feat: add support for custom scheduler (#694)

---------

Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
rmatif 2025-12-13 09:20:02 +01:00 committed by GitHub
parent 15d0f82760
commit 8f05f5bc6e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 101 additions and 9 deletions

View File

@ -121,6 +121,7 @@ Generation Options:
ddim_trailing, tcd] default: euler for Flux/SD3/Wan, euler_a otherwise ddim_trailing, tcd] default: euler for Flux/SD3/Wan, euler_a otherwise
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm], --scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm],
default: discrete default: discrete
--sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0").
--skip-layers layers to skip for SLG steps (default: [7,8,9]) --skip-layers layers to skip for SLG steps (default: [7,8,9])
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9]) --high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
-r, --ref-image reference image for Flux Kontext models (can be used multiple times) -r, --ref-image reference image for Flux Kontext models (can be used multiple times)

View File

@ -258,7 +258,15 @@ std::string get_image_params(const SDCliParams& cli_params, const SDContextParam
parameter_string += "Sampler RNG: " + std::string(sd_rng_type_name(ctx_params.sampler_rng_type)) + ", "; parameter_string += "Sampler RNG: " + std::string(sd_rng_type_name(ctx_params.sampler_rng_type)) + ", ";
} }
parameter_string += "Sampler: " + std::string(sd_sample_method_name(gen_params.sample_params.sample_method)); parameter_string += "Sampler: " + std::string(sd_sample_method_name(gen_params.sample_params.sample_method));
if (gen_params.sample_params.scheduler != SCHEDULER_COUNT) { if (!gen_params.custom_sigmas.empty()) {
parameter_string += ", Custom Sigmas: [";
for (size_t i = 0; i < gen_params.custom_sigmas.size(); ++i) {
std::ostringstream oss;
oss << std::fixed << std::setprecision(4) << gen_params.custom_sigmas[i];
parameter_string += oss.str() + (i == gen_params.custom_sigmas.size() - 1 ? "" : ", ");
}
parameter_string += "]";
} else if (gen_params.sample_params.scheduler != SCHEDULER_COUNT) { // Only show schedule if not using custom sigmas
parameter_string += " " + std::string(sd_scheduler_name(gen_params.sample_params.scheduler)); parameter_string += " " + std::string(sd_scheduler_name(gen_params.sample_params.scheduler));
} }
parameter_string += ", "; parameter_string += ", ";
@ -806,4 +814,4 @@ int main(int argc, const char* argv[]) {
release_all_resources(); release_all_resources();
return 0; return 0;
} }

View File

