#include "extensions/generation_extension.h"

#include <algorithm>
#include <cinttypes>
#include <cstring>
#include <exception>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "core/tensor_ggml.hpp"
#include "core/util.h"
#include "model/adapter/lora.hpp"
#include "model/adapter/pmid.hpp"

/**
 * Tokenize `text` for PhotoMaker, replacing the trigger token with a run of class tokens.
 *
 * Every occurrence of `image_token` is removed from the token stream. If a token precedes it
 * (the "class" word, e.g. "man" in "a man img"), that token is repeated so that it occupies
 * `trigger_token_count` consecutive positions, each carrying the class token's weight. The
 * result is then padded with `pad_tokens` to the text model's `n_token` length.
 *
 * If the trigger occurs more than once, every occurrence is expanded, but only the run of the
 * LAST occurrence is reported in the mask.
 *
 * @param clip_conditioner    CLIP conditioner supplying the tokenizer and text model.
 * @param text                Prompt text.
 * @param trigger_token_count Number of positions the class token should occupy.
 * @param image_token         Token id of the trigger word.
 * @return (tokens, weights, class_token_mask). The mask has one entry per padded token. It is
 *         all-false when no usable trigger was found (trigger absent, trigger with no preceding
 *         token, or the run truncated away by padding), so callers can detect that case.
 */
static std::tuple<std::vector<int>, std::vector<float>, std::vector<bool>>
tokenize_photomaker_trigger(FrozenCLIPEmbedderWithCustomWords& clip_conditioner,
                            const std::string& text,
                            int trigger_token_count,
                            int32_t image_token) {
    auto tokens_and_weights             = clip_conditioner.tokenize(text);
    std::vector<int> source_tokens      = std::move(tokens_and_weights.first);
    std::vector<float> source_weights   = std::move(tokens_and_weights.second);

    // Strip a leading BOS / trailing EOS so the expansion below works on the raw prompt tokens.
    // pad_tokens(..., true) is called afterwards; it presumably re-adds the special tokens
    // (TODO confirm against the tokenizer).
    if (!source_tokens.empty() && source_tokens.front() == clip_conditioner.tokenizer.BOS_TOKEN_ID) {
        source_tokens.erase(source_tokens.begin());
        source_weights.erase(source_weights.begin());
    }
    if (!source_tokens.empty() && source_tokens.back() == clip_conditioner.tokenizer.EOS_TOKEN_ID) {
        source_tokens.pop_back();
        source_weights.pop_back();
    }

    std::vector<int> tokens;
    std::vector<float> weights;
    // Pre-padding index of the class token preceding the last trigger; -1 = none found.
    int32_t class_idx = -1;
    for (size_t i = 0; i < source_tokens.size(); i++) {
        const int token = source_tokens[i];
        if (token != image_token) {
            tokens.push_back(token);
            weights.push_back(source_weights[i]);
            continue;
        }
        // The trigger itself is dropped. Without a preceding class word there is nothing to expand.
        if (tokens.empty()) {
            continue;
        }
        class_idx                = static_cast<int32_t>(tokens.size()) - 1;
        const int class_token    = tokens.back();
        const float class_weight = weights.back();
        // The class token is already present once; add the remaining copies.
        for (int j = 1; j < trigger_token_count; j++) {
            tokens.push_back(class_token);
            weights.push_back(class_weight);
        }
    }

    clip_conditioner.tokenizer.pad_tokens(tokens,
                                          &weights,
                                          nullptr,
                                          clip_conditioner.text_model->model.n_token,
                                          clip_conditioner.text_model->model.n_token,
                                          true);

    // Only mark positions when a trigger was actually expanded. With class_idx == -1 the old range
    // test marked [0, trigger_token_count) unconditionally, which defeated the caller's
    // "trigger not found" check.
    std::vector<bool> class_token_mask(tokens.size(), false);
    if (class_idx >= 0 && trigger_token_count > 0) {
        // +1: after padding the class run starts one position later, presumably because
        // pad_tokens() prepends BOS (TODO confirm).
        const size_t first = static_cast<size_t>(class_idx) + 1;
        const size_t last  = std::min(tokens.size(), first + static_cast<size_t>(trigger_token_count));
        for (size_t i = first; i < last; i++) {
            class_token_mask[i] = true;
        }
    }
    return std::make_tuple(std::move(tokens), std::move(weights), std::move(class_token_mask));
}

