feat: add logit-normal scheduler (#1669)

This commit is contained in:
stduhpf 2026-06-24 18:06:11 +02:00 committed by GitHub
parent f440ad9c29
commit 2938272d82
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 208 additions and 2 deletions

View File

@ -960,7 +960,7 @@ ArgOptions SDGenerationParams::get_options() {
&hires_upscaler},
{"",
"--extra-sample-args",
"extra sampler/scheduler/guidance args, key=value list. CFG supports guidance_schedule; APG supports apg_eta, apg_momentum, apg_norm_threshold, apg_norm_threshold_smoothing; SLG supports slg_uncond; lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma;",
"extra sampler/scheduler/guidance args, key=value list. CFG supports guidance_schedule; APG supports apg_eta, apg_momentum, apg_norm_threshold, apg_norm_threshold_smoothing; SLG supports slg_uncond; lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma;; logit_normal supports mu, std, logsnr_min, logsnr_max, resolution_aware",
(int)',',
&extra_sample_args},
{"",
@ -1475,7 +1475,7 @@ ArgOptions SDGenerationParams::get_options() {
on_high_noise_sample_method_arg},
{"",
"--scheduler",
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent, ltx2], default: model-specific",
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent, ltx2, logit_normal], default: model-specific",
on_scheduler_arg},
{"",
"--sigmas",

View File

@ -70,6 +70,7 @@ enum scheduler_t {
LCM_SCHEDULER,
BONG_TANGENT_SCHEDULER,
LTX2_SCHEDULER,
LOGIT_NORMAL_SCHEDULER,
SCHEDULER_COUNT
};

View File

@ -559,6 +559,203 @@ struct LTX2Scheduler : SigmaScheduler {
}
};
/*
* Logit-Normal Scheduler
* Based on: https://github.com/ideogram-oss/ideogram4/blob/main/src/ideogram4/scheduler.py
*/
/**
 * Sigma scheduler that places the sampling steps at quantiles of a logit-normal
 * distribution: for evenly spaced t in [0, 1) it emits sigmoid(mean + std * -ndtri(t)),
 * clamped to a range derived from logsnr_min / logsnr_max, and finishes with 0.
 *
 * Tunable through the extra sample args `mu`, `std`, `logsnr_min`, `logsnr_max`
 * and `resolution_aware`.
 */
struct LogitNormalScheduler : SigmaScheduler {
    float mean            = 0.0f;    // location of the underlying normal (arg key "mu")
    float std             = 1.75f;   // scale of the underlying normal, expected to be > 0
    float logsnr_min      = -15.0f;  // yields the lower clamp: sigmoid(0.5 * logsnr_min)
    float logsnr_max      = 18.0f;   // yields the upper clamp: sigmoid(0.5 * logsnr_max)
    bool resolution_aware = true;    // shift `mean` according to the image sequence length
    float one_minus_t_min = 0.0f;    // upper clamp applied to every emitted value
    float one_minus_t_max = 0.0f;    // lower clamp applied to every emitted value

    /**
     * Recompute the clamp bounds from the current logsnr_min / logsnr_max.
     *
     * With t_min = 1 / (1 + exp(0.5 * logsnr_max)) and t_max = 1 / (1 + exp(0.5 * logsnr_min)),
     * the stored values are 1 - t_min and 1 - t_max, i.e. sigmoid(0.5 * logsnr_{max,min}).
     */
    void update_clamp_bounds() {
        one_minus_t_min = sigmoid(0.5f * logsnr_max);
        one_minus_t_max = sigmoid(0.5f * logsnr_min);
    }

