Compare commits

...

6 Commits

Author SHA1 Message Date
leejet
20345888a3
refactor: optimize the handling of sample method (#999) 2025-11-22 14:00:25 +08:00
akleine
490c51d963
feat: report success/failure when saving PNG/JPG output (#912) 2025-11-22 13:57:44 +08:00
Wagner Bruna
45c46779af
feat: add LCM scheduler (#983) 2025-11-22 13:53:31 +08:00
leejet
869d023416
refactor: optimize the handling of scheduler (#998) 2025-11-22 12:48:53 +08:00
akleine
e9bc3b6c06
fix: check the PhotoMaker id_embeds tensor ONLY in PhotoMaker V2 mode (#987) 2025-11-22 12:47:40 +08:00
Wagner Bruna
b542894fb9
fix: avoid crash on default video preview path (#997)
Co-authored-by: masamaru-san
2025-11-22 12:46:27 +08:00
5 changed files with 203 additions and 186 deletions

View File

@ -11,14 +11,13 @@
#define TIMESTEPS 1000 #define TIMESTEPS 1000
#define FLUX_TIMESTEPS 1000 #define FLUX_TIMESTEPS 1000
struct SigmaSchedule { struct SigmaScheduler {
int version = 0;
typedef std::function<float(float)> t_to_sigma_t; typedef std::function<float(float)> t_to_sigma_t;
virtual std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) = 0; virtual std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) = 0;
}; };
struct DiscreteSchedule : SigmaSchedule { struct DiscreteScheduler : SigmaScheduler {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override { std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
std::vector<float> result; std::vector<float> result;
@ -42,7 +41,7 @@ struct DiscreteSchedule : SigmaSchedule {
} }
}; };
struct ExponentialSchedule : SigmaSchedule { struct ExponentialScheduler : SigmaScheduler {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override { std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
std::vector<float> sigmas; std::vector<float> sigmas;
@ -149,7 +148,10 @@ std::vector<float> log_linear_interpolation(std::vector<float> sigma_in,
/* /*
https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
*/ */
struct AYSSchedule : SigmaSchedule { struct AYSScheduler : SigmaScheduler {
SDVersion version;
explicit AYSScheduler(SDVersion version)
: version(version) {}
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override { std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
const std::vector<float> noise_levels[] = { const std::vector<float> noise_levels[] = {
/* SD1.5 */ /* SD1.5 */
@ -169,19 +171,19 @@ struct AYSSchedule : SigmaSchedule {
std::vector<float> results(n + 1); std::vector<float> results(n + 1);
if (sd_version_is_sd2((SDVersion)version)) { if (sd_version_is_sd2((SDVersion)version)) {
LOG_WARN("AYS not designed for SD2.X models"); LOG_WARN("AYS_SCHEDULER not designed for SD2.X models");
} /* fallthrough */ } /* fallthrough */
else if (sd_version_is_sd1((SDVersion)version)) { else if (sd_version_is_sd1((SDVersion)version)) {
LOG_INFO("AYS using SD1.5 noise levels"); LOG_INFO("AYS_SCHEDULER using SD1.5 noise levels");
inputs = noise_levels[0]; inputs = noise_levels[0];
} else if (sd_version_is_sdxl((SDVersion)version)) { } else if (sd_version_is_sdxl((SDVersion)version)) {
LOG_INFO("AYS using SDXL noise levels"); LOG_INFO("AYS_SCHEDULER using SDXL noise levels");
inputs = noise_levels[1]; inputs = noise_levels[1];
} else if (version == VERSION_SVD) { } else if (version == VERSION_SVD) {
LOG_INFO("AYS using SVD noise levels"); LOG_INFO("AYS_SCHEDULER using SVD noise levels");
inputs = noise_levels[2]; inputs = noise_levels[2];
} else { } else {
LOG_ERROR("Version not compatible with AYS scheduler"); LOG_ERROR("Version not compatible with AYS_SCHEDULER scheduler");
return results; return results;
} }
@ -203,7 +205,7 @@ struct AYSSchedule : SigmaSchedule {
/* /*
* GITS Scheduler: https://github.com/zju-pi/diff-sampler/tree/main/gits-main * GITS Scheduler: https://github.com/zju-pi/diff-sampler/tree/main/gits-main
*/ */
struct GITSSchedule : SigmaSchedule { struct GITSScheduler : SigmaScheduler {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override { std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
if (sigma_max <= 0.0f) { if (sigma_max <= 0.0f) {
return std::vector<float>{}; return std::vector<float>{};
@ -232,7 +234,7 @@ struct GITSSchedule : SigmaSchedule {
} }
}; };
struct SGMUniformSchedule : SigmaSchedule { struct SGMUniformScheduler : SigmaScheduler {
std::vector<float> get_sigmas(uint32_t n, float sigma_min_in, float sigma_max_in, t_to_sigma_t t_to_sigma_func) override { std::vector<float> get_sigmas(uint32_t n, float sigma_min_in, float sigma_max_in, t_to_sigma_t t_to_sigma_func) override {
std::vector<float> result; std::vector<float> result;
if (n == 0) { if (n == 0) {
@ -251,7 +253,24 @@ struct SGMUniformSchedule : SigmaSchedule {
} }
}; };
struct KarrasSchedule : SigmaSchedule { struct LCMScheduler : SigmaScheduler {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
std::vector<float> result;
result.reserve(n + 1);
const int original_steps = 50;
const int k = TIMESTEPS / original_steps;
for (int i = 0; i < n; i++) {
// the rounding ensures we match the training schedule of the LCM model
int index = (i * original_steps) / n;
int timestep = (original_steps - index) * k - 1;
result.push_back(t_to_sigma(timestep));
}
result.push_back(0.0f);
return result;
}
};
struct KarrasScheduler : SigmaScheduler {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override { std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
// These *COULD* be function arguments here, // These *COULD* be function arguments here,
// but does anybody ever bother to touch them? // but does anybody ever bother to touch them?
@ -270,7 +289,7 @@ struct KarrasSchedule : SigmaSchedule {
} }
}; };
struct SimpleSchedule : SigmaSchedule { struct SimpleScheduler : SigmaScheduler {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override { std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
std::vector<float> result_sigmas; std::vector<float> result_sigmas;
@ -299,8 +318,8 @@ struct SimpleSchedule : SigmaSchedule {
} }
}; };
// Close to Beta Schedule, but increadably simple in code. // Close to Beta Scheduler, but increadably simple in code.
struct SmoothStepSchedule : SigmaSchedule { struct SmoothStepScheduler : SigmaScheduler {
static constexpr float smoothstep(float x) { static constexpr float smoothstep(float x) {
return x * x * (3.0f - 2.0f * x); return x * x * (3.0f - 2.0f * x);
} }
@ -329,7 +348,6 @@ struct SmoothStepSchedule : SigmaSchedule {
}; };
struct Denoiser { struct Denoiser {
std::shared_ptr<SigmaSchedule> scheduler = std::make_shared<DiscreteSchedule>();
virtual float sigma_min() = 0; virtual float sigma_min() = 0;
virtual float sigma_max() = 0; virtual float sigma_max() = 0;
virtual float sigma_to_t(float sigma) = 0; virtual float sigma_to_t(float sigma) = 0;
@ -338,8 +356,51 @@ struct Denoiser {
virtual ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) = 0; virtual ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) = 0;
virtual ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) = 0; virtual ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) = 0;
virtual std::vector<float> get_sigmas(uint32_t n) { virtual std::vector<float> get_sigmas(uint32_t n, scheduler_t scheduler_type, SDVersion version) {
auto bound_t_to_sigma = std::bind(&Denoiser::t_to_sigma, this, std::placeholders::_1); auto bound_t_to_sigma = std::bind(&Denoiser::t_to_sigma, this, std::placeholders::_1);
std::shared_ptr<SigmaScheduler> scheduler;
switch (scheduler_type) {
case DISCRETE_SCHEDULER:
LOG_INFO("get_sigmas with discrete scheduler");
scheduler = std::make_shared<DiscreteScheduler>();
break;
case KARRAS_SCHEDULER:
LOG_INFO("get_sigmas with Karras scheduler");
scheduler = std::make_shared<KarrasScheduler>();
break;
case EXPONENTIAL_SCHEDULER:
LOG_INFO("get_sigmas exponential scheduler");
scheduler = std::make_shared<ExponentialScheduler>();
break;
case AYS_SCHEDULER:
LOG_INFO("get_sigmas with Align-Your-Steps scheduler");
scheduler = std::make_shared<AYSScheduler>(version);
break;
case GITS_SCHEDULER:
LOG_INFO("get_sigmas with GITS scheduler");
scheduler = std::make_shared<GITSScheduler>();
break;
case SGM_UNIFORM_SCHEDULER:
LOG_INFO("get_sigmas with SGM Uniform scheduler");
scheduler = std::make_shared<SGMUniformScheduler>();
break;
case SIMPLE_SCHEDULER:
LOG_INFO("get_sigmas with Simple scheduler");
scheduler = std::make_shared<SimpleScheduler>();
break;
case SMOOTHSTEP_SCHEDULER:
LOG_INFO("get_sigmas with SmoothStep scheduler");
scheduler = std::make_shared<SmoothStepScheduler>();
break;
case LCM_SCHEDULER:
LOG_INFO("get_sigmas with LCM scheduler");
scheduler = std::make_shared<LCMScheduler>();
break;
default:
LOG_INFO("get_sigmas with discrete scheduler (default)");
scheduler = std::make_shared<DiscreteScheduler>();
break;
}
return scheduler->get_sigmas(n, sigma_min(), sigma_max(), bound_t_to_sigma); return scheduler->get_sigmas(n, sigma_min(), sigma_max(), bound_t_to_sigma);
} }
}; };
@ -426,7 +487,6 @@ struct EDMVDenoiser : public CompVisVDenoiser {
EDMVDenoiser(float min_sigma = 0.002, float max_sigma = 120.0) EDMVDenoiser(float min_sigma = 0.002, float max_sigma = 120.0)
: min_sigma(min_sigma), max_sigma(max_sigma) { : min_sigma(min_sigma), max_sigma(max_sigma) {
scheduler = std::make_shared<ExponentialSchedule>();
} }
float t_to_sigma(float t) override { float t_to_sigma(float t) override {
@ -580,7 +640,7 @@ static void sample_k_diffusion(sample_method_t method,
size_t steps = sigmas.size() - 1; size_t steps = sigmas.size() - 1;
// sample_euler_ancestral // sample_euler_ancestral
switch (method) { switch (method) {
case EULER_A: { case EULER_A_SAMPLE_METHOD: {
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
@ -633,7 +693,7 @@ static void sample_k_diffusion(sample_method_t method,
} }
} }
} break; } break;
case EULER: // Implemented without any sigma churn case EULER_SAMPLE_METHOD: // Implemented without any sigma churn
{ {
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
@ -666,7 +726,7 @@ static void sample_k_diffusion(sample_method_t method,
} }
} }
} break; } break;
case HEUN: { case HEUN_SAMPLE_METHOD: {
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x);
@ -716,7 +776,7 @@ static void sample_k_diffusion(sample_method_t method,
} }
} }
} break; } break;
case DPM2: { case DPM2_SAMPLE_METHOD: {
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x);
@ -768,7 +828,7 @@ static void sample_k_diffusion(sample_method_t method,
} }
} break; } break;
case DPMPP2S_A: { case DPMPP2S_A_SAMPLE_METHOD: {
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x);
@ -832,7 +892,7 @@ static void sample_k_diffusion(sample_method_t method,
} }
} }
} break; } break;
case DPMPP2M: // DPM++ (2M) from Karras et al (2022) case DPMPP2M_SAMPLE_METHOD: // DPM++ (2M) from Karras et al (2022)
{ {
struct ggml_tensor* old_denoised = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* old_denoised = ggml_dup_tensor(work_ctx, x);
@ -871,7 +931,7 @@ static void sample_k_diffusion(sample_method_t method,
} }
} }
} break; } break;
case DPMPP2Mv2: // Modified DPM++ (2M) from https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/8457 case DPMPP2Mv2_SAMPLE_METHOD: // Modified DPM++ (2M) from https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/8457
{ {
struct ggml_tensor* old_denoised = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* old_denoised = ggml_dup_tensor(work_ctx, x);
@ -914,7 +974,7 @@ static void sample_k_diffusion(sample_method_t method,
} }
} }
} break; } break;
case IPNDM: // iPNDM sampler from https://github.com/zju-pi/diff-sampler/tree/main/diff-solvers-main case IPNDM_SAMPLE_METHOD: // iPNDM sampler from https://github.com/zju-pi/diff-sampler/tree/main/diff-solvers-main
{ {
int max_order = 4; int max_order = 4;
ggml_tensor* x_next = x; ggml_tensor* x_next = x;
@ -989,7 +1049,7 @@ static void sample_k_diffusion(sample_method_t method,
} }
} }
} break; } break;
case IPNDM_V: // iPNDM_v sampler from https://github.com/zju-pi/diff-sampler/tree/main/diff-solvers-main case IPNDM_V_SAMPLE_METHOD: // iPNDM_v sampler from https://github.com/zju-pi/diff-sampler/tree/main/diff-solvers-main
{ {
int max_order = 4; int max_order = 4;
std::vector<ggml_tensor*> buffer_model; std::vector<ggml_tensor*> buffer_model;
@ -1063,7 +1123,7 @@ static void sample_k_diffusion(sample_method_t method,
d_cur = ggml_dup_tensor(work_ctx, x_next); d_cur = ggml_dup_tensor(work_ctx, x_next);
} }
} break; } break;
case LCM: // Latent Consistency Models case LCM_SAMPLE_METHOD: // Latent Consistency Models
{ {
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
@ -1098,8 +1158,8 @@ static void sample_k_diffusion(sample_method_t method,
} }
} }
} break; } break;
case DDIM_TRAILING: // Denoising Diffusion Implicit Models case DDIM_TRAILING_SAMPLE_METHOD: // Denoising Diffusion Implicit Models
// with the "trailing" timestep spacing // with the "trailing" timestep spacing
{ {
// See J. Song et al., "Denoising Diffusion Implicit // See J. Song et al., "Denoising Diffusion Implicit
// Models", arXiv:2010.02502 [cs.LG] // Models", arXiv:2010.02502 [cs.LG]
@ -1109,7 +1169,7 @@ static void sample_k_diffusion(sample_method_t method,
// end beta) (which unfortunately k-diffusion's data // end beta) (which unfortunately k-diffusion's data
// structure hides from the denoiser), and the sigmas are // structure hides from the denoiser), and the sigmas are
// also needed to invert the behavior of CompVisDenoiser // also needed to invert the behavior of CompVisDenoiser
// (k-diffusion's LMSDiscreteScheduler) // (k-diffusion's LMSDiscreteSchedulerr)
float beta_start = 0.00085f; float beta_start = 0.00085f;
float beta_end = 0.0120f; float beta_end = 0.0120f;
std::vector<double> alphas_cumprod; std::vector<double> alphas_cumprod;
@ -1137,7 +1197,7 @@ static void sample_k_diffusion(sample_method_t method,
for (int i = 0; i < steps; i++) { for (int i = 0; i < steps; i++) {
// The "trailing" DDIM timestep, see S. Lin et al., // The "trailing" DDIM timestep, see S. Lin et al.,
// "Common Diffusion Noise Schedules and Sample Steps // "Common Diffusion Noise Schedulers and Sample Steps
// are Flawed", arXiv:2305.08891 [cs], p. 4, Table // are Flawed", arXiv:2305.08891 [cs], p. 4, Table
// 2. Most variables below follow Diffusers naming // 2. Most variables below follow Diffusers naming
// //
@ -1292,8 +1352,8 @@ static void sample_k_diffusion(sample_method_t method,
// factor c_in. // factor c_in.
} }
} break; } break;
case TCD: // Strategic Stochastic Sampling (Algorithm 4) in case TCD_SAMPLE_METHOD: // Strategic Stochastic Sampling (Algorithm 4) in
// Trajectory Consistency Distillation // Trajectory Consistency Distillation
{ {
// See J. Zheng et al., "Trajectory Consistency // See J. Zheng et al., "Trajectory Consistency
// Distillation: Improved Latent Consistency Distillation // Distillation: Improved Latent Consistency Distillation

View File

@ -107,8 +107,8 @@ Options:
compatibility issues with quantized parameters, but it usually offers faster inference compatibility issues with quantized parameters, but it usually offers faster inference
speed and, in some cases, lower memory usage. The at_runtime mode, on the other speed and, in some cases, lower memory usage. The at_runtime mode, on the other
hand, is exactly the opposite. hand, is exactly the opposite.
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple], default: --scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm],
discrete default: discrete
--skip-layers layers to skip for SLG steps (default: [7,8,9]) --skip-layers layers to skip for SLG steps (default: [7,8,9])
--high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, --high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm,
ddim_trailing, tcd] default: euler for Flux/SD3/Wan, euler_a otherwise ddim_trailing, tcd] default: euler for Flux/SD3/Wan, euler_a otherwise

View File

@ -912,13 +912,13 @@ void parse_args(int argc, const char** argv, SDParams& params) {
return 1; return 1;
}; };
auto on_schedule_arg = [&](int argc, const char** argv, int index) { auto on_scheduler_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) { if (++index >= argc) {
return -1; return -1;
} }
const char* arg = argv[index]; const char* arg = argv[index];
params.sample_params.scheduler = str_to_schedule(arg); params.sample_params.scheduler = str_to_scheduler(arg);
if (params.sample_params.scheduler == SCHEDULE_COUNT) { if (params.sample_params.scheduler == SCHEDULER_COUNT) {
fprintf(stderr, "error: invalid scheduler %s\n", fprintf(stderr, "error: invalid scheduler %s\n",
arg); arg);
return -1; return -1;
@ -926,20 +926,6 @@ void parse_args(int argc, const char** argv, SDParams& params) {
return 1; return 1;
}; };
auto on_high_noise_schedule_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) {
return -1;
}
const char* arg = argv[index];
params.high_noise_sample_params.scheduler = str_to_schedule(arg);
if (params.high_noise_sample_params.scheduler == SCHEDULE_COUNT) {
fprintf(stderr, "error: invalid high noise scheduler %s\n",
arg);
return -1;
}
return 1;
};
auto on_prediction_arg = [&](int argc, const char** argv, int index) { auto on_prediction_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) { if (++index >= argc) {
return -1; return -1;
@ -1211,8 +1197,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
on_lora_apply_mode_arg}, on_lora_apply_mode_arg},
{"", {"",
"--scheduler", "--scheduler",
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple], default: discrete", "denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm], default: discrete",
on_schedule_arg}, on_scheduler_arg},
{"", {"",
"--skip-layers", "--skip-layers",
"layers to skip for SLG steps (default: [7,8,9])", "layers to skip for SLG steps (default: [7,8,9])",
@ -1222,10 +1208,6 @@ void parse_args(int argc, const char** argv, SDParams& params) {
"(high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd]" "(high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd]"
" default: euler for Flux/SD3/Wan, euler_a otherwise", " default: euler for Flux/SD3/Wan, euler_a otherwise",
on_high_noise_sample_method_arg}, on_high_noise_sample_method_arg},
{"",
"--high-noise-scheduler",
"(high noise) denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple], default: discrete",
on_high_noise_schedule_arg},
{"", {"",
"--high-noise-skip-layers", "--high-noise-skip-layers",
"(high noise) layers to skip for SLG steps (default: [7,8,9])", "(high noise) layers to skip for SLG steps (default: [7,8,9])",
@ -1442,8 +1424,8 @@ std::string get_image_params(SDParams params, int64_t seed) {
parameter_string += "Sampler RNG: " + std::string(sd_rng_type_name(params.sampler_rng_type)) + ", "; parameter_string += "Sampler RNG: " + std::string(sd_rng_type_name(params.sampler_rng_type)) + ", ";
} }
parameter_string += "Sampler: " + std::string(sd_sample_method_name(params.sample_params.sample_method)); parameter_string += "Sampler: " + std::string(sd_sample_method_name(params.sample_params.sample_method));
if (params.sample_params.scheduler != DEFAULT) { if (params.sample_params.scheduler != SCHEDULER_COUNT) {
parameter_string += " " + std::string(sd_schedule_name(params.sample_params.scheduler)); parameter_string += " " + std::string(sd_scheduler_name(params.sample_params.scheduler));
} }
parameter_string += ", "; parameter_string += ", ";
for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path, params.qwen2vl_path, params.qwen2vl_vision_path}) { for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path, params.qwen2vl_path, params.qwen2vl_vision_path}) {
@ -1648,7 +1630,7 @@ bool load_images_from_dir(const std::string dir,
return true; return true;
} }
const char* preview_path; std::string preview_path;
float preview_fps; float preview_fps;
void step_callback(int step, int frame_count, sd_image_t* image, bool is_noisy) { void step_callback(int step, int frame_count, sd_image_t* image, bool is_noisy) {
@ -1657,16 +1639,16 @@ void step_callback(int step, int frame_count, sd_image_t* image, bool is_noisy)
// is_noisy is set to true if the preview corresponds to noisy latents, false if it's denoised latents // is_noisy is set to true if the preview corresponds to noisy latents, false if it's denoised latents
// unused in this app, it will either be always noisy or always denoised here // unused in this app, it will either be always noisy or always denoised here
if (frame_count == 1) { if (frame_count == 1) {
stbi_write_png(preview_path, image->width, image->height, image->channel, image->data, 0); stbi_write_png(preview_path.c_str(), image->width, image->height, image->channel, image->data, 0);
} else { } else {
create_mjpg_avi_from_sd_images(preview_path, image, frame_count, preview_fps); create_mjpg_avi_from_sd_images(preview_path.c_str(), image, frame_count, preview_fps);
} }
} }
int main(int argc, const char* argv[]) { int main(int argc, const char* argv[]) {
SDParams params; SDParams params;
parse_args(argc, argv, params); parse_args(argc, argv, params);
preview_path = params.preview_path.c_str(); preview_path = params.preview_path;
if (params.video_frames > 4) { if (params.video_frames > 4) {
size_t last_dot_pos = params.preview_path.find_last_of("."); size_t last_dot_pos = params.preview_path.find_last_of(".");
std::string base_path = params.preview_path; std::string base_path = params.preview_path;
@ -1677,8 +1659,7 @@ int main(int argc, const char* argv[]) {
std::transform(file_ext.begin(), file_ext.end(), file_ext.begin(), ::tolower); std::transform(file_ext.begin(), file_ext.end(), file_ext.begin(), ::tolower);
} }
if (file_ext == ".png") { if (file_ext == ".png") {
base_path = base_path + ".avi"; preview_path = base_path + ".avi";
preview_path = base_path.c_str();
} }
} }
preview_fps = params.fps; preview_fps = params.fps;
@ -1921,10 +1902,18 @@ int main(int argc, const char* argv[]) {
return 1; return 1;
} }
if (params.sample_params.sample_method == SAMPLE_METHOD_DEFAULT) { if (params.sample_params.sample_method == SAMPLE_METHOD_COUNT) {
params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx); params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
} }
if (params.high_noise_sample_params.sample_method == SAMPLE_METHOD_COUNT) {
params.high_noise_sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
}
if (params.sample_params.scheduler == SCHEDULER_COUNT) {
params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx);
}
if (params.mode == IMG_GEN) { if (params.mode == IMG_GEN) {
sd_img_gen_params_t img_gen_params = { sd_img_gen_params_t img_gen_params = {
params.prompt.c_str(), params.prompt.c_str(),
@ -2067,15 +2056,16 @@ int main(int argc, const char* argv[]) {
if (results[i].data == nullptr) { if (results[i].data == nullptr) {
continue; continue;
} }
int write_ok;
std::string final_image_path = i > 0 ? base_path + "_" + std::to_string(i + 1) + file_ext : base_path + file_ext; std::string final_image_path = i > 0 ? base_path + "_" + std::to_string(i + 1) + file_ext : base_path + file_ext;
if (is_jpg) { if (is_jpg) {
stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel, write_ok = stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
results[i].data, 90, get_image_params(params, params.seed + i).c_str()); results[i].data, 90, get_image_params(params, params.seed + i).c_str());
printf("save result JPEG image to '%s'\n", final_image_path.c_str()); printf("save result JPEG image to '%s' (%s)\n", final_image_path.c_str(), write_ok == 0 ? "failure" : "success");
} else { } else {
stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel, write_ok = stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
results[i].data, 0, get_image_params(params, params.seed + i).c_str()); results[i].data, 0, get_image_params(params, params.seed + i).c_str());
printf("save result PNG image to '%s'\n", final_image_path.c_str()); printf("save result PNG image to '%s' (%s)\n", final_image_path.c_str(), write_ok == 0 ? "failure" : "success");
} }
} }
} }

