mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-12 13:28:37 +00:00
feat: add ovis image support (#1057)
This commit is contained in:
parent
bfbb929790
commit
2f0bd31a84
@ -49,6 +49,7 @@ API and command-line option may change frequently.***
|
||||
- [Chroma1-Radiance](./docs/chroma_radiance.md)
|
||||
- [Qwen Image](./docs/qwen_image.md)
|
||||
- [Z-Image](./docs/z_image.md)
|
||||
- [Ovis-Image](./docs/ovis_image.md)
|
||||
- Image Edit Models
|
||||
- [FLUX.1-Kontext-dev](./docs/kontext.md)
|
||||
- [Qwen Image Edit/Qwen Image Edit 2509](./docs/qwen_image_edit.md)
|
||||
@ -134,6 +135,7 @@ If you want to improve performance or reduce VRAM/RAM usage, please refer to [pe
|
||||
- [🔥Qwen Image Edit/Qwen Image Edit 2509](./docs/qwen_image_edit.md)
|
||||
- [🔥Wan2.1/Wan2.2](./docs/wan.md)
|
||||
- [🔥Z-Image](./docs/z_image.md)
|
||||
- [Ovis-Image](./docs/ovis_image.md)
|
||||
- [LoRA](./docs/lora.md)
|
||||
- [LCM/LCM-LoRA](./docs/lcm.md)
|
||||
- [Using PhotoMaker to personalize image generation](./docs/photo_maker.md)
|
||||
|
||||
BIN
assets/ovis_image/example.png
Normal file
BIN
assets/ovis_image/example.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 401 KiB |
@ -1638,7 +1638,7 @@ struct LLMEmbedder : public Conditioner {
|
||||
LLM::LLMArch arch = LLM::LLMArch::QWEN2_5_VL;
|
||||
if (sd_version_is_flux2(version)) {
|
||||
arch = LLM::LLMArch::MISTRAL_SMALL_3_2;
|
||||
} else if (sd_version_is_z_image(version)) {
|
||||
} else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE) {
|
||||
arch = LLM::LLMArch::QWEN3;
|
||||
}
|
||||
if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) {
|
||||
@ -1728,6 +1728,7 @@ struct LLMEmbedder : public Conditioner {
|
||||
std::vector<std::pair<int, ggml_tensor*>> image_embeds;
|
||||
std::pair<int, int> prompt_attn_range;
|
||||
int prompt_template_encode_start_idx = 34;
|
||||
int max_length = 0;
|
||||
std::set<int> out_layers;
|
||||
if (llm->enable_vision && conditioner_params.ref_images.size() > 0) {
|
||||
LOG_INFO("QwenImageEditPlusPipeline");
|
||||
@ -1825,6 +1826,17 @@ struct LLMEmbedder : public Conditioner {
|
||||
prompt_attn_range.second = prompt.size();
|
||||
|
||||
prompt += "[/INST]";
|
||||
} else if (version == VERSION_OVIS_IMAGE) {
|
||||
prompt_template_encode_start_idx = 28;
|
||||
max_length = prompt_template_encode_start_idx + 256;
|
||||
|
||||
prompt = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background:";
|
||||
|
||||
prompt_attn_range.first = static_cast<int>(prompt.size());
|
||||
prompt += " " + conditioner_params.text;
|
||||
prompt_attn_range.second = static_cast<int>(prompt.size());
|
||||
|
||||
prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
|
||||
} else {
|
||||
prompt_template_encode_start_idx = 34;
|
||||
|
||||
@ -1837,7 +1849,7 @@ struct LLMEmbedder : public Conditioner {
|
||||
prompt += "<|im_end|>\n<|im_start|>assistant\n";
|
||||
}
|
||||
|
||||
auto tokens_and_weights = tokenize(prompt, prompt_attn_range, 0, false);
|
||||
auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0);
|
||||
auto& tokens = std::get<0>(tokens_and_weights);
|
||||
auto& weights = std::get<1>(tokens_and_weights);
|
||||
|
||||
@ -1870,9 +1882,13 @@ struct LLMEmbedder : public Conditioner {
|
||||
|
||||
GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx);
|
||||
|
||||
int64_t zero_pad_len = 0;
|
||||
int64_t min_length = 0;
|
||||
if (sd_version_is_flux2(version)) {
|
||||
int64_t min_length = 512;
|
||||
min_length = 512;
|
||||
}
|
||||
|
||||
int64_t zero_pad_len = 0;
|
||||
if (min_length > 0) {
|
||||
if (hidden_states->ne[1] - prompt_template_encode_start_idx < min_length) {
|
||||
zero_pad_len = min_length - hidden_states->ne[1] + prompt_template_encode_start_idx;
|
||||
}
|
||||
@ -1892,6 +1908,8 @@ struct LLMEmbedder : public Conditioner {
|
||||
ggml_ext_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3);
|
||||
});
|
||||
|
||||
// print_ggml_tensor(new_hidden_states);
|
||||
|
||||
int64_t t1 = ggml_time_ms();
|
||||
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
|
||||
return {new_hidden_states, nullptr, nullptr};
|
||||
|
||||
19
docs/ovis_image.md
Normal file
19
docs/ovis_image.md
Normal file
@ -0,0 +1,19 @@
|
||||
# How to Use
|
||||
|
||||
## Download weights
|
||||
|
||||
- Download Ovis-Image-7B
|
||||
- safetensors: https://huggingface.co/Comfy-Org/Ovis-Image/tree/main/split_files/diffusion_models
|
||||
- gguf: https://huggingface.co/leejet/Ovis-Image-7B-GGUF
|
||||
- Download vae
|
||||
- safetensors: https://huggingface.co/black-forest-labs/FLUX.1-schnell/tree/main
|
||||
- Download Ovis 2.5
|
||||
- safetensors: https://huggingface.co/Comfy-Org/Ovis-Image/tree/main/split_files/text_encoders
|
||||
|
||||
## Examples
|
||||
|
||||
```
|
||||
.\bin\Release\sd.exe --diffusion-model ovis_image-Q4_0.gguf --vae ..\..\ComfyUI\models\vae\ae.sft --llm ..\..\ComfyUI\models\text_encoders\ovis_2.5.safetensors -p "a lovely cat" --cfg-scale 5.0 -v --offload-to-cpu --diffusion-fa
|
||||
```
|
||||
|
||||
<img alt="ovis image example" src="../assets/ovis_image/example.png" />
|
||||
134
flux.hpp
134
flux.hpp
@ -134,6 +134,54 @@ namespace Flux {
|
||||
}
|
||||
};
|
||||
|
||||
struct MLP : public UnaryBlock {
|
||||
bool use_mlp_silu_act;
|
||||
|
||||
public:
|
||||
MLP(int64_t hidden_size, int64_t intermediate_size, bool use_mlp_silu_act = false, bool bias = false)
|
||||
: use_mlp_silu_act(use_mlp_silu_act) {
|
||||
int64_t mlp_mult_factor = use_mlp_silu_act ? 2 : 1;
|
||||
blocks["0"] = std::make_shared<Linear>(hidden_size, intermediate_size * mlp_mult_factor, bias);
|
||||
blocks["2"] = std::make_shared<Linear>(intermediate_size, hidden_size, bias);
|
||||
}
|
||||
|
||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||
auto mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["0"]);
|
||||
auto mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["2"]);
|
||||
|
||||
x = mlp_0->forward(ctx, x);
|
||||
if (use_mlp_silu_act) {
|
||||
x = ggml_ext_silu_act(ctx->ggml_ctx, x);
|
||||
} else {
|
||||
x = ggml_gelu_inplace(ctx->ggml_ctx, x);
|
||||
}
|
||||
x = mlp_2->forward(ctx, x);
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
struct YakMLP : public UnaryBlock {
|
||||
public:
|
||||
YakMLP(int64_t hidden_size, int64_t intermediate_size, bool bias = true) {
|
||||
blocks["gate_proj"] = std::make_shared<Linear>(hidden_size, intermediate_size, bias);
|
||||
blocks["up_proj"] = std::make_shared<Linear>(hidden_size, intermediate_size, bias);
|
||||
blocks["down_proj"] = std::make_shared<Linear>(intermediate_size, hidden_size, bias);
|
||||
}
|
||||
|
||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||
auto gate_proj = std::dynamic_pointer_cast<Linear>(blocks["gate_proj"]);
|
||||
auto up_proj = std::dynamic_pointer_cast<Linear>(blocks["up_proj"]);
|
||||
auto down_proj = std::dynamic_pointer_cast<Linear>(blocks["down_proj"]);
|
||||
|
||||
auto gate = gate_proj->forward(ctx, x);
|
||||
gate = ggml_silu_inplace(ctx->ggml_ctx, gate);
|
||||
x = up_proj->forward(ctx, x);
|
||||
x = ggml_mul(ctx->ggml_ctx, x, gate);
|
||||
x = down_proj->forward(ctx, x);
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
struct ModulationOut {
|
||||
ggml_tensor* shift = nullptr;
|
||||
ggml_tensor* scale = nullptr;
|
||||
@ -199,7 +247,6 @@ namespace Flux {
|
||||
struct DoubleStreamBlock : public GGMLBlock {
|
||||
bool prune_mod;
|
||||
int idx = 0;
|
||||
bool use_mlp_silu_act;
|
||||
|
||||
public:
|
||||
DoubleStreamBlock(int64_t hidden_size,
|
||||
@ -210,10 +257,10 @@ namespace Flux {
|
||||
bool prune_mod = false,
|
||||
bool share_modulation = false,
|
||||
bool mlp_proj_bias = true,
|
||||
bool use_yak_mlp = false,
|
||||
bool use_mlp_silu_act = false)
|
||||
: idx(idx), prune_mod(prune_mod), use_mlp_silu_act(use_mlp_silu_act) {
|
||||
int64_t mlp_hidden_dim = hidden_size * mlp_ratio;
|
||||
int64_t mlp_mult_factor = use_mlp_silu_act ? 2 : 1;
|
||||
: idx(idx), prune_mod(prune_mod) {
|
||||
int64_t mlp_hidden_dim = hidden_size * mlp_ratio;
|
||||
|
||||
if (!prune_mod && !share_modulation) {
|
||||
blocks["img_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
|
||||
@ -222,9 +269,11 @@ namespace Flux {
|
||||
blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias));
|
||||
|
||||
blocks["img_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
|
||||
blocks["img_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias));
|
||||
// img_mlp.1 is nn.GELU(approximate="tanh")
|
||||
blocks["img_mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(mlp_hidden_dim, hidden_size, mlp_proj_bias));
|
||||
if (use_yak_mlp) {
|
||||
blocks["img_mlp"] = std::shared_ptr<GGMLBlock>(new YakMLP(hidden_size, mlp_hidden_dim, mlp_proj_bias));
|
||||
} else {
|
||||
blocks["img_mlp"] = std::shared_ptr<GGMLBlock>(new MLP(hidden_size, mlp_hidden_dim, use_mlp_silu_act, mlp_proj_bias));
|
||||
}
|
||||
|
||||
if (!prune_mod && !share_modulation) {
|
||||
blocks["txt_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
|
||||
@ -233,9 +282,11 @@ namespace Flux {
|
||||
blocks["txt_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias));
|
||||
|
||||
blocks["txt_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
|
||||
blocks["txt_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias));
|
||||
// img_mlp.1 is nn.GELU(approximate="tanh")
|
||||
blocks["txt_mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(mlp_hidden_dim, hidden_size, mlp_proj_bias));
|
||||
if (use_yak_mlp) {
|
||||
blocks["txt_mlp"] = std::shared_ptr<GGMLBlock>(new YakMLP(hidden_size, mlp_hidden_dim, mlp_proj_bias));
|
||||
} else {
|
||||
blocks["txt_mlp"] = std::shared_ptr<GGMLBlock>(new MLP(hidden_size, mlp_hidden_dim, use_mlp_silu_act, mlp_proj_bias));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ModulationOut> get_distil_img_mod(GGMLRunnerContext* ctx, struct ggml_tensor* vec) {
|
||||
@ -272,15 +323,13 @@ namespace Flux {
|
||||
auto img_attn = std::dynamic_pointer_cast<SelfAttention>(blocks["img_attn"]);
|
||||
|
||||
auto img_norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["img_norm2"]);
|
||||
auto img_mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["img_mlp.0"]);
|
||||
auto img_mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["img_mlp.2"]);
|
||||
auto img_mlp = std::dynamic_pointer_cast<UnaryBlock>(blocks["img_mlp"]);
|
||||
|
||||
auto txt_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["txt_norm1"]);
|
||||
auto txt_attn = std::dynamic_pointer_cast<SelfAttention>(blocks["txt_attn"]);
|
||||
|
||||
auto txt_norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["txt_norm2"]);
|
||||
auto txt_mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["txt_mlp.0"]);
|
||||
auto txt_mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["txt_mlp.2"]);
|
||||
auto txt_mlp = std::dynamic_pointer_cast<UnaryBlock>(blocks["txt_mlp"]);
|
||||
|
||||
if (img_mods.empty()) {
|
||||
if (prune_mod) {
|
||||
@ -348,27 +397,15 @@ namespace Flux {
|
||||
// calculate the img bloks
|
||||
img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate));
|
||||
|
||||
auto img_mlp_out = img_mlp_0->forward(ctx, Flux::modulate(ctx->ggml_ctx, img_norm2->forward(ctx, img), img_mod2.shift, img_mod2.scale));
|
||||
if (use_mlp_silu_act) {
|
||||
img_mlp_out = ggml_ext_silu_act(ctx->ggml_ctx, img_mlp_out);
|
||||
} else {
|
||||
img_mlp_out = ggml_gelu_inplace(ctx->ggml_ctx, img_mlp_out);
|
||||
}
|
||||
img_mlp_out = img_mlp_2->forward(ctx, img_mlp_out);
|
||||
auto img_mlp_out = img_mlp->forward(ctx, Flux::modulate(ctx->ggml_ctx, img_norm2->forward(ctx, img), img_mod2.shift, img_mod2.scale));
|
||||
|
||||
img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_mlp_out, img_mod2.gate));
|
||||
|
||||
// calculate the txt bloks
|
||||
txt = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_attn->post_attention(ctx, txt_attn_out), txt_mod1.gate));
|
||||
|
||||
auto txt_mlp_out = txt_mlp_0->forward(ctx, Flux::modulate(ctx->ggml_ctx, txt_norm2->forward(ctx, txt), txt_mod2.shift, txt_mod2.scale));
|
||||
if (use_mlp_silu_act) {
|
||||
txt_mlp_out = ggml_ext_silu_act(ctx->ggml_ctx, txt_mlp_out);
|
||||
} else {
|
||||
txt_mlp_out = ggml_gelu_inplace(ctx->ggml_ctx, txt_mlp_out);
|
||||
}
|
||||
txt_mlp_out = txt_mlp_2->forward(ctx, txt_mlp_out);
|
||||
txt = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_mlp_out, txt_mod2.gate));
|
||||
auto txt_mlp_out = txt_mlp->forward(ctx, Flux::modulate(ctx->ggml_ctx, txt_norm2->forward(ctx, txt), txt_mod2.shift, txt_mod2.scale));
|
||||
txt = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_mlp_out, txt_mod2.gate));
|
||||
|
||||
return {img, txt};
|
||||
}
|
||||
@ -381,6 +418,7 @@ namespace Flux {
|
||||
int64_t mlp_hidden_dim;
|
||||
bool prune_mod;
|
||||
int idx = 0;
|
||||
bool use_yak_mlp;
|
||||
bool use_mlp_silu_act;
|
||||
int64_t mlp_mult_factor;
|
||||
|
||||
@ -393,8 +431,9 @@ namespace Flux {
|
||||
bool prune_mod = false,
|
||||
bool share_modulation = false,
|
||||
bool mlp_proj_bias = true,
|
||||
bool use_yak_mlp = false,
|
||||
bool use_mlp_silu_act = false)
|
||||
: hidden_size(hidden_size), num_heads(num_heads), idx(idx), prune_mod(prune_mod), use_mlp_silu_act(use_mlp_silu_act) {
|
||||
: hidden_size(hidden_size), num_heads(num_heads), idx(idx), prune_mod(prune_mod), use_yak_mlp(use_yak_mlp), use_mlp_silu_act(use_mlp_silu_act) {
|
||||
int64_t head_dim = hidden_size / num_heads;
|
||||
float scale = qk_scale;
|
||||
if (scale <= 0.f) {
|
||||
@ -402,7 +441,7 @@ namespace Flux {
|
||||
}
|
||||
mlp_hidden_dim = hidden_size * mlp_ratio;
|
||||
mlp_mult_factor = 1;
|
||||
if (use_mlp_silu_act) {
|
||||
if (use_yak_mlp || use_mlp_silu_act) {
|
||||
mlp_mult_factor = 2;
|
||||
}
|
||||
|
||||
@ -481,7 +520,9 @@ namespace Flux {
|
||||
k = norm->key_norm(ctx, k);
|
||||
auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_token, hidden_size]
|
||||
|
||||
if (use_mlp_silu_act) {
|
||||
if (use_yak_mlp) {
|
||||
mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp, false);
|
||||
} else if (use_mlp_silu_act) {
|
||||
mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp);
|
||||
} else {
|
||||
mlp = ggml_gelu_inplace(ctx->ggml_ctx, mlp);
|
||||
@ -726,6 +767,8 @@ namespace Flux {
|
||||
int64_t in_dim = 64;
|
||||
bool disable_bias = false;
|
||||
bool share_modulation = false;
|
||||
bool semantic_txt_norm = false;
|
||||
bool use_yak_mlp = false;
|
||||
bool use_mlp_silu_act = false;
|
||||
float ref_index_scale = 1.f;
|
||||
ChromaRadianceParams chroma_radiance_params;
|
||||
@ -759,6 +802,9 @@ namespace Flux {
|
||||
blocks["guidance_in"] = std::make_shared<MLPEmbedder>(256, params.hidden_size, !params.disable_bias);
|
||||
}
|
||||
}
|
||||
if (params.semantic_txt_norm) {
|
||||
blocks["txt_norm"] = std::make_shared<RMSNorm>(params.context_in_dim);
|
||||
}
|
||||
blocks["txt_in"] = std::make_shared<Linear>(params.context_in_dim, params.hidden_size, !params.disable_bias);
|
||||
|
||||
for (int i = 0; i < params.depth; i++) {
|
||||
@ -770,6 +816,7 @@ namespace Flux {
|
||||
params.is_chroma,
|
||||
params.share_modulation,
|
||||
!params.disable_bias,
|
||||
params.use_yak_mlp,
|
||||
params.use_mlp_silu_act);
|
||||
}
|
||||
|
||||
@ -782,6 +829,7 @@ namespace Flux {
|
||||
params.is_chroma,
|
||||
params.share_modulation,
|
||||
!params.disable_bias,
|
||||
params.use_yak_mlp,
|
||||
params.use_mlp_silu_act);
|
||||
}
|
||||
|
||||
@ -948,6 +996,12 @@ namespace Flux {
|
||||
ss_mods = single_stream_modulation->forward(ctx, vec);
|
||||
}
|
||||
|
||||
if (params.semantic_txt_norm) {
|
||||
auto semantic_txt_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["txt_norm"]);
|
||||
|
||||
txt = semantic_txt_norm->forward(ctx, txt);
|
||||
}
|
||||
|
||||
txt = txt_in->forward(ctx, txt);
|
||||
|
||||
for (int i = 0; i < params.depth; i++) {
|
||||
@ -1206,6 +1260,11 @@ namespace Flux {
|
||||
} else if (version == VERSION_CHROMA_RADIANCE) {
|
||||
flux_params.in_channels = 3;
|
||||
flux_params.patch_size = 16;
|
||||
} else if (version == VERSION_OVIS_IMAGE) {
|
||||
flux_params.semantic_txt_norm = true;
|
||||
flux_params.use_yak_mlp = true;
|
||||
flux_params.context_in_dim = 2048;
|
||||
flux_params.vec_in_dim = 0;
|
||||
} else if (sd_version_is_flux2(version)) {
|
||||
flux_params.context_in_dim = 15360;
|
||||
flux_params.in_channels = 128;
|
||||
@ -1364,13 +1423,22 @@ namespace Flux {
|
||||
ref_latents[i] = to_backend(ref_latents[i]);
|
||||
}
|
||||
|
||||
std::set<int> txt_arange_dims;
|
||||
if (sd_version_is_flux2(version)) {
|
||||
txt_arange_dims = {3};
|
||||
increase_ref_index = true;
|
||||
} else if (version == VERSION_OVIS_IMAGE) {
|
||||
txt_arange_dims = {1, 2};
|
||||
}
|
||||
|
||||
pe_vec = Rope::gen_flux_pe(x->ne[1],
|
||||
x->ne[0],
|
||||
flux_params.patch_size,
|
||||
x->ne[3],
|
||||
context->ne[1],
|
||||
txt_arange_dims,
|
||||
ref_latents,
|
||||
sd_version_is_flux2(version) ? true : increase_ref_index,
|
||||
increase_ref_index,
|
||||
flux_params.ref_index_scale,
|
||||
flux_params.theta,
|
||||
flux_params.axes_dim);
|
||||
|
||||
@ -760,17 +760,23 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_ext_chunk(struct ggml_co
|
||||
return chunks;
|
||||
}
|
||||
|
||||
__STATIC_INLINE__ ggml_tensor* ggml_ext_silu_act(ggml_context* ctx, ggml_tensor* x) {
|
||||
__STATIC_INLINE__ ggml_tensor* ggml_ext_silu_act(ggml_context* ctx, ggml_tensor* x, bool gate_first = true) {
|
||||
// x: [ne3, ne2, ne1, ne0]
|
||||
// return: [ne3, ne2, ne1, ne0/2]
|
||||
|
||||
auto x_vec = ggml_ext_chunk(ctx, x, 2, 0);
|
||||
auto x1 = x_vec[0]; // [ne3, ne2, ne1, ne0/2]
|
||||
auto x2 = x_vec[1]; // [ne3, ne2, ne1, ne0/2]
|
||||
ggml_tensor* gate;
|
||||
if (gate_first) {
|
||||
gate = x_vec[0];
|
||||
x = x_vec[1];
|
||||
} else {
|
||||
x = x_vec[0];
|
||||
gate = x_vec[1];
|
||||
}
|
||||
|
||||
x1 = ggml_silu_inplace(ctx, x1);
|
||||
gate = ggml_silu_inplace(ctx, gate);
|
||||
|
||||
x = ggml_mul(ctx, x1, x2); // [ne3, ne2, ne1, ne0/2]
|
||||
x = ggml_mul(ctx, x, gate); // [ne3, ne2, ne1, ne0/2]
|
||||
|
||||
return x;
|
||||
}
|
||||
|
||||
71
llm.hpp
71
llm.hpp
@ -356,6 +356,10 @@ namespace LLM {
|
||||
"<|fim_pad|>",
|
||||
"<|repo_name|>",
|
||||
"<|file_sep|>",
|
||||
"<tool_response>",
|
||||
"</tool_response>",
|
||||
"<think>",
|
||||
"</think>",
|
||||
};
|
||||
|
||||
if (merges_utf8_str.size() > 0) {
|
||||
@ -859,11 +863,11 @@ namespace LLM {
|
||||
}
|
||||
|
||||
if (arch == LLMArch::MISTRAL_SMALL_3_2) {
|
||||
q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 131072, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
|
||||
k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 131072, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
|
||||
q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 8192, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
|
||||
k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 8192, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
|
||||
} else if (arch == LLMArch::QWEN3) {
|
||||
q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 151936, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
|
||||
k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 151936, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
|
||||
q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
|
||||
k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
|
||||
} else {
|
||||
int sections[4] = {16, 24, 24, 0};
|
||||
q = ggml_rope_multi(ctx->ggml_ctx, q, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
|
||||
@ -1073,29 +1077,22 @@ namespace LLM {
|
||||
: GGMLRunner(backend, offload_params_to_cpu), enable_vision(enable_vision_) {
|
||||
params.arch = arch;
|
||||
if (arch == LLMArch::MISTRAL_SMALL_3_2) {
|
||||
params.num_layers = 40;
|
||||
params.hidden_size = 5120;
|
||||
params.intermediate_size = 32768;
|
||||
params.head_dim = 128;
|
||||
params.num_heads = 32;
|
||||
params.num_kv_heads = 8;
|
||||
params.qkv_bias = false;
|
||||
params.vocab_size = 131072;
|
||||
params.rms_norm_eps = 1e-5f;
|
||||
params.head_dim = 128;
|
||||
params.num_heads = 32;
|
||||
params.num_kv_heads = 8;
|
||||
params.qkv_bias = false;
|
||||
params.rms_norm_eps = 1e-5f;
|
||||
} else if (arch == LLMArch::QWEN3) {
|
||||
params.num_layers = 36;
|
||||
params.hidden_size = 2560;
|
||||
params.intermediate_size = 9728;
|
||||
params.head_dim = 128;
|
||||
params.num_heads = 32;
|
||||
params.num_kv_heads = 8;
|
||||
params.qkv_bias = false;
|
||||
params.qk_norm = true;
|
||||
params.vocab_size = 151936;
|
||||
params.rms_norm_eps = 1e-6f;
|
||||
params.head_dim = 128;
|
||||
params.num_heads = 32;
|
||||
params.num_kv_heads = 8;
|
||||
params.qkv_bias = false;
|
||||
params.qk_norm = true;
|
||||
params.rms_norm_eps = 1e-6f;
|
||||
}
|
||||
bool have_vision_weight = false;
|
||||
bool llama_cpp_style = false;
|
||||
params.num_layers = 0;
|
||||
for (auto pair : tensor_storage_map) {
|
||||
std::string tensor_name = pair.first;
|
||||
if (tensor_name.find(prefix) == std::string::npos)
|
||||
@ -1105,10 +1102,36 @@ namespace LLM {
|
||||
have_vision_weight = true;
|
||||
if (contains(tensor_name, "attn.q_proj")) {
|
||||
llama_cpp_style = true;
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
pos = tensor_name.find("layers.");
|
||||
if (pos != std::string::npos) {
|
||||
tensor_name = tensor_name.substr(pos); // remove prefix
|
||||
auto items = split_string(tensor_name, '.');
|
||||
if (items.size() > 1) {
|
||||
int block_index = atoi(items[1].c_str());
|
||||
if (block_index + 1 > params.num_layers) {
|
||||
params.num_layers = block_index + 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (contains(tensor_name, "embed_tokens.weight")) {
|
||||
params.hidden_size = pair.second.ne[0];
|
||||
params.vocab_size = pair.second.ne[1];
|
||||
}
|
||||
if (contains(tensor_name, "layers.0.mlp.gate_proj.weight")) {
|
||||
params.intermediate_size = pair.second.ne[1];
|
||||
}
|
||||
}
|
||||
if (arch == LLMArch::QWEN3 && params.num_layers == 28) { // Qwen3 2B
|
||||
params.num_heads = 16;
|
||||
}
|
||||
LOG_DEBUG("llm: num_layers = %" PRId64 ", vocab_size = %" PRId64 ", hidden_size = %" PRId64 ", intermediate_size = %" PRId64,
|
||||
params.num_layers,
|
||||
params.vocab_size,
|
||||
params.hidden_size,
|
||||
params.intermediate_size);
|
||||
if (enable_vision && !have_vision_weight) {
|
||||
LOG_WARN("no vision weights detected, vision disabled");
|
||||
enable_vision = false;
|
||||
|
||||
@ -1056,6 +1056,9 @@ SDVersion ModelLoader::get_sd_version() {
|
||||
if (tensor_storage.name.find("model.diffusion_model.double_stream_modulation_img.lin.weight") != std::string::npos) {
|
||||
return VERSION_FLUX2;
|
||||
}
|
||||
if (tensor_storage.name.find("model.diffusion_model.double_blocks.0.img_mlp.gate_proj.weight") != std::string::npos) {
|
||||
return VERSION_OVIS_IMAGE;
|
||||
}
|
||||
if (tensor_storage.name.find("model.diffusion_model.cap_embedder.0.weight") != std::string::npos) {
|
||||
return VERSION_Z_IMAGE;
|
||||
}
|
||||
|
||||
2
model.h
2
model.h
@ -45,6 +45,7 @@ enum SDVersion {
|
||||
VERSION_QWEN_IMAGE,
|
||||
VERSION_FLUX2,
|
||||
VERSION_Z_IMAGE,
|
||||
VERSION_OVIS_IMAGE,
|
||||
VERSION_COUNT,
|
||||
};
|
||||
|
||||
@ -90,6 +91,7 @@ static inline bool sd_version_is_flux(SDVersion version) {
|
||||
version == VERSION_FLUX_FILL ||
|
||||
version == VERSION_FLUX_CONTROLS ||
|
||||
version == VERSION_FLEX_2 ||
|
||||
version == VERSION_OVIS_IMAGE ||
|
||||
version == VERSION_CHROMA_RADIANCE) {
|
||||
return true;
|
||||
}
|
||||
|
||||
15
rope.hpp
15
rope.hpp
@ -72,11 +72,13 @@ namespace Rope {
|
||||
}
|
||||
|
||||
// Generate IDs for image patches and text
|
||||
__STATIC_INLINE__ std::vector<std::vector<float>> gen_flux_txt_ids(int bs, int context_len, int axes_dim_num) {
|
||||
__STATIC_INLINE__ std::vector<std::vector<float>> gen_flux_txt_ids(int bs, int context_len, int axes_dim_num, std::set<int> arange_dims) {
|
||||
auto txt_ids = std::vector<std::vector<float>>(bs * context_len, std::vector<float>(axes_dim_num, 0.0f));
|
||||
if (axes_dim_num == 4) {
|
||||
for (int i = 0; i < bs * context_len; i++) {
|
||||
txt_ids[i][3] = (i % context_len);
|
||||
for (int dim = 0; dim < axes_dim_num; dim++) {
|
||||
if (arange_dims.find(dim) != arange_dims.end()) {
|
||||
for (int i = 0; i < bs * context_len; i++) {
|
||||
txt_ids[i][dim] = (i % context_len);
|
||||
}
|
||||
}
|
||||
}
|
||||
return txt_ids;
|
||||
@ -211,10 +213,11 @@ namespace Rope {
|
||||
int bs,
|
||||
int axes_dim_num,
|
||||
int context_len,
|
||||
std::set<int> txt_arange_dims,
|
||||
const std::vector<ggml_tensor*>& ref_latents,
|
||||
bool increase_ref_index,
|
||||
float ref_index_scale) {
|
||||
auto txt_ids = gen_flux_txt_ids(bs, context_len, axes_dim_num);
|
||||
auto txt_ids = gen_flux_txt_ids(bs, context_len, axes_dim_num, txt_arange_dims);
|
||||
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num);
|
||||
|
||||
auto ids = concat_ids(txt_ids, img_ids, bs);
|
||||
@ -231,6 +234,7 @@ namespace Rope {
|
||||
int patch_size,
|
||||
int bs,
|
||||
int context_len,
|
||||
std::set<int> txt_arange_dims,
|
||||
const std::vector<ggml_tensor*>& ref_latents,
|
||||
bool increase_ref_index,
|
||||
float ref_index_scale,
|
||||
@ -242,6 +246,7 @@ namespace Rope {
|
||||
bs,
|
||||
static_cast<int>(axes_dim.size()),
|
||||
context_len,
|
||||
txt_arange_dims,
|
||||
ref_latents,
|
||||
increase_ref_index,
|
||||
ref_index_scale);
|
||||
|
||||
@ -46,6 +46,7 @@ const char* model_version_to_str[] = {
|
||||
"Qwen Image",
|
||||
"Flux.2",
|
||||
"Z-Image",
|
||||
"Ovis Image",
|
||||
};
|
||||
|
||||
const char* sampling_methods_str[] = {
|
||||
@ -424,6 +425,13 @@ public:
|
||||
tensor_storage_map,
|
||||
sd_ctx_params->chroma_use_t5_mask,
|
||||
sd_ctx_params->chroma_t5_mask_pad);
|
||||
} else if (version == VERSION_OVIS_IMAGE) {
|
||||
cond_stage_model = std::make_shared<LLMEmbedder>(clip_backend,
|
||||
offload_params_to_cpu,
|
||||
tensor_storage_map,
|
||||
version,
|
||||
"",
|
||||
false);
|
||||
} else {
|
||||
cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend,
|
||||
offload_params_to_cpu,
|
||||
@ -690,6 +698,11 @@ public:
|
||||
ignore_tensors.insert("first_stage_model.quant");
|
||||
ignore_tensors.insert("text_encoders.llm.visual.");
|
||||
}
|
||||
if (version == VERSION_OVIS_IMAGE) {
|
||||
ignore_tensors.insert("text_encoders.llm.vision_model.");
|
||||
ignore_tensors.insert("text_encoders.llm.visual_tokenizer.");
|
||||
ignore_tensors.insert("text_encoders.llm.vte.");
|
||||
}
|
||||
if (version == VERSION_SVD) {
|
||||
ignore_tensors.insert("conditioner.embedders.3");
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user