feat: implement DDIM with the "trailing" timestep spacing and TCD (#568)

This commit is contained in:
yslai 2025-02-22 05:34:22 -08:00 committed by GitHub
parent f27f2b2aa2
commit 19d876ee30
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 400 additions and 3 deletions

View File

@ -474,7 +474,8 @@ static void sample_k_diffusion(sample_method_t method,
ggml_context* work_ctx,
ggml_tensor* x,
std::vector<float> sigmas,
std::shared_ptr<RNG> rng) {
std::shared_ptr<RNG> rng,
float eta) {
size_t steps = sigmas.size() - 1;
// sample_euler_ancestral
switch (method) {
@ -1005,6 +1006,374 @@ static void sample_k_diffusion(sample_method_t method,
}
}
} break;
case DDIM_TRAILING: // Denoising Diffusion Implicit Models
// with the "trailing" timestep spacing
{
// See J. Song et al., "Denoising Diffusion Implicit
// Models", arXiv:2010.02502 [cs.LG]
//
// DDIM itself needs alphas_cumprod (DDPM, J. Ho et al.,
// arXiv:2006.11239 [cs.LG] with k-diffusion's start and
// end beta) (which unfortunately k-diffusion's data
// structure hides from the denoiser), and the sigmas are
// also needed to invert the behavior of CompVisDenoiser
// (k-diffusion's LMSDiscreteScheduler)
float beta_start = 0.00085f;
float beta_end = 0.0120f;
std::vector<double> alphas_cumprod;
std::vector<double> compvis_sigmas;
alphas_cumprod.reserve(TIMESTEPS);
compvis_sigmas.reserve(TIMESTEPS);
for (int i = 0; i < TIMESTEPS; i++) {
alphas_cumprod[i] =
(i == 0 ? 1.0f : alphas_cumprod[i - 1]) *
(1.0f -
std::pow(sqrtf(beta_start) +
(sqrtf(beta_end) - sqrtf(beta_start)) *
((float)i / (TIMESTEPS - 1)), 2));
compvis_sigmas[i] =
std::sqrt((1 - alphas_cumprod[i]) /
alphas_cumprod[i]);
}
struct ggml_tensor* pred_original_sample =
ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* variance_noise =
ggml_dup_tensor(work_ctx, x);
for (int i = 0; i < steps; i++) {
// The "trailing" DDIM timestep, see S. Lin et al.,
// "Common Diffusion Noise Schedules and Sample Steps
// are Flawed", arXiv:2305.08891 [cs], p. 4, Table
// 2. Most variables below follow Diffusers naming
//
// Diffuser naming vs. Song et al. (2010), p. 5, (12)
// and p. 16, (16) (<variable name> -> <name in
// paper>):
//
// - pred_noise_t -> epsilon_theta^(t)(x_t)
// - pred_original_sample -> f_theta^(t)(x_t) or x_0
// - std_dev_t -> sigma_t (not the LMS sigma)
// - eta -> eta (set to 0 at the moment)
// - pred_sample_direction -> "direction pointing to
// x_t"
// - pred_prev_sample -> "x_t-1"
int timestep =
roundf(TIMESTEPS -
i * ((float)TIMESTEPS / steps)) - 1;
// 1. get previous step value (=t-1)
int prev_timestep = timestep - TIMESTEPS / steps;
// The sigma here is chosen to cause the
// CompVisDenoiser to produce t = timestep
float sigma = compvis_sigmas[timestep];
if (i == 0) {
// The function add_noise intializes x to
// Diffusers' latents * sigma (as in Diffusers'
// pipeline) or sample * sigma (Diffusers'
// scheduler), where this sigma = init_noise_sigma
// in Diffusers. For DDPM and DDIM however,
// init_noise_sigma = 1. But the k-diffusion
// model() also evaluates F_theta(c_in(sigma) x;
// ...) instead of the bare U-net F_theta, with
// c_in = 1 / sqrt(sigma^2 + 1), as defined in
// T. Karras et al., "Elucidating the Design Space
// of Diffusion-Based Generative Models",
// arXiv:2206.00364 [cs.CV], p. 3, Table 1. Hence
// the first call has to be prescaled as x <- x /
// (c_in * sigma) with the k-diffusion pipeline
// and CompVisDenoiser.
float* vec_x = (float*)x->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] *= std::sqrt(sigma * sigma + 1) /
sigma;
}
}
else {
// For the subsequent steps after the first one,
// at this point x = latents or x = sample, and
// needs to be prescaled with x <- sample / c_in
// to compensate for model() applying the scale
// c_in before the U-net F_theta
float* vec_x = (float*)x->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] *= std::sqrt(sigma * sigma + 1);
}
}
// Note (also noise_pred in Diffuser's pipeline)
// model_output = model() is the D(x, sigma) as
// defined in Karras et al. (2022), p. 3, Table 1 and
// p. 8 (7), compare also p. 38 (226) therein.
struct ggml_tensor* model_output =
model(x, sigma, i + 1);
// Here model_output is still the k-diffusion denoiser
// output, not the U-net output F_theta(c_in(sigma) x;
// ...) in Karras et al. (2022), whereas Diffusers'
// model_output is F_theta(...). Recover the actual
// model_output, which is also referred to as the
// "Karras ODE derivative" d or d_cur in several
// samplers above.
{
float* vec_x = (float*)x->data;
float* vec_model_output =
(float*)model_output->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_model_output[j] =
(vec_x[j] - vec_model_output[j]) *
(1 / sigma);
}
}
// 2. compute alphas, betas
float alpha_prod_t = alphas_cumprod[timestep];
// Note final_alpha_cumprod = alphas_cumprod[0] due to
// trailing timestep spacing
float alpha_prod_t_prev = prev_timestep >= 0 ?
alphas_cumprod[prev_timestep] : alphas_cumprod[0];
float beta_prod_t = 1 - alpha_prod_t;
// 3. compute predicted original sample from predicted
// noise also called "predicted x_0" of formula (12)
// from https://arxiv.org/pdf/2010.02502.pdf
{
float* vec_x = (float*)x->data;
float* vec_model_output =
(float*)model_output->data;
float* vec_pred_original_sample =
(float*)pred_original_sample->data;
// Note the substitution of latents or sample = x
// * c_in = x / sqrt(sigma^2 + 1)
for (int j = 0; j < ggml_nelements(x); j++) {
vec_pred_original_sample[j] =
(vec_x[j] / std::sqrt(sigma * sigma + 1) -
std::sqrt(beta_prod_t) *
vec_model_output[j]) *
(1 / std::sqrt(alpha_prod_t));
}
}
// Assuming the "epsilon" prediction type, where below
// pred_epsilon = model_output is inserted, and is not
// defined/copied explicitly.
//
// 5. compute variance: "sigma_t(eta)" -> see formula
// (16)
//
// sigma_t = sqrt((1 - alpha_t-1)/(1 - alpha_t)) *
// sqrt(1 - alpha_t/alpha_t-1)
float beta_prod_t_prev = 1 - alpha_prod_t_prev;
float variance = (beta_prod_t_prev / beta_prod_t) *
(1 - alpha_prod_t / alpha_prod_t_prev);
float std_dev_t = eta * std::sqrt(variance);
// 6. compute "direction pointing to x_t" of formula
// (12) from https://arxiv.org/pdf/2010.02502.pdf
// 7. compute x_t without "random noise" of formula
// (12) from https://arxiv.org/pdf/2010.02502.pdf
{
float* vec_model_output = (float*)model_output->data;
float* vec_pred_original_sample =
(float*)pred_original_sample->data;
float* vec_x = (float*)x->data;
for (int j = 0; j < ggml_nelements(x); j++) {
// Two step inner loop without an explicit
// tensor
float pred_sample_direction =
std::sqrt(1 - alpha_prod_t_prev -
std::pow(std_dev_t, 2)) *
vec_model_output[j];
vec_x[j] = std::sqrt(alpha_prod_t_prev) *
vec_pred_original_sample[j] +
pred_sample_direction;
}
}
if (eta > 0) {
ggml_tensor_set_f32_randn(variance_noise, rng);
float* vec_variance_noise =
(float*)variance_noise->data;
float* vec_x = (float*)x->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] += std_dev_t * vec_variance_noise[j];
}
}
// See the note above: x = latents or sample here, and
// is not scaled by the c_in. For the final output
// this is correct, but for subsequent iterations, x
// needs to be prescaled again, since k-diffusion's
// model() differes from the bare U-net F_theta by the
// factor c_in.
}
} break;
case TCD: // Strategic Stochastic Sampling (Algorithm 4) in
// Trajectory Consistency Distillation
{
// See J. Zheng et al., "Trajectory Consistency
// Distillation: Improved Latent Consistency Distillation
// by Semi-Linear Consistency Function with Trajectory
// Mapping", arXiv:2402.19159 [cs.CV]
float beta_start = 0.00085f;
float beta_end = 0.0120f;
std::vector<double> alphas_cumprod;
std::vector<double> compvis_sigmas;
alphas_cumprod.reserve(TIMESTEPS);
compvis_sigmas.reserve(TIMESTEPS);
for (int i = 0; i < TIMESTEPS; i++) {
alphas_cumprod[i] =
(i == 0 ? 1.0f : alphas_cumprod[i - 1]) *
(1.0f -
std::pow(sqrtf(beta_start) +
(sqrtf(beta_end) - sqrtf(beta_start)) *
((float)i / (TIMESTEPS - 1)), 2));
compvis_sigmas[i] =
std::sqrt((1 - alphas_cumprod[i]) /
alphas_cumprod[i]);
}
int original_steps = 50;
struct ggml_tensor* pred_original_sample =
ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* noise =
ggml_dup_tensor(work_ctx, x);
for (int i = 0; i < steps; i++) {
// Analytic form for TCD timesteps
int timestep = TIMESTEPS - 1 -
(TIMESTEPS / original_steps) *
(int)floor(i * ((float)original_steps / steps));
// 1. get previous step value
int prev_timestep = i >= steps - 1 ? 0 :
TIMESTEPS - 1 - (TIMESTEPS / original_steps) *
(int)floor((i + 1) *
((float)original_steps / steps));
// Here timestep_s is tau_n' in Algorithm 4. The _s
// notation appears to be that from C. Lu,
// "DPM-Solver: A Fast ODE Solver for Diffusion
// Probabilistic Model Sampling in Around 10 Steps",
// arXiv:2206.00927 [cs.LG], but this notation is not
// continued in Algorithm 4, where _n' is used.
int timestep_s =
(int)floor((1 - eta) * prev_timestep);
// Begin k-diffusion specific workaround for
// evaluating F_theta(x; ...) from D(x, sigma), same
// as in DDIM (and see there for detailed comments)
float sigma = compvis_sigmas[timestep];
if (i == 0) {
float* vec_x = (float*)x->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] *= std::sqrt(sigma * sigma + 1) /
sigma;
}
}
else {
float* vec_x = (float*)x->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_x[j] *= std::sqrt(sigma * sigma + 1);
}
}
struct ggml_tensor* model_output =
model(x, sigma, i + 1);
{
float* vec_x = (float*)x->data;
float* vec_model_output =
(float*)model_output->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_model_output[j] =
(vec_x[j] - vec_model_output[j]) *
(1 / sigma);
}
}
// 2. compute alphas, betas
//
// When comparing TCD with DDPM/DDIM note that Zheng
// et al. (2024) follows the DPM-Solver notation for
// alpha. One can find the following comment in the
// original DPM-Solver code
// (https://github.com/LuChengTHU/dpm-solver/):
// "**Important**: Please pay special attention for
// the args for `alphas_cumprod`: The `alphas_cumprod`
// is the \hat{alpha_n} arrays in the notations of
// DDPM. [...] Therefore, the notation \hat{alpha_n}
// is different from the notation alpha_t in
// DPM-Solver. In fact, we have alpha_{t_n} =
// \sqrt{\hat{alpha_n}}, [...]"
float alpha_prod_t = alphas_cumprod[timestep];
float beta_prod_t = 1 - alpha_prod_t;
// Note final_alpha_cumprod = alphas_cumprod[0] since
// TCD is always "trailing"
float alpha_prod_t_prev = prev_timestep >= 0 ?
alphas_cumprod[prev_timestep] : alphas_cumprod[0];
// The subscript _s are the only portion in this
// section (2) unique to TCD
float alpha_prod_s = alphas_cumprod[timestep_s];
float beta_prod_s = 1 - alpha_prod_s;
// 3. Compute the predicted noised sample x_s based on
// the model parameterization
//
// This section is also exactly the same as DDIM
{
float* vec_x = (float*)x->data;
float* vec_model_output =
(float*)model_output->data;
float* vec_pred_original_sample =
(float*)pred_original_sample->data;
for (int j = 0; j < ggml_nelements(x); j++) {
vec_pred_original_sample[j] =
(vec_x[j] / std::sqrt(sigma * sigma + 1) -
std::sqrt(beta_prod_t) *
vec_model_output[j]) *
(1 / std::sqrt(alpha_prod_t));
}
}
// This consistency function step can be difficult to
// decipher from Algorithm 4, as it is simply stated
// using a consistency function. This step is the
// modified DDIM, i.e. p. 8 (32) in Zheng et
// al. (2024), with eta set to 0 (see the paragraph
// immediately thereafter that states this somewhat
// obliquely).
{
float* vec_pred_original_sample =
(float*)pred_original_sample->data;
float* vec_model_output =
(float*)model_output->data;
float* vec_x = (float*)x->data;
for (int j = 0; j < ggml_nelements(x); j++) {
// Substituting x = pred_noised_sample and
// pred_epsilon = model_output
vec_x[j] =
std::sqrt(alpha_prod_s) *
vec_pred_original_sample[j] +
std::sqrt(beta_prod_s) *
vec_model_output[j];
}
}
// 4. Sample and inject noise z ~ N(0, I) for
// MultiStep Inference Noise is not used on the final
// timestep of the timestep schedule. This also means
// that noise is not used for one-step sampling. Eta
// (referred to as "gamma" in the paper) was
// introduced to control the stochasticity in every
// step. When eta = 0, it represents deterministic
// sampling, whereas eta = 1 indicates full stochastic
// sampling.
if (eta > 0 && i != steps - 1) {
// In this case, x is still pred_noised_sample,
// continue in-place
ggml_tensor_set_f32_randn(noise, rng);
float* vec_x = (float*)x->data;
float* vec_noise = (float*)noise->data;
for (int j = 0; j < ggml_nelements(x); j++) {
// Corresponding to (35) in Zheng et
// al. (2024), substituting x =
// pred_noised_sample
vec_x[j] =
std::sqrt(alpha_prod_t_prev /
alpha_prod_s) *
vec_x[j] +
std::sqrt(1 - alpha_prod_t_prev /
alpha_prod_s) *
vec_noise[j];
}
}
}
} break;
default:
LOG_ERROR("Attempting to sample with nonexisting sample method %i", method);

