diff --git a/README.md b/README.md
index 2e70b690..0d54c05c 100644
--- a/README.md
+++ b/README.md
@@ -15,6 +15,7 @@ API and command-line option may change frequently.***
## ๐ฅImportant News
+* **2026/06/25** ๐ stable-diffusion.cpp now supports **Krea2**
* **2026/06/04** ๐ stable-diffusion.cpp now supports **Ideogram4**
* **2026/05/31** ๐ stable-diffusion.cpp now supports **PiD**
* **2026/05/27** ๐ stable-diffusion.cpp now supports **Lens**
@@ -51,6 +52,7 @@ API and command-line option may change frequently.***
- [Anima](./docs/anima.md)
- [ERNIE-Image](./docs/ernie_image.md)
- [Boogu Image](./docs/boogu_image.md)
+ - [Krea2](./docs/krea2.md)
- [HiDream-O1-Image](./docs/hidream_o1_image.md)
- [Ideogram4](./docs/ideogram4.md)
- Image Edit Models
diff --git a/assets/krea2/example.png b/assets/krea2/example.png
new file mode 100644
index 00000000..c665e1e7
Binary files /dev/null and b/assets/krea2/example.png differ
diff --git a/docs/krea2.md b/docs/krea2.md
new file mode 100644
index 00000000..b47a0354
--- /dev/null
+++ b/docs/krea2.md
@@ -0,0 +1,27 @@
+# How to Use
+
+Krea2 uses a Krea2 diffusion transformer, the Wan2.1 VAE, and Qwen3-VL 4B as the LLM text encoder.
+
+## Download weights
+
+- Download Krea2 Raw
+ - safetensors: https://huggingface.co/krea/Krea-2-Raw/tree/main
+ - gguf: https://huggingface.co/realrebelai/KREA-2_GGUFs/tree/main/BASE
+- Download Krea2 Turbo
+ - safetensors: https://huggingface.co/krea/Krea-2-Turbo/tree/main
+ - gguf: https://huggingface.co/realrebelai/KREA-2_GGUFs/tree/main/TURBO
+- Download vae
+ - safetensors: https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/vae/wan_2.1_vae.safetensors
+- Download Qwen3-VL 4B
+ - safetensors: https://huggingface.co/Comfy-Org/Krea-2/tree/main/text_encoders
+ - gguf: https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct-GGUF/tree/main
+
+## Examples
+
+### Krea2
+
+```
+.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\Krea-2-Raw-Q8_0.gguf --llm ..\..\ComfyUI\models\text_encoders\Qwen3-VL-4B-Instruct-Q4_K_M.gguf --vae ..\..\ComfyUI\models\vae\wan_2.1_vae.safetensors -p "a lovely cat holding a sign says 'krea2.cpp'" --diffusion-fa -v --offload-to-cpu
+```
+
+
diff --git a/src/conditioning/conditioner.hpp b/src/conditioning/conditioner.hpp
index ae1a5b5b..e037fe76 100644
--- a/src/conditioning/conditioner.hpp
+++ b/src/conditioning/conditioner.hpp
@@ -1518,7 +1518,7 @@ struct LLMEmbedder : public Conditioner {
arch = LLM::LLMArch::GPT_OSS_20B;
} else if (sd_version_is_pid(version)) {
arch = LLM::LLMArch::GEMMA2_2B;
- } else if (sd_version_is_ideogram4(version) || sd_version_is_boogu_image(version)) {
+ } else if (sd_version_is_ideogram4(version) || sd_version_is_boogu_image(version) || sd_version_is_krea2(version)) {
arch = LLM::LLMArch::QWEN3_VL;
} else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE || version == VERSION_FLUX2_KLEIN) {
arch = LLM::LLMArch::QWEN3;
@@ -1837,6 +1837,17 @@ struct LLMEmbedder : public Conditioner {
prompt_attn_range.second = static_cast(prompt.size());
prompt += "<|im_end|>\n";
}
+ } else if (sd_version_is_krea2(version)) {
+ prompt_template_encode_start_idx = 34;
+ out_layers = {2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35};
+
+ prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n";
+
+ prompt_attn_range.first = static_cast(prompt.size());
+ prompt += conditioner_params.text;
+ prompt_attn_range.second = static_cast(prompt.size());
+
+ prompt += "<|im_end|>\n<|im_start|>assistant\n";
} else if (sd_version_is_longcat(version)) {
spell_quotes = true;
diff --git a/src/core/ggml_extend.hpp b/src/core/ggml_extend.hpp
index a3dda16b..65196813 100644
--- a/src/core/ggml_extend.hpp
+++ b/src/core/ggml_extend.hpp
@@ -1382,7 +1382,16 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_attention_ext(ggml_context* ctx,
if (!ggml_backend_supports_op(backend, kqv)) {
kqv = nullptr;
} else {
- kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_q, kqv->nb[1], kqv->nb[2], 0);
+ kqv = ggml_view_4d(ctx,
+ kqv,
+ d_head,
+ n_head,
+ L_q,
+ N,
+ kqv->nb[1],
+ kqv->nb[2],
+ kqv->nb[1] * n_head,
+ 0);
}
}
}
diff --git a/src/model.h b/src/model.h
index d02ed65b..cce30913 100644
--- a/src/model.h
+++ b/src/model.h
@@ -49,6 +49,7 @@ enum SDVersion {
VERSION_LONGCAT,
VERSION_PID,
VERSION_IDEOGRAM4,
+ VERSION_KREA2,
VERSION_ESRGAN,
VERSION_COUNT,
};
@@ -186,6 +187,13 @@ static inline bool sd_version_is_ideogram4(SDVersion version) {
return false;
}
+static inline bool sd_version_is_krea2(SDVersion version) {
+ if (version == VERSION_KREA2) {
+ return true;
+ }
+ return false;
+}
+
static inline bool sd_version_uses_flux_vae(SDVersion version) {
if (sd_version_is_flux(version) || sd_version_is_z_image(version) || sd_version_is_boogu_image(version) || sd_version_is_longcat(version)) {
return true;
@@ -226,7 +234,8 @@ static inline bool sd_version_is_dit(SDVersion version) {
sd_version_is_lens(version) ||
sd_version_is_longcat(version) ||
sd_version_is_pid(version) ||
- sd_version_is_ideogram4(version)) {
+ sd_version_is_ideogram4(version) ||
+ sd_version_is_krea2(version)) {
return true;
}
return false;
diff --git a/src/model/diffusion/krea2.hpp b/src/model/diffusion/krea2.hpp
new file mode 100644
index 00000000..02e65559
--- /dev/null
+++ b/src/model/diffusion/krea2.hpp
@@ -0,0 +1,683 @@
+#ifndef __SD_MODEL_DIFFUSION_KREA2_HPP__
+#define __SD_MODEL_DIFFUSION_KREA2_HPP__
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "core/ggml_extend.hpp"
+#include "core/ggml_graph_cut.h"
+#include "model/common/rope.hpp"
+#include "model/diffusion/dit.hpp"
+#include "model/diffusion/flux.hpp"
+#include "model/diffusion/model.hpp"
+#include "model_loader.h"
+
+namespace Krea2 {
+ constexpr int KREA2_GRAPH_SIZE = 65536;
+
+ struct Krea2Config {
+ int patch_size = 2;
+ int64_t in_channels = 16;
+ int64_t out_channels = 16;
+ int64_t features = 6144;
+ int64_t timestep_dim = 256;
+ int64_t text_dim = 2560;
+ int64_t text_layers = 12;
+ int64_t layers = 28;
+ int64_t heads = 48;
+ int64_t kv_heads = 12;
+ int64_t text_heads = 20;
+ int64_t text_kv_heads = 20;
+ int64_t mlp_multiplier = 4;
+ float theta = 1000.f;
+ float norm_eps = 1e-5f;
+ std::vector axes_dim = {32, 48, 48};
+ int axes_dim_sum = 128;
+
+ int64_t head_dim() const {
+ return features / heads;
+ }
+
+ static int64_t count_blocks(const String2TensorStorage& tensor_storage_map,
+ const std::string& prefix,
+ const std::string& block_prefix) {
+ int64_t count = 0;
+ std::string full_prefix = prefix.empty() ? block_prefix : prefix + "." + block_prefix;
+ for (const auto& [name, _] : tensor_storage_map) {
+ if (!starts_with(name, full_prefix)) {
+ continue;
+ }
+ std::string tail = name.substr(full_prefix.size());
+ size_t dot = tail.find('.');
+ if (dot == std::string::npos) {
+ continue;
+ }
+ int block_index = std::atoi(tail.substr(0, dot).c_str());
+ count = std::max(count, block_index + 1);
+ }
+ return count;
+ }
+
+ void update_axes_dim() {
+ int64_t dim_head = head_dim();
+ int64_t unit = dim_head / 16;
+ axes_dim = {
+ static_cast(dim_head - 12 * unit),
+ static_cast(6 * unit),
+ static_cast(6 * unit),
+ };
+ axes_dim_sum = axes_dim[0] + axes_dim[1] + axes_dim[2];
+ }
+
+ static Krea2Config detect_from_weights(const String2TensorStorage& tensor_storage_map,
+ const std::string& prefix) {
+ Krea2Config config;
+ int64_t detected_head_dim = 0;
+ int64_t detected_text_head_dim = 0;
+
+ for (const auto& [name, tensor_storage] : tensor_storage_map) {
+ if (!starts_with(name, prefix)) {
+ continue;
+ }
+ if (ends_with(name, "first.weight") && tensor_storage.n_dims == 2) {
+ config.in_channels = tensor_storage.ne[0] / (config.patch_size * config.patch_size);
+ config.out_channels = config.in_channels;
+ config.features = tensor_storage.ne[1];
+ } else if (ends_with(name, "blocks.0.attn.qknorm.qnorm.scale") && tensor_storage.n_dims == 1) {
+ detected_head_dim = tensor_storage.ne[0];
+ } else if (ends_with(name, "blocks.0.attn.wq.weight") && tensor_storage.n_dims == 2) {
+ if (detected_head_dim > 0) {
+ config.heads = tensor_storage.ne[1] / detected_head_dim;
+ }
+ } else if (ends_with(name, "blocks.0.attn.wk.weight") && tensor_storage.n_dims == 2) {
+ if (detected_head_dim > 0) {
+ config.kv_heads = tensor_storage.ne[1] / detected_head_dim;
+ }
+ } else if (ends_with(name, "txtfusion.projector.weight") && tensor_storage.n_dims == 2) {
+ config.text_layers = tensor_storage.ne[0];
+ } else if (ends_with(name, "txtfusion.layerwise_blocks.0.prenorm.scale") && tensor_storage.n_dims == 1) {
+ config.text_dim = tensor_storage.ne[0];
+ } else if (ends_with(name, "txtfusion.layerwise_blocks.0.attn.qknorm.qnorm.scale") && tensor_storage.n_dims == 1) {
+ detected_text_head_dim = tensor_storage.ne[0];
+ } else if (ends_with(name, "txtfusion.layerwise_blocks.0.attn.wq.weight") && tensor_storage.n_dims == 2) {
+ if (detected_text_head_dim > 0) {
+ config.text_heads = tensor_storage.ne[1] / detected_text_head_dim;
+ }
+ } else if (ends_with(name, "txtfusion.layerwise_blocks.0.attn.wk.weight") && tensor_storage.n_dims == 2) {
+ if (detected_text_head_dim > 0) {
+ config.text_kv_heads = tensor_storage.ne[1] / detected_text_head_dim;
+ }
+ } else if (ends_with(name, "last.linear.weight") && tensor_storage.n_dims == 2) {
+ config.out_channels = tensor_storage.ne[1] / (config.patch_size * config.patch_size);
+ }
+ }
+
+ config.layers = std::max(1, count_blocks(tensor_storage_map, prefix, "blocks."));
+ if (detected_head_dim > 0 && config.features > 0) {
+ config.heads = config.features / detected_head_dim;
+ }
+ if (detected_head_dim > 0) {
+ std::string wk_name = prefix.empty() ? "blocks.0.attn.wk.weight" : prefix + ".blocks.0.attn.wk.weight";
+ auto it = tensor_storage_map.find(wk_name);
+ if (it != tensor_storage_map.end() && it->second.n_dims == 2) {
+ config.kv_heads = it->second.ne[1] / detected_head_dim;
+ }
+ }
+ if (detected_text_head_dim > 0 && config.text_dim > 0) {
+ config.text_heads = config.text_dim / detected_text_head_dim;
+ }
+ if (detected_text_head_dim > 0) {
+ std::string wk_name = prefix.empty() ? "txtfusion.layerwise_blocks.0.attn.wk.weight" : prefix + ".txtfusion.layerwise_blocks.0.attn.wk.weight";
+ auto it = tensor_storage_map.find(wk_name);
+ if (it != tensor_storage_map.end() && it->second.n_dims == 2) {
+ config.text_kv_heads = it->second.ne[1] / detected_text_head_dim;
+ }
+ }
+ config.update_axes_dim();
+
+ LOG_DEBUG("krea2: layers=%" PRId64 ", features=%" PRId64 ", heads=%" PRId64 ", kv_heads=%" PRId64 ", text_dim=%" PRId64 ", text_layers=%" PRId64 ", text_heads=%" PRId64 ", text_kv_heads=%" PRId64 ", channels=%" PRId64,
+ config.layers,
+ config.features,
+ config.heads,
+ config.kv_heads,
+ config.text_dim,
+ config.text_layers,
+ config.text_heads,
+ config.text_kv_heads,
+ config.in_channels);
+ return config;
+ }
+ };
+
+ __STATIC_INLINE__ int64_t ceil_to_multiple(int64_t value, int64_t multiple) {
+ return ((value + multiple - 1) / multiple) * multiple;
+ }
+
+ class KreaRMSNorm : public UnaryBlock {
+ protected:
+ int64_t hidden_size;
+ float eps;
+ std::string prefix;
+
+ void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
+ GGML_UNUSED(tensor_storage_map);
+ this->prefix = prefix;
+ params["scale"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size);
+ }
+
+ public:
+ KreaRMSNorm(int64_t hidden_size, float eps = 1e-5f)
+ : hidden_size(hidden_size),
+ eps(eps) {}
+
+ ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override {
+ ggml_tensor* scale = params["scale"];
+ scale = ggml_add(ctx->ggml_ctx, scale, ggml_ext_ones(ctx->ggml_ctx, scale->ne[0], 1, 1, 1));
+ x = ggml_rms_norm(ctx->ggml_ctx, x, eps);
+ x = ggml_mul_inplace(ctx->ggml_ctx, x, scale);
+ return x;
+ }
+ };
+
+ class KreaSwiGLU : public UnaryBlock {
+ public:
+ KreaSwiGLU(int64_t features, int64_t multiplier) {
+ int64_t mlp_dim = ceil_to_multiple(((2 * features) / 3) * multiplier, 128);
+ blocks["gate"] = std::make_shared(features, mlp_dim, false);
+ blocks["up"] = std::make_shared(features, mlp_dim, false);
+ blocks["down"] = std::make_shared(mlp_dim, features, false);
+ }
+
+ ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override {
+ auto gate = std::dynamic_pointer_cast(blocks["gate"]);
+ auto up = std::dynamic_pointer_cast(blocks["up"]);
+ auto down = std::dynamic_pointer_cast(blocks["down"]);
+
+ auto gated = ggml_silu(ctx->ggml_ctx, gate->forward(ctx, x));
+ auto up_x = up->forward(ctx, x);
+ x = ggml_mul(ctx->ggml_ctx, gated, up_x);
+ return down->forward(ctx, x);
+ }
+ };
+
+ class KreaAttention : public GGMLBlock {
+ protected:
+ int64_t features;
+ int64_t heads;
+ int64_t kv_heads;
+ int64_t head_dim_;
+
+ ggml_tensor* attention_no_rope(GGMLRunnerContext* ctx,
+ ggml_tensor* q,
+ ggml_tensor* k,
+ ggml_tensor* v,
+ ggml_tensor* mask) {
+ int64_t Lq = q->ne[2];
+ int64_t Lk = k->ne[2];
+ int64_t N = q->ne[3];
+ q = ggml_reshape_3d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, q), head_dim_ * heads, Lq, N);
+ k = ggml_reshape_3d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, k), head_dim_ * kv_heads, Lk, N);
+ v = ggml_reshape_3d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, v), head_dim_ * kv_heads, Lk, N);
+ return ggml_ext_attention_ext(ctx->ggml_ctx,
+ ctx->backend,
+ q,
+ k,
+ v,
+ heads,
+ mask,
+ false,
+ ctx->flash_attn_enabled);
+ }
+
+ public:
+ KreaAttention(int64_t features,
+ int64_t heads,
+ int64_t kv_heads,
+ float eps = 1e-5f)
+ : features(features),
+ heads(heads),
+ kv_heads(kv_heads),
+ head_dim_(features / heads) {
+ blocks["wq"] = std::make_shared(features, heads * head_dim_, false);
+ blocks["wk"] = std::make_shared(features, kv_heads * head_dim_, false);
+ blocks["wv"] = std::make_shared(features, kv_heads * head_dim_, false);
+ blocks["gate"] = std::make_shared(features, features, false);
+ blocks["qknorm.qnorm"] = std::make_shared(head_dim_, eps);
+ blocks["qknorm.knorm"] = std::make_shared(head_dim_, eps);
+ blocks["wo"] = std::make_shared(features, features, false);
+ }
+
+ ggml_tensor* forward(GGMLRunnerContext* ctx,
+ ggml_tensor* x,
+ ggml_tensor* pe = nullptr,
+ ggml_tensor* mask = nullptr) {
+ auto wq = std::dynamic_pointer_cast(blocks["wq"]);
+ auto wk = std::dynamic_pointer_cast(blocks["wk"]);
+ auto wv = std::dynamic_pointer_cast(blocks["wv"]);
+ auto gate = std::dynamic_pointer_cast(blocks["gate"]);
+ auto qnorm = std::dynamic_pointer_cast(blocks["qknorm.qnorm"]);
+ auto knorm = std::dynamic_pointer_cast(blocks["qknorm.knorm"]);
+ auto wo = std::dynamic_pointer_cast(blocks["wo"]);
+
+ if (sd_backend_is(ctx->backend, "Vulkan")) {
+ wo->set_force_prec_f32(true);
+ }
+
+ int64_t L = x->ne[1];
+ int64_t N = x->ne[2];
+
+ auto q = wq->forward(ctx, x);
+ q = ggml_reshape_4d(ctx->ggml_ctx, q, head_dim_, heads, L, N);
+ auto k = wk->forward(ctx, x);
+ k = ggml_reshape_4d(ctx->ggml_ctx, k, head_dim_, kv_heads, L, N);
+ auto v = wv->forward(ctx, x);
+ v = ggml_reshape_4d(ctx->ggml_ctx, v, head_dim_, kv_heads, L, N);
+
+ q = qnorm->forward(ctx, q);
+ k = knorm->forward(ctx, k);
+
+ auto out = pe != nullptr ? Rope::attention(ctx, q, k, v, pe, mask)
+ : attention_no_rope(ctx, q, k, v, mask);
+ out = ggml_mul(ctx->ggml_ctx, out, ggml_sigmoid(ctx->ggml_ctx, gate->forward(ctx, x)));
+ out = wo->forward(ctx, out);
+ return out;
+ }
+ };
+
+ class KreaDoubleSharedModulation : public GGMLBlock {
+ protected:
+ int64_t dim;
+
+ void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
+ GGML_UNUSED(tensor_storage_map);
+ GGML_UNUSED(prefix);
+ params["lin"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim * 6);
+ }
+
+ public:
+ KreaDoubleSharedModulation(int64_t dim)
+ : dim(dim) {}
+
+ std::vector forward(GGMLRunnerContext* ctx, ggml_tensor* vec) {
+ auto lin = ggml_repeat(ctx->ggml_ctx, params["lin"], vec);
+ auto out = ggml_add(ctx->ggml_ctx, vec, lin);
+ return ggml_ext_chunk(ctx->ggml_ctx, out, 6, 0);
+ }
+ };
+
+ class KreaFinalModulation : public GGMLBlock {
+ protected:
+ int64_t dim;
+
+ void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
+ GGML_UNUSED(tensor_storage_map);
+ GGML_UNUSED(prefix);
+ params["lin"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, 2);
+ }
+
+ public:
+ KreaFinalModulation(int64_t dim)
+ : dim(dim) {}
+
+ std::vector forward(GGMLRunnerContext* ctx, ggml_tensor* vec) {
+ auto out = ggml_add(ctx->ggml_ctx, params["lin"], vec);
+ return ggml_ext_chunk(ctx->ggml_ctx, out, 2, 1);
+ }
+ };
+
+ class KreaTextFusionBlock : public UnaryBlock {
+ public:
+ KreaTextFusionBlock(int64_t dim,
+ int64_t heads,
+ int64_t kv_heads,
+ int64_t multiplier,
+ float eps) {
+ blocks["prenorm"] = std::make_shared(dim, eps);
+ blocks["postnorm"] = std::make_shared(dim, eps);
+ blocks["attn"] = std::make_shared(dim, heads, kv_heads, eps);
+ blocks["mlp"] = std::make_shared(dim, multiplier);
+ }
+
+ ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override {
+ auto prenorm = std::dynamic_pointer_cast(blocks["prenorm"]);
+ auto postnorm = std::dynamic_pointer_cast(blocks["postnorm"]);
+ auto attn = std::dynamic_pointer_cast(blocks["attn"]);
+ auto mlp = std::dynamic_pointer_cast(blocks["mlp"]);
+
+ x = ggml_add(ctx->ggml_ctx, x, attn->forward(ctx, prenorm->forward(ctx, x)));
+ x = ggml_add(ctx->ggml_ctx, x, mlp->forward(ctx, postnorm->forward(ctx, x)));
+ return x;
+ }
+ };
+
+ class KreaTextFusionTransformer : public UnaryBlock {
+ protected:
+ Krea2Config config;
+
+ public:
+ explicit KreaTextFusionTransformer(Krea2Config config)
+ : config(std::move(config)) {
+ for (int i = 0; i < 2; ++i) {
+ blocks["layerwise_blocks." + std::to_string(i)] = std::make_shared(this->config.text_dim,
+ this->config.text_heads,
+ this->config.text_kv_heads,
+ this->config.mlp_multiplier,
+ this->config.norm_eps);
+ blocks["refiner_blocks." + std::to_string(i)] = std::make_shared(this->config.text_dim,
+ this->config.text_heads,
+ this->config.text_kv_heads,
+ this->config.mlp_multiplier,
+ this->config.norm_eps);
+ }
+ blocks["projector"] = std::make_shared(this->config.text_layers, 1, false);
+ }
+
+ ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* context) override {
+ int64_t text_tokens = context->ne[1];
+ int64_t batch = context->ne[2];
+
+ context = ggml_reshape_3d(ctx->ggml_ctx,
+ context,
+ config.text_dim,
+ config.text_layers,
+ text_tokens * batch);
+
+ for (int i = 0; i < 2; ++i) {
+ auto block = std::dynamic_pointer_cast(blocks["layerwise_blocks." + std::to_string(i)]);
+ context = block->forward(ctx, context);
+ }
+
+ context = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, context, 1, 0, 2, 3));
+ auto projector = std::dynamic_pointer_cast(blocks["projector"]);
+ context = projector->forward(ctx, context);
+ context = ggml_reshape_3d(ctx->ggml_ctx, context, config.text_dim, text_tokens, batch);
+
+ for (int i = 0; i < 2; ++i) {
+ auto block = std::dynamic_pointer_cast(blocks["refiner_blocks." + std::to_string(i)]);
+ context = block->forward(ctx, context);
+ }
+ return context;
+ }
+ };
+
+ class KreaSingleStreamBlock : public UnaryBlock {
+ public:
+ explicit KreaSingleStreamBlock(Krea2Config config) {
+ blocks["mod"] = std::make_shared(config.features);
+ blocks["prenorm"] = std::make_shared(config.features, config.norm_eps);
+ blocks["postnorm"] = std::make_shared(config.features, config.norm_eps);
+ blocks["attn"] = std::make_shared(config.features, config.heads, config.kv_heads, config.norm_eps);
+ blocks["mlp"] = std::make_shared(config.features, config.mlp_multiplier);
+ }
+
+ ggml_tensor* forward(GGMLRunnerContext* ctx,
+ ggml_tensor* x,
+ ggml_tensor* vec,
+ ggml_tensor* pe) {
+ auto mod = std::dynamic_pointer_cast(blocks["mod"]);
+ auto prenorm = std::dynamic_pointer_cast(blocks["prenorm"]);
+ auto postnorm = std::dynamic_pointer_cast(blocks["postnorm"]);
+ auto attn = std::dynamic_pointer_cast(blocks["attn"]);
+ auto mlp = std::dynamic_pointer_cast(blocks["mlp"]);
+
+ auto mods = mod->forward(ctx, vec);
+ auto attn_input = Flux::modulate(ctx->ggml_ctx,
+ prenorm->forward(ctx, x),
+ mods[1],
+ mods[0],
+ true);
+ auto attn_out = attn->forward(ctx, attn_input, pe);
+ x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, attn_out, mods[2]));
+
+ auto mlp_input = Flux::modulate(ctx->ggml_ctx,
+ postnorm->forward(ctx, x),
+ mods[4],
+ mods[3],
+ true);
+ auto mlp_out = mlp->forward(ctx, mlp_input);
+ x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, mlp_out, mods[5]));
+ return x;
+ }
+
+ ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override {
+ GGML_UNUSED(ctx);
+ GGML_UNUSED(x);
+ GGML_ABORT("KreaSingleStreamBlock requires conditioning");
+ return nullptr;
+ }
+ };
+
+ class KreaTimeMLP : public UnaryBlock {
+ public:
+ explicit KreaTimeMLP(Krea2Config config) {
+ blocks["0"] = std::make_shared(config.timestep_dim, config.features, true);
+ blocks["2"] = std::make_shared(config.features, config.features, true);
+ }
+
+ ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override {
+ auto linear_0 = std::dynamic_pointer_cast(blocks["0"]);
+ auto linear_2 = std::dynamic_pointer_cast(blocks["2"]);
+ x = linear_0->forward(ctx, x);
+ x = ggml_ext_gelu(ctx->ggml_ctx, x, false);
+ x = linear_2->forward(ctx, x);
+ return x;
+ }
+ };
+
+ class KreaTProj : public UnaryBlock {
+ public:
+ explicit KreaTProj(Krea2Config config) {
+ blocks["1"] = std::make_shared(config.features, config.features * 6, true);
+ }
+
+ ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override {
+ auto linear_1 = std::dynamic_pointer_cast(blocks["1"]);
+ x = ggml_ext_gelu(ctx->ggml_ctx, x, false);
+ x = linear_1->forward(ctx, x);
+ return x;
+ }
+ };
+
+ class KreaTextMLP : public UnaryBlock {
+ public:
+ explicit KreaTextMLP(Krea2Config config) {
+ blocks["0"] = std::make_shared(config.text_dim, config.norm_eps);
+ blocks["1"] = std::make_shared(config.text_dim, config.features, true);
+ blocks["3"] = std::make_shared(config.features, config.features, true);
+ }
+
+ ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override {
+ auto norm = std::dynamic_pointer_cast(blocks["0"]);
+ auto linear_1 = std::dynamic_pointer_cast(blocks["1"]);
+ auto linear_3 = std::dynamic_pointer_cast(blocks["3"]);
+ x = norm->forward(ctx, x);
+ x = linear_1->forward(ctx, x);
+ x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
+ x = linear_3->forward(ctx, x);
+ return x;
+ }
+ };
+
+ class KreaLastLayer : public GGMLBlock {
+ public:
+ explicit KreaLastLayer(Krea2Config config) {
+ blocks["norm"] = std::make_shared(config.features, config.norm_eps);
+ blocks["linear"] = std::make_shared(config.features, config.patch_size * config.patch_size * config.out_channels, true);
+ blocks["modulation"] = std::make_shared(config.features);
+ }
+
+ ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* vec) {
+ auto norm = std::dynamic_pointer_cast(blocks["norm"]);
+ auto linear = std::dynamic_pointer_cast(blocks["linear"]);
+ auto modulation = std::dynamic_pointer_cast(blocks["modulation"]);
+
+ auto mods = modulation->forward(ctx, vec);
+ x = Flux::modulate(ctx->ggml_ctx,
+ norm->forward(ctx, x),
+ mods[1],
+ mods[0],
+ true);
+ x = linear->forward(ctx, x);
+ return x;
+ }
+ };
+
+ class Krea2Model : public GGMLBlock {
+ protected:
+ Krea2Config config;
+
+ public:
+ Krea2Model() = default;
+ explicit Krea2Model(Krea2Config config)
+ : config(std::move(config)) {
+ blocks["first"] = std::make_shared(this->config.patch_size * this->config.patch_size * this->config.in_channels,
+ this->config.features,
+ true);
+ blocks["tmlp"] = std::make_shared(this->config);
+ blocks["txtfusion"] = std::make_shared(this->config);
+ blocks["txtmlp"] = std::make_shared(this->config);
+ blocks["tproj"] = std::make_shared(this->config);
+ for (int i = 0; i < this->config.layers; ++i) {
+ blocks["blocks." + std::to_string(i)] = std::make_shared(this->config);
+ }
+ blocks["last"] = std::make_shared(this->config);
+ }
+
+ ggml_tensor* forward(GGMLRunnerContext* ctx,
+ ggml_tensor* x,
+ ggml_tensor* timestep,
+ ggml_tensor* context,
+ ggml_tensor* pe) {
+ int64_t W = x->ne[0];
+ int64_t H = x->ne[1];
+ int64_t N = x->ne[3];
+ GGML_ASSERT(N == 1);
+
+ auto first = std::dynamic_pointer_cast(blocks["first"]);
+ auto tmlp = std::dynamic_pointer_cast(blocks["tmlp"]);
+ auto txtfusion = std::dynamic_pointer_cast(blocks["txtfusion"]);
+ auto txtmlp = std::dynamic_pointer_cast(blocks["txtmlp"]);
+ auto tproj = std::dynamic_pointer_cast(blocks["tproj"]);
+ auto last = std::dynamic_pointer_cast(blocks["last"]);
+
+ auto img = DiT::pad_and_patchify(ctx, x, config.patch_size, config.patch_size, true);
+ int64_t img_len = img->ne[1];
+ img = first->forward(ctx, img);
+
+ auto t = ggml_ext_timestep_embedding(ctx->ggml_ctx, timestep, static_cast(config.timestep_dim), 10000, 1000.f);
+ t = tmlp->forward(ctx, t);
+ t = ggml_reshape_3d(ctx->ggml_ctx, t, t->ne[0], 1, t->ne[1]);
+ auto tvec = tproj->forward(ctx, t);
+
+ auto txt = txtfusion->forward(ctx, context);
+ txt = txtmlp->forward(ctx, txt);
+ int64_t txt_len = txt->ne[1];
+
+ auto hidden_states = ggml_concat(ctx->ggml_ctx, txt, img, 1);
+ for (int i = 0; i < config.layers; ++i) {
+ auto block = std::dynamic_pointer_cast(blocks["blocks." + std::to_string(i)]);
+ hidden_states = block->forward(ctx, hidden_states, tvec, pe);
+ sd::ggml_graph_cut::mark_graph_cut(hidden_states, "krea2.blocks." + std::to_string(i), "hidden_states");
+ }
+
+ hidden_states = last->forward(ctx, hidden_states, t);
+ hidden_states = ggml_ext_slice(ctx->ggml_ctx, hidden_states, 1, txt_len, txt_len + img_len);
+ hidden_states = DiT::unpatchify_and_crop(ctx->ggml_ctx, hidden_states, H, W, config.patch_size, config.patch_size, true);
+ return hidden_states;
+ }
+ };
+
+ __STATIC_INLINE__ std::vector gen_krea2_pe(int h,
+ int w,
+ int patch_size,
+ int bs,
+ int context_len,
+ float theta,
+ const std::vector& axes_dim) {
+ auto txt_ids = Rope::gen_flux_txt_ids(bs, context_len, 3, {});
+ auto img_ids = Rope::gen_flux_img_ids(h, w, patch_size, bs, 3, 0, 0, 0, false);
+ auto ids = Rope::concat_ids(txt_ids, img_ids, bs);
+ return Rope::embed_nd(ids, bs, theta, axes_dim);
+ }
+
+ struct Krea2Runner : public DiffusionModelRunner {
+ Krea2Config config;
+ Krea2Model model;
+ std::vector pe_vec;
+
+ Krea2Runner(ggml_backend_t backend,
+ const String2TensorStorage& tensor_storage_map = {},
+ const std::string prefix = "",
+ std::shared_ptr weight_manager = nullptr)
+ : DiffusionModelRunner(backend, prefix, weight_manager),
+ config(Krea2Config::detect_from_weights(tensor_storage_map, prefix)) {
+ model = Krea2Model(config);
+ model.init(params_ctx, tensor_storage_map, prefix);
+ }
+
+ std::string get_desc() override {
+ return "krea2";
+ }
+
+ void get_param_tensors(std::map& tensors, const std::string& prefix) override {
+ model.get_param_tensors(tensors, prefix);
+ }
+
+ ggml_cgraph* build_graph(const sd::Tensor& x_tensor,
+ const sd::Tensor& timesteps_tensor,
+ const sd::Tensor& context_tensor) {
+ ggml_cgraph* gf = new_graph_custom(KREA2_GRAPH_SIZE);
+ ggml_tensor* x = make_input(x_tensor);
+ ggml_tensor* timesteps = make_input(timesteps_tensor);
+ GGML_ASSERT(x->ne[3] == 1);
+ GGML_ASSERT(!context_tensor.empty());
+ ggml_tensor* context = make_input(context_tensor);
+
+ pe_vec = gen_krea2_pe(static_cast(x->ne[1]),
+ static_cast(x->ne[0]),
+ config.patch_size,
+ static_cast(x->ne[3]),
+ static_cast(context->ne[1]),
+ config.theta,
+ config.axes_dim);
+ int pos_len = static_cast(pe_vec.size() / config.axes_dim_sum / 2);
+ auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.axes_dim_sum / 2, pos_len);
+ set_backend_tensor_data(pe, pe_vec.data());
+
+ auto runner_ctx = get_context();
+ ggml_tensor* out = model.forward(&runner_ctx, x, timesteps, context, pe);
+ ggml_build_forward_expand(gf, out);
+ return gf;
+ }
+
+ sd::Tensor compute(int n_threads,
+ const sd::Tensor& x,
+ const sd::Tensor& timesteps,
+ const sd::Tensor& context) {
+ auto get_graph = [&]() -> ggml_cgraph* {
+ return build_graph(x, timesteps, context);
+ };
+ return restore_trailing_singleton_dims(GGMLRunner::compute(get_graph, n_threads, false, false, false), x.dim());
+ }
+
+ sd::Tensor compute(int n_threads,
+ const DiffusionParams& diffusion_params) override {
+ GGML_ASSERT(diffusion_params.x != nullptr);
+ GGML_ASSERT(diffusion_params.timesteps != nullptr);
+ return compute(n_threads,
+ *diffusion_params.x,
+ *diffusion_params.timesteps,
+ tensor_or_empty(diffusion_params.context));
+ }
+ };
+} // namespace Krea2
+
+#endif // __SD_MODEL_DIFFUSION_KREA2_HPP__
diff --git a/src/model_loader.cpp b/src/model_loader.cpp
index 2fd854a8..33c056b3 100644
--- a/src/model_loader.cpp
+++ b/src/model_loader.cpp
@@ -453,6 +453,10 @@ SDVersion ModelLoader::get_sd_version() {
if (tensor_storage.name.find("embed_image_indicator.weight") != std::string::npos) {
return VERSION_IDEOGRAM4;
}
+ if (tensor_storage.name.find("model.diffusion_model.txtfusion.projector.weight") != std::string::npos ||
+ tensor_storage.name.find("model.diffusion_model.text_fusion.projector.weight") != std::string::npos) {
+ return VERSION_KREA2;
+ }
if (tensor_storage.name.find("model.diffusion_model.nerf_final_layer_conv.") != std::string::npos) {
return VERSION_CHROMA_RADIANCE;
}
diff --git a/src/name_conversion.cpp b/src/name_conversion.cpp
index da2a8d5e..ccc8347b 100644
--- a/src/name_conversion.cpp
+++ b/src/name_conversion.cpp
@@ -704,6 +704,38 @@ std::string convert_other_dit_to_original_anima(std::string name) {
return name;
}
+std::string convert_diffusers_dit_to_original_krea2(std::string name) {
+ static const std::vector> prefix_map = {
+ {"img_in.", "first."},
+ {"time_embed.linear_1.", "tmlp.0."},
+ {"time_embed.linear_2.", "tmlp.2."},
+ {"time_mod_proj.", "tproj.1."},
+ {"txt_in.linear_1.", "txtmlp.1."},
+ {"txt_in.linear_2.", "txtmlp.3."},
+ {"text_fusion.", "txtfusion."},
+ {"transformer_blocks.", "blocks."},
+ {"final_layer.", "last."},
+ };
+ static const std::vector> name_map = {
+ {"attn.to_out.0.", "attn.wo."},
+ {"attn.to_out.", "attn.wo."},
+ {"attn.to_gate.", "attn.gate."},
+ {"attn.to_q.", "attn.wq."},
+ {"attn.to_k.", "attn.wk."},
+ {"attn.to_v.", "attn.wv."},
+ {"ff.gate.", "mlp.gate."},
+ {"ff.up.", "mlp.up."},
+ {"ff.down.", "mlp.down."},
+ {"txt_in.norm.", "txtmlp.0."},
+ {"last.norm.weight", "last.norm.scale"},
+ {"last.modulation.weight", "last.modulation.lin"},
+ };
+
+ replace_with_prefix_map(name, prefix_map);
+ replace_with_name_map(name, name_map);
+ return name;
+}
+
std::string convert_diffusion_model_name(std::string name, std::string prefix, SDVersion version) {
if (sd_version_is_sd1(version) || sd_version_is_sd2(version)) {
name = convert_diffusers_unet_to_original_sd1(name);
@@ -717,6 +749,8 @@ std::string convert_diffusion_model_name(std::string name, std::string prefix, S
name = convert_diffusers_dit_to_original_lumina2(name);
} else if (sd_version_is_anima(version)) {
name = convert_other_dit_to_original_anima(name);
+ } else if (sd_version_is_krea2(version)) {
+ name = convert_diffusers_dit_to_original_krea2(name);
}
return name;
}
@@ -1175,7 +1209,7 @@ std::string convert_tensor_name(std::string name, SDVersion version) {
replace_with_prefix_map(name, prefix_map);
- if (sd_version_is_boogu_image(version) && starts_with(name, "text_encoders.llm.visual.")) {
+ if ((sd_version_is_boogu_image(version) || sd_version_is_krea2(version)) && starts_with(name, "text_encoders.llm.visual.")) {
name = convert_qwen3_vl_vision_name(std::move(name));
}
diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp
index 08a3ed23..311c7511 100644
--- a/src/stable-diffusion.cpp
+++ b/src/stable-diffusion.cpp
@@ -26,6 +26,7 @@
#include "model/diffusion/flux.hpp"
#include "model/diffusion/hidream_o1.hpp"
#include "model/diffusion/ideogram4.hpp"
+#include "model/diffusion/krea2.hpp"
#include "model/diffusion/lens.hpp"
#include "model/diffusion/ltxv.hpp"
#include "model/diffusion/mmdit.hpp"
@@ -95,6 +96,7 @@ const char* model_version_to_str[] = {
"Longcat-Image",
"PiD",
"Ideogram 4",
+ "Krea2",
"ESRGAN",
};
@@ -645,6 +647,17 @@ public:
tensor_storage_map,
"model.diffusion_model",
model_manager);
+ } else if (sd_version_is_krea2(version)) {
+ cond_stage_model = std::make_shared(backend_for(SDBackendModule::TE),
+ tensor_storage_map,
+ version,
+ "",
+ false,
+ model_manager);
+ diffusion_model = std::make_shared(backend_for(SDBackendModule::DIFFUSION),
+ tensor_storage_map,
+ "model.diffusion_model",
+ model_manager);
} else if (sd_version_is_flux(version)) {
bool is_chroma = false;
for (auto pair : tensor_storage_map) {
@@ -881,6 +894,7 @@ public:
auto create_tae = [&](bool decode_only) -> std::shared_ptr {
if (sd_version_is_wan(version) ||
sd_version_is_qwen_image(version) ||
+ sd_version_is_krea2(version) ||
sd_version_is_anima(version) ||
sd_version_is_ltxav(version)) {
return std::make_shared(backend_for(SDBackendModule::VAE),
@@ -921,6 +935,7 @@ public:
model_manager);
} else if (sd_version_is_wan(version) ||
sd_version_is_qwen_image(version) ||
+ sd_version_is_krea2(version) ||
sd_version_is_anima(version)) {
return std::make_shared(backend_for(SDBackendModule::VAE),
tensor_storage_map,
@@ -1267,7 +1282,8 @@ public:
} else if (sd_version_is_flux(version) ||
sd_version_is_longcat(version) ||
sd_version_is_lens(version) ||
- sd_version_is_ltxav(version)) {
+ sd_version_is_ltxav(version) ||
+ sd_version_is_krea2(version)) {
pred_type = FLUX_FLOW_PRED;
default_flow_shift = 1.0f; // TODO: validate
@@ -1283,6 +1299,8 @@ public:
default_flow_shift = 1.83f;
} else if (sd_version_is_ltxav(version)) {
default_flow_shift = 2.37f;
+ } else if (sd_version_is_krea2(version)) {
+ default_flow_shift = 1.15f;
}
} else if (sd_version_is_flux2(version)) {
pred_type = FLUX2_FLOW_PRED;
@@ -1724,7 +1742,7 @@ public:
} else if (sd_version_uses_flux_vae(version)) {
latent_rgb_proj = flux_latent_rgb_proj;
latent_rgb_bias = flux_latent_rgb_bias;
- } else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
+ } else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version) || sd_version_is_krea2(version)) {
latent_rgb_proj = wan_21_latent_rgb_proj;
latent_rgb_bias = wan_21_latent_rgb_bias;
} else {