feat: add PiD support (#1585)

This commit is contained in:
leejet 2026-05-31 22:38:39 +08:00 committed by GitHub
parent d2797b8667
commit 0982807139
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 1387 additions and 23 deletions

View File

@ -15,6 +15,7 @@ API and command-line option may change frequently.***
## 🔥Important News ## 🔥Important News
* **2026/05/31** 🚀 stable-diffusion.cpp now supports **PiD**
* **2026/05/27** 🚀 stable-diffusion.cpp now supports **Lens** * **2026/05/27** 🚀 stable-diffusion.cpp now supports **Lens**
* **2026/05/17** 🚀 stable-diffusion.cpp now supports **LTX-2.3** * **2026/05/17** 🚀 stable-diffusion.cpp now supports **LTX-2.3**
* **2026/04/11** 🚀 stable-diffusion.cpp now uses a brand-new embedded web UI. * **2026/04/11** 🚀 stable-diffusion.cpp now uses a brand-new embedded web UI.
@ -42,6 +43,7 @@ API and command-line option may change frequently.***
- [Chroma](./docs/chroma.md) - [Chroma](./docs/chroma.md)
- [Chroma1-Radiance](./docs/chroma_radiance.md) - [Chroma1-Radiance](./docs/chroma_radiance.md)
- [Qwen Image](./docs/qwen_image.md) - [Qwen Image](./docs/qwen_image.md)
- [PiD](./docs/pid.md)
- [LongCat Image](./docs/longcat_image.md) - [LongCat Image](./docs/longcat_image.md)
- [Z-Image](./docs/z_image.md) - [Z-Image](./docs/z_image.md)
- [Ovis-Image](./docs/ovis_image.md) - [Ovis-Image](./docs/ovis_image.md)

BIN
assets/pid/example.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 9.0 MiB

39
docs/pid.md Normal file
View File

@ -0,0 +1,39 @@
# How to Use
PiD is NVIDIA's Pixel Diffusion Decoder. It replaces the usual VAE decode or decode-then-upscale path with a pixel-space diffusion decoder conditioned on a
source latent and text prompt.
In stable-diffusion.cpp, PiD currently runs as an image edit pipeline: you provide a reference image with `-r`/`--ref-image`, the image is encoded with the matching VAE, and the PiD diffusion model then decodes/upscales the resulting latent directly to RGB.
## Download weights
- Download PiD
- safetensors: https://huggingface.co/Comfy-Org/PixelDiT/tree/main/diffusion_models
- Download Gemma 2 2B
- safetensors: https://huggingface.co/Comfy-Org/PixelDiT/tree/main/text_encoders
- Download the VAE that matches the PiD checkpoint backbone
- safetensors: https://huggingface.co/nvidia/PiD/tree/main/checkpoints
- Flux / Z-Image PiD: use the Flux VAE and pass `--vae-format flux`
- SD3 PiD: use the SD3 VAE and pass `--vae-format sd3`
- Flux.2 PiD: use the Flux.2 VAE and pass `--vae-format flux2`
Check the official PiD model card before use. At the time of the initial PiD release, the official weights were published under the NSCLv1 non-commercial license.
## Examples
```
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\pid_flux1_512_to_2048_4step_bf16.safetensors --llm "..\..\ComfyUI\models\text_encoders\gemma_2_2b_it_elm_bf16.safetensors" --vae ..\..\ComfyUI\models\vae\ae.sft --vae-format flux --cfg-scale 1.0 -p "a lovely cat" -r ..\assets\ernie_image\turbo_example.png --diffusion-fa -v --steps 4 -H 2048 -W 2048 --rng cpu
```
Before:
<img width="256" alt="ERNIE-Image Turbo example" src="../assets/ernie_image/turbo_example.png" />
After:
<img width="1024" alt="PiD example" src="../assets/pid/example.png" />
## Notes
- `-r`/`--ref-image` is required. PiD uses the first reference image as the source latent condition.
- `--vae-format` should match the VAE latent layout used by the PiD checkpoint. This is important when using standalone VAE files because the PiD diffusion
checkpoint alone does not identify the VAE format.

View File

@ -35,6 +35,22 @@ const char* const modes_str[] = {
"metadata", "metadata",
}; };
// Translates the textual --vae-format value into its sd_vae_format_t enumerator.
// Recognized names are "auto", "flux", "sd3" and "flux2" (exact, case-sensitive match).
// Any other string yields SD_VAE_FORMAT_COUNT, which callers treat as "invalid".
static sd_vae_format_t str_to_vae_format(const std::string& value) {
    struct NamedFormat {
        const char* name;
        sd_vae_format_t format;
    };
    static const NamedFormat known_formats[] = {
        {"auto", SD_VAE_FORMAT_AUTO},
        {"flux", SD_VAE_FORMAT_FLUX},
        {"sd3", SD_VAE_FORMAT_SD3},
        {"flux2", SD_VAE_FORMAT_FLUX2},
    };
    for (const NamedFormat& entry : known_formats) {
        if (value == entry.name) {
            return entry.format;
        }
    }
    return SD_VAE_FORMAT_COUNT;
}
#if defined(_WIN32) #if defined(_WIN32)
static std::string utf16_to_utf8(const std::wstring& wstr) { static std::string utf16_to_utf8(const std::wstring& wstr) {
if (wstr.empty()) if (wstr.empty())
@ -348,6 +364,10 @@ ArgOptions SDContextParams::get_options() {
"--vae", "--vae",
"path to standalone vae model", "path to standalone vae model",
&vae_path}, &vae_path},
{"",
"--vae-format",
"VAE latent format override: auto, flux, sd3, or flux2 (default: auto)",
&vae_format},
{"", {"",
"--audio-vae", "--audio-vae",
"path to standalone LTX audio vae model", "path to standalone LTX audio vae model",
@ -639,6 +659,11 @@ bool SDContextParams::validate(SDMode mode) {
} }
} }
if (str_to_vae_format(vae_format) == SD_VAE_FORMAT_COUNT) {
LOG_ERROR("error: vae_format must be 'auto', 'flux', 'sd3', or 'flux2'");
return false;
}
return true; return true;
} }
@ -679,6 +704,7 @@ std::string SDContextParams::to_string() const {
<< " high_noise_diffusion_model_path: \"" << high_noise_diffusion_model_path << "\",\n" << " high_noise_diffusion_model_path: \"" << high_noise_diffusion_model_path << "\",\n"
<< " embeddings_connectors_path: \"" << embeddings_connectors_path << "\",\n" << " embeddings_connectors_path: \"" << embeddings_connectors_path << "\",\n"
<< " vae_path: \"" << vae_path << "\",\n" << " vae_path: \"" << vae_path << "\",\n"
<< " vae_format: \"" << vae_format << "\",\n"
<< " audio_vae_path: \"" << audio_vae_path << "\",\n" << " audio_vae_path: \"" << audio_vae_path << "\",\n"
<< " taesd_path: \"" << taesd_path << "\",\n" << " taesd_path: \"" << taesd_path << "\",\n"
<< " esrgan_path: \"" << esrgan_path << "\",\n" << " esrgan_path: \"" << esrgan_path << "\",\n"
@ -772,6 +798,7 @@ sd_ctx_params_t SDContextParams::to_sd_ctx_params_t(bool vae_decode_only, bool f
chroma_use_t5_mask, chroma_use_t5_mask,
chroma_t5_mask_pad, chroma_t5_mask_pad,
qwen_image_zero_cond_t, qwen_image_zero_cond_t,
str_to_vae_format(vae_format),
max_vram, max_vram,
backend.c_str(), backend.c_str(),
params_backend.c_str(), params_backend.c_str(),

View File

@ -94,6 +94,7 @@ struct SDContextParams {
std::string high_noise_diffusion_model_path; std::string high_noise_diffusion_model_path;
std::string embeddings_connectors_path; std::string embeddings_connectors_path;
std::string vae_path; std::string vae_path;
std::string vae_format = "auto";
std::string audio_vae_path; std::string audio_vae_path;
std::string taesd_path; std::string taesd_path;
std::string esrgan_path; std::string esrgan_path;

View File

@ -168,6 +168,14 @@ typedef struct {
const char* path; const char* path;
} sd_embedding_t; } sd_embedding_t;
// VAE latent format selector.
// SD_VAE_FORMAT_AUTO (-1) means "no explicit override"; the remaining named
// entries are the concrete formats, numbered from 0. SD_VAE_FORMAT_COUNT is the
// number of concrete formats and is not itself a valid selection.
enum sd_vae_format_t {
    SD_VAE_FORMAT_AUTO  = -1,
    SD_VAE_FORMAT_FLUX  = 0,
    SD_VAE_FORMAT_SD3   = 1,
    SD_VAE_FORMAT_FLUX2 = 2,
    SD_VAE_FORMAT_COUNT = 3,
};
typedef struct { typedef struct {
const char* model_path; const char* model_path;
const char* clip_l_path; const char* clip_l_path;
@ -212,6 +220,7 @@ typedef struct {
bool chroma_use_t5_mask; bool chroma_use_t5_mask;
int chroma_t5_mask_pad; int chroma_t5_mask_pad;
bool qwen_image_zero_cond_t; bool qwen_image_zero_cond_t;
enum sd_vae_format_t vae_format;
float max_vram; // GiB budget for graph-cut segmented param offload (0 = disabled, -1 = auto free VRAM minus 1 GiB) float max_vram; // GiB budget for graph-cut segmented param offload (0 = disabled, -1 = auto free VRAM minus 1 GiB)
const char* backend; const char* backend;
const char* params_backend; const char* params_backend;

View File

@ -1171,7 +1171,6 @@ struct FluxCLIPEmbedder : public Conditioner {
return true; return true;
} }
void free_params_buffer() override { void free_params_buffer() override {
if (clip_l) { if (clip_l) {
clip_l->free_params_buffer(); clip_l->free_params_buffer();
@ -1601,8 +1600,8 @@ struct AnimaConditioner : public Conditioner {
bool alloc_params_buffer() override { bool alloc_params_buffer() override {
if (!llm->alloc_params_buffer()) { if (!llm->alloc_params_buffer()) {
return false; return false;
} }
return true; return true;
} }
@ -1719,6 +1718,8 @@ struct LLMEmbedder : public Conditioner {
arch = LLM::LLMArch::MINISTRAL_3_3B; arch = LLM::LLMArch::MINISTRAL_3_3B;
} else if (sd_version_is_lens(version)) { } else if (sd_version_is_lens(version)) {
arch = LLM::LLMArch::GPT_OSS_20B; arch = LLM::LLMArch::GPT_OSS_20B;
} else if (sd_version_is_pid(version)) {
arch = LLM::LLMArch::GEMMA2_2B;
} else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE || version == VERSION_FLUX2_KLEIN) { } else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE || version == VERSION_FLUX2_KLEIN) {
arch = LLM::LLMArch::QWEN3; arch = LLM::LLMArch::QWEN3;
} }
@ -1726,6 +1727,8 @@ struct LLMEmbedder : public Conditioner {
tokenizer = std::make_shared<MistralTokenizer>(); tokenizer = std::make_shared<MistralTokenizer>();
} else if (arch == LLM::LLMArch::GPT_OSS_20B) { } else if (arch == LLM::LLMArch::GPT_OSS_20B) {
tokenizer = std::make_shared<GPTOSSTokenizer>(); tokenizer = std::make_shared<GPTOSSTokenizer>();
} else if (arch == LLM::LLMArch::GEMMA2_2B) {
tokenizer = std::make_shared<Gemma2Tokenizer>();
} else { } else {
tokenizer = std::make_shared<Qwen2Tokenizer>(); tokenizer = std::make_shared<Qwen2Tokenizer>();
} }
@ -1743,7 +1746,7 @@ struct LLMEmbedder : public Conditioner {
bool alloc_params_buffer() override { bool alloc_params_buffer() override {
if (!llm->alloc_params_buffer()) { if (!llm->alloc_params_buffer()) {
return false; return false;
} }
return true; return true;
} }
@ -1847,12 +1850,16 @@ struct LLMEmbedder : public Conditioner {
sd::Tensor<int32_t> input_ids({static_cast<int64_t>(tokens.size())}, tokens); sd::Tensor<int32_t> input_ids({static_cast<int64_t>(tokens.size())}, tokens);
sd::Tensor<float> attention_mask; sd::Tensor<float> attention_mask;
if (!mask.empty()) { if (!mask.empty()) {
attention_mask = sd::Tensor<float>({static_cast<int64_t>(mask.size()), static_cast<int64_t>(mask.size())}); attention_mask = sd::Tensor<float>({static_cast<int64_t>(mask.size()), static_cast<int64_t>(mask.size())});
const float masked_attention_value = -std::numeric_limits<float>::max() / 4.0f;
for (size_t i1 = 0; i1 < mask.size(); ++i1) { for (size_t i1 = 0; i1 < mask.size(); ++i1) {
for (size_t i0 = 0; i0 < mask.size(); ++i0) { for (size_t i0 = 0; i0 < mask.size(); ++i0) {
float value = 0.0f; float value = 0.0f;
if (mask[i0] == 0.0f || i0 > i1) { if (mask[i0] == 0.0f) {
value = -INFINITY; value += masked_attention_value;
}
if (i0 > i1) {
value += masked_attention_value;
} }
attention_mask[static_cast<int64_t>(i0 + mask.size() * i1)] = value; attention_mask[static_cast<int64_t>(i0 + mask.size() * i1)] = value;
} }
@ -2126,6 +2133,53 @@ struct LLMEmbedder : public Conditioner {
prompt_attn_range.second = static_cast<int>(prompt.size()); prompt_attn_range.second = static_cast<int>(prompt.size());
prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"; prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
} else if (sd_version_is_pid(version)) {
constexpr int pixeldit_max_length = 300;
const std::string chi_prompt =
"Given a user prompt, generate an \"Enhanced prompt\" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:\n"
"- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.\n"
"- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n"
"Here are examples of how to transform or refine prompts:\n"
"- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.\n"
"- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n"
"Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:\n"
"User Prompt: ";
auto chi_tokens = std::get<0>(tokenize(chi_prompt, {0, 0}));
size_t num_chi_tokens = chi_tokens.size();
max_length = (int)num_chi_tokens + pixeldit_max_length - 2;
min_length = max_length;
prompt_attn_range.first = static_cast<int>(prompt.size());
prompt += " " + conditioner_params.text;
prompt_attn_range.second = static_cast<int>(prompt.size());
auto hidden_states = encode_prompt(n_threads,
prompt,
prompt_attn_range,
min_length,
0,
image_embeds,
out_layers,
0,
false,
max_length);
GGML_ASSERT(!hidden_states.empty());
if (hidden_states.shape()[1] > pixeldit_max_length) {
auto bos = sd::ops::slice(hidden_states, 1, 0, 1);
auto tail = sd::ops::slice(hidden_states,
1,
hidden_states.shape()[1] - (pixeldit_max_length - 1),
hidden_states.shape()[1]);
hidden_states = sd::ops::concat(bos, tail, 1);
}
int64_t t1 = ggml_time_ms();
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
SDCondition result;
result.c_crossattn = std::move(hidden_states);
return result;
} else { } else {
GGML_ABORT("unknown version %d", version); GGML_ABORT("unknown version %d", version);
} }
@ -2268,10 +2322,10 @@ struct LTXAVEmbedder : public Conditioner {
bool alloc_params_buffer() override { bool alloc_params_buffer() override {
if (!llm->alloc_params_buffer()) { if (!llm->alloc_params_buffer()) {
return false; return false;
} }
if (!projector->alloc_params_buffer()) { if (!projector->alloc_params_buffer()) {
return false; return false;
} }
return true; return true;
} }

View File

@ -37,6 +37,7 @@ namespace LLM {
MISTRAL_SMALL_3_2, MISTRAL_SMALL_3_2,
MINISTRAL_3_3B, MINISTRAL_3_3B,
GEMMA3_12B, GEMMA3_12B,
GEMMA2_2B,
GPT_OSS_20B, GPT_OSS_20B,
ARCH_COUNT, ARCH_COUNT,
}; };
@ -48,6 +49,7 @@ namespace LLM {
"mistral_small3.2", "mistral_small3.2",
"ministral3.3b", "ministral3.3b",
"gemma3_12b", "gemma3_12b",
"gemma2_2b",
"gpt_oss_20b", "gpt_oss_20b",
}; };
@ -900,6 +902,33 @@ namespace LLM {
1.f, 1.f,
32.f, 32.f,
1.f); 1.f);
} else if (arch == LLMArch::GEMMA2_2B) {
q = ggml_rope_ext(ctx->ggml_ctx,
q,
input_pos,
nullptr,
head_dim,
GGML_ROPE_TYPE_NEOX,
8192,
10000.f,
1.f,
0.f,
1.f,
32.f,
1.f);
k = ggml_rope_ext(ctx->ggml_ctx,
k,
input_pos,
nullptr,
head_dim,
GGML_ROPE_TYPE_NEOX,
8192,
10000.f,
1.f,
0.f,
1.f,
32.f,
1.f);
} else if (arch == LLMArch::QWEN3_VL) { } else if (arch == LLMArch::QWEN3_VL) {
int sections[4] = {24, 20, 20, 0}; int sections[4] = {24, 20, 20, 0};
q = ggml_rope_multi(ctx->ggml_ctx, q, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_IMROPE, 262144, 5000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); q = ggml_rope_multi(ctx->ggml_ctx, q, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_IMROPE, 262144, 5000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
@ -957,10 +986,18 @@ namespace LLM {
: arch(params.arch), : arch(params.arch),
sliding_attention(0) { sliding_attention(0) {
if (params.arch == LLMArch::GEMMA3_12B) { if (params.arch == LLMArch::GEMMA3_12B) {
post_attention_norm_name = "post_attention_norm"; post_attention_norm_name = "post_attention_norm"; // attn_post_norm
post_ffw_norm_name = "post_ffw_norm"; pre_ffw_norm_name = "post_attention_layernorm"; // ffn_norm
post_ffw_norm_name = "post_ffw_norm"; // ffn_post_norm
} else if (params.arch == LLMArch::GEMMA2_2B) {
post_attention_norm_name = "post_attention_layernorm"; // ffn_norm
pre_ffw_norm_name = "pre_feedforward_layernorm";
post_ffw_norm_name = "post_feedforward_layernorm";
} else if (params.arch == LLMArch::GPT_OSS_20B) {
pre_ffw_norm_name = "post_attention_norm"; // attn_post_norm
} else {
pre_ffw_norm_name = "post_attention_layernorm"; // ffn_norm
} }
pre_ffw_norm_name = params.arch == LLMArch::GPT_OSS_20B ? "post_attention_norm" : "post_attention_layernorm";
blocks["self_attn"] = std::make_shared<Attention>(params); blocks["self_attn"] = std::make_shared<Attention>(params);
if (params.arch == LLMArch::GPT_OSS_20B) { if (params.arch == LLMArch::GPT_OSS_20B) {
@ -1447,6 +1484,21 @@ namespace LLM {
params.rope_thetas = {1000000.f, 10000.f}; params.rope_thetas = {1000000.f, 10000.f};
params.rope_scales = {8.f, 1.f}; params.rope_scales = {8.f, 1.f};
params.sliding_attention = {1024, 1024, 1024, 1024, 1024, 0}; params.sliding_attention = {1024, 1024, 1024, 1024, 1024, 0};
} else if (arch == LLMArch::GEMMA2_2B) {
params.head_dim = 256;
params.num_heads = 8;
params.num_kv_heads = 4;
params.qkv_bias = false;
params.qk_norm = false;
params.rms_norm_eps = 1e-6f;
params.rms_norm_add = true;
params.normalize_input = true;
params.max_position_embeddings = 8192;
params.mlp_activation = MLPActivation::GELU_TANH;
params.hidden_size = 2304;
params.intermediate_size = 9216;
params.num_layers = 26;
params.vocab_size = 256000;
} else if (arch == LLMArch::GPT_OSS_20B) { } else if (arch == LLMArch::GPT_OSS_20B) {
params.head_dim = 64; params.head_dim = 64;
params.num_heads = 64; params.num_heads = 64;
@ -1585,6 +1637,7 @@ namespace LLM {
params.arch == LLMArch::MINISTRAL_3_3B || params.arch == LLMArch::MINISTRAL_3_3B ||
params.arch == LLMArch::QWEN3 || params.arch == LLMArch::QWEN3 ||
params.arch == LLMArch::GEMMA3_12B || params.arch == LLMArch::GEMMA3_12B ||
params.arch == LLMArch::GEMMA2_2B ||
params.arch == LLMArch::GPT_OSS_20B) { params.arch == LLMArch::GPT_OSS_20B) {
input_pos_vec.resize(n_tokens); input_pos_vec.resize(n_tokens);
for (int i = 0; i < n_tokens; ++i) { for (int i = 0; i < n_tokens; ++i) {

View File

@ -91,7 +91,6 @@ struct LoraModel : public GGMLRunner {
return false; return false;
} }
dry_run = false; dry_run = false;
model_loader.load_tensors(on_new_tensor_cb, n_threads); model_loader.load_tensors(on_new_tensor_cb, n_threads);

View File

@ -1069,8 +1069,8 @@ namespace LTXV {
prefix); prefix);
if (!ltx_audio_vae->alloc_params_buffer()) { if (!ltx_audio_vae->alloc_params_buffer()) {
LOG_ERROR("ltx audio vae buffer allocation failed"); LOG_ERROR("ltx audio vae buffer allocation failed");
return; return;
} }
std::map<std::string, ggml_tensor*> tensors; std::map<std::string, ggml_tensor*> tensors;

View File

@ -432,6 +432,9 @@ SDVersion ModelLoader::get_sd_version() {
tensor_storage.name.find("model.diffusion_model.single_transformer_blocks.") != std::string::npos) { tensor_storage.name.find("model.diffusion_model.single_transformer_blocks.") != std::string::npos) {
is_flux = true; is_flux = true;
} }
if (tensor_storage.name.find("model.diffusion_model.net.lq_proj.latent_proj.0.weight") != std::string::npos) {
return VERSION_PID;
}
if (tensor_storage.name.find("model.diffusion_model.nerf_final_layer_conv.") != std::string::npos) { if (tensor_storage.name.find("model.diffusion_model.nerf_final_layer_conv.") != std::string::npos) {
return VERSION_CHROMA_RADIANCE; return VERSION_CHROMA_RADIANCE;
} }

View File

@ -49,6 +49,7 @@ enum SDVersion {
VERSION_ERNIE_IMAGE, VERSION_ERNIE_IMAGE,
VERSION_LENS, VERSION_LENS,
VERSION_LONGCAT, VERSION_LONGCAT,
VERSION_PID,
VERSION_COUNT, VERSION_COUNT,
}; };
@ -164,6 +165,13 @@ static inline bool sd_version_is_lens(SDVersion version) {
return false; return false;
} }
// True only when `version` is the PiD model version.
static inline bool sd_version_is_pid(SDVersion version) {
    return version == VERSION_PID;
}
static inline bool sd_version_uses_flux2_vae(SDVersion version) { static inline bool sd_version_uses_flux2_vae(SDVersion version) {
if (sd_version_is_flux2(version) || sd_version_is_ernie_image(version) || sd_version_is_lens(version)) { if (sd_version_is_flux2(version) || sd_version_is_ernie_image(version) || sd_version_is_lens(version)) {
return true; return true;
@ -194,7 +202,8 @@ static inline bool sd_version_is_dit(SDVersion version) {
sd_version_is_z_image(version) || sd_version_is_z_image(version) ||
sd_version_is_ernie_image(version) || sd_version_is_ernie_image(version) ||
sd_version_is_lens(version) || sd_version_is_lens(version) ||
sd_version_is_longcat(version)) { sd_version_is_longcat(version) ||
sd_version_is_pid(version)) {
return true; return true;
} }
return false; return false;

842
src/pid.hpp Normal file
View File

@ -0,0 +1,842 @@
#ifndef __SD_PID_HPP__
#define __SD_PID_HPP__
#include <cmath>
#include <cstdlib>
#include <memory>
#include <string>
#include <vector>
#include "common_dit.hpp"
#include "ggml_extend.hpp"
#include "mmdit.hpp"
#include "rope.hpp"
namespace Pid {
// NOTE(review): presumably the ggml compute-graph size budget used by the PiD
// runner; the use site is outside this view, so confirm before relying on it.
constexpr int PID_GRAPH_SIZE = 196608;
// Pi as a single-precision constant.
constexpr float PID_PI = 3.14159265358979323846f;
// Default hyper-parameters for the PixelDiT network used by PiD.
// The grouping comments below follow the field-name prefixes; the exact role of
// each value is determined by code outside this struct (TODO confirm there).
struct PixelDiTParams {
    // Patch-level transformer.
    int64_t in_channels{3};
    int64_t hidden_size{1536};
    int64_t num_groups{24};
    int64_t patch_mlp_hidden_dim{4096};

    // Pixel-level pathway.
    int64_t pixel_hidden_size{16};
    int64_t pixel_attn_hidden_size{1152};
    int64_t pixel_num_groups{16};

    // Depths and patch geometry.
    int64_t patch_depth{14};
    int64_t pixel_depth{2};
    int64_t patch_size{16};

    // Text conditioning.
    int64_t txt_embed_dim{2304};
    int64_t txt_max_length{300};
    float text_rope_theta{10000.f};

    // "lq" conditioning branch (presumably the low-quality / source-latent input).
    int64_t lq_latent_channels{16};
    int64_t lq_hidden_dim{512};
    int64_t lq_num_res_blocks{4};
    int64_t lq_interval{2};
    int64_t lq_sr_scale{4};
    int64_t lq_latent_down_factor{8};

    // Reference grid for rotary position embeddings.
    int64_t rope_ref_grid_h{64};
    int64_t rope_ref_grid_w{64};
};
// Builds a flattened 1-D rotary table for positions 0 .. length-1.
// `dim` must be even. The table layout is whatever Rope::rope / Rope::flatten
// produce (not visible here).
inline std::vector<float> make_rope_1d(int length, int dim, float theta) {
    GGML_ASSERT(dim % 2 == 0);
    const auto positions = Rope::linspace(0.f, static_cast<float>(length - 1), length);
    const auto table     = Rope::rope(positions, dim, theta);
    return Rope::flatten(table);
}
// Builds a 2-D rotary table for a height x width grid by forwarding every
// argument to Rope::embed_2d_interleaved. `dim` must be a multiple of 4.
inline std::vector<float> make_rope_2d(int height, int width, int dim,
                                       float theta = 10000.f, float scale = 16.f,
                                       int ref_grid_h = 0, int ref_grid_w = 0) {
    GGML_ASSERT(dim % 4 == 0);
    std::vector<float> table = Rope::embed_2d_interleaved(height, width, dim, theta, scale, ref_grid_h, ref_grid_w);
    return table;
}
// Builds an absolute position embedding for every pixel of a height x width
// grid, visited row by row. For each pixel the output holds `dim` floats: the
// first dim/2 come from timestep_embedding() of the x coordinate, the last
// dim/2 from timestep_embedding() of the y coordinate. `dim` must be a
// multiple of 4.
// Assumes timestep_embedding() returns half_dim contiguous floats per input
// position (the indexing below relies on it) — TODO confirm.
inline std::vector<float> make_pixel_abs_pos(int height, int width, int dim) {
    GGML_ASSERT(dim % 4 == 0);
    const int half_dim = dim / 2;

    std::vector<float> xs;
    std::vector<float> ys;
    xs.reserve(static_cast<size_t>(height) * width);
    ys.reserve(static_cast<size_t>(height) * width);
    for (int row = 0; row < height; ++row) {
        const float y = static_cast<float>(row);
        for (int col = 0; col < width; ++col) {
            xs.push_back(static_cast<float>(col));
            ys.push_back(y);
        }
    }

    const auto emb_x = timestep_embedding(xs, half_dim, 10000, false);
    const auto emb_y = timestep_embedding(ys, half_dim, 10000, false);

    std::vector<float> result(static_cast<size_t>(dim) * height * width);
    const int num_positions = height * width;
    for (int p = 0; p < num_positions; ++p) {
        const size_t src_offset = static_cast<size_t>(p) * half_dim;
        const float* src_x      = emb_x.data() + src_offset;
        const float* src_y      = emb_y.data() + src_offset;
        float* dst              = result.data() + static_cast<size_t>(p) * dim;
        for (int c = 0; c < half_dim; ++c) {
            dst[c]            = src_x[c];
            dst[half_dim + c] = src_y[c];
        }
    }
    return result;
}
// Adaptive-norm modulation: returns (x + x * scale) + shift as new graph nodes.
// Note the argument order: shift comes before scale.
inline ggml_tensor* apply_adaln(ggml_context* ctx, ggml_tensor* x, ggml_tensor* shift, ggml_tensor* scale) {
    ggml_tensor* scaled    = ggml_mul(ctx, x, scale);
    ggml_tensor* modulated = ggml_add(ctx, x, scaled);
    return ggml_add(ctx, modulated, shift);
}
// Linear projection of patch tokens ("proj"), optionally followed by an
// RMSNorm ("norm", eps 1e-6) over the embedding dimension.
struct PatchTokenEmbedder : public GGMLBlock {
    bool use_rms_norm;

    // in_chans     -> input feature size of the projection
    // embed_dim    -> output feature size (and RMSNorm size when enabled)
    // use_rms_norm -> register and apply the "norm" sub-block
    // bias         -> forwarded to the Linear constructor
    PatchTokenEmbedder(int64_t in_chans, int64_t embed_dim, bool use_rms_norm = false, bool bias = true)
        : use_rms_norm(use_rms_norm) {
        blocks["proj"] = std::make_shared<Linear>(in_chans, embed_dim, bias);
        if (this->use_rms_norm) {
            blocks["norm"] = std::make_shared<RMSNorm>(embed_dim, 1e-6f);
        }
    }

    // Projects x and, when configured, normalizes the projection.
    ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
        auto projection = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
        auto projected  = projection->forward(ctx, x);
        if (!use_rms_norm) {
            return projected;
        }
        auto rms = std::dynamic_pointer_cast<RMSNorm>(blocks["norm"]);
        return rms->forward(ctx, projected);
    }
};
// Timestep embedder: sinusoidal features -> Linear ("mlp.0") -> SiLU -> Linear ("mlp.2").
struct PixelDiTTimestepEmbedder : public GGMLBlock {
    int frequency_embedding_size;

    // Both Linear layers are built with their last two constructor flags set to
    // true (presumably bias and force-f32 — confirm against Linear).
    PixelDiTTimestepEmbedder(int64_t hidden_size, int frequency_embedding_size = 256)
        : frequency_embedding_size(frequency_embedding_size) {
        blocks["mlp.0"] = std::make_shared<Linear>(frequency_embedding_size, hidden_size, true, true);
        blocks["mlp.2"] = std::make_shared<Linear>(hidden_size, hidden_size, true, true);
    }

    // Embeds timestep tensor t into a hidden_size vector.
    ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* t) {
        auto fc_in  = std::dynamic_pointer_cast<Linear>(blocks["mlp.0"]);
        auto fc_out = std::dynamic_pointer_cast<Linear>(blocks["mlp.2"]);

        // The trailing 10 is passed straight to ggml_ext_timestep_embedding
        // (presumably its max_period — confirm against the helper).
        auto emb = ggml_ext_timestep_embedding(ctx->ggml_ctx, t, frequency_embedding_size, 10);
        emb      = fc_in->forward(ctx, emb);
        emb      = ggml_silu_inplace(ctx->ggml_ctx, emb);
        return fc_out->forward(ctx, emb);
    }
};
// Gated feed-forward block without biases: w2( silu(w1(x)) * w3(x) ).
struct FeedForward : public GGMLBlock {
    FeedForward(int64_t dim, int64_t hidden_dim) {
        blocks["w1"] = std::make_shared<Linear>(dim, hidden_dim, false);
        blocks["w2"] = std::make_shared<Linear>(hidden_dim, dim, false);
        blocks["w3"] = std::make_shared<Linear>(dim, hidden_dim, false);
    }

    ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
        auto gate_proj = std::dynamic_pointer_cast<Linear>(blocks["w1"]);
        auto down_proj = std::dynamic_pointer_cast<Linear>(blocks["w2"]);
        auto up_proj   = std::dynamic_pointer_cast<Linear>(blocks["w3"]);

        // Gate branch is activated first, then multiplied in place by the up branch.
        auto gate  = gate_proj->forward(ctx, x);
        gate       = ggml_silu_inplace(ctx->ggml_ctx, gate);
        auto up    = up_proj->forward(ctx, x);
        auto gated = ggml_mul_inplace(ctx->ggml_ctx, gate, up);
        return down_proj->forward(ctx, gated);
    }
};
// Output head: RMSNorm ("norm", eps 1e-6) followed by a biased Linear ("linear").
struct FinalLayer : public GGMLBlock {
    FinalLayer(int64_t hidden_size, int64_t out_channels) {
        blocks["norm"]   = std::make_shared<RMSNorm>(hidden_size, 1e-6f);
        blocks["linear"] = std::make_shared<Linear>(hidden_size, out_channels, true);
    }

    ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
        auto rms    = std::dynamic_pointer_cast<RMSNorm>(blocks["norm"]);
        auto head   = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
        auto normed = rms->forward(ctx, x);
        return head->forward(ctx, normed);
    }
};
struct RotaryAttention : public GGMLBlock {
int64_t dim;
int64_t num_heads;
RotaryAttention(int64_t dim, int64_t num_heads)
: dim(dim), num_heads(num_heads) {
int64_t head_dim = dim / num_heads;
blocks["qkv"] = std::make_shared<Linear>(dim, dim * 3, false);
blocks["q_norm"] = std::make_shared<RMSNorm>(head_dim, 1e-6f);
blocks["k_norm"] = std::make_shared<RMSNorm>(head_dim, 1e-6f);
blocks["proj"] = std::make_shared<Linear>(dim, dim, true);
}
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* pos) {
auto qkv_proj = std::dynamic_pointer_cast<Linear>(blocks["qkv"]);
auto q_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["q_norm"]);
auto k_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["k_norm"]);
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
auto qkv = qkv_proj->forward(ctx, x);
auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv);
int64_t L = x->ne[1];
int64_t N = x->ne[2];
int64_t head_dim = dim / num_heads;
auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, L, N);
auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, L, N);
auto v = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[2], head_dim, num_heads, L, N);
q = q_norm->forward(ctx, q);
k = k_norm->forward(ctx, k);
x = Rope::attention(ctx, q, k, v, pos, nullptr, 1.0f / 128.f, true);
return proj->forward(ctx, x);
}
};
// Joint attention over two token streams, x and y. Each stream has its own
// bias-free fused QKV projection, its own per-head RMSNorm for Q and K, and its
// own biased output projection; attention itself runs once over the
// concatenation of both streams.
struct MMDiTJointAttention : public GGMLBlock {
int64_t dim;
int64_t num_heads;
// Registers the per-stream sub-blocks; the Q/K norms operate on head_dim = dim / num_heads.
MMDiTJointAttention(int64_t dim, int64_t num_heads)
: dim(dim), num_heads(num_heads) {
int64_t head_dim = dim / num_heads;
blocks["qkv_x"] = std::make_shared<Linear>(dim, dim * 3, false);
blocks["qkv_y"] = std::make_shared<Linear>(dim, dim * 3, false);
blocks["q_norm_x"] = std::make_shared<RMSNorm>(head_dim, 1e-6f);
blocks["k_norm_x"] = std::make_shared<RMSNorm>(head_dim, 1e-6f);
blocks["q_norm_y"] = std::make_shared<RMSNorm>(head_dim, 1e-6f);
blocks["k_norm_y"] = std::make_shared<RMSNorm>(head_dim, 1e-6f);
blocks["proj_x"] = std::make_shared<Linear>(dim, dim, true);
blocks["proj_y"] = std::make_shared<Linear>(dim, dim, true);
}
// Returns {projected x output, projected y output}.
// pos_img / pos_txt are the rotary position inputs for x and y respectively
// (parameter names suggest x = image tokens, y = text tokens).
std::pair<ggml_tensor*, ggml_tensor*> forward(GGMLRunnerContext* ctx,
ggml_tensor* x,
ggml_tensor* y,
ggml_tensor* pos_img,
ggml_tensor* pos_txt) {
auto qkv_x_proj = std::dynamic_pointer_cast<Linear>(blocks["qkv_x"]);
auto qkv_y_proj = std::dynamic_pointer_cast<Linear>(blocks["qkv_y"]);
auto q_norm_x = std::dynamic_pointer_cast<RMSNorm>(blocks["q_norm_x"]);
auto k_norm_x = std::dynamic_pointer_cast<RMSNorm>(blocks["k_norm_x"]);
auto q_norm_y = std::dynamic_pointer_cast<RMSNorm>(blocks["q_norm_y"]);
auto k_norm_y = std::dynamic_pointer_cast<RMSNorm>(blocks["k_norm_y"]);
auto proj_x = std::dynamic_pointer_cast<Linear>(blocks["proj_x"]);
auto proj_y = std::dynamic_pointer_cast<Linear>(blocks["proj_y"]);
// Token counts of each stream (ne[1]); the batch extent is taken from x only (ne[2]).
int64_t Nx = x->ne[1];
int64_t Ny = y->ne[1];
int64_t N = x->ne[2];
int64_t head_dim = dim / num_heads;
// x stream: project, split into Q/K/V, reshape to (head_dim, heads, tokens, batch), normalize Q and K.
auto qkv_x = split_qkv(ctx->ggml_ctx, qkv_x_proj->forward(ctx, x));
auto qx = ggml_reshape_4d(ctx->ggml_ctx, qkv_x[0], head_dim, num_heads, Nx, N);
auto kx = ggml_reshape_4d(ctx->ggml_ctx, qkv_x[1], head_dim, num_heads, Nx, N);
auto vx = ggml_reshape_4d(ctx->ggml_ctx, qkv_x[2], head_dim, num_heads, Nx, N);
qx = q_norm_x->forward(ctx, qx);
kx = k_norm_x->forward(ctx, kx);
// y stream: same steps with the y-specific weights.
auto qkv_y = split_qkv(ctx->ggml_ctx, qkv_y_proj->forward(ctx, y));
auto qy = ggml_reshape_4d(ctx->ggml_ctx, qkv_y[0], head_dim, num_heads, Ny, N);
auto ky = ggml_reshape_4d(ctx->ggml_ctx, qkv_y[1], head_dim, num_heads, Ny, N);
auto vy = ggml_reshape_4d(ctx->ggml_ctx, qkv_y[2], head_dim, num_heads, Ny, N);
qy = q_norm_y->forward(ctx, qy);
ky = k_norm_y->forward(ctx, ky);
// Concatenate along the token axis with the y tokens FIRST, then the x tokens.
// The position inputs are concatenated in the same y-then-x order so they stay aligned.
auto q_joint = ggml_concat(ctx->ggml_ctx, qy, qx, 2);
auto k_joint = ggml_concat(ctx->ggml_ctx, ky, kx, 2);
auto v_joint = ggml_concat(ctx->ggml_ctx, vy, vx, 2);
auto pos_joint = ggml_concat(ctx->ggml_ctx, pos_txt, pos_img, 3);
// No mask; 1.0f and `true` are forwarded unchanged to Rope::attention
// (presumably a K/V scale and the interleaved-rope flag — confirm there).
auto out = Rope::attention(ctx, q_joint, k_joint, v_joint, pos_joint, nullptr, 1.0f, true);
// Split the joint result back: the first Ny tokens belong to y, the next Nx to x.
auto out_y = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, Ny);
auto out_x = ggml_ext_slice(ctx->ggml_ctx, out, 1, Ny, Ny + Nx);
return {proj_x->forward(ctx, out_x), proj_y->forward(ctx, out_y)};
}
};
// Two-stream transformer block: joint attention followed by per-stream gated
// feed-forward layers. Both sub-layers are modulated by values derived from the
// conditioning tensor c, separately for the x and y streams.
struct MMDiTBlockT2I : public GGMLBlock {
int64_t hidden_size;
// groups is passed to MMDiTJointAttention as its head count.
MMDiTBlockT2I(int64_t hidden_size, int64_t groups, int64_t mlp_hidden_dim)
: hidden_size(hidden_size) {
blocks["norm_x1"] = std::make_shared<RMSNorm>(hidden_size, 1e-6f);
blocks["norm_y1"] = std::make_shared<RMSNorm>(hidden_size, 1e-6f);
blocks["attn"] = std::make_shared<MMDiTJointAttention>(hidden_size, groups);
blocks["norm_x2"] = std::make_shared<RMSNorm>(hidden_size, 1e-6f);
blocks["norm_y2"] = std::make_shared<RMSNorm>(hidden_size, 1e-6f);
blocks["mlp_x"] = std::make_shared<FeedForward>(hidden_size, mlp_hidden_dim);
blocks["mlp_y"] = std::make_shared<FeedForward>(hidden_size, mlp_hidden_dim);
blocks["adaLN_modulation_img.0"] = std::make_shared<Linear>(hidden_size, 6 * hidden_size, true);
blocks["adaLN_modulation_txt.0"] = std::make_shared<Linear>(hidden_size, 6 * hidden_size, true);
}
// Returns the updated {x, y} pair.
std::pair<ggml_tensor*, ggml_tensor*> forward(GGMLRunnerContext* ctx,
ggml_tensor* x,
ggml_tensor* y,
ggml_tensor* c,
ggml_tensor* pos_img,
ggml_tensor* pos_txt) {
auto norm_x1 = std::dynamic_pointer_cast<RMSNorm>(blocks["norm_x1"]);
auto norm_y1 = std::dynamic_pointer_cast<RMSNorm>(blocks["norm_y1"]);
auto attn = std::dynamic_pointer_cast<MMDiTJointAttention>(blocks["attn"]);
auto norm_x2 = std::dynamic_pointer_cast<RMSNorm>(blocks["norm_x2"]);
auto norm_y2 = std::dynamic_pointer_cast<RMSNorm>(blocks["norm_y2"]);
auto mlp_x = std::dynamic_pointer_cast<FeedForward>(blocks["mlp_x"]);
auto mlp_y = std::dynamic_pointer_cast<FeedForward>(blocks["mlp_y"]);
auto ada_img = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation_img.0"]);
auto ada_txt = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation_txt.0"]);
// Each modulation projection is split into 6 chunks, used below as:
// [0] shift and [1] scale before attention, [2] gate on the attention output,
// [3] shift and [4] scale before the MLP, [5] gate on the MLP output.
auto mx = ggml_ext_chunk(ctx->ggml_ctx, ada_img->forward(ctx, c), 6, 0);
auto my = ggml_ext_chunk(ctx->ggml_ctx, ada_txt->forward(ctx, c), 6, 0);
// Attention sub-layer: normalize, modulate, attend jointly, add gated residual.
auto x_norm = apply_adaln(ctx->ggml_ctx, norm_x1->forward(ctx, x), mx[0], mx[1]);
auto y_norm = apply_adaln(ctx->ggml_ctx, norm_y1->forward(ctx, y), my[0], my[1]);
auto attn_out = attn->forward(ctx, x_norm, y_norm, pos_img, pos_txt);
x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, attn_out.first, mx[2]));
y = ggml_add(ctx->ggml_ctx, y, ggml_mul(ctx->ggml_ctx, attn_out.second, my[2]));
// Feed-forward sub-layer: normalize, modulate, run the per-stream MLP, add gated residual.
auto x_mlp = mlp_x->forward(ctx, apply_adaln(ctx->ggml_ctx, norm_x2->forward(ctx, x), mx[3], mx[4]));
auto y_mlp = mlp_y->forward(ctx, apply_adaln(ctx->ggml_ctx, norm_y2->forward(ctx, y), my[3], my[4]));
x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, x_mlp, mx[5]));
y = ggml_add(ctx->ggml_ctx, y, ggml_mul(ctx->ggml_ctx, y_mlp, my[5]));
return {x, y};
}
};
// Projects every pixel of the input independently to `hidden_size_output` features,
// adds a per-pixel positional tensor, and regroups the pixels by patch.
struct PixelTokenEmbedder : public GGMLBlock {
    int64_t in_channels;
    int64_t hidden_size_output;

    PixelTokenEmbedder(int64_t in_channels, int64_t hidden_size_output)
        : in_channels(in_channels), hidden_size_output(hidden_size_output) {
        blocks["proj"] = std::make_shared<Linear>(in_channels, hidden_size_output, true);
    }

    // inputs     : read as ne = [W, H, in_channels, B]
    // patch_size : side length of a square patch, in pixels
    // pos_full   : tensor added to the projected pixels (must broadcast against
    //              [hidden_size_output, W * H, B])
    // Returns a tensor reshaped to [hidden_size_output, patch_size^2, L * B], where
    // L = (W / patch_size) * (H / patch_size).
    ggml_tensor* forward(GGMLRunnerContext* ctx,
                         ggml_tensor* inputs,
                         int64_t patch_size,
                         ggml_tensor* pos_full) {
        ggml_context* g = ctx->ggml_ctx;
        auto pixel_proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);

        const int64_t width             = inputs->ne[0];
        const int64_t height            = inputs->ne[1];
        const int64_t batch             = inputs->ne[3];
        const int64_t patches_per_image = (width / patch_size) * (height / patch_size);
        const int64_t pixels_per_patch  = patch_size * patch_size;
        const int patch                 = static_cast<int>(patch_size);

        // Move channels to the innermost dimension so each pixel becomes one token.
        ggml_tensor* tokens = ggml_ext_torch_permute(g, inputs, 2, 0, 1, 3);
        tokens              = ggml_cont(g, tokens);
        tokens              = ggml_reshape_3d(g, tokens, in_channels, width * height, batch);

        tokens = pixel_proj->forward(ctx, tokens);
        tokens = ggml_add(g, tokens, pos_full);

        // Back to an image-like layout so the tokens can be cut into patches.
        tokens = ggml_reshape_4d(g, tokens, hidden_size_output, width, height, batch);
        tokens = ggml_ext_torch_permute(g, tokens, 1, 2, 0, 3);
        tokens = ggml_cont(g, tokens);
        tokens = DiT::patchify(g, tokens, patch, patch, false);

        return ggml_reshape_3d(g, tokens, hidden_size_output, pixels_per_patch, patches_per_image * batch);
    }
};
struct PiTBlock : public GGMLBlock {
int64_t pixel_dim;
int64_t context_dim;
int64_t attn_dim;
int64_t num_heads;
int64_t patch_size;
PiTBlock(int64_t pixel_dim,
int64_t context_dim,
int64_t patch_size,
int64_t attn_dim,
int64_t num_heads)
: pixel_dim(pixel_dim),
context_dim(context_dim),
attn_dim(attn_dim),
num_heads(num_heads),
patch_size(patch_size) {
int64_t p2 = patch_size * patch_size;
blocks["compress_to_attn"] = std::make_shared<Linear>(p2 * pixel_dim, attn_dim, true);
blocks["expand_from_attn"] = std::make_shared<Linear>(attn_dim, p2 * pixel_dim, true);
blocks["norm1"] = std::make_shared<RMSNorm>(pixel_dim, 1e-6f);
blocks["attn"] = std::make_shared<RotaryAttention>(attn_dim, num_heads);
blocks["norm2"] = std::make_shared<RMSNorm>(pixel_dim, 1e-6f);
blocks["mlp"] = std::make_shared<Mlp>(pixel_dim, pixel_dim * 4);
blocks["adaLN_modulation.0"] = std::make_shared<Linear>(context_dim, 6 * pixel_dim * p2, true);
}
ggml_tensor* forward(GGMLRunnerContext* ctx,
ggml_tensor* x,
ggml_tensor* s_cond,
int64_t image_height,
int64_t image_width,
ggml_tensor* pos_comp) {
auto compress = std::dynamic_pointer_cast<Linear>(blocks["compress_to_attn"]);
auto expand = std::dynamic_pointer_cast<Linear>(blocks["expand_from_attn"]);
auto norm1 = std::dynamic_pointer_cast<RMSNorm>(blocks["norm1"]);
auto attn = std::dynamic_pointer_cast<RotaryAttention>(blocks["attn"]);
auto norm2 = std::dynamic_pointer_cast<RMSNorm>(blocks["norm2"]);
auto mlp = std::dynamic_pointer_cast<Mlp>(blocks["mlp"]);
auto ada = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.0"]);
int64_t Hs = image_height / patch_size;
int64_t Ws = image_width / patch_size;
int64_t L = Hs * Ws;
int64_t BL = x->ne[2];
int64_t B = BL / L;
int64_t P2 = patch_size * patch_size;
auto ada_params = ada->forward(ctx, s_cond);
ada_params = ggml_reshape_3d(ctx->ggml_ctx, ada_params, 6 * pixel_dim, P2, BL);
auto mod = ggml_ext_chunk(ctx->ggml_ctx, ada_params, 6, 0);
auto x_norm = apply_adaln(ctx->ggml_ctx, norm1->forward(ctx, x), mod[0], mod[1]);
auto x_flat = ggml_reshape_2d(ctx->ggml_ctx, x_norm, P2 * pixel_dim, BL);
auto x_comp = compress->forward(ctx, x_flat);
x_comp = ggml_reshape_3d(ctx->ggml_ctx, x_comp, attn_dim, L, B);
auto attn_out = attn->forward(ctx, x_comp, pos_comp);
auto attn_flat = expand->forward(ctx, ggml_reshape_2d(ctx->ggml_ctx, attn_out, attn_dim, BL));
auto attn_exp = ggml_reshape_3d(ctx->ggml_ctx, attn_flat, pixel_dim, P2, BL);
x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, attn_exp, mod[2]));
auto mlp_out = mlp->forward(ctx, apply_adaln(ctx->ggml_ctx, norm2->forward(ctx, x), mod[3], mod[4]));
return ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, mlp_out, mod[5]));
}
};
// Gate that mixes a conditioning feature `lq` into `x`:
//   gate = sigmoid(content_proj(concat(x, lq)) - exp(log_alpha) * sigma)
//   out  = x + gate * lq
// `log_alpha` is a single learned F32 value registered in init_params().
struct SigmaAwareGate : public GGMLBlock {
    int64_t dim;

