feat: ancestral sampler implementations for flow models (#1374)

* feat: add support for the eta parameter to ancestral samplers

* feat: Euler Ancestral sampler implementation for flow models

* refine flow ancestral sampling and normalize eta defaults

---------

Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
Wagner Bruna 2026-04-01 14:35:29 -03:00 committed by GitHub
parent 09b12d5f6d
commit 99c1de379b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 117 additions and 17 deletions

View File

@ -114,7 +114,7 @@ Generation Options:
medium medium
--skip-layer-start <float> SLG enabling point (default: 0.01) --skip-layer-start <float> SLG enabling point (default: 0.01)
--skip-layer-end <float> SLG disabling point (default: 0.2) --skip-layer-end <float> SLG disabling point (default: 0.2)
--eta <float> eta in DDIM, only for DDIM and TCD (default: 0) --eta <float> noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a and dpm++2s_a)
--flow-shift <float> shift value for Flow models like SD3.x or WAN (default: auto) --flow-shift <float> shift value for Flow models like SD3.x or WAN (default: auto)
--high-noise-cfg-scale <float> (high noise) unconditional guidance scale: (default: 7.0) --high-noise-cfg-scale <float> (high noise) unconditional guidance scale: (default: 7.0)
--high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale) --high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale)
@ -122,7 +122,7 @@ Generation Options:
--high-noise-slg-scale <float> (high noise) skip layer guidance (SLG) scale, only for DiT models: (default: 0) --high-noise-slg-scale <float> (high noise) skip layer guidance (SLG) scale, only for DiT models: (default: 0)
--high-noise-skip-layer-start <float> (high noise) SLG enabling point (default: 0.01) --high-noise-skip-layer-start <float> (high noise) SLG enabling point (default: 0.01)
--high-noise-skip-layer-end <float> (high noise) SLG disabling point (default: 0.2) --high-noise-skip-layer-end <float> (high noise) SLG disabling point (default: 0.2)
--high-noise-eta <float> (high noise) eta in DDIM, only for DDIM and TCD (default: 0) --high-noise-eta <float> (high noise) noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a and dpm++2s_a)
--strength <float> strength for noising/unnoising (default: 0.75) --strength <float> strength for noising/unnoising (default: 0.75)
--pm-style-strength <float> --pm-style-strength <float>
--control-strength <float> strength to apply Control Net (default: 0.9). 1.0 corresponds to full destruction of information in init image --control-strength <float> strength to apply Control Net (default: 0.9). 1.0 corresponds to full destruction of information in init image

View File

@ -1131,7 +1131,7 @@ struct SDGenerationParams {
&sample_params.guidance.slg.layer_end}, &sample_params.guidance.slg.layer_end},
{"", {"",
"--eta", "--eta",
"eta in DDIM, only for DDIM and TCD (default: 0)", "noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a and dpm++2s_a)",
&sample_params.eta}, &sample_params.eta},
{"", {"",
"--flow-shift", "--flow-shift",
@ -1163,7 +1163,7 @@ struct SDGenerationParams {
&high_noise_sample_params.guidance.slg.layer_end}, &high_noise_sample_params.guidance.slg.layer_end},
{"", {"",
"--high-noise-eta", "--high-noise-eta",
"(high noise) eta in DDIM, only for DDIM and TCD (default: 0)", "(high noise) noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a and dpm++2s_a)",
&high_noise_sample_params.eta}, &high_noise_sample_params.eta},
{"", {"",
"--strength", "--strength",

View File

@ -189,7 +189,7 @@ Default Generation Options:
medium medium
--skip-layer-start <float> SLG enabling point (default: 0.01) --skip-layer-start <float> SLG enabling point (default: 0.01)
--skip-layer-end <float> SLG disabling point (default: 0.2) --skip-layer-end <float> SLG disabling point (default: 0.2)
--eta <float> eta in DDIM, only for DDIM and TCD (default: 0) --eta <float> noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a and dpm++2s_a)
--flow-shift <float> shift value for Flow models like SD3.x or WAN (default: auto) --flow-shift <float> shift value for Flow models like SD3.x or WAN (default: auto)
--high-noise-cfg-scale <float> (high noise) unconditional guidance scale: (default: 7.0) --high-noise-cfg-scale <float> (high noise) unconditional guidance scale: (default: 7.0)
--high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale) --high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale)
@ -197,7 +197,7 @@ Default Generation Options:
--high-noise-slg-scale <float> (high noise) skip layer guidance (SLG) scale, only for DiT models: (default: 0) --high-noise-slg-scale <float> (high noise) skip layer guidance (SLG) scale, only for DiT models: (default: 0)
--high-noise-skip-layer-start <float> (high noise) SLG enabling point (default: 0.01) --high-noise-skip-layer-start <float> (high noise) SLG enabling point (default: 0.01)
--high-noise-skip-layer-end <float> (high noise) SLG disabling point (default: 0.2) --high-noise-skip-layer-end <float> (high noise) SLG disabling point (default: 0.2)
--high-noise-eta <float> (high noise) eta in DDIM, only for DDIM and TCD (default: 0) --high-noise-eta <float> (high noise) noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a and dpm++2s_a)
--strength <float> strength for noising/unnoising (default: 0.75) --strength <float> strength for noising/unnoising (default: 0.75)
--pm-style-strength <float> --pm-style-strength <float>
--control-strength <float> strength to apply Control Net (default: 0.9). 1.0 corresponds to full destruction of information in init image --control-strength <float> strength to apply Control Net (default: 0.9). 1.0 corresponds to full destruction of information in init image

