diff --git a/model.cpp b/model.cpp index 7a57253..eeee6d3 100644 --- a/model.cpp +++ b/model.cpp @@ -1539,6 +1539,15 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s return true; } +bool ModelLoader::model_is_unet() { + for (auto& tensor_storage : tensor_storages) { + if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos) { + return true; + } + } + return false; +} + SDVersion ModelLoader::get_sd_version() { TensorStorage token_embedding_weight, input_block_weight; bool input_block_checked = false; diff --git a/model.h b/model.h index 79c2533..82885dd 100644 --- a/model.h +++ b/model.h @@ -210,6 +210,7 @@ public: std::map tensor_storages_types; bool init_from_file(const std::string& file_path, const std::string& prefix = ""); + bool model_is_unet(); SDVersion get_sd_version(); ggml_type get_sd_wtype(); ggml_type get_conditioner_wtype(); diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index b5860cf..9c82657 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -213,16 +213,25 @@ public: } } + if (diffusion_model_path.size() > 0) { + LOG_INFO("loading diffusion model from '%s'", diffusion_model_path.c_str()); + if (!model_loader.init_from_file(diffusion_model_path, "model.diffusion_model.")) { + LOG_WARN("loading diffusion model from '%s' failed", diffusion_model_path.c_str()); + } + } + + bool is_unet = model_loader.model_is_unet(); + if (clip_l_path.size() > 0) { LOG_INFO("loading clip_l from '%s'", clip_l_path.c_str()); - if (!model_loader.init_from_file(clip_l_path, "text_encoders.clip_l.transformer.")) { + if (!model_loader.init_from_file(clip_l_path, is_unet ? "cond_stage_model.transformer." : "text_encoders.clip_l.transformer.")) { LOG_WARN("loading clip_l from '%s' failed", clip_l_path.c_str()); } } if (clip_g_path.size() > 0) { LOG_INFO("loading clip_g from '%s'", clip_g_path.c_str()); - if (!model_loader.init_from_file(clip_g_path, "text_encoders.clip_g.transformer.")) { + if (!model_loader.init_from_file(clip_g_path, is_unet ? "cond_stage_model.1.transformer." : "text_encoders.clip_g.transformer.")) { LOG_WARN("loading clip_g from '%s' failed", clip_g_path.c_str()); } } @@ -234,13 +243,6 @@ public: } } - if (diffusion_model_path.size() > 0) { - LOG_INFO("loading diffusion model from '%s'", diffusion_model_path.c_str()); - if (!model_loader.init_from_file(diffusion_model_path, "model.diffusion_model.")) { - LOG_WARN("loading diffusion model from '%s' failed", diffusion_model_path.c_str()); - } - } - if (vae_path.size() > 0) { LOG_INFO("loading vae from '%s'", vae_path.c_str()); if (!model_loader.init_from_file(vae_path, "vae.")) {