feat: default to LCM scheduler for LCM sampling (#1109)

* feat: default to LCM scheduler for LCM sampling

* fix bug and attempt to get default scheduler for vid_gen when none is set

---------

Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
Wagner Bruna 2025-12-18 10:43:39 -03:00 committed by GitHub
parent 97cf2efe45
commit 78e15bd4af
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 17 additions and 6 deletions

View File

@ -579,7 +579,7 @@ int main(int argc, const char* argv[]) {
} }
if (gen_params.sample_params.scheduler == SCHEDULER_COUNT) { if (gen_params.sample_params.scheduler == SCHEDULER_COUNT) {
gen_params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx); gen_params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx, gen_params.sample_params.sample_method);
} }
if (cli_params.mode == IMG_GEN) { if (cli_params.mode == IMG_GEN) {
@ -752,4 +752,4 @@ int main(int argc, const char* argv[]) {
release_all_resources(); release_all_resources();
return 0; return 0;
} }

View File

@ -2777,13 +2777,16 @@ enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx) {
return EULER_A_SAMPLE_METHOD; return EULER_A_SAMPLE_METHOD;
} }
enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx) { enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_method_t sample_method) {
if (sd_ctx != nullptr && sd_ctx->sd != nullptr) { if (sd_ctx != nullptr && sd_ctx->sd != nullptr) {
auto edm_v_denoiser = std::dynamic_pointer_cast<EDMVDenoiser>(sd_ctx->sd->denoiser); auto edm_v_denoiser = std::dynamic_pointer_cast<EDMVDenoiser>(sd_ctx->sd->denoiser);
if (edm_v_denoiser) { if (edm_v_denoiser) {
return EXPONENTIAL_SCHEDULER; return EXPONENTIAL_SCHEDULER;
} }
} }
if (sample_method == LCM_SAMPLE_METHOD) {
return LCM_SCHEDULER;
}
return DISCRETE_SCHEDULER; return DISCRETE_SCHEDULER;
} }
@ -3218,9 +3221,13 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
LOG_WARN("sample_steps != custom_sigmas_count - 1, set sample_steps to %d", sample_steps); LOG_WARN("sample_steps != custom_sigmas_count - 1, set sample_steps to %d", sample_steps);
} }
} else { } else {
scheduler_t scheduler = sd_img_gen_params->sample_params.scheduler;
if (scheduler == SCHEDULER_COUNT) {
scheduler = sd_get_default_scheduler(sd_ctx, sample_method);
}
sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps, sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps,
sd_ctx->sd->get_image_seq_len(height, width), sd_ctx->sd->get_image_seq_len(height, width),
sd_img_gen_params->sample_params.scheduler, scheduler,
sd_ctx->sd->version); sd_ctx->sd->version);
} }
@ -3503,9 +3510,13 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
} }
} }
} else { } else {
scheduler_t scheduler = sd_vid_gen_params->sample_params.scheduler;
if (scheduler == SCHEDULER_COUNT) {
scheduler = sd_get_default_scheduler(sd_ctx, sample_method);
}
sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps, sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps,
0, 0,
sd_vid_gen_params->sample_params.scheduler, scheduler,
sd_ctx->sd->version); sd_ctx->sd->version);
} }

View File

@ -335,7 +335,7 @@ SD_API void sd_sample_params_init(sd_sample_params_t* sample_params);
SD_API char* sd_sample_params_to_str(const sd_sample_params_t* sample_params); SD_API char* sd_sample_params_to_str(const sd_sample_params_t* sample_params);
SD_API enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx); SD_API enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx);
SD_API enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx); SD_API enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_method_t sample_method);
SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params); SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params);
SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params); SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params);