    SigmaAwareGate(int64_t dim)
        : dim(dim) {
        // Input is x and lq concatenated along dim 0, hence 2 * dim input features.
        blocks["content_proj"] = std::make_shared<Linear>(dim * 2, dim, true);
    }

    // Registers the single-element `log_alpha` parameter. The storage map and prefix
    // are not consulted here.
    void init_params(ggml_context* ctx,
                     const String2TensorStorage& tensor_storage_map = {},
                     std::string prefix                             = "") override {
        params["log_alpha"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
    }

    // x, lq : tensors of matching shape (they are concatenated along dim 0 and later
    //         combined element-wise)
    // sigma : 1-D tensor; it is reshaped to (1, 1, N) so that it lines up with dim 2
    //         of the content logits
    // Returns x + gate * lq.
    ggml_tensor* forward(GGMLRunnerContext* ctx,
                         ggml_tensor* x,
                         ggml_tensor* lq,
                         ggml_tensor* sigma) {
        ggml_context* g   = ctx->ggml_ctx;
        auto content_proj = std::dynamic_pointer_cast<Linear>(blocks["content_proj"]);

        auto content_logit = content_proj->forward(ctx, ggml_concat(g, x, lq, 0));
        sigma              = ggml_reshape_3d(g, sigma, 1, 1, sigma->ne[0]);
        auto alpha         = ggml_exp(g, params["log_alpha"]);

        // ggml_mul broadcasts its SECOND operand onto the first and returns a tensor shaped
        // like the first. sigma therefore has to come first: with the one-element alpha in
        // front, any sigma holding more than one value would fail ggml's repeat check.
        // For a one-element sigma the result is identical either way.
        auto offset = ggml_neg(g, ggml_mul(g, sigma, alpha));

        auto gate = ggml_sigmoid(g, ggml_add(g, content_logit, offset));
        return ggml_add(g, x, ggml_mul(g, gate, lq));
    }
};
struct PiDResBlock : public GGMLBlock {
PiDResBlock(int64_t channels) {
blocks["block.0"] = std::make_shared<GroupNorm>(4, channels, 1e-5f);
blocks["block.2"] = std::make_shared<Conv2d>(channels, channels, std::pair<int, int>{3, 3}, std::pair<int, int>{1, 1}, std::pair<int, int>{1, 1});
blocks["block.3"] = std::make_shared<GroupNorm>(4, channels, 1e-5f);
blocks["block.5"] = std::make_shared<Conv2d>(channels, channels, std::pair<int, int>{3, 3}, std::pair<int, int>{1, 1}, std::pair<int, int>{1, 1});
}
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
auto norm1 = std::dynamic_pointer_cast<GroupNorm>(blocks["block.0"]);
auto conv1 = std::dynamic_pointer_cast<Conv2d>(blocks["block.2"]);
auto norm2 = std::dynamic_pointer_cast<GroupNorm>(blocks["block.3"]);
auto conv2 = std::dynamic_pointer_cast<Conv2d>(blocks["block.5"]);
auto h = ggml_silu_inplace(ctx->ggml_ctx, norm1->forward(ctx, x));
h = conv1->forward(ctx, h);
h = ggml_silu_inplace(ctx->ggml_ctx, norm2->forward(ctx, h));
h = conv2->forward(ctx, h);
return ggml_add(ctx->ggml_ctx, x, h);
}
};
struct LQProjection2D : public GGMLBlock {
PixelDiTParams params_cfg;
LQProjection2D(const PixelDiTParams& params_cfg)
: params_cfg(params_cfg) {
blocks["latent_proj.0"] = std::make_shared<Conv2d>(params_cfg.lq_latent_channels, params_cfg.lq_hidden_dim, std::pair<int, int>{3, 3}, std::pair<int, int>{1, 1}, std::pair<int, int>{1, 1});
blocks["latent_proj.2"] = std::make_shared<Conv2d>(params_cfg.lq_hidden_dim, params_cfg.lq_hidden_dim, std::pair<int, int>{3, 3}, std::pair<int, int>{1, 1}, std::pair<int, int>{1, 1});
for (int i = 0; i < params_cfg.lq_num_res_blocks; ++i) {
blocks["latent_proj." + std::to_string(3 + i)] = std::make_shared<PiDResBlock>(params_cfg.lq_hidden_dim);
}
int num_outputs = static_cast<int>((params_cfg.patch_depth + params_cfg.lq_interval - 1) / params_cfg.lq_interval);
for (int i = 0; i < num_outputs; ++i) {
blocks["output_heads." + std::to_string(i)] = std::make_shared<Linear>(params_cfg.lq_hidden_dim, params_cfg.hidden_size, true);
blocks["gate_modules." + std::to_string(i)] = std::make_shared<SigmaAwareGate>(params_cfg.hidden_size);
}
}
bool is_gate_active(int block_idx) const {
return block_idx % params_cfg.lq_interval == 0;
}
int get_output_index(int block_idx) const {
return block_idx / static_cast<int>(params_cfg.lq_interval);
}
ggml_tensor* gate(GGMLRunnerContext* ctx,
ggml_tensor* x,
ggml_tensor* lq,
ggml_tensor* sigma,
int out_idx) {
auto gate_module = std::dynamic_pointer_cast<SigmaAwareGate>(blocks["gate_modules." + std::to_string(out_idx)]);
return gate_module->forward(ctx, x, lq, sigma);
}
std::vector<ggml_tensor*> forward(GGMLRunnerContext* ctx,
ggml_tensor* lq_latent,
int64_t target_pH,
int64_t target_pW) {
auto conv0 = std::dynamic_pointer_cast<Conv2d>(blocks["latent_proj.0"]);
auto conv2 = std::dynamic_pointer_cast<Conv2d>(blocks["latent_proj.2"]);
float z_to_patch_ratio = static_cast<float>(params_cfg.lq_sr_scale * params_cfg.lq_latent_down_factor) /
static_cast<float>(params_cfg.patch_size);
GGML_ASSERT(z_to_patch_ratio >= 1.0f);
if (lq_latent->ne[0] != target_pW || lq_latent->ne[1] != target_pH) {
lq_latent = ggml_interpolate(ctx->ggml_ctx,
lq_latent,
target_pW,
target_pH,
lq_latent->ne[2],
lq_latent->ne[3],
GGML_SCALE_MODE_NEAREST);
}
auto feat = conv0->forward(ctx, lq_latent);
feat = ggml_silu_inplace(ctx->ggml_ctx, feat);
feat = conv2->forward(ctx, feat);
for (int i = 0; i < params_cfg.lq_num_res_blocks; ++i) {
auto block = std::dynamic_pointer_cast<PiDResBlock>(blocks["latent_proj." + std::to_string(3 + i)]);
feat = block->forward(ctx, feat);
}
int64_t B = feat->ne[3];
int64_t C = feat->ne[2];
int64_t L = target_pH * target_pW;
auto tokens = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, feat, 2, 0, 1, 3));
tokens = ggml_reshape_3d(ctx->ggml_ctx, tokens, C, L, B);
int num_outputs = static_cast<int>((params_cfg.patch_depth + params_cfg.lq_interval - 1) / params_cfg.lq_interval);
std::vector<ggml_tensor*> outputs;
outputs.reserve(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
auto head = std::dynamic_pointer_cast<Linear>(blocks["output_heads." + std::to_string(i)]);
outputs.push_back(head->forward(ctx, tokens));
}
return outputs;
}
};
// PixelDiT: the PiD network. From a PixelDiTParams it registers
//   "pixel_embedder" : PixelTokenEmbedder (in_channels -> pixel_hidden_size)
//   "s_embedder"     : PatchTokenEmbedder over flattened patches (in_channels * patch_size^2 -> hidden_size)
//   "t_embedder"     : PixelDiTTimestepEmbedder (hidden_size)
//   "y_embedder"     : PatchTokenEmbedder for text features (txt_embed_dim -> hidden_size)
//   "patch_blocks.i" : patch_depth MMDiTBlockT2I blocks
//   "pixel_blocks.i" : pixel_depth PiTBlock blocks
//   "final_layer"    : FinalLayer (pixel_hidden_size -> in_channels)
//   "lq_proj"        : LQProjection2D
struct PixelDiT : public GGMLBlock {
PixelDiTParams params_cfg;
// Default construction registers no sub-blocks; PiDRunner holds a PixelDiT by value and
// assigns a configured instance afterwards.
PixelDiT() = default;
PixelDiT(const PixelDiTParams& params_cfg)
: params_cfg(params_cfg) {
blocks["pixel_embedder"] = std::make_shared<PixelTokenEmbedder>(params_cfg.in_channels, params_cfg.pixel_hidden_size);
blocks["s_embedder"] = std::make_shared<PatchTokenEmbedder>(params_cfg.in_channels * params_cfg.patch_size * params_cfg.patch_size, params_cfg.hidden_size, false, true);
blocks["t_embedder"] = std::make_shared<PixelDiTTimestepEmbedder>(params_cfg.hidden_size);
blocks["y_embedder"] = std::make_shared<PatchTokenEmbedder>(params_cfg.txt_embed_dim, params_cfg.hidden_size, true, true);
for (int i = 0; i < params_cfg.patch_depth; ++i) {
blocks["patch_blocks." + std::to_string(i)] = std::make_shared<MMDiTBlockT2I>(params_cfg.hidden_size, params_cfg.num_groups, params_cfg.patch_mlp_hidden_dim);
}
for (int i = 0; i < params_cfg.pixel_depth; ++i) {
blocks["pixel_blocks." + std::to_string(i)] = std::make_shared<PiTBlock>(params_cfg.pixel_hidden_size,
params_cfg.hidden_size,
params_cfg.patch_size,
params_cfg.pixel_attn_hidden_size,
params_cfg.pixel_num_groups);
}
blocks["final_layer"] = std::make_shared<FinalLayer>(params_cfg.pixel_hidden_size, params_cfg.in_channels);
blocks["lq_proj"] = std::make_shared<LQProjection2D>(params_cfg);
}
// Registers the text positional embedding as an F32 tensor with
// ne = [hidden_size, txt_max_length, 1]. The storage map and prefix are not consulted here.
void init_params(ggml_context* ctx,
const String2TensorStorage& tensor_storage_map = {},
std::string prefix = "") override {
params["y_pos_embedding"] = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, params_cfg.hidden_size, params_cfg.txt_max_length, 1);
}
// Builds the forward graph.
//   x              : input image, read as ne = [W, H, C, B]
//   timesteps      : input of the timestep embedder
//   context        : text features; must not be null. Only the first txt_max_length
//                    entries along dim 1 are used.
//   lq_latent      : low-quality latent handed to the LQ projection
//   degrade_sigma  : sigma tensor consumed by the LQ gates
//   pos_img/pos_txt: positional data for the patch-level joint attention
//   pixel_pos_full : per-pixel positional tensor added inside the pixel embedder
//   pixel_pos_comp : positional data for the attention inside the pixel blocks
// Returns the final-layer output, unpatchified and sliced back to the original W x H.
ggml_tensor* forward(GGMLRunnerContext* ctx,
ggml_tensor* x,
ggml_tensor* timesteps,
ggml_tensor* context,
ggml_tensor* lq_latent,
ggml_tensor* degrade_sigma,
ggml_tensor* pos_img,
ggml_tensor* pos_txt,
ggml_tensor* pixel_pos_full,
ggml_tensor* pixel_pos_comp) {
auto pixel_embedder = std::dynamic_pointer_cast<PixelTokenEmbedder>(blocks["pixel_embedder"]);
auto s_embedder = std::dynamic_pointer_cast<PatchTokenEmbedder>(blocks["s_embedder"]);
auto t_embedder = std::dynamic_pointer_cast<PixelDiTTimestepEmbedder>(blocks["t_embedder"]);
auto y_embedder = std::dynamic_pointer_cast<PatchTokenEmbedder>(blocks["y_embedder"]);
auto final_layer = std::dynamic_pointer_cast<FinalLayer>(blocks["final_layer"]);
auto lq_proj = std::dynamic_pointer_cast<LQProjection2D>(blocks["lq_proj"]);
// Remember the unpadded size so the output can be cropped back at the end.
int64_t W_orig = x->ne[0];
int64_t H_orig = x->ne[1];
x = DiT::pad_to_patch_size(ctx, x, static_cast<int>(params_cfg.patch_size), static_cast<int>(params_cfg.patch_size));
// All sizes below are taken from the padded tensor.
int64_t W = x->ne[0];
int64_t H = x->ne[1];
int64_t B = x->ne[3];
int64_t Hs = H / params_cfg.patch_size;
int64_t Ws = W / params_cfg.patch_size;
int64_t L = Hs * Ws;
int64_t P2 = params_cfg.patch_size * params_cfg.patch_size;
auto x_patches = DiT::patchify(ctx->ggml_ctx, x, static_cast<int>(params_cfg.patch_size), static_cast<int>(params_cfg.patch_size), true);
// The patch blocks are conditioned on SiLU(t_emb); the raw t_emb is reused after the loop.
auto t_emb = t_embedder->forward(ctx, timesteps);
auto condition = ggml_silu(ctx->ggml_ctx, t_emb);
GGML_ASSERT(context != nullptr);
// Clamp the text length to txt_max_length and add the matching slice of the learned
// positional embedding.
int64_t Ltxt = std::min<int64_t>(context->ne[1], params_cfg.txt_max_length);
auto y = ggml_ext_slice(ctx->ggml_ctx, context, 1, 0, Ltxt);
auto y_emb = y_embedder->forward(ctx, y);
auto y_pos = ggml_ext_slice(ctx->ggml_ctx, params["y_pos_embedding"], 1, 0, Ltxt);
y_emb = ggml_add(ctx->ggml_ctx, y_emb, y_pos);
// One LQ feature tensor per output head, computed on the Hs x Ws patch grid.
std::vector<ggml_tensor*> lq_features = lq_proj->forward(ctx, lq_latent, Hs, Ws);
auto s = s_embedder->forward(ctx, x_patches);
for (int i = 0; i < params_cfg.patch_depth; ++i) {
// Blocks selected by is_gate_active() first blend the matching LQ feature into s.
if (lq_proj->is_gate_active(i)) {
int out_idx = lq_proj->get_output_index(i);
if (out_idx < static_cast<int>(lq_features.size())) {
s = lq_proj->gate(ctx, s, lq_features[out_idx], degrade_sigma, out_idx);
}
}
auto block = std::dynamic_pointer_cast<MMDiTBlockT2I>(blocks["patch_blocks." + std::to_string(i)]);
auto out = block->forward(ctx,
s,
y_emb,
condition,
pos_img,
pos_txt);
s = out.first;
y_emb = out.second;
// NOTE(review): mark_graph_cut is defined elsewhere; presumably it tags these tensors as
// graph split points — confirm against ggml_graph_cut.h.
sd::ggml_graph_cut::mark_graph_cut(s, "pid.patch_blocks." + std::to_string(i), "s");
sd::ggml_graph_cut::mark_graph_cut(y_emb, "pid.patch_blocks." + std::to_string(i), "y");
}
// Per-patch condition for the pixel blocks: SiLU(s + t_emb), flattened to
// [hidden_size, L * B].
s = ggml_silu(ctx->ggml_ctx, ggml_add(ctx->ggml_ctx, s, t_emb));
auto s_cond = ggml_reshape_2d(ctx->ggml_ctx, s, params_cfg.hidden_size, L * B);
// Pixel branch: works on the padded image x, not on x_patches.
auto pixels = pixel_embedder->forward(ctx, x, params_cfg.patch_size, pixel_pos_full);
for (int i = 0; i < params_cfg.pixel_depth; ++i) {
auto block = std::dynamic_pointer_cast<PiTBlock>(blocks["pixel_blocks." + std::to_string(i)]);
pixels = block->forward(ctx, pixels, s_cond, H, W, pixel_pos_comp);
sd::ggml_graph_cut::mark_graph_cut(pixels, "pid.pixel_blocks." + std::to_string(i), "pixels");
}
pixels = final_layer->forward(ctx, pixels);
// Regroup to [in_channels * P2, L, B] so the patches can be reassembled into an image.
pixels = ggml_reshape_3d(ctx->ggml_ctx, pixels, params_cfg.in_channels * P2, L, B);
auto out = DiT::unpatchify(ctx->ggml_ctx,
pixels,
Hs,
Ws,
static_cast<int>(params_cfg.patch_size),
static_cast<int>(params_cfg.patch_size),
false);
// Drop the padding: keep the first H_orig entries of dim 1 and W_orig entries of dim 0.
out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H_orig);
out = ggml_ext_slice(ctx->ggml_ctx, out, 0, 0, W_orig);
return out;
}
};
// PiDRunner: DiffusionModelRunner that owns a PixelDiT, infers part of its configuration
// from the checkpoint's tensor names/shapes, and builds and runs its compute graph.
struct PiDRunner : public DiffusionModelRunner {
PixelDiTParams params_cfg;
PixelDiT model;
// Host-side buffers for the positional tables. They are members (not locals of
// build_graph) and their data pointers are handed to set_backend_tensor_data.
std::vector<float> pos_img_vec;
std::vector<float> pos_txt_vec;
std::vector<float> pixel_pos_vec;
std::vector<float> pixel_pos_comp_vec;
// Scans every tensor name that contains `prefix` and derives:
//   patch_depth / pixel_depth : highest "patch_blocks.N" / "pixel_blocks.N" index + 1
//   lq_latent_channels        : ne[2] of "lq_proj.latent_proj.0.weight"
//   lq_latent_down_factor     : 16 when lq_latent_channels >= 64, otherwise 8
//   patch_mlp_hidden_dim      : ne[1] of "patch_blocks.0.mlp_x.w1.weight"
// All other PixelDiTParams fields keep their defaults. The model is then constructed
// and initialised against the tensor storage map.
PiDRunner(ggml_backend_t backend,
ggml_backend_t params_backend,
const String2TensorStorage& tensor_storage_map,
const std::string prefix = "model.diffusion_model")
: DiffusionModelRunner(backend, params_backend, prefix) {
for (const auto& pair : tensor_storage_map) {
const std::string& tensor_name = pair.first;
// NOTE: this is a substring match, not a starts-with check.
if (tensor_name.find(prefix) == std::string::npos) {
continue;
}
size_t pos = tensor_name.find("patch_blocks.");
if (pos != std::string::npos) {
// items[0] is "patch_blocks", items[1] the block index.
auto items = split_string(tensor_name.substr(pos), '.');
if (items.size() > 1) {
int block_index = atoi(items[1].c_str());
params_cfg.patch_depth = std::max<int64_t>(params_cfg.patch_depth, block_index + 1);
}
}
pos = tensor_name.find("pixel_blocks.");
if (pos != std::string::npos) {
auto items = split_string(tensor_name.substr(pos), '.');
if (items.size() > 1) {
int block_index = atoi(items[1].c_str());
params_cfg.pixel_depth = std::max<int64_t>(params_cfg.pixel_depth, block_index + 1);
}
}
if (tensor_name.find("lq_proj.latent_proj.0.weight") != std::string::npos) {
params_cfg.lq_latent_channels = pair.second.ne[2];
// Heuristic: the down factor is chosen from the latent channel count alone.
params_cfg.lq_latent_down_factor = params_cfg.lq_latent_channels >= 64 ? 16 : 8;
}
if (tensor_name.find("patch_blocks.0.mlp_x.w1.weight") != std::string::npos) {
params_cfg.patch_mlp_hidden_dim = pair.second.ne[1];
}
}
LOG_INFO("PiD params: patch_depth=%" PRId64 ", pixel_depth=%" PRId64 ", patch_mlp_hidden_dim=%" PRId64 ", lq_latent_channels=%" PRId64 ", lq_latent_down_factor=%" PRId64,
params_cfg.patch_depth,
params_cfg.pixel_depth,
params_cfg.patch_mlp_hidden_dim,
params_cfg.lq_latent_channels,
params_cfg.lq_latent_down_factor);
model = PixelDiT(params_cfg);
model.init(params_ctx, tensor_storage_map, prefix);
}
// Short human-readable name of this runner.
std::string get_desc() override {
return "PiD";
}
// Forwards parameter-tensor collection to the wrapped model.
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string& prefix) override {
model.get_param_tensors(tensors, prefix);
}
// Creates the graph inputs, fills the four positional tables for the current image
// size, and builds the PixelDiT forward graph.
ggml_cgraph* build_graph(const sd::Tensor<float>& x_tensor,
const sd::Tensor<float>& timesteps_tensor,
const sd::Tensor<float>& context_tensor,
const sd::Tensor<float>& lq_latent_tensor,
const sd::Tensor<float>& degrade_sigma_tensor) {
ggml_cgraph* gf = new_graph_custom(PID_GRAPH_SIZE);
ggml_tensor* x = make_input(x_tensor);
ggml_tensor* timesteps = make_input(timesteps_tensor);
ggml_tensor* context = make_input(context_tensor);
ggml_tensor* lq_latent = make_input(lq_latent_tensor);
ggml_tensor* degrade_sigma = make_input(degrade_sigma_tensor);
int64_t W = x->ne[0];
int64_t H = x->ne[1];
// NOTE(review): B is not used anywhere else in this function.
int64_t B = x->ne[3];
// Image size rounded up to a multiple of patch_size, and the resulting patch grid.
int64_t Wp = align_up(static_cast<int>(W), static_cast<int>(params_cfg.patch_size));
int64_t Hp = align_up(static_cast<int>(H), static_cast<int>(params_cfg.patch_size));
int64_t Hs = Hp / params_cfg.patch_size;
int64_t Ws = Wp / params_cfg.patch_size;
// 2-D rotary table for the patch-level attention; per-head dim is hidden_size / num_groups.
pos_img_vec = make_rope_2d(static_cast<int>(Hs),
static_cast<int>(Ws),
static_cast<int>(params_cfg.hidden_size / params_cfg.num_groups),
10000.f,
16.f,
static_cast<int>(params_cfg.rope_ref_grid_h),
static_cast<int>(params_cfg.rope_ref_grid_w));
auto pos_img = ggml_new_tensor_4d(compute_ctx,
GGML_TYPE_F32,
2,
2,
params_cfg.hidden_size / params_cfg.num_groups / 2,
Hs * Ws);
set_backend_tensor_data(pos_img, pos_img_vec.data());
// 1-D rotary table for the text tokens, using the same length clamp as PixelDiT::forward.
int64_t Ltxt = std::min<int64_t>(context->ne[1], params_cfg.txt_max_length);
pos_txt_vec = make_rope_1d(static_cast<int>(Ltxt),
static_cast<int>(params_cfg.hidden_size / params_cfg.num_groups),
params_cfg.text_rope_theta);
auto pos_txt = ggml_new_tensor_4d(compute_ctx,
GGML_TYPE_F32,
2,
2,
params_cfg.hidden_size / params_cfg.num_groups / 2,
Ltxt);
set_backend_tensor_data(pos_txt, pos_txt_vec.data());
// Per-pixel positional table over the padded image, ne = [pixel_hidden_size, Wp * Hp, 1].
pixel_pos_vec = make_pixel_abs_pos(static_cast<int>(Hp),
static_cast<int>(Wp),
static_cast<int>(params_cfg.pixel_hidden_size));
auto pixel_pos = ggml_new_tensor_3d(compute_ctx,
GGML_TYPE_F32,
params_cfg.pixel_hidden_size,
Wp * Hp,
1);
set_backend_tensor_data(pixel_pos, pixel_pos_vec.data());
// 2-D rotary table for the attention inside the pixel blocks (one token per patch).
pixel_pos_comp_vec = make_rope_2d(static_cast<int>(Hs),
static_cast<int>(Ws),
static_cast<int>(params_cfg.pixel_attn_hidden_size / params_cfg.pixel_num_groups),
10000.f,
16.f,
static_cast<int>(params_cfg.rope_ref_grid_h),
static_cast<int>(params_cfg.rope_ref_grid_w));
auto pixel_pos_comp = ggml_new_tensor_4d(compute_ctx,
GGML_TYPE_F32,
2,
2,
params_cfg.pixel_attn_hidden_size / params_cfg.pixel_num_groups / 2,
Hs * Ws);
set_backend_tensor_data(pixel_pos_comp, pixel_pos_comp_vec.data());
auto runner_ctx = get_context();
auto out = model.forward(&runner_ctx,
x,
timesteps,
context,
lq_latent,
degrade_sigma,
pos_img,
pos_txt,
pixel_pos,
pixel_pos_comp);
ggml_build_forward_expand(gf, out);
return gf;
}
// Builds and runs the graph for explicit inputs. The result is passed through
// restore_trailing_singleton_dims with the dimensionality of x.
sd::Tensor<float> compute(int n_threads,
const sd::Tensor<float>& x,
const sd::Tensor<float>& timesteps,
const sd::Tensor<float>& context,
const sd::Tensor<float>& lq_latent,
const sd::Tensor<float>& degrade_sigma) {
auto get_graph = [&]() -> ggml_cgraph* {
return build_graph(x, timesteps, context, lq_latent, degrade_sigma);
};
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), x.dim());
}
// DiffusionModelRunner entry point. Requires x, timesteps, context and at least one
// reference latent; the first reference latent is used as the LQ latent and the degrade
// sigma is fixed to a single 0.0 value.
sd::Tensor<float> compute(int n_threads,
const DiffusionParams& diffusion_params) override {
GGML_ASSERT(diffusion_params.x != nullptr);
GGML_ASSERT(diffusion_params.timesteps != nullptr);
GGML_ASSERT(diffusion_params.context != nullptr);
GGML_ASSERT(diffusion_params.ref_latents != nullptr);
GGML_ASSERT(!diffusion_params.ref_latents->empty());
auto degrade_sigma = sd::Tensor<float>::from_vector({0.0f});
return compute(n_threads,
*diffusion_params.x,
*diffusion_params.timesteps,
*diffusion_params.context,
diffusion_params.ref_latents->front(),
degrade_sigma);
}
};
} // namespace Pid
#endif // __SD_PID_HPP__

