refactor: img-cond->img_uncond (#1594)

* refactor: img-cond->img_uncond

* align APG and CFG++ with img-uncond CFG

* set default img_cfg to 1.f

---------

Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
stduhpf 2026-06-03 16:57:42 +02:00 committed by GitHub
parent 2d40a8b2ad
commit 4513e3fda9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 128 additions and 113 deletions

View File

@ -82,17 +82,18 @@ namespace sd::guidance {
output.pred = pred_cond; output.pred = pred_cond;
if (has_tensor(input.pred_uncond)) { if (has_tensor(input.pred_uncond)) {
const sd::Tensor<float>& pred_uncond = *input.pred_uncond; const sd::Tensor<float>& pred_uncond = *input.pred_uncond;
if (has_tensor(input.pred_img_cond)) { if (has_tensor(input.pred_img_uncond)) {
const sd::Tensor<float>& pred_img_cond = *input.pred_img_cond; const sd::Tensor<float>& pred_img_uncond = *input.pred_img_uncond;
output.pred = pred_uncond + output.pred = pred_img_uncond +
image_guidance_scale_ * (pred_img_cond - pred_uncond) + image_guidance_scale_ * (pred_uncond - pred_img_uncond) +
guidance_scale_ * (pred_cond - pred_img_cond); guidance_scale_ * (pred_cond - pred_uncond);
} else { } else {
output.pred = pred_uncond + guidance_scale_ * (pred_cond - pred_uncond); output.pred = pred_uncond + guidance_scale_ * (pred_cond - pred_uncond);
} }
} else if (has_tensor(input.pred_img_cond)) { } else if (has_tensor(input.pred_img_uncond)) {
const sd::Tensor<float>& pred_img_cond = *input.pred_img_cond; const sd::Tensor<float>& pred_img_uncond = *input.pred_img_uncond;
output.pred = pred_img_cond + guidance_scale_ * (pred_cond - pred_img_cond); output.pred = pred_img_uncond + guidance_scale_ * (pred_cond - pred_img_uncond);
} }
return output; return output;
@ -108,20 +109,20 @@ namespace sd::guidance {
static sd::Tensor<float> calculate_guidance_delta(const sd::Tensor<float>& pred_cond, static sd::Tensor<float> calculate_guidance_delta(const sd::Tensor<float>& pred_cond,
const sd::Tensor<float>* pred_uncond, const sd::Tensor<float>* pred_uncond,
const sd::Tensor<float>* pred_img_cond, const sd::Tensor<float>* pred_img_uncond,
float guidance_scale, float guidance_scale,
float image_guidance_scale) { float image_guidance_scale) {
if (pred_img_cond != nullptr) { if (pred_img_uncond != nullptr) {
if (pred_uncond != nullptr && guidance_scale == 1.0f) { if (pred_uncond != nullptr && guidance_scale == 1.0f) {
return *pred_img_cond - *pred_uncond; return *pred_uncond - *pred_img_uncond;
} }
if (pred_uncond != nullptr) { if (pred_uncond != nullptr) {
return pred_cond + return pred_cond +
(*pred_uncond * (1.0f - image_guidance_scale) + (*pred_uncond * (image_guidance_scale - guidance_scale) +
*pred_img_cond * (image_guidance_scale - guidance_scale)) / *pred_img_uncond * (1.0f - image_guidance_scale)) /
(guidance_scale - 1.0f); (guidance_scale - 1.0f);
} }
return pred_cond - *pred_img_cond; return pred_cond - *pred_img_uncond;
} }
return pred_cond - *pred_uncond; return pred_cond - *pred_uncond;
} }
@ -139,28 +140,28 @@ namespace sd::guidance {
output.pred = pred_cond; output.pred = pred_cond;
if (has_tensor(input.pred_uncond)) { if (has_tensor(input.pred_uncond)) {
const sd::Tensor<float>& pred_uncond = *input.pred_uncond; const sd::Tensor<float>& pred_uncond = *input.pred_uncond;
if (has_tensor(input.pred_img_cond)) { if (has_tensor(input.pred_img_uncond)) {
const sd::Tensor<float>& pred_img_cond = *input.pred_img_cond; const sd::Tensor<float>& pred_img_uncond = *input.pred_img_uncond;
output.pred = pred_uncond + output.pred = pred_img_uncond +
image_guidance_scale_ * (pred_img_cond - pred_uncond) + image_guidance_scale_ * (pred_uncond - pred_img_uncond) +
guidance_scale_ * (pred_cond - pred_img_cond); guidance_scale_ * (pred_cond - pred_uncond);
} else { } else {
output.pred = pred_uncond + guidance_scale_ * (pred_cond - pred_uncond); output.pred = pred_uncond + guidance_scale_ * (pred_cond - pred_uncond);
} }
} else if (has_tensor(input.pred_img_cond)) { } else if (has_tensor(input.pred_img_uncond)) {
const sd::Tensor<float>& pred_img_cond = *input.pred_img_cond; const sd::Tensor<float>& pred_img_uncond = *input.pred_img_uncond;
output.pred = pred_img_cond + guidance_scale_ * (pred_cond - pred_img_cond); output.pred = pred_img_uncond + guidance_scale_ * (pred_cond - pred_img_uncond);
} }
if (!has_tensor(input.pred_uncond) && !has_tensor(input.pred_img_cond)) { if (!has_tensor(input.pred_uncond) && !has_tensor(input.pred_img_uncond)) {
return output; return output;
} }
const sd::Tensor<float>* pred_uncond = input.pred_uncond; const sd::Tensor<float>* pred_uncond = input.pred_uncond;
const sd::Tensor<float>* pred_img_cond = input.pred_img_cond; const sd::Tensor<float>* pred_img_uncond = input.pred_img_uncond;
sd::Tensor<float> deltas = calculate_guidance_delta(pred_cond, sd::Tensor<float> deltas = calculate_guidance_delta(pred_cond,
pred_uncond, pred_uncond,
pred_img_cond, pred_img_uncond,
guidance_scale_, guidance_scale_,
image_guidance_scale_); image_guidance_scale_);
if (params_.momentum != 0.0f) { if (params_.momentum != 0.0f) {
@ -202,11 +203,11 @@ namespace sd::guidance {
if (pred_uncond != nullptr) { if (pred_uncond != nullptr) {
if (guidance_scale_ != 1.0f) { if (guidance_scale_ != 1.0f) {
output.pred = pred_cond + (guidance_scale_ - 1.0f) * deltas; output.pred = pred_cond + (guidance_scale_ - 1.0f) * deltas;
} else if (pred_img_cond != nullptr) { } else if (pred_img_uncond != nullptr) {
output.pred = pred_cond + (image_guidance_scale_ - 1.0f) * deltas; output.pred = pred_cond + (image_guidance_scale_ - 1.0f) * deltas;
} }
} else if (pred_img_cond != nullptr) { } else if (pred_img_uncond != nullptr) {
output.pred = *pred_img_cond + guidance_scale_ * deltas; output.pred = *pred_img_uncond + guidance_scale_ * deltas;
} }
return output; return output;

View File

@ -33,7 +33,7 @@ namespace sd::guidance {
size_t schedule_size = 0; size_t schedule_size = 0;
const sd::Tensor<float>* pred_cond = nullptr; const sd::Tensor<float>* pred_cond = nullptr;
const sd::Tensor<float>* pred_uncond = nullptr; const sd::Tensor<float>* pred_uncond = nullptr;
const sd::Tensor<float>* pred_img_cond = nullptr; const sd::Tensor<float>* pred_img_uncond = nullptr;
std::function<sd::Tensor<float>()> predict_skip_layer; std::function<sd::Tensor<float>()> predict_skip_layer;
}; };

View File

@ -1945,7 +1945,7 @@ public:
sd::Tensor<float> noise, sd::Tensor<float> noise,
const SDCondition& cond, const SDCondition& cond,
const SDCondition& uncond, const SDCondition& uncond,
const SDCondition& img_cond, const SDCondition& img_uncond,
const SDCondition& id_cond, const SDCondition& id_cond,
const sd::Tensor<float>& control_image, const sd::Tensor<float>& control_image,
float control_strength, float control_strength,
@ -2070,7 +2070,7 @@ public:
sd::Tensor<float> cond_out; sd::Tensor<float> cond_out;
sd::Tensor<float> uncond_out; sd::Tensor<float> uncond_out;
sd::Tensor<float> img_cond_out; sd::Tensor<float> img_uncond_out;
sd_sample::SampleStepCacheDispatcher step_cache(cache_runtime, step, sigma); sd_sample::SampleStepCacheDispatcher step_cache(cache_runtime, step, sigma);
std::vector<sd::Tensor<float>> controls; std::vector<sd::Tensor<float>> controls;
DiffusionParams diffusion_params; DiffusionParams diffusion_params;
@ -2089,7 +2089,7 @@ public:
&controls); &controls);
static const std::vector<sd::Tensor<float>> empty_ref_latents; static const std::vector<sd::Tensor<float>> empty_ref_latents;
bool uncond_without_ref_latents = !img_cond.empty() && bool uncond_without_ref_latents = !img_uncond.empty() &&
!ref_latents.empty() && !ref_latents.empty() &&
sd_version_supports_ref_latent_img_cfg(version); sd_version_supports_ref_latent_img_cfg(version);
@ -2176,17 +2176,18 @@ public:
uncond_skip_layers = &skip_layer_guidance.layers(); uncond_skip_layers = &skip_layer_guidance.layers();
} }
uncond_out = run_condition(uncond, uncond_out = run_condition(uncond,
nullptr, uncond.c_concat.empty() ? nullptr : &uncond.c_concat,
uncond_skip_layers, uncond_skip_layers);
uncond_without_ref_latents ? &empty_ref_latents : nullptr);
if (uncond_out.empty()) { if (uncond_out.empty()) {
return {}; return {};
} }
} }
if (!img_cond.empty()) { if (!img_uncond.empty()) {
img_cond_out = run_condition(img_cond, img_uncond_out = run_condition(img_uncond,
cond.c_concat.empty() ? nullptr : &cond.c_concat); img_uncond.c_concat.empty() ? nullptr : &img_uncond.c_concat,
if (img_cond_out.empty()) { nullptr,
uncond_without_ref_latents ? &empty_ref_latents : nullptr);
if (img_uncond_out.empty()) {
return {}; return {};
} }
} }
@ -2195,7 +2196,7 @@ public:
guidance_input.schedule_size = sigmas.size(); guidance_input.schedule_size = sigmas.size();
guidance_input.pred_cond = &cond_out; guidance_input.pred_cond = &cond_out;
guidance_input.pred_uncond = uncond_out.empty() ? nullptr : &uncond_out; guidance_input.pred_uncond = uncond_out.empty() ? nullptr : &uncond_out;
guidance_input.pred_img_cond = img_cond_out.empty() ? nullptr : &img_cond_out; guidance_input.pred_img_uncond = img_uncond_out.empty() ? nullptr : &img_uncond_out;
sd::guidance::GuiderOutput guided = primary_guidance.forward(guidance_input, {}); sd::guidance::GuiderOutput guided = primary_guidance.forward(guidance_input, {});
if (guided.pred.empty()) { if (guided.pred.empty()) {
@ -2222,7 +2223,9 @@ public:
sd::guidance::GuiderOutput output; sd::guidance::GuiderOutput output;
output.pred = denoised; output.pred = denoised;
if (needs_uncond_denoised) { if (needs_uncond_denoised) {
const sd::Tensor<float>& base_uncond = !uncond_out.empty() ? uncond_out : cond_out; const sd::Tensor<float>& base_uncond = !img_uncond_out.empty()
? img_uncond_out
: (!uncond_out.empty() ? uncond_out : cond_out);
output.pred_uncond = base_uncond * c_out + x * c_skip; output.pred_uncond = base_uncond * c_out + x * c_skip;
} }
if (cache_runtime.spectrum_enabled) { if (cache_runtime.spectrum_enabled) {
@ -3196,9 +3199,9 @@ struct GenerationRequest {
int diffusion_model_down_factor = -1; int diffusion_model_down_factor = -1;
int64_t seed = -1; int64_t seed = -1;
bool use_uncond = false; bool use_uncond = false;
bool use_img_cond = false; bool use_img_uncond = false;
bool use_high_noise_uncond = false; bool use_high_noise_uncond = false;
bool use_high_noise_img_cond = false; bool use_high_noise_img_uncond = false;
bool has_ref_images = false; bool has_ref_images = false;
const sd_cache_params_t* cache_params = nullptr; const sd_cache_params_t* cache_params = nullptr;
int batch_count = 1; int batch_count = 1;
@ -3356,35 +3359,38 @@ struct GenerationRequest {
static void resolve_guidance(sd_ctx_t* sd_ctx, static void resolve_guidance(sd_ctx_t* sd_ctx,
sd_guidance_params_t* guidance, sd_guidance_params_t* guidance,
bool* use_uncond, bool* use_uncond,
bool* use_img_cond, bool* use_img_uncond,
bool has_ref_images, bool has_ref_images,
const char* stage_name = nullptr) { const char* stage_name = nullptr) {
GGML_ASSERT(guidance != nullptr); GGML_ASSERT(guidance != nullptr);
GGML_ASSERT(use_uncond != nullptr); GGML_ASSERT(use_uncond != nullptr);
GGML_ASSERT(use_img_cond != nullptr); GGML_ASSERT(use_img_uncond != nullptr);
// out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond) // out_img_uncond + text_cfg_scale * (out_cond - out_uncond) + image_cfg_scale * (out_uncond - out_img_uncond)
// img_cfg == txt_cfg means that img_cfg is not used // -> text_cfg_scale * out_cond + (image_cfg_scale - text_cfg_scale) * out_uncond + (1 - image_cfg_scale) * out_img_uncond
bool img_cfg_was_unset = !std::isfinite(guidance->img_cfg); // out_cond : prompt, image latent
if (!std::isfinite(guidance->img_cfg)) { // out_uncond : negative prompt, image latent
guidance->img_cfg = guidance->txt_cfg; // out_img_uncond : negative prompt, zero image latent
// image_cfg_scale == 1 reduces 3-cond CFG to 2-cond CFG.
bool img_cfg_was_set = std::isfinite(guidance->img_cfg);
if (!img_cfg_was_set) {
guidance->img_cfg = 1.f;
} }
if (!sd_version_supports_img_cfg(sd_ctx->sd->version, has_ref_images)) { if (!sd_version_supports_img_cfg(sd_ctx->sd->version, has_ref_images)) {
if (!img_cfg_was_unset && guidance->img_cfg != guidance->txt_cfg) { if (img_cfg_was_set && guidance->img_cfg != 1.f) {
LOG_WARN("2-conditioning CFG is not supported with this model, disabling it for better performance"); LOG_WARN("3-conditioning CFG is not supported with this model, disabling it for better performance");
} }
guidance->img_cfg = guidance->txt_cfg; guidance->img_cfg = 1.f;
}
if (guidance->txt_cfg != 1.f) {
*use_uncond = true;
} }
if (guidance->img_cfg != guidance->txt_cfg) { if (guidance->img_cfg != guidance->txt_cfg) {
*use_img_cond = true;
*use_uncond = true; *use_uncond = true;
} }
if (guidance->img_cfg != 1.f) {
*use_img_uncond = true;
}
if (guidance->txt_cfg < 1.f) { if (guidance->txt_cfg < 1.f) {
const char* prefix = stage_name == nullptr ? "" : stage_name; const char* prefix = stage_name == nullptr ? "" : stage_name;
if (guidance->txt_cfg == 0.f) { if (guidance->txt_cfg == 0.f) {
@ -3401,12 +3407,12 @@ struct GenerationRequest {
resolve_hires(); resolve_hires();
seed = resolve_seed(seed); seed = resolve_seed(seed);
resolve_guidance(sd_ctx, &guidance, &use_uncond, &use_img_cond, has_ref_images); resolve_guidance(sd_ctx, &guidance, &use_uncond, &use_img_uncond, has_ref_images);
if (sd_ctx->sd->high_noise_diffusion_model) { if (sd_ctx->sd->high_noise_diffusion_model) {
resolve_guidance(sd_ctx, resolve_guidance(sd_ctx,
&high_noise_guidance, &high_noise_guidance,
&use_high_noise_uncond, &use_high_noise_uncond,
&use_high_noise_img_cond, &use_high_noise_img_uncond,
has_ref_images, has_ref_images,
"high noise: "); "high noise: ");
} }
@ -3525,7 +3531,7 @@ struct SamplePlan {
struct ImageGenerationLatents { struct ImageGenerationLatents {
sd::Tensor<float> init_latent; sd::Tensor<float> init_latent;
sd::Tensor<float> concat_latent; sd::Tensor<float> concat_latent;
sd::Tensor<float> uncond_concat_latent; sd::Tensor<float> img_uncond_concat_latent;
sd::Tensor<float> audio_latent; sd::Tensor<float> audio_latent;
sd::Tensor<float> video_positions; sd::Tensor<float> video_positions;
sd::Tensor<float> control_image; sd::Tensor<float> control_image;
@ -3848,7 +3854,7 @@ static int get_ltxav_num_audio_latents(int frames, int fps) {
struct ImageGenerationEmbeds { struct ImageGenerationEmbeds {
SDCondition cond; SDCondition cond;
SDCondition uncond; SDCondition uncond;
SDCondition img_cond; SDCondition img_uncond;
SDCondition id_cond; SDCondition id_cond;
}; };
@ -4007,7 +4013,7 @@ static std::optional<ImageGenerationLatents> prepare_image_generation_latents(sd
LOG_WARN("This model needs at least one reference image; using an empty reference"); LOG_WARN("This model needs at least one reference image; using an empty reference");
ref_images.push_back(sd::zeros<float>({request->width, request->height, 3, 1})); ref_images.push_back(sd::zeros<float>({request->width, request->height, 3, 1}));
request->guidance.img_cfg = request->guidance.txt_cfg; request->guidance.img_cfg = request->guidance.txt_cfg;
request->use_img_cond = false; request->use_img_uncond = false;
} }
if (!ref_images.empty()) { if (!ref_images.empty()) {
@ -4060,7 +4066,7 @@ static std::optional<ImageGenerationLatents> prepare_image_generation_latents(sd
} }
sd::Tensor<float> concat_latent; sd::Tensor<float> concat_latent;
sd::Tensor<float> uncond_concat_latent; sd::Tensor<float> img_uncond_concat_latent;
if (sd_version_is_inpaint(sd_ctx->sd->version)) { if (sd_version_is_inpaint(sd_ctx->sd->version)) {
sd::Tensor<float> masked_init_latent; sd::Tensor<float> masked_init_latent;
@ -4089,7 +4095,7 @@ static std::optional<ImageGenerationLatents> prepare_image_generation_latents(sd
mask = mask.permute({1, 3, 0, 2}).reshape({request->width / request->vae_scale_factor, request->height / request->vae_scale_factor, request->vae_scale_factor * request->vae_scale_factor, 1}); mask = mask.permute({1, 3, 0, 2}).reshape({request->width / request->vae_scale_factor, request->height / request->vae_scale_factor, request->vae_scale_factor * request->vae_scale_factor, 1});
concat_latent = sd::ops::concat(masked_init_latent, mask, 2); concat_latent = sd::ops::concat(masked_init_latent, mask, 2);
uncond_concat_latent = sd::ops::concat(uncond_masked_init_latent, mask, 2); img_uncond_concat_latent = sd::ops::concat(uncond_masked_init_latent, mask, 2);
} else if (sd_ctx->sd->version == VERSION_FLEX_2) { } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
concat_latent = sd::ops::concat(masked_init_latent, latent_mask, 2); concat_latent = sd::ops::concat(masked_init_latent, latent_mask, 2);
if (!control_latent.empty()) { if (!control_latent.empty()) {
@ -4098,16 +4104,16 @@ static std::optional<ImageGenerationLatents> prepare_image_generation_latents(sd
concat_latent = sd::ops::concat(concat_latent, sd::Tensor<float>::zeros_like(masked_init_latent), 2); concat_latent = sd::ops::concat(concat_latent, sd::Tensor<float>::zeros_like(masked_init_latent), 2);
} }
uncond_concat_latent = sd::ops::concat(uncond_masked_init_latent, latent_mask, 2); img_uncond_concat_latent = sd::ops::concat(uncond_masked_init_latent, latent_mask, 2);
uncond_concat_latent = sd::ops::concat(uncond_concat_latent, sd::Tensor<float>::zeros_like(masked_init_latent), 2); img_uncond_concat_latent = sd::ops::concat(img_uncond_concat_latent, sd::Tensor<float>::zeros_like(masked_init_latent), 2);
} else { // SD1.x SD2.x SDXL inpaint } else { // SD1.x SD2.x SDXL inpaint
concat_latent = sd::ops::concat(latent_mask, masked_init_latent, 2); concat_latent = sd::ops::concat(latent_mask, masked_init_latent, 2);
uncond_concat_latent = sd::ops::concat(latent_mask, uncond_masked_init_latent, 2); img_uncond_concat_latent = sd::ops::concat(latent_mask, uncond_masked_init_latent, 2);
} }
} }
if (sd_version_is_unet_edit(sd_ctx->sd->version)) { if (sd_version_is_unet_edit(sd_ctx->sd->version)) {
concat_latent = sd::ops::interpolate<float>(ref_latents[0], init_latent.shape()); concat_latent = sd::ops::interpolate<float>(ref_latents[0], init_latent.shape());
uncond_concat_latent = sd::Tensor<float>::zeros_like(concat_latent); img_uncond_concat_latent = sd::Tensor<float>::zeros_like(concat_latent);
} }
if (sd_ctx->sd->version == VERSION_FLUX_CONTROLS) { if (sd_ctx->sd->version == VERSION_FLUX_CONTROLS) {
if (!control_latent.empty()) { if (!control_latent.empty()) {
@ -4115,7 +4121,7 @@ static std::optional<ImageGenerationLatents> prepare_image_generation_latents(sd
} else { } else {
concat_latent = sd::Tensor<float>::zeros_like(init_latent); concat_latent = sd::Tensor<float>::zeros_like(init_latent);
} }
uncond_concat_latent = sd::Tensor<float>::zeros_like(concat_latent); img_uncond_concat_latent = sd::Tensor<float>::zeros_like(concat_latent);
} }
if (sd_img_gen_params->init_image.data != nullptr || sd_img_gen_params->ref_images_count > 0) { if (sd_img_gen_params->init_image.data != nullptr || sd_img_gen_params->ref_images_count > 0) {
@ -4126,7 +4132,7 @@ static std::optional<ImageGenerationLatents> prepare_image_generation_latents(sd
ImageGenerationLatents latents; ImageGenerationLatents latents;
latents.init_latent = std::move(init_latent); latents.init_latent = std::move(init_latent);
latents.concat_latent = std::move(concat_latent); latents.concat_latent = std::move(concat_latent);
latents.uncond_concat_latent = std::move(uncond_concat_latent); latents.img_uncond_concat_latent = std::move(img_uncond_concat_latent);
latents.control_image = std::move(control_image_tensor); latents.control_image = std::move(control_image_tensor);
latents.ref_images = std::move(ref_images); latents.ref_images = std::move(ref_images);
latents.ref_latents = std::move(ref_latents); latents.ref_latents = std::move(ref_latents);
@ -4163,7 +4169,7 @@ static std::optional<ImageGenerationEmbeds> prepare_image_generation_embeds(sd_c
cond.c_concat = latents->concat_latent; // TODO: optimize cond.c_concat = latents->concat_latent; // TODO: optimize
} }
bool use_ref_latent_img_cfg = request->use_img_cond && bool use_ref_latent_img_cfg = request->use_img_uncond &&
!latents->ref_images.empty() && !latents->ref_images.empty() &&
sd_version_supports_ref_latent_img_cfg(sd_ctx->sd->version); sd_version_supports_ref_latent_img_cfg(sd_ctx->sd->version);
@ -4180,24 +4186,32 @@ static std::optional<ImageGenerationEmbeds> prepare_image_generation_embeds(sd_c
uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(sd_ctx->sd->n_threads, uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(sd_ctx->sd->n_threads,
condition_params); condition_params);
if (uncond.c_concat.empty()) { if (uncond.c_concat.empty()) {
uncond.c_concat = latents->uncond_concat_latent; // TODO: optimize uncond.c_concat = latents->concat_latent; // TODO: optimize
} }
} }
SDCondition img_cond; SDCondition img_uncond;
if (request->use_img_cond) { if (request->use_img_uncond) {
if ((request->use_uncond || request->use_high_noise_uncond) && (latents->ref_images.empty() || !use_ref_latent_img_cfg)) {
img_uncond = SDCondition(uncond.c_crossattn, uncond.c_vector, latents->img_uncond_concat_latent);
} else {
bool zero_out_masked = false;
if (sd_version_is_sdxl(sd_ctx->sd->version) &&
request->negative_prompt.empty() &&
!sd_ctx->sd->is_using_edm_v_parameterization) {
zero_out_masked = true;
}
condition_params.text = request->negative_prompt;
condition_params.zero_out_masked = zero_out_masked;
if (use_ref_latent_img_cfg) { if (use_ref_latent_img_cfg) {
img_cond = uncond;
std::vector<sd::Tensor<float>> empty_ref_images; std::vector<sd::Tensor<float>> empty_ref_images;
condition_params.ref_images = &empty_ref_images; condition_params.ref_images = &empty_ref_images;
uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(sd_ctx->sd->n_threads,
condition_params);
if (uncond.c_concat.empty()) {
uncond.c_concat = latents->uncond_concat_latent; // TODO: optimize
} }
} else { img_uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(sd_ctx->sd->n_threads,
img_cond = SDCondition(uncond.c_crossattn, uncond.c_vector, cond.c_concat); condition_params);
if (img_uncond.c_concat.empty()) {
img_uncond.c_concat = latents->img_uncond_concat_latent; // TODO: optimize
}
} }
} }
@ -4209,7 +4223,7 @@ static std::optional<ImageGenerationEmbeds> prepare_image_generation_embeds(sd_c
} }
ImageGenerationEmbeds embeds; ImageGenerationEmbeds embeds;
embeds.img_cond = std::move(img_cond); embeds.img_uncond = std::move(img_uncond);
embeds.cond = std::move(cond); embeds.cond = std::move(cond);
embeds.uncond = std::move(uncond); embeds.uncond = std::move(uncond);
embeds.id_cond = std::move(id_cond); embeds.id_cond = std::move(id_cond);
@ -4492,7 +4506,7 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s
std::move(noise), std::move(noise),
embeds.cond, embeds.cond,
embeds.uncond, embeds.uncond,
embeds.img_cond, embeds.img_uncond,
embeds.id_cond, embeds.id_cond,
latents.control_image, latents.control_image,
request.control_strength, request.control_strength,
@ -4612,7 +4626,7 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s
std::move(noise), std::move(noise),
embeds.cond, embeds.cond,
embeds.uncond, embeds.uncond,
embeds.img_cond, embeds.img_uncond,
embeds.id_cond, embeds.id_cond,
latents.control_image, latents.control_image,
request.control_strength, request.control_strength,
@ -5327,7 +5341,7 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
std::move(noise), std::move(noise),
embeds.cond, embeds.cond,
request.use_high_noise_uncond ? embeds.uncond : SDCondition(), request.use_high_noise_uncond ? embeds.uncond : SDCondition(),
embeds.img_cond, embeds.img_uncond,
embeds.id_cond, embeds.id_cond,
sd::Tensor<float>(), sd::Tensor<float>(),
0.f, 0.f,
@ -5373,7 +5387,7 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
std::move(noise), std::move(noise),
embeds.cond, embeds.cond,
request.use_uncond ? embeds.uncond : SDCondition(), request.use_uncond ? embeds.uncond : SDCondition(),
embeds.img_cond, embeds.img_uncond,
embeds.id_cond, embeds.id_cond,
sd::Tensor<float>(), sd::Tensor<float>(),
0.f, 0.f,
@ -5517,7 +5531,7 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
std::move(noise), std::move(noise),
embeds.cond, embeds.cond,
hires_request.use_uncond ? embeds.uncond : SDCondition(), hires_request.use_uncond ? embeds.uncond : SDCondition(),
embeds.img_cond, embeds.img_uncond,
embeds.id_cond, embeds.id_cond,
sd::Tensor<float>(), sd::Tensor<float>(),
0.f, 0.f,