feat: add flux2 support (#1016)

* add flux2 support

* rename qwenvl to llm

* add Flux2FlowDenoiser

* update docs
This commit is contained in:
leejet 2025-11-30 11:32:56 +08:00 committed by GitHub
parent 20345888a3
commit 52b67c538b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 489706 additions and 576 deletions

View File

@ -37,7 +37,8 @@ API and command-line option may change frequently.***
- SDXL, [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo) - SDXL, [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo)
- [Some SD1.x and SDXL distilled models](./docs/distilled_sd.md) - [Some SD1.x and SDXL distilled models](./docs/distilled_sd.md)
- [SD3/SD3.5](./docs/sd3.md) - [SD3/SD3.5](./docs/sd3.md)
- [Flux-dev/Flux-schnell](./docs/flux.md) - [FLUX.1-dev/FLUX.1-schnell](./docs/flux.md)
- [FLUX.2-dev](./docs/flux2.md)
- [Chroma](./docs/chroma.md) - [Chroma](./docs/chroma.md)
- [Chroma1-Radiance](./docs/chroma_radiance.md) - [Chroma1-Radiance](./docs/chroma_radiance.md)
- [Qwen Image](./docs/qwen_image.md) - [Qwen Image](./docs/qwen_image.md)
@ -118,7 +119,8 @@ If you want to improve performance or reduce VRAM/RAM usage, please refer to [pe
- [SD1.x/SD2.x/SDXL](./docs/sd.md) - [SD1.x/SD2.x/SDXL](./docs/sd.md)
- [SD3/SD3.5](./docs/sd3.md) - [SD3/SD3.5](./docs/sd3.md)
- [Flux-dev/Flux-schnell](./docs/flux.md) - [FLUX.1-dev/FLUX.1-schnell](./docs/flux.md)
- [FLUX.2-dev](./docs/flux2.md)
- [FLUX.1-Kontext-dev](./docs/kontext.md) - [FLUX.1-Kontext-dev](./docs/kontext.md)
- [Chroma](./docs/chroma.md) - [Chroma](./docs/chroma.md)
- [🔥Qwen Image](./docs/qwen_image.md) - [🔥Qwen Image](./docs/qwen_image.md)

BIN
assets/flux2/example.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 556 KiB

View File

@ -2,7 +2,7 @@
#define __CONDITIONER_HPP__ #define __CONDITIONER_HPP__
#include "clip.hpp" #include "clip.hpp"
#include "qwenvl.hpp" #include "llm.hpp"
#include "t5.hpp" #include "t5.hpp"
struct SDCondition { struct SDCondition {
@ -1623,61 +1623,72 @@ struct T5CLIPEmbedder : public Conditioner {
} }
}; };
struct Qwen2_5_VLCLIPEmbedder : public Conditioner { struct LLMEmbedder : public Conditioner {
Qwen::Qwen2Tokenizer tokenizer; SDVersion version;
std::shared_ptr<Qwen::Qwen2_5_VLRunner> qwenvl; std::shared_ptr<LLM::BPETokenizer> tokenizer;
std::shared_ptr<LLM::LLMRunner> llm;
Qwen2_5_VLCLIPEmbedder(ggml_backend_t backend, LLMEmbedder(ggml_backend_t backend,
bool offload_params_to_cpu, bool offload_params_to_cpu,
const String2TensorStorage& tensor_storage_map = {}, const String2TensorStorage& tensor_storage_map = {},
SDVersion version = VERSION_QWEN_IMAGE,
const std::string prefix = "", const std::string prefix = "",
bool enable_vision = false) { bool enable_vision = false)
qwenvl = std::make_shared<Qwen::Qwen2_5_VLRunner>(backend, : version(version) {
LLM::LLMArch arch = LLM::LLMArch::QWEN2_5_VL;
if (sd_version_is_flux2(version)) {
arch = LLM::LLMArch::MISTRAL_SMALL_3_2;
}
if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) {
tokenizer = std::make_shared<LLM::MistralTokenizer>();
} else {
tokenizer = std::make_shared<LLM::Qwen2Tokenizer>();
}
llm = std::make_shared<LLM::LLMRunner>(arch,
backend,
offload_params_to_cpu, offload_params_to_cpu,
tensor_storage_map, tensor_storage_map,
"text_encoders.qwen2vl", "text_encoders.llm",
enable_vision); enable_vision);
} }
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) override { void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) override {
qwenvl->get_param_tensors(tensors, "text_encoders.qwen2vl"); llm->get_param_tensors(tensors, "text_encoders.llm");
} }
void alloc_params_buffer() override { void alloc_params_buffer() override {
qwenvl->alloc_params_buffer(); llm->alloc_params_buffer();
} }
void free_params_buffer() override { void free_params_buffer() override {
qwenvl->free_params_buffer(); llm->free_params_buffer();
} }
size_t get_params_buffer_size() override { size_t get_params_buffer_size() override {
size_t buffer_size = 0; size_t buffer_size = 0;
buffer_size += qwenvl->get_params_buffer_size(); buffer_size += llm->get_params_buffer_size();
return buffer_size; return buffer_size;
} }
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override { void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
if (qwenvl) { if (llm) {
qwenvl->set_weight_adapter(adapter); llm->set_weight_adapter(adapter);
} }
} }
std::tuple<std::vector<int>, std::vector<float>> tokenize(std::string text, std::tuple<std::vector<int>, std::vector<float>> tokenize(std::string text,
std::pair<int, int> attn_range,
size_t max_length = 0, size_t max_length = 0,
size_t system_prompt_length = 0,
bool padding = false) { bool padding = false) {
std::vector<std::pair<std::string, float>> parsed_attention; std::vector<std::pair<std::string, float>> parsed_attention;
if (system_prompt_length > 0) { parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f);
parsed_attention.emplace_back(text.substr(0, system_prompt_length), 1.f); if (attn_range.second - attn_range.first > 0) {
auto new_parsed_attention = parse_prompt_attention(text.substr(system_prompt_length, text.size() - system_prompt_length)); auto new_parsed_attention = parse_prompt_attention(text.substr(attn_range.first, attn_range.second - attn_range.first));
parsed_attention.insert(parsed_attention.end(), parsed_attention.insert(parsed_attention.end(),
new_parsed_attention.begin(), new_parsed_attention.begin(),
new_parsed_attention.end()); new_parsed_attention.end());
} else {
parsed_attention = parse_prompt_attention(text);
} }
parsed_attention.emplace_back(text.substr(attn_range.second), 1.f);
{ {
std::stringstream ss; std::stringstream ss;
ss << "["; ss << "[";
@ -1693,12 +1704,12 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
for (const auto& item : parsed_attention) { for (const auto& item : parsed_attention) {
const std::string& curr_text = item.first; const std::string& curr_text = item.first;
float curr_weight = item.second; float curr_weight = item.second;
std::vector<int> curr_tokens = tokenizer.tokenize(curr_text, nullptr); std::vector<int> curr_tokens = tokenizer->tokenize(curr_text, nullptr);
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end()); tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
weights.insert(weights.end(), curr_tokens.size(), curr_weight); weights.insert(weights.end(), curr_tokens.size(), curr_weight);
} }
tokenizer.pad_tokens(tokens, weights, max_length, padding); tokenizer->pad_tokens(tokens, weights, max_length, padding);
// for (int i = 0; i < tokens.size(); i++) { // for (int i = 0; i < tokens.size(); i++) {
// std::cout << tokens[i] << ":" << weights[i] << ", " << i << std::endl; // std::cout << tokens[i] << ":" << weights[i] << ", " << i << std::endl;
@ -1713,9 +1724,10 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
const ConditionerParams& conditioner_params) override { const ConditionerParams& conditioner_params) override {
std::string prompt; std::string prompt;
std::vector<std::pair<int, ggml_tensor*>> image_embeds; std::vector<std::pair<int, ggml_tensor*>> image_embeds;
size_t system_prompt_length = 0; std::pair<int, int> prompt_attn_range;
int prompt_template_encode_start_idx = 34; int prompt_template_encode_start_idx = 34;
if (qwenvl->enable_vision && conditioner_params.ref_images.size() > 0) { std::set<int> out_layers;
if (llm->enable_vision && conditioner_params.ref_images.size() > 0) {
LOG_INFO("QwenImageEditPlusPipeline"); LOG_INFO("QwenImageEditPlusPipeline");
prompt_template_encode_start_idx = 64; prompt_template_encode_start_idx = 64;
int image_embed_idx = 64 + 6; int image_embed_idx = 64 + 6;
@ -1727,7 +1739,7 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
for (int i = 0; i < conditioner_params.ref_images.size(); i++) { for (int i = 0; i < conditioner_params.ref_images.size(); i++) {
sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]); sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]);
double factor = qwenvl->params.vision.patch_size * qwenvl->params.vision.spatial_merge_size; double factor = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size;
int height = image.height; int height = image.height;
int width = image.width; int width = image.width;
int h_bar = static_cast<int>(std::round(height / factor)) * factor; int h_bar = static_cast<int>(std::round(height / factor)) * factor;
@ -1757,7 +1769,7 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
resized_image.data = nullptr; resized_image.data = nullptr;
ggml_tensor* image_embed = nullptr; ggml_tensor* image_embed = nullptr;
qwenvl->encode_image(n_threads, image_tensor, &image_embed, work_ctx); llm->encode_image(n_threads, image_tensor, &image_embed, work_ctx);
image_embeds.emplace_back(image_embed_idx, image_embed); image_embeds.emplace_back(image_embed_idx, image_embed);
image_embed_idx += 1 + image_embed->ne[1] + 6; image_embed_idx += 1 + image_embed->ne[1] + 6;
@ -1771,17 +1783,37 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
} }
prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n"; prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n";
system_prompt_length = prompt.size();
prompt += img_prompt; prompt += img_prompt;
prompt_attn_range.first = prompt.size();
prompt += conditioner_params.text; prompt += conditioner_params.text;
prompt_attn_range.second = prompt.size();
prompt += "<|im_end|>\n<|im_start|>assistant\n"; prompt += "<|im_end|>\n<|im_start|>assistant\n";
} else if (sd_version_is_flux2(version)) {
prompt_template_encode_start_idx = 0;
out_layers = {10, 20, 30};
prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]";
prompt_attn_range.first = prompt.size();
prompt += conditioner_params.text;
prompt_attn_range.second = prompt.size();
prompt += "[/INST]";
} else { } else {
prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n" + conditioner_params.text + "<|im_end|>\n<|im_start|>assistant\n"; prompt_template_encode_start_idx = 34;
prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n";
prompt_attn_range.first = prompt.size();
prompt += conditioner_params.text;
prompt_attn_range.second = prompt.size();
prompt += "<|im_end|>\n<|im_start|>assistant\n";
} }
auto tokens_and_weights = tokenize(prompt, 0, system_prompt_length, false); auto tokens_and_weights = tokenize(prompt, prompt_attn_range, 0, false);
auto& tokens = std::get<0>(tokens_and_weights); auto& tokens = std::get<0>(tokens_and_weights);
auto& weights = std::get<1>(tokens_and_weights); auto& weights = std::get<1>(tokens_and_weights);
@ -1790,9 +1822,10 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens); auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
qwenvl->compute(n_threads, llm->compute(n_threads,
input_ids, input_ids,
image_embeds, image_embeds,
out_layers,
&hidden_states, &hidden_states,
work_ctx); work_ctx);
{ {
@ -1813,14 +1846,25 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx); GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx);
int64_t zero_pad_len = 0;
if (sd_version_is_flux2(version)) {
int64_t min_length = 512;
if (hidden_states->ne[1] - prompt_template_encode_start_idx < min_length) {
zero_pad_len = min_length - hidden_states->ne[1] + prompt_template_encode_start_idx;
}
}
ggml_tensor* new_hidden_states = ggml_new_tensor_3d(work_ctx, ggml_tensor* new_hidden_states = ggml_new_tensor_3d(work_ctx,
GGML_TYPE_F32, GGML_TYPE_F32,
hidden_states->ne[0], hidden_states->ne[0],
hidden_states->ne[1] - prompt_template_encode_start_idx, hidden_states->ne[1] - prompt_template_encode_start_idx + zero_pad_len,
hidden_states->ne[2]); hidden_states->ne[2]);
ggml_ext_tensor_iter(new_hidden_states, [&](ggml_tensor* new_hidden_states, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { ggml_ext_tensor_iter(new_hidden_states, [&](ggml_tensor* new_hidden_states, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = ggml_ext_tensor_get_f32(hidden_states, i0, i1 + prompt_template_encode_start_idx, i2, i3); float value = 0.f;
if (i1 + prompt_template_encode_start_idx < hidden_states->ne[1]) {
value = ggml_ext_tensor_get_f32(hidden_states, i0, i1 + prompt_template_encode_start_idx, i2, i3);
}
ggml_ext_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3); ggml_ext_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3);
}); });

View File