View File

@ -786,10 +786,43 @@ static std::pair<float, float> get_ancestral_step(float sigma_from,
return {sigma_down, sigma_up}; return {sigma_down, sigma_up};
} }
// Computes the parameters of one ancestral step for a flow model going from
// noise level `sigma_from` to `sigma_to`.
//
// Returns {sigma_down, sigma_up, alpha_scale}:
//   - sigma_down:  noise level to step down to before noise is re-injected
//   - sigma_up:    multiplier for the fresh noise added after the step
//   - alpha_scale: factor applied to the sample before the noise is added
//
// `eta` controls how much noise is re-injected; values above 1 are clamped to 1.
// When eta or either sigma is non-positive, or when the rescale denominator is
// non-positive, the noise-free result {sigma_to, 0, 1} is returned.
static std::tuple<float, float, float> get_ancestral_step_flow(float sigma_from,
                                                               float sigma_to,
                                                               float eta = 1.0f) {
    // Nothing to re-inject for non-positive eta or sigmas: plain step to sigma_to.
    const bool no_noise = eta <= 0.0f || sigma_from <= 0.0f || sigma_to <= 0.0f;
    if (no_noise) {
        return {sigma_to, 0.0f, 1.0f};
    }

    // Flow Euler ancestral sampling becomes numerically unstable for eta > 1, so
    // clamp to the valid maximum-noise regime instead of letting NaNs propagate.
    const float clamped_eta = std::min(eta, 1.0f);

    // Blend between sigma_to (eta == 0) and sigma_to * sigma_to / sigma_from (eta == 1),
    // then keep the result inside [0, sigma_to].
    const float step_ratio = sigma_to / sigma_from;
    float down             = sigma_to * (1.0f + (step_ratio - 1.0f) * clamped_eta);
    down                   = std::max(0.0f, std::min(sigma_to, down));

    // (1 - sigma_down) is the divisor of the rescale factor; bail out to the
    // noise-free step when it is not strictly positive.
    const float one_minus_down = 1.0f - down;
    if (one_minus_down <= 0.0f) {
        return {sigma_to, 0.0f, 1.0f};
    }

    const float rescale = (1.0f - sigma_to) / one_minus_down;

    // sigma_up = sigma_to * sqrt(1 - term^2), with term clamped to [-1, 1] and the
    // radicand floored at 0 so rounding can never feed sqrt a negative value.
    float term           = (down / sigma_to) * rescale;
    term                 = std::max(-1.0f, std::min(1.0f, term));
    const float radicand = std::max(1.0f - term * term, 0.0f);
    const float up       = sigma_to * std::sqrt(radicand);

    return {down, up, rescale};
}
static sd::Tensor<float> sample_euler_ancestral(denoise_cb_t model, static sd::Tensor<float> sample_euler_ancestral(denoise_cb_t model,
sd::Tensor<float> x, sd::Tensor<float> x,
const std::vector<float>& sigmas, const std::vector<float>& sigmas,
std::shared_ptr<RNG> rng) { std::shared_ptr<RNG> rng,
float eta) {
int steps = static_cast<int>(sigmas.size()) - 1; int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) { for (int i = 0; i < steps; i++) {
float sigma = sigmas[i]; float sigma = sigmas[i];
@ -799,7 +832,7 @@ static sd::Tensor<float> sample_euler_ancestral(denoise_cb_t model,
} }
sd::Tensor<float> denoised = std::move(denoised_opt); sd::Tensor<float> denoised = std::move(denoised_opt);
sd::Tensor<float> d = (x - denoised) / sigma; sd::Tensor<float> d = (x - denoised) / sigma;
auto [sigma_down, sigma_up] = get_ancestral_step(sigmas[i], sigmas[i + 1]); auto [sigma_down, sigma_up] = get_ancestral_step(sigmas[i], sigmas[i + 1], eta);
x += d * (sigma_down - sigmas[i]); x += d * (sigma_down - sigmas[i]);
if (sigmas[i + 1] > 0) { if (sigmas[i + 1] > 0) {
x += sd::Tensor<float>::randn_like(x, rng) * sigma_up; x += sd::Tensor<float>::randn_like(x, rng) * sigma_up;
@ -808,6 +841,30 @@ static sd::Tensor<float> sample_euler_ancestral(denoise_cb_t model,
return x; return x;
} }
// Euler ancestral sampling loop for flow denoisers.
//
// For each consecutive pair of sigmas the model is asked for a denoised estimate,
// the sample is interpolated toward that estimate down to `sigma_down`, and then,
// when `sigma_up` is positive, rescaled by `alpha_scale` and perturbed with fresh
// noise drawn from `rng`. `eta` is forwarded to get_ancestral_step_flow.
//
// Returns the final sample, or an empty tensor as soon as the model callback
// returns an empty result.
static sd::Tensor<float> sample_euler_flow(denoise_cb_t model,
                                           sd::Tensor<float> x,
                                           const std::vector<float>& sigmas,
                                           std::shared_ptr<RNG> rng,
                                           float eta) {
    const int steps = static_cast<int>(sigmas.size()) - 1;
    // `step` is 1-based because that is the step index handed to the model callback.
    for (int step = 1; step <= steps; step++) {
        const float sigma_cur  = sigmas[step - 1];
        const float sigma_next = sigmas[step];

        auto denoised_opt = model(x, sigma_cur, step);
        if (denoised_opt.empty()) {
            return {};
        }
        sd::Tensor<float> denoised = std::move(denoised_opt);

        auto [sigma_down, sigma_up, alpha_scale] = get_ancestral_step_flow(sigma_cur, sigma_next, eta);

        // Linear interpolation between the current sample and the denoised estimate.
        const float keep = sigma_down / sigma_cur;
        x                = keep * x + (1.0f - keep) * denoised;

        // Re-inject noise only when the step asks for it.
        if (sigma_up > 0.0f) {
            sd::Tensor<float> noise = sd::Tensor<float>::randn_like(x, rng);
            x                       = alpha_scale * x + noise * sigma_up;
        }
    }
    return x;
}
static sd::Tensor<float> sample_euler(denoise_cb_t model, static sd::Tensor<float> sample_euler(denoise_cb_t model,
sd::Tensor<float> x, sd::Tensor<float> x,
const std::vector<float>& sigmas) { const std::vector<float>& sigmas) {
@ -885,7 +942,8 @@ static sd::Tensor<float> sample_dpm2(denoise_cb_t model,
static sd::Tensor<float> sample_dpmpp_2s_ancestral(denoise_cb_t model, static sd::Tensor<float> sample_dpmpp_2s_ancestral(denoise_cb_t model,
sd::Tensor<float> x, sd::Tensor<float> x,
const std::vector<float>& sigmas, const std::vector<float>& sigmas,
std::shared_ptr<RNG> rng) { std::shared_ptr<RNG> rng,
float eta) {
auto t_fn = [](float sigma) -> float { return -log(sigma); }; auto t_fn = [](float sigma) -> float { return -log(sigma); };
auto sigma_fn = [](float t) -> float { return exp(-t); }; auto sigma_fn = [](float t) -> float { return exp(-t); };
@ -896,7 +954,7 @@ static sd::Tensor<float> sample_dpmpp_2s_ancestral(denoise_cb_t model,
return {}; return {};
} }
sd::Tensor<float> denoised = std::move(denoised_opt); sd::Tensor<float> denoised = std::move(denoised_opt);
auto [sigma_down, sigma_up] = get_ancestral_step(sigmas[i], sigmas[i + 1]); auto [sigma_down, sigma_up] = get_ancestral_step(sigmas[i], sigmas[i + 1], eta);
if (sigma_down == 0) { if (sigma_down == 0) {
x = denoised; x = denoised;
@ -1368,10 +1426,14 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
sd::Tensor<float> x, sd::Tensor<float> x,
std::vector<float> sigmas, std::vector<float> sigmas,
std::shared_ptr<RNG> rng, std::shared_ptr<RNG> rng,
float eta) { float eta,
bool is_flow_denoiser) {
switch (method) { switch (method) {
case EULER_A_SAMPLE_METHOD: case EULER_A_SAMPLE_METHOD:
return sample_euler_ancestral(model, std::move(x), sigmas, rng); if (is_flow_denoiser)
return sample_euler_flow(model, std::move(x), sigmas, rng, eta);
else
return sample_euler_ancestral(model, std::move(x), sigmas, rng, eta);
case EULER_SAMPLE_METHOD: case EULER_SAMPLE_METHOD:
return sample_euler(model, std::move(x), sigmas); return sample_euler(model, std::move(x), sigmas);
case HEUN_SAMPLE_METHOD: case HEUN_SAMPLE_METHOD:
@ -1379,7 +1441,7 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
case DPM2_SAMPLE_METHOD: case DPM2_SAMPLE_METHOD:
return sample_dpm2(model, std::move(x), sigmas); return sample_dpm2(model, std::move(x), sigmas);
case DPMPP2S_A_SAMPLE_METHOD: case DPMPP2S_A_SAMPLE_METHOD:
return sample_dpmpp_2s_ancestral(model, std::move(x), sigmas, rng); return sample_dpmpp_2s_ancestral(model, std::move(x), sigmas, rng, eta);
case DPMPP2M_SAMPLE_METHOD: case DPMPP2M_SAMPLE_METHOD:
return sample_dpmpp_2m(model, std::move(x), sigmas); return sample_dpmpp_2m(model, std::move(x), sigmas);
case DPMPP2Mv2_SAMPLE_METHOD: case DPMPP2Mv2_SAMPLE_METHOD:

View File

@ -1593,6 +1593,7 @@ public:
float eta, float eta,
int shifted_timestep, int shifted_timestep,
sample_method_t method, sample_method_t method,
bool is_flow_denoiser,
const std::vector<float>& sigmas, const std::vector<float>& sigmas,
int start_merge_step, int start_merge_step,
const std::vector<sd::Tensor<float>>& ref_latents, const std::vector<sd::Tensor<float>>& ref_latents,
@ -1791,7 +1792,7 @@ public:
return denoised; return denoised;
}; };
auto x0_opt = sample_k_diffusion(method, denoise, x_t, sigmas, sampler_rng, eta); auto x0_opt = sample_k_diffusion(method, denoise, x_t, sigmas, sampler_rng, eta, is_flow_denoiser);
if (x0_opt.empty()) { if (x0_opt.empty()) {
LOG_ERROR("Diffusion model sampling failed"); LOG_ERROR("Diffusion model sampling failed");
if (control_net) { if (control_net) {
@ -1909,6 +1910,11 @@ public:
flow_denoiser->set_shift(flow_shift); flow_denoiser->set_shift(flow_shift);
} }
} }
// Returns true when `denoiser` can be downcast to DiscreteFlowDenoiser,
// i.e. the dynamic cast yields a non-null pointer.
bool is_flow_denoiser() {
    return std::dynamic_pointer_cast<DiscreteFlowDenoiser>(denoiser) != nullptr;
}
}; };
/*================================================= SD API ==================================================*/ /*================================================= SD API ==================================================*/
@ -2225,6 +2231,7 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) {
sample_params->scheduler = SCHEDULER_COUNT; sample_params->scheduler = SCHEDULER_COUNT;
sample_params->sample_method = SAMPLE_METHOD_COUNT; sample_params->sample_method = SAMPLE_METHOD_COUNT;
sample_params->sample_steps = 20; sample_params->sample_steps = 20;
sample_params->eta = INFINITY;
sample_params->custom_sigmas = nullptr; sample_params->custom_sigmas = nullptr;
sample_params->custom_sigmas_count = 0; sample_params->custom_sigmas_count = 0;
sample_params->flow_shift = INFINITY; sample_params->flow_shift = INFINITY;
@ -2438,6 +2445,26 @@ static scheduler_t resolve_scheduler(sd_ctx_t* sd_ctx,
return scheduler; return scheduler;
} }
// Resolves the effective eta for a sampler.
//
// Any value other than INFINITY is returned unchanged. INFINITY selects the
// per-method default:
//   - 1.0 for euler_a and dpm++2s_a
//   - 0.0 for every other method (including ddim_trailing, tcd, res_multistep
//     and res_2s)
//
// `sd_ctx` is not read by this function.
static float resolve_eta(sd_ctx_t* sd_ctx,
                         float eta,
                         enum sample_method_t sample_method) {
    // Caller supplied a value: use it verbatim.
    if (eta != INFINITY) {
        return eta;
    }

    const bool defaults_to_full_noise = sample_method == EULER_A_SAMPLE_METHOD ||
                                        sample_method == DPMPP2S_A_SAMPLE_METHOD;
    return defaults_to_full_noise ? 1.0f : 0.0f;
}
struct GenerationRequest { struct GenerationRequest {
std::string prompt; std::string prompt;
std::string negative_prompt; std::string negative_prompt;
@ -2586,6 +2613,8 @@ struct GenerationRequest {
struct SamplePlan { struct SamplePlan {
enum sample_method_t sample_method = SAMPLE_METHOD_COUNT; enum sample_method_t sample_method = SAMPLE_METHOD_COUNT;
enum sample_method_t high_noise_sample_method = SAMPLE_METHOD_COUNT; enum sample_method_t high_noise_sample_method = SAMPLE_METHOD_COUNT;
float eta = 0.f;
float high_noise_eta = 0.f;
int sample_steps = 0; int sample_steps = 0;
int high_noise_sample_steps = 0; int high_noise_sample_steps = 0;
int total_steps = 0; int total_steps = 0;
@ -2597,6 +2626,7 @@ struct SamplePlan {
const sd_img_gen_params_t* sd_img_gen_params, const sd_img_gen_params_t* sd_img_gen_params,
const GenerationRequest& request) { const GenerationRequest& request) {
sample_method = sd_img_gen_params->sample_params.sample_method; sample_method = sd_img_gen_params->sample_params.sample_method;
eta = sd_img_gen_params->sample_params.eta;
sample_steps = sd_img_gen_params->sample_params.sample_steps; sample_steps = sd_img_gen_params->sample_params.sample_steps;
resolve(sd_ctx, &request, &sd_img_gen_params->sample_params); resolve(sd_ctx, &request, &sd_img_gen_params->sample_params);
} }
@ -2605,10 +2635,12 @@ struct SamplePlan {
const sd_vid_gen_params_t* sd_vid_gen_params, const sd_vid_gen_params_t* sd_vid_gen_params,
const GenerationRequest& request) { const GenerationRequest& request) {
sample_method = sd_vid_gen_params->sample_params.sample_method; sample_method = sd_vid_gen_params->sample_params.sample_method;
eta = sd_vid_gen_params->sample_params.eta;
sample_steps = sd_vid_gen_params->sample_params.sample_steps; sample_steps = sd_vid_gen_params->sample_params.sample_steps;
if (sd_ctx->sd->high_noise_diffusion_model) { if (sd_ctx->sd->high_noise_diffusion_model) {
high_noise_sample_steps = sd_vid_gen_params->high_noise_sample_params.sample_steps; high_noise_sample_steps = sd_vid_gen_params->high_noise_sample_params.sample_steps;
high_noise_sample_method = sd_vid_gen_params->high_noise_sample_params.sample_method; high_noise_sample_method = sd_vid_gen_params->high_noise_sample_params.sample_method;
high_noise_eta = sd_vid_gen_params->high_noise_sample_params.eta;
} }
moe_boundary = sd_vid_gen_params->moe_boundary; moe_boundary = sd_vid_gen_params->moe_boundary;
resolve(sd_ctx, &request, &sd_vid_gen_params->sample_params); resolve(sd_ctx, &request, &sd_vid_gen_params->sample_params);
@ -2644,6 +2676,8 @@ struct SamplePlan {
sd_ctx->sd->version); sd_ctx->sd->version);
} }
eta = resolve_eta(sd_ctx, eta, sample_method);
if (high_noise_sample_steps < 0) { if (high_noise_sample_steps < 0) {
for (size_t i = 0; i < sigmas.size(); ++i) { for (size_t i = 0; i < sigmas.size(); ++i) {
if (sigmas[i] < moe_boundary) { if (sigmas[i] < moe_boundary) {
@ -2658,6 +2692,7 @@ struct SamplePlan {
if (high_noise_sample_steps > 0) { if (high_noise_sample_steps > 0) {
high_noise_sample_method = resolve_sample_method(sd_ctx, high_noise_sample_method = resolve_sample_method(sd_ctx,
high_noise_sample_method); high_noise_sample_method);
high_noise_eta = resolve_eta(sd_ctx, high_noise_eta, high_noise_sample_method);
LOG_INFO("sampling(high noise) using %s method", sampling_methods_str[high_noise_sample_method]); LOG_INFO("sampling(high noise) using %s method", sampling_methods_str[high_noise_sample_method]);
} }
@ -3123,9 +3158,10 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s
latents.control_image, latents.control_image,
request.control_strength, request.control_strength,
request.guidance, request.guidance,
request.eta, plan.eta,
request.shifted_timestep, request.shifted_timestep,
plan.sample_method, plan.sample_method,
sd_ctx->sd->is_flow_denoiser(),
plan.sigmas, plan.sigmas,
plan.start_merge_step, plan.start_merge_step,
latents.ref_latents, latents.ref_latents,
@ -3482,9 +3518,10 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
sd::Tensor<float>(), sd::Tensor<float>(),
0.f, 0.f,
request.high_noise_guidance, request.high_noise_guidance,
sd_vid_gen_params->high_noise_sample_params.eta, plan.high_noise_eta,
request.shifted_timestep, request.shifted_timestep,
plan.high_noise_sample_method, plan.high_noise_sample_method,
sd_ctx->sd->is_flow_denoiser(),
high_noise_sigmas, high_noise_sigmas,
-1, -1,
std::vector<sd::Tensor<float>>{}, std::vector<sd::Tensor<float>>{},
@ -3523,9 +3560,10 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
sd::Tensor<float>(), sd::Tensor<float>(),
0.f, 0.f,
sd_vid_gen_params->sample_params.guidance, sd_vid_gen_params->sample_params.guidance,
sd_vid_gen_params->sample_params.eta, plan.eta,
sd_vid_gen_params->sample_params.shifted_timestep, sd_vid_gen_params->sample_params.shifted_timestep,
plan.sample_method, plan.sample_method,
sd_ctx->sd->is_flow_denoiser(),
plan.sigmas, plan.sigmas,
-1, -1,
std::vector<sd::Tensor<float>>{}, std::vector<sd::Tensor<float>>{},