mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-24 23:26:43 +00:00
feat: add logit-normal scheduler (#1669)
This commit is contained in:
parent
f440ad9c29
commit
2938272d82
@ -960,7 +960,7 @@ ArgOptions SDGenerationParams::get_options() {
|
||||
&hires_upscaler},
|
||||
{"",
|
||||
"--extra-sample-args",
|
||||
"extra sampler/scheduler/guidance args, key=value list. CFG supports guidance_schedule; APG supports apg_eta, apg_momentum, apg_norm_threshold, apg_norm_threshold_smoothing; SLG supports slg_uncond; lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma;",
|
||||
"extra sampler/scheduler/guidance args, key=value list. CFG supports guidance_schedule; APG supports apg_eta, apg_momentum, apg_norm_threshold, apg_norm_threshold_smoothing; SLG supports slg_uncond; lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma;; logit_normal supports mu, std, logsnr_min, logsnr_max, resolution_aware",
|
||||
(int)',',
|
||||
&extra_sample_args},
|
||||
{"",
|
||||
@ -1475,7 +1475,7 @@ ArgOptions SDGenerationParams::get_options() {
|
||||
on_high_noise_sample_method_arg},
|
||||
{"",
|
||||
"--scheduler",
|
||||
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent, ltx2], default: model-specific",
|
||||
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent, ltx2, logit_normal], default: model-specific",
|
||||
on_scheduler_arg},
|
||||
{"",
|
||||
"--sigmas",
|
||||
|
||||
@ -70,6 +70,7 @@ enum scheduler_t {
|
||||
LCM_SCHEDULER,
|
||||
BONG_TANGENT_SCHEDULER,
|
||||
LTX2_SCHEDULER,
|
||||
LOGIT_NORMAL_SCHEDULER,
|
||||
SCHEDULER_COUNT
|
||||
};
|
||||
|
||||
|
||||
@ -559,6 +559,203 @@ struct LTX2Scheduler : SigmaScheduler {
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* Logit-Normal Scheduler
|
||||
* Based on: https://github.com/ideogram-oss/ideogram4/blob/main/src/ideogram4/scheduler.py
|
||||
*/
|
||||
struct LogitNormalScheduler : SigmaScheduler {
|
||||
float mean = 0.0f;
|
||||
float std = 1.75f;
|
||||
float logsnr_min = -15.0f;
|
||||
float logsnr_max = 18.0f;
|
||||
|
||||
bool resolution_aware = true;
|
||||
|
||||
float one_minus_t_min, one_minus_t_max;
|
||||
|
||||
void parse_extra_sample_args(int image_seq_len = 0, const char* extra_sample_args = nullptr) {
|
||||
const int known_seq_len = (512 * 512) / (16 * 16);
|
||||
if (extra_sample_args) {
|
||||
for (const auto& [key, value] : parse_key_value_args(extra_sample_args, "logit-normal scheduler arg")) {
|
||||
if (key == "mu") {
|
||||
if (!parse_strict_float(value, mean)) {
|
||||
LOG_WARN("ignoring invalid logit-normal scheduler arg '%s=%s'", key.c_str(), value.c_str());
|
||||
}
|
||||
} else if (key == "std") {
|
||||
if (!parse_strict_float(value, std)) {
|
||||
LOG_WARN("ignoring invalid logit-normal scheduler arg '%s=%s'", key.c_str(), value.c_str());
|
||||
}
|
||||
}
|
||||
if (key == "logsnr_min") {
|
||||
if (!parse_strict_float(value, logsnr_min)) {
|
||||
LOG_WARN("ignoring invalid logit-normal scheduler arg '%s=%s'", key.c_str(), value.c_str());
|
||||
}
|
||||
} else if (key == "logsnr_max") {
|
||||
if (!parse_strict_float(value, logsnr_max)) {
|
||||
LOG_WARN("ignoring invalid logit-normal scheduler arg '%s=%s'", key.c_str(), value.c_str());
|
||||
}
|
||||
} else if (key == "resolution_aware") {
|
||||
if (!parse_strict_bool(value, resolution_aware)) {
|
||||
LOG_WARN("ignoring invalid logit-normal scheduler arg '%s=%s'", key.c_str(), value.c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (image_seq_len > 0 && resolution_aware) {
|
||||
mean += 0.5 * std::log(static_cast<float>(image_seq_len) / static_cast<float>(known_seq_len));
|
||||
}
|
||||
}
|
||||
|
||||
float sigmoid(float x) {
|
||||
return 1.0f / (1.0f + std::exp(-x));
|
||||
}
|
||||
|
||||
LogitNormalScheduler(float mean = 0.0f, float std = 1.75f, float logsnr_min = -18.0f, float logsnr_max = 15.0f)
|
||||
: mean(mean), std(std), logsnr_min(logsnr_min), logsnr_max(logsnr_max) {
|
||||
// t_min = 1.0f / (1.0f + std::exp(0.5f * logsnr_max));
|
||||
one_minus_t_min = sigmoid(0.5f * logsnr_max);
|
||||
// t_max = 1.0f / (1.0f + std::exp(0.5f * logsnr_min));
|
||||
one_minus_t_max = sigmoid(0.5f * logsnr_min);
|
||||
|
||||
}
|
||||
|
||||
LogitNormalScheduler(int image_seq_len = 0, const char* extra_sample_args = nullptr) {
|
||||
mean = 0.0f;
|
||||
std = 1.75f;
|
||||
logsnr_min = -15.0f;
|
||||
logsnr_max = 18.0f;
|
||||
|
||||
parse_extra_sample_args(image_seq_len, extra_sample_args);
|
||||
// t_min = 1.0f / (1.0f + std::exp(0.5f * logsnr_max));
|
||||
one_minus_t_min = sigmoid(0.5f * logsnr_max);
|
||||
// t_max = 1.0f / (1.0f + std::exp(0.5f * logsnr_min));
|
||||
one_minus_t_max = sigmoid(0.5f * logsnr_min);
|
||||
}
|
||||
|
||||
// https://stackedboxes.org/2017/05/01/acklams-normal-quantile-function/
|
||||
double ndtri(double p) {
|
||||
if (p <= 0.0) {
|
||||
return -std::numeric_limits<double>::infinity();
|
||||
} else if (p >= 1.0) {
|
||||
return std::numeric_limits<double>::infinity();
|
||||
}
|
||||
|
||||
static const double p_low = 0.02425;
|
||||
static const double p_high = 1.0 - p_low;
|
||||
|
||||
static const double c[6] = {-7.784894002430293e-03,
|
||||
-3.223964580411365e-01,
|
||||
-2.400758277161838e+00,
|
||||
-2.549732539343734e+00,
|
||||
4.374664141464968e+00,
|
||||
2.938163982698783e+00};
|
||||
|
||||
static const double d[5] = {7.784695709041462e-03,
|
||||
3.224671290700398e-01,
|
||||
2.445134137142996e+00,
|
||||
3.754408661907416e+00,
|
||||
1.0};
|
||||
|
||||
// Coefficients for the central region
|
||||
static const double a[6] = {-3.969683028665376e+01,
|
||||
2.209460984245205e+02,
|
||||
-2.759285104469687e+02,
|
||||
1.383577518672690e+02,
|
||||
-3.066479806614716e+01,
|
||||
2.506628277459239e+00};
|
||||
|
||||
static const double b[6] = {-5.447609879822406e+01,
|
||||
1.615858368580409e+02,
|
||||
-1.556989798598866e+02,
|
||||
6.680131188771972e+01,
|
||||
-1.328068155288572e+01,
|
||||
1.0};
|
||||
|
||||
double x = 0.0;
|
||||
|
||||
if (p < p_low) {
|
||||
// Lower region
|
||||
double q = std::sqrt(-2.0 * std::log(p));
|
||||
|
||||
// Numerator: c[0]*q^5 + c[1]*q^4 + ... + c[5]
|
||||
double numerator = c[0];
|
||||
for (int i = 1; i < 6; ++i) {
|
||||
numerator = numerator * q + c[i];
|
||||
}
|
||||
|
||||
// Denominator: d[0]*q^4 + d[1]*q^3 + ... + d[3]*q + 1
|
||||
double denominator = d[0];
|
||||
for (int i = 1; i < 5; ++i) {
|
||||
denominator = denominator * q + d[i];
|
||||
}
|
||||
|
||||
x = numerator / denominator;
|
||||
} else if (p > p_high) {
|
||||
// Upper region
|
||||
double q = std::sqrt(-2.0 * std::log(1.0 - p));
|
||||
|
||||
double numerator = c[0];
|
||||
for (int i = 1; i < 6; ++i) {
|
||||
numerator = numerator * q + c[i];
|
||||
}
|
||||
|
||||
double denominator = d[0];
|
||||
for (int i = 1; i < 5; ++i) {
|
||||
denominator = denominator * q + d[i];
|
||||
}
|
||||
|
||||
x = -(numerator / denominator);
|
||||
} else {
|
||||
// Central region
|
||||
double q = p - 0.5;
|
||||
double r = q * q;
|
||||
|
||||
// Numerator: (a[0]*r^5 + a[1]*r^4 + ... + a[5])*q
|
||||
double numerator = a[0];
|
||||
for (int i = 1; i < 6; ++i) {
|
||||
numerator = numerator * r + a[i];
|
||||
}
|
||||
numerator *= q;
|
||||
|
||||
// Denominator: b[0]*r^4 + b[1]*r^3 + ... + b[4]*r + 1
|
||||
double denominator = b[0];
|
||||
for (int i = 1; i < 6; ++i) {
|
||||
denominator = denominator * r + b[i];
|
||||
}
|
||||
|
||||
x = numerator / denominator;
|
||||
}
|
||||
return x;
|
||||
}
|
||||
|
||||
std::vector<float> get_sigmas(uint32_t n, float /*sigma_min*/, float /*sigma_max*/, t_to_sigma_t /*t_to_sigma*/) override {
|
||||
std::vector<float> sigmas;
|
||||
LOG_INFO("LOGIT_NORMAL_SCHEDULER using mean=%.4f, std=%.4f, logsnr_min=%.4f, logsnr_max=%.4f", mean, std, logsnr_min, logsnr_max);
|
||||
sigmas.reserve(n + 1);
|
||||
for (uint32_t i = 0; i <= n; ++i) {
|
||||
float t = static_cast<float>(i) / static_cast<float>(n);
|
||||
|
||||
// ndtri(1-t) == -ndtri(t)
|
||||
float z = -ndtri(t);
|
||||
|
||||
float y = mean + std * z;
|
||||
|
||||
float timestep = sigmoid(y);
|
||||
|
||||
if (timestep > one_minus_t_min)
|
||||
timestep = one_minus_t_min;
|
||||
if (timestep < one_minus_t_max)
|
||||
timestep = one_minus_t_max;
|
||||
|
||||
float sigma = timestep;
|
||||
|
||||
sigmas.push_back(sigma);
|
||||
}
|
||||
sigmas[n] = 0.0f;
|
||||
return sigmas;
|
||||
}
|
||||
};
|
||||
|
||||
struct Denoiser {
|
||||
virtual float sigma_min() = 0;
|
||||
virtual float sigma_max() = 0;
|
||||
@ -623,6 +820,11 @@ struct Denoiser {
|
||||
LOG_INFO("get_sigmas with LTX2 scheduler");
|
||||
scheduler = std::make_shared<LTX2Scheduler>(image_seq_len, extra_sample_args);
|
||||
break;
|
||||
case LOGIT_NORMAL_SCHEDULER: {
|
||||
LOG_INFO("get_sigmas with Logit-Normal scheduler");
|
||||
scheduler = std::make_shared<LogitNormalScheduler>(image_seq_len, extra_sample_args);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
LOG_INFO("get_sigmas with discrete scheduler (default)");
|
||||
scheduler = std::make_shared<DiscreteScheduler>();
|
||||
|
||||
@ -2535,6 +2535,7 @@ const char* scheduler_to_str[] = {
|
||||
"lcm",
|
||||
"bong_tangent",
|
||||
"ltx2",
|
||||
"logit_normal",
|
||||
};
|
||||
|
||||
const char* sd_scheduler_name(enum scheduler_t scheduler) {
|
||||
@ -3137,6 +3138,8 @@ enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_me
|
||||
return SIMPLE_SCHEDULER;
|
||||
} else if (sd_ctx != nullptr && sd_ctx->sd != nullptr && sd_version_is_ltxav(sd_ctx->sd->version)) {
|
||||
return LTX2_SCHEDULER;
|
||||
} else if(sd_ctx != nullptr && sd_ctx->sd != nullptr && sd_version_is_ideogram4(sd_ctx->sd->version)) {
|
||||
return LOGIT_NORMAL_SCHEDULER;
|
||||
}
|
||||
return DISCRETE_SCHEDULER;
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user