mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-24 15:16:38 +00:00
feat: add krea2 support
This commit is contained in:
parent
f440ad9c29
commit
49e7882137
@ -1518,7 +1518,7 @@ struct LLMEmbedder : public Conditioner {
|
|||||||
arch = LLM::LLMArch::GPT_OSS_20B;
|
arch = LLM::LLMArch::GPT_OSS_20B;
|
||||||
} else if (sd_version_is_pid(version)) {
|
} else if (sd_version_is_pid(version)) {
|
||||||
arch = LLM::LLMArch::GEMMA2_2B;
|
arch = LLM::LLMArch::GEMMA2_2B;
|
||||||
} else if (sd_version_is_ideogram4(version) || sd_version_is_boogu_image(version)) {
|
} else if (sd_version_is_ideogram4(version) || sd_version_is_boogu_image(version) || sd_version_is_krea2(version)) {
|
||||||
arch = LLM::LLMArch::QWEN3_VL;
|
arch = LLM::LLMArch::QWEN3_VL;
|
||||||
} else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE || version == VERSION_FLUX2_KLEIN) {
|
} else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE || version == VERSION_FLUX2_KLEIN) {
|
||||||
arch = LLM::LLMArch::QWEN3;
|
arch = LLM::LLMArch::QWEN3;
|
||||||
@ -1837,6 +1837,17 @@ struct LLMEmbedder : public Conditioner {
|
|||||||
prompt_attn_range.second = static_cast<int>(prompt.size());
|
prompt_attn_range.second = static_cast<int>(prompt.size());
|
||||||
prompt += "<|im_end|>\n";
|
prompt += "<|im_end|>\n";
|
||||||
}
|
}
|
||||||
|
} else if (sd_version_is_krea2(version)) {
|
||||||
|
prompt_template_encode_start_idx = 34;
|
||||||
|
out_layers = {2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35};
|
||||||
|
|
||||||
|
prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n";
|
||||||
|
|
||||||
|
prompt_attn_range.first = static_cast<int>(prompt.size());
|
||||||
|
prompt += conditioner_params.text;
|
||||||
|
prompt_attn_range.second = static_cast<int>(prompt.size());
|
||||||
|
|
||||||
|
prompt += "<|im_end|>\n<|im_start|>assistant\n";
|
||||||
} else if (sd_version_is_longcat(version)) {
|
} else if (sd_version_is_longcat(version)) {
|
||||||
spell_quotes = true;
|
spell_quotes = true;
|
||||||
|
|
||||||
|
|||||||
@ -1382,7 +1382,16 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_attention_ext(ggml_context* ctx,
|
|||||||
if (!ggml_backend_supports_op(backend, kqv)) {
|
if (!ggml_backend_supports_op(backend, kqv)) {
|
||||||
kqv = nullptr;
|
kqv = nullptr;
|
||||||
} else {
|
} else {
|
||||||
kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_q, kqv->nb[1], kqv->nb[2], 0);
|
kqv = ggml_view_4d(ctx,
|
||||||
|
kqv,
|
||||||
|
d_head,
|
||||||
|
n_head,
|
||||||
|
L_q,
|
||||||
|
N,
|
||||||
|
kqv->nb[1],
|
||||||
|
kqv->nb[2],
|
||||||
|
kqv->nb[1] * n_head,
|
||||||
|
0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
11
src/model.h
11
src/model.h
@ -49,6 +49,7 @@ enum SDVersion {
|
|||||||
VERSION_LONGCAT,
|
VERSION_LONGCAT,
|
||||||
VERSION_PID,
|
VERSION_PID,
|
||||||
VERSION_IDEOGRAM4,
|
VERSION_IDEOGRAM4,
|
||||||
|
VERSION_KREA2,
|
||||||
VERSION_ESRGAN,
|
VERSION_ESRGAN,
|
||||||
VERSION_COUNT,
|
VERSION_COUNT,
|
||||||
};
|
};
|
||||||
@ -186,6 +187,13 @@ static inline bool sd_version_is_ideogram4(SDVersion version) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Predicate for the Krea2 model family: true only for VERSION_KREA2.
static inline bool sd_version_is_krea2(SDVersion version) {
    return version == VERSION_KREA2;
}
|
||||||
|
|
||||||
static inline bool sd_version_uses_flux_vae(SDVersion version) {
|
static inline bool sd_version_uses_flux_vae(SDVersion version) {
|
||||||
if (sd_version_is_flux(version) || sd_version_is_z_image(version) || sd_version_is_boogu_image(version) || sd_version_is_longcat(version)) {
|
if (sd_version_is_flux(version) || sd_version_is_z_image(version) || sd_version_is_boogu_image(version) || sd_version_is_longcat(version)) {
|
||||||
return true;
|
return true;
|
||||||
@ -226,7 +234,8 @@ static inline bool sd_version_is_dit(SDVersion version) {
|
|||||||
sd_version_is_lens(version) ||
|
sd_version_is_lens(version) ||
|
||||||
sd_version_is_longcat(version) ||
|
sd_version_is_longcat(version) ||
|
||||||
sd_version_is_pid(version) ||
|
sd_version_is_pid(version) ||
|
||||||
sd_version_is_ideogram4(version)) {
|
sd_version_is_ideogram4(version) ||
|
||||||
|
sd_version_is_krea2(version)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
|
|||||||
683
src/model/diffusion/krea2.hpp
Normal file
683
src/model/diffusion/krea2.hpp
Normal file
@ -0,0 +1,683 @@
|
|||||||
|
#ifndef __SD_MODEL_DIFFUSION_KREA2_HPP__
|
||||||
|
#define __SD_MODEL_DIFFUSION_KREA2_HPP__
|
||||||
|
|
||||||
|
#include <inttypes.h>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <memory>
|
||||||
|
#include <set>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "core/ggml_extend.hpp"
|
||||||
|
#include "core/ggml_graph_cut.h"
|
||||||
|
#include "model/common/rope.hpp"
|
||||||
|
#include "model/diffusion/dit.hpp"
|
||||||
|
#include "model/diffusion/flux.hpp"
|
||||||
|
#include "model/diffusion/model.hpp"
|
||||||
|
#include "model_loader.h"
|
||||||
|
|
||||||
|
namespace Krea2 {
|
||||||
|
// Size argument passed to new_graph_custom() when the Krea2 runner builds its compute graph.
constexpr int KREA2_GRAPH_SIZE = 65536;
|
||||||
|
|
||||||
|
struct Krea2Config {
|
||||||
|
int patch_size = 2;
|
||||||
|
int64_t in_channels = 16;
|
||||||
|
int64_t out_channels = 16;
|
||||||
|
int64_t features = 6144;
|
||||||
|
int64_t timestep_dim = 256;
|
||||||
|
int64_t text_dim = 2560;
|
||||||
|
int64_t text_layers = 12;
|
||||||
|
int64_t layers = 28;
|
||||||
|
int64_t heads = 48;
|
||||||
|
int64_t kv_heads = 12;
|
||||||
|
int64_t text_heads = 20;
|
||||||
|
int64_t text_kv_heads = 20;
|
||||||
|
int64_t mlp_multiplier = 4;
|
||||||
|
float theta = 1000.f;
|
||||||
|
float norm_eps = 1e-5f;
|
||||||
|
std::vector<int> axes_dim = {32, 48, 48};
|
||||||
|
int axes_dim_sum = 128;
|
||||||
|
|
||||||
|
int64_t head_dim() const {
|
||||||
|
return features / heads;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int64_t count_blocks(const String2TensorStorage& tensor_storage_map,
|
||||||
|
const std::string& prefix,
|
||||||
|
const std::string& block_prefix) {
|
||||||
|
int64_t count = 0;
|
||||||
|
std::string full_prefix = prefix.empty() ? block_prefix : prefix + "." + block_prefix;
|
||||||
|
for (const auto& [name, _] : tensor_storage_map) {
|
||||||
|
if (!starts_with(name, full_prefix)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
std::string tail = name.substr(full_prefix.size());
|
||||||
|
size_t dot = tail.find('.');
|
||||||
|
if (dot == std::string::npos) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
int block_index = std::atoi(tail.substr(0, dot).c_str());
|
||||||
|
count = std::max<int64_t>(count, block_index + 1);
|
||||||
|
}
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
|
||||||
|
void update_axes_dim() {
|
||||||
|
int64_t dim_head = head_dim();
|
||||||
|
int64_t unit = dim_head / 16;
|
||||||
|
axes_dim = {
|
||||||
|
static_cast<int>(dim_head - 12 * unit),
|
||||||
|
static_cast<int>(6 * unit),
|
||||||
|
static_cast<int>(6 * unit),
|
||||||
|
};
|
||||||
|
axes_dim_sum = axes_dim[0] + axes_dim[1] + axes_dim[2];
|
||||||
|
}
|
||||||
|
|
||||||
|
static Krea2Config detect_from_weights(const String2TensorStorage& tensor_storage_map,
|
||||||
|
const std::string& prefix) {
|
||||||
|
Krea2Config config;
|
||||||
|
int64_t detected_head_dim = 0;
|
||||||
|
int64_t detected_text_head_dim = 0;
|
||||||
|
|
||||||
|
for (const auto& [name, tensor_storage] : tensor_storage_map) {
|
||||||
|
if (!starts_with(name, prefix)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (ends_with(name, "first.weight") && tensor_storage.n_dims == 2) {
|
||||||
|
config.in_channels = tensor_storage.ne[0] / (config.patch_size * config.patch_size);
|
||||||
|
config.out_channels = config.in_channels;
|
||||||
|
config.features = tensor_storage.ne[1];
|
||||||
|
} else if (ends_with(name, "blocks.0.attn.qknorm.qnorm.scale") && tensor_storage.n_dims == 1) {
|
||||||
|
detected_head_dim = tensor_storage.ne[0];
|
||||||
|
} else if (ends_with(name, "blocks.0.attn.wq.weight") && tensor_storage.n_dims == 2) {
|
||||||
|
if (detected_head_dim > 0) {
|
||||||
|
config.heads = tensor_storage.ne[1] / detected_head_dim;
|
||||||
|
}
|
||||||
|
} else if (ends_with(name, "blocks.0.attn.wk.weight") && tensor_storage.n_dims == 2) {
|
||||||
|
if (detected_head_dim > 0) {
|
||||||
|
config.kv_heads = tensor_storage.ne[1] / detected_head_dim;
|
||||||
|
}
|
||||||
|
} else if (ends_with(name, "txtfusion.projector.weight") && tensor_storage.n_dims == 2) {
|
||||||
|
config.text_layers = tensor_storage.ne[0];
|
||||||
|
} else if (ends_with(name, "txtfusion.layerwise_blocks.0.prenorm.scale") && tensor_storage.n_dims == 1) {
|
||||||
|
config.text_dim = tensor_storage.ne[0];
|
||||||
|
} else if (ends_with(name, "txtfusion.layerwise_blocks.0.attn.qknorm.qnorm.scale") && tensor_storage.n_dims == 1) {
|
||||||
|
detected_text_head_dim = tensor_storage.ne[0];
|
||||||
|
} else if (ends_with(name, "txtfusion.layerwise_blocks.0.attn.wq.weight") && tensor_storage.n_dims == 2) {
|
||||||
|
if (detected_text_head_dim > 0) {
|
||||||
|
config.text_heads = tensor_storage.ne[1] / detected_text_head_dim;
|
||||||
|
}
|
||||||
|
} else if (ends_with(name, "txtfusion.layerwise_blocks.0.attn.wk.weight") && tensor_storage.n_dims == 2) {
|
||||||
|
if (detected_text_head_dim > 0) {
|
||||||
|
config.text_kv_heads = tensor_storage.ne[1] / detected_text_head_dim;
|
||||||
|
}
|
||||||
|
} else if (ends_with(name, "last.linear.weight") && tensor_storage.n_dims == 2) {
|
||||||
|
config.out_channels = tensor_storage.ne[1] / (config.patch_size * config.patch_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
config.layers = std::max<int64_t>(1, count_blocks(tensor_storage_map, prefix, "blocks."));
|
||||||
|
if (detected_head_dim > 0 && config.features > 0) {
|
||||||
|
config.heads = config.features / detected_head_dim;
|
||||||
|
}
|
||||||
|
if (detected_head_dim > 0) {
|
||||||
|
std::string wk_name = prefix.empty() ? "blocks.0.attn.wk.weight" : prefix + ".blocks.0.attn.wk.weight";
|
||||||
|
auto it = tensor_storage_map.find(wk_name);
|
||||||
|
if (it != tensor_storage_map.end() && it->second.n_dims == 2) {
|
||||||
|
config.kv_heads = it->second.ne[1] / detected_head_dim;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (detected_text_head_dim > 0 && config.text_dim > 0) {
|
||||||
|
config.text_heads = config.text_dim / detected_text_head_dim;
|
||||||
|
}
|
||||||
|
if (detected_text_head_dim > 0) {
|
||||||
|
std::string wk_name = prefix.empty() ? "txtfusion.layerwise_blocks.0.attn.wk.weight" : prefix + ".txtfusion.layerwise_blocks.0.attn.wk.weight";
|
||||||
|
auto it = tensor_storage_map.find(wk_name);
|
||||||
|
if (it != tensor_storage_map.end() && it->second.n_dims == 2) {
|
||||||
|
config.text_kv_heads = it->second.ne[1] / detected_text_head_dim;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
config.update_axes_dim();
|
||||||
|
|
||||||
|
LOG_DEBUG("krea2: layers=%" PRId64 ", features=%" PRId64 ", heads=%" PRId64 ", kv_heads=%" PRId64 ", text_dim=%" PRId64 ", text_layers=%" PRId64 ", text_heads=%" PRId64 ", text_kv_heads=%" PRId64 ", channels=%" PRId64,
|
||||||
|
config.layers,
|
||||||
|
config.features,
|
||||||
|
config.heads,
|
||||||
|
config.kv_heads,
|
||||||
|
config.text_dim,
|
||||||
|
config.text_layers,
|
||||||
|
config.text_heads,
|
||||||
|
config.text_kv_heads,
|
||||||
|
config.in_channels);
|
||||||
|
return config;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Rounds `value` up to the next multiple of `multiple`
// (values that are already a multiple are returned unchanged).
__STATIC_INLINE__ int64_t ceil_to_multiple(int64_t value, int64_t multiple) {
    const int64_t bumped = value + multiple - 1;
    // bumped - bumped % multiple == (bumped / multiple) * multiple for truncating division.
    return bumped - bumped % multiple;
}
|
||||||
|
|
||||||
|
class KreaRMSNorm : public UnaryBlock {
|
||||||
|
protected:
|
||||||
|
int64_t hidden_size;
|
||||||
|
float eps;
|
||||||
|
std::string prefix;
|
||||||
|
|
||||||
|
void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
|
||||||
|
GGML_UNUSED(tensor_storage_map);
|
||||||
|
this->prefix = prefix;
|
||||||
|
params["scale"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
KreaRMSNorm(int64_t hidden_size, float eps = 1e-5f)
|
||||||
|
: hidden_size(hidden_size),
|
||||||
|
eps(eps) {}
|
||||||
|
|
||||||
|
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override {
|
||||||
|
ggml_tensor* scale = params["scale"];
|
||||||
|
scale = ggml_add(ctx->ggml_ctx, scale, ggml_ext_ones(ctx->ggml_ctx, scale->ne[0], 1, 1, 1));
|
||||||
|
x = ggml_rms_norm(ctx->ggml_ctx, x, eps);
|
||||||
|
x = ggml_mul_inplace(ctx->ggml_ctx, x, scale);
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class KreaSwiGLU : public UnaryBlock {
|
||||||
|
public:
|
||||||
|
KreaSwiGLU(int64_t features, int64_t multiplier) {
|
||||||
|
int64_t mlp_dim = ceil_to_multiple(((2 * features) / 3) * multiplier, 128);
|
||||||
|
blocks["gate"] = std::make_shared<Linear>(features, mlp_dim, false);
|
||||||
|
blocks["up"] = std::make_shared<Linear>(features, mlp_dim, false);
|
||||||
|
blocks["down"] = std::make_shared<Linear>(mlp_dim, features, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override {
|
||||||
|
auto gate = std::dynamic_pointer_cast<Linear>(blocks["gate"]);
|
||||||
|
auto up = std::dynamic_pointer_cast<Linear>(blocks["up"]);
|
||||||
|
auto down = std::dynamic_pointer_cast<Linear>(blocks["down"]);
|
||||||
|
|
||||||
|
auto gated = ggml_silu(ctx->ggml_ctx, gate->forward(ctx, x));
|
||||||
|
auto up_x = up->forward(ctx, x);
|
||||||
|
x = ggml_mul(ctx->ggml_ctx, gated, up_x);
|
||||||
|
return down->forward(ctx, x);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class KreaAttention : public GGMLBlock {
|
||||||
|
protected:
|
||||||
|
int64_t features;
|
||||||
|
int64_t heads;
|
||||||
|
int64_t kv_heads;
|
||||||
|
int64_t head_dim_;
|
||||||
|
|
||||||
|
ggml_tensor* attention_no_rope(GGMLRunnerContext* ctx,
|
||||||
|
ggml_tensor* q,
|
||||||
|
ggml_tensor* k,
|
||||||
|
ggml_tensor* v,
|
||||||
|
ggml_tensor* mask) {
|
||||||
|
int64_t Lq = q->ne[2];
|
||||||
|
int64_t Lk = k->ne[2];
|
||||||
|
int64_t N = q->ne[3];
|
||||||
|
q = ggml_reshape_3d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, q), head_dim_ * heads, Lq, N);
|
||||||
|
k = ggml_reshape_3d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, k), head_dim_ * kv_heads, Lk, N);
|
||||||
|
v = ggml_reshape_3d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, v), head_dim_ * kv_heads, Lk, N);
|
||||||
|
return ggml_ext_attention_ext(ctx->ggml_ctx,
|
||||||
|
ctx->backend,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
heads,
|
||||||
|
mask,
|
||||||
|
false,
|
||||||
|
ctx->flash_attn_enabled);
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
KreaAttention(int64_t features,
|
||||||
|
int64_t heads,
|
||||||
|
int64_t kv_heads,
|
||||||
|
float eps = 1e-5f)
|
||||||
|
: features(features),
|
||||||
|
heads(heads),
|
||||||
|
kv_heads(kv_heads),
|
||||||
|
head_dim_(features / heads) {
|
||||||
|
blocks["wq"] = std::make_shared<Linear>(features, heads * head_dim_, false);
|
||||||
|
blocks["wk"] = std::make_shared<Linear>(features, kv_heads * head_dim_, false);
|
||||||
|
blocks["wv"] = std::make_shared<Linear>(features, kv_heads * head_dim_, false);
|
||||||
|
blocks["gate"] = std::make_shared<Linear>(features, features, false);
|
||||||
|
blocks["qknorm.qnorm"] = std::make_shared<KreaRMSNorm>(head_dim_, eps);
|
||||||
|
blocks["qknorm.knorm"] = std::make_shared<KreaRMSNorm>(head_dim_, eps);
|
||||||
|
blocks["wo"] = std::make_shared<Linear>(features, features, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
|
ggml_tensor* x,
|
||||||
|
ggml_tensor* pe = nullptr,
|
||||||
|
ggml_tensor* mask = nullptr) {
|
||||||
|
auto wq = std::dynamic_pointer_cast<Linear>(blocks["wq"]);
|
||||||
|
auto wk = std::dynamic_pointer_cast<Linear>(blocks["wk"]);
|
||||||
|
auto wv = std::dynamic_pointer_cast<Linear>(blocks["wv"]);
|
||||||
|
auto gate = std::dynamic_pointer_cast<Linear>(blocks["gate"]);
|
||||||
|
auto qnorm = std::dynamic_pointer_cast<KreaRMSNorm>(blocks["qknorm.qnorm"]);
|
||||||
|
auto knorm = std::dynamic_pointer_cast<KreaRMSNorm>(blocks["qknorm.knorm"]);
|
||||||
|
auto wo = std::dynamic_pointer_cast<Linear>(blocks["wo"]);
|
||||||
|
|
||||||
|
if (sd_backend_is(ctx->backend, "Vulkan")) {
|
||||||
|
wo->set_force_prec_f32(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t L = x->ne[1];
|
||||||
|
int64_t N = x->ne[2];
|
||||||
|
|
||||||
|
auto q = wq->forward(ctx, x);
|
||||||
|
q = ggml_reshape_4d(ctx->ggml_ctx, q, head_dim_, heads, L, N);
|
||||||
|
auto k = wk->forward(ctx, x);
|
||||||
|
k = ggml_reshape_4d(ctx->ggml_ctx, k, head_dim_, kv_heads, L, N);
|
||||||
|
auto v = wv->forward(ctx, x);
|
||||||
|
v = ggml_reshape_4d(ctx->ggml_ctx, v, head_dim_, kv_heads, L, N);
|
||||||
|
|
||||||
|
q = qnorm->forward(ctx, q);
|
||||||
|
k = knorm->forward(ctx, k);
|
||||||
|
|
||||||
|
auto out = pe != nullptr ? Rope::attention(ctx, q, k, v, pe, mask)
|
||||||
|
: attention_no_rope(ctx, q, k, v, mask);
|
||||||
|
out = ggml_mul(ctx->ggml_ctx, out, ggml_sigmoid(ctx->ggml_ctx, gate->forward(ctx, x)));
|
||||||
|
out = wo->forward(ctx, out);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class KreaDoubleSharedModulation : public GGMLBlock {
|
||||||
|
protected:
|
||||||
|
int64_t dim;
|
||||||
|
|
||||||
|
void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
|
||||||
|
GGML_UNUSED(tensor_storage_map);
|
||||||
|
GGML_UNUSED(prefix);
|
||||||
|
params["lin"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim * 6);
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
KreaDoubleSharedModulation(int64_t dim)
|
||||||
|
: dim(dim) {}
|
||||||
|
|
||||||
|
std::vector<ggml_tensor*> forward(GGMLRunnerContext* ctx, ggml_tensor* vec) {
|
||||||
|
auto lin = ggml_repeat(ctx->ggml_ctx, params["lin"], vec);
|
||||||
|
auto out = ggml_add(ctx->ggml_ctx, vec, lin);
|
||||||
|
return ggml_ext_chunk(ctx->ggml_ctx, out, 6, 0);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class KreaFinalModulation : public GGMLBlock {
|
||||||
|
protected:
|
||||||
|
int64_t dim;
|
||||||
|
|
||||||
|
void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
|
||||||
|
GGML_UNUSED(tensor_storage_map);
|
||||||
|
GGML_UNUSED(prefix);
|
||||||
|
params["lin"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
KreaFinalModulation(int64_t dim)
|
||||||
|
: dim(dim) {}
|
||||||
|
|
||||||
|
std::vector<ggml_tensor*> forward(GGMLRunnerContext* ctx, ggml_tensor* vec) {
|
||||||
|
auto out = ggml_add(ctx->ggml_ctx, params["lin"], vec);
|
||||||
|
return ggml_ext_chunk(ctx->ggml_ctx, out, 2, 1);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class KreaTextFusionBlock : public UnaryBlock {
|
||||||
|
public:
|
||||||
|
KreaTextFusionBlock(int64_t dim,
|
||||||
|
int64_t heads,
|
||||||
|
int64_t kv_heads,
|
||||||
|
int64_t multiplier,
|
||||||
|
float eps) {
|
||||||
|
blocks["prenorm"] = std::make_shared<KreaRMSNorm>(dim, eps);
|
||||||
|
blocks["postnorm"] = std::make_shared<KreaRMSNorm>(dim, eps);
|
||||||
|
blocks["attn"] = std::make_shared<KreaAttention>(dim, heads, kv_heads, eps);
|
||||||
|
blocks["mlp"] = std::make_shared<KreaSwiGLU>(dim, multiplier);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override {
|
||||||
|
auto prenorm = std::dynamic_pointer_cast<KreaRMSNorm>(blocks["prenorm"]);
|
||||||
|
auto postnorm = std::dynamic_pointer_cast<KreaRMSNorm>(blocks["postnorm"]);
|
||||||
|
auto attn = std::dynamic_pointer_cast<KreaAttention>(blocks["attn"]);
|
||||||
|
auto mlp = std::dynamic_pointer_cast<KreaSwiGLU>(blocks["mlp"]);
|
||||||
|
|
||||||
|
x = ggml_add(ctx->ggml_ctx, x, attn->forward(ctx, prenorm->forward(ctx, x)));
|
||||||
|
x = ggml_add(ctx->ggml_ctx, x, mlp->forward(ctx, postnorm->forward(ctx, x)));
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class KreaTextFusionTransformer : public UnaryBlock {
|
||||||
|
protected:
|
||||||
|
Krea2Config config;
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit KreaTextFusionTransformer(Krea2Config config)
|
||||||
|
: config(std::move(config)) {
|
||||||
|
for (int i = 0; i < 2; ++i) {
|
||||||
|
blocks["layerwise_blocks." + std::to_string(i)] = std::make_shared<KreaTextFusionBlock>(this->config.text_dim,
|
||||||
|
this->config.text_heads,
|
||||||
|
this->config.text_kv_heads,
|
||||||
|
this->config.mlp_multiplier,
|
||||||
|
this->config.norm_eps);
|
||||||
|
blocks["refiner_blocks." + std::to_string(i)] = std::make_shared<KreaTextFusionBlock>(this->config.text_dim,
|
||||||
|
this->config.text_heads,
|
||||||
|
this->config.text_kv_heads,
|
||||||
|
this->config.mlp_multiplier,
|
||||||
|
this->config.norm_eps);
|
||||||
|
}
|
||||||
|
blocks["projector"] = std::make_shared<Linear>(this->config.text_layers, 1, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* context) override {
|
||||||
|
int64_t text_tokens = context->ne[1];
|
||||||
|
int64_t batch = context->ne[2];
|
||||||
|
|
||||||
|
context = ggml_reshape_3d(ctx->ggml_ctx,
|
||||||
|
context,
|
||||||
|
config.text_dim,
|
||||||
|
config.text_layers,
|
||||||
|
text_tokens * batch);
|
||||||
|
|
||||||
|
for (int i = 0; i < 2; ++i) {
|
||||||
|
auto block = std::dynamic_pointer_cast<KreaTextFusionBlock>(blocks["layerwise_blocks." + std::to_string(i)]);
|
||||||
|
context = block->forward(ctx, context);
|
||||||
|
}
|
||||||
|
|
||||||
|
context = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, context, 1, 0, 2, 3));
|
||||||
|
auto projector = std::dynamic_pointer_cast<Linear>(blocks["projector"]);
|
||||||
|
context = projector->forward(ctx, context);
|
||||||
|
context = ggml_reshape_3d(ctx->ggml_ctx, context, config.text_dim, text_tokens, batch);
|
||||||
|
|
||||||
|
for (int i = 0; i < 2; ++i) {
|
||||||
|
auto block = std::dynamic_pointer_cast<KreaTextFusionBlock>(blocks["refiner_blocks." + std::to_string(i)]);
|
||||||
|
context = block->forward(ctx, context);
|
||||||
|
}
|
||||||
|
return context;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class KreaSingleStreamBlock : public UnaryBlock {
|
||||||
|
public:
|
||||||
|
explicit KreaSingleStreamBlock(Krea2Config config) {
|
||||||
|
blocks["mod"] = std::make_shared<KreaDoubleSharedModulation>(config.features);
|
||||||
|
blocks["prenorm"] = std::make_shared<KreaRMSNorm>(config.features, config.norm_eps);
|
||||||
|
blocks["postnorm"] = std::make_shared<KreaRMSNorm>(config.features, config.norm_eps);
|
||||||
|
blocks["attn"] = std::make_shared<KreaAttention>(config.features, config.heads, config.kv_heads, config.norm_eps);
|
||||||
|
blocks["mlp"] = std::make_shared<KreaSwiGLU>(config.features, config.mlp_multiplier);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
|
ggml_tensor* x,
|
||||||
|
ggml_tensor* vec,
|
||||||
|
ggml_tensor* pe) {
|
||||||
|
auto mod = std::dynamic_pointer_cast<KreaDoubleSharedModulation>(blocks["mod"]);
|
||||||
|
auto prenorm = std::dynamic_pointer_cast<KreaRMSNorm>(blocks["prenorm"]);
|
||||||
|
auto postnorm = std::dynamic_pointer_cast<KreaRMSNorm>(blocks["postnorm"]);
|
||||||
|
auto attn = std::dynamic_pointer_cast<KreaAttention>(blocks["attn"]);
|
||||||
|
auto mlp = std::dynamic_pointer_cast<KreaSwiGLU>(blocks["mlp"]);
|
||||||
|
|
||||||
|
auto mods = mod->forward(ctx, vec);
|
||||||
|
auto attn_input = Flux::modulate(ctx->ggml_ctx,
|
||||||
|
prenorm->forward(ctx, x),
|
||||||
|
mods[1],
|
||||||
|
mods[0],
|
||||||
|
true);
|
||||||
|
auto attn_out = attn->forward(ctx, attn_input, pe);
|
||||||
|
x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, attn_out, mods[2]));
|
||||||
|
|
||||||
|
auto mlp_input = Flux::modulate(ctx->ggml_ctx,
|
||||||
|
postnorm->forward(ctx, x),
|
||||||
|
mods[4],
|
||||||
|
mods[3],
|
||||||
|
true);
|
||||||
|
auto mlp_out = mlp->forward(ctx, mlp_input);
|
||||||
|
x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, mlp_out, mods[5]));
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override {
|
||||||
|
GGML_UNUSED(ctx);
|
||||||
|
GGML_UNUSED(x);
|
||||||
|
GGML_ABORT("KreaSingleStreamBlock requires conditioning");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class KreaTimeMLP : public UnaryBlock {
|
||||||
|
public:
|
||||||
|
explicit KreaTimeMLP(Krea2Config config) {
|
||||||
|
blocks["0"] = std::make_shared<Linear>(config.timestep_dim, config.features, true);
|
||||||
|
blocks["2"] = std::make_shared<Linear>(config.features, config.features, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override {
|
||||||
|
auto linear_0 = std::dynamic_pointer_cast<Linear>(blocks["0"]);
|
||||||
|
auto linear_2 = std::dynamic_pointer_cast<Linear>(blocks["2"]);
|
||||||
|
x = linear_0->forward(ctx, x);
|
||||||
|
x = ggml_ext_gelu(ctx->ggml_ctx, x, false);
|
||||||
|
x = linear_2->forward(ctx, x);
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class KreaTProj : public UnaryBlock {
|
||||||
|
public:
|
||||||
|
explicit KreaTProj(Krea2Config config) {
|
||||||
|
blocks["1"] = std::make_shared<Linear>(config.features, config.features * 6, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override {
|
||||||
|
auto linear_1 = std::dynamic_pointer_cast<Linear>(blocks["1"]);
|
||||||
|
x = ggml_ext_gelu(ctx->ggml_ctx, x, false);
|
||||||
|
x = linear_1->forward(ctx, x);
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class KreaTextMLP : public UnaryBlock {
|
||||||
|
public:
|
||||||
|
explicit KreaTextMLP(Krea2Config config) {
|
||||||
|
blocks["0"] = std::make_shared<KreaRMSNorm>(config.text_dim, config.norm_eps);
|
||||||
|
blocks["1"] = std::make_shared<Linear>(config.text_dim, config.features, true);
|
||||||
|
blocks["3"] = std::make_shared<Linear>(config.features, config.features, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override {
|
||||||
|
auto norm = std::dynamic_pointer_cast<KreaRMSNorm>(blocks["0"]);
|
||||||
|
auto linear_1 = std::dynamic_pointer_cast<Linear>(blocks["1"]);
|
||||||
|
auto linear_3 = std::dynamic_pointer_cast<Linear>(blocks["3"]);
|
||||||
|
x = norm->forward(ctx, x);
|
||||||
|
x = linear_1->forward(ctx, x);
|
||||||
|
x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
|
||||||
|
x = linear_3->forward(ctx, x);
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class KreaLastLayer : public GGMLBlock {
|
||||||
|
public:
|
||||||
|
explicit KreaLastLayer(Krea2Config config) {
|
||||||
|
blocks["norm"] = std::make_shared<KreaRMSNorm>(config.features, config.norm_eps);
|
||||||
|
blocks["linear"] = std::make_shared<Linear>(config.features, config.patch_size * config.patch_size * config.out_channels, true);
|
||||||
|
blocks["modulation"] = std::make_shared<KreaFinalModulation>(config.features);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* vec) {
|
||||||
|
auto norm = std::dynamic_pointer_cast<KreaRMSNorm>(blocks["norm"]);
|
||||||
|
auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
|
||||||
|
auto modulation = std::dynamic_pointer_cast<KreaFinalModulation>(blocks["modulation"]);
|
||||||
|
|
||||||
|
auto mods = modulation->forward(ctx, vec);
|
||||||
|
x = Flux::modulate(ctx->ggml_ctx,
|
||||||
|
norm->forward(ctx, x),
|
||||||
|
mods[1],
|
||||||
|
mods[0],
|
||||||
|
true);
|
||||||
|
x = linear->forward(ctx, x);
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class Krea2Model : public GGMLBlock {
|
||||||
|
protected:
|
||||||
|
Krea2Config config;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Krea2Model() = default;
|
||||||
|
explicit Krea2Model(Krea2Config config)
|
||||||
|
: config(std::move(config)) {
|
||||||
|
blocks["first"] = std::make_shared<Linear>(this->config.patch_size * this->config.patch_size * this->config.in_channels,
|
||||||
|
this->config.features,
|
||||||
|
true);
|
||||||
|
blocks["tmlp"] = std::make_shared<KreaTimeMLP>(this->config);
|
||||||
|
blocks["txtfusion"] = std::make_shared<KreaTextFusionTransformer>(this->config);
|
||||||
|
blocks["txtmlp"] = std::make_shared<KreaTextMLP>(this->config);
|
||||||
|
blocks["tproj"] = std::make_shared<KreaTProj>(this->config);
|
||||||
|
for (int i = 0; i < this->config.layers; ++i) {
|
||||||
|
blocks["blocks." + std::to_string(i)] = std::make_shared<KreaSingleStreamBlock>(this->config);
|
||||||
|
}
|
||||||
|
blocks["last"] = std::make_shared<KreaLastLayer>(this->config);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
|
ggml_tensor* x,
|
||||||
|
ggml_tensor* timestep,
|
||||||
|
ggml_tensor* context,
|
||||||
|
ggml_tensor* pe) {
|
||||||
|
int64_t W = x->ne[0];
|
||||||
|
int64_t H = x->ne[1];
|
||||||
|
int64_t N = x->ne[3];
|
||||||
|
GGML_ASSERT(N == 1);
|
||||||
|
|
||||||
|
auto first = std::dynamic_pointer_cast<Linear>(blocks["first"]);
|
||||||
|
auto tmlp = std::dynamic_pointer_cast<KreaTimeMLP>(blocks["tmlp"]);
|
||||||
|
auto txtfusion = std::dynamic_pointer_cast<KreaTextFusionTransformer>(blocks["txtfusion"]);
|
||||||
|
auto txtmlp = std::dynamic_pointer_cast<KreaTextMLP>(blocks["txtmlp"]);
|
||||||
|
auto tproj = std::dynamic_pointer_cast<KreaTProj>(blocks["tproj"]);
|
||||||
|
auto last = std::dynamic_pointer_cast<KreaLastLayer>(blocks["last"]);
|
||||||
|
|
||||||
|
auto img = DiT::pad_and_patchify(ctx, x, config.patch_size, config.patch_size, true);
|
||||||
|
int64_t img_len = img->ne[1];
|
||||||
|
img = first->forward(ctx, img);
|
||||||
|
|
||||||
|
auto t = ggml_ext_timestep_embedding(ctx->ggml_ctx, timestep, static_cast<int>(config.timestep_dim), 10000, 1000.f);
|
||||||
|
t = tmlp->forward(ctx, t);
|
||||||
|
t = ggml_reshape_3d(ctx->ggml_ctx, t, t->ne[0], 1, t->ne[1]);
|
||||||
|
auto tvec = tproj->forward(ctx, t);
|
||||||
|
|
||||||
|
auto txt = txtfusion->forward(ctx, context);
|
||||||
|
txt = txtmlp->forward(ctx, txt);
|
||||||
|
int64_t txt_len = txt->ne[1];
|
||||||
|
|
||||||
|
auto hidden_states = ggml_concat(ctx->ggml_ctx, txt, img, 1);
|
||||||
|
for (int i = 0; i < config.layers; ++i) {
|
||||||
|
auto block = std::dynamic_pointer_cast<KreaSingleStreamBlock>(blocks["blocks." + std::to_string(i)]);
|
||||||
|
hidden_states = block->forward(ctx, hidden_states, tvec, pe);
|
||||||
|
sd::ggml_graph_cut::mark_graph_cut(hidden_states, "krea2.blocks." + std::to_string(i), "hidden_states");
|
||||||
|
}
|
||||||
|
|
||||||
|
hidden_states = last->forward(ctx, hidden_states, t);
|
||||||
|
hidden_states = ggml_ext_slice(ctx->ggml_ctx, hidden_states, 1, txt_len, txt_len + img_len);
|
||||||
|
hidden_states = DiT::unpatchify_and_crop(ctx->ggml_ctx, hidden_states, H, W, config.patch_size, config.patch_size, true);
|
||||||
|
return hidden_states;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Builds the positional-embedding table for a Krea2 sequence: 3-axis position
// ids are generated for the text tokens and for the image patches, joined with
// the text ids first, and embedded with Rope::embed_nd.
__STATIC_INLINE__ std::vector<float> gen_krea2_pe(int h,
                                                  int w,
                                                  int patch_size,
                                                  int bs,
                                                  int context_len,
                                                  float theta,
                                                  const std::vector<int>& axes_dim) {
    auto text_positions  = Rope::gen_flux_txt_ids(bs, context_len, 3, {});
    auto image_positions = Rope::gen_flux_img_ids(h, w, patch_size, bs, 3, 0, 0, 0, false);
    auto all_positions   = Rope::concat_ids(text_positions, image_positions, bs);
    return Rope::embed_nd(all_positions, bs, theta, axes_dim);
}
|
||||||
|
|
||||||
|
struct Krea2Runner : public DiffusionModelRunner {
|
||||||
|
Krea2Config config;
|
||||||
|
Krea2Model model;
|
||||||
|
std::vector<float> pe_vec;
|
||||||
|
|
||||||
|
Krea2Runner(ggml_backend_t backend,
|
||||||
|
const String2TensorStorage& tensor_storage_map = {},
|
||||||
|
const std::string prefix = "",
|
||||||
|
std::shared_ptr<RunnerWeightManager> weight_manager = nullptr)
|
||||||
|
: DiffusionModelRunner(backend, prefix, weight_manager),
|
||||||
|
config(Krea2Config::detect_from_weights(tensor_storage_map, prefix)) {
|
||||||
|
model = Krea2Model(config);
|
||||||
|
model.init(params_ctx, tensor_storage_map, prefix);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string get_desc() override {
|
||||||
|
return "krea2";
|
||||||
|
}
|
||||||
|
|
||||||
|
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string& prefix) override {
|
||||||
|
model.get_param_tensors(tensors, prefix);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_cgraph* build_graph(const sd::Tensor<float>& x_tensor,
|
||||||
|
const sd::Tensor<float>& timesteps_tensor,
|
||||||
|
const sd::Tensor<float>& context_tensor) {
|
||||||
|
ggml_cgraph* gf = new_graph_custom(KREA2_GRAPH_SIZE);
|
||||||
|
ggml_tensor* x = make_input(x_tensor);
|
||||||
|
ggml_tensor* timesteps = make_input(timesteps_tensor);
|
||||||
|
GGML_ASSERT(x->ne[3] == 1);
|
||||||
|
GGML_ASSERT(!context_tensor.empty());
|
||||||
|
ggml_tensor* context = make_input(context_tensor);
|
||||||
|
|
||||||
|
pe_vec = gen_krea2_pe(static_cast<int>(x->ne[1]),
|
||||||
|
static_cast<int>(x->ne[0]),
|
||||||
|
config.patch_size,
|
||||||
|
static_cast<int>(x->ne[3]),
|
||||||
|
static_cast<int>(context->ne[1]),
|
||||||
|
config.theta,
|
||||||
|
config.axes_dim);
|
||||||
|
int pos_len = static_cast<int>(pe_vec.size() / config.axes_dim_sum / 2);
|
||||||
|
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.axes_dim_sum / 2, pos_len);
|
||||||
|
set_backend_tensor_data(pe, pe_vec.data());
|
||||||
|
|
||||||
|
auto runner_ctx = get_context();
|
||||||
|
ggml_tensor* out = model.forward(&runner_ctx, x, timesteps, context, pe);
|
||||||
|
ggml_build_forward_expand(gf, out);
|
||||||
|
return gf;
|
||||||
|
}
|
||||||
|
|
||||||
|
sd::Tensor<float> compute(int n_threads,
|
||||||
|
const sd::Tensor<float>& x,
|
||||||
|
const sd::Tensor<float>& timesteps,
|
||||||
|
const sd::Tensor<float>& context) {
|
||||||
|
auto get_graph = [&]() -> ggml_cgraph* {
|
||||||
|
return build_graph(x, timesteps, context);
|
||||||
|
};
|
||||||
|
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false, false, false), x.dim());
|
||||||
|
}
|
||||||
|
|
||||||
|
sd::Tensor<float> compute(int n_threads,
|
||||||
|
const DiffusionParams& diffusion_params) override {
|
||||||
|
GGML_ASSERT(diffusion_params.x != nullptr);
|
||||||
|
GGML_ASSERT(diffusion_params.timesteps != nullptr);
|
||||||
|
return compute(n_threads,
|
||||||
|
*diffusion_params.x,
|
||||||
|
*diffusion_params.timesteps,
|
||||||
|
tensor_or_empty(diffusion_params.context));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace Krea2
|
||||||
|
|
||||||
|
#endif // __SD_MODEL_DIFFUSION_KREA2_HPP__
|
||||||
@ -453,6 +453,10 @@ SDVersion ModelLoader::get_sd_version() {
|
|||||||
if (tensor_storage.name.find("embed_image_indicator.weight") != std::string::npos) {
|
if (tensor_storage.name.find("embed_image_indicator.weight") != std::string::npos) {
|
||||||
return VERSION_IDEOGRAM4;
|
return VERSION_IDEOGRAM4;
|
||||||
}
|
}
|
||||||
|
if (tensor_storage.name.find("model.diffusion_model.txtfusion.projector.weight") != std::string::npos ||
|
||||||
|
tensor_storage.name.find("model.diffusion_model.text_fusion.projector.weight") != std::string::npos) {
|
||||||
|
return VERSION_KREA2;
|
||||||
|
}
|
||||||
if (tensor_storage.name.find("model.diffusion_model.nerf_final_layer_conv.") != std::string::npos) {
|
if (tensor_storage.name.find("model.diffusion_model.nerf_final_layer_conv.") != std::string::npos) {
|
||||||
return VERSION_CHROMA_RADIANCE;
|
return VERSION_CHROMA_RADIANCE;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -704,6 +704,38 @@ std::string convert_other_dit_to_original_anima(std::string name) {
|
|||||||
return name;
|
return name;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Translate a diffusers-style Krea2 transformer tensor name into the naming
 * used by the Krea2 model implementation.
 *
 * Two rename tables are applied in sequence: first the module-prefix table
 * via replace_with_prefix_map, then the inner-name table via
 * replace_with_name_map.
 *
 * @param name  tensor name to convert (taken by value and rewritten)
 * @return the converted name
 */
std::string convert_diffusers_dit_to_original_krea2(std::string name) {
    using RenameTable = std::vector<std::pair<std::string, std::string>>;

    // Top-level module prefixes.
    static const RenameTable module_prefix_renames = {
        {"img_in.", "first."},
        {"time_embed.linear_1.", "tmlp.0."},
        {"time_embed.linear_2.", "tmlp.2."},
        {"time_mod_proj.", "tproj.1."},
        {"txt_in.linear_1.", "txtmlp.1."},
        {"txt_in.linear_2.", "txtmlp.3."},
        {"text_fusion.", "txtfusion."},
        {"transformer_blocks.", "blocks."},
        {"final_layer.", "last."},
    };

    // Names inside the modules. The "last.*" keys are written against the
    // already-renamed "final_layer." prefix, so this table has to be applied
    // after the prefix table.
    static const RenameTable inner_name_renames = {
        {"attn.to_out.0.", "attn.wo."},
        {"attn.to_out.", "attn.wo."},
        {"attn.to_gate.", "attn.gate."},
        {"attn.to_q.", "attn.wq."},
        {"attn.to_k.", "attn.wk."},
        {"attn.to_v.", "attn.wv."},
        {"ff.gate.", "mlp.gate."},
        {"ff.up.", "mlp.up."},
        {"ff.down.", "mlp.down."},
        {"txt_in.norm.", "txtmlp.0."},
        {"last.norm.weight", "last.norm.scale"},
        {"last.modulation.weight", "last.modulation.lin"},
    };

    replace_with_prefix_map(name, module_prefix_renames);
    replace_with_name_map(name, inner_name_renames);
    return name;
}
|
||||||
|
|
||||||
std::string convert_diffusion_model_name(std::string name, std::string prefix, SDVersion version) {
|
std::string convert_diffusion_model_name(std::string name, std::string prefix, SDVersion version) {
|
||||||
if (sd_version_is_sd1(version) || sd_version_is_sd2(version)) {
|
if (sd_version_is_sd1(version) || sd_version_is_sd2(version)) {
|
||||||
name = convert_diffusers_unet_to_original_sd1(name);
|
name = convert_diffusers_unet_to_original_sd1(name);
|
||||||
@ -717,6 +749,8 @@ std::string convert_diffusion_model_name(std::string name, std::string prefix, S
|
|||||||
name = convert_diffusers_dit_to_original_lumina2(name);
|
name = convert_diffusers_dit_to_original_lumina2(name);
|
||||||
} else if (sd_version_is_anima(version)) {
|
} else if (sd_version_is_anima(version)) {
|
||||||
name = convert_other_dit_to_original_anima(name);
|
name = convert_other_dit_to_original_anima(name);
|
||||||
|
} else if (sd_version_is_krea2(version)) {
|
||||||
|
name = convert_diffusers_dit_to_original_krea2(name);
|
||||||
}
|
}
|
||||||
return name;
|
return name;
|
||||||
}
|
}
|
||||||
@ -1175,7 +1209,7 @@ std::string convert_tensor_name(std::string name, SDVersion version) {
|
|||||||
|
|
||||||
replace_with_prefix_map(name, prefix_map);
|
replace_with_prefix_map(name, prefix_map);
|
||||||
|
|
||||||
if (sd_version_is_boogu_image(version) && starts_with(name, "text_encoders.llm.visual.")) {
|
if ((sd_version_is_boogu_image(version) || sd_version_is_krea2(version)) && starts_with(name, "text_encoders.llm.visual.")) {
|
||||||
name = convert_qwen3_vl_vision_name(std::move(name));
|
name = convert_qwen3_vl_vision_name(std::move(name));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -26,6 +26,7 @@
|
|||||||
#include "model/diffusion/flux.hpp"
|
#include "model/diffusion/flux.hpp"
|
||||||
#include "model/diffusion/hidream_o1.hpp"
|
#include "model/diffusion/hidream_o1.hpp"
|
||||||
#include "model/diffusion/ideogram4.hpp"
|
#include "model/diffusion/ideogram4.hpp"
|
||||||
|
#include "model/diffusion/krea2.hpp"
|
||||||
#include "model/diffusion/lens.hpp"
|
#include "model/diffusion/lens.hpp"
|
||||||
#include "model/diffusion/ltxv.hpp"
|
#include "model/diffusion/ltxv.hpp"
|
||||||
#include "model/diffusion/mmdit.hpp"
|
#include "model/diffusion/mmdit.hpp"
|
||||||
@ -95,6 +96,7 @@ const char* model_version_to_str[] = {
|
|||||||
"Longcat-Image",
|
"Longcat-Image",
|
||||||
"PiD",
|
"PiD",
|
||||||
"Ideogram 4",
|
"Ideogram 4",
|
||||||
|
"Krea2",
|
||||||
"ESRGAN",
|
"ESRGAN",
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -645,6 +647,17 @@ public:
|
|||||||
tensor_storage_map,
|
tensor_storage_map,
|
||||||
"model.diffusion_model",
|
"model.diffusion_model",
|
||||||
model_manager);
|
model_manager);
|
||||||
|
} else if (sd_version_is_krea2(version)) {
|
||||||
|
cond_stage_model = std::make_shared<LLMEmbedder>(backend_for(SDBackendModule::TE),
|
||||||
|
tensor_storage_map,
|
||||||
|
version,
|
||||||
|
"",
|
||||||
|
false,
|
||||||
|
model_manager);
|
||||||
|
diffusion_model = std::make_shared<Krea2::Krea2Runner>(backend_for(SDBackendModule::DIFFUSION),
|
||||||
|
tensor_storage_map,
|
||||||
|
"model.diffusion_model",
|
||||||
|
model_manager);
|
||||||
} else if (sd_version_is_flux(version)) {
|
} else if (sd_version_is_flux(version)) {
|
||||||
bool is_chroma = false;
|
bool is_chroma = false;
|
||||||
for (auto pair : tensor_storage_map) {
|
for (auto pair : tensor_storage_map) {
|
||||||
@ -881,6 +894,7 @@ public:
|
|||||||
auto create_tae = [&](bool decode_only) -> std::shared_ptr<VAE> {
|
auto create_tae = [&](bool decode_only) -> std::shared_ptr<VAE> {
|
||||||
if (sd_version_is_wan(version) ||
|
if (sd_version_is_wan(version) ||
|
||||||
sd_version_is_qwen_image(version) ||
|
sd_version_is_qwen_image(version) ||
|
||||||
|
sd_version_is_krea2(version) ||
|
||||||
sd_version_is_anima(version) ||
|
sd_version_is_anima(version) ||
|
||||||
sd_version_is_ltxav(version)) {
|
sd_version_is_ltxav(version)) {
|
||||||
return std::make_shared<TinyVideoAutoEncoder>(backend_for(SDBackendModule::VAE),
|
return std::make_shared<TinyVideoAutoEncoder>(backend_for(SDBackendModule::VAE),
|
||||||
@ -921,6 +935,7 @@ public:
|
|||||||
model_manager);
|
model_manager);
|
||||||
} else if (sd_version_is_wan(version) ||
|
} else if (sd_version_is_wan(version) ||
|
||||||
sd_version_is_qwen_image(version) ||
|
sd_version_is_qwen_image(version) ||
|
||||||
|
sd_version_is_krea2(version) ||
|
||||||
sd_version_is_anima(version)) {
|
sd_version_is_anima(version)) {
|
||||||
return std::make_shared<WAN::WanVAERunner>(backend_for(SDBackendModule::VAE),
|
return std::make_shared<WAN::WanVAERunner>(backend_for(SDBackendModule::VAE),
|
||||||
tensor_storage_map,
|
tensor_storage_map,
|
||||||
@ -1267,7 +1282,8 @@ public:
|
|||||||
} else if (sd_version_is_flux(version) ||
|
} else if (sd_version_is_flux(version) ||
|
||||||
sd_version_is_longcat(version) ||
|
sd_version_is_longcat(version) ||
|
||||||
sd_version_is_lens(version) ||
|
sd_version_is_lens(version) ||
|
||||||
sd_version_is_ltxav(version)) {
|
sd_version_is_ltxav(version) ||
|
||||||
|
sd_version_is_krea2(version)) {
|
||||||
pred_type = FLUX_FLOW_PRED;
|
pred_type = FLUX_FLOW_PRED;
|
||||||
|
|
||||||
default_flow_shift = 1.0f; // TODO: validate
|
default_flow_shift = 1.0f; // TODO: validate
|
||||||
@ -1283,6 +1299,8 @@ public:
|
|||||||
default_flow_shift = 1.83f;
|
default_flow_shift = 1.83f;
|
||||||
} else if (sd_version_is_ltxav(version)) {
|
} else if (sd_version_is_ltxav(version)) {
|
||||||
default_flow_shift = 2.37f;
|
default_flow_shift = 2.37f;
|
||||||
|
} else if (sd_version_is_krea2(version)) {
|
||||||
|
default_flow_shift = 1.15f;
|
||||||
}
|
}
|
||||||
} else if (sd_version_is_flux2(version)) {
|
} else if (sd_version_is_flux2(version)) {
|
||||||
pred_type = FLUX2_FLOW_PRED;
|
pred_type = FLUX2_FLOW_PRED;
|
||||||
@ -1724,7 +1742,7 @@ public:
|
|||||||
} else if (sd_version_uses_flux_vae(version)) {
|
} else if (sd_version_uses_flux_vae(version)) {
|
||||||
latent_rgb_proj = flux_latent_rgb_proj;
|
latent_rgb_proj = flux_latent_rgb_proj;
|
||||||
latent_rgb_bias = flux_latent_rgb_bias;
|
latent_rgb_bias = flux_latent_rgb_bias;
|
||||||
} else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
|
} else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version) || sd_version_is_krea2(version)) {
|
||||||
latent_rgb_proj = wan_21_latent_rgb_proj;
|
latent_rgb_proj = wan_21_latent_rgb_proj;
|
||||||
latent_rgb_bias = wan_21_latent_rgb_bias;
|
latent_rgb_bias = wan_21_latent_rgb_bias;
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user