@ -883,6 +883,8 @@ struct SDGenerationParams {
std::vector<int> high_noise_skip_layers = {7, 8, 9}; std::vector<int> high_noise_skip_layers = {7, 8, 9};
sd_sample_params_t high_noise_sample_params; sd_sample_params_t high_noise_sample_params;
std::vector<float> custom_sigmas;
std::string easycache_option; std::string easycache_option;
sd_easycache_params_t easycache_params; sd_easycache_params_t easycache_params;
@ -1201,6 +1203,43 @@ struct SDGenerationParams {
return 1; return 1;
}; };
// Option handler for `--sigmas`: parses the next argv entry into `custom_sigmas`.
//
// Accepted form: a comma-separated list of floats, optionally wrapped in one
// leading '[' and/or one trailing ']' (e.g. "14.61,7.8,3.5,0.0" or
// "[14.61, 7.8, 3.5, 0.0]"). Whitespace around each item is trimmed and empty
// items are skipped.
//
// Returns 1 on success and -1 on error (missing value, malformed or
// out-of-range item, or a non-empty argument that yields no values), the same
// convention the neighbouring option handlers use.
auto on_sigmas_arg = [&](int argc, const char** argv, int index) {
    if (++index >= argc) {
        return -1;
    }

    std::string sigmas_str = argv[index];
    if (!sigmas_str.empty() && sigmas_str.front() == '[') {
        sigmas_str.erase(0, 1);
    }
    if (!sigmas_str.empty() && sigmas_str.back() == ']') {
        sigmas_str.pop_back();
    }

    const char* const blanks = " \t\n\r\f\v";

    // Parse into a scratch vector first so that a malformed item does not
    // leave `custom_sigmas` half-populated, and so that the "nothing parsed"
    // check below looks at this argument only.
    std::vector<float> parsed;
    std::stringstream ss(sigmas_str);
    std::string item;
    while (std::getline(ss, item, ',')) {
        item.erase(0, item.find_first_not_of(blanks));
        item.erase(item.find_last_not_of(blanks) + 1);
        if (item.empty()) {
            continue;
        }
        try {
            std::size_t consumed = 0;
            const float value    = std::stof(item, &consumed);
            // std::stof stops at the first character it cannot use; anything
            // left over (e.g. "7.8abc" or "1.5 2.5") means the item is not a
            // clean float and must be rejected rather than truncated.
            if (consumed != item.size()) {
                fprintf(stderr, "error: invalid float value '%s' in --sigmas\n", item.c_str());
                return -1;
            }
            parsed.push_back(value);
        } catch (const std::invalid_argument&) {
            fprintf(stderr, "error: invalid float value '%s' in --sigmas\n", item.c_str());
            return -1;
        } catch (const std::out_of_range&) {
            fprintf(stderr, "error: float value '%s' out of range in --sigmas\n", item.c_str());
            return -1;
        }
    }

    if (parsed.empty() && !sigmas_str.empty()) {
        fprintf(stderr, "error: could not parse any sigma values from '%s'\n", argv[index]);
        return -1;
    }

    // Append, as the original did, so values from an earlier --sigmas are kept.
    custom_sigmas.insert(custom_sigmas.end(), parsed.begin(), parsed.end());
    return 1;
};
auto on_ref_image_arg = [&](int argc, const char** argv, int index) { auto on_ref_image_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) { if (++index >= argc) {
return -1; return -1;
@ -1260,6 +1299,10 @@ struct SDGenerationParams {
"--scheduler", "--scheduler",
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm], default: discrete", "denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm], default: discrete",
on_scheduler_arg}, on_scheduler_arg},
{"",
"--sigmas",
"custom sigma values for the sampler, comma-separated (e.g., \"14.61,7.8,3.5,0.0\").",
on_sigmas_arg},
{"", {"",
"--skip-layers", "--skip-layers",
"layers to skip for SLG steps (default: [7,8,9])", "layers to skip for SLG steps (default: [7,8,9])",
@ -1512,6 +1555,8 @@ struct SDGenerationParams {
sample_params.guidance.slg.layers = skip_layers.data(); sample_params.guidance.slg.layers = skip_layers.data();
sample_params.guidance.slg.layer_count = skip_layers.size(); sample_params.guidance.slg.layer_count = skip_layers.size();
sample_params.custom_sigmas = custom_sigmas.data();
sample_params.custom_sigmas_count = static_cast<int>(custom_sigmas.size());
high_noise_sample_params.guidance.slg.layers = high_noise_skip_layers.data(); high_noise_sample_params.guidance.slg.layers = high_noise_skip_layers.data();
high_noise_sample_params.guidance.slg.layer_count = high_noise_skip_layers.size(); high_noise_sample_params.guidance.slg.layer_count = high_noise_skip_layers.size();
@ -1606,6 +1651,7 @@ struct SDGenerationParams {
<< " sample_params: " << sample_params_str << ",\n" << " sample_params: " << sample_params_str << ",\n"
<< " high_noise_skip_layers: " << vec_to_string(high_noise_skip_layers) << ",\n" << " high_noise_skip_layers: " << vec_to_string(high_noise_skip_layers) << ",\n"
<< " high_noise_sample_params: " << high_noise_sample_params_str << ",\n" << " high_noise_sample_params: " << high_noise_sample_params_str << ",\n"
<< " custom_sigmas: " << vec_to_string(custom_sigmas) << ",\n"
<< " easycache_option: \"" << easycache_option << "\",\n" << " easycache_option: \"" << easycache_option << "\",\n"
<< " easycache: " << " easycache: "
<< (easycache_params.enabled ? "enabled" : "disabled") << (easycache_params.enabled ? "enabled" : "disabled")

View File

@ -115,6 +115,7 @@ Default Generation Options:
ddim_trailing, tcd] default: euler for Flux/SD3/Wan, euler_a otherwise ddim_trailing, tcd] default: euler for Flux/SD3/Wan, euler_a otherwise
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm], --scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm],
default: discrete default: discrete
--sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0").
--skip-layers layers to skip for SLG steps (default: [7,8,9]) --skip-layers layers to skip for SLG steps (default: [7,8,9])
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9]) --high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
-r, --ref-image reference image for Flux Kontext models (can be used multiple times) -r, --ref-image reference image for Flux Kontext models (can be used multiple times)

