fix: correct image to image DDIM and TCD (#1410)

This commit is contained in:
Wagner Bruna 2026-04-19 06:51:28 -03:00 committed by GitHub
parent e77e4c46bf
commit 7023fc4cfb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 37 additions and 47 deletions

View File

@ -1525,32 +1525,12 @@ static sd::Tensor<float> sample_ddim_trailing(denoise_cb_t model,
const std::vector<float>& sigmas,
std::shared_ptr<RNG> rng,
float eta) {
float beta_start = 0.00085f;
float beta_end = 0.0120f;
std::vector<double> alphas_cumprod(TIMESTEPS);
std::vector<double> compvis_sigmas(TIMESTEPS);
for (int i = 0; i < TIMESTEPS; i++) {
alphas_cumprod[i] =
(i == 0 ? 1.0f : alphas_cumprod[i - 1]) *
(1.0f -
std::pow(sqrtf(beta_start) +
(sqrtf(beta_end) - sqrtf(beta_start)) *
((float)i / (TIMESTEPS - 1)),
2));
compvis_sigmas[i] =
std::sqrt((1 - alphas_cumprod[i]) / alphas_cumprod[i]);
}
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) {
int timestep = static_cast<int>(roundf(TIMESTEPS - i * ((float)TIMESTEPS / steps))) - 1;
int prev_timestep = timestep - TIMESTEPS / steps;
float sigma = static_cast<float>(compvis_sigmas[timestep]);
if (i == 0) {
x *= std::sqrt(sigma * sigma + 1) / sigma;
} else {
x *= std::sqrt(sigma * sigma + 1);
}
float sigma = sigmas[i];
float sigma_to = sigmas[i + 1];
auto model_output_opt = model(x, sigma, i + 1);
if (model_output_opt.empty()) {
@ -1559,8 +1539,8 @@ static sd::Tensor<float> sample_ddim_trailing(denoise_cb_t model,
sd::Tensor<float> model_output = std::move(model_output_opt);
model_output = (x - model_output) * (1.0f / sigma);
float alpha_prod_t = static_cast<float>(alphas_cumprod[timestep]);
float alpha_prod_t_prev = static_cast<float>(prev_timestep >= 0 ? alphas_cumprod[prev_timestep] : alphas_cumprod[0]);
float alpha_prod_t = 1.0f / (sigma * sigma + 1.0f);
float alpha_prod_t_prev = 1.0f / (sigma_to * sigma_to + 1.0f);
float beta_prod_t = 1.0f - alpha_prod_t;
sd::Tensor<float> pred_original_sample = ((x / std::sqrt(sigma * sigma + 1)) -
@ -1572,12 +1552,13 @@ static sd::Tensor<float> sample_ddim_trailing(denoise_cb_t model,
(1.0f - alpha_prod_t / alpha_prod_t_prev);
float std_dev_t = eta * std::sqrt(variance);
x = std::sqrt(alpha_prod_t_prev) * pred_original_sample +
std::sqrt(1.0f - alpha_prod_t_prev - std::pow(std_dev_t, 2)) * model_output;
x = pred_original_sample +
std::sqrt((1.0f - alpha_prod_t_prev - std::pow(std_dev_t, 2))/ alpha_prod_t_prev) * model_output;
if (eta > 0) {
x += std_dev_t * sd::Tensor<float>::randn_like(x, rng);
x+= std_dev_t / std::sqrt(alpha_prod_t_prev) * sd::Tensor<float>::randn_like(x, rng);
}
}
return x;
}
@ -1603,19 +1584,25 @@ static sd::Tensor<float> sample_tcd(denoise_cb_t model,
std::sqrt((1 - alphas_cumprod[i]) / alphas_cumprod[i]);
}
int original_steps = 50;
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) {
int timestep = TIMESTEPS - 1 - (TIMESTEPS / original_steps) * (int)floor(i * ((float)original_steps / steps));
int prev_timestep = i >= steps - 1 ? 0 : TIMESTEPS - 1 - (TIMESTEPS / original_steps) * (int)floor((i + 1) * ((float)original_steps / steps));
int timestep_s = (int)floor((1 - eta) * prev_timestep);
float sigma = static_cast<float>(compvis_sigmas[timestep]);
if (i == 0) {
x *= std::sqrt(sigma * sigma + 1) / sigma;
} else {
x *= std::sqrt(sigma * sigma + 1);
auto get_timestep_from_sigma = [&](float s) -> int {
auto it = std::lower_bound(compvis_sigmas.begin(), compvis_sigmas.end(), s);
if (it == compvis_sigmas.begin()) return 0;
if (it == compvis_sigmas.end()) return TIMESTEPS - 1;
int idx_high = static_cast<int>(std::distance(compvis_sigmas.begin(), it));
int idx_low = idx_high - 1;
if (std::abs(compvis_sigmas[idx_high] - s) < std::abs(compvis_sigmas[idx_low] - s)) {
return idx_high;
}
return idx_low;
};
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) {
float sigma_to = sigmas[i + 1];
int prev_timestep = get_timestep_from_sigma(sigma_to);
int timestep_s = (int)floor((1 - eta) * prev_timestep);
float sigma = sigmas[i];
auto model_output_opt = model(x, sigma, i + 1);
if (model_output_opt.empty()) {
@ -1624,9 +1611,9 @@ static sd::Tensor<float> sample_tcd(denoise_cb_t model,
sd::Tensor<float> model_output = std::move(model_output_opt);
model_output = (x - model_output) * (1.0f / sigma);
float alpha_prod_t = static_cast<float>(alphas_cumprod[timestep]);
float alpha_prod_t = 1.0f / (sigma * sigma + 1.0f);
float beta_prod_t = 1.0f - alpha_prod_t;
float alpha_prod_t_prev = static_cast<float>(prev_timestep >= 0 ? alphas_cumprod[prev_timestep] : alphas_cumprod[0]);
float alpha_prod_t_prev = 1.0f / (sigma_to * sigma_to + 1.0f);
float alpha_prod_s = static_cast<float>(alphas_cumprod[timestep_s]);
float beta_prod_s = 1.0f - alpha_prod_s;
@ -1634,13 +1621,14 @@ static sd::Tensor<float> sample_tcd(denoise_cb_t model,
std::sqrt(beta_prod_t) * model_output) *
(1.0f / std::sqrt(alpha_prod_t));
x = std::sqrt(alpha_prod_s) * pred_original_sample +
std::sqrt(beta_prod_s) * model_output;
x = std::sqrt(alpha_prod_s / alpha_prod_t_prev) * pred_original_sample +
std::sqrt(beta_prod_s / alpha_prod_t_prev) * model_output;
if (eta > 0 && i != steps - 1) {
if (eta > 0 && sigma_to > 0.0f) {
x = std::sqrt(alpha_prod_t_prev / alpha_prod_s) * x +
std::sqrt(1.0f - alpha_prod_t_prev / alpha_prod_s) * sd::Tensor<float>::randn_like(x, rng);
std::sqrt(1.0f / alpha_prod_t_prev - 1.0f / alpha_prod_s) * sd::Tensor<float>::randn_like(x, rng);
}
}
return x;
}

View File

@ -2457,8 +2457,10 @@ enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_me
return EXPONENTIAL_SCHEDULER;
}
}
if (sample_method == LCM_SAMPLE_METHOD) {
if (sample_method == LCM_SAMPLE_METHOD || sample_method == TCD_SAMPLE_METHOD) {
return LCM_SCHEDULER;
} else if (sample_method == DDIM_TRAILING_SAMPLE_METHOD) {
return SIMPLE_SCHEDULER;
}
return DISCRETE_SCHEDULER;
}