feat: default to LCM scheduler for LCM sampling (#1109)

* feat: default to LCM scheduler for LCM sampling

* fix bug and attempt to get default scheduler for vid_gen when none is set

---------

Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
Wagner Bruna 2025-12-18 10:43:39 -03:00 committed by GitHub
parent 97cf2efe45
commit 78e15bd4af
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 17 additions and 6 deletions

View File

@ -579,7 +579,7 @@ int main(int argc, const char* argv[]) {
}
if (gen_params.sample_params.scheduler == SCHEDULER_COUNT) {
gen_params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx);
gen_params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx, gen_params.sample_params.sample_method);
}
if (cli_params.mode == IMG_GEN) {
@ -752,4 +752,4 @@ int main(int argc, const char* argv[]) {
release_all_resources();
return 0;
}
}

View File

@ -2777,13 +2777,16 @@ enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx) {
return EULER_A_SAMPLE_METHOD;
}
enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx) {
enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_method_t sample_method) {
if (sd_ctx != nullptr && sd_ctx->sd != nullptr) {
auto edm_v_denoiser = std::dynamic_pointer_cast<EDMVDenoiser>(sd_ctx->sd->denoiser);
if (edm_v_denoiser) {
return EXPONENTIAL_SCHEDULER;
}
}
if (sample_method == LCM_SAMPLE_METHOD) {
return LCM_SCHEDULER;
}
return DISCRETE_SCHEDULER;
}
@ -3218,9 +3221,13 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
LOG_WARN("sample_steps != custom_sigmas_count - 1, set sample_steps to %d", sample_steps);
}
} else {
scheduler_t scheduler = sd_img_gen_params->sample_params.scheduler;
if (scheduler == SCHEDULER_COUNT) {
scheduler = sd_get_default_scheduler(sd_ctx, sample_method);
}
sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps,
sd_ctx->sd->get_image_seq_len(height, width),
sd_img_gen_params->sample_params.scheduler,
scheduler,
sd_ctx->sd->version);
}
@ -3503,9 +3510,13 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
}
}
} else {
scheduler_t scheduler = sd_vid_gen_params->sample_params.scheduler;
if (scheduler == SCHEDULER_COUNT) {
scheduler = sd_get_default_scheduler(sd_ctx, sample_method);
}
sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps,
0,
sd_vid_gen_params->sample_params.scheduler,
scheduler,
sd_ctx->sd->version);
}

View File

@ -335,7 +335,7 @@ SD_API void sd_sample_params_init(sd_sample_params_t* sample_params);
SD_API char* sd_sample_params_to_str(const sd_sample_params_t* sample_params);
SD_API enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx);
SD_API enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx);
SD_API enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_method_t sample_method);
SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params);
SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params);