@ -356,7 +356,7 @@ struct Denoiser {
virtual ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) = 0; virtual ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) = 0;
virtual ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) = 0; virtual ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) = 0;
virtual std::vector<float> get_sigmas(uint32_t n, scheduler_t scheduler_type, SDVersion version) { virtual std::vector<float> get_sigmas(uint32_t n, int /*image_seq_len*/, scheduler_t scheduler_type, SDVersion version) {
auto bound_t_to_sigma = std::bind(&Denoiser::t_to_sigma, this, std::placeholders::_1); auto bound_t_to_sigma = std::bind(&Denoiser::t_to_sigma, this, std::placeholders::_1);
std::shared_ptr<SigmaScheduler> scheduler; std::shared_ptr<SigmaScheduler> scheduler;
switch (scheduler_type) { switch (scheduler_type) {
@ -582,10 +582,14 @@ struct FluxFlowDenoiser : public Denoiser {
set_parameters(shift); set_parameters(shift);
} }
void set_parameters(float shift = 1.15f) { void set_shift(float shift) {
this->shift = shift; this->shift = shift;
for (int i = 1; i < TIMESTEPS + 1; i++) { }
sigmas[i - 1] = t_to_sigma(i / TIMESTEPS * TIMESTEPS);
void set_parameters(float shift) {
set_shift(shift);
for (int i = 0; i < TIMESTEPS; i++) {
sigmas[i] = t_to_sigma(i);
} }
} }
@ -627,6 +631,38 @@ struct FluxFlowDenoiser : public Denoiser {
} }
}; };
struct Flux2FlowDenoiser : public FluxFlowDenoiser {
Flux2FlowDenoiser() = default;
// Computes the flow-shift value ("mu") from the step count and the image
// sequence length.
//
// @param n              number of sampling steps
// @param image_seq_len  image sequence length fed to the model
// @return the shift value:
//   - for image_seq_len > 4300, the line `a2 * image_seq_len + b2`,
//     independent of n;
//   - otherwise, a value linear in n that equals `a1 * image_seq_len + b1`
//     at n == 10 and `a2 * image_seq_len + b2` at n == 200.
float compute_empirical_mu(uint32_t n, int image_seq_len) {
    // Coefficients of the two fitted lines (10-step and 200-step anchors).
    const float slope_10steps      = 8.73809524e-05f;
    const float intercept_10steps  = 1.89833333f;
    const float slope_200steps     = 0.00016927f;
    const float intercept_200steps = 0.45666666f;

    // Long sequences use the 200-step line directly.
    if (image_seq_len > 4300) {
        return slope_200steps * image_seq_len + intercept_200steps;
    }

    const float mu_at_200 = slope_200steps * image_seq_len + intercept_200steps;
    const float mu_at_10  = slope_10steps * image_seq_len + intercept_10steps;

    // Line through (10, mu_at_10) and (200, mu_at_200), evaluated at n.
    const float step_slope     = (mu_at_200 - mu_at_10) / 190.0f;
    const float step_intercept = mu_at_200 - 200.0f * step_slope;
    return step_slope * n + step_intercept;
}
// Builds the sigma schedule after first updating this denoiser's shift.
//
// The shift is recomputed on every call from the step count `n` and
// `image_seq_len` via compute_empirical_mu(), applied with set_shift(), and
// then the schedule itself is produced by the base Denoiser::get_sigmas().
// Note that this mutates the denoiser's state as a side effect.
std::vector<float> get_sigmas(uint32_t n, int image_seq_len, scheduler_t scheduler_type, SDVersion version) override {
    const float empirical_shift = compute_empirical_mu(n, image_seq_len);
    LOG_DEBUG("Flux2FlowDenoiser: set shift to %.3f", empirical_shift);
    set_shift(empirical_shift);

    return Denoiser::get_sigmas(n, image_seq_len, scheduler_type, version);
}
};
typedef std::function<ggml_tensor*(ggml_tensor*, float, int)> denoise_cb_t; typedef std::function<ggml_tensor*(ggml_tensor*, float, int)> denoise_cb_t;
// k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t // k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t

21
docs/flux2.md Normal file
View File

@ -0,0 +1,21 @@
# How to Use
## Download weights
- Download FLUX.2-dev
- gguf: https://huggingface.co/city96/FLUX.2-dev-gguf/tree/main
- Download vae
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-dev/tree/main
- Download Mistral-Small-3.2-24B-Instruct-2506-GGUF
- gguf: https://huggingface.co/unsloth/Mistral-Small-3.2-24B-Instruct-2506-GGUF/tree/main
## Examples
```
.\bin\Release\sd.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux2-dev-Q4_K_S.gguf --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\Mistral-Small-3.2-24B-Instruct-2506-Q4_K_M.gguf -r .\kontext_input.png -p "change 'flux.cpp' to 'flux2-dev.cpp'" --cfg-scale 1.0 --sampling-method euler -v --diffusion-fa --offload-to-cpu
```
<img alt="flux2 example" src="../assets/flux2/example.png" />

View File

@ -14,7 +14,7 @@
## Examples ## Examples
``` ```
.\bin\Release\sd.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\qwen-image-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --qwen2vl ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf -p '一个穿着"QWEN"标志的T恤的中国美女正拿着黑色的马克笔面相镜头微笑。她身后的玻璃板上手写体写着 “一、Qwen-Image的技术路线 探索视觉生成基础模型的极限开创理解与生成一体化的未来。二、Qwen-Image的模型特色1、复杂文字渲染。支持中英渲染、自动布局 2、精准图像编辑。支持文字编辑、物体增减、风格变换。三、Qwen-Image的未来愿景赋能专业内容创作、助力生成式AI发展。”' --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu -H 1024 -W 1024 --diffusion-fa --flow-shift 3 .\bin\Release\sd.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\qwen-image-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --llm ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf -p '一个穿着"QWEN"标志的T恤的中国美女正拿着黑色的马克笔面相镜头微笑。她身后的玻璃板上手写体写着 “一、Qwen-Image的技术路线 探索视觉生成基础模型的极限开创理解与生成一体化的未来。二、Qwen-Image的模型特色1、复杂文字渲染。支持中英渲染、自动布局 2、精准图像编辑。支持文字编辑、物体增减、风格变换。三、Qwen-Image的未来愿景赋能专业内容创作、助力生成式AI发展。”' --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu -H 1024 -W 1024 --diffusion-fa --flow-shift 3
``` ```
<img alt="qwen example" src="../assets/qwen/example.png" /> <img alt="qwen example" src="../assets/qwen/example.png" />

View File

@ -20,7 +20,7 @@
### Qwen Image Edit ### Qwen Image Edit
``` ```
.\bin\Release\sd.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\Qwen_Image_Edit-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --qwen2vl ..\..\ComfyUI\models\text_encoders\qwen_2.5_vl_7b.safetensors --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu --diffusion-fa --flow-shift 3 -r ..\assets\flux\flux1-dev-q8_0.png -p "change 'flux.cpp' to 'edit.cpp'" --seed 1118877715456453 .\bin\Release\sd.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\Qwen_Image_Edit-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_2.5_vl_7b.safetensors --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu --diffusion-fa --flow-shift 3 -r ..\assets\flux\flux1-dev-q8_0.png -p "change 'flux.cpp' to 'edit.cpp'" --seed 1118877715456453
``` ```
<img alt="qwen_image_edit" src="../assets/qwen/qwen_image_edit.png" /> <img alt="qwen_image_edit" src="../assets/qwen/qwen_image_edit.png" />
@ -29,7 +29,7 @@
### Qwen Image Edit 2509 ### Qwen Image Edit 2509
``` ```
.\bin\Release\sd.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\Qwen-Image-Edit-2509-Q4_K_S.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --qwen2vl ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf --qwen2vl_vision ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct.mmproj-Q8_0.gguf --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu --diffusion-fa --flow-shift 3 -r ..\assets\flux\flux1-dev-q8_0.png -p "change 'flux.cpp' to 'Qwen Image Edit 2509'" .\bin\Release\sd.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\Qwen-Image-Edit-2509-Q4_K_S.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --llm ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf --llm_vision ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct.mmproj-Q8_0.gguf --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu --diffusion-fa --flow-shift 3 -r ..\assets\flux\flux1-dev-q8_0.png -p "change 'flux.cpp' to 'Qwen Image Edit 2509'"
``` ```
<img alt="qwen_image_edit_2509" src="../assets/qwen/qwen_image_edit_2509.png" /> <img alt="qwen_image_edit_2509" src="../assets/qwen/qwen_image_edit_2509.png" />

View File

@ -9,8 +9,10 @@ Options:
--clip_g <string> path to the clip-g text encoder --clip_g <string> path to the clip-g text encoder
--clip_vision <string> path to the clip-vision encoder --clip_vision <string> path to the clip-vision encoder
--t5xxl <string> path to the t5xxl text encoder --t5xxl <string> path to the t5xxl text encoder
--qwen2vl <string> path to the qwen2vl text encoder --llm <string> path to the llm text encoder. For example: (qwenvl2.5 for qwen-image, mistral-small3.2 for flux2, ...)
--qwen2vl_vision <string> path to the qwen2vl vit --llm_vision <string> path to the llm vit
--qwen2vl <string> alias of --llm. Deprecated.
--qwen2vl_vision <string> alias of --llm_vision. Deprecated.
--diffusion-model <string> path to the standalone diffusion model --diffusion-model <string> path to the standalone diffusion model
--high-noise-diffusion-model <string> path to the standalone high noise diffusion model --high-noise-diffusion-model <string> path to the standalone high noise diffusion model
--vae <string> path to standalone vae model --vae <string> path to standalone vae model
@ -33,7 +35,6 @@ Options:
-p, --prompt <string> the prompt to render -p, --prompt <string> the prompt to render
-n, --negative-prompt <string> the negative prompt (default: "") -n, --negative-prompt <string> the negative prompt (default: "")
--preview-path <string> path to write preview image to (default: ./preview.png) --preview-path <string> path to write preview image to (default: ./preview.png)
--easycache <string> enable EasyCache for DiT models, accepts optional "threshold,start_percent,end_percent" values (defaults to 0.2,0.15,0.95)
--upscale-model <string> path to esrgan model. --upscale-model <string> path to esrgan model.
-t, --threads <int> number of threads to use during computation (default: -1). If threads <= 0, then threads will be set to the number of -t, --threads <int> number of threads to use during computation (default: -1). If threads <= 0, then threads will be set to the number of
CPU physical cores CPU physical cores
@ -100,20 +101,18 @@ Options:
-s, --seed RNG seed (default: 42, use random seed for < 0) -s, --seed RNG seed (default: 42, use random seed for < 0)
--sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, --sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing,
tcd] (default: euler for Flux/SD3/Wan, euler_a otherwise) tcd] (default: euler for Flux/SD3/Wan, euler_a otherwise)
--prediction prediction type override, one of [eps, v, edm_v, sd3_flow, flux_flow] --prediction prediction type override, one of [eps, v, edm_v, sd3_flow, flux_flow, flux2_flow]
--lora-apply-mode the way to apply LoRA, one of [auto, immediately, at_runtime], default is auto. In auto mode, if the model weights --lora-apply-mode the way to apply LoRA, one of [auto, immediately, at_runtime], default is auto. In auto mode, if the model weights
contain any quantized parameters, the at_runtime mode will be used; otherwise, contain any quantized parameters, the at_runtime mode will be used; otherwise,
immediately will be used.The immediately mode may have precision and immediately will be used.The immediately mode may have precision and
compatibility issues with quantized parameters, but it usually offers faster inference compatibility issues with quantized parameters, but it usually offers faster inference
speed and, in some cases, lower memory usage. The at_runtime mode, on the other speed and, in some cases, lower memory usage. The at_runtime mode, on the
hand, is exactly the opposite. other hand, is exactly the opposite.
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm], --scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm],
default: discrete default: discrete
--skip-layers layers to skip for SLG steps (default: [7,8,9]) --skip-layers layers to skip for SLG steps (default: [7,8,9])
--high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, --high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm,
ddim_trailing, tcd] default: euler for Flux/SD3/Wan, euler_a otherwise ddim_trailing, tcd] default: euler for Flux/SD3/Wan, euler_a otherwise
--high-noise-scheduler (high noise) denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform,
simple], default: discrete
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9]) --high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
-r, --ref-image reference image for Flux Kontext models (can be used multiple times) -r, --ref-image reference image for Flux Kontext models (can be used multiple times)
-h, --help show this help message and exit -h, --help show this help message and exit
@ -121,4 +120,5 @@ Options:
--vae-relative-tile-size relative tile size for vae tiling, format [X]x[Y], in fraction of image size if < 1, in number of tiles per dim if >=1 --vae-relative-tile-size relative tile size for vae tiling, format [X]x[Y], in fraction of image size if < 1, in number of tiles per dim if >=1
(overrides --vae-tile-size) (overrides --vae-tile-size)
--preview preview method. must be one of the following [none, proj, tae, vae] (default is none) --preview preview method. must be one of the following [none, proj, tae, vae] (default is none)
--easycache enable EasyCache for DiT models with optional "threshold,start_percent,end_percent" (default: 0.2,0.15,0.95)
``` ```

View File

