mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-13 05:48:56 +00:00
feat: override text encoders for unet models (#682)
This commit is contained in:
parent
76c72628b1
commit
19fbfd8639
@ -1539,6 +1539,15 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool ModelLoader::model_is_unet() {
|
||||||
|
for (auto& tensor_storage : tensor_storages) {
|
||||||
|
if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
SDVersion ModelLoader::get_sd_version() {
|
SDVersion ModelLoader::get_sd_version() {
|
||||||
TensorStorage token_embedding_weight, input_block_weight;
|
TensorStorage token_embedding_weight, input_block_weight;
|
||||||
bool input_block_checked = false;
|
bool input_block_checked = false;
|
||||||
|
|||||||
1
model.h
1
model.h
@ -210,6 +210,7 @@ public:
|
|||||||
std::map<std::string, enum ggml_type> tensor_storages_types;
|
std::map<std::string, enum ggml_type> tensor_storages_types;
|
||||||
|
|
||||||
bool init_from_file(const std::string& file_path, const std::string& prefix = "");
|
bool init_from_file(const std::string& file_path, const std::string& prefix = "");
|
||||||
|
bool model_is_unet();
|
||||||
SDVersion get_sd_version();
|
SDVersion get_sd_version();
|
||||||
ggml_type get_sd_wtype();
|
ggml_type get_sd_wtype();
|
||||||
ggml_type get_conditioner_wtype();
|
ggml_type get_conditioner_wtype();
|
||||||
|
|||||||
@ -213,16 +213,25 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (diffusion_model_path.size() > 0) {
|
||||||
|
LOG_INFO("loading diffusion model from '%s'", diffusion_model_path.c_str());
|
||||||
|
if (!model_loader.init_from_file(diffusion_model_path, "model.diffusion_model.")) {
|
||||||
|
LOG_WARN("loading diffusion model from '%s' failed", diffusion_model_path.c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_unet = model_loader.model_is_unet();
|
||||||
|
|
||||||
if (clip_l_path.size() > 0) {
|
if (clip_l_path.size() > 0) {
|
||||||
LOG_INFO("loading clip_l from '%s'", clip_l_path.c_str());
|
LOG_INFO("loading clip_l from '%s'", clip_l_path.c_str());
|
||||||
if (!model_loader.init_from_file(clip_l_path, "text_encoders.clip_l.transformer.")) {
|
if (!model_loader.init_from_file(clip_l_path, is_unet ? "cond_stage_model.transformer." : "text_encoders.clip_l.transformer.")) {
|
||||||
LOG_WARN("loading clip_l from '%s' failed", clip_l_path.c_str());
|
LOG_WARN("loading clip_l from '%s' failed", clip_l_path.c_str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (clip_g_path.size() > 0) {
|
if (clip_g_path.size() > 0) {
|
||||||
LOG_INFO("loading clip_g from '%s'", clip_g_path.c_str());
|
LOG_INFO("loading clip_g from '%s'", clip_g_path.c_str());
|
||||||
if (!model_loader.init_from_file(clip_g_path, "text_encoders.clip_g.transformer.")) {
|
if (!model_loader.init_from_file(clip_g_path, is_unet ? "cond_stage_model.1.transformer." : "text_encoders.clip_g.transformer.")) {
|
||||||
LOG_WARN("loading clip_g from '%s' failed", clip_g_path.c_str());
|
LOG_WARN("loading clip_g from '%s' failed", clip_g_path.c_str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -234,13 +243,6 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (diffusion_model_path.size() > 0) {
|
|
||||||
LOG_INFO("loading diffusion model from '%s'", diffusion_model_path.c_str());
|
|
||||||
if (!model_loader.init_from_file(diffusion_model_path, "model.diffusion_model.")) {
|
|
||||||
LOG_WARN("loading diffusion model from '%s' failed", diffusion_model_path.c_str());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (vae_path.size() > 0) {
|
if (vae_path.size() > 0) {
|
||||||
LOG_INFO("loading vae from '%s'", vae_path.c_str());
|
LOG_INFO("loading vae from '%s'", vae_path.c_str());
|
||||||
if (!model_loader.init_from_file(vae_path, "vae.")) {
|
if (!model_loader.init_from_file(vae_path, "vae.")) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user