mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-24 23:26:43 +00:00
feat: add logit-normal scheduler (#1669)
This commit is contained in:
parent
f440ad9c29
commit
2938272d82
@ -960,7 +960,7 @@ ArgOptions SDGenerationParams::get_options() {
|
|||||||
&hires_upscaler},
|
&hires_upscaler},
|
||||||
{"",
|
{"",
|
||||||
"--extra-sample-args",
|
"--extra-sample-args",
|
||||||
"extra sampler/scheduler/guidance args, key=value list. CFG supports guidance_schedule; APG supports apg_eta, apg_momentum, apg_norm_threshold, apg_norm_threshold_smoothing; SLG supports slg_uncond; lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma;",
|
"extra sampler/scheduler/guidance args, key=value list. CFG supports guidance_schedule; APG supports apg_eta, apg_momentum, apg_norm_threshold, apg_norm_threshold_smoothing; SLG supports slg_uncond; lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma;; logit_normal supports mu, std, logsnr_min, logsnr_max, resolution_aware",
|
||||||
(int)',',
|
(int)',',
|
||||||
&extra_sample_args},
|
&extra_sample_args},
|
||||||
{"",
|
{"",
|
||||||
@ -1475,7 +1475,7 @@ ArgOptions SDGenerationParams::get_options() {
|
|||||||
on_high_noise_sample_method_arg},
|
on_high_noise_sample_method_arg},
|
||||||
{"",
|
{"",
|
||||||
"--scheduler",
|
"--scheduler",
|
||||||
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent, ltx2], default: model-specific",
|
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent, ltx2, logit_normal], default: model-specific",
|
||||||
on_scheduler_arg},
|
on_scheduler_arg},
|
||||||
{"",
|
{"",
|
||||||
"--sigmas",
|
"--sigmas",
|
||||||
|
|||||||
@ -70,6 +70,7 @@ enum scheduler_t {
|
|||||||
LCM_SCHEDULER,
|
LCM_SCHEDULER,
|
||||||
BONG_TANGENT_SCHEDULER,
|
BONG_TANGENT_SCHEDULER,
|
||||||
LTX2_SCHEDULER,
|
LTX2_SCHEDULER,
|
||||||
|
LOGIT_NORMAL_SCHEDULER,
|
||||||
SCHEDULER_COUNT
|
SCHEDULER_COUNT
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -559,6 +559,203 @@ struct LTX2Scheduler : SigmaScheduler {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Logit-Normal Scheduler
|
||||||
|
* Based on: https://github.com/ideogram-oss/ideogram4/blob/main/src/ideogram4/scheduler.py
|
||||||
|
*/
|
||||||
|
struct LogitNormalScheduler : SigmaScheduler {
|
||||||
|
float mean = 0.0f;
|
||||||
|
float std = 1.75f;
|
||||||
|
float logsnr_min = -15.0f;
|
||||||
|
float logsnr_max = 18.0f;
|
||||||
|
|
||||||
|
bool resolution_aware = true;
|
||||||
|
|
||||||
|
float one_minus_t_min, one_minus_t_max;
|
||||||
|
|
||||||
|
void parse_extra_sample_args(int image_seq_len = 0, const char* extra_sample_args = nullptr) {
|
||||||
|
const int known_seq_len = (512 * 512) / (16 * 16);
|
||||||
|
if (extra_sample_args) {
|
||||||
|
for (const auto& [key, value] : parse_key_value_args(extra_sample_args, "logit-normal scheduler arg")) {
|
||||||
|
if (key == "mu") {
|
||||||
|
if (!parse_strict_float(value, mean)) {
|
||||||
|
LOG_WARN("ignoring invalid logit-normal scheduler arg '%s=%s'", key.c_str(), value.c_str());
|
||||||
|
}
|
||||||
|
} else if (key == "std") {
|
||||||
|
if (!parse_strict_float(value, std)) {
|
||||||
|
LOG_WARN("ignoring invalid logit-normal scheduler arg '%s=%s'", key.c_str(), value.c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (key == "logsnr_min") {
|
||||||
|
if (!parse_strict_float(value, logsnr_min)) {
|
||||||
|
LOG_WARN("ignoring invalid logit-normal scheduler arg '%s=%s'", key.c_str(), value.c_str());
|
||||||
|
}
|
||||||
|
} else if (key == "logsnr_max") {
|
||||||
|
if (!parse_strict_float(value, logsnr_max)) {
|
||||||
|
LOG_WARN("ignoring invalid logit-normal scheduler arg '%s=%s'", key.c_str(), value.c_str());
|
||||||
|
}
|
||||||
|
} else if (key == "resolution_aware") {
|
||||||
|
if (!parse_strict_bool(value, resolution_aware)) {
|
||||||
|
LOG_WARN("ignoring invalid logit-normal scheduler arg '%s=%s'", key.c_str(), value.c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (image_seq_len > 0 && resolution_aware) {
|
||||||
|
mean += 0.5 * std::log(static_cast<float>(image_seq_len) / static_cast<float>(known_seq_len));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
float sigmoid(float x) {
|
||||||
|
return 1.0f / (1.0f + std::exp(-x));
|
||||||
|
}
|
||||||
|
|
||||||
|
LogitNormalScheduler(float mean = 0.0f, float std = 1.75f, float logsnr_min = -18.0f, float logsnr_max = 15.0f)
|
||||||
|
: mean(mean), std(std), logsnr_min(logsnr_min), logsnr_max(logsnr_max) {
|
||||||
|
// t_min = 1.0f / (1.0f + std::exp(0.5f * logsnr_max));
|
||||||
|
one_minus_t_min = sigmoid(0.5f * logsnr_max);
|
||||||
|
// t_max = 1.0f / (1.0f + std::exp(0.5f * logsnr_min));
|
||||||
|
one_minus_t_max = sigmoid(0.5f * logsnr_min);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
LogitNormalScheduler(int image_seq_len = 0, const char* extra_sample_args = nullptr) {
|
||||||
|
mean = 0.0f;
|
||||||
|
std = 1.75f;
|
||||||
|
logsnr_min = -15.0f;
|
||||||
|
logsnr_max = 18.0f;
|
||||||
|
|
||||||
|
parse_extra_sample_args(image_seq_len, extra_sample_args);
|
||||||
|
// t_min = 1.0f / (1.0f + std::exp(0.5f * logsnr_max));
|
||||||
|
one_minus_t_min = sigmoid(0.5f * logsnr_max);
|
||||||
|
// t_max = 1.0f / (1.0f + std::exp(0.5f * logsnr_min));
|
||||||
|
one_minus_t_max = sigmoid(0.5f * logsnr_min);
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://stackedboxes.org/2017/05/01/acklams-normal-quantile-function/
|
||||||
|
double ndtri(double p) {
|
||||||
|
if (p <= 0.0) {
|
||||||
|
return -std::numeric_limits<double>::infinity();
|
||||||
|
} else if (p >= 1.0) {
|
||||||
|
return std::numeric_limits<double>::infinity();
|
||||||
|
}
|
||||||
|
|
||||||
|
static const double p_low = 0.02425;
|
||||||
|
static const double p_high = 1.0 - p_low;
|
||||||
|
|
||||||
|
static const double c[6] = {-7.784894002430293e-03,
|
||||||
|
-3.223964580411365e-01,
|
||||||
|
-2.400758277161838e+00,
|
||||||
|
-2.549732539343734e+00,
|
||||||
|
4.374664141464968e+00,
|
||||||
|
2.938163982698783e+00};
|
||||||
|
|
||||||
|
static const double d[5] = {7.784695709041462e-03,
|
||||||
|
3.224671290700398e-01,
|
||||||
|
2.445134137142996e+00,
|
||||||
|
3.754408661907416e+00,
|
||||||
|
1.0};
|
||||||
|
|
||||||
|
// Coefficients for the central region
|
||||||
|
static const double a[6] = {-3.969683028665376e+01,
|
||||||
|
2.209460984245205e+02,
|
||||||
|
-2.759285104469687e+02,
|
||||||
|
1.383577518672690e+02,
|
||||||
|
-3.066479806614716e+01,
|
||||||
|
2.506628277459239e+00};
|
||||||
|
|
||||||
|
static const double b[6] = {-5.447609879822406e+01,
|
||||||
|
1.615858368580409e+02,
|
||||||
|
-1.556989798598866e+02,
|
||||||
|
6.680131188771972e+01,
|
||||||
|
-1.328068155288572e+01,
|
||||||
|
1.0};
|
||||||
|
|
||||||
|
double x = 0.0;
|
||||||
|
|
||||||
|
if (p < p_low) {
|
||||||
|
// Lower region
|
||||||
|
double q = std::sqrt(-2.0 * std::log(p));
|
||||||
|
|
||||||
|
// Numerator: c[0]*q^5 + c[1]*q^4 + ... + c[5]
|
||||||
|
double numerator = c[0];
|
||||||
|
for (int i = 1; i < 6; ++i) {
|
||||||
|
numerator = numerator * q + c[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Denominator: d[0]*q^4 + d[1]*q^3 + ... + d[3]*q + 1
|
||||||
|
double denominator = d[0];
|
||||||
|
for (int i = 1; i < 5; ++i) {
|
||||||
|
denominator = denominator * q + d[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
x = numerator / denominator;
|
||||||
|
} else if (p > p_high) {
|
||||||
|
// Upper region
|
||||||
|
double q = std::sqrt(-2.0 * std::log(1.0 - p));
|
||||||
|
|
||||||
|
double numerator = c[0];
|
||||||
|
for (int i = 1; i < 6; ++i) {
|
||||||
|
numerator = numerator * q + c[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
double denominator = d[0];
|
||||||
|
for (int i = 1; i < 5; ++i) {
|
||||||
|
denominator = denominator * q + d[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
x = -(numerator / denominator);
|
||||||
|
} else {
|
||||||
|
// Central region
|
||||||
|
double q = p - 0.5;
|
||||||
|
double r = q * q;
|
||||||
|
|
||||||
|
// Numerator: (a[0]*r^5 + a[1]*r^4 + ... + a[5])*q
|
||||||
|
double numerator = a[0];
|
||||||
|
for (int i = 1; i < 6; ++i) {
|
||||||
|
numerator = numerator * r + a[i];
|
||||||
|
}
|
||||||
|
numerator *= q;
|
||||||
|
|
||||||
|
// Denominator: b[0]*r^4 + b[1]*r^3 + ... + b[4]*r + 1
|
||||||
|
double denominator = b[0];
|
||||||
|
for (int i = 1; i < 6; ++i) {
|
||||||
|
denominator = denominator * r + b[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
x = numerator / denominator;
|
||||||
|
}
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> get_sigmas(uint32_t n, float /*sigma_min*/, float /*sigma_max*/, t_to_sigma_t /*t_to_sigma*/) override {
|
||||||
|
std::vector<float> sigmas;
|
||||||
|
LOG_INFO("LOGIT_NORMAL_SCHEDULER using mean=%.4f, std=%.4f, logsnr_min=%.4f, logsnr_max=%.4f", mean, std, logsnr_min, logsnr_max);
|
||||||
|
sigmas.reserve(n + 1);
|
||||||
|
for (uint32_t i = 0; i <= n; ++i) {
|
||||||
|
float t = static_cast<float>(i) / static_cast<float>(n);
|
||||||
|
|
||||||
|
// ndtri(1-t) == -ndtri(t)
|
||||||
|
float z = -ndtri(t);
|
||||||
|
|
||||||
|
float y = mean + std * z;
|
||||||
|
|
||||||
|
float timestep = sigmoid(y);
|
||||||
|
|
||||||
|
if (timestep > one_minus_t_min)
|
||||||
|
timestep = one_minus_t_min;
|
||||||
|
if (timestep < one_minus_t_max)
|
||||||
|
timestep = one_minus_t_max;
|
||||||
|
|
||||||
|
float sigma = timestep;
|
||||||
|
|
||||||
|
sigmas.push_back(sigma);
|
||||||
|
}
|
||||||
|
sigmas[n] = 0.0f;
|
||||||
|
return sigmas;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct Denoiser {
|
struct Denoiser {
|
||||||
virtual float sigma_min() = 0;
|
virtual float sigma_min() = 0;
|
||||||
virtual float sigma_max() = 0;
|
virtual float sigma_max() = 0;
|
||||||
@ -623,6 +820,11 @@ struct Denoiser {
|
|||||||
LOG_INFO("get_sigmas with LTX2 scheduler");
|
LOG_INFO("get_sigmas with LTX2 scheduler");
|
||||||
scheduler = std::make_shared<LTX2Scheduler>(image_seq_len, extra_sample_args);
|
scheduler = std::make_shared<LTX2Scheduler>(image_seq_len, extra_sample_args);
|
||||||
break;
|
break;
|
||||||
|
case LOGIT_NORMAL_SCHEDULER: {
|
||||||
|
LOG_INFO("get_sigmas with Logit-Normal scheduler");
|
||||||
|
scheduler = std::make_shared<LogitNormalScheduler>(image_seq_len, extra_sample_args);
|
||||||
|
break;
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
LOG_INFO("get_sigmas with discrete scheduler (default)");
|
LOG_INFO("get_sigmas with discrete scheduler (default)");
|
||||||
scheduler = std::make_shared<DiscreteScheduler>();
|
scheduler = std::make_shared<DiscreteScheduler>();
|
||||||
|
|||||||
@ -2535,6 +2535,7 @@ const char* scheduler_to_str[] = {
|
|||||||
"lcm",
|
"lcm",
|
||||||
"bong_tangent",
|
"bong_tangent",
|
||||||
"ltx2",
|
"ltx2",
|
||||||
|
"logit_normal",
|
||||||
};
|
};
|
||||||
|
|
||||||
const char* sd_scheduler_name(enum scheduler_t scheduler) {
|
const char* sd_scheduler_name(enum scheduler_t scheduler) {
|
||||||
@ -3137,6 +3138,8 @@ enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_me
|
|||||||
return SIMPLE_SCHEDULER;
|
return SIMPLE_SCHEDULER;
|
||||||
} else if (sd_ctx != nullptr && sd_ctx->sd != nullptr && sd_version_is_ltxav(sd_ctx->sd->version)) {
|
} else if (sd_ctx != nullptr && sd_ctx->sd != nullptr && sd_version_is_ltxav(sd_ctx->sd->version)) {
|
||||||
return LTX2_SCHEDULER;
|
return LTX2_SCHEDULER;
|
||||||
|
} else if(sd_ctx != nullptr && sd_ctx->sd != nullptr && sd_version_is_ideogram4(sd_ctx->sd->version)) {
|
||||||
|
return LOGIT_NORMAL_SCHEDULER;
|
||||||
}
|
}
|
||||||
return DISCRETE_SCHEDULER;
|
return DISCRETE_SCHEDULER;
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user