feat: add ovis image support (#1057)

This commit is contained in:
leejet 2025-12-07 12:32:56 +08:00 committed by GitHub
parent bfbb929790
commit 2f0bd31a84
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 230 additions and 71 deletions

View File

@ -49,6 +49,7 @@ API and command-line option may change frequently.***
- [Chroma1-Radiance](./docs/chroma_radiance.md)
- [Qwen Image](./docs/qwen_image.md)
- [Z-Image](./docs/z_image.md)
- [Ovis-Image](./docs/ovis_image.md)
- Image Edit Models
- [FLUX.1-Kontext-dev](./docs/kontext.md)
- [Qwen Image Edit/Qwen Image Edit 2509](./docs/qwen_image_edit.md)
@ -134,6 +135,7 @@ If you want to improve performance or reduce VRAM/RAM usage, please refer to [pe
- [🔥Qwen Image Edit/Qwen Image Edit 2509](./docs/qwen_image_edit.md)
- [🔥Wan2.1/Wan2.2](./docs/wan.md)
- [🔥Z-Image](./docs/z_image.md)
- [Ovis-Image](./docs/ovis_image.md)
- [LoRA](./docs/lora.md)
- [LCM/LCM-LoRA](./docs/lcm.md)
- [Using PhotoMaker to personalize image generation](./docs/photo_maker.md)

Binary file not shown.

After

Width:  |  Height:  |  Size: 401 KiB

View File

@ -1638,7 +1638,7 @@ struct LLMEmbedder : public Conditioner {
LLM::LLMArch arch = LLM::LLMArch::QWEN2_5_VL;
if (sd_version_is_flux2(version)) {
arch = LLM::LLMArch::MISTRAL_SMALL_3_2;
} else if (sd_version_is_z_image(version)) {
} else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE) {
arch = LLM::LLMArch::QWEN3;
}
if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) {
@ -1728,6 +1728,7 @@ struct LLMEmbedder : public Conditioner {
std::vector<std::pair<int, ggml_tensor*>> image_embeds;
std::pair<int, int> prompt_attn_range;
int prompt_template_encode_start_idx = 34;
int max_length = 0;
std::set<int> out_layers;
if (llm->enable_vision && conditioner_params.ref_images.size() > 0) {
LOG_INFO("QwenImageEditPlusPipeline");
@ -1825,6 +1826,17 @@ struct LLMEmbedder : public Conditioner {
prompt_attn_range.second = prompt.size();
prompt += "[/INST]";
} else if (version == VERSION_OVIS_IMAGE) {
prompt_template_encode_start_idx = 28;
max_length = prompt_template_encode_start_idx + 256;
prompt = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background:";
prompt_attn_range.first = static_cast<int>(prompt.size());
prompt += " " + conditioner_params.text;
prompt_attn_range.second = static_cast<int>(prompt.size());
prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
} else {
prompt_template_encode_start_idx = 34;
@ -1837,7 +1849,7 @@ struct LLMEmbedder : public Conditioner {
prompt += "<|im_end|>\n<|im_start|>assistant\n";
}
auto tokens_and_weights = tokenize(prompt, prompt_attn_range, 0, false);
auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0);
auto& tokens = std::get<0>(tokens_and_weights);
auto& weights = std::get<1>(tokens_and_weights);
@ -1870,9 +1882,13 @@ struct LLMEmbedder : public Conditioner {
GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx);
int64_t zero_pad_len = 0;
int64_t min_length = 0;
if (sd_version_is_flux2(version)) {
int64_t min_length = 512;
min_length = 512;
}
int64_t zero_pad_len = 0;
if (min_length > 0) {
if (hidden_states->ne[1] - prompt_template_encode_start_idx < min_length) {
zero_pad_len = min_length - hidden_states->ne[1] + prompt_template_encode_start_idx;
}
@ -1892,6 +1908,8 @@ struct LLMEmbedder : public Conditioner {
ggml_ext_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3);
});
// print_ggml_tensor(new_hidden_states);
int64_t t1 = ggml_time_ms();
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
return {new_hidden_states, nullptr, nullptr};

19
docs/ovis_image.md Normal file
View File

@ -0,0 +1,19 @@
# How to Use
## Download weights
- Download Ovis-Image-7B
- safetensors: https://huggingface.co/Comfy-Org/Ovis-Image/tree/main/split_files/diffusion_models
- gguf: https://huggingface.co/leejet/Ovis-Image-7B-GGUF
- Download vae
- safetensors: https://huggingface.co/black-forest-labs/FLUX.1-schnell/tree/main
- Download Ovis 2.5
- safetensors: https://huggingface.co/Comfy-Org/Ovis-Image/tree/main/split_files/text_encoders
## Examples
```
.\bin\Release\sd.exe --diffusion-model ovis_image-Q4_0.gguf --vae ..\..\ComfyUI\models\vae\ae.sft --llm ..\..\ComfyUI\models\text_encoders\ovis_2.5.safetensors -p "a lovely cat" --cfg-scale 5.0 -v --offload-to-cpu --diffusion-fa
```
<img alt="ovis image example" src="../assets/ovis_image/example.png" />

130
flux.hpp
View File

@ -134,6 +134,54 @@ namespace Flux {
}
};
struct MLP : public UnaryBlock {
bool use_mlp_silu_act;
public:
MLP(int64_t hidden_size, int64_t intermediate_size, bool use_mlp_silu_act = false, bool bias = false)
: use_mlp_silu_act(use_mlp_silu_act) {
int64_t mlp_mult_factor = use_mlp_silu_act ? 2 : 1;
blocks["0"] = std::make_shared<Linear>(hidden_size, intermediate_size * mlp_mult_factor, bias);
blocks["2"] = std::make_shared<Linear>(intermediate_size, hidden_size, bias);
}
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
auto mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["0"]);
auto mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["2"]);
x = mlp_0->forward(ctx, x);
if (use_mlp_silu_act) {
x = ggml_ext_silu_act(ctx->ggml_ctx, x);
} else {
x = ggml_gelu_inplace(ctx->ggml_ctx, x);
}
x = mlp_2->forward(ctx, x);
return x;
}
};
struct YakMLP : public UnaryBlock {
public:
YakMLP(int64_t hidden_size, int64_t intermediate_size, bool bias = true) {
blocks["gate_proj"] = std::make_shared<Linear>(hidden_size, intermediate_size, bias);
blocks["up_proj"] = std::make_shared<Linear>(hidden_size, intermediate_size, bias);
blocks["down_proj"] = std::make_shared<Linear>(intermediate_size, hidden_size, bias);
}
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
auto gate_proj = std::dynamic_pointer_cast<Linear>(blocks["gate_proj"]);
auto up_proj = std::dynamic_pointer_cast<Linear>(blocks["up_proj"]);
auto down_proj = std::dynamic_pointer_cast<Linear>(blocks["down_proj"]);
auto gate = gate_proj->forward(ctx, x);
gate = ggml_silu_inplace(ctx->ggml_ctx, gate);
x = up_proj->forward(ctx, x);
x = ggml_mul(ctx->ggml_ctx, x, gate);
x = down_proj->forward(ctx, x);
return x;
}
};
struct ModulationOut {
ggml_tensor* shift = nullptr;
ggml_tensor* scale = nullptr;
@ -199,7 +247,6 @@ namespace Flux {
struct DoubleStreamBlock : public GGMLBlock {
bool prune_mod;
int idx = 0;
bool use_mlp_silu_act;
public:
DoubleStreamBlock(int64_t hidden_size,
@ -210,10 +257,10 @@ namespace Flux {
bool prune_mod = false,
bool share_modulation = false,
bool mlp_proj_bias = true,
bool use_yak_mlp = false,
bool use_mlp_silu_act = false)
: idx(idx), prune_mod(prune_mod), use_mlp_silu_act(use_mlp_silu_act) {
: idx(idx), prune_mod(prune_mod) {
int64_t mlp_hidden_dim = hidden_size * mlp_ratio;
int64_t mlp_mult_factor = use_mlp_silu_act ? 2 : 1;
if (!prune_mod && !share_modulation) {
blocks["img_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
@ -222,9 +269,11 @@ namespace Flux {
blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias));
blocks["img_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["img_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias));
// img_mlp.1 is nn.GELU(approximate="tanh")
blocks["img_mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(mlp_hidden_dim, hidden_size, mlp_proj_bias));
if (use_yak_mlp) {
blocks["img_mlp"] = std::shared_ptr<GGMLBlock>(new YakMLP(hidden_size, mlp_hidden_dim, mlp_proj_bias));
} else {
blocks["img_mlp"] = std::shared_ptr<GGMLBlock>(new MLP(hidden_size, mlp_hidden_dim, use_mlp_silu_act, mlp_proj_bias));
}
if (!prune_mod && !share_modulation) {
blocks["txt_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
@ -233,9 +282,11 @@ namespace Flux {
blocks["txt_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias));
blocks["txt_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["txt_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias));
// img_mlp.1 is nn.GELU(approximate="tanh")
blocks["txt_mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(mlp_hidden_dim, hidden_size, mlp_proj_bias));
if (use_yak_mlp) {
blocks["txt_mlp"] = std::shared_ptr<GGMLBlock>(new YakMLP(hidden_size, mlp_hidden_dim, mlp_proj_bias));
} else {
blocks["txt_mlp"] = std::shared_ptr<GGMLBlock>(new MLP(hidden_size, mlp_hidden_dim, use_mlp_silu_act, mlp_proj_bias));
}
}
std::vector<ModulationOut> get_distil_img_mod(GGMLRunnerContext* ctx, struct ggml_tensor* vec) {
@ -272,15 +323,13 @@ namespace Flux {
auto img_attn = std::dynamic_pointer_cast<SelfAttention>(blocks["img_attn"]);
auto img_norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["img_norm2"]);
auto img_mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["img_mlp.0"]);
auto img_mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["img_mlp.2"]);
auto img_mlp = std::dynamic_pointer_cast<UnaryBlock>(blocks["img_mlp"]);
auto txt_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["txt_norm1"]);
auto txt_attn = std::dynamic_pointer_cast<SelfAttention>(blocks["txt_attn"]);
auto txt_norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["txt_norm2"]);
auto txt_mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["txt_mlp.0"]);
auto txt_mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["txt_mlp.2"]);
auto txt_mlp = std::dynamic_pointer_cast<UnaryBlock>(blocks["txt_mlp"]);
if (img_mods.empty()) {
if (prune_mod) {
@ -348,26 +397,14 @@ namespace Flux {
// calculate the img bloks
img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate));
auto img_mlp_out = img_mlp_0->forward(ctx, Flux::modulate(ctx->ggml_ctx, img_norm2->forward(ctx, img), img_mod2.shift, img_mod2.scale));
if (use_mlp_silu_act) {
img_mlp_out = ggml_ext_silu_act(ctx->ggml_ctx, img_mlp_out);
} else {
img_mlp_out = ggml_gelu_inplace(ctx->ggml_ctx, img_mlp_out);
}
img_mlp_out = img_mlp_2->forward(ctx, img_mlp_out);
auto img_mlp_out = img_mlp->forward(ctx, Flux::modulate(ctx->ggml_ctx, img_norm2->forward(ctx, img), img_mod2.shift, img_mod2.scale));
img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_mlp_out, img_mod2.gate));
// calculate the txt bloks
txt = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_attn->post_attention(ctx, txt_attn_out), txt_mod1.gate));
auto txt_mlp_out = txt_mlp_0->forward(ctx, Flux::modulate(ctx->ggml_ctx, txt_norm2->forward(ctx, txt), txt_mod2.shift, txt_mod2.scale));
if (use_mlp_silu_act) {
txt_mlp_out = ggml_ext_silu_act(ctx->ggml_ctx, txt_mlp_out);
} else {
txt_mlp_out = ggml_gelu_inplace(ctx->ggml_ctx, txt_mlp_out);
}
txt_mlp_out = txt_mlp_2->forward(ctx, txt_mlp_out);
auto txt_mlp_out = txt_mlp->forward(ctx, Flux::modulate(ctx->ggml_ctx, txt_norm2->forward(ctx, txt), txt_mod2.shift, txt_mod2.scale));
txt = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_mlp_out, txt_mod2.gate));
return {img, txt};
@ -381,6 +418,7 @@ namespace Flux {
int64_t mlp_hidden_dim;
bool prune_mod;
int idx = 0;
bool use_yak_mlp;
bool use_mlp_silu_act;
int64_t mlp_mult_factor;
@ -393,8 +431,9 @@ namespace Flux {
bool prune_mod = false,
bool share_modulation = false,
bool mlp_proj_bias = true,
bool use_yak_mlp = false,
bool use_mlp_silu_act = false)
: hidden_size(hidden_size), num_heads(num_heads), idx(idx), prune_mod(prune_mod), use_mlp_silu_act(use_mlp_silu_act) {
: hidden_size(hidden_size), num_heads(num_heads), idx(idx), prune_mod(prune_mod), use_yak_mlp(use_yak_mlp), use_mlp_silu_act(use_mlp_silu_act) {
int64_t head_dim = hidden_size / num_heads;
float scale = qk_scale;
if (scale <= 0.f) {
@ -402,7 +441,7 @@ namespace Flux {
}
mlp_hidden_dim = hidden_size * mlp_ratio;
mlp_mult_factor = 1;
if (use_mlp_silu_act) {
if (use_yak_mlp || use_mlp_silu_act) {
mlp_mult_factor = 2;
}
@ -481,7 +520,9 @@ namespace Flux {
k = norm->key_norm(ctx, k);
auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_token, hidden_size]
if (use_mlp_silu_act) {
if (use_yak_mlp) {
mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp, false);
} else if (use_mlp_silu_act) {
mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp);
} else {
mlp = ggml_gelu_inplace(ctx->ggml_ctx, mlp);
@ -726,6 +767,8 @@ namespace Flux {
int64_t in_dim = 64;
bool disable_bias = false;
bool share_modulation = false;
bool semantic_txt_norm = false;
bool use_yak_mlp = false;
bool use_mlp_silu_act = false;
float ref_index_scale = 1.f;
ChromaRadianceParams chroma_radiance_params;
@ -759,6 +802,9 @@ namespace Flux {
blocks["guidance_in"] = std::make_shared<MLPEmbedder>(256, params.hidden_size, !params.disable_bias);
}
}
if (params.semantic_txt_norm) {
blocks["txt_norm"] = std::make_shared<RMSNorm>(params.context_in_dim);
}
blocks["txt_in"] = std::make_shared<Linear>(params.context_in_dim, params.hidden_size, !params.disable_bias);
for (int i = 0; i < params.depth; i++) {
@ -770,6 +816,7 @@ namespace Flux {
params.is_chroma,
params.share_modulation,
!params.disable_bias,
params.use_yak_mlp,
params.use_mlp_silu_act);
}
@ -782,6 +829,7 @@ namespace Flux {
params.is_chroma,
params.share_modulation,
!params.disable_bias,
params.use_yak_mlp,
params.use_mlp_silu_act);
}
@ -948,6 +996,12 @@ namespace Flux {
ss_mods = single_stream_modulation->forward(ctx, vec);
}
if (params.semantic_txt_norm) {
auto semantic_txt_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["txt_norm"]);
txt = semantic_txt_norm->forward(ctx, txt);
}
txt = txt_in->forward(ctx, txt);
for (int i = 0; i < params.depth; i++) {
@ -1206,6 +1260,11 @@ namespace Flux {
} else if (version == VERSION_CHROMA_RADIANCE) {
flux_params.in_channels = 3;
flux_params.patch_size = 16;
} else if (version == VERSION_OVIS_IMAGE) {
flux_params.semantic_txt_norm = true;
flux_params.use_yak_mlp = true;
flux_params.context_in_dim = 2048;
flux_params.vec_in_dim = 0;
} else if (sd_version_is_flux2(version)) {
flux_params.context_in_dim = 15360;
flux_params.in_channels = 128;
@ -1364,13 +1423,22 @@ namespace Flux {
ref_latents[i] = to_backend(ref_latents[i]);
}
std::set<int> txt_arange_dims;
if (sd_version_is_flux2(version)) {
txt_arange_dims = {3};
increase_ref_index = true;
} else if (version == VERSION_OVIS_IMAGE) {
txt_arange_dims = {1, 2};
}
pe_vec = Rope::gen_flux_pe(x->ne[1],
x->ne[0],
flux_params.patch_size,
x->ne[3],
context->ne[1],
txt_arange_dims,
ref_latents,
sd_version_is_flux2(version) ? true : increase_ref_index,
increase_ref_index,
flux_params.ref_index_scale,
flux_params.theta,
flux_params.axes_dim);

View File

@ -760,17 +760,23 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_ext_chunk(struct ggml_co
return chunks;
}
__STATIC_INLINE__ ggml_tensor* ggml_ext_silu_act(ggml_context* ctx, ggml_tensor* x) {
__STATIC_INLINE__ ggml_tensor* ggml_ext_silu_act(ggml_context* ctx, ggml_tensor* x, bool gate_first = true) {
// x: [ne3, ne2, ne1, ne0]
// return: [ne3, ne2, ne1, ne0/2]
auto x_vec = ggml_ext_chunk(ctx, x, 2, 0);
auto x1 = x_vec[0]; // [ne3, ne2, ne1, ne0/2]
auto x2 = x_vec[1]; // [ne3, ne2, ne1, ne0/2]
ggml_tensor* gate;
if (gate_first) {
gate = x_vec[0];
x = x_vec[1];
} else {
x = x_vec[0];
gate = x_vec[1];
}
x1 = ggml_silu_inplace(ctx, x1);
gate = ggml_silu_inplace(ctx, gate);
x = ggml_mul(ctx, x1, x2); // [ne3, ne2, ne1, ne0/2]
x = ggml_mul(ctx, x, gate); // [ne3, ne2, ne1, ne0/2]
return x;
}

49
llm.hpp
View File

@ -356,6 +356,10 @@ namespace LLM {
"<|fim_pad|>",
"<|repo_name|>",
"<|file_sep|>",
"<tool_response>",
"</tool_response>",
"<think>",
"</think>",
};
if (merges_utf8_str.size() > 0) {
@ -859,11 +863,11 @@ namespace LLM {
}
if (arch == LLMArch::MISTRAL_SMALL_3_2) {
q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 131072, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 131072, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 8192, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 8192, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
} else if (arch == LLMArch::QWEN3) {
q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 151936, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 151936, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
} else {
int sections[4] = {16, 24, 24, 0};
q = ggml_rope_multi(ctx->ggml_ctx, q, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
@ -1073,29 +1077,22 @@ namespace LLM {
: GGMLRunner(backend, offload_params_to_cpu), enable_vision(enable_vision_) {
params.arch = arch;
if (arch == LLMArch::MISTRAL_SMALL_3_2) {
params.num_layers = 40;
params.hidden_size = 5120;
params.intermediate_size = 32768;
params.head_dim = 128;
params.num_heads = 32;
params.num_kv_heads = 8;
params.qkv_bias = false;
params.vocab_size = 131072;
params.rms_norm_eps = 1e-5f;
} else if (arch == LLMArch::QWEN3) {
params.num_layers = 36;
params.hidden_size = 2560;
params.intermediate_size = 9728;
params.head_dim = 128;
params.num_heads = 32;
params.num_kv_heads = 8;
params.qkv_bias = false;
params.qk_norm = true;
params.vocab_size = 151936;
params.rms_norm_eps = 1e-6f;
}
bool have_vision_weight = false;
bool llama_cpp_style = false;
params.num_layers = 0;
for (auto pair : tensor_storage_map) {
std::string tensor_name = pair.first;
if (tensor_name.find(prefix) == std::string::npos)
@ -1105,10 +1102,36 @@ namespace LLM {
have_vision_weight = true;
if (contains(tensor_name, "attn.q_proj")) {
llama_cpp_style = true;
break;
}
continue;
}
pos = tensor_name.find("layers.");
if (pos != std::string::npos) {
tensor_name = tensor_name.substr(pos); // remove prefix
auto items = split_string(tensor_name, '.');
if (items.size() > 1) {
int block_index = atoi(items[1].c_str());
if (block_index + 1 > params.num_layers) {
params.num_layers = block_index + 1;
}
}
}
if (contains(tensor_name, "embed_tokens.weight")) {
params.hidden_size = pair.second.ne[0];
params.vocab_size = pair.second.ne[1];
}
if (contains(tensor_name, "layers.0.mlp.gate_proj.weight")) {
params.intermediate_size = pair.second.ne[1];
}
}
if (arch == LLMArch::QWEN3 && params.num_layers == 28) { // Qwen3 2B
params.num_heads = 16;
}
LOG_DEBUG("llm: num_layers = %" PRId64 ", vocab_size = %" PRId64 ", hidden_size = %" PRId64 ", intermediate_size = %" PRId64,
params.num_layers,
params.vocab_size,
params.hidden_size,
params.intermediate_size);
if (enable_vision && !have_vision_weight) {
LOG_WARN("no vision weights detected, vision disabled");
enable_vision = false;

View File

@ -1056,6 +1056,9 @@ SDVersion ModelLoader::get_sd_version() {
if (tensor_storage.name.find("model.diffusion_model.double_stream_modulation_img.lin.weight") != std::string::npos) {
return VERSION_FLUX2;
}
if (tensor_storage.name.find("model.diffusion_model.double_blocks.0.img_mlp.gate_proj.weight") != std::string::npos) {
return VERSION_OVIS_IMAGE;
}
if (tensor_storage.name.find("model.diffusion_model.cap_embedder.0.weight") != std::string::npos) {
return VERSION_Z_IMAGE;
}

View File

@ -45,6 +45,7 @@ enum SDVersion {
VERSION_QWEN_IMAGE,
VERSION_FLUX2,
VERSION_Z_IMAGE,
VERSION_OVIS_IMAGE,
VERSION_COUNT,
};
@ -90,6 +91,7 @@ static inline bool sd_version_is_flux(SDVersion version) {
version == VERSION_FLUX_FILL ||
version == VERSION_FLUX_CONTROLS ||
version == VERSION_FLEX_2 ||
version == VERSION_OVIS_IMAGE ||
version == VERSION_CHROMA_RADIANCE) {
return true;
}

View File

@ -72,11 +72,13 @@ namespace Rope {
}
// Generate IDs for image patches and text
__STATIC_INLINE__ std::vector<std::vector<float>> gen_flux_txt_ids(int bs, int context_len, int axes_dim_num) {
__STATIC_INLINE__ std::vector<std::vector<float>> gen_flux_txt_ids(int bs, int context_len, int axes_dim_num, std::set<int> arange_dims) {
auto txt_ids = std::vector<std::vector<float>>(bs * context_len, std::vector<float>(axes_dim_num, 0.0f));
if (axes_dim_num == 4) {
for (int dim = 0; dim < axes_dim_num; dim++) {
if (arange_dims.find(dim) != arange_dims.end()) {
for (int i = 0; i < bs * context_len; i++) {
txt_ids[i][3] = (i % context_len);
txt_ids[i][dim] = (i % context_len);
}
}
}
return txt_ids;
@ -211,10 +213,11 @@ namespace Rope {
int bs,
int axes_dim_num,
int context_len,
std::set<int> txt_arange_dims,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index,
float ref_index_scale) {
auto txt_ids = gen_flux_txt_ids(bs, context_len, axes_dim_num);
auto txt_ids = gen_flux_txt_ids(bs, context_len, axes_dim_num, txt_arange_dims);
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num);
auto ids = concat_ids(txt_ids, img_ids, bs);
@ -231,6 +234,7 @@ namespace Rope {
int patch_size,
int bs,
int context_len,
std::set<int> txt_arange_dims,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index,
float ref_index_scale,
@ -242,6 +246,7 @@ namespace Rope {
bs,
static_cast<int>(axes_dim.size()),
context_len,
txt_arange_dims,
ref_latents,
increase_ref_index,
ref_index_scale);

View File

@ -46,6 +46,7 @@ const char* model_version_to_str[] = {
"Qwen Image",
"Flux.2",
"Z-Image",
"Ovis Image",
};
const char* sampling_methods_str[] = {
@ -424,6 +425,13 @@ public:
tensor_storage_map,
sd_ctx_params->chroma_use_t5_mask,
sd_ctx_params->chroma_t5_mask_pad);
} else if (version == VERSION_OVIS_IMAGE) {
cond_stage_model = std::make_shared<LLMEmbedder>(clip_backend,
offload_params_to_cpu,
tensor_storage_map,
version,
"",
false);
} else {
cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend,
offload_params_to_cpu,
@ -690,6 +698,11 @@ public:
ignore_tensors.insert("first_stage_model.quant");
ignore_tensors.insert("text_encoders.llm.visual.");
}
if (version == VERSION_OVIS_IMAGE) {
ignore_tensors.insert("text_encoders.llm.vision_model.");
ignore_tensors.insert("text_encoders.llm.visual_tokenizer.");
ignore_tensors.insert("text_encoders.llm.vte.");
}
if (version == VERSION_SVD) {
ignore_tensors.insert("conditioner.embedders.3");
}