Compare commits

..

No commits in common. "20345888a313c11826e5d511c5c00109b8ab61a6" and "5498cc0d67b0f95b4ef6004890b18cdbe3678175" have entirely different histories.

5 changed files with 186 additions and 203 deletions

View File

@ -11,13 +11,14 @@
#define TIMESTEPS 1000 #define TIMESTEPS 1000
#define FLUX_TIMESTEPS 1000 #define FLUX_TIMESTEPS 1000
struct SigmaScheduler { struct SigmaSchedule {
int version = 0;
typedef std::function<float(float)> t_to_sigma_t; typedef std::function<float(float)> t_to_sigma_t;
virtual std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) = 0; virtual std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) = 0;
}; };
struct DiscreteScheduler : SigmaScheduler { struct DiscreteSchedule : SigmaSchedule {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override { std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
std::vector<float> result; std::vector<float> result;
@ -41,7 +42,7 @@ struct DiscreteScheduler : SigmaScheduler {
} }
}; };
struct ExponentialScheduler : SigmaScheduler { struct ExponentialSchedule : SigmaSchedule {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override { std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
std::vector<float> sigmas; std::vector<float> sigmas;
@ -148,10 +149,7 @@ std::vector<float> log_linear_interpolation(std::vector<float> sigma_in,
/* /*
https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
*/ */
struct AYSScheduler : SigmaScheduler { struct AYSSchedule : SigmaSchedule {
SDVersion version;
explicit AYSScheduler(SDVersion version)
: version(version) {}
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override { std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
const std::vector<float> noise_levels[] = { const std::vector<float> noise_levels[] = {
/* SD1.5 */ /* SD1.5 */
@ -171,19 +169,19 @@ struct AYSScheduler : SigmaScheduler {
std::vector<float> results(n + 1); std::vector<float> results(n + 1);
if (sd_version_is_sd2((SDVersion)version)) { if (sd_version_is_sd2((SDVersion)version)) {
LOG_WARN("AYS_SCHEDULER not designed for SD2.X models"); LOG_WARN("AYS not designed for SD2.X models");
} /* fallthrough */ } /* fallthrough */
else if (sd_version_is_sd1((SDVersion)version)) { else if (sd_version_is_sd1((SDVersion)version)) {
LOG_INFO("AYS_SCHEDULER using SD1.5 noise levels"); LOG_INFO("AYS using SD1.5 noise levels");
inputs = noise_levels[0]; inputs = noise_levels[0];
} else if (sd_version_is_sdxl((SDVersion)version)) { } else if (sd_version_is_sdxl((SDVersion)version)) {
LOG_INFO("AYS_SCHEDULER using SDXL noise levels"); LOG_INFO("AYS using SDXL noise levels");
inputs = noise_levels[1]; inputs = noise_levels[1];
} else if (version == VERSION_SVD) { } else if (version == VERSION_SVD) {
LOG_INFO("AYS_SCHEDULER using SVD noise levels"); LOG_INFO("AYS using SVD noise levels");
inputs = noise_levels[2]; inputs = noise_levels[2];
} else { } else {
LOG_ERROR("Version not compatible with AYS_SCHEDULER scheduler"); LOG_ERROR("Version not compatible with AYS scheduler");
return results; return results;
} }
@ -205,7 +203,7 @@ struct AYSScheduler : SigmaScheduler {
/* /*
* GITS Scheduler: https://github.com/zju-pi/diff-sampler/tree/main/gits-main * GITS Scheduler: https://github.com/zju-pi/diff-sampler/tree/main/gits-main
*/ */
struct GITSScheduler : SigmaScheduler { struct GITSSchedule : SigmaSchedule {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override { std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
if (sigma_max <= 0.0f) { if (sigma_max <= 0.0f) {
return std::vector<float>{}; return std::vector<float>{};
@ -234,7 +232,7 @@ struct GITSScheduler : SigmaScheduler {
} }
}; };
struct SGMUniformScheduler : SigmaScheduler { struct SGMUniformSchedule : SigmaSchedule {
std::vector<float> get_sigmas(uint32_t n, float sigma_min_in, float sigma_max_in, t_to_sigma_t t_to_sigma_func) override { std::vector<float> get_sigmas(uint32_t n, float sigma_min_in, float sigma_max_in, t_to_sigma_t t_to_sigma_func) override {
std::vector<float> result; std::vector<float> result;
if (n == 0) { if (n == 0) {
@ -253,24 +251,7 @@ struct SGMUniformScheduler : SigmaScheduler {
} }
}; };
struct LCMScheduler : SigmaScheduler { struct KarrasSchedule : SigmaSchedule {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
std::vector<float> result;
result.reserve(n + 1);
const int original_steps = 50;
const int k = TIMESTEPS / original_steps;
for (int i = 0; i < n; i++) {
// the rounding ensures we match the training schedule of the LCM model
int index = (i * original_steps) / n;
int timestep = (original_steps - index) * k - 1;
result.push_back(t_to_sigma(timestep));
}
result.push_back(0.0f);
return result;
}
};
struct KarrasScheduler : SigmaScheduler {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override { std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
// These *COULD* be function arguments here, // These *COULD* be function arguments here,
// but does anybody ever bother to touch them? // but does anybody ever bother to touch them?
@ -289,7 +270,7 @@ struct KarrasScheduler : SigmaScheduler {
} }
}; };
struct SimpleScheduler : SigmaScheduler { struct SimpleSchedule : SigmaSchedule {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override { std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
std::vector<float> result_sigmas; std::vector<float> result_sigmas;
@ -318,8 +299,8 @@ struct SimpleScheduler : SigmaScheduler {
} }
}; };
// Close to Beta Scheduler, but increadably simple in code. // Close to Beta Schedule, but increadably simple in code.
struct SmoothStepScheduler : SigmaScheduler { struct SmoothStepSchedule : SigmaSchedule {
static constexpr float smoothstep(float x) { static constexpr float smoothstep(float x) {
return x * x * (3.0f - 2.0f * x); return x * x * (3.0f - 2.0f * x);
} }
@ -348,6 +329,7 @@ struct SmoothStepScheduler : SigmaScheduler {
}; };
struct Denoiser { struct Denoiser {
std::shared_ptr<SigmaSchedule> scheduler = std::make_shared<DiscreteSchedule>();
virtual float sigma_min() = 0; virtual float sigma_min() = 0;
virtual float sigma_max() = 0; virtual float sigma_max() = 0;
virtual float sigma_to_t(float sigma) = 0; virtual float sigma_to_t(float sigma) = 0;
@ -356,51 +338,8 @@ struct Denoiser {
virtual ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) = 0; virtual ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) = 0;
virtual ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) = 0; virtual ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) = 0;
virtual std::vector<float> get_sigmas(uint32_t n, scheduler_t scheduler_type, SDVersion version) { virtual std::vector<float> get_sigmas(uint32_t n) {
auto bound_t_to_sigma = std::bind(&Denoiser::t_to_sigma, this, std::placeholders::_1); auto bound_t_to_sigma = std::bind(&Denoiser::t_to_sigma, this, std::placeholders::_1);
std::shared_ptr<SigmaScheduler> scheduler;
switch (scheduler_type) {
case DISCRETE_SCHEDULER:
LOG_INFO("get_sigmas with discrete scheduler");
scheduler = std::make_shared<DiscreteScheduler>();
break;
case KARRAS_SCHEDULER:
LOG_INFO("get_sigmas with Karras scheduler");
scheduler = std::make_shared<KarrasScheduler>();
break;
case EXPONENTIAL_SCHEDULER:
LOG_INFO("get_sigmas exponential scheduler");
scheduler = std::make_shared<ExponentialScheduler>();
break;
case AYS_SCHEDULER:
LOG_INFO("get_sigmas with Align-Your-Steps scheduler");
scheduler = std::make_shared<AYSScheduler>(version);
break;
case GITS_SCHEDULER:
LOG_INFO("get_sigmas with GITS scheduler");
scheduler = std::make_shared<GITSScheduler>();
break;
case SGM_UNIFORM_SCHEDULER:
LOG_INFO("get_sigmas with SGM Uniform scheduler");
scheduler = std::make_shared<SGMUniformScheduler>();
break;
case SIMPLE_SCHEDULER:
LOG_INFO("get_sigmas with Simple scheduler");
scheduler = std::make_shared<SimpleScheduler>();
break;
case SMOOTHSTEP_SCHEDULER:
LOG_INFO("get_sigmas with SmoothStep scheduler");
scheduler = std::make_shared<SmoothStepScheduler>();
break;
case LCM_SCHEDULER:
LOG_INFO("get_sigmas with LCM scheduler");
scheduler = std::make_shared<LCMScheduler>();
break;
default:
LOG_INFO("get_sigmas with discrete scheduler (default)");
scheduler = std::make_shared<DiscreteScheduler>();
break;
}
return scheduler->get_sigmas(n, sigma_min(), sigma_max(), bound_t_to_sigma); return scheduler->get_sigmas(n, sigma_min(), sigma_max(), bound_t_to_sigma);
} }
}; };
@ -487,6 +426,7 @@ struct EDMVDenoiser : public CompVisVDenoiser {
EDMVDenoiser(float min_sigma = 0.002, float max_sigma = 120.0) EDMVDenoiser(float min_sigma = 0.002, float max_sigma = 120.0)
: min_sigma(min_sigma), max_sigma(max_sigma) { : min_sigma(min_sigma), max_sigma(max_sigma) {
scheduler = std::make_shared<ExponentialSchedule>();
} }
float t_to_sigma(float t) override { float t_to_sigma(float t) override {
@ -640,7 +580,7 @@ static void sample_k_diffusion(sample_method_t method,
size_t steps = sigmas.size() - 1; size_t steps = sigmas.size() - 1;
// sample_euler_ancestral // sample_euler_ancestral
switch (method) { switch (method) {
case EULER_A_SAMPLE_METHOD: { case EULER_A: {
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
@ -693,7 +633,7 @@ static void sample_k_diffusion(sample_method_t method,
} }
} }
} break; } break;
case EULER_SAMPLE_METHOD: // Implemented without any sigma churn case EULER: // Implemented without any sigma churn
{ {
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
@ -726,7 +666,7 @@ static void sample_k_diffusion(sample_method_t method,
} }
} }
} break; } break;
case HEUN_SAMPLE_METHOD: { case HEUN: {
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x);
@ -776,7 +716,7 @@ static void sample_k_diffusion(sample_method_t method,
} }
} }
} break; } break;
case DPM2_SAMPLE_METHOD: { case DPM2: {
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x);
@ -828,7 +768,7 @@ static void sample_k_diffusion(sample_method_t method,
} }
} break; } break;
case DPMPP2S_A_SAMPLE_METHOD: { case DPMPP2S_A: {
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x);
@ -892,7 +832,7 @@ static void sample_k_diffusion(sample_method_t method,
} }
} }
} break; } break;
case DPMPP2M_SAMPLE_METHOD: // DPM++ (2M) from Karras et al (2022) case DPMPP2M: // DPM++ (2M) from Karras et al (2022)
{ {
struct ggml_tensor* old_denoised = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* old_denoised = ggml_dup_tensor(work_ctx, x);
@ -931,7 +871,7 @@ static void sample_k_diffusion(sample_method_t method,
} }
} }
} break; } break;
case DPMPP2Mv2_SAMPLE_METHOD: // Modified DPM++ (2M) from https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/8457 case DPMPP2Mv2: // Modified DPM++ (2M) from https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/8457
{ {
struct ggml_tensor* old_denoised = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* old_denoised = ggml_dup_tensor(work_ctx, x);
@ -974,7 +914,7 @@ static void sample_k_diffusion(sample_method_t method,
} }
} }
} break; } break;
case IPNDM_SAMPLE_METHOD: // iPNDM sampler from https://github.com/zju-pi/diff-sampler/tree/main/diff-solvers-main case IPNDM: // iPNDM sampler from https://github.com/zju-pi/diff-sampler/tree/main/diff-solvers-main
{ {
int max_order = 4; int max_order = 4;
ggml_tensor* x_next = x; ggml_tensor* x_next = x;
@ -1049,7 +989,7 @@ static void sample_k_diffusion(sample_method_t method,
} }
} }
} break; } break;
case IPNDM_V_SAMPLE_METHOD: // iPNDM_v sampler from https://github.com/zju-pi/diff-sampler/tree/main/diff-solvers-main case IPNDM_V: // iPNDM_v sampler from https://github.com/zju-pi/diff-sampler/tree/main/diff-solvers-main
{ {
int max_order = 4; int max_order = 4;
std::vector<ggml_tensor*> buffer_model; std::vector<ggml_tensor*> buffer_model;
@ -1123,7 +1063,7 @@ static void sample_k_diffusion(sample_method_t method,
d_cur = ggml_dup_tensor(work_ctx, x_next); d_cur = ggml_dup_tensor(work_ctx, x_next);
} }
} break; } break;
case LCM_SAMPLE_METHOD: // Latent Consistency Models case LCM: // Latent Consistency Models
{ {
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x);
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
@ -1158,8 +1098,8 @@ static void sample_k_diffusion(sample_method_t method,
} }
} }
} break; } break;
case DDIM_TRAILING_SAMPLE_METHOD: // Denoising Diffusion Implicit Models case DDIM_TRAILING: // Denoising Diffusion Implicit Models
// with the "trailing" timestep spacing // with the "trailing" timestep spacing
{ {
// See J. Song et al., "Denoising Diffusion Implicit // See J. Song et al., "Denoising Diffusion Implicit
// Models", arXiv:2010.02502 [cs.LG] // Models", arXiv:2010.02502 [cs.LG]
@ -1169,7 +1109,7 @@ static void sample_k_diffusion(sample_method_t method,
// end beta) (which unfortunately k-diffusion's data // end beta) (which unfortunately k-diffusion's data
// structure hides from the denoiser), and the sigmas are // structure hides from the denoiser), and the sigmas are
// also needed to invert the behavior of CompVisDenoiser // also needed to invert the behavior of CompVisDenoiser
// (k-diffusion's LMSDiscreteSchedulerr) // (k-diffusion's LMSDiscreteScheduler)
float beta_start = 0.00085f; float beta_start = 0.00085f;
float beta_end = 0.0120f; float beta_end = 0.0120f;
std::vector<double> alphas_cumprod; std::vector<double> alphas_cumprod;
@ -1197,7 +1137,7 @@ static void sample_k_diffusion(sample_method_t method,
for (int i = 0; i < steps; i++) { for (int i = 0; i < steps; i++) {
// The "trailing" DDIM timestep, see S. Lin et al., // The "trailing" DDIM timestep, see S. Lin et al.,
// "Common Diffusion Noise Schedulers and Sample Steps // "Common Diffusion Noise Schedules and Sample Steps
// are Flawed", arXiv:2305.08891 [cs], p. 4, Table // are Flawed", arXiv:2305.08891 [cs], p. 4, Table
// 2. Most variables below follow Diffusers naming // 2. Most variables below follow Diffusers naming
// //
@ -1352,8 +1292,8 @@ static void sample_k_diffusion(sample_method_t method,
// factor c_in. // factor c_in.
} }
} break; } break;
case TCD_SAMPLE_METHOD: // Strategic Stochastic Sampling (Algorithm 4) in case TCD: // Strategic Stochastic Sampling (Algorithm 4) in
// Trajectory Consistency Distillation // Trajectory Consistency Distillation
{ {
// See J. Zheng et al., "Trajectory Consistency // See J. Zheng et al., "Trajectory Consistency
// Distillation: Improved Latent Consistency Distillation // Distillation: Improved Latent Consistency Distillation

View File

@ -107,8 +107,8 @@ Options:
compatibility issues with quantized parameters, but it usually offers faster inference compatibility issues with quantized parameters, but it usually offers faster inference
speed and, in some cases, lower memory usage. The at_runtime mode, on the other speed and, in some cases, lower memory usage. The at_runtime mode, on the other
hand, is exactly the opposite. hand, is exactly the opposite.
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm], --scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple], default:
default: discrete discrete
--skip-layers layers to skip for SLG steps (default: [7,8,9]) --skip-layers layers to skip for SLG steps (default: [7,8,9])
--high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, --high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm,
ddim_trailing, tcd] default: euler for Flux/SD3/Wan, euler_a otherwise ddim_trailing, tcd] default: euler for Flux/SD3/Wan, euler_a otherwise

View File

@ -912,13 +912,13 @@ void parse_args(int argc, const char** argv, SDParams& params) {
return 1; return 1;
}; };
auto on_scheduler_arg = [&](int argc, const char** argv, int index) { auto on_schedule_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) { if (++index >= argc) {
return -1; return -1;
} }
const char* arg = argv[index]; const char* arg = argv[index];
params.sample_params.scheduler = str_to_scheduler(arg); params.sample_params.scheduler = str_to_schedule(arg);
if (params.sample_params.scheduler == SCHEDULER_COUNT) { if (params.sample_params.scheduler == SCHEDULE_COUNT) {
fprintf(stderr, "error: invalid scheduler %s\n", fprintf(stderr, "error: invalid scheduler %s\n",
arg); arg);
return -1; return -1;
@ -926,6 +926,20 @@ void parse_args(int argc, const char** argv, SDParams& params) {
return 1; return 1;
}; };
auto on_high_noise_schedule_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) {
return -1;
}
const char* arg = argv[index];
params.high_noise_sample_params.scheduler = str_to_schedule(arg);
if (params.high_noise_sample_params.scheduler == SCHEDULE_COUNT) {
fprintf(stderr, "error: invalid high noise scheduler %s\n",
arg);
return -1;
}
return 1;
};
auto on_prediction_arg = [&](int argc, const char** argv, int index) { auto on_prediction_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) { if (++index >= argc) {
return -1; return -1;
@ -1197,8 +1211,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
on_lora_apply_mode_arg}, on_lora_apply_mode_arg},
{"", {"",
"--scheduler", "--scheduler",
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm], default: discrete", "denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple], default: discrete",
on_scheduler_arg}, on_schedule_arg},
{"", {"",
"--skip-layers", "--skip-layers",
"layers to skip for SLG steps (default: [7,8,9])", "layers to skip for SLG steps (default: [7,8,9])",
@ -1208,6 +1222,10 @@ void parse_args(int argc, const char** argv, SDParams& params) {
"(high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd]" "(high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd]"
" default: euler for Flux/SD3/Wan, euler_a otherwise", " default: euler for Flux/SD3/Wan, euler_a otherwise",
on_high_noise_sample_method_arg}, on_high_noise_sample_method_arg},
{"",
"--high-noise-scheduler",
"(high noise) denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple], default: discrete",
on_high_noise_schedule_arg},
{"", {"",
"--high-noise-skip-layers", "--high-noise-skip-layers",
"(high noise) layers to skip for SLG steps (default: [7,8,9])", "(high noise) layers to skip for SLG steps (default: [7,8,9])",
@ -1424,8 +1442,8 @@ std::string get_image_params(SDParams params, int64_t seed) {
parameter_string += "Sampler RNG: " + std::string(sd_rng_type_name(params.sampler_rng_type)) + ", "; parameter_string += "Sampler RNG: " + std::string(sd_rng_type_name(params.sampler_rng_type)) + ", ";
} }
parameter_string += "Sampler: " + std::string(sd_sample_method_name(params.sample_params.sample_method)); parameter_string += "Sampler: " + std::string(sd_sample_method_name(params.sample_params.sample_method));
if (params.sample_params.scheduler != SCHEDULER_COUNT) { if (params.sample_params.scheduler != DEFAULT) {
parameter_string += " " + std::string(sd_scheduler_name(params.sample_params.scheduler)); parameter_string += " " + std::string(sd_schedule_name(params.sample_params.scheduler));
} }
parameter_string += ", "; parameter_string += ", ";
for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path, params.qwen2vl_path, params.qwen2vl_vision_path}) { for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path, params.qwen2vl_path, params.qwen2vl_vision_path}) {
@ -1630,7 +1648,7 @@ bool load_images_from_dir(const std::string dir,
return true; return true;
} }
std::string preview_path; const char* preview_path;
float preview_fps; float preview_fps;
void step_callback(int step, int frame_count, sd_image_t* image, bool is_noisy) { void step_callback(int step, int frame_count, sd_image_t* image, bool is_noisy) {
@ -1639,16 +1657,16 @@ void step_callback(int step, int frame_count, sd_image_t* image, bool is_noisy)
// is_noisy is set to true if the preview corresponds to noisy latents, false if it's denoised latents // is_noisy is set to true if the preview corresponds to noisy latents, false if it's denoised latents
// unused in this app, it will either be always noisy or always denoised here // unused in this app, it will either be always noisy or always denoised here
if (frame_count == 1) { if (frame_count == 1) {
stbi_write_png(preview_path.c_str(), image->width, image->height, image->channel, image->data, 0); stbi_write_png(preview_path, image->width, image->height, image->channel, image->data, 0);
} else { } else {
create_mjpg_avi_from_sd_images(preview_path.c_str(), image, frame_count, preview_fps); create_mjpg_avi_from_sd_images(preview_path, image, frame_count, preview_fps);
} }
} }
int main(int argc, const char* argv[]) { int main(int argc, const char* argv[]) {
SDParams params; SDParams params;
parse_args(argc, argv, params); parse_args(argc, argv, params);
preview_path = params.preview_path; preview_path = params.preview_path.c_str();
if (params.video_frames > 4) { if (params.video_frames > 4) {
size_t last_dot_pos = params.preview_path.find_last_of("."); size_t last_dot_pos = params.preview_path.find_last_of(".");
std::string base_path = params.preview_path; std::string base_path = params.preview_path;
@ -1659,7 +1677,8 @@ int main(int argc, const char* argv[]) {
std::transform(file_ext.begin(), file_ext.end(), file_ext.begin(), ::tolower); std::transform(file_ext.begin(), file_ext.end(), file_ext.begin(), ::tolower);
} }
if (file_ext == ".png") { if (file_ext == ".png") {
preview_path = base_path + ".avi"; base_path = base_path + ".avi";
preview_path = base_path.c_str();
} }
} }
preview_fps = params.fps; preview_fps = params.fps;
@ -1902,18 +1921,10 @@ int main(int argc, const char* argv[]) {
return 1; return 1;
} }
if (params.sample_params.sample_method == SAMPLE_METHOD_COUNT) { if (params.sample_params.sample_method == SAMPLE_METHOD_DEFAULT) {
params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx); params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
} }
if (params.high_noise_sample_params.sample_method == SAMPLE_METHOD_COUNT) {
params.high_noise_sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
}
if (params.sample_params.scheduler == SCHEDULER_COUNT) {
params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx);
}
if (params.mode == IMG_GEN) { if (params.mode == IMG_GEN) {
sd_img_gen_params_t img_gen_params = { sd_img_gen_params_t img_gen_params = {
params.prompt.c_str(), params.prompt.c_str(),
@ -2056,16 +2067,15 @@ int main(int argc, const char* argv[]) {
if (results[i].data == nullptr) { if (results[i].data == nullptr) {
continue; continue;
} }
int write_ok;
std::string final_image_path = i > 0 ? base_path + "_" + std::to_string(i + 1) + file_ext : base_path + file_ext; std::string final_image_path = i > 0 ? base_path + "_" + std::to_string(i + 1) + file_ext : base_path + file_ext;
if (is_jpg) { if (is_jpg) {
write_ok = stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel, stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
results[i].data, 90, get_image_params(params, params.seed + i).c_str()); results[i].data, 90, get_image_params(params, params.seed + i).c_str());
printf("save result JPEG image to '%s' (%s)\n", final_image_path.c_str(), write_ok == 0 ? "failure" : "success"); printf("save result JPEG image to '%s'\n", final_image_path.c_str());
} else { } else {
write_ok = stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel, stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
results[i].data, 0, get_image_params(params, params.seed + i).c_str()); results[i].data, 0, get_image_params(params, params.seed + i).c_str());
printf("save result PNG image to '%s' (%s)\n", final_image_path.c_str(), write_ok == 0 ? "failure" : "success"); printf("save result PNG image to '%s'\n", final_image_path.c_str());
} }
} }
} }