/**
 * Build the CLIP text condition for a prompt containing the PhotoMaker trigger word.
 *
 * @param clip_conditioner    CLIP conditioner used for tokenization and encoding.
 * @param n_threads           Thread count forwarded to the encoder.
 * @param conditioner_params  Prompt text plus clip_skip / width / height / zero_out_masked.
 * @param trigger_word        Trigger word; must tokenize to exactly one token (asserted).
 * @param trigger_token_count Number of positions the class token is expanded to.
 * @return (condition, class-token mask). See tokenize_photomaker_trigger for the mask semantics.
 */
static std::tuple<SDCondition, std::vector<bool>>
get_photomaker_condition_with_trigger(FrozenCLIPEmbedderWithCustomWords& clip_conditioner,
                                      int n_threads,
                                      const ConditionerParams& conditioner_params,
                                      const std::string& trigger_word,
                                      int trigger_token_count) {
    auto image_tokens = clip_conditioner.convert_token_to_id(trigger_word);
    GGML_ASSERT(image_tokens.size() == 1);

    auto [tokens, weights, trigger_mask] = tokenize_photomaker_trigger(clip_conditioner,
                                                                       conditioner_params.text,
                                                                       trigger_token_count,
                                                                       image_tokens[0]);
    auto cond = clip_conditioner.get_learned_condition_common(n_threads,
                                                              tokens,
                                                              weights,
                                                              conditioner_params.clip_skip,
                                                              conditioner_params.width,
                                                              conditioner_params.height,
                                                              conditioner_params.zero_out_masked);
    return std::make_tuple(std::move(cond), std::move(trigger_mask));
}

/**
 * Return `prompt` with the first occurrence of the trigger token removed, decoded back to text.
 *
 * Aborts via GGML_ASSERT if the trigger word is not a single token or does not occur in `prompt`.
 */
static std::string remove_photomaker_trigger_from_prompt(FrozenCLIPEmbedderWithCustomWords& clip_conditioner,
                                                         const std::string& prompt,
                                                         const std::string& trigger_word) {
    auto image_tokens = clip_conditioner.convert_token_to_id(trigger_word);
    GGML_ASSERT(image_tokens.size() == 1);
    auto tokens_and_weights = clip_conditioner.tokenize(prompt);
    auto& tokens            = tokens_and_weights.first;
    auto it                 = std::find(tokens.begin(), tokens.end(), image_tokens[0]);
    GGML_ASSERT(it != tokens.end());
    tokens.erase(it);
    return clip_conditioner.decode(tokens);
}

/**
 * Generation extension implementing PhotoMaker ID stacking.
 *
 * init() loads the PhotoMaker ID encoder and its LoRA from `photo_maker_path`.
 * prepare_condition() builds a per-request condition in which the trigger word's class tokens
 * are replaced by stacked ID embeddings. before_condition() swaps that condition in after
 * `start_merge_step`.
 */
struct PhotoMakerExtension : public GenerationExtension {
    std::shared_ptr<PhotoMakerIDEncoder> pmid_model;
    std::shared_ptr<LoraModel> pmid_lora;
    bool enabled = false;
    std::string model_path;
    std::string trigger_word = "img";
    SDCondition id_condition;   // per-request condition carrying the stacked ID embeddings
    int start_merge_step = -1;  // steps strictly after this use id_condition; -1 = never

    const char* name() const override { return "photomaker"; }

    bool is_enabled() const override { return enabled; }

    /**
     * Create and load the PhotoMaker models when a model path is configured.
     *
     * @return false on backend setup or LoRA load failure. Returns true (leaving the extension
     *         disabled) when no path is configured or when loading the ID-embedding weights fails.
     */
    bool init(const GenerationExtensionInitContext& ctx) override {
        model_path = SAFE_STR(ctx.params->photo_maker_path);
        if (model_path.empty()) {
            return true;  // PhotoMaker not requested.
        }
        if (!ctx.ensure_backend_pair(SDBackendModule::PHOTOMAKER)) {
            return false;
        }

        // The version is inferred from the file name: any path containing "v2" is treated as v2.
        const PMVersion pm_version = std::strstr(model_path.c_str(), "v2") != nullptr ? PM_VERSION_2 : PM_VERSION_1;
        pmid_model = std::make_shared<PhotoMakerIDEncoder>(ctx.backend_for(SDBackendModule::PHOTOMAKER),
                                                           ctx.params_backend_for(SDBackendModule::PHOTOMAKER),
                                                           ctx.tensor_storage_map,
                                                           "pmid",
                                                           ctx.version,
                                                           pm_version);
        if (pm_version == PM_VERSION_2) {
            LOG_INFO("using PhotoMaker Version 2");
        }

        pmid_lora = std::make_shared<LoraModel>("pmid",
                                                ctx.backend_for(SDBackendModule::PHOTOMAKER),
                                                ctx.params_backend_for(SDBackendModule::PHOTOMAKER),
                                                model_path,
                                                "",
                                                ctx.version);
        // Only the "lora.model*" tensors of the PhotoMaker file belong to the LoRA.
        auto lora_tensor_filter = [](const std::string& tensor_name) {
            return starts_with(tensor_name, "lora.model");
        };
        if (!pmid_lora->load_from_file(ctx.n_threads, lora_tensor_filter)) {
            LOG_WARN("load photomaker lora tensors from %s failed", model_path.c_str());
            return false;
        }

        LOG_INFO("loading stacked ID embedding (PHOTOMAKER) model file from '%s'", model_path.c_str());
        if (!ctx.model_loader.init_from_file_and_convert_name(model_path, "pmid.")) {
            // Best effort: PhotoMaker stays disabled but overall initialization succeeds.
            LOG_WARN("loading stacked ID embedding from '%s' failed", model_path.c_str());
            return true;
        }
        enabled = true;
        return true;
    }