@ -70,8 +70,8 @@ struct SDParams {
std::string clip_g_path; std::string clip_g_path;
std::string clip_vision_path; std::string clip_vision_path;
std::string t5xxl_path; std::string t5xxl_path;
std::string qwen2vl_path; std::string llm_path;
std::string qwen2vl_vision_path; std::string llm_vision_path;
std::string diffusion_model_path; std::string diffusion_model_path;
std::string high_noise_diffusion_model_path; std::string high_noise_diffusion_model_path;
std::string vae_path; std::string vae_path;
@ -174,8 +174,8 @@ void print_params(SDParams params) {
printf(" clip_g_path: %s\n", params.clip_g_path.c_str()); printf(" clip_g_path: %s\n", params.clip_g_path.c_str());
printf(" clip_vision_path: %s\n", params.clip_vision_path.c_str()); printf(" clip_vision_path: %s\n", params.clip_vision_path.c_str());
printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str()); printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str());
printf(" qwen2vl_path: %s\n", params.qwen2vl_path.c_str()); printf(" llm_path: %s\n", params.llm_path.c_str());
printf(" qwen2vl_vision_path: %s\n", params.qwen2vl_vision_path.c_str()); printf(" llm_vision_path: %s\n", params.llm_vision_path.c_str());
printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str()); printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str());
printf(" high_noise_diffusion_model_path: %s\n", params.high_noise_diffusion_model_path.c_str()); printf(" high_noise_diffusion_model_path: %s\n", params.high_noise_diffusion_model_path.c_str());
printf(" vae_path: %s\n", params.vae_path.c_str()); printf(" vae_path: %s\n", params.vae_path.c_str());
@ -532,14 +532,22 @@ void parse_args(int argc, const char** argv, SDParams& params) {
"--t5xxl", "--t5xxl",
"path to the t5xxl text encoder", "path to the t5xxl text encoder",
&params.t5xxl_path}, &params.t5xxl_path},
{"",
"--llm",
"path to the llm text encoder. For example: (qwenvl2.5 for qwen-image, mistral-small3.2 for flux2, ...)",
&params.llm_path},
{"",
"--llm_vision",
"path to the llm vit",
&params.llm_vision_path},
{"", {"",
"--qwen2vl", "--qwen2vl",
"path to the qwen2vl text encoder", "alias of --llm. Deprecated.",
&params.qwen2vl_path}, &params.llm_path},
{"", {"",
"--qwen2vl_vision", "--qwen2vl_vision",
"path to the qwen2vl vit", "alias of --llm_vision. Deprecated.",
&params.qwen2vl_vision_path}, &params.llm_vision_path},
{"", {"",
"--diffusion-model", "--diffusion-model",
"path to the standalone diffusion model", "path to the standalone diffusion model",
@ -1185,7 +1193,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
on_sample_method_arg}, on_sample_method_arg},
{"", {"",
"--prediction", "--prediction",
"prediction type override, one of [eps, v, edm_v, sd3_flow, flux_flow]", "prediction type override, one of [eps, v, edm_v, sd3_flow, flux_flow, flux2_flow]",
on_prediction_arg}, on_prediction_arg},
{"", {"",
"--lora-apply-mode", "--lora-apply-mode",
@ -1230,7 +1238,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
on_relative_tile_size_arg}, on_relative_tile_size_arg},
{"", {"",
"--preview", "--preview",
std::string("preview method. must be one of the following [") + previews_str[0] + ", " + previews_str[1] + ", " + previews_str[2] + ", " + previews_str[3] + "] (default is " + previews_str[PREVIEW_NONE] + ")\n", std::string("preview method. must be one of the following [") + previews_str[0] + ", " + previews_str[1] + ", " + previews_str[2] + ", " + previews_str[3] + "] (default is " + previews_str[PREVIEW_NONE] + ")",
on_preview_arg}, on_preview_arg},
{"", {"",
"--easycache", "--easycache",
@ -1428,7 +1436,7 @@ std::string get_image_params(SDParams params, int64_t seed) {
parameter_string += " " + std::string(sd_scheduler_name(params.sample_params.scheduler)); parameter_string += " " + std::string(sd_scheduler_name(params.sample_params.scheduler));
} }
parameter_string += ", "; parameter_string += ", ";
for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path, params.qwen2vl_path, params.qwen2vl_vision_path}) { for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path, params.llm_path, params.llm_vision_path}) {
if (!te.empty()) { if (!te.empty()) {
parameter_string += "TE: " + sd_basename(te) + ", "; parameter_string += "TE: " + sd_basename(te) + ", ";
} }
@ -1845,8 +1853,8 @@ int main(int argc, const char* argv[]) {
params.clip_g_path.c_str(), params.clip_g_path.c_str(),
params.clip_vision_path.c_str(), params.clip_vision_path.c_str(),
params.t5xxl_path.c_str(), params.t5xxl_path.c_str(),
params.qwen2vl_path.c_str(), params.llm_path.c_str(),
params.qwen2vl_vision_path.c_str(), params.llm_vision_path.c_str(),
params.diffusion_model_path.c_str(), params.diffusion_model_path.c_str(),
params.high_noise_diffusion_model_path.c_str(), params.high_noise_diffusion_model_path.c_str(),
params.vae_path.c_str(), params.vae_path.c_str(),

196
flux.hpp
View File

@ -14,9 +14,9 @@ namespace Flux {
struct MLPEmbedder : public UnaryBlock { struct MLPEmbedder : public UnaryBlock {
public: public:
MLPEmbedder(int64_t in_dim, int64_t hidden_dim) { MLPEmbedder(int64_t in_dim, int64_t hidden_dim, bool bias = true) {
blocks["in_layer"] = std::shared_ptr<GGMLBlock>(new Linear(in_dim, hidden_dim, true)); blocks["in_layer"] = std::shared_ptr<GGMLBlock>(new Linear(in_dim, hidden_dim, bias));
blocks["out_layer"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_dim, hidden_dim, true)); blocks["out_layer"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_dim, hidden_dim, bias));
} }
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
@ -89,12 +89,13 @@ namespace Flux {
public: public:
SelfAttention(int64_t dim, SelfAttention(int64_t dim,
int64_t num_heads = 8, int64_t num_heads = 8,
bool qkv_bias = false) bool qkv_bias = false,
bool proj_bias = true)
: num_heads(num_heads) { : num_heads(num_heads) {
int64_t head_dim = dim / num_heads; int64_t head_dim = dim / num_heads;
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias)); blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
blocks["norm"] = std::shared_ptr<GGMLBlock>(new QKNorm(head_dim)); blocks["norm"] = std::shared_ptr<GGMLBlock>(new QKNorm(head_dim));
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim)); blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim, proj_bias));
} }
std::vector<struct ggml_tensor*> pre_attention(GGMLRunnerContext* ctx, struct ggml_tensor* x) { std::vector<struct ggml_tensor*> pre_attention(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
@ -155,10 +156,10 @@ namespace Flux {
int multiplier; int multiplier;
public: public:
Modulation(int64_t dim, bool is_double) Modulation(int64_t dim, bool is_double, bool bias = true)
: is_double(is_double) { : is_double(is_double) {
multiplier = is_double ? 6 : 3; multiplier = is_double ? 6 : 3;
blocks["lin"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * multiplier)); blocks["lin"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * multiplier, bias));
} }
std::vector<ModulationOut> forward(GGMLRunnerContext* ctx, struct ggml_tensor* vec) { std::vector<ModulationOut> forward(GGMLRunnerContext* ctx, struct ggml_tensor* vec) {
@ -198,6 +199,7 @@ namespace Flux {
struct DoubleStreamBlock : public GGMLBlock { struct DoubleStreamBlock : public GGMLBlock {
bool prune_mod; bool prune_mod;
int idx = 0; int idx = 0;
bool use_mlp_silu_act;
public: public:
DoubleStreamBlock(int64_t hidden_size, DoubleStreamBlock(int64_t hidden_size,
@ -205,30 +207,35 @@ namespace Flux {
float mlp_ratio, float mlp_ratio,
int idx = 0, int idx = 0,
bool qkv_bias = false, bool qkv_bias = false,
bool prune_mod = false) bool prune_mod = false,
: idx(idx), prune_mod(prune_mod) { bool share_modulation = false,
bool mlp_proj_bias = true,
bool use_mlp_silu_act = false)
: idx(idx), prune_mod(prune_mod), use_mlp_silu_act(use_mlp_silu_act) {
int64_t mlp_hidden_dim = hidden_size * mlp_ratio; int64_t mlp_hidden_dim = hidden_size * mlp_ratio;
if (!prune_mod) { int64_t mlp_mult_factor = use_mlp_silu_act ? 2 : 1;
if (!prune_mod && !share_modulation) {
blocks["img_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true)); blocks["img_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
} }
blocks["img_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false)); blocks["img_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias)); blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias));
blocks["img_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false)); blocks["img_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["img_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim)); blocks["img_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias));
// img_mlp.1 is nn.GELU(approximate="tanh") // img_mlp.1 is nn.GELU(approximate="tanh")
blocks["img_mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(mlp_hidden_dim, hidden_size)); blocks["img_mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(mlp_hidden_dim, hidden_size, mlp_proj_bias));
if (!prune_mod) { if (!prune_mod && !share_modulation) {
blocks["txt_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true)); blocks["txt_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
} }
blocks["txt_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false)); blocks["txt_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["txt_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias)); blocks["txt_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias));
blocks["txt_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false)); blocks["txt_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["txt_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim)); blocks["txt_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias));
// img_mlp.1 is nn.GELU(approximate="tanh") // img_mlp.1 is nn.GELU(approximate="tanh")
blocks["txt_mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(mlp_hidden_dim, hidden_size)); blocks["txt_mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(mlp_hidden_dim, hidden_size, mlp_proj_bias));
} }
std::vector<ModulationOut> get_distil_img_mod(GGMLRunnerContext* ctx, struct ggml_tensor* vec) { std::vector<ModulationOut> get_distil_img_mod(GGMLRunnerContext* ctx, struct ggml_tensor* vec) {
@ -254,7 +261,9 @@ namespace Flux {
struct ggml_tensor* txt, struct ggml_tensor* txt,
struct ggml_tensor* vec, struct ggml_tensor* vec,
struct ggml_tensor* pe, struct ggml_tensor* pe,
struct ggml_tensor* mask = nullptr) { struct ggml_tensor* mask = nullptr,
std::vector<ModulationOut> img_mods = {},
std::vector<ModulationOut> txt_mods = {}) {
// img: [N, n_img_token, hidden_size] // img: [N, n_img_token, hidden_size]
// txt: [N, n_txt_token, hidden_size] // txt: [N, n_txt_token, hidden_size]
// pe: [n_img_token + n_txt_token, d_head/2, 2, 2] // pe: [n_img_token + n_txt_token, d_head/2, 2, 2]
@ -273,22 +282,24 @@ namespace Flux {
auto txt_mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["txt_mlp.0"]); auto txt_mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["txt_mlp.0"]);
auto txt_mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["txt_mlp.2"]); auto txt_mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["txt_mlp.2"]);
std::vector<ModulationOut> img_mods; if (img_mods.empty()) {
if (prune_mod) { if (prune_mod) {
img_mods = get_distil_img_mod(ctx, vec); img_mods = get_distil_img_mod(ctx, vec);
} else { } else {
auto img_mod = std::dynamic_pointer_cast<Modulation>(blocks["img_mod"]); auto img_mod = std::dynamic_pointer_cast<Modulation>(blocks["img_mod"]);
img_mods = img_mod->forward(ctx, vec); img_mods = img_mod->forward(ctx, vec);
} }
}
ModulationOut img_mod1 = img_mods[0]; ModulationOut img_mod1 = img_mods[0];
ModulationOut img_mod2 = img_mods[1]; ModulationOut img_mod2 = img_mods[1];
std::vector<ModulationOut> txt_mods; if (txt_mods.empty()) {
if (prune_mod) { if (prune_mod) {
txt_mods = get_distil_txt_mod(ctx, vec); txt_mods = get_distil_txt_mod(ctx, vec);
} else { } else {
auto txt_mod = std::dynamic_pointer_cast<Modulation>(blocks["txt_mod"]); auto txt_mod = std::dynamic_pointer_cast<Modulation>(blocks["txt_mod"]);
txt_mods = txt_mod->forward(ctx, vec); txt_mods = txt_mod->forward(ctx, vec);
} }
}
ModulationOut txt_mod1 = txt_mods[0]; ModulationOut txt_mod1 = txt_mods[0];
ModulationOut txt_mod2 = txt_mods[1]; ModulationOut txt_mod2 = txt_mods[1];
@ -338,7 +349,11 @@ namespace Flux {
img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate)); img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate));
auto img_mlp_out = img_mlp_0->forward(ctx, Flux::modulate(ctx->ggml_ctx, img_norm2->forward(ctx, img), img_mod2.shift, img_mod2.scale)); auto img_mlp_out = img_mlp_0->forward(ctx, Flux::modulate(ctx->ggml_ctx, img_norm2->forward(ctx, img), img_mod2.shift, img_mod2.scale));
if (use_mlp_silu_act) {
img_mlp_out = ggml_ext_silu_act(ctx->ggml_ctx, img_mlp_out);
} else {
img_mlp_out = ggml_gelu_inplace(ctx->ggml_ctx, img_mlp_out); img_mlp_out = ggml_gelu_inplace(ctx->ggml_ctx, img_mlp_out);
}
img_mlp_out = img_mlp_2->forward(ctx, img_mlp_out); img_mlp_out = img_mlp_2->forward(ctx, img_mlp_out);
img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_mlp_out, img_mod2.gate)); img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_mlp_out, img_mod2.gate));
@ -347,9 +362,12 @@ namespace Flux {
txt = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_attn->post_attention(ctx, txt_attn_out), txt_mod1.gate)); txt = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_attn->post_attention(ctx, txt_attn_out), txt_mod1.gate));
auto txt_mlp_out = txt_mlp_0->forward(ctx, Flux::modulate(ctx->ggml_ctx, txt_norm2->forward(ctx, txt), txt_mod2.shift, txt_mod2.scale)); auto txt_mlp_out = txt_mlp_0->forward(ctx, Flux::modulate(ctx->ggml_ctx, txt_norm2->forward(ctx, txt), txt_mod2.shift, txt_mod2.scale));
if (use_mlp_silu_act) {
txt_mlp_out = ggml_ext_silu_act(ctx->ggml_ctx, txt_mlp_out);
} else {
txt_mlp_out = ggml_gelu_inplace(ctx->ggml_ctx, txt_mlp_out); txt_mlp_out = ggml_gelu_inplace(ctx->ggml_ctx, txt_mlp_out);
}
txt_mlp_out = txt_mlp_2->forward(ctx, txt_mlp_out); txt_mlp_out = txt_mlp_2->forward(ctx, txt_mlp_out);
txt = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_mlp_out, txt_mod2.gate)); txt = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_mlp_out, txt_mod2.gate));
return {img, txt}; return {img, txt};
@ -363,6 +381,8 @@ namespace Flux {
int64_t mlp_hidden_dim; int64_t mlp_hidden_dim;
bool prune_mod; bool prune_mod;
int idx = 0; int idx = 0;
bool use_mlp_silu_act;
int64_t mlp_mult_factor;
public: public:
SingleStreamBlock(int64_t hidden_size, SingleStreamBlock(int64_t hidden_size,
@ -370,21 +390,28 @@ namespace Flux {
float mlp_ratio = 4.0f, float mlp_ratio = 4.0f,
int idx = 0, int idx = 0,
float qk_scale = 0.f, float qk_scale = 0.f,
bool prune_mod = false) bool prune_mod = false,
: hidden_size(hidden_size), num_heads(num_heads), idx(idx), prune_mod(prune_mod) { bool share_modulation = false,
bool mlp_proj_bias = true,
bool use_mlp_silu_act = false)
: hidden_size(hidden_size), num_heads(num_heads), idx(idx), prune_mod(prune_mod), use_mlp_silu_act(use_mlp_silu_act) {
int64_t head_dim = hidden_size / num_heads; int64_t head_dim = hidden_size / num_heads;
float scale = qk_scale; float scale = qk_scale;
if (scale <= 0.f) { if (scale <= 0.f) {
scale = 1 / sqrt((float)head_dim); scale = 1 / sqrt((float)head_dim);
} }
mlp_hidden_dim = hidden_size * mlp_ratio; mlp_hidden_dim = hidden_size * mlp_ratio;
mlp_mult_factor = 1;
if (use_mlp_silu_act) {
mlp_mult_factor = 2;
}
blocks["linear1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim)); blocks["linear1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias));
blocks["linear2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size + mlp_hidden_dim, hidden_size)); blocks["linear2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size + mlp_hidden_dim, hidden_size, mlp_proj_bias));
blocks["norm"] = std::shared_ptr<GGMLBlock>(new QKNorm(head_dim)); blocks["norm"] = std::shared_ptr<GGMLBlock>(new QKNorm(head_dim));
blocks["pre_norm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false)); blocks["pre_norm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
// mlp_act is nn.GELU(approximate="tanh") // mlp_act is nn.GELU(approximate="tanh")
if (!prune_mod) { if (!prune_mod && !share_modulation) {
blocks["modulation"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, false)); blocks["modulation"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, false));
} }
} }
@ -398,7 +425,8 @@ namespace Flux {
struct ggml_tensor* x, struct ggml_tensor* x,
struct ggml_tensor* vec, struct ggml_tensor* vec,
struct ggml_tensor* pe, struct ggml_tensor* pe,
struct ggml_tensor* mask = nullptr) { struct ggml_tensor* mask = nullptr,
std::vector<ModulationOut> mods = {}) {
// x: [N, n_token, hidden_size] // x: [N, n_token, hidden_size]
// pe: [n_token, d_head/2, 2, 2] // pe: [n_token, d_head/2, 2, 2]
// return: [N, n_token, hidden_size] // return: [N, n_token, hidden_size]
@ -407,7 +435,11 @@ namespace Flux {
auto linear2 = std::dynamic_pointer_cast<Linear>(blocks["linear2"]); auto linear2 = std::dynamic_pointer_cast<Linear>(blocks["linear2"]);
auto norm = std::dynamic_pointer_cast<QKNorm>(blocks["norm"]); auto norm = std::dynamic_pointer_cast<QKNorm>(blocks["norm"]);
auto pre_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["pre_norm"]); auto pre_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["pre_norm"]);
ModulationOut mod; ModulationOut mod;
if (!mods.empty()) {
mod = mods[0];
} else {
if (prune_mod) { if (prune_mod) {
mod = get_distil_mod(ctx, vec); mod = get_distil_mod(ctx, vec);
} else { } else {
@ -415,6 +447,8 @@ namespace Flux {
mod = modulation->forward(ctx, vec)[0]; mod = modulation->forward(ctx, vec)[0];
} }
}
auto x_mod = Flux::modulate(ctx->ggml_ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale); auto x_mod = Flux::modulate(ctx->ggml_ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale);
auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim] auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim]
qkv_mlp = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, qkv_mlp, 2, 0, 1, 3)); // [hidden_size * 3 + mlp_hidden_dim, N, n_token] qkv_mlp = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, qkv_mlp, 2, 0, 1, 3)); // [hidden_size * 3 + mlp_hidden_dim, N, n_token]
@ -432,11 +466,11 @@ namespace Flux {
qkv_mlp, qkv_mlp,
qkv_mlp->ne[0], qkv_mlp->ne[0],
qkv_mlp->ne[1], qkv_mlp->ne[1],
mlp_hidden_dim, mlp_hidden_dim * mlp_mult_factor,
qkv_mlp->nb[1], qkv_mlp->nb[1],
qkv_mlp->nb[2], qkv_mlp->nb[2],
qkv_mlp->nb[2] * hidden_size * 3); // [mlp_hidden_dim , N, n_token] qkv_mlp->nb[2] * hidden_size * 3); // [mlp_hidden_dim*mlp_mult_factor , N, n_token]
mlp = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, mlp, 1, 2, 0, 3)); // [N, n_token, mlp_hidden_dim] mlp = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, mlp, 1, 2, 0, 3)); // [N, n_token, mlp_hidden_dim*mlp_mult_factor]
auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv); // q,k,v: [N, n_token, hidden_size] auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv); // q,k,v: [N, n_token, hidden_size]
int64_t head_dim = hidden_size / num_heads; int64_t head_dim = hidden_size / num_heads;
@ -447,7 +481,12 @@ namespace Flux {
k = norm->key_norm(ctx, k); k = norm->key_norm(ctx, k);
auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_token, hidden_size] auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_token, hidden_size]
auto attn_mlp = ggml_concat(ctx->ggml_ctx, attn, ggml_gelu_inplace(ctx->ggml_ctx, mlp), 0); // [N, n_token, hidden_size + mlp_hidden_dim] if (use_mlp_silu_act) {
mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp);
} else {
mlp = ggml_gelu_inplace(ctx->ggml_ctx, mlp);
}
auto attn_mlp = ggml_concat(ctx->ggml_ctx, attn, mlp, 0); // [N, n_token, hidden_size + mlp_hidden_dim]
auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size] auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size]
output = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, output, mod.gate)); output = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, output, mod.gate));
@ -462,12 +501,13 @@ namespace Flux {
LastLayer(int64_t hidden_size, LastLayer(int64_t hidden_size,
int64_t patch_size, int64_t patch_size,
int64_t out_channels, int64_t out_channels,
bool prune_mod = false) bool prune_mod = false,
bool bias = true)
: prune_mod(prune_mod) { : prune_mod(prune_mod) {
blocks["norm_final"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false)); blocks["norm_final"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
blocks["linear"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, patch_size * patch_size * out_channels)); blocks["linear"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, patch_size * patch_size * out_channels, bias));
if (!prune_mod) { if (!prune_mod) {
blocks["adaLN_modulation.1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, 2 * hidden_size)); blocks["adaLN_modulation.1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, 2 * hidden_size, bias));
} }
} }
@ -684,6 +724,10 @@ namespace Flux {
bool qkv_bias = true; bool qkv_bias = true;
bool guidance_embed = true; bool guidance_embed = true;
int64_t in_dim = 64; int64_t in_dim = 64;
bool disable_bias = false;
bool share_modulation = false;
bool use_mlp_silu_act = false;
float ref_index_scale = 1.f;
ChromaRadianceParams chroma_radiance_params; ChromaRadianceParams chroma_radiance_params;
}; };
@ -702,18 +746,20 @@ namespace Flux {
kernel_size, kernel_size,
stride); stride);
} else { } else {
blocks["img_in"] = std::make_shared<Linear>(params.in_channels, params.hidden_size, true); blocks["img_in"] = std::make_shared<Linear>(params.in_channels, params.hidden_size, !params.disable_bias);
} }
if (params.is_chroma) { if (params.is_chroma) {
blocks["distilled_guidance_layer"] = std::make_shared<ChromaApproximator>(params.in_dim, params.hidden_size); blocks["distilled_guidance_layer"] = std::make_shared<ChromaApproximator>(params.in_dim, params.hidden_size);
} else { } else {
blocks["time_in"] = std::make_shared<MLPEmbedder>(256, params.hidden_size); blocks["time_in"] = std::make_shared<MLPEmbedder>(256, params.hidden_size, !params.disable_bias);
blocks["vector_in"] = std::make_shared<MLPEmbedder>(params.vec_in_dim, params.hidden_size); if (params.vec_in_dim > 0) {
blocks["vector_in"] = std::make_shared<MLPEmbedder>(params.vec_in_dim, params.hidden_size, !params.disable_bias);
}
if (params.guidance_embed) { if (params.guidance_embed) {
blocks["guidance_in"] = std::make_shared<MLPEmbedder>(256, params.hidden_size); blocks["guidance_in"] = std::make_shared<MLPEmbedder>(256, params.hidden_size, !params.disable_bias);
} }
} }
blocks["txt_in"] = std::make_shared<Linear>(params.context_in_dim, params.hidden_size, true); blocks["txt_in"] = std::make_shared<Linear>(params.context_in_dim, params.hidden_size, !params.disable_bias);
for (int i = 0; i < params.depth; i++) { for (int i = 0; i < params.depth; i++) {
blocks["double_blocks." + std::to_string(i)] = std::make_shared<DoubleStreamBlock>(params.hidden_size, blocks["double_blocks." + std::to_string(i)] = std::make_shared<DoubleStreamBlock>(params.hidden_size,
@ -721,7 +767,10 @@ namespace Flux {
params.mlp_ratio, params.mlp_ratio,
i, i,
params.qkv_bias, params.qkv_bias,
params.is_chroma); params.is_chroma,
params.share_modulation,
!params.disable_bias,
params.use_mlp_silu_act);
} }
for (int i = 0; i < params.depth_single_blocks; i++) { for (int i = 0; i < params.depth_single_blocks; i++) {
@ -730,7 +779,10 @@ namespace Flux {
params.mlp_ratio, params.mlp_ratio,
i, i,
0.f, 0.f,
params.is_chroma); params.is_chroma,
params.share_modulation,
!params.disable_bias,
params.use_mlp_silu_act);
} }
if (params.version == VERSION_CHROMA_RADIANCE) { if (params.version == VERSION_CHROMA_RADIANCE) {
@ -748,7 +800,13 @@ namespace Flux {
params.in_channels); params.in_channels);
} else { } else {
blocks["final_layer"] = std::make_shared<LastLayer>(params.hidden_size, 1, params.out_channels, params.is_chroma); blocks["final_layer"] = std::make_shared<LastLayer>(params.hidden_size, 1, params.out_channels, params.is_chroma, !params.disable_bias);
}
if (params.share_modulation) {
blocks["double_stream_modulation_img"] = std::make_shared<Modulation>(params.hidden_size, true, !params.disable_bias);
blocks["double_stream_modulation_txt"] = std::make_shared<Modulation>(params.hidden_size, true, !params.disable_bias);
blocks["single_stream_modulation"] = std::make_shared<Modulation>(params.hidden_size, false, !params.disable_bias);
} }
} }
@ -862,7 +920,6 @@ namespace Flux {
} }
} else { } else {
auto time_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["time_in"]); auto time_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["time_in"]);
auto vector_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["vector_in"]);
vec = time_in->forward(ctx, ggml_ext_timestep_embedding(ctx->ggml_ctx, timesteps, 256, 10000, 1000.f)); vec = time_in->forward(ctx, ggml_ext_timestep_embedding(ctx->ggml_ctx, timesteps, 256, 10000, 1000.f));
if (params.guidance_embed) { if (params.guidance_embed) {
GGML_ASSERT(guidance != nullptr); GGML_ASSERT(guidance != nullptr);
@ -872,8 +929,24 @@ namespace Flux {
vec = ggml_add(ctx->ggml_ctx, vec, guidance_in->forward(ctx, g_in)); vec = ggml_add(ctx->ggml_ctx, vec, guidance_in->forward(ctx, g_in));
} }
if (params.vec_in_dim > 0) {
auto vector_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["vector_in"]);
vec = ggml_add(ctx->ggml_ctx, vec, vector_in->forward(ctx, y)); vec = ggml_add(ctx->ggml_ctx, vec, vector_in->forward(ctx, y));
} }
}
std::vector<ModulationOut> ds_img_mods;
std::vector<ModulationOut> ds_txt_mods;
std::vector<ModulationOut> ss_mods;
if (params.share_modulation) {
auto double_stream_modulation_img = std::dynamic_pointer_cast<Modulation>(blocks["double_stream_modulation_img"]);
auto double_stream_modulation_txt = std::dynamic_pointer_cast<Modulation>(blocks["double_stream_modulation_txt"]);
auto single_stream_modulation = std::dynamic_pointer_cast<Modulation>(blocks["single_stream_modulation"]);
ds_img_mods = double_stream_modulation_img->forward(ctx, vec);
ds_txt_mods = double_stream_modulation_txt->forward(ctx, vec);
ss_mods = single_stream_modulation->forward(ctx, vec);
}
txt = txt_in->forward(ctx, txt); txt = txt_in->forward(ctx, txt);
@ -884,7 +957,7 @@ namespace Flux {
auto block = std::dynamic_pointer_cast<DoubleStreamBlock>(blocks["double_blocks." + std::to_string(i)]); auto block = std::dynamic_pointer_cast<DoubleStreamBlock>(blocks["double_blocks." + std::to_string(i)]);
auto img_txt = block->forward(ctx, img, txt, vec, pe, txt_img_mask); auto img_txt = block->forward(ctx, img, txt, vec, pe, txt_img_mask, ds_img_mods, ds_txt_mods);
img = img_txt.first; // [N, n_img_token, hidden_size] img = img_txt.first; // [N, n_img_token, hidden_size]
txt = img_txt.second; // [N, n_txt_token, hidden_size] txt = img_txt.second; // [N, n_txt_token, hidden_size]
} }
@ -896,7 +969,7 @@ namespace Flux {
} }
auto block = std::dynamic_pointer_cast<SingleStreamBlock>(blocks["single_blocks." + std::to_string(i)]); auto block = std::dynamic_pointer_cast<SingleStreamBlock>(blocks["single_blocks." + std::to_string(i)]);
txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask); txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask, ss_mods);
} }
txt_img = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_img, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] txt_img = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_img, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
@ -1133,6 +1206,22 @@ namespace Flux {
} else if (version == VERSION_CHROMA_RADIANCE) { } else if (version == VERSION_CHROMA_RADIANCE) {
flux_params.in_channels = 3; flux_params.in_channels = 3;
flux_params.patch_size = 16; flux_params.patch_size = 16;
} else if (sd_version_is_flux2(version)) {
flux_params.context_in_dim = 15360;
flux_params.in_channels = 128;
flux_params.hidden_size = 6144;
flux_params.num_heads = 48;
flux_params.patch_size = 1;
flux_params.out_channels = 128;
flux_params.mlp_ratio = 3.f;
flux_params.theta = 2000;
flux_params.axes_dim = {32, 32, 32, 32};
flux_params.vec_in_dim = 0;
flux_params.qkv_bias = false;
flux_params.disable_bias = true;
flux_params.share_modulation = true;
flux_params.ref_index_scale = 10.f;
flux_params.use_mlp_silu_act = true;
} }
for (auto pair : tensor_storage_map) { for (auto pair : tensor_storage_map) {
std::string tensor_name = pair.first; std::string tensor_name = pair.first;
@ -1281,7 +1370,8 @@ namespace Flux {
x->ne[3], x->ne[3],
context->ne[1], context->ne[1],
ref_latents, ref_latents,
increase_ref_index, sd_version_is_flux2(version) ? true : increase_ref_index,
flux_params.ref_index_scale,
flux_params.theta, flux_params.theta,
flux_params.axes_dim); flux_params.axes_dim);
int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2; int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;
@ -1360,9 +1450,9 @@ namespace Flux {
// cpu f16: // cpu f16:
// cuda f16: nan // cuda f16: nan
// cuda q8_0: pass // cuda q8_0: pass
// auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 16, 16, 16, 1); auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 16, 16, 128, 1);
// ggml_set_f32(x, 0.01f); // ggml_set_f32(x, 0.01f);
auto x = load_tensor_from_file(work_ctx, "chroma_x.bin"); // auto x = load_tensor_from_file(work_ctx, "chroma_x.bin");
// print_ggml_tensor(x); // print_ggml_tensor(x);
std::vector<float> timesteps_vec(1, 1.f); std::vector<float> timesteps_vec(1, 1.f);
@ -1371,9 +1461,9 @@ namespace Flux {
std::vector<float> guidance_vec(1, 0.f); std::vector<float> guidance_vec(1, 0.f);
auto guidance = vector_to_ggml_tensor(work_ctx, guidance_vec); auto guidance = vector_to_ggml_tensor(work_ctx, guidance_vec);
// auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 4096, 256, 1); auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 15360, 256, 1);
// ggml_set_f32(context, 0.01f); // ggml_set_f32(context, 0.01f);
auto context = load_tensor_from_file(work_ctx, "chroma_context.bin"); // auto context = load_tensor_from_file(work_ctx, "chroma_context.bin");
// print_ggml_tensor(context); // print_ggml_tensor(context);
// auto y = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 768, 1); // auto y = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 768, 1);
@ -1395,7 +1485,7 @@ namespace Flux {
static void load_from_file_and_test(const std::string& file_path) { static void load_from_file_and_test(const std::string& file_path) {
// ggml_backend_t backend = ggml_backend_cuda_init(0); // ggml_backend_t backend = ggml_backend_cuda_init(0);
ggml_backend_t backend = ggml_backend_cpu_init(); ggml_backend_t backend = ggml_backend_cpu_init();
ggml_type model_data_type = GGML_TYPE_Q8_0; ggml_type model_data_type = GGML_TYPE_COUNT;
ModelLoader model_loader; ModelLoader model_loader;
if (!model_loader.init_from_file_and_convert_name(file_path, "model.diffusion_model.")) { if (!model_loader.init_from_file_and_convert_name(file_path, "model.diffusion_model.")) {
@ -1404,17 +1494,19 @@ namespace Flux {
} }
auto& tensor_storage_map = model_loader.get_tensor_storage_map(); auto& tensor_storage_map = model_loader.get_tensor_storage_map();
if (model_data_type != GGML_TYPE_COUNT) {
for (auto& [name, tensor_storage] : tensor_storage_map) { for (auto& [name, tensor_storage] : tensor_storage_map) {
if (ends_with(name, "weight")) { if (ends_with(name, "weight")) {
tensor_storage.expected_type = model_data_type; tensor_storage.expected_type = model_data_type;
} }
} }
}
std::shared_ptr<FluxRunner> flux = std::make_shared<FluxRunner>(backend, std::shared_ptr<FluxRunner> flux = std::make_shared<FluxRunner>(backend,
false, false,
tensor_storage_map, tensor_storage_map,
"model.diffusion_model", "model.diffusion_model",
VERSION_CHROMA_RADIANCE, VERSION_FLUX2,
false); false);
flux->alloc_params_buffer(); flux->alloc_params_buffer();

View File

@ -760,6 +760,21 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_ext_chunk(struct ggml_co
return chunks; return chunks;
} }
// SiLU-gated activation: chunk x into two halves along dim 0 and return
// silu(first_half) * second_half.
//
// x:      [ne3, ne2, ne1, ne0]
// return: [ne3, ne2, ne1, ne0/2]
__STATIC_INLINE__ ggml_tensor* ggml_ext_silu_act(ggml_context* ctx, ggml_tensor* x) {
    auto halves = ggml_ext_chunk(ctx, x, 2, 0);
    auto gate   = halves[0];  // [ne3, ne2, ne1, ne0/2]
    auto value  = halves[1];  // [ne3, ne2, ne1, ne0/2]

    // The gate goes through SiLU, as the function name (and the
    // use_mlp_silu_act switch that selects it) promises; GELU here would
    // silently compute a different activation than the weights expect.
    gate = ggml_silu_inplace(ctx, gate);

    return ggml_mul(ctx, gate, value);  // [ne3, ne2, ne1, ne0/2]
}
typedef std::function<void(ggml_tensor*, ggml_tensor*, bool)> on_tile_process; typedef std::function<void(ggml_tensor*, ggml_tensor*, bool)> on_tile_process;
__STATIC_INLINE__ void sd_tiling_calc_tiles(int& num_tiles_dim, __STATIC_INLINE__ void sd_tiling_calc_tiles(int& num_tiles_dim,

File diff suppressed because it is too large Load Diff

View File

@ -17,6 +17,7 @@
#include "stable-diffusion.h" #include "stable-diffusion.h"
#include "util.h" #include "util.h"
#include "vocab.hpp" #include "vocab.hpp"
#include "vocab_mistral.hpp"
#include "vocab_qwen.hpp" #include "vocab_qwen.hpp"
#include "vocab_umt5.hpp" #include "vocab_umt5.hpp"
@ -104,8 +105,9 @@ const char* unused_tensors[] = {
"denoiser.sigmas", "denoiser.sigmas",
"edm_vpred.sigma_max", "edm_vpred.sigma_max",
"text_encoders.t5xxl.transformer.encoder.embed_tokens.weight", // only used during training "text_encoders.t5xxl.transformer.encoder.embed_tokens.weight", // only used during training
"text_encoders.qwen2vl.output.weight", "text_encoders.llm.output.weight",
"text_encoders.qwen2vl.lm_head.", "text_encoders.llm.lm_head.",
"first_stage_model.bn.",
}; };
bool is_unused_tensor(std::string name) { bool is_unused_tensor(std::string name) {
@ -1062,6 +1064,9 @@ SDVersion ModelLoader::get_sd_version() {
if (tensor_storage.name.find("model.diffusion_model.transformer_blocks.0.img_mod.1.weight") != std::string::npos) { if (tensor_storage.name.find("model.diffusion_model.transformer_blocks.0.img_mod.1.weight") != std::string::npos) {
return VERSION_QWEN_IMAGE; return VERSION_QWEN_IMAGE;
} }
if (tensor_storage.name.find("model.diffusion_model.double_stream_modulation_img.lin.weight") != std::string::npos) {
return VERSION_FLUX2;
}
if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) { if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) {
is_wan = true; is_wan = true;
} }
@ -1320,6 +1325,16 @@ std::string ModelLoader::load_qwen2_merges() {
return merges_utf8_str; return merges_utf8_str;
} }
// Returns the embedded Mistral BPE merges blob as a std::string.
// The string length is the full sizeof() of the embedded array.
std::string ModelLoader::load_mistral_merges() {
    const char* blob_begin      = reinterpret_cast<const char*>(mistral_merges_utf8_c_str);
    const std::size_t blob_size = sizeof(mistral_merges_utf8_c_str);
    return std::string(blob_begin, blob_size);
}
// Returns the embedded Mistral vocabulary JSON blob as a std::string.
// The string length is the full sizeof() of the embedded array.
std::string ModelLoader::load_mistral_vocab_json() {
    const char* blob_begin      = reinterpret_cast<const char*>(mistral_vocab_json_utf8_c_str);
    const std::size_t blob_size = sizeof(mistral_vocab_json_utf8_c_str);
    return std::string(blob_begin, blob_size);
}
std::string ModelLoader::load_t5_tokenizer_json() { std::string ModelLoader::load_t5_tokenizer_json() {
std::string json_str(reinterpret_cast<const char*>(t5_tokenizer_json_str), sizeof(t5_tokenizer_json_str)); std::string json_str(reinterpret_cast<const char*>(t5_tokenizer_json_str), sizeof(t5_tokenizer_json_str));
return json_str; return json_str;

11
model.h
View File

@ -43,6 +43,7 @@ enum SDVersion {
VERSION_WAN2_2_I2V, VERSION_WAN2_2_I2V,
VERSION_WAN2_2_TI2V, VERSION_WAN2_2_TI2V,
VERSION_QWEN_IMAGE, VERSION_QWEN_IMAGE,
VERSION_FLUX2,
VERSION_COUNT, VERSION_COUNT,
}; };
@ -94,6 +95,13 @@ static inline bool sd_version_is_flux(SDVersion version) {
return false; return false;
} }
// True only when `version` is VERSION_FLUX2.
static inline bool sd_version_is_flux2(SDVersion version) {
    return version == VERSION_FLUX2;
}
static inline bool sd_version_is_wan(SDVersion version) { static inline bool sd_version_is_wan(SDVersion version) {
if (version == VERSION_WAN2 || version == VERSION_WAN2_2_I2V || version == VERSION_WAN2_2_TI2V) { if (version == VERSION_WAN2 || version == VERSION_WAN2_2_I2V || version == VERSION_WAN2_2_TI2V) {
return true; return true;
@ -121,6 +129,7 @@ static inline bool sd_version_is_inpaint(SDVersion version) {
static inline bool sd_version_is_dit(SDVersion version) { static inline bool sd_version_is_dit(SDVersion version) {
if (sd_version_is_flux(version) || if (sd_version_is_flux(version) ||
sd_version_is_flux2(version) ||
sd_version_is_sd3(version) || sd_version_is_sd3(version) ||
sd_version_is_wan(version) || sd_version_is_wan(version) ||
sd_version_is_qwen_image(version)) { sd_version_is_qwen_image(version)) {
@ -313,6 +322,8 @@ public:
static std::string load_merges(); static std::string load_merges();
static std::string load_qwen2_merges(); static std::string load_qwen2_merges();
static std::string load_mistral_merges();
static std::string load_mistral_vocab_json();
static std::string load_t5_tokenizer_json(); static std::string load_t5_tokenizer_json();
static std::string load_umt5_tokenizer_json(); static std::string load_umt5_tokenizer_json();
}; };

View File

@ -127,7 +127,7 @@ std::string convert_cond_stage_model_name(std::string name, std::string prefix)
{"token_embd.", "shared."}, {"token_embd.", "shared."},
}; };
static const std::vector<std::pair<std::string, std::string>> qwenvl_name_map{ static const std::vector<std::pair<std::string, std::string>> llm_name_map{
{"token_embd.", "model.embed_tokens."}, {"token_embd.", "model.embed_tokens."},
{"blk.", "model.layers."}, {"blk.", "model.layers."},
{"attn_q.", "self_attn.q_proj."}, {"attn_q.", "self_attn.q_proj."},
@ -142,7 +142,7 @@ std::string convert_cond_stage_model_name(std::string name, std::string prefix)
{"output_norm.", "model.norm."}, {"output_norm.", "model.norm."},
}; };
static const std::vector<std::pair<std::string, std::string>> qwenvl_vision_name_map{ static const std::vector<std::pair<std::string, std::string>> llm_vision_name_map{
{"mm.", "merger.mlp."}, {"mm.", "merger.mlp."},
{"v.post_ln.", "merger.ln_q."}, {"v.post_ln.", "merger.ln_q."},
{"v.patch_embd.weight", "patch_embed.proj.0.weight"}, {"v.patch_embd.weight", "patch_embed.proj.0.weight"},
@ -161,11 +161,11 @@ std::string convert_cond_stage_model_name(std::string name, std::string prefix)
}; };
if (contains(name, "t5xxl")) { if (contains(name, "t5xxl")) {
replace_with_name_map(name, t5_name_map); replace_with_name_map(name, t5_name_map);
} else if (contains(name, "qwen2vl")) { } else if (contains(name, "llm")) {
if (contains(name, "qwen2vl.visual")) { if (contains(name, "llm.visual")) {
replace_with_name_map(name, qwenvl_vision_name_map); replace_with_name_map(name, llm_vision_name_map);
} else { } else {
replace_with_name_map(name, qwenvl_name_map); replace_with_name_map(name, llm_name_map);
} }
} else { } else {
name = convert_open_clip_to_hf_clip_name(name); name = convert_open_clip_to_hf_clip_name(name);
@ -620,7 +620,7 @@ std::string convert_diffusion_model_name(std::string name, std::string prefix, S
name = convert_diffusers_unet_to_original_sdxl(name); name = convert_diffusers_unet_to_original_sdxl(name);
} else if (sd_version_is_sd3(version)) { } else if (sd_version_is_sd3(version)) {
name = convert_diffusers_dit_to_original_sd3(name); name = convert_diffusers_dit_to_original_sd3(name);
} else if (sd_version_is_flux(version)) { } else if (sd_version_is_flux(version) || sd_version_is_flux2(version)) {
name = convert_diffusers_dit_to_original_flux(name); name = convert_diffusers_dit_to_original_flux(name);
} }
return name; return name;
@ -722,6 +722,11 @@ std::string convert_diffusers_vae_to_original_sd1(std::string name) {
} }
std::string convert_first_stage_model_name(std::string name, std::string prefix) { std::string convert_first_stage_model_name(std::string name, std::string prefix) {
static std::unordered_map<std::string, std::string> vae_name_map = {
{"decoder.post_quant_conv.", "post_quant_conv."},
{"encoder.quant_conv.", "quant_conv."},
};
replace_with_prefix_map(name, vae_name_map);
name = convert_diffusers_vae_to_original_sd1(name); name = convert_diffusers_vae_to_original_sd1(name);
return name; return name;
} }

View File

@ -72,15 +72,28 @@ namespace Rope {
} }
// Generate IDs for image patches and text // Generate IDs for image patches and text
__STATIC_INLINE__ std::vector<std::vector<float>> gen_txt_ids(int bs, int context_len) { __STATIC_INLINE__ std::vector<std::vector<float>> gen_flux_txt_ids(int bs, int context_len, int axes_dim_num) {
return std::vector<std::vector<float>>(bs * context_len, std::vector<float>(3, 0.0)); auto txt_ids = std::vector<std::vector<float>>(bs * context_len, std::vector<float>(axes_dim_num, 0.0f));
if (axes_dim_num == 4) {
for (int i = 0; i < bs * context_len; i++) {
txt_ids[i][3] = (i % context_len);
}
}
return txt_ids;
} }
__STATIC_INLINE__ std::vector<std::vector<float>> gen_img_ids(int h, int w, int patch_size, int bs, int index = 0, int h_offset = 0, int w_offset = 0) { __STATIC_INLINE__ std::vector<std::vector<float>> gen_flux_img_ids(int h,
int w,
int patch_size,
int bs,
int axes_dim_num,
int index = 0,
int h_offset = 0,
int w_offset = 0) {
int h_len = (h + (patch_size / 2)) / patch_size; int h_len = (h + (patch_size / 2)) / patch_size;
int w_len = (w + (patch_size / 2)) / patch_size; int w_len = (w + (patch_size / 2)) / patch_size;
std::vector<std::vector<float>> img_ids(h_len * w_len, std::vector<float>(3, 0.0)); std::vector<std::vector<float>> img_ids(h_len * w_len, std::vector<float>(axes_dim_num, 0.0));
std::vector<float> row_ids = linspace<float>(h_offset, h_len - 1 + h_offset, h_len); std::vector<float> row_ids = linspace<float>(h_offset, h_len - 1 + h_offset, h_len);
std::vector<float> col_ids = linspace<float>(w_offset, w_len - 1 + w_offset, w_len); std::vector<float> col_ids = linspace<float>(w_offset, w_len - 1 + w_offset, w_len);
@ -153,8 +166,10 @@ namespace Rope {
__STATIC_INLINE__ std::vector<std::vector<float>> gen_refs_ids(int patch_size, __STATIC_INLINE__ std::vector<std::vector<float>> gen_refs_ids(int patch_size,
int bs, int bs,
int axes_dim_num,
const std::vector<ggml_tensor*>& ref_latents, const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index) { bool increase_ref_index,
float ref_index_scale) {
std::vector<std::vector<float>> ids; std::vector<std::vector<float>> ids;
uint64_t curr_h_offset = 0; uint64_t curr_h_offset = 0;
uint64_t curr_w_offset = 0; uint64_t curr_w_offset = 0;
@ -170,7 +185,14 @@ namespace Rope {
} }
} }
auto ref_ids = gen_img_ids(ref->ne[1], ref->ne[0], patch_size, bs, index, h_offset, w_offset); auto ref_ids = gen_flux_img_ids(ref->ne[1],
ref->ne[0],
patch_size,
bs,
axes_dim_num,
static_cast<int>(index * ref_index_scale),
h_offset,
w_offset);
ids = concat_ids(ids, ref_ids, bs); ids = concat_ids(ids, ref_ids, bs);
if (increase_ref_index) { if (increase_ref_index) {
@ -187,15 +209,17 @@ namespace Rope {
int w, int w,
int patch_size, int patch_size,
int bs, int bs,
int axes_dim_num,
int context_len, int context_len,
const std::vector<ggml_tensor*>& ref_latents, const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index) { bool increase_ref_index,
auto txt_ids = gen_txt_ids(bs, context_len); float ref_index_scale) {
auto img_ids = gen_img_ids(h, w, patch_size, bs); auto txt_ids = gen_flux_txt_ids(bs, context_len, axes_dim_num);
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num);
auto ids = concat_ids(txt_ids, img_ids, bs); auto ids = concat_ids(txt_ids, img_ids, bs);
if (ref_latents.size() > 0) { if (ref_latents.size() > 0) {
auto refs_ids = gen_refs_ids(patch_size, bs, ref_latents, increase_ref_index); auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, ref_index_scale);
ids = concat_ids(ids, refs_ids, bs); ids = concat_ids(ids, refs_ids, bs);
} }
return ids; return ids;
@ -209,9 +233,18 @@ namespace Rope {
int context_len, int context_len,
const std::vector<ggml_tensor*>& ref_latents, const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index, bool increase_ref_index,
float ref_index_scale,
int theta, int theta,
const std::vector<int>& axes_dim) { const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_flux_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index); std::vector<std::vector<float>> ids = gen_flux_ids(h,
w,
patch_size,
bs,
static_cast<int>(axes_dim.size()),
context_len,
ref_latents,
increase_ref_index,
ref_index_scale);
return embed_nd(ids, bs, theta, axes_dim); return embed_nd(ids, bs, theta, axes_dim);
} }
@ -232,10 +265,11 @@ namespace Rope {
txt_ids_repeated[i * txt_ids.size() + j] = {txt_ids[j], txt_ids[j], txt_ids[j]}; txt_ids_repeated[i * txt_ids.size() + j] = {txt_ids[j], txt_ids[j], txt_ids[j]};
} }
} }
auto img_ids = gen_img_ids(h, w, patch_size, bs); int axes_dim_num = 3;
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num);
auto ids = concat_ids(txt_ids_repeated, img_ids, bs); auto ids = concat_ids(txt_ids_repeated, img_ids, bs);
if (ref_latents.size() > 0) { if (ref_latents.size() > 0) {
auto refs_ids = gen_refs_ids(patch_size, bs, ref_latents, increase_ref_index); auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, 1.f);
ids = concat_ids(ids, refs_ids, bs); ids = concat_ids(ids, refs_ids, bs);
} }
return ids; return ids;

View File

@ -44,6 +44,7 @@ const char* model_version_to_str[] = {
"Wan 2.2 I2V", "Wan 2.2 I2V",
"Wan 2.2 TI2V", "Wan 2.2 TI2V",
"Qwen Image", "Qwen Image",
"Flux.2",
}; };
const char* sampling_methods_str[] = { const char* sampling_methods_str[] = {
@ -275,17 +276,17 @@ public:
} }
} }
if (strlen(SAFE_STR(sd_ctx_params->qwen2vl_path)) > 0) { if (strlen(SAFE_STR(sd_ctx_params->llm_path)) > 0) {
LOG_INFO("loading qwen2vl from '%s'", sd_ctx_params->qwen2vl_path); LOG_INFO("loading llm from '%s'", sd_ctx_params->llm_path);
if (!model_loader.init_from_file(sd_ctx_params->qwen2vl_path, "text_encoders.qwen2vl.")) { if (!model_loader.init_from_file(sd_ctx_params->llm_path, "text_encoders.llm.")) {
LOG_WARN("loading qwen2vl from '%s' failed", sd_ctx_params->qwen2vl_path); LOG_WARN("loading llm from '%s' failed", sd_ctx_params->llm_path);
} }
} }
if (strlen(SAFE_STR(sd_ctx_params->qwen2vl_vision_path)) > 0) { if (strlen(SAFE_STR(sd_ctx_params->llm_vision_path)) > 0) {
LOG_INFO("loading qwen2vl vision from '%s'", sd_ctx_params->qwen2vl_vision_path); LOG_INFO("loading llm vision from '%s'", sd_ctx_params->llm_vision_path);
if (!model_loader.init_from_file(sd_ctx_params->qwen2vl_vision_path, "text_encoders.qwen2vl.visual.")) { if (!model_loader.init_from_file(sd_ctx_params->llm_vision_path, "text_encoders.llm.visual.")) {
LOG_WARN("loading qwen2vl vision from '%s' failed", sd_ctx_params->qwen2vl_vision_path); LOG_WARN("loading llm vision from '%s' failed", sd_ctx_params->llm_vision_path);
} }
} }
@ -306,7 +307,7 @@ public:
auto& tensor_storage_map = model_loader.get_tensor_storage_map(); auto& tensor_storage_map = model_loader.get_tensor_storage_map();
for (auto& [name, tensor_storage] : tensor_storage_map) { for (auto& [name, tensor_storage] : tensor_storage_map) {
if (contains(name, "qwen2vl") && if (contains(name, "llm") &&
ends_with(name, "weight") && ends_with(name, "weight") &&
(tensor_storage.type == GGML_TYPE_F32 || tensor_storage.type == GGML_TYPE_BF16)) { (tensor_storage.type == GGML_TYPE_F32 || tensor_storage.type == GGML_TYPE_BF16)) {
tensor_storage.expected_type = GGML_TYPE_F16; tensor_storage.expected_type = GGML_TYPE_F16;
@ -379,8 +380,11 @@ public:
} else if (sd_version_is_flux(version)) { } else if (sd_version_is_flux(version)) {
scale_factor = 0.3611f; scale_factor = 0.3611f;
shift_factor = 0.1159f; shift_factor = 0.1159f;
} else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) { } else if (sd_version_is_wan(version) ||
sd_version_is_qwen_image(version) ||
sd_version_is_flux2(version)) {
scale_factor = 1.0f; scale_factor = 1.0f;
shift_factor = 0.f;
} }
if (sd_version_is_control(version)) { if (sd_version_is_control(version)) {
@ -436,6 +440,17 @@ public:
tensor_storage_map, tensor_storage_map,
version, version,
sd_ctx_params->chroma_use_dit_mask); sd_ctx_params->chroma_use_dit_mask);
} else if (sd_version_is_flux2(version)) {
// FLUX.2 pairs the LLM text embedder with the Flux diffusion model.
// (The former local `is_chroma` was never read in this branch and has been dropped.)
cond_stage_model = std::make_shared<LLMEmbedder>(clip_backend,
                                                 offload_params_to_cpu,
                                                 tensor_storage_map,
                                                 version);
diffusion_model = std::make_shared<FluxModel>(backend,
                                              offload_params_to_cpu,
                                              tensor_storage_map,
                                              version,
                                              sd_ctx_params->chroma_use_dit_mask);
} else if (sd_version_is_wan(version)) { } else if (sd_version_is_wan(version)) {
cond_stage_model = std::make_shared<T5CLIPEmbedder>(clip_backend, cond_stage_model = std::make_shared<T5CLIPEmbedder>(clip_backend,
offload_params_to_cpu, offload_params_to_cpu,
@ -469,9 +484,10 @@ public:
if (!vae_decode_only) { if (!vae_decode_only) {
enable_vision = true; enable_vision = true;
} }
cond_stage_model = std::make_shared<Qwen2_5_VLCLIPEmbedder>(clip_backend, cond_stage_model = std::make_shared<LLMEmbedder>(clip_backend,
offload_params_to_cpu, offload_params_to_cpu,
tensor_storage_map, tensor_storage_map,
version,
"", "",
enable_vision); enable_vision);
diffusion_model = std::make_shared<QwenImageModel>(backend, diffusion_model = std::make_shared<QwenImageModel>(backend,
@ -668,7 +684,7 @@ public:
ignore_tensors.insert("first_stage_model.encoder"); ignore_tensors.insert("first_stage_model.encoder");
ignore_tensors.insert("first_stage_model.conv1"); ignore_tensors.insert("first_stage_model.conv1");
ignore_tensors.insert("first_stage_model.quant"); ignore_tensors.insert("first_stage_model.quant");
ignore_tensors.insert("text_encoders.qwen2vl.visual."); ignore_tensors.insert("text_encoders.llm.visual.");
} }
if (version == VERSION_SVD) { if (version == VERSION_SVD) {
ignore_tensors.insert("conditioner.embedders.3"); ignore_tensors.insert("conditioner.embedders.3");
@ -786,6 +802,11 @@ public:
denoiser = std::make_shared<FluxFlowDenoiser>(shift); denoiser = std::make_shared<FluxFlowDenoiser>(shift);
break; break;
} }
case FLUX2_FLOW_PRED: {
LOG_INFO("running in Flux2 FLOW mode");
denoiser = std::make_shared<Flux2FlowDenoiser>();
break;
}
default: { default: {
LOG_ERROR("Unknown parametrization %i", sd_ctx_params->prediction); LOG_ERROR("Unknown parametrization %i", sd_ctx_params->prediction);
return false; return false;
@ -830,6 +851,9 @@ public:
} }
} }
denoiser = std::make_shared<FluxFlowDenoiser>(shift); denoiser = std::make_shared<FluxFlowDenoiser>(shift);
} else if (sd_version_is_flux2(version)) {
LOG_INFO("running in Flux2 FLOW mode");
denoiser = std::make_shared<Flux2FlowDenoiser>();
} else if (sd_version_is_wan(version)) { } else if (sd_version_is_wan(version)) {
LOG_INFO("running in FLOW mode"); LOG_INFO("running in FLOW mode");
float shift = sd_ctx_params->flow_shift; float shift = sd_ctx_params->flow_shift;
@ -1826,6 +1850,8 @@ public:
int vae_scale_factor = 8; int vae_scale_factor = 8;
if (version == VERSION_WAN2_2_TI2V) { if (version == VERSION_WAN2_2_TI2V) {
vae_scale_factor = 16; vae_scale_factor = 16;
} else if (sd_version_is_flux2(version)) {
vae_scale_factor = 16;
} else if (version == VERSION_CHROMA_RADIANCE) { } else if (version == VERSION_CHROMA_RADIANCE) {
vae_scale_factor = 1; vae_scale_factor = 1;
} }
@ -1839,6 +1865,8 @@ public:
latent_channel = 48; latent_channel = 48;
} else if (version == VERSION_CHROMA_RADIANCE) { } else if (version == VERSION_CHROMA_RADIANCE) {
latent_channel = 3; latent_channel = 3;
} else if (sd_version_is_flux2(version)) {
latent_channel = 128;
} else { } else {
latent_channel = 16; latent_channel = 16;
} }
@ -1846,6 +1874,11 @@ public:
return latent_channel; return latent_channel;
} }
int get_image_seq_len(int h, int w) {
int vae_scale_factor = get_vae_scale_factor();
return (h / vae_scale_factor) * (w / vae_scale_factor);
}
ggml_tensor* generate_init_latent(ggml_context* work_ctx, ggml_tensor* generate_init_latent(ggml_context* work_ctx,
int width, int width,
int height, int height,
@ -1869,14 +1902,14 @@ public:
return init_latent; return init_latent;
} }
void process_latent_in(ggml_tensor* latent) { void get_latents_mean_std_vec(ggml_tensor* latent, int channel_dim, std::vector<float>& latents_mean_vec, std::vector<float>& latents_std_vec) {
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) { GGML_ASSERT(latent->ne[channel_dim] == 16 || latent->ne[channel_dim] == 48 || latent->ne[channel_dim] == 128);
GGML_ASSERT(latent->ne[3] == 16 || latent->ne[3] == 48); if (latent->ne[channel_dim] == 16) {
std::vector<float> latents_mean_vec = {-0.7571f, -0.7089f, -0.9113f, 0.1075f, -0.1745f, 0.9653f, -0.1517f, 1.5508f, latents_mean_vec = {-0.7571f, -0.7089f, -0.9113f, 0.1075f, -0.1745f, 0.9653f, -0.1517f, 1.5508f,
0.4134f, -0.0715f, 0.5517f, -0.3632f, -0.1922f, -0.9497f, 0.2503f, -0.2921f}; 0.4134f, -0.0715f, 0.5517f, -0.3632f, -0.1922f, -0.9497f, 0.2503f, -0.2921f};
std::vector<float> latents_std_vec = {2.8184f, 1.4541f, 2.3275f, 2.6558f, 1.2196f, 1.7708f, 2.6052f, 2.0743f, latents_std_vec = {2.8184f, 1.4541f, 2.3275f, 2.6558f, 1.2196f, 1.7708f, 2.6052f, 2.0743f,
3.2687f, 2.1526f, 2.8652f, 1.5579f, 1.6382f, 1.1253f, 2.8251f, 1.9160f}; 3.2687f, 2.1526f, 2.8652f, 1.5579f, 1.6382f, 1.1253f, 2.8251f, 1.9160f};
if (latent->ne[3] == 48) { } else if (latent->ne[channel_dim] == 48) {
latents_mean_vec = {-0.2289f, -0.0052f, -0.1323f, -0.2339f, -0.2799f, 0.0174f, 0.1838f, 0.1557f, latents_mean_vec = {-0.2289f, -0.0052f, -0.1323f, -0.2339f, -0.2799f, 0.0174f, 0.1838f, 0.1557f,
-0.1382f, 0.0542f, 0.2813f, 0.0891f, 0.1570f, -0.0098f, 0.0375f, -0.1825f, -0.1382f, 0.0542f, 0.2813f, 0.0891f, 0.1570f, -0.0098f, 0.0375f, -0.1825f,
-0.2246f, -0.1207f, -0.0698f, 0.5109f, 0.2665f, -0.2108f, -0.2158f, 0.2502f, -0.2246f, -0.1207f, -0.0698f, 0.5109f, 0.2665f, -0.2108f, -0.2158f, 0.2502f,
@ -1890,11 +1923,63 @@ public:
0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f, 0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f,
0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f, 0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f,
0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f}; 0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f};
} else if (latent->ne[channel_dim] == 128) {
// flux2
latents_mean_vec = {-0.0676f, -0.0715f, -0.0753f, -0.0745f, 0.0223f, 0.0180f, 0.0142f, 0.0184f,
-0.0001f, -0.0063f, -0.0002f, -0.0031f, -0.0272f, -0.0281f, -0.0276f, -0.0290f,
-0.0769f, -0.0672f, -0.0902f, -0.0892f, 0.0168f, 0.0152f, 0.0079f, 0.0086f,
0.0083f, 0.0015f, 0.0003f, -0.0043f, -0.0439f, -0.0419f, -0.0438f, -0.0431f,
-0.0102f, -0.0132f, -0.0066f, -0.0048f, -0.0311f, -0.0306f, -0.0279f, -0.0180f,
0.0030f, 0.0015f, 0.0126f, 0.0145f, 0.0347f, 0.0338f, 0.0337f, 0.0283f,
0.0020f, 0.0047f, 0.0047f, 0.0050f, 0.0123f, 0.0081f, 0.0081f, 0.0146f,
0.0681f, 0.0679f, 0.0767f, 0.0732f, -0.0462f, -0.0474f, -0.0392f, -0.0511f,
-0.0528f, -0.0477f, -0.0470f, -0.0517f, -0.0317f, -0.0316f, -0.0345f, -0.0283f,
0.0510f, 0.0445f, 0.0578f, 0.0458f, -0.0412f, -0.0458f, -0.0487f, -0.0467f,
-0.0088f, -0.0106f, -0.0088f, -0.0046f, -0.0376f, -0.0432f, -0.0436f, -0.0499f,
0.0118f, 0.0166f, 0.0203f, 0.0279f, 0.0113f, 0.0129f, 0.0016f, 0.0072f,
-0.0118f, -0.0018f, -0.0141f, -0.0054f, -0.0091f, -0.0138f, -0.0145f, -0.0187f,
0.0323f, 0.0305f, 0.0259f, 0.0300f, 0.0540f, 0.0614f, 0.0495f, 0.0590f,
-0.0511f, -0.0603f, -0.0478f, -0.0524f, -0.0227f, -0.0274f, -0.0154f, -0.0255f,
-0.0572f, -0.0565f, -0.0518f, -0.0496f, 0.0116f, 0.0054f, 0.0163f, 0.0104f};
latents_std_vec = {
1.8029f, 1.7786f, 1.7868f, 1.7837f, 1.7717f, 1.7590f, 1.7610f, 1.7479f,
1.7336f, 1.7373f, 1.7340f, 1.7343f, 1.8626f, 1.8527f, 1.8629f, 1.8589f,
1.7593f, 1.7526f, 1.7556f, 1.7583f, 1.7363f, 1.7400f, 1.7355f, 1.7394f,
1.7342f, 1.7246f, 1.7392f, 1.7304f, 1.7551f, 1.7513f, 1.7559f, 1.7488f,
1.8449f, 1.8454f, 1.8550f, 1.8535f, 1.8240f, 1.7813f, 1.7854f, 1.7945f,
1.8047f, 1.7876f, 1.7695f, 1.7676f, 1.7782f, 1.7667f, 1.7925f, 1.7848f,
1.7579f, 1.7407f, 1.7483f, 1.7368f, 1.7961f, 1.7998f, 1.7920f, 1.7925f,
1.7780f, 1.7747f, 1.7727f, 1.7749f, 1.7526f, 1.7447f, 1.7657f, 1.7495f,
1.7775f, 1.7720f, 1.7813f, 1.7813f, 1.8162f, 1.8013f, 1.8023f, 1.8033f,
1.7527f, 1.7331f, 1.7563f, 1.7482f, 1.7610f, 1.7507f, 1.7681f, 1.7613f,
1.7665f, 1.7545f, 1.7828f, 1.7726f, 1.7896f, 1.7999f, 1.7864f, 1.7760f,
1.7613f, 1.7625f, 1.7560f, 1.7577f, 1.7783f, 1.7671f, 1.7810f, 1.7799f,
1.7201f, 1.7068f, 1.7265f, 1.7091f, 1.7793f, 1.7578f, 1.7502f, 1.7455f,
1.7587f, 1.7500f, 1.7525f, 1.7362f, 1.7616f, 1.7572f, 1.7444f, 1.7430f,
1.7509f, 1.7610f, 1.7634f, 1.7612f, 1.7254f, 1.7135f, 1.7321f, 1.7226f,
1.7664f, 1.7624f, 1.7718f, 1.7664f, 1.7457f, 1.7441f, 1.7569f, 1.7530f};
} }
}
void process_latent_in(ggml_tensor* latent) {
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_flux2(version)) {
int channel_dim = sd_version_is_flux2(version) ? 2 : 3;
std::vector<float> latents_mean_vec;
std::vector<float> latents_std_vec;
get_latents_mean_std_vec(latent, channel_dim, latents_mean_vec, latents_std_vec);
float mean;
float std_;
for (int i = 0; i < latent->ne[3]; i++) { for (int i = 0; i < latent->ne[3]; i++) {
float mean = latents_mean_vec[i]; if (channel_dim == 3) {
float std_ = latents_std_vec[i]; mean = latents_mean_vec[i];
std_ = latents_std_vec[i];
}
for (int j = 0; j < latent->ne[2]; j++) { for (int j = 0; j < latent->ne[2]; j++) {
// channel_dim == 2: the channel axis is ne[2], which the enclosing loop walks
// with j (i walks ne[3]). The per-channel statistics therefore have to be
// selected with j; indexing with i applied a single entry to every channel.
if (channel_dim == 2) {
    mean = latents_mean_vec[j];
    std_ = latents_std_vec[j];
}
for (int k = 0; k < latent->ne[1]; k++) { for (int k = 0; k < latent->ne[1]; k++) {
for (int l = 0; l < latent->ne[0]; l++) { for (int l = 0; l < latent->ne[0]; l++) {
float value = ggml_ext_tensor_get_f32(latent, l, k, j, i); float value = ggml_ext_tensor_get_f32(latent, l, k, j, i);
@ -1916,31 +2001,24 @@ public:
} }
void process_latent_out(ggml_tensor* latent) { void process_latent_out(ggml_tensor* latent) {
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) { if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_flux2(version)) {
GGML_ASSERT(latent->ne[3] == 16 || latent->ne[3] == 48); int channel_dim = sd_version_is_flux2(version) ? 2 : 3;
std::vector<float> latents_mean_vec = {-0.7571f, -0.7089f, -0.9113f, 0.1075f, -0.1745f, 0.9653f, -0.1517f, 1.5508f, std::vector<float> latents_mean_vec;
0.4134f, -0.0715f, 0.5517f, -0.3632f, -0.1922f, -0.9497f, 0.2503f, -0.2921f}; std::vector<float> latents_std_vec;
std::vector<float> latents_std_vec = {2.8184f, 1.4541f, 2.3275f, 2.6558f, 1.2196f, 1.7708f, 2.6052f, 2.0743f, get_latents_mean_std_vec(latent, channel_dim, latents_mean_vec, latents_std_vec);
3.2687f, 2.1526f, 2.8652f, 1.5579f, 1.6382f, 1.1253f, 2.8251f, 1.9160f};
if (latent->ne[3] == 48) { float mean;
latents_mean_vec = {-0.2289f, -0.0052f, -0.1323f, -0.2339f, -0.2799f, 0.0174f, 0.1838f, 0.1557f, float std_;
-0.1382f, 0.0542f, 0.2813f, 0.0891f, 0.1570f, -0.0098f, 0.0375f, -0.1825f,
-0.2246f, -0.1207f, -0.0698f, 0.5109f, 0.2665f, -0.2108f, -0.2158f, 0.2502f,
-0.2055f, -0.0322f, 0.1109f, 0.1567f, -0.0729f, 0.0899f, -0.2799f, -0.1230f,
-0.0313f, -0.1649f, 0.0117f, 0.0723f, -0.2839f, -0.2083f, -0.0520f, 0.3748f,
0.0152f, 0.1957f, 0.1433f, -0.2944f, 0.3573f, -0.0548f, -0.1681f, -0.0667f};
latents_std_vec = {
0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f,
0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f,
0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f,
0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f,
0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f,
0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f};
}
for (int i = 0; i < latent->ne[3]; i++) { for (int i = 0; i < latent->ne[3]; i++) {
float mean = latents_mean_vec[i]; if (channel_dim == 3) {
float std_ = latents_std_vec[i]; mean = latents_mean_vec[i];
std_ = latents_std_vec[i];
}
for (int j = 0; j < latent->ne[2]; j++) { for (int j = 0; j < latent->ne[2]; j++) {
// channel_dim == 2: the channel axis is ne[2], which the enclosing loop walks
// with j (i walks ne[3]). The per-channel statistics therefore have to be
// selected with j; indexing with i applied a single entry to every channel.
if (channel_dim == 2) {
    mean = latents_mean_vec[j];
    std_ = latents_std_vec[j];
}
for (int k = 0; k < latent->ne[1]; k++) { for (int k = 0; k < latent->ne[1]; k++) {
for (int l = 0; l < latent->ne[0]; l++) { for (int l = 0; l < latent->ne[0]; l++) {
float value = ggml_ext_tensor_get_f32(latent, l, k, j, i); float value = ggml_ext_tensor_get_f32(latent, l, k, j, i);
@ -2087,6 +2165,7 @@ public:
if (use_tiny_autoencoder || if (use_tiny_autoencoder ||
sd_version_is_qwen_image(version) || sd_version_is_qwen_image(version) ||
sd_version_is_wan(version) || sd_version_is_wan(version) ||
sd_version_is_flux2(version) ||
version == VERSION_CHROMA_RADIANCE) { version == VERSION_CHROMA_RADIANCE) {
latent = vae_output; latent = vae_output;
} else if (version == VERSION_SD1_PIX2PIX) { } else if (version == VERSION_SD1_PIX2PIX) {
@ -2292,6 +2371,7 @@ const char* prediction_to_str[] = {
"edm_v", "edm_v",
"sd3_flow", "sd3_flow",
"flux_flow", "flux_flow",
"flux2_flow",
}; };
const char* sd_prediction_name(enum prediction_t prediction) { const char* sd_prediction_name(enum prediction_t prediction) {
@ -2396,8 +2476,8 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
"clip_g_path: %s\n" "clip_g_path: %s\n"
"clip_vision_path: %s\n" "clip_vision_path: %s\n"
"t5xxl_path: %s\n" "t5xxl_path: %s\n"
"qwen2vl_path: %s\n" "llm_path: %s\n"
"qwen2vl_vision_path: %s\n" "llm_vision_path: %s\n"
"diffusion_model_path: %s\n" "diffusion_model_path: %s\n"
"high_noise_diffusion_model_path: %s\n" "high_noise_diffusion_model_path: %s\n"
"vae_path: %s\n" "vae_path: %s\n"
@ -2427,8 +2507,8 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
SAFE_STR(sd_ctx_params->clip_g_path), SAFE_STR(sd_ctx_params->clip_g_path),
SAFE_STR(sd_ctx_params->clip_vision_path), SAFE_STR(sd_ctx_params->clip_vision_path),
SAFE_STR(sd_ctx_params->t5xxl_path), SAFE_STR(sd_ctx_params->t5xxl_path),
SAFE_STR(sd_ctx_params->qwen2vl_path), SAFE_STR(sd_ctx_params->llm_path),
SAFE_STR(sd_ctx_params->qwen2vl_vision_path), SAFE_STR(sd_ctx_params->llm_vision_path),
SAFE_STR(sd_ctx_params->diffusion_model_path), SAFE_STR(sd_ctx_params->diffusion_model_path),
SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path), SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path),
SAFE_STR(sd_ctx_params->vae_path), SAFE_STR(sd_ctx_params->vae_path),
@ -3062,7 +3142,10 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]); LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
int sample_steps = sd_img_gen_params->sample_params.sample_steps; int sample_steps = sd_img_gen_params->sample_params.sample_steps;
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps, sd_img_gen_params->sample_params.scheduler, sd_ctx->sd->version); std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps,
sd_ctx->sd->get_image_seq_len(height, width),
sd_img_gen_params->sample_params.scheduler,
sd_ctx->sd->version);
ggml_tensor* init_latent = nullptr; ggml_tensor* init_latent = nullptr;
ggml_tensor* concat_latent = nullptr; ggml_tensor* concat_latent = nullptr;
@ -3315,7 +3398,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
if (high_noise_sample_steps > 0) { if (high_noise_sample_steps > 0) {
total_steps += high_noise_sample_steps; total_steps += high_noise_sample_steps;
} }
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps, sd_vid_gen_params->sample_params.scheduler, sd_ctx->sd->version); std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps, 0, sd_vid_gen_params->sample_params.scheduler, sd_ctx->sd->version);
if (high_noise_sample_steps < 0) { if (high_noise_sample_steps < 0) {
// timesteps ∝ sigmas for Flow models (like wan2.2 a14b) // timesteps ∝ sigmas for Flow models (like wan2.2 a14b)

View File

@ -71,6 +71,7 @@ enum prediction_t {
EDM_V_PRED, EDM_V_PRED,
SD3_FLOW_PRED, SD3_FLOW_PRED,
FLUX_FLOW_PRED, FLUX_FLOW_PRED,
FLUX2_FLOW_PRED,
PREDICTION_COUNT PREDICTION_COUNT
}; };
@ -156,8 +157,8 @@ typedef struct {
const char* clip_g_path; const char* clip_g_path;
const char* clip_vision_path; const char* clip_vision_path;
const char* t5xxl_path; const char* t5xxl_path;
const char* qwen2vl_path; const char* llm_path;
const char* qwen2vl_vision_path; const char* llm_vision_path;
const char* diffusion_model_path; const char* diffusion_model_path;
const char* high_noise_diffusion_model_path; const char* high_noise_diffusion_model_path;
const char* vae_path; const char* vae_path;

View File

@ -811,6 +811,8 @@ bool starts_with(const std::vector<char32_t>& text,
return std::equal(prefix.begin(), prefix.end(), text.begin() + index); return std::equal(prefix.begin(), prefix.end(), text.begin() + index);
} }
// mistral: [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+
// qwen2: (?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+
std::vector<std::string> token_split(const std::string& text) { std::vector<std::string> token_split(const std::string& text) {
std::vector<std::string> tokens; std::vector<std::string> tokens;
auto cps = utf8_to_codepoints(text); auto cps = utf8_to_codepoints(text);

52
vae.hpp
View File

@ -487,6 +487,7 @@ public:
// ldm.models.autoencoder.AutoencoderKL // ldm.models.autoencoder.AutoencoderKL
class AutoencodingEngine : public GGMLBlock { class AutoencodingEngine : public GGMLBlock {
protected: protected:
SDVersion version;
bool decode_only = true; bool decode_only = true;
bool use_video_decoder = false; bool use_video_decoder = false;
bool use_quant = true; bool use_quant = true;
@ -507,10 +508,15 @@ public:
bool decode_only = true, bool decode_only = true,
bool use_linear_projection = false, bool use_linear_projection = false,
bool use_video_decoder = false) bool use_video_decoder = false)
: decode_only(decode_only), use_video_decoder(use_video_decoder) { : version(version), decode_only(decode_only), use_video_decoder(use_video_decoder) {
if (sd_version_is_dit(version)) { if (sd_version_is_dit(version)) {
dd_config.z_channels = 16; if (sd_version_is_flux2(version)) {
dd_config.z_channels = 32;
embed_dim = 32;
} else {
use_quant = false; use_quant = false;
dd_config.z_channels = 16;
}
} }
if (use_video_decoder) { if (use_video_decoder) {
use_quant = false; use_quant = false;
@ -547,6 +553,24 @@ public:
struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) { struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
// z: [N, z_channels, h, w] // z: [N, z_channels, h, w]
// FLUX.2 only: unpatchify the latent before it is decoded. The p*p factor stored
// in the channel axis is moved back into the two spatial axes (p = 2).
// Shape comments below are written outermost-first, i.e. the reverse of ggml's ne[] order.
if (sd_version_is_flux2(version)) {
// [N, C*p*p, h, w] -> [N, C, h*p, w*p]
int64_t p = 2;
int64_t N = z->ne[3];
int64_t C = z->ne[2] / p / p;  // channel count once the p*p factor is split off
int64_t h = z->ne[1];
int64_t w = z->ne[0];
int64_t H = h * p;
int64_t W = w * p;
z = ggml_reshape_4d(ctx->ggml_ctx, z, w * h, p * p, C, N); // [N, C, p*p, h*w]
z = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, z, 1, 0, 2, 3)); // [N, C, h*w, p*p]
z = ggml_reshape_4d(ctx->ggml_ctx, z, p, p, w, h * C * N); // [N*C*h, w, p, p]
z = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, z, 0, 2, 1, 3)); // [N*C*h, p, w, p]
z = ggml_reshape_4d(ctx->ggml_ctx, z, W, H, C, N); // [N, C, h*p, w*p]
}
if (use_quant) { if (use_quant) {
auto post_quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["post_quant_conv"]); auto post_quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["post_quant_conv"]);
z = post_quant_conv->forward(ctx, z); // [N, z_channels, h, w] z = post_quant_conv->forward(ctx, z); // [N, z_channels, h, w]
@ -563,12 +587,30 @@ public:
// x: [N, in_channels, h, w] // x: [N, in_channels, h, w]
auto encoder = std::dynamic_pointer_cast<Encoder>(blocks["encoder"]); auto encoder = std::dynamic_pointer_cast<Encoder>(blocks["encoder"]);
auto h = encoder->forward(ctx, x); // [N, 2*z_channels, h/8, w/8] auto z = encoder->forward(ctx, x); // [N, 2*z_channels, h/8, w/8]
if (use_quant) { if (use_quant) {
auto quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["quant_conv"]); auto quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["quant_conv"]);
h = quant_conv->forward(ctx, h); // [N, 2*embed_dim, h/8, w/8] z = quant_conv->forward(ctx, z); // [N, 2*embed_dim, h/8, w/8]
} }
return h; if (sd_version_is_flux2(version)) {
z = ggml_ext_chunk(ctx->ggml_ctx, z, 2, 2)[0];
// [N, C, H, W] -> [N, C*p*p, H/p, W/p]
int64_t p = 2;
int64_t N = z->ne[3];
int64_t C = z->ne[2];
int64_t H = z->ne[1];
int64_t W = z->ne[0];
int64_t h = H / p;
int64_t w = W / p;
z = ggml_reshape_4d(ctx->ggml_ctx, z, p, w, p, h * C * N); // [N*C*h, p, w, p]
z = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, z, 0, 2, 1, 3)); // [N*C*h, w, p, p]
z = ggml_reshape_4d(ctx->ggml_ctx, z, p * p, w * h, C, N); // [N, C, h*w, p*p]
z = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, z, 1, 0, 2, 3)); // [N, C, p*p, h*w]
z = ggml_reshape_4d(ctx->ggml_ctx, z, w, h, p * p * C, N); // [N, C*p*p, h*w]
}
return z;
} }
}; };

488508
vocab_mistral.hpp Normal file

File diff suppressed because it is too large Load Diff