View File

@ -47,8 +47,8 @@ const char* model_version_to_str[] = {
}; };
const char* sampling_methods_str[] = { const char* sampling_methods_str[] = {
"default",
"Euler", "Euler",
"Euler A",
"Heun", "Heun",
"DPM2", "DPM2",
"DPM++ (2s)", "DPM++ (2s)",
@ -59,6 +59,7 @@ const char* sampling_methods_str[] = {
"LCM", "LCM",
"DDIM \"trailing\"", "DDIM \"trailing\"",
"TCD", "TCD",
"Euler A",
}; };
/*================================================== Helper Functions ================================================*/ /*================================================== Helper Functions ================================================*/
@ -869,6 +870,53 @@ public:
return true; return true;
} }
void init_scheduler(scheduler_t scheduler) {
switch (scheduler) {
case DISCRETE:
LOG_INFO("running with discrete scheduler");
denoiser->scheduler = std::make_shared<DiscreteSchedule>();
break;
case KARRAS:
LOG_INFO("running with Karras scheduler");
denoiser->scheduler = std::make_shared<KarrasSchedule>();
break;
case EXPONENTIAL:
LOG_INFO("running exponential scheduler");
denoiser->scheduler = std::make_shared<ExponentialSchedule>();
break;
case AYS:
LOG_INFO("Running with Align-Your-Steps scheduler");
denoiser->scheduler = std::make_shared<AYSSchedule>();
denoiser->scheduler->version = version;
break;
case GITS:
LOG_INFO("Running with GITS scheduler");
denoiser->scheduler = std::make_shared<GITSSchedule>();
denoiser->scheduler->version = version;
break;
case SGM_UNIFORM:
LOG_INFO("Running with SGM Uniform schedule");
denoiser->scheduler = std::make_shared<SGMUniformSchedule>();
denoiser->scheduler->version = version;
break;
case SIMPLE:
LOG_INFO("Running with Simple schedule");
denoiser->scheduler = std::make_shared<SimpleSchedule>();
denoiser->scheduler->version = version;
break;
case SMOOTHSTEP:
LOG_INFO("Running with SmoothStep scheduler");
denoiser->scheduler = std::make_shared<SmoothStepSchedule>();
break;
case DEFAULT:
// Don't touch anything.
break;
default:
LOG_ERROR("Unknown scheduler %i", scheduler);
abort();
}
}
bool is_using_v_parameterization_for_sd2(ggml_context* work_ctx, bool is_inpaint = false) { bool is_using_v_parameterization_for_sd2(ggml_context* work_ctx, bool is_inpaint = false) {
struct ggml_tensor* x_t = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 4, 1); struct ggml_tensor* x_t = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 4, 1);
ggml_set_f32(x_t, 0.5); ggml_set_f32(x_t, 0.5);
@ -2227,8 +2275,8 @@ enum rng_type_t str_to_rng_type(const char* str) {
} }
const char* sample_method_to_str[] = { const char* sample_method_to_str[] = {
"default",
"euler", "euler",
"euler_a",
"heun", "heun",
"dpm2", "dpm2",
"dpm++2s_a", "dpm++2s_a",
@ -2239,6 +2287,7 @@ const char* sample_method_to_str[] = {
"lcm", "lcm",
"ddim_trailing", "ddim_trailing",
"tcd", "tcd",
"euler_a",
}; };
const char* sd_sample_method_name(enum sample_method_t sample_method) { const char* sd_sample_method_name(enum sample_method_t sample_method) {
@ -2257,7 +2306,8 @@ enum sample_method_t str_to_sample_method(const char* str) {
return SAMPLE_METHOD_COUNT; return SAMPLE_METHOD_COUNT;
} }
const char* scheduler_to_str[] = { const char* schedule_to_str[] = {
"default",
"discrete", "discrete",
"karras", "karras",
"exponential", "exponential",
@ -2266,23 +2316,22 @@ const char* scheduler_to_str[] = {
"sgm_uniform", "sgm_uniform",
"simple", "simple",
"smoothstep", "smoothstep",
"lcm",
}; };
const char* sd_scheduler_name(enum scheduler_t scheduler) { const char* sd_schedule_name(enum scheduler_t scheduler) {
if (scheduler < SCHEDULER_COUNT) { if (scheduler < SCHEDULE_COUNT) {
return scheduler_to_str[scheduler]; return schedule_to_str[scheduler];
} }
return NONE_STR; return NONE_STR;
} }
enum scheduler_t str_to_scheduler(const char* str) { enum scheduler_t str_to_schedule(const char* str) {
for (int i = 0; i < SCHEDULER_COUNT; i++) { for (int i = 0; i < SCHEDULE_COUNT; i++) {
if (!strcmp(str, scheduler_to_str[i])) { if (!strcmp(str, schedule_to_str[i])) {
return (enum scheduler_t)i; return (enum scheduler_t)i;
} }
} }
return SCHEDULER_COUNT; return SCHEDULE_COUNT;
} }
const char* prediction_to_str[] = { const char* prediction_to_str[] = {
@ -2466,8 +2515,8 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) {
sample_params->guidance.slg.layer_start = 0.01f; sample_params->guidance.slg.layer_start = 0.01f;
sample_params->guidance.slg.layer_end = 0.2f; sample_params->guidance.slg.layer_end = 0.2f;
sample_params->guidance.slg.scale = 0.f; sample_params->guidance.slg.scale = 0.f;
sample_params->scheduler = SCHEDULER_COUNT; sample_params->scheduler = DEFAULT;
sample_params->sample_method = SAMPLE_METHOD_COUNT; sample_params->sample_method = SAMPLE_METHOD_DEFAULT;
sample_params->sample_steps = 20; sample_params->sample_steps = 20;
} }
@ -2499,7 +2548,7 @@ char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
sample_params->guidance.slg.layer_start, sample_params->guidance.slg.layer_start,
sample_params->guidance.slg.layer_end, sample_params->guidance.slg.layer_end,
sample_params->guidance.slg.scale, sample_params->guidance.slg.scale,
sd_scheduler_name(sample_params->scheduler), sd_schedule_name(sample_params->scheduler),
sd_sample_method_name(sample_params->sample_method), sd_sample_method_name(sample_params->sample_method),
sample_params->sample_steps, sample_params->sample_steps,
sample_params->eta, sample_params->eta,
@ -2625,21 +2674,13 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) {
enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx) { enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx) {
if (sd_ctx != nullptr && sd_ctx->sd != nullptr) { if (sd_ctx != nullptr && sd_ctx->sd != nullptr) {
if (sd_version_is_dit(sd_ctx->sd->version)) { SDVersion version = sd_ctx->sd->version;
return EULER_SAMPLE_METHOD; if (sd_version_is_dit(version))
} return EULER;
else
return EULER_A;
} }
return EULER_A_SAMPLE_METHOD; return SAMPLE_METHOD_COUNT;
}
enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx) {
if (sd_ctx != nullptr && sd_ctx->sd != nullptr) {
auto edm_v_denoiser = std::dynamic_pointer_cast<EDMVDenoiser>(sd_ctx->sd->denoiser);
if (edm_v_denoiser) {
return EXPONENTIAL_SCHEDULER;
}
}
return DISCRETE_SCHEDULER;
} }
sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
@ -2759,7 +2800,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
LOG_WARN("Turn off PhotoMaker"); LOG_WARN("Turn off PhotoMaker");
sd_ctx->sd->stacked_id = false; sd_ctx->sd->stacked_id = false;
} else { } else {
if (pmv2 && pm_params.id_images_count != id_embeds->ne[1]) { if (pm_params.id_images_count != id_embeds->ne[1]) {
LOG_WARN("PhotoMaker image count (%d) does NOT match ID embeds (%d). You should run face_detect.py again.", pm_params.id_images_count, id_embeds->ne[1]); LOG_WARN("PhotoMaker image count (%d) does NOT match ID embeds (%d). You should run face_detect.py again.", pm_params.id_images_count, id_embeds->ne[1]);
LOG_WARN("Turn off PhotoMaker"); LOG_WARN("Turn off PhotoMaker");
sd_ctx->sd->stacked_id = false; sd_ctx->sd->stacked_id = false;
@ -2825,6 +2866,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
int C = sd_ctx->sd->get_latent_channel(); int C = sd_ctx->sd->get_latent_channel();
int W = width / sd_ctx->sd->get_vae_scale_factor(); int W = width / sd_ctx->sd->get_vae_scale_factor();
int H = height / sd_ctx->sd->get_vae_scale_factor(); int H = height / sd_ctx->sd->get_vae_scale_factor();
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
struct ggml_tensor* control_latent = nullptr; struct ggml_tensor* control_latent = nullptr;
if (sd_version_is_control(sd_ctx->sd->version) && image_hint != nullptr) { if (sd_version_is_control(sd_ctx->sd->version) && image_hint != nullptr) {
@ -3053,16 +3095,12 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
sd_ctx->sd->rng->manual_seed(seed); sd_ctx->sd->rng->manual_seed(seed);
sd_ctx->sd->sampler_rng->manual_seed(seed); sd_ctx->sd->sampler_rng->manual_seed(seed);
int sample_steps = sd_img_gen_params->sample_params.sample_steps;
size_t t0 = ggml_time_ms(); size_t t0 = ggml_time_ms();
enum sample_method_t sample_method = sd_img_gen_params->sample_params.sample_method; sd_ctx->sd->init_scheduler(sd_img_gen_params->sample_params.scheduler);
if (sample_method == SAMPLE_METHOD_COUNT) { std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
sample_method = sd_get_default_sample_method(sd_ctx);
}
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
int sample_steps = sd_img_gen_params->sample_params.sample_steps;
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps, sd_img_gen_params->sample_params.scheduler, sd_ctx->sd->version);
ggml_tensor* init_latent = nullptr; ggml_tensor* init_latent = nullptr;
ggml_tensor* concat_latent = nullptr; ggml_tensor* concat_latent = nullptr;
@ -3250,6 +3288,11 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
} }
enum sample_method_t sample_method = sd_img_gen_params->sample_params.sample_method;
if (sample_method == SAMPLE_METHOD_DEFAULT) {
sample_method = sd_get_default_sample_method(sd_ctx);
}
sd_image_t* result_images = generate_image_internal(sd_ctx, sd_image_t* result_images = generate_image_internal(sd_ctx,
work_ctx, work_ctx,
init_latent, init_latent,
@ -3299,14 +3342,11 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor(); int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor();
enum sample_method_t sample_method = sd_vid_gen_params->sample_params.sample_method; sd_ctx->sd->init_scheduler(sd_vid_gen_params->sample_params.scheduler);
if (sample_method == SAMPLE_METHOD_COUNT) {
sample_method = sd_get_default_sample_method(sd_ctx);
}
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
int high_noise_sample_steps = 0; int high_noise_sample_steps = 0;
if (sd_ctx->sd->high_noise_diffusion_model) { if (sd_ctx->sd->high_noise_diffusion_model) {
sd_ctx->sd->init_scheduler(sd_vid_gen_params->high_noise_sample_params.scheduler);
high_noise_sample_steps = sd_vid_gen_params->high_noise_sample_params.sample_steps; high_noise_sample_steps = sd_vid_gen_params->high_noise_sample_params.sample_steps;
} }
@ -3315,7 +3355,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
if (high_noise_sample_steps > 0) { if (high_noise_sample_steps > 0) {
total_steps += high_noise_sample_steps; total_steps += high_noise_sample_steps;
} }
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps, sd_vid_gen_params->sample_params.scheduler, sd_ctx->sd->version); std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps);
if (high_noise_sample_steps < 0) { if (high_noise_sample_steps < 0) {
// timesteps ∝ sigmas for Flow models (like wan2.2 a14b) // timesteps ∝ sigmas for Flow models (like wan2.2 a14b)
@ -3573,12 +3613,6 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
// High Noise Sample // High Noise Sample
if (high_noise_sample_steps > 0) { if (high_noise_sample_steps > 0) {
LOG_DEBUG("sample(high noise) %dx%dx%d", W, H, T); LOG_DEBUG("sample(high noise) %dx%dx%d", W, H, T);
enum sample_method_t high_noise_sample_method = sd_vid_gen_params->high_noise_sample_params.sample_method;
if (high_noise_sample_method == SAMPLE_METHOD_COUNT) {
high_noise_sample_method = sd_get_default_sample_method(sd_ctx);
}
LOG_INFO("sampling(high noise) using %s method", sampling_methods_str[high_noise_sample_method]);
int64_t sampling_start = ggml_time_ms(); int64_t sampling_start = ggml_time_ms();
std::vector<float> high_noise_sigmas = std::vector<float>(sigmas.begin(), sigmas.begin() + high_noise_sample_steps + 1); std::vector<float> high_noise_sigmas = std::vector<float>(sigmas.begin(), sigmas.begin() + high_noise_sample_steps + 1);
@ -3597,7 +3631,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
sd_vid_gen_params->high_noise_sample_params.guidance, sd_vid_gen_params->high_noise_sample_params.guidance,
sd_vid_gen_params->high_noise_sample_params.eta, sd_vid_gen_params->high_noise_sample_params.eta,
sd_vid_gen_params->high_noise_sample_params.shifted_timestep, sd_vid_gen_params->high_noise_sample_params.shifted_timestep,
high_noise_sample_method, sd_vid_gen_params->high_noise_sample_params.sample_method,
high_noise_sigmas, high_noise_sigmas,
-1, -1,
{}, {},
@ -3634,7 +3668,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
sd_vid_gen_params->sample_params.guidance, sd_vid_gen_params->sample_params.guidance,
sd_vid_gen_params->sample_params.eta, sd_vid_gen_params->sample_params.eta,
sd_vid_gen_params->sample_params.shifted_timestep, sd_vid_gen_params->sample_params.shifted_timestep,
sample_method, sd_vid_gen_params->sample_params.sample_method,
sigmas, sigmas,
-1, -1,
{}, {},

View File

@ -36,32 +36,33 @@ enum rng_type_t {
}; };
enum sample_method_t { enum sample_method_t {
EULER_SAMPLE_METHOD, SAMPLE_METHOD_DEFAULT,
EULER_A_SAMPLE_METHOD, EULER,
HEUN_SAMPLE_METHOD, HEUN,
DPM2_SAMPLE_METHOD, DPM2,
DPMPP2S_A_SAMPLE_METHOD, DPMPP2S_A,
DPMPP2M_SAMPLE_METHOD, DPMPP2M,
DPMPP2Mv2_SAMPLE_METHOD, DPMPP2Mv2,
IPNDM_SAMPLE_METHOD, IPNDM,
IPNDM_V_SAMPLE_METHOD, IPNDM_V,
LCM_SAMPLE_METHOD, LCM,
DDIM_TRAILING_SAMPLE_METHOD, DDIM_TRAILING,
TCD_SAMPLE_METHOD, TCD,
EULER_A,
SAMPLE_METHOD_COUNT SAMPLE_METHOD_COUNT
}; };
enum scheduler_t { enum scheduler_t {
DISCRETE_SCHEDULER, DEFAULT,
KARRAS_SCHEDULER, DISCRETE,
EXPONENTIAL_SCHEDULER, KARRAS,
AYS_SCHEDULER, EXPONENTIAL,
GITS_SCHEDULER, AYS,
SGM_UNIFORM_SCHEDULER, GITS,
SIMPLE_SCHEDULER, SGM_UNIFORM,
SMOOTHSTEP_SCHEDULER, SIMPLE,
LCM_SCHEDULER, SMOOTHSTEP,
SCHEDULER_COUNT SCHEDULE_COUNT
}; };
enum prediction_t { enum prediction_t {
@ -296,8 +297,8 @@ SD_API const char* sd_rng_type_name(enum rng_type_t rng_type);
SD_API enum rng_type_t str_to_rng_type(const char* str); SD_API enum rng_type_t str_to_rng_type(const char* str);
SD_API const char* sd_sample_method_name(enum sample_method_t sample_method); SD_API const char* sd_sample_method_name(enum sample_method_t sample_method);
SD_API enum sample_method_t str_to_sample_method(const char* str); SD_API enum sample_method_t str_to_sample_method(const char* str);
SD_API const char* sd_scheduler_name(enum scheduler_t scheduler); SD_API const char* sd_schedule_name(enum scheduler_t scheduler);
SD_API enum scheduler_t str_to_scheduler(const char* str); SD_API enum scheduler_t str_to_schedule(const char* str);
SD_API const char* sd_prediction_name(enum prediction_t prediction); SD_API const char* sd_prediction_name(enum prediction_t prediction);
SD_API enum prediction_t str_to_prediction(const char* str); SD_API enum prediction_t str_to_prediction(const char* str);
SD_API const char* sd_preview_name(enum preview_t preview); SD_API const char* sd_preview_name(enum preview_t preview);
@ -312,13 +313,11 @@ SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params);
SD_API sd_ctx_t* new_sd_ctx(const sd_ctx_params_t* sd_ctx_params); SD_API sd_ctx_t* new_sd_ctx(const sd_ctx_params_t* sd_ctx_params);
SD_API void free_sd_ctx(sd_ctx_t* sd_ctx); SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);
SD_API enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx);
SD_API void sd_sample_params_init(sd_sample_params_t* sample_params); SD_API void sd_sample_params_init(sd_sample_params_t* sample_params);
SD_API char* sd_sample_params_to_str(const sd_sample_params_t* sample_params); SD_API char* sd_sample_params_to_str(const sd_sample_params_t* sample_params);
SD_API enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx);
SD_API enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx);
SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params); SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params);
SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params); SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params);
SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params); SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params);