feat: add KL Optimal scheduler (#1098)

This commit is contained in:
Daniele 2025-12-18 14:02:55 +01:00 committed by GitHub
parent bda7fab9f2
commit 97cf2efe45
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 47 additions and 6 deletions

View File

@ -347,6 +347,41 @@ struct SmoothStepScheduler : SigmaScheduler {
} }
}; };
// Implementation adapted from https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15608
struct KLOptimalScheduler : SigmaScheduler {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
std::vector<float> sigmas;
if (n == 0) {
return sigmas;
}
if (n == 1) {
sigmas.push_back(sigma_max);
sigmas.push_back(0.0f);
return sigmas;
}
float alpha_min = std::atan(sigma_min);
float alpha_max = std::atan(sigma_max);
for (uint32_t i = 0; i < n; ++i) {
// t goes from 0.0 to 1.0
float t = static_cast<float>(i) / static_cast<float>(n-1);
// Interpolate in the angle domain
float angle = t * alpha_min + (1.0f - t) * alpha_max;
// Convert back to sigma
sigmas.push_back(std::tan(angle));
}
// Append the final zero to sigma
sigmas.push_back(0.0f);
return sigmas;
}
};
struct Denoiser { struct Denoiser {
virtual float sigma_min() = 0; virtual float sigma_min() = 0;
virtual float sigma_max() = 0; virtual float sigma_max() = 0;
@ -392,6 +427,10 @@ struct Denoiser {
LOG_INFO("get_sigmas with SmoothStep scheduler"); LOG_INFO("get_sigmas with SmoothStep scheduler");
scheduler = std::make_shared<SmoothStepScheduler>(); scheduler = std::make_shared<SmoothStepScheduler>();
break; break;
case KL_OPTIMAL_SCHEDULER:
LOG_INFO("get_sigmas with KL Optimal scheduler");
scheduler = std::make_shared<KLOptimalScheduler>();
break;
case LCM_SCHEDULER: case LCM_SCHEDULER:
LOG_INFO("get_sigmas with LCM scheduler"); LOG_INFO("get_sigmas with LCM scheduler");
scheduler = std::make_shared<LCMScheduler>(); scheduler = std::make_shared<LCMScheduler>();

View File

@ -120,7 +120,7 @@ Generation Options:
tcd] (default: euler for Flux/SD3/Wan, euler_a otherwise) tcd] (default: euler for Flux/SD3/Wan, euler_a otherwise)
--high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, --high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm,
ddim_trailing, tcd] default: euler for Flux/SD3/Wan, euler_a otherwise ddim_trailing, tcd] default: euler for Flux/SD3/Wan, euler_a otherwise
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm], --scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm],
default: discrete default: discrete
--sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0"). --sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0").
--skip-layers layers to skip for SLG steps (default: [7,8,9]) --skip-layers layers to skip for SLG steps (default: [7,8,9])

View File

@ -1409,7 +1409,7 @@ struct SDGenerationParams {
on_high_noise_sample_method_arg}, on_high_noise_sample_method_arg},
{"", {"",
"--scheduler", "--scheduler",
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm], default: discrete", "denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm], default: discrete",
on_scheduler_arg}, on_scheduler_arg},
{"", {"",
"--sigmas", "--sigmas",

View File

@ -114,7 +114,7 @@ Default Generation Options:
tcd] (default: euler for Flux/SD3/Wan, euler_a otherwise) tcd] (default: euler for Flux/SD3/Wan, euler_a otherwise)
--high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, --high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm,
ddim_trailing, tcd] default: euler for Flux/SD3/Wan, euler_a otherwise ddim_trailing, tcd] default: euler for Flux/SD3/Wan, euler_a otherwise
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm], --scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm],
default: discrete default: discrete
--sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0"). --sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0").
--skip-layers layers to skip for SLG steps (default: [7,8,9]) --skip-layers layers to skip for SLG steps (default: [7,8,9])

View File

@ -2412,6 +2412,7 @@ const char* scheduler_to_str[] = {
"sgm_uniform", "sgm_uniform",
"simple", "simple",
"smoothstep", "smoothstep",
"kl_optimal",
"lcm", "lcm",
}; };

View File

@ -60,6 +60,7 @@ enum scheduler_t {
SGM_UNIFORM_SCHEDULER, SGM_UNIFORM_SCHEDULER,
SIMPLE_SCHEDULER, SIMPLE_SCHEDULER,
SMOOTHSTEP_SCHEDULER, SMOOTHSTEP_SCHEDULER,
KL_OPTIMAL_SCHEDULER,
LCM_SCHEDULER, LCM_SCHEDULER,
SCHEDULER_COUNT SCHEDULER_COUNT
}; };