diff --git a/CMakeLists.txt b/CMakeLists.txt index 7fc7c7c6..150ecde1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -216,6 +216,9 @@ file(GLOB SD_LIB_SOURCES CONFIGURE_DEPENDS "src/core/*.h" "src/core/*.cpp" "src/core/*.hpp" + "src/extensions/*.h" + "src/extensions/*.cpp" + "src/extensions/*.hpp" "src/model/*/*.h" "src/model/*/*.cpp" "src/model/*/*.hpp" diff --git a/format-code.sh b/format-code.sh index fb619371..1ed77a5b 100644 --- a/format-code.sh +++ b/format-code.sh @@ -1,6 +1,7 @@ for f in src/*.cpp src/*.h src/*.hpp \ src/conditioning/*.cpp src/conditioning/*.h src/conditioning/*.hpp \ src/core/*.cpp src/core/*.h src/core/*.hpp \ + src/extensions/*.cpp src/extensions/*.h src/extensions/*.hpp \ src/runtime/*.cpp src/runtime/*.h src/runtime/*.hpp \ src/model/*/*.cpp src/model/*/*.h src/model/*/*.hpp \ src/tokenizers/*.h src/tokenizers/*.cpp src/tokenizers/vocab/*.h src/tokenizers/vocab/*.cpp \ diff --git a/src/conditioning/conditioner.hpp b/src/conditioning/conditioner.hpp index 78c3e8fe..217658bf 100644 --- a/src/conditioning/conditioner.hpp +++ b/src/conditioning/conditioner.hpp @@ -103,7 +103,6 @@ struct ConditionerParams { int width = -1; int height = -1; bool zero_out_masked = false; - int num_input_imgs = 0; // for photomaker const std::vector>* ref_images = nullptr; // for qwen image edit }; @@ -121,25 +120,16 @@ public: virtual void set_stream_layers_enabled(bool enabled) {} virtual void set_flash_attention_enabled(bool enabled) = 0; virtual void set_weight_adapter(const std::shared_ptr& adapter) {} - virtual std::tuple> get_learned_condition_with_trigger(int n_threads, - const ConditionerParams& conditioner_params) { - GGML_ABORT("Not implemented yet!"); - } - virtual std::string remove_trigger_from_prompt(const std::string& prompt) { - GGML_ABORT("Not implemented yet!"); - } }; // ldm.modules.encoders.modules.FrozenCLIPEmbedder // Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283 struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { - SDVersion version = VERSION_SD1; - PMVersion pm_version = PM_VERSION_1; + SDVersion version = VERSION_SD1; CLIPTokenizer tokenizer; std::shared_ptr text_model; std::shared_ptr text_model2; - std::string trigger_word = "img"; // should be user settable std::map embedding_map; int32_t num_custom_embeddings = 0; int32_t num_custom_embeddings_2 = 0; @@ -150,9 +140,8 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { ggml_backend_t params_backend, const String2TensorStorage& tensor_storage_map, const std::map& orig_embedding_map, - SDVersion version = VERSION_SD1, - PMVersion pv = PM_VERSION_1) - : version(version), pm_version(pv), tokenizer(sd_version_is_sd2(version) ? 0 : 49407) { + SDVersion version = VERSION_SD1) + : version(version), tokenizer(sd_version_is_sd2(version) ? 0 : 49407) { for (const auto& kv : orig_embedding_map) { std::string name = kv.first; std::transform(name.begin(), name.end(), name.begin(), [](unsigned char c) { return std::tolower(c); }); @@ -329,121 +318,6 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { return tokenizer.decode(tokens); } - std::tuple, std::vector, std::vector> - tokenize_with_trigger_token(std::string text, - int num_input_imgs, - int32_t image_token) { - auto parsed_attention = parse_prompt_attention(text); - - { - std::stringstream ss; - ss << "["; - for (const auto& item : parsed_attention) { - ss << "['" << item.first << "', " << item.second << "], "; - } - ss << "]"; - LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str()); - } - - auto on_new_token_cb = [&](std::string& str, std::vector& bpe_tokens) -> bool { - auto iter = embedding_map.find(str); - if (iter == embedding_map.end()) { - return false; - } - std::string embedding_path = iter->second; - if (load_embedding(str, embedding_path, bpe_tokens)) { - return true; - } - return false; - }; - - std::vector tokens; - std::vector weights; - std::vector class_token_mask; - int32_t class_idx = -1, tokens_acc = 0; - for (const auto& item : parsed_attention) { - std::vector class_token_index; - std::vector clean_input_ids; - const std::string& curr_text = item.first; - float curr_weight = item.second; - // printf(" %s: %f \n", curr_text.c_str(), curr_weight); - int32_t clean_index = 0; - if (curr_text == "BREAK" && curr_weight == -1.0f) { - // Pad token array up to chunk size at this point. - // TODO: This is a hardcoded chunk_len, like in stable-diffusion.cpp, make it a parameter for the future? - // Also, this is 75 instead of 77 to leave room for BOS and EOS tokens. - int padding_size = 75 - (tokens_acc % 75); - for (int j = 0; j < padding_size; j++) { - clean_input_ids.push_back(tokenizer.EOS_TOKEN_ID); - clean_index++; - } - - // After padding, continue to the next iteration to process the following text as a new segment - tokens.insert(tokens.end(), clean_input_ids.begin(), clean_input_ids.end()); - weights.insert(weights.end(), padding_size, curr_weight); - continue; - } - - // Regular token, process normally - std::vector curr_tokens = tokenizer.encode(curr_text, on_new_token_cb); - for (uint32_t i = 0; i < curr_tokens.size(); i++) { - int token_id = curr_tokens[i]; - if (token_id == image_token) { - class_token_index.push_back(clean_index - 1); - } else { - clean_input_ids.push_back(token_id); - clean_index++; - } - } - // GGML_ASSERT(class_token_index.size() == 1); // PhotoMaker currently does not support multiple - // trigger words in a single prompt. - if (class_token_index.size() == 1) { - // Expand the class word token and corresponding mask - int class_token = clean_input_ids[class_token_index[0]]; - class_idx = tokens_acc + class_token_index[0]; - std::vector clean_input_ids_tmp; - for (int i = 0; i < class_token_index[0]; i++) - clean_input_ids_tmp.push_back(clean_input_ids[i]); - for (int i = 0; i < (pm_version == PM_VERSION_2 ? 2 * num_input_imgs : num_input_imgs); i++) - clean_input_ids_tmp.push_back(class_token); - for (int i = class_token_index[0] + 1; i < clean_input_ids.size(); i++) - clean_input_ids_tmp.push_back(clean_input_ids[i]); - clean_input_ids.clear(); - clean_input_ids = clean_input_ids_tmp; - } - tokens_acc += clean_index; - tokens.insert(tokens.end(), clean_input_ids.begin(), clean_input_ids.end()); - weights.insert(weights.end(), clean_input_ids.size(), curr_weight); - } - // BUG!! double couting, pad_tokens will add BOS at the beginning - // tokens.insert(tokens.begin(), tokenizer.BOS_TOKEN_ID); - // weights.insert(weights.begin(), 1.0); - - tokenizer.pad_tokens(tokens, &weights, nullptr, text_model->model.n_token, text_model->model.n_token, true); - int offset = pm_version == PM_VERSION_2 ? 2 * num_input_imgs : num_input_imgs; - for (int i = 0; i < tokens.size(); i++) { - // if (class_idx + 1 <= i && i < class_idx + 1 + 2*num_input_imgs) // photomaker V2 has num_tokens(=2)*num_input_imgs - if (class_idx + 1 <= i && i < class_idx + 1 + offset) // photomaker V2 has num_tokens(=2)*num_input_imgs - // hardcode for now - class_token_mask.push_back(true); - else - class_token_mask.push_back(false); - } - - // printf("["); - // for (int i = 0; i < tokens.size(); i++) { - // printf("%d, ", class_token_mask[i] ? 1 : 0); - // } - // printf("]\n"); - - // for (int i = 0; i < tokens.size(); i++) { - // std::cout << tokens[i] << ":" << weights[i] << ", "; - // } - // std::cout << std::endl; - - return std::make_tuple(tokens, weights, class_token_mask); - } - std::pair, std::vector> tokenize(std::string text, size_t min_length = 0, size_t max_length = 0, @@ -631,49 +505,6 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { return result; } - std::tuple> - get_learned_condition_with_trigger(int n_threads, - const ConditionerParams& conditioner_params) override { - auto image_tokens = convert_token_to_id(trigger_word); - // if(image_tokens.size() == 1){ - // printf(" image token id is: %d \n", image_tokens[0]); - // } - GGML_ASSERT(image_tokens.size() == 1); - auto tokens_and_weights = tokenize_with_trigger_token(conditioner_params.text, - conditioner_params.num_input_imgs, - image_tokens[0]); - std::vector& tokens = std::get<0>(tokens_and_weights); - std::vector& weights = std::get<1>(tokens_and_weights); - std::vector& clsm = std::get<2>(tokens_and_weights); - // printf("tokens: \n"); - // for(int i = 0; i < tokens.size(); ++i) - // printf("%d ", tokens[i]); - // printf("\n"); - // printf("clsm: \n"); - // for(int i = 0; i < clsm.size(); ++i) - // printf("%d ", clsm[i]?1:0); - // printf("\n"); - auto cond = get_learned_condition_common(n_threads, - tokens, - weights, - conditioner_params.clip_skip, - conditioner_params.width, - conditioner_params.height, - conditioner_params.zero_out_masked); - return std::make_tuple(cond, clsm); - } - - std::string remove_trigger_from_prompt(const std::string& prompt) override { - auto image_tokens = convert_token_to_id(trigger_word); - GGML_ASSERT(image_tokens.size() == 1); - auto tokens_and_weights = tokenize(prompt); - std::vector& tokens = tokens_and_weights.first; - auto it = std::find(tokens.begin(), tokens.end(), image_tokens[0]); - GGML_ASSERT(it != tokens.end()); // prompt must have trigger word - tokens.erase(it); - return decode(tokens); - } - SDCondition get_learned_condition(int n_threads, const ConditionerParams& conditioner_params) override { auto tokens_and_weights = tokenize(conditioner_params.text, text_model->model.n_token, text_model->model.n_token, true); diff --git a/src/extensions/generation_extension.h b/src/extensions/generation_extension.h new file mode 100644 index 00000000..0c895b87 --- /dev/null +++ b/src/extensions/generation_extension.h @@ -0,0 +1,73 @@ +#ifndef __SD_EXTENSIONS_GENERATION_EXTENSION_H__ +#define __SD_EXTENSIONS_GENERATION_EXTENSION_H__ + +#include +#include +#include +#include +#include + +#include "conditioning/conditioner.hpp" +#include "core/ggml_extend_backend.h" +#include "model.h" +#include "stable-diffusion.h" + +struct GenerationExtensionInitContext { + const sd_ctx_params_t* params; + SDVersion version; + const String2TensorStorage& tensor_storage_map; + ModelLoader& model_loader; + int n_threads; + std::function ensure_backend_pair; + std::function backend_for; + std::function params_backend_for; +}; + +struct GenerationExtensionTensorContext { + std::map& tensors; + std::map& mmap_able_tensors; + std::function module_can_mmap; +}; + +struct GenerationExtensionConditionContext { + Conditioner* conditioner; + ConditionerParams& condition_params; + const sd_pm_params_t& pm_params; + std::map& tensors; + SDVersion version; + int n_threads; + int total_steps; + bool free_params_immediately; +}; + +struct GenerationExtension { + virtual ~GenerationExtension() = default; + + virtual const char* name() const = 0; + virtual bool is_enabled() const { + return false; + } + virtual bool init(const GenerationExtensionInitContext&) { + return true; + } + virtual void collect_param_tensors(GenerationExtensionTensorContext&) {} + virtual void add_ignore_tensors(std::set&) const {} + virtual bool alloc_params_buffer() { + return true; + } + virtual size_t get_params_buffer_size() const { + return 0; + } + virtual void reset_runtime_condition() {} + virtual bool prepare_condition(GenerationExtensionConditionContext&) { + return false; + } + virtual const SDCondition& before_condition(int step, + const SDCondition& condition) const { + return condition; + } +}; + +std::shared_ptr create_photomaker_extension(); + +#endif diff --git a/src/extensions/photomaker_extension.cpp b/src/extensions/photomaker_extension.cpp new file mode 100644 index 00000000..ac3949a1 --- /dev/null +++ b/src/extensions/photomaker_extension.cpp @@ -0,0 +1,325 @@ +#include "extensions/generation_extension.h" + +#include +#include +#include +#include + +#include "core/tensor_ggml.hpp" +#include "core/util.h" +#include "model/adapter/lora.hpp" +#include "model/adapter/pmid.hpp" + +static std::tuple, std::vector, std::vector> +tokenize_photomaker_trigger(FrozenCLIPEmbedderWithCustomWords& clip_conditioner, + const std::string& text, + int trigger_token_count, + int32_t image_token) { + auto tokens_and_weights = clip_conditioner.tokenize(text); + std::vector source_tokens = std::move(tokens_and_weights.first); + std::vector source_weights = std::move(tokens_and_weights.second); + + if (!source_tokens.empty() && source_tokens.front() == clip_conditioner.tokenizer.BOS_TOKEN_ID) { + source_tokens.erase(source_tokens.begin()); + source_weights.erase(source_weights.begin()); + } + if (!source_tokens.empty() && source_tokens.back() == clip_conditioner.tokenizer.EOS_TOKEN_ID) { + source_tokens.pop_back(); + source_weights.pop_back(); + } + + std::vector tokens; + std::vector weights; + int32_t class_idx = -1; + for (size_t i = 0; i < source_tokens.size(); i++) { + int token = source_tokens[i]; + if (token == image_token) { + if (!tokens.empty()) { + class_idx = static_cast(tokens.size()) - 1; + int class_token = tokens.back(); + float class_weight = weights.back(); + for (int j = 1; j < trigger_token_count; j++) { + tokens.push_back(class_token); + weights.push_back(class_weight); + } + } + continue; + } + tokens.push_back(token); + weights.push_back(source_weights[i]); + } + + clip_conditioner.tokenizer.pad_tokens(tokens, + &weights, + nullptr, + clip_conditioner.text_model->model.n_token, + clip_conditioner.text_model->model.n_token, + true); + std::vector class_token_mask; + for (int i = 0; i < tokens.size(); i++) { + class_token_mask.push_back(class_idx + 1 <= i && i < class_idx + 1 + trigger_token_count); + } + + return std::make_tuple(tokens, weights, class_token_mask); +} + +static std::tuple> +get_photomaker_condition_with_trigger(FrozenCLIPEmbedderWithCustomWords& clip_conditioner, + int n_threads, + const ConditionerParams& conditioner_params, + const std::string& trigger_word, + int trigger_token_count) { + auto image_tokens = clip_conditioner.convert_token_to_id(trigger_word); + GGML_ASSERT(image_tokens.size() == 1); + auto tokens_and_weights = tokenize_photomaker_trigger(clip_conditioner, + conditioner_params.text, + trigger_token_count, + image_tokens[0]); + std::vector& tokens = std::get<0>(tokens_and_weights); + std::vector& weights = std::get<1>(tokens_and_weights); + std::vector& trigger_mask = std::get<2>(tokens_and_weights); + auto cond = clip_conditioner.get_learned_condition_common(n_threads, + tokens, + weights, + conditioner_params.clip_skip, + conditioner_params.width, + conditioner_params.height, + conditioner_params.zero_out_masked); + return std::make_tuple(std::move(cond), trigger_mask); +} + +static std::string remove_photomaker_trigger_from_prompt(FrozenCLIPEmbedderWithCustomWords& clip_conditioner, + const std::string& prompt, + const std::string& trigger_word) { + auto image_tokens = clip_conditioner.convert_token_to_id(trigger_word); + GGML_ASSERT(image_tokens.size() == 1); + auto tokens_and_weights = clip_conditioner.tokenize(prompt); + std::vector& tokens = tokens_and_weights.first; + auto it = std::find(tokens.begin(), tokens.end(), image_tokens[0]); + GGML_ASSERT(it != tokens.end()); + tokens.erase(it); + return clip_conditioner.decode(tokens); +} + +struct PhotoMakerExtension : public GenerationExtension { + std::shared_ptr pmid_model; + std::shared_ptr pmid_lora; + bool enabled = false; + std::string model_path; + std::string trigger_word = "img"; + SDCondition id_condition; + int start_merge_step = -1; + + const char* name() const override { + return "photomaker"; + } + + bool is_enabled() const override { + return enabled; + } + + bool init(const GenerationExtensionInitContext& ctx) override { + model_path = SAFE_STR(ctx.params->photo_maker_path); + if (model_path.empty()) { + return true; + } + + if (!ctx.ensure_backend_pair(SDBackendModule::PHOTOMAKER)) { + return false; + } + + PMVersion pm_version = std::strstr(model_path.c_str(), "v2") != nullptr ? PM_VERSION_2 : PM_VERSION_1; + pmid_model = std::make_shared(ctx.backend_for(SDBackendModule::PHOTOMAKER), + ctx.params_backend_for(SDBackendModule::PHOTOMAKER), + ctx.tensor_storage_map, + "pmid", + ctx.version, + pm_version); + if (pm_version == PM_VERSION_2) { + LOG_INFO("using PhotoMaker Version 2"); + } + + pmid_lora = std::make_shared("pmid", + ctx.backend_for(SDBackendModule::PHOTOMAKER), + ctx.params_backend_for(SDBackendModule::PHOTOMAKER), + model_path, + "", + ctx.version); + auto lora_tensor_filter = [&](const std::string& tensor_name) { + return starts_with(tensor_name, "lora.model"); + }; + if (!pmid_lora->load_from_file(ctx.n_threads, lora_tensor_filter)) { + LOG_WARN("load photomaker lora tensors from %s failed", model_path.c_str()); + return false; + } + + LOG_INFO("loading stacked ID embedding (PHOTOMAKER) model file from '%s'", model_path.c_str()); + if (!ctx.model_loader.init_from_file_and_convert_name(model_path, "pmid.")) { + LOG_WARN("loading stacked ID embedding from '%s' failed", model_path.c_str()); + return true; + } + + enabled = true; + return true; + } + + void collect_param_tensors(GenerationExtensionTensorContext& ctx) override { + if (!enabled || pmid_model == nullptr) { + return; + } + + std::map temp; + pmid_model->get_param_tensors(temp, "pmid"); + bool do_mmap = ctx.module_can_mmap(SDBackendModule::PHOTOMAKER); + for (const auto& [key, tensor] : temp) { + ctx.tensors[key] = tensor; + if (do_mmap) { + ctx.mmap_able_tensors[key] = tensor; + } + } + } + + void add_ignore_tensors(std::set& ignore_tensors) const override { + if (!enabled) { + return; + } + ignore_tensors.insert("pmid.unet."); + } + + bool alloc_params_buffer() override { + if (!enabled || pmid_model == nullptr) { + return true; + } + return pmid_model->alloc_params_buffer(); + } + + size_t get_params_buffer_size() const override { + if (!enabled || pmid_model == nullptr) { + return 0; + } + return pmid_model->get_params_buffer_size(); + } + + void reset_runtime_condition() override { + id_condition = {}; + start_merge_step = -1; + } + + bool prepare_condition(GenerationExtensionConditionContext& ctx) override { + reset_runtime_condition(); + if (!enabled || pmid_model == nullptr || pmid_lora == nullptr) { + return false; + } + + if (!pmid_lora->applied) { + int64_t t0 = ggml_time_ms(); + pmid_lora->apply(ctx.tensors, ctx.version, ctx.n_threads); + int64_t t1 = ggml_time_ms(); + pmid_lora->applied = true; + LOG_INFO("pmid_lora apply completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); + if (ctx.free_params_immediately) { + pmid_lora->free_params_buffer(); + } + } + + bool pmv2 = pmid_model->get_version() == PM_VERSION_2; + if (ctx.pm_params.id_images_count <= 0 || ctx.pm_params.id_images == nullptr) { + LOG_WARN("Provided PhotoMaker model file, but NO input ID images"); + LOG_WARN("Turn off PhotoMaker for this request"); + return false; + } + auto* clip_conditioner = dynamic_cast(ctx.conditioner); + if (clip_conditioner == nullptr) { + LOG_WARN("PhotoMaker requires FrozenCLIPEmbedderWithCustomWords conditioner"); + LOG_WARN("Turn off PhotoMaker for this request"); + return false; + } + + int clip_image_size = 224; + pmid_model->style_strength = ctx.pm_params.style_strength; + sd::Tensor id_image_tensor; + for (int i = 0; i < ctx.pm_params.id_images_count; i++) { + auto id_image = sd_image_to_tensor(ctx.pm_params.id_images[i]); + auto processed_id_image = clip_preprocess(id_image, clip_image_size, clip_image_size); + if (id_image_tensor.empty()) { + id_image_tensor = processed_id_image; + } else { + id_image_tensor = sd::ops::concat(id_image_tensor, processed_id_image, 3); + } + } + + int64_t t0 = ggml_time_ms(); + int trigger_token_count = pmv2 ? 2 * ctx.pm_params.id_images_count : ctx.pm_params.id_images_count; + auto cond_tup = get_photomaker_condition_with_trigger(*clip_conditioner, + ctx.n_threads, + ctx.condition_params, + trigger_word, + trigger_token_count); + SDCondition prepared_id_condition = std::get<0>(cond_tup); + auto class_tokens_mask = std::get<1>(cond_tup); + if (std::find(class_tokens_mask.begin(), class_tokens_mask.end(), true) == class_tokens_mask.end()) { + LOG_WARN("PhotoMaker trigger word '%s' was not found in prompt", trigger_word.c_str()); + LOG_WARN("Turn off PhotoMaker for this request"); + return false; + } + + sd::Tensor id_embeds; + if (pmv2 && ctx.pm_params.id_embed_path != nullptr) { + try { + id_embeds = sd::load_tensor_from_file_as_tensor(ctx.pm_params.id_embed_path); + } catch (const std::exception&) { + id_embeds = {}; + } + } + if (pmv2 && id_embeds.empty()) { + LOG_WARN("Provided PhotoMaker images, but NO valid ID embeds file for PM v2"); + LOG_WARN("Turn off PhotoMaker for this request"); + return false; + } + if (pmv2 && ctx.pm_params.id_images_count != id_embeds.shape()[1]) { + LOG_WARN("PhotoMaker image count (%d) does NOT match ID embeds (%d). You should run face_detect.py again.", + ctx.pm_params.id_images_count, + static_cast(id_embeds.shape()[1])); + LOG_WARN("Turn off PhotoMaker for this request"); + return false; + } + + auto res = pmid_model->compute(ctx.n_threads, + id_image_tensor, + prepared_id_condition.c_crossattn, + id_embeds, + class_tokens_mask); + if (res.empty()) { + LOG_ERROR("Photomaker ID Stacking failed"); + LOG_WARN("Turn off PhotoMaker for this request"); + return false; + } + + prepared_id_condition.c_crossattn = std::move(res); + int64_t t1 = ggml_time_ms(); + id_condition = std::move(prepared_id_condition); + start_merge_step = int(ctx.pm_params.style_strength / 100.f * ctx.total_steps); + ctx.condition_params.text = remove_photomaker_trigger_from_prompt(*clip_conditioner, + ctx.condition_params.text, + trigger_word); + LOG_INFO("Photomaker ID Stacking, taking %" PRId64 " ms", t1 - t0); + LOG_INFO("PHOTOMAKER: start_merge_step: %d", start_merge_step); + + if (ctx.free_params_immediately) { + pmid_model->free_params_buffer(); + } + return true; + } + + const SDCondition& before_condition(int step, + const SDCondition& condition) const override { + if (!id_condition.empty() && start_merge_step != -1 && step > start_merge_step) { + return id_condition; + } + return condition; + } +}; + +std::shared_ptr create_photomaker_extension() { + return std::make_shared(); +} diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index 07ba0d0f..266618c4 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include "core/ggml_extend.hpp" #include "core/ggml_graph_cut.h" @@ -13,8 +14,8 @@ #include "stable-diffusion.h" #include "conditioning/conditioner.hpp" +#include "extensions/generation_extension.h" #include "model/adapter/lora.hpp" -#include "model/adapter/pmid.hpp" #include "model/diffusion/anima.hpp" #include "model/diffusion/control.hpp" #include "model/diffusion/ernie_image.hpp" @@ -180,9 +181,7 @@ public: std::shared_ptr preview_vae; std::shared_ptr audio_vae_model; std::shared_ptr control_net; - std::shared_ptr pmid_model; - std::shared_ptr pmid_lora; - std::shared_ptr pmid_id_embeds; + std::vector> generation_extensions; std::vector> cond_stage_lora_models; std::vector> diffusion_lora_models; std::vector> first_stage_lora_models; @@ -193,7 +192,6 @@ public: bool offload_params_to_cpu = false; float max_vram = 0.f; bool stream_layers = false; - bool use_pmid = false; std::string backend_spec; std::string params_backend_spec; @@ -743,21 +741,12 @@ public: for (uint32_t i = 0; i < sd_ctx_params->embedding_count; i++) { embbeding_map.emplace(SAFE_STR(sd_ctx_params->embeddings[i].name), SAFE_STR(sd_ctx_params->embeddings[i].path)); } - if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) { - cond_stage_model = std::make_shared(backend_for(SDBackendModule::TE), - params_backend_for(SDBackendModule::TE), - tensor_storage_map, - embbeding_map, - version, - PM_VERSION_2); - } else { - cond_stage_model = std::make_shared(backend_for(SDBackendModule::TE), - params_backend_for(SDBackendModule::TE), - tensor_storage_map, - embbeding_map, - version); - } - diffusion_model = std::make_shared(backend_for(SDBackendModule::DIFFUSION), + cond_stage_model = std::make_shared(backend_for(SDBackendModule::TE), + params_backend_for(SDBackendModule::TE), + tensor_storage_map, + embbeding_map, + version); + diffusion_model = std::make_shared(backend_for(SDBackendModule::DIFFUSION), params_backend_for(SDBackendModule::DIFFUSION), tensor_storage_map, "model.diffusion_model", @@ -914,50 +903,35 @@ public: } } - if (strlen(SAFE_STR(sd_ctx_params->photo_maker_path)) > 0) { - if (!ensure_backend_pair(SDBackendModule::PHOTOMAKER)) { - return false; - } - if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) { - pmid_model = std::make_shared(backend_for(SDBackendModule::PHOTOMAKER), - params_backend_for(SDBackendModule::PHOTOMAKER), - tensor_storage_map, - "pmid", - version, - PM_VERSION_2); - LOG_INFO("using PhotoMaker Version 2"); - } else { - pmid_model = std::make_shared(backend_for(SDBackendModule::PHOTOMAKER), - params_backend_for(SDBackendModule::PHOTOMAKER), - tensor_storage_map, - "pmid", - version); - } - pmid_lora = std::make_shared("pmid", - backend_for(SDBackendModule::PHOTOMAKER), - params_backend_for(SDBackendModule::PHOTOMAKER), - sd_ctx_params->photo_maker_path, - "", - version); - auto lora_tensor_filter = [&](const std::string& tensor_name) { - if (starts_with(tensor_name, "lora.model")) { - return true; - } - return false; + { + generation_extensions.clear(); + auto photomaker_extension = create_photomaker_extension(); + GenerationExtensionInitContext extension_ctx{ + sd_ctx_params, + version, + tensor_storage_map, + model_loader, + n_threads, + [this](SDBackendModule module) { return ensure_backend_pair(module); }, + [this](SDBackendModule module) { return backend_for(module); }, + [this](SDBackendModule module) { return params_backend_for(module); }, }; - if (!pmid_lora->load_from_file(n_threads, lora_tensor_filter)) { - LOG_WARN("load photomaker lora tensors from %s failed", sd_ctx_params->photo_maker_path); + if (!photomaker_extension->init(extension_ctx)) { return false; } - LOG_INFO("loading stacked ID embedding (PHOTOMAKER) model file from '%s'", sd_ctx_params->photo_maker_path); - if (!model_loader.init_from_file_and_convert_name(sd_ctx_params->photo_maker_path, "pmid.")) { - LOG_WARN("loading stacked ID embedding from '%s' failed", sd_ctx_params->photo_maker_path); - } else { - use_pmid = true; + if (photomaker_extension->is_enabled()) { + generation_extensions.push_back(photomaker_extension); } } - if (use_pmid) { - get_param_tensors_p(pmid_model, module_can_mmap(SDBackendModule::PHOTOMAKER), "pmid"); + { + GenerationExtensionTensorContext extension_tensor_ctx{ + tensors, + mmap_able_tensors, + module_can_mmap, + }; + for (auto& extension : generation_extensions) { + extension->collect_param_tensors(extension_tensor_ctx); + } } if (sd_ctx_params->flash_attn) { @@ -1011,8 +985,8 @@ public: if (use_tae && !tae_preview_only) { ignore_tensors.insert("first_stage_model."); } - if (use_pmid) { - ignore_tensors.insert("pmid.unet."); + for (auto& extension : generation_extensions) { + extension->add_ignore_tensors(ignore_tensors); } ignore_tensors.insert("model.diffusion_model.__x0__"); ignore_tensors.insert("model.diffusion_model.__32x32__"); @@ -1099,10 +1073,12 @@ public: ggml_free(ctx); return false; } - if (use_pmid && pmid_model && !pmid_model->alloc_params_buffer()) { - LOG_ERROR("PhotoMaker params buffer allocation failed"); - ggml_free(ctx); - return false; + for (auto& extension : generation_extensions) { + if (!extension->alloc_params_buffer()) { + LOG_ERROR("%s params buffer allocation failed", extension->name()); + ggml_free(ctx); + return false; + } } bool success = model_loader.load_tensors(tensors, ignore_tensors, n_threads, sd_ctx_params->enable_mmap); @@ -1136,9 +1112,9 @@ public: } control_net_params_mem_size = control_net->get_params_buffer_size(); } - size_t pmid_params_mem_size = 0; - if (use_pmid) { - pmid_params_mem_size = pmid_model->get_params_buffer_size(); + size_t extension_params_mem_size = 0; + for (auto& extension : generation_extensions) { + extension_params_mem_size += extension->get_params_buffer_size(); } size_t total_params_ram_size = 0; @@ -1170,7 +1146,7 @@ public: }; if (!add_params_memory(clip_params_mem_size, SDBackendModule::TE) || - !add_params_memory(pmid_params_mem_size, SDBackendModule::PHOTOMAKER) || + !add_params_memory(extension_params_mem_size, SDBackendModule::PHOTOMAKER) || !add_params_memory(unet_params_mem_size, SDBackendModule::DIFFUSION) || !add_params_memory(vae_params_mem_size, SDBackendModule::VAE) || !add_params_memory(control_net_params_mem_size, SDBackendModule::CONTROL_NET)) { @@ -1181,7 +1157,7 @@ public: size_t total_params_size = total_params_ram_size + total_params_vram_size; LOG_INFO( "total params memory size = %.2fMB (VRAM %.2fMB, RAM %.2fMB): " - "text_encoders %.2fMB(%s), diffusion_model %.2fMB(%s), vae %.2fMB(%s), controlnet %.2fMB(%s), pmid %.2fMB(%s)", + "text_encoders %.2fMB(%s), diffusion_model %.2fMB(%s), vae %.2fMB(%s), controlnet %.2fMB(%s), extensions %.2fMB(%s)", total_params_size / 1024.0 / 1024.0, total_params_vram_size / 1024.0 / 1024.0, total_params_ram_size / 1024.0 / 1024.0, @@ -1193,8 +1169,8 @@ public: params_memory_location(vae_params_mem_size, SDBackendModule::VAE), control_net_params_mem_size / 1024.0 / 1024.0, params_memory_location(control_net_params_mem_size, SDBackendModule::CONTROL_NET), - pmid_params_mem_size / 1024.0 / 1024.0, - params_memory_location(pmid_params_mem_size, SDBackendModule::PHOTOMAKER)); + extension_params_mem_size / 1024.0 / 1024.0, + params_memory_location(extension_params_mem_size, SDBackendModule::PHOTOMAKER)); } // init denoiser @@ -1599,88 +1575,30 @@ public: } } - SDCondition get_pmid_conditon(sd_pm_params_t pm_params, - ConditionerParams& condition_params) { - SDCondition id_cond; - if (use_pmid) { - if (!pmid_lora->applied) { - int64_t t0 = ggml_time_ms(); - pmid_lora->apply(tensors, version, n_threads); - int64_t t1 = ggml_time_ms(); - pmid_lora->applied = true; - LOG_INFO("pmid_lora apply completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); - if (free_params_immediately) { - pmid_lora->free_params_buffer(); - } - } - // preprocess input id images - bool pmv2 = pmid_model->get_version() == PM_VERSION_2; - if (pm_params.id_images_count > 0) { - int clip_image_size = 224; - pmid_model->style_strength = pm_params.style_strength; - sd::Tensor id_image_tensor; - for (int i = 0; i < pm_params.id_images_count; i++) { - auto id_image = sd_image_to_tensor(pm_params.id_images[i]); - auto processed_id_image = clip_preprocess(id_image, clip_image_size, clip_image_size); - if (id_image_tensor.empty()) { - id_image_tensor = processed_id_image; - } else { - id_image_tensor = sd::ops::concat(id_image_tensor, processed_id_image, 3); - } - } - - int64_t t0 = ggml_time_ms(); - condition_params.num_input_imgs = pm_params.id_images_count; - auto cond_tup = cond_stage_model->get_learned_condition_with_trigger(n_threads, - condition_params); - id_cond = std::get<0>(cond_tup); - auto class_tokens_mask = std::get<1>(cond_tup); - sd::Tensor id_embeds; - if (pmv2 && pm_params.id_embed_path != nullptr) { - try { - id_embeds = sd::load_tensor_from_file_as_tensor(pm_params.id_embed_path); - } catch (const std::exception&) { - id_embeds = {}; - } - } - if (pmv2 && id_embeds.empty()) { - LOG_WARN("Provided PhotoMaker images, but NO valid ID embeds file for PM v2"); - LOG_WARN("Turn off PhotoMaker"); - use_pmid = false; - } else { - if (pmv2 && pm_params.id_images_count != id_embeds.shape()[1]) { - LOG_WARN("PhotoMaker image count (%d) does NOT match ID embeds (%d). You should run face_detect.py again.", pm_params.id_images_count, static_cast(id_embeds.shape()[1])); - LOG_WARN("Turn off PhotoMaker"); - use_pmid = false; - } else { - auto res = pmid_model->compute(n_threads, - id_image_tensor, - id_cond.c_crossattn, - id_embeds, - class_tokens_mask); - if (res.empty()) { - LOG_ERROR("Photomaker ID Stacking failed"); - LOG_WARN("Turn off PhotoMaker"); - use_pmid = false; - } else { - id_cond.c_crossattn = std::move(res); - int64_t t1 = ggml_time_ms(); - LOG_INFO("Photomaker ID Stacking, taking %" PRId64 " ms", t1 - t0); - // Encode input prompt without the trigger word for delayed conditioning - condition_params.text = cond_stage_model->remove_trigger_from_prompt(condition_params.text); - } - if (free_params_immediately) { - pmid_model->free_params_buffer(); - } - } - } - } else { - LOG_WARN("Provided PhotoMaker model file, but NO input ID images"); - LOG_WARN("Turn off PhotoMaker"); - use_pmid = false; - } + void reset_generation_extensions() { + for (auto& extension : generation_extensions) { + extension->reset_runtime_condition(); + } + } + + void prepare_generation_extensions(const sd_pm_params_t& pm_params, + ConditionerParams& condition_params, + int total_steps) { + reset_generation_extensions(); + GenerationExtensionConditionContext ctx{ + cond_stage_model.get(), + condition_params, + pm_params, + tensors, + version, + n_threads, + total_steps, + free_params_immediately, + }; + + for (auto& extension : generation_extensions) { + extension->prepare_condition(ctx); } - return id_cond; } sd::Tensor get_clip_vision_output(const sd::Tensor& image, @@ -1979,7 +1897,6 @@ public: const SDCondition& cond, const SDCondition& uncond, const SDCondition& img_uncond, - const SDCondition& id_cond, const sd::Tensor& control_image, float control_strength, const sd_guidance_params_t& guidance, @@ -1989,7 +1906,6 @@ public: bool is_flow_denoiser, const char* extra_sample_args, const std::vector& sigmas, - int start_merge_step, const std::vector>& ref_latents, bool increase_ref_index, const sd::Tensor& denoise_mask, @@ -2181,20 +2097,24 @@ public: return output_opt; }; - if (start_merge_step == -1 || step <= start_merge_step) { - cond_out = run_condition(cond); - if (cond_out.empty()) { - return {}; - } - } else { - GGML_ASSERT(!id_cond.empty()); - cond_out = run_condition(id_cond, - cond.c_concat.empty() ? nullptr : &cond.c_concat); - if (cond_out.empty()) { - return {}; + const SDCondition* positive_condition = &cond; + const sd::Tensor* c_concat_override = nullptr; + for (const auto& extension : generation_extensions) { + const SDCondition& next_condition = extension->before_condition(step, *positive_condition); + if (&next_condition != positive_condition) { + positive_condition = &next_condition; + if (positive_condition != &cond) { + c_concat_override = cond.c_concat.empty() ? nullptr : &cond.c_concat; + } + break; } } + cond_out = run_condition(*positive_condition, c_concat_override); + if (cond_out.empty()) { + return {}; + } + if (!uncond.empty()) { if (!step_cache.is_step_skipped()) { compute_sample_controls(control_image, @@ -3470,7 +3390,6 @@ struct SamplePlan { int high_noise_sample_steps = 0; int total_steps = 0; float moe_boundary = 0.f; - int start_merge_step = -1; std::vector sigmas; SamplePlan(sd_ctx_t* sd_ctx, @@ -3555,11 +3474,6 @@ struct SamplePlan { high_noise_eta = resolve_eta(sd_ctx, high_noise_eta, high_noise_sample_method); LOG_INFO("sampling(high noise) using %s method", sampling_methods_str[high_noise_sample_method]); } - - if (sd_ctx->sd->use_pmid) { - start_merge_step = int(sd_ctx->sd->pmid_model->style_strength / 100.f * total_steps); - LOG_INFO("PHOTOMAKER: start_merge_step: %d", start_merge_step); - } } }; @@ -3890,7 +3804,6 @@ struct ImageGenerationEmbeds { SDCondition cond; SDCondition uncond; SDCondition img_uncond; - SDCondition id_cond; }; struct CircularAxesState { @@ -4195,7 +4108,9 @@ static std::optional prepare_image_generation_embeds(sd_c condition_params.height = request->height; condition_params.ref_images = &latents->ref_images; - auto id_cond = sd_ctx->sd->get_pmid_conditon(request->pm_params, condition_params); + sd_ctx->sd->prepare_generation_extensions(request->pm_params, + condition_params, + plan->total_steps); int64_t prepare_start_ms = ggml_time_ms(); condition_params.zero_out_masked = false; auto cond = sd_ctx->sd->cond_stage_model->get_learned_condition(sd_ctx->sd->n_threads, @@ -4265,7 +4180,6 @@ static std::optional prepare_image_generation_embeds(sd_c embeds.img_uncond = std::move(img_uncond); embeds.cond = std::move(cond); embeds.uncond = std::move(uncond); - embeds.id_cond = std::move(id_cond); return embeds; } @@ -4546,7 +4460,6 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s embeds.cond, embeds.uncond, embeds.img_uncond, - embeds.id_cond, latents.control_image, request.control_strength, request.guidance, @@ -4556,7 +4469,6 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s sd_ctx->sd->is_flow_denoiser(), plan.extra_sample_args, plan.sigmas, - plan.start_merge_step, latents.ref_latents, request.increase_ref_index, latents.denoise_mask, @@ -4666,7 +4578,6 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s embeds.cond, embeds.uncond, embeds.img_uncond, - embeds.id_cond, latents.control_image, request.control_strength, request.guidance, @@ -4676,7 +4587,6 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s sd_ctx->sd->is_flow_denoiser(), plan.extra_sample_args, hires_sigma_sched, - plan.start_merge_step, latents.ref_latents, request.increase_ref_index, hires_denoise_mask, @@ -5335,6 +5245,7 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx, sd_ctx->sd->sampler_rng->manual_seed(request.seed); sd_ctx->sd->set_flow_shift(sd_vid_gen_params->sample_params.flow_shift); sd_ctx->sd->apply_loras(sd_vid_gen_params->loras, sd_vid_gen_params->lora_count); + sd_ctx->sd->reset_generation_extensions(); SamplePlan plan(sd_ctx, sd_vid_gen_params, request); auto latent_inputs_opt = prepare_video_generation_latents(sd_ctx, sd_vid_gen_params, &request); @@ -5381,7 +5292,6 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx, embeds.cond, request.use_high_noise_uncond ? embeds.uncond : SDCondition(), embeds.img_uncond, - embeds.id_cond, sd::Tensor(), 0.f, request.high_noise_guidance, @@ -5391,7 +5301,6 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx, sd_ctx->sd->is_flow_denoiser(), plan.high_noise_extra_sample_args, high_noise_sigmas, - -1, std::vector>{}, false, latents.denoise_mask, @@ -5427,7 +5336,6 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx, embeds.cond, request.use_uncond ? embeds.uncond : SDCondition(), embeds.img_uncond, - embeds.id_cond, sd::Tensor(), 0.f, sd_vid_gen_params->sample_params.guidance, @@ -5437,7 +5345,6 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx, sd_ctx->sd->is_flow_denoiser(), plan.extra_sample_args, plan.sigmas, - -1, std::vector>{}, false, latents.denoise_mask, @@ -5571,7 +5478,6 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx, embeds.cond, hires_request.use_uncond ? embeds.uncond : SDCondition(), embeds.img_uncond, - embeds.id_cond, sd::Tensor(), 0.f, sd_vid_gen_params->sample_params.guidance, @@ -5581,7 +5487,6 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx, sd_ctx->sd->is_flow_denoiser(), plan.extra_sample_args, hires_sigma_sched, - -1, std::vector>{}, false, hires_denoise_mask,