mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-12 21:38:58 +00:00
feat: implement DDIM with the "trailing" timestep spacing and TCD (#568)
This commit is contained in:
parent
f27f2b2aa2
commit
19d876ee30
371
denoiser.hpp
371
denoiser.hpp
@ -474,7 +474,8 @@ static void sample_k_diffusion(sample_method_t method,
|
|||||||
ggml_context* work_ctx,
|
ggml_context* work_ctx,
|
||||||
ggml_tensor* x,
|
ggml_tensor* x,
|
||||||
std::vector<float> sigmas,
|
std::vector<float> sigmas,
|
||||||
std::shared_ptr<RNG> rng) {
|
std::shared_ptr<RNG> rng,
|
||||||
|
float eta) {
|
||||||
size_t steps = sigmas.size() - 1;
|
size_t steps = sigmas.size() - 1;
|
||||||
// sample_euler_ancestral
|
// sample_euler_ancestral
|
||||||
switch (method) {
|
switch (method) {
|
||||||
@ -1005,6 +1006,374 @@ static void sample_k_diffusion(sample_method_t method,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case DDIM_TRAILING: // Denoising Diffusion Implicit Models
|
||||||
|
// with the "trailing" timestep spacing
|
||||||
|
{
|
||||||
|
// See J. Song et al., "Denoising Diffusion Implicit
|
||||||
|
// Models", arXiv:2010.02502 [cs.LG]
|
||||||
|
//
|
||||||
|
// DDIM itself needs alphas_cumprod (DDPM, J. Ho et al.,
|
||||||
|
// arXiv:2006.11239 [cs.LG] with k-diffusion's start and
|
||||||
|
// end beta) (which unfortunately k-diffusion's data
|
||||||
|
// structure hides from the denoiser), and the sigmas are
|
||||||
|
// also needed to invert the behavior of CompVisDenoiser
|
||||||
|
// (k-diffusion's LMSDiscreteScheduler)
|
||||||
|
float beta_start = 0.00085f;
|
||||||
|
float beta_end = 0.0120f;
|
||||||
|
std::vector<double> alphas_cumprod;
|
||||||
|
std::vector<double> compvis_sigmas;
|
||||||
|
|
||||||
|
alphas_cumprod.reserve(TIMESTEPS);
|
||||||
|
compvis_sigmas.reserve(TIMESTEPS);
|
||||||
|
for (int i = 0; i < TIMESTEPS; i++) {
|
||||||
|
alphas_cumprod[i] =
|
||||||
|
(i == 0 ? 1.0f : alphas_cumprod[i - 1]) *
|
||||||
|
(1.0f -
|
||||||
|
std::pow(sqrtf(beta_start) +
|
||||||
|
(sqrtf(beta_end) - sqrtf(beta_start)) *
|
||||||
|
((float)i / (TIMESTEPS - 1)), 2));
|
||||||
|
compvis_sigmas[i] =
|
||||||
|
std::sqrt((1 - alphas_cumprod[i]) /
|
||||||
|
alphas_cumprod[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* pred_original_sample =
|
||||||
|
ggml_dup_tensor(work_ctx, x);
|
||||||
|
struct ggml_tensor* variance_noise =
|
||||||
|
ggml_dup_tensor(work_ctx, x);
|
||||||
|
|
||||||
|
for (int i = 0; i < steps; i++) {
|
||||||
|
// The "trailing" DDIM timestep, see S. Lin et al.,
|
||||||
|
// "Common Diffusion Noise Schedules and Sample Steps
|
||||||
|
// are Flawed", arXiv:2305.08891 [cs], p. 4, Table
|
||||||
|
// 2. Most variables below follow Diffusers naming
|
||||||
|
//
|
||||||
|
// Diffusers naming vs. Song et al. (2020), p. 5, (12)
|
||||||
|
// and p. 16, (16) (<variable name> -> <name in
|
||||||
|
// paper>):
|
||||||
|
//
|
||||||
|
// - pred_noise_t -> epsilon_theta^(t)(x_t)
|
||||||
|
// - pred_original_sample -> f_theta^(t)(x_t) or x_0
|
||||||
|
// - std_dev_t -> sigma_t (not the LMS sigma)
|
||||||
|
// - eta -> eta (caller-supplied; 0 gives deterministic DDIM)
|
||||||
|
// - pred_sample_direction -> "direction pointing to
|
||||||
|
// x_t"
|
||||||
|
// - pred_prev_sample -> "x_t-1"
|
||||||
|
int timestep =
|
||||||
|
roundf(TIMESTEPS -
|
||||||
|
i * ((float)TIMESTEPS / steps)) - 1;
|
||||||
|
// 1. get previous step value (=t-1)
|
||||||
|
int prev_timestep = timestep - TIMESTEPS / steps;
|
||||||
|
// The sigma here is chosen to cause the
|
||||||
|
// CompVisDenoiser to produce t = timestep
|
||||||
|
float sigma = compvis_sigmas[timestep];
|
||||||
|
if (i == 0) {
|
||||||
|
// The function add_noise initializes x to
|
||||||
|
// Diffusers' latents * sigma (as in Diffusers'
|
||||||
|
// pipeline) or sample * sigma (Diffusers'
|
||||||
|
// scheduler), where this sigma = init_noise_sigma
|
||||||
|
// in Diffusers. For DDPM and DDIM however,
|
||||||
|
// init_noise_sigma = 1. But the k-diffusion
|
||||||
|
// model() also evaluates F_theta(c_in(sigma) x;
|
||||||
|
// ...) instead of the bare U-net F_theta, with
|
||||||
|
// c_in = 1 / sqrt(sigma^2 + 1), as defined in
|
||||||
|
// T. Karras et al., "Elucidating the Design Space
|
||||||
|
// of Diffusion-Based Generative Models",
|
||||||
|
// arXiv:2206.00364 [cs.CV], p. 3, Table 1. Hence
|
||||||
|
// the first call has to be prescaled as x <- x /
|
||||||
|
// (c_in * sigma) with the k-diffusion pipeline
|
||||||
|
// and CompVisDenoiser.
|
||||||
|
float* vec_x = (float*)x->data;
|
||||||
|
for (int j = 0; j < ggml_nelements(x); j++) {
|
||||||
|
vec_x[j] *= std::sqrt(sigma * sigma + 1) /
|
||||||
|
sigma;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
// For the subsequent steps after the first one,
|
||||||
|
// at this point x = latents or x = sample, and
|
||||||
|
// needs to be prescaled with x <- sample / c_in
|
||||||
|
// to compensate for model() applying the scale
|
||||||
|
// c_in before the U-net F_theta
|
||||||
|
float* vec_x = (float*)x->data;
|
||||||
|
for (int j = 0; j < ggml_nelements(x); j++) {
|
||||||
|
vec_x[j] *= std::sqrt(sigma * sigma + 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Note (also noise_pred in Diffusers' pipeline)
|
||||||
|
// model_output = model() is the D(x, sigma) as
|
||||||
|
// defined in Karras et al. (2022), p. 3, Table 1 and
|
||||||
|
// p. 8 (7), compare also p. 38 (226) therein.
|
||||||
|
struct ggml_tensor* model_output =
|
||||||
|
model(x, sigma, i + 1);
|
||||||
|
// Here model_output is still the k-diffusion denoiser
|
||||||
|
// output, not the U-net output F_theta(c_in(sigma) x;
|
||||||
|
// ...) in Karras et al. (2022), whereas Diffusers'
|
||||||
|
// model_output is F_theta(...). Recover the actual
|
||||||
|
// model_output, which is also referred to as the
|
||||||
|
// "Karras ODE derivative" d or d_cur in several
|
||||||
|
// samplers above.
|
||||||
|
{
|
||||||
|
float* vec_x = (float*)x->data;
|
||||||
|
float* vec_model_output =
|
||||||
|
(float*)model_output->data;
|
||||||
|
for (int j = 0; j < ggml_nelements(x); j++) {
|
||||||
|
vec_model_output[j] =
|
||||||
|
(vec_x[j] - vec_model_output[j]) *
|
||||||
|
(1 / sigma);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 2. compute alphas, betas
|
||||||
|
float alpha_prod_t = alphas_cumprod[timestep];
|
||||||
|
// Note final_alpha_cumprod = alphas_cumprod[0] due to
|
||||||
|
// trailing timestep spacing
|
||||||
|
float alpha_prod_t_prev = prev_timestep >= 0 ?
|
||||||
|
alphas_cumprod[prev_timestep] : alphas_cumprod[0];
|
||||||
|
float beta_prod_t = 1 - alpha_prod_t;
|
||||||
|
// 3. compute predicted original sample from predicted
|
||||||
|
// noise also called "predicted x_0" of formula (12)
|
||||||
|
// from https://arxiv.org/pdf/2010.02502.pdf
|
||||||
|
{
|
||||||
|
float* vec_x = (float*)x->data;
|
||||||
|
float* vec_model_output =
|
||||||
|
(float*)model_output->data;
|
||||||
|
float* vec_pred_original_sample =
|
||||||
|
(float*)pred_original_sample->data;
|
||||||
|
// Note the substitution of latents or sample = x
|
||||||
|
// * c_in = x / sqrt(sigma^2 + 1)
|
||||||
|
for (int j = 0; j < ggml_nelements(x); j++) {
|
||||||
|
vec_pred_original_sample[j] =
|
||||||
|
(vec_x[j] / std::sqrt(sigma * sigma + 1) -
|
||||||
|
std::sqrt(beta_prod_t) *
|
||||||
|
vec_model_output[j]) *
|
||||||
|
(1 / std::sqrt(alpha_prod_t));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Assuming the "epsilon" prediction type, where below
|
||||||
|
// pred_epsilon = model_output is inserted, and is not
|
||||||
|
// defined/copied explicitly.
|
||||||
|
//
|
||||||
|
// 5. compute variance: "sigma_t(eta)" -> see formula
|
||||||
|
// (16)
|
||||||
|
//
|
||||||
|
// sigma_t = sqrt((1 - alpha_t-1)/(1 - alpha_t)) *
|
||||||
|
// sqrt(1 - alpha_t/alpha_t-1)
|
||||||
|
float beta_prod_t_prev = 1 - alpha_prod_t_prev;
|
||||||
|
float variance = (beta_prod_t_prev / beta_prod_t) *
|
||||||
|
(1 - alpha_prod_t / alpha_prod_t_prev);
|
||||||
|
float std_dev_t = eta * std::sqrt(variance);
|
||||||
|
// 6. compute "direction pointing to x_t" of formula
|
||||||
|
// (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||||
|
// 7. compute x_t without "random noise" of formula
|
||||||
|
// (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||||
|
{
|
||||||
|
float* vec_model_output = (float*)model_output->data;
|
||||||
|
float* vec_pred_original_sample =
|
||||||
|
(float*)pred_original_sample->data;
|
||||||
|
float* vec_x = (float*)x->data;
|
||||||
|
for (int j = 0; j < ggml_nelements(x); j++) {
|
||||||
|
// Two step inner loop without an explicit
|
||||||
|
// tensor
|
||||||
|
float pred_sample_direction =
|
||||||
|
std::sqrt(1 - alpha_prod_t_prev -
|
||||||
|
std::pow(std_dev_t, 2)) *
|
||||||
|
vec_model_output[j];
|
||||||
|
vec_x[j] = std::sqrt(alpha_prod_t_prev) *
|
||||||
|
vec_pred_original_sample[j] +
|
||||||
|
pred_sample_direction;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (eta > 0) {
|
||||||
|
ggml_tensor_set_f32_randn(variance_noise, rng);
|
||||||
|
float* vec_variance_noise =
|
||||||
|
(float*)variance_noise->data;
|
||||||
|
float* vec_x = (float*)x->data;
|
||||||
|
for (int j = 0; j < ggml_nelements(x); j++) {
|
||||||
|
vec_x[j] += std_dev_t * vec_variance_noise[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// See the note above: x = latents or sample here, and
|
||||||
|
// is not scaled by the c_in. For the final output
|
||||||
|
// this is correct, but for subsequent iterations, x
|
||||||
|
// needs to be prescaled again, since k-diffusion's
|
||||||
|
// model() differs from the bare U-net F_theta by the
|
||||||
|
// factor c_in.
|
||||||
|
}
|
||||||
|
} break;
|
||||||
|
case TCD: // Strategic Stochastic Sampling (Algorithm 4) in
|
||||||
|
// Trajectory Consistency Distillation
|
||||||
|
{
|
||||||
|
// See J. Zheng et al., "Trajectory Consistency
|
||||||
|
// Distillation: Improved Latent Consistency Distillation
|
||||||
|
// by Semi-Linear Consistency Function with Trajectory
|
||||||
|
// Mapping", arXiv:2402.19159 [cs.CV]
|
||||||
|
float beta_start = 0.00085f;
|
||||||
|
float beta_end = 0.0120f;
|
||||||
|
std::vector<double> alphas_cumprod;
|
||||||
|
std::vector<double> compvis_sigmas;
|
||||||
|
|
||||||
|
alphas_cumprod.reserve(TIMESTEPS);
|
||||||
|
compvis_sigmas.reserve(TIMESTEPS);
|
||||||
|
for (int i = 0; i < TIMESTEPS; i++) {
|
||||||
|
alphas_cumprod[i] =
|
||||||
|
(i == 0 ? 1.0f : alphas_cumprod[i - 1]) *
|
||||||
|
(1.0f -
|
||||||
|
std::pow(sqrtf(beta_start) +
|
||||||
|
(sqrtf(beta_end) - sqrtf(beta_start)) *
|
||||||
|
((float)i / (TIMESTEPS - 1)), 2));
|
||||||
|
compvis_sigmas[i] =
|
||||||
|
std::sqrt((1 - alphas_cumprod[i]) /
|
||||||
|
alphas_cumprod[i]);
|
||||||
|
}
|
||||||
|
int original_steps = 50;
|
||||||
|
|
||||||
|
struct ggml_tensor* pred_original_sample =
|
||||||
|
ggml_dup_tensor(work_ctx, x);
|
||||||
|
struct ggml_tensor* noise =
|
||||||
|
ggml_dup_tensor(work_ctx, x);
|
||||||
|
|
||||||
|
for (int i = 0; i < steps; i++) {
|
||||||
|
// Analytic form for TCD timesteps
|
||||||
|
int timestep = TIMESTEPS - 1 -
|
||||||
|
(TIMESTEPS / original_steps) *
|
||||||
|
(int)floor(i * ((float)original_steps / steps));
|
||||||
|
// 1. get previous step value
|
||||||
|
int prev_timestep = i >= steps - 1 ? 0 :
|
||||||
|
TIMESTEPS - 1 - (TIMESTEPS / original_steps) *
|
||||||
|
(int)floor((i + 1) *
|
||||||
|
((float)original_steps / steps));
|
||||||
|
// Here timestep_s is tau_n' in Algorithm 4. The _s
|
||||||
|
// notation appears to be that from C. Lu,
|
||||||
|
// "DPM-Solver: A Fast ODE Solver for Diffusion
|
||||||
|
// Probabilistic Model Sampling in Around 10 Steps",
|
||||||
|
// arXiv:2206.00927 [cs.LG], but this notation is not
|
||||||
|
// continued in Algorithm 4, where _n' is used.
|
||||||
|
int timestep_s =
|
||||||
|
(int)floor((1 - eta) * prev_timestep);
|
||||||
|
// Begin k-diffusion specific workaround for
|
||||||
|
// evaluating F_theta(x; ...) from D(x, sigma), same
|
||||||
|
// as in DDIM (and see there for detailed comments)
|
||||||
|
float sigma = compvis_sigmas[timestep];
|
||||||
|
if (i == 0) {
|
||||||
|
float* vec_x = (float*)x->data;
|
||||||
|
for (int j = 0; j < ggml_nelements(x); j++) {
|
||||||
|
vec_x[j] *= std::sqrt(sigma * sigma + 1) /
|
||||||
|
sigma;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
float* vec_x = (float*)x->data;
|
||||||
|
for (int j = 0; j < ggml_nelements(x); j++) {
|
||||||
|
vec_x[j] *= std::sqrt(sigma * sigma + 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
struct ggml_tensor* model_output =
|
||||||
|
model(x, sigma, i + 1);
|
||||||
|
{
|
||||||
|
float* vec_x = (float*)x->data;
|
||||||
|
float* vec_model_output =
|
||||||
|
(float*)model_output->data;
|
||||||
|
for (int j = 0; j < ggml_nelements(x); j++) {
|
||||||
|
vec_model_output[j] =
|
||||||
|
(vec_x[j] - vec_model_output[j]) *
|
||||||
|
(1 / sigma);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 2. compute alphas, betas
|
||||||
|
//
|
||||||
|
// When comparing TCD with DDPM/DDIM note that Zheng
|
||||||
|
// et al. (2024) follows the DPM-Solver notation for
|
||||||
|
// alpha. One can find the following comment in the
|
||||||
|
// original DPM-Solver code
|
||||||
|
// (https://github.com/LuChengTHU/dpm-solver/):
|
||||||
|
// "**Important**: Please pay special attention for
|
||||||
|
// the args for `alphas_cumprod`: The `alphas_cumprod`
|
||||||
|
// is the \hat{alpha_n} arrays in the notations of
|
||||||
|
// DDPM. [...] Therefore, the notation \hat{alpha_n}
|
||||||
|
// is different from the notation alpha_t in
|
||||||
|
// DPM-Solver. In fact, we have alpha_{t_n} =
|
||||||
|
// \sqrt{\hat{alpha_n}}, [...]"
|
||||||
|
float alpha_prod_t = alphas_cumprod[timestep];
|
||||||
|
float beta_prod_t = 1 - alpha_prod_t;
|
||||||
|
// Note final_alpha_cumprod = alphas_cumprod[0] since
|
||||||
|
// TCD is always "trailing"
|
||||||
|
float alpha_prod_t_prev = prev_timestep >= 0 ?
|
||||||
|
alphas_cumprod[prev_timestep] : alphas_cumprod[0];
|
||||||
|
// The _s-subscripted quantities are the only portion in this
|
||||||
|
// section (2) unique to TCD
|
||||||
|
float alpha_prod_s = alphas_cumprod[timestep_s];
|
||||||
|
float beta_prod_s = 1 - alpha_prod_s;
|
||||||
|
// 3. Compute the predicted noised sample x_s based on
|
||||||
|
// the model parameterization
|
||||||
|
//
|
||||||
|
// This section is also exactly the same as DDIM
|
||||||
|
{
|
||||||
|
float* vec_x = (float*)x->data;
|
||||||
|
float* vec_model_output =
|
||||||
|
(float*)model_output->data;
|
||||||
|
float* vec_pred_original_sample =
|
||||||
|
(float*)pred_original_sample->data;
|
||||||
|
for (int j = 0; j < ggml_nelements(x); j++) {
|
||||||
|
vec_pred_original_sample[j] =
|
||||||
|
(vec_x[j] / std::sqrt(sigma * sigma + 1) -
|
||||||
|
std::sqrt(beta_prod_t) *
|
||||||
|
vec_model_output[j]) *
|
||||||
|
(1 / std::sqrt(alpha_prod_t));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// This consistency function step can be difficult to
|
||||||
|
// decipher from Algorithm 4, as it is simply stated
|
||||||
|
// using a consistency function. This step is the
|
||||||
|
// modified DDIM, i.e. p. 8 (32) in Zheng et
|
||||||
|
// al. (2024), with eta set to 0 (see the paragraph
|
||||||
|
// immediately thereafter that states this somewhat
|
||||||
|
// obliquely).
|
||||||
|
{
|
||||||
|
float* vec_pred_original_sample =
|
||||||
|
(float*)pred_original_sample->data;
|
||||||
|
float* vec_model_output =
|
||||||
|
(float*)model_output->data;
|
||||||
|
float* vec_x = (float*)x->data;
|
||||||
|
for (int j = 0; j < ggml_nelements(x); j++) {
|
||||||
|
// Substituting x = pred_noised_sample and
|
||||||
|
// pred_epsilon = model_output
|
||||||
|
vec_x[j] =
|
||||||
|
std::sqrt(alpha_prod_s) *
|
||||||
|
vec_pred_original_sample[j] +
|
||||||
|
std::sqrt(beta_prod_s) *
|
||||||
|
vec_model_output[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 4. Sample and inject noise z ~ N(0, I) for
|
||||||
|
// MultiStep Inference. Noise is not used on the final
|
||||||
|
// timestep of the timestep schedule. This also means
|
||||||
|
// that noise is not used for one-step sampling. Eta
|
||||||
|
// (referred to as "gamma" in the paper) was
|
||||||
|
// introduced to control the stochasticity in every
|
||||||
|
// step. When eta = 0, it represents deterministic
|
||||||
|
// sampling, whereas eta = 1 indicates full stochastic
|
||||||
|
// sampling.
|
||||||
|
if (eta > 0 && i != steps - 1) {
|
||||||
|
// In this case, x is still pred_noised_sample,
|
||||||
|
// continue in-place
|
||||||
|
ggml_tensor_set_f32_randn(noise, rng);
|
||||||
|
float* vec_x = (float*)x->data;
|
||||||
|
float* vec_noise = (float*)noise->data;
|
||||||
|
for (int j = 0; j < ggml_nelements(x); j++) {
|
||||||
|
// Corresponding to (35) in Zheng et
|
||||||
|
// al. (2024), substituting x =
|
||||||
|
// pred_noised_sample
|
||||||
|
vec_x[j] =
|
||||||
|
std::sqrt(alpha_prod_t_prev /
|
||||||
|
alpha_prod_s) *
|
||||||
|
vec_x[j] +
|
||||||
|
std::sqrt(1 - alpha_prod_t_prev /
|
||||||
|
alpha_prod_s) *
|
||||||
|
vec_noise[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} break;
|
||||||
|
|
||||||
default:
|
default:
|
||||||
LOG_ERROR("Attempting to sample with nonexisting sample method %i", method);
|
LOG_ERROR("Attempting to sample with nonexisting sample method %i", method);
|
||||||
|
|||||||
@ -39,6 +39,8 @@ const char* sample_method_str[] = {
|
|||||||
"ipndm",
|
"ipndm",
|
||||||
"ipndm_v",
|
"ipndm_v",
|
||||||
"lcm",
|
"lcm",
|
||||||
|
"ddim_trailing",
|
||||||
|
"tcd",
|
||||||
};
|
};
|
||||||
|
|
||||||
// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h
|
// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h
|
||||||
@ -93,6 +95,7 @@ struct SDParams {
|
|||||||
float min_cfg = 1.0f;
|
float min_cfg = 1.0f;
|
||||||
float cfg_scale = 7.0f;
|
float cfg_scale = 7.0f;
|
||||||
float guidance = 3.5f;
|
float guidance = 3.5f;
|
||||||
|
float eta = 0.f;
|
||||||
float style_ratio = 20.f;
|
float style_ratio = 20.f;
|
||||||
int clip_skip = -1; // <= 0 represents unspecified
|
int clip_skip = -1; // <= 0 represents unspecified
|
||||||
int width = 512;
|
int width = 512;
|
||||||
@ -162,6 +165,7 @@ void print_params(SDParams params) {
|
|||||||
printf(" cfg_scale: %.2f\n", params.cfg_scale);
|
printf(" cfg_scale: %.2f\n", params.cfg_scale);
|
||||||
printf(" slg_scale: %.2f\n", params.slg_scale);
|
printf(" slg_scale: %.2f\n", params.slg_scale);
|
||||||
printf(" guidance: %.2f\n", params.guidance);
|
printf(" guidance: %.2f\n", params.guidance);
|
||||||
|
printf(" eta: %.2f\n", params.eta);
|
||||||
printf(" clip_skip: %d\n", params.clip_skip);
|
printf(" clip_skip: %d\n", params.clip_skip);
|
||||||
printf(" width: %d\n", params.width);
|
printf(" width: %d\n", params.width);
|
||||||
printf(" height: %d\n", params.height);
|
printf(" height: %d\n", params.height);
|
||||||
@ -211,6 +215,7 @@ void print_usage(int argc, const char* argv[]) {
|
|||||||
printf(" --guidance SCALE guidance scale for img2img (default: 3.5)\n");
|
printf(" --guidance SCALE guidance scale for img2img (default: 3.5)\n");
|
||||||
printf(" --slg-scale SCALE skip layer guidance (SLG) scale, only for DiT models: (default: 0)\n");
|
printf(" --slg-scale SCALE skip layer guidance (SLG) scale, only for DiT models: (default: 0)\n");
|
||||||
printf(" 0 means disabled, a value of 2.5 is nice for sd3.5 medium\n");
|
printf(" 0 means disabled, a value of 2.5 is nice for sd3.5 medium\n");
|
||||||
|
printf(" --eta SCALE eta in DDIM, only for DDIM and TCD: (default: 0)\n");
|
||||||
printf(" --skip-layers LAYERS Layers to skip for SLG steps: (default: [7,8,9])\n");
|
printf(" --skip-layers LAYERS Layers to skip for SLG steps: (default: [7,8,9])\n");
|
||||||
printf(" --skip-layer-start START SLG enabling point: (default: 0.01)\n");
|
printf(" --skip-layer-start START SLG enabling point: (default: 0.01)\n");
|
||||||
printf(" --skip-layer-end END SLG disabling point: (default: 0.2)\n");
|
printf(" --skip-layer-end END SLG disabling point: (default: 0.2)\n");
|
||||||
@ -221,7 +226,7 @@ void print_usage(int argc, const char* argv[]) {
|
|||||||
printf(" 1.0 corresponds to full destruction of information in init image\n");
|
printf(" 1.0 corresponds to full destruction of information in init image\n");
|
||||||
printf(" -H, --height H image height, in pixel space (default: 512)\n");
|
printf(" -H, --height H image height, in pixel space (default: 512)\n");
|
||||||
printf(" -W, --width W image width, in pixel space (default: 512)\n");
|
printf(" -W, --width W image width, in pixel space (default: 512)\n");
|
||||||
printf(" --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm}\n");
|
printf(" --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd}\n");
|
||||||
printf(" sampling method (default: \"euler_a\")\n");
|
printf(" sampling method (default: \"euler_a\")\n");
|
||||||
printf(" --steps STEPS number of sample steps (default: 20)\n");
|
printf(" --steps STEPS number of sample steps (default: 20)\n");
|
||||||
printf(" --rng {std_default, cuda} RNG (default: cuda)\n");
|
printf(" --rng {std_default, cuda} RNG (default: cuda)\n");
|
||||||
@ -440,6 +445,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
params.guidance = std::stof(argv[i]);
|
params.guidance = std::stof(argv[i]);
|
||||||
|
} else if (arg == "--eta") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_arg = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
params.eta = std::stof(argv[i]);
|
||||||
} else if (arg == "--strength") {
|
} else if (arg == "--strength") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_arg = true;
|
invalid_arg = true;
|
||||||
@ -719,6 +730,7 @@ std::string get_image_params(SDParams params, int64_t seed) {
|
|||||||
parameter_string += "Skip layer end: " + std::to_string(params.skip_layer_end) + ", ";
|
parameter_string += "Skip layer end: " + std::to_string(params.skip_layer_end) + ", ";
|
||||||
}
|
}
|
||||||
parameter_string += "Guidance: " + std::to_string(params.guidance) + ", ";
|
parameter_string += "Guidance: " + std::to_string(params.guidance) + ", ";
|
||||||
|
parameter_string += "Eta: " + std::to_string(params.eta) + ", ";
|
||||||
parameter_string += "Seed: " + std::to_string(seed) + ", ";
|
parameter_string += "Seed: " + std::to_string(seed) + ", ";
|
||||||
parameter_string += "Size: " + std::to_string(params.width) + "x" + std::to_string(params.height) + ", ";
|
parameter_string += "Size: " + std::to_string(params.width) + "x" + std::to_string(params.height) + ", ";
|
||||||
parameter_string += "Model: " + sd_basename(params.model_path) + ", ";
|
parameter_string += "Model: " + sd_basename(params.model_path) + ", ";
|
||||||
@ -939,6 +951,7 @@ int main(int argc, const char* argv[]) {
|
|||||||
params.clip_skip,
|
params.clip_skip,
|
||||||
params.cfg_scale,
|
params.cfg_scale,
|
||||||
params.guidance,
|
params.guidance,
|
||||||
|
params.eta,
|
||||||
params.width,
|
params.width,
|
||||||
params.height,
|
params.height,
|
||||||
params.sample_method,
|
params.sample_method,
|
||||||
@ -1006,6 +1019,7 @@ int main(int argc, const char* argv[]) {
|
|||||||
params.clip_skip,
|
params.clip_skip,
|
||||||
params.cfg_scale,
|
params.cfg_scale,
|
||||||
params.guidance,
|
params.guidance,
|
||||||
|
params.eta,
|
||||||
params.width,
|
params.width,
|
||||||
params.height,
|
params.height,
|
||||||
params.sample_method,
|
params.sample_method,
|
||||||
|
|||||||
@ -47,6 +47,8 @@ const char* sampling_methods_str[] = {
|
|||||||
"iPNDM",
|
"iPNDM",
|
||||||
"iPNDM_v",
|
"iPNDM_v",
|
||||||
"LCM",
|
"LCM",
|
||||||
|
"DDIM \"trailing\"",
|
||||||
|
"TCD"
|
||||||
};
|
};
|
||||||
|
|
||||||
/*================================================== Helper Functions ================================================*/
|
/*================================================== Helper Functions ================================================*/
|
||||||
@ -793,6 +795,7 @@ public:
|
|||||||
float min_cfg,
|
float min_cfg,
|
||||||
float cfg_scale,
|
float cfg_scale,
|
||||||
float guidance,
|
float guidance,
|
||||||
|
float eta,
|
||||||
sample_method_t method,
|
sample_method_t method,
|
||||||
const std::vector<float>& sigmas,
|
const std::vector<float>& sigmas,
|
||||||
int start_merge_step,
|
int start_merge_step,
|
||||||
@ -988,7 +991,7 @@ public:
|
|||||||
return denoised;
|
return denoised;
|
||||||
};
|
};
|
||||||
|
|
||||||
sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng);
|
sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng, eta);
|
||||||
|
|
||||||
x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x);
|
x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x);
|
||||||
|
|
||||||
@ -1194,6 +1197,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
|
|||||||
int clip_skip,
|
int clip_skip,
|
||||||
float cfg_scale,
|
float cfg_scale,
|
||||||
float guidance,
|
float guidance,
|
||||||
|
float eta,
|
||||||
int width,
|
int width,
|
||||||
int height,
|
int height,
|
||||||
enum sample_method_t sample_method,
|
enum sample_method_t sample_method,
|
||||||
@ -1457,6 +1461,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
|
|||||||
cfg_scale,
|
cfg_scale,
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
guidance,
|
guidance,
|
||||||
|
eta,
|
||||||
sample_method,
|
sample_method,
|
||||||
sigmas,
|
sigmas,
|
||||||
start_merge_step,
|
start_merge_step,
|
||||||
@ -1522,6 +1527,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
|
|||||||
int clip_skip,
|
int clip_skip,
|
||||||
float cfg_scale,
|
float cfg_scale,
|
||||||
float guidance,
|
float guidance,
|
||||||
|
float eta,
|
||||||
int width,
|
int width,
|
||||||
int height,
|
int height,
|
||||||
enum sample_method_t sample_method,
|
enum sample_method_t sample_method,
|
||||||
@ -1600,6 +1606,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
|
|||||||
clip_skip,
|
clip_skip,
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
guidance,
|
guidance,
|
||||||
|
eta,
|
||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
sample_method,
|
sample_method,
|
||||||
@ -1631,6 +1638,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
|
|||||||
int clip_skip,
|
int clip_skip,
|
||||||
float cfg_scale,
|
float cfg_scale,
|
||||||
float guidance,
|
float guidance,
|
||||||
|
float eta,
|
||||||
int width,
|
int width,
|
||||||
int height,
|
int height,
|
||||||
sample_method_t sample_method,
|
sample_method_t sample_method,
|
||||||
@ -1778,6 +1786,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
|
|||||||
clip_skip,
|
clip_skip,
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
guidance,
|
guidance,
|
||||||
|
eta,
|
||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
sample_method,
|
sample_method,
|
||||||
@ -1891,6 +1900,7 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
|
|||||||
min_cfg,
|
min_cfg,
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
0.f,
|
0.f,
|
||||||
|
0.f,
|
||||||
sample_method,
|
sample_method,
|
||||||
sigmas,
|
sigmas,
|
||||||
-1,
|
-1,
|
||||||
|
|||||||
@ -44,6 +44,8 @@ enum sample_method_t {
|
|||||||
IPNDM,
|
IPNDM,
|
||||||
IPNDM_V,
|
IPNDM_V,
|
||||||
LCM,
|
LCM,
|
||||||
|
DDIM_TRAILING,
|
||||||
|
TCD,
|
||||||
N_SAMPLE_METHODS
|
N_SAMPLE_METHODS
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -155,6 +157,7 @@ SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx,
|
|||||||
int clip_skip,
|
int clip_skip,
|
||||||
float cfg_scale,
|
float cfg_scale,
|
||||||
float guidance,
|
float guidance,
|
||||||
|
float eta,
|
||||||
int width,
|
int width,
|
||||||
int height,
|
int height,
|
||||||
enum sample_method_t sample_method,
|
enum sample_method_t sample_method,
|
||||||
@ -180,6 +183,7 @@ SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx,
|
|||||||
int clip_skip,
|
int clip_skip,
|
||||||
float cfg_scale,
|
float cfg_scale,
|
||||||
float guidance,
|
float guidance,
|
||||||
|
float eta,
|
||||||
int width,
|
int width,
|
||||||
int height,
|
int height,
|
||||||
enum sample_method_t sample_method,
|
enum sample_method_t sample_method,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user