    /** Register the ID encoder's parameter tensors (prefixed "pmid") with the loader context. */
    void collect_param_tensors(GenerationExtensionTensorContext& ctx) override {
        if (!enabled || pmid_model == nullptr) {
            return;
        }
        std::map<std::string, ggml_tensor*> temp;
        pmid_model->get_param_tensors(temp, "pmid");
        const bool do_mmap = ctx.module_can_mmap(SDBackendModule::PHOTOMAKER);
        for (const auto& [key, tensor] : temp) {
            ctx.tensors[key] = tensor;
            if (do_mmap) {
                ctx.mmap_able_tensors[key] = tensor;
            }
        }
    }

    /** Skip "pmid.unet." tensors during generic loading when PhotoMaker is enabled. */
    void add_ignore_tensors(std::set<std::string>& ignore_tensors) const override {
        if (!enabled) {
            return;
        }
        ignore_tensors.insert("pmid.unet.");
    }

    bool alloc_params_buffer() override {
        if (!enabled || pmid_model == nullptr) {
            return true;
        }
        return pmid_model->alloc_params_buffer();
    }

    size_t get_params_buffer_size() const override {
        if (!enabled || pmid_model == nullptr) {
            return 0;
        }
        return pmid_model->get_params_buffer_size();
    }

    /** Clear any condition prepared for a previous request. */
    void reset_runtime_condition() override {
        id_condition     = {};
        start_merge_step = -1;
    }

