mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-13 05:48:56 +00:00
Compare commits
No commits in common. "dafc32d0dd0922272f970c93f961834af460f013" and "b9e4718facf1794f47e4259dd536b2f5f3c39cd2" have entirely different histories.
dafc32d0dd
...
b9e4718fac
82
model.cpp
82
model.cpp
@ -100,7 +100,6 @@ const char* unused_tensors[] = {
|
|||||||
"model_ema.diffusion_model",
|
"model_ema.diffusion_model",
|
||||||
"embedding_manager",
|
"embedding_manager",
|
||||||
"denoiser.sigmas",
|
"denoiser.sigmas",
|
||||||
"text_encoders.t5xxl.transformer.encoder.embed_tokens.weight", // only used during training
|
|
||||||
};
|
};
|
||||||
|
|
||||||
bool is_unused_tensor(std::string name) {
|
bool is_unused_tensor(std::string name) {
|
||||||
@ -338,10 +337,6 @@ std::unordered_map<std::string, std::unordered_map<std::string, std::string>> su
|
|||||||
{"to_v", "v"},
|
{"to_v", "v"},
|
||||||
{"to_out_0", "proj_out"},
|
{"to_out_0", "proj_out"},
|
||||||
{"group_norm", "norm"},
|
{"group_norm", "norm"},
|
||||||
{"key", "k"},
|
|
||||||
{"query", "q"},
|
|
||||||
{"value", "v"},
|
|
||||||
{"proj_attn", "proj_out"},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -366,10 +361,6 @@ std::unordered_map<std::string, std::unordered_map<std::string, std::string>> su
|
|||||||
{"to_v", "v"},
|
{"to_v", "v"},
|
||||||
{"to_out.0", "proj_out"},
|
{"to_out.0", "proj_out"},
|
||||||
{"group_norm", "norm"},
|
{"group_norm", "norm"},
|
||||||
{"key", "k"},
|
|
||||||
{"query", "q"},
|
|
||||||
{"value", "v"},
|
|
||||||
{"proj_attn", "proj_out"},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -441,10 +432,6 @@ std::string convert_diffusers_name_to_compvis(std::string key, char seq) {
|
|||||||
return format("model%cdiffusion_model%ctime_embed%c", seq, seq, seq) + std::to_string(std::stoi(m[0]) * 2 - 2) + m[1];
|
return format("model%cdiffusion_model%ctime_embed%c", seq, seq, seq) + std::to_string(std::stoi(m[0]) * 2 - 2) + m[1];
|
||||||
}
|
}
|
||||||
|
|
||||||
if (match(m, std::regex(format("unet%cadd_embedding%clinear_(\\d+)(.*)", seq, seq)), key)) {
|
|
||||||
return format("model%cdiffusion_model%clabel_emb%c0%c", seq, seq, seq, seq) + std::to_string(std::stoi(m[0]) * 2 - 2) + m[1];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (match(m, std::regex(format("unet%cdown_blocks%c(\\d+)%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) {
|
if (match(m, std::regex(format("unet%cdown_blocks%c(\\d+)%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) {
|
||||||
std::string suffix = get_converted_suffix(m[1], m[3]);
|
std::string suffix = get_converted_suffix(m[1], m[3]);
|
||||||
// LOG_DEBUG("%s %s %s %s", m[0].c_str(), m[1].c_str(), m[2].c_str(), m[3].c_str());
|
// LOG_DEBUG("%s %s %s %s", m[0].c_str(), m[1].c_str(), m[2].c_str(), m[3].c_str());
|
||||||
@ -482,19 +469,6 @@ std::string convert_diffusers_name_to_compvis(std::string key, char seq) {
|
|||||||
return format("cond_stage_model%ctransformer%ctext_model", seq, seq) + m[0];
|
return format("cond_stage_model%ctransformer%ctext_model", seq, seq) + m[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
// clip-g
|
|
||||||
if (match(m, std::regex(format("te%c1%ctext_model%cencoder%clayers%c(\\d+)%c(.+)", seq, seq, seq, seq, seq, seq)), key)) {
|
|
||||||
return format("cond_stage_model%c1%ctransformer%ctext_model%cencoder%clayers%c", seq, seq, seq, seq, seq, seq) + m[0] + seq + m[1];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (match(m, std::regex(format("te%c1%ctext_model(.*)", seq, seq)), key)) {
|
|
||||||
return format("cond_stage_model%c1%ctransformer%ctext_model", seq, seq, seq) + m[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (match(m, std::regex(format("te%c1%ctext_projection", seq, seq)), key)) {
|
|
||||||
return format("cond_stage_model%c1%ctransformer%ctext_model%ctext_projection", seq, seq, seq, seq);
|
|
||||||
}
|
|
||||||
|
|
||||||
// vae
|
// vae
|
||||||
if (match(m, std::regex(format("vae%c(.*)%cconv_norm_out(.*)", seq, seq)), key)) {
|
if (match(m, std::regex(format("vae%c(.*)%cconv_norm_out(.*)", seq, seq)), key)) {
|
||||||
return format("first_stage_model%c%s%cnorm_out%s", seq, m[0].c_str(), seq, m[1].c_str());
|
return format("first_stage_model%c%s%cnorm_out%s", seq, m[0].c_str(), seq, m[1].c_str());
|
||||||
@ -631,8 +605,6 @@ std::string convert_tensor_name(std::string name) {
|
|||||||
std::string new_key = convert_diffusers_name_to_compvis(name_without_network_parts, '.');
|
std::string new_key = convert_diffusers_name_to_compvis(name_without_network_parts, '.');
|
||||||
if (new_key.empty()) {
|
if (new_key.empty()) {
|
||||||
new_name = name;
|
new_name = name;
|
||||||
} else if (new_key == "cond_stage_model.1.transformer.text_model.text_projection") {
|
|
||||||
new_name = new_key;
|
|
||||||
} else {
|
} else {
|
||||||
new_name = new_key + "." + network_part;
|
new_name = new_key + "." + network_part;
|
||||||
}
|
}
|
||||||
@ -1056,14 +1028,10 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
|
|||||||
ttype = GGML_TYPE_F32;
|
ttype = GGML_TYPE_F32;
|
||||||
} else if (dtype == "F32") {
|
} else if (dtype == "F32") {
|
||||||
ttype = GGML_TYPE_F32;
|
ttype = GGML_TYPE_F32;
|
||||||
} else if (dtype == "F64") {
|
|
||||||
ttype = GGML_TYPE_F64;
|
|
||||||
} else if (dtype == "F8_E4M3") {
|
} else if (dtype == "F8_E4M3") {
|
||||||
ttype = GGML_TYPE_F16;
|
ttype = GGML_TYPE_F16;
|
||||||
} else if (dtype == "F8_E5M2") {
|
} else if (dtype == "F8_E5M2") {
|
||||||
ttype = GGML_TYPE_F16;
|
ttype = GGML_TYPE_F16;
|
||||||
} else if (dtype == "I64") {
|
|
||||||
ttype = GGML_TYPE_I64;
|
|
||||||
}
|
}
|
||||||
return ttype;
|
return ttype;
|
||||||
}
|
}
|
||||||
@ -1076,7 +1044,6 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
|
|||||||
std::ifstream file(file_path, std::ios::binary);
|
std::ifstream file(file_path, std::ios::binary);
|
||||||
if (!file.is_open()) {
|
if (!file.is_open()) {
|
||||||
LOG_ERROR("failed to open '%s'", file_path.c_str());
|
LOG_ERROR("failed to open '%s'", file_path.c_str());
|
||||||
file_paths_.pop_back();
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1088,7 +1055,6 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
|
|||||||
// read header size
|
// read header size
|
||||||
if (file_size_ <= ST_HEADER_SIZE_LEN) {
|
if (file_size_ <= ST_HEADER_SIZE_LEN) {
|
||||||
LOG_ERROR("invalid safetensor file '%s'", file_path.c_str());
|
LOG_ERROR("invalid safetensor file '%s'", file_path.c_str());
|
||||||
file_paths_.pop_back();
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1102,7 +1068,6 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
|
|||||||
size_t header_size_ = read_u64(header_size_buf);
|
size_t header_size_ = read_u64(header_size_buf);
|
||||||
if (header_size_ >= file_size_) {
|
if (header_size_ >= file_size_) {
|
||||||
LOG_ERROR("invalid safetensor file '%s'", file_path.c_str());
|
LOG_ERROR("invalid safetensor file '%s'", file_path.c_str());
|
||||||
file_paths_.pop_back();
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1113,7 +1078,6 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
|
|||||||
file.read(header_buf.data(), header_size_);
|
file.read(header_buf.data(), header_size_);
|
||||||
if (!file) {
|
if (!file) {
|
||||||
LOG_ERROR("read safetensors header failed: '%s'", file_path.c_str());
|
LOG_ERROR("read safetensors header failed: '%s'", file_path.c_str());
|
||||||
file_paths_.pop_back();
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1169,7 +1133,6 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
|
|||||||
n_dims = 1;
|
n_dims = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
TensorStorage tensor_storage(prefix + name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
|
TensorStorage tensor_storage(prefix + name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
|
||||||
tensor_storage.reverse_ne();
|
tensor_storage.reverse_ne();
|
||||||
|
|
||||||
@ -1202,45 +1165,18 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
|
|||||||
/*================================================= DiffusersModelLoader ==================================================*/
|
/*================================================= DiffusersModelLoader ==================================================*/
|
||||||
|
|
||||||
bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const std::string& prefix) {
|
bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const std::string& prefix) {
|
||||||
std::string unet_path = path_join(file_path, "unet/diffusion_pytorch_model.safetensors");
|
std::string unet_path = path_join(file_path, "unet/diffusion_pytorch_model.safetensors");
|
||||||
std::string vae_path = path_join(file_path, "vae/diffusion_pytorch_model.safetensors");
|
std::string vae_path = path_join(file_path, "vae/diffusion_pytorch_model.safetensors");
|
||||||
std::string clip_path = path_join(file_path, "text_encoder/model.safetensors");
|
std::string clip_path = path_join(file_path, "text_encoder/model.safetensors");
|
||||||
std::string clip_g_path = path_join(file_path, "text_encoder_2/model.safetensors");
|
|
||||||
|
|
||||||
if (!init_from_safetensors_file(unet_path, "unet.")) {
|
if (!init_from_safetensors_file(unet_path, "unet.")) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
for (auto ts : tensor_storages) {
|
|
||||||
if (ts.name.find("add_embedding") != std::string::npos || ts.name.find("label_emb") != std::string::npos) {
|
|
||||||
// probably SDXL
|
|
||||||
LOG_DEBUG("Fixing name for SDXL output blocks.2.2");
|
|
||||||
for (auto& tensor_storage : tensor_storages) {
|
|
||||||
int len = 34;
|
|
||||||
auto pos = tensor_storage.name.find("unet.up_blocks.0.upsamplers.0.conv");
|
|
||||||
if (pos == std::string::npos) {
|
|
||||||
len = 44;
|
|
||||||
pos = tensor_storage.name.find("model.diffusion_model.output_blocks.2.1.conv");
|
|
||||||
}
|
|
||||||
if (pos != std::string::npos) {
|
|
||||||
tensor_storage.name = "model.diffusion_model.output_blocks.2.2.conv" + tensor_storage.name.substr(len);
|
|
||||||
LOG_DEBUG("NEW NAME: %s", tensor_storage.name.c_str());
|
|
||||||
add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!init_from_safetensors_file(vae_path, "vae.")) {
|
if (!init_from_safetensors_file(vae_path, "vae.")) {
|
||||||
LOG_WARN("Couldn't find working VAE in %s", file_path.c_str());
|
return false;
|
||||||
// return false;
|
|
||||||
}
|
}
|
||||||
if (!init_from_safetensors_file(clip_path, "te.")) {
|
if (!init_from_safetensors_file(clip_path, "te.")) {
|
||||||
LOG_WARN("Couldn't find working text encoder in %s", file_path.c_str());
|
return false;
|
||||||
// return false;
|
|
||||||
}
|
|
||||||
if (!init_from_safetensors_file(clip_g_path, "te.1.")) {
|
|
||||||
LOG_DEBUG("Couldn't find working second text encoder in %s", file_path.c_str());
|
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -1634,7 +1570,7 @@ SDVersion ModelLoader::get_sd_version() {
|
|||||||
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) {
|
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) {
|
||||||
return VERSION_SD3;
|
return VERSION_SD3;
|
||||||
}
|
}
|
||||||
if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos || tensor_storage.name.find("unet.down_blocks.") != std::string::npos) {
|
if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos) {
|
||||||
is_unet = true;
|
is_unet = true;
|
||||||
if (has_multiple_encoders) {
|
if (has_multiple_encoders) {
|
||||||
is_xl = true;
|
is_xl = true;
|
||||||
@ -1643,7 +1579,7 @@ SDVersion ModelLoader::get_sd_version() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos || tensor_storage.name.find("te.1") != std::string::npos) {
|
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos) {
|
||||||
has_multiple_encoders = true;
|
has_multiple_encoders = true;
|
||||||
if (is_unet) {
|
if (is_unet) {
|
||||||
is_xl = true;
|
is_xl = true;
|
||||||
@ -1665,7 +1601,7 @@ SDVersion ModelLoader::get_sd_version() {
|
|||||||
token_embedding_weight = tensor_storage;
|
token_embedding_weight = tensor_storage;
|
||||||
// break;
|
// break;
|
||||||
}
|
}
|
||||||
if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight" || tensor_storage.name == "unet.conv_in.weight") {
|
if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight") {
|
||||||
input_block_weight = tensor_storage;
|
input_block_weight = tensor_storage;
|
||||||
input_block_checked = true;
|
input_block_checked = true;
|
||||||
if (found_family) {
|
if (found_family) {
|
||||||
@ -1750,7 +1686,7 @@ ggml_type ModelLoader::get_diffusion_model_wtype() {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tensor_storage.name.find("model.diffusion_model.") == std::string::npos && tensor_storage.name.find("unet.") == std::string::npos) {
|
if (tensor_storage.name.find("model.diffusion_model.") == std::string::npos) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user