feat: support img-cfg for edit models (#929)

Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
stduhpf 2026-06-01 16:54:25 +02:00 committed by GitHub
parent be65ac7511
commit f8935d6f25
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 69 additions and 19 deletions

View File

@ -129,8 +129,8 @@ Generation Options:
--hires-upscale-tile-size <int> highres fix upscaler tile size, reserved for model-backed upscalers (default: --hires-upscale-tile-size <int> highres fix upscaler tile size, reserved for model-backed upscalers (default:
128) 128)
--cfg-scale <float> unconditional guidance scale: (default: 7.0) --cfg-scale <float> unconditional guidance scale: (default: 7.0)
--img-cfg-scale <float> image guidance scale for inpaint or instruct-pix2pix models: (default: same --img-cfg-scale <float> image guidance scale for inpaint or image edit models: (default: same as
as --cfg-scale) --cfg-scale)
--guidance <float> distilled guidance scale for models with guidance input (default: 3.5) --guidance <float> distilled guidance scale for models with guidance input (default: 3.5)
--slg-scale <float> skip layer guidance (SLG) scale, only for DiT models: (default: 0). 0 means --slg-scale <float> skip layer guidance (SLG) scale, only for DiT models: (default: 0). 0 means
disabled, a value of 2.5 is nice for sd3.5 medium disabled, a value of 2.5 is nice for sd3.5 medium
@ -140,8 +140,8 @@ Generation Options:
res_2s; 1 for euler_a, er_sde and dpm++2s_a) res_2s; 1 for euler_a, er_sde and dpm++2s_a)
--flow-shift <float> shift value for Flow models like SD3.x or WAN (default: auto) --flow-shift <float> shift value for Flow models like SD3.x or WAN (default: auto)
--high-noise-cfg-scale <float> (high noise) unconditional guidance scale: (default: 7.0) --high-noise-cfg-scale <float> (high noise) unconditional guidance scale: (default: 7.0)
--high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or instruct-pix2pix models --high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or image edit models (default:
(default: same as --cfg-scale) same as --cfg-scale)
--high-noise-guidance <float> (high noise) distilled guidance scale for models with guidance input --high-noise-guidance <float> (high noise) distilled guidance scale for models with guidance input
(default: 3.5) (default: 3.5)
--high-noise-slg-scale <float> (high noise) skip layer guidance (SLG) scale, only for DiT models: (default: --high-noise-slg-scale <float> (high noise) skip layer guidance (SLG) scale, only for DiT models: (default:

View File

@ -940,7 +940,7 @@ ArgOptions SDGenerationParams::get_options() {
&sample_params.guidance.txt_cfg}, &sample_params.guidance.txt_cfg},
{"", {"",
"--img-cfg-scale", "--img-cfg-scale",
"image guidance scale for inpaint or instruct-pix2pix models: (default: same as --cfg-scale)", "image guidance scale for inpaint or image edit models: (default: same as --cfg-scale)",
&sample_params.guidance.img_cfg}, &sample_params.guidance.img_cfg},
{"", {"",
"--guidance", "--guidance",
@ -972,7 +972,7 @@ ArgOptions SDGenerationParams::get_options() {
&high_noise_sample_params.guidance.txt_cfg}, &high_noise_sample_params.guidance.txt_cfg},
{"", {"",
"--high-noise-img-cfg-scale", "--high-noise-img-cfg-scale",
"(high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale)", "(high noise) image guidance scale for inpaint or image edit models (default: same as --cfg-scale)",
&high_noise_sample_params.guidance.img_cfg}, &high_noise_sample_params.guidance.img_cfg},
{"", {"",
"--high-noise-guidance", "--high-noise-guidance",

View File

@ -231,8 +231,8 @@ Default Generation Options:
--hires-upscale-tile-size <int> highres fix upscaler tile size, reserved for model-backed upscalers (default: --hires-upscale-tile-size <int> highres fix upscaler tile size, reserved for model-backed upscalers (default:
128) 128)
--cfg-scale <float> unconditional guidance scale: (default: 7.0) --cfg-scale <float> unconditional guidance scale: (default: 7.0)
--img-cfg-scale <float> image guidance scale for inpaint or instruct-pix2pix models: (default: same --img-cfg-scale <float> image guidance scale for inpaint or image edit models: (default: same as
as --cfg-scale) --cfg-scale)
--guidance <float> distilled guidance scale for models with guidance input (default: 3.5) --guidance <float> distilled guidance scale for models with guidance input (default: 3.5)
--slg-scale <float> skip layer guidance (SLG) scale, only for DiT models: (default: 0). 0 means --slg-scale <float> skip layer guidance (SLG) scale, only for DiT models: (default: 0). 0 means
disabled, a value of 2.5 is nice for sd3.5 medium disabled, a value of 2.5 is nice for sd3.5 medium
@ -242,8 +242,8 @@ Default Generation Options:
res_2s; 1 for euler_a, er_sde and dpm++2s_a) res_2s; 1 for euler_a, er_sde and dpm++2s_a)
--flow-shift <float> shift value for Flow models like SD3.x or WAN (default: auto) --flow-shift <float> shift value for Flow models like SD3.x or WAN (default: auto)
--high-noise-cfg-scale <float> (high noise) unconditional guidance scale: (default: 7.0) --high-noise-cfg-scale <float> (high noise) unconditional guidance scale: (default: 7.0)
--high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or instruct-pix2pix models --high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or image edit models (default:
(default: same as --cfg-scale) same as --cfg-scale)
--high-noise-guidance <float> (high noise) distilled guidance scale for models with guidance input --high-noise-guidance <float> (high noise) distilled guidance scale for models with guidance input
(default: 3.5) (default: 3.5)
--high-noise-slg-scale <float> (high noise) skip layer guidance (SLG) scale, only for DiT models: (default: --high-noise-slg-scale <float> (high noise) skip layer guidance (SLG) scale, only for DiT models: (default:

View File

@ -109,6 +109,19 @@ const char* sampling_methods_str[] = {
/*================================================== Helper Functions ================================================*/ /*================================================== Helper Functions ================================================*/
// Returns true when `version` is VERSION_FLUX or is accepted by any of the
// flux2 / qwen-image / longcat / z-image version predicates. These are the
// model versions for which image CFG driven by reference latents is enabled.
// Checks run in a fixed order and stop at the first match.
static bool sd_version_supports_ref_latent_img_cfg(SDVersion version) {
    if (version == VERSION_FLUX) {
        return true;
    }
    if (sd_version_is_flux2(version)) {
        return true;
    }
    if (sd_version_is_qwen_image(version)) {
        return true;
    }
    if (sd_version_is_longcat(version)) {
        return true;
    }
    return sd_version_is_z_image(version);
}
// Decides whether a separate image guidance scale (img_cfg) is usable.
//   version        - model version being queried.
//   has_ref_images - caller-supplied flag; only consulted when `version` is
//                    not an inpaint / UNet edit version.
// Returns true for inpaint or UNet edit versions unconditionally; otherwise
// true only if `has_ref_images` is set and the version supports
// reference-latent image CFG.
static bool sd_version_supports_img_cfg(SDVersion version, bool has_ref_images) {
    if (sd_version_is_inpaint_or_unet_edit(version)) {
        return true;
    }
    if (!has_ref_images) {
        return false;
    }
    return sd_version_supports_ref_latent_img_cfg(version);
}
void calculate_alphas_cumprod(float* alphas_cumprod, void calculate_alphas_cumprod(float* alphas_cumprod,
float linear_start = 0.00085f, float linear_start = 0.00085f,
float linear_end = 0.0120f, float linear_end = 0.0120f,
@ -2059,13 +2072,19 @@ public:
cond, cond,
&controls); &controls);
static const std::vector<sd::Tensor<float>> empty_ref_latents;
bool uncond_without_ref_latents = !img_cond.empty() &&
!ref_latents.empty() &&
sd_version_supports_ref_latent_img_cfg(version);
auto run_condition = [&](const SDCondition& condition, auto run_condition = [&](const SDCondition& condition,
const sd::Tensor<float>* c_concat_override = nullptr, const sd::Tensor<float>* c_concat_override = nullptr,
const std::vector<int>* local_skip_layers = nullptr) -> sd::Tensor<float> { const std::vector<int>* local_skip_layers = nullptr,
const std::vector<sd::Tensor<float>>* ref_latents_override = nullptr) -> sd::Tensor<float> {
diffusion_params.context = condition.c_crossattn.empty() ? nullptr : &condition.c_crossattn; diffusion_params.context = condition.c_crossattn.empty() ? nullptr : &condition.c_crossattn;
diffusion_params.c_concat = c_concat_override != nullptr ? c_concat_override : (condition.c_concat.empty() ? nullptr : &condition.c_concat); diffusion_params.c_concat = c_concat_override != nullptr ? c_concat_override : (condition.c_concat.empty() ? nullptr : &condition.c_concat);
diffusion_params.y = condition.c_vector.empty() ? nullptr : &condition.c_vector; diffusion_params.y = condition.c_vector.empty() ? nullptr : &condition.c_vector;
diffusion_params.ref_latents = condition.c_ref_images.empty() ? &ref_latents : &condition.c_ref_images; diffusion_params.ref_latents = ref_latents_override != nullptr ? ref_latents_override : (condition.c_ref_images.empty() ? &ref_latents : &condition.c_ref_images);
if (sd_version_is_unet(version)) { if (sd_version_is_unet(version)) {
diffusion_params.extra = UNetDiffusionExtra{-1, &controls, control_strength}; diffusion_params.extra = UNetDiffusionExtra{-1, &controls, control_strength};
@ -2140,7 +2159,10 @@ public:
LOG_DEBUG("Skipping layers at uncond step %d\n", step); LOG_DEBUG("Skipping layers at uncond step %d\n", step);
uncond_skip_layers = &skip_layer_guidance.layers(); uncond_skip_layers = &skip_layer_guidance.layers();
} }
uncond_out = run_condition(uncond, nullptr, uncond_skip_layers); uncond_out = run_condition(uncond,
nullptr,
uncond_skip_layers,
uncond_without_ref_latents ? &empty_ref_latents : nullptr);
if (uncond_out.empty()) { if (uncond_out.empty()) {
return {}; return {};
} }
@ -3149,6 +3171,7 @@ struct GenerationRequest {
bool use_img_cond = false; bool use_img_cond = false;
bool use_high_noise_uncond = false; bool use_high_noise_uncond = false;
bool use_high_noise_img_cond = false; bool use_high_noise_img_cond = false;
bool has_ref_images = false;
const sd_cache_params_t* cache_params = nullptr; const sd_cache_params_t* cache_params = nullptr;
int batch_count = 1; int batch_count = 1;
int shifted_timestep = 0; int shifted_timestep = 0;
@ -3182,6 +3205,7 @@ struct GenerationRequest {
eta = sd_img_gen_params->sample_params.eta; eta = sd_img_gen_params->sample_params.eta;
increase_ref_index = sd_img_gen_params->increase_ref_index; increase_ref_index = sd_img_gen_params->increase_ref_index;
auto_resize_ref_image = sd_img_gen_params->auto_resize_ref_image; auto_resize_ref_image = sd_img_gen_params->auto_resize_ref_image;
has_ref_images = sd_img_gen_params->ref_images_count > 0;
guidance = sd_img_gen_params->sample_params.guidance; guidance = sd_img_gen_params->sample_params.guidance;
pm_params = sd_img_gen_params->pm_params; pm_params = sd_img_gen_params->pm_params;
hires = sd_img_gen_params->hires; hires = sd_img_gen_params->hires;
@ -3305,17 +3329,22 @@ struct GenerationRequest {
sd_guidance_params_t* guidance, sd_guidance_params_t* guidance,
bool* use_uncond, bool* use_uncond,
bool* use_img_cond, bool* use_img_cond,
bool has_ref_images,
const char* stage_name = nullptr) { const char* stage_name = nullptr) {
GGML_ASSERT(guidance != nullptr); GGML_ASSERT(guidance != nullptr);
GGML_ASSERT(use_uncond != nullptr); GGML_ASSERT(use_uncond != nullptr);
GGML_ASSERT(use_img_cond != nullptr); GGML_ASSERT(use_img_cond != nullptr);
// out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond) // out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond)
// img_cfg == txt_cfg means that img_cfg is not used // img_cfg == txt_cfg means that img_cfg is not used
bool img_cfg_was_unset = !std::isfinite(guidance->img_cfg);
if (!std::isfinite(guidance->img_cfg)) { if (!std::isfinite(guidance->img_cfg)) {
guidance->img_cfg = guidance->txt_cfg; guidance->img_cfg = guidance->txt_cfg;
} }
if (!sd_version_is_inpaint_or_unet_edit(sd_ctx->sd->version)) { if (!sd_version_supports_img_cfg(sd_ctx->sd->version, has_ref_images)) {
if (!img_cfg_was_unset && guidance->img_cfg != guidance->txt_cfg) {
LOG_WARN("2-conditioning CFG is not supported with this model, disabling it for better performance");
}
guidance->img_cfg = guidance->txt_cfg; guidance->img_cfg = guidance->txt_cfg;
} }
@ -3344,12 +3373,13 @@ struct GenerationRequest {
resolve_hires(); resolve_hires();
seed = resolve_seed(seed); seed = resolve_seed(seed);
resolve_guidance(sd_ctx, &guidance, &use_uncond, &use_img_cond); resolve_guidance(sd_ctx, &guidance, &use_uncond, &use_img_cond, has_ref_images);
if (sd_ctx->sd->high_noise_diffusion_model) { if (sd_ctx->sd->high_noise_diffusion_model) {
resolve_guidance(sd_ctx, resolve_guidance(sd_ctx,
&high_noise_guidance, &high_noise_guidance,
&use_high_noise_uncond, &use_high_noise_uncond,
&use_high_noise_img_cond, &use_high_noise_img_cond,
has_ref_images,
"high noise: "); "high noise: ");
} }
@ -3949,6 +3979,7 @@ static std::optional<ImageGenerationLatents> prepare_image_generation_latents(sd
LOG_WARN("This model needs at least one reference image; using an empty reference"); LOG_WARN("This model needs at least one reference image; using an empty reference");
ref_images.push_back(sd::zeros<float>({request->width, request->height, 3, 1})); ref_images.push_back(sd::zeros<float>({request->width, request->height, 3, 1}));
request->guidance.img_cfg = request->guidance.txt_cfg; request->guidance.img_cfg = request->guidance.txt_cfg;
request->use_img_cond = false;
} }
if (!ref_images.empty()) { if (!ref_images.empty()) {
@ -4104,6 +4135,10 @@ static std::optional<ImageGenerationEmbeds> prepare_image_generation_embeds(sd_c
cond.c_concat = latents->concat_latent; // TODO: optimize cond.c_concat = latents->concat_latent; // TODO: optimize
} }
bool use_ref_latent_img_cfg = request->use_img_cond &&
!latents->ref_images.empty() &&
sd_version_supports_ref_latent_img_cfg(sd_ctx->sd->version);
SDCondition uncond; SDCondition uncond;
if (request->use_uncond || request->use_high_noise_uncond) { if (request->use_uncond || request->use_high_noise_uncond) {
bool zero_out_masked = false; bool zero_out_masked = false;
@ -4121,6 +4156,23 @@ static std::optional<ImageGenerationEmbeds> prepare_image_generation_embeds(sd_c
} }
} }
SDCondition img_cond;
if (request->use_img_cond) {
if (use_ref_latent_img_cfg) {
img_cond = uncond;
std::vector<sd::Tensor<float>> empty_ref_images;
condition_params.ref_images = &empty_ref_images;
uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(sd_ctx->sd->n_threads,
condition_params);
if (uncond.c_concat.empty()) {
uncond.c_concat = latents->uncond_concat_latent; // TODO: optimize
}
} else {
img_cond = SDCondition(uncond.c_crossattn, uncond.c_vector, cond.c_concat);
}
}
int64_t t1 = ggml_time_ms(); int64_t t1 = ggml_time_ms();
LOG_INFO("get_learned_condition completed, taking %.2fs", (t1 - prepare_start_ms) * 1.0f / 1000); LOG_INFO("get_learned_condition completed, taking %.2fs", (t1 - prepare_start_ms) * 1.0f / 1000);
@ -4129,9 +4181,7 @@ static std::optional<ImageGenerationEmbeds> prepare_image_generation_embeds(sd_c
} }
ImageGenerationEmbeds embeds; ImageGenerationEmbeds embeds;
if (request->use_img_cond) { embeds.img_cond = std::move(img_cond);
embeds.img_cond = SDCondition(uncond.c_crossattn, uncond.c_vector, cond.c_concat);
}
embeds.cond = std::move(cond); embeds.cond = std::move(cond);
embeds.uncond = std::move(uncond); embeds.uncond = std::move(uncond);
embeds.id_cond = std::move(id_cond); embeds.id_cond = std::move(id_cond);