    /**
     * Prepare the PhotoMaker ID condition for the current request.
     *
     * On success, stores id_condition and start_merge_step, strips the trigger word from
     * ctx.condition_params.text, and returns true. Returns false (PhotoMaker off for this request)
     * when the extension is disabled, no ID images are given, the conditioner is not a CLIP
     * custom-words embedder, the trigger word is missing, v2 ID embeds are missing/mismatched,
     * or ID stacking fails.
     */
    bool prepare_condition(GenerationExtensionConditionContext& ctx) override {
        reset_runtime_condition();
        if (!enabled || pmid_model == nullptr || pmid_lora == nullptr) {
            return false;
        }

        // Apply the PhotoMaker LoRA to ctx.tensors once; the `applied` flag prevents re-applying.
        if (!pmid_lora->applied) {
            int64_t t0 = ggml_time_ms();
            pmid_lora->apply(ctx.tensors, ctx.version, ctx.n_threads);
            int64_t t1         = ggml_time_ms();
            pmid_lora->applied = true;
            LOG_INFO("pmid_lora apply completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
            if (ctx.free_params_immediately) {
                pmid_lora->free_params_buffer();
            }
        }

        const bool pmv2 = pmid_model->get_version() == PM_VERSION_2;
        if (ctx.pm_params.id_images_count <= 0 || ctx.pm_params.id_images == nullptr) {
            LOG_WARN("Provided PhotoMaker model file, but NO input ID images");
            LOG_WARN("Turn off PhotoMaker for this request");
            return false;
        }

        auto* clip_conditioner = dynamic_cast<FrozenCLIPEmbedderWithCustomWords*>(ctx.conditioner);
        if (clip_conditioner == nullptr) {
            LOG_WARN("PhotoMaker requires FrozenCLIPEmbedderWithCustomWords conditioner");
            LOG_WARN("Turn off PhotoMaker for this request");
            return false;
        }

        const int clip_image_size    = 224;
        pmid_model->style_strength   = ctx.pm_params.style_strength;

        // Preprocess every ID image to the CLIP input size and concatenate them along dim 3.
        sd::Tensor<float> id_image_tensor;
        for (int i = 0; i < ctx.pm_params.id_images_count; i++) {
            auto id_image           = sd_image_to_tensor(ctx.pm_params.id_images[i]);
            auto processed_id_image = clip_preprocess(id_image, clip_image_size, clip_image_size);
            if (id_image_tensor.empty()) {
                id_image_tensor = processed_id_image;
            } else {
                id_image_tensor = sd::ops::concat(id_image_tensor, processed_id_image, 3);
            }
        }

        int64_t t0 = ggml_time_ms();
        // v2 expands the class token to two positions per ID image, v1 to one.
        const int trigger_token_count = pmv2 ? 2 * ctx.pm_params.id_images_count : ctx.pm_params.id_images_count;
        auto [prepared_id_condition, class_tokens_mask] = get_photomaker_condition_with_trigger(*clip_conditioner,
                                                                                                ctx.n_threads,
                                                                                                ctx.condition_params,
                                                                                                trigger_word,
                                                                                                trigger_token_count);
        // An all-false mask means no usable trigger. Bailing out here also guarantees that
        // remove_photomaker_trigger_from_prompt() below finds the trigger and does not assert.
        if (std::find(class_tokens_mask.begin(), class_tokens_mask.end(), true) == class_tokens_mask.end()) {
            LOG_WARN("PhotoMaker trigger word '%s' was not found in prompt", trigger_word.c_str());
            LOG_WARN("Turn off PhotoMaker for this request");
            return false;
        }

        sd::Tensor<float> id_embeds;
        if (pmv2 && ctx.pm_params.id_embed_path != nullptr) {
            try {
                id_embeds = sd::load_tensor_from_file_as_tensor<float>(ctx.pm_params.id_embed_path);
            } catch (const std::exception& e) {
                // Report why the file was rejected; the empty-tensor check below disables PhotoMaker.
                LOG_WARN("failed to load PhotoMaker ID embeds from '%s': %s", ctx.pm_params.id_embed_path, e.what());
                id_embeds = {};
            }
        }
        if (pmv2 && id_embeds.empty()) {
            LOG_WARN("Provided PhotoMaker images, but NO valid ID embeds file for PM v2");
            LOG_WARN("Turn off PhotoMaker for this request");
            return false;
        }
        if (pmv2 && ctx.pm_params.id_images_count != id_embeds.shape()[1]) {
            LOG_WARN("PhotoMaker image count (%d) does NOT match ID embeds (%d). You should run face_detect.py again.",
                     ctx.pm_params.id_images_count,
                     static_cast<int>(id_embeds.shape()[1]));
            LOG_WARN("Turn off PhotoMaker for this request");
            return false;
        }

        auto res = pmid_model->compute(ctx.n_threads,
                                       id_image_tensor,
                                       prepared_id_condition.c_crossattn,
                                       id_embeds,
                                       class_tokens_mask);
        if (res.empty()) {
            LOG_ERROR("Photomaker ID Stacking failed");
            LOG_WARN("Turn off PhotoMaker for this request");
            return false;
        }
        prepared_id_condition.c_crossattn = std::move(res);
        int64_t t1                        = ggml_time_ms();

        id_condition     = std::move(prepared_id_condition);
        start_merge_step = static_cast<int>(ctx.pm_params.style_strength / 100.f * ctx.total_steps);
        ctx.condition_params.text = remove_photomaker_trigger_from_prompt(*clip_conditioner,
                                                                          ctx.condition_params.text,
                                                                          trigger_word);
        LOG_INFO("Photomaker ID Stacking, taking %" PRId64 " ms", t1 - t0);
        LOG_INFO("PHOTOMAKER: start_merge_step: %d", start_merge_step);
        if (ctx.free_params_immediately) {
            pmid_model->free_params_buffer();
        }
        return true;
    }

    /** Use the ID condition for steps strictly after start_merge_step; otherwise pass through. */
    const SDCondition& before_condition(int step, const SDCondition& condition) const override {
        if (!id_condition.empty() && start_merge_step != -1 && step > start_merge_step) {
            return id_condition;
        }
        return condition;
    }
};

/** Factory for the PhotoMaker generation extension. */
std::shared_ptr<GenerationExtension> create_photomaker_extension() {
    return std::make_shared<PhotoMakerExtension>();
}