    /**
     * Apply user overrides and the optional resolution shift, then refresh the clamp bounds.
     *
     * @param image_seq_len     sequence length of the image being generated; values <= 0
     *                          disable the resolution shift.
     * @param extra_sample_args key=value list (may be nullptr). Recognised keys: mu, std,
     *                          logsnr_min, logsnr_max, resolution_aware. Other keys are
     *                          ignored silently; unparsable values are ignored with a warning.
     *
     * The shift is added onto `mean`, so calling this more than once accumulates it.
     */
    void parse_extra_sample_args(int image_seq_len = 0, const char* extra_sample_args = nullptr) {
        // Reference sequence length (1024) at which no shift is applied
        // (presumably a 512x512 image cut into 16x16 patches; confirm against the model).
        const int known_seq_len = (512 * 512) / (16 * 16);

        if (extra_sample_args) {
            for (const auto& [key, value] : parse_key_value_args(extra_sample_args, "logit-normal scheduler arg")) {
                bool valid = true;
                if (key == "mu") {
                    valid = parse_strict_float(value, mean);
                } else if (key == "std") {
                    // Parse into a temporary: a zero, negative or non-finite scale would turn the
                    // schedule end points into NaN (0 * inf) or invert the ordering.
                    float parsed_std = std;
                    valid            = parse_strict_float(value, parsed_std) && std::isfinite(parsed_std) && parsed_std > 0.0f;
                    if (valid) {
                        std = parsed_std;
                    }
                } else if (key == "logsnr_min") {
                    valid = parse_strict_float(value, logsnr_min);
                } else if (key == "logsnr_max") {
                    valid = parse_strict_float(value, logsnr_max);
                } else if (key == "resolution_aware") {
                    valid = parse_strict_bool(value, resolution_aware);
                }
                if (!valid) {
                    LOG_WARN("ignoring invalid logit-normal scheduler arg '%s=%s'", key.c_str(), value.c_str());
                }
            }
        }

        if (image_seq_len > 0 && resolution_aware) {
            mean += 0.5 * std::log(static_cast<float>(image_seq_len) / static_cast<float>(known_seq_len));
        }

        // Keep the clamp bounds in sync with whatever logsnr_* values are now in effect.
        update_clamp_bounds();
    }

    /** Logistic function 1 / (1 + exp(-x)). */
    static float sigmoid(float x) {
        return 1.0f / (1.0f + std::exp(-x));
    }

    /**
     * Construct from explicit parameters (no arg parsing, no resolution shift).
     *
     * `mean` deliberately has no default: with one, a zero-argument construction would be
     * ambiguous with the (image_seq_len, extra_sample_args) constructor below.
     *
     * NOTE(review): the logsnr defaults here (-18 / 15) differ from the member initialisers
     * and from what the other constructor uses (-15 / 18); confirm which pair is intended.
     */
    LogitNormalScheduler(float mean, float std = 1.75f, float logsnr_min = -18.0f, float logsnr_max = 15.0f)
        : mean(mean), std(std), logsnr_min(logsnr_min), logsnr_max(logsnr_max) {
        update_clamp_bounds();
    }

    /**
     * Construct with the member defaults, then apply `extra_sample_args` and the
     * resolution shift for `image_seq_len` (see parse_extra_sample_args).
     */
    LogitNormalScheduler(int image_seq_len = 0, const char* extra_sample_args = nullptr) {
        parse_extra_sample_args(image_seq_len, extra_sample_args);
    }

    /** Horner evaluation of coeffs[0]*x^(count-1) + ... + coeffs[count-1]. */
    static double eval_poly(const double* coeffs, int count, double x) {
        double acc = coeffs[0];
        for (int i = 1; i < count; ++i) {
            acc = acc * x + coeffs[i];
        }
        return acc;
    }

