refactor: unify Euler, Euler Ancestral and DDIM implementations (#1474)

This commit is contained in:
Wagner Bruna 2026-05-16 05:13:28 -03:00 committed by GitHub
parent db08b84607
commit fd1a2794f3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -824,45 +824,33 @@ static std::tuple<float, float, float> get_ancestral_step(float sigma_from,
static sd::Tensor<float> sample_euler_ancestral(denoise_cb_t model,
sd::Tensor<float> x,
const std::vector<float>& sigmas,
std::shared_ptr<RNG> rng,
float eta) {
std::shared_ptr<RNG> rng = nullptr,
bool is_flow_denoiser = false,
float eta = 0.f) {
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) {
float sigma = sigmas[i];
float sigma_to = sigmas[i + 1];
auto denoised_opt = model(x, sigma, i + 1, nullptr);
if (denoised_opt.empty()) {
return {};
}
sd::Tensor<float> denoised = std::move(denoised_opt);
sd::Tensor<float> d = (x - denoised) / sigma;
auto [sigma_down, sigma_up] = get_ancestral_step(sigmas[i], sigmas[i + 1], eta);
x += d * (sigma_down - sigmas[i]);
if (sigmas[i + 1] > 0) {
x += sd::Tensor<float>::randn_like(x, rng) * sigma_up;
}
}
return x;
}
static sd::Tensor<float> sample_euler_flow(denoise_cb_t model,
sd::Tensor<float> x,
const std::vector<float>& sigmas,
std::shared_ptr<RNG> rng,
float eta) {
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) {
float sigma = sigmas[i];
auto denoised_opt = model(x, sigma, i + 1, nullptr);
if (denoised_opt.empty()) {
return {};
}
sd::Tensor<float> denoised = std::move(denoised_opt);
auto [sigma_down, sigma_up, alpha_scale] = get_ancestral_step_flow(sigma, sigmas[i + 1], eta);
float sigma_ratio = sigma_down / sigma;
x = sigma_ratio * x + (1.0f - sigma_ratio) * denoised;
if (sigma_up > 0.0f) {
x = alpha_scale * x + sd::Tensor<float>::randn_like(x, rng) * sigma_up;
sd::Tensor<float> denoised = std::move(denoised_opt);
if (sigma_to == 0.f) {
x = denoised;
} else if (eta == 0.f) {
float sigma_ratio = sigma_to / sigma;
x = sigma_ratio * x + (1.0 - sigma_ratio) * denoised;
} else {
auto [sigma_down, sigma_up, alpha_scale] = get_ancestral_step(sigma, sigma_to, eta, is_flow_denoiser);
float sigma_ratio = sigma_down / sigma;
x = sigma_ratio * x + (1.0f - sigma_ratio) * denoised;
if (sigma_up > 0.f) {
if (is_flow_denoiser) {
x *= alpha_scale;
}
x += sd::Tensor<float>::randn_like(x, rng) * sigma_up;
}
}
}
return x;
@ -1633,46 +1621,6 @@ static sd::Tensor<float> sample_er_sde(denoise_cb_t model,
return x;
}
static sd::Tensor<float> sample_ddim_trailing(denoise_cb_t model,
sd::Tensor<float> x,
const std::vector<float>& sigmas,
std::shared_ptr<RNG> rng,
float eta) {
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) {
float sigma = sigmas[i];
float sigma_to = sigmas[i + 1];
auto model_output_opt = model(x, sigma, i + 1, nullptr);
if (model_output_opt.empty()) {
return {};
}
sd::Tensor<float> model_output = std::move(model_output_opt);
model_output = (x - model_output) * (1.0f / sigma);
float alpha_prod_t = 1.0f / (sigma * sigma + 1.0f);
float alpha_prod_t_prev = 1.0f / (sigma_to * sigma_to + 1.0f);
float beta_prod_t = 1.0f - alpha_prod_t;
sd::Tensor<float> pred_original_sample = ((x / std::sqrt(sigma * sigma + 1)) -
std::sqrt(beta_prod_t) * model_output) *
(1.0f / std::sqrt(alpha_prod_t));
float beta_prod_t_prev = 1.0f - alpha_prod_t_prev;
float variance = (beta_prod_t_prev / beta_prod_t) *
(1.0f - alpha_prod_t / alpha_prod_t_prev);
float std_dev_t = eta * std::sqrt(variance);
x = pred_original_sample +
std::sqrt((1.0f - alpha_prod_t_prev - std::pow(std_dev_t, 2)) / alpha_prod_t_prev) * model_output;
if (eta > 0) {
x += std_dev_t / std::sqrt(alpha_prod_t_prev) * sd::Tensor<float>::randn_like(x, rng);
}
}
return x;
}
static sd::Tensor<float> sample_tcd(denoise_cb_t model,
sd::Tensor<float> x,
const std::vector<float>& sigmas,
@ -1715,12 +1663,12 @@ static sd::Tensor<float> sample_tcd(denoise_cb_t model,
int timestep_s = (int)floor((1 - eta) * prev_timestep);
float sigma = sigmas[i];
auto model_output_opt = model(x, sigma, i + 1, nullptr);
if (model_output_opt.empty()) {
auto denoised_opt = model(x, sigma, i + 1, nullptr);
if (denoised_opt.empty()) {
return {};
}
sd::Tensor<float> model_output = std::move(model_output_opt);
model_output = (x - model_output) * (1.0f / sigma);
sd::Tensor<float> denoised = std::move(denoised_opt);
sd::Tensor<float> d = (x - denoised) / sigma;
float alpha_prod_t = 1.0f / (sigma * sigma + 1.0f);
float beta_prod_t = 1.0f - alpha_prod_t;
@ -1728,12 +1676,8 @@ static sd::Tensor<float> sample_tcd(denoise_cb_t model,
float alpha_prod_s = static_cast<float>(alphas_cumprod[timestep_s]);
float beta_prod_s = 1.0f - alpha_prod_s;
sd::Tensor<float> pred_original_sample = ((x / std::sqrt(sigma * sigma + 1)) -
std::sqrt(beta_prod_t) * model_output) *
(1.0f / std::sqrt(alpha_prod_t));
x = std::sqrt(alpha_prod_s / alpha_prod_t_prev) * pred_original_sample +
std::sqrt(beta_prod_s / alpha_prod_t_prev) * model_output;
x = std::sqrt(alpha_prod_s / alpha_prod_t_prev) * denoised +
std::sqrt(beta_prod_s / alpha_prod_t_prev) * d;
if (eta > 0 && sigma_to > 0.0f) {
x = std::sqrt(alpha_prod_t_prev / alpha_prod_s) * x +
@ -1804,10 +1748,7 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
const char* extra_sample_args) {
switch (method) {
case EULER_A_SAMPLE_METHOD:
if (is_flow_denoiser)
return sample_euler_flow(model, std::move(x), sigmas, rng, eta);
else
return sample_euler_ancestral(model, std::move(x), sigmas, rng, eta);
return sample_euler_ancestral(model, std::move(x), sigmas, rng, is_flow_denoiser, eta);
case EULER_SAMPLE_METHOD:
return sample_euler(model, std::move(x), sigmas);
case HEUN_SAMPLE_METHOD:
@ -1836,7 +1777,8 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
case ER_SDE_SAMPLE_METHOD:
return sample_er_sde(model, std::move(x), sigmas, rng, is_flow_denoiser, eta);
case DDIM_TRAILING_SAMPLE_METHOD:
return sample_ddim_trailing(model, std::move(x), sigmas, rng, eta);
// DDIM is equivalent to Euler Ancestral with the Simple scheduler
return sample_euler_ancestral(model, std::move(x), sigmas, rng, is_flow_denoiser, eta);
case TCD_SAMPLE_METHOD:
return sample_tcd(model, std::move(x), sigmas, rng, eta);
case EULER_CFG_PP_SAMPLE_METHOD: