Mirror of https://github.com/leejet/stable-diffusion.cpp.git, synced 2025-12-12 13:28:37 +00:00.
refactor: simplify the model loading logic (#933)
* remove String2GGMLType
* remove preprocess_tensor
* fix clip init
* simplify the logic for reading weights
This commit is contained in:
parent: 6103d86e2c
commit: 8f6c5c217b
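Note: the hunks below replace the old String2GGMLType (a plain name -> ggml_type map) with the loader's String2TensorStorage map, whose entries also carry an optional expected_type override. A minimal, self-contained sketch of that lookup rule, using stand-in types (the real ggml_type enum and TensorStorage struct live in ggml and model.h; the "sketch_" names are mine):

#include <map>
#include <string>

// Stand-ins for ggml_type and TensorStorage, for illustration only.
enum sketch_type { SKETCH_F32, SKETCH_F16, SKETCH_Q8_0, SKETCH_COUNT };

struct sketch_tensor_storage {
    sketch_type type          = SKETCH_F32;   // type as stored in the checkpoint
    sketch_type expected_type = SKETCH_COUNT; // optional override (e.g. set from a type flag)
};

using sketch_storage_map = std::map<std::string, sketch_tensor_storage>;

// Mirrors the new get_type in this commit: prefer the override,
// then the stored type, then the caller's default.
sketch_type get_type_sketch(const std::string& name, const sketch_storage_map& m, sketch_type default_type) {
    auto it = m.find(name);
    if (it == m.end())
        return default_type;
    const sketch_tensor_storage& ts = it->second;
    return ts.expected_type != SKETCH_COUNT ? ts.expected_type : ts.type;
}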
clip.hpp (54)
@@ -476,11 +476,12 @@ protected:
public:
CLIPLayer(int64_t d_model,
int64_t n_head,
int64_t intermediate_size)
int64_t intermediate_size,
bool proj_in = false)
: d_model(d_model),
n_head(n_head),
intermediate_size(intermediate_size) {
blocks["self_attn"] = std::shared_ptr<GGMLBlock>(new MultiheadAttention(d_model, n_head, true, true));
blocks["self_attn"] = std::shared_ptr<GGMLBlock>(new MultiheadAttention(d_model, n_head, true, true, proj_in));

blocks["layer_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(d_model));
blocks["layer_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(d_model));

@@ -509,11 +510,12 @@ public:
CLIPEncoder(int64_t n_layer,
int64_t d_model,
int64_t n_head,
int64_t intermediate_size)
int64_t intermediate_size,
bool proj_in = false)
: n_layer(n_layer) {
for (int i = 0; i < n_layer; i++) {
std::string name = "layers." + std::to_string(i);
blocks[name] = std::shared_ptr<GGMLBlock>(new CLIPLayer(d_model, n_head, intermediate_size));
blocks[name] = std::shared_ptr<GGMLBlock>(new CLIPLayer(d_model, n_head, intermediate_size, proj_in));
}
}

@@ -549,10 +551,10 @@ protected:
int64_t num_positions;
bool force_clip_f32;

void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
enum ggml_type token_wtype = GGML_TYPE_F32;
if (!force_clip_f32) {
token_wtype = get_type(prefix + "token_embedding.weight", tensor_types, GGML_TYPE_F32);
token_wtype = get_type(prefix + "token_embedding.weight", tensor_storage_map, GGML_TYPE_F32);
if (!support_get_rows(token_wtype)) {
token_wtype = GGML_TYPE_F32;
}

@@ -605,7 +607,8 @@ protected:
int64_t image_size;
int64_t num_patches;
int64_t num_positions;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {

void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
enum ggml_type patch_wtype = GGML_TYPE_F16;
enum ggml_type class_wtype = GGML_TYPE_F32;
enum ggml_type position_wtype = GGML_TYPE_F32;

@@ -668,7 +671,7 @@ enum CLIPVersion {

class CLIPTextModel : public GGMLBlock {
protected:
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
if (version == OPEN_CLIP_VIT_BIGG_14) {
enum ggml_type wtype = GGML_TYPE_F32;
params["text_projection"] = ggml_new_tensor_2d(ctx, wtype, projection_dim, hidden_size);

@@ -689,7 +692,8 @@ public:

CLIPTextModel(CLIPVersion version = OPENAI_CLIP_VIT_L_14,
bool with_final_ln = true,
bool force_clip_f32 = false)
bool force_clip_f32 = false,
bool proj_in = false)
: version(version), with_final_ln(with_final_ln) {
if (version == OPEN_CLIP_VIT_H_14) {
hidden_size = 1024;

@@ -704,7 +708,7 @@ public:
}

blocks["embeddings"] = std::shared_ptr<GGMLBlock>(new CLIPEmbeddings(hidden_size, vocab_size, n_token, force_clip_f32));
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new CLIPEncoder(n_layer, hidden_size, n_head, intermediate_size));
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new CLIPEncoder(n_layer, hidden_size, n_head, intermediate_size, proj_in));
blocks["final_layer_norm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size));
}

@@ -758,7 +762,7 @@ public:
int32_t n_layer = 24;

public:
CLIPVisionModel(CLIPVersion version = OPENAI_CLIP_VIT_L_14) {
CLIPVisionModel(CLIPVersion version = OPENAI_CLIP_VIT_L_14, bool proj_in = false) {
if (version == OPEN_CLIP_VIT_H_14) {
hidden_size = 1280;
intermediate_size = 5120;

@@ -773,7 +777,7 @@ public:

blocks["embeddings"] = std::shared_ptr<GGMLBlock>(new CLIPVisionEmbeddings(hidden_size, num_channels, patch_size, image_size));
blocks["pre_layernorm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size));
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new CLIPEncoder(n_layer, hidden_size, n_head, intermediate_size));
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new CLIPEncoder(n_layer, hidden_size, n_head, intermediate_size, proj_in));
blocks["post_layernorm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size));
}

@@ -811,8 +815,8 @@ protected:
int64_t out_features;
bool transpose_weight;

void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32);
if (transpose_weight) {
params["weight"] = ggml_new_tensor_2d(ctx, wtype, out_features, in_features);
} else {

@@ -845,7 +849,8 @@ public:

public:
CLIPVisionModelProjection(CLIPVersion version = OPENAI_CLIP_VIT_L_14,
bool transpose_proj_w = false) {
bool transpose_proj_w = false,
bool proj_in = false) {
if (version == OPEN_CLIP_VIT_H_14) {
hidden_size = 1280;
projection_dim = 1024;

@@ -853,7 +858,7 @@ public:
hidden_size = 1664;
}

blocks["vision_model"] = std::shared_ptr<GGMLBlock>(new CLIPVisionModel(version));
blocks["vision_model"] = std::shared_ptr<GGMLBlock>(new CLIPVisionModel(version, proj_in));
blocks["visual_projection"] = std::shared_ptr<GGMLBlock>(new CLIPProjection(hidden_size, projection_dim, transpose_proj_w));
}

@@ -881,13 +886,24 @@ struct CLIPTextModelRunner : public GGMLRunner {

CLIPTextModelRunner(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types,
const String2TensorStorage& tensor_storage_map,
const std::string prefix,
CLIPVersion version = OPENAI_CLIP_VIT_L_14,
bool with_final_ln = true,
bool force_clip_f32 = false)
: GGMLRunner(backend, offload_params_to_cpu), model(version, with_final_ln, force_clip_f32) {
model.init(params_ctx, tensor_types, prefix);
: GGMLRunner(backend, offload_params_to_cpu) {
bool proj_in = false;
for (const auto& [name, tensor_storage] : tensor_storage_map) {
if (!starts_with(name, prefix)) {
continue;
}
if (contains(name, "self_attn.in_proj")) {
proj_in = true;
break;
}
}
model = CLIPTextModel(version, with_final_ln, force_clip_f32, proj_in);
model.init(params_ctx, tensor_storage_map, prefix);
}

std::string get_desc() override {
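With the String2GGMLType parameter gone, CLIPTextModelRunner (and FrozenCLIPVisionEmbedder in the conditioner hunks below) decides at construction time whether the checkpoint keeps a fused qkv projection by scanning tensor names under its prefix for "self_attn.in_proj". A hedged sketch of that check factored into a helper (the helper name is mine; the commit inlines the loop, and starts_with/contains are existing string utilities in the repo):

// Returns true if any tensor under `prefix` stores the fused qkv projection.
static bool has_fused_in_proj(const String2TensorStorage& tensor_storage_map,
                              const std::string& prefix) {
    for (const auto& [name, tensor_storage] : tensor_storage_map) {
        if (!starts_with(name, prefix)) {
            continue; // ignore tensors that belong to other sub-models
        }
        if (contains(name, "self_attn.in_proj")) {
            return true; // q/k/v are stored as a single in_proj tensor
        }
    }
    return false;
}

// Usage, mirroring the new constructor:
//   bool proj_in = has_fused_in_proj(tensor_storage_map, prefix);
//   model = CLIPTextModel(version, with_final_ln, force_clip_f32, proj_in);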
common.hpp (59)
@@ -182,8 +182,8 @@ protected:
int64_t dim_in;
int64_t dim_out;

void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") override {
enum ggml_type wtype = get_type(prefix + "proj.weight", tensor_types, GGML_TYPE_F32);
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override {
enum ggml_type wtype = get_type(prefix + "proj.weight", tensor_storage_map, GGML_TYPE_F32);
enum ggml_type bias_wtype = GGML_TYPE_F32;
params["proj.weight"] = ggml_new_tensor_2d(ctx, wtype, dim_in, dim_out * 2);
params["proj.bias"] = ggml_new_tensor_1d(ctx, bias_wtype, dim_out * 2);

@@ -408,30 +408,40 @@ protected:
int64_t d_head;
int64_t depth = 1; // 1
int64_t context_dim = 768; // hidden_size, 1024 for VERSION_SD2
bool use_linear = false;

public:
SpatialTransformer(int64_t in_channels,
int64_t n_head,
int64_t d_head,
int64_t depth,
int64_t context_dim)
int64_t context_dim,
bool use_linear)
: in_channels(in_channels),
n_head(n_head),
d_head(d_head),
depth(depth),
context_dim(context_dim) {
// We will convert unet transformer linear to conv2d 1x1 when loading the weights, so use_linear is always False
context_dim(context_dim),
use_linear(use_linear) {
// disable_self_attn is always False
int64_t inner_dim = n_head * d_head; // in_channels
blocks["norm"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels));
blocks["proj_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, inner_dim, {1, 1}));
if (use_linear) {
blocks["proj_in"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, inner_dim));
} else {
blocks["proj_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, inner_dim, {1, 1}));
}

for (int i = 0; i < depth; i++) {
std::string name = "transformer_blocks." + std::to_string(i);
blocks[name] = std::shared_ptr<GGMLBlock>(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim, false));
}

blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(inner_dim, in_channels, {1, 1}));
if (use_linear) {
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, in_channels));
} else {
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(inner_dim, in_channels, {1, 1}));
}
}

virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx,

@@ -440,8 +450,8 @@ public:
// x: [N, in_channels, h, w]
// context: [N, max_position(aka n_token), hidden_size(aka context_dim)]
auto norm = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm"]);
auto proj_in = std::dynamic_pointer_cast<Conv2d>(blocks["proj_in"]);
auto proj_out = std::dynamic_pointer_cast<Conv2d>(blocks["proj_out"]);
auto proj_in = std::dynamic_pointer_cast<UnaryBlock>(blocks["proj_in"]);
auto proj_out = std::dynamic_pointer_cast<UnaryBlock>(blocks["proj_out"]);

auto x_in = x;
int64_t n = x->ne[3];

@@ -450,10 +460,15 @@ public:
int64_t inner_dim = n_head * d_head;

x = norm->forward(ctx, x);
x = proj_in->forward(ctx, x); // [N, inner_dim, h, w]

x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 2, 0, 3)); // [N, h, w, inner_dim]
x = ggml_reshape_3d(ctx->ggml_ctx, x, inner_dim, w * h, n); // [N, h * w, inner_dim]
if (use_linear) {
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 2, 0, 3)); // [N, h, w, inner_dim]
x = ggml_reshape_3d(ctx->ggml_ctx, x, inner_dim, w * h, n); // [N, h * w, inner_dim]
x = proj_in->forward(ctx, x); // [N, inner_dim, h, w]
} else {
x = proj_in->forward(ctx, x); // [N, inner_dim, h, w]
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 2, 0, 3)); // [N, h, w, inner_dim]
x = ggml_reshape_3d(ctx->ggml_ctx, x, inner_dim, w * h, n); // [N, h * w, inner_dim]
}

for (int i = 0; i < depth; i++) {
std::string name = "transformer_blocks." + std::to_string(i);

@@ -462,11 +477,19 @@ public:
x = transformer_block->forward(ctx, x, context);
}

x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [N, inner_dim, h * w]
x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, inner_dim, n); // [N, inner_dim, h, w]
if (use_linear) {
// proj_out
x = proj_out->forward(ctx, x); // [N, in_channels, h, w]

// proj_out
x = proj_out->forward(ctx, x); // [N, in_channels, h, w]
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [N, inner_dim, h * w]
x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, inner_dim, n); // [N, inner_dim, h, w]
} else {
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [N, inner_dim, h * w]
x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, inner_dim, n); // [N, inner_dim, h, w]

// proj_out
x = proj_out->forward(ctx, x); // [N, in_channels, h, w]
}

x = ggml_add(ctx->ggml_ctx, x, x_in);
return x;

@@ -475,7 +498,7 @@ public:

class AlphaBlender : public GGMLBlock {
protected:
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") override {
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override {
// Get the type of the "mix_factor" tensor from the input tensors map with the specified prefix
enum ggml_type wtype = GGML_TYPE_F32;
params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1);
@@ -63,19 +63,19 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {

FrozenCLIPEmbedderWithCustomWords(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types,
const String2TensorStorage& tensor_storage_map,
const std::string& embd_dir,
SDVersion version = VERSION_SD1,
PMVersion pv = PM_VERSION_1)
: version(version), pm_version(pv), tokenizer(sd_version_is_sd2(version) ? 0 : 49407), embd_dir(embd_dir) {
bool force_clip_f32 = embd_dir.size() > 0;
if (sd_version_is_sd1(version)) {
text_model = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, true, force_clip_f32);
text_model = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_storage_map, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, true, force_clip_f32);
} else if (sd_version_is_sd2(version)) {
text_model = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "cond_stage_model.transformer.text_model", OPEN_CLIP_VIT_H_14, true, force_clip_f32);
text_model = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_storage_map, "cond_stage_model.transformer.text_model", OPEN_CLIP_VIT_H_14, true, force_clip_f32);
} else if (sd_version_is_sdxl(version)) {
text_model = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, false, force_clip_f32);
text_model2 = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "cond_stage_model.1.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false, force_clip_f32);
text_model = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_storage_map, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, false, force_clip_f32);
text_model2 = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_storage_map, "cond_stage_model.1.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false, force_clip_f32);
}
}

@@ -623,9 +623,21 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {

FrozenCLIPVisionEmbedder(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {})
: vision_model(OPEN_CLIP_VIT_H_14), GGMLRunner(backend, offload_params_to_cpu) {
vision_model.init(params_ctx, tensor_types, "cond_stage_model.transformer");
const String2TensorStorage& tensor_storage_map = {})
: GGMLRunner(backend, offload_params_to_cpu) {
std::string prefix = "cond_stage_model.transformer";
bool proj_in = false;
for (const auto& [name, tensor_storage] : tensor_storage_map) {
if (!starts_with(name, prefix)) {
continue;
}
if (contains(name, "self_attn.in_proj")) {
proj_in = true;
break;
}
}
vision_model = CLIPVisionModelProjection(OPEN_CLIP_VIT_H_14, false, proj_in);
vision_model.init(params_ctx, tensor_storage_map, prefix);
}

std::string get_desc() override {

@@ -673,12 +685,12 @@ struct SD3CLIPEmbedder : public Conditioner {

SD3CLIPEmbedder(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {})
const String2TensorStorage& tensor_storage_map = {})
: clip_g_tokenizer(0) {
bool use_clip_l = false;
bool use_clip_g = false;
bool use_t5 = false;
for (auto pair : tensor_types) {
for (auto pair : tensor_storage_map) {
if (pair.first.find("text_encoders.clip_l") != std::string::npos) {
use_clip_l = true;
} else if (pair.first.find("text_encoders.clip_g") != std::string::npos) {

@@ -692,13 +704,13 @@ struct SD3CLIPEmbedder : public Conditioner {
return;
}
if (use_clip_l) {
clip_l = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, false);
clip_l = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_storage_map, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, false);
}
if (use_clip_g) {
clip_g = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.clip_g.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false);
clip_g = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_storage_map, "text_encoders.clip_g.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false);
}
if (use_t5) {
t5 = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.t5xxl.transformer");
t5 = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_storage_map, "text_encoders.t5xxl.transformer");
}
}

@@ -1082,10 +1094,10 @@ struct FluxCLIPEmbedder : public Conditioner {

FluxCLIPEmbedder(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {}) {
const String2TensorStorage& tensor_storage_map = {}) {
bool use_clip_l = false;
bool use_t5 = false;
for (auto pair : tensor_types) {
for (auto pair : tensor_storage_map) {
if (pair.first.find("text_encoders.clip_l") != std::string::npos) {
use_clip_l = true;
} else if (pair.first.find("text_encoders.t5xxl") != std::string::npos) {

@@ -1099,12 +1111,12 @@ struct FluxCLIPEmbedder : public Conditioner {
}

if (use_clip_l) {
clip_l = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, true);
clip_l = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_storage_map, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, true);
} else {
LOG_WARN("clip_l text encoder not found! Prompt adherence might be degraded.");
}
if (use_t5) {
t5 = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.t5xxl.transformer");
t5 = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_storage_map, "text_encoders.t5xxl.transformer");
} else {
LOG_WARN("t5xxl text encoder not found! Prompt adherence might be degraded.");
}

@@ -1342,13 +1354,13 @@ struct T5CLIPEmbedder : public Conditioner {

T5CLIPEmbedder(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
bool use_mask = false,
int mask_pad = 1,
bool is_umt5 = false)
const String2TensorStorage& tensor_storage_map = {},
bool use_mask = false,
int mask_pad = 1,
bool is_umt5 = false)
: use_mask(use_mask), mask_pad(mask_pad), t5_tokenizer(is_umt5) {
bool use_t5 = false;
for (auto pair : tensor_types) {
for (auto pair : tensor_storage_map) {
if (pair.first.find("text_encoders.t5xxl") != std::string::npos) {
use_t5 = true;
}

@@ -1358,7 +1370,7 @@ struct T5CLIPEmbedder : public Conditioner {
LOG_WARN("IMPORTANT NOTICE: No text encoders provided, cannot process prompts!");
return;
} else {
t5 = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.t5xxl.transformer", is_umt5);
t5 = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_storage_map, "text_encoders.t5xxl.transformer", is_umt5);
}
}

@@ -1549,12 +1561,12 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {

Qwen2_5_VLCLIPEmbedder(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
const std::string prefix = "",
bool enable_vision = false) {
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "",
bool enable_vision = false) {
qwenvl = std::make_shared<Qwen::Qwen2_5_VLRunner>(backend,
offload_params_to_cpu,
tensor_types,
tensor_storage_map,
"text_encoders.qwen2vl",
enable_vision);
}
@@ -27,6 +27,7 @@ protected:
int num_heads = 8;
int num_head_channels = -1; // channels // num_heads
int context_dim = 768; // 1024 for VERSION_SD2, 2048 for VERSION_SDXL
bool use_linear_projection = false;

public:
int model_channels = 320;

@@ -82,7 +83,7 @@ public:
int64_t d_head,
int64_t depth,
int64_t context_dim) -> SpatialTransformer* {
return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim);
return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim, use_linear_projection);
};

auto make_zero_conv = [&](int64_t channels) {

@@ -318,10 +319,10 @@ struct ControlNet : public GGMLRunner {

ControlNet(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
SDVersion version = VERSION_SD1)
const String2TensorStorage& tensor_storage_map = {},
SDVersion version = VERSION_SD1)
: GGMLRunner(backend, offload_params_to_cpu), control_net(version) {
control_net.init(params_ctx, tensor_types, "");
control_net.init(params_ctx, tensor_storage_map, "");
}

~ControlNet() override {
@@ -44,9 +44,9 @@ struct UNetModel : public DiffusionModel {

UNetModel(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
SDVersion version = VERSION_SD1)
: unet(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version) {
const String2TensorStorage& tensor_storage_map = {},
SDVersion version = VERSION_SD1)
: unet(backend, offload_params_to_cpu, tensor_storage_map, "model.diffusion_model", version) {
}

std::string get_desc() override {

@@ -102,8 +102,8 @@ struct MMDiTModel : public DiffusionModel {

MMDiTModel(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {})
: mmdit(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model") {
const String2TensorStorage& tensor_storage_map = {})
: mmdit(backend, offload_params_to_cpu, tensor_storage_map, "model.diffusion_model") {
}

std::string get_desc() override {

@@ -158,10 +158,10 @@ struct FluxModel : public DiffusionModel {

FluxModel(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
SDVersion version = VERSION_FLUX,
bool use_mask = false)
: flux(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, use_mask) {
const String2TensorStorage& tensor_storage_map = {},
SDVersion version = VERSION_FLUX,
bool use_mask = false)
: flux(backend, offload_params_to_cpu, tensor_storage_map, "model.diffusion_model", version, use_mask) {
}

std::string get_desc() override {

@@ -221,10 +221,10 @@ struct WanModel : public DiffusionModel {

WanModel(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
const std::string prefix = "model.diffusion_model",
SDVersion version = VERSION_WAN2)
: prefix(prefix), wan(backend, offload_params_to_cpu, tensor_types, prefix, version) {
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "model.diffusion_model",
SDVersion version = VERSION_WAN2)
: prefix(prefix), wan(backend, offload_params_to_cpu, tensor_storage_map, prefix, version) {
}

std::string get_desc() override {

@@ -283,10 +283,10 @@ struct QwenImageModel : public DiffusionModel {

QwenImageModel(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
const std::string prefix = "model.diffusion_model",
SDVersion version = VERSION_QWEN_IMAGE)
: prefix(prefix), qwen_image(backend, offload_params_to_cpu, tensor_types, prefix, version) {
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "model.diffusion_model",
SDVersion version = VERSION_QWEN_IMAGE)
: prefix(prefix), qwen_image(backend, offload_params_to_cpu, tensor_storage_map, prefix, version) {
}

std::string get_desc() override {
@@ -156,7 +156,7 @@ struct ESRGAN : public GGMLRunner {

ESRGAN(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {})
const String2TensorStorage& tensor_storage_map = {})
: GGMLRunner(backend, offload_params_to_cpu) {
// rrdb_net will be created in load_from_file
}
flux.hpp (25)
@@ -37,7 +37,7 @@ namespace Flux {
int64_t hidden_size;
float eps;

void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
ggml_type wtype = GGML_TYPE_F32;
params["scale"] = ggml_new_tensor_1d(ctx, wtype, hidden_size);
}

@@ -1115,10 +1115,10 @@ namespace Flux {

FluxRunner(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
const std::string prefix = "",
SDVersion version = VERSION_FLUX,
bool use_mask = false)
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "",
SDVersion version = VERSION_FLUX,
bool use_mask = false)
: GGMLRunner(backend, offload_params_to_cpu), version(version), use_mask(use_mask) {
flux_params.version = version;
flux_params.guidance_embed = false;

@@ -1134,7 +1134,7 @@ namespace Flux {
flux_params.in_channels = 3;
flux_params.patch_size = 16;
}
for (auto pair : tensor_types) {
for (auto pair : tensor_storage_map) {
std::string tensor_name = pair.first;
if (!starts_with(tensor_name, prefix))
continue;

@@ -1172,7 +1172,7 @@ namespace Flux {
}

flux = Flux(flux_params);
flux.init(params_ctx, tensor_types, prefix);
flux.init(params_ctx, tensor_storage_map, prefix);
}

std::string get_desc() override {

@@ -1403,17 +1403,16 @@ namespace Flux {
return;
}

auto tensor_types = model_loader.tensor_storages_types;
for (auto& item : tensor_types) {
// LOG_DEBUG("%s %u", item.first.c_str(), item.second);
if (ends_with(item.first, "weight")) {
// item.second = model_data_type;
auto& tensor_storage_map = model_loader.get_tensor_storage_map();
for (auto& [name, tensor_storage] : tensor_storage_map) {
if (ends_with(name, "weight")) {
tensor_storage.expected_type = model_data_type;
}
}

std::shared_ptr<FluxRunner> flux = std::make_shared<FluxRunner>(backend,
false,
tensor_types,
tensor_storage_map,
"model.diffusion_model",
VERSION_CHROMA_RADIANCE,
false);
@@ -1460,8 +1460,6 @@ __STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) {
#define MAX_PARAMS_TENSOR_NUM 32768
#define MAX_GRAPH_SIZE 327680

typedef std::map<std::string, enum ggml_type> String2GGMLType;

struct GGMLRunnerContext {
ggml_backend_t backend = nullptr;
ggml_context* ggml_ctx = nullptr;

@@ -1900,30 +1898,36 @@ protected:
GGMLBlockMap blocks;
ParameterMap params;

ggml_type get_type(const std::string& name, const String2GGMLType& tensor_types, ggml_type default_type) {
auto iter = tensor_types.find(name);
if (iter != tensor_types.end()) {
return iter->second;
ggml_type get_type(const std::string& name, const String2TensorStorage& tensor_storage_map, ggml_type default_type) {
ggml_type wtype = default_type;
auto iter = tensor_storage_map.find(name);
if (iter != tensor_storage_map.end()) {
const TensorStorage& tensor_storage = iter->second;
if (tensor_storage.expected_type != GGML_TYPE_COUNT) {
wtype = tensor_storage.expected_type;
} else {
wtype = tensor_storage.type;
}
}
return default_type;
return wtype;
}

void init_blocks(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
void init_blocks(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") {
for (auto& pair : blocks) {
auto& block = pair.second;
block->init(ctx, tensor_types, prefix + pair.first);
block->init(ctx, tensor_storage_map, prefix + pair.first);
}
}

virtual void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {}
virtual void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") {}

public:
void init(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
void init(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") {
if (prefix.size() > 0) {
prefix = prefix + ".";
}
init_blocks(ctx, tensor_types, prefix);
init_params(ctx, tensor_types, prefix);
init_blocks(ctx, tensor_storage_map, prefix);
init_params(ctx, tensor_storage_map, prefix);
}

size_t get_params_num() {

@@ -2001,8 +2005,8 @@ protected:
bool force_prec_f32;
float scale;

void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32);
if (in_features % ggml_blck_size(wtype) != 0 || force_f32) {
wtype = GGML_TYPE_F32;
}

@@ -2049,8 +2053,8 @@ class Embedding : public UnaryBlock {
protected:
int64_t embedding_dim;
int64_t num_embeddings;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") override {
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map, const std::string prefix = "") override {
enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32);
if (!support_get_rows(wtype)) {
wtype = GGML_TYPE_F32;
}

@@ -2093,7 +2097,7 @@ protected:
bool bias;
float scale = 1.f;

void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") override {
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map, const std::string prefix = "") override {
enum ggml_type wtype = GGML_TYPE_F16;
params["weight"] = ggml_new_tensor_4d(ctx, wtype, kernel_size.second, kernel_size.first, in_channels, out_channels);
if (bias) {

@@ -2157,7 +2161,7 @@ protected:
int64_t dilation;
bool bias;

void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") override {
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map, const std::string prefix = "") override {
enum ggml_type wtype = GGML_TYPE_F16;
params["weight"] = ggml_new_tensor_4d(ctx, wtype, 1, kernel_size, in_channels, out_channels); // 5d => 4d
if (bias) {

@@ -2204,7 +2208,7 @@ protected:
std::tuple<int, int, int> dilation;
bool bias;

void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") override {
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map, const std::string prefix = "") override {
enum ggml_type wtype = GGML_TYPE_F16;
params["weight"] = ggml_new_tensor_4d(ctx,
wtype,

@@ -2253,7 +2257,7 @@ protected:
bool elementwise_affine;
bool bias;

void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
if (elementwise_affine) {
enum ggml_type wtype = GGML_TYPE_F32;
params["weight"] = ggml_new_tensor_1d(ctx, wtype, normalized_shape);

@@ -2295,7 +2299,7 @@ protected:
float eps;
bool affine;

void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
if (affine) {
enum ggml_type wtype = GGML_TYPE_F32;
enum ggml_type bias_wtype = GGML_TYPE_F32;

@@ -2336,7 +2340,7 @@ protected:
int64_t hidden_size;
float eps;

void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") override {
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override {
enum ggml_type wtype = GGML_TYPE_F32;
params["weight"] = ggml_new_tensor_1d(ctx, wtype, hidden_size);
}

@@ -2359,9 +2363,11 @@ class MultiheadAttention : public GGMLBlock {
protected:
int64_t embed_dim;
int64_t n_head;
bool proj_in;
std::string q_proj_name;
std::string k_proj_name;
std::string v_proj_name;
std::string in_proj_name;
std::string out_proj_name;

public:

@@ -2369,19 +2375,27 @@ public:
int64_t n_head,
bool qkv_proj_bias = true,
bool out_proj_bias = true,
bool proj_in = false,
std::string q_proj_name = "q_proj",
std::string k_proj_name = "k_proj",
std::string v_proj_name = "v_proj",
std::string in_proj_name = "in_proj",
std::string out_proj_name = "out_proj")
: embed_dim(embed_dim),
n_head(n_head),
proj_in(proj_in),
q_proj_name(q_proj_name),
k_proj_name(k_proj_name),
v_proj_name(v_proj_name),
in_proj_name(in_proj_name),
out_proj_name(out_proj_name) {
blocks[q_proj_name] = std::shared_ptr<GGMLBlock>(new Linear(embed_dim, embed_dim, qkv_proj_bias));
blocks[k_proj_name] = std::shared_ptr<GGMLBlock>(new Linear(embed_dim, embed_dim, qkv_proj_bias));
blocks[v_proj_name] = std::shared_ptr<GGMLBlock>(new Linear(embed_dim, embed_dim, qkv_proj_bias));
if (proj_in) {
blocks[in_proj_name] = std::shared_ptr<GGMLBlock>(new Linear(embed_dim, embed_dim * 3, qkv_proj_bias));
} else {
blocks[q_proj_name] = std::shared_ptr<GGMLBlock>(new Linear(embed_dim, embed_dim, qkv_proj_bias));
blocks[k_proj_name] = std::shared_ptr<GGMLBlock>(new Linear(embed_dim, embed_dim, qkv_proj_bias));
blocks[v_proj_name] = std::shared_ptr<GGMLBlock>(new Linear(embed_dim, embed_dim, qkv_proj_bias));
}
blocks[out_proj_name] = std::shared_ptr<GGMLBlock>(new Linear(embed_dim, embed_dim, out_proj_bias));
}

@@ -2389,14 +2403,27 @@ public:
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
bool mask = false) {
auto q_proj = std::dynamic_pointer_cast<Linear>(blocks[q_proj_name]);
auto k_proj = std::dynamic_pointer_cast<Linear>(blocks[k_proj_name]);
auto v_proj = std::dynamic_pointer_cast<Linear>(blocks[v_proj_name]);
auto out_proj = std::dynamic_pointer_cast<Linear>(blocks[out_proj_name]);

struct ggml_tensor* q = q_proj->forward(ctx, x);
struct ggml_tensor* k = k_proj->forward(ctx, x);
struct ggml_tensor* v = v_proj->forward(ctx, x);
ggml_tensor* q;
ggml_tensor* k;
ggml_tensor* v;
if (proj_in) {
auto in_proj = std::dynamic_pointer_cast<Linear>(blocks[in_proj_name]);
auto qkv = in_proj->forward(ctx, x);
auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv);
q = qkv_vec[0];
k = qkv_vec[1];
v = qkv_vec[2];
} else {
auto q_proj = std::dynamic_pointer_cast<Linear>(blocks[q_proj_name]);
auto k_proj = std::dynamic_pointer_cast<Linear>(blocks[k_proj_name]);
auto v_proj = std::dynamic_pointer_cast<Linear>(blocks[v_proj_name]);

q = q_proj->forward(ctx, x);
k = k_proj->forward(ctx, x);
v = v_proj->forward(ctx, x);
}

x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, mask); // [N, n_token, embed_dim]
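When proj_in is set, forward() runs one fused in_proj Linear and splits the result with split_qkv. A hedged sketch of what such a split can look like with ggml views, assuming a contiguous qkv tensor whose first dimension is 3 * embed_dim with q, k and v stacked in blocks (the repo's actual split_qkv helper may be implemented differently):

#include "ggml.h"
#include <vector>

// Split [3*embed_dim, n_token, N] into three [embed_dim, n_token, N] views.
static std::vector<ggml_tensor*> split_qkv_sketch(ggml_context* ctx, ggml_tensor* qkv, int64_t embed_dim) {
    std::vector<ggml_tensor*> parts;
    for (int i = 0; i < 3; i++) {
        size_t offset = (size_t)i * embed_dim * ggml_element_size(qkv);
        // The views keep qkv's row strides, so callers may ggml_cont() them if needed.
        parts.push_back(ggml_view_3d(ctx, qkv,
                                     embed_dim, qkv->ne[1], qkv->ne[2],
                                     qkv->nb[1], qkv->nb[2], offset));
    }
    return parts;
}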
mmdit.hpp (15)
@@ -633,13 +633,13 @@ protected:
int64_t hidden_size;
std::string qk_norm;

void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") override {
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override {
enum ggml_type wtype = GGML_TYPE_F32;
params["pos_embed"] = ggml_new_tensor_3d(ctx, wtype, hidden_size, num_patchs, 1);
}

public:
MMDiT(const String2GGMLType& tensor_types = {}) {
MMDiT(const String2TensorStorage& tensor_storage_map = {}) {
// input_size is always None
// learn_sigma is always False
// register_length is alwalys 0

@@ -652,8 +652,7 @@ public:
// pos_embed_offset is not used
// context_embedder_config is always {'target': 'torch.nn.Linear', 'params': {'in_features': 4096, 'out_features': 1536}}

// read tensors from tensor_types
for (auto pair : tensor_types) {
for (auto pair : tensor_storage_map) {
std::string tensor_name = pair.first;
if (tensor_name.find("model.diffusion_model.") == std::string::npos)
continue;

@@ -852,10 +851,10 @@ struct MMDiTRunner : public GGMLRunner {

MMDiTRunner(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
const std::string prefix = "")
: GGMLRunner(backend, offload_params_to_cpu), mmdit(tensor_types) {
mmdit.init(params_ctx, tensor_types, prefix);
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "")
: GGMLRunner(backend, offload_params_to_cpu), mmdit(tensor_storage_map) {
mmdit.init(params_ctx, tensor_storage_map, prefix);
}

std::string get_desc() override {
model.cpp (415)
@@ -140,7 +140,9 @@ std::unordered_map<std::string, std::string> open_clip_to_hf_clip_model = {
{"model.visual.proj", "transformer.visual_projection.weight"},
};

std::unordered_map<std::string, std::string> open_clip_to_hk_clip_resblock = {
std::unordered_map<std::string, std::string> open_clip_to_hf_clip_resblock = {
{"attn.in_proj_bias", "self_attn.in_proj.bias"},
{"attn.in_proj_weight", "self_attn.in_proj.weight"},
{"attn.out_proj.bias", "self_attn.out_proj.bias"},
{"attn.out_proj.weight", "self_attn.out_proj.weight"},
{"ln_1.bias", "layer_norm1.bias"},

@@ -351,10 +353,8 @@ std::string convert_cond_model_name(const std::string& name) {
std::string idx = remain.substr(0, remain.find("."));
std::string suffix = remain.substr(idx.length() + 1);

if (suffix == "attn.in_proj_weight" || suffix == "attn.in_proj_bias") {
new_name = hf_clip_resblock_prefix + idx + "." + suffix;
} else if (open_clip_to_hk_clip_resblock.find(suffix) != open_clip_to_hk_clip_resblock.end()) {
std::string new_suffix = open_clip_to_hk_clip_resblock[suffix];
if (open_clip_to_hf_clip_resblock.find(suffix) != open_clip_to_hf_clip_resblock.end()) {
std::string new_suffix = open_clip_to_hf_clip_resblock[suffix];
new_name = hf_clip_resblock_prefix + idx + "." + new_suffix;
}
}

@@ -740,80 +740,6 @@ std::string convert_tensor_name(std::string name) {
return new_name;
}

void add_preprocess_tensor_storage_types(String2GGMLType& tensor_storages_types, std::string name, enum ggml_type type) {
std::string new_name = convert_tensor_name(name);

if (new_name.find("cond_stage_model") != std::string::npos && ends_with(new_name, "attn.in_proj_weight")) {
size_t prefix_size = new_name.find("attn.in_proj_weight");
std::string prefix = new_name.substr(0, prefix_size);
tensor_storages_types[prefix + "self_attn.q_proj.weight"] = type;
tensor_storages_types[prefix + "self_attn.k_proj.weight"] = type;
tensor_storages_types[prefix + "self_attn.v_proj.weight"] = type;
} else if (new_name.find("cond_stage_model") != std::string::npos && ends_with(new_name, "attn.in_proj_bias")) {
size_t prefix_size = new_name.find("attn.in_proj_bias");
std::string prefix = new_name.substr(0, prefix_size);
tensor_storages_types[prefix + "self_attn.q_proj.bias"] = type;
tensor_storages_types[prefix + "self_attn.k_proj.bias"] = type;
tensor_storages_types[prefix + "self_attn.v_proj.bias"] = type;
} else {
tensor_storages_types[new_name] = type;
}
}

void preprocess_tensor(TensorStorage tensor_storage,
std::vector<TensorStorage>& processed_tensor_storages) {
std::vector<TensorStorage> result;
std::string new_name = convert_tensor_name(tensor_storage.name);

// convert unet transformer linear to conv2d 1x1
if (starts_with(new_name, "model.diffusion_model.") &&
!starts_with(new_name, "model.diffusion_model.proj_out.") &&
(ends_with(new_name, "proj_in.weight") || ends_with(new_name, "proj_out.weight"))) {
tensor_storage.unsqueeze();
}

// convert vae attn block linear to conv2d 1x1
if (starts_with(new_name, "first_stage_model.") && new_name.find("attn_1") != std::string::npos) {
tensor_storage.unsqueeze();
}

// wan vae
if (ends_with(new_name, "gamma")) {
tensor_storage.reverse_ne();
tensor_storage.n_dims = 1;
tensor_storage.reverse_ne();
}

tensor_storage.name = new_name;

if (new_name.find("cond_stage_model") != std::string::npos &&
ends_with(new_name, "attn.in_proj_weight")) {
size_t prefix_size = new_name.find("attn.in_proj_weight");
std::string prefix = new_name.substr(0, prefix_size);

std::vector<TensorStorage> chunks = tensor_storage.chunk(3);
chunks[0].name = prefix + "self_attn.q_proj.weight";
chunks[1].name = prefix + "self_attn.k_proj.weight";
chunks[2].name = prefix + "self_attn.v_proj.weight";

processed_tensor_storages.insert(processed_tensor_storages.end(), chunks.begin(), chunks.end());

} else if (new_name.find("cond_stage_model") != std::string::npos &&
ends_with(new_name, "attn.in_proj_bias")) {
size_t prefix_size = new_name.find("attn.in_proj_bias");
std::string prefix = new_name.substr(0, prefix_size);

std::vector<TensorStorage> chunks = tensor_storage.chunk(3);
chunks[0].name = prefix + "self_attn.q_proj.bias";
chunks[1].name = prefix + "self_attn.k_proj.bias";
chunks[2].name = prefix + "self_attn.v_proj.bias";

processed_tensor_storages.insert(processed_tensor_storages.end(), chunks.begin(), chunks.end());
} else {
processed_tensor_storages.push_back(tensor_storage);
}
}
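This deleted chunking is the other half of the proj_in change: instead of splitting attn.in_proj_weight/bias into separate q/k/v tensors at load time, the rename table above now maps them to self_attn.in_proj.*, and MultiheadAttention allocates a matching fused in_proj Linear (see the clip.hpp and MultiheadAttention hunks earlier). A tiny illustration using the two map entries added in this diff:

// After this commit the loader only renames; the fused tensor stays intact.
#include <string>
#include <unordered_map>

static const std::unordered_map<std::string, std::string> resblock_suffix_map = {
    {"attn.in_proj_weight", "self_attn.in_proj.weight"},
    {"attn.in_proj_bias", "self_attn.in_proj.bias"},
};
// e.g. "...resblocks.0.attn.in_proj_weight" -> "...layers.0.self_attn.in_proj.weight"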
float bf16_to_f32(uint16_t bfloat16) {
uint32_t val_bits = (static_cast<uint32_t>(bfloat16) << 16);
return *reinterpret_cast<float*>(&val_bits);

@@ -989,44 +915,10 @@ void convert_tensor(void* src,

/*================================================= ModelLoader ==================================================*/

// ported from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py#L16
std::map<char, int> unicode_to_byte() {
std::map<int, char> byte_to_unicode;

// List of utf-8 byte ranges
for (int b = static_cast<int>('!'); b <= static_cast<int>('~'); ++b) {
byte_to_unicode[b] = static_cast<char>(b);
}

for (int b = 49825; b <= 49836; ++b) {
byte_to_unicode[b] = static_cast<char>(b);
}

for (int b = 49838; b <= 50111; ++b) {
byte_to_unicode[b] = static_cast<char>(b);
}
// printf("%d %d %d %d\n", static_cast<int>('¡'), static_cast<int>('¬'), static_cast<int>('®'), static_cast<int>('ÿ'));
// exit(1);

int n = 0;
for (int b = 0; b < 256; ++b) {
if (byte_to_unicode.find(b) == byte_to_unicode.end()) {
byte_to_unicode[b] = static_cast<char>(256 + n);
n++;
}
}

// byte_encoder = bytes_to_unicode()
// byte_decoder = {v: k for k, v in byte_encoder.items()}
std::map<char, int> byte_decoder;

for (const auto& entry : byte_to_unicode) {
byte_decoder[entry.second] = entry.first;
}

byte_to_unicode.clear();

return byte_decoder;
void ModelLoader::add_tensor_storage(const TensorStorage& tensor_storage) {
TensorStorage copy = tensor_storage;
copy.name = convert_tensor_name(copy.name);
tensor_storage_map[copy.name] = std::move(copy);
}

bool is_zip_file(const std::string& file_path) {

@@ -1156,8 +1048,7 @@ bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::s

// LOG_DEBUG("%s %s", name.c_str(), tensor_storage.to_string().c_str());

tensor_storages.push_back(tensor_storage);
add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
add_tensor_storage(tensor_storage);
}

return true;

@@ -1182,8 +1073,7 @@ bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::s

GGML_ASSERT(ggml_nbytes(dummy) == tensor_storage.nbytes());

tensor_storages.push_back(tensor_storage);
add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
add_tensor_storage(tensor_storage);
}

gguf_free(ctx_gguf_);

@@ -1350,8 +1240,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size);
}

tensor_storages.push_back(tensor_storage);
add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
add_tensor_storage(tensor_storage);

// LOG_DEBUG("%s %s", tensor_storage.to_string().c_str(), dtype.c_str());
}

@@ -1370,11 +1259,13 @@ bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const s
if (!init_from_safetensors_file(unet_path, "unet.")) {
return false;
}
for (auto ts : tensor_storages) {
if (ts.name.find("add_embedding") != std::string::npos || ts.name.find("label_emb") != std::string::npos) {
for (auto& [name, tensor_storage] : tensor_storage_map) {
if (name.find("add_embedding") != std::string::npos || name.find("label_emb") != std::string::npos) {
// probably SDXL
LOG_DEBUG("Fixing name for SDXL output blocks.2.2");
for (auto& tensor_storage : tensor_storages) {
String2TensorStorage new_tensor_storage_map;

for (auto& [name, tensor_storage] : tensor_storage_map) {
int len = 34;
auto pos = tensor_storage.name.find("unet.up_blocks.0.upsamplers.0.conv");
if (pos == std::string::npos) {

@@ -1382,11 +1273,15 @@ bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const s
pos = tensor_storage.name.find("model.diffusion_model.output_blocks.2.1.conv");
}
if (pos != std::string::npos) {
tensor_storage.name = "model.diffusion_model.output_blocks.2.2.conv" + tensor_storage.name.substr(len);
LOG_DEBUG("NEW NAME: %s", tensor_storage.name.c_str());
add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
std::string new_name = "model.diffusion_model.output_blocks.2.2.conv" + name.substr(len);
LOG_DEBUG("NEW NAME: %s", new_name.c_str());
tensor_storage.name = new_name;
new_tensor_storage_map[new_name] = tensor_storage;
} else {
new_tensor_storage_map[name] = tensor_storage;
}
}
tensor_storage_map = new_tensor_storage_map;
break;
}
}

@@ -1712,8 +1607,7 @@ bool ModelLoader::parse_data_pkl(uint8_t* buffer,
name = prefix + name;
}
reader.tensor_storage.name = name;
tensor_storages.push_back(reader.tensor_storage);
add_preprocess_tensor_storage_types(tensor_storages_types, reader.tensor_storage.name, reader.tensor_storage.type);
add_tensor_storage(reader.tensor_storage);

// LOG_DEBUG("%s", reader.tensor_storage.name.c_str());
// reset

@@ -1767,15 +1661,6 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
return true;
}

bool ModelLoader::model_is_unet() {
for (auto& tensor_storage : tensor_storages) {
if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos) {
return true;
}
}
return false;
}

SDVersion ModelLoader::get_sd_version() {
TensorStorage token_embedding_weight, input_block_weight;

@@ -1789,7 +1674,7 @@ SDVersion ModelLoader::get_sd_version() {
bool has_img_emb = false;
bool has_middle_block_1 = false;

for (auto& tensor_storage : tensor_storages) {
for (auto& [name, tensor_storage] : tensor_storage_map) {
if (!(is_xl)) {
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
is_flux = true;

@@ -1910,7 +1795,7 @@ SDVersion ModelLoader::get_sd_version() {

std::map<ggml_type, uint32_t> ModelLoader::get_wtype_stat() {
std::map<ggml_type, uint32_t> wtype_stat;
for (auto& tensor_storage : tensor_storages) {
for (auto& [name, tensor_storage] : tensor_storage_map) {
if (is_unused_tensor(tensor_storage.name)) {
continue;
}

@@ -1927,7 +1812,7 @@ std::map<ggml_type, uint32_t> ModelLoader::get_wtype_stat() {

std::map<ggml_type, uint32_t> ModelLoader::get_conditioner_wtype_stat() {
std::map<ggml_type, uint32_t> wtype_stat;
for (auto& tensor_storage : tensor_storages) {
for (auto& [name, tensor_storage] : tensor_storage_map) {
if (is_unused_tensor(tensor_storage.name)) {
continue;
}

@@ -1951,7 +1836,7 @@ std::map<ggml_type, uint32_t> ModelLoader::get_conditioner_wtype_stat() {

std::map<ggml_type, uint32_t> ModelLoader::get_diffusion_model_wtype_stat() {
std::map<ggml_type, uint32_t> wtype_stat;
for (auto& tensor_storage : tensor_storages) {
for (auto& [name, tensor_storage] : tensor_storage_map) {
if (is_unused_tensor(tensor_storage.name)) {
continue;
}

@@ -1972,7 +1857,7 @@ std::map<ggml_type, uint32_t> ModelLoader::get_diffusion_model_wtype_stat() {

std::map<ggml_type, uint32_t> ModelLoader::get_vae_wtype_stat() {
std::map<ggml_type, uint32_t> wtype_stat;
for (auto& tensor_storage : tensor_storages) {
for (auto& [name, tensor_storage] : tensor_storage_map) {
if (is_unused_tensor(tensor_storage.name)) {
continue;
}

@@ -1993,26 +1878,14 @@ std::map<ggml_type, uint32_t> ModelLoader::get_vae_wtype_stat() {
}

void ModelLoader::set_wtype_override(ggml_type wtype, std::string prefix) {
for (auto& pair : tensor_storages_types) {
if (prefix.size() < 1 || pair.first.substr(0, prefix.size()) == prefix) {
bool found = false;
for (auto& tensor_storage : tensor_storages) {
std::map<std::string, ggml_type> temp;
add_preprocess_tensor_storage_types(temp, tensor_storage.name, tensor_storage.type);
for (auto& preprocessed_name : temp) {
if (preprocessed_name.first == pair.first) {
if (tensor_should_be_converted(tensor_storage, wtype)) {
pair.second = wtype;
}
found = true;
break;
}
}
if (found) {
break;
}
}
for (auto& [name, tensor_storage] : tensor_storage_map) {
if (!starts_with(name, prefix)) {
continue;
}
if (!tensor_should_be_converted(tensor_storage, wtype)) {
continue;
}
tensor_storage.expected_type = wtype;
}
}
@ -2047,74 +1920,13 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
|
||||
LOG_DEBUG("using %d threads for model loading", num_threads_to_use);
|
||||
|
||||
int64_t start_time = ggml_time_ms();
|
||||
|
||||
std::vector<TensorStorage> processed_tensor_storages;
|
||||
|
||||
{
|
||||
struct IndexedStorage {
|
||||
size_t index;
|
||||
TensorStorage ts;
|
||||
};
|
||||
|
||||
std::mutex vec_mutex;
|
||||
std::vector<IndexedStorage> all_results;
|
||||
|
||||
int n_threads = std::min(num_threads_to_use, (int)tensor_storages.size());
|
||||
if (n_threads < 1) {
|
||||
n_threads = 1;
|
||||
}
|
||||
std::vector<std::thread> workers;
|
||||
|
||||
for (int i = 0; i < n_threads; ++i) {
|
||||
workers.emplace_back([&, thread_id = i]() {
|
||||
std::vector<IndexedStorage> local_results;
|
||||
std::vector<TensorStorage> temp_storages;
|
||||
|
||||
for (size_t j = thread_id; j < tensor_storages.size(); j += n_threads) {
|
||||
const auto& tensor_storage = tensor_storages[j];
|
||||
if (is_unused_tensor(tensor_storage.name)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
temp_storages.clear();
|
||||
preprocess_tensor(tensor_storage, temp_storages);
|
||||
|
||||
for (const auto& ts : temp_storages) {
|
||||
local_results.push_back({j, ts});
|
||||
}
|
||||
}
|
||||
|
||||
if (!local_results.empty()) {
|
||||
std::lock_guard<std::mutex> lock(vec_mutex);
|
||||
all_results.insert(all_results.end(),
|
||||
local_results.begin(), local_results.end());
|
||||
}
|
||||
});
|
||||
}
|
||||
for (auto& w : workers) {
|
||||
w.join();
|
||||
}
|
||||
|
||||
std::vector<IndexedStorage> deduplicated;
|
||||
deduplicated.reserve(all_results.size());
|
||||
std::unordered_map<std::string, size_t> name_to_pos;
|
||||
for (auto& entry : all_results) {
|
||||
auto it = name_to_pos.find(entry.ts.name);
|
||||
if (it == name_to_pos.end()) {
|
||||
name_to_pos.emplace(entry.ts.name, deduplicated.size());
|
||||
deduplicated.push_back(entry);
|
||||
} else if (deduplicated[it->second].index < entry.index) {
|
||||
deduplicated[it->second] = entry;
|
||||
}
|
||||
}
|
||||
|
||||
std::sort(deduplicated.begin(), deduplicated.end(), [](const IndexedStorage& a, const IndexedStorage& b) {
|
||||
return a.index < b.index;
|
||||
});
|
||||
|
||||
processed_tensor_storages.reserve(deduplicated.size());
|
||||
for (auto& entry : deduplicated) {
|
||||
processed_tensor_storages.push_back(entry.ts);
|
||||
for (auto& [name, tensor_storage] : tensor_storage_map) {
|
||||
if (is_unused_tensor(tensor_storage.name)) {
|
||||
continue;
|
||||
}
|
||||
processed_tensor_storages.push_back(tensor_storage);
|
||||
}
|
||||
|
||||
process_time_ms = ggml_time_ms() - start_time;
|
||||
@ -2231,107 +2043,72 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
|
||||
}
|
||||
};
|
||||
|
||||
char* read_buf = nullptr;
|
||||
char* target_buf = nullptr;
|
||||
char* convert_buf = nullptr;
|
||||
if (dst_tensor->buffer == nullptr || ggml_backend_buffer_is_host(dst_tensor->buffer)) {
|
||||
if (tensor_storage.type == dst_tensor->type) {
|
||||
GGML_ASSERT(ggml_nbytes(dst_tensor) == tensor_storage.nbytes());
|
||||
if (tensor_storage.is_f64 || tensor_storage.is_i64) {
|
||||
read_buffer.resize(tensor_storage.nbytes_to_read());
|
||||
read_data((char*)read_buffer.data(), nbytes_to_read);
|
||||
read_buf = (char*)read_buffer.data();
|
||||
} else {
|
||||
read_data((char*)dst_tensor->data, nbytes_to_read);
|
||||
read_buf = (char*)dst_tensor->data;
|
||||
}
|
||||
t1 = ggml_time_ms();
|
||||
read_time_ms.fetch_add(t1 - t0);
|
||||
|
||||
t0 = ggml_time_ms();
|
||||
if (tensor_storage.is_bf16) {
|
||||
// inplace op
|
||||
bf16_to_f32_vec((uint16_t*)dst_tensor->data, (float*)dst_tensor->data, tensor_storage.nelements());
|
||||
} else if (tensor_storage.is_f8_e4m3) {
|
||||
// inplace op
|
||||
f8_e4m3_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements());
|
||||
} else if (tensor_storage.is_f8_e5m2) {
|
||||
// inplace op
|
||||
f8_e5m2_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements());
|
||||
} else if (tensor_storage.is_f64) {
|
||||
f64_to_f32_vec((double*)read_buffer.data(), (float*)dst_tensor->data, tensor_storage.nelements());
|
||||
} else if (tensor_storage.is_i64) {
|
||||
i64_to_i32_vec((int64_t*)read_buffer.data(), (int32_t*)dst_tensor->data, tensor_storage.nelements());
|
||||
}
|
||||
t1 = ggml_time_ms();
|
||||
convert_time_ms.fetch_add(t1 - t0);
|
||||
target_buf = (char*)dst_tensor->data;
|
||||
} else {
|
||||
read_buffer.resize(std::max(tensor_storage.nbytes(), tensor_storage.nbytes_to_read()));
|
||||
read_data((char*)read_buffer.data(), nbytes_to_read);
|
||||
t1 = ggml_time_ms();
|
||||
read_time_ms.fetch_add(t1 - t0);
|
||||
|
||||
t0 = ggml_time_ms();
|
||||
if (tensor_storage.is_bf16) {
|
||||
// inplace op
|
||||
bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
|
||||
} else if (tensor_storage.is_f8_e4m3) {
|
||||
// inplace op
|
||||
f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
|
||||
} else if (tensor_storage.is_f8_e5m2) {
|
||||
// inplace op
|
||||
f8_e5m2_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
|
||||
} else if (tensor_storage.is_f64) {
|
||||
// inplace op
|
||||
f64_to_f32_vec((double*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
|
||||
} else if (tensor_storage.is_i64) {
|
||||
// inplace op
|
||||
i64_to_i32_vec((int64_t*)read_buffer.data(), (int32_t*)read_buffer.data(), tensor_storage.nelements());
|
||||
}
|
||||
convert_tensor((void*)read_buffer.data(), tensor_storage.type, dst_tensor->data, dst_tensor->type, (int)tensor_storage.nelements() / (int)tensor_storage.ne[0], (int)tensor_storage.ne[0]);
|
||||
t1 = ggml_time_ms();
|
||||
convert_time_ms.fetch_add(t1 - t0);
|
||||
read_buf = (char*)read_buffer.data();
|
||||
target_buf = read_buf;
|
||||
convert_buf = (char*)dst_tensor->data;
|
||||
}
|
||||
} else {
|
||||
read_buffer.resize(std::max(tensor_storage.nbytes(), tensor_storage.nbytes_to_read()));
|
||||
read_data((char*)read_buffer.data(), nbytes_to_read);
|
||||
t1 = ggml_time_ms();
|
||||
read_time_ms.fetch_add(t1 - t0);
|
||||
|
||||
t0 = ggml_time_ms();
|
||||
if (tensor_storage.is_bf16) {
|
||||
// inplace op
|
||||
bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
|
||||
} else if (tensor_storage.is_f8_e4m3) {
|
||||
// inplace op
|
||||
f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
|
||||
} else if (tensor_storage.is_f8_e5m2) {
|
||||
// inplace op
|
||||
f8_e5m2_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
|
||||
} else if (tensor_storage.is_f64) {
|
||||
// inplace op
|
||||
f64_to_f32_vec((double*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
|
||||
} else if (tensor_storage.is_i64) {
|
||||
// inplace op
|
||||
i64_to_i32_vec((int64_t*)read_buffer.data(), (int32_t*)read_buffer.data(), tensor_storage.nelements());
|
||||
}
|
||||
|
||||
if (tensor_storage.type == dst_tensor->type) {
|
||||
// copy to device memory
|
||||
t1 = ggml_time_ms();
|
||||
convert_time_ms.fetch_add(t1 - t0);
|
||||
t0 = ggml_time_ms();
|
||||
ggml_backend_tensor_set(dst_tensor, read_buffer.data(), 0, ggml_nbytes(dst_tensor));
|
||||
t1 = ggml_time_ms();
|
||||
copy_to_backend_time_ms.fetch_add(t1 - t0);
|
||||
} else {
|
||||
// convert first, then copy to device memory
|
||||
read_buf = (char*)read_buffer.data();
|
||||
target_buf = read_buf;
|
||||
|
||||
if (tensor_storage.type != dst_tensor->type) {
|
||||
convert_buffer.resize(ggml_nbytes(dst_tensor));
|
||||
convert_tensor((void*)read_buffer.data(), tensor_storage.type, (void*)convert_buffer.data(), dst_tensor->type, (int)tensor_storage.nelements() / (int)tensor_storage.ne[0], (int)tensor_storage.ne[0]);
|
||||
t1 = ggml_time_ms();
|
||||
convert_time_ms.fetch_add(t1 - t0);
|
||||
t0 = ggml_time_ms();
|
||||
ggml_backend_tensor_set(dst_tensor, convert_buffer.data(), 0, ggml_nbytes(dst_tensor));
|
||||
t1 = ggml_time_ms();
|
||||
copy_to_backend_time_ms.fetch_add(t1 - t0);
|
||||
convert_buf = (char*)convert_buffer.data();
|
||||
}
|
||||
}
|
||||
|
||||
t0 = ggml_time_ms();
|
||||
read_data(read_buf, nbytes_to_read);
|
||||
t1 = ggml_time_ms();
|
||||
read_time_ms.fetch_add(t1 - t0);
|
||||
|
||||
t0 = ggml_time_ms();
|
||||
if (tensor_storage.is_bf16) {
|
||||
bf16_to_f32_vec((uint16_t*)read_buf, (float*)target_buf, tensor_storage.nelements());
|
||||
} else if (tensor_storage.is_f8_e4m3) {
|
||||
f8_e4m3_to_f16_vec((uint8_t*)read_buf, (uint16_t*)target_buf, tensor_storage.nelements());
|
||||
} else if (tensor_storage.is_f8_e5m2) {
|
||||
f8_e5m2_to_f16_vec((uint8_t*)read_buf, (uint16_t*)target_buf, tensor_storage.nelements());
|
||||
} else if (tensor_storage.is_f64) {
|
||||
f64_to_f32_vec((double*)read_buf, (float*)target_buf, tensor_storage.nelements());
|
||||
} else if (tensor_storage.is_i64) {
|
||||
i64_to_i32_vec((int64_t*)read_buf, (int32_t*)target_buf, tensor_storage.nelements());
|
||||
}
|
||||
if (tensor_storage.type != dst_tensor->type) {
|
||||
convert_tensor((void*)target_buf,
|
||||
tensor_storage.type,
|
||||
convert_buf,
|
||||
dst_tensor->type,
|
||||
(int)tensor_storage.nelements() / (int)tensor_storage.ne[0],
|
||||
(int)tensor_storage.ne[0]);
|
||||
} else {
|
||||
convert_buf = read_buf;
|
||||
}
|
||||
t1 = ggml_time_ms();
|
||||
convert_time_ms.fetch_add(t1 - t0);
|
||||
|
||||
if (dst_tensor->buffer != nullptr && !ggml_backend_buffer_is_host(dst_tensor->buffer)) {
|
||||
t0 = ggml_time_ms();
|
||||
ggml_backend_tensor_set(dst_tensor, convert_buf, 0, ggml_nbytes(dst_tensor));
|
||||
t1 = ggml_time_ms();
|
||||
copy_to_backend_time_ms.fetch_add(t1 - t0);
|
||||
}
|
||||
}
|
||||
if (zip != nullptr) {
|
||||
zip_close(zip);
|
||||
@ -2520,7 +2297,7 @@ bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage
|
||||
bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules_str) {
|
||||
auto backend = ggml_backend_cpu_init();
|
||||
size_t mem_size = 1 * 1024 * 1024; // for padding
|
||||
mem_size += tensor_storages.size() * ggml_tensor_overhead();
|
||||
mem_size += tensor_storage_map.size() * ggml_tensor_overhead();
|
||||
mem_size += get_params_mem_size(backend, type);
|
||||
LOG_INFO("model tensors mem size: %.2fMB", mem_size / 1024.f / 1024.f);
|
||||
ggml_context* ggml_ctx = ggml_init({mem_size, nullptr, false});
|
||||
@ -2587,14 +2364,10 @@ int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type)
|
||||
}
|
||||
int64_t mem_size = 0;
|
||||
std::vector<TensorStorage> processed_tensor_storages;
|
||||
for (auto& tensor_storage : tensor_storages) {
|
||||
for (auto [name, tensor_storage] : tensor_storage_map) {
|
||||
if (is_unused_tensor(tensor_storage.name)) {
|
||||
continue;
|
||||
}
|
||||
preprocess_tensor(tensor_storage, processed_tensor_storages);
|
||||
}
|
||||
|
||||
for (auto& tensor_storage : processed_tensor_storages) {
|
||||
if (tensor_should_be_converted(tensor_storage, type)) {
|
||||
tensor_storage.type = type;
|
||||
}
|
||||
|
||||
24
model.h
@ -65,6 +65,15 @@ static inline bool sd_version_is_sdxl(SDVersion version) {
|
||||
return false;
|
||||
}
|
||||
|
||||
static inline bool sd_version_is_unet(SDVersion version) {
|
||||
if (sd_version_is_sd1(version) ||
|
||||
sd_version_is_sd2(version) ||
|
||||
sd_version_is_sdxl(version)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static inline bool sd_version_is_sd3(SDVersion version) {
|
||||
if (version == VERSION_SD3) {
|
||||
return true;
|
||||
@ -134,6 +143,7 @@ enum PMVersion {
|
||||
struct TensorStorage {
|
||||
std::string name;
|
||||
ggml_type type = GGML_TYPE_F32;
|
||||
ggml_type expected_type = GGML_TYPE_COUNT;
|
||||
bool is_bf16 = false;
|
||||
bool is_f8_e4m3 = false;
|
||||
bool is_f8_e5m2 = false;
|
||||
@ -242,12 +252,14 @@ struct TensorStorage {
|
||||
|
||||
typedef std::function<bool(const TensorStorage&, ggml_tensor**)> on_new_tensor_cb_t;
|
||||
|
||||
typedef std::map<std::string, enum ggml_type> String2GGMLType;
|
||||
typedef std::map<std::string, TensorStorage> String2TensorStorage;
|
||||
|
||||
class ModelLoader {
|
||||
protected:
|
||||
std::vector<std::string> file_paths_;
|
||||
std::vector<TensorStorage> tensor_storages;
|
||||
String2TensorStorage tensor_storage_map;
|
||||
|
||||
void add_tensor_storage(const TensorStorage& tensor_storage);
|
||||
|
||||
bool parse_data_pkl(uint8_t* buffer,
|
||||
size_t buffer_size,
|
||||
@ -262,15 +274,13 @@ protected:
|
||||
bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = "");
|
||||
|
||||
public:
|
||||
String2GGMLType tensor_storages_types;
|
||||
|
||||
bool init_from_file(const std::string& file_path, const std::string& prefix = "");
|
||||
bool model_is_unet();
|
||||
SDVersion get_sd_version();
|
||||
std::map<ggml_type, uint32_t> get_wtype_stat();
|
||||
std::map<ggml_type, uint32_t> get_conditioner_wtype_stat();
|
||||
std::map<ggml_type, uint32_t> get_diffusion_model_wtype_stat();
|
||||
std::map<ggml_type, uint32_t> get_vae_wtype_stat();
|
||||
String2TensorStorage& get_tensor_storage_map() { return tensor_storage_map; }
|
||||
void set_wtype_override(ggml_type wtype, std::string prefix = "");
|
||||
bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0);
|
||||
bool load_tensors(std::map<std::string, struct ggml_tensor*>& tensors,
|
||||
@ -279,8 +289,8 @@ public:
|
||||
|
||||
std::vector<std::string> get_tensor_names() const {
|
||||
std::vector<std::string> names;
|
||||
for (const auto& ts : tensor_storages) {
|
||||
names.push_back(ts.name);
|
||||
for (const auto& [name, tensor_storage] : tensor_storage_map) {
|
||||
names.push_back(name);
|
||||
}
|
||||
return names;
|
||||
}
|
||||
|
||||
6
pmid.hpp
@ -412,7 +412,7 @@ public:
|
||||
public:
|
||||
PhotoMakerIDEncoder(ggml_backend_t backend,
|
||||
bool offload_params_to_cpu,
|
||||
const String2GGMLType& tensor_types,
|
||||
const String2TensorStorage& tensor_storage_map,
|
||||
const std::string prefix,
|
||||
SDVersion version = VERSION_SDXL,
|
||||
PMVersion pm_v = PM_VERSION_1,
|
||||
@ -422,9 +422,9 @@ public:
|
||||
pm_version(pm_v),
|
||||
style_strength(sty) {
|
||||
if (pm_version == PM_VERSION_1) {
|
||||
id_encoder.init(params_ctx, tensor_types, prefix);
|
||||
id_encoder.init(params_ctx, tensor_storage_map, prefix);
|
||||
} else if (pm_version == PM_VERSION_2) {
|
||||
id_encoder2.init(params_ctx, tensor_types, prefix);
|
||||
id_encoder2.init(params_ctx, tensor_storage_map, prefix);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -502,12 +502,12 @@ namespace Qwen {
|
||||
|
||||
QwenImageRunner(ggml_backend_t backend,
|
||||
bool offload_params_to_cpu,
|
||||
const String2GGMLType& tensor_types = {},
|
||||
const std::string prefix = "",
|
||||
SDVersion version = VERSION_QWEN_IMAGE)
|
||||
const String2TensorStorage& tensor_storage_map = {},
|
||||
const std::string prefix = "",
|
||||
SDVersion version = VERSION_QWEN_IMAGE)
|
||||
: GGMLRunner(backend, offload_params_to_cpu) {
|
||||
qwen_image_params.num_layers = 0;
|
||||
for (auto pair : tensor_types) {
|
||||
for (auto pair : tensor_storage_map) {
|
||||
std::string tensor_name = pair.first;
|
||||
if (tensor_name.find(prefix) == std::string::npos)
|
||||
continue;
|
||||
@ -526,7 +526,7 @@ namespace Qwen {
|
||||
}
|
||||
LOG_INFO("qwen_image_params.num_layers: %ld", qwen_image_params.num_layers);
|
||||
qwen_image = QwenImageModel(qwen_image_params);
|
||||
qwen_image.init(params_ctx, tensor_types, prefix);
|
||||
qwen_image.init(params_ctx, tensor_storage_map, prefix);
|
||||
}
|
||||
|
||||
std::string get_desc() override {
|
||||
@ -649,17 +649,16 @@ namespace Qwen {
|
||||
return;
|
||||
}
|
||||
|
||||
auto tensor_types = model_loader.tensor_storages_types;
|
||||
for (auto& item : tensor_types) {
|
||||
// LOG_DEBUG("%s %u", item.first.c_str(), item.second);
|
||||
if (ends_with(item.first, "weight")) {
|
||||
item.second = model_data_type;
|
||||
auto& tensor_storage_map = model_loader.get_tensor_storage_map();
|
||||
for (auto& [name, tensor_storage] : tensor_storage_map) {
|
||||
if (ends_with(name, "weight")) {
|
||||
tensor_storage.expected_type = model_data_type;
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<QwenImageRunner> qwen_image = std::make_shared<QwenImageRunner>(backend,
|
||||
false,
|
||||
tensor_types,
|
||||
tensor_storage_map,
|
||||
"model.diffusion_model",
|
||||
VERSION_QWEN_IMAGE);
|
||||
|
||||
|
||||
25
qwenvl.hpp
@ -910,13 +910,13 @@ namespace Qwen {
|
||||
|
||||
Qwen2_5_VLRunner(ggml_backend_t backend,
|
||||
bool offload_params_to_cpu,
|
||||
const String2GGMLType& tensor_types,
|
||||
const String2TensorStorage& tensor_storage_map,
|
||||
const std::string prefix,
|
||||
bool enable_vision_ = false)
|
||||
: GGMLRunner(backend, offload_params_to_cpu), enable_vision(enable_vision_) {
|
||||
bool have_vision_weight = false;
|
||||
bool llama_cpp_style = false;
|
||||
for (auto pair : tensor_types) {
|
||||
for (auto pair : tensor_storage_map) {
|
||||
std::string tensor_name = pair.first;
|
||||
if (tensor_name.find(prefix) == std::string::npos)
|
||||
continue;
|
||||
@ -940,7 +940,7 @@ namespace Qwen {
|
||||
}
|
||||
}
|
||||
model = Qwen2_5_VL(params, enable_vision, llama_cpp_style);
|
||||
model.init(params_ctx, tensor_types, prefix);
|
||||
model.init(params_ctx, tensor_storage_map, prefix);
|
||||
}
|
||||
|
||||
std::string get_desc() override {
|
||||
@ -1188,10 +1188,10 @@ namespace Qwen {
|
||||
|
||||
Qwen2_5_VLEmbedder(ggml_backend_t backend,
|
||||
bool offload_params_to_cpu,
|
||||
const String2GGMLType& tensor_types = {},
|
||||
const std::string prefix = "",
|
||||
bool enable_vision = false)
|
||||
: model(backend, offload_params_to_cpu, tensor_types, prefix, enable_vision) {
|
||||
const String2TensorStorage& tensor_storage_map = {},
|
||||
const std::string prefix = "",
|
||||
bool enable_vision = false)
|
||||
: model(backend, offload_params_to_cpu, tensor_storage_map, prefix, enable_vision) {
|
||||
}
|
||||
|
||||
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
|
||||
@ -1347,17 +1347,16 @@ namespace Qwen {
|
||||
return;
|
||||
}
|
||||
|
||||
auto tensor_types = model_loader.tensor_storages_types;
|
||||
for (auto& item : tensor_types) {
|
||||
// LOG_DEBUG("%s %u", item.first.c_str(), item.second);
|
||||
if (ends_with(item.first, "weight")) {
|
||||
item.second = model_data_type;
|
||||
auto& tensor_storage_map = model_loader.get_tensor_storage_map();
|
||||
for (auto& [name, tensor_storage] : tensor_storage_map) {
|
||||
if (ends_with(name, "weight")) {
|
||||
tensor_storage.expected_type = model_data_type;
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<Qwen2_5_VLEmbedder> qwenvl = std::make_shared<Qwen2_5_VLEmbedder>(backend,
|
||||
false,
|
||||
tensor_types,
|
||||
tensor_storage_map,
|
||||
"qwen2vl",
|
||||
true);
|
||||
|
||||
|
||||
@ -213,7 +213,7 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
bool is_unet = model_loader.model_is_unet();
|
||||
bool is_unet = sd_version_is_unet(model_loader.get_sd_version());
|
||||
|
||||
if (strlen(SAFE_STR(sd_ctx_params->clip_l_path)) > 0) {
|
||||
LOG_INFO("loading clip_l from '%s'", sd_ctx_params->clip_l_path);
|
||||
@ -273,12 +273,12 @@ public:
|
||||
return false;
|
||||
}
|
||||
|
||||
auto& tensor_types = model_loader.tensor_storages_types;
|
||||
for (auto& item : tensor_types) {
|
||||
// LOG_DEBUG("%s %u", item.first.c_str(), item.second);
|
||||
if (contains(item.first, "qwen2vl") && ends_with(item.first, "weight") && (item.second == GGML_TYPE_F32 || item.second == GGML_TYPE_BF16)) {
|
||||
item.second = GGML_TYPE_F16;
|
||||
// LOG_DEBUG(" change %s %u", item.first.c_str(), item.second);
|
||||
auto& tensor_storage_map = model_loader.get_tensor_storage_map();
|
||||
for (auto& [name, tensor_storage] : tensor_storage_map) {
|
||||
if (contains(name, "qwen2vl") &&
|
||||
ends_with(name, "weight") &&
|
||||
(tensor_storage.type == GGML_TYPE_F32 || tensor_storage.type == GGML_TYPE_BF16)) {
|
||||
tensor_storage.expected_type = GGML_TYPE_F16;
|
||||
}
|
||||
}
|
||||
|
||||
@ -344,13 +344,13 @@ public:
|
||||
if (sd_version_is_sd3(version)) {
|
||||
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types);
|
||||
tensor_storage_map);
|
||||
diffusion_model = std::make_shared<MMDiTModel>(backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types);
|
||||
tensor_storage_map);
|
||||
} else if (sd_version_is_flux(version)) {
|
||||
bool is_chroma = false;
|
||||
for (auto pair : model_loader.tensor_storages_types) {
|
||||
for (auto pair : tensor_storage_map) {
|
||||
if (pair.first.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
|
||||
is_chroma = true;
|
||||
break;
|
||||
@ -368,42 +368,42 @@ public:
|
||||
|
||||
cond_stage_model = std::make_shared<T5CLIPEmbedder>(clip_backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
sd_ctx_params->chroma_use_t5_mask,
|
||||
sd_ctx_params->chroma_t5_mask_pad);
|
||||
} else {
|
||||
cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types);
|
||||
tensor_storage_map);
|
||||
}
|
||||
diffusion_model = std::make_shared<FluxModel>(backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
version,
|
||||
sd_ctx_params->chroma_use_dit_mask);
|
||||
} else if (sd_version_is_wan(version)) {
|
||||
cond_stage_model = std::make_shared<T5CLIPEmbedder>(clip_backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
true,
|
||||
1,
|
||||
true);
|
||||
diffusion_model = std::make_shared<WanModel>(backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
"model.diffusion_model",
|
||||
version);
|
||||
if (strlen(SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path)) > 0) {
|
||||
high_noise_diffusion_model = std::make_shared<WanModel>(backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
"model.high_noise_diffusion_model",
|
||||
version);
|
||||
}
|
||||
if (diffusion_model->get_desc() == "Wan2.1-I2V-14B" || diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") {
|
||||
clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types);
|
||||
tensor_storage_map);
|
||||
clip_vision->alloc_params_buffer();
|
||||
clip_vision->get_param_tensors(tensors);
|
||||
}
|
||||
@ -414,32 +414,32 @@ public:
|
||||
}
|
||||
cond_stage_model = std::make_shared<Qwen2_5_VLCLIPEmbedder>(clip_backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
"",
|
||||
enable_vision);
|
||||
diffusion_model = std::make_shared<QwenImageModel>(backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
"model.diffusion_model",
|
||||
version);
|
||||
} else { // SD1.x SD2.x SDXL
|
||||
if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) {
|
||||
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
SAFE_STR(sd_ctx_params->embedding_dir),
|
||||
version,
|
||||
PM_VERSION_2);
|
||||
} else {
|
||||
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
SAFE_STR(sd_ctx_params->embedding_dir),
|
||||
version);
|
||||
}
|
||||
diffusion_model = std::make_shared<UNetModel>(backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
version);
|
||||
if (sd_ctx_params->diffusion_conv_direct) {
|
||||
LOG_INFO("Using Conv2d direct in the diffusion model");
|
||||
@ -477,7 +477,7 @@ public:
|
||||
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
|
||||
first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
"first_stage_model",
|
||||
vae_decode_only,
|
||||
version);
|
||||
@ -489,7 +489,7 @@ public:
|
||||
} else if (!use_tiny_autoencoder) {
|
||||
first_stage_model = std::make_shared<AutoEncoderKL>(vae_backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
"first_stage_model",
|
||||
vae_decode_only,
|
||||
false,
|
||||
@ -512,7 +512,7 @@ public:
|
||||
} else {
|
||||
tae_first_stage = std::make_shared<TinyAutoEncoder>(vae_backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
"decoder.layers",
|
||||
vae_decode_only,
|
||||
version);
|
||||
@ -533,7 +533,7 @@ public:
|
||||
}
|
||||
control_net = std::make_shared<ControlNet>(controlnet_backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
version);
|
||||
if (sd_ctx_params->diffusion_conv_direct) {
|
||||
LOG_INFO("Using Conv2d direct in the control net");
|
||||
@ -544,7 +544,7 @@ public:
|
||||
if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) {
|
||||
pmid_model = std::make_shared<PhotoMakerIDEncoder>(backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
"pmid",
|
||||
version,
|
||||
PM_VERSION_2);
|
||||
@ -552,7 +552,7 @@ public:
|
||||
} else {
|
||||
pmid_model = std::make_shared<PhotoMakerIDEncoder>(backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
"pmid",
|
||||
version);
|
||||
}
|
||||
@ -733,12 +733,12 @@ public:
|
||||
is_using_v_parameterization = true;
|
||||
}
|
||||
} else if (sd_version_is_sdxl(version)) {
|
||||
if (model_loader.tensor_storages_types.find("edm_vpred.sigma_max") != model_loader.tensor_storages_types.end()) {
|
||||
if (tensor_storage_map.find("edm_vpred.sigma_max") != tensor_storage_map.end()) {
|
||||
// CosXL models
|
||||
// TODO: get sigma_min and sigma_max values from file
|
||||
is_using_edm_v_parameterization = true;
|
||||
}
|
||||
if (model_loader.tensor_storages_types.find("v_pred") != model_loader.tensor_storages_types.end()) {
|
||||
if (tensor_storage_map.find("v_pred") != tensor_storage_map.end()) {
|
||||
is_using_v_parameterization = true;
|
||||
}
|
||||
} else if (version == VERSION_SVD) {
|
||||
@ -758,10 +758,9 @@ public:
|
||||
float shift = sd_ctx_params->flow_shift;
|
||||
if (shift == INFINITY) {
|
||||
shift = 1.0f; // TODO: validate
|
||||
for (auto pair : model_loader.tensor_storages_types) {
|
||||
if (pair.first.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) {
|
||||
for (const auto& [name, tensor_storage] : tensor_storage_map) {
|
||||
if (starts_with(name, "model.diffusion_model.guidance_in.in_layer.weight")) {
|
||||
shift = 1.15f;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
25
t5.hpp
@ -461,7 +461,7 @@ protected:
|
||||
int64_t hidden_size;
|
||||
float eps;
|
||||
|
||||
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
|
||||
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
|
||||
enum ggml_type wtype = GGML_TYPE_F32;
|
||||
params["weight"] = ggml_new_tensor_1d(ctx, wtype, hidden_size);
|
||||
}
|
||||
@ -759,7 +759,7 @@ struct T5Runner : public GGMLRunner {
|
||||
|
||||
T5Runner(ggml_backend_t backend,
|
||||
bool offload_params_to_cpu,
|
||||
const String2GGMLType& tensor_types,
|
||||
const String2TensorStorage& tensor_storage_map,
|
||||
const std::string prefix,
|
||||
bool is_umt5 = false)
|
||||
: GGMLRunner(backend, offload_params_to_cpu) {
|
||||
@ -768,7 +768,7 @@ struct T5Runner : public GGMLRunner {
|
||||
params.relative_attention = false;
|
||||
}
|
||||
model = T5(params);
|
||||
model.init(params_ctx, tensor_types, prefix);
|
||||
model.init(params_ctx, tensor_storage_map, prefix);
|
||||
}
|
||||
|
||||
std::string get_desc() override {
|
||||
@ -905,10 +905,10 @@ struct T5Embedder {
|
||||
|
||||
T5Embedder(ggml_backend_t backend,
|
||||
bool offload_params_to_cpu,
|
||||
const String2GGMLType& tensor_types = {},
|
||||
const std::string prefix = "",
|
||||
bool is_umt5 = false)
|
||||
: model(backend, offload_params_to_cpu, tensor_types, prefix, is_umt5), tokenizer(is_umt5) {
|
||||
const String2TensorStorage& tensor_storage_map = {},
|
||||
const std::string prefix = "",
|
||||
bool is_umt5 = false)
|
||||
: model(backend, offload_params_to_cpu, tensor_storage_map, prefix, is_umt5), tokenizer(is_umt5) {
|
||||
}
|
||||
|
||||
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
|
||||
@ -1009,15 +1009,14 @@ struct T5Embedder {
|
||||
return;
|
||||
}
|
||||
|
||||
auto tensor_types = model_loader.tensor_storages_types;
|
||||
for (auto& item : tensor_types) {
|
||||
// LOG_DEBUG("%s %u", item.first.c_str(), item.second);
|
||||
if (ends_with(item.first, "weight")) {
|
||||
item.second = model_data_type;
|
||||
auto& tensor_storage_map = model_loader.get_tensor_storage_map();
|
||||
for (auto& [name, tensor_storage] : tensor_storage_map) {
|
||||
if (ends_with(name, "weight")) {
|
||||
tensor_storage.expected_type = model_data_type;
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<T5Embedder> t5 = std::make_shared<T5Embedder>(backend, false, tensor_types, "", true);
|
||||
std::shared_ptr<T5Embedder> t5 = std::make_shared<T5Embedder>(backend, false, tensor_storage_map, "", true);
|
||||
|
||||
t5->alloc_params_buffer();
|
||||
std::map<std::string, ggml_tensor*> tensors;
|
||||
|
||||
4
tae.hpp
@ -197,14 +197,14 @@ struct TinyAutoEncoder : public GGMLRunner {
|
||||
|
||||
TinyAutoEncoder(ggml_backend_t backend,
|
||||
bool offload_params_to_cpu,
|
||||
const String2GGMLType& tensor_types,
|
||||
const String2TensorStorage& tensor_storage_map,
|
||||
const std::string prefix,
|
||||
bool decoder_only = true,
|
||||
SDVersion version = VERSION_SD1)
|
||||
: decode_only(decoder_only),
|
||||
taesd(decoder_only, version),
|
||||
GGMLRunner(backend, offload_params_to_cpu) {
|
||||
taesd.init(params_ctx, tensor_types, prefix);
|
||||
taesd.init(params_ctx, tensor_storage_map, prefix);
|
||||
}
|
||||
|
||||
std::string get_desc() override {
|
||||
|
||||
37
unet.hpp
@ -20,9 +20,10 @@ public:
|
||||
int64_t d_head,
|
||||
int64_t depth,
|
||||
int64_t context_dim,
|
||||
bool use_linear,
|
||||
int64_t time_depth = 1,
|
||||
int64_t max_time_embed_period = 10000)
|
||||
: SpatialTransformer(in_channels, n_head, d_head, depth, context_dim),
|
||||
: SpatialTransformer(in_channels, n_head, d_head, depth, context_dim, use_linear),
|
||||
max_time_embed_period(max_time_embed_period) {
|
||||
// We will convert unet transformer linear to conv2d 1x1 when loading the weights, so use_linear is always False
|
||||
// use_spatial_context is always True
|
||||
@ -178,17 +179,19 @@ protected:
|
||||
int num_heads = 8;
|
||||
int num_head_channels = -1; // channels // num_heads
|
||||
int context_dim = 768; // 1024 for VERSION_SD2, 2048 for VERSION_SDXL
|
||||
bool use_linear_projection = false;
|
||||
|
||||
public:
|
||||
int model_channels = 320;
|
||||
int adm_in_channels = 2816; // only for VERSION_SDXL/SVD
|
||||
|
||||
UnetModelBlock(SDVersion version = VERSION_SD1, const String2GGMLType& tensor_types = {})
|
||||
UnetModelBlock(SDVersion version = VERSION_SD1, const String2TensorStorage& tensor_storage_map = {})
|
||||
: version(version) {
|
||||
if (sd_version_is_sd2(version)) {
|
||||
context_dim = 1024;
|
||||
num_head_channels = 64;
|
||||
num_heads = -1;
|
||||
context_dim = 1024;
|
||||
num_head_channels = 64;
|
||||
num_heads = -1;
|
||||
use_linear_projection = true;
|
||||
} else if (sd_version_is_sdxl(version)) {
|
||||
context_dim = 2048;
|
||||
attention_resolutions = {4, 2};
|
||||
@ -196,13 +199,15 @@ public:
|
||||
transformer_depth = {1, 2, 10};
|
||||
num_head_channels = 64;
|
||||
num_heads = -1;
|
||||
use_linear_projection = true;
|
||||
} else if (version == VERSION_SVD) {
|
||||
in_channels = 8;
|
||||
out_channels = 4;
|
||||
context_dim = 1024;
|
||||
adm_in_channels = 768;
|
||||
num_head_channels = 64;
|
||||
num_heads = -1;
|
||||
in_channels = 8;
|
||||
out_channels = 4;
|
||||
context_dim = 1024;
|
||||
adm_in_channels = 768;
|
||||
num_head_channels = 64;
|
||||
num_heads = -1;
|
||||
use_linear_projection = true;
|
||||
} else if (version == VERSION_SD1_TINY_UNET) {
|
||||
num_res_blocks = 1;
|
||||
channel_mult = {1, 2, 4};
|
||||
@ -249,9 +254,9 @@ public:
|
||||
int64_t depth,
|
||||
int64_t context_dim) -> SpatialTransformer* {
|
||||
if (version == VERSION_SVD) {
|
||||
return new SpatialVideoTransformer(in_channels, n_head, d_head, depth, context_dim);
|
||||
return new SpatialVideoTransformer(in_channels, n_head, d_head, depth, context_dim, use_linear_projection);
|
||||
} else {
|
||||
return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim);
|
||||
return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim, use_linear_projection);
|
||||
}
|
||||
};
|
||||
|
||||
@ -581,11 +586,11 @@ struct UNetModelRunner : public GGMLRunner {
|
||||
|
||||
UNetModelRunner(ggml_backend_t backend,
|
||||
bool offload_params_to_cpu,
|
||||
const String2GGMLType& tensor_types,
|
||||
const String2TensorStorage& tensor_storage_map,
|
||||
const std::string prefix,
|
||||
SDVersion version = VERSION_SD1)
|
||||
: GGMLRunner(backend, offload_params_to_cpu), unet(version, tensor_types) {
|
||||
unet.init(params_ctx, tensor_types, prefix);
|
||||
: GGMLRunner(backend, offload_params_to_cpu), unet(version, tensor_storage_map) {
|
||||
unet.init(params_ctx, tensor_storage_map, prefix);
|
||||
}
|
||||
|
||||
std::string get_desc() override {
|
||||
|
||||
@ -51,7 +51,7 @@ struct UpscalerGGML {
|
||||
backend = ggml_backend_cpu_init();
|
||||
}
|
||||
LOG_INFO("Upscaler weight type: %s", ggml_type_name(model_data_type));
|
||||
esrgan_upscaler = std::make_shared<ESRGAN>(backend, offload_params_to_cpu, model_loader.tensor_storages_types);
|
||||
esrgan_upscaler = std::make_shared<ESRGAN>(backend, offload_params_to_cpu, model_loader.get_tensor_storage_map());
|
||||
if (direct) {
|
||||
esrgan_upscaler->set_conv2d_direct_enabled(true);
|
||||
}
|
||||
|
||||
118
vae.hpp
@ -64,25 +64,32 @@ public:
|
||||
class AttnBlock : public UnaryBlock {
|
||||
protected:
|
||||
int64_t in_channels;
|
||||
bool use_linear;
|
||||
|
||||
public:
|
||||
AttnBlock(int64_t in_channels)
|
||||
: in_channels(in_channels) {
|
||||
AttnBlock(int64_t in_channels, bool use_linear)
|
||||
: in_channels(in_channels), use_linear(use_linear) {
|
||||
blocks["norm"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels));
|
||||
blocks["q"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
|
||||
blocks["k"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
|
||||
blocks["v"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
|
||||
|
||||
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
|
||||
if (use_linear) {
|
||||
blocks["q"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
|
||||
blocks["k"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
|
||||
blocks["v"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
|
||||
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
|
||||
} else {
|
||||
blocks["q"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
|
||||
blocks["k"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
|
||||
blocks["v"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
|
||||
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
||||
// x: [N, in_channels, h, w]
|
||||
auto norm = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm"]);
|
||||
auto q_proj = std::dynamic_pointer_cast<Conv2d>(blocks["q"]);
|
||||
auto k_proj = std::dynamic_pointer_cast<Conv2d>(blocks["k"]);
|
||||
auto v_proj = std::dynamic_pointer_cast<Conv2d>(blocks["v"]);
|
||||
auto proj_out = std::dynamic_pointer_cast<Conv2d>(blocks["proj_out"]);
|
||||
auto q_proj = std::dynamic_pointer_cast<UnaryBlock>(blocks["q"]);
|
||||
auto k_proj = std::dynamic_pointer_cast<UnaryBlock>(blocks["k"]);
|
||||
auto v_proj = std::dynamic_pointer_cast<UnaryBlock>(blocks["v"]);
|
||||
auto proj_out = std::dynamic_pointer_cast<UnaryBlock>(blocks["proj_out"]);
|
||||
|
||||
auto h_ = norm->forward(ctx, x);
|
||||
|
||||
@ -91,23 +98,44 @@ public:
|
||||
const int64_t h = h_->ne[1];
|
||||
const int64_t w = h_->ne[0];
|
||||
|
||||
auto q = q_proj->forward(ctx, h_); // [N, in_channels, h, w]
|
||||
q = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, q, 1, 2, 0, 3)); // [N, h, w, in_channels]
|
||||
q = ggml_reshape_3d(ctx->ggml_ctx, q, c, h * w, n); // [N, h * w, in_channels]
|
||||
ggml_tensor* q;
|
||||
ggml_tensor* k;
|
||||
ggml_tensor* v;
|
||||
if (use_linear) {
|
||||
h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 2, 0, 3)); // [N, h, w, in_channels]
|
||||
h_ = ggml_reshape_3d(ctx->ggml_ctx, h_, c, h * w, n); // [N, h * w, in_channels]
|
||||
|
||||
auto k = k_proj->forward(ctx, h_); // [N, in_channels, h, w]
|
||||
k = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, k, 1, 2, 0, 3)); // [N, h, w, in_channels]
|
||||
k = ggml_reshape_3d(ctx->ggml_ctx, k, c, h * w, n); // [N, h * w, in_channels]
|
||||
q = q_proj->forward(ctx, h_); // [N, h * w, in_channels]
|
||||
k = k_proj->forward(ctx, h_); // [N, h * w, in_channels]
|
||||
v = v_proj->forward(ctx, h_); // [N, h * w, in_channels]
|
||||
|
||||
auto v = v_proj->forward(ctx, h_); // [N, in_channels, h, w]
|
||||
v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [N, in_channels, h * w]
|
||||
v = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, v, 1, 0, 2, 3)); // [N, in_channels, h * w]
|
||||
} else {
|
||||
q = q_proj->forward(ctx, h_); // [N, in_channels, h, w]
|
||||
q = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, q, 1, 2, 0, 3)); // [N, h, w, in_channels]
|
||||
q = ggml_reshape_3d(ctx->ggml_ctx, q, c, h * w, n); // [N, h * w, in_channels]
|
||||
|
||||
k = k_proj->forward(ctx, h_); // [N, in_channels, h, w]
|
||||
k = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, k, 1, 2, 0, 3)); // [N, h, w, in_channels]
|
||||
k = ggml_reshape_3d(ctx->ggml_ctx, k, c, h * w, n); // [N, h * w, in_channels]
|
||||
|
||||
v = v_proj->forward(ctx, h_); // [N, in_channels, h, w]
|
||||
v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [N, in_channels, h * w]
|
||||
}
|
||||
|
||||
h_ = ggml_ext_attention(ctx->ggml_ctx, q, k, v, false); // [N, h * w, in_channels]
|
||||
|
||||
h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w]
|
||||
h_ = ggml_reshape_4d(ctx->ggml_ctx, h_, w, h, c, n); // [N, in_channels, h, w]
|
||||
if (use_linear) {
|
||||
h_ = proj_out->forward(ctx, h_); // [N, h * w, in_channels]
|
||||
|
||||
h_ = proj_out->forward(ctx, h_); // [N, in_channels, h, w]
|
||||
h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w]
|
||||
h_ = ggml_reshape_4d(ctx->ggml_ctx, h_, w, h, c, n); // [N, in_channels, h, w]
|
||||
} else {
|
||||
h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w]
|
||||
h_ = ggml_reshape_4d(ctx->ggml_ctx, h_, w, h, c, n); // [N, in_channels, h, w]
|
||||
|
||||
h_ = proj_out->forward(ctx, h_); // [N, in_channels, h, w]
|
||||
}
|
||||
|
||||
h_ = ggml_add(ctx->ggml_ctx, h_, x);
|
||||
return h_;
|
||||
@ -163,8 +191,8 @@ public:
|
||||
|
||||
class VideoResnetBlock : public ResnetBlock {
|
||||
protected:
|
||||
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
|
||||
enum ggml_type wtype = get_type(prefix + "mix_factor", tensor_types, GGML_TYPE_F32);
|
||||
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
|
||||
enum ggml_type wtype = get_type(prefix + "mix_factor", tensor_storage_map, GGML_TYPE_F32);
|
||||
params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1);
|
||||
}
|
||||
|
||||
@ -233,7 +261,8 @@ public:
|
||||
int num_res_blocks,
|
||||
int in_channels,
|
||||
int z_channels,
|
||||
bool double_z = true)
|
||||
bool double_z = true,
|
||||
bool use_linear_projection = false)
|
||||
: ch(ch),
|
||||
ch_mult(ch_mult),
|
||||
num_res_blocks(num_res_blocks),
|
||||
@ -264,7 +293,7 @@ public:
|
||||
}
|
||||
|
||||
blocks["mid.block_1"] = std::shared_ptr<GGMLBlock>(new ResnetBlock(block_in, block_in));
|
||||
blocks["mid.attn_1"] = std::shared_ptr<GGMLBlock>(new AttnBlock(block_in));
|
||||
blocks["mid.attn_1"] = std::shared_ptr<GGMLBlock>(new AttnBlock(block_in, use_linear_projection));
|
||||
blocks["mid.block_2"] = std::shared_ptr<GGMLBlock>(new ResnetBlock(block_in, block_in));
|
||||
|
||||
blocks["norm_out"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(block_in));
|
||||
@ -351,8 +380,9 @@ public:
|
||||
std::vector<int> ch_mult,
|
||||
int num_res_blocks,
|
||||
int z_channels,
|
||||
bool video_decoder = false,
|
||||
int video_kernel_size = 3)
|
||||
bool use_linear_projection = false,
|
||||
bool video_decoder = false,
|
||||
int video_kernel_size = 3)
|
||||
: ch(ch),
|
||||
out_ch(out_ch),
|
||||
ch_mult(ch_mult),
|
||||
@ -366,7 +396,7 @@ public:
|
||||
blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, block_in, {3, 3}, {1, 1}, {1, 1}));
|
||||
|
||||
blocks["mid.block_1"] = get_resnet_block(block_in, block_in);
|
||||
blocks["mid.attn_1"] = std::shared_ptr<GGMLBlock>(new AttnBlock(block_in));
|
||||
blocks["mid.attn_1"] = std::shared_ptr<GGMLBlock>(new AttnBlock(block_in, use_linear_projection));
|
||||
blocks["mid.block_2"] = get_resnet_block(block_in, block_in);
|
||||
|
||||
for (int i = num_resolutions - 1; i >= 0; i--) {
|
||||
@ -454,9 +484,10 @@ protected:
|
||||
} dd_config;
|
||||
|
||||
public:
|
||||
AutoencodingEngine(bool decode_only = true,
|
||||
bool use_video_decoder = false,
|
||||
SDVersion version = VERSION_SD1)
|
||||
AutoencodingEngine(SDVersion version = VERSION_SD1,
|
||||
bool decode_only = true,
|
||||
bool use_linear_projection = false,
|
||||
bool use_video_decoder = false)
|
||||
: decode_only(decode_only), use_video_decoder(use_video_decoder) {
|
||||
if (sd_version_is_dit(version)) {
|
||||
dd_config.z_channels = 16;
|
||||
@ -470,6 +501,7 @@ public:
|
||||
dd_config.ch_mult,
|
||||
dd_config.num_res_blocks,
|
||||
dd_config.z_channels,
|
||||
use_linear_projection,
|
||||
use_video_decoder));
|
||||
if (use_quant) {
|
||||
blocks["post_quant_conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(dd_config.z_channels,
|
||||
@ -482,7 +514,8 @@ public:
|
||||
dd_config.num_res_blocks,
|
||||
dd_config.in_channels,
|
||||
dd_config.z_channels,
|
||||
dd_config.double_z));
|
||||
dd_config.double_z,
|
||||
use_linear_projection));
|
||||
if (use_quant) {
|
||||
int factor = dd_config.double_z ? 2 : 1;
|
||||
|
||||
@ -562,13 +595,26 @@ struct AutoEncoderKL : public VAE {
|
||||
|
||||
AutoEncoderKL(ggml_backend_t backend,
|
||||
bool offload_params_to_cpu,
|
||||
const String2GGMLType& tensor_types,
|
||||
const String2TensorStorage& tensor_storage_map,
|
||||
const std::string prefix,
|
||||
bool decode_only = false,
|
||||
bool use_video_decoder = false,
|
||||
SDVersion version = VERSION_SD1)
|
||||
: decode_only(decode_only), ae(decode_only, use_video_decoder, version), VAE(backend, offload_params_to_cpu) {
|
||||
ae.init(params_ctx, tensor_types, prefix);
|
||||
: decode_only(decode_only), VAE(backend, offload_params_to_cpu) {
|
||||
bool use_linear_projection = false;
|
||||
for (const auto& [name, tensor_storage] : tensor_storage_map) {
|
||||
if (!starts_with(name, prefix)) {
|
||||
continue;
|
||||
}
|
||||
if (ends_with(name, "attn_1.proj_out.weight")) {
|
||||
if (tensor_storage.n_dims == 2) {
|
||||
use_linear_projection = true;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
ae = AutoencodingEngine(version, decode_only, use_linear_projection, use_video_decoder);
|
||||
ae.init(params_ctx, tensor_storage_map, prefix);
|
||||
}
|
||||
|
||||
void set_conv2d_scale(float scale) override {
|
||||
|
||||
59
wan.hpp
@ -26,7 +26,7 @@ namespace WAN {
|
||||
std::tuple<int, int, int> dilation;
|
||||
bool bias;
|
||||
|
||||
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
|
||||
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
|
||||
params["weight"] = ggml_new_tensor_4d(ctx,
|
||||
GGML_TYPE_F16,
|
||||
std::get<2>(kernel_size),
|
||||
@ -87,9 +87,14 @@ namespace WAN {
|
||||
protected:
|
||||
int64_t dim;
|
||||
|
||||
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
|
||||
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
|
||||
ggml_type wtype = GGML_TYPE_F32;
|
||||
params["gamma"] = ggml_new_tensor_1d(ctx, wtype, dim);
|
||||
auto iter = tensor_storage_map.find(prefix + "gamma");
|
||||
if (iter != tensor_storage_map.end()) {
|
||||
params["gamma"] = ggml_new_tensor(ctx, wtype, iter->second.n_dims, &iter->second.ne[0]);
|
||||
} else {
|
||||
params["gamma"] = ggml_new_tensor_1d(ctx, wtype, dim);
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
@ -101,6 +106,7 @@ namespace WAN {
|
||||
// assert N == 1
|
||||
|
||||
struct ggml_tensor* w = params["gamma"];
|
||||
w = ggml_reshape_1d(ctx->ggml_ctx, w, ggml_nelements(w));
|
||||
auto h = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 3, 0, 1, 2)); // [ID, IH, IW, N*IC]
|
||||
h = ggml_rms_norm(ctx->ggml_ctx, h, 1e-12);
|
||||
h = ggml_mul(ctx->ggml_ctx, h, w);
|
||||
@ -1110,12 +1116,12 @@ namespace WAN {
|
||||
|
||||
WanVAERunner(ggml_backend_t backend,
|
||||
bool offload_params_to_cpu,
|
||||
const String2GGMLType& tensor_types = {},
|
||||
const std::string prefix = "",
|
||||
bool decode_only = false,
|
||||
SDVersion version = VERSION_WAN2)
|
||||
const String2TensorStorage& tensor_storage_map = {},
|
||||
const std::string prefix = "",
|
||||
bool decode_only = false,
|
||||
SDVersion version = VERSION_WAN2)
|
||||
: decode_only(decode_only), ae(decode_only, version == VERSION_WAN2_2_TI2V), VAE(backend, offload_params_to_cpu) {
|
||||
ae.init(params_ctx, tensor_types, prefix);
|
||||
ae.init(params_ctx, tensor_storage_map, prefix);
|
||||
}
|
||||
|
||||
std::string get_desc() override {
|
||||
@ -1256,7 +1262,7 @@ namespace WAN {
|
||||
// ggml_backend_t backend = ggml_backend_cuda_init(0);
|
||||
ggml_backend_t backend = ggml_backend_cpu_init();
|
||||
ggml_type model_data_type = GGML_TYPE_F16;
|
||||
std::shared_ptr<WanVAERunner> vae = std::make_shared<WanVAERunner>(backend, false, String2GGMLType{}, "", false, VERSION_WAN2_2_TI2V);
|
||||
std::shared_ptr<WanVAERunner> vae = std::make_shared<WanVAERunner>(backend, false, String2TensorStorage{}, "", false, VERSION_WAN2_2_TI2V);
|
||||
{
|
||||
LOG_INFO("loading from '%s'", file_path.c_str());
|
||||
|
||||
@ -1494,8 +1500,8 @@ namespace WAN {
|
||||
protected:
|
||||
int dim;
|
||||
|
||||
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
|
||||
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
|
||||
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
|
||||
enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32);
|
||||
params["modulation"] = ggml_new_tensor_3d(ctx, wtype, dim, 6, 1);
|
||||
}
|
||||
|
||||
@ -1582,8 +1588,8 @@ namespace WAN {
|
||||
class VaceWanAttentionBlock : public WanAttentionBlock {
|
||||
protected:
|
||||
int block_id;
|
||||
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
|
||||
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
|
||||
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
|
||||
enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32);
|
||||
params["modulation"] = ggml_new_tensor_3d(ctx, wtype, dim, 6, 1);
|
||||
}
|
||||
|
||||
@ -1634,8 +1640,8 @@ namespace WAN {
|
||||
protected:
|
||||
int dim;
|
||||
|
||||
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
|
||||
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
|
||||
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
|
||||
enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32);
|
||||
params["modulation"] = ggml_new_tensor_3d(ctx, wtype, dim, 2, 1);
|
||||
}
|
||||
|
||||
@ -1681,7 +1687,7 @@ namespace WAN {
|
||||
int in_dim;
|
||||
int flf_pos_embed_token_number;
|
||||
|
||||
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
|
||||
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
|
||||
if (flf_pos_embed_token_number > 0) {
|
||||
params["emb_pos"] = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, in_dim, flf_pos_embed_token_number, 1);
|
||||
}
|
||||
@ -2015,12 +2021,12 @@ namespace WAN {
|
||||
|
||||
WanRunner(ggml_backend_t backend,
|
||||
bool offload_params_to_cpu,
|
||||
const String2GGMLType& tensor_types = {},
|
||||
const std::string prefix = "",
|
||||
SDVersion version = VERSION_WAN2)
|
||||
const String2TensorStorage& tensor_storage_map = {},
|
||||
const std::string prefix = "",
|
||||
SDVersion version = VERSION_WAN2)
|
||||
: GGMLRunner(backend, offload_params_to_cpu) {
|
||||
wan_params.num_layers = 0;
|
||||
for (auto pair : tensor_types) {
|
||||
for (auto pair : tensor_storage_map) {
|
||||
std::string tensor_name = pair.first;
|
||||
if (tensor_name.find(prefix) == std::string::npos)
|
||||
continue;
|
||||
@ -2117,7 +2123,7 @@ namespace WAN {
|
||||
LOG_INFO("%s", desc.c_str());
|
||||
|
||||
wan = Wan(wan_params);
|
||||
wan.init(params_ctx, tensor_types, prefix);
|
||||
wan.init(params_ctx, tensor_storage_map, prefix);
|
||||
}
|
||||
|
||||
std::string get_desc() override {
|
||||
@ -2254,17 +2260,16 @@ namespace WAN {
|
||||
return;
|
||||
}
|
||||
|
||||
auto tensor_types = model_loader.tensor_storages_types;
|
||||
for (auto& item : tensor_types) {
|
||||
// LOG_DEBUG("%s %u", item.first.c_str(), item.second);
|
||||
if (ends_with(item.first, "weight")) {
|
||||
item.second = model_data_type;
|
||||
auto& tensor_storage_map = model_loader.get_tensor_storage_map();
|
||||
for (auto& [name, tensor_storage] : tensor_storage_map) {
|
||||
if (ends_with(name, "weight")) {
|
||||
tensor_storage.expected_type = model_data_type;
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<WanRunner> wan = std::make_shared<WanRunner>(backend,
|
||||
false,
|
||||
tensor_types,
|
||||
tensor_storage_map,
|
||||
"model.diffusion_model",
|
||||
VERSION_WAN2_2_TI2V);
|
||||
|
||||
|
||||