feat: add support for f64/i64 and clip_g diffusers model (#681)

This commit is contained in:
stduhpf 2025-07-06 17:24:55 +02:00 committed by GitHub
parent 225162f270
commit dafc32d0dd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -338,6 +338,10 @@ std::unordered_map<std::string, std::unordered_map<std::string, std::string>> su
{"to_v", "v"},
{"to_out_0", "proj_out"},
{"group_norm", "norm"},
{"key", "k"},
{"query", "q"},
{"value", "v"},
{"proj_attn", "proj_out"},
},
},
{
@ -362,6 +366,10 @@ std::unordered_map<std::string, std::unordered_map<std::string, std::string>> su
{"to_v", "v"},
{"to_out.0", "proj_out"},
{"group_norm", "norm"},
{"key", "k"},
{"query", "q"},
{"value", "v"},
{"proj_attn", "proj_out"},
},
},
{
@ -433,6 +441,10 @@ std::string convert_diffusers_name_to_compvis(std::string key, char seq) {
return format("model%cdiffusion_model%ctime_embed%c", seq, seq, seq) + std::to_string(std::stoi(m[0]) * 2 - 2) + m[1];
}
if (match(m, std::regex(format("unet%cadd_embedding%clinear_(\\d+)(.*)", seq, seq)), key)) {
return format("model%cdiffusion_model%clabel_emb%c0%c", seq, seq, seq, seq) + std::to_string(std::stoi(m[0]) * 2 - 2) + m[1];
}
if (match(m, std::regex(format("unet%cdown_blocks%c(\\d+)%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) {
std::string suffix = get_converted_suffix(m[1], m[3]);
// LOG_DEBUG("%s %s %s %s", m[0].c_str(), m[1].c_str(), m[2].c_str(), m[3].c_str());
@ -470,6 +482,19 @@ std::string convert_diffusers_name_to_compvis(std::string key, char seq) {
return format("cond_stage_model%ctransformer%ctext_model", seq, seq) + m[0];
}
// clip-g
if (match(m, std::regex(format("te%c1%ctext_model%cencoder%clayers%c(\\d+)%c(.+)", seq, seq, seq, seq, seq, seq)), key)) {
return format("cond_stage_model%c1%ctransformer%ctext_model%cencoder%clayers%c", seq, seq, seq, seq, seq, seq) + m[0] + seq + m[1];
}
if (match(m, std::regex(format("te%c1%ctext_model(.*)", seq, seq)), key)) {
return format("cond_stage_model%c1%ctransformer%ctext_model", seq, seq, seq) + m[0];
}
if (match(m, std::regex(format("te%c1%ctext_projection", seq, seq)), key)) {
return format("cond_stage_model%c1%ctransformer%ctext_model%ctext_projection", seq, seq, seq, seq);
}
// vae
if (match(m, std::regex(format("vae%c(.*)%cconv_norm_out(.*)", seq, seq)), key)) {
return format("first_stage_model%c%s%cnorm_out%s", seq, m[0].c_str(), seq, m[1].c_str());
@ -606,6 +631,8 @@ std::string convert_tensor_name(std::string name) {
std::string new_key = convert_diffusers_name_to_compvis(name_without_network_parts, '.');
if (new_key.empty()) {
new_name = name;
} else if (new_key == "cond_stage_model.1.transformer.text_model.text_projection") {
new_name = new_key;
} else {
new_name = new_key + "." + network_part;
}
@ -1029,10 +1056,14 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
ttype = GGML_TYPE_F32;
} else if (dtype == "F32") {
ttype = GGML_TYPE_F32;
} else if (dtype == "F64") {
ttype = GGML_TYPE_F64;
} else if (dtype == "F8_E4M3") {
ttype = GGML_TYPE_F16;
} else if (dtype == "F8_E5M2") {
ttype = GGML_TYPE_F16;
} else if (dtype == "I64") {
ttype = GGML_TYPE_I64;
}
return ttype;
}
@ -1045,6 +1076,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
std::ifstream file(file_path, std::ios::binary);
if (!file.is_open()) {
LOG_ERROR("failed to open '%s'", file_path.c_str());
file_paths_.pop_back();
return false;
}
@ -1056,6 +1088,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
// read header size
if (file_size_ <= ST_HEADER_SIZE_LEN) {
LOG_ERROR("invalid safetensor file '%s'", file_path.c_str());
file_paths_.pop_back();
return false;
}
@ -1069,6 +1102,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
size_t header_size_ = read_u64(header_size_buf);
if (header_size_ >= file_size_) {
LOG_ERROR("invalid safetensor file '%s'", file_path.c_str());
file_paths_.pop_back();
return false;
}
@ -1079,6 +1113,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
file.read(header_buf.data(), header_size_);
if (!file) {
LOG_ERROR("read safetensors header failed: '%s'", file_path.c_str());
file_paths_.pop_back();
return false;
}
@ -1134,6 +1169,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
n_dims = 1;
}
TensorStorage tensor_storage(prefix + name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
tensor_storage.reverse_ne();
@ -1166,18 +1202,45 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
/*================================================= DiffusersModelLoader ==================================================*/
bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const std::string& prefix) {
std::string unet_path = path_join(file_path, "unet/diffusion_pytorch_model.safetensors");
std::string vae_path = path_join(file_path, "vae/diffusion_pytorch_model.safetensors");
std::string clip_path = path_join(file_path, "text_encoder/model.safetensors");
std::string unet_path = path_join(file_path, "unet/diffusion_pytorch_model.safetensors");
std::string vae_path = path_join(file_path, "vae/diffusion_pytorch_model.safetensors");
std::string clip_path = path_join(file_path, "text_encoder/model.safetensors");
std::string clip_g_path = path_join(file_path, "text_encoder_2/model.safetensors");
if (!init_from_safetensors_file(unet_path, "unet.")) {
return false;
}
for (auto ts : tensor_storages) {
if (ts.name.find("add_embedding") != std::string::npos || ts.name.find("label_emb") != std::string::npos) {
// probably SDXL
LOG_DEBUG("Fixing name for SDXL output blocks.2.2");
for (auto& tensor_storage : tensor_storages) {
int len = 34;
auto pos = tensor_storage.name.find("unet.up_blocks.0.upsamplers.0.conv");
if (pos == std::string::npos) {
len = 44;
pos = tensor_storage.name.find("model.diffusion_model.output_blocks.2.1.conv");
}
if (pos != std::string::npos) {
tensor_storage.name = "model.diffusion_model.output_blocks.2.2.conv" + tensor_storage.name.substr(len);
LOG_DEBUG("NEW NAME: %s", tensor_storage.name.c_str());
add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
}
}
break;
}
}
if (!init_from_safetensors_file(vae_path, "vae.")) {
return false;
LOG_WARN("Couldn't find working VAE in %s", file_path.c_str());
// return false;
}
if (!init_from_safetensors_file(clip_path, "te.")) {
return false;
LOG_WARN("Couldn't find working text encoder in %s", file_path.c_str());
// return false;
}
if (!init_from_safetensors_file(clip_g_path, "te.1.")) {
LOG_DEBUG("Couldn't find working second text encoder in %s", file_path.c_str());
}
return true;
}
@ -1571,7 +1634,7 @@ SDVersion ModelLoader::get_sd_version() {
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) {
return VERSION_SD3;
}
if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos) {
if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos || tensor_storage.name.find("unet.down_blocks.") != std::string::npos) {
is_unet = true;
if (has_multiple_encoders) {
is_xl = true;
@ -1580,7 +1643,7 @@ SDVersion ModelLoader::get_sd_version() {
}
}
}
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos) {
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos || tensor_storage.name.find("te.1") != std::string::npos) {
has_multiple_encoders = true;
if (is_unet) {
is_xl = true;
@ -1602,7 +1665,7 @@ SDVersion ModelLoader::get_sd_version() {
token_embedding_weight = tensor_storage;
// break;
}
if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight") {
if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight" || tensor_storage.name == "unet.conv_in.weight") {
input_block_weight = tensor_storage;
input_block_checked = true;
if (found_family) {
@ -1687,7 +1750,7 @@ ggml_type ModelLoader::get_diffusion_model_wtype() {
continue;
}
if (tensor_storage.name.find("model.diffusion_model.") == std::string::npos) {
if (tensor_storage.name.find("model.diffusion_model.") == std::string::npos && tensor_storage.name.find("unet.") == std::string::npos) {
continue;
}