mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-12 13:28:37 +00:00
feat: add support for f64/i64 and clip_g diffusers model (#681)
This commit is contained in:
parent
225162f270
commit
dafc32d0dd
81
model.cpp
81
model.cpp
@ -338,6 +338,10 @@ std::unordered_map<std::string, std::unordered_map<std::string, std::string>> su
|
||||
{"to_v", "v"},
|
||||
{"to_out_0", "proj_out"},
|
||||
{"group_norm", "norm"},
|
||||
{"key", "k"},
|
||||
{"query", "q"},
|
||||
{"value", "v"},
|
||||
{"proj_attn", "proj_out"},
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -362,6 +366,10 @@ std::unordered_map<std::string, std::unordered_map<std::string, std::string>> su
|
||||
{"to_v", "v"},
|
||||
{"to_out.0", "proj_out"},
|
||||
{"group_norm", "norm"},
|
||||
{"key", "k"},
|
||||
{"query", "q"},
|
||||
{"value", "v"},
|
||||
{"proj_attn", "proj_out"},
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -433,6 +441,10 @@ std::string convert_diffusers_name_to_compvis(std::string key, char seq) {
|
||||
return format("model%cdiffusion_model%ctime_embed%c", seq, seq, seq) + std::to_string(std::stoi(m[0]) * 2 - 2) + m[1];
|
||||
}
|
||||
|
||||
if (match(m, std::regex(format("unet%cadd_embedding%clinear_(\\d+)(.*)", seq, seq)), key)) {
|
||||
return format("model%cdiffusion_model%clabel_emb%c0%c", seq, seq, seq, seq) + std::to_string(std::stoi(m[0]) * 2 - 2) + m[1];
|
||||
}
|
||||
|
||||
if (match(m, std::regex(format("unet%cdown_blocks%c(\\d+)%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) {
|
||||
std::string suffix = get_converted_suffix(m[1], m[3]);
|
||||
// LOG_DEBUG("%s %s %s %s", m[0].c_str(), m[1].c_str(), m[2].c_str(), m[3].c_str());
|
||||
@ -470,6 +482,19 @@ std::string convert_diffusers_name_to_compvis(std::string key, char seq) {
|
||||
return format("cond_stage_model%ctransformer%ctext_model", seq, seq) + m[0];
|
||||
}
|
||||
|
||||
// clip-g
|
||||
if (match(m, std::regex(format("te%c1%ctext_model%cencoder%clayers%c(\\d+)%c(.+)", seq, seq, seq, seq, seq, seq)), key)) {
|
||||
return format("cond_stage_model%c1%ctransformer%ctext_model%cencoder%clayers%c", seq, seq, seq, seq, seq, seq) + m[0] + seq + m[1];
|
||||
}
|
||||
|
||||
if (match(m, std::regex(format("te%c1%ctext_model(.*)", seq, seq)), key)) {
|
||||
return format("cond_stage_model%c1%ctransformer%ctext_model", seq, seq, seq) + m[0];
|
||||
}
|
||||
|
||||
if (match(m, std::regex(format("te%c1%ctext_projection", seq, seq)), key)) {
|
||||
return format("cond_stage_model%c1%ctransformer%ctext_model%ctext_projection", seq, seq, seq, seq);
|
||||
}
|
||||
|
||||
// vae
|
||||
if (match(m, std::regex(format("vae%c(.*)%cconv_norm_out(.*)", seq, seq)), key)) {
|
||||
return format("first_stage_model%c%s%cnorm_out%s", seq, m[0].c_str(), seq, m[1].c_str());
|
||||
@ -606,6 +631,8 @@ std::string convert_tensor_name(std::string name) {
|
||||
std::string new_key = convert_diffusers_name_to_compvis(name_without_network_parts, '.');
|
||||
if (new_key.empty()) {
|
||||
new_name = name;
|
||||
} else if (new_key == "cond_stage_model.1.transformer.text_model.text_projection") {
|
||||
new_name = new_key;
|
||||
} else {
|
||||
new_name = new_key + "." + network_part;
|
||||
}
|
||||
@ -1029,10 +1056,14 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
|
||||
ttype = GGML_TYPE_F32;
|
||||
} else if (dtype == "F32") {
|
||||
ttype = GGML_TYPE_F32;
|
||||
} else if (dtype == "F64") {
|
||||
ttype = GGML_TYPE_F64;
|
||||
} else if (dtype == "F8_E4M3") {
|
||||
ttype = GGML_TYPE_F16;
|
||||
} else if (dtype == "F8_E5M2") {
|
||||
ttype = GGML_TYPE_F16;
|
||||
} else if (dtype == "I64") {
|
||||
ttype = GGML_TYPE_I64;
|
||||
}
|
||||
return ttype;
|
||||
}
|
||||
@ -1045,6 +1076,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
|
||||
std::ifstream file(file_path, std::ios::binary);
|
||||
if (!file.is_open()) {
|
||||
LOG_ERROR("failed to open '%s'", file_path.c_str());
|
||||
file_paths_.pop_back();
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -1056,6 +1088,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
|
||||
// read header size
|
||||
if (file_size_ <= ST_HEADER_SIZE_LEN) {
|
||||
LOG_ERROR("invalid safetensor file '%s'", file_path.c_str());
|
||||
file_paths_.pop_back();
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -1069,6 +1102,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
|
||||
size_t header_size_ = read_u64(header_size_buf);
|
||||
if (header_size_ >= file_size_) {
|
||||
LOG_ERROR("invalid safetensor file '%s'", file_path.c_str());
|
||||
file_paths_.pop_back();
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -1079,6 +1113,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
|
||||
file.read(header_buf.data(), header_size_);
|
||||
if (!file) {
|
||||
LOG_ERROR("read safetensors header failed: '%s'", file_path.c_str());
|
||||
file_paths_.pop_back();
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -1134,6 +1169,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
|
||||
n_dims = 1;
|
||||
}
|
||||
|
||||
|
||||
TensorStorage tensor_storage(prefix + name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
|
||||
tensor_storage.reverse_ne();
|
||||
|
||||
@ -1166,18 +1202,45 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
|
||||
/*================================================= DiffusersModelLoader ==================================================*/
|
||||
|
||||
bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const std::string& prefix) {
|
||||
std::string unet_path = path_join(file_path, "unet/diffusion_pytorch_model.safetensors");
|
||||
std::string vae_path = path_join(file_path, "vae/diffusion_pytorch_model.safetensors");
|
||||
std::string clip_path = path_join(file_path, "text_encoder/model.safetensors");
|
||||
std::string unet_path = path_join(file_path, "unet/diffusion_pytorch_model.safetensors");
|
||||
std::string vae_path = path_join(file_path, "vae/diffusion_pytorch_model.safetensors");
|
||||
std::string clip_path = path_join(file_path, "text_encoder/model.safetensors");
|
||||
std::string clip_g_path = path_join(file_path, "text_encoder_2/model.safetensors");
|
||||
|
||||
if (!init_from_safetensors_file(unet_path, "unet.")) {
|
||||
return false;
|
||||
}
|
||||
for (auto ts : tensor_storages) {
|
||||
if (ts.name.find("add_embedding") != std::string::npos || ts.name.find("label_emb") != std::string::npos) {
|
||||
// probably SDXL
|
||||
LOG_DEBUG("Fixing name for SDXL output blocks.2.2");
|
||||
for (auto& tensor_storage : tensor_storages) {
|
||||
int len = 34;
|
||||
auto pos = tensor_storage.name.find("unet.up_blocks.0.upsamplers.0.conv");
|
||||
if (pos == std::string::npos) {
|
||||
len = 44;
|
||||
pos = tensor_storage.name.find("model.diffusion_model.output_blocks.2.1.conv");
|
||||
}
|
||||
if (pos != std::string::npos) {
|
||||
tensor_storage.name = "model.diffusion_model.output_blocks.2.2.conv" + tensor_storage.name.substr(len);
|
||||
LOG_DEBUG("NEW NAME: %s", tensor_storage.name.c_str());
|
||||
add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!init_from_safetensors_file(vae_path, "vae.")) {
|
||||
return false;
|
||||
LOG_WARN("Couldn't find working VAE in %s", file_path.c_str());
|
||||
// return false;
|
||||
}
|
||||
if (!init_from_safetensors_file(clip_path, "te.")) {
|
||||
return false;
|
||||
LOG_WARN("Couldn't find working text encoder in %s", file_path.c_str());
|
||||
// return false;
|
||||
}
|
||||
if (!init_from_safetensors_file(clip_g_path, "te.1.")) {
|
||||
LOG_DEBUG("Couldn't find working second text encoder in %s", file_path.c_str());
|
||||
}
|
||||
return true;
|
||||
}
|
||||
@ -1571,7 +1634,7 @@ SDVersion ModelLoader::get_sd_version() {
|
||||
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) {
|
||||
return VERSION_SD3;
|
||||
}
|
||||
if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos) {
|
||||
if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos || tensor_storage.name.find("unet.down_blocks.") != std::string::npos) {
|
||||
is_unet = true;
|
||||
if (has_multiple_encoders) {
|
||||
is_xl = true;
|
||||
@ -1580,7 +1643,7 @@ SDVersion ModelLoader::get_sd_version() {
|
||||
}
|
||||
}
|
||||
}
|
||||
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos) {
|
||||
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos || tensor_storage.name.find("te.1") != std::string::npos) {
|
||||
has_multiple_encoders = true;
|
||||
if (is_unet) {
|
||||
is_xl = true;
|
||||
@ -1602,7 +1665,7 @@ SDVersion ModelLoader::get_sd_version() {
|
||||
token_embedding_weight = tensor_storage;
|
||||
// break;
|
||||
}
|
||||
if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight") {
|
||||
if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight" || tensor_storage.name == "unet.conv_in.weight") {
|
||||
input_block_weight = tensor_storage;
|
||||
input_block_checked = true;
|
||||
if (found_family) {
|
||||
@ -1687,7 +1750,7 @@ ggml_type ModelLoader::get_diffusion_model_wtype() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (tensor_storage.name.find("model.diffusion_model.") == std::string::npos) {
|
||||
if (tensor_storage.name.find("model.diffusion_model.") == std::string::npos && tensor_storage.name.find("unet.") == std::string::npos) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user