mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-02-04 10:53:34 +00:00
add support for flux2 klein (#1193)

* add support for flux2 klein 4b
* add support for flux2 klein 8b
* use attention_mask in Flux.2 klein LLMEmbedder
* update docs

parent fbce16e02d
commit 9565c7f6bd

@@ -43,8 +43,8 @@ API and command-line option may change frequently.***
 - SDXL, [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo)
 - [Some SD1.x and SDXL distilled models](./docs/distilled_sd.md)
 - [SD3/SD3.5](./docs/sd3.md)
-- [FlUX.1-dev/FlUX.1-schnell](./docs/flux.md)
-- [FLUX.2-dev](./docs/flux2.md)
+- [FLUX.1-dev/FLUX.1-schnell](./docs/flux.md)
+- [FLUX.2-dev/FLUX.2-klein](./docs/flux2.md)
 - [Chroma](./docs/chroma.md)
 - [Chroma1-Radiance](./docs/chroma_radiance.md)
 - [Qwen Image](./docs/qwen_image.md)
@@ -127,8 +127,8 @@ If you want to improve performance or reduce VRAM/RAM usage, please refer to [pe

 - [SD1.x/SD2.x/SDXL](./docs/sd.md)
 - [SD3/SD3.5](./docs/sd3.md)
-- [FlUX.1-dev/FlUX.1-schnell](./docs/flux.md)
-- [FLUX.2-dev](./docs/flux2.md)
+- [FLUX.1-dev/FLUX.1-schnell](./docs/flux.md)
+- [FLUX.2-dev/FLUX.2-klein](./docs/flux2.md)
 - [FLUX.1-Kontext-dev](./docs/kontext.md)
 - [Chroma](./docs/chroma.md)
 - [🔥Qwen Image](./docs/qwen_image.md)
BIN  assets/flux2/flux2-klein-4b-edit.png (new file, 510 KiB)
BIN  assets/flux2/flux2-klein-4b.png (new file, 455 KiB)
BIN  assets/flux2/flux2-klein-9b-edit.png (new file, 511 KiB)
BIN  assets/flux2/flux2-klein-9b.png (new file, 491 KiB)
BIN  assets/flux2/flux2-klein-base-4b.png (new file, 464 KiB)
BIN  assets/flux2/flux2-klein-base-9b.png (new file, 552 KiB)

@@ -1614,9 +1614,9 @@ struct LLMEmbedder : public Conditioner {
                 bool enable_vision = false)
        : version(version) {
        LLM::LLMArch arch = LLM::LLMArch::QWEN2_5_VL;
-       if (sd_version_is_flux2(version)) {
+       if (version == VERSION_FLUX2) {
            arch = LLM::LLMArch::MISTRAL_SMALL_3_2;
-       } else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE) {
+       } else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE || version == VERSION_FLUX2_KLEIN) {
            arch = LLM::LLMArch::QWEN3;
        }
        if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) {
@@ -1708,6 +1708,9 @@ struct LLMEmbedder : public Conditioner {
        int prompt_template_encode_start_idx = 34;
        int max_length = 0;
        std::set<int> out_layers;
+       std::vector<int> tokens;
+       std::vector<float> weights;
+       std::vector<float> mask;
        if (llm->enable_vision && conditioner_params.ref_images.size() > 0) {
            LOG_INFO("QwenImageEditPlusPipeline");
            prompt_template_encode_start_idx = 64;
@@ -1771,7 +1774,7 @@ struct LLMEmbedder : public Conditioner {
            prompt_attn_range.second = static_cast<int>(prompt.size());

            prompt += "<|im_end|>\n<|im_start|>assistant\n";
-       } else if (sd_version_is_flux2(version)) {
+       } else if (version == VERSION_FLUX2) {
            prompt_template_encode_start_idx = 0;
            out_layers = {10, 20, 30};

@@ -1793,17 +1796,28 @@ struct LLMEmbedder : public Conditioner {
            prompt_attn_range.second = static_cast<int>(prompt.size());

            prompt += "<|im_end|>\n<|im_start|>assistant\n";
-       } else if (sd_version_is_flux2(version)) {
+       } else if (version == VERSION_FLUX2_KLEIN) {
            prompt_template_encode_start_idx = 0;
-           out_layers = {10, 20, 30};
+           max_length = 512;
+           out_layers = {9, 18, 27};

-           prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]";
+           prompt = "<|im_start|>user\n";

            prompt_attn_range.first = static_cast<int>(prompt.size());
            prompt += conditioner_params.text;
            prompt_attn_range.second = static_cast<int>(prompt.size());

-           prompt += "[/INST]";
+           prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
+
+           auto tokens_and_weights = tokenize(prompt, prompt_attn_range, 0, false);
+           tokens                  = std::get<0>(tokens_and_weights);
+           weights                 = std::get<1>(tokens_and_weights);
+
+           mask.insert(mask.end(), tokens.size(), 1.f);
+           if (tokens.size() < max_length) {
+               mask.insert(mask.end(), max_length - tokens.size(), 0.f);
+               tokenizer->pad_tokens(tokens, weights, max_length, true);
+           }
        } else if (version == VERSION_OVIS_IMAGE) {
            prompt_template_encode_start_idx = 28;
            max_length = prompt_template_encode_start_idx + 256;
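The klein branch above tokenizes immediately, rather than deferring to the shared tokenize call further down, so that it can pad to a fixed 512-token window and record which positions hold real tokens. A minimal self-contained sketch of that pad-and-mask step, assuming `pad_tokens` simply appends a pad id and a neutral weight up to the target length (the real tokenizer's behavior may differ):

```cpp
#include <vector>

// Hypothetical stand-in for the pad step above: extend `tokens`/`weights`
// to `max_length` and record real tokens (1.f) vs padding (0.f) in `mask`.
// Illustrative only; the pad token id 0 is an assumption.
static void pad_with_mask(std::vector<int>& tokens,
                          std::vector<float>& weights,
                          std::vector<float>& mask,
                          size_t max_length) {
    mask.assign(tokens.size(), 1.f);  // positions filled so far are real
    while (tokens.size() < max_length) {
        tokens.push_back(0);     // assumed pad token id
        weights.push_back(1.f);  // neutral per-token weight
        mask.push_back(0.f);     // marked as padding
    }
}
```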
@@ -1827,17 +1841,34 @@ struct LLMEmbedder : public Conditioner {
            prompt += "<|im_end|>\n<|im_start|>assistant\n";
        }

+       if (tokens.empty()) {
            auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0);
-           auto& tokens  = std::get<0>(tokens_and_weights);
-           auto& weights = std::get<1>(tokens_and_weights);
+           tokens  = std::get<0>(tokens_and_weights);
+           weights = std::get<1>(tokens_and_weights);
+       }

        int64_t t0 = ggml_time_ms();
        struct ggml_tensor* hidden_states = nullptr;  // [N, n_token, 3584]

        auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);

+       ggml_tensor* attention_mask = nullptr;
+       if (!mask.empty()) {
+           attention_mask = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, mask.size(), mask.size());
+           ggml_ext_tensor_iter(attention_mask, [&](ggml_tensor* attention_mask, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
+               float value = 0.f;
+               if (mask[i0] == 0.f) {
+                   value = -INFINITY;
+               } else if (i0 > i1) {
+                   value = -INFINITY;
+               }
+               ggml_ext_tensor_set_f32(attention_mask, value, i0, i1, i2, i3);
+           });
+       }
+
        llm->compute(n_threads,
                     input_ids,
+                    attention_mask,
                     image_embeds,
                     out_layers,
                     &hidden_states,
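The `ggml_ext_tensor_iter` fill above folds two constraints into one additive [n_token, n_token] bias: key columns that are padding are blocked for every query, and no query may attend to a future key. A plain-C++ sketch of the same rule, independent of the ggml tensor API:

```cpp
#include <cmath>
#include <vector>

// Build an [n, n] additive attention bias from a per-token keep mask
// (1.f = real token, 0.f = padding). bias[i1 * n + i0] is added to the
// score of query i1 attending to key i0; -INFINITY removes that pair
// after softmax. Mirrors the tensor fill above.
std::vector<float> build_bias(const std::vector<float>& keep) {
    const size_t n = keep.size();
    std::vector<float> bias(n * n, 0.f);
    for (size_t i1 = 0; i1 < n; i1++) {      // query index
        for (size_t i0 = 0; i0 < n; i0++) {  // key index
            if (keep[i0] == 0.f || i0 > i1)  // padding, or future token
                bias[i1 * n + i0] = -INFINITY;
        }
    }
    return bias;
}
```

Adding this bias to the raw attention scores before softmax drives the masked pairs' weights to zero, which is exactly how it is consumed in `ggml_ext_attention_ext` further down.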
@@ -1861,7 +1892,7 @@ struct LLMEmbedder : public Conditioner {
        GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx);

        int64_t min_length = 0;
-       if (sd_version_is_flux2(version)) {
+       if (version == VERSION_FLUX2) {
            min_length = 512;
        }


@@ -1,6 +1,8 @@
 # How to Use

-## Download weights
+## Flux.2-dev
+
+### Download weights

 - Download FLUX.2-dev
   - gguf: https://huggingface.co/city96/FLUX.2-dev-gguf/tree/main
@@ -9,7 +11,7 @@
 - Download Mistral-Small-3.2-24B-Instruct-2506-GGUF
   - gguf: https://huggingface.co/unsloth/Mistral-Small-3.2-24B-Instruct-2506-GGUF/tree/main

-## Examples
+### Examples

 ```
 .\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux2-dev-Q4_K_S.gguf --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\Mistral-Small-3.2-24B-Instruct-2506-Q4_K_M.gguf -r .\kontext_input.png -p "change 'flux.cpp' to 'flux2-dev.cpp'" --cfg-scale 1.0 --sampling-method euler -v --diffusion-fa --offload-to-cpu
@@ -17,5 +19,74 @@

 <img alt="flux2 example" src="../assets/flux2/example.png" />
+
+## Flux.2 klein 4B / Flux.2 klein base 4B
+
+### Download weights
+
+- Download FLUX.2-klein-4B
+  - safetensors: https://huggingface.co/black-forest-labs/FLUX.2-klein-4B
+  - gguf: https://huggingface.co/leejet/FLUX.2-klein-4B-GGUF/tree/main
+- Download FLUX.2-klein-base-4B
+  - safetensors: https://huggingface.co/black-forest-labs/FLUX.2-klein-base-4B
+  - gguf: https://huggingface.co/leejet/FLUX.2-klein-base-4B-GGUF/tree/main
+- Download vae
+  - safetensors: https://huggingface.co/black-forest-labs/FLUX.2-dev/tree/main
+- Download Qwen3 4B
+  - safetensors: https://huggingface.co/Comfy-Org/flux2-klein-4B/tree/main/split_files/text_encoders
+  - gguf: https://huggingface.co/unsloth/Qwen3-4B-GGUF/tree/main
+
+### Examples
+
+```
+.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-4b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_4b.safetensors -p "a lovely cat" --cfg-scale 1.0 --steps 4 -v --offload-to-cpu --diffusion-fa
+```
+
+<img alt="flux2-klein-4b" src="../assets/flux2/flux2-klein-4b.png" />
+
+```
+.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-4b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_4b.safetensors -r .\kontext_input.png -p "change 'flux.cpp' to 'klein.cpp'" --cfg-scale 1.0 --sampling-method euler -v --diffusion-fa --offload-to-cpu --steps 4
+```
+
+<img alt="flux2-klein-4b-edit" src="../assets/flux2/flux2-klein-4b-edit.png" />
+
+```
+.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-base-4b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_4b.safetensors -p "a lovely cat" --cfg-scale 4.0 --steps 20 -v --offload-to-cpu --diffusion-fa
+```
+
+<img alt="flux2-klein-base-4b" src="../assets/flux2/flux2-klein-base-4b.png" />
+
+## Flux.2 klein 9B / Flux.2 klein base 9B
+
+### Download weights
+
+- Download FLUX.2-klein-9B
+  - safetensors: https://huggingface.co/black-forest-labs/FLUX.2-klein-9B
+  - gguf: https://huggingface.co/leejet/FLUX.2-klein-9B-GGUF/tree/main
+- Download FLUX.2-klein-base-9B
+  - safetensors: https://huggingface.co/black-forest-labs/FLUX.2-klein-base-9B
+  - gguf: https://huggingface.co/leejet/FLUX.2-klein-base-9B-GGUF/tree/main
+- Download vae
+  - safetensors: https://huggingface.co/black-forest-labs/FLUX.2-dev/tree/main
+- Download Qwen3 8B
+  - safetensors: https://huggingface.co/Comfy-Org/flux2-klein-9B/tree/main/split_files/text_encoders
+  - gguf: https://huggingface.co/unsloth/Qwen3-8B-GGUF/tree/main
+
+### Examples
+
+```
+.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-9b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_8b.safetensors -p "a lovely cat" --cfg-scale 1.0 --steps 4 -v --offload-to-cpu --diffusion-fa
+```
+
+<img alt="flux2-klein-9b" src="../assets/flux2/flux2-klein-9b.png" />
+
+```
+.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-9b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_8b.safetensors -r .\kontext_input.png -p "change 'flux.cpp' to 'klein.cpp'" --cfg-scale 1.0 --sampling-method euler -v --diffusion-fa --offload-to-cpu --steps 4
+```
+
+<img alt="flux2-klein-9b-edit" src="../assets/flux2/flux2-klein-9b-edit.png" />
+
+```
+.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-base-9b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_8b.safetensors -p "a lovely cat" --cfg-scale 4.0 --steps 20 -v --offload-to-cpu --diffusion-fa
+```
+
+<img alt="flux2-klein-base-9b" src="../assets/flux2/flux2-klein-base-9b.png" />

flux.hpp: 29 changed lines

@@ -1288,13 +1288,9 @@ namespace Flux {
        } else if (version == VERSION_OVIS_IMAGE) {
            flux_params.semantic_txt_norm = true;
            flux_params.use_yak_mlp = true;
-           flux_params.context_in_dim = 2048;
            flux_params.vec_in_dim = 0;
        } else if (sd_version_is_flux2(version)) {
-           flux_params.context_in_dim = 15360;
            flux_params.in_channels = 128;
-           flux_params.hidden_size = 6144;
-           flux_params.num_heads = 48;
            flux_params.patch_size = 1;
            flux_params.out_channels = 128;
            flux_params.mlp_ratio = 3.f;
@@ -1307,12 +1303,12 @@ namespace Flux {
            flux_params.ref_index_scale = 10.f;
            flux_params.use_mlp_silu_act = true;
        }
+       int64_t head_dim = 0;
        for (auto pair : tensor_storage_map) {
            std::string tensor_name = pair.first;
            if (!starts_with(tensor_name, prefix))
                continue;
            if (tensor_name.find("guidance_in.in_layer.weight") != std::string::npos) {
-               // not schnell
                flux_params.guidance_embed = true;
            }
            if (tensor_name.find("__x0__") != std::string::npos) {
@@ -1344,13 +1340,30 @@ namespace Flux {
                    flux_params.depth_single_blocks = block_depth + 1;
                }
            }
+           if (ends_with(tensor_name, "txt_in.weight")) {
+               flux_params.context_in_dim = pair.second.ne[0];
+               flux_params.hidden_size    = pair.second.ne[1];
+           }
+           if (ends_with(tensor_name, "single_blocks.0.norm.key_norm.scale")) {
+               head_dim = pair.second.ne[0];
+           }
+           if (ends_with(tensor_name, "double_blocks.0.txt_attn.norm.key_norm.scale")) {
+               head_dim = pair.second.ne[0];
+           }
        }

-       LOG_INFO("Flux blocks: %d double, %d single", flux_params.depth, flux_params.depth_single_blocks);
+       flux_params.num_heads = static_cast<int>(flux_params.hidden_size / head_dim);
+
+       LOG_INFO("flux: depth = %d, depth_single_blocks = %d, guidance_embed = %s, context_in_dim = %" PRId64
+                ", hidden_size = %" PRId64 ", num_heads = %d",
+                flux_params.depth,
+                flux_params.depth_single_blocks,
+                flux_params.guidance_embed ? "true" : "false",
+                flux_params.context_in_dim,
+                flux_params.hidden_size,
+                flux_params.num_heads);
        if (flux_params.is_chroma) {
            LOG_INFO("Using pruned modulation (Chroma)");
-       } else if (!flux_params.guidance_embed) {
-           LOG_INFO("Flux guidance is disabled (Schnell mode)");
        }

        flux = Flux(flux_params);
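With the hardcoded constants gone, the loader derives the model geometry from tensor shapes: `txt_in.weight` is `[context_in_dim, hidden_size]`, a `key_norm.scale` vector gives `head_dim`, and `num_heads = hidden_size / head_dim`. A worked example using the FLUX.2-dev values just deleted from the hardcoded branch (head_dim = 128 is my inference from 6144 / 48, not stated in the diff):

```cpp
#include <cstdint>
#include <cstdio>

// Shape-driven inference, as in the loop above. The concrete numbers are the
// FLUX.2-dev constants removed from the hardcoded branch; klein checkpoints
// report their own (smaller) shapes through the same path.
int main() {
    int64_t txt_in_ne0 = 15360;  // txt_in.weight ne[0] -> context_in_dim
    int64_t txt_in_ne1 = 6144;   // txt_in.weight ne[1] -> hidden_size
    int64_t head_dim   = 128;    // length of *.key_norm.scale (assumed value)
    int num_heads = static_cast<int>(txt_in_ne1 / head_dim);
    printf("context_in_dim=%lld hidden_size=%lld num_heads=%d\n",
           (long long)txt_in_ne0, (long long)txt_in_ne1, num_heads);  // 48
    return 0;
}
```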

@@ -1348,6 +1348,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
    v = ggml_reshape_3d(ctx, v, L_k, d_head, n_kv_head * N);  // [N * n_kv_head, d_head, L_k]

    auto kq = ggml_mul_mat(ctx, k, q);  // [N * n_head, L_q, L_k]
+   ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
    kq = ggml_scale_inplace(ctx, kq, scale);
    if (mask) {
        kq = ggml_add_inplace(ctx, kq, mask);
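Two details in this fallback path: the mask is added to the scaled scores right before softmax, and `ggml_mul_mat_set_prec(kq, GGML_PREC_F32)` pins the score matmul to F32, which plausibly keeps the large negative mask values well-behaved in reduced-precision backends (my reading; the commit states no rationale). A standalone sketch of additive-mask softmax, showing why `-INFINITY` entries end up with weight exactly zero:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Row softmax over biased scores: prob_i is proportional to
// exp(score_i + bias_i), with bias_i in {0, -INFINITY}. Since
// exp(-INFINITY) == 0, masked positions get probability exactly zero.
// Assumes at least one unmasked position per row.
std::vector<float> masked_softmax(std::vector<float> scores,
                                  const std::vector<float>& bias) {
    float max_v = -INFINITY;
    for (size_t i = 0; i < scores.size(); i++) {
        scores[i] += bias[i];
        max_v = std::max(max_v, scores[i]);
    }
    float sum = 0.f;
    for (float& s : scores) {
        s = std::exp(s - max_v);  // numerically stable; -inf maps to 0
        sum += s;
    }
    for (float& s : scores)
        s /= sum;
    return scores;
}
```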

llm.hpp: 51 changed lines

@@ -837,7 +837,8 @@ namespace LLM {

    struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                struct ggml_tensor* x,
-                               struct ggml_tensor* input_pos) {
+                               struct ggml_tensor* input_pos,
+                               struct ggml_tensor* attention_mask = nullptr) {
        // x: [N, n_token, hidden_size]
        int64_t n_token = x->ne[1];
        int64_t N = x->ne[2];
@@ -880,7 +881,7 @@ namespace LLM {
        k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 2, 1, 3));  // [N, num_kv_heads, n_token, head_dim]
        k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]);      // [N*num_kv_heads, n_token, head_dim]

-       x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, true, true, false);          // [N, n_token, hidden_size]
+       x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, false, true, false);  // [N, n_token, hidden_size]

        x = out_proj->forward(ctx, x);  // [N, n_token, hidden_size]
        return x;
@@ -898,7 +899,8 @@ namespace LLM {

    struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                struct ggml_tensor* x,
-                               struct ggml_tensor* input_pos) {
+                               struct ggml_tensor* input_pos,
+                               struct ggml_tensor* attention_mask = nullptr) {
        // x: [N, n_token, hidden_size]
        auto self_attn = std::dynamic_pointer_cast<Attention>(blocks["self_attn"]);
        auto mlp = std::dynamic_pointer_cast<MLP>(blocks["mlp"]);
@@ -907,7 +909,7 @@ namespace LLM {

        auto residual = x;
        x = input_layernorm->forward(ctx, x);
-       x = self_attn->forward(ctx, x, input_pos);
+       x = self_attn->forward(ctx, x, input_pos, attention_mask);
        x = ggml_add_inplace(ctx->ggml_ctx, x, residual);

        residual = x;
@@ -936,6 +938,7 @@ namespace LLM {
    struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                struct ggml_tensor* input_ids,
                                struct ggml_tensor* input_pos,
+                               struct ggml_tensor* attention_mask,
                                std::vector<std::pair<int, ggml_tensor*>> image_embeds,
                                std::set<int> out_layers) {
        // input_ids: [N, n_token]
@@ -990,7 +993,7 @@ namespace LLM {
        for (int i = 0; i < num_layers; i++) {
            auto block = std::dynamic_pointer_cast<TransformerBlock>(blocks["layers." + std::to_string(i)]);

-           x = block->forward(ctx, x, input_pos);
+           x = block->forward(ctx, x, input_pos, attention_mask);
            if (out_layers.find(i + 1) != out_layers.end()) {
                intermediate_outputs.push_back(x);
            }
@@ -1036,12 +1039,13 @@ namespace LLM {
    struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                struct ggml_tensor* input_ids,
                                struct ggml_tensor* input_pos,
+                               struct ggml_tensor* attention_mask,
                                std::vector<std::pair<int, ggml_tensor*>> image_embeds,
                                std::set<int> out_layers) {
        // input_ids: [N, n_token]
        auto model = std::dynamic_pointer_cast<TextModel>(blocks["model"]);

-       auto x = model->forward(ctx, input_ids, input_pos, image_embeds, out_layers);
+       auto x = model->forward(ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers);
        return x;
    }

@@ -1063,6 +1067,7 @@ namespace LLM {
    LLM model;

    std::vector<int> input_pos_vec;
+   std::vector<float> attention_mask_vec;
    std::vector<float> window_mask_vec;
    std::vector<int> window_index_vec;
    std::vector<int> window_inverse_index_vec;
@@ -1157,9 +1162,10 @@ namespace LLM {
    struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                struct ggml_tensor* input_ids,
                                struct ggml_tensor* input_pos,
+                               struct ggml_tensor* attention_mask,
                                std::vector<std::pair<int, ggml_tensor*>> image_embeds,
                                std::set<int> out_layers) {
-       auto hidden_states = model.forward(ctx, input_ids, input_pos, image_embeds, out_layers);  // [N, n_token, hidden_size]
+       auto hidden_states = model.forward(ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers);  // [N, n_token, hidden_size]
        return hidden_states;
    }

@@ -1174,6 +1180,7 @@ namespace LLM {
    }

    struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids,
+                                   struct ggml_tensor* attention_mask,
                                    std::vector<std::pair<int, ggml_tensor*>> image_embeds,
                                    std::set<int> out_layers) {
        struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
@@ -1205,9 +1212,26 @@ namespace LLM {
                                        input_pos_vec.size());
        set_backend_tensor_data(input_pos, input_pos_vec.data());

+       if (attention_mask != nullptr) {
+           attention_mask = to_backend(attention_mask);
+       } else {
+           attention_mask_vec.resize(n_tokens * n_tokens);
+           for (int i0 = 0; i0 < n_tokens; i0++) {
+               for (int i1 = 0; i1 < n_tokens; i1++) {
+                   float value = 0.f;
+                   if (i0 > i1) {
+                       value = -INFINITY;
+                   }
+                   attention_mask_vec[i1 * n_tokens + i0] = value;
+               }
+           }
+           attention_mask = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, n_tokens, n_tokens);
+           set_backend_tensor_data(attention_mask, attention_mask_vec.data());
+       }

        auto runner_ctx = get_context();

-       struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, input_pos, image_embeds, out_layers);
+       struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers);

        ggml_build_forward_expand(gf, hidden_states);

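Because `build_graph` synthesizes a causal mask whenever the caller passes `nullptr`, existing call sites (including the test harnesses below, which now pass `nullptr` explicitly) keep their previous behavior. A tiny runnable illustration of the fallback mask's `[i1 * n_tokens + i0]` row-major layout for four tokens:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Prints the fallback causal mask built above for n_tokens = 4:
// row i1 = query, column i0 = key; "X" marks -INFINITY (blocked).
int main() {
    const int n_tokens = 4;
    std::vector<float> mask(n_tokens * n_tokens, 0.f);
    for (int i0 = 0; i0 < n_tokens; i0++)
        for (int i1 = 0; i1 < n_tokens; i1++)
            if (i0 > i1)
                mask[i1 * n_tokens + i0] = -INFINITY;
    for (int i1 = 0; i1 < n_tokens; i1++) {
        for (int i0 = 0; i0 < n_tokens; i0++)
            printf("%c ", std::isinf(mask[i1 * n_tokens + i0]) ? 'X' : '.');
        printf("\n");  // . X X X / . . X X / . . . X / . . . .
    }
    return 0;
}
```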
@@ -1216,12 +1240,13 @@ namespace LLM {

    bool compute(const int n_threads,
                 struct ggml_tensor* input_ids,
+                struct ggml_tensor* attention_mask,
                 std::vector<std::pair<int, ggml_tensor*>> image_embeds,
                 std::set<int> out_layers,
                 ggml_tensor** output,
                 ggml_context* output_ctx = nullptr) {
        auto get_graph = [&]() -> struct ggml_cgraph* {
-           return build_graph(input_ids, image_embeds, out_layers);
+           return build_graph(input_ids, attention_mask, image_embeds, out_layers);
        };
        return GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
    }
@@ -1525,7 +1550,7 @@ namespace LLM {
    struct ggml_tensor* out = nullptr;

    int64_t t0 = ggml_time_ms();
-   model.compute(8, input_ids, image_embeds, {}, &out, work_ctx);
+   model.compute(8, input_ids, nullptr, image_embeds, {}, &out, work_ctx);
    int64_t t1 = ggml_time_ms();

    print_ggml_tensor(out);
@@ -1565,7 +1590,7 @@ namespace LLM {
    struct ggml_tensor* out = nullptr;

    int64_t t0 = ggml_time_ms();
-   model.compute(8, input_ids, {}, {10, 20, 30}, &out, work_ctx);
+   model.compute(8, input_ids, nullptr, {}, {10, 20, 30}, &out, work_ctx);
    int64_t t1 = ggml_time_ms();

    print_ggml_tensor(out);
@@ -1588,7 +1613,7 @@ namespace LLM {
    struct ggml_tensor* out = nullptr;

    int64_t t0 = ggml_time_ms();
-   model.compute(8, input_ids, {}, {35}, &out, work_ctx);
+   model.compute(8, input_ids, nullptr, {}, {35}, &out, work_ctx);
    int64_t t1 = ggml_time_ms();

    print_ggml_tensor(out);
@@ -1611,7 +1636,7 @@ namespace LLM {
    struct ggml_tensor* out = nullptr;

    int64_t t0 = ggml_time_ms();
-   model.compute(8, input_ids, {}, {}, &out, work_ctx);
+   model.compute(8, input_ids, nullptr, {}, {}, &out, work_ctx);
    int64_t t1 = ggml_time_ms();

    print_ggml_tensor(out);

model.cpp: 16 changed lines

@@ -1034,6 +1034,8 @@ SDVersion ModelLoader::get_sd_version() {

    bool is_xl = false;
    bool is_flux = false;
+   bool is_flux2 = false;
+   bool has_single_block_47 = false;
    bool is_wan = false;
    int64_t patch_embedding_channels = 0;
    bool has_img_emb = false;
@@ -1055,7 +1057,10 @@ SDVersion ModelLoader::get_sd_version() {
        return VERSION_QWEN_IMAGE;
    }
    if (tensor_storage.name.find("model.diffusion_model.double_stream_modulation_img.lin.weight") != std::string::npos) {
-       return VERSION_FLUX2;
+       is_flux2 = true;
+   }
+   if (tensor_storage.name.find("single_blocks.47.linear1.weight") != std::string::npos) {
+       has_single_block_47 = true;
    }
    if (tensor_storage.name.find("model.diffusion_model.double_blocks.0.img_mlp.gate_proj.weight") != std::string::npos) {
        return VERSION_OVIS_IMAGE;
@@ -1138,7 +1143,7 @@ SDVersion ModelLoader::get_sd_version() {
        return VERSION_SDXL;
    }

-   if (is_flux) {
+   if (is_flux && !is_flux2) {
        if (input_block_weight.ne[0] == 384) {
            return VERSION_FLUX_FILL;
        }
@@ -1151,6 +1156,13 @@ SDVersion ModelLoader::get_sd_version() {
        return VERSION_FLUX;
    }

+   if (is_flux2) {
+       if (has_single_block_47) {
+           return VERSION_FLUX2;
+       }
+       return VERSION_FLUX2_KLEIN;
+   }
+
    if (token_embedding_weight.ne[0] == 768) {
        if (is_inpaint) {
            return VERSION_SD1_INPAINT;
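So the detection rule is: a `double_stream_modulation_img` tensor marks a FLUX.2-family checkpoint, and the presence of single block index 47 (a 48-deep single-block stack, which evidently only the full-size model has) separates FLUX.2-dev from the klein variants. A condensed standalone restatement of that classification (`VERSION_OTHER` is a placeholder for everything outside this family):

```cpp
#include <string>
#include <vector>

enum SDVersion { VERSION_FLUX2, VERSION_FLUX2_KLEIN, VERSION_OTHER };

// Condensed restatement of the checks above: scan tensor names once and
// classify the FLUX.2 variant by which blocks exist in the checkpoint.
SDVersion classify_flux2(const std::vector<std::string>& tensor_names) {
    bool is_flux2 = false, has_single_block_47 = false;
    for (const auto& name : tensor_names) {
        if (name.find("double_stream_modulation_img.lin.weight") != std::string::npos)
            is_flux2 = true;
        if (name.find("single_blocks.47.linear1.weight") != std::string::npos)
            has_single_block_47 = true;
    }
    if (!is_flux2)
        return VERSION_OTHER;
    return has_single_block_47 ? VERSION_FLUX2 : VERSION_FLUX2_KLEIN;
}
```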

model.h: 3 changed lines

@@ -45,6 +45,7 @@ enum SDVersion {
    VERSION_WAN2_2_TI2V,
    VERSION_QWEN_IMAGE,
    VERSION_FLUX2,
+   VERSION_FLUX2_KLEIN,
    VERSION_Z_IMAGE,
    VERSION_OVIS_IMAGE,
    VERSION_COUNT,
@@ -100,7 +101,7 @@ static inline bool sd_version_is_flux(SDVersion version) {
}

static inline bool sd_version_is_flux2(SDVersion version) {
-   if (version == VERSION_FLUX2) {
+   if (version == VERSION_FLUX2 || version == VERSION_FLUX2_KLEIN) {
        return true;
    }
    return false;

@@ -48,6 +48,7 @@ const char* model_version_to_str[] = {
    "Wan 2.2 TI2V",
    "Qwen Image",
    "Flux.2",
+   "Flux.2 klein",
    "Z-Image",
    "Ovis Image",
};