View File

@ -47,8 +47,8 @@ const char* model_version_to_str[] = {
}; };
const char* sampling_methods_str[] = { const char* sampling_methods_str[] = {
"default",
"Euler", "Euler",
"Euler A",
"Heun", "Heun",
"DPM2", "DPM2",
"DPM++ (2s)", "DPM++ (2s)",
@ -59,7 +59,6 @@ const char* sampling_methods_str[] = {
"LCM", "LCM",
"DDIM \"trailing\"", "DDIM \"trailing\"",
"TCD", "TCD",
"Euler A",
}; };
/*================================================== Helper Functions ================================================*/ /*================================================== Helper Functions ================================================*/
@ -870,53 +869,6 @@ public:
return true; return true;
} }
void init_scheduler(scheduler_t scheduler) {
switch (scheduler) {
case DISCRETE:
LOG_INFO("running with discrete scheduler");
denoiser->scheduler = std::make_shared<DiscreteSchedule>();
break;
case KARRAS:
LOG_INFO("running with Karras scheduler");
denoiser->scheduler = std::make_shared<KarrasSchedule>();
break;
case EXPONENTIAL:
LOG_INFO("running exponential scheduler");
denoiser->scheduler = std::make_shared<ExponentialSchedule>();
break;
case AYS:
LOG_INFO("Running with Align-Your-Steps scheduler");
denoiser->scheduler = std::make_shared<AYSSchedule>();
denoiser->scheduler->version = version;
break;
case GITS:
LOG_INFO("Running with GITS scheduler");
denoiser->scheduler = std::make_shared<GITSSchedule>();
denoiser->scheduler->version = version;
break;
case SGM_UNIFORM:
LOG_INFO("Running with SGM Uniform schedule");
denoiser->scheduler = std::make_shared<SGMUniformSchedule>();
denoiser->scheduler->version = version;
break;
case SIMPLE:
LOG_INFO("Running with Simple schedule");
denoiser->scheduler = std::make_shared<SimpleSchedule>();
denoiser->scheduler->version = version;
break;
case SMOOTHSTEP:
LOG_INFO("Running with SmoothStep scheduler");
denoiser->scheduler = std::make_shared<SmoothStepSchedule>();
break;
case DEFAULT:
// Don't touch anything.
break;
default:
LOG_ERROR("Unknown scheduler %i", scheduler);
abort();
}
}
bool is_using_v_parameterization_for_sd2(ggml_context* work_ctx, bool is_inpaint = false) { bool is_using_v_parameterization_for_sd2(ggml_context* work_ctx, bool is_inpaint = false) {
struct ggml_tensor* x_t = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 4, 1); struct ggml_tensor* x_t = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 4, 1);
ggml_set_f32(x_t, 0.5); ggml_set_f32(x_t, 0.5);
@ -2275,8 +2227,8 @@ enum rng_type_t str_to_rng_type(const char* str) {
} }
const char* sample_method_to_str[] = { const char* sample_method_to_str[] = {
"default",
"euler", "euler",
"euler_a",
"heun", "heun",
"dpm2", "dpm2",
"dpm++2s_a", "dpm++2s_a",
@ -2287,7 +2239,6 @@ const char* sample_method_to_str[] = {
"lcm", "lcm",
"ddim_trailing", "ddim_trailing",
"tcd", "tcd",
"euler_a",
}; };
const char* sd_sample_method_name(enum sample_method_t sample_method) { const char* sd_sample_method_name(enum sample_method_t sample_method) {
@ -2306,8 +2257,7 @@ enum sample_method_t str_to_sample_method(const char* str) {
return SAMPLE_METHOD_COUNT; return SAMPLE_METHOD_COUNT;
} }
const char* schedule_to_str[] = { const char* scheduler_to_str[] = {
"default",
"discrete", "discrete",
"karras", "karras",
"exponential", "exponential",
@ -2316,22 +2266,23 @@ const char* schedule_to_str[] = {
"sgm_uniform", "sgm_uniform",
"simple", "simple",
"smoothstep", "smoothstep",
"lcm",
}; };
const char* sd_schedule_name(enum scheduler_t scheduler) { const char* sd_scheduler_name(enum scheduler_t scheduler) {
if (scheduler < SCHEDULE_COUNT) { if (scheduler < SCHEDULER_COUNT) {
return schedule_to_str[scheduler]; return scheduler_to_str[scheduler];
} }
return NONE_STR; return NONE_STR;
} }
enum scheduler_t str_to_schedule(const char* str) { enum scheduler_t str_to_scheduler(const char* str) {
for (int i = 0; i < SCHEDULE_COUNT; i++) { for (int i = 0; i < SCHEDULER_COUNT; i++) {
if (!strcmp(str, schedule_to_str[i])) { if (!strcmp(str, scheduler_to_str[i])) {
return (enum scheduler_t)i; return (enum scheduler_t)i;
} }
} }
return SCHEDULE_COUNT; return SCHEDULER_COUNT;
} }
const char* prediction_to_str[] = { const char* prediction_to_str[] = {
@ -2515,8 +2466,8 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) {
sample_params->guidance.slg.layer_start = 0.01f; sample_params->guidance.slg.layer_start = 0.01f;
sample_params->guidance.slg.layer_end = 0.2f; sample_params->guidance.slg.layer_end = 0.2f;
sample_params->guidance.slg.scale = 0.f; sample_params->guidance.slg.scale = 0.f;
sample_params->scheduler = DEFAULT; sample_params->scheduler = SCHEDULER_COUNT;
sample_params->sample_method = SAMPLE_METHOD_DEFAULT; sample_params->sample_method = SAMPLE_METHOD_COUNT;
sample_params->sample_steps = 20; sample_params->sample_steps = 20;
} }
@ -2548,7 +2499,7 @@ char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
sample_params->guidance.slg.layer_start, sample_params->guidance.slg.layer_start,
sample_params->guidance.slg.layer_end, sample_params->guidance.slg.layer_end,
sample_params->guidance.slg.scale, sample_params->guidance.slg.scale,
sd_schedule_name(sample_params->scheduler), sd_scheduler_name(sample_params->scheduler),
sd_sample_method_name(sample_params->sample_method), sd_sample_method_name(sample_params->sample_method),
sample_params->sample_steps, sample_params->sample_steps,
sample_params->eta, sample_params->eta,
@ -2674,13 +2625,21 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) {
enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx) { enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx) {
if (sd_ctx != nullptr && sd_ctx->sd != nullptr) { if (sd_ctx != nullptr && sd_ctx->sd != nullptr) {
SDVersion version = sd_ctx->sd->version; if (sd_version_is_dit(sd_ctx->sd->version)) {
if (sd_version_is_dit(version)) return EULER_SAMPLE_METHOD;
return EULER; }
else
return EULER_A;
} }
return SAMPLE_METHOD_COUNT; return EULER_A_SAMPLE_METHOD;
}
enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx) {
if (sd_ctx != nullptr && sd_ctx->sd != nullptr) {
auto edm_v_denoiser = std::dynamic_pointer_cast<EDMVDenoiser>(sd_ctx->sd->denoiser);
if (edm_v_denoiser) {
return EXPONENTIAL_SCHEDULER;
}
}
return DISCRETE_SCHEDULER;
} }
sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
@ -2800,7 +2759,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
LOG_WARN("Turn off PhotoMaker"); LOG_WARN("Turn off PhotoMaker");
sd_ctx->sd->stacked_id = false; sd_ctx->sd->stacked_id = false;
} else { } else {
if (pm_params.id_images_count != id_embeds->ne[1]) { if (pmv2 && pm_params.id_images_count != id_embeds->ne[1]) {
LOG_WARN("PhotoMaker image count (%d) does NOT match ID embeds (%d). You should run face_detect.py again.", pm_params.id_images_count, id_embeds->ne[1]); LOG_WARN("PhotoMaker image count (%d) does NOT match ID embeds (%d). You should run face_detect.py again.", pm_params.id_images_count, id_embeds->ne[1]);
LOG_WARN("Turn off PhotoMaker"); LOG_WARN("Turn off PhotoMaker");
sd_ctx->sd->stacked_id = false; sd_ctx->sd->stacked_id = false;
@ -2866,7 +2825,6 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
int C = sd_ctx->sd->get_latent_channel(); int C = sd_ctx->sd->get_latent_channel();
int W = width / sd_ctx->sd->get_vae_scale_factor(); int W = width / sd_ctx->sd->get_vae_scale_factor();
int H = height / sd_ctx->sd->get_vae_scale_factor(); int H = height / sd_ctx->sd->get_vae_scale_factor();
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
struct ggml_tensor* control_latent = nullptr; struct ggml_tensor* control_latent = nullptr;
if (sd_version_is_control(sd_ctx->sd->version) && image_hint != nullptr) { if (sd_version_is_control(sd_ctx->sd->version) && image_hint != nullptr) {
@ -3095,12 +3053,16 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
sd_ctx->sd->rng->manual_seed(seed); sd_ctx->sd->rng->manual_seed(seed);
sd_ctx->sd->sampler_rng->manual_seed(seed); sd_ctx->sd->sampler_rng->manual_seed(seed);
int sample_steps = sd_img_gen_params->sample_params.sample_steps;
size_t t0 = ggml_time_ms(); size_t t0 = ggml_time_ms();
sd_ctx->sd->init_scheduler(sd_img_gen_params->sample_params.scheduler); enum sample_method_t sample_method = sd_img_gen_params->sample_params.sample_method;
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps); if (sample_method == SAMPLE_METHOD_COUNT) {
sample_method = sd_get_default_sample_method(sd_ctx);
}
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
int sample_steps = sd_img_gen_params->sample_params.sample_steps;
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps, sd_img_gen_params->sample_params.scheduler, sd_ctx->sd->version);
ggml_tensor* init_latent = nullptr; ggml_tensor* init_latent = nullptr;
ggml_tensor* concat_latent = nullptr; ggml_tensor* concat_latent = nullptr;
@ -3288,11 +3250,6 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
} }
enum sample_method_t sample_method = sd_img_gen_params->sample_params.sample_method;
if (sample_method == SAMPLE_METHOD_DEFAULT) {
sample_method = sd_get_default_sample_method(sd_ctx);
}
sd_image_t* result_images = generate_image_internal(sd_ctx, sd_image_t* result_images = generate_image_internal(sd_ctx,
work_ctx, work_ctx,
init_latent, init_latent,
@ -3342,11 +3299,14 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor(); int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor();
sd_ctx->sd->init_scheduler(sd_vid_gen_params->sample_params.scheduler); enum sample_method_t sample_method = sd_vid_gen_params->sample_params.sample_method;
if (sample_method == SAMPLE_METHOD_COUNT) {
sample_method = sd_get_default_sample_method(sd_ctx);
}
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
int high_noise_sample_steps = 0; int high_noise_sample_steps = 0;
if (sd_ctx->sd->high_noise_diffusion_model) { if (sd_ctx->sd->high_noise_diffusion_model) {
sd_ctx->sd->init_scheduler(sd_vid_gen_params->high_noise_sample_params.scheduler);
high_noise_sample_steps = sd_vid_gen_params->high_noise_sample_params.sample_steps; high_noise_sample_steps = sd_vid_gen_params->high_noise_sample_params.sample_steps;
} }
@ -3355,7 +3315,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
if (high_noise_sample_steps > 0) { if (high_noise_sample_steps > 0) {
total_steps += high_noise_sample_steps; total_steps += high_noise_sample_steps;
} }
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps); std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps, sd_vid_gen_params->sample_params.scheduler, sd_ctx->sd->version);
if (high_noise_sample_steps < 0) { if (high_noise_sample_steps < 0) {
// timesteps ∝ sigmas for Flow models (like wan2.2 a14b) // timesteps ∝ sigmas for Flow models (like wan2.2 a14b)
@ -3613,6 +3573,12 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
// High Noise Sample // High Noise Sample
if (high_noise_sample_steps > 0) { if (high_noise_sample_steps > 0) {
LOG_DEBUG("sample(high noise) %dx%dx%d", W, H, T); LOG_DEBUG("sample(high noise) %dx%dx%d", W, H, T);
enum sample_method_t high_noise_sample_method = sd_vid_gen_params->high_noise_sample_params.sample_method;
if (high_noise_sample_method == SAMPLE_METHOD_COUNT) {
high_noise_sample_method = sd_get_default_sample_method(sd_ctx);
}
LOG_INFO("sampling(high noise) using %s method", sampling_methods_str[high_noise_sample_method]);
int64_t sampling_start = ggml_time_ms(); int64_t sampling_start = ggml_time_ms();
std::vector<float> high_noise_sigmas = std::vector<float>(sigmas.begin(), sigmas.begin() + high_noise_sample_steps + 1); std::vector<float> high_noise_sigmas = std::vector<float>(sigmas.begin(), sigmas.begin() + high_noise_sample_steps + 1);
@ -3631,7 +3597,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
sd_vid_gen_params->high_noise_sample_params.guidance, sd_vid_gen_params->high_noise_sample_params.guidance,
sd_vid_gen_params->high_noise_sample_params.eta, sd_vid_gen_params->high_noise_sample_params.eta,
sd_vid_gen_params->high_noise_sample_params.shifted_timestep, sd_vid_gen_params->high_noise_sample_params.shifted_timestep,
sd_vid_gen_params->high_noise_sample_params.sample_method, high_noise_sample_method,
high_noise_sigmas, high_noise_sigmas,
-1, -1,
{}, {},
@ -3668,7 +3634,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
sd_vid_gen_params->sample_params.guidance, sd_vid_gen_params->sample_params.guidance,
sd_vid_gen_params->sample_params.eta, sd_vid_gen_params->sample_params.eta,
sd_vid_gen_params->sample_params.shifted_timestep, sd_vid_gen_params->sample_params.shifted_timestep,
sd_vid_gen_params->sample_params.sample_method, sample_method,
sigmas, sigmas,
-1, -1,
{}, {},

View File

@ -36,33 +36,32 @@ enum rng_type_t {
}; };
enum sample_method_t { enum sample_method_t {
SAMPLE_METHOD_DEFAULT, EULER_SAMPLE_METHOD,
EULER, EULER_A_SAMPLE_METHOD,
HEUN, HEUN_SAMPLE_METHOD,
DPM2, DPM2_SAMPLE_METHOD,
DPMPP2S_A, DPMPP2S_A_SAMPLE_METHOD,
DPMPP2M, DPMPP2M_SAMPLE_METHOD,
DPMPP2Mv2, DPMPP2Mv2_SAMPLE_METHOD,
IPNDM, IPNDM_SAMPLE_METHOD,
IPNDM_V, IPNDM_V_SAMPLE_METHOD,
LCM, LCM_SAMPLE_METHOD,
DDIM_TRAILING, DDIM_TRAILING_SAMPLE_METHOD,
TCD, TCD_SAMPLE_METHOD,
EULER_A,
SAMPLE_METHOD_COUNT SAMPLE_METHOD_COUNT
}; };
enum scheduler_t { enum scheduler_t {
DEFAULT, DISCRETE_SCHEDULER,
DISCRETE, KARRAS_SCHEDULER,
KARRAS, EXPONENTIAL_SCHEDULER,
EXPONENTIAL, AYS_SCHEDULER,
AYS, GITS_SCHEDULER,
GITS, SGM_UNIFORM_SCHEDULER,
SGM_UNIFORM, SIMPLE_SCHEDULER,
SIMPLE, SMOOTHSTEP_SCHEDULER,
SMOOTHSTEP, LCM_SCHEDULER,
SCHEDULE_COUNT SCHEDULER_COUNT
}; };
enum prediction_t { enum prediction_t {
@ -297,8 +296,8 @@ SD_API const char* sd_rng_type_name(enum rng_type_t rng_type);
SD_API enum rng_type_t str_to_rng_type(const char* str); SD_API enum rng_type_t str_to_rng_type(const char* str);
SD_API const char* sd_sample_method_name(enum sample_method_t sample_method); SD_API const char* sd_sample_method_name(enum sample_method_t sample_method);
SD_API enum sample_method_t str_to_sample_method(const char* str); SD_API enum sample_method_t str_to_sample_method(const char* str);
SD_API const char* sd_schedule_name(enum scheduler_t scheduler); SD_API const char* sd_scheduler_name(enum scheduler_t scheduler);
SD_API enum scheduler_t str_to_schedule(const char* str); SD_API enum scheduler_t str_to_scheduler(const char* str);
SD_API const char* sd_prediction_name(enum prediction_t prediction); SD_API const char* sd_prediction_name(enum prediction_t prediction);
SD_API enum prediction_t str_to_prediction(const char* str); SD_API enum prediction_t str_to_prediction(const char* str);
SD_API const char* sd_preview_name(enum preview_t preview); SD_API const char* sd_preview_name(enum preview_t preview);
@ -313,11 +312,13 @@ SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params);
SD_API sd_ctx_t* new_sd_ctx(const sd_ctx_params_t* sd_ctx_params); SD_API sd_ctx_t* new_sd_ctx(const sd_ctx_params_t* sd_ctx_params);
SD_API void free_sd_ctx(sd_ctx_t* sd_ctx); SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);
SD_API enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx);
SD_API void sd_sample_params_init(sd_sample_params_t* sample_params); SD_API void sd_sample_params_init(sd_sample_params_t* sample_params);
SD_API char* sd_sample_params_to_str(const sd_sample_params_t* sample_params); SD_API char* sd_sample_params_to_str(const sd_sample_params_t* sample_params);
SD_API enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx);
SD_API enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx);
SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params); SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params);
SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params); SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params);
SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params); SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params);