View File

@ -39,6 +39,8 @@ const char* sample_method_str[] = {
"ipndm",
"ipndm_v",
"lcm",
"ddim_trailing",
"tcd",
};
// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h
@ -93,6 +95,7 @@ struct SDParams {
float min_cfg = 1.0f;
float cfg_scale = 7.0f;
float guidance = 3.5f;
float eta = 0.f;
float style_ratio = 20.f;
int clip_skip = -1; // <= 0 represents unspecified
int width = 512;
@ -162,6 +165,7 @@ void print_params(SDParams params) {
printf(" cfg_scale: %.2f\n", params.cfg_scale);
printf(" slg_scale: %.2f\n", params.slg_scale);
printf(" guidance: %.2f\n", params.guidance);
printf(" eta: %.2f\n", params.eta);
printf(" clip_skip: %d\n", params.clip_skip);
printf(" width: %d\n", params.width);
printf(" height: %d\n", params.height);
@ -211,6 +215,7 @@ void print_usage(int argc, const char* argv[]) {
printf(" --guidance SCALE guidance scale for img2img (default: 3.5)\n");
printf(" --slg-scale SCALE skip layer guidance (SLG) scale, only for DiT models: (default: 0)\n");
printf(" 0 means disabled, a value of 2.5 is nice for sd3.5 medium\n");
printf(" --eta SCALE eta in DDIM, only for DDIM and TCD: (default: 0)\n");
printf(" --skip-layers LAYERS Layers to skip for SLG steps: (default: [7,8,9])\n");
printf(" --skip-layer-start START SLG enabling point: (default: 0.01)\n");
printf(" --skip-layer-end END SLG disabling point: (default: 0.2)\n");
@ -221,7 +226,7 @@ void print_usage(int argc, const char* argv[]) {
printf(" 1.0 corresponds to full destruction of information in init image\n");
printf(" -H, --height H image height, in pixel space (default: 512)\n");
printf(" -W, --width W image width, in pixel space (default: 512)\n");
printf(" --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm}\n");
printf(" --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd}\n");
printf(" sampling method (default: \"euler_a\")\n");
printf(" --steps STEPS number of sample steps (default: 20)\n");
printf(" --rng {std_default, cuda} RNG (default: cuda)\n");
@ -440,6 +445,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
break;
}
params.guidance = std::stof(argv[i]);
} else if (arg == "--eta") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.eta = std::stof(argv[i]);
} else if (arg == "--strength") {
if (++i >= argc) {
invalid_arg = true;
@ -719,6 +730,7 @@ std::string get_image_params(SDParams params, int64_t seed) {
parameter_string += "Skip layer end: " + std::to_string(params.skip_layer_end) + ", ";
}
parameter_string += "Guidance: " + std::to_string(params.guidance) + ", ";
parameter_string += "Eta: " + std::to_string(params.eta) + ", ";
parameter_string += "Seed: " + std::to_string(seed) + ", ";
parameter_string += "Size: " + std::to_string(params.width) + "x" + std::to_string(params.height) + ", ";
parameter_string += "Model: " + sd_basename(params.model_path) + ", ";
@ -939,6 +951,7 @@ int main(int argc, const char* argv[]) {
params.clip_skip,
params.cfg_scale,
params.guidance,
params.eta,
params.width,
params.height,
params.sample_method,
@ -1006,6 +1019,7 @@ int main(int argc, const char* argv[]) {
params.clip_skip,
params.cfg_scale,
params.guidance,
params.eta,
params.width,
params.height,
params.sample_method,

View File

@ -47,6 +47,8 @@ const char* sampling_methods_str[] = {
"iPNDM",
"iPNDM_v",
"LCM",
"DDIM \"trailing\"",
"TCD"
};
/*================================================== Helper Functions ================================================*/
@ -793,6 +795,7 @@ public:
float min_cfg,
float cfg_scale,
float guidance,
float eta,
sample_method_t method,
const std::vector<float>& sigmas,
int start_merge_step,
@ -988,7 +991,7 @@ public:
return denoised;
};
sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng);
sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng, eta);
x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x);
@ -1194,6 +1197,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
int clip_skip,
float cfg_scale,
float guidance,
float eta,
int width,
int height,
enum sample_method_t sample_method,
@ -1457,6 +1461,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
cfg_scale,
cfg_scale,
guidance,
eta,
sample_method,
sigmas,
start_merge_step,
@ -1522,6 +1527,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
int clip_skip,
float cfg_scale,
float guidance,
float eta,
int width,
int height,
enum sample_method_t sample_method,
@ -1600,6 +1606,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
clip_skip,
cfg_scale,
guidance,
eta,
width,
height,
sample_method,
@ -1631,6 +1638,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
int clip_skip,
float cfg_scale,
float guidance,
float eta,
int width,
int height,
sample_method_t sample_method,
@ -1778,6 +1786,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
clip_skip,
cfg_scale,
guidance,
eta,
width,
height,
sample_method,
@ -1891,6 +1900,7 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
min_cfg,
cfg_scale,
0.f,
0.f,
sample_method,
sigmas,
-1,

View File

@ -44,6 +44,8 @@ enum sample_method_t {
IPNDM,
IPNDM_V,
LCM,
DDIM_TRAILING,
TCD,
N_SAMPLE_METHODS
};
@ -155,6 +157,7 @@ SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx,
int clip_skip,
float cfg_scale,
float guidance,
float eta,
int width,
int height,
enum sample_method_t sample_method,
@ -180,6 +183,7 @@ SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx,
int clip_skip,
float cfg_scale,
float guidance,
float eta,
int width,
int height,
enum sample_method_t sample_method,