refactor: optimize the handling of scheduler (#998)

This commit is contained in:
leejet 2025-11-22 12:48:53 +08:00 committed by GitHub
parent e9bc3b6c06
commit 869d023416
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 104 additions and 122 deletions

View File

@ -11,14 +11,13 @@
#define TIMESTEPS 1000 #define TIMESTEPS 1000
#define FLUX_TIMESTEPS 1000 #define FLUX_TIMESTEPS 1000
struct SigmaSchedule { struct SigmaScheduler {
int version = 0;
typedef std::function<float(float)> t_to_sigma_t; typedef std::function<float(float)> t_to_sigma_t;
virtual std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) = 0; virtual std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) = 0;
}; };
struct DiscreteSchedule : SigmaSchedule { struct DiscreteScheduler : SigmaScheduler {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override { std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
std::vector<float> result; std::vector<float> result;
@ -42,7 +41,7 @@ struct DiscreteSchedule : SigmaSchedule {
} }
}; };
struct ExponentialSchedule : SigmaSchedule { struct ExponentialScheduler : SigmaScheduler {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override { std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
std::vector<float> sigmas; std::vector<float> sigmas;
@ -149,7 +148,10 @@ std::vector<float> log_linear_interpolation(std::vector<float> sigma_in,
/* /*
https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
*/ */
struct AYSSchedule : SigmaSchedule { struct AYSScheduler : SigmaScheduler {
SDVersion version;
explicit AYSScheduler(SDVersion version)
: version(version) {}
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override { std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
const std::vector<float> noise_levels[] = { const std::vector<float> noise_levels[] = {
/* SD1.5 */ /* SD1.5 */
@ -169,19 +171,19 @@ struct AYSSchedule : SigmaSchedule {
std::vector<float> results(n + 1); std::vector<float> results(n + 1);
if (sd_version_is_sd2((SDVersion)version)) { if (sd_version_is_sd2((SDVersion)version)) {
LOG_WARN("AYS not designed for SD2.X models"); LOG_WARN("AYS_SCHEDULER not designed for SD2.X models");
} /* fallthrough */ } /* fallthrough */
else if (sd_version_is_sd1((SDVersion)version)) { else if (sd_version_is_sd1((SDVersion)version)) {
LOG_INFO("AYS using SD1.5 noise levels"); LOG_INFO("AYS_SCHEDULER using SD1.5 noise levels");
inputs = noise_levels[0]; inputs = noise_levels[0];
} else if (sd_version_is_sdxl((SDVersion)version)) { } else if (sd_version_is_sdxl((SDVersion)version)) {
LOG_INFO("AYS using SDXL noise levels"); LOG_INFO("AYS_SCHEDULER using SDXL noise levels");
inputs = noise_levels[1]; inputs = noise_levels[1];
} else if (version == VERSION_SVD) { } else if (version == VERSION_SVD) {
LOG_INFO("AYS using SVD noise levels"); LOG_INFO("AYS_SCHEDULER using SVD noise levels");
inputs = noise_levels[2]; inputs = noise_levels[2];
} else { } else {
LOG_ERROR("Version not compatible with AYS scheduler"); LOG_ERROR("Version not compatible with AYS_SCHEDULER scheduler");
return results; return results;
} }
@ -203,7 +205,7 @@ struct AYSSchedule : SigmaSchedule {
/* /*
* GITS Scheduler: https://github.com/zju-pi/diff-sampler/tree/main/gits-main * GITS Scheduler: https://github.com/zju-pi/diff-sampler/tree/main/gits-main
*/ */
struct GITSSchedule : SigmaSchedule { struct GITSScheduler : SigmaScheduler {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override { std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
if (sigma_max <= 0.0f) { if (sigma_max <= 0.0f) {
return std::vector<float>{}; return std::vector<float>{};
@ -232,7 +234,7 @@ struct GITSSchedule : SigmaSchedule {
} }
}; };
struct SGMUniformSchedule : SigmaSchedule { struct SGMUniformScheduler : SigmaScheduler {
std::vector<float> get_sigmas(uint32_t n, float sigma_min_in, float sigma_max_in, t_to_sigma_t t_to_sigma_func) override { std::vector<float> get_sigmas(uint32_t n, float sigma_min_in, float sigma_max_in, t_to_sigma_t t_to_sigma_func) override {
std::vector<float> result; std::vector<float> result;
if (n == 0) { if (n == 0) {
@ -251,7 +253,7 @@ struct SGMUniformSchedule : SigmaSchedule {
} }
}; };
struct KarrasSchedule : SigmaSchedule { struct KarrasScheduler : SigmaScheduler {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override { std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
// These *COULD* be function arguments here, // These *COULD* be function arguments here,
// but does anybody ever bother to touch them? // but does anybody ever bother to touch them?
@ -270,7 +272,7 @@ struct KarrasSchedule : SigmaSchedule {
} }
}; };
struct SimpleSchedule : SigmaSchedule { struct SimpleScheduler : SigmaScheduler {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override { std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
std::vector<float> result_sigmas; std::vector<float> result_sigmas;
@ -299,8 +301,8 @@ struct SimpleSchedule : SigmaSchedule {
} }
}; };
// Close to Beta Schedule, but increadably simple in code. // Close to Beta Scheduler, but increadably simple in code.
struct SmoothStepSchedule : SigmaSchedule { struct SmoothStepScheduler : SigmaScheduler {
static constexpr float smoothstep(float x) { static constexpr float smoothstep(float x) {
return x * x * (3.0f - 2.0f * x); return x * x * (3.0f - 2.0f * x);
} }
@ -329,7 +331,6 @@ struct SmoothStepSchedule : SigmaSchedule {
}; };
struct Denoiser { struct Denoiser {
std::shared_ptr<SigmaSchedule> scheduler = std::make_shared<DiscreteSchedule>();
virtual float sigma_min() = 0; virtual float sigma_min() = 0;
virtual float sigma_max() = 0; virtual float sigma_max() = 0;
virtual float sigma_to_t(float sigma) = 0; virtual float sigma_to_t(float sigma) = 0;
@ -338,8 +339,47 @@ struct Denoiser {
virtual ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) = 0; virtual ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) = 0;
virtual ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) = 0; virtual ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) = 0;
virtual std::vector<float> get_sigmas(uint32_t n) { virtual std::vector<float> get_sigmas(uint32_t n, scheduler_t scheduler_type, SDVersion version) {
auto bound_t_to_sigma = std::bind(&Denoiser::t_to_sigma, this, std::placeholders::_1); auto bound_t_to_sigma = std::bind(&Denoiser::t_to_sigma, this, std::placeholders::_1);
std::shared_ptr<SigmaScheduler> scheduler;
switch (scheduler_type) {
case DISCRETE_SCHEDULER:
LOG_INFO("get_sigmas with discrete scheduler");
scheduler = std::make_shared<DiscreteScheduler>();
break;
case KARRAS_SCHEDULER:
LOG_INFO("get_sigmas with Karras scheduler");
scheduler = std::make_shared<KarrasScheduler>();
break;
case EXPONENTIAL_SCHEDULER:
LOG_INFO("get_sigmas exponential scheduler");
scheduler = std::make_shared<ExponentialScheduler>();
break;
case AYS_SCHEDULER:
LOG_INFO("get_sigmas with Align-Your-Steps scheduler");
scheduler = std::make_shared<AYSScheduler>(version);
break;
case GITS_SCHEDULER:
LOG_INFO("get_sigmas with GITS scheduler");
scheduler = std::make_shared<GITSScheduler>();
break;
case SGM_UNIFORM_SCHEDULER:
LOG_INFO("get_sigmas with SGM Uniform scheduler");
scheduler = std::make_shared<SGMUniformScheduler>();
break;
case SIMPLE_SCHEDULER:
LOG_INFO("get_sigmas with Simple scheduler");
scheduler = std::make_shared<SimpleScheduler>();
break;
case SMOOTHSTEP_SCHEDULER:
LOG_INFO("get_sigmas with SmoothStep scheduler");
scheduler = std::make_shared<SmoothStepScheduler>();
break;
default:
LOG_INFO("get_sigmas with discrete scheduler (default)");
scheduler = std::make_shared<DiscreteScheduler>();
break;
}
return scheduler->get_sigmas(n, sigma_min(), sigma_max(), bound_t_to_sigma); return scheduler->get_sigmas(n, sigma_min(), sigma_max(), bound_t_to_sigma);
} }
}; };
@ -426,7 +466,6 @@ struct EDMVDenoiser : public CompVisVDenoiser {
EDMVDenoiser(float min_sigma = 0.002, float max_sigma = 120.0) EDMVDenoiser(float min_sigma = 0.002, float max_sigma = 120.0)
: min_sigma(min_sigma), max_sigma(max_sigma) { : min_sigma(min_sigma), max_sigma(max_sigma) {
scheduler = std::make_shared<ExponentialSchedule>();
} }
float t_to_sigma(float t) override { float t_to_sigma(float t) override {
@ -1109,7 +1148,7 @@ static void sample_k_diffusion(sample_method_t method,
// end beta) (which unfortunately k-diffusion's data // end beta) (which unfortunately k-diffusion's data
// structure hides from the denoiser), and the sigmas are // structure hides from the denoiser), and the sigmas are
// also needed to invert the behavior of CompVisDenoiser // also needed to invert the behavior of CompVisDenoiser
// (k-diffusion's LMSDiscreteScheduler) // (k-diffusion's LMSDiscreteSchedulerr)
float beta_start = 0.00085f; float beta_start = 0.00085f;
float beta_end = 0.0120f; float beta_end = 0.0120f;
std::vector<double> alphas_cumprod; std::vector<double> alphas_cumprod;
@ -1137,7 +1176,7 @@ static void sample_k_diffusion(sample_method_t method,
for (int i = 0; i < steps; i++) { for (int i = 0; i < steps; i++) {
// The "trailing" DDIM timestep, see S. Lin et al., // The "trailing" DDIM timestep, see S. Lin et al.,
// "Common Diffusion Noise Schedules and Sample Steps // "Common Diffusion Noise Schedulers and Sample Steps
// are Flawed", arXiv:2305.08891 [cs], p. 4, Table // are Flawed", arXiv:2305.08891 [cs], p. 4, Table
// 2. Most variables below follow Diffusers naming // 2. Most variables below follow Diffusers naming
// //

View File

@ -912,13 +912,13 @@ void parse_args(int argc, const char** argv, SDParams& params) {
return 1; return 1;
}; };
auto on_schedule_arg = [&](int argc, const char** argv, int index) { auto on_scheduler_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) { if (++index >= argc) {
return -1; return -1;
} }
const char* arg = argv[index]; const char* arg = argv[index];
params.sample_params.scheduler = str_to_schedule(arg); params.sample_params.scheduler = str_to_scheduler(arg);
if (params.sample_params.scheduler == SCHEDULE_COUNT) { if (params.sample_params.scheduler == SCHEDULER_COUNT) {
fprintf(stderr, "error: invalid scheduler %s\n", fprintf(stderr, "error: invalid scheduler %s\n",
arg); arg);
return -1; return -1;
@ -926,20 +926,6 @@ void parse_args(int argc, const char** argv, SDParams& params) {
return 1; return 1;
}; };
auto on_high_noise_schedule_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) {
return -1;
}
const char* arg = argv[index];
params.high_noise_sample_params.scheduler = str_to_schedule(arg);
if (params.high_noise_sample_params.scheduler == SCHEDULE_COUNT) {
fprintf(stderr, "error: invalid high noise scheduler %s\n",
arg);
return -1;
}
return 1;
};
auto on_prediction_arg = [&](int argc, const char** argv, int index) { auto on_prediction_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) { if (++index >= argc) {
return -1; return -1;
@ -1212,7 +1198,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
{"", {"",
"--scheduler", "--scheduler",
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple], default: discrete", "denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple], default: discrete",
on_schedule_arg}, on_scheduler_arg},
{"", {"",
"--skip-layers", "--skip-layers",
"layers to skip for SLG steps (default: [7,8,9])", "layers to skip for SLG steps (default: [7,8,9])",
@ -1222,10 +1208,6 @@ void parse_args(int argc, const char** argv, SDParams& params) {
"(high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd]" "(high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd]"
" default: euler for Flux/SD3/Wan, euler_a otherwise", " default: euler for Flux/SD3/Wan, euler_a otherwise",
on_high_noise_sample_method_arg}, on_high_noise_sample_method_arg},
{"",
"--high-noise-scheduler",
"(high noise) denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple], default: discrete",
on_high_noise_schedule_arg},
{"", {"",
"--high-noise-skip-layers", "--high-noise-skip-layers",
"(high noise) layers to skip for SLG steps (default: [7,8,9])", "(high noise) layers to skip for SLG steps (default: [7,8,9])",
@ -1442,8 +1424,8 @@ std::string get_image_params(SDParams params, int64_t seed) {
parameter_string += "Sampler RNG: " + std::string(sd_rng_type_name(params.sampler_rng_type)) + ", "; parameter_string += "Sampler RNG: " + std::string(sd_rng_type_name(params.sampler_rng_type)) + ", ";
} }
parameter_string += "Sampler: " + std::string(sd_sample_method_name(params.sample_params.sample_method)); parameter_string += "Sampler: " + std::string(sd_sample_method_name(params.sample_params.sample_method));
if (params.sample_params.scheduler != DEFAULT) { if (params.sample_params.scheduler != SCHEDULER_COUNT) {
parameter_string += " " + std::string(sd_schedule_name(params.sample_params.scheduler)); parameter_string += " " + std::string(sd_scheduler_name(params.sample_params.scheduler));
} }
parameter_string += ", "; parameter_string += ", ";
for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path, params.qwen2vl_path, params.qwen2vl_vision_path}) { for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path, params.qwen2vl_path, params.qwen2vl_vision_path}) {
@ -1924,6 +1906,10 @@ int main(int argc, const char* argv[]) {
params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx); params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
} }
if (params.sample_params.scheduler == SCHEDULER_COUNT) {
params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx);
}
if (params.mode == IMG_GEN) { if (params.mode == IMG_GEN) {
sd_img_gen_params_t img_gen_params = { sd_img_gen_params_t img_gen_params = {
params.prompt.c_str(), params.prompt.c_str(),

View File

@ -870,53 +870,6 @@ public:
return true; return true;
} }
void init_scheduler(scheduler_t scheduler) {
switch (scheduler) {
case DISCRETE:
LOG_INFO("running with discrete scheduler");
denoiser->scheduler = std::make_shared<DiscreteSchedule>();
break;
case KARRAS:
LOG_INFO("running with Karras scheduler");
denoiser->scheduler = std::make_shared<KarrasSchedule>();
break;
case EXPONENTIAL:
LOG_INFO("running exponential scheduler");
denoiser->scheduler = std::make_shared<ExponentialSchedule>();
break;
case AYS:
LOG_INFO("Running with Align-Your-Steps scheduler");
denoiser->scheduler = std::make_shared<AYSSchedule>();
denoiser->scheduler->version = version;
break;
case GITS:
LOG_INFO("Running with GITS scheduler");
denoiser->scheduler = std::make_shared<GITSSchedule>();
denoiser->scheduler->version = version;
break;
case SGM_UNIFORM:
LOG_INFO("Running with SGM Uniform schedule");
denoiser->scheduler = std::make_shared<SGMUniformSchedule>();
denoiser->scheduler->version = version;
break;
case SIMPLE:
LOG_INFO("Running with Simple schedule");
denoiser->scheduler = std::make_shared<SimpleSchedule>();
denoiser->scheduler->version = version;
break;
case SMOOTHSTEP:
LOG_INFO("Running with SmoothStep scheduler");
denoiser->scheduler = std::make_shared<SmoothStepSchedule>();
break;
case DEFAULT:
// Don't touch anything.
break;
default:
LOG_ERROR("Unknown scheduler %i", scheduler);
abort();
}
}
bool is_using_v_parameterization_for_sd2(ggml_context* work_ctx, bool is_inpaint = false) { bool is_using_v_parameterization_for_sd2(ggml_context* work_ctx, bool is_inpaint = false) {
struct ggml_tensor* x_t = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 4, 1); struct ggml_tensor* x_t = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 4, 1);
ggml_set_f32(x_t, 0.5); ggml_set_f32(x_t, 0.5);
@ -2306,8 +2259,7 @@ enum sample_method_t str_to_sample_method(const char* str) {
return SAMPLE_METHOD_COUNT; return SAMPLE_METHOD_COUNT;
} }
const char* schedule_to_str[] = { const char* scheduler_to_str[] = {
"default",
"discrete", "discrete",
"karras", "karras",
"exponential", "exponential",
@ -2318,20 +2270,20 @@ const char* schedule_to_str[] = {
"smoothstep", "smoothstep",
}; };
const char* sd_schedule_name(enum scheduler_t scheduler) { const char* sd_scheduler_name(enum scheduler_t scheduler) {
if (scheduler < SCHEDULE_COUNT) { if (scheduler < SCHEDULER_COUNT) {
return schedule_to_str[scheduler]; return scheduler_to_str[scheduler];
} }
return NONE_STR; return NONE_STR;
} }
enum scheduler_t str_to_schedule(const char* str) { enum scheduler_t str_to_scheduler(const char* str) {
for (int i = 0; i < SCHEDULE_COUNT; i++) { for (int i = 0; i < SCHEDULER_COUNT; i++) {
if (!strcmp(str, schedule_to_str[i])) { if (!strcmp(str, scheduler_to_str[i])) {
return (enum scheduler_t)i; return (enum scheduler_t)i;
} }
} }
return SCHEDULE_COUNT; return SCHEDULER_COUNT;
} }
const char* prediction_to_str[] = { const char* prediction_to_str[] = {
@ -2515,7 +2467,7 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) {
sample_params->guidance.slg.layer_start = 0.01f; sample_params->guidance.slg.layer_start = 0.01f;
sample_params->guidance.slg.layer_end = 0.2f; sample_params->guidance.slg.layer_end = 0.2f;
sample_params->guidance.slg.scale = 0.f; sample_params->guidance.slg.scale = 0.f;
sample_params->scheduler = DEFAULT; sample_params->scheduler = SCHEDULER_COUNT;
sample_params->sample_method = SAMPLE_METHOD_DEFAULT; sample_params->sample_method = SAMPLE_METHOD_DEFAULT;
sample_params->sample_steps = 20; sample_params->sample_steps = 20;
} }
@ -2548,7 +2500,7 @@ char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
sample_params->guidance.slg.layer_start, sample_params->guidance.slg.layer_start,
sample_params->guidance.slg.layer_end, sample_params->guidance.slg.layer_end,
sample_params->guidance.slg.scale, sample_params->guidance.slg.scale,
sd_schedule_name(sample_params->scheduler), sd_scheduler_name(sample_params->scheduler),
sd_sample_method_name(sample_params->sample_method), sd_sample_method_name(sample_params->sample_method),
sample_params->sample_steps, sample_params->sample_steps,
sample_params->eta, sample_params->eta,
@ -2683,6 +2635,14 @@ enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx) {
return SAMPLE_METHOD_COUNT; return SAMPLE_METHOD_COUNT;
} }
enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx) {
auto edm_v_denoiser = std::dynamic_pointer_cast<EDMVDenoiser>(sd_ctx->sd->denoiser);
if (edm_v_denoiser) {
return EXPONENTIAL_SCHEDULER;
}
return DISCRETE_SCHEDULER;
}
sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
struct ggml_context* work_ctx, struct ggml_context* work_ctx,
ggml_tensor* init_latent, ggml_tensor* init_latent,
@ -3099,8 +3059,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
size_t t0 = ggml_time_ms(); size_t t0 = ggml_time_ms();
sd_ctx->sd->init_scheduler(sd_img_gen_params->sample_params.scheduler); std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps, sd_img_gen_params->sample_params.scheduler, sd_ctx->sd->version);
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
ggml_tensor* init_latent = nullptr; ggml_tensor* init_latent = nullptr;
ggml_tensor* concat_latent = nullptr; ggml_tensor* concat_latent = nullptr;
@ -3342,11 +3301,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor(); int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor();
sd_ctx->sd->init_scheduler(sd_vid_gen_params->sample_params.scheduler);
int high_noise_sample_steps = 0; int high_noise_sample_steps = 0;
if (sd_ctx->sd->high_noise_diffusion_model) { if (sd_ctx->sd->high_noise_diffusion_model) {
sd_ctx->sd->init_scheduler(sd_vid_gen_params->high_noise_sample_params.scheduler);
high_noise_sample_steps = sd_vid_gen_params->high_noise_sample_params.sample_steps; high_noise_sample_steps = sd_vid_gen_params->high_noise_sample_params.sample_steps;
} }
@ -3355,7 +3311,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
if (high_noise_sample_steps > 0) { if (high_noise_sample_steps > 0) {
total_steps += high_noise_sample_steps; total_steps += high_noise_sample_steps;
} }
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps); std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps, sd_vid_gen_params->sample_params.scheduler, sd_ctx->sd->version);
if (high_noise_sample_steps < 0) { if (high_noise_sample_steps < 0) {
// timesteps ∝ sigmas for Flow models (like wan2.2 a14b) // timesteps ∝ sigmas for Flow models (like wan2.2 a14b)

View File

@ -53,16 +53,15 @@ enum sample_method_t {
}; };
enum scheduler_t { enum scheduler_t {
DEFAULT, DISCRETE_SCHEDULER,
DISCRETE, KARRAS_SCHEDULER,
KARRAS, EXPONENTIAL_SCHEDULER,
EXPONENTIAL, AYS_SCHEDULER,
AYS, GITS_SCHEDULER,
GITS, SGM_UNIFORM_SCHEDULER,
SGM_UNIFORM, SIMPLE_SCHEDULER,
SIMPLE, SMOOTHSTEP_SCHEDULER,
SMOOTHSTEP, SCHEDULER_COUNT
SCHEDULE_COUNT
}; };
enum prediction_t { enum prediction_t {
@ -297,8 +296,8 @@ SD_API const char* sd_rng_type_name(enum rng_type_t rng_type);
SD_API enum rng_type_t str_to_rng_type(const char* str); SD_API enum rng_type_t str_to_rng_type(const char* str);
SD_API const char* sd_sample_method_name(enum sample_method_t sample_method); SD_API const char* sd_sample_method_name(enum sample_method_t sample_method);
SD_API enum sample_method_t str_to_sample_method(const char* str); SD_API enum sample_method_t str_to_sample_method(const char* str);
SD_API const char* sd_schedule_name(enum scheduler_t scheduler); SD_API const char* sd_scheduler_name(enum scheduler_t scheduler);
SD_API enum scheduler_t str_to_schedule(const char* str); SD_API enum scheduler_t str_to_scheduler(const char* str);
SD_API const char* sd_prediction_name(enum prediction_t prediction); SD_API const char* sd_prediction_name(enum prediction_t prediction);
SD_API enum prediction_t str_to_prediction(const char* str); SD_API enum prediction_t str_to_prediction(const char* str);
SD_API const char* sd_preview_name(enum preview_t preview); SD_API const char* sd_preview_name(enum preview_t preview);
@ -313,11 +312,13 @@ SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params);
SD_API sd_ctx_t* new_sd_ctx(const sd_ctx_params_t* sd_ctx_params); SD_API sd_ctx_t* new_sd_ctx(const sd_ctx_params_t* sd_ctx_params);
SD_API void free_sd_ctx(sd_ctx_t* sd_ctx); SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);
SD_API enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx);
SD_API void sd_sample_params_init(sd_sample_params_t* sample_params); SD_API void sd_sample_params_init(sd_sample_params_t* sample_params);
SD_API char* sd_sample_params_to_str(const sd_sample_params_t* sample_params); SD_API char* sd_sample_params_to_str(const sd_sample_params_t* sample_params);
SD_API enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx);
SD_API enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx);
SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params); SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params);
SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params); SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params);
SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params); SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params);