refactor: optimize the logic for name conversion and the processing of the LoRA model (#955)

This commit is contained in:
leejet 2025-11-10 00:12:20 +08:00 committed by GitHub
parent 8ecdf053ac
commit 694f0d9235
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 1670 additions and 1415 deletions

View File

@ -111,7 +111,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
bool load_embedding(std::string embd_name, std::string embd_path, std::vector<int32_t>& bpe_tokens) { bool load_embedding(std::string embd_name, std::string embd_path, std::vector<int32_t>& bpe_tokens) {
// the order matters // the order matters
ModelLoader model_loader; ModelLoader model_loader;
if (!model_loader.init_from_file(embd_path)) { if (!model_loader.init_from_file_and_convert_name(embd_path)) {
LOG_ERROR("embedding '%s' failed", embd_name.c_str()); LOG_ERROR("embedding '%s' failed", embd_name.c_str());
return false; return false;
} }

View File

@ -442,7 +442,7 @@ struct ControlNet : public GGMLRunner {
std::set<std::string> ignore_tensors; std::set<std::string> ignore_tensors;
ModelLoader model_loader; ModelLoader model_loader;
if (!model_loader.init_from_file(file_path)) { if (!model_loader.init_from_file_and_convert_name(file_path)) {
LOG_ERROR("init control net model loader from file failed: '%s'", file_path.c_str()); LOG_ERROR("init control net model loader from file failed: '%s'", file_path.c_str());
return false; return false;
} }

View File

@ -169,7 +169,7 @@ struct ESRGAN : public GGMLRunner {
LOG_INFO("loading esrgan from '%s'", file_path.c_str()); LOG_INFO("loading esrgan from '%s'", file_path.c_str());
ModelLoader model_loader; ModelLoader model_loader;
if (!model_loader.init_from_file(file_path)) { if (!model_loader.init_from_file_and_convert_name(file_path)) {
LOG_ERROR("init esrgan model loader from file failed: '%s'", file_path.c_str()); LOG_ERROR("init esrgan model loader from file failed: '%s'", file_path.c_str());
return false; return false;
} }

View File

@ -1398,7 +1398,7 @@ namespace Flux {
ggml_type model_data_type = GGML_TYPE_Q8_0; ggml_type model_data_type = GGML_TYPE_Q8_0;
ModelLoader model_loader; ModelLoader model_loader;
if (!model_loader.init_from_file(file_path, "model.diffusion_model.")) { if (!model_loader.init_from_file_and_convert_name(file_path, "model.diffusion_model.")) {
LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str()); LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str());
return; return;
} }

View File

@ -1568,8 +1568,10 @@ protected:
struct ggml_cgraph* get_compute_graph(get_graph_cb_t get_graph) { struct ggml_cgraph* get_compute_graph(get_graph_cb_t get_graph) {
prepare_build_in_tensor_before(); prepare_build_in_tensor_before();
struct ggml_cgraph* gf = get_graph(); struct ggml_cgraph* gf = get_graph();
if (ggml_graph_n_nodes(gf) > 0) {
auto result = ggml_graph_node(gf, -1); auto result = ggml_graph_node(gf, -1);
ggml_set_name(result, final_result_name.c_str()); ggml_set_name(result, final_result_name.c_str());
}
prepare_build_in_tensor_after(gf); prepare_build_in_tensor_after(gf);
return gf; return gf;
} }

987
lora.hpp

File diff suppressed because it is too large Load Diff

View File

@ -961,7 +961,7 @@ struct MMDiTRunner : public GGMLRunner {
mmdit->get_param_tensors(tensors, "model.diffusion_model"); mmdit->get_param_tensors(tensors, "model.diffusion_model");
ModelLoader model_loader; ModelLoader model_loader;
if (!model_loader.init_from_file(file_path)) { if (!model_loader.init_from_file_and_convert_name(file_path)) {
LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str()); LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str());
return; return;
} }

687
model.cpp
View File

@ -25,6 +25,7 @@
#include "ggml-cpu.h" #include "ggml-cpu.h"
#include "ggml.h" #include "ggml.h"
#include "name_conversion.h"
#include "stable-diffusion.h" #include "stable-diffusion.h"
#ifdef SD_USE_METAL #ifdef SD_USE_METAL
@ -75,15 +76,6 @@ uint16_t read_short(uint8_t* buffer) {
/*================================================= Preprocess ==================================================*/ /*================================================= Preprocess ==================================================*/
// Suffixes of the split (per-projection) CLIP self-attention tensors.
// NOTE(review): not referenced anywhere in the code visible here — presumably
// consumed by attention-tensor preprocessing elsewhere in this file; confirm
// before removing.
std::string self_attn_names[] = {
    "self_attn.q_proj.weight",
    "self_attn.k_proj.weight",
    "self_attn.v_proj.weight",
    "self_attn.q_proj.bias",
    "self_attn.k_proj.bias",
    "self_attn.v_proj.bias",
};
const char* unused_tensors[] = { const char* unused_tensors[] = {
"betas", "betas",
"alphas_cumprod_prev", "alphas_cumprod_prev",
@ -97,9 +89,9 @@ const char* unused_tensors[] = {
"posterior_mean_coef1", "posterior_mean_coef1",
"posterior_mean_coef2", "posterior_mean_coef2",
"cond_stage_model.transformer.text_model.embeddings.position_ids", "cond_stage_model.transformer.text_model.embeddings.position_ids",
"cond_stage_model.1.model.text_model.embeddings.position_ids",
"cond_stage_model.transformer.vision_model.embeddings.position_ids", "cond_stage_model.transformer.vision_model.embeddings.position_ids",
"cond_stage_model.model.logit_scale", "cond_stage_model.model.logit_scale",
"cond_stage_model.model.text_projection",
"conditioner.embedders.0.transformer.text_model.embeddings.position_ids", "conditioner.embedders.0.transformer.text_model.embeddings.position_ids",
"conditioner.embedders.0.model.logit_scale", "conditioner.embedders.0.model.logit_scale",
"conditioner.embedders.1.model.logit_scale", "conditioner.embedders.1.model.logit_scale",
@ -110,6 +102,7 @@ const char* unused_tensors[] = {
"model_ema.diffusion_model", "model_ema.diffusion_model",
"embedding_manager", "embedding_manager",
"denoiser.sigmas", "denoiser.sigmas",
"edm_vpred.sigma_max",
"text_encoders.t5xxl.transformer.encoder.embed_tokens.weight", // only used during training "text_encoders.t5xxl.transformer.encoder.embed_tokens.weight", // only used during training
"text_encoders.qwen2vl.output.weight", "text_encoders.qwen2vl.output.weight",
"text_encoders.qwen2vl.lm_head.", "text_encoders.qwen2vl.lm_head.",
@ -124,622 +117,6 @@ bool is_unused_tensor(std::string name) {
return false; return false;
} }
// Exact-name mapping from open_clip checkpoint tensor names ("model.*") to the
// HF-CLIP-style names used internally. Consulted by convert_cond_model_name
// after the model prefix has been stripped.
std::unordered_map<std::string, std::string> open_clip_to_hf_clip_model = {
    {"model.ln_final.bias", "transformer.text_model.final_layer_norm.bias"},
    {"model.ln_final.weight", "transformer.text_model.final_layer_norm.weight"},
    {"model.positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"},
    {"model.token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"},
    {"model.text_projection", "transformer.text_model.text_projection"},
    {"model.visual.class_embedding", "transformer.vision_model.embeddings.class_embedding"},
    {"model.visual.conv1.weight", "transformer.vision_model.embeddings.patch_embedding.weight"},
    {"model.visual.ln_post.bias", "transformer.vision_model.post_layernorm.bias"},
    {"model.visual.ln_post.weight", "transformer.vision_model.post_layernorm.weight"},
    {"model.visual.ln_pre.bias", "transformer.vision_model.pre_layernorm.bias"},
    {"model.visual.ln_pre.weight", "transformer.vision_model.pre_layernorm.weight"},
    {"model.visual.positional_embedding", "transformer.vision_model.embeddings.position_embedding.weight"},
    {"model.visual.proj", "transformer.visual_projection.weight"},
};
// Per-residual-block suffix mapping from open_clip "resblocks.<i>.<suffix>"
// tensors to HF CLIP "encoder.layers.<i>.<suffix>" tensors. Note the fused
// attn.in_proj_* tensors stay fused (mapped to self_attn.in_proj.*), they are
// not split into q/k/v here.
std::unordered_map<std::string, std::string> open_clip_to_hf_clip_resblock = {
    {"attn.in_proj_bias", "self_attn.in_proj.bias"},
    {"attn.in_proj_weight", "self_attn.in_proj.weight"},
    {"attn.out_proj.bias", "self_attn.out_proj.bias"},
    {"attn.out_proj.weight", "self_attn.out_proj.weight"},
    {"ln_1.bias", "layer_norm1.bias"},
    {"ln_1.weight", "layer_norm1.weight"},
    {"ln_2.bias", "layer_norm2.bias"},
    {"ln_2.weight", "layer_norm2.weight"},
    {"mlp.c_fc.bias", "mlp.fc1.bias"},
    {"mlp.c_fc.weight", "mlp.fc1.weight"},
    {"mlp.c_proj.bias", "mlp.fc2.bias"},
    {"mlp.c_proj.weight", "mlp.fc2.weight"},
};
// Normalizes the upstream "pre_layrnorm" spelling (a typo preserved in some
// released CLIP vision checkpoints) to "pre_layernorm".
std::unordered_map<std::string, std::string> cond_model_name_map = {
    {"transformer.vision_model.pre_layrnorm.weight", "transformer.vision_model.pre_layernorm.weight"},
    {"transformer.vision_model.pre_layrnorm.bias", "transformer.vision_model.pre_layernorm.bias"},
};
// Maps diffusers-style VAE decoder mid-block attention names (to_q/to_k/to_v/
// to_out.0) to the legacy q/k/v/proj_out names used internally.
std::unordered_map<std::string, std::string> vae_decoder_name_map = {
    {"first_stage_model.decoder.mid.attn_1.to_k.bias", "first_stage_model.decoder.mid.attn_1.k.bias"},
    {"first_stage_model.decoder.mid.attn_1.to_k.weight", "first_stage_model.decoder.mid.attn_1.k.weight"},
    {"first_stage_model.decoder.mid.attn_1.to_out.0.bias", "first_stage_model.decoder.mid.attn_1.proj_out.bias"},
    {"first_stage_model.decoder.mid.attn_1.to_out.0.weight", "first_stage_model.decoder.mid.attn_1.proj_out.weight"},
    {"first_stage_model.decoder.mid.attn_1.to_q.bias", "first_stage_model.decoder.mid.attn_1.q.bias"},
    {"first_stage_model.decoder.mid.attn_1.to_q.weight", "first_stage_model.decoder.mid.attn_1.q.weight"},
    {"first_stage_model.decoder.mid.attn_1.to_v.bias", "first_stage_model.decoder.mid.attn_1.v.bias"},
    {"first_stage_model.decoder.mid.attn_1.to_v.weight", "first_stage_model.decoder.mid.attn_1.v.weight"},
};
// Renames for pmid (presumably the PhotoMaker v2 ID encoder — confirm)
// qformer/perceiver tensors: bare nn.Sequential numeric indices (.0/.1/.2/.3)
// are mapped onto named fc1/fc2 linear layers.
std::unordered_map<std::string, std::string> pmid_v2_name_map = {
    {"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.weight",
     "pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.fc1.weight"},
    {"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.3.weight",
     "pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.fc2.weight"},
    {"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.weight",
     "pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.fc1.weight"},
    {"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.3.weight",
     "pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.fc2.weight"},
    {"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.weight",
     "pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.fc1.weight"},
    {"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.3.weight",
     "pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.fc2.weight"},
    {"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.weight",
     "pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.fc1.weight"},
    {"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.3.weight",
     "pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.fc2.weight"},
    {"pmid.qformer_perceiver.token_proj.0.bias",
     "pmid.qformer_perceiver.token_proj.fc1.bias"},
    {"pmid.qformer_perceiver.token_proj.2.bias",
     "pmid.qformer_perceiver.token_proj.fc2.bias"},
    {"pmid.qformer_perceiver.token_proj.0.weight",
     "pmid.qformer_perceiver.token_proj.fc1.weight"},
    {"pmid.qformer_perceiver.token_proj.2.weight",
     "pmid.qformer_perceiver.token_proj.fc2.weight"},
};
// Substring rewrites from llama.cpp/GGUF-style Qwen2-VL language-model tensor
// names to HF-style names. Each entry is applied at its first occurrence by
// convert_cond_model_name.
std::unordered_map<std::string, std::string> qwenvl_name_map{
    {"token_embd.", "model.embed_tokens."},
    {"blk.", "model.layers."},
    {"attn_q.", "self_attn.q_proj."},
    {"attn_k.", "self_attn.k_proj."},
    {"attn_v.", "self_attn.v_proj."},
    {"attn_output.", "self_attn.o_proj."},
    {"attn_norm.", "input_layernorm."},
    {"ffn_down.", "mlp.down_proj."},
    {"ffn_gate.", "mlp.gate_proj."},
    {"ffn_up.", "mlp.up_proj."},
    {"ffn_norm.", "post_attention_layernorm."},
    {"output_norm.", "model.norm."},
};
// Substring rewrites from llama.cpp/GGUF-style Qwen2-VL vision-tower tensor
// names to HF-style names.
// NOTE(review): this is iterated as an unordered_map in convert_cond_model_name,
// so application order is unspecified — and the entries chain:
// "v.patch_embd.weight" -> "patch_embed.proj.0.weight" while
// "patch_embed.proj.0.weight.1" -> "patch_embed.proj.1.weight". The result for
// "v.patch_embd.weight.1" may therefore depend on hash ordering (the direct
// "v.patch_embd.weight.1" entry covers the common case) — verify.
std::unordered_map<std::string, std::string> qwenvl_vision_name_map{
    {"mm.", "merger.mlp."},
    {"v.post_ln.", "merger.ln_q."},
    {"v.patch_embd.weight", "patch_embed.proj.0.weight"},
    {"patch_embed.proj.0.weight.1", "patch_embed.proj.1.weight"},
    {"v.patch_embd.weight.1", "patch_embed.proj.1.weight"},
    {"v.blk.", "blocks."},
    {"attn_q.", "attn.q_proj."},
    {"attn_k.", "attn.k_proj."},
    {"attn_v.", "attn.v_proj."},
    {"attn_out.", "attn.proj."},
    {"ffn_down.", "mlp.down_proj."},
    {"ffn_gate.", "mlp.gate_proj."},
    {"ffn_up.", "mlp.up_proj."},
    {"ln1.", "norm1."},
    {"ln2.", "norm2."},
};
// Normalize a conditioner / text-encoder tensor name to the internal
// (HF-CLIP-style) naming scheme. Handles:
//   - llama.cpp GGUF T5 names (the ".enc." convention),
//   - llama.cpp GGUF Qwen2-VL names (LLM and vision tower),
//   - open_clip checkpoint names ("model.*", "model.visual.*"),
//   - the "conditioner.embedders.N." / "cond_stage_model." prefixes.
// Names matching no rule are returned unchanged.
std::string convert_cond_model_name(const std::string& name) {
    std::string new_name = name;
    std::string prefix;

    // Replace the first occurrence of `from` in new_name with `to` (no-op if absent).
    auto replace_first = [&new_name](const char* from, const char* to) {
        size_t pos = new_name.find(from);
        if (pos != std::string::npos) {
            new_name.replace(pos, strlen(from), to);
        }
    };

    if (contains(new_name, ".enc.")) {
        // llama.cpp naming convention for T5; each rule is applied at most
        // once, in this order (table-driven replacement of the old repetitive
        // find/replace chain — same rules, same order).
        static const std::pair<const char*, const char*> t5_rules[] = {
            {".enc.", ".encoder."},
            {"blk.", "block."},
            {"output_norm.", "final_layer_norm."},
            {"attn_k.", "layer.0.SelfAttention.k."},
            {"attn_v.", "layer.0.SelfAttention.v."},
            {"attn_o.", "layer.0.SelfAttention.o."},
            {"attn_q.", "layer.0.SelfAttention.q."},
            {"attn_norm.", "layer.0.layer_norm."},
            {"ffn_norm.", "layer.1.layer_norm."},
            {"ffn_up.", "layer.1.DenseReluDense.wi_1."},
            {"ffn_down.", "layer.1.DenseReluDense.wo."},
            {"ffn_gate.", "layer.1.DenseReluDense.wi_0."},
            {"attn_rel_b.", "layer.0.SelfAttention.relative_attention_bias."},
        };
        for (const auto& [from, to] : t5_rules) {
            replace_first(from, to);
        }
    } else if (contains(name, "qwen2vl")) {
        // NOTE(review): the rule tables are unordered_maps, so application
        // order is unspecified; qwenvl_vision_name_map contains chained
        // entries, making some renames potentially order-dependent — confirm.
        const auto& rules = contains(name, "qwen2vl.visual") ? qwenvl_vision_name_map
                                                             : qwenvl_name_map;
        for (const auto& kv : rules) {  // const ref: avoid copying each pair
            size_t pos = new_name.find(kv.first);
            if (pos != std::string::npos) {
                new_name.replace(pos, kv.first.size(), kv.second);
            }
        }
    } else if (name == "text_encoders.t5xxl.transformer.token_embd.weight") {
        new_name = "text_encoders.t5xxl.transformer.shared.weight";
    }

    // Strip/normalize the model prefix; the early returns below bypass the
    // open_clip -> HF CLIP remapping entirely.
    if (starts_with(new_name, "conditioner.embedders.0.open_clip.")) {
        prefix   = "cond_stage_model.";
        new_name = new_name.substr(strlen("conditioner.embedders.0.open_clip."));
    } else if (starts_with(new_name, "conditioner.embedders.0.")) {
        prefix   = "cond_stage_model.";
        new_name = new_name.substr(strlen("conditioner.embedders.0."));
    } else if (starts_with(new_name, "conditioner.embedders.1.")) {
        prefix   = "cond_stage_model.1.";
        // bugfix: strip the prefix that actually matched ("embedders.1.");
        // the old code used strlen("conditioner.embedders.0.") here, which
        // only worked because both literals happen to have the same length.
        new_name = new_name.substr(strlen("conditioner.embedders.1."));
    } else if (starts_with(new_name, "cond_stage_model.")) {
        prefix   = "cond_stage_model.";
        new_name = new_name.substr(strlen("cond_stage_model."));
    } else if (ends_with(new_name, "vision_model.visual_projection.weight")) {
        prefix   = new_name.substr(0, new_name.size() - strlen("vision_model.visual_projection.weight"));
        new_name = prefix + "visual_projection.weight";
        return new_name;
    } else if (ends_with(new_name, "transformer.text_projection.weight")) {
        prefix   = new_name.substr(0, new_name.size() - strlen("transformer.text_projection.weight"));
        new_name = prefix + "transformer.text_model.text_projection";
        return new_name;
    } else {
        return new_name;
    }

    if (new_name == "model.text_projection.weight") {
        new_name = "transformer.text_model.text_projection";
    }
    // Single lookup instead of the old find + operator[] double lookup.
    if (auto it = open_clip_to_hf_clip_model.find(new_name); it != open_clip_to_hf_clip_model.end()) {
        new_name = it->second;
    }
    if (auto it = cond_model_name_map.find(new_name); it != cond_model_name_map.end()) {
        new_name = it->second;
    }

    // Rewrite "<oc_prefix><idx>.<suffix>" to "<hf_prefix><idx>.<new_suffix>"
    // for both the text and vision transformer residual blocks.
    auto convert_resblock = [&new_name](const std::string& open_clip_resblock_prefix,
                                        const std::string& hf_clip_resblock_prefix) {
        if (new_name.find(open_clip_resblock_prefix) == 0) {
            std::string remain = new_name.substr(open_clip_resblock_prefix.length());
            std::string idx    = remain.substr(0, remain.find("."));
            std::string suffix = remain.substr(idx.length() + 1);
            auto it            = open_clip_to_hf_clip_resblock.find(suffix);
            if (it != open_clip_to_hf_clip_resblock.end()) {
                new_name = hf_clip_resblock_prefix + idx + "." + it->second;
            }
        }
    };
    convert_resblock("model.transformer.resblocks.", "transformer.text_model.encoder.layers.");
    convert_resblock("model.visual.transformer.resblocks.", "transformer.vision_model.encoder.layers.");
    return prefix + new_name;
}
// Translate a VAE decoder tensor name via vae_decoder_name_map; names not in
// the map pass through unchanged.
std::string convert_vae_decoder_name(const std::string& name) {
    auto it = vae_decoder_name_map.find(name);
    return (it != vae_decoder_name_map.end()) ? it->second : name;
}
// Translate a pmid (PhotoMaker v2) tensor name via pmid_v2_name_map; names not
// in the map pass through unchanged.
std::string convert_pmid_v2_name(const std::string& name) {
    auto it = pmid_v2_name_map.find(name);
    return (it != pmid_v2_name_map.end()) ? it->second : name;
}
/* If this is not an SDXL LoRA, the "unet" prefix will have already been
 * replaced by this point, and "te2"/"te1" don't seem to appear in non-SDXL
 * models — only "te_" does. */
// Rewrite the leading SDXL LoRA component prefix (unet / te1 / te2 /
// text_encoder / text_encoder_2) to the internal model prefix. Names that do
// not start with a known component are returned unchanged.
std::string convert_sdxl_lora_name(std::string tensor_name) {
    static const std::pair<const char*, const char*> sdxl_lora_name_lookup[] = {
        {"unet", "model_diffusion_model"},
        {"te2", "cond_stage_model_1_transformer"},
        {"te1", "cond_stage_model_transformer"},
        {"text_encoder_2", "cond_stage_model_1_transformer"},
        {"text_encoder", "cond_stage_model_transformer"},
    };
    for (const auto& [from, to] : sdxl_lora_name_lookup) {
        const size_t from_len = strlen(from);
        if (tensor_name.compare(0, from_len, from) == 0) {
            // bugfix: replace only the matched prefix. The previous
            // implementation used std::regex_replace, which replaced *every*
            // occurrence of the key anywhere in the name (and interpreted it
            // as a regex), despite only the prefix having been checked.
            // Plain prefix replacement matches the documented intent and
            // avoids constructing a std::regex per call.
            tensor_name.replace(0, from_len, to);
            break;
        }
    }
    return tensor_name;
}
// Suffix rewrites for diffusers "attentions"/"resnets" sub-blocks, used by
// convert_diffusers_name_to_compvis when the name separator is '_'
// (underscore-flattened LoRA-style names).
std::unordered_map<std::string, std::unordered_map<std::string, std::string>> suffix_conversion_underline = {
    {
        "attentions",
        {
            {"to_k", "k"},
            {"to_q", "q"},
            {"to_v", "v"},
            {"to_out_0", "proj_out"},
            {"group_norm", "norm"},
            {"key", "k"},
            {"query", "q"},
            {"value", "v"},
            {"proj_attn", "proj_out"},
        },
    },
    {
        "resnets",
        {
            {"conv1", "in_layers_2"},
            {"conv2", "out_layers_3"},
            {"norm1", "in_layers_0"},
            {"norm2", "out_layers_0"},
            {"time_emb_proj", "emb_layers_1"},
            {"conv_shortcut", "skip_connection"},
        },
    },
};
// Same table as suffix_conversion_underline but for '.'-separated names; only
// the entries containing a separator differ (e.g. "to_out.0", "in_layers.2").
std::unordered_map<std::string, std::unordered_map<std::string, std::string>> suffix_conversion_dot = {
    {
        "attentions",
        {
            {"to_k", "k"},
            {"to_q", "q"},
            {"to_v", "v"},
            {"to_out.0", "proj_out"},
            {"group_norm", "norm"},
            {"key", "k"},
            {"query", "q"},
            {"value", "v"},
            {"proj_attn", "proj_out"},
        },
    },
    {
        "resnets",
        {
            {"conv1", "in_layers.2"},
            {"conv2", "out_layers.3"},
            {"norm1", "in_layers.0"},
            {"norm2", "out_layers.0"},
            {"time_emb_proj", "emb_layers.1"},
            {"conv_shortcut", "skip_connection"},
        },
    },
};
std::string convert_diffusers_name_to_compvis(std::string key, char seq) {
std::vector<std::string> m;
auto match = [](std::vector<std::string>& match_list, const std::regex& regex, const std::string& key) {
auto r = std::smatch{};
if (!std::regex_match(key, r, regex)) {
return false;
}
match_list.clear();
for (size_t i = 1; i < r.size(); ++i) {
match_list.push_back(r.str(i));
}
return true;
};
std::unordered_map<std::string, std::unordered_map<std::string, std::string>> suffix_conversion;
if (seq == '_') {
suffix_conversion = suffix_conversion_underline;
} else {
suffix_conversion = suffix_conversion_dot;
}
auto get_converted_suffix = [&suffix_conversion](const std::string& outer_key, const std::string& inner_key) {
auto outer_iter = suffix_conversion.find(outer_key);
if (outer_iter != suffix_conversion.end()) {
auto inner_iter = outer_iter->second.find(inner_key);
if (inner_iter != outer_iter->second.end()) {
return inner_iter->second;
}
}
return inner_key;
};
// convert attn to out
if (ends_with(key, "to_out")) {
key += format("%c0", seq);
}
// unet
if (match(m, std::regex(format("unet%cconv_in(.*)", seq)), key)) {
return format("model%cdiffusion_model%cinput_blocks%c0%c0", seq, seq, seq, seq) + m[0];
}
if (match(m, std::regex(format("unet%cconv%cout(.*)", seq, seq)), key)) {
return format("model%cdiffusion_model%cout%c2", seq, seq, seq) + m[0];
}
if (match(m, std::regex(format("unet%cconv_norm_out(.*)", seq)), key)) {
return format("model%cdiffusion_model%cout%c0", seq, seq, seq) + m[0];
}
if (match(m, std::regex(format("unet%ctime_embedding%clinear_(\\d+)(.*)", seq, seq)), key)) {
return format("model%cdiffusion_model%ctime_embed%c", seq, seq, seq) + std::to_string(std::stoi(m[0]) * 2 - 2) + m[1];
}
if (match(m, std::regex(format("unet%cadd_embedding%clinear_(\\d+)(.*)", seq, seq)), key)) {
return format("model%cdiffusion_model%clabel_emb%c0%c", seq, seq, seq, seq) + std::to_string(std::stoi(m[0]) * 2 - 2) + m[1];
}
if (match(m, std::regex(format("unet%cdown_blocks%c(\\d+)%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) {
std::string suffix = get_converted_suffix(m[1], m[3]);
// LOG_DEBUG("%s %s %s %s", m[0].c_str(), m[1].c_str(), m[2].c_str(), m[3].c_str());
return format("model%cdiffusion_model%cinput_blocks%c", seq, seq, seq) + std::to_string(1 + std::stoi(m[0]) * 3 + std::stoi(m[2])) + seq +
(m[1] == "attentions" ? "1" : "0") + seq + suffix;
}
if (match(m, std::regex(format("unet%cmid_block%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq)), key)) {
std::string suffix = get_converted_suffix(m[0], m[2]);
return format("model%cdiffusion_model%cmiddle_block%c", seq, seq, seq) + (m[0] == "attentions" ? "1" : std::to_string(std::stoi(m[1]) * 2)) +
seq + suffix;
}
if (match(m, std::regex(format("unet%cup_blocks%c(\\d+)%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) {
std::string suffix = get_converted_suffix(m[1], m[3]);
return format("model%cdiffusion_model%coutput_blocks%c", seq, seq, seq) + std::to_string(std::stoi(m[0]) * 3 + std::stoi(m[2])) + seq +
(m[1] == "attentions" ? "1" : "0") + seq + suffix;
}
if (match(m, std::regex(format("unet%cdown_blocks%c(\\d+)%cdownsamplers%c0%cconv", seq, seq, seq, seq, seq)), key)) {
return format("model%cdiffusion_model%cinput_blocks%c", seq, seq, seq) + std::to_string(3 + std::stoi(m[0]) * 3) + seq + "0" + seq + "op";
}
if (match(m, std::regex(format("unet%cup_blocks%c(\\d+)%cupsamplers%c0%cconv", seq, seq, seq, seq, seq)), key)) {
return format("model%cdiffusion_model%coutput_blocks%c", seq, seq, seq) + std::to_string(2 + std::stoi(m[0]) * 3) + seq +
(std::stoi(m[0]) > 0 ? "2" : "1") + seq + "conv";
}
// clip
if (match(m, std::regex(format("te%ctext_model%cencoder%clayers%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) {
return format("cond_stage_model%ctransformer%ctext_model%cencoder%clayers%c", seq, seq, seq, seq, seq) + m[0] + seq + m[1];
}
if (match(m, std::regex(format("te%ctext_model(.*)", seq)), key)) {
return format("cond_stage_model%ctransformer%ctext_model", seq, seq) + m[0];
}
// clip-g
if (match(m, std::regex(format("te%c1%ctext_model%cencoder%clayers%c(\\d+)%c(.+)", seq, seq, seq, seq, seq, seq)), key)) {
return format("cond_stage_model%c1%ctransformer%ctext_model%cencoder%clayers%c", seq, seq, seq, seq, seq, seq) + m[0] + seq + m[1];
}
if (match(m, std::regex(format("te%c1%ctext_model(.*)", seq, seq)), key)) {
return format("cond_stage_model%c1%ctransformer%ctext_model", seq, seq, seq) + m[0];
}
if (match(m, std::regex(format("te%c1%ctext_projection", seq, seq)), key)) {
return format("cond_stage_model%c1%ctransformer%ctext_model%ctext_projection", seq, seq, seq, seq);
}
// vae
if (match(m, std::regex(format("vae%c(.*)%cconv_norm_out(.*)", seq, seq)), key)) {
return format("first_stage_model%c%s%cnorm_out%s", seq, m[0].c_str(), seq, m[1].c_str());
}
if (match(m, std::regex(format("vae%c(.*)%cmid_block%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) {
std::string suffix;
std::string block_name;
if (m[1] == "attentions") {
block_name = "attn";
suffix = get_converted_suffix(m[1], m[3]);
} else {
block_name = "block";
suffix = m[3];
}
return format("first_stage_model%c%s%cmid%c%s_%d%c%s",
seq, m[0].c_str(), seq, seq, block_name.c_str(), std::stoi(m[2]) + 1, seq, suffix.c_str());
}
if (match(m, std::regex(format("vae%c(.*)%cup_blocks%c(\\d+)%cresnets%c(\\d+)%c(.+)", seq, seq, seq, seq, seq, seq)), key)) {
std::string suffix = m[3];
if (suffix == "conv_shortcut") {
suffix = "nin_shortcut";
}
return format("first_stage_model%c%s%cup%c%d%cblock%c%s%c%s",
seq, m[0].c_str(), seq, seq, 3 - std::stoi(m[1]), seq, seq, m[2].c_str(), seq, suffix.c_str());
}
if (match(m, std::regex(format("vae%c(.*)%cdown_blocks%c(\\d+)%cdownsamplers%c0%cconv", seq, seq, seq, seq, seq, seq)), key)) {
return format("first_stage_model%c%s%cdown%c%d%cdownsample%cconv",
seq, m[0].c_str(), seq, seq, std::stoi(m[1]), seq, seq);
}
if (match(m, std::regex(format("vae%c(.*)%cdown_blocks%c(\\d+)%cresnets%c(\\d+)%c(.+)", seq, seq, seq, seq, seq, seq)), key)) {
std::string suffix = m[3];
if (suffix == "conv_shortcut") {
suffix = "nin_shortcut";
}
return format("first_stage_model%c%s%cdown%c%d%cblock%c%s%c%s",
seq, m[0].c_str(), seq, seq, std::stoi(m[1]), seq, seq, m[2].c_str(), seq, suffix.c_str());
}
if (match(m, std::regex(format("vae%c(.*)%cup_blocks%c(\\d+)%cupsamplers%c0%cconv", seq, seq, seq, seq, seq, seq)), key)) {
return format("first_stage_model%c%s%cup%c%d%cupsample%cconv",
seq, m[0].c_str(), seq, seq, 3 - std::stoi(m[1]), seq, seq);
}
if (match(m, std::regex(format("vae%c(.*)", seq)), key)) {
return format("first_stage_model%c", seq) + m[0];
}
return key;
}
// Top-level tensor-name normalizer: dispatches a raw checkpoint tensor name to
// the appropriate per-family converter (text encoders, VAE decoder, pmid,
// controlnet, LoRA in its several conventions, diffusers) and returns the
// internal canonical name. Names matching no rule are returned unchanged.
// NOTE: branch order matters — the prefix checks below are not disjoint.
std::string convert_tensor_name(std::string name) {
    // Bare "diffusion_model.*" names get the "model." prefix first.
    if (starts_with(name, "diffusion_model")) {
        name = "model." + name;
    }
    // Special-case remap of the first two up-block attentions to their
    // output_blocks positions (done before any other conversion).
    if (starts_with(name, "model.diffusion_model.up_blocks.0.attentions.0.")) {
        name.replace(0, sizeof("model.diffusion_model.up_blocks.0.attentions.0.") - 1,
                     "model.diffusion_model.output_blocks.0.1.");
    }
    if (starts_with(name, "model.diffusion_model.up_blocks.0.attentions.1.")) {
        name.replace(0, sizeof("model.diffusion_model.up_blocks.0.attentions.1.") - 1,
                     "model.diffusion_model.output_blocks.1.1.");
    }
    // size_t pos = name.find("lora_A");
    // if (pos != std::string::npos) {
    //     name.replace(pos, strlen("lora_A"), "lora_up");
    // }
    // pos = name.find("lora_B");
    // if (pos != std::string::npos) {
    //     name.replace(pos, strlen("lora_B"), "lora_down");
    // }
    std::string new_name = name;
    if (starts_with(name, "cond_stage_model.") ||
        starts_with(name, "conditioner.embedders.") ||
        starts_with(name, "text_encoders.") ||
        ends_with(name, ".vision_model.visual_projection.weight") ||
        starts_with(name, "qwen2vl")) {
        // text-encoder / conditioner tensors
        new_name = convert_cond_model_name(name);
    } else if (starts_with(name, "first_stage_model.decoder")) {
        // VAE decoder tensors
        new_name = convert_vae_decoder_name(name);
    } else if (starts_with(name, "pmid.qformer_perceiver")) {
        // PhotoMaker v2 tensors
        new_name = convert_pmid_v2_name(name);
    } else if (starts_with(name, "control_model.")) {  // for controlnet pth models
        // strip the "control_model." prefix
        size_t pos = name.find('.');
        if (pos != std::string::npos) {
            new_name = name.substr(pos + 1);
        }
    } else if (starts_with(name, "lora_")) {  // for lora
        // "lora_<model-parts>.<network-part>": convert the underscore-flattened
        // model path, then re-prefix with "lora.".
        size_t pos = name.find('.');
        if (pos != std::string::npos) {
            std::string name_without_network_parts = name.substr(5, pos - 5);
            std::string network_part               = name.substr(pos + 1);
            // LOG_DEBUG("%s %s", name_without_network_parts.c_str(), network_part.c_str());
            std::string new_key = convert_diffusers_name_to_compvis(name_without_network_parts, '_');
            /* For dealing with the new SDXL LoRA tensor naming convention */
            new_key = convert_sdxl_lora_name(new_key);
            if (new_key.empty()) {
                new_name = name;
            } else {
                new_name = "lora." + new_key + "." + network_part;
            }
        } else {
            new_name = name;
        }
    } else if (ends_with(name, ".diff") || ends_with(name, ".diff_b")) {
        // weight-diff style LoRA tensors: just tag with the "lora." prefix
        new_name = "lora." + name;
    } else if (contains(name, "lora_up") || contains(name, "lora_down") ||
               contains(name, "lora.up") || contains(name, "lora.down") ||
               contains(name, "lora_linear") || ends_with(name, ".alpha")) {
        // dotted LoRA naming (diffusers/kohya): split at the last "lora"/"alpha"
        // marker, convert the model-path part, underscore-join it, and rebuild.
        size_t pos = new_name.find(".processor");
        if (pos != std::string::npos) {
            new_name.replace(pos, strlen(".processor"), "");
        }
        // if (starts_with(new_name, "transformer.transformer_blocks") || starts_with(new_name, "transformer.single_transformer_blocks")) {
        //     new_name = "model.diffusion_model." + new_name;
        // }
        if (ends_with(name, ".alpha")) {
            pos = new_name.rfind("alpha");
        } else {
            pos = new_name.rfind("lora");
        }
        if (pos != std::string::npos) {
            // pos - 1 drops the separator before the marker; assumes the
            // marker is never at position 0 here (prefix checks above)
            std::string name_without_network_parts = new_name.substr(0, pos - 1);
            std::string network_part               = new_name.substr(pos);
            // LOG_DEBUG("%s %s", name_without_network_parts.c_str(), network_part.c_str());
            std::string new_key = convert_diffusers_name_to_compvis(name_without_network_parts, '.');
            new_key             = convert_sdxl_lora_name(new_key);
            replace_all_chars(new_key, '.', '_');
            size_t npos = network_part.rfind("_linear_layer");
            if (npos != std::string::npos) {
                network_part.replace(npos, strlen("_linear_layer"), "");
            }
            if (starts_with(network_part, "lora.")) {
                network_part = "lora_" + network_part.substr(5);
            }
            if (new_key.size() > 0) {
                new_name = "lora." + new_key + "." + network_part;
            }
            // LOG_DEBUG("new name: %s", new_name.c_str());
        }
    } else if (starts_with(name, "unet") || starts_with(name, "vae") || starts_with(name, "te")) {  // for diffuser
        // plain diffusers checkpoint names: convert the path before the final
        // ".weight"/".bias" component
        size_t pos = name.find_last_of('.');
        if (pos != std::string::npos) {
            std::string name_without_network_parts = name.substr(0, pos);
            std::string network_part               = name.substr(pos + 1);
            // LOG_DEBUG("%s %s", name_without_network_parts.c_str(), network_part.c_str());
            std::string new_key = convert_diffusers_name_to_compvis(name_without_network_parts, '.');
            if (new_key.empty()) {
                new_name = name;
            } else if (new_key == "cond_stage_model.1.transformer.text_model.text_projection") {
                // text_projection has no trailing component in the target scheme
                new_name = new_key;
            } else {
                new_name = new_key + "." + network_part;
            }
        } else {
            new_name = name;
        }
    } else {
        new_name = name;
    }
    // if (new_name != name) {
    //     LOG_DEBUG("%s => %s", name.c_str(), new_name.c_str());
    // }
    return new_name;
}
float bf16_to_f32(uint16_t bfloat16) { float bf16_to_f32(uint16_t bfloat16) {
uint32_t val_bits = (static_cast<uint32_t>(bfloat16) << 16); uint32_t val_bits = (static_cast<uint32_t>(bfloat16) << 16);
return *reinterpret_cast<float*>(&val_bits); return *reinterpret_cast<float*>(&val_bits);
@ -916,9 +293,7 @@ void convert_tensor(void* src,
/*================================================= ModelLoader ==================================================*/ /*================================================= ModelLoader ==================================================*/
void ModelLoader::add_tensor_storage(const TensorStorage& tensor_storage) { void ModelLoader::add_tensor_storage(const TensorStorage& tensor_storage) {
TensorStorage copy = tensor_storage; tensor_storage_map[tensor_storage.name] = tensor_storage;
copy.name = convert_tensor_name(copy.name);
tensor_storage_map[copy.name] = std::move(copy);
} }
bool is_zip_file(const std::string& file_path) { bool is_zip_file(const std::string& file_path) {
@ -1012,6 +387,31 @@ bool ModelLoader::init_from_file(const std::string& file_path, const std::string
} }
} }
// Rewrite every tensor name in tensor_storage_map to the internal naming
// scheme and re-key the map by the converted names. Uses the loader's cached
// version_ if already set, otherwise asks get_sd_version() (presumably
// inferred from the currently loaded tensors — confirm).
// If two old names convert to the same new name, the later one in (unordered)
// map iteration order silently wins.
void ModelLoader::convert_tensors_name() {
    SDVersion version = (version_ == VERSION_COUNT) ? get_sd_version() : version_;
    String2TensorStorage new_map;
    for (auto& [_, tensor_storage] : tensor_storage_map) {
        auto new_name = convert_tensor_name(tensor_storage.name, version);
        // LOG_DEBUG("%s -> %s", tensor_storage.name.c_str(), new_name.c_str());
        tensor_storage.name = new_name;
        // moved-from entries remain in the old map only until the swap below
        new_map[new_name] = std::move(tensor_storage);
    }
    tensor_storage_map.swap(new_map);
}
// Load tensors from file_path (optionally prefixing their names) and then
// normalize all tensor names to the internal naming scheme. A caller-supplied
// version hint is adopted only when this loader has not already established
// one. Returns false if loading fails; name conversion itself does not fail.
bool ModelLoader::init_from_file_and_convert_name(const std::string& file_path, const std::string& prefix, SDVersion version) {
    const bool have_hint    = (version != VERSION_COUNT);
    const bool version_unset = (version_ == VERSION_COUNT);
    if (have_hint && version_unset) {
        version_ = version;
    }
    const bool loaded = init_from_file(file_path, prefix);
    if (loaded) {
        convert_tensors_name();
    }
    return loaded;
}
/*================================================= GGUFModelLoader ==================================================*/ /*================================================= GGUFModelLoader ==================================================*/
bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::string& prefix) { bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::string& prefix) {
@ -1259,32 +659,6 @@ bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const s
if (!init_from_safetensors_file(unet_path, "unet.")) { if (!init_from_safetensors_file(unet_path, "unet.")) {
return false; return false;
} }
for (auto& [name, tensor_storage] : tensor_storage_map) {
if (name.find("add_embedding") != std::string::npos || name.find("label_emb") != std::string::npos) {
// probably SDXL
LOG_DEBUG("Fixing name for SDXL output blocks.2.2");
String2TensorStorage new_tensor_storage_map;
for (auto& [name, tensor_storage] : tensor_storage_map) {
int len = 34;
auto pos = tensor_storage.name.find("unet.up_blocks.0.upsamplers.0.conv");
if (pos == std::string::npos) {
len = 44;
pos = tensor_storage.name.find("model.diffusion_model.output_blocks.2.1.conv");
}
if (pos != std::string::npos) {
std::string new_name = "model.diffusion_model.output_blocks.2.2.conv" + name.substr(len);
LOG_DEBUG("NEW NAME: %s", new_name.c_str());
tensor_storage.name = new_name;
new_tensor_storage_map[new_name] = tensor_storage;
} else {
new_tensor_storage_map[name] = tensor_storage;
}
}
tensor_storage_map = new_tensor_storage_map;
break;
}
}
if (!init_from_safetensors_file(vae_path, "vae.")) { if (!init_from_safetensors_file(vae_path, "vae.")) {
LOG_WARN("Couldn't find working VAE in %s", file_path.c_str()); LOG_WARN("Couldn't find working VAE in %s", file_path.c_str());
@ -1925,7 +1299,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
int64_t start_time = ggml_time_ms(); int64_t start_time = ggml_time_ms();
std::vector<TensorStorage> processed_tensor_storages; std::vector<TensorStorage> processed_tensor_storages;
for (auto& [name, tensor_storage] : tensor_storage_map) { for (const auto& [name, tensor_storage] : tensor_storage_map) {
if (is_unused_tensor(tensor_storage.name)) { if (is_unused_tensor(tensor_storage.name)) {
continue; continue;
} }
@ -2394,6 +1768,7 @@ bool convert(const char* input_path, const char* vae_path, const char* output_pa
return false; return false;
} }
} }
model_loader.convert_tensors_name();
bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type, tensor_type_rules); bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type, tensor_type_rules);
return success; return success;
} }

14
model.h
View File

@ -15,6 +15,7 @@
#include "ggml.h" #include "ggml.h"
#include "gguf.h" #include "gguf.h"
#include "json.hpp" #include "json.hpp"
#include "ordered_map.hpp"
#include "zip.h" #include "zip.h"
#define SD_MAX_DIMS 5 #define SD_MAX_DIMS 5
@ -108,7 +109,11 @@ static inline bool sd_version_is_qwen_image(SDVersion version) {
} }
static inline bool sd_version_is_inpaint(SDVersion version) { static inline bool sd_version_is_inpaint(SDVersion version) {
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2) { if (version == VERSION_SD1_INPAINT ||
version == VERSION_SD2_INPAINT ||
version == VERSION_SDXL_INPAINT ||
version == VERSION_FLUX_FILL ||
version == VERSION_FLEX_2) {
return true; return true;
} }
return false; return false;
@ -253,10 +258,11 @@ struct TensorStorage {
typedef std::function<bool(const TensorStorage&, ggml_tensor**)> on_new_tensor_cb_t; typedef std::function<bool(const TensorStorage&, ggml_tensor**)> on_new_tensor_cb_t;
typedef std::map<std::string, TensorStorage> String2TensorStorage; typedef OrderedMap<std::string, TensorStorage> String2TensorStorage;
class ModelLoader { class ModelLoader {
protected: protected:
SDVersion version_ = VERSION_COUNT;
std::vector<std::string> file_paths_; std::vector<std::string> file_paths_;
String2TensorStorage tensor_storage_map; String2TensorStorage tensor_storage_map;
@ -276,6 +282,10 @@ protected:
public: public:
bool init_from_file(const std::string& file_path, const std::string& prefix = ""); bool init_from_file(const std::string& file_path, const std::string& prefix = "");
void convert_tensors_name();
bool init_from_file_and_convert_name(const std::string& file_path,
const std::string& prefix = "",
SDVersion version = VERSION_COUNT);
SDVersion get_sd_version(); SDVersion get_sd_version();
std::map<ggml_type, uint32_t> get_wtype_stat(); std::map<ggml_type, uint32_t> get_wtype_stat();
std::map<ggml_type, uint32_t> get_conditioner_wtype_stat(); std::map<ggml_type, uint32_t> get_conditioner_wtype_stat();

1028
name_conversion.cpp Normal file

File diff suppressed because it is too large Load Diff

10
name_conversion.h Normal file
View File

@ -0,0 +1,10 @@
#ifndef __NAME_CONVERSION_H__  // fixed typo: was __NAME_CONVERSTION_H__
#define __NAME_CONVERSION_H__

#include <string>

#include "model.h"

// Normalize a raw checkpoint tensor name to the canonical internal name.
// `name` is taken by value because the conversion mutates it in place.
std::string convert_tensor_name(std::string name, SDVersion version);

#endif  // __NAME_CONVERSION_H__

177
ordered_map.hpp Normal file
View File

@ -0,0 +1,177 @@
#ifndef __ORDERED_MAP_HPP__
#define __ORDERED_MAP_HPP__

#include <initializer_list>
#include <iterator>
#include <list>
#include <stdexcept>
#include <unordered_map>
#include <utility>

// A map that preserves insertion order for iteration while keeping O(1)
// average lookup. Entries live in `data_` (a std::list, so iterators stay
// valid across unrelated insert/erase); `index_` maps each key to its list
// iterator.
//
// NOTE: copy construction/assignment must rebuild `index_` — a memberwise
// copy would leave the copy's index holding iterators into the SOURCE
// object's list, so reads/writes through the copy would alias the original.
template <typename Key, typename T>
class OrderedMap {
public:
    using key_type        = Key;
    using mapped_type     = T;
    using value_type      = std::pair<const Key, T>;
    using list_type       = std::list<value_type>;
    using size_type       = typename list_type::size_type;
    using difference_type = typename list_type::difference_type;
    using iterator        = typename list_type::iterator;
    using const_iterator  = typename list_type::const_iterator;

private:
    list_type data_;
    std::unordered_map<Key, iterator> index_;

    // Re-point index_ at this object's own list nodes (used after a copy).
    void rebuild_index_() {
        index_.clear();
        for (auto it = data_.begin(); it != data_.end(); ++it) {
            index_[it->first] = it;
        }
    }

public:
    // --- constructors ---
    OrderedMap() = default;

    OrderedMap(std::initializer_list<value_type> init) {
        for (const auto& kv : init)
            insert(kv);
    }

    // Deep copy: duplicate the element list, then rebuild the index so it
    // points into our own list (a defaulted copy would alias `other`).
    OrderedMap(const OrderedMap& other) : data_(other.data_) {
        rebuild_index_();
    }

    // Copy-and-swap: also avoids assigning pair<const Key, T> elements,
    // which is ill-formed (const first member).
    OrderedMap& operator=(const OrderedMap& other) {
        if (this != &other) {
            OrderedMap tmp(other);
            swap(tmp);
        }
        return *this;
    }

    // Moves are safe as defaulted: std::list iterators remain valid and
    // refer into the destination list after a move, and index_ never
    // stores end().
    OrderedMap(OrderedMap&&) noexcept            = default;
    OrderedMap& operator=(OrderedMap&&) noexcept = default;

    // --- element access ---
    T& at(const Key& key) {
        auto it = index_.find(key);
        if (it == index_.end())
            throw std::out_of_range("OrderedMap::at: key not found");
        return it->second->second;
    }

    const T& at(const Key& key) const {
        auto it = index_.find(key);
        if (it == index_.end())
            throw std::out_of_range("OrderedMap::at: key not found");
        return it->second->second;
    }

    // Inserts a value-initialized T at the back if `key` is absent.
    T& operator[](const Key& key) {
        auto it = index_.find(key);
        if (it == index_.end()) {
            data_.emplace_back(key, T{});
            auto iter   = std::prev(data_.end());
            index_[key] = iter;
            return iter->second;
        }
        return it->second->second;
    }

    // --- iterators (insertion order) ---
    iterator begin() noexcept { return data_.begin(); }
    const_iterator begin() const noexcept { return data_.begin(); }
    const_iterator cbegin() const noexcept { return data_.cbegin(); }
    iterator end() noexcept { return data_.end(); }
    const_iterator end() const noexcept { return data_.end(); }
    const_iterator cend() const noexcept { return data_.cend(); }

    // --- capacity ---
    bool empty() const noexcept { return data_.empty(); }
    size_type size() const noexcept { return data_.size(); }

    // --- modifiers ---
    void clear() noexcept {
        data_.clear();
        index_.clear();
    }

    // Returns {iterator to element, inserted?}; no-op if the key exists.
    std::pair<iterator, bool> insert(const value_type& value) {
        auto it = index_.find(value.first);
        if (it != index_.end()) {
            return {it->second, false};
        }
        data_.push_back(value);
        auto iter           = std::prev(data_.end());
        index_[value.first] = iter;
        return {iter, true};
    }

    std::pair<iterator, bool> insert(value_type&& value) {
        auto it = index_.find(value.first);
        if (it != index_.end()) {
            return {it->second, false};
        }
        data_.push_back(std::move(value));
        auto iter          = std::prev(data_.end());
        index_[iter->first] = iter;
        return {iter, true};
    }

    void erase(const Key& key) {
        auto it = index_.find(key);
        if (it != index_.end()) {
            data_.erase(it->second);
            index_.erase(it);
        }
    }

    iterator erase(iterator pos) {
        index_.erase(pos->first);
        return data_.erase(pos);
    }

    // --- lookup ---
    size_type count(const Key& key) const {
        return index_.count(key);
    }

    iterator find(const Key& key) {
        auto it = index_.find(key);
        if (it == index_.end())
            return data_.end();
        return it->second;
    }

    const_iterator find(const Key& key) const {
        auto it = index_.find(key);
        if (it == index_.end())
            return data_.end();
        return it->second;
    }

    bool contains(const Key& key) const {
        return index_.find(key) != index_.end();
    }

    // --- comparison (order-sensitive: compares the element sequence) ---
    bool operator==(const OrderedMap& other) const {
        return data_ == other.data_;
    }

    bool operator!=(const OrderedMap& other) const {
        return !(*this == other);
    }

    // Constructs the pair up front, then behaves like insert(&&).
    template <typename... Args>
    std::pair<iterator, bool> emplace(Args&&... args) {
        value_type value(std::forward<Args>(args)...);
        auto it = index_.find(value.first);
        if (it != index_.end()) {
            return {it->second, false};
        }
        data_.push_back(std::move(value));
        auto iter           = std::prev(data_.end());
        index_[iter->first] = iter;
        return {iter, true};
    }

    void swap(OrderedMap& other) noexcept {
        data_.swap(other.data_);
        index_.swap(other.index_);
    }
};

#endif  // __ORDERED_MAP_HPP__

View File

@ -578,7 +578,7 @@ struct PhotoMakerIDEmbed : public GGMLRunner {
const std::string& file_path = "", const std::string& file_path = "",
const std::string& prefix = "") const std::string& prefix = "")
: file_path(file_path), GGMLRunner(backend, offload_params_to_cpu), model_loader(ml) { : file_path(file_path), GGMLRunner(backend, offload_params_to_cpu), model_loader(ml) {
if (!model_loader->init_from_file(file_path, prefix)) { if (!model_loader->init_from_file_and_convert_name(file_path, prefix)) {
load_failed = true; load_failed = true;
} }
} }

View File

@ -644,7 +644,7 @@ namespace Qwen {
ggml_type model_data_type = GGML_TYPE_Q8_0; ggml_type model_data_type = GGML_TYPE_Q8_0;
ModelLoader model_loader; ModelLoader model_loader;
if (!model_loader.init_from_file(file_path, "model.diffusion_model.")) { if (!model_loader.init_from_file_and_convert_name(file_path, "model.diffusion_model.")) {
LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str()); LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str());
return; return;
} }

View File

@ -1342,7 +1342,7 @@ namespace Qwen {
ggml_type model_data_type = GGML_TYPE_F16; ggml_type model_data_type = GGML_TYPE_F16;
ModelLoader model_loader; ModelLoader model_loader;
if (!model_loader.init_from_file(file_path, "qwen2vl.")) { if (!model_loader.init_from_file_and_convert_name(file_path, "qwen2vl.")) {
LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str()); LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str());
return; return;
} }

View File

@ -278,6 +278,8 @@ public:
} }
} }
model_loader.convert_tensors_name();
version = model_loader.get_sd_version(); version = model_loader.get_sd_version();
if (version == VERSION_COUNT) { if (version == VERSION_COUNT) {
LOG_ERROR("get sd version from file failed: '%s'", SAFE_STR(sd_ctx_params->model_path)); LOG_ERROR("get sd version from file failed: '%s'", SAFE_STR(sd_ctx_params->model_path));
@ -569,13 +571,13 @@ public:
version); version);
} }
if (strlen(SAFE_STR(sd_ctx_params->photo_maker_path)) > 0) { if (strlen(SAFE_STR(sd_ctx_params->photo_maker_path)) > 0) {
pmid_lora = std::make_shared<LoraModel>(backend, sd_ctx_params->photo_maker_path, ""); pmid_lora = std::make_shared<LoraModel>(backend, sd_ctx_params->photo_maker_path, "", version);
if (!pmid_lora->load_from_file(true, n_threads)) { if (!pmid_lora->load_from_file(true, n_threads)) {
LOG_WARN("load photomaker lora tensors from %s failed", sd_ctx_params->photo_maker_path); LOG_WARN("load photomaker lora tensors from %s failed", sd_ctx_params->photo_maker_path);
return false; return false;
} }
LOG_INFO("loading stacked ID embedding (PHOTOMAKER) model file from '%s'", sd_ctx_params->photo_maker_path); LOG_INFO("loading stacked ID embedding (PHOTOMAKER) model file from '%s'", sd_ctx_params->photo_maker_path);
if (!model_loader.init_from_file(sd_ctx_params->photo_maker_path, "pmid.")) { if (!model_loader.init_from_file_and_convert_name(sd_ctx_params->photo_maker_path, "pmid.")) {
LOG_WARN("loading stacked ID embedding from '%s' failed", sd_ctx_params->photo_maker_path); LOG_WARN("loading stacked ID embedding from '%s' failed", sd_ctx_params->photo_maker_path);
} else { } else {
stacked_id = true; stacked_id = true;
@ -609,7 +611,7 @@ public:
ignore_tensors.insert("first_stage_model."); ignore_tensors.insert("first_stage_model.");
} }
if (stacked_id) { if (stacked_id) {
ignore_tensors.insert("lora."); ignore_tensors.insert("pmid.unet.");
} }
if (vae_decode_only) { if (vae_decode_only) {
@ -925,7 +927,7 @@ public:
LOG_WARN("can not find %s or %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), lora_name.c_str()); LOG_WARN("can not find %s or %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), lora_name.c_str());
return; return;
} }
LoraModel lora(backend, file_path, is_high_noise ? "model.high_noise_" : ""); LoraModel lora(backend, file_path, is_high_noise ? "model.high_noise_" : "", version);
if (!lora.load_from_file(false, n_threads)) { if (!lora.load_from_file(false, n_threads)) {
LOG_WARN("load lora tensors from %s failed", file_path.c_str()); LOG_WARN("load lora tensors from %s failed", file_path.c_str());
return; return;

2
t5.hpp
View File

@ -1004,7 +1004,7 @@ struct T5Embedder {
ggml_type model_data_type = GGML_TYPE_F16; ggml_type model_data_type = GGML_TYPE_F16;
ModelLoader model_loader; ModelLoader model_loader;
if (!model_loader.init_from_file(file_path)) { if (!model_loader.init_from_file_and_convert_name(file_path)) {
LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str()); LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str());
return; return;
} }

View File

@ -222,7 +222,7 @@ struct TinyAutoEncoder : public GGMLRunner {
} }
ModelLoader model_loader; ModelLoader model_loader;
if (!model_loader.init_from_file(file_path)) { if (!model_loader.init_from_file_and_convert_name(file_path)) {
LOG_ERROR("init taesd model loader from file failed: '%s'", file_path.c_str()); LOG_ERROR("init taesd model loader from file failed: '%s'", file_path.c_str());
return false; return false;
} }

View File

@ -42,7 +42,7 @@ struct UpscalerGGML {
backend = ggml_backend_sycl_init(0); backend = ggml_backend_sycl_init(0);
#endif #endif
ModelLoader model_loader; ModelLoader model_loader;
if (!model_loader.init_from_file(esrgan_path)) { if (!model_loader.init_from_file_and_convert_name(esrgan_path)) {
LOG_ERROR("init model loader from file failed: '%s'", esrgan_path.c_str()); LOG_ERROR("init model loader from file failed: '%s'", esrgan_path.c_str());
} }
model_loader.set_wtype_override(model_data_type); model_loader.set_wtype_override(model_data_type);

View File

@ -1271,7 +1271,7 @@ namespace WAN {
vae->get_param_tensors(tensors, "first_stage_model"); vae->get_param_tensors(tensors, "first_stage_model");
ModelLoader model_loader; ModelLoader model_loader;
if (!model_loader.init_from_file(file_path, "vae.")) { if (!model_loader.init_from_file_and_convert_name(file_path, "vae.")) {
LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str()); LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str());
return; return;
} }
@ -2255,7 +2255,7 @@ namespace WAN {
LOG_INFO("loading from '%s'", file_path.c_str()); LOG_INFO("loading from '%s'", file_path.c_str());
ModelLoader model_loader; ModelLoader model_loader;
if (!model_loader.init_from_file(file_path, "model.diffusion_model.")) { if (!model_loader.init_from_file_and_convert_name(file_path, "model.diffusion_model.")) {
LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str()); LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str());
return; return;
} }