mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-09 15:56:39 +00:00
feat: add PiD support (#1585)
This commit is contained in:
parent
d2797b8667
commit
0982807139
@ -15,6 +15,7 @@ API and command-line option may change frequently.***
|
|||||||
|
|
||||||
## 🔥Important News
|
## 🔥Important News
|
||||||
|
|
||||||
|
* **2026/05/31** 🚀 stable-diffusion.cpp now supports **PiD**
|
||||||
* **2026/05/27** 🚀 stable-diffusion.cpp now supports **Lens**
|
* **2026/05/27** 🚀 stable-diffusion.cpp now supports **Lens**
|
||||||
* **2026/05/17** 🚀 stable-diffusion.cpp now supports **LTX-2.3**
|
* **2026/05/17** 🚀 stable-diffusion.cpp now supports **LTX-2.3**
|
||||||
* **2026/04/11** 🚀 stable-diffusion.cpp now uses a brand-new embedded web UI.
|
* **2026/04/11** 🚀 stable-diffusion.cpp now uses a brand-new embedded web UI.
|
||||||
@ -42,6 +43,7 @@ API and command-line option may change frequently.***
|
|||||||
- [Chroma](./docs/chroma.md)
|
- [Chroma](./docs/chroma.md)
|
||||||
- [Chroma1-Radiance](./docs/chroma_radiance.md)
|
- [Chroma1-Radiance](./docs/chroma_radiance.md)
|
||||||
- [Qwen Image](./docs/qwen_image.md)
|
- [Qwen Image](./docs/qwen_image.md)
|
||||||
|
- [PiD](./docs/pid.md)
|
||||||
- [LongCat Image](./docs/longcat_image.md)
|
- [LongCat Image](./docs/longcat_image.md)
|
||||||
- [Z-Image](./docs/z_image.md)
|
- [Z-Image](./docs/z_image.md)
|
||||||
- [Ovis-Image](./docs/ovis_image.md)
|
- [Ovis-Image](./docs/ovis_image.md)
|
||||||
|
|||||||
BIN
assets/pid/example.png
Normal file
BIN
assets/pid/example.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 9.0 MiB |
39
docs/pid.md
Normal file
39
docs/pid.md
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
# How to Use
|
||||||
|
|
||||||
|
PiD is NVIDIA's Pixel Diffusion Decoder. It replaces the usual VAE decode or decode-then-upscale path with a pixel-space diffusion decoder conditioned on a
|
||||||
|
source latent and text prompt.
|
||||||
|
|
||||||
|
In stable-diffusion.cpp, PiD currently runs as an image edit pipeline: provide a reference image with `-r`/`--ref-image`, encode that image with a matching VAE, then let the PiD diffusion model decode/upscale directly to RGB.
|
||||||
|
|
||||||
|
## Download weights
|
||||||
|
|
||||||
|
- Download PiD
|
||||||
|
- safetensors: https://huggingface.co/Comfy-Org/PixelDiT/tree/main/diffusion_models
|
||||||
|
- Download Gemma 2 2B
|
||||||
|
- safetensors: https://huggingface.co/Comfy-Org/PixelDiT/tree/main/text_encoders
|
||||||
|
- Download the VAE that matches the PiD checkpoint backbone
|
||||||
|
- safetensors: https://huggingface.co/nvidia/PiD/tree/main/checkpoints
|
||||||
|
- Flux / Z-Image PiD: use the Flux VAE and pass `--vae-format flux`
|
||||||
|
- SD3 PiD: use the SD3 VAE and pass `--vae-format sd3`
|
||||||
|
- Flux.2 PiD: use the Flux.2 VAE and pass `--vae-format flux2`
|
||||||
|
|
||||||
|
The official PiD model card should be checked before use. At the time of the initial PiD release, the official weights are under the NSCLv1 non-commercial license.
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
```
|
||||||
|
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\pid_flux1_512_to_2048_4step_bf16.safetensors --llm "..\..\ComfyUI\models\text_encoders\gemma_2_2b_it_elm_bf16.safetensors" --vae ..\..\ComfyUI\models\vae\ae.sft --vae-format flux --cfg-scale 1.0 -p "a lovely cat" -r ..\assets\ernie_image\turbo_example.png --diffusion-fa -v --steps 4 -H 2048 -W 2048 --rng cpu
|
||||||
|
```
|
||||||
|
|
||||||
|
Before:
|
||||||
|
|
||||||
|
<img width="256" alt="ERNIE-Image Turbo example" src="../assets/ernie_image/turbo_example.png" />
|
||||||
|
|
||||||
|
After:
|
||||||
|
<img width="1024" alt="PiD example" src="../assets/pid/example.png" />
|
||||||
|
|
||||||
|
## Notes
|
||||||
|
|
||||||
|
- `-r`/`--ref-image` is required. PiD uses the first reference image as the source latent condition.
|
||||||
|
- `--vae-format` should match the VAE latent layout used by the PiD checkpoint. This is important when using standalone VAE files because the PiD diffusion
|
||||||
|
checkpoint alone does not identify the VAE format.
|
||||||
@ -35,6 +35,22 @@ const char* const modes_str[] = {
|
|||||||
"metadata",
|
"metadata",
|
||||||
};
|
};
|
||||||
|
|
||||||
|
static sd_vae_format_t str_to_vae_format(const std::string& value) {
|
||||||
|
if (value == "auto") {
|
||||||
|
return SD_VAE_FORMAT_AUTO;
|
||||||
|
}
|
||||||
|
if (value == "flux") {
|
||||||
|
return SD_VAE_FORMAT_FLUX;
|
||||||
|
}
|
||||||
|
if (value == "sd3") {
|
||||||
|
return SD_VAE_FORMAT_SD3;
|
||||||
|
}
|
||||||
|
if (value == "flux2") {
|
||||||
|
return SD_VAE_FORMAT_FLUX2;
|
||||||
|
}
|
||||||
|
return SD_VAE_FORMAT_COUNT;
|
||||||
|
}
|
||||||
|
|
||||||
#if defined(_WIN32)
|
#if defined(_WIN32)
|
||||||
static std::string utf16_to_utf8(const std::wstring& wstr) {
|
static std::string utf16_to_utf8(const std::wstring& wstr) {
|
||||||
if (wstr.empty())
|
if (wstr.empty())
|
||||||
@ -348,6 +364,10 @@ ArgOptions SDContextParams::get_options() {
|
|||||||
"--vae",
|
"--vae",
|
||||||
"path to standalone vae model",
|
"path to standalone vae model",
|
||||||
&vae_path},
|
&vae_path},
|
||||||
|
{"",
|
||||||
|
"--vae-format",
|
||||||
|
"VAE latent format override: auto, flux, sd3, or flux2 (default: auto)",
|
||||||
|
&vae_format},
|
||||||
{"",
|
{"",
|
||||||
"--audio-vae",
|
"--audio-vae",
|
||||||
"path to standalone LTX audio vae model",
|
"path to standalone LTX audio vae model",
|
||||||
@ -639,6 +659,11 @@ bool SDContextParams::validate(SDMode mode) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (str_to_vae_format(vae_format) == SD_VAE_FORMAT_COUNT) {
|
||||||
|
LOG_ERROR("error: vae_format must be 'auto', 'flux', 'sd3', or 'flux2'");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -679,6 +704,7 @@ std::string SDContextParams::to_string() const {
|
|||||||
<< " high_noise_diffusion_model_path: \"" << high_noise_diffusion_model_path << "\",\n"
|
<< " high_noise_diffusion_model_path: \"" << high_noise_diffusion_model_path << "\",\n"
|
||||||
<< " embeddings_connectors_path: \"" << embeddings_connectors_path << "\",\n"
|
<< " embeddings_connectors_path: \"" << embeddings_connectors_path << "\",\n"
|
||||||
<< " vae_path: \"" << vae_path << "\",\n"
|
<< " vae_path: \"" << vae_path << "\",\n"
|
||||||
|
<< " vae_format: \"" << vae_format << "\",\n"
|
||||||
<< " audio_vae_path: \"" << audio_vae_path << "\",\n"
|
<< " audio_vae_path: \"" << audio_vae_path << "\",\n"
|
||||||
<< " taesd_path: \"" << taesd_path << "\",\n"
|
<< " taesd_path: \"" << taesd_path << "\",\n"
|
||||||
<< " esrgan_path: \"" << esrgan_path << "\",\n"
|
<< " esrgan_path: \"" << esrgan_path << "\",\n"
|
||||||
@ -772,6 +798,7 @@ sd_ctx_params_t SDContextParams::to_sd_ctx_params_t(bool vae_decode_only, bool f
|
|||||||
chroma_use_t5_mask,
|
chroma_use_t5_mask,
|
||||||
chroma_t5_mask_pad,
|
chroma_t5_mask_pad,
|
||||||
qwen_image_zero_cond_t,
|
qwen_image_zero_cond_t,
|
||||||
|
str_to_vae_format(vae_format),
|
||||||
max_vram,
|
max_vram,
|
||||||
backend.c_str(),
|
backend.c_str(),
|
||||||
params_backend.c_str(),
|
params_backend.c_str(),
|
||||||
|
|||||||
@ -94,6 +94,7 @@ struct SDContextParams {
|
|||||||
std::string high_noise_diffusion_model_path;
|
std::string high_noise_diffusion_model_path;
|
||||||
std::string embeddings_connectors_path;
|
std::string embeddings_connectors_path;
|
||||||
std::string vae_path;
|
std::string vae_path;
|
||||||
|
std::string vae_format = "auto";
|
||||||
std::string audio_vae_path;
|
std::string audio_vae_path;
|
||||||
std::string taesd_path;
|
std::string taesd_path;
|
||||||
std::string esrgan_path;
|
std::string esrgan_path;
|
||||||
|
|||||||
@ -168,6 +168,14 @@ typedef struct {
|
|||||||
const char* path;
|
const char* path;
|
||||||
} sd_embedding_t;
|
} sd_embedding_t;
|
||||||
|
|
||||||
|
enum sd_vae_format_t {
|
||||||
|
SD_VAE_FORMAT_AUTO = -1,
|
||||||
|
SD_VAE_FORMAT_FLUX,
|
||||||
|
SD_VAE_FORMAT_SD3,
|
||||||
|
SD_VAE_FORMAT_FLUX2,
|
||||||
|
SD_VAE_FORMAT_COUNT,
|
||||||
|
};
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
const char* model_path;
|
const char* model_path;
|
||||||
const char* clip_l_path;
|
const char* clip_l_path;
|
||||||
@ -212,6 +220,7 @@ typedef struct {
|
|||||||
bool chroma_use_t5_mask;
|
bool chroma_use_t5_mask;
|
||||||
int chroma_t5_mask_pad;
|
int chroma_t5_mask_pad;
|
||||||
bool qwen_image_zero_cond_t;
|
bool qwen_image_zero_cond_t;
|
||||||
|
enum sd_vae_format_t vae_format;
|
||||||
float max_vram; // GiB budget for graph-cut segmented param offload (0 = disabled, -1 = auto free VRAM minus 1 GiB)
|
float max_vram; // GiB budget for graph-cut segmented param offload (0 = disabled, -1 = auto free VRAM minus 1 GiB)
|
||||||
const char* backend;
|
const char* backend;
|
||||||
const char* params_backend;
|
const char* params_backend;
|
||||||
|
|||||||
@ -1171,7 +1171,6 @@ struct FluxCLIPEmbedder : public Conditioner {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void free_params_buffer() override {
|
void free_params_buffer() override {
|
||||||
if (clip_l) {
|
if (clip_l) {
|
||||||
clip_l->free_params_buffer();
|
clip_l->free_params_buffer();
|
||||||
@ -1601,8 +1600,8 @@ struct AnimaConditioner : public Conditioner {
|
|||||||
|
|
||||||
bool alloc_params_buffer() override {
|
bool alloc_params_buffer() override {
|
||||||
if (!llm->alloc_params_buffer()) {
|
if (!llm->alloc_params_buffer()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1719,6 +1718,8 @@ struct LLMEmbedder : public Conditioner {
|
|||||||
arch = LLM::LLMArch::MINISTRAL_3_3B;
|
arch = LLM::LLMArch::MINISTRAL_3_3B;
|
||||||
} else if (sd_version_is_lens(version)) {
|
} else if (sd_version_is_lens(version)) {
|
||||||
arch = LLM::LLMArch::GPT_OSS_20B;
|
arch = LLM::LLMArch::GPT_OSS_20B;
|
||||||
|
} else if (sd_version_is_pid(version)) {
|
||||||
|
arch = LLM::LLMArch::GEMMA2_2B;
|
||||||
} else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE || version == VERSION_FLUX2_KLEIN) {
|
} else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE || version == VERSION_FLUX2_KLEIN) {
|
||||||
arch = LLM::LLMArch::QWEN3;
|
arch = LLM::LLMArch::QWEN3;
|
||||||
}
|
}
|
||||||
@ -1726,6 +1727,8 @@ struct LLMEmbedder : public Conditioner {
|
|||||||
tokenizer = std::make_shared<MistralTokenizer>();
|
tokenizer = std::make_shared<MistralTokenizer>();
|
||||||
} else if (arch == LLM::LLMArch::GPT_OSS_20B) {
|
} else if (arch == LLM::LLMArch::GPT_OSS_20B) {
|
||||||
tokenizer = std::make_shared<GPTOSSTokenizer>();
|
tokenizer = std::make_shared<GPTOSSTokenizer>();
|
||||||
|
} else if (arch == LLM::LLMArch::GEMMA2_2B) {
|
||||||
|
tokenizer = std::make_shared<Gemma2Tokenizer>();
|
||||||
} else {
|
} else {
|
||||||
tokenizer = std::make_shared<Qwen2Tokenizer>();
|
tokenizer = std::make_shared<Qwen2Tokenizer>();
|
||||||
}
|
}
|
||||||
@ -1743,7 +1746,7 @@ struct LLMEmbedder : public Conditioner {
|
|||||||
|
|
||||||
bool alloc_params_buffer() override {
|
bool alloc_params_buffer() override {
|
||||||
if (!llm->alloc_params_buffer()) {
|
if (!llm->alloc_params_buffer()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -1847,12 +1850,16 @@ struct LLMEmbedder : public Conditioner {
|
|||||||
sd::Tensor<int32_t> input_ids({static_cast<int64_t>(tokens.size())}, tokens);
|
sd::Tensor<int32_t> input_ids({static_cast<int64_t>(tokens.size())}, tokens);
|
||||||
sd::Tensor<float> attention_mask;
|
sd::Tensor<float> attention_mask;
|
||||||
if (!mask.empty()) {
|
if (!mask.empty()) {
|
||||||
attention_mask = sd::Tensor<float>({static_cast<int64_t>(mask.size()), static_cast<int64_t>(mask.size())});
|
attention_mask = sd::Tensor<float>({static_cast<int64_t>(mask.size()), static_cast<int64_t>(mask.size())});
|
||||||
|
const float masked_attention_value = -std::numeric_limits<float>::max() / 4.0f;
|
||||||
for (size_t i1 = 0; i1 < mask.size(); ++i1) {
|
for (size_t i1 = 0; i1 < mask.size(); ++i1) {
|
||||||
for (size_t i0 = 0; i0 < mask.size(); ++i0) {
|
for (size_t i0 = 0; i0 < mask.size(); ++i0) {
|
||||||
float value = 0.0f;
|
float value = 0.0f;
|
||||||
if (mask[i0] == 0.0f || i0 > i1) {
|
if (mask[i0] == 0.0f) {
|
||||||
value = -INFINITY;
|
value += masked_attention_value;
|
||||||
|
}
|
||||||
|
if (i0 > i1) {
|
||||||
|
value += masked_attention_value;
|
||||||
}
|
}
|
||||||
attention_mask[static_cast<int64_t>(i0 + mask.size() * i1)] = value;
|
attention_mask[static_cast<int64_t>(i0 + mask.size() * i1)] = value;
|
||||||
}
|
}
|
||||||
@ -2126,6 +2133,53 @@ struct LLMEmbedder : public Conditioner {
|
|||||||
prompt_attn_range.second = static_cast<int>(prompt.size());
|
prompt_attn_range.second = static_cast<int>(prompt.size());
|
||||||
|
|
||||||
prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
|
prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
|
||||||
|
} else if (sd_version_is_pid(version)) {
|
||||||
|
constexpr int pixeldit_max_length = 300;
|
||||||
|
const std::string chi_prompt =
|
||||||
|
"Given a user prompt, generate an \"Enhanced prompt\" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:\n"
|
||||||
|
"- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.\n"
|
||||||
|
"- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n"
|
||||||
|
"Here are examples of how to transform or refine prompts:\n"
|
||||||
|
"- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.\n"
|
||||||
|
"- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n"
|
||||||
|
"Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:\n"
|
||||||
|
"User Prompt: ";
|
||||||
|
auto chi_tokens = std::get<0>(tokenize(chi_prompt, {0, 0}));
|
||||||
|
size_t num_chi_tokens = chi_tokens.size();
|
||||||
|
max_length = (int)num_chi_tokens + pixeldit_max_length - 2;
|
||||||
|
min_length = max_length;
|
||||||
|
|
||||||
|
prompt_attn_range.first = static_cast<int>(prompt.size());
|
||||||
|
prompt += " " + conditioner_params.text;
|
||||||
|
prompt_attn_range.second = static_cast<int>(prompt.size());
|
||||||
|
|
||||||
|
auto hidden_states = encode_prompt(n_threads,
|
||||||
|
prompt,
|
||||||
|
prompt_attn_range,
|
||||||
|
min_length,
|
||||||
|
0,
|
||||||
|
image_embeds,
|
||||||
|
out_layers,
|
||||||
|
0,
|
||||||
|
false,
|
||||||
|
max_length);
|
||||||
|
GGML_ASSERT(!hidden_states.empty());
|
||||||
|
|
||||||
|
if (hidden_states.shape()[1] > pixeldit_max_length) {
|
||||||
|
auto bos = sd::ops::slice(hidden_states, 1, 0, 1);
|
||||||
|
auto tail = sd::ops::slice(hidden_states,
|
||||||
|
1,
|
||||||
|
hidden_states.shape()[1] - (pixeldit_max_length - 1),
|
||||||
|
hidden_states.shape()[1]);
|
||||||
|
hidden_states = sd::ops::concat(bos, tail, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t t1 = ggml_time_ms();
|
||||||
|
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
|
||||||
|
|
||||||
|
SDCondition result;
|
||||||
|
result.c_crossattn = std::move(hidden_states);
|
||||||
|
return result;
|
||||||
} else {
|
} else {
|
||||||
GGML_ABORT("unknown version %d", version);
|
GGML_ABORT("unknown version %d", version);
|
||||||
}
|
}
|
||||||
@ -2268,10 +2322,10 @@ struct LTXAVEmbedder : public Conditioner {
|
|||||||
|
|
||||||
bool alloc_params_buffer() override {
|
bool alloc_params_buffer() override {
|
||||||
if (!llm->alloc_params_buffer()) {
|
if (!llm->alloc_params_buffer()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!projector->alloc_params_buffer()) {
|
if (!projector->alloc_params_buffer()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|||||||
59
src/llm.hpp
59
src/llm.hpp
@ -37,6 +37,7 @@ namespace LLM {
|
|||||||
MISTRAL_SMALL_3_2,
|
MISTRAL_SMALL_3_2,
|
||||||
MINISTRAL_3_3B,
|
MINISTRAL_3_3B,
|
||||||
GEMMA3_12B,
|
GEMMA3_12B,
|
||||||
|
GEMMA2_2B,
|
||||||
GPT_OSS_20B,
|
GPT_OSS_20B,
|
||||||
ARCH_COUNT,
|
ARCH_COUNT,
|
||||||
};
|
};
|
||||||
@ -48,6 +49,7 @@ namespace LLM {
|
|||||||
"mistral_small3.2",
|
"mistral_small3.2",
|
||||||
"ministral3.3b",
|
"ministral3.3b",
|
||||||
"gemma3_12b",
|
"gemma3_12b",
|
||||||
|
"gemma2_2b",
|
||||||
"gpt_oss_20b",
|
"gpt_oss_20b",
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -900,6 +902,33 @@ namespace LLM {
|
|||||||
1.f,
|
1.f,
|
||||||
32.f,
|
32.f,
|
||||||
1.f);
|
1.f);
|
||||||
|
} else if (arch == LLMArch::GEMMA2_2B) {
|
||||||
|
q = ggml_rope_ext(ctx->ggml_ctx,
|
||||||
|
q,
|
||||||
|
input_pos,
|
||||||
|
nullptr,
|
||||||
|
head_dim,
|
||||||
|
GGML_ROPE_TYPE_NEOX,
|
||||||
|
8192,
|
||||||
|
10000.f,
|
||||||
|
1.f,
|
||||||
|
0.f,
|
||||||
|
1.f,
|
||||||
|
32.f,
|
||||||
|
1.f);
|
||||||
|
k = ggml_rope_ext(ctx->ggml_ctx,
|
||||||
|
k,
|
||||||
|
input_pos,
|
||||||
|
nullptr,
|
||||||
|
head_dim,
|
||||||
|
GGML_ROPE_TYPE_NEOX,
|
||||||
|
8192,
|
||||||
|
10000.f,
|
||||||
|
1.f,
|
||||||
|
0.f,
|
||||||
|
1.f,
|
||||||
|
32.f,
|
||||||
|
1.f);
|
||||||
} else if (arch == LLMArch::QWEN3_VL) {
|
} else if (arch == LLMArch::QWEN3_VL) {
|
||||||
int sections[4] = {24, 20, 20, 0};
|
int sections[4] = {24, 20, 20, 0};
|
||||||
q = ggml_rope_multi(ctx->ggml_ctx, q, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_IMROPE, 262144, 5000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
|
q = ggml_rope_multi(ctx->ggml_ctx, q, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_IMROPE, 262144, 5000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
|
||||||
@ -957,10 +986,18 @@ namespace LLM {
|
|||||||
: arch(params.arch),
|
: arch(params.arch),
|
||||||
sliding_attention(0) {
|
sliding_attention(0) {
|
||||||
if (params.arch == LLMArch::GEMMA3_12B) {
|
if (params.arch == LLMArch::GEMMA3_12B) {
|
||||||
post_attention_norm_name = "post_attention_norm";
|
post_attention_norm_name = "post_attention_norm"; // attn_post_norm
|
||||||
post_ffw_norm_name = "post_ffw_norm";
|
pre_ffw_norm_name = "post_attention_layernorm"; // ffn_norm
|
||||||
|
post_ffw_norm_name = "post_ffw_norm"; // ffn_post_norm
|
||||||
|
} else if (params.arch == LLMArch::GEMMA2_2B) {
|
||||||
|
post_attention_norm_name = "post_attention_layernorm"; // ffn_norm
|
||||||
|
pre_ffw_norm_name = "pre_feedforward_layernorm";
|
||||||
|
post_ffw_norm_name = "post_feedforward_layernorm";
|
||||||
|
} else if (params.arch == LLMArch::GPT_OSS_20B) {
|
||||||
|
pre_ffw_norm_name = "post_attention_norm"; // attn_post_norm
|
||||||
|
} else {
|
||||||
|
pre_ffw_norm_name = "post_attention_layernorm"; // ffn_norm
|
||||||
}
|
}
|
||||||
pre_ffw_norm_name = params.arch == LLMArch::GPT_OSS_20B ? "post_attention_norm" : "post_attention_layernorm";
|
|
||||||
|
|
||||||
blocks["self_attn"] = std::make_shared<Attention>(params);
|
blocks["self_attn"] = std::make_shared<Attention>(params);
|
||||||
if (params.arch == LLMArch::GPT_OSS_20B) {
|
if (params.arch == LLMArch::GPT_OSS_20B) {
|
||||||
@ -1447,6 +1484,21 @@ namespace LLM {
|
|||||||
params.rope_thetas = {1000000.f, 10000.f};
|
params.rope_thetas = {1000000.f, 10000.f};
|
||||||
params.rope_scales = {8.f, 1.f};
|
params.rope_scales = {8.f, 1.f};
|
||||||
params.sliding_attention = {1024, 1024, 1024, 1024, 1024, 0};
|
params.sliding_attention = {1024, 1024, 1024, 1024, 1024, 0};
|
||||||
|
} else if (arch == LLMArch::GEMMA2_2B) {
|
||||||
|
params.head_dim = 256;
|
||||||
|
params.num_heads = 8;
|
||||||
|
params.num_kv_heads = 4;
|
||||||
|
params.qkv_bias = false;
|
||||||
|
params.qk_norm = false;
|
||||||
|
params.rms_norm_eps = 1e-6f;
|
||||||
|
params.rms_norm_add = true;
|
||||||
|
params.normalize_input = true;
|
||||||
|
params.max_position_embeddings = 8192;
|
||||||
|
params.mlp_activation = MLPActivation::GELU_TANH;
|
||||||
|
params.hidden_size = 2304;
|
||||||
|
params.intermediate_size = 9216;
|
||||||
|
params.num_layers = 26;
|
||||||
|
params.vocab_size = 256000;
|
||||||
} else if (arch == LLMArch::GPT_OSS_20B) {
|
} else if (arch == LLMArch::GPT_OSS_20B) {
|
||||||
params.head_dim = 64;
|
params.head_dim = 64;
|
||||||
params.num_heads = 64;
|
params.num_heads = 64;
|
||||||
@ -1585,6 +1637,7 @@ namespace LLM {
|
|||||||
params.arch == LLMArch::MINISTRAL_3_3B ||
|
params.arch == LLMArch::MINISTRAL_3_3B ||
|
||||||
params.arch == LLMArch::QWEN3 ||
|
params.arch == LLMArch::QWEN3 ||
|
||||||
params.arch == LLMArch::GEMMA3_12B ||
|
params.arch == LLMArch::GEMMA3_12B ||
|
||||||
|
params.arch == LLMArch::GEMMA2_2B ||
|
||||||
params.arch == LLMArch::GPT_OSS_20B) {
|
params.arch == LLMArch::GPT_OSS_20B) {
|
||||||
input_pos_vec.resize(n_tokens);
|
input_pos_vec.resize(n_tokens);
|
||||||
for (int i = 0; i < n_tokens; ++i) {
|
for (int i = 0; i < n_tokens; ++i) {
|
||||||
|
|||||||
@ -91,7 +91,6 @@ struct LoraModel : public GGMLRunner {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
dry_run = false;
|
dry_run = false;
|
||||||
model_loader.load_tensors(on_new_tensor_cb, n_threads);
|
model_loader.load_tensors(on_new_tensor_cb, n_threads);
|
||||||
|
|
||||||
|
|||||||
@ -1069,8 +1069,8 @@ namespace LTXV {
|
|||||||
prefix);
|
prefix);
|
||||||
|
|
||||||
if (!ltx_audio_vae->alloc_params_buffer()) {
|
if (!ltx_audio_vae->alloc_params_buffer()) {
|
||||||
LOG_ERROR("ltx audio vae buffer allocation failed");
|
LOG_ERROR("ltx audio vae buffer allocation failed");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::map<std::string, ggml_tensor*> tensors;
|
std::map<std::string, ggml_tensor*> tensors;
|
||||||
|
|||||||
@ -432,6 +432,9 @@ SDVersion ModelLoader::get_sd_version() {
|
|||||||
tensor_storage.name.find("model.diffusion_model.single_transformer_blocks.") != std::string::npos) {
|
tensor_storage.name.find("model.diffusion_model.single_transformer_blocks.") != std::string::npos) {
|
||||||
is_flux = true;
|
is_flux = true;
|
||||||
}
|
}
|
||||||
|
if (tensor_storage.name.find("model.diffusion_model.net.lq_proj.latent_proj.0.weight") != std::string::npos) {
|
||||||
|
return VERSION_PID;
|
||||||
|
}
|
||||||
if (tensor_storage.name.find("model.diffusion_model.nerf_final_layer_conv.") != std::string::npos) {
|
if (tensor_storage.name.find("model.diffusion_model.nerf_final_layer_conv.") != std::string::npos) {
|
||||||
return VERSION_CHROMA_RADIANCE;
|
return VERSION_CHROMA_RADIANCE;
|
||||||
}
|
}
|
||||||
|
|||||||
11
src/model.h
11
src/model.h
@ -49,6 +49,7 @@ enum SDVersion {
|
|||||||
VERSION_ERNIE_IMAGE,
|
VERSION_ERNIE_IMAGE,
|
||||||
VERSION_LENS,
|
VERSION_LENS,
|
||||||
VERSION_LONGCAT,
|
VERSION_LONGCAT,
|
||||||
|
VERSION_PID,
|
||||||
VERSION_COUNT,
|
VERSION_COUNT,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -164,6 +165,13 @@ static inline bool sd_version_is_lens(SDVersion version) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static inline bool sd_version_is_pid(SDVersion version) {
|
||||||
|
if (version == VERSION_PID) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
static inline bool sd_version_uses_flux2_vae(SDVersion version) {
|
static inline bool sd_version_uses_flux2_vae(SDVersion version) {
|
||||||
if (sd_version_is_flux2(version) || sd_version_is_ernie_image(version) || sd_version_is_lens(version)) {
|
if (sd_version_is_flux2(version) || sd_version_is_ernie_image(version) || sd_version_is_lens(version)) {
|
||||||
return true;
|
return true;
|
||||||
@ -194,7 +202,8 @@ static inline bool sd_version_is_dit(SDVersion version) {
|
|||||||
sd_version_is_z_image(version) ||
|
sd_version_is_z_image(version) ||
|
||||||
sd_version_is_ernie_image(version) ||
|
sd_version_is_ernie_image(version) ||
|
||||||
sd_version_is_lens(version) ||
|
sd_version_is_lens(version) ||
|
||||||
sd_version_is_longcat(version)) {
|
sd_version_is_longcat(version) ||
|
||||||
|
sd_version_is_pid(version)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
|
|||||||
842
src/pid.hpp
Normal file
842
src/pid.hpp
Normal file
@ -0,0 +1,842 @@
|
|||||||
|
#ifndef __SD_PID_HPP__
|
||||||
|
#define __SD_PID_HPP__
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "common_dit.hpp"
|
||||||
|
#include "ggml_extend.hpp"
|
||||||
|
#include "mmdit.hpp"
|
||||||
|
#include "rope.hpp"
|
||||||
|
|
||||||
|
namespace Pid {
|
||||||
|
constexpr int PID_GRAPH_SIZE = 196608;
|
||||||
|
constexpr float PID_PI = 3.14159265358979323846f;
|
||||||
|
|
||||||
|
struct PixelDiTParams {
|
||||||
|
int64_t in_channels = 3;
|
||||||
|
int64_t hidden_size = 1536;
|
||||||
|
int64_t num_groups = 24;
|
||||||
|
int64_t patch_mlp_hidden_dim = 4096;
|
||||||
|
int64_t pixel_hidden_size = 16;
|
||||||
|
int64_t pixel_attn_hidden_size = 1152;
|
||||||
|
int64_t pixel_num_groups = 16;
|
||||||
|
int64_t patch_depth = 14;
|
||||||
|
int64_t pixel_depth = 2;
|
||||||
|
int64_t patch_size = 16;
|
||||||
|
int64_t txt_embed_dim = 2304;
|
||||||
|
int64_t txt_max_length = 300;
|
||||||
|
float text_rope_theta = 10000.f;
|
||||||
|
int64_t lq_latent_channels = 16;
|
||||||
|
int64_t lq_hidden_dim = 512;
|
||||||
|
int64_t lq_num_res_blocks = 4;
|
||||||
|
int64_t lq_interval = 2;
|
||||||
|
int64_t lq_sr_scale = 4;
|
||||||
|
int64_t lq_latent_down_factor = 8;
|
||||||
|
int64_t rope_ref_grid_h = 64;
|
||||||
|
int64_t rope_ref_grid_w = 64;
|
||||||
|
};
|
||||||
|
|
||||||
|
inline std::vector<float> make_rope_1d(int length,
|
||||||
|
int dim,
|
||||||
|
float theta) {
|
||||||
|
GGML_ASSERT(dim % 2 == 0);
|
||||||
|
return Rope::flatten(Rope::rope(Rope::linspace(0.f, static_cast<float>(length - 1), length), dim, theta));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::vector<float> make_rope_2d(int height,
|
||||||
|
int width,
|
||||||
|
int dim,
|
||||||
|
float theta = 10000.f,
|
||||||
|
float scale = 16.f,
|
||||||
|
int ref_grid_h = 0,
|
||||||
|
int ref_grid_w = 0) {
|
||||||
|
GGML_ASSERT(dim % 4 == 0);
|
||||||
|
return Rope::embed_2d_interleaved(height, width, dim, theta, scale, ref_grid_h, ref_grid_w);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::vector<float> make_pixel_abs_pos(int height,
|
||||||
|
int width,
|
||||||
|
int dim) {
|
||||||
|
GGML_ASSERT(dim % 4 == 0);
|
||||||
|
int half_dim = dim / 2;
|
||||||
|
std::vector<float> x_pos;
|
||||||
|
std::vector<float> y_pos;
|
||||||
|
x_pos.reserve(static_cast<size_t>(height) * width);
|
||||||
|
y_pos.reserve(static_cast<size_t>(height) * width);
|
||||||
|
for (int iy = 0; iy < height; ++iy) {
|
||||||
|
for (int ix = 0; ix < width; ++ix) {
|
||||||
|
x_pos.push_back(static_cast<float>(ix));
|
||||||
|
y_pos.push_back(static_cast<float>(iy));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto x_emb = timestep_embedding(x_pos, half_dim, 10000, false);
|
||||||
|
auto y_emb = timestep_embedding(y_pos, half_dim, 10000, false);
|
||||||
|
|
||||||
|
std::vector<float> out(static_cast<size_t>(dim) * height * width);
|
||||||
|
for (int pos = 0; pos < height * width; ++pos) {
|
||||||
|
size_t out_base = static_cast<size_t>(pos) * dim;
|
||||||
|
size_t emb_base = static_cast<size_t>(pos) * half_dim;
|
||||||
|
for (int i = 0; i < half_dim; ++i) {
|
||||||
|
out[out_base + i] = x_emb[emb_base + i];
|
||||||
|
out[out_base + half_dim + i] = y_emb[emb_base + i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline ggml_tensor* apply_adaln(ggml_context* ctx,
|
||||||
|
ggml_tensor* x,
|
||||||
|
ggml_tensor* shift,
|
||||||
|
ggml_tensor* scale) {
|
||||||
|
return ggml_add(ctx, ggml_add(ctx, x, ggml_mul(ctx, x, scale)), shift);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct PatchTokenEmbedder : public GGMLBlock {
|
||||||
|
bool use_rms_norm;
|
||||||
|
|
||||||
|
PatchTokenEmbedder(int64_t in_chans,
|
||||||
|
int64_t embed_dim,
|
||||||
|
bool use_rms_norm = false,
|
||||||
|
bool bias = true)
|
||||||
|
: use_rms_norm(use_rms_norm) {
|
||||||
|
blocks["proj"] = std::make_shared<Linear>(in_chans, embed_dim, bias);
|
||||||
|
if (use_rms_norm) {
|
||||||
|
blocks["norm"] = std::make_shared<RMSNorm>(embed_dim, 1e-6f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
|
||||||
|
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
|
||||||
|
x = proj->forward(ctx, x);
|
||||||
|
if (use_rms_norm) {
|
||||||
|
auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["norm"]);
|
||||||
|
x = norm->forward(ctx, x);
|
||||||
|
}
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct PixelDiTTimestepEmbedder : public GGMLBlock {
|
||||||
|
int frequency_embedding_size;
|
||||||
|
|
||||||
|
PixelDiTTimestepEmbedder(int64_t hidden_size,
|
||||||
|
int frequency_embedding_size = 256)
|
||||||
|
: frequency_embedding_size(frequency_embedding_size) {
|
||||||
|
blocks["mlp.0"] = std::make_shared<Linear>(frequency_embedding_size, hidden_size, true, true);
|
||||||
|
blocks["mlp.2"] = std::make_shared<Linear>(hidden_size, hidden_size, true, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* t) {
|
||||||
|
auto mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["mlp.0"]);
|
||||||
|
auto mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["mlp.2"]);
|
||||||
|
auto t_emb = ggml_ext_timestep_embedding(ctx->ggml_ctx, t, frequency_embedding_size, 10);
|
||||||
|
t_emb = mlp_0->forward(ctx, t_emb);
|
||||||
|
t_emb = ggml_silu_inplace(ctx->ggml_ctx, t_emb);
|
||||||
|
return mlp_2->forward(ctx, t_emb);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct FeedForward : public GGMLBlock {
|
||||||
|
FeedForward(int64_t dim, int64_t hidden_dim) {
|
||||||
|
blocks["w1"] = std::make_shared<Linear>(dim, hidden_dim, false);
|
||||||
|
blocks["w2"] = std::make_shared<Linear>(hidden_dim, dim, false);
|
||||||
|
blocks["w3"] = std::make_shared<Linear>(dim, hidden_dim, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
|
||||||
|
auto w1 = std::dynamic_pointer_cast<Linear>(blocks["w1"]);
|
||||||
|
auto w2 = std::dynamic_pointer_cast<Linear>(blocks["w2"]);
|
||||||
|
auto w3 = std::dynamic_pointer_cast<Linear>(blocks["w3"]);
|
||||||
|
auto h = ggml_silu_inplace(ctx->ggml_ctx, w1->forward(ctx, x));
|
||||||
|
h = ggml_mul_inplace(ctx->ggml_ctx, h, w3->forward(ctx, x));
|
||||||
|
return w2->forward(ctx, h);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct FinalLayer : public GGMLBlock {
|
||||||
|
FinalLayer(int64_t hidden_size, int64_t out_channels) {
|
||||||
|
blocks["norm"] = std::make_shared<RMSNorm>(hidden_size, 1e-6f);
|
||||||
|
blocks["linear"] = std::make_shared<Linear>(hidden_size, out_channels, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
|
||||||
|
auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["norm"]);
|
||||||
|
auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
|
||||||
|
return linear->forward(ctx, norm->forward(ctx, x));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct RotaryAttention : public GGMLBlock {
|
||||||
|
int64_t dim;
|
||||||
|
int64_t num_heads;
|
||||||
|
|
||||||
|
RotaryAttention(int64_t dim, int64_t num_heads)
|
||||||
|
: dim(dim), num_heads(num_heads) {
|
||||||
|
int64_t head_dim = dim / num_heads;
|
||||||
|
blocks["qkv"] = std::make_shared<Linear>(dim, dim * 3, false);
|
||||||
|
blocks["q_norm"] = std::make_shared<RMSNorm>(head_dim, 1e-6f);
|
||||||
|
blocks["k_norm"] = std::make_shared<RMSNorm>(head_dim, 1e-6f);
|
||||||
|
blocks["proj"] = std::make_shared<Linear>(dim, dim, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* pos) {
|
||||||
|
auto qkv_proj = std::dynamic_pointer_cast<Linear>(blocks["qkv"]);
|
||||||
|
auto q_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["q_norm"]);
|
||||||
|
auto k_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["k_norm"]);
|
||||||
|
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
|
||||||
|
|
||||||
|
auto qkv = qkv_proj->forward(ctx, x);
|
||||||
|
auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv);
|
||||||
|
int64_t L = x->ne[1];
|
||||||
|
int64_t N = x->ne[2];
|
||||||
|
int64_t head_dim = dim / num_heads;
|
||||||
|
auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, L, N);
|
||||||
|
auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, L, N);
|
||||||
|
auto v = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[2], head_dim, num_heads, L, N);
|
||||||
|
q = q_norm->forward(ctx, q);
|
||||||
|
k = k_norm->forward(ctx, k);
|
||||||
|
x = Rope::attention(ctx, q, k, v, pos, nullptr, 1.0f / 128.f, true);
|
||||||
|
return proj->forward(ctx, x);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct MMDiTJointAttention : public GGMLBlock {
|
||||||
|
int64_t dim;
|
||||||
|
int64_t num_heads;
|
||||||
|
|
||||||
|
MMDiTJointAttention(int64_t dim, int64_t num_heads)
|
||||||
|
: dim(dim), num_heads(num_heads) {
|
||||||
|
int64_t head_dim = dim / num_heads;
|
||||||
|
blocks["qkv_x"] = std::make_shared<Linear>(dim, dim * 3, false);
|
||||||
|
blocks["qkv_y"] = std::make_shared<Linear>(dim, dim * 3, false);
|
||||||
|
blocks["q_norm_x"] = std::make_shared<RMSNorm>(head_dim, 1e-6f);
|
||||||
|
blocks["k_norm_x"] = std::make_shared<RMSNorm>(head_dim, 1e-6f);
|
||||||
|
blocks["q_norm_y"] = std::make_shared<RMSNorm>(head_dim, 1e-6f);
|
||||||
|
blocks["k_norm_y"] = std::make_shared<RMSNorm>(head_dim, 1e-6f);
|
||||||
|
blocks["proj_x"] = std::make_shared<Linear>(dim, dim, true);
|
||||||
|
blocks["proj_y"] = std::make_shared<Linear>(dim, dim, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<ggml_tensor*, ggml_tensor*> forward(GGMLRunnerContext* ctx,
|
||||||
|
ggml_tensor* x,
|
||||||
|
ggml_tensor* y,
|
||||||
|
ggml_tensor* pos_img,
|
||||||
|
ggml_tensor* pos_txt) {
|
||||||
|
auto qkv_x_proj = std::dynamic_pointer_cast<Linear>(blocks["qkv_x"]);
|
||||||
|
auto qkv_y_proj = std::dynamic_pointer_cast<Linear>(blocks["qkv_y"]);
|
||||||
|
auto q_norm_x = std::dynamic_pointer_cast<RMSNorm>(blocks["q_norm_x"]);
|
||||||
|
auto k_norm_x = std::dynamic_pointer_cast<RMSNorm>(blocks["k_norm_x"]);
|
||||||
|
auto q_norm_y = std::dynamic_pointer_cast<RMSNorm>(blocks["q_norm_y"]);
|
||||||
|
auto k_norm_y = std::dynamic_pointer_cast<RMSNorm>(blocks["k_norm_y"]);
|
||||||
|
auto proj_x = std::dynamic_pointer_cast<Linear>(blocks["proj_x"]);
|
||||||
|
auto proj_y = std::dynamic_pointer_cast<Linear>(blocks["proj_y"]);
|
||||||
|
|
||||||
|
int64_t Nx = x->ne[1];
|
||||||
|
int64_t Ny = y->ne[1];
|
||||||
|
int64_t N = x->ne[2];
|
||||||
|
int64_t head_dim = dim / num_heads;
|
||||||
|
|
||||||
|
auto qkv_x = split_qkv(ctx->ggml_ctx, qkv_x_proj->forward(ctx, x));
|
||||||
|
auto qx = ggml_reshape_4d(ctx->ggml_ctx, qkv_x[0], head_dim, num_heads, Nx, N);
|
||||||
|
auto kx = ggml_reshape_4d(ctx->ggml_ctx, qkv_x[1], head_dim, num_heads, Nx, N);
|
||||||
|
auto vx = ggml_reshape_4d(ctx->ggml_ctx, qkv_x[2], head_dim, num_heads, Nx, N);
|
||||||
|
qx = q_norm_x->forward(ctx, qx);
|
||||||
|
kx = k_norm_x->forward(ctx, kx);
|
||||||
|
|
||||||
|
auto qkv_y = split_qkv(ctx->ggml_ctx, qkv_y_proj->forward(ctx, y));
|
||||||
|
auto qy = ggml_reshape_4d(ctx->ggml_ctx, qkv_y[0], head_dim, num_heads, Ny, N);
|
||||||
|
auto ky = ggml_reshape_4d(ctx->ggml_ctx, qkv_y[1], head_dim, num_heads, Ny, N);
|
||||||
|
auto vy = ggml_reshape_4d(ctx->ggml_ctx, qkv_y[2], head_dim, num_heads, Ny, N);
|
||||||
|
qy = q_norm_y->forward(ctx, qy);
|
||||||
|
ky = k_norm_y->forward(ctx, ky);
|
||||||
|
|
||||||
|
auto q_joint = ggml_concat(ctx->ggml_ctx, qy, qx, 2);
|
||||||
|
auto k_joint = ggml_concat(ctx->ggml_ctx, ky, kx, 2);
|
||||||
|
auto v_joint = ggml_concat(ctx->ggml_ctx, vy, vx, 2);
|
||||||
|
auto pos_joint = ggml_concat(ctx->ggml_ctx, pos_txt, pos_img, 3);
|
||||||
|
auto out = Rope::attention(ctx, q_joint, k_joint, v_joint, pos_joint, nullptr, 1.0f, true);
|
||||||
|
|
||||||
|
auto out_y = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, Ny);
|
||||||
|
auto out_x = ggml_ext_slice(ctx->ggml_ctx, out, 1, Ny, Ny + Nx);
|
||||||
|
return {proj_x->forward(ctx, out_x), proj_y->forward(ctx, out_y)};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct MMDiTBlockT2I : public GGMLBlock {
|
||||||
|
int64_t hidden_size;
|
||||||
|
|
||||||
|
MMDiTBlockT2I(int64_t hidden_size, int64_t groups, int64_t mlp_hidden_dim)
|
||||||
|
: hidden_size(hidden_size) {
|
||||||
|
blocks["norm_x1"] = std::make_shared<RMSNorm>(hidden_size, 1e-6f);
|
||||||
|
blocks["norm_y1"] = std::make_shared<RMSNorm>(hidden_size, 1e-6f);
|
||||||
|
blocks["attn"] = std::make_shared<MMDiTJointAttention>(hidden_size, groups);
|
||||||
|
blocks["norm_x2"] = std::make_shared<RMSNorm>(hidden_size, 1e-6f);
|
||||||
|
blocks["norm_y2"] = std::make_shared<RMSNorm>(hidden_size, 1e-6f);
|
||||||
|
blocks["mlp_x"] = std::make_shared<FeedForward>(hidden_size, mlp_hidden_dim);
|
||||||
|
blocks["mlp_y"] = std::make_shared<FeedForward>(hidden_size, mlp_hidden_dim);
|
||||||
|
blocks["adaLN_modulation_img.0"] = std::make_shared<Linear>(hidden_size, 6 * hidden_size, true);
|
||||||
|
blocks["adaLN_modulation_txt.0"] = std::make_shared<Linear>(hidden_size, 6 * hidden_size, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<ggml_tensor*, ggml_tensor*> forward(GGMLRunnerContext* ctx,
|
||||||
|
ggml_tensor* x,
|
||||||
|
ggml_tensor* y,
|
||||||
|
ggml_tensor* c,
|
||||||
|
ggml_tensor* pos_img,
|
||||||
|
ggml_tensor* pos_txt) {
|
||||||
|
auto norm_x1 = std::dynamic_pointer_cast<RMSNorm>(blocks["norm_x1"]);
|
||||||
|
auto norm_y1 = std::dynamic_pointer_cast<RMSNorm>(blocks["norm_y1"]);
|
||||||
|
auto attn = std::dynamic_pointer_cast<MMDiTJointAttention>(blocks["attn"]);
|
||||||
|
auto norm_x2 = std::dynamic_pointer_cast<RMSNorm>(blocks["norm_x2"]);
|
||||||
|
auto norm_y2 = std::dynamic_pointer_cast<RMSNorm>(blocks["norm_y2"]);
|
||||||
|
auto mlp_x = std::dynamic_pointer_cast<FeedForward>(blocks["mlp_x"]);
|
||||||
|
auto mlp_y = std::dynamic_pointer_cast<FeedForward>(blocks["mlp_y"]);
|
||||||
|
auto ada_img = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation_img.0"]);
|
||||||
|
auto ada_txt = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation_txt.0"]);
|
||||||
|
|
||||||
|
auto mx = ggml_ext_chunk(ctx->ggml_ctx, ada_img->forward(ctx, c), 6, 0);
|
||||||
|
auto my = ggml_ext_chunk(ctx->ggml_ctx, ada_txt->forward(ctx, c), 6, 0);
|
||||||
|
|
||||||
|
auto x_norm = apply_adaln(ctx->ggml_ctx, norm_x1->forward(ctx, x), mx[0], mx[1]);
|
||||||
|
auto y_norm = apply_adaln(ctx->ggml_ctx, norm_y1->forward(ctx, y), my[0], my[1]);
|
||||||
|
auto attn_out = attn->forward(ctx, x_norm, y_norm, pos_img, pos_txt);
|
||||||
|
|
||||||
|
x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, attn_out.first, mx[2]));
|
||||||
|
y = ggml_add(ctx->ggml_ctx, y, ggml_mul(ctx->ggml_ctx, attn_out.second, my[2]));
|
||||||
|
|
||||||
|
auto x_mlp = mlp_x->forward(ctx, apply_adaln(ctx->ggml_ctx, norm_x2->forward(ctx, x), mx[3], mx[4]));
|
||||||
|
auto y_mlp = mlp_y->forward(ctx, apply_adaln(ctx->ggml_ctx, norm_y2->forward(ctx, y), my[3], my[4]));
|
||||||
|
x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, x_mlp, mx[5]));
|
||||||
|
y = ggml_add(ctx->ggml_ctx, y, ggml_mul(ctx->ggml_ctx, y_mlp, my[5]));
|
||||||
|
return {x, y};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct PixelTokenEmbedder : public GGMLBlock {
|
||||||
|
int64_t in_channels;
|
||||||
|
int64_t hidden_size_output;
|
||||||
|
|
||||||
|
PixelTokenEmbedder(int64_t in_channels, int64_t hidden_size_output)
|
||||||
|
: in_channels(in_channels), hidden_size_output(hidden_size_output) {
|
||||||
|
blocks["proj"] = std::make_shared<Linear>(in_channels, hidden_size_output, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
|
ggml_tensor* inputs,
|
||||||
|
int64_t patch_size,
|
||||||
|
ggml_tensor* pos_full) {
|
||||||
|
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
|
||||||
|
int64_t W = inputs->ne[0];
|
||||||
|
int64_t H = inputs->ne[1];
|
||||||
|
int64_t B = inputs->ne[3];
|
||||||
|
int64_t L = (W / patch_size) * (H / patch_size);
|
||||||
|
int64_t P2 = patch_size * patch_size;
|
||||||
|
|
||||||
|
auto x = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, inputs, 2, 0, 1, 3));
|
||||||
|
x = ggml_reshape_3d(ctx->ggml_ctx, x, in_channels, W * H, B);
|
||||||
|
x = proj->forward(ctx, x);
|
||||||
|
x = ggml_add(ctx->ggml_ctx, x, pos_full);
|
||||||
|
x = ggml_reshape_4d(ctx->ggml_ctx, x, hidden_size_output, W, H, B);
|
||||||
|
x = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 1, 2, 0, 3));
|
||||||
|
x = DiT::patchify(ctx->ggml_ctx, x, static_cast<int>(patch_size), static_cast<int>(patch_size), false);
|
||||||
|
x = ggml_reshape_3d(ctx->ggml_ctx, x, hidden_size_output, P2, L * B);
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct PiTBlock : public GGMLBlock {
|
||||||
|
int64_t pixel_dim;
|
||||||
|
int64_t context_dim;
|
||||||
|
int64_t attn_dim;
|
||||||
|
int64_t num_heads;
|
||||||
|
int64_t patch_size;
|
||||||
|
|
||||||
|
PiTBlock(int64_t pixel_dim,
|
||||||
|
int64_t context_dim,
|
||||||
|
int64_t patch_size,
|
||||||
|
int64_t attn_dim,
|
||||||
|
int64_t num_heads)
|
||||||
|
: pixel_dim(pixel_dim),
|
||||||
|
context_dim(context_dim),
|
||||||
|
attn_dim(attn_dim),
|
||||||
|
num_heads(num_heads),
|
||||||
|
patch_size(patch_size) {
|
||||||
|
int64_t p2 = patch_size * patch_size;
|
||||||
|
blocks["compress_to_attn"] = std::make_shared<Linear>(p2 * pixel_dim, attn_dim, true);
|
||||||
|
blocks["expand_from_attn"] = std::make_shared<Linear>(attn_dim, p2 * pixel_dim, true);
|
||||||
|
blocks["norm1"] = std::make_shared<RMSNorm>(pixel_dim, 1e-6f);
|
||||||
|
blocks["attn"] = std::make_shared<RotaryAttention>(attn_dim, num_heads);
|
||||||
|
blocks["norm2"] = std::make_shared<RMSNorm>(pixel_dim, 1e-6f);
|
||||||
|
blocks["mlp"] = std::make_shared<Mlp>(pixel_dim, pixel_dim * 4);
|
||||||
|
blocks["adaLN_modulation.0"] = std::make_shared<Linear>(context_dim, 6 * pixel_dim * p2, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
|
ggml_tensor* x,
|
||||||
|
ggml_tensor* s_cond,
|
||||||
|
int64_t image_height,
|
||||||
|
int64_t image_width,
|
||||||
|
ggml_tensor* pos_comp) {
|
||||||
|
auto compress = std::dynamic_pointer_cast<Linear>(blocks["compress_to_attn"]);
|
||||||
|
auto expand = std::dynamic_pointer_cast<Linear>(blocks["expand_from_attn"]);
|
||||||
|
auto norm1 = std::dynamic_pointer_cast<RMSNorm>(blocks["norm1"]);
|
||||||
|
auto attn = std::dynamic_pointer_cast<RotaryAttention>(blocks["attn"]);
|
||||||
|
auto norm2 = std::dynamic_pointer_cast<RMSNorm>(blocks["norm2"]);
|
||||||
|
auto mlp = std::dynamic_pointer_cast<Mlp>(blocks["mlp"]);
|
||||||
|
auto ada = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.0"]);
|
||||||
|
|
||||||
|
int64_t Hs = image_height / patch_size;
|
||||||
|
int64_t Ws = image_width / patch_size;
|
||||||
|
int64_t L = Hs * Ws;
|
||||||
|
int64_t BL = x->ne[2];
|
||||||
|
int64_t B = BL / L;
|
||||||
|
int64_t P2 = patch_size * patch_size;
|
||||||
|
|
||||||
|
auto ada_params = ada->forward(ctx, s_cond);
|
||||||
|
ada_params = ggml_reshape_3d(ctx->ggml_ctx, ada_params, 6 * pixel_dim, P2, BL);
|
||||||
|
auto mod = ggml_ext_chunk(ctx->ggml_ctx, ada_params, 6, 0);
|
||||||
|
|
||||||
|
auto x_norm = apply_adaln(ctx->ggml_ctx, norm1->forward(ctx, x), mod[0], mod[1]);
|
||||||
|
auto x_flat = ggml_reshape_2d(ctx->ggml_ctx, x_norm, P2 * pixel_dim, BL);
|
||||||
|
auto x_comp = compress->forward(ctx, x_flat);
|
||||||
|
x_comp = ggml_reshape_3d(ctx->ggml_ctx, x_comp, attn_dim, L, B);
|
||||||
|
auto attn_out = attn->forward(ctx, x_comp, pos_comp);
|
||||||
|
auto attn_flat = expand->forward(ctx, ggml_reshape_2d(ctx->ggml_ctx, attn_out, attn_dim, BL));
|
||||||
|
auto attn_exp = ggml_reshape_3d(ctx->ggml_ctx, attn_flat, pixel_dim, P2, BL);
|
||||||
|
x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, attn_exp, mod[2]));
|
||||||
|
|
||||||
|
auto mlp_out = mlp->forward(ctx, apply_adaln(ctx->ggml_ctx, norm2->forward(ctx, x), mod[3], mod[4]));
|
||||||
|
return ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, mlp_out, mod[5]));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct SigmaAwareGate : public GGMLBlock {
|
||||||
|
int64_t dim;
|
||||||
|
|
||||||
|
SigmaAwareGate(int64_t dim)
|
||||||
|
: dim(dim) {
|
||||||
|
blocks["content_proj"] = std::make_shared<Linear>(dim * 2, dim, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
void init_params(ggml_context* ctx,
|
||||||
|
const String2TensorStorage& tensor_storage_map = {},
|
||||||
|
std::string prefix = "") override {
|
||||||
|
params["log_alpha"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
|
ggml_tensor* x,
|
||||||
|
ggml_tensor* lq,
|
||||||
|
ggml_tensor* sigma) {
|
||||||
|
auto content_proj = std::dynamic_pointer_cast<Linear>(blocks["content_proj"]);
|
||||||
|
|
||||||
|
auto content_logit = content_proj->forward(ctx, ggml_concat(ctx->ggml_ctx, x, lq, 0));
|
||||||
|
sigma = ggml_reshape_3d(ctx->ggml_ctx, sigma, 1, 1, sigma->ne[0]);
|
||||||
|
auto alpha = ggml_exp(ctx->ggml_ctx, params["log_alpha"]);
|
||||||
|
auto offset = ggml_neg(ctx->ggml_ctx, ggml_mul(ctx->ggml_ctx, alpha, sigma));
|
||||||
|
auto gate = ggml_sigmoid(ctx->ggml_ctx, ggml_add(ctx->ggml_ctx, content_logit, offset));
|
||||||
|
return ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, gate, lq));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct PiDResBlock : public GGMLBlock {
|
||||||
|
PiDResBlock(int64_t channels) {
|
||||||
|
blocks["block.0"] = std::make_shared<GroupNorm>(4, channels, 1e-5f);
|
||||||
|
blocks["block.2"] = std::make_shared<Conv2d>(channels, channels, std::pair<int, int>{3, 3}, std::pair<int, int>{1, 1}, std::pair<int, int>{1, 1});
|
||||||
|
blocks["block.3"] = std::make_shared<GroupNorm>(4, channels, 1e-5f);
|
||||||
|
blocks["block.5"] = std::make_shared<Conv2d>(channels, channels, std::pair<int, int>{3, 3}, std::pair<int, int>{1, 1}, std::pair<int, int>{1, 1});
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
|
||||||
|
auto norm1 = std::dynamic_pointer_cast<GroupNorm>(blocks["block.0"]);
|
||||||
|
auto conv1 = std::dynamic_pointer_cast<Conv2d>(blocks["block.2"]);
|
||||||
|
auto norm2 = std::dynamic_pointer_cast<GroupNorm>(blocks["block.3"]);
|
||||||
|
auto conv2 = std::dynamic_pointer_cast<Conv2d>(blocks["block.5"]);
|
||||||
|
auto h = ggml_silu_inplace(ctx->ggml_ctx, norm1->forward(ctx, x));
|
||||||
|
h = conv1->forward(ctx, h);
|
||||||
|
h = ggml_silu_inplace(ctx->ggml_ctx, norm2->forward(ctx, h));
|
||||||
|
h = conv2->forward(ctx, h);
|
||||||
|
return ggml_add(ctx->ggml_ctx, x, h);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct LQProjection2D : public GGMLBlock {
|
||||||
|
PixelDiTParams params_cfg;
|
||||||
|
|
||||||
|
LQProjection2D(const PixelDiTParams& params_cfg)
|
||||||
|
: params_cfg(params_cfg) {
|
||||||
|
blocks["latent_proj.0"] = std::make_shared<Conv2d>(params_cfg.lq_latent_channels, params_cfg.lq_hidden_dim, std::pair<int, int>{3, 3}, std::pair<int, int>{1, 1}, std::pair<int, int>{1, 1});
|
||||||
|
blocks["latent_proj.2"] = std::make_shared<Conv2d>(params_cfg.lq_hidden_dim, params_cfg.lq_hidden_dim, std::pair<int, int>{3, 3}, std::pair<int, int>{1, 1}, std::pair<int, int>{1, 1});
|
||||||
|
for (int i = 0; i < params_cfg.lq_num_res_blocks; ++i) {
|
||||||
|
blocks["latent_proj." + std::to_string(3 + i)] = std::make_shared<PiDResBlock>(params_cfg.lq_hidden_dim);
|
||||||
|
}
|
||||||
|
|
||||||
|
int num_outputs = static_cast<int>((params_cfg.patch_depth + params_cfg.lq_interval - 1) / params_cfg.lq_interval);
|
||||||
|
for (int i = 0; i < num_outputs; ++i) {
|
||||||
|
blocks["output_heads." + std::to_string(i)] = std::make_shared<Linear>(params_cfg.lq_hidden_dim, params_cfg.hidden_size, true);
|
||||||
|
blocks["gate_modules." + std::to_string(i)] = std::make_shared<SigmaAwareGate>(params_cfg.hidden_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_gate_active(int block_idx) const {
|
||||||
|
return block_idx % params_cfg.lq_interval == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int get_output_index(int block_idx) const {
|
||||||
|
return block_idx / static_cast<int>(params_cfg.lq_interval);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* gate(GGMLRunnerContext* ctx,
|
||||||
|
ggml_tensor* x,
|
||||||
|
ggml_tensor* lq,
|
||||||
|
ggml_tensor* sigma,
|
||||||
|
int out_idx) {
|
||||||
|
auto gate_module = std::dynamic_pointer_cast<SigmaAwareGate>(blocks["gate_modules." + std::to_string(out_idx)]);
|
||||||
|
return gate_module->forward(ctx, x, lq, sigma);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<ggml_tensor*> forward(GGMLRunnerContext* ctx,
|
||||||
|
ggml_tensor* lq_latent,
|
||||||
|
int64_t target_pH,
|
||||||
|
int64_t target_pW) {
|
||||||
|
auto conv0 = std::dynamic_pointer_cast<Conv2d>(blocks["latent_proj.0"]);
|
||||||
|
auto conv2 = std::dynamic_pointer_cast<Conv2d>(blocks["latent_proj.2"]);
|
||||||
|
float z_to_patch_ratio = static_cast<float>(params_cfg.lq_sr_scale * params_cfg.lq_latent_down_factor) /
|
||||||
|
static_cast<float>(params_cfg.patch_size);
|
||||||
|
GGML_ASSERT(z_to_patch_ratio >= 1.0f);
|
||||||
|
if (lq_latent->ne[0] != target_pW || lq_latent->ne[1] != target_pH) {
|
||||||
|
lq_latent = ggml_interpolate(ctx->ggml_ctx,
|
||||||
|
lq_latent,
|
||||||
|
target_pW,
|
||||||
|
target_pH,
|
||||||
|
lq_latent->ne[2],
|
||||||
|
lq_latent->ne[3],
|
||||||
|
GGML_SCALE_MODE_NEAREST);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto feat = conv0->forward(ctx, lq_latent);
|
||||||
|
feat = ggml_silu_inplace(ctx->ggml_ctx, feat);
|
||||||
|
feat = conv2->forward(ctx, feat);
|
||||||
|
for (int i = 0; i < params_cfg.lq_num_res_blocks; ++i) {
|
||||||
|
auto block = std::dynamic_pointer_cast<PiDResBlock>(blocks["latent_proj." + std::to_string(3 + i)]);
|
||||||
|
feat = block->forward(ctx, feat);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t B = feat->ne[3];
|
||||||
|
int64_t C = feat->ne[2];
|
||||||
|
int64_t L = target_pH * target_pW;
|
||||||
|
auto tokens = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, feat, 2, 0, 1, 3));
|
||||||
|
tokens = ggml_reshape_3d(ctx->ggml_ctx, tokens, C, L, B);
|
||||||
|
|
||||||
|
int num_outputs = static_cast<int>((params_cfg.patch_depth + params_cfg.lq_interval - 1) / params_cfg.lq_interval);
|
||||||
|
std::vector<ggml_tensor*> outputs;
|
||||||
|
outputs.reserve(num_outputs);
|
||||||
|
for (int i = 0; i < num_outputs; ++i) {
|
||||||
|
auto head = std::dynamic_pointer_cast<Linear>(blocks["output_heads." + std::to_string(i)]);
|
||||||
|
outputs.push_back(head->forward(ctx, tokens));
|
||||||
|
}
|
||||||
|
return outputs;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct PixelDiT : public GGMLBlock {
|
||||||
|
PixelDiTParams params_cfg;
|
||||||
|
|
||||||
|
PixelDiT() = default;
|
||||||
|
|
||||||
|
PixelDiT(const PixelDiTParams& params_cfg)
|
||||||
|
: params_cfg(params_cfg) {
|
||||||
|
blocks["pixel_embedder"] = std::make_shared<PixelTokenEmbedder>(params_cfg.in_channels, params_cfg.pixel_hidden_size);
|
||||||
|
blocks["s_embedder"] = std::make_shared<PatchTokenEmbedder>(params_cfg.in_channels * params_cfg.patch_size * params_cfg.patch_size, params_cfg.hidden_size, false, true);
|
||||||
|
blocks["t_embedder"] = std::make_shared<PixelDiTTimestepEmbedder>(params_cfg.hidden_size);
|
||||||
|
blocks["y_embedder"] = std::make_shared<PatchTokenEmbedder>(params_cfg.txt_embed_dim, params_cfg.hidden_size, true, true);
|
||||||
|
for (int i = 0; i < params_cfg.patch_depth; ++i) {
|
||||||
|
blocks["patch_blocks." + std::to_string(i)] = std::make_shared<MMDiTBlockT2I>(params_cfg.hidden_size, params_cfg.num_groups, params_cfg.patch_mlp_hidden_dim);
|
||||||
|
}
|
||||||
|
for (int i = 0; i < params_cfg.pixel_depth; ++i) {
|
||||||
|
blocks["pixel_blocks." + std::to_string(i)] = std::make_shared<PiTBlock>(params_cfg.pixel_hidden_size,
|
||||||
|
params_cfg.hidden_size,
|
||||||
|
params_cfg.patch_size,
|
||||||
|
params_cfg.pixel_attn_hidden_size,
|
||||||
|
params_cfg.pixel_num_groups);
|
||||||
|
}
|
||||||
|
blocks["final_layer"] = std::make_shared<FinalLayer>(params_cfg.pixel_hidden_size, params_cfg.in_channels);
|
||||||
|
blocks["lq_proj"] = std::make_shared<LQProjection2D>(params_cfg);
|
||||||
|
}
|
||||||
|
|
||||||
|
void init_params(ggml_context* ctx,
|
||||||
|
const String2TensorStorage& tensor_storage_map = {},
|
||||||
|
std::string prefix = "") override {
|
||||||
|
params["y_pos_embedding"] = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, params_cfg.hidden_size, params_cfg.txt_max_length, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
|
ggml_tensor* x,
|
||||||
|
ggml_tensor* timesteps,
|
||||||
|
ggml_tensor* context,
|
||||||
|
ggml_tensor* lq_latent,
|
||||||
|
ggml_tensor* degrade_sigma,
|
||||||
|
ggml_tensor* pos_img,
|
||||||
|
ggml_tensor* pos_txt,
|
||||||
|
ggml_tensor* pixel_pos_full,
|
||||||
|
ggml_tensor* pixel_pos_comp) {
|
||||||
|
auto pixel_embedder = std::dynamic_pointer_cast<PixelTokenEmbedder>(blocks["pixel_embedder"]);
|
||||||
|
auto s_embedder = std::dynamic_pointer_cast<PatchTokenEmbedder>(blocks["s_embedder"]);
|
||||||
|
auto t_embedder = std::dynamic_pointer_cast<PixelDiTTimestepEmbedder>(blocks["t_embedder"]);
|
||||||
|
auto y_embedder = std::dynamic_pointer_cast<PatchTokenEmbedder>(blocks["y_embedder"]);
|
||||||
|
auto final_layer = std::dynamic_pointer_cast<FinalLayer>(blocks["final_layer"]);
|
||||||
|
auto lq_proj = std::dynamic_pointer_cast<LQProjection2D>(blocks["lq_proj"]);
|
||||||
|
|
||||||
|
int64_t W_orig = x->ne[0];
|
||||||
|
int64_t H_orig = x->ne[1];
|
||||||
|
x = DiT::pad_to_patch_size(ctx, x, static_cast<int>(params_cfg.patch_size), static_cast<int>(params_cfg.patch_size));
|
||||||
|
int64_t W = x->ne[0];
|
||||||
|
int64_t H = x->ne[1];
|
||||||
|
int64_t B = x->ne[3];
|
||||||
|
int64_t Hs = H / params_cfg.patch_size;
|
||||||
|
int64_t Ws = W / params_cfg.patch_size;
|
||||||
|
int64_t L = Hs * Ws;
|
||||||
|
int64_t P2 = params_cfg.patch_size * params_cfg.patch_size;
|
||||||
|
|
||||||
|
auto x_patches = DiT::patchify(ctx->ggml_ctx, x, static_cast<int>(params_cfg.patch_size), static_cast<int>(params_cfg.patch_size), true);
|
||||||
|
auto t_emb = t_embedder->forward(ctx, timesteps);
|
||||||
|
auto condition = ggml_silu(ctx->ggml_ctx, t_emb);
|
||||||
|
|
||||||
|
GGML_ASSERT(context != nullptr);
|
||||||
|
int64_t Ltxt = std::min<int64_t>(context->ne[1], params_cfg.txt_max_length);
|
||||||
|
auto y = ggml_ext_slice(ctx->ggml_ctx, context, 1, 0, Ltxt);
|
||||||
|
auto y_emb = y_embedder->forward(ctx, y);
|
||||||
|
auto y_pos = ggml_ext_slice(ctx->ggml_ctx, params["y_pos_embedding"], 1, 0, Ltxt);
|
||||||
|
y_emb = ggml_add(ctx->ggml_ctx, y_emb, y_pos);
|
||||||
|
|
||||||
|
std::vector<ggml_tensor*> lq_features = lq_proj->forward(ctx, lq_latent, Hs, Ws);
|
||||||
|
|
||||||
|
auto s = s_embedder->forward(ctx, x_patches);
|
||||||
|
|
||||||
|
for (int i = 0; i < params_cfg.patch_depth; ++i) {
|
||||||
|
if (lq_proj->is_gate_active(i)) {
|
||||||
|
int out_idx = lq_proj->get_output_index(i);
|
||||||
|
if (out_idx < static_cast<int>(lq_features.size())) {
|
||||||
|
s = lq_proj->gate(ctx, s, lq_features[out_idx], degrade_sigma, out_idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto block = std::dynamic_pointer_cast<MMDiTBlockT2I>(blocks["patch_blocks." + std::to_string(i)]);
|
||||||
|
auto out = block->forward(ctx,
|
||||||
|
s,
|
||||||
|
y_emb,
|
||||||
|
condition,
|
||||||
|
pos_img,
|
||||||
|
pos_txt);
|
||||||
|
s = out.first;
|
||||||
|
y_emb = out.second;
|
||||||
|
sd::ggml_graph_cut::mark_graph_cut(s, "pid.patch_blocks." + std::to_string(i), "s");
|
||||||
|
sd::ggml_graph_cut::mark_graph_cut(y_emb, "pid.patch_blocks." + std::to_string(i), "y");
|
||||||
|
}
|
||||||
|
s = ggml_silu(ctx->ggml_ctx, ggml_add(ctx->ggml_ctx, s, t_emb));
|
||||||
|
|
||||||
|
auto s_cond = ggml_reshape_2d(ctx->ggml_ctx, s, params_cfg.hidden_size, L * B);
|
||||||
|
auto pixels = pixel_embedder->forward(ctx, x, params_cfg.patch_size, pixel_pos_full);
|
||||||
|
for (int i = 0; i < params_cfg.pixel_depth; ++i) {
|
||||||
|
auto block = std::dynamic_pointer_cast<PiTBlock>(blocks["pixel_blocks." + std::to_string(i)]);
|
||||||
|
pixels = block->forward(ctx, pixels, s_cond, H, W, pixel_pos_comp);
|
||||||
|
sd::ggml_graph_cut::mark_graph_cut(pixels, "pid.pixel_blocks." + std::to_string(i), "pixels");
|
||||||
|
}
|
||||||
|
|
||||||
|
pixels = final_layer->forward(ctx, pixels);
|
||||||
|
pixels = ggml_reshape_3d(ctx->ggml_ctx, pixels, params_cfg.in_channels * P2, L, B);
|
||||||
|
auto out = DiT::unpatchify(ctx->ggml_ctx,
|
||||||
|
pixels,
|
||||||
|
Hs,
|
||||||
|
Ws,
|
||||||
|
static_cast<int>(params_cfg.patch_size),
|
||||||
|
static_cast<int>(params_cfg.patch_size),
|
||||||
|
false);
|
||||||
|
out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H_orig);
|
||||||
|
out = ggml_ext_slice(ctx->ggml_ctx, out, 0, 0, W_orig);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct PiDRunner : public DiffusionModelRunner {
|
||||||
|
PixelDiTParams params_cfg;
|
||||||
|
PixelDiT model;
|
||||||
|
std::vector<float> pos_img_vec;
|
||||||
|
std::vector<float> pos_txt_vec;
|
||||||
|
std::vector<float> pixel_pos_vec;
|
||||||
|
std::vector<float> pixel_pos_comp_vec;
|
||||||
|
|
||||||
|
PiDRunner(ggml_backend_t backend,
|
||||||
|
ggml_backend_t params_backend,
|
||||||
|
const String2TensorStorage& tensor_storage_map,
|
||||||
|
const std::string prefix = "model.diffusion_model")
|
||||||
|
: DiffusionModelRunner(backend, params_backend, prefix) {
|
||||||
|
for (const auto& pair : tensor_storage_map) {
|
||||||
|
const std::string& tensor_name = pair.first;
|
||||||
|
if (tensor_name.find(prefix) == std::string::npos) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
size_t pos = tensor_name.find("patch_blocks.");
|
||||||
|
if (pos != std::string::npos) {
|
||||||
|
auto items = split_string(tensor_name.substr(pos), '.');
|
||||||
|
if (items.size() > 1) {
|
||||||
|
int block_index = atoi(items[1].c_str());
|
||||||
|
params_cfg.patch_depth = std::max<int64_t>(params_cfg.patch_depth, block_index + 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pos = tensor_name.find("pixel_blocks.");
|
||||||
|
if (pos != std::string::npos) {
|
||||||
|
auto items = split_string(tensor_name.substr(pos), '.');
|
||||||
|
if (items.size() > 1) {
|
||||||
|
int block_index = atoi(items[1].c_str());
|
||||||
|
params_cfg.pixel_depth = std::max<int64_t>(params_cfg.pixel_depth, block_index + 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (tensor_name.find("lq_proj.latent_proj.0.weight") != std::string::npos) {
|
||||||
|
params_cfg.lq_latent_channels = pair.second.ne[2];
|
||||||
|
params_cfg.lq_latent_down_factor = params_cfg.lq_latent_channels >= 64 ? 16 : 8;
|
||||||
|
}
|
||||||
|
if (tensor_name.find("patch_blocks.0.mlp_x.w1.weight") != std::string::npos) {
|
||||||
|
params_cfg.patch_mlp_hidden_dim = pair.second.ne[1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
LOG_INFO("PiD params: patch_depth=%" PRId64 ", pixel_depth=%" PRId64 ", patch_mlp_hidden_dim=%" PRId64 ", lq_latent_channels=%" PRId64 ", lq_latent_down_factor=%" PRId64,
|
||||||
|
params_cfg.patch_depth,
|
||||||
|
params_cfg.pixel_depth,
|
||||||
|
params_cfg.patch_mlp_hidden_dim,
|
||||||
|
params_cfg.lq_latent_channels,
|
||||||
|
params_cfg.lq_latent_down_factor);
|
||||||
|
model = PixelDiT(params_cfg);
|
||||||
|
model.init(params_ctx, tensor_storage_map, prefix);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string get_desc() override {
|
||||||
|
return "PiD";
|
||||||
|
}
|
||||||
|
|
||||||
|
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string& prefix) override {
|
||||||
|
model.get_param_tensors(tensors, prefix);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_cgraph* build_graph(const sd::Tensor<float>& x_tensor,
|
||||||
|
const sd::Tensor<float>& timesteps_tensor,
|
||||||
|
const sd::Tensor<float>& context_tensor,
|
||||||
|
const sd::Tensor<float>& lq_latent_tensor,
|
||||||
|
const sd::Tensor<float>& degrade_sigma_tensor) {
|
||||||
|
ggml_cgraph* gf = new_graph_custom(PID_GRAPH_SIZE);
|
||||||
|
ggml_tensor* x = make_input(x_tensor);
|
||||||
|
ggml_tensor* timesteps = make_input(timesteps_tensor);
|
||||||
|
ggml_tensor* context = make_input(context_tensor);
|
||||||
|
ggml_tensor* lq_latent = make_input(lq_latent_tensor);
|
||||||
|
ggml_tensor* degrade_sigma = make_input(degrade_sigma_tensor);
|
||||||
|
|
||||||
|
int64_t W = x->ne[0];
|
||||||
|
int64_t H = x->ne[1];
|
||||||
|
int64_t B = x->ne[3];
|
||||||
|
int64_t Wp = align_up(static_cast<int>(W), static_cast<int>(params_cfg.patch_size));
|
||||||
|
int64_t Hp = align_up(static_cast<int>(H), static_cast<int>(params_cfg.patch_size));
|
||||||
|
int64_t Hs = Hp / params_cfg.patch_size;
|
||||||
|
int64_t Ws = Wp / params_cfg.patch_size;
|
||||||
|
|
||||||
|
pos_img_vec = make_rope_2d(static_cast<int>(Hs),
|
||||||
|
static_cast<int>(Ws),
|
||||||
|
static_cast<int>(params_cfg.hidden_size / params_cfg.num_groups),
|
||||||
|
10000.f,
|
||||||
|
16.f,
|
||||||
|
static_cast<int>(params_cfg.rope_ref_grid_h),
|
||||||
|
static_cast<int>(params_cfg.rope_ref_grid_w));
|
||||||
|
auto pos_img = ggml_new_tensor_4d(compute_ctx,
|
||||||
|
GGML_TYPE_F32,
|
||||||
|
2,
|
||||||
|
2,
|
||||||
|
params_cfg.hidden_size / params_cfg.num_groups / 2,
|
||||||
|
Hs * Ws);
|
||||||
|
set_backend_tensor_data(pos_img, pos_img_vec.data());
|
||||||
|
|
||||||
|
int64_t Ltxt = std::min<int64_t>(context->ne[1], params_cfg.txt_max_length);
|
||||||
|
pos_txt_vec = make_rope_1d(static_cast<int>(Ltxt),
|
||||||
|
static_cast<int>(params_cfg.hidden_size / params_cfg.num_groups),
|
||||||
|
params_cfg.text_rope_theta);
|
||||||
|
auto pos_txt = ggml_new_tensor_4d(compute_ctx,
|
||||||
|
GGML_TYPE_F32,
|
||||||
|
2,
|
||||||
|
2,
|
||||||
|
params_cfg.hidden_size / params_cfg.num_groups / 2,
|
||||||
|
Ltxt);
|
||||||
|
set_backend_tensor_data(pos_txt, pos_txt_vec.data());
|
||||||
|
|
||||||
|
pixel_pos_vec = make_pixel_abs_pos(static_cast<int>(Hp),
|
||||||
|
static_cast<int>(Wp),
|
||||||
|
static_cast<int>(params_cfg.pixel_hidden_size));
|
||||||
|
auto pixel_pos = ggml_new_tensor_3d(compute_ctx,
|
||||||
|
GGML_TYPE_F32,
|
||||||
|
params_cfg.pixel_hidden_size,
|
||||||
|
Wp * Hp,
|
||||||
|
1);
|
||||||
|
set_backend_tensor_data(pixel_pos, pixel_pos_vec.data());
|
||||||
|
|
||||||
|
pixel_pos_comp_vec = make_rope_2d(static_cast<int>(Hs),
|
||||||
|
static_cast<int>(Ws),
|
||||||
|
static_cast<int>(params_cfg.pixel_attn_hidden_size / params_cfg.pixel_num_groups),
|
||||||
|
10000.f,
|
||||||
|
16.f,
|
||||||
|
static_cast<int>(params_cfg.rope_ref_grid_h),
|
||||||
|
static_cast<int>(params_cfg.rope_ref_grid_w));
|
||||||
|
auto pixel_pos_comp = ggml_new_tensor_4d(compute_ctx,
|
||||||
|
GGML_TYPE_F32,
|
||||||
|
2,
|
||||||
|
2,
|
||||||
|
params_cfg.pixel_attn_hidden_size / params_cfg.pixel_num_groups / 2,
|
||||||
|
Hs * Ws);
|
||||||
|
set_backend_tensor_data(pixel_pos_comp, pixel_pos_comp_vec.data());
|
||||||
|
|
||||||
|
auto runner_ctx = get_context();
|
||||||
|
auto out = model.forward(&runner_ctx,
|
||||||
|
x,
|
||||||
|
timesteps,
|
||||||
|
context,
|
||||||
|
lq_latent,
|
||||||
|
degrade_sigma,
|
||||||
|
pos_img,
|
||||||
|
pos_txt,
|
||||||
|
pixel_pos,
|
||||||
|
pixel_pos_comp);
|
||||||
|
ggml_build_forward_expand(gf, out);
|
||||||
|
return gf;
|
||||||
|
}
|
||||||
|
|
||||||
|
sd::Tensor<float> compute(int n_threads,
|
||||||
|
const sd::Tensor<float>& x,
|
||||||
|
const sd::Tensor<float>& timesteps,
|
||||||
|
const sd::Tensor<float>& context,
|
||||||
|
const sd::Tensor<float>& lq_latent,
|
||||||
|
const sd::Tensor<float>& degrade_sigma) {
|
||||||
|
auto get_graph = [&]() -> ggml_cgraph* {
|
||||||
|
return build_graph(x, timesteps, context, lq_latent, degrade_sigma);
|
||||||
|
};
|
||||||
|
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), x.dim());
|
||||||
|
}
|
||||||
|
|
||||||
|
sd::Tensor<float> compute(int n_threads,
|
||||||
|
const DiffusionParams& diffusion_params) override {
|
||||||
|
GGML_ASSERT(diffusion_params.x != nullptr);
|
||||||
|
GGML_ASSERT(diffusion_params.timesteps != nullptr);
|
||||||
|
GGML_ASSERT(diffusion_params.context != nullptr);
|
||||||
|
GGML_ASSERT(diffusion_params.ref_latents != nullptr);
|
||||||
|
GGML_ASSERT(!diffusion_params.ref_latents->empty());
|
||||||
|
auto degrade_sigma = sd::Tensor<float>::from_vector({0.0f});
|
||||||
|
return compute(n_threads,
|
||||||
|
*diffusion_params.x,
|
||||||
|
*diffusion_params.timesteps,
|
||||||
|
*diffusion_params.context,
|
||||||
|
diffusion_params.ref_latents->front(),
|
||||||
|
degrade_sigma);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace Pid
|
||||||
|
|
||||||
|
#endif // __SD_PID_HPP__
|
||||||
53
src/rope.hpp
53
src/rope.hpp
@ -249,6 +249,59 @@ namespace Rope {
|
|||||||
return embed_nd(ids, bs, axis_thetas, axes_dim, wrap_dims, layout);
|
return embed_nd(ids, bs, axis_thetas, axes_dim, wrap_dims, layout);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__STATIC_INLINE__ std::vector<float> embed_2d_interleaved(int height,
|
||||||
|
int width,
|
||||||
|
int dim,
|
||||||
|
float theta = 10000.f,
|
||||||
|
float scale = 16.f,
|
||||||
|
int ref_grid_h = 0,
|
||||||
|
int ref_grid_w = 0) {
|
||||||
|
assert(dim % 4 == 0);
|
||||||
|
int half_dim = dim / 2;
|
||||||
|
int dim_axis = dim / 2;
|
||||||
|
int axis_half_dim = dim_axis / 2;
|
||||||
|
|
||||||
|
float h_ntk = 1.f;
|
||||||
|
float w_ntk = 1.f;
|
||||||
|
if (ref_grid_h > 0 && ref_grid_w > 0 && dim_axis > 2) {
|
||||||
|
float power = static_cast<float>(dim_axis) / static_cast<float>(dim_axis - 2);
|
||||||
|
h_ntk = std::pow(static_cast<float>(height) / static_cast<float>(ref_grid_h), power);
|
||||||
|
w_ntk = std::pow(static_cast<float>(width) / static_cast<float>(ref_grid_w), power);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> x_pos;
|
||||||
|
std::vector<float> y_pos;
|
||||||
|
x_pos.reserve(static_cast<size_t>(height) * width);
|
||||||
|
y_pos.reserve(static_cast<size_t>(height) * width);
|
||||||
|
for (int iy = 0; iy < height; ++iy) {
|
||||||
|
float y = height == 1 ? 0.f : scale * static_cast<float>(iy) / static_cast<float>(height - 1);
|
||||||
|
for (int ix = 0; ix < width; ++ix) {
|
||||||
|
float x = width == 1 ? 0.f : scale * static_cast<float>(ix) / static_cast<float>(width - 1);
|
||||||
|
x_pos.push_back(x);
|
||||||
|
y_pos.push_back(y);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto x_emb = rope(x_pos, dim_axis, theta * w_ntk);
|
||||||
|
auto y_emb = rope(y_pos, dim_axis, theta * h_ntk);
|
||||||
|
|
||||||
|
std::vector<float> out(static_cast<size_t>(height) * width * half_dim * 4);
|
||||||
|
for (int pos = 0; pos < height * width; ++pos) {
|
||||||
|
for (int i = 0; i < axis_half_dim; ++i) {
|
||||||
|
int jx = 2 * i;
|
||||||
|
int jy = 2 * i + 1;
|
||||||
|
size_t base_x = static_cast<size_t>(pos) * half_dim * 4 + static_cast<size_t>(jx) * 4;
|
||||||
|
size_t base_y = static_cast<size_t>(pos) * half_dim * 4 + static_cast<size_t>(jy) * 4;
|
||||||
|
size_t axis = static_cast<size_t>(i) * 4;
|
||||||
|
for (int k = 0; k < 4; ++k) {
|
||||||
|
out[base_x + k] = x_emb[pos][axis + k];
|
||||||
|
out[base_y + k] = y_emb[pos][axis + k];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
__STATIC_INLINE__ std::vector<std::vector<float>> gen_refs_ids(int patch_size,
|
__STATIC_INLINE__ std::vector<std::vector<float>> gen_refs_ids(int patch_size,
|
||||||
int bs,
|
int bs,
|
||||||
int axes_dim_num,
|
int axes_dim_num,
|
||||||
|
|||||||
@ -1,3 +1,7 @@
|
|||||||
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstdlib>
|
||||||
|
|
||||||
#include "ggml_extend.hpp"
|
#include "ggml_extend.hpp"
|
||||||
#include "ggml_graph_cut.h"
|
#include "ggml_graph_cut.h"
|
||||||
|
|
||||||
@ -26,6 +30,7 @@
|
|||||||
#include "ltx_vae.hpp"
|
#include "ltx_vae.hpp"
|
||||||
#include "ltxv.hpp"
|
#include "ltxv.hpp"
|
||||||
#include "mmdit.hpp"
|
#include "mmdit.hpp"
|
||||||
|
#include "pid.hpp"
|
||||||
#include "pmid.hpp"
|
#include "pmid.hpp"
|
||||||
#include "qwen_image.hpp"
|
#include "qwen_image.hpp"
|
||||||
#include "sample-cache.h"
|
#include "sample-cache.h"
|
||||||
@ -39,6 +44,9 @@
|
|||||||
#include "latent-preview.h"
|
#include "latent-preview.h"
|
||||||
#include "name_conversion.h"
|
#include "name_conversion.h"
|
||||||
|
|
||||||
|
const char* sd_vae_format_name(enum sd_vae_format_t format);
|
||||||
|
static SDVersion sd_vae_format_to_version(enum sd_vae_format_t format, SDVersion fallback);
|
||||||
|
|
||||||
const char* model_version_to_str[] = {
|
const char* model_version_to_str[] = {
|
||||||
"SD 1.x",
|
"SD 1.x",
|
||||||
"SD 1.x Inpaint",
|
"SD 1.x Inpaint",
|
||||||
@ -75,6 +83,7 @@ const char* model_version_to_str[] = {
|
|||||||
"Ernie Image",
|
"Ernie Image",
|
||||||
"Lens",
|
"Lens",
|
||||||
"Longcat-Image",
|
"Longcat-Image",
|
||||||
|
"PiD",
|
||||||
};
|
};
|
||||||
|
|
||||||
const char* sampling_methods_str[] = {
|
const char* sampling_methods_str[] = {
|
||||||
@ -501,6 +510,16 @@ public:
|
|||||||
params_backend_for(SDBackendModule::DIFFUSION),
|
params_backend_for(SDBackendModule::DIFFUSION),
|
||||||
tensor_storage_map,
|
tensor_storage_map,
|
||||||
"model.diffusion_model");
|
"model.diffusion_model");
|
||||||
|
} else if (sd_version_is_pid(version)) {
|
||||||
|
vae_decode_only = false;
|
||||||
|
cond_stage_model = std::make_shared<LLMEmbedder>(backend_for(SDBackendModule::TE),
|
||||||
|
params_backend_for(SDBackendModule::TE),
|
||||||
|
tensor_storage_map,
|
||||||
|
version);
|
||||||
|
diffusion_model = std::make_shared<Pid::PiDRunner>(backend_for(SDBackendModule::DIFFUSION),
|
||||||
|
params_backend_for(SDBackendModule::DIFFUSION),
|
||||||
|
tensor_storage_map,
|
||||||
|
"model.diffusion_model.net");
|
||||||
} else if (sd_version_is_flux(version)) {
|
} else if (sd_version_is_flux(version)) {
|
||||||
bool is_chroma = false;
|
bool is_chroma = false;
|
||||||
for (auto pair : tensor_storage_map) {
|
for (auto pair : tensor_storage_map) {
|
||||||
@ -743,6 +762,16 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
sd_vae_format_t vae_format = sd_ctx_params->vae_format;
|
||||||
|
if (vae_format < SD_VAE_FORMAT_AUTO || vae_format >= SD_VAE_FORMAT_COUNT) {
|
||||||
|
LOG_WARN("invalid VAE format override, using auto");
|
||||||
|
vae_format = SD_VAE_FORMAT_AUTO;
|
||||||
|
}
|
||||||
|
SDVersion vae_version = version;
|
||||||
|
if (sd_version_is_pid(version) && vae_format != SD_VAE_FORMAT_AUTO) {
|
||||||
|
vae_version = sd_vae_format_to_version(vae_format, vae_version);
|
||||||
|
}
|
||||||
|
|
||||||
auto create_vae = [&]() -> std::shared_ptr<VAE> {
|
auto create_vae = [&]() -> std::shared_ptr<VAE> {
|
||||||
if (sd_version_is_ltxav(version)) {
|
if (sd_version_is_ltxav(version)) {
|
||||||
return std::make_shared<LTXVideoVAE>(backend_for(SDBackendModule::VAE),
|
return std::make_shared<LTXVideoVAE>(backend_for(SDBackendModule::VAE),
|
||||||
@ -767,7 +796,7 @@ public:
|
|||||||
"first_stage_model",
|
"first_stage_model",
|
||||||
vae_decode_only,
|
vae_decode_only,
|
||||||
false,
|
false,
|
||||||
version);
|
vae_version);
|
||||||
if (sd_version_is_sdxl(version) &&
|
if (sd_version_is_sdxl(version) &&
|
||||||
(strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale || external_vae_is_invalid)) {
|
(strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale || external_vae_is_invalid)) {
|
||||||
float vae_conv_2d_scale = 1.f / 32.f;
|
float vae_conv_2d_scale = 1.f / 32.f;
|
||||||
@ -1140,12 +1169,15 @@ public:
|
|||||||
version == VERSION_HIDREAM_O1 ||
|
version == VERSION_HIDREAM_O1 ||
|
||||||
sd_version_is_anima(version) ||
|
sd_version_is_anima(version) ||
|
||||||
sd_version_is_ernie_image(version) ||
|
sd_version_is_ernie_image(version) ||
|
||||||
sd_version_is_z_image(version)) {
|
sd_version_is_z_image(version) ||
|
||||||
|
sd_version_is_pid(version)) {
|
||||||
pred_type = FLOW_PRED;
|
pred_type = FLOW_PRED;
|
||||||
if (sd_version_is_wan(version)) {
|
if (sd_version_is_wan(version)) {
|
||||||
default_flow_shift = 5.f;
|
default_flow_shift = 5.f;
|
||||||
} else if (sd_version_is_ernie_image(version)) {
|
} else if (sd_version_is_ernie_image(version)) {
|
||||||
default_flow_shift = 4.f;
|
default_flow_shift = 4.f;
|
||||||
|
} else if (sd_version_is_pid(version)) {
|
||||||
|
default_flow_shift = 1.5f;
|
||||||
} else {
|
} else {
|
||||||
default_flow_shift = 3.f;
|
default_flow_shift = 3.f;
|
||||||
}
|
}
|
||||||
@ -2180,6 +2212,9 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
int get_vae_scale_factor() {
|
int get_vae_scale_factor() {
|
||||||
|
if (sd_version_is_pid(version)) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
return first_stage_model->get_scale_factor();
|
return first_stage_model->get_scale_factor();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2206,6 +2241,8 @@ public:
|
|||||||
latent_channel = 3;
|
latent_channel = 3;
|
||||||
} else if (version == VERSION_CHROMA_RADIANCE) {
|
} else if (version == VERSION_CHROMA_RADIANCE) {
|
||||||
latent_channel = 3;
|
latent_channel = 3;
|
||||||
|
} else if (sd_version_is_pid(version)) {
|
||||||
|
latent_channel = 3;
|
||||||
} else if (sd_version_uses_flux2_vae(version)) {
|
} else if (sd_version_uses_flux2_vae(version)) {
|
||||||
latent_channel = 128;
|
latent_channel = 128;
|
||||||
} else {
|
} else {
|
||||||
@ -2283,6 +2320,9 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
sd::Tensor<float> decode_first_stage(const sd::Tensor<float>& x, bool decode_video = false) {
|
sd::Tensor<float> decode_first_stage(const sd::Tensor<float>& x, bool decode_video = false) {
|
||||||
|
if (sd_version_is_pid(version)) {
|
||||||
|
return sd::ops::clamp((x + 1.f) * 0.5f, 0.0f, 1.0f);
|
||||||
|
}
|
||||||
auto latents = first_stage_model->diffusion_to_vae_latents(x);
|
auto latents = first_stage_model->diffusion_to_vae_latents(x);
|
||||||
first_stage_model->set_temporal_tiling_enabled(vae_tiling_params.temporal_tiling);
|
first_stage_model->set_temporal_tiling_enabled(vae_tiling_params.temporal_tiling);
|
||||||
return first_stage_model->decode(n_threads, latents, vae_tiling_params, decode_video, circular_x, circular_y);
|
return first_stage_model->decode(n_threads, latents, vae_tiling_params, decode_video, circular_x, circular_y);
|
||||||
@ -2543,6 +2583,35 @@ enum sd_hires_upscaler_t str_to_sd_hires_upscaler(const char* str) {
|
|||||||
return SD_HIRES_UPSCALER_COUNT;
|
return SD_HIRES_UPSCALER_COUNT;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const char* sd_vae_format_name(enum sd_vae_format_t format) {
|
||||||
|
switch (format) {
|
||||||
|
case SD_VAE_FORMAT_AUTO:
|
||||||
|
return "auto";
|
||||||
|
case SD_VAE_FORMAT_FLUX:
|
||||||
|
return "flux";
|
||||||
|
case SD_VAE_FORMAT_SD3:
|
||||||
|
return "sd3";
|
||||||
|
case SD_VAE_FORMAT_FLUX2:
|
||||||
|
return "flux2";
|
||||||
|
default:
|
||||||
|
return NONE_STR;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static SDVersion sd_vae_format_to_version(enum sd_vae_format_t format, SDVersion fallback) {
|
||||||
|
switch (format) {
|
||||||
|
case SD_VAE_FORMAT_FLUX:
|
||||||
|
return VERSION_FLUX;
|
||||||
|
case SD_VAE_FORMAT_SD3:
|
||||||
|
return VERSION_SD3;
|
||||||
|
case SD_VAE_FORMAT_FLUX2:
|
||||||
|
return VERSION_FLUX2;
|
||||||
|
case SD_VAE_FORMAT_AUTO:
|
||||||
|
default:
|
||||||
|
return fallback;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void sd_cache_params_init(sd_cache_params_t* cache_params) {
|
void sd_cache_params_init(sd_cache_params_t* cache_params) {
|
||||||
*cache_params = {};
|
*cache_params = {};
|
||||||
cache_params->mode = SD_CACHE_DISABLED;
|
cache_params->mode = SD_CACHE_DISABLED;
|
||||||
@ -2608,6 +2677,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
|
|||||||
sd_ctx_params->chroma_use_dit_mask = true;
|
sd_ctx_params->chroma_use_dit_mask = true;
|
||||||
sd_ctx_params->chroma_use_t5_mask = false;
|
sd_ctx_params->chroma_use_t5_mask = false;
|
||||||
sd_ctx_params->chroma_t5_mask_pad = 1;
|
sd_ctx_params->chroma_t5_mask_pad = 1;
|
||||||
|
sd_ctx_params->vae_format = SD_VAE_FORMAT_AUTO;
|
||||||
sd_ctx_params->backend = nullptr;
|
sd_ctx_params->backend = nullptr;
|
||||||
sd_ctx_params->params_backend = nullptr;
|
sd_ctx_params->params_backend = nullptr;
|
||||||
}
|
}
|
||||||
@ -2655,7 +2725,8 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
|
|||||||
"circular_y: %s\n"
|
"circular_y: %s\n"
|
||||||
"chroma_use_dit_mask: %s\n"
|
"chroma_use_dit_mask: %s\n"
|
||||||
"chroma_use_t5_mask: %s\n"
|
"chroma_use_t5_mask: %s\n"
|
||||||
"chroma_t5_mask_pad: %d\n",
|
"chroma_t5_mask_pad: %d\n"
|
||||||
|
"vae_format: %s\n",
|
||||||
SAFE_STR(sd_ctx_params->model_path),
|
SAFE_STR(sd_ctx_params->model_path),
|
||||||
SAFE_STR(sd_ctx_params->clip_l_path),
|
SAFE_STR(sd_ctx_params->clip_l_path),
|
||||||
SAFE_STR(sd_ctx_params->clip_g_path),
|
SAFE_STR(sd_ctx_params->clip_g_path),
|
||||||
@ -2692,7 +2763,8 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
|
|||||||
BOOL_STR(sd_ctx_params->circular_y),
|
BOOL_STR(sd_ctx_params->circular_y),
|
||||||
BOOL_STR(sd_ctx_params->chroma_use_dit_mask),
|
BOOL_STR(sd_ctx_params->chroma_use_dit_mask),
|
||||||
BOOL_STR(sd_ctx_params->chroma_use_t5_mask),
|
BOOL_STR(sd_ctx_params->chroma_use_t5_mask),
|
||||||
sd_ctx_params->chroma_t5_mask_pad);
|
sd_ctx_params->chroma_t5_mask_pad,
|
||||||
|
sd_vae_format_name(sd_ctx_params->vae_format));
|
||||||
|
|
||||||
return buf;
|
return buf;
|
||||||
}
|
}
|
||||||
@ -2969,6 +3041,9 @@ SD_API bool sd_ctx_supports_video_generation(const sd_ctx_t* sd_ctx) {
|
|||||||
|
|
||||||
enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx) {
|
enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx) {
|
||||||
if (sd_ctx != nullptr && sd_ctx->sd != nullptr) {
|
if (sd_ctx != nullptr && sd_ctx->sd != nullptr) {
|
||||||
|
if (sd_version_is_pid(sd_ctx->sd->version)) {
|
||||||
|
return LCM_SAMPLE_METHOD;
|
||||||
|
}
|
||||||
if (sd_version_is_dit(sd_ctx->sd->version)) {
|
if (sd_version_is_dit(sd_ctx->sd->version)) {
|
||||||
return EULER_SAMPLE_METHOD;
|
return EULER_SAMPLE_METHOD;
|
||||||
}
|
}
|
||||||
@ -3867,7 +3942,7 @@ static std::optional<ImageGenerationLatents> prepare_image_generation_latents(sd
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
sd::Tensor<float> ref_latent;
|
sd::Tensor<float> ref_latent;
|
||||||
if (request->auto_resize_ref_image) {
|
if (request->auto_resize_ref_image && !sd_version_is_pid(sd_ctx->sd->version)) {
|
||||||
LOG_DEBUG("auto resize ref images");
|
LOG_DEBUG("auto resize ref images");
|
||||||
int vae_image_size = std::min(1024 * 1024, request->width * request->height);
|
int vae_image_size = std::min(1024 * 1024, request->width * request->height);
|
||||||
double vae_width = sqrt(vae_image_size * ref_images[i].shape()[0] / ref_images[i].shape()[1]);
|
double vae_width = sqrt(vae_image_size * ref_images[i].shape()[0] / ref_images[i].shape()[1]);
|
||||||
@ -3899,6 +3974,13 @@ static std::optional<ImageGenerationLatents> prepare_image_generation_latents(sd
|
|||||||
ref_latents.push_back(std::move(ref_latent));
|
ref_latents.push_back(std::move(ref_latent));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (sd_version_is_pid(sd_ctx->sd->version)) {
|
||||||
|
if (ref_latents.empty()) {
|
||||||
|
LOG_ERROR("PiD requires a reference image");
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
sd::Tensor<float> concat_latent;
|
sd::Tensor<float> concat_latent;
|
||||||
sd::Tensor<float> uncond_concat_latent;
|
sd::Tensor<float> uncond_concat_latent;
|
||||||
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
|
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
|
||||||
|
|||||||
@ -478,7 +478,7 @@ struct T5Embedder {
|
|||||||
bool alloc_params_buffer() {
|
bool alloc_params_buffer() {
|
||||||
if (!model.alloc_params_buffer()) {
|
if (!model.alloc_params_buffer()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -182,7 +182,8 @@ std::vector<int> BPETokenizer::encode(const std::string& text, on_new_token_cb_t
|
|||||||
unsigned char b = utf8_token_str[i];
|
unsigned char b = utf8_token_str[i];
|
||||||
char hex_buf[16];
|
char hex_buf[16];
|
||||||
snprintf(hex_buf, sizeof(hex_buf), "<0x%02X>", b);
|
snprintf(hex_buf, sizeof(hex_buf), "<0x%02X>", b);
|
||||||
iter = encoder.find(utf8_to_utf32(hex_buf));
|
iter = encoder.find(utf8_to_utf32(hex_buf));
|
||||||
|
token_id = iter != encoder.end() ? iter->second : UNK_TOKEN_ID;
|
||||||
bpe_tokens.push_back(token_id);
|
bpe_tokens.push_back(token_id);
|
||||||
token_strs.push_back(hex_buf);
|
token_strs.push_back(hex_buf);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -189,3 +189,164 @@ GemmaTokenizer::GemmaTokenizer(const std::string& merges_utf8_str, const std::st
|
|||||||
load_from_merges(load_gemma_merges(), load_gemma_vocab_json());
|
load_from_merges(load_gemma_merges(), load_gemma_vocab_json());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string Gemma2Tokenizer::normalize(const std::string& text) const {
|
||||||
|
std::string normalized = text;
|
||||||
|
size_t pos = 0;
|
||||||
|
while ((pos = normalized.find(' ', pos)) != std::string::npos) {
|
||||||
|
normalized.replace(pos, 1, "\xE2\x96\x81");
|
||||||
|
pos += 3;
|
||||||
|
}
|
||||||
|
return normalized;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Gemma2Tokenizer::load_from_merges(const std::string& merges_utf8_str, const std::string& vocab_utf8_str) {
|
||||||
|
nlohmann::json vocab;
|
||||||
|
try {
|
||||||
|
vocab = nlohmann::json::parse(vocab_utf8_str);
|
||||||
|
} catch (const nlohmann::json::parse_error&) {
|
||||||
|
GGML_ABORT("invalid vocab json str");
|
||||||
|
}
|
||||||
|
for (const auto& [key, value] : vocab.items()) {
|
||||||
|
std::u32string token = utf8_to_utf32(key);
|
||||||
|
int i = value;
|
||||||
|
encoder[token] = i;
|
||||||
|
decoder[i] = token;
|
||||||
|
}
|
||||||
|
encoder_len = static_cast<int>(vocab.size());
|
||||||
|
LOG_DEBUG("vocab size: %d", encoder_len);
|
||||||
|
|
||||||
|
std::vector<std::u32string> merges = split_utf32(merges_utf8_str);
|
||||||
|
std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
|
||||||
|
for (const auto& merge : merges) {
|
||||||
|
size_t space_pos = merge.find(' ');
|
||||||
|
merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
|
||||||
|
}
|
||||||
|
LOG_DEBUG("merges size %zu", merge_pairs.size());
|
||||||
|
|
||||||
|
int rank = 0;
|
||||||
|
for (const auto& merge : merge_pairs) {
|
||||||
|
bpe_ranks[merge] = rank++;
|
||||||
|
}
|
||||||
|
bpe_len = rank;
|
||||||
|
}
|
||||||
|
|
||||||
|
Gemma2Tokenizer::Gemma2Tokenizer(const std::string& merges_utf8_str, const std::string& vocab_utf8_str) {
|
||||||
|
byte_level_bpe = false;
|
||||||
|
byte_fallback = true;
|
||||||
|
add_bos_token = true;
|
||||||
|
PAD_TOKEN = "<pad>";
|
||||||
|
EOS_TOKEN = "<eos>";
|
||||||
|
BOS_TOKEN = "<bos>";
|
||||||
|
UNK_TOKEN = "<unk>";
|
||||||
|
|
||||||
|
PAD_TOKEN_ID = 0;
|
||||||
|
EOS_TOKEN_ID = 1;
|
||||||
|
BOS_TOKEN_ID = 2;
|
||||||
|
UNK_TOKEN_ID = 3;
|
||||||
|
|
||||||
|
std::vector<std::string> special_tokens_before_merge = {
|
||||||
|
PAD_TOKEN,
|
||||||
|
EOS_TOKEN,
|
||||||
|
BOS_TOKEN,
|
||||||
|
UNK_TOKEN,
|
||||||
|
"<mask>",
|
||||||
|
"<2mass>",
|
||||||
|
"[@BOS@]",
|
||||||
|
};
|
||||||
|
for (int i = 0; i <= 98; i++) {
|
||||||
|
special_tokens_before_merge.push_back("<unused" + std::to_string(i) + ">");
|
||||||
|
}
|
||||||
|
special_tokens_before_merge.push_back("<start_of_turn>");
|
||||||
|
special_tokens_before_merge.push_back("<end_of_turn>");
|
||||||
|
for (int i = 1; i <= 31; i++) {
|
||||||
|
special_tokens_before_merge.push_back(std::string(i, '\n'));
|
||||||
|
}
|
||||||
|
for (int i = 2; i <= 31; i++) {
|
||||||
|
std::string whitespace_token;
|
||||||
|
for (int j = 0; j < i; j++) {
|
||||||
|
whitespace_token += "\xE2\x96\x81";
|
||||||
|
}
|
||||||
|
special_tokens_before_merge.push_back(whitespace_token);
|
||||||
|
}
|
||||||
|
std::vector<std::string> html_tokens = {
|
||||||
|
"<table>",
|
||||||
|
"<caption>",
|
||||||
|
"<thead>",
|
||||||
|
"<tbody>",
|
||||||
|
"<tfoot>",
|
||||||
|
"<tr>",
|
||||||
|
"<th>",
|
||||||
|
"<td>",
|
||||||
|
"</table>",
|
||||||
|
"</caption>",
|
||||||
|
"</thead>",
|
||||||
|
"</tbody>",
|
||||||
|
"</tfoot>",
|
||||||
|
"</tr>",
|
||||||
|
"</th>",
|
||||||
|
"</td>",
|
||||||
|
"<h1>",
|
||||||
|
"<h2>",
|
||||||
|
"<h3>",
|
||||||
|
"<h4>",
|
||||||
|
"<h5>",
|
||||||
|
"<h6>",
|
||||||
|
"<blockquote>",
|
||||||
|
"</h1>",
|
||||||
|
"</h2>",
|
||||||
|
"</h3>",
|
||||||
|
"</h4>",
|
||||||
|
"</h5>",
|
||||||
|
"</h6>",
|
||||||
|
"</blockquote>",
|
||||||
|
"<strong>",
|
||||||
|
"<em>",
|
||||||
|
"<b>",
|
||||||
|
"<i>",
|
||||||
|
"<u>",
|
||||||
|
"<s>",
|
||||||
|
"<sub>",
|
||||||
|
"<sup>",
|
||||||
|
"<code>",
|
||||||
|
"</strong>",
|
||||||
|
"</em>",
|
||||||
|
"</b>",
|
||||||
|
"</i>",
|
||||||
|
"</u>",
|
||||||
|
"</s>",
|
||||||
|
"</sub>",
|
||||||
|
"</sup>",
|
||||||
|
"</code>",
|
||||||
|
};
|
||||||
|
special_tokens_before_merge.insert(special_tokens_before_merge.end(),
|
||||||
|
html_tokens.begin(),
|
||||||
|
html_tokens.end());
|
||||||
|
for (int i = 0; i <= 0xFF; i++) {
|
||||||
|
char hex_buf[16];
|
||||||
|
snprintf(hex_buf, sizeof(hex_buf), "<0x%02X>", i);
|
||||||
|
special_tokens_before_merge.push_back(hex_buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> special_tokens_after_merge = {
|
||||||
|
"[toxicity=0]",
|
||||||
|
};
|
||||||
|
for (int i = 1; i <= 31; i++) {
|
||||||
|
special_tokens_after_merge.insert(special_tokens_after_merge.begin() + i - 1,
|
||||||
|
std::string(i, '\t'));
|
||||||
|
}
|
||||||
|
for (int i = 99; i <= 99; i++) {
|
||||||
|
special_tokens_after_merge.push_back("<unused" + std::to_string(i) + ">");
|
||||||
|
}
|
||||||
|
|
||||||
|
special_tokens = special_tokens_before_merge;
|
||||||
|
special_tokens.insert(special_tokens.end(),
|
||||||
|
special_tokens_after_merge.begin(),
|
||||||
|
special_tokens_after_merge.end());
|
||||||
|
|
||||||
|
if (merges_utf8_str.size() > 0 && vocab_utf8_str.size() > 0) {
|
||||||
|
load_from_merges(merges_utf8_str, vocab_utf8_str);
|
||||||
|
} else {
|
||||||
|
load_from_merges(load_gemma2_merges(), load_gemma2_vocab_json());
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -14,4 +14,13 @@ public:
|
|||||||
explicit GemmaTokenizer(const std::string& merges_utf8_str = "", const std::string& vocab_utf8_str = "");
|
explicit GemmaTokenizer(const std::string& merges_utf8_str = "", const std::string& vocab_utf8_str = "");
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class Gemma2Tokenizer : public BPETokenizer {
|
||||||
|
protected:
|
||||||
|
void load_from_merges(const std::string& merges_utf8_str, const std::string& vocab_utf8_str);
|
||||||
|
std::string normalize(const std::string& text) const override;
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit Gemma2Tokenizer(const std::string& merges_utf8_str = "", const std::string& vocab_utf8_str = "");
|
||||||
|
};
|
||||||
|
|
||||||
#endif // __SD_TOKENIZERS_GEMMA_TOKENIZER_H__
|
#endif // __SD_TOKENIZERS_GEMMA_TOKENIZER_H__
|
||||||
|
|||||||
3
src/tokenizers/vocab/gemma2_merges.hpp
Normal file
3
src/tokenizers/vocab/gemma2_merges.hpp
Normal file
File diff suppressed because one or more lines are too long
3
src/tokenizers/vocab/gemma2_vocab.hpp
Normal file
3
src/tokenizers/vocab/gemma2_vocab.hpp
Normal file
File diff suppressed because one or more lines are too long
@ -1,5 +1,7 @@
|
|||||||
#include "vocab.h"
|
#include "vocab.h"
|
||||||
#include "clip_merges.hpp"
|
#include "clip_merges.hpp"
|
||||||
|
#include "gemma2_merges.hpp"
|
||||||
|
#include "gemma2_vocab.hpp"
|
||||||
#include "gemma_merges.hpp"
|
#include "gemma_merges.hpp"
|
||||||
#include "gemma_vocab.hpp"
|
#include "gemma_vocab.hpp"
|
||||||
#include "gpt_oss_merges.hpp"
|
#include "gpt_oss_merges.hpp"
|
||||||
@ -50,6 +52,16 @@ std::string load_gemma_vocab_json() {
|
|||||||
return json_str;
|
return json_str;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string load_gemma2_merges() {
|
||||||
|
std::string merges_utf8_str(reinterpret_cast<const char*>(gemma2_merges_utf8_c_str), sizeof(gemma2_merges_utf8_c_str));
|
||||||
|
return merges_utf8_str;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string load_gemma2_vocab_json() {
|
||||||
|
std::string json_str(reinterpret_cast<const char*>(gemma2_vocab_json_utf8_c_str), sizeof(gemma2_vocab_json_utf8_c_str));
|
||||||
|
return json_str;
|
||||||
|
}
|
||||||
|
|
||||||
std::string load_gpt_oss_merges() {
|
std::string load_gpt_oss_merges() {
|
||||||
std::string merges_utf8_str(reinterpret_cast<const char*>(gpt_oss_merges_utf8_c_str), sizeof(gpt_oss_merges_utf8_c_str));
|
std::string merges_utf8_str(reinterpret_cast<const char*>(gpt_oss_merges_utf8_c_str), sizeof(gpt_oss_merges_utf8_c_str));
|
||||||
return merges_utf8_str;
|
return merges_utf8_str;
|
||||||
|
|||||||
@ -11,6 +11,8 @@ std::string load_t5_tokenizer_json();
|
|||||||
std::string load_umt5_tokenizer_json();
|
std::string load_umt5_tokenizer_json();
|
||||||
std::string load_gemma_merges();
|
std::string load_gemma_merges();
|
||||||
std::string load_gemma_vocab_json();
|
std::string load_gemma_vocab_json();
|
||||||
|
std::string load_gemma2_merges();
|
||||||
|
std::string load_gemma2_vocab_json();
|
||||||
std::string load_gpt_oss_merges();
|
std::string load_gpt_oss_merges();
|
||||||
std::string load_gpt_oss_vocab_json();
|
std::string load_gpt_oss_vocab_json();
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user