feat: add ernie image support (#1427)

This commit is contained in:
leejet 2026-04-17 00:51:42 +08:00 committed by GitHub
parent c41c5ded7a
commit 5c243db9a8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 699 additions and 20 deletions

View File

@ -57,6 +57,7 @@ API and command-line option may change frequently.***
- [Z-Image](./docs/z_image.md) - [Z-Image](./docs/z_image.md)
- [Ovis-Image](./docs/ovis_image.md) - [Ovis-Image](./docs/ovis_image.md)
- [Anima](./docs/anima.md) - [Anima](./docs/anima.md)
- [ERNIE-Image](./docs/ernie_image.md)
- Image Edit Models - Image Edit Models
- [FLUX.1-Kontext-dev](./docs/kontext.md) - [FLUX.1-Kontext-dev](./docs/kontext.md)
- [Qwen Image Edit series](./docs/qwen_image_edit.md) - [Qwen Image Edit series](./docs/qwen_image_edit.md)
@ -144,6 +145,7 @@ If you want to improve performance or reduce VRAM/RAM usage, please refer to [pe
- [🔥Z-Image](./docs/z_image.md) - [🔥Z-Image](./docs/z_image.md)
- [Ovis-Image](./docs/ovis_image.md) - [Ovis-Image](./docs/ovis_image.md)
- [Anima](./docs/anima.md) - [Anima](./docs/anima.md)
- [ERNIE-Image](./docs/ernie_image.md)
- [LoRA](./docs/lora.md) - [LoRA](./docs/lora.md)
- [LCM/LCM-LoRA](./docs/lcm.md) - [LCM/LCM-LoRA](./docs/lcm.md)
- [Using PhotoMaker to personalize image generation](./docs/photo_maker.md) - [Using PhotoMaker to personalize image generation](./docs/photo_maker.md)

Binary file not shown.

After

Width:  |  Height:  |  Size: 595 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 562 KiB

35
docs/ernie_image.md Normal file
View File

@ -0,0 +1,35 @@
# How to Use
You can run ERNIE-Image with stable-diffusion.cpp on GPUs with 4GB of VRAM — or even less.
## Download weights
- Download ERNIE-Image-Turbo
- safetensors: https://huggingface.co/Comfy-Org/ERNIE-Image/tree/main/diffusion_models
- gguf: https://huggingface.co/unsloth/ERNIE-Image-Turbo-GGUF/tree/main
- Download ERNIE-Image
- safetensors: https://huggingface.co/Comfy-Org/ERNIE-Image/tree/main/diffusion_models
- gguf: https://huggingface.co/unsloth/ERNIE-Image-GGUF/tree/main
- Download vae
- safetensors: https://huggingface.co/Comfy-Org/ERNIE-Image/tree/main/vae
- Download ministral 3b
- safetensors: https://huggingface.co/Comfy-Org/ERNIE-Image/tree/main/text_encoders
- gguf: https://huggingface.co/unsloth/Ministral-3-3B-Instruct-2512-GGUF/tree/main
## Examples
### ERNIE-Image-Turbo
```
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\ernie-image-turbo.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\ministral-3-3b.safetensors -p "a lovely cat" --cfg-scale 1.0 --steps 8 -v --offload-to-cpu --diffusion-fa
```
<img width="256" alt="ERNIE-Image Turbo example" src="../assets/ernie_image/turbo_example.png" />
### ERNIE-Image
```
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\ernie-image-UD-Q4_K_M.gguf --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\ministral-3-3b.safetensors -p "a lovely cat" --cfg-scale 5.0 -v --offload-to-cpu --diffusion-fa
```
<img width="256" alt="ERNIE-Image example" src="../assets/ernie_image/example.png" />

View File

@ -533,7 +533,7 @@ public:
const std::string& prefix = "") const std::string& prefix = "")
: version(version), decode_only(decode_only), use_video_decoder(use_video_decoder) { : version(version), decode_only(decode_only), use_video_decoder(use_video_decoder) {
if (sd_version_is_dit(version)) { if (sd_version_is_dit(version)) {
if (sd_version_is_flux2(version)) { if (sd_version_uses_flux2_vae(version)) {
dd_config.z_channels = 32; dd_config.z_channels = 32;
embed_dim = 32; embed_dim = 32;
} else { } else {
@ -578,7 +578,7 @@ public:
ggml_tensor* decode(GGMLRunnerContext* ctx, ggml_tensor* z) { ggml_tensor* decode(GGMLRunnerContext* ctx, ggml_tensor* z) {
// z: [N, z_channels, h, w] // z: [N, z_channels, h, w]
if (sd_version_is_flux2(version)) { if (sd_version_uses_flux2_vae(version)) {
// [N, C*p*p, h, w] -> [N, C, h*p, w*p] // [N, C*p*p, h, w] -> [N, C, h*p, w*p]
int64_t p = 2; int64_t p = 2;
@ -617,7 +617,7 @@ public:
auto quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["quant_conv"]); auto quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["quant_conv"]);
z = quant_conv->forward(ctx, z); // [N, 2*embed_dim, h/8, w/8] z = quant_conv->forward(ctx, z); // [N, 2*embed_dim, h/8, w/8]
} }
if (sd_version_is_flux2(version)) { if (sd_version_uses_flux2_vae(version)) {
z = ggml_ext_chunk(ctx->ggml_ctx, z, 2, 2)[0]; z = ggml_ext_chunk(ctx->ggml_ctx, z, 2, 2)[0];
// [N, C, H, W] -> [N, C*p*p, H/p, W/p] // [N, C, H, W] -> [N, C*p*p, H/p, W/p]
@ -640,7 +640,7 @@ public:
int get_encoder_output_channels() { int get_encoder_output_channels() {
int factor = dd_config.double_z ? 2 : 1; int factor = dd_config.double_z ? 2 : 1;
if (sd_version_is_flux2(version)) { if (sd_version_uses_flux2_vae(version)) {
return dd_config.z_channels * 4; return dd_config.z_channels * 4;
} }
return dd_config.z_channels * factor; return dd_config.z_channels * factor;
@ -673,7 +673,7 @@ struct AutoEncoderKL : public VAE {
} else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) { } else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) {
scale_factor = 0.3611f; scale_factor = 0.3611f;
shift_factor = 0.1159f; shift_factor = 0.1159f;
} else if (sd_version_is_flux2(version)) { } else if (sd_version_uses_flux2_vae(version)) {
scale_factor = 1.0f; scale_factor = 1.0f;
shift_factor = 0.f; shift_factor = 0.f;
} }
@ -747,7 +747,7 @@ struct AutoEncoderKL : public VAE {
} }
sd::Tensor<float> vae_output_to_latents(const sd::Tensor<float>& vae_output, std::shared_ptr<RNG> rng) override { sd::Tensor<float> vae_output_to_latents(const sd::Tensor<float>& vae_output, std::shared_ptr<RNG> rng) override {
if (sd_version_is_flux2(version)) { if (sd_version_uses_flux2_vae(version)) {
return vae_output; return vae_output;
} else if (version == VERSION_SD1_PIX2PIX) { } else if (version == VERSION_SD1_PIX2PIX) {
return sd::ops::chunk(vae_output, 2, 2)[0]; return sd::ops::chunk(vae_output, 2, 2)[0];
@ -758,7 +758,7 @@ struct AutoEncoderKL : public VAE {
std::pair<sd::Tensor<float>, sd::Tensor<float>> get_latents_mean_std(const sd::Tensor<float>& latents, int channel_dim) { std::pair<sd::Tensor<float>, sd::Tensor<float>> get_latents_mean_std(const sd::Tensor<float>& latents, int channel_dim) {
GGML_ASSERT(channel_dim >= 0 && static_cast<size_t>(channel_dim) < static_cast<size_t>(latents.dim())); GGML_ASSERT(channel_dim >= 0 && static_cast<size_t>(channel_dim) < static_cast<size_t>(latents.dim()));
if (sd_version_is_flux2(version)) { if (sd_version_uses_flux2_vae(version)) {
GGML_ASSERT(latents.shape()[channel_dim] == 128); GGML_ASSERT(latents.shape()[channel_dim] == 128);
std::vector<int64_t> stats_shape(static_cast<size_t>(latents.dim()), 1); std::vector<int64_t> stats_shape(static_cast<size_t>(latents.dim()), 1);
stats_shape[static_cast<size_t>(channel_dim)] = latents.shape()[channel_dim]; stats_shape[static_cast<size_t>(channel_dim)] = latents.shape()[channel_dim];
@ -804,7 +804,7 @@ struct AutoEncoderKL : public VAE {
} }
sd::Tensor<float> diffusion_to_vae_latents(const sd::Tensor<float>& latents) override { sd::Tensor<float> diffusion_to_vae_latents(const sd::Tensor<float>& latents) override {
if (sd_version_is_flux2(version)) { if (sd_version_uses_flux2_vae(version)) {
int channel_dim = 2; int channel_dim = 2;
auto [mean_tensor, std_tensor] = get_latents_mean_std(latents, channel_dim); auto [mean_tensor, std_tensor] = get_latents_mean_std(latents, channel_dim);
return (latents * std_tensor) / scale_factor + mean_tensor; return (latents * std_tensor) / scale_factor + mean_tensor;
@ -813,7 +813,7 @@ struct AutoEncoderKL : public VAE {
} }
sd::Tensor<float> vae_to_diffusion_latents(const sd::Tensor<float>& latents) override { sd::Tensor<float> vae_to_diffusion_latents(const sd::Tensor<float>& latents) override {
if (sd_version_is_flux2(version)) { if (sd_version_uses_flux2_vae(version)) {
int channel_dim = 2; int channel_dim = 2;
auto [mean_tensor, std_tensor] = get_latents_mean_std(latents, channel_dim); auto [mean_tensor, std_tensor] = get_latents_mean_std(latents, channel_dim);
return ((latents - mean_tensor) * scale_factor) / std_tensor; return ((latents - mean_tensor) * scale_factor) / std_tensor;

View File

@ -1621,10 +1621,12 @@ struct LLMEmbedder : public Conditioner {
LLM::LLMArch arch = LLM::LLMArch::QWEN2_5_VL; LLM::LLMArch arch = LLM::LLMArch::QWEN2_5_VL;
if (version == VERSION_FLUX2) { if (version == VERSION_FLUX2) {
arch = LLM::LLMArch::MISTRAL_SMALL_3_2; arch = LLM::LLMArch::MISTRAL_SMALL_3_2;
} else if (sd_version_is_ernie_image(version)) {
arch = LLM::LLMArch::MINISTRAL_3_3B;
} else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE || version == VERSION_FLUX2_KLEIN) { } else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE || version == VERSION_FLUX2_KLEIN) {
arch = LLM::LLMArch::QWEN3; arch = LLM::LLMArch::QWEN3;
} }
if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) { if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2 || arch == LLM::LLMArch::MINISTRAL_3_3B) {
tokenizer = std::make_shared<MistralTokenizer>(); tokenizer = std::make_shared<MistralTokenizer>();
} else { } else {
tokenizer = std::make_shared<Qwen2Tokenizer>(); tokenizer = std::make_shared<Qwen2Tokenizer>();
@ -1867,6 +1869,13 @@ struct LLMEmbedder : public Conditioner {
prompt_attn_range.second = static_cast<int>(prompt.size()); prompt_attn_range.second = static_cast<int>(prompt.size());
prompt += "[/INST]"; prompt += "[/INST]";
} else if (sd_version_is_ernie_image(version)) {
prompt_template_encode_start_idx = 0;
out_layers = {25}; // -2
prompt_attn_range.first = 0;
prompt += conditioner_params.text;
prompt_attn_range.second = static_cast<int>(prompt.size());
} else if (sd_version_is_z_image(version)) { } else if (sd_version_is_z_image(version)) {
prompt_template_encode_start_idx = 0; prompt_template_encode_start_idx = 0;
out_layers = {35}; // -2 out_layers = {35}; // -2

View File

@ -3,6 +3,7 @@
#include <optional> #include <optional>
#include "anima.hpp" #include "anima.hpp"
#include "ernie_image.hpp"
#include "flux.hpp" #include "flux.hpp"
#include "mmdit.hpp" #include "mmdit.hpp"
#include "qwen_image.hpp" #include "qwen_image.hpp"
@ -516,4 +517,66 @@ struct ZImageModel : public DiffusionModel {
} }
}; };
struct ErnieImageModel : public DiffusionModel {
std::string prefix;
ErnieImage::ErnieImageRunner ernie_image;
ErnieImageModel(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "model.diffusion_model")
: prefix(prefix), ernie_image(backend, offload_params_to_cpu, tensor_storage_map, prefix) {
}
std::string get_desc() override {
return ernie_image.get_desc();
}
void alloc_params_buffer() override {
ernie_image.alloc_params_buffer();
}
void free_params_buffer() override {
ernie_image.free_params_buffer();
}
void free_compute_buffer() override {
ernie_image.free_compute_buffer();
}
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors) override {
ernie_image.get_param_tensors(tensors, prefix);
}
size_t get_params_buffer_size() override {
return ernie_image.get_params_buffer_size();
}
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
ernie_image.set_weight_adapter(adapter);
}
int64_t get_adm_in_channels() override {
return 768;
}
void set_flash_attention_enabled(bool enabled) {
ernie_image.set_flash_attention_enabled(enabled);
}
void set_circular_axes(bool circular_x, bool circular_y) override {
ernie_image.set_circular_axes(circular_x, circular_y);
}
sd::Tensor<float> compute(int n_threads,
const DiffusionParams& diffusion_params) override {
GGML_ASSERT(diffusion_params.x != nullptr);
GGML_ASSERT(diffusion_params.timesteps != nullptr);
return ernie_image.compute(n_threads,
*diffusion_params.x,
*diffusion_params.timesteps,
tensor_or_empty(diffusion_params.context));
}
};
#endif #endif

438
src/ernie_image.hpp Normal file
View File

@ -0,0 +1,438 @@
#ifndef __SD_ERNIE_IMAGE_HPP__
#define __SD_ERNIE_IMAGE_HPP__
#include <memory>
#include <vector>
#include "common_dit.hpp"
#include "flux.hpp"
#include "qwen_image.hpp"
#include "rope.hpp"
namespace ErnieImage {
constexpr int ERNIE_IMAGE_GRAPH_SIZE = 40960;
__STATIC_INLINE__ ggml_tensor* timestep_embedding_sin_cos(ggml_context* ctx,
ggml_tensor* timesteps,
int dim,
int max_period = 10000) {
auto emb = ggml_ext_timestep_embedding(ctx, timesteps, dim, max_period, 1.0f);
int64_t half = dim / 2;
auto cos_part = ggml_view_2d(ctx, emb, half, emb->ne[1], emb->nb[1], 0);
auto sin_part = ggml_view_2d(ctx, emb, half, emb->ne[1], emb->nb[1], half * emb->nb[0]);
auto sin_first = ggml_concat(ctx, sin_part, cos_part, 0);
return sin_first;
}
__STATIC_INLINE__ ggml_tensor* apply_rotary_emb(ggml_context* ctx, ggml_tensor* x, ggml_tensor* pe) {
// x: [N, S, heads, head_dim]
// pe: [2, S, 1, head_dim], stored as ggml [head_dim, 1, S, 2].
int64_t head_dim = x->ne[0];
int64_t heads = x->ne[1];
int64_t S = x->ne[2];
int64_t N = x->ne[3];
int64_t rot_dim = pe->ne[0];
GGML_ASSERT(rot_dim <= head_dim);
GGML_ASSERT(rot_dim % 2 == 0);
GGML_ASSERT(pe->ne[1] == 1 && pe->ne[2] == S && pe->ne[3] == 2);
x = ggml_cont(ctx, x);
auto x_rot = ggml_ext_slice(ctx, x, 0, 0, rot_dim, false);
auto x_pass = rot_dim < head_dim ? ggml_ext_slice(ctx, x, 0, rot_dim, head_dim, false) : nullptr;
int64_t half = rot_dim / 2;
auto x1 = ggml_view_4d(ctx, x_rot, half, heads, S, N, x_rot->nb[1], x_rot->nb[2], x_rot->nb[3], 0);
auto x2 = ggml_view_4d(ctx, x_rot, half, heads, S, N, x_rot->nb[1], x_rot->nb[2], x_rot->nb[3], half * x_rot->nb[0]);
x1 = ggml_cont(ctx, x1);
x2 = ggml_cont(ctx, x2);
auto rotated = ggml_concat(ctx, ggml_neg(ctx, x2), x1, 0);
auto cos_emb = ggml_ext_slice(ctx, pe, 3, 0, 1, false);
auto sin_emb = ggml_ext_slice(ctx, pe, 3, 1, 2, false);
auto out = ggml_add(ctx, ggml_mul(ctx, x_rot, cos_emb), ggml_mul(ctx, rotated, sin_emb));
if (x_pass != nullptr) {
out = ggml_concat(ctx, out, x_pass, 0);
}
return out;
}
struct ErnieImageAttention : public GGMLBlock {
int64_t num_heads;
int64_t head_dim;
ErnieImageAttention(int64_t query_dim,
int64_t heads,
int64_t dim_head,
float eps = 1e-6f)
: num_heads(heads), head_dim(dim_head) {
int64_t inner_dim = heads * dim_head;
blocks["to_q"] = std::make_shared<Linear>(query_dim, inner_dim, false);
blocks["to_k"] = std::make_shared<Linear>(query_dim, inner_dim, false);
blocks["to_v"] = std::make_shared<Linear>(query_dim, inner_dim, false);
blocks["norm_q"] = std::make_shared<RMSNorm>(dim_head, eps);
blocks["norm_k"] = std::make_shared<RMSNorm>(dim_head, eps);
blocks["to_out.0"] = std::make_shared<Linear>(inner_dim, query_dim, false);
}
ggml_tensor* forward(GGMLRunnerContext* ctx,
ggml_tensor* x,
ggml_tensor* pe,
ggml_tensor* attention_mask = nullptr) {
// x: [N, S, hidden_size]
// pe: [S, head_dim/2, 2, 2], generated in image-token-first order.
auto to_q = std::dynamic_pointer_cast<Linear>(blocks["to_q"]);
auto to_k = std::dynamic_pointer_cast<Linear>(blocks["to_k"]);
auto to_v = std::dynamic_pointer_cast<Linear>(blocks["to_v"]);
auto norm_q = std::dynamic_pointer_cast<RMSNorm>(blocks["norm_q"]);
auto norm_k = std::dynamic_pointer_cast<RMSNorm>(blocks["norm_k"]);
auto to_out_0 = std::dynamic_pointer_cast<Linear>(blocks["to_out.0"]);
int64_t S = x->ne[1];
int64_t N = x->ne[2];
auto q = to_q->forward(ctx, x);
auto k = to_k->forward(ctx, x);
auto v = to_v->forward(ctx, x);
q = ggml_reshape_4d(ctx->ggml_ctx, q, head_dim, num_heads, S, N); // [N, S, heads, head_dim]
k = ggml_reshape_4d(ctx->ggml_ctx, k, head_dim, num_heads, S, N); // [N, S, heads, head_dim]
v = ggml_reshape_4d(ctx->ggml_ctx, v, head_dim, num_heads, S, N); // [N, S, heads, head_dim]
q = norm_q->forward(ctx, q);
k = norm_k->forward(ctx, k);
q = apply_rotary_emb(ctx->ggml_ctx, q, pe);
k = apply_rotary_emb(ctx->ggml_ctx, k, pe);
q = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, q, 0, 2, 1, 3)); // [N, heads, S, head_dim]
q = ggml_reshape_3d(ctx->ggml_ctx, q, q->ne[0], q->ne[1], q->ne[2] * q->ne[3]);
k = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, k, 0, 2, 1, 3)); // [N, heads, S, head_dim]
k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]);
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, true, ctx->flash_attn_enabled); // [N, S, hidden_size]
x = to_out_0->forward(ctx, x);
return x;
}
};
struct ErnieImageFeedForward : public GGMLBlock {
public:
ErnieImageFeedForward(int64_t hidden_size, int64_t ffn_hidden_size) {
blocks["gate_proj"] = std::make_shared<Linear>(hidden_size, ffn_hidden_size, false);
blocks["up_proj"] = std::make_shared<Linear>(hidden_size, ffn_hidden_size, false);
blocks["linear_fc2"] = std::make_shared<Linear>(ffn_hidden_size, hidden_size, false);
}
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
auto gate_proj = std::dynamic_pointer_cast<Linear>(blocks["gate_proj"]);
auto up_proj = std::dynamic_pointer_cast<Linear>(blocks["up_proj"]);
auto linear_fc2 = std::dynamic_pointer_cast<Linear>(blocks["linear_fc2"]);
auto gate = gate_proj->forward(ctx, x);
gate = ggml_ext_gelu(ctx->ggml_ctx, gate);
x = up_proj->forward(ctx, x);
x = ggml_mul(ctx->ggml_ctx, x, gate);
x = linear_fc2->forward(ctx, x);
return x;
}
};
struct ErnieImageSharedAdaLNBlock : public GGMLBlock {
public:
ErnieImageSharedAdaLNBlock(int64_t hidden_size,
int64_t num_heads,
int64_t ffn_hidden_size,
float eps = 1e-6f) {
blocks["adaLN_sa_ln"] = std::make_shared<RMSNorm>(hidden_size, eps);
blocks["self_attention"] = std::make_shared<ErnieImageAttention>(hidden_size,
num_heads,
hidden_size / num_heads,
eps);
blocks["adaLN_mlp_ln"] = std::make_shared<RMSNorm>(hidden_size, eps);
blocks["mlp"] = std::make_shared<ErnieImageFeedForward>(hidden_size, ffn_hidden_size);
}
ggml_tensor* forward(GGMLRunnerContext* ctx,
ggml_tensor* x,
ggml_tensor* pe,
const std::vector<ggml_tensor*>& temb,
ggml_tensor* attention_mask = nullptr) {
// x: [N, image_tokens + text_tokens, hidden_size]
auto adaLN_sa_ln = std::dynamic_pointer_cast<RMSNorm>(blocks["adaLN_sa_ln"]);
auto self_attention = std::dynamic_pointer_cast<ErnieImageAttention>(blocks["self_attention"]);
auto adaLN_mlp_ln = std::dynamic_pointer_cast<RMSNorm>(blocks["adaLN_mlp_ln"]);
auto mlp = std::dynamic_pointer_cast<ErnieImageFeedForward>(blocks["mlp"]);
auto shift_msa = temb[0];
auto scale_msa = temb[1];
auto gate_msa = temb[2];
auto shift_mlp = temb[3];
auto scale_mlp = temb[4];
auto gate_mlp = temb[5];
auto residual = x;
x = adaLN_sa_ln->forward(ctx, x);
x = Flux::modulate(ctx->ggml_ctx, x, shift_msa, scale_msa, true);
auto attn_out = self_attention->forward(ctx, x, pe, attention_mask);
x = ggml_add(ctx->ggml_ctx, residual, ggml_mul(ctx->ggml_ctx, attn_out, gate_msa));
residual = x;
x = adaLN_mlp_ln->forward(ctx, x);
x = Flux::modulate(ctx->ggml_ctx, x, shift_mlp, scale_mlp, true);
x = ggml_add(ctx->ggml_ctx, residual, ggml_mul(ctx->ggml_ctx, mlp->forward(ctx, x), gate_mlp));
return x;
}
};
struct ErnieImageAdaLNContinuous : public GGMLBlock {
public:
ErnieImageAdaLNContinuous(int64_t hidden_size, float eps = 1e-6f) {
blocks["norm"] = std::make_shared<LayerNorm>(hidden_size, eps, false);
blocks["linear"] = std::make_shared<Linear>(hidden_size, hidden_size * 2, true);
}
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* conditioning) {
auto norm = std::dynamic_pointer_cast<LayerNorm>(blocks["norm"]);
auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
auto mods = ggml_ext_chunk(ctx->ggml_ctx, linear->forward(ctx, conditioning), 2, 0);
auto scale = mods[0];
auto shift = mods[1];
x = norm->forward(ctx, x);
x = Flux::modulate(ctx->ggml_ctx, x, shift, scale);
return x;
}
};
struct ErnieImageParams {
int64_t hidden_size = 4096;
int64_t num_heads = 32;
int64_t num_layers = 36;
int64_t ffn_hidden_size = 12288;
int64_t in_channels = 128;
int64_t out_channels = 128;
int patch_size = 1;
int64_t text_in_dim = 3072;
int theta = 256;
std::vector<int> axes_dim = {32, 48, 48};
int axes_dim_sum = 128;
float eps = 1e-6f;
};
class ErnieImageModel : public GGMLBlock {
public:
ErnieImageParams params;
ErnieImageModel() = default;
ErnieImageModel(ErnieImageParams params)
: params(params) {
blocks["x_embedder.proj"] = std::make_shared<Conv2d>(params.in_channels,
params.hidden_size,
std::pair<int, int>{params.patch_size, params.patch_size},
std::pair<int, int>{params.patch_size, params.patch_size},
std::pair<int, int>{0, 0},
std::pair<int, int>{1, 1},
true);
if (params.text_in_dim != params.hidden_size) {
blocks["text_proj"] = std::make_shared<Linear>(params.text_in_dim, params.hidden_size, false);
}
blocks["time_embedding"] = std::make_shared<Qwen::TimestepEmbedding>(params.hidden_size, params.hidden_size);
blocks["adaLN_modulation.1"] = std::make_shared<Linear>(params.hidden_size, 6 * params.hidden_size, true);
for (int i = 0; i < params.num_layers; i++) {
blocks["layers." + std::to_string(i)] = std::make_shared<ErnieImageSharedAdaLNBlock>(params.hidden_size,
params.num_heads,
params.ffn_hidden_size,
params.eps);
}
blocks["final_norm"] = std::make_shared<ErnieImageAdaLNContinuous>(params.hidden_size, params.eps);
blocks["final_linear"] = std::make_shared<Linear>(params.hidden_size,
params.patch_size * params.patch_size * params.out_channels,
true);
}
ggml_tensor* forward(GGMLRunnerContext* ctx,
ggml_tensor* x,
ggml_tensor* timestep,
ggml_tensor* context,
ggml_tensor* pe) {
// x: [N, C, H, W]
// context: [N, text_tokens, 3072]
// pe: [image_tokens + text_tokens, head_dim/2, 2, 2]
GGML_ASSERT(context != nullptr);
GGML_ASSERT(x->ne[1] % params.patch_size == 0 && x->ne[0] % params.patch_size == 0);
int64_t W = x->ne[0];
int64_t H = x->ne[1];
int64_t Hp = H / params.patch_size;
int64_t Wp = W / params.patch_size;
int64_t n_img = Hp * Wp;
int64_t N = x->ne[3];
auto x_embedder_proj = std::dynamic_pointer_cast<Conv2d>(blocks["x_embedder.proj"]);
auto time_embedding = std::dynamic_pointer_cast<Qwen::TimestepEmbedding>(blocks["time_embedding"]);
auto adaLN_mod = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
auto final_norm = std::dynamic_pointer_cast<ErnieImageAdaLNContinuous>(blocks["final_norm"]);
auto final_linear = std::dynamic_pointer_cast<Linear>(blocks["final_linear"]);
auto img = x_embedder_proj->forward(ctx, x); // [N, hidden_size, Hp, Wp]
img = ggml_reshape_3d(ctx->ggml_ctx, img, img->ne[0] * img->ne[1], img->ne[2], N); // [N, hidden_size, image_tokens]
img = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, img, 1, 0, 2, 3)); // [N, image_tokens, hidden_size]
auto txt = context;
auto text_proj = std::dynamic_pointer_cast<Linear>(blocks["text_proj"]);
if (text_proj) {
txt = text_proj->forward(ctx, txt);
}
auto hidden_states = ggml_concat(ctx->ggml_ctx, img, txt, 1); // [N, image_tokens + text_tokens, hidden_size]
auto sample = timestep_embedding_sin_cos(ctx->ggml_ctx, timestep, static_cast<int>(params.hidden_size));
auto c = time_embedding->forward(ctx, sample); // [N, hidden_size]
auto mod_params = adaLN_mod->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 6 * hidden_size]
auto chunks = ggml_ext_chunk(ctx->ggml_ctx, mod_params, 6, 0);
std::vector<ggml_tensor*> temb;
temb.reserve(6);
for (auto chunk : chunks) {
temb.push_back(ggml_reshape_3d(ctx->ggml_ctx, chunk, chunk->ne[0], 1, chunk->ne[1])); // [N, 1, hidden_size]
}
for (int i = 0; i < params.num_layers; i++) {
auto layer = std::dynamic_pointer_cast<ErnieImageSharedAdaLNBlock>(blocks["layers." + std::to_string(i)]);
hidden_states = layer->forward(ctx, hidden_states, pe, temb);
}
hidden_states = final_norm->forward(ctx, hidden_states, c);
hidden_states = final_linear->forward(ctx, hidden_states); // [N, image_tokens, p*p*out_channels]
auto patches = ggml_ext_slice(ctx->ggml_ctx, hidden_states, 1, 0, n_img); // [N, image_tokens, hidden_size]
auto out = DiT::unpatchify(ctx->ggml_ctx,
patches,
Hp,
Wp,
params.patch_size,
params.patch_size,
false); // [N, out_channels, H, W]
return out;
}
};
struct ErnieImageRunner : public GGMLRunner {
ErnieImageParams ernie_params;
ErnieImageModel ernie_image;
std::vector<float> pe_vec;
ErnieImageRunner(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "")
: GGMLRunner(backend, offload_params_to_cpu) {
ernie_params.num_layers = 0;
for (const auto& [name, tensor_storage] : tensor_storage_map) {
if (!starts_with(name, prefix)) {
continue;
}
if (ends_with(name, "x_embedder.proj.weight") && tensor_storage.n_dims == 4) {
ernie_params.patch_size = static_cast<int>(tensor_storage.ne[0]);
ernie_params.in_channels = tensor_storage.ne[2];
ernie_params.hidden_size = tensor_storage.ne[3];
} else if (ends_with(name, "text_proj.weight") && tensor_storage.n_dims == 2) {
ernie_params.text_in_dim = tensor_storage.ne[0];
} else if (ends_with(name, "layers.0.self_attention.norm_q.weight")) {
int64_t head_dim = tensor_storage.ne[0];
ernie_params.num_heads = ernie_params.hidden_size / head_dim;
} else if (ends_with(name, "layers.0.mlp.gate_proj.weight") && tensor_storage.n_dims == 2) {
ernie_params.ffn_hidden_size = tensor_storage.ne[1];
} else if (ends_with(name, "final_linear.weight") && tensor_storage.n_dims == 2) {
int64_t out_dim = tensor_storage.ne[1];
ernie_params.out_channels = out_dim / ernie_params.patch_size / ernie_params.patch_size;
}
size_t pos = name.find("layers.");
if (pos != std::string::npos) {
std::string layer_name = name.substr(pos);
auto items = split_string(layer_name, '.');
if (items.size() > 1) {
int block_index = atoi(items[1].c_str());
if (block_index + 1 > ernie_params.num_layers) {
ernie_params.num_layers = block_index + 1;
}
}
}
}
if (ernie_params.num_layers == 0) {
ernie_params.num_layers = 36;
}
ernie_params.axes_dim_sum = 0;
for (int axis_dim : ernie_params.axes_dim) {
ernie_params.axes_dim_sum += axis_dim;
}
LOG_INFO("ernie_image: layers = %" PRId64 ", hidden_size = %" PRId64 ", heads = %" PRId64
", ffn_hidden_size = %" PRId64 ", in_channels = %" PRId64 ", out_channels = %" PRId64,
ernie_params.num_layers,
ernie_params.hidden_size,
ernie_params.num_heads,
ernie_params.ffn_hidden_size,
ernie_params.in_channels,
ernie_params.out_channels);
ernie_image = ErnieImageModel(ernie_params);
ernie_image.init(params_ctx, tensor_storage_map, prefix);
}
std::string get_desc() override {
return "ernie_image";
}
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) {
ernie_image.get_param_tensors(tensors, prefix);
}
ggml_cgraph* build_graph(const sd::Tensor<float>& x_tensor,
const sd::Tensor<float>& timesteps_tensor,
const sd::Tensor<float>& context_tensor) {
ggml_cgraph* gf = new_graph_custom(ERNIE_IMAGE_GRAPH_SIZE);
ggml_tensor* x = make_input(x_tensor);
ggml_tensor* timesteps = make_input(timesteps_tensor);
GGML_ASSERT(x->ne[3] == 1);
GGML_ASSERT(!context_tensor.empty());
ggml_tensor* context = make_input(context_tensor);
pe_vec = Rope::gen_ernie_image_pe(static_cast<int>(x->ne[1]),
static_cast<int>(x->ne[0]),
ernie_params.patch_size,
static_cast<int>(x->ne[3]),
static_cast<int>(context->ne[1]),
ernie_params.theta,
circular_y_enabled,
circular_x_enabled,
ernie_params.axes_dim);
int pos_len = static_cast<int>(pe_vec.size() / ernie_params.axes_dim_sum / 2);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, ernie_params.axes_dim_sum, 1, pos_len, 2);
set_backend_tensor_data(pe, pe_vec.data());
auto runner_ctx = get_context();
ggml_tensor* out = ernie_image.forward(&runner_ctx, x, timesteps, context, pe);
ggml_build_forward_expand(gf, out);
return gf;
}
sd::Tensor<float> compute(int n_threads,
const sd::Tensor<float>& x,
const sd::Tensor<float>& timesteps,
const sd::Tensor<float>& context) {
auto get_graph = [&]() -> ggml_cgraph* {
return build_graph(x, timesteps, context);
};
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), x.dim());
}
};
} // namespace ErnieImage
#endif // __SD_ERNIE_IMAGE_HPP__