    /**
     * Inverse of the standard normal CDF (rational approximation by P. Acklam).
     * https://stackedboxes.org/2017/05/01/acklams-normal-quantile-function/
     *
     * @param p probability; p <= 0 returns -infinity and p >= 1 returns +infinity.
     * @return  approximate z such that Phi(z) == p.
     */
    static double ndtri(double p) {
        if (p <= 0.0) {
            return -std::numeric_limits<double>::infinity();
        }
        if (p >= 1.0) {
            return std::numeric_limits<double>::infinity();
        }

        // Break-points between the tail approximations and the central one.
        static constexpr double p_low  = 0.02425;
        static constexpr double p_high = 1.0 - p_low;

        // Tail regions: numerator c (degree 5) over denominator d (degree 4, constant term 1).
        static constexpr double c[6] = {-7.784894002430293e-03,
                                        -3.223964580411365e-01,
                                        -2.400758277161838e+00,
                                        -2.549732539343734e+00,
                                        4.374664141464968e+00,
                                        2.938163982698783e+00};
        static constexpr double d[5] = {7.784695709041462e-03,
                                        3.224671290700398e-01,
                                        2.445134137142996e+00,
                                        3.754408661907416e+00,
                                        1.0};
        // Central region: numerator a and denominator b, both degree 5 in r = q^2
        // (b carries the constant term 1 as its last entry).
        static constexpr double a[6] = {-3.969683028665376e+01,
                                        2.209460984245205e+02,
                                        -2.759285104469687e+02,
                                        1.383577518672690e+02,
                                        -3.066479806614716e+01,
                                        2.506628277459239e+00};
        static constexpr double b[6] = {-5.447609879822406e+01,
                                        1.615858368580409e+02,
                                        -1.556989798598866e+02,
                                        6.680131188771972e+01,
                                        -1.328068155288572e+01,
                                        1.0};

        if (p < p_low) {
            // Lower tail.
            const double q = std::sqrt(-2.0 * std::log(p));
            return eval_poly(c, 6, q) / eval_poly(d, 5, q);
        }
        if (p > p_high) {
            // Upper tail: mirror image of the lower tail.
            const double q = std::sqrt(-2.0 * std::log(1.0 - p));
            return -(eval_poly(c, 6, q) / eval_poly(d, 5, q));
        }
        // Central region.
        const double q = p - 0.5;
        const double r = q * q;
        return eval_poly(a, 6, r) * q / eval_poly(b, 6, r);
    }

    /**
     * Build the schedule: n clamped logit-normal quantile values followed by a final 0.
     *
     * @param n number of steps; the result always holds n + 1 entries (just {0} for n == 0).
     *          The sigma_min / sigma_max / t_to_sigma arguments are unused.
     */
    std::vector<float> get_sigmas(uint32_t n, float /*sigma_min*/, float /*sigma_max*/, t_to_sigma_t /*t_to_sigma*/) override {
        LOG_INFO("LOGIT_NORMAL_SCHEDULER using mean=%.4f, std=%.4f, logsnr_min=%.4f, logsnr_max=%.4f", mean, std, logsnr_min, logsnr_max);

        std::vector<float> sigmas;
        if (n == 0) {
            // Nothing to schedule; also avoids evaluating 0 / 0 below.
            sigmas.push_back(0.0f);
            return sigmas;
        }

        sigmas.reserve(static_cast<size_t>(n) + 1);
        for (uint32_t i = 0; i < n; ++i) {
            const float t = static_cast<float>(i) / static_cast<float>(n);
            // ndtri(1-t) == -ndtri(t)
            const float z = static_cast<float>(-ndtri(t));
            float timestep = sigmoid(mean + std * z);
            if (timestep > one_minus_t_min) {
                timestep = one_minus_t_min;
            }
            if (timestep < one_minus_t_max) {
                timestep = one_minus_t_max;
            }
            sigmas.push_back(timestep);
        }
        // The schedule always terminates at exactly zero.
        sigmas.push_back(0.0f);
        return sigmas;
    }
};
struct Denoiser {
virtual float sigma_min() = 0;
virtual float sigma_max() = 0;
@ -623,6 +820,11 @@ struct Denoiser {
LOG_INFO("get_sigmas with LTX2 scheduler");
scheduler = std::make_shared<LTX2Scheduler>(image_seq_len, extra_sample_args);
break;
case LOGIT_NORMAL_SCHEDULER: {
LOG_INFO("get_sigmas with Logit-Normal scheduler");
scheduler = std::make_shared<LogitNormalScheduler>(image_seq_len, extra_sample_args);
break;
}
default:
LOG_INFO("get_sigmas with discrete scheduler (default)");
scheduler = std::make_shared<DiscreteScheduler>();

View File

@ -2535,6 +2535,7 @@ const char* scheduler_to_str[] = {
"lcm",
"bong_tangent",
"ltx2",
"logit_normal",
};
const char* sd_scheduler_name(enum scheduler_t scheduler) {
@ -3137,6 +3138,8 @@ enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_me
return SIMPLE_SCHEDULER;
} else if (sd_ctx != nullptr && sd_ctx->sd != nullptr && sd_version_is_ltxav(sd_ctx->sd->version)) {
return LTX2_SCHEDULER;
} else if(sd_ctx != nullptr && sd_ctx->sd != nullptr && sd_version_is_ideogram4(sd_ctx->sd->version)) {
return LOGIT_NORMAL_SCHEDULER;
}
return DISCRETE_SCHEDULER;
}