refactor: simplify the model loading logic (#933)

* remove String2GGMLType

* remove preprocess_tensor

* fix clip init

* simplify the logic for reading weights
leejet 2025-11-03 21:21:34 +08:00 committed by GitHub
parent 6103d86e2c
commit 8f6c5c217b
21 changed files with 534 additions and 622 deletions
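
At its core, the change replaces the bare name-to-ggml_type map (String2GGMLType) with a single map of per-tensor metadata: ModelLoader now exposes get_tensor_storage_map(), keyed by the converted tensor name, and each TensorStorage records the on-disk type plus an optional expected_type override. GGMLBlock::get_type() reads from that map, and the separate preprocess_tensor() pass goes away because names are converted once when tensors are registered. A minimal sketch of the new lookup follows, using simplified stand-in types (TensorStorageSketch and get_type_sketch are illustrative, not the project's definitions; it assumes ggml.h is on the include path):

#include <map>
#include <string>
#include "ggml.h"

struct TensorStorageSketch {
    enum ggml_type type          = GGML_TYPE_F32;   // type as stored in the model file
    enum ggml_type expected_type = GGML_TYPE_COUNT; // optional override (COUNT means "no override")
};

using String2TensorStorageSketch = std::map<std::string, TensorStorageSketch>;

// Mirrors the new GGMLBlock::get_type(): prefer the per-tensor override,
// then the on-disk type, and finally the caller's default.
static enum ggml_type get_type_sketch(const std::string& name,
                                      const String2TensorStorageSketch& map,
                                      enum ggml_type default_type) {
    auto it = map.find(name);
    if (it == map.end()) {
        return default_type;
    }
    const TensorStorageSketch& ts = it->second;
    return ts.expected_type != GGML_TYPE_COUNT ? ts.expected_type : ts.type;
}

The same map is what set_wtype_override() mutates: instead of rebuilding a separate type map, it stamps expected_type on the matching entries.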


@ -476,11 +476,12 @@ protected:
public:
CLIPLayer(int64_t d_model,
int64_t n_head,
int64_t intermediate_size)
int64_t intermediate_size,
bool proj_in = false)
: d_model(d_model),
n_head(n_head),
intermediate_size(intermediate_size) {
blocks["self_attn"] = std::shared_ptr<GGMLBlock>(new MultiheadAttention(d_model, n_head, true, true));
blocks["self_attn"] = std::shared_ptr<GGMLBlock>(new MultiheadAttention(d_model, n_head, true, true, proj_in));
blocks["layer_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(d_model));
blocks["layer_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(d_model));
@ -509,11 +510,12 @@ public:
CLIPEncoder(int64_t n_layer,
int64_t d_model,
int64_t n_head,
int64_t intermediate_size)
int64_t intermediate_size,
bool proj_in = false)
: n_layer(n_layer) {
for (int i = 0; i < n_layer; i++) {
std::string name = "layers." + std::to_string(i);
blocks[name] = std::shared_ptr<GGMLBlock>(new CLIPLayer(d_model, n_head, intermediate_size));
blocks[name] = std::shared_ptr<GGMLBlock>(new CLIPLayer(d_model, n_head, intermediate_size, proj_in));
}
}
@ -549,10 +551,10 @@ protected:
int64_t num_positions;
bool force_clip_f32;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
enum ggml_type token_wtype = GGML_TYPE_F32;
if (!force_clip_f32) {
token_wtype = get_type(prefix + "token_embedding.weight", tensor_types, GGML_TYPE_F32);
token_wtype = get_type(prefix + "token_embedding.weight", tensor_storage_map, GGML_TYPE_F32);
if (!support_get_rows(token_wtype)) {
token_wtype = GGML_TYPE_F32;
}
@ -605,7 +607,8 @@ protected:
int64_t image_size;
int64_t num_patches;
int64_t num_positions;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
enum ggml_type patch_wtype = GGML_TYPE_F16;
enum ggml_type class_wtype = GGML_TYPE_F32;
enum ggml_type position_wtype = GGML_TYPE_F32;
@ -668,7 +671,7 @@ enum CLIPVersion {
class CLIPTextModel : public GGMLBlock {
protected:
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
if (version == OPEN_CLIP_VIT_BIGG_14) {
enum ggml_type wtype = GGML_TYPE_F32;
params["text_projection"] = ggml_new_tensor_2d(ctx, wtype, projection_dim, hidden_size);
@ -689,7 +692,8 @@ public:
CLIPTextModel(CLIPVersion version = OPENAI_CLIP_VIT_L_14,
bool with_final_ln = true,
bool force_clip_f32 = false)
bool force_clip_f32 = false,
bool proj_in = false)
: version(version), with_final_ln(with_final_ln) {
if (version == OPEN_CLIP_VIT_H_14) {
hidden_size = 1024;
@ -704,7 +708,7 @@ public:
}
blocks["embeddings"] = std::shared_ptr<GGMLBlock>(new CLIPEmbeddings(hidden_size, vocab_size, n_token, force_clip_f32));
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new CLIPEncoder(n_layer, hidden_size, n_head, intermediate_size));
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new CLIPEncoder(n_layer, hidden_size, n_head, intermediate_size, proj_in));
blocks["final_layer_norm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size));
}
@ -758,7 +762,7 @@ public:
int32_t n_layer = 24;
public:
CLIPVisionModel(CLIPVersion version = OPENAI_CLIP_VIT_L_14) {
CLIPVisionModel(CLIPVersion version = OPENAI_CLIP_VIT_L_14, bool proj_in = false) {
if (version == OPEN_CLIP_VIT_H_14) {
hidden_size = 1280;
intermediate_size = 5120;
@ -773,7 +777,7 @@ public:
blocks["embeddings"] = std::shared_ptr<GGMLBlock>(new CLIPVisionEmbeddings(hidden_size, num_channels, patch_size, image_size));
blocks["pre_layernorm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size));
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new CLIPEncoder(n_layer, hidden_size, n_head, intermediate_size));
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new CLIPEncoder(n_layer, hidden_size, n_head, intermediate_size, proj_in));
blocks["post_layernorm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size));
}
@ -811,8 +815,8 @@ protected:
int64_t out_features;
bool transpose_weight;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32);
if (transpose_weight) {
params["weight"] = ggml_new_tensor_2d(ctx, wtype, out_features, in_features);
} else {
@ -845,7 +849,8 @@ public:
public:
CLIPVisionModelProjection(CLIPVersion version = OPENAI_CLIP_VIT_L_14,
bool transpose_proj_w = false) {
bool transpose_proj_w = false,
bool proj_in = false) {
if (version == OPEN_CLIP_VIT_H_14) {
hidden_size = 1280;
projection_dim = 1024;
@ -853,7 +858,7 @@ public:
hidden_size = 1664;
}
blocks["vision_model"] = std::shared_ptr<GGMLBlock>(new CLIPVisionModel(version));
blocks["vision_model"] = std::shared_ptr<GGMLBlock>(new CLIPVisionModel(version, proj_in));
blocks["visual_projection"] = std::shared_ptr<GGMLBlock>(new CLIPProjection(hidden_size, projection_dim, transpose_proj_w));
}
@ -881,13 +886,24 @@ struct CLIPTextModelRunner : public GGMLRunner {
CLIPTextModelRunner(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types,
const String2TensorStorage& tensor_storage_map,
const std::string prefix,
CLIPVersion version = OPENAI_CLIP_VIT_L_14,
bool with_final_ln = true,
bool force_clip_f32 = false)
: GGMLRunner(backend, offload_params_to_cpu), model(version, with_final_ln, force_clip_f32) {
model.init(params_ctx, tensor_types, prefix);
: GGMLRunner(backend, offload_params_to_cpu) {
bool proj_in = false;
for (const auto& [name, tensor_storage] : tensor_storage_map) {
if (!starts_with(name, prefix)) {
continue;
}
if (contains(name, "self_attn.in_proj")) {
proj_in = true;
break;
}
}
model = CLIPTextModel(version, with_final_ln, force_clip_f32, proj_in);
model.init(params_ctx, tensor_storage_map, prefix);
}
std::string get_desc() override {


@ -182,8 +182,8 @@ protected:
int64_t dim_in;
int64_t dim_out;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") override {
enum ggml_type wtype = get_type(prefix + "proj.weight", tensor_types, GGML_TYPE_F32);
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override {
enum ggml_type wtype = get_type(prefix + "proj.weight", tensor_storage_map, GGML_TYPE_F32);
enum ggml_type bias_wtype = GGML_TYPE_F32;
params["proj.weight"] = ggml_new_tensor_2d(ctx, wtype, dim_in, dim_out * 2);
params["proj.bias"] = ggml_new_tensor_1d(ctx, bias_wtype, dim_out * 2);
@ -408,30 +408,40 @@ protected:
int64_t d_head;
int64_t depth = 1; // 1
int64_t context_dim = 768; // hidden_size, 1024 for VERSION_SD2
bool use_linear = false;
public:
SpatialTransformer(int64_t in_channels,
int64_t n_head,
int64_t d_head,
int64_t depth,
int64_t context_dim)
int64_t context_dim,
bool use_linear)
: in_channels(in_channels),
n_head(n_head),
d_head(d_head),
depth(depth),
context_dim(context_dim) {
// We will convert unet transformer linear to conv2d 1x1 when loading the weights, so use_linear is always False
context_dim(context_dim),
use_linear(use_linear) {
// disable_self_attn is always False
int64_t inner_dim = n_head * d_head; // in_channels
blocks["norm"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels));
blocks["proj_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, inner_dim, {1, 1}));
if (use_linear) {
blocks["proj_in"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, inner_dim));
} else {
blocks["proj_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, inner_dim, {1, 1}));
}
for (int i = 0; i < depth; i++) {
std::string name = "transformer_blocks." + std::to_string(i);
blocks[name] = std::shared_ptr<GGMLBlock>(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim, false));
}
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(inner_dim, in_channels, {1, 1}));
if (use_linear) {
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, in_channels));
} else {
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(inner_dim, in_channels, {1, 1}));
}
}
virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx,
@ -440,8 +450,8 @@ public:
// x: [N, in_channels, h, w]
// context: [N, max_position(aka n_token), hidden_size(aka context_dim)]
auto norm = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm"]);
auto proj_in = std::dynamic_pointer_cast<Conv2d>(blocks["proj_in"]);
auto proj_out = std::dynamic_pointer_cast<Conv2d>(blocks["proj_out"]);
auto proj_in = std::dynamic_pointer_cast<UnaryBlock>(blocks["proj_in"]);
auto proj_out = std::dynamic_pointer_cast<UnaryBlock>(blocks["proj_out"]);
auto x_in = x;
int64_t n = x->ne[3];
@ -450,10 +460,15 @@ public:
int64_t inner_dim = n_head * d_head;
x = norm->forward(ctx, x);
x = proj_in->forward(ctx, x); // [N, inner_dim, h, w]
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 2, 0, 3)); // [N, h, w, inner_dim]
x = ggml_reshape_3d(ctx->ggml_ctx, x, inner_dim, w * h, n); // [N, h * w, inner_dim]
if (use_linear) {
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 2, 0, 3)); // [N, h, w, inner_dim]
x = ggml_reshape_3d(ctx->ggml_ctx, x, inner_dim, w * h, n); // [N, h * w, inner_dim]
x = proj_in->forward(ctx, x); // [N, h * w, inner_dim]
} else {
x = proj_in->forward(ctx, x); // [N, inner_dim, h, w]
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 2, 0, 3)); // [N, h, w, inner_dim]
x = ggml_reshape_3d(ctx->ggml_ctx, x, inner_dim, w * h, n); // [N, h * w, inner_dim]
}
for (int i = 0; i < depth; i++) {
std::string name = "transformer_blocks." + std::to_string(i);
@ -462,11 +477,19 @@ public:
x = transformer_block->forward(ctx, x, context);
}
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [N, inner_dim, h * w]
x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, inner_dim, n); // [N, inner_dim, h, w]
if (use_linear) {
// proj_out
x = proj_out->forward(ctx, x); // [N, in_channels, h, w]
// proj_out
x = proj_out->forward(ctx, x); // [N, in_channels, h, w]
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [N, inner_dim, h * w]
x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, inner_dim, n); // [N, inner_dim, h, w]
} else {
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [N, inner_dim, h * w]
x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, inner_dim, n); // [N, inner_dim, h, w]
// proj_out
x = proj_out->forward(ctx, x); // [N, in_channels, h, w]
}
x = ggml_add(ctx->ggml_ctx, x, x_in);
return x;
@ -475,7 +498,7 @@ public:
class AlphaBlender : public GGMLBlock {
protected:
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") override {
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override {
// Get the type of the "mix_factor" tensor from the input tensors map with the specified prefix
enum ggml_type wtype = GGML_TYPE_F32;
params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1);


@ -63,19 +63,19 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
FrozenCLIPEmbedderWithCustomWords(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types,
const String2TensorStorage& tensor_storage_map,
const std::string& embd_dir,
SDVersion version = VERSION_SD1,
PMVersion pv = PM_VERSION_1)
: version(version), pm_version(pv), tokenizer(sd_version_is_sd2(version) ? 0 : 49407), embd_dir(embd_dir) {
bool force_clip_f32 = embd_dir.size() > 0;
if (sd_version_is_sd1(version)) {
text_model = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, true, force_clip_f32);
text_model = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_storage_map, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, true, force_clip_f32);
} else if (sd_version_is_sd2(version)) {
text_model = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "cond_stage_model.transformer.text_model", OPEN_CLIP_VIT_H_14, true, force_clip_f32);
text_model = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_storage_map, "cond_stage_model.transformer.text_model", OPEN_CLIP_VIT_H_14, true, force_clip_f32);
} else if (sd_version_is_sdxl(version)) {
text_model = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, false, force_clip_f32);
text_model2 = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "cond_stage_model.1.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false, force_clip_f32);
text_model = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_storage_map, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, false, force_clip_f32);
text_model2 = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_storage_map, "cond_stage_model.1.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false, force_clip_f32);
}
}
@ -623,9 +623,21 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {
FrozenCLIPVisionEmbedder(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {})
: vision_model(OPEN_CLIP_VIT_H_14), GGMLRunner(backend, offload_params_to_cpu) {
vision_model.init(params_ctx, tensor_types, "cond_stage_model.transformer");
const String2TensorStorage& tensor_storage_map = {})
: GGMLRunner(backend, offload_params_to_cpu) {
std::string prefix = "cond_stage_model.transformer";
bool proj_in = false;
for (const auto& [name, tensor_storage] : tensor_storage_map) {
if (!starts_with(name, prefix)) {
continue;
}
if (contains(name, "self_attn.in_proj")) {
proj_in = true;
break;
}
}
vision_model = CLIPVisionModelProjection(OPEN_CLIP_VIT_H_14, false, proj_in);
vision_model.init(params_ctx, tensor_storage_map, prefix);
}
std::string get_desc() override {
@ -673,12 +685,12 @@ struct SD3CLIPEmbedder : public Conditioner {
SD3CLIPEmbedder(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {})
const String2TensorStorage& tensor_storage_map = {})
: clip_g_tokenizer(0) {
bool use_clip_l = false;
bool use_clip_g = false;
bool use_t5 = false;
for (auto pair : tensor_types) {
for (auto pair : tensor_storage_map) {
if (pair.first.find("text_encoders.clip_l") != std::string::npos) {
use_clip_l = true;
} else if (pair.first.find("text_encoders.clip_g") != std::string::npos) {
@ -692,13 +704,13 @@ struct SD3CLIPEmbedder : public Conditioner {
return;
}
if (use_clip_l) {
clip_l = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, false);
clip_l = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_storage_map, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, false);
}
if (use_clip_g) {
clip_g = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.clip_g.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false);
clip_g = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_storage_map, "text_encoders.clip_g.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false);
}
if (use_t5) {
t5 = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.t5xxl.transformer");
t5 = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_storage_map, "text_encoders.t5xxl.transformer");
}
}
@ -1082,10 +1094,10 @@ struct FluxCLIPEmbedder : public Conditioner {
FluxCLIPEmbedder(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {}) {
const String2TensorStorage& tensor_storage_map = {}) {
bool use_clip_l = false;
bool use_t5 = false;
for (auto pair : tensor_types) {
for (auto pair : tensor_storage_map) {
if (pair.first.find("text_encoders.clip_l") != std::string::npos) {
use_clip_l = true;
} else if (pair.first.find("text_encoders.t5xxl") != std::string::npos) {
@ -1099,12 +1111,12 @@ struct FluxCLIPEmbedder : public Conditioner {
}
if (use_clip_l) {
clip_l = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, true);
clip_l = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_storage_map, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, true);
} else {
LOG_WARN("clip_l text encoder not found! Prompt adherence might be degraded.");
}
if (use_t5) {
t5 = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.t5xxl.transformer");
t5 = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_storage_map, "text_encoders.t5xxl.transformer");
} else {
LOG_WARN("t5xxl text encoder not found! Prompt adherence might be degraded.");
}
@ -1342,13 +1354,13 @@ struct T5CLIPEmbedder : public Conditioner {
T5CLIPEmbedder(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
bool use_mask = false,
int mask_pad = 1,
bool is_umt5 = false)
const String2TensorStorage& tensor_storage_map = {},
bool use_mask = false,
int mask_pad = 1,
bool is_umt5 = false)
: use_mask(use_mask), mask_pad(mask_pad), t5_tokenizer(is_umt5) {
bool use_t5 = false;
for (auto pair : tensor_types) {
for (auto pair : tensor_storage_map) {
if (pair.first.find("text_encoders.t5xxl") != std::string::npos) {
use_t5 = true;
}
@ -1358,7 +1370,7 @@ struct T5CLIPEmbedder : public Conditioner {
LOG_WARN("IMPORTANT NOTICE: No text encoders provided, cannot process prompts!");
return;
} else {
t5 = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.t5xxl.transformer", is_umt5);
t5 = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_storage_map, "text_encoders.t5xxl.transformer", is_umt5);
}
}
@ -1549,12 +1561,12 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
Qwen2_5_VLCLIPEmbedder(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
const std::string prefix = "",
bool enable_vision = false) {
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "",
bool enable_vision = false) {
qwenvl = std::make_shared<Qwen::Qwen2_5_VLRunner>(backend,
offload_params_to_cpu,
tensor_types,
tensor_storage_map,
"text_encoders.qwen2vl",
enable_vision);
}


@ -27,6 +27,7 @@ protected:
int num_heads = 8;
int num_head_channels = -1; // channels // num_heads
int context_dim = 768; // 1024 for VERSION_SD2, 2048 for VERSION_SDXL
bool use_linear_projection = false;
public:
int model_channels = 320;
@ -82,7 +83,7 @@ public:
int64_t d_head,
int64_t depth,
int64_t context_dim) -> SpatialTransformer* {
return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim);
return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim, use_linear_projection);
};
auto make_zero_conv = [&](int64_t channels) {
@ -318,10 +319,10 @@ struct ControlNet : public GGMLRunner {
ControlNet(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
SDVersion version = VERSION_SD1)
const String2TensorStorage& tensor_storage_map = {},
SDVersion version = VERSION_SD1)
: GGMLRunner(backend, offload_params_to_cpu), control_net(version) {
control_net.init(params_ctx, tensor_types, "");
control_net.init(params_ctx, tensor_storage_map, "");
}
~ControlNet() override {


@ -44,9 +44,9 @@ struct UNetModel : public DiffusionModel {
UNetModel(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
SDVersion version = VERSION_SD1)
: unet(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version) {
const String2TensorStorage& tensor_storage_map = {},
SDVersion version = VERSION_SD1)
: unet(backend, offload_params_to_cpu, tensor_storage_map, "model.diffusion_model", version) {
}
std::string get_desc() override {
@ -102,8 +102,8 @@ struct MMDiTModel : public DiffusionModel {
MMDiTModel(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {})
: mmdit(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model") {
const String2TensorStorage& tensor_storage_map = {})
: mmdit(backend, offload_params_to_cpu, tensor_storage_map, "model.diffusion_model") {
}
std::string get_desc() override {
@ -158,10 +158,10 @@ struct FluxModel : public DiffusionModel {
FluxModel(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
SDVersion version = VERSION_FLUX,
bool use_mask = false)
: flux(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, use_mask) {
const String2TensorStorage& tensor_storage_map = {},
SDVersion version = VERSION_FLUX,
bool use_mask = false)
: flux(backend, offload_params_to_cpu, tensor_storage_map, "model.diffusion_model", version, use_mask) {
}
std::string get_desc() override {
@ -221,10 +221,10 @@ struct WanModel : public DiffusionModel {
WanModel(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
const std::string prefix = "model.diffusion_model",
SDVersion version = VERSION_WAN2)
: prefix(prefix), wan(backend, offload_params_to_cpu, tensor_types, prefix, version) {
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "model.diffusion_model",
SDVersion version = VERSION_WAN2)
: prefix(prefix), wan(backend, offload_params_to_cpu, tensor_storage_map, prefix, version) {
}
std::string get_desc() override {
@ -283,10 +283,10 @@ struct QwenImageModel : public DiffusionModel {
QwenImageModel(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
const std::string prefix = "model.diffusion_model",
SDVersion version = VERSION_QWEN_IMAGE)
: prefix(prefix), qwen_image(backend, offload_params_to_cpu, tensor_types, prefix, version) {
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "model.diffusion_model",
SDVersion version = VERSION_QWEN_IMAGE)
: prefix(prefix), qwen_image(backend, offload_params_to_cpu, tensor_storage_map, prefix, version) {
}
std::string get_desc() override {


@ -156,7 +156,7 @@ struct ESRGAN : public GGMLRunner {
ESRGAN(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {})
const String2TensorStorage& tensor_storage_map = {})
: GGMLRunner(backend, offload_params_to_cpu) {
// rrdb_net will be created in load_from_file
}


@ -37,7 +37,7 @@ namespace Flux {
int64_t hidden_size;
float eps;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
ggml_type wtype = GGML_TYPE_F32;
params["scale"] = ggml_new_tensor_1d(ctx, wtype, hidden_size);
}
@ -1115,10 +1115,10 @@ namespace Flux {
FluxRunner(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
const std::string prefix = "",
SDVersion version = VERSION_FLUX,
bool use_mask = false)
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "",
SDVersion version = VERSION_FLUX,
bool use_mask = false)
: GGMLRunner(backend, offload_params_to_cpu), version(version), use_mask(use_mask) {
flux_params.version = version;
flux_params.guidance_embed = false;
@ -1134,7 +1134,7 @@ namespace Flux {
flux_params.in_channels = 3;
flux_params.patch_size = 16;
}
for (auto pair : tensor_types) {
for (auto pair : tensor_storage_map) {
std::string tensor_name = pair.first;
if (!starts_with(tensor_name, prefix))
continue;
@ -1172,7 +1172,7 @@ namespace Flux {
}
flux = Flux(flux_params);
flux.init(params_ctx, tensor_types, prefix);
flux.init(params_ctx, tensor_storage_map, prefix);
}
std::string get_desc() override {
@ -1403,17 +1403,16 @@ namespace Flux {
return;
}
auto tensor_types = model_loader.tensor_storages_types;
for (auto& item : tensor_types) {
// LOG_DEBUG("%s %u", item.first.c_str(), item.second);
if (ends_with(item.first, "weight")) {
// item.second = model_data_type;
auto& tensor_storage_map = model_loader.get_tensor_storage_map();
for (auto& [name, tensor_storage] : tensor_storage_map) {
if (ends_with(name, "weight")) {
tensor_storage.expected_type = model_data_type;
}
}
std::shared_ptr<FluxRunner> flux = std::make_shared<FluxRunner>(backend,
false,
tensor_types,
tensor_storage_map,
"model.diffusion_model",
VERSION_CHROMA_RADIANCE,
false);


@ -1460,8 +1460,6 @@ __STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) {
#define MAX_PARAMS_TENSOR_NUM 32768
#define MAX_GRAPH_SIZE 327680
typedef std::map<std::string, enum ggml_type> String2GGMLType;
struct GGMLRunnerContext {
ggml_backend_t backend = nullptr;
ggml_context* ggml_ctx = nullptr;
@ -1900,30 +1898,36 @@ protected:
GGMLBlockMap blocks;
ParameterMap params;
ggml_type get_type(const std::string& name, const String2GGMLType& tensor_types, ggml_type default_type) {
auto iter = tensor_types.find(name);
if (iter != tensor_types.end()) {
return iter->second;
ggml_type get_type(const std::string& name, const String2TensorStorage& tensor_storage_map, ggml_type default_type) {
ggml_type wtype = default_type;
auto iter = tensor_storage_map.find(name);
if (iter != tensor_storage_map.end()) {
const TensorStorage& tensor_storage = iter->second;
if (tensor_storage.expected_type != GGML_TYPE_COUNT) {
wtype = tensor_storage.expected_type;
} else {
wtype = tensor_storage.type;
}
}
return default_type;
return wtype;
}
void init_blocks(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
void init_blocks(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") {
for (auto& pair : blocks) {
auto& block = pair.second;
block->init(ctx, tensor_types, prefix + pair.first);
block->init(ctx, tensor_storage_map, prefix + pair.first);
}
}
virtual void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {}
virtual void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") {}
public:
void init(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
void init(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") {
if (prefix.size() > 0) {
prefix = prefix + ".";
}
init_blocks(ctx, tensor_types, prefix);
init_params(ctx, tensor_types, prefix);
init_blocks(ctx, tensor_storage_map, prefix);
init_params(ctx, tensor_storage_map, prefix);
}
size_t get_params_num() {
@ -2001,8 +2005,8 @@ protected:
bool force_prec_f32;
float scale;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32);
if (in_features % ggml_blck_size(wtype) != 0 || force_f32) {
wtype = GGML_TYPE_F32;
}
@ -2049,8 +2053,8 @@ class Embedding : public UnaryBlock {
protected:
int64_t embedding_dim;
int64_t num_embeddings;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") override {
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map, const std::string prefix = "") override {
enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32);
if (!support_get_rows(wtype)) {
wtype = GGML_TYPE_F32;
}
@ -2093,7 +2097,7 @@ protected:
bool bias;
float scale = 1.f;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") override {
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map, const std::string prefix = "") override {
enum ggml_type wtype = GGML_TYPE_F16;
params["weight"] = ggml_new_tensor_4d(ctx, wtype, kernel_size.second, kernel_size.first, in_channels, out_channels);
if (bias) {
@ -2157,7 +2161,7 @@ protected:
int64_t dilation;
bool bias;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") override {
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map, const std::string prefix = "") override {
enum ggml_type wtype = GGML_TYPE_F16;
params["weight"] = ggml_new_tensor_4d(ctx, wtype, 1, kernel_size, in_channels, out_channels); // 5d => 4d
if (bias) {
@ -2204,7 +2208,7 @@ protected:
std::tuple<int, int, int> dilation;
bool bias;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") override {
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map, const std::string prefix = "") override {
enum ggml_type wtype = GGML_TYPE_F16;
params["weight"] = ggml_new_tensor_4d(ctx,
wtype,
@ -2253,7 +2257,7 @@ protected:
bool elementwise_affine;
bool bias;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
if (elementwise_affine) {
enum ggml_type wtype = GGML_TYPE_F32;
params["weight"] = ggml_new_tensor_1d(ctx, wtype, normalized_shape);
@ -2295,7 +2299,7 @@ protected:
float eps;
bool affine;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
if (affine) {
enum ggml_type wtype = GGML_TYPE_F32;
enum ggml_type bias_wtype = GGML_TYPE_F32;
@ -2336,7 +2340,7 @@ protected:
int64_t hidden_size;
float eps;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") override {
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override {
enum ggml_type wtype = GGML_TYPE_F32;
params["weight"] = ggml_new_tensor_1d(ctx, wtype, hidden_size);
}
@ -2359,9 +2363,11 @@ class MultiheadAttention : public GGMLBlock {
protected:
int64_t embed_dim;
int64_t n_head;
bool proj_in;
std::string q_proj_name;
std::string k_proj_name;
std::string v_proj_name;
std::string in_proj_name;
std::string out_proj_name;
public:
@ -2369,19 +2375,27 @@ public:
int64_t n_head,
bool qkv_proj_bias = true,
bool out_proj_bias = true,
bool proj_in = false,
std::string q_proj_name = "q_proj",
std::string k_proj_name = "k_proj",
std::string v_proj_name = "v_proj",
std::string in_proj_name = "in_proj",
std::string out_proj_name = "out_proj")
: embed_dim(embed_dim),
n_head(n_head),
proj_in(proj_in),
q_proj_name(q_proj_name),
k_proj_name(k_proj_name),
v_proj_name(v_proj_name),
in_proj_name(in_proj_name),
out_proj_name(out_proj_name) {
blocks[q_proj_name] = std::shared_ptr<GGMLBlock>(new Linear(embed_dim, embed_dim, qkv_proj_bias));
blocks[k_proj_name] = std::shared_ptr<GGMLBlock>(new Linear(embed_dim, embed_dim, qkv_proj_bias));
blocks[v_proj_name] = std::shared_ptr<GGMLBlock>(new Linear(embed_dim, embed_dim, qkv_proj_bias));
if (proj_in) {
blocks[in_proj_name] = std::shared_ptr<GGMLBlock>(new Linear(embed_dim, embed_dim * 3, qkv_proj_bias));
} else {
blocks[q_proj_name] = std::shared_ptr<GGMLBlock>(new Linear(embed_dim, embed_dim, qkv_proj_bias));
blocks[k_proj_name] = std::shared_ptr<GGMLBlock>(new Linear(embed_dim, embed_dim, qkv_proj_bias));
blocks[v_proj_name] = std::shared_ptr<GGMLBlock>(new Linear(embed_dim, embed_dim, qkv_proj_bias));
}
blocks[out_proj_name] = std::shared_ptr<GGMLBlock>(new Linear(embed_dim, embed_dim, out_proj_bias));
}
@ -2389,14 +2403,27 @@ public:
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
bool mask = false) {
auto q_proj = std::dynamic_pointer_cast<Linear>(blocks[q_proj_name]);
auto k_proj = std::dynamic_pointer_cast<Linear>(blocks[k_proj_name]);
auto v_proj = std::dynamic_pointer_cast<Linear>(blocks[v_proj_name]);
auto out_proj = std::dynamic_pointer_cast<Linear>(blocks[out_proj_name]);
struct ggml_tensor* q = q_proj->forward(ctx, x);
struct ggml_tensor* k = k_proj->forward(ctx, x);
struct ggml_tensor* v = v_proj->forward(ctx, x);
ggml_tensor* q;
ggml_tensor* k;
ggml_tensor* v;
if (proj_in) {
auto in_proj = std::dynamic_pointer_cast<Linear>(blocks[in_proj_name]);
auto qkv = in_proj->forward(ctx, x);
auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv);
q = qkv_vec[0];
k = qkv_vec[1];
v = qkv_vec[2];
} else {
auto q_proj = std::dynamic_pointer_cast<Linear>(blocks[q_proj_name]);
auto k_proj = std::dynamic_pointer_cast<Linear>(blocks[k_proj_name]);
auto v_proj = std::dynamic_pointer_cast<Linear>(blocks[v_proj_name]);
q = q_proj->forward(ctx, x);
k = k_proj->forward(ctx, x);
v = v_proj->forward(ctx, x);
}
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, mask); // [N, n_token, embed_dim]


@ -633,13 +633,13 @@ protected:
int64_t hidden_size;
std::string qk_norm;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") override {
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override {
enum ggml_type wtype = GGML_TYPE_F32;
params["pos_embed"] = ggml_new_tensor_3d(ctx, wtype, hidden_size, num_patchs, 1);
}
public:
MMDiT(const String2GGMLType& tensor_types = {}) {
MMDiT(const String2TensorStorage& tensor_storage_map = {}) {
// input_size is always None
// learn_sigma is always False
// register_length is always 0
@ -652,8 +652,7 @@ public:
// pos_embed_offset is not used
// context_embedder_config is always {'target': 'torch.nn.Linear', 'params': {'in_features': 4096, 'out_features': 1536}}
// read tensors from tensor_types
for (auto pair : tensor_types) {
for (auto pair : tensor_storage_map) {
std::string tensor_name = pair.first;
if (tensor_name.find("model.diffusion_model.") == std::string::npos)
continue;
@ -852,10 +851,10 @@ struct MMDiTRunner : public GGMLRunner {
MMDiTRunner(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
const std::string prefix = "")
: GGMLRunner(backend, offload_params_to_cpu), mmdit(tensor_types) {
mmdit.init(params_ctx, tensor_types, prefix);
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "")
: GGMLRunner(backend, offload_params_to_cpu), mmdit(tensor_storage_map) {
mmdit.init(params_ctx, tensor_storage_map, prefix);
}
std::string get_desc() override {

model.cpp

@ -140,7 +140,9 @@ std::unordered_map<std::string, std::string> open_clip_to_hf_clip_model = {
{"model.visual.proj", "transformer.visual_projection.weight"},
};
std::unordered_map<std::string, std::string> open_clip_to_hk_clip_resblock = {
std::unordered_map<std::string, std::string> open_clip_to_hf_clip_resblock = {
{"attn.in_proj_bias", "self_attn.in_proj.bias"},
{"attn.in_proj_weight", "self_attn.in_proj.weight"},
{"attn.out_proj.bias", "self_attn.out_proj.bias"},
{"attn.out_proj.weight", "self_attn.out_proj.weight"},
{"ln_1.bias", "layer_norm1.bias"},
@ -351,10 +353,8 @@ std::string convert_cond_model_name(const std::string& name) {
std::string idx = remain.substr(0, remain.find("."));
std::string suffix = remain.substr(idx.length() + 1);
if (suffix == "attn.in_proj_weight" || suffix == "attn.in_proj_bias") {
new_name = hf_clip_resblock_prefix + idx + "." + suffix;
} else if (open_clip_to_hk_clip_resblock.find(suffix) != open_clip_to_hk_clip_resblock.end()) {
std::string new_suffix = open_clip_to_hk_clip_resblock[suffix];
if (open_clip_to_hf_clip_resblock.find(suffix) != open_clip_to_hf_clip_resblock.end()) {
std::string new_suffix = open_clip_to_hf_clip_resblock[suffix];
new_name = hf_clip_resblock_prefix + idx + "." + new_suffix;
}
}
@ -740,80 +740,6 @@ std::string convert_tensor_name(std::string name) {
return new_name;
}
void add_preprocess_tensor_storage_types(String2GGMLType& tensor_storages_types, std::string name, enum ggml_type type) {
std::string new_name = convert_tensor_name(name);
if (new_name.find("cond_stage_model") != std::string::npos && ends_with(new_name, "attn.in_proj_weight")) {
size_t prefix_size = new_name.find("attn.in_proj_weight");
std::string prefix = new_name.substr(0, prefix_size);
tensor_storages_types[prefix + "self_attn.q_proj.weight"] = type;
tensor_storages_types[prefix + "self_attn.k_proj.weight"] = type;
tensor_storages_types[prefix + "self_attn.v_proj.weight"] = type;
} else if (new_name.find("cond_stage_model") != std::string::npos && ends_with(new_name, "attn.in_proj_bias")) {
size_t prefix_size = new_name.find("attn.in_proj_bias");
std::string prefix = new_name.substr(0, prefix_size);
tensor_storages_types[prefix + "self_attn.q_proj.bias"] = type;
tensor_storages_types[prefix + "self_attn.k_proj.bias"] = type;
tensor_storages_types[prefix + "self_attn.v_proj.bias"] = type;
} else {
tensor_storages_types[new_name] = type;
}
}
void preprocess_tensor(TensorStorage tensor_storage,
std::vector<TensorStorage>& processed_tensor_storages) {
std::vector<TensorStorage> result;
std::string new_name = convert_tensor_name(tensor_storage.name);
// convert unet transformer linear to conv2d 1x1
if (starts_with(new_name, "model.diffusion_model.") &&
!starts_with(new_name, "model.diffusion_model.proj_out.") &&
(ends_with(new_name, "proj_in.weight") || ends_with(new_name, "proj_out.weight"))) {
tensor_storage.unsqueeze();
}
// convert vae attn block linear to conv2d 1x1
if (starts_with(new_name, "first_stage_model.") && new_name.find("attn_1") != std::string::npos) {
tensor_storage.unsqueeze();
}
// wan vae
if (ends_with(new_name, "gamma")) {
tensor_storage.reverse_ne();
tensor_storage.n_dims = 1;
tensor_storage.reverse_ne();
}
tensor_storage.name = new_name;
if (new_name.find("cond_stage_model") != std::string::npos &&
ends_with(new_name, "attn.in_proj_weight")) {
size_t prefix_size = new_name.find("attn.in_proj_weight");
std::string prefix = new_name.substr(0, prefix_size);
std::vector<TensorStorage> chunks = tensor_storage.chunk(3);
chunks[0].name = prefix + "self_attn.q_proj.weight";
chunks[1].name = prefix + "self_attn.k_proj.weight";
chunks[2].name = prefix + "self_attn.v_proj.weight";
processed_tensor_storages.insert(processed_tensor_storages.end(), chunks.begin(), chunks.end());
} else if (new_name.find("cond_stage_model") != std::string::npos &&
ends_with(new_name, "attn.in_proj_bias")) {
size_t prefix_size = new_name.find("attn.in_proj_bias");
std::string prefix = new_name.substr(0, prefix_size);
std::vector<TensorStorage> chunks = tensor_storage.chunk(3);
chunks[0].name = prefix + "self_attn.q_proj.bias";
chunks[1].name = prefix + "self_attn.k_proj.bias";
chunks[2].name = prefix + "self_attn.v_proj.bias";
processed_tensor_storages.insert(processed_tensor_storages.end(), chunks.begin(), chunks.end());
} else {
processed_tensor_storages.push_back(tensor_storage);
}
}
float bf16_to_f32(uint16_t bfloat16) {
uint32_t val_bits = (static_cast<uint32_t>(bfloat16) << 16);
return *reinterpret_cast<float*>(&val_bits);
@ -989,44 +915,10 @@ void convert_tensor(void* src,
/*================================================= ModelLoader ==================================================*/
// ported from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py#L16
std::map<char, int> unicode_to_byte() {
std::map<int, char> byte_to_unicode;
// List of utf-8 byte ranges
for (int b = static_cast<int>('!'); b <= static_cast<int>('~'); ++b) {
byte_to_unicode[b] = static_cast<char>(b);
}
for (int b = 49825; b <= 49836; ++b) {
byte_to_unicode[b] = static_cast<char>(b);
}
for (int b = 49838; b <= 50111; ++b) {
byte_to_unicode[b] = static_cast<char>(b);
}
// printf("%d %d %d %d\n", static_cast<int>('¡'), static_cast<int>('¬'), static_cast<int>('®'), static_cast<int>('ÿ'));
// exit(1);
int n = 0;
for (int b = 0; b < 256; ++b) {
if (byte_to_unicode.find(b) == byte_to_unicode.end()) {
byte_to_unicode[b] = static_cast<char>(256 + n);
n++;
}
}
// byte_encoder = bytes_to_unicode()
// byte_decoder = {v: k for k, v in byte_encoder.items()}
std::map<char, int> byte_decoder;
for (const auto& entry : byte_to_unicode) {
byte_decoder[entry.second] = entry.first;
}
byte_to_unicode.clear();
return byte_decoder;
void ModelLoader::add_tensor_storage(const TensorStorage& tensor_storage) {
TensorStorage copy = tensor_storage;
copy.name = convert_tensor_name(copy.name);
tensor_storage_map[copy.name] = std::move(copy);
}
bool is_zip_file(const std::string& file_path) {
@ -1156,8 +1048,7 @@ bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::s
// LOG_DEBUG("%s %s", name.c_str(), tensor_storage.to_string().c_str());
tensor_storages.push_back(tensor_storage);
add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
add_tensor_storage(tensor_storage);
}
return true;
@ -1182,8 +1073,7 @@ bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::s
GGML_ASSERT(ggml_nbytes(dummy) == tensor_storage.nbytes());
tensor_storages.push_back(tensor_storage);
add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
add_tensor_storage(tensor_storage);
}
gguf_free(ctx_gguf_);
@ -1350,8 +1240,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size);
}
tensor_storages.push_back(tensor_storage);
add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
add_tensor_storage(tensor_storage);
// LOG_DEBUG("%s %s", tensor_storage.to_string().c_str(), dtype.c_str());
}
@ -1370,11 +1259,13 @@ bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const s
if (!init_from_safetensors_file(unet_path, "unet.")) {
return false;
}
for (auto ts : tensor_storages) {
if (ts.name.find("add_embedding") != std::string::npos || ts.name.find("label_emb") != std::string::npos) {
for (auto& [name, tensor_storage] : tensor_storage_map) {
if (name.find("add_embedding") != std::string::npos || name.find("label_emb") != std::string::npos) {
// probably SDXL
LOG_DEBUG("Fixing name for SDXL output blocks.2.2");
for (auto& tensor_storage : tensor_storages) {
String2TensorStorage new_tensor_storage_map;
for (auto& [name, tensor_storage] : tensor_storage_map) {
int len = 34;
auto pos = tensor_storage.name.find("unet.up_blocks.0.upsamplers.0.conv");
if (pos == std::string::npos) {
@ -1382,11 +1273,15 @@ bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const s
pos = tensor_storage.name.find("model.diffusion_model.output_blocks.2.1.conv");
}
if (pos != std::string::npos) {
tensor_storage.name = "model.diffusion_model.output_blocks.2.2.conv" + tensor_storage.name.substr(len);
LOG_DEBUG("NEW NAME: %s", tensor_storage.name.c_str());
add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
std::string new_name = "model.diffusion_model.output_blocks.2.2.conv" + name.substr(len);
LOG_DEBUG("NEW NAME: %s", new_name.c_str());
tensor_storage.name = new_name;
new_tensor_storage_map[new_name] = tensor_storage;
} else {
new_tensor_storage_map[name] = tensor_storage;
}
}
tensor_storage_map = new_tensor_storage_map;
break;
}
}
@ -1712,8 +1607,7 @@ bool ModelLoader::parse_data_pkl(uint8_t* buffer,
name = prefix + name;
}
reader.tensor_storage.name = name;
tensor_storages.push_back(reader.tensor_storage);
add_preprocess_tensor_storage_types(tensor_storages_types, reader.tensor_storage.name, reader.tensor_storage.type);
add_tensor_storage(reader.tensor_storage);
// LOG_DEBUG("%s", reader.tensor_storage.name.c_str());
// reset
@ -1767,15 +1661,6 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
return true;
}
bool ModelLoader::model_is_unet() {
for (auto& tensor_storage : tensor_storages) {
if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos) {
return true;
}
}
return false;
}
SDVersion ModelLoader::get_sd_version() {
TensorStorage token_embedding_weight, input_block_weight;
@ -1789,7 +1674,7 @@ SDVersion ModelLoader::get_sd_version() {
bool has_img_emb = false;
bool has_middle_block_1 = false;
for (auto& tensor_storage : tensor_storages) {
for (auto& [name, tensor_storage] : tensor_storage_map) {
if (!(is_xl)) {
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
is_flux = true;
@ -1910,7 +1795,7 @@ SDVersion ModelLoader::get_sd_version() {
std::map<ggml_type, uint32_t> ModelLoader::get_wtype_stat() {
std::map<ggml_type, uint32_t> wtype_stat;
for (auto& tensor_storage : tensor_storages) {
for (auto& [name, tensor_storage] : tensor_storage_map) {
if (is_unused_tensor(tensor_storage.name)) {
continue;
}
@ -1927,7 +1812,7 @@ std::map<ggml_type, uint32_t> ModelLoader::get_wtype_stat() {
std::map<ggml_type, uint32_t> ModelLoader::get_conditioner_wtype_stat() {
std::map<ggml_type, uint32_t> wtype_stat;
for (auto& tensor_storage : tensor_storages) {
for (auto& [name, tensor_storage] : tensor_storage_map) {
if (is_unused_tensor(tensor_storage.name)) {
continue;
}
@ -1951,7 +1836,7 @@ std::map<ggml_type, uint32_t> ModelLoader::get_conditioner_wtype_stat() {
std::map<ggml_type, uint32_t> ModelLoader::get_diffusion_model_wtype_stat() {
std::map<ggml_type, uint32_t> wtype_stat;
for (auto& tensor_storage : tensor_storages) {
for (auto& [name, tensor_storage] : tensor_storage_map) {
if (is_unused_tensor(tensor_storage.name)) {
continue;
}
@ -1972,7 +1857,7 @@ std::map<ggml_type, uint32_t> ModelLoader::get_diffusion_model_wtype_stat() {
std::map<ggml_type, uint32_t> ModelLoader::get_vae_wtype_stat() {
std::map<ggml_type, uint32_t> wtype_stat;
for (auto& tensor_storage : tensor_storages) {
for (auto& [name, tensor_storage] : tensor_storage_map) {
if (is_unused_tensor(tensor_storage.name)) {
continue;
}
@ -1993,26 +1878,14 @@ std::map<ggml_type, uint32_t> ModelLoader::get_vae_wtype_stat() {
}
void ModelLoader::set_wtype_override(ggml_type wtype, std::string prefix) {
for (auto& pair : tensor_storages_types) {
if (prefix.size() < 1 || pair.first.substr(0, prefix.size()) == prefix) {
bool found = false;
for (auto& tensor_storage : tensor_storages) {
std::map<std::string, ggml_type> temp;
add_preprocess_tensor_storage_types(temp, tensor_storage.name, tensor_storage.type);
for (auto& preprocessed_name : temp) {
if (preprocessed_name.first == pair.first) {
if (tensor_should_be_converted(tensor_storage, wtype)) {
pair.second = wtype;
}
found = true;
break;
}
}
if (found) {
break;
}
}
for (auto& [name, tensor_storage] : tensor_storage_map) {
if (!starts_with(name, prefix)) {
continue;
}
if (!tensor_should_be_converted(tensor_storage, wtype)) {
continue;
}
tensor_storage.expected_type = wtype;
}
}
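For reference, a usage sketch of the simplified override above (the variable name model_loader and the chosen type and prefix are illustrative): eligible tensors under the prefix get their expected_type stamped, and load_tensors() later converts them accordingly.

// Illustrative only: request q8_0 for diffusion-model weights that
// tensor_should_be_converted() accepts; everything else keeps its stored type.
model_loader.set_wtype_override(GGML_TYPE_Q8_0, "model.diffusion_model.");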
@ -2047,74 +1920,13 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
LOG_DEBUG("using %d threads for model loading", num_threads_to_use);
int64_t start_time = ggml_time_ms();
std::vector<TensorStorage> processed_tensor_storages;
{
struct IndexedStorage {
size_t index;
TensorStorage ts;
};
std::mutex vec_mutex;
std::vector<IndexedStorage> all_results;
int n_threads = std::min(num_threads_to_use, (int)tensor_storages.size());
if (n_threads < 1) {
n_threads = 1;
}
std::vector<std::thread> workers;
for (int i = 0; i < n_threads; ++i) {
workers.emplace_back([&, thread_id = i]() {
std::vector<IndexedStorage> local_results;
std::vector<TensorStorage> temp_storages;
for (size_t j = thread_id; j < tensor_storages.size(); j += n_threads) {
const auto& tensor_storage = tensor_storages[j];
if (is_unused_tensor(tensor_storage.name)) {
continue;
}
temp_storages.clear();
preprocess_tensor(tensor_storage, temp_storages);
for (const auto& ts : temp_storages) {
local_results.push_back({j, ts});
}
}
if (!local_results.empty()) {
std::lock_guard<std::mutex> lock(vec_mutex);
all_results.insert(all_results.end(),
local_results.begin(), local_results.end());
}
});
}
for (auto& w : workers) {
w.join();
}
std::vector<IndexedStorage> deduplicated;
deduplicated.reserve(all_results.size());
std::unordered_map<std::string, size_t> name_to_pos;
for (auto& entry : all_results) {
auto it = name_to_pos.find(entry.ts.name);
if (it == name_to_pos.end()) {
name_to_pos.emplace(entry.ts.name, deduplicated.size());
deduplicated.push_back(entry);
} else if (deduplicated[it->second].index < entry.index) {
deduplicated[it->second] = entry;
}
}
std::sort(deduplicated.begin(), deduplicated.end(), [](const IndexedStorage& a, const IndexedStorage& b) {
return a.index < b.index;
});
processed_tensor_storages.reserve(deduplicated.size());
for (auto& entry : deduplicated) {
processed_tensor_storages.push_back(entry.ts);
for (auto& [name, tensor_storage] : tensor_storage_map) {
if (is_unused_tensor(tensor_storage.name)) {
continue;
}
processed_tensor_storages.push_back(tensor_storage);
}
process_time_ms = ggml_time_ms() - start_time;
@ -2231,107 +2043,72 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
}
};
char* read_buf = nullptr;
char* target_buf = nullptr;
char* convert_buf = nullptr;
if (dst_tensor->buffer == nullptr || ggml_backend_buffer_is_host(dst_tensor->buffer)) {
if (tensor_storage.type == dst_tensor->type) {
GGML_ASSERT(ggml_nbytes(dst_tensor) == tensor_storage.nbytes());
if (tensor_storage.is_f64 || tensor_storage.is_i64) {
read_buffer.resize(tensor_storage.nbytes_to_read());
read_data((char*)read_buffer.data(), nbytes_to_read);
read_buf = (char*)read_buffer.data();
} else {
read_data((char*)dst_tensor->data, nbytes_to_read);
read_buf = (char*)dst_tensor->data;
}
t1 = ggml_time_ms();
read_time_ms.fetch_add(t1 - t0);
t0 = ggml_time_ms();
if (tensor_storage.is_bf16) {
// inplace op
bf16_to_f32_vec((uint16_t*)dst_tensor->data, (float*)dst_tensor->data, tensor_storage.nelements());
} else if (tensor_storage.is_f8_e4m3) {
// inplace op
f8_e4m3_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements());
} else if (tensor_storage.is_f8_e5m2) {
// inplace op
f8_e5m2_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements());
} else if (tensor_storage.is_f64) {
f64_to_f32_vec((double*)read_buffer.data(), (float*)dst_tensor->data, tensor_storage.nelements());
} else if (tensor_storage.is_i64) {
i64_to_i32_vec((int64_t*)read_buffer.data(), (int32_t*)dst_tensor->data, tensor_storage.nelements());
}
t1 = ggml_time_ms();
convert_time_ms.fetch_add(t1 - t0);
target_buf = (char*)dst_tensor->data;
} else {
read_buffer.resize(std::max(tensor_storage.nbytes(), tensor_storage.nbytes_to_read()));
read_data((char*)read_buffer.data(), nbytes_to_read);
t1 = ggml_time_ms();
read_time_ms.fetch_add(t1 - t0);
t0 = ggml_time_ms();
if (tensor_storage.is_bf16) {
// inplace op
bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
} else if (tensor_storage.is_f8_e4m3) {
// inplace op
f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
} else if (tensor_storage.is_f8_e5m2) {
// inplace op
f8_e5m2_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
} else if (tensor_storage.is_f64) {
// inplace op
f64_to_f32_vec((double*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
} else if (tensor_storage.is_i64) {
// inplace op
i64_to_i32_vec((int64_t*)read_buffer.data(), (int32_t*)read_buffer.data(), tensor_storage.nelements());
}
convert_tensor((void*)read_buffer.data(), tensor_storage.type, dst_tensor->data, dst_tensor->type, (int)tensor_storage.nelements() / (int)tensor_storage.ne[0], (int)tensor_storage.ne[0]);
t1 = ggml_time_ms();
convert_time_ms.fetch_add(t1 - t0);
read_buf = (char*)read_buffer.data();
target_buf = read_buf;
convert_buf = (char*)dst_tensor->data;
}
} else {
read_buffer.resize(std::max(tensor_storage.nbytes(), tensor_storage.nbytes_to_read()));
read_data((char*)read_buffer.data(), nbytes_to_read);
t1 = ggml_time_ms();
read_time_ms.fetch_add(t1 - t0);
t0 = ggml_time_ms();
if (tensor_storage.is_bf16) {
// inplace op
bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
} else if (tensor_storage.is_f8_e4m3) {
// inplace op
f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
} else if (tensor_storage.is_f8_e5m2) {
// inplace op
f8_e5m2_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
} else if (tensor_storage.is_f64) {
// inplace op
f64_to_f32_vec((double*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
} else if (tensor_storage.is_i64) {
// inplace op
i64_to_i32_vec((int64_t*)read_buffer.data(), (int32_t*)read_buffer.data(), tensor_storage.nelements());
}
if (tensor_storage.type == dst_tensor->type) {
// copy to device memory
t1 = ggml_time_ms();
convert_time_ms.fetch_add(t1 - t0);
t0 = ggml_time_ms();
ggml_backend_tensor_set(dst_tensor, read_buffer.data(), 0, ggml_nbytes(dst_tensor));
t1 = ggml_time_ms();
copy_to_backend_time_ms.fetch_add(t1 - t0);
} else {
// convert first, then copy to device memory
read_buf = (char*)read_buffer.data();
target_buf = read_buf;
if (tensor_storage.type != dst_tensor->type) {
convert_buffer.resize(ggml_nbytes(dst_tensor));
convert_tensor((void*)read_buffer.data(), tensor_storage.type, (void*)convert_buffer.data(), dst_tensor->type, (int)tensor_storage.nelements() / (int)tensor_storage.ne[0], (int)tensor_storage.ne[0]);
t1 = ggml_time_ms();
convert_time_ms.fetch_add(t1 - t0);
t0 = ggml_time_ms();
ggml_backend_tensor_set(dst_tensor, convert_buffer.data(), 0, ggml_nbytes(dst_tensor));
t1 = ggml_time_ms();
copy_to_backend_time_ms.fetch_add(t1 - t0);
convert_buf = (char*)convert_buffer.data();
}
}
t0 = ggml_time_ms();
read_data(read_buf, nbytes_to_read);
t1 = ggml_time_ms();
read_time_ms.fetch_add(t1 - t0);
t0 = ggml_time_ms();
if (tensor_storage.is_bf16) {
bf16_to_f32_vec((uint16_t*)read_buf, (float*)target_buf, tensor_storage.nelements());
} else if (tensor_storage.is_f8_e4m3) {
f8_e4m3_to_f16_vec((uint8_t*)read_buf, (uint16_t*)target_buf, tensor_storage.nelements());
} else if (tensor_storage.is_f8_e5m2) {
f8_e5m2_to_f16_vec((uint8_t*)read_buf, (uint16_t*)target_buf, tensor_storage.nelements());
} else if (tensor_storage.is_f64) {
f64_to_f32_vec((double*)read_buf, (float*)target_buf, tensor_storage.nelements());
} else if (tensor_storage.is_i64) {
i64_to_i32_vec((int64_t*)read_buf, (int32_t*)target_buf, tensor_storage.nelements());
}
if (tensor_storage.type != dst_tensor->type) {
convert_tensor((void*)target_buf,
tensor_storage.type,
convert_buf,
dst_tensor->type,
(int)tensor_storage.nelements() / (int)tensor_storage.ne[0],
(int)tensor_storage.ne[0]);
} else {
convert_buf = read_buf;
}
t1 = ggml_time_ms();
convert_time_ms.fetch_add(t1 - t0);
if (dst_tensor->buffer != nullptr && !ggml_backend_buffer_is_host(dst_tensor->buffer)) {
t0 = ggml_time_ms();
ggml_backend_tensor_set(dst_tensor, convert_buf, 0, ggml_nbytes(dst_tensor));
t1 = ggml_time_ms();
copy_to_backend_time_ms.fetch_add(t1 - t0);
}
}
if (zip != nullptr) {
zip_close(zip);
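Whatever the on-disk encoding, the branch above follows one pipeline: read the raw bytes, widen any non-native encoding (bf16, f8_e4m3, f8_e5m2, f64, i64) in place, convert to the destination tensor's ggml type if it differs, then copy into the backend buffer with ggml_backend_tensor_set. A minimal standalone sketch of that flow for the bf16 case, using a hypothetical helper name and only public ggml calls rather than the loader's read_buffer/convert_tensor machinery shown above:

    #include "ggml.h"
    #include "ggml-backend.h"
    #include <cstdint>
    #include <cstring>
    #include <vector>

    // Hypothetical helper: load a bf16-encoded tensor into dst (F32 or F16 here).
    static void load_bf16_into(const uint16_t* src, struct ggml_tensor* dst) {
        const int64_t n = ggml_nelements(dst);
        std::vector<float> f32(n);
        for (int64_t i = 0; i < n; i++) {
            uint32_t bits = (uint32_t)src[i] << 16;  // bf16 is the top half of an f32
            std::memcpy(&f32[i], &bits, sizeof(float));
        }
        if (dst->type == GGML_TYPE_F32) {
            ggml_backend_tensor_set(dst, f32.data(), 0, ggml_nbytes(dst));
        } else if (dst->type == GGML_TYPE_F16) {
            std::vector<ggml_fp16_t> f16(n);
            ggml_fp32_to_fp16_row(f32.data(), f16.data(), n);
            ggml_backend_tensor_set(dst, f16.data(), 0, ggml_nbytes(dst));
        }
        // quantized destinations go through a quantize step (ggml_quantize_chunk) instead
    }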
@ -2520,7 +2297,7 @@ bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage
bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules_str) {
auto backend = ggml_backend_cpu_init();
size_t mem_size = 1 * 1024 * 1024; // for padding
mem_size += tensor_storages.size() * ggml_tensor_overhead();
mem_size += tensor_storage_map.size() * ggml_tensor_overhead();
mem_size += get_params_mem_size(backend, type);
LOG_INFO("model tensors mem size: %.2fMB", mem_size / 1024.f / 1024.f);
ggml_context* ggml_ctx = ggml_init({mem_size, nullptr, false});
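The context created here appears to hold both the tensor metadata and the (possibly converted) parameter data, so the estimate is just per-tensor overhead plus payload plus a fixed 1 MiB of slack. The same sizing rule as a standalone sketch:

    #include "ggml.h"

    // Sketch of the sizing rule above: metadata overhead per tensor plus the
    // payload bytes, plus 1 MiB of slack for padding/alignment.
    static size_t estimate_ctx_mem_size(size_t n_tensors, size_t params_bytes) {
        size_t mem_size = 1 * 1024 * 1024;               // padding
        mem_size += n_tensors * ggml_tensor_overhead();  // per-tensor metadata
        mem_size += params_bytes;                        // tensor data
        return mem_size;
    }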
@ -2587,14 +2364,10 @@ int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type)
}
int64_t mem_size = 0;
std::vector<TensorStorage> processed_tensor_storages;
for (auto& tensor_storage : tensor_storages) {
for (auto [name, tensor_storage] : tensor_storage_map) {
if (is_unused_tensor(tensor_storage.name)) {
continue;
}
preprocess_tensor(tensor_storage, processed_tensor_storages);
}
for (auto& tensor_storage : processed_tensor_storages) {
if (tensor_should_be_converted(tensor_storage, type)) {
tensor_storage.type = type;
}

model.h

@ -65,6 +65,15 @@ static inline bool sd_version_is_sdxl(SDVersion version) {
return false;
}
static inline bool sd_version_is_unet(SDVersion version) {
if (sd_version_is_sd1(version) ||
sd_version_is_sd2(version) ||
sd_version_is_sdxl(version)) {
return true;
}
return false;
}
static inline bool sd_version_is_sd3(SDVersion version) {
if (version == VERSION_SD3) {
return true;
@ -134,6 +143,7 @@ enum PMVersion {
struct TensorStorage {
std::string name;
ggml_type type = GGML_TYPE_F32;
ggml_type expected_type = GGML_TYPE_COUNT;
bool is_bf16 = false;
bool is_f8_e4m3 = false;
bool is_f8_e5m2 = false;
@ -242,12 +252,14 @@ struct TensorStorage {
typedef std::function<bool(const TensorStorage&, ggml_tensor**)> on_new_tensor_cb_t;
typedef std::map<std::string, enum ggml_type> String2GGMLType;
typedef std::map<std::string, TensorStorage> String2TensorStorage;
class ModelLoader {
protected:
std::vector<std::string> file_paths_;
std::vector<TensorStorage> tensor_storages;
String2TensorStorage tensor_storage_map;
void add_tensor_storage(const TensorStorage& tensor_storage);
bool parse_data_pkl(uint8_t* buffer,
size_t buffer_size,
@ -262,15 +274,13 @@ protected:
bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = "");
public:
String2GGMLType tensor_storages_types;
bool init_from_file(const std::string& file_path, const std::string& prefix = "");
bool model_is_unet();
SDVersion get_sd_version();
std::map<ggml_type, uint32_t> get_wtype_stat();
std::map<ggml_type, uint32_t> get_conditioner_wtype_stat();
std::map<ggml_type, uint32_t> get_diffusion_model_wtype_stat();
std::map<ggml_type, uint32_t> get_vae_wtype_stat();
String2TensorStorage& get_tensor_storage_map() { return tensor_storage_map; }
void set_wtype_override(ggml_type wtype, std::string prefix = "");
bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0);
bool load_tensors(std::map<std::string, struct ggml_tensor*>& tensors,
@ -279,8 +289,8 @@ public:
std::vector<std::string> get_tensor_names() const {
std::vector<std::string> names;
for (const auto& ts : tensor_storages) {
names.push_back(ts.name);
for (const auto& [name, tensor_storage] : tensor_storage_map) {
names.push_back(name);
}
return names;
}
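model.h now exposes the full per-tensor metadata instead of a flat name-to-type map: String2TensorStorage maps each tensor name to its TensorStorage, and the new expected_type field lets a caller record the type a weight should be converted to on load, while type keeps what is actually stored in the file. The loops later in this commit (stable-diffusion.cpp, t5.hpp, qwenvl.hpp, qwen_image.hpp, wan.hpp) all follow the same pattern; a sketch of the intended call sequence, with a hypothetical file path and the commented-out runner standing in for any of the GGMLRunner subclasses below:

    #include "model.h"  // ModelLoader, String2TensorStorage, TensorStorage

    void example_usage() {
        ModelLoader model_loader;
        if (!model_loader.init_from_file("sd-v1-5.safetensors")) {  // hypothetical path
            return;
        }
        auto& tensor_storage_map = model_loader.get_tensor_storage_map();
        for (auto& [name, tensor_storage] : tensor_storage_map) {
            if (ends_with(name, "weight")) {                   // project string helper
                tensor_storage.expected_type = GGML_TYPE_F16;  // request f16 weights on load
            }
        }
        // UNetModelRunner unet(backend, false, tensor_storage_map, "model.diffusion_model");
    }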

pmid.hpp

@ -412,7 +412,7 @@ public:
public:
PhotoMakerIDEncoder(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types,
const String2TensorStorage& tensor_storage_map,
const std::string prefix,
SDVersion version = VERSION_SDXL,
PMVersion pm_v = PM_VERSION_1,
@ -422,9 +422,9 @@ public:
pm_version(pm_v),
style_strength(sty) {
if (pm_version == PM_VERSION_1) {
id_encoder.init(params_ctx, tensor_types, prefix);
id_encoder.init(params_ctx, tensor_storage_map, prefix);
} else if (pm_version == PM_VERSION_2) {
id_encoder2.init(params_ctx, tensor_types, prefix);
id_encoder2.init(params_ctx, tensor_storage_map, prefix);
}
}

qwen_image.hpp

@ -502,12 +502,12 @@ namespace Qwen {
QwenImageRunner(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
const std::string prefix = "",
SDVersion version = VERSION_QWEN_IMAGE)
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "",
SDVersion version = VERSION_QWEN_IMAGE)
: GGMLRunner(backend, offload_params_to_cpu) {
qwen_image_params.num_layers = 0;
for (auto pair : tensor_types) {
for (auto pair : tensor_storage_map) {
std::string tensor_name = pair.first;
if (tensor_name.find(prefix) == std::string::npos)
continue;
@ -526,7 +526,7 @@ namespace Qwen {
}
LOG_INFO("qwen_image_params.num_layers: %ld", qwen_image_params.num_layers);
qwen_image = QwenImageModel(qwen_image_params);
qwen_image.init(params_ctx, tensor_types, prefix);
qwen_image.init(params_ctx, tensor_storage_map, prefix);
}
std::string get_desc() override {
@ -649,17 +649,16 @@ namespace Qwen {
return;
}
auto tensor_types = model_loader.tensor_storages_types;
for (auto& item : tensor_types) {
// LOG_DEBUG("%s %u", item.first.c_str(), item.second);
if (ends_with(item.first, "weight")) {
item.second = model_data_type;
auto& tensor_storage_map = model_loader.get_tensor_storage_map();
for (auto& [name, tensor_storage] : tensor_storage_map) {
if (ends_with(name, "weight")) {
tensor_storage.expected_type = model_data_type;
}
}
std::shared_ptr<QwenImageRunner> qwen_image = std::make_shared<QwenImageRunner>(backend,
false,
tensor_types,
tensor_storage_map,
"model.diffusion_model",
VERSION_QWEN_IMAGE);

qwenvl.hpp

@ -910,13 +910,13 @@ namespace Qwen {
Qwen2_5_VLRunner(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types,
const String2TensorStorage& tensor_storage_map,
const std::string prefix,
bool enable_vision_ = false)
: GGMLRunner(backend, offload_params_to_cpu), enable_vision(enable_vision_) {
bool have_vision_weight = false;
bool llama_cpp_style = false;
for (auto pair : tensor_types) {
for (auto pair : tensor_storage_map) {
std::string tensor_name = pair.first;
if (tensor_name.find(prefix) == std::string::npos)
continue;
@ -940,7 +940,7 @@ namespace Qwen {
}
}
model = Qwen2_5_VL(params, enable_vision, llama_cpp_style);
model.init(params_ctx, tensor_types, prefix);
model.init(params_ctx, tensor_storage_map, prefix);
}
std::string get_desc() override {
@ -1188,10 +1188,10 @@ namespace Qwen {
Qwen2_5_VLEmbedder(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
const std::string prefix = "",
bool enable_vision = false)
: model(backend, offload_params_to_cpu, tensor_types, prefix, enable_vision) {
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "",
bool enable_vision = false)
: model(backend, offload_params_to_cpu, tensor_storage_map, prefix, enable_vision) {
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
@ -1347,17 +1347,16 @@ namespace Qwen {
return;
}
auto tensor_types = model_loader.tensor_storages_types;
for (auto& item : tensor_types) {
// LOG_DEBUG("%s %u", item.first.c_str(), item.second);
if (ends_with(item.first, "weight")) {
item.second = model_data_type;
auto& tensor_storage_map = model_loader.get_tensor_storage_map();
for (auto& [name, tensor_storage] : tensor_storage_map) {
if (ends_with(name, "weight")) {
tensor_storage.expected_type = model_data_type;
}
}
std::shared_ptr<Qwen2_5_VLEmbedder> qwenvl = std::make_shared<Qwen2_5_VLEmbedder>(backend,
false,
tensor_types,
tensor_storage_map,
"qwen2vl",
true);

stable-diffusion.cpp

@ -213,7 +213,7 @@ public:
}
}
bool is_unet = model_loader.model_is_unet();
bool is_unet = sd_version_is_unet(model_loader.get_sd_version());
if (strlen(SAFE_STR(sd_ctx_params->clip_l_path)) > 0) {
LOG_INFO("loading clip_l from '%s'", sd_ctx_params->clip_l_path);
@ -273,12 +273,12 @@ public:
return false;
}
auto& tensor_types = model_loader.tensor_storages_types;
for (auto& item : tensor_types) {
// LOG_DEBUG("%s %u", item.first.c_str(), item.second);
if (contains(item.first, "qwen2vl") && ends_with(item.first, "weight") && (item.second == GGML_TYPE_F32 || item.second == GGML_TYPE_BF16)) {
item.second = GGML_TYPE_F16;
// LOG_DEBUG(" change %s %u", item.first.c_str(), item.second);
auto& tensor_storage_map = model_loader.get_tensor_storage_map();
for (auto& [name, tensor_storage] : tensor_storage_map) {
if (contains(name, "qwen2vl") &&
ends_with(name, "weight") &&
(tensor_storage.type == GGML_TYPE_F32 || tensor_storage.type == GGML_TYPE_BF16)) {
tensor_storage.expected_type = GGML_TYPE_F16;
}
}
@ -344,13 +344,13 @@ public:
if (sd_version_is_sd3(version)) {
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend,
offload_params_to_cpu,
model_loader.tensor_storages_types);
tensor_storage_map);
diffusion_model = std::make_shared<MMDiTModel>(backend,
offload_params_to_cpu,
model_loader.tensor_storages_types);
tensor_storage_map);
} else if (sd_version_is_flux(version)) {
bool is_chroma = false;
for (auto pair : model_loader.tensor_storages_types) {
for (auto pair : tensor_storage_map) {
if (pair.first.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
is_chroma = true;
break;
@ -368,42 +368,42 @@ public:
cond_stage_model = std::make_shared<T5CLIPEmbedder>(clip_backend,
offload_params_to_cpu,
model_loader.tensor_storages_types,
tensor_storage_map,
sd_ctx_params->chroma_use_t5_mask,
sd_ctx_params->chroma_t5_mask_pad);
} else {
cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend,
offload_params_to_cpu,
model_loader.tensor_storages_types);
tensor_storage_map);
}
diffusion_model = std::make_shared<FluxModel>(backend,
offload_params_to_cpu,
model_loader.tensor_storages_types,
tensor_storage_map,
version,
sd_ctx_params->chroma_use_dit_mask);
} else if (sd_version_is_wan(version)) {
cond_stage_model = std::make_shared<T5CLIPEmbedder>(clip_backend,
offload_params_to_cpu,
model_loader.tensor_storages_types,
tensor_storage_map,
true,
1,
true);
diffusion_model = std::make_shared<WanModel>(backend,
offload_params_to_cpu,
model_loader.tensor_storages_types,
tensor_storage_map,
"model.diffusion_model",
version);
if (strlen(SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path)) > 0) {
high_noise_diffusion_model = std::make_shared<WanModel>(backend,
offload_params_to_cpu,
model_loader.tensor_storages_types,
tensor_storage_map,
"model.high_noise_diffusion_model",
version);
}
if (diffusion_model->get_desc() == "Wan2.1-I2V-14B" || diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") {
clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend,
offload_params_to_cpu,
model_loader.tensor_storages_types);
tensor_storage_map);
clip_vision->alloc_params_buffer();
clip_vision->get_param_tensors(tensors);
}
@ -414,32 +414,32 @@ public:
}
cond_stage_model = std::make_shared<Qwen2_5_VLCLIPEmbedder>(clip_backend,
offload_params_to_cpu,
model_loader.tensor_storages_types,
tensor_storage_map,
"",
enable_vision);
diffusion_model = std::make_shared<QwenImageModel>(backend,
offload_params_to_cpu,
model_loader.tensor_storages_types,
tensor_storage_map,
"model.diffusion_model",
version);
} else { // SD1.x SD2.x SDXL
if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) {
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
offload_params_to_cpu,
model_loader.tensor_storages_types,
tensor_storage_map,
SAFE_STR(sd_ctx_params->embedding_dir),
version,
PM_VERSION_2);
} else {
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
offload_params_to_cpu,
model_loader.tensor_storages_types,
tensor_storage_map,
SAFE_STR(sd_ctx_params->embedding_dir),
version);
}
diffusion_model = std::make_shared<UNetModel>(backend,
offload_params_to_cpu,
model_loader.tensor_storages_types,
tensor_storage_map,
version);
if (sd_ctx_params->diffusion_conv_direct) {
LOG_INFO("Using Conv2d direct in the diffusion model");
@ -477,7 +477,7 @@ public:
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
offload_params_to_cpu,
model_loader.tensor_storages_types,
tensor_storage_map,
"first_stage_model",
vae_decode_only,
version);
@ -489,7 +489,7 @@ public:
} else if (!use_tiny_autoencoder) {
first_stage_model = std::make_shared<AutoEncoderKL>(vae_backend,
offload_params_to_cpu,
model_loader.tensor_storages_types,
tensor_storage_map,
"first_stage_model",
vae_decode_only,
false,
@ -512,7 +512,7 @@ public:
} else {
tae_first_stage = std::make_shared<TinyAutoEncoder>(vae_backend,
offload_params_to_cpu,
model_loader.tensor_storages_types,
tensor_storage_map,
"decoder.layers",
vae_decode_only,
version);
@ -533,7 +533,7 @@ public:
}
control_net = std::make_shared<ControlNet>(controlnet_backend,
offload_params_to_cpu,
model_loader.tensor_storages_types,
tensor_storage_map,
version);
if (sd_ctx_params->diffusion_conv_direct) {
LOG_INFO("Using Conv2d direct in the control net");
@ -544,7 +544,7 @@ public:
if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) {
pmid_model = std::make_shared<PhotoMakerIDEncoder>(backend,
offload_params_to_cpu,
model_loader.tensor_storages_types,
tensor_storage_map,
"pmid",
version,
PM_VERSION_2);
@ -552,7 +552,7 @@ public:
} else {
pmid_model = std::make_shared<PhotoMakerIDEncoder>(backend,
offload_params_to_cpu,
model_loader.tensor_storages_types,
tensor_storage_map,
"pmid",
version);
}
@ -733,12 +733,12 @@ public:
is_using_v_parameterization = true;
}
} else if (sd_version_is_sdxl(version)) {
if (model_loader.tensor_storages_types.find("edm_vpred.sigma_max") != model_loader.tensor_storages_types.end()) {
if (tensor_storage_map.find("edm_vpred.sigma_max") != tensor_storage_map.end()) {
// CosXL models
// TODO: get sigma_min and sigma_max values from file
is_using_edm_v_parameterization = true;
}
if (model_loader.tensor_storages_types.find("v_pred") != model_loader.tensor_storages_types.end()) {
if (tensor_storage_map.find("v_pred") != tensor_storage_map.end()) {
is_using_v_parameterization = true;
}
} else if (version == VERSION_SVD) {
@ -758,10 +758,9 @@ public:
float shift = sd_ctx_params->flow_shift;
if (shift == INFINITY) {
shift = 1.0f; // TODO: validate
for (auto pair : model_loader.tensor_storages_types) {
if (pair.first.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) {
for (const auto& [name, tensor_storage] : tensor_storage_map) {
if (starts_with(name, "model.diffusion_model.guidance_in.in_layer.weight")) {
shift = 1.15f;
break;
}
}
}

t5.hpp

@ -461,7 +461,7 @@ protected:
int64_t hidden_size;
float eps;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
enum ggml_type wtype = GGML_TYPE_F32;
params["weight"] = ggml_new_tensor_1d(ctx, wtype, hidden_size);
}
@ -759,7 +759,7 @@ struct T5Runner : public GGMLRunner {
T5Runner(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types,
const String2TensorStorage& tensor_storage_map,
const std::string prefix,
bool is_umt5 = false)
: GGMLRunner(backend, offload_params_to_cpu) {
@ -768,7 +768,7 @@ struct T5Runner : public GGMLRunner {
params.relative_attention = false;
}
model = T5(params);
model.init(params_ctx, tensor_types, prefix);
model.init(params_ctx, tensor_storage_map, prefix);
}
std::string get_desc() override {
@ -905,10 +905,10 @@ struct T5Embedder {
T5Embedder(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
const std::string prefix = "",
bool is_umt5 = false)
: model(backend, offload_params_to_cpu, tensor_types, prefix, is_umt5), tokenizer(is_umt5) {
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "",
bool is_umt5 = false)
: model(backend, offload_params_to_cpu, tensor_storage_map, prefix, is_umt5), tokenizer(is_umt5) {
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
@ -1009,15 +1009,14 @@ struct T5Embedder {
return;
}
auto tensor_types = model_loader.tensor_storages_types;
for (auto& item : tensor_types) {
// LOG_DEBUG("%s %u", item.first.c_str(), item.second);
if (ends_with(item.first, "weight")) {
item.second = model_data_type;
auto& tensor_storage_map = model_loader.get_tensor_storage_map();
for (auto& [name, tensor_storage] : tensor_storage_map) {
if (ends_with(name, "weight")) {
tensor_storage.expected_type = model_data_type;
}
}
std::shared_ptr<T5Embedder> t5 = std::make_shared<T5Embedder>(backend, false, tensor_types, "", true);
std::shared_ptr<T5Embedder> t5 = std::make_shared<T5Embedder>(backend, false, tensor_storage_map, "", true);
t5->alloc_params_buffer();
std::map<std::string, ggml_tensor*> tensors;

tae.hpp

@ -197,14 +197,14 @@ struct TinyAutoEncoder : public GGMLRunner {
TinyAutoEncoder(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types,
const String2TensorStorage& tensor_storage_map,
const std::string prefix,
bool decoder_only = true,
SDVersion version = VERSION_SD1)
: decode_only(decoder_only),
taesd(decoder_only, version),
GGMLRunner(backend, offload_params_to_cpu) {
taesd.init(params_ctx, tensor_types, prefix);
taesd.init(params_ctx, tensor_storage_map, prefix);
}
std::string get_desc() override {

unet.hpp

@ -20,9 +20,10 @@ public:
int64_t d_head,
int64_t depth,
int64_t context_dim,
bool use_linear,
int64_t time_depth = 1,
int64_t max_time_embed_period = 10000)
: SpatialTransformer(in_channels, n_head, d_head, depth, context_dim),
: SpatialTransformer(in_channels, n_head, d_head, depth, context_dim, use_linear),
max_time_embed_period(max_time_embed_period) {
// We will convert unet transformer linear to conv2d 1x1 when loading the weights, so use_linear is always False
// use_spatial_context is always True
@ -178,17 +179,19 @@ protected:
int num_heads = 8;
int num_head_channels = -1; // channels // num_heads
int context_dim = 768; // 1024 for VERSION_SD2, 2048 for VERSION_SDXL
bool use_linear_projection = false;
public:
int model_channels = 320;
int adm_in_channels = 2816; // only for VERSION_SDXL/SVD
UnetModelBlock(SDVersion version = VERSION_SD1, const String2GGMLType& tensor_types = {})
UnetModelBlock(SDVersion version = VERSION_SD1, const String2TensorStorage& tensor_storage_map = {})
: version(version) {
if (sd_version_is_sd2(version)) {
context_dim = 1024;
num_head_channels = 64;
num_heads = -1;
context_dim = 1024;
num_head_channels = 64;
num_heads = -1;
use_linear_projection = true;
} else if (sd_version_is_sdxl(version)) {
context_dim = 2048;
attention_resolutions = {4, 2};
@ -196,13 +199,15 @@ public:
transformer_depth = {1, 2, 10};
num_head_channels = 64;
num_heads = -1;
use_linear_projection = true;
} else if (version == VERSION_SVD) {
in_channels = 8;
out_channels = 4;
context_dim = 1024;
adm_in_channels = 768;
num_head_channels = 64;
num_heads = -1;
in_channels = 8;
out_channels = 4;
context_dim = 1024;
adm_in_channels = 768;
num_head_channels = 64;
num_heads = -1;
use_linear_projection = true;
} else if (version == VERSION_SD1_TINY_UNET) {
num_res_blocks = 1;
channel_mult = {1, 2, 4};
@ -249,9 +254,9 @@ public:
int64_t depth,
int64_t context_dim) -> SpatialTransformer* {
if (version == VERSION_SVD) {
return new SpatialVideoTransformer(in_channels, n_head, d_head, depth, context_dim);
return new SpatialVideoTransformer(in_channels, n_head, d_head, depth, context_dim, use_linear_projection);
} else {
return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim);
return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim, use_linear_projection);
}
};
@ -581,11 +586,11 @@ struct UNetModelRunner : public GGMLRunner {
UNetModelRunner(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types,
const String2TensorStorage& tensor_storage_map,
const std::string prefix,
SDVersion version = VERSION_SD1)
: GGMLRunner(backend, offload_params_to_cpu), unet(version, tensor_types) {
unet.init(params_ctx, tensor_types, prefix);
: GGMLRunner(backend, offload_params_to_cpu), unet(version, tensor_storage_map) {
unet.init(params_ctx, tensor_storage_map, prefix);
}
std::string get_desc() override {

upscaler.cpp

@ -51,7 +51,7 @@ struct UpscalerGGML {
backend = ggml_backend_cpu_init();
}
LOG_INFO("Upscaler weight type: %s", ggml_type_name(model_data_type));
esrgan_upscaler = std::make_shared<ESRGAN>(backend, offload_params_to_cpu, model_loader.tensor_storages_types);
esrgan_upscaler = std::make_shared<ESRGAN>(backend, offload_params_to_cpu, model_loader.get_tensor_storage_map());
if (direct) {
esrgan_upscaler->set_conv2d_direct_enabled(true);
}

vae.hpp

@ -64,25 +64,32 @@ public:
class AttnBlock : public UnaryBlock {
protected:
int64_t in_channels;
bool use_linear;
public:
AttnBlock(int64_t in_channels)
: in_channels(in_channels) {
AttnBlock(int64_t in_channels, bool use_linear)
: in_channels(in_channels), use_linear(use_linear) {
blocks["norm"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels));
blocks["q"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
blocks["k"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
blocks["v"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
if (use_linear) {
blocks["q"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
blocks["k"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
blocks["v"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
} else {
blocks["q"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
blocks["k"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
blocks["v"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
}
}
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
// x: [N, in_channels, h, w]
auto norm = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm"]);
auto q_proj = std::dynamic_pointer_cast<Conv2d>(blocks["q"]);
auto k_proj = std::dynamic_pointer_cast<Conv2d>(blocks["k"]);
auto v_proj = std::dynamic_pointer_cast<Conv2d>(blocks["v"]);
auto proj_out = std::dynamic_pointer_cast<Conv2d>(blocks["proj_out"]);
auto q_proj = std::dynamic_pointer_cast<UnaryBlock>(blocks["q"]);
auto k_proj = std::dynamic_pointer_cast<UnaryBlock>(blocks["k"]);
auto v_proj = std::dynamic_pointer_cast<UnaryBlock>(blocks["v"]);
auto proj_out = std::dynamic_pointer_cast<UnaryBlock>(blocks["proj_out"]);
auto h_ = norm->forward(ctx, x);
@ -91,23 +98,44 @@ public:
const int64_t h = h_->ne[1];
const int64_t w = h_->ne[0];
auto q = q_proj->forward(ctx, h_); // [N, in_channels, h, w]
q = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, q, 1, 2, 0, 3)); // [N, h, w, in_channels]
q = ggml_reshape_3d(ctx->ggml_ctx, q, c, h * w, n); // [N, h * w, in_channels]
ggml_tensor* q;
ggml_tensor* k;
ggml_tensor* v;
if (use_linear) {
h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 2, 0, 3)); // [N, h, w, in_channels]
h_ = ggml_reshape_3d(ctx->ggml_ctx, h_, c, h * w, n); // [N, h * w, in_channels]
auto k = k_proj->forward(ctx, h_); // [N, in_channels, h, w]
k = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, k, 1, 2, 0, 3)); // [N, h, w, in_channels]
k = ggml_reshape_3d(ctx->ggml_ctx, k, c, h * w, n); // [N, h * w, in_channels]
q = q_proj->forward(ctx, h_); // [N, h * w, in_channels]
k = k_proj->forward(ctx, h_); // [N, h * w, in_channels]
v = v_proj->forward(ctx, h_); // [N, h * w, in_channels]
auto v = v_proj->forward(ctx, h_); // [N, in_channels, h, w]
v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [N, in_channels, h * w]
v = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, v, 1, 0, 2, 3)); // [N, in_channels, h * w]
} else {
q = q_proj->forward(ctx, h_); // [N, in_channels, h, w]
q = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, q, 1, 2, 0, 3)); // [N, h, w, in_channels]
q = ggml_reshape_3d(ctx->ggml_ctx, q, c, h * w, n); // [N, h * w, in_channels]
k = k_proj->forward(ctx, h_); // [N, in_channels, h, w]
k = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, k, 1, 2, 0, 3)); // [N, h, w, in_channels]
k = ggml_reshape_3d(ctx->ggml_ctx, k, c, h * w, n); // [N, h * w, in_channels]
v = v_proj->forward(ctx, h_); // [N, in_channels, h, w]
v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [N, in_channels, h * w]
}
h_ = ggml_ext_attention(ctx->ggml_ctx, q, k, v, false); // [N, h * w, in_channels]
h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w]
h_ = ggml_reshape_4d(ctx->ggml_ctx, h_, w, h, c, n); // [N, in_channels, h, w]
if (use_linear) {
h_ = proj_out->forward(ctx, h_); // [N, h * w, in_channels]
h_ = proj_out->forward(ctx, h_); // [N, in_channels, h, w]
h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w]
h_ = ggml_reshape_4d(ctx->ggml_ctx, h_, w, h, c, n); // [N, in_channels, h, w]
} else {
h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w]
h_ = ggml_reshape_4d(ctx->ggml_ctx, h_, w, h, c, n); // [N, in_channels, h, w]
h_ = proj_out->forward(ctx, h_); // [N, in_channels, h, w]
}
h_ = ggml_add(ctx->ggml_ctx, h_, x);
return h_;
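The two branches differ only in layout: with use_linear the feature map is flattened to [N, h*w, in_channels] once, so q/k/v and proj_out act on token vectors, while the Conv2d path keeps [N, in_channels, h, w] and permutes around the attention call. Which layout a checkpoint expects can be read off the stored weights, which is what the AutoEncoderKL constructor further down does; the same check as a standalone helper (hypothetical name, assuming String2TensorStorage from model.h and the project's starts_with/ends_with string helpers):

    #include "model.h"  // String2TensorStorage, TensorStorage; starts_with/ends_with helpers

    // A Linear weight is 2-D ([out, in]); a 1x1 Conv2d weight is 4-D ([out, in, 1, 1]).
    // The values are interchangeable, so only the block type (and the surrounding
    // reshapes) has to match what the checkpoint stores.
    static bool vae_attn_uses_linear(const String2TensorStorage& tensor_storage_map,
                                     const std::string& prefix) {
        for (const auto& [name, tensor_storage] : tensor_storage_map) {
            if (!starts_with(name, prefix)) {
                continue;
            }
            if (ends_with(name, "attn_1.proj_out.weight")) {
                return tensor_storage.n_dims == 2;
            }
        }
        return false;  // default to the classic Conv2d 1x1 layout
    }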
@ -163,8 +191,8 @@ public:
class VideoResnetBlock : public ResnetBlock {
protected:
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
enum ggml_type wtype = get_type(prefix + "mix_factor", tensor_types, GGML_TYPE_F32);
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
enum ggml_type wtype = get_type(prefix + "mix_factor", tensor_storage_map, GGML_TYPE_F32);
params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1);
}
@ -233,7 +261,8 @@ public:
int num_res_blocks,
int in_channels,
int z_channels,
bool double_z = true)
bool double_z = true,
bool use_linear_projection = false)
: ch(ch),
ch_mult(ch_mult),
num_res_blocks(num_res_blocks),
@ -264,7 +293,7 @@ public:
}
blocks["mid.block_1"] = std::shared_ptr<GGMLBlock>(new ResnetBlock(block_in, block_in));
blocks["mid.attn_1"] = std::shared_ptr<GGMLBlock>(new AttnBlock(block_in));
blocks["mid.attn_1"] = std::shared_ptr<GGMLBlock>(new AttnBlock(block_in, use_linear_projection));
blocks["mid.block_2"] = std::shared_ptr<GGMLBlock>(new ResnetBlock(block_in, block_in));
blocks["norm_out"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(block_in));
@ -351,8 +380,9 @@ public:
std::vector<int> ch_mult,
int num_res_blocks,
int z_channels,
bool video_decoder = false,
int video_kernel_size = 3)
bool use_linear_projection = false,
bool video_decoder = false,
int video_kernel_size = 3)
: ch(ch),
out_ch(out_ch),
ch_mult(ch_mult),
@ -366,7 +396,7 @@ public:
blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, block_in, {3, 3}, {1, 1}, {1, 1}));
blocks["mid.block_1"] = get_resnet_block(block_in, block_in);
blocks["mid.attn_1"] = std::shared_ptr<GGMLBlock>(new AttnBlock(block_in));
blocks["mid.attn_1"] = std::shared_ptr<GGMLBlock>(new AttnBlock(block_in, use_linear_projection));
blocks["mid.block_2"] = get_resnet_block(block_in, block_in);
for (int i = num_resolutions - 1; i >= 0; i--) {
@ -454,9 +484,10 @@ protected:
} dd_config;
public:
AutoencodingEngine(bool decode_only = true,
bool use_video_decoder = false,
SDVersion version = VERSION_SD1)
AutoencodingEngine(SDVersion version = VERSION_SD1,
bool decode_only = true,
bool use_linear_projection = false,
bool use_video_decoder = false)
: decode_only(decode_only), use_video_decoder(use_video_decoder) {
if (sd_version_is_dit(version)) {
dd_config.z_channels = 16;
@ -470,6 +501,7 @@ public:
dd_config.ch_mult,
dd_config.num_res_blocks,
dd_config.z_channels,
use_linear_projection,
use_video_decoder));
if (use_quant) {
blocks["post_quant_conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(dd_config.z_channels,
@ -482,7 +514,8 @@ public:
dd_config.num_res_blocks,
dd_config.in_channels,
dd_config.z_channels,
dd_config.double_z));
dd_config.double_z,
use_linear_projection));
if (use_quant) {
int factor = dd_config.double_z ? 2 : 1;
@ -562,13 +595,26 @@ struct AutoEncoderKL : public VAE {
AutoEncoderKL(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types,
const String2TensorStorage& tensor_storage_map,
const std::string prefix,
bool decode_only = false,
bool use_video_decoder = false,
SDVersion version = VERSION_SD1)
: decode_only(decode_only), ae(decode_only, use_video_decoder, version), VAE(backend, offload_params_to_cpu) {
ae.init(params_ctx, tensor_types, prefix);
: decode_only(decode_only), VAE(backend, offload_params_to_cpu) {
bool use_linear_projection = false;
for (const auto& [name, tensor_storage] : tensor_storage_map) {
if (!starts_with(name, prefix)) {
continue;
}
if (ends_with(name, "attn_1.proj_out.weight")) {
if (tensor_storage.n_dims == 2) {
use_linear_projection = true;
}
break;
}
}
ae = AutoencodingEngine(version, decode_only, use_linear_projection, use_video_decoder);
ae.init(params_ctx, tensor_storage_map, prefix);
}
void set_conv2d_scale(float scale) override {

wan.hpp

@ -26,7 +26,7 @@ namespace WAN {
std::tuple<int, int, int> dilation;
bool bias;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
params["weight"] = ggml_new_tensor_4d(ctx,
GGML_TYPE_F16,
std::get<2>(kernel_size),
@ -87,9 +87,14 @@ namespace WAN {
protected:
int64_t dim;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
ggml_type wtype = GGML_TYPE_F32;
params["gamma"] = ggml_new_tensor_1d(ctx, wtype, dim);
auto iter = tensor_storage_map.find(prefix + "gamma");
if (iter != tensor_storage_map.end()) {
params["gamma"] = ggml_new_tensor(ctx, wtype, iter->second.n_dims, &iter->second.ne[0]);
} else {
params["gamma"] = ggml_new_tensor_1d(ctx, wtype, dim);
}
}
public:
@ -101,6 +106,7 @@ namespace WAN {
// assert N == 1
struct ggml_tensor* w = params["gamma"];
w = ggml_reshape_1d(ctx->ggml_ctx, w, ggml_nelements(w));
auto h = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 3, 0, 1, 2)); // [ID, IH, IW, N*IC]
h = ggml_rms_norm(ctx->ggml_ctx, h, 1e-12);
h = ggml_mul(ctx->ggml_ctx, h, w);
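Sizing gamma from the shape recorded in the checkpoint, instead of assuming a flat [dim] vector, is the other thing init_params gains from receiving the tensor storage map. The same lookup generalized into a hypothetical helper:

    #include "ggml.h"
    #include "model.h"  // String2TensorStorage, TensorStorage
    #include <string>

    // Create a parameter with the shape stored in the checkpoint when the tensor
    // is present, otherwise fall back to a default 1-D shape.
    static ggml_tensor* new_param_like(ggml_context* ctx,
                                       const String2TensorStorage& tensor_storage_map,
                                       const std::string& name,
                                       ggml_type wtype,
                                       int64_t fallback_ne0) {
        auto iter = tensor_storage_map.find(name);
        if (iter != tensor_storage_map.end()) {
            const TensorStorage& ts = iter->second;
            return ggml_new_tensor(ctx, wtype, ts.n_dims, &ts.ne[0]);
        }
        return ggml_new_tensor_1d(ctx, wtype, fallback_ne0);
    }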
@ -1110,12 +1116,12 @@ namespace WAN {
WanVAERunner(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
const std::string prefix = "",
bool decode_only = false,
SDVersion version = VERSION_WAN2)
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "",
bool decode_only = false,
SDVersion version = VERSION_WAN2)
: decode_only(decode_only), ae(decode_only, version == VERSION_WAN2_2_TI2V), VAE(backend, offload_params_to_cpu) {
ae.init(params_ctx, tensor_types, prefix);
ae.init(params_ctx, tensor_storage_map, prefix);
}
std::string get_desc() override {
@ -1256,7 +1262,7 @@ namespace WAN {
// ggml_backend_t backend = ggml_backend_cuda_init(0);
ggml_backend_t backend = ggml_backend_cpu_init();
ggml_type model_data_type = GGML_TYPE_F16;
std::shared_ptr<WanVAERunner> vae = std::make_shared<WanVAERunner>(backend, false, String2GGMLType{}, "", false, VERSION_WAN2_2_TI2V);
std::shared_ptr<WanVAERunner> vae = std::make_shared<WanVAERunner>(backend, false, String2TensorStorage{}, "", false, VERSION_WAN2_2_TI2V);
{
LOG_INFO("loading from '%s'", file_path.c_str());
@ -1494,8 +1500,8 @@ namespace WAN {
protected:
int dim;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32);
params["modulation"] = ggml_new_tensor_3d(ctx, wtype, dim, 6, 1);
}
@ -1582,8 +1588,8 @@ namespace WAN {
class VaceWanAttentionBlock : public WanAttentionBlock {
protected:
int block_id;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32);
params["modulation"] = ggml_new_tensor_3d(ctx, wtype, dim, 6, 1);
}
@ -1634,8 +1640,8 @@ namespace WAN {
protected:
int dim;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32);
params["modulation"] = ggml_new_tensor_3d(ctx, wtype, dim, 2, 1);
}
@ -1681,7 +1687,7 @@ namespace WAN {
int in_dim;
int flf_pos_embed_token_number;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
if (flf_pos_embed_token_number > 0) {
params["emb_pos"] = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, in_dim, flf_pos_embed_token_number, 1);
}
@ -2015,12 +2021,12 @@ namespace WAN {
WanRunner(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
const std::string prefix = "",
SDVersion version = VERSION_WAN2)
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "",
SDVersion version = VERSION_WAN2)
: GGMLRunner(backend, offload_params_to_cpu) {
wan_params.num_layers = 0;
for (auto pair : tensor_types) {
for (auto pair : tensor_storage_map) {
std::string tensor_name = pair.first;
if (tensor_name.find(prefix) == std::string::npos)
continue;
@ -2117,7 +2123,7 @@ namespace WAN {
LOG_INFO("%s", desc.c_str());
wan = Wan(wan_params);
wan.init(params_ctx, tensor_types, prefix);
wan.init(params_ctx, tensor_storage_map, prefix);
}
std::string get_desc() override {
@ -2254,17 +2260,16 @@ namespace WAN {
return;
}
auto tensor_types = model_loader.tensor_storages_types;
for (auto& item : tensor_types) {
// LOG_DEBUG("%s %u", item.first.c_str(), item.second);
if (ends_with(item.first, "weight")) {
item.second = model_data_type;
auto& tensor_storage_map = model_loader.get_tensor_storage_map();
for (auto& [name, tensor_storage] : tensor_storage_map) {
if (ends_with(name, "weight")) {
tensor_storage.expected_type = model_data_type;
}
}
std::shared_ptr<WanRunner> wan = std::make_shared<WanRunner>(backend,
false,
tensor_types,
tensor_storage_map,
"model.diffusion_model",
VERSION_WAN2_2_TI2V);