View File

@ -28,6 +28,7 @@ namespace LLM {
QWEN2_5_VL, QWEN2_5_VL,
QWEN3, QWEN3,
MISTRAL_SMALL_3_2, MISTRAL_SMALL_3_2,
MINISTRAL_3_3B,
ARCH_COUNT, ARCH_COUNT,
}; };
@ -35,6 +36,7 @@ namespace LLM {
"qwen2.5vl", "qwen2.5vl",
"qwen3", "qwen3",
"mistral_small3.2", "mistral_small3.2",
"ministral3.3b",
}; };
struct LLMVisionParams { struct LLMVisionParams {
@ -419,6 +421,9 @@ namespace LLM {
if (arch == LLMArch::MISTRAL_SMALL_3_2) { if (arch == LLMArch::MISTRAL_SMALL_3_2) {
q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 8192, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 8192, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 8192, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 8192, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
} else if (arch == LLMArch::MINISTRAL_3_3B) {
q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 262144, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 262144, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
} else if (arch == LLMArch::QWEN3) { } else if (arch == LLMArch::QWEN3) {
q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
@ -634,7 +639,7 @@ namespace LLM {
bool enable_vision_ = false) bool enable_vision_ = false)
: GGMLRunner(backend, offload_params_to_cpu), enable_vision(enable_vision_) { : GGMLRunner(backend, offload_params_to_cpu), enable_vision(enable_vision_) {
params.arch = arch; params.arch = arch;
if (arch == LLMArch::MISTRAL_SMALL_3_2) { if (arch == LLMArch::MISTRAL_SMALL_3_2 || arch == LLMArch::MINISTRAL_3_3B) {
params.head_dim = 128; params.head_dim = 128;
params.num_heads = 32; params.num_heads = 32;
params.num_kv_heads = 8; params.num_kv_heads = 8;
@ -746,7 +751,7 @@ namespace LLM {
} }
int64_t n_tokens = input_ids->ne[0]; int64_t n_tokens = input_ids->ne[0];
if (params.arch == LLMArch::MISTRAL_SMALL_3_2 || params.arch == LLMArch::QWEN3) { if (params.arch == LLMArch::MISTRAL_SMALL_3_2 || params.arch == LLMArch::MINISTRAL_3_3B || params.arch == LLMArch::QWEN3) {
input_pos_vec.resize(n_tokens); input_pos_vec.resize(n_tokens);
for (int i = 0; i < n_tokens; ++i) { for (int i = 0; i < n_tokens; ++i) {
input_pos_vec[i] = i; input_pos_vec[i] = i;
@ -982,7 +987,7 @@ namespace LLM {
const std::string prefix = "", const std::string prefix = "",
bool enable_vision = false) bool enable_vision = false)
: model(arch, backend, offload_params_to_cpu, tensor_storage_map, prefix, enable_vision) { : model(arch, backend, offload_params_to_cpu, tensor_storage_map, prefix, enable_vision) {
if (arch == LLMArch::MISTRAL_SMALL_3_2) { if (arch == LLMArch::MISTRAL_SMALL_3_2 || arch == LLMArch::MINISTRAL_3_3B) {
tokenizer = std::make_shared<MistralTokenizer>(); tokenizer = std::make_shared<MistralTokenizer>();
} else { } else {
tokenizer = std::make_shared<Qwen2Tokenizer>(); tokenizer = std::make_shared<Qwen2Tokenizer>();

View File

@ -1049,6 +1049,9 @@ SDVersion ModelLoader::get_sd_version() {
if (tensor_storage.name.find("model.diffusion_model.cap_embedder.0.weight") != std::string::npos) { if (tensor_storage.name.find("model.diffusion_model.cap_embedder.0.weight") != std::string::npos) {
return VERSION_Z_IMAGE; return VERSION_Z_IMAGE;
} }
if (tensor_storage.name.find("model.diffusion_model.layers.0.adaLN_sa_ln.weight") != std::string::npos) {
return VERSION_ERNIE_IMAGE;
}
if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) { if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) {
is_wan = true; is_wan = true;
} }

View File

@ -50,6 +50,7 @@ enum SDVersion {
VERSION_FLUX2_KLEIN, VERSION_FLUX2_KLEIN,
VERSION_Z_IMAGE, VERSION_Z_IMAGE,
VERSION_OVIS_IMAGE, VERSION_OVIS_IMAGE,
VERSION_ERNIE_IMAGE,
VERSION_COUNT, VERSION_COUNT,
}; };
@ -137,6 +138,20 @@ static inline bool sd_version_is_z_image(SDVersion version) {
return false; return false;
} }
static inline bool sd_version_is_ernie_image(SDVersion version) {
if (version == VERSION_ERNIE_IMAGE) {
return true;
}
return false;
}
static inline bool sd_version_uses_flux2_vae(SDVersion version) {
if (sd_version_is_flux2(version) || sd_version_is_ernie_image(version)) {
return true;
}
return false;
}
static inline bool sd_version_is_inpaint(SDVersion version) { static inline bool sd_version_is_inpaint(SDVersion version) {
if (version == VERSION_SD1_INPAINT || if (version == VERSION_SD1_INPAINT ||
version == VERSION_SD2_INPAINT || version == VERSION_SD2_INPAINT ||
@ -155,7 +170,8 @@ static inline bool sd_version_is_dit(SDVersion version) {
sd_version_is_wan(version) || sd_version_is_wan(version) ||
sd_version_is_qwen_image(version) || sd_version_is_qwen_image(version) ||
sd_version_is_anima(version) || sd_version_is_anima(version) ||
sd_version_is_z_image(version)) { sd_version_is_z_image(version) ||
sd_version_is_ernie_image(version)) {
return true; return true;
} }
return false; return false;

View File

@ -7,6 +7,11 @@
#include "ggml_extend.hpp" #include "ggml_extend.hpp"
namespace Rope { namespace Rope {
enum class EmbedNDLayout {
Matrix,
ErnieImage,
};
template <class T> template <class T>
__STATIC_INLINE__ std::vector<T> linspace(T start, T end, int num) { __STATIC_INLINE__ std::vector<T> linspace(T start, T end, int num) {
std::vector<T> result(num); std::vector<T> result(num);
@ -169,7 +174,8 @@ namespace Rope {
int bs, int bs,
const std::vector<float>& axis_thetas, const std::vector<float>& axis_thetas,
const std::vector<int>& axes_dim, const std::vector<int>& axes_dim,
const std::vector<std::vector<int>>& wrap_dims = {}) { const std::vector<std::vector<int>>& wrap_dims = {},
EmbedNDLayout layout = EmbedNDLayout::Matrix) {
std::vector<std::vector<float>> trans_ids = transpose(ids); std::vector<std::vector<float>> trans_ids = transpose(ids);
size_t pos_len = ids.size() / bs; size_t pos_len = ids.size() / bs;
size_t num_axes = axes_dim.size(); size_t num_axes = axes_dim.size();
@ -204,6 +210,24 @@ namespace Rope {
offset += rope_emb[0].size(); offset += rope_emb[0].size();
} }
if (layout == EmbedNDLayout::ErnieImage) {
int head_dim = emb_dim * 2;
std::vector<float> ernie_emb(bs * pos_len * head_dim * 2, 0.0f);
for (size_t pos_idx = 0; pos_idx < bs * pos_len; ++pos_idx) {
for (int i = 0; i < emb_dim; ++i) {
float cos_val = emb[pos_idx][4 * i];
float sin_val = emb[pos_idx][4 * i + 2];
size_t cos_offset = pos_idx * head_dim + 2 * i;
size_t sin_offset = bs * pos_len * head_dim + cos_offset;
ernie_emb[cos_offset] = cos_val;
ernie_emb[cos_offset + 1] = cos_val;
ernie_emb[sin_offset] = sin_val;
ernie_emb[sin_offset + 1] = sin_val;
}
}
return ernie_emb;
}
return flatten(emb); return flatten(emb);
} }
@ -211,9 +235,10 @@ namespace Rope {
int bs, int bs,
float theta, float theta,
const std::vector<int>& axes_dim, const std::vector<int>& axes_dim,
const std::vector<std::vector<int>>& wrap_dims = {}) { const std::vector<std::vector<int>>& wrap_dims = {},
EmbedNDLayout layout = EmbedNDLayout::Matrix) {
std::vector<float> axis_thetas(axes_dim.size(), theta); std::vector<float> axis_thetas(axes_dim.size(), theta);
return embed_nd(ids, bs, axis_thetas, axes_dim, wrap_dims); return embed_nd(ids, bs, axis_thetas, axes_dim, wrap_dims, layout);
} }
__STATIC_INLINE__ std::vector<std::vector<float>> gen_refs_ids(int patch_size, __STATIC_INLINE__ std::vector<std::vector<float>> gen_refs_ids(int patch_size,
@ -437,6 +462,74 @@ namespace Rope {
return embed_nd(ids, bs, static_cast<float>(theta), axes_dim, wrap_dims); return embed_nd(ids, bs, static_cast<float>(theta), axes_dim, wrap_dims);
} }
__STATIC_INLINE__ std::vector<std::vector<float>> gen_ernie_image_ids(int h,
int w,
int patch_size,
int bs,
int context_len) {
int h_len = h / patch_size;
int w_len = w / patch_size;
std::vector<std::vector<float>> img_ids(h_len * w_len, std::vector<float>(3, 0.0f));
std::vector<float> h_ids = linspace<float>(0.f, static_cast<float>(h_len - 1), h_len);
std::vector<float> w_ids = linspace<float>(0.f, static_cast<float>(w_len - 1), w_len);
for (int i = 0; i < h_len; ++i) {
for (int j = 0; j < w_len; ++j) {
img_ids[i * w_len + j][0] = static_cast<float>(context_len);
img_ids[i * w_len + j][1] = h_ids[i];
img_ids[i * w_len + j][2] = w_ids[j];
}
}
std::vector<std::vector<float>> img_ids_repeated(bs * img_ids.size(), std::vector<float>(3, 0.0f));
for (int i = 0; i < bs; ++i) {
for (int j = 0; j < static_cast<int>(img_ids.size()); ++j) {
img_ids_repeated[i * img_ids.size() + j] = img_ids[j];
}
}
std::vector<std::vector<float>> txt_ids(bs * context_len, std::vector<float>(3, 0.0f));
for (int i = 0; i < bs; ++i) {
for (int j = 0; j < context_len; ++j) {
txt_ids[i * context_len + j][0] = static_cast<float>(j);
}
}
return concat_ids(img_ids_repeated, txt_ids, bs);
}
__STATIC_INLINE__ std::vector<float> gen_ernie_image_pe(int h,
int w,
int patch_size,
int bs,
int context_len,
int theta,
bool circular_h,
bool circular_w,
const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_ernie_image_ids(h, w, patch_size, bs, context_len);
std::vector<std::vector<int>> wrap_dims;
if ((circular_h || circular_w) && bs > 0 && axes_dim.size() >= 3) {
int h_len = h / patch_size;
int w_len = w / patch_size;
if (h_len > 0 && w_len > 0) {
size_t pos_len = ids.size() / bs;
wrap_dims.assign(axes_dim.size(), std::vector<int>(pos_len, 0));
const size_t img_tokens = static_cast<size_t>(h_len) * static_cast<size_t>(w_len);
for (size_t token_i = 0; token_i < img_tokens; ++token_i) {
if (circular_h) {
wrap_dims[1][token_i] = h_len;
}
if (circular_w) {
wrap_dims[2][token_i] = w_len;
}
}
}
}
return embed_nd(ids, bs, static_cast<float>(theta), axes_dim, wrap_dims, EmbedNDLayout::ErnieImage);
}
__STATIC_INLINE__ std::vector<std::vector<float>> gen_vid_ids(int t, __STATIC_INLINE__ std::vector<std::vector<float>> gen_vid_ids(int t,
int h, int h,
int w, int w,

View File

@ -52,6 +52,7 @@ const char* model_version_to_str[] = {
"Flux.2 klein", "Flux.2 klein",
"Z-Image", "Z-Image",
"Ovis Image", "Ovis Image",
"Ernie Image",
}; };
const char* sampling_methods_str[] = { const char* sampling_methods_str[] = {
@ -551,6 +552,15 @@ public:
tensor_storage_map, tensor_storage_map,
"model.diffusion_model", "model.diffusion_model",
version); version);
} else if (sd_version_is_ernie_image(version)) {
cond_stage_model = std::make_shared<LLMEmbedder>(clip_backend,
offload_params_to_cpu,
tensor_storage_map,
version);
diffusion_model = std::make_shared<ErnieImageModel>(backend,
offload_params_to_cpu,
tensor_storage_map,
"model.diffusion_model");
} else { // SD1.x SD2.x SDXL } else { // SD1.x SD2.x SDXL
std::map<std::string, std::string> embbeding_map; std::map<std::string, std::string> embbeding_map;
for (uint32_t i = 0; i < sd_ctx_params->embedding_count; i++) { for (uint32_t i = 0; i < sd_ctx_params->embedding_count; i++) {
@ -819,6 +829,10 @@ public:
if (version == VERSION_SVD) { if (version == VERSION_SVD) {
ignore_tensors.insert("conditioner.embedders.3"); ignore_tensors.insert("conditioner.embedders.3");
} }
if (sd_version_is_ernie_image(version)) {
ignore_tensors.insert("text_encoders.llm.vision_tower.");
ignore_tensors.insert("text_encoders.llm.multi_modal_projector.");
}
bool success = model_loader.load_tensors(tensors, ignore_tensors, n_threads, sd_ctx_params->enable_mmap); bool success = model_loader.load_tensors(tensors, ignore_tensors, n_threads, sd_ctx_params->enable_mmap);
if (!success) { if (!success) {
LOG_ERROR("load tensors from model loader failed"); LOG_ERROR("load tensors from model loader failed");
@ -922,6 +936,7 @@ public:
sd_version_is_wan(version) || sd_version_is_wan(version) ||
sd_version_is_qwen_image(version) || sd_version_is_qwen_image(version) ||
sd_version_is_anima(version) || sd_version_is_anima(version) ||
sd_version_is_ernie_image(version) ||
sd_version_is_z_image(version)) { sd_version_is_z_image(version)) {
pred_type = FLOW_PRED; pred_type = FLOW_PRED;
if (sd_version_is_wan(version)) { if (sd_version_is_wan(version)) {
@ -1395,7 +1410,7 @@ public:
uint32_t dim = is_video ? static_cast<uint32_t>(latents.shape()[3]) : static_cast<uint32_t>(latents.shape()[2]); uint32_t dim = is_video ? static_cast<uint32_t>(latents.shape()[3]) : static_cast<uint32_t>(latents.shape()[2]);
if (dim == 128) { if (dim == 128) {
if (sd_version_is_flux2(version)) { if (sd_version_uses_flux2_vae(version)) {
latent_rgb_proj = flux2_latent_rgb_proj; latent_rgb_proj = flux2_latent_rgb_proj;
latent_rgb_bias = flux2_latent_rgb_bias; latent_rgb_bias = flux2_latent_rgb_bias;
patch_sz = 2; patch_sz = 2;
@ -1844,7 +1859,7 @@ public:
latent_channel = 48; latent_channel = 48;
} else if (version == VERSION_CHROMA_RADIANCE) { } else if (version == VERSION_CHROMA_RADIANCE) {
latent_channel = 3; latent_channel = 3;
} else if (sd_version_is_flux2(version)) { } else if (sd_version_uses_flux2_vae(version)) {
latent_channel = 128; latent_channel = 128;
} else { } else {
latent_channel = 16; latent_channel = 16;

View File

@ -69,7 +69,7 @@ public:
int scale_factor = 8; int scale_factor = 8;
if (version == VERSION_WAN2_2_TI2V) { if (version == VERSION_WAN2_2_TI2V) {
scale_factor = 16; scale_factor = 16;
} else if (sd_version_is_flux2(version)) { } else if (sd_version_uses_flux2_vae(version)) {
scale_factor = 16; scale_factor = 16;
} else if (version == VERSION_CHROMA_RADIANCE) { } else if (version == VERSION_CHROMA_RADIANCE) {
scale_factor = 1; scale_factor = 1;