From a2d83dd0c804dc39d7ceb2e599cfbb02488da8e9 Mon Sep 17 00:00:00 2001 From: leejet Date: Sat, 27 Dec 2025 16:48:15 +0800 Subject: [PATCH] refactor: move pmid condition logic into get_pmid_condition (#1148) --- stable-diffusion.cpp | 201 ++++++++++++++++++++----------------------- 1 file changed, 92 insertions(+), 109 deletions(-) diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 58d4204..9a97d4c 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -129,7 +129,7 @@ public: bool use_tiny_autoencoder = false; sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0, 0}; bool offload_params_to_cpu = false; - bool stacked_id = false; + bool use_pmid = false; bool is_using_v_parameterization = false; bool is_using_edm_v_parameterization = false; @@ -701,10 +701,10 @@ public: if (!model_loader.init_from_file_and_convert_name(sd_ctx_params->photo_maker_path, "pmid.")) { LOG_WARN("loading stacked ID embedding from '%s' failed", sd_ctx_params->photo_maker_path); } else { - stacked_id = true; + use_pmid = true; } } - if (stacked_id) { + if (use_pmid) { if (!pmid_model->alloc_params_buffer()) { LOG_ERROR(" pmid model params buffer allocation failed"); return false; @@ -745,7 +745,7 @@ public: if (use_tiny_autoencoder) { ignore_tensors.insert("first_stage_model."); } - if (stacked_id) { + if (use_pmid) { ignore_tensors.insert("pmid.unet."); } ignore_tensors.insert("model.diffusion_model.__x0__"); @@ -799,7 +799,7 @@ public: control_net_params_mem_size = control_net->get_params_buffer_size(); } size_t pmid_params_mem_size = 0; - if (stacked_id) { + if (use_pmid) { pmid_params_mem_size = pmid_model->get_params_buffer_size(); } @@ -1211,14 +1211,89 @@ public: } } - ggml_tensor* id_encoder(ggml_context* work_ctx, - ggml_tensor* init_img, - ggml_tensor* prompts_embeds, - ggml_tensor* id_embeds, - std::vector& class_tokens_mask) { - ggml_tensor* res = nullptr; - pmid_model->compute(n_threads, init_img, prompts_embeds, id_embeds, class_tokens_mask, &res, 
work_ctx); - return res; + SDCondition get_pmid_conditon(ggml_context* work_ctx, + sd_pm_params_t pm_params, + ConditionerParams& condition_params) { + SDCondition id_cond; + if (use_pmid) { + if (!pmid_lora->applied) { + int64_t t0 = ggml_time_ms(); + pmid_lora->apply(tensors, version, n_threads); + int64_t t1 = ggml_time_ms(); + pmid_lora->applied = true; + LOG_INFO("pmid_lora apply completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); + if (free_params_immediately) { + pmid_lora->free_params_buffer(); + } + } + // preprocess input id images + bool pmv2 = pmid_model->get_version() == PM_VERSION_2; + if (pm_params.id_images_count > 0) { + int clip_image_size = 224; + pmid_model->style_strength = pm_params.style_strength; + + auto id_image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, clip_image_size, clip_image_size, 3, pm_params.id_images_count); + + std::vector processed_id_images; + for (int i = 0; i < pm_params.id_images_count; i++) { + sd_image_f32_t id_image = sd_image_t_to_sd_image_f32_t(pm_params.id_images[i]); + sd_image_f32_t processed_id_image = clip_preprocess(id_image, clip_image_size, clip_image_size); + free(id_image.data); + id_image.data = nullptr; + processed_id_images.push_back(processed_id_image); + } + + ggml_ext_tensor_iter(id_image_tensor, [&](ggml_tensor* id_image_tensor, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { + float value = sd_image_get_f32(processed_id_images[i3], i0, i1, i2, false); + ggml_ext_tensor_set_f32(id_image_tensor, value, i0, i1, i2, i3); + }); + + for (auto& image : processed_id_images) { + free(image.data); + image.data = nullptr; + } + processed_id_images.clear(); + + int64_t t0 = ggml_time_ms(); + condition_params.num_input_imgs = pm_params.id_images_count; + auto cond_tup = cond_stage_model->get_learned_condition_with_trigger(work_ctx, + n_threads, + condition_params); + id_cond = std::get<0>(cond_tup); + auto class_tokens_mask = std::get<1>(cond_tup); + struct ggml_tensor* id_embeds = nullptr; + if (pmv2 
&& pm_params.id_embed_path != nullptr) { + id_embeds = load_tensor_from_file(work_ctx, pm_params.id_embed_path); + } + if (pmv2 && id_embeds == nullptr) { + LOG_WARN("Provided PhotoMaker images, but NO valid ID embeds file for PM v2"); + LOG_WARN("Turn off PhotoMaker"); + use_pmid = false; + } else { + if (pmv2 && pm_params.id_images_count != id_embeds->ne[1]) { + LOG_WARN("PhotoMaker image count (%d) does NOT match ID embeds (%" PRId64 "). You should run face_detect.py again.", pm_params.id_images_count, id_embeds->ne[1]); + LOG_WARN("Turn off PhotoMaker"); + use_pmid = false; + } else { + ggml_tensor* res = nullptr; + pmid_model->compute(n_threads, id_image_tensor, id_cond.c_crossattn, id_embeds, class_tokens_mask, &res, work_ctx); + id_cond.c_crossattn = res; + int64_t t1 = ggml_time_ms(); + LOG_INFO("Photomaker ID Stacking, taking %" PRId64 " ms", t1 - t0); + if (free_params_immediately) { + pmid_model->free_params_buffer(); + } + // Encode input prompt without the trigger word for delayed conditioning + condition_params.text = cond_stage_model->remove_trigger_from_prompt(work_ctx, condition_params.text); + } + } + } else { + LOG_WARN("Provided PhotoMaker model file, but NO input ID images"); + LOG_WARN("Turn off PhotoMaker"); + use_pmid = false; + } + } + return id_cond; } ggml_tensor* get_clip_vision_output(ggml_context* work_ctx, @@ -3117,114 +3192,22 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, guidance.img_cfg = guidance.txt_cfg; } - // for (auto v : sigmas) { - // std::cout << v << " "; - // } - // std::cout << std::endl; - int sample_steps = sigmas.size() - 1; int64_t t0 = ggml_time_ms(); - // Photo Maker - std::string prompt_text_only; - ggml_tensor* init_img = nullptr; - SDCondition id_cond; - std::vector class_tokens_mask; - ConditionerParams condition_params; + condition_params.text = prompt; condition_params.clip_skip = clip_skip; condition_params.width = width; condition_params.height = height; condition_params.ref_images = ref_images; 
condition_params.adm_in_channels = sd_ctx->sd->diffusion_model->get_adm_in_channels(); - if (sd_ctx->sd->stacked_id) { - if (!sd_ctx->sd->pmid_lora->applied) { - int64_t t0 = ggml_time_ms(); - sd_ctx->sd->pmid_lora->apply(sd_ctx->sd->tensors, sd_ctx->sd->version, sd_ctx->sd->n_threads); - int64_t t1 = ggml_time_ms(); - sd_ctx->sd->pmid_lora->applied = true; - LOG_INFO("pmid_lora apply completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); - if (sd_ctx->sd->free_params_immediately) { - sd_ctx->sd->pmid_lora->free_params_buffer(); - } - } - // preprocess input id images - bool pmv2 = sd_ctx->sd->pmid_model->get_version() == PM_VERSION_2; - if (pm_params.id_images_count > 0) { - int clip_image_size = 224; - sd_ctx->sd->pmid_model->style_strength = pm_params.style_strength; - - init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, clip_image_size, clip_image_size, 3, pm_params.id_images_count); - - std::vector processed_id_images; - for (int i = 0; i < pm_params.id_images_count; i++) { - sd_image_f32_t id_image = sd_image_t_to_sd_image_f32_t(pm_params.id_images[i]); - sd_image_f32_t processed_id_image = clip_preprocess(id_image, clip_image_size, clip_image_size); - free(id_image.data); - id_image.data = nullptr; - processed_id_images.push_back(processed_id_image); - } - - ggml_ext_tensor_iter(init_img, [&](ggml_tensor* init_img, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { - float value = sd_image_get_f32(processed_id_images[i3], i0, i1, i2, false); - ggml_ext_tensor_set_f32(init_img, value, i0, i1, i2, i3); - }); - - for (auto& image : processed_id_images) { - free(image.data); - image.data = nullptr; - } - processed_id_images.clear(); - - int64_t t0 = ggml_time_ms(); - condition_params.text = prompt; - condition_params.num_input_imgs = pm_params.id_images_count; - auto cond_tup = sd_ctx->sd->cond_stage_model->get_learned_condition_with_trigger(work_ctx, - sd_ctx->sd->n_threads, - condition_params); - id_cond = std::get<0>(cond_tup); - class_tokens_mask = 
std::get<1>(cond_tup); // - struct ggml_tensor* id_embeds = nullptr; - if (pmv2 && pm_params.id_embed_path != nullptr) { - id_embeds = load_tensor_from_file(work_ctx, pm_params.id_embed_path); - // print_ggml_tensor(id_embeds, true, "id_embeds:"); - } - if (pmv2 && id_embeds == nullptr) { - LOG_WARN("Provided PhotoMaker images, but NO valid ID embeds file for PM v2"); - LOG_WARN("Turn off PhotoMaker"); - sd_ctx->sd->stacked_id = false; - } else { - if (pmv2 && pm_params.id_images_count != id_embeds->ne[1]) { - LOG_WARN("PhotoMaker image count (%d) does NOT match ID embeds (%d). You should run face_detect.py again.", pm_params.id_images_count, id_embeds->ne[1]); - LOG_WARN("Turn off PhotoMaker"); - sd_ctx->sd->stacked_id = false; - } else { - id_cond.c_crossattn = sd_ctx->sd->id_encoder(work_ctx, init_img, id_cond.c_crossattn, id_embeds, class_tokens_mask); - int64_t t1 = ggml_time_ms(); - LOG_INFO("Photomaker ID Stacking, taking %" PRId64 " ms", t1 - t0); - if (sd_ctx->sd->free_params_immediately) { - sd_ctx->sd->pmid_model->free_params_buffer(); - } - // Encode input prompt without the trigger word for delayed conditioning - prompt_text_only = sd_ctx->sd->cond_stage_model->remove_trigger_from_prompt(work_ctx, prompt); - // printf("%s || %s \n", prompt.c_str(), prompt_text_only.c_str()); - prompt = prompt_text_only; // - if (sample_steps < 50) { - LOG_WARN("It's recommended to use >= 50 steps for photo maker!"); - } - } - } - } else { - LOG_WARN("Provided PhotoMaker model file, but NO input ID images"); - LOG_WARN("Turn off PhotoMaker"); - sd_ctx->sd->stacked_id = false; - } - } + // Photo Maker + SDCondition id_cond = sd_ctx->sd->get_pmid_conditon(work_ctx, pm_params, condition_params); // Get learned condition - condition_params.text = prompt; condition_params.zero_out_masked = false; SDCondition cond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx, sd_ctx->sd->n_threads, @@ -3364,7 +3347,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, 
ggml_ext_im_set_randn_f32(noise, sd_ctx->sd->rng); int start_merge_step = -1; - if (sd_ctx->sd->stacked_id) { + if (sd_ctx->sd->use_pmid) { start_merge_step = int(sd_ctx->sd->pmid_model->style_strength / 100.f * sample_steps); // if (start_merge_step > 30) // start_merge_step = 30;