View File

@ -2600,6 +2600,8 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) {
sample_params->scheduler = SCHEDULER_COUNT; sample_params->scheduler = SCHEDULER_COUNT;
sample_params->sample_method = SAMPLE_METHOD_COUNT; sample_params->sample_method = SAMPLE_METHOD_COUNT;
sample_params->sample_steps = 20; sample_params->sample_steps = 20;
sample_params->custom_sigmas = nullptr;
sample_params->custom_sigmas_count = 0;
} }
char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) { char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
@ -3194,11 +3196,21 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
} }
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]); LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
int sample_steps = sd_img_gen_params->sample_params.sample_steps; int sample_steps = sd_img_gen_params->sample_params.sample_steps;
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps, std::vector<float> sigmas;
sd_ctx->sd->get_image_seq_len(height, width), if (sd_img_gen_params->sample_params.custom_sigmas_count > 0) {
sd_img_gen_params->sample_params.scheduler, sigmas = std::vector<float>(sd_img_gen_params->sample_params.custom_sigmas,
sd_ctx->sd->version); sd_img_gen_params->sample_params.custom_sigmas + sd_img_gen_params->sample_params.custom_sigmas_count);
if (sample_steps != sigmas.size() - 1) {
sample_steps = static_cast<int>(sigmas.size()) - 1;
LOG_WARN("sample_steps != custom_sigmas_count - 1, set sample_steps to %d", sample_steps);
}
} else {
sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps,
sd_ctx->sd->get_image_seq_len(height, width),
sd_img_gen_params->sample_params.scheduler,
sd_ctx->sd->version);
}
ggml_tensor* init_latent = nullptr; ggml_tensor* init_latent = nullptr;
ggml_tensor* concat_latent = nullptr; ggml_tensor* concat_latent = nullptr;
@ -3461,7 +3473,29 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
if (high_noise_sample_steps > 0) { if (high_noise_sample_steps > 0) {
total_steps += high_noise_sample_steps; total_steps += high_noise_sample_steps;
} }
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps, 0, sd_vid_gen_params->sample_params.scheduler, sd_ctx->sd->version);
std::vector<float> sigmas;
if (sd_vid_gen_params->sample_params.custom_sigmas_count > 0) {
sigmas = std::vector<float>(sd_vid_gen_params->sample_params.custom_sigmas,
sd_vid_gen_params->sample_params.custom_sigmas + sd_vid_gen_params->sample_params.custom_sigmas_count);
if (total_steps != sigmas.size() - 1) {
total_steps = static_cast<int>(sigmas.size()) - 1;
LOG_WARN("total_steps != custom_sigmas_count - 1, set total_steps to %d", total_steps);
if (sample_steps >= total_steps) {
sample_steps = total_steps;
LOG_WARN("total_steps != custom_sigmas_count - 1, set sample_steps to %d", sample_steps);
}
if (high_noise_sample_steps > 0) {
high_noise_sample_steps = total_steps - sample_steps;
LOG_WARN("total_steps != custom_sigmas_count - 1, set high_noise_sample_steps to %d", high_noise_sample_steps);
}
}
} else {
sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps,
0,
sd_vid_gen_params->sample_params.scheduler,
sd_ctx->sd->version);
}
if (high_noise_sample_steps < 0) { if (high_noise_sample_steps < 0) {
// timesteps ∝ sigmas for Flow models (like wan2.2 a14b) // timesteps ∝ sigmas for Flow models (like wan2.2 a14b)
@ -3841,4 +3875,4 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
LOG_INFO("generate_video completed in %.2fs", (t5 - t0) * 1.0f / 1000); LOG_INFO("generate_video completed in %.2fs", (t5 - t0) * 1.0f / 1000);
return result_images; return result_images;
} }

View File

@ -225,6 +225,8 @@ typedef struct {
int sample_steps; int sample_steps;
float eta; float eta;
int shifted_timestep; int shifted_timestep;
float* custom_sigmas;
int custom_sigmas_count;
} sd_sample_params_t; } sd_sample_params_t;
typedef struct { typedef struct {