View File

@ -249,6 +249,59 @@ namespace Rope {
return embed_nd(ids, bs, axis_thetas, axes_dim, wrap_dims, layout); return embed_nd(ids, bs, axis_thetas, axes_dim, wrap_dims, layout);
} }
// Builds a 2-D rotary table for a height x width grid with the two axes interleaved.
// For every grid position (row-major) the output holds dim / 2 groups of four floats,
// alternating between a group taken from the x-axis table and one from the y-axis table.
//   dim                    : per-head dimension; expected to be a multiple of 4
//   theta                  : base frequency passed to rope(); scaled per axis when a
//                            reference grid is given
//   scale                  : coordinates are spread over [0, scale] along each axis
//   ref_grid_h, ref_grid_w : when both are > 0 (and the per-axis dim is > 2) theta is
//                            multiplied by (size / ref_size)^(d / (d - 2)) per axis
__STATIC_INLINE__ std::vector<float> embed_2d_interleaved(int height,
                                                          int width,
                                                          int dim,
                                                          float theta    = 10000.f,
                                                          float scale    = 16.f,
                                                          int ref_grid_h = 0,
                                                          int ref_grid_w = 0) {
    assert(dim % 4 == 0);
    const int groups_per_pos  = dim / 2;           // 4-float groups stored per position
    const int per_axis_dim    = dim / 2;           // rotary dim handed to each axis
    const int groups_per_axis = per_axis_dim / 2;  // 4-float groups each axis contributes

    float theta_mul_h  = 1.f;
    float theta_mul_w  = 1.f;
    const bool rescale = ref_grid_h > 0 && ref_grid_w > 0 && per_axis_dim > 2;
    if (rescale) {
        const float exponent = static_cast<float>(per_axis_dim) / static_cast<float>(per_axis_dim - 2);
        theta_mul_h          = std::pow(static_cast<float>(height) / static_cast<float>(ref_grid_h), exponent);
        theta_mul_w          = std::pow(static_cast<float>(width) / static_cast<float>(ref_grid_w), exponent);
    }

    // Row-major list of (x, y) coordinates; a single row/column sits at coordinate 0.
    std::vector<float> col_coords;
    std::vector<float> row_coords;
    col_coords.reserve(static_cast<size_t>(height) * width);
    row_coords.reserve(static_cast<size_t>(height) * width);
    for (int row = 0; row < height; ++row) {
        float row_coord = 0.f;
        if (height != 1) {
            row_coord = scale * static_cast<float>(row) / static_cast<float>(height - 1);
        }
        for (int col = 0; col < width; ++col) {
            float col_coord = 0.f;
            if (width != 1) {
                col_coord = scale * static_cast<float>(col) / static_cast<float>(width - 1);
            }
            col_coords.push_back(col_coord);
            row_coords.push_back(row_coord);
        }
    }

    auto col_table = rope(col_coords, per_axis_dim, theta * theta_mul_w);
    auto row_table = rope(row_coords, per_axis_dim, theta * theta_mul_h);

    std::vector<float> out(static_cast<size_t>(height) * width * groups_per_pos * 4);
    const int position_count = height * width;
    for (int pos = 0; pos < position_count; ++pos) {
        const size_t pos_base = static_cast<size_t>(pos) * groups_per_pos * 4;
        for (int group = 0; group < groups_per_axis; ++group) {
            // x group lands in even slot 2 * group, y group in the odd slot right after it.
            const size_t dst_x = pos_base + static_cast<size_t>(2 * group) * 4;
            const size_t dst_y = pos_base + static_cast<size_t>(2 * group + 1) * 4;
            const size_t src   = static_cast<size_t>(group) * 4;
            for (int k = 0; k < 4; ++k) {
                out[dst_x + k] = col_table[pos][src + k];
                out[dst_y + k] = row_table[pos][src + k];
            }
        }
    }
    return out;
}
__STATIC_INLINE__ std::vector<std::vector<float>> gen_refs_ids(int patch_size, __STATIC_INLINE__ std::vector<std::vector<float>> gen_refs_ids(int patch_size,
int bs, int bs,
int axes_dim_num, int axes_dim_num,

View File

@ -1,3 +1,7 @@
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include "ggml_extend.hpp" #include "ggml_extend.hpp"
#include "ggml_graph_cut.h" #include "ggml_graph_cut.h"
@ -26,6 +30,7 @@
#include "ltx_vae.hpp" #include "ltx_vae.hpp"
#include "ltxv.hpp" #include "ltxv.hpp"
#include "mmdit.hpp" #include "mmdit.hpp"
#include "pid.hpp"
#include "pmid.hpp" #include "pmid.hpp"
#include "qwen_image.hpp" #include "qwen_image.hpp"
#include "sample-cache.h" #include "sample-cache.h"
@ -39,6 +44,9 @@
#include "latent-preview.h" #include "latent-preview.h"
#include "name_conversion.h" #include "name_conversion.h"
const char* sd_vae_format_name(enum sd_vae_format_t format);
static SDVersion sd_vae_format_to_version(enum sd_vae_format_t format, SDVersion fallback);
const char* model_version_to_str[] = { const char* model_version_to_str[] = {
"SD 1.x", "SD 1.x",
"SD 1.x Inpaint", "SD 1.x Inpaint",
@ -75,6 +83,7 @@ const char* model_version_to_str[] = {
"Ernie Image", "Ernie Image",
"Lens", "Lens",
"Longcat-Image", "Longcat-Image",
"PiD",
}; };
const char* sampling_methods_str[] = { const char* sampling_methods_str[] = {
@ -501,6 +510,16 @@ public:
params_backend_for(SDBackendModule::DIFFUSION), params_backend_for(SDBackendModule::DIFFUSION),
tensor_storage_map, tensor_storage_map,
"model.diffusion_model"); "model.diffusion_model");
} else if (sd_version_is_pid(version)) {
vae_decode_only = false;
cond_stage_model = std::make_shared<LLMEmbedder>(backend_for(SDBackendModule::TE),
params_backend_for(SDBackendModule::TE),
tensor_storage_map,
version);
diffusion_model = std::make_shared<Pid::PiDRunner>(backend_for(SDBackendModule::DIFFUSION),
params_backend_for(SDBackendModule::DIFFUSION),
tensor_storage_map,
"model.diffusion_model.net");
} else if (sd_version_is_flux(version)) { } else if (sd_version_is_flux(version)) {
bool is_chroma = false; bool is_chroma = false;
for (auto pair : tensor_storage_map) { for (auto pair : tensor_storage_map) {
@ -743,6 +762,16 @@ public:
} }
}; };
sd_vae_format_t vae_format = sd_ctx_params->vae_format;
if (vae_format < SD_VAE_FORMAT_AUTO || vae_format >= SD_VAE_FORMAT_COUNT) {
LOG_WARN("invalid VAE format override, using auto");
vae_format = SD_VAE_FORMAT_AUTO;
}
SDVersion vae_version = version;
if (sd_version_is_pid(version) && vae_format != SD_VAE_FORMAT_AUTO) {
vae_version = sd_vae_format_to_version(vae_format, vae_version);
}
auto create_vae = [&]() -> std::shared_ptr<VAE> { auto create_vae = [&]() -> std::shared_ptr<VAE> {
if (sd_version_is_ltxav(version)) { if (sd_version_is_ltxav(version)) {
return std::make_shared<LTXVideoVAE>(backend_for(SDBackendModule::VAE), return std::make_shared<LTXVideoVAE>(backend_for(SDBackendModule::VAE),
@ -767,7 +796,7 @@ public:
"first_stage_model", "first_stage_model",
vae_decode_only, vae_decode_only,
false, false,
version); vae_version);
if (sd_version_is_sdxl(version) && if (sd_version_is_sdxl(version) &&
(strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale || external_vae_is_invalid)) { (strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale || external_vae_is_invalid)) {
float vae_conv_2d_scale = 1.f / 32.f; float vae_conv_2d_scale = 1.f / 32.f;
@ -1140,12 +1169,15 @@ public:
version == VERSION_HIDREAM_O1 || version == VERSION_HIDREAM_O1 ||
sd_version_is_anima(version) || sd_version_is_anima(version) ||
sd_version_is_ernie_image(version) || sd_version_is_ernie_image(version) ||
sd_version_is_z_image(version)) { sd_version_is_z_image(version) ||
sd_version_is_pid(version)) {
pred_type = FLOW_PRED; pred_type = FLOW_PRED;
if (sd_version_is_wan(version)) { if (sd_version_is_wan(version)) {
default_flow_shift = 5.f; default_flow_shift = 5.f;
} else if (sd_version_is_ernie_image(version)) { } else if (sd_version_is_ernie_image(version)) {
default_flow_shift = 4.f; default_flow_shift = 4.f;
} else if (sd_version_is_pid(version)) {
default_flow_shift = 1.5f;
} else { } else {
default_flow_shift = 3.f; default_flow_shift = 3.f;
} }
@ -2180,6 +2212,9 @@ public:
} }
int get_vae_scale_factor() { int get_vae_scale_factor() {
if (sd_version_is_pid(version)) {
return 1;
}
return first_stage_model->get_scale_factor(); return first_stage_model->get_scale_factor();
} }
@ -2206,6 +2241,8 @@ public:
latent_channel = 3; latent_channel = 3;
} else if (version == VERSION_CHROMA_RADIANCE) { } else if (version == VERSION_CHROMA_RADIANCE) {
latent_channel = 3; latent_channel = 3;
} else if (sd_version_is_pid(version)) {
latent_channel = 3;
} else if (sd_version_uses_flux2_vae(version)) { } else if (sd_version_uses_flux2_vae(version)) {
latent_channel = 128; latent_channel = 128;
} else { } else {
@ -2283,6 +2320,9 @@ public:
} }
sd::Tensor<float> decode_first_stage(const sd::Tensor<float>& x, bool decode_video = false) { sd::Tensor<float> decode_first_stage(const sd::Tensor<float>& x, bool decode_video = false) {
if (sd_version_is_pid(version)) {
return sd::ops::clamp((x + 1.f) * 0.5f, 0.0f, 1.0f);
}
auto latents = first_stage_model->diffusion_to_vae_latents(x); auto latents = first_stage_model->diffusion_to_vae_latents(x);
first_stage_model->set_temporal_tiling_enabled(vae_tiling_params.temporal_tiling); first_stage_model->set_temporal_tiling_enabled(vae_tiling_params.temporal_tiling);
return first_stage_model->decode(n_threads, latents, vae_tiling_params, decode_video, circular_x, circular_y); return first_stage_model->decode(n_threads, latents, vae_tiling_params, decode_video, circular_x, circular_y);
@ -2543,6 +2583,35 @@ enum sd_hires_upscaler_t str_to_sd_hires_upscaler(const char* str) {
return SD_HIRES_UPSCALER_COUNT; return SD_HIRES_UPSCALER_COUNT;
} }
// Returns the lowercase name of a VAE format ("auto", "flux", "sd3", "flux2").
// Any other value maps to NONE_STR.
const char* sd_vae_format_name(enum sd_vae_format_t format) {
    if (format == SD_VAE_FORMAT_AUTO) {
        return "auto";
    }
    if (format == SD_VAE_FORMAT_FLUX) {
        return "flux";
    }
    if (format == SD_VAE_FORMAT_SD3) {
        return "sd3";
    }
    if (format == SD_VAE_FORMAT_FLUX2) {
        return "flux2";
    }
    return NONE_STR;
}
// Maps an explicit VAE format to the model version whose VAE layout it selects.
// SD_VAE_FORMAT_AUTO and any unrecognised value yield `fallback`.
static SDVersion sd_vae_format_to_version(enum sd_vae_format_t format, SDVersion fallback) {
    if (format == SD_VAE_FORMAT_FLUX) {
        return VERSION_FLUX;
    }
    if (format == SD_VAE_FORMAT_SD3) {
        return VERSION_SD3;
    }
    if (format == SD_VAE_FORMAT_FLUX2) {
        return VERSION_FLUX2;
    }
    return fallback;
}
void sd_cache_params_init(sd_cache_params_t* cache_params) { void sd_cache_params_init(sd_cache_params_t* cache_params) {
*cache_params = {}; *cache_params = {};
cache_params->mode = SD_CACHE_DISABLED; cache_params->mode = SD_CACHE_DISABLED;
@ -2608,6 +2677,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
sd_ctx_params->chroma_use_dit_mask = true; sd_ctx_params->chroma_use_dit_mask = true;
sd_ctx_params->chroma_use_t5_mask = false; sd_ctx_params->chroma_use_t5_mask = false;
sd_ctx_params->chroma_t5_mask_pad = 1; sd_ctx_params->chroma_t5_mask_pad = 1;
sd_ctx_params->vae_format = SD_VAE_FORMAT_AUTO;
sd_ctx_params->backend = nullptr; sd_ctx_params->backend = nullptr;
sd_ctx_params->params_backend = nullptr; sd_ctx_params->params_backend = nullptr;
} }
@ -2655,7 +2725,8 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
"circular_y: %s\n" "circular_y: %s\n"
"chroma_use_dit_mask: %s\n" "chroma_use_dit_mask: %s\n"
"chroma_use_t5_mask: %s\n" "chroma_use_t5_mask: %s\n"
"chroma_t5_mask_pad: %d\n", "chroma_t5_mask_pad: %d\n"
"vae_format: %s\n",
SAFE_STR(sd_ctx_params->model_path), SAFE_STR(sd_ctx_params->model_path),
SAFE_STR(sd_ctx_params->clip_l_path), SAFE_STR(sd_ctx_params->clip_l_path),
SAFE_STR(sd_ctx_params->clip_g_path), SAFE_STR(sd_ctx_params->clip_g_path),
@ -2692,7 +2763,8 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
BOOL_STR(sd_ctx_params->circular_y), BOOL_STR(sd_ctx_params->circular_y),
BOOL_STR(sd_ctx_params->chroma_use_dit_mask), BOOL_STR(sd_ctx_params->chroma_use_dit_mask),
BOOL_STR(sd_ctx_params->chroma_use_t5_mask), BOOL_STR(sd_ctx_params->chroma_use_t5_mask),
sd_ctx_params->chroma_t5_mask_pad); sd_ctx_params->chroma_t5_mask_pad,
sd_vae_format_name(sd_ctx_params->vae_format));
return buf; return buf;
} }
@ -2969,6 +3041,9 @@ SD_API bool sd_ctx_supports_video_generation(const sd_ctx_t* sd_ctx) {
enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx) { enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx) {
if (sd_ctx != nullptr && sd_ctx->sd != nullptr) { if (sd_ctx != nullptr && sd_ctx->sd != nullptr) {
if (sd_version_is_pid(sd_ctx->sd->version)) {
return LCM_SAMPLE_METHOD;
}
if (sd_version_is_dit(sd_ctx->sd->version)) { if (sd_version_is_dit(sd_ctx->sd->version)) {
return EULER_SAMPLE_METHOD; return EULER_SAMPLE_METHOD;
} }
@ -3867,7 +3942,7 @@ static std::optional<ImageGenerationLatents> prepare_image_generation_latents(sd
continue; continue;
} }
sd::Tensor<float> ref_latent; sd::Tensor<float> ref_latent;
if (request->auto_resize_ref_image) { if (request->auto_resize_ref_image && !sd_version_is_pid(sd_ctx->sd->version)) {
LOG_DEBUG("auto resize ref images"); LOG_DEBUG("auto resize ref images");
int vae_image_size = std::min(1024 * 1024, request->width * request->height); int vae_image_size = std::min(1024 * 1024, request->width * request->height);
double vae_width = sqrt(vae_image_size * ref_images[i].shape()[0] / ref_images[i].shape()[1]); double vae_width = sqrt(vae_image_size * ref_images[i].shape()[0] / ref_images[i].shape()[1]);
@ -3899,6 +3974,13 @@ static std::optional<ImageGenerationLatents> prepare_image_generation_latents(sd
ref_latents.push_back(std::move(ref_latent)); ref_latents.push_back(std::move(ref_latent));
} }
if (sd_version_is_pid(sd_ctx->sd->version)) {
if (ref_latents.empty()) {
LOG_ERROR("PiD requires a reference image");
return std::nullopt;
}
}
sd::Tensor<float> concat_latent; sd::Tensor<float> concat_latent;
sd::Tensor<float> uncond_concat_latent; sd::Tensor<float> uncond_concat_latent;
if (sd_version_is_inpaint(sd_ctx->sd->version)) { if (sd_version_is_inpaint(sd_ctx->sd->version)) {

View File

@ -478,7 +478,7 @@ struct T5Embedder {
bool alloc_params_buffer() { bool alloc_params_buffer() {
if (!model.alloc_params_buffer()) { if (!model.alloc_params_buffer()) {
return false; return false;
} }
return true; return true;
} }

View File

@ -182,7 +182,8 @@ std::vector<int> BPETokenizer::encode(const std::string& text, on_new_token_cb_t
unsigned char b = utf8_token_str[i]; unsigned char b = utf8_token_str[i];
char hex_buf[16]; char hex_buf[16];
snprintf(hex_buf, sizeof(hex_buf), "<0x%02X>", b); snprintf(hex_buf, sizeof(hex_buf), "<0x%02X>", b);
iter = encoder.find(utf8_to_utf32(hex_buf)); iter = encoder.find(utf8_to_utf32(hex_buf));
token_id = iter != encoder.end() ? iter->second : UNK_TOKEN_ID;
bpe_tokens.push_back(token_id); bpe_tokens.push_back(token_id);
token_strs.push_back(hex_buf); token_strs.push_back(hex_buf);
} }

View File

@ -189,3 +189,164 @@ GemmaTokenizer::GemmaTokenizer(const std::string& merges_utf8_str, const std::st
load_from_merges(load_gemma_merges(), load_gemma_vocab_json()); load_from_merges(load_gemma_merges(), load_gemma_vocab_json());
} }
} }
// Returns a copy of `text` in which every ASCII space has been rewritten as the
// three-byte sequence "\xE2\x96\x81" (UTF-8 for U+2581); every other byte is
// copied through unchanged.
std::string Gemma2Tokenizer::normalize(const std::string& text) const {
    std::string result;
    result.reserve(text.size());
    for (const char ch : text) {
        if (ch == ' ') {
            result += "\xE2\x96\x81";
        } else {
            result += ch;
        }
    }
    return result;
}
// Fills the tokenizer tables from a vocabulary JSON document and a merges list.
//
// vocab_utf8_str:  JSON object mapping token text -> integer id. Every entry is
//                  stored in both `encoder` (token -> id) and `decoder`
//                  (id -> token). Aborts through GGML_ABORT when the text cannot
//                  be parsed as JSON.
// merges_utf8_str: merge rules, separated into entries by split_utf32(). Each
//                  entry is expected to look like "<left> <right>" and is divided
//                  at its first space. The rank stored in `bpe_ranks` is the
//                  entry's position among the accepted entries.
//
// Entries that have no space, or that have an empty left or right side, are
// skipped: without this guard `find` returns npos, `substr(npos + 1)` wraps to
// `substr(0)`, and a meaningless (entry, entry) pair would be registered and
// consume a rank.
void Gemma2Tokenizer::load_from_merges(const std::string& merges_utf8_str, const std::string& vocab_utf8_str) {
    nlohmann::json vocab;
    try {
        vocab = nlohmann::json::parse(vocab_utf8_str);
    } catch (const nlohmann::json::parse_error&) {
        GGML_ABORT("invalid vocab json str");
    }

    for (const auto& [key, value] : vocab.items()) {
        std::u32string token = utf8_to_utf32(key);
        int i                = value;
        encoder[token]       = i;
        decoder[i]           = token;
    }
    encoder_len = static_cast<int>(vocab.size());
    LOG_DEBUG("vocab size: %d", encoder_len);

    std::vector<std::u32string> merges = split_utf32(merges_utf8_str);
    std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
    merge_pairs.reserve(merges.size());
    size_t skipped = 0;
    for (const auto& merge : merges) {
        const size_t space_pos = merge.find(U' ');
        const bool malformed   = space_pos == std::u32string::npos ||  // no separator at all
                                 space_pos == 0 ||                     // empty left token
                                 space_pos + 1 >= merge.size();        // empty right token
        if (malformed) {
            skipped++;
            continue;
        }
        merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
    }
    LOG_DEBUG("merges size %zu", merge_pairs.size());
    if (skipped > 0) {
        LOG_DEBUG("skipped %zu malformed merge entries", skipped);
    }

    // Earlier entries get lower ranks; a repeated pair keeps the rank of its
    // last occurrence, as before.
    int rank = 0;
    for (const auto& merge : merge_pairs) {
        bpe_ranks[merge] = rank++;
    }
    bpe_len = rank;
}
// Builds a Gemma2 tokenizer: sets the tokenizer flags and reserved token
// names/ids, assembles the ordered `special_tokens` list, then loads the
// vocabulary and merge rules.
//
// merges_utf8_str / vocab_utf8_str: merge rules and vocabulary JSON to load.
// When either string is empty, the data returned by load_gemma2_merges() and
// load_gemma2_vocab_json() is used instead.
Gemma2Tokenizer::Gemma2Tokenizer(const std::string& merges_utf8_str, const std::string& vocab_utf8_str) {
// Tokenizer flags consumed outside this block.
// NOTE(review): presumably read by the BPETokenizer base class - confirm there.
byte_level_bpe = false;
byte_fallback = true;
add_bos_token = true;
// Names and ids of the four reserved tokens.
PAD_TOKEN = "<pad>";
EOS_TOKEN = "<eos>";
BOS_TOKEN = "<bos>";
UNK_TOKEN = "<unk>";
PAD_TOKEN_ID = 0;
EOS_TOKEN_ID = 1;
BOS_TOKEN_ID = 2;
UNK_TOKEN_ID = 3;
// First group of special tokens. The element order is kept exactly as written;
// NOTE(review): the "before_merge"/"after_merge" split presumably mirrors the
// layout of the upstream Gemma2 vocabulary - confirm before reordering.
std::vector<std::string> special_tokens_before_merge = {
PAD_TOKEN,
EOS_TOKEN,
BOS_TOKEN,
UNK_TOKEN,
"<mask>",
"<2mass>",
"[@BOS@]",
};
// "<unused0>" through "<unused98>".
for (int i = 0; i <= 98; i++) {
special_tokens_before_merge.push_back("<unused" + std::to_string(i) + ">");
}
special_tokens_before_merge.push_back("<start_of_turn>");
special_tokens_before_merge.push_back("<end_of_turn>");
// Runs of 1 to 31 newline characters.
for (int i = 1; i <= 31; i++) {
special_tokens_before_merge.push_back(std::string(i, '\n'));
}
// Runs of 2 to 31 copies of "\xE2\x96\x81" (UTF-8 for U+2581, the sequence
// normalize() substitutes for a space).
for (int i = 2; i <= 31; i++) {
std::string whitespace_token;
for (int j = 0; j < i; j++) {
whitespace_token += "\xE2\x96\x81";
}
special_tokens_before_merge.push_back(whitespace_token);
}
// HTML-style tag tokens, appended after the whitespace runs.
std::vector<std::string> html_tokens = {
"<table>",
"<caption>",
"<thead>",
"<tbody>",
"<tfoot>",
"<tr>",
"<th>",
"<td>",
"</table>",
"</caption>",
"</thead>",
"</tbody>",
"</tfoot>",
"</tr>",
"</th>",
"</td>",
"<h1>",
"<h2>",
"<h3>",
"<h4>",
"<h5>",
"<h6>",
"<blockquote>",
"</h1>",
"</h2>",
"</h3>",
"</h4>",
"</h5>",
"</h6>",
"</blockquote>",
"<strong>",
"<em>",
"<b>",
"<i>",
"<u>",
"<s>",
"<sub>",
"<sup>",
"<code>",
"</strong>",
"</em>",
"</b>",
"</i>",
"</u>",
"</s>",
"</sub>",
"</sup>",
"</code>",
};
special_tokens_before_merge.insert(special_tokens_before_merge.end(),
html_tokens.begin(),
html_tokens.end());
// One token per byte value, spelled "<0x00>" through "<0xFF>".
for (int i = 0; i <= 0xFF; i++) {
char hex_buf[16];
snprintf(hex_buf, sizeof(hex_buf), "<0x%02X>", i);
special_tokens_before_merge.push_back(hex_buf);
}
// Second group of special tokens.
std::vector<std::string> special_tokens_after_merge = {
"[toxicity=0]",
};
// Inserting at index i - 1 places the tab runs of length 1 to 31 in ascending
// order ahead of "[toxicity=0]".
for (int i = 1; i <= 31; i++) {
special_tokens_after_merge.insert(special_tokens_after_merge.begin() + i - 1,
std::string(i, '\t'));
}
// Single iteration: appends "<unused99>" only.
for (int i = 99; i <= 99; i++) {
special_tokens_after_merge.push_back("<unused" + std::to_string(i) + ">");
}
// Final list: the "before merge" group followed by the "after merge" group.
special_tokens = special_tokens_before_merge;
special_tokens.insert(special_tokens.end(),
special_tokens_after_merge.begin(),
special_tokens_after_merge.end());
// Use caller-supplied data only when both strings are non-empty; otherwise
// fall back to the data from load_gemma2_merges()/load_gemma2_vocab_json().
if (merges_utf8_str.size() > 0 && vocab_utf8_str.size() > 0) {
load_from_merges(merges_utf8_str, vocab_utf8_str);
} else {
load_from_merges(load_gemma2_merges(), load_gemma2_vocab_json());
}
}

View File

@ -14,4 +14,13 @@ public:
explicit GemmaTokenizer(const std::string& merges_utf8_str = "", const std::string& vocab_utf8_str = ""); explicit GemmaTokenizer(const std::string& merges_utf8_str = "", const std::string& vocab_utf8_str = "");
}; };
// BPE tokenizer variant for Gemma2, derived from BPETokenizer.
class Gemma2Tokenizer : public BPETokenizer {
protected:
// Loads the vocabulary (JSON text) and the merge rules from UTF-8 strings.
void load_from_merges(const std::string& merges_utf8_str, const std::string& vocab_utf8_str);
// Overrides the base-class normalization hook for input text.
std::string normalize(const std::string& text) const override;
public:
// Both arguments default to empty strings.
explicit Gemma2Tokenizer(const std::string& merges_utf8_str = "", const std::string& vocab_utf8_str = "");
};
#endif // __SD_TOKENIZERS_GEMMA_TOKENIZER_H__ #endif // __SD_TOKENIZERS_GEMMA_TOKENIZER_H__

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -1,5 +1,7 @@
#include "vocab.h" #include "vocab.h"
#include "clip_merges.hpp" #include "clip_merges.hpp"
#include "gemma2_merges.hpp"
#include "gemma2_vocab.hpp"
#include "gemma_merges.hpp" #include "gemma_merges.hpp"
#include "gemma_vocab.hpp" #include "gemma_vocab.hpp"
#include "gpt_oss_merges.hpp" #include "gpt_oss_merges.hpp"
@ -50,6 +52,16 @@ std::string load_gemma_vocab_json() {
return json_str; return json_str;
} }
// Returns the contents of the gemma2_merges_utf8_c_str array as a std::string
// spanning the array's full sizeof().
std::string load_gemma2_merges() {
    const char* data = reinterpret_cast<const char*>(gemma2_merges_utf8_c_str);
    const auto size  = sizeof(gemma2_merges_utf8_c_str);
    return std::string(data, size);
}
// Returns the contents of the gemma2_vocab_json_utf8_c_str array as a
// std::string spanning the array's full sizeof().
std::string load_gemma2_vocab_json() {
    const char* data = reinterpret_cast<const char*>(gemma2_vocab_json_utf8_c_str);
    const auto size  = sizeof(gemma2_vocab_json_utf8_c_str);
    return std::string(data, size);
}
std::string load_gpt_oss_merges() { std::string load_gpt_oss_merges() {
std::string merges_utf8_str(reinterpret_cast<const char*>(gpt_oss_merges_utf8_c_str), sizeof(gpt_oss_merges_utf8_c_str)); std::string merges_utf8_str(reinterpret_cast<const char*>(gpt_oss_merges_utf8_c_str), sizeof(gpt_oss_merges_utf8_c_str));
return merges_utf8_str; return merges_utf8_str;

View File

@ -11,6 +11,8 @@ std::string load_t5_tokenizer_json();
std::string load_umt5_tokenizer_json(); std::string load_umt5_tokenizer_json();
std::string load_gemma_merges(); std::string load_gemma_merges();
std::string load_gemma_vocab_json(); std::string load_gemma_vocab_json();
std::string load_gemma2_merges();
std::string load_gemma2_vocab_json();
std::string load_gpt_oss_merges(); std::string load_gpt_oss_merges();
std::string load_gpt_oss_vocab_json(); std::string load_gpt_oss_vocab_json();