diff --git a/clip.hpp b/clip.hpp
index d359f61..7ca565d 100644
--- a/clip.hpp
+++ b/clip.hpp
@@ -545,9 +545,9 @@ protected:
     int64_t vocab_size;
     int64_t num_positions;
 
-    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
-        enum ggml_type token_wtype    = GGML_TYPE_F32;  //(tensor_types.find(prefix + "token_embedding.weight") != tensor_types.end()) ? tensor_types[prefix + "token_embedding.weight"] : GGML_TYPE_F32;
-        enum ggml_type position_wtype = GGML_TYPE_F32;  //(tensor_types.find(prefix + "position_embedding.weight") != tensor_types.end()) ? tensor_types[prefix + "position_embedding.weight"] : GGML_TYPE_F32;
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
+        enum ggml_type token_wtype    = GGML_TYPE_F32;
+        enum ggml_type position_wtype = GGML_TYPE_F32;
 
         params["token_embedding.weight"]    = ggml_new_tensor_2d(ctx, token_wtype, embed_dim, vocab_size);
         params["position_embedding.weight"] = ggml_new_tensor_2d(ctx, position_wtype, embed_dim, num_positions);
@@ -594,10 +594,10 @@ protected:
     int64_t image_size;
     int64_t num_patches;
     int64_t num_positions;
-    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
-        enum ggml_type patch_wtype    = GGML_TYPE_F16;  // tensor_types.find(prefix + "patch_embedding.weight") != tensor_types.end() ? tensor_types[prefix + "patch_embedding.weight"] : GGML_TYPE_F16;
-        enum ggml_type class_wtype    = GGML_TYPE_F32;  // tensor_types.find(prefix + "class_embedding") != tensor_types.end() ? tensor_types[prefix + "class_embedding"] : GGML_TYPE_F32;
-        enum ggml_type position_wtype = GGML_TYPE_F32;  // tensor_types.find(prefix + "position_embedding.weight") != tensor_types.end() ? tensor_types[prefix + "position_embedding.weight"] : GGML_TYPE_F32;
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
+        enum ggml_type patch_wtype    = GGML_TYPE_F16;
+        enum ggml_type class_wtype    = GGML_TYPE_F32;
+        enum ggml_type position_wtype = GGML_TYPE_F32;
 
         params["patch_embedding.weight"] = ggml_new_tensor_4d(ctx, patch_wtype, patch_size, patch_size, num_channels, embed_dim);
         params["class_embedding"]        = ggml_new_tensor_1d(ctx, class_wtype, embed_dim);
@@ -657,9 +657,9 @@ enum CLIPVersion {
 
 class CLIPTextModel : public GGMLBlock {
 protected:
-    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
         if (version == OPEN_CLIP_VIT_BIGG_14) {
-            enum ggml_type wtype      = GGML_TYPE_F32;  // tensor_types.find(prefix + "text_projection") != tensor_types.end() ? tensor_types[prefix + "text_projection"] : GGML_TYPE_F32;
+            enum ggml_type wtype      = GGML_TYPE_F32;
             params["text_projection"] = ggml_new_tensor_2d(ctx, wtype, projection_dim, hidden_size);
         }
     }
@@ -805,8 +805,8 @@ protected:
     int64_t out_features;
     bool transpose_weight;
 
-    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
-        enum ggml_type wtype = tensor_types.find(prefix + "weight") != tensor_types.end() ? tensor_types[prefix + "weight"] : GGML_TYPE_F32;
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
+        enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
         if (transpose_weight) {
             params["weight"] = ggml_new_tensor_2d(ctx, wtype, out_features, in_features);
         } else {
@@ -868,7 +868,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
     CLIPTextModel model;
 
     CLIPTextModelRunner(ggml_backend_t backend,
-                        std::map<std::string, enum ggml_type>& tensor_types,
+                        const String2GGMLType& tensor_types,
                         const std::string prefix,
                         CLIPVersion version = OPENAI_CLIP_VIT_L_14,
                         bool with_final_ln  = true,
diff --git a/common.hpp b/common.hpp
index 9b5cc53..3a13077 100644
--- a/common.hpp
+++ b/common.hpp
@@ -182,9 +182,9 @@ protected:
     int64_t dim_in;
     int64_t dim_out;
 
-    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, std::string prefix = "") {
-        enum ggml_type wtype      = (tensor_types.find(prefix + "proj.weight") != tensor_types.end()) ? tensor_types[prefix + "proj.weight"] : GGML_TYPE_F32;
-        enum ggml_type bias_wtype = GGML_TYPE_F32;  //(tensor_types.find(prefix + "proj.bias") != tensor_types.end()) ? tensor_types[prefix + "proj.bias"] : GGML_TYPE_F32;
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
+        enum ggml_type wtype      = get_type(prefix + "proj.weight", tensor_types, GGML_TYPE_F32);
+        enum ggml_type bias_wtype = GGML_TYPE_F32;
         params["proj.weight"] = ggml_new_tensor_2d(ctx, wtype, dim_in, dim_out * 2);
         params["proj.bias"]   = ggml_new_tensor_1d(ctx, bias_wtype, dim_out * 2);
     }
@@ -440,9 +440,9 @@ public:
 
 class AlphaBlender : public GGMLBlock {
 protected:
-    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, std::string prefix = "") {
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
         // Get the type of the "mix_factor" tensor from the input tensors map with the specified prefix
-        enum ggml_type wtype = GGML_TYPE_F32;  //(tensor_types.ypes.find(prefix + "mix_factor") != tensor_types.end()) ? tensor_types[prefix + "mix_factor"] : GGML_TYPE_F32;
+        enum ggml_type wtype = GGML_TYPE_F32;
         params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1);
     }
tensor_types[prefix + "mix_factor"] : GGML_TYPE_F32; + enum ggml_type wtype = GGML_TYPE_F32; params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1); } diff --git a/conditioner.hpp b/conditioner.hpp index 3f89d52..6a51dce 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -57,7 +57,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { std::vector readed_embeddings; FrozenCLIPEmbedderWithCustomWords(ggml_backend_t backend, - std::map& tensor_types, + const String2GGMLType& tensor_types, const std::string& embd_dir, SDVersion version = VERSION_SD1, PMVersion pv = PM_VERSION_1, @@ -618,7 +618,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { struct FrozenCLIPVisionEmbedder : public GGMLRunner { CLIPVisionModelProjection vision_model; - FrozenCLIPVisionEmbedder(ggml_backend_t backend, std::map& tensor_types) + FrozenCLIPVisionEmbedder(ggml_backend_t backend, const String2GGMLType& tensor_types = {}) : vision_model(OPEN_CLIP_VIT_H_14, true), GGMLRunner(backend) { vision_model.init(params_ctx, tensor_types, "cond_stage_model.transformer"); } @@ -663,8 +663,8 @@ struct SD3CLIPEmbedder : public Conditioner { std::shared_ptr t5; SD3CLIPEmbedder(ggml_backend_t backend, - std::map& tensor_types, - int clip_skip = -1) + const String2GGMLType& tensor_types = {}, + int clip_skip = -1) : clip_g_tokenizer(0) { clip_l = std::make_shared(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, false); clip_g = std::make_shared(backend, tensor_types, "text_encoders.clip_g.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false); @@ -1010,8 +1010,8 @@ struct FluxCLIPEmbedder : public Conditioner { size_t chunk_len = 256; FluxCLIPEmbedder(ggml_backend_t backend, - std::map& tensor_types, - int clip_skip = -1) { + const String2GGMLType& tensor_types = {}, + int clip_skip = -1) { clip_l = std::make_shared(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, true); t5 = std::make_shared(backend, tensor_types, "text_encoders.t5xxl.transformer"); set_clip_skip(clip_skip); @@ -1231,10 +1231,10 @@ struct PixArtCLIPEmbedder : public Conditioner { int mask_pad = 1; PixArtCLIPEmbedder(ggml_backend_t backend, - std::map& tensor_types, - int clip_skip = -1, - bool use_mask = false, - int mask_pad = 1) + const String2GGMLType& tensor_types = {}, + int clip_skip = -1, + bool use_mask = false, + int mask_pad = 1) : use_mask(use_mask), mask_pad(mask_pad) { t5 = std::make_shared(backend, tensor_types, "text_encoders.t5xxl.transformer"); } diff --git a/control.hpp b/control.hpp index 23b75fe..d8f81fc 100644 --- a/control.hpp +++ b/control.hpp @@ -317,8 +317,8 @@ struct ControlNet : public GGMLRunner { bool guided_hint_cached = false; ControlNet(ggml_backend_t backend, - std::map& tensor_types, - SDVersion version = VERSION_SD1) + const String2GGMLType& tensor_types = {}, + SDVersion version = VERSION_SD1) : GGMLRunner(backend), control_net(version) { control_net.init(params_ctx, tensor_types, ""); } diff --git a/diffusion_model.hpp b/diffusion_model.hpp index 5c34943..787a4fa 100644 --- a/diffusion_model.hpp +++ b/diffusion_model.hpp @@ -32,9 +32,9 @@ struct UNetModel : public DiffusionModel { UNetModelRunner unet; UNetModel(ggml_backend_t backend, - std::map& tensor_types, - SDVersion version = VERSION_SD1, - bool flash_attn = false) + const String2GGMLType& tensor_types = {}, + SDVersion version = VERSION_SD1, + bool flash_attn = false) : unet(backend, tensor_types, "model.diffusion_model", version, flash_attn) { } @@ 
diff --git a/diffusion_model.hpp b/diffusion_model.hpp
index 5c34943..787a4fa 100644
--- a/diffusion_model.hpp
+++ b/diffusion_model.hpp
@@ -32,9 +32,9 @@ struct UNetModel : public DiffusionModel {
     UNetModelRunner unet;
 
     UNetModel(ggml_backend_t backend,
-              std::map<std::string, enum ggml_type>& tensor_types,
-              SDVersion version = VERSION_SD1,
-              bool flash_attn   = false)
+              const String2GGMLType& tensor_types = {},
+              SDVersion version                   = VERSION_SD1,
+              bool flash_attn                     = false)
         : unet(backend, tensor_types, "model.diffusion_model", version, flash_attn) {
     }
 
@@ -85,7 +85,7 @@ struct MMDiTModel : public DiffusionModel {
     MMDiTRunner mmdit;
 
     MMDiTModel(ggml_backend_t backend,
-               std::map<std::string, enum ggml_type>& tensor_types)
+               const String2GGMLType& tensor_types = {})
         : mmdit(backend, tensor_types, "model.diffusion_model") {
     }
 
@@ -135,10 +135,10 @@ struct FluxModel : public DiffusionModel {
     Flux::FluxRunner flux;
 
     FluxModel(ggml_backend_t backend,
-              std::map<std::string, enum ggml_type>& tensor_types,
-              SDVersion version = VERSION_FLUX,
-              bool flash_attn   = false,
-              bool use_mask     = false)
+              const String2GGMLType& tensor_types = {},
+              SDVersion version                   = VERSION_FLUX,
+              bool flash_attn                     = false,
+              bool use_mask                       = false)
         : flux(backend, tensor_types, "model.diffusion_model", version, flash_attn, use_mask) {
     }
diff --git a/esrgan.hpp b/esrgan.hpp
index 5cbb4ad..4215db1 100644
--- a/esrgan.hpp
+++ b/esrgan.hpp
@@ -142,7 +142,7 @@ struct ESRGAN : public GGMLRunner {
     int scale     = 4;
     int tile_size = 128;  // avoid cuda OOM for 4gb VRAM
 
-    ESRGAN(ggml_backend_t backend, std::map<std::string, enum ggml_type>& tensor_types)
+    ESRGAN(ggml_backend_t backend, const String2GGMLType& tensor_types = {})
         : GGMLRunner(backend) {
         rrdb_net.init(params_ctx, tensor_types, "");
     }
diff --git a/flux.hpp b/flux.hpp
index 1104591..40838f2 100644
--- a/flux.hpp
+++ b/flux.hpp
@@ -35,8 +35,8 @@ namespace Flux {
         int64_t hidden_size;
         float eps;
 
-        void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
-            ggml_type wtype = GGML_TYPE_F32;  //(tensor_types.find(prefix + "scale") != tensor_types.end()) ? tensor_types[prefix + "scale"] : GGML_TYPE_F32;
+        void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
+            ggml_type wtype = GGML_TYPE_F32;
             params["scale"] = ggml_new_tensor_1d(ctx, wtype, hidden_size);
         }
 
@@ -1039,8 +1039,6 @@ namespace Flux {
     };
 
     struct FluxRunner : public GGMLRunner {
-        static std::map<std::string, enum ggml_type> empty_tensor_types;
-
     public:
         FluxParams flux_params;
         Flux flux;
@@ -1050,11 +1048,11 @@ namespace Flux {
         bool use_mask = false;
 
         FluxRunner(ggml_backend_t backend,
-                   std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types,
-                   const std::string prefix                            = "",
-                   SDVersion version                                   = VERSION_FLUX,
-                   bool flash_attn                                     = false,
-                   bool use_mask                                       = false)
+                   const String2GGMLType& tensor_types = {},
+                   const std::string prefix            = "",
+                   SDVersion version                   = VERSION_FLUX,
+                   bool flash_attn                     = false,
+                   bool use_mask                       = false)
             : GGMLRunner(backend), use_mask(use_mask) {
             flux_params.flash_attn     = flash_attn;
             flux_params.guidance_embed = false;
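The deleted `static ... empty_tensor_types;` members existed only because a non-const reference parameter cannot default to a temporary, while a const reference can. A minimal stand-alone illustration of that language rule (the names below are stand-ins, not repository code):

```cpp
#include <map>
#include <string>

typedef std::map<std::string, int> TypeMap;  // stand-in for String2GGMLType

// Old shape: the default argument had to be a named lvalue.
//     static TypeMap empty_tensor_types;
//     void init(TypeMap& tensor_types = empty_tensor_types);

// New shape: a const reference binds to the empty temporary directly,
// so the static helper object can be deleted.
void init(const TypeMap& tensor_types = {}) {
    (void)tensor_types.size();
}

int main() {
    init();                       // uses the defaulted empty map
    TypeMap quantized{{"w", 1}};
    init(quantized);              // or a caller-provided one
}
```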
diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index eb33f02..d4e4278 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -841,21 +841,19 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     float scale = (1.0f / sqrt((float)d_head));
 
     int kv_pad = 0;
-    //if (flash_attn) {
-    //    LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
-    //}
-    // is there anything oddly shaped?? ping Green-Sky if you can trip this assert
+    // if (flash_attn) {
+    //     LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
+    // }
+    // is there anything oddly shaped?? ping Green-Sky if you can trip this assert
     GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0));
 
     bool can_use_flash_attn = true;
-    can_use_flash_attn      = can_use_flash_attn && (
-        d_head == 64 ||
-        d_head == 80 ||
-        d_head == 96 ||
-        d_head == 112 ||
-        d_head == 128 ||
-        d_head == 256
-    );
+    can_use_flash_attn      = can_use_flash_attn && (d_head == 64 ||
+                                                     d_head == 80 ||
+                                                     d_head == 96 ||
+                                                     d_head == 112 ||
+                                                     d_head == 128 ||
+                                                     d_head == 256);
 #if 0
     can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0;
 #else
@@ -880,9 +878,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     ggml_tensor* kqv = nullptr;
     // GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn);
     if (can_use_flash_attn && flash_attn) {
-        //LOG_DEBUG(" uses flash attention");
+        // LOG_DEBUG(" uses flash attention");
         if (kv_pad != 0) {
-            //LOG_DEBUG(" padding k and v dim1 by %d", kv_pad);
+            // LOG_DEBUG(" padding k and v dim1 by %d", kv_pad);
             k = ggml_pad(ctx, k, 0, kv_pad, 0, 0);
         }
         k = ggml_cast(ctx, k, GGML_TYPE_F16);
@@ -1099,6 +1097,8 @@ __STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) {
 #define MAX_PARAMS_TENSOR_NUM 32768
 #define MAX_GRAPH_SIZE 32768
 
+typedef std::map<std::string, enum ggml_type> String2GGMLType;
+
 struct GGMLRunner {
 protected:
     typedef std::function<struct ggml_cgraph*()> get_graph_cb_t;
@@ -1310,17 +1310,25 @@ protected:
     GGMLBlockMap blocks;
     ParameterMap params;
 
-    void init_blocks(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
+    ggml_type get_type(const std::string& name, const String2GGMLType& tensor_types, ggml_type default_type) {
+        auto iter = tensor_types.find(name);
+        if (iter != tensor_types.end()) {
+            return iter->second;
+        }
+        return default_type;
+    }
+
+    void init_blocks(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
         for (auto& pair : blocks) {
             auto& block = pair.second;
             block->init(ctx, tensor_types, prefix + pair.first);
         }
     }
 
-    virtual void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {}
+    virtual void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {}
 
 public:
-    void init(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, std::string prefix = "") {
+    void init(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
         if (prefix.size() > 0) {
             prefix = prefix + ".";
         }
@@ -1381,8 +1389,8 @@ protected:
     bool bias;
     bool force_f32;
 
-    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
-        enum ggml_type wtype = (tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F32;
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
+        enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
         if (in_features % ggml_blck_size(wtype) != 0 || force_f32) {
             wtype = GGML_TYPE_F32;
         }
@@ -1417,8 +1425,8 @@ class Embedding : public UnaryBlock {
 protected:
     int64_t embedding_dim;
     int64_t num_embeddings;
-    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
-        enum ggml_type wtype = (tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F32;
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") {
+        enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
         params["weight"]     = ggml_new_tensor_2d(ctx, wtype, embedding_dim, num_embeddings);
     }
 
@@ -1457,11 +1465,11 @@ protected:
     std::pair<int, int> dilation;
     bool bias;
 
-    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
-        enum ggml_type wtype = GGML_TYPE_F16;  //(tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F16;
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") {
+        enum ggml_type wtype = GGML_TYPE_F16;
         params["weight"]     = ggml_new_tensor_4d(ctx, wtype, kernel_size.second, kernel_size.first, in_channels, out_channels);
         if (bias) {
-            enum ggml_type wtype = GGML_TYPE_F32;  // (tensor_types.find(prefix + "bias") != tensor_types.end()) ? tensor_types[prefix + "bias"] : GGML_TYPE_F32;
+            enum ggml_type wtype = GGML_TYPE_F32;
             params["bias"]       = ggml_new_tensor_1d(ctx, wtype, out_channels);
         }
     }
@@ -1502,11 +1510,11 @@ protected:
     int64_t dilation;
     bool bias;
 
-    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
-        enum ggml_type wtype = GGML_TYPE_F16;  //(tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F16;
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") {
+        enum ggml_type wtype = GGML_TYPE_F16;
         params["weight"]     = ggml_new_tensor_4d(ctx, wtype, 1, kernel_size, in_channels, out_channels);  // 5d => 4d
         if (bias) {
-            enum ggml_type wtype = GGML_TYPE_F32;  //(tensor_types.find(prefix + "bias") != tensor_types.end()) ? tensor_types[prefix + "bias"] : GGML_TYPE_F32;
+            enum ggml_type wtype = GGML_TYPE_F32;
             params["bias"]       = ggml_new_tensor_1d(ctx, wtype, out_channels);
         }
     }
@@ -1546,12 +1554,12 @@ protected:
     bool elementwise_affine;
     bool bias;
 
-    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
         if (elementwise_affine) {
-            enum ggml_type wtype = GGML_TYPE_F32;  //(tensor_types.ypes.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F32;
+            enum ggml_type wtype = GGML_TYPE_F32;
             params["weight"]     = ggml_new_tensor_1d(ctx, wtype, normalized_shape);
             if (bias) {
-                enum ggml_type wtype = GGML_TYPE_F32;  //(tensor_types.ypes.find(prefix + "bias") != tensor_types.end()) ? tensor_types[prefix + "bias"] : GGML_TYPE_F32;
+                enum ggml_type wtype = GGML_TYPE_F32;
                 params["bias"]       = ggml_new_tensor_1d(ctx, wtype, normalized_shape);
             }
         }
@@ -1588,10 +1596,10 @@ protected:
     float eps;
     bool affine;
 
-    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
         if (affine) {
-            enum ggml_type wtype      = GGML_TYPE_F32;  //(tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F32;
-            enum ggml_type bias_wtype = GGML_TYPE_F32;  //(tensor_types.find(prefix + "bias") != tensor_types.end()) ? tensor_types[prefix + "bias"] : GGML_TYPE_F32;
+            enum ggml_type wtype      = GGML_TYPE_F32;
+            enum ggml_type bias_wtype = GGML_TYPE_F32;
             params["weight"] = ggml_new_tensor_1d(ctx, wtype, num_channels);
             params["bias"]   = ggml_new_tensor_1d(ctx, bias_wtype, num_channels);
         }
tensor_types[prefix + "bias"] : GGML_TYPE_F32; + enum ggml_type wtype = GGML_TYPE_F32; + enum ggml_type bias_wtype = GGML_TYPE_F32; params["weight"] = ggml_new_tensor_1d(ctx, wtype, num_channels); params["bias"] = ggml_new_tensor_1d(ctx, bias_wtype, num_channels); } diff --git a/mmdit.hpp b/mmdit.hpp index dee7b1c..a93a35d 100644 --- a/mmdit.hpp +++ b/mmdit.hpp @@ -147,8 +147,8 @@ protected: int64_t hidden_size; float eps; - void init_params(struct ggml_context* ctx, std::map& tensor_types, std::string prefix = "") { - enum ggml_type wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F32; + void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") { + enum ggml_type wtype = GGML_TYPE_F32; params["weight"] = ggml_new_tensor_1d(ctx, wtype, hidden_size); } @@ -652,13 +652,13 @@ protected: int64_t hidden_size; std::string qk_norm; - void init_params(struct ggml_context* ctx, std::map& tensor_types, std::string prefix = "") { - enum ggml_type wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "pos_embed") != tensor_types.end()) ? tensor_types[prefix + "pos_embed"] : GGML_TYPE_F32; + void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") { + enum ggml_type wtype = GGML_TYPE_F32; params["pos_embed"] = ggml_new_tensor_3d(ctx, wtype, hidden_size, num_patchs, 1); } public: - MMDiT(std::map& tensor_types) { + MMDiT(const String2GGMLType& tensor_types = {}) { // input_size is always None // learn_sigma is always False // register_length is alwalys 0 @@ -869,11 +869,9 @@ public: struct MMDiTRunner : public GGMLRunner { MMDiT mmdit; - static std::map empty_tensor_types; - MMDiTRunner(ggml_backend_t backend, - std::map& tensor_types = empty_tensor_types, - const std::string prefix = "") + const String2GGMLType& tensor_types = {}, + const std::string prefix = "") : GGMLRunner(backend), mmdit(tensor_types) { mmdit.init(params_ctx, tensor_types, prefix); } diff --git a/model.cpp b/model.cpp index 9529cc5..df1c863 100644 --- a/model.cpp +++ b/model.cpp @@ -648,7 +648,7 @@ std::string convert_tensor_name(std::string name) { return new_name; } -void add_preprocess_tensor_storage_types(std::map& tensor_storages_types, std::string name, enum ggml_type type) { +void add_preprocess_tensor_storage_types(String2GGMLType& tensor_storages_types, std::string name, enum ggml_type type) { std::string new_name = convert_tensor_name(name); if (new_name.find("cond_stage_model") != std::string::npos && ends_with(new_name, "attn.in_proj_weight")) { diff --git a/model.h b/model.h index ea71610..869c24c 100644 --- a/model.h +++ b/model.h @@ -207,6 +207,8 @@ struct TensorStorage { typedef std::function on_new_tensor_cb_t; +typedef std::map String2GGMLType; + class ModelLoader { protected: std::vector file_paths_; @@ -225,7 +227,7 @@ protected: bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = ""); public: - std::map tensor_storages_types; + String2GGMLType tensor_storages_types; bool init_from_file(const std::string& file_path, const std::string& prefix = ""); bool model_is_unet(); diff --git a/pmid.hpp b/pmid.hpp index ea9f02e..e2a0f62 100644 --- a/pmid.hpp +++ b/pmid.hpp @@ -623,7 +623,12 @@ public: std::vector zeros_right; public: - PhotoMakerIDEncoder(ggml_backend_t backend, std::map& tensor_types, const std::string prefix, SDVersion version = VERSION_SDXL, PMVersion pm_v = PM_VERSION_1, float sty = 
20.f) + PhotoMakerIDEncoder(ggml_backend_t backend, + const String2GGMLType& tensor_types, + const std::string prefix, + SDVersion version = VERSION_SDXL, + PMVersion pm_v = PM_VERSION_1, + float sty = 20.f) : GGMLRunner(backend), version(version), pm_version(pm_v), diff --git a/t5.hpp b/t5.hpp index d511ef2..f00dc96 100644 --- a/t5.hpp +++ b/t5.hpp @@ -457,8 +457,8 @@ protected: int64_t hidden_size; float eps; - void init_params(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") { - enum ggml_type wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F32; + void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") { + enum ggml_type wtype = GGML_TYPE_F32; params["weight"] = ggml_new_tensor_1d(ctx, wtype, hidden_size); } @@ -735,7 +735,7 @@ struct T5Runner : public GGMLRunner { std::vector relative_position_bucket_vec; T5Runner(ggml_backend_t backend, - std::map& tensor_types, + const String2GGMLType& tensor_types, const std::string prefix, int64_t num_layers = 24, int64_t model_dim = 4096, @@ -876,16 +876,14 @@ struct T5Embedder { T5UniGramTokenizer tokenizer; T5Runner model; - static std::map empty_tensor_types; - T5Embedder(ggml_backend_t backend, - std::map& tensor_types = empty_tensor_types, - const std::string prefix = "", - int64_t num_layers = 24, - int64_t model_dim = 4096, - int64_t ff_dim = 10240, - int64_t num_heads = 64, - int64_t vocab_size = 32128) + const String2GGMLType& tensor_types = {}, + const std::string prefix = "", + int64_t num_layers = 24, + int64_t model_dim = 4096, + int64_t ff_dim = 10240, + int64_t num_heads = 64, + int64_t vocab_size = 32128) : model(backend, tensor_types, prefix, num_layers, model_dim, ff_dim, num_heads, vocab_size) { } diff --git a/tae.hpp b/tae.hpp index 678c44c..51fb94f 100644 --- a/tae.hpp +++ b/tae.hpp @@ -196,7 +196,7 @@ struct TinyAutoEncoder : public GGMLRunner { bool decode_only = false; TinyAutoEncoder(ggml_backend_t backend, - std::map& tensor_types, + const String2GGMLType& tensor_types, const std::string prefix, bool decoder_only = true, SDVersion version = VERSION_SD1) diff --git a/unet.hpp b/unet.hpp index 9193dcd..7ab4934 100644 --- a/unet.hpp +++ b/unet.hpp @@ -166,7 +166,6 @@ public: // ldm.modules.diffusionmodules.openaimodel.UNetModel class UnetModelBlock : public GGMLBlock { protected: - static std::map empty_tensor_types; SDVersion version = VERSION_SD1; // network hparams int in_channels = 4; @@ -184,7 +183,7 @@ public: int model_channels = 320; int adm_in_channels = 2816; // only for VERSION_SDXL/SVD - UnetModelBlock(SDVersion version = VERSION_SD1, std::map& tensor_types = empty_tensor_types, bool flash_attn = false) + UnetModelBlock(SDVersion version = VERSION_SD1, const String2GGMLType& tensor_types = {}, bool flash_attn = false) : version(version) { if (sd_version_is_sd2(version)) { context_dim = 1024; @@ -539,7 +538,7 @@ struct UNetModelRunner : public GGMLRunner { UnetModelBlock unet; UNetModelRunner(ggml_backend_t backend, - std::map& tensor_types, + const String2GGMLType& tensor_types, const std::string prefix, SDVersion version = VERSION_SD1, bool flash_attn = false) diff --git a/vae.hpp b/vae.hpp index 4add881..41f53ee 100644 --- a/vae.hpp +++ b/vae.hpp @@ -163,8 +163,8 @@ public: class VideoResnetBlock : public ResnetBlock { protected: - void init_params(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") { - enum 
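Since `ModelLoader::tensor_storages_types` is now the same `String2GGMLType` alias the runners accept, the per-tensor types read from a checkpoint can be forwarded without conversion. A sketch of the intended flow, assuming the repository headers and an initialized backend; the path and prefix are placeholders:

```cpp
#include "model.h"
#include "unet.hpp"

void load_unet(ggml_backend_t backend) {
    ModelLoader loader;
    if (loader.init_from_file("model.safetensors")) {
        // The loader's map binds directly to the const-reference parameter,
        // so each weight is allocated in the type recorded in the file.
        UNetModelRunner unet(backend, loader.tensor_storages_types,
                             "model.diffusion_model");
    }
}
```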
diff --git a/vae.hpp b/vae.hpp
index 4add881..41f53ee 100644
--- a/vae.hpp
+++ b/vae.hpp
@@ -163,8 +163,8 @@ public:
 
 class VideoResnetBlock : public ResnetBlock {
 protected:
-    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
-        enum ggml_type wtype = (tensor_types.find(prefix + "mix_factor") != tensor_types.end()) ? tensor_types[prefix + "mix_factor"] : GGML_TYPE_F32;
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
+        enum ggml_type wtype = get_type(prefix + "mix_factor", tensor_types, GGML_TYPE_F32);
         params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1);
     }
 
@@ -525,7 +525,7 @@ struct AutoEncoderKL : public GGMLRunner {
     AutoencodingEngine ae;
 
     AutoEncoderKL(ggml_backend_t backend,
-                  std::map<std::string, enum ggml_type>& tensor_types,
+                  const String2GGMLType& tensor_types,
                   const std::string prefix,
                   bool decode_only       = false,
                   bool use_video_decoder = false,