mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-01-02 10:43:35 +00:00
feat: default to LCM scheduler for LCM sampling (#1109)
* feat: default to LCM scheduler for LCM sampling * fix bug and attempt to get default scheduler for vid_gen when none is set --------- Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
parent
97cf2efe45
commit
78e15bd4af
@ -579,7 +579,7 @@ int main(int argc, const char* argv[]) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (gen_params.sample_params.scheduler == SCHEDULER_COUNT) {
|
if (gen_params.sample_params.scheduler == SCHEDULER_COUNT) {
|
||||||
gen_params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx);
|
gen_params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx, gen_params.sample_params.sample_method);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cli_params.mode == IMG_GEN) {
|
if (cli_params.mode == IMG_GEN) {
|
||||||
@ -752,4 +752,4 @@ int main(int argc, const char* argv[]) {
|
|||||||
release_all_resources();
|
release_all_resources();
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2777,13 +2777,16 @@ enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx) {
|
|||||||
return EULER_A_SAMPLE_METHOD;
|
return EULER_A_SAMPLE_METHOD;
|
||||||
}
|
}
|
||||||
|
|
||||||
enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx) {
|
enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_method_t sample_method) {
|
||||||
if (sd_ctx != nullptr && sd_ctx->sd != nullptr) {
|
if (sd_ctx != nullptr && sd_ctx->sd != nullptr) {
|
||||||
auto edm_v_denoiser = std::dynamic_pointer_cast<EDMVDenoiser>(sd_ctx->sd->denoiser);
|
auto edm_v_denoiser = std::dynamic_pointer_cast<EDMVDenoiser>(sd_ctx->sd->denoiser);
|
||||||
if (edm_v_denoiser) {
|
if (edm_v_denoiser) {
|
||||||
return EXPONENTIAL_SCHEDULER;
|
return EXPONENTIAL_SCHEDULER;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (sample_method == LCM_SAMPLE_METHOD) {
|
||||||
|
return LCM_SCHEDULER;
|
||||||
|
}
|
||||||
return DISCRETE_SCHEDULER;
|
return DISCRETE_SCHEDULER;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3218,9 +3221,13 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
|
|||||||
LOG_WARN("sample_steps != custom_sigmas_count - 1, set sample_steps to %d", sample_steps);
|
LOG_WARN("sample_steps != custom_sigmas_count - 1, set sample_steps to %d", sample_steps);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
scheduler_t scheduler = sd_img_gen_params->sample_params.scheduler;
|
||||||
|
if (scheduler == SCHEDULER_COUNT) {
|
||||||
|
scheduler = sd_get_default_scheduler(sd_ctx, sample_method);
|
||||||
|
}
|
||||||
sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps,
|
sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps,
|
||||||
sd_ctx->sd->get_image_seq_len(height, width),
|
sd_ctx->sd->get_image_seq_len(height, width),
|
||||||
sd_img_gen_params->sample_params.scheduler,
|
scheduler,
|
||||||
sd_ctx->sd->version);
|
sd_ctx->sd->version);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3503,9 +3510,13 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
scheduler_t scheduler = sd_vid_gen_params->sample_params.scheduler;
|
||||||
|
if (scheduler == SCHEDULER_COUNT) {
|
||||||
|
scheduler = sd_get_default_scheduler(sd_ctx, sample_method);
|
||||||
|
}
|
||||||
sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps,
|
sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps,
|
||||||
0,
|
0,
|
||||||
sd_vid_gen_params->sample_params.scheduler,
|
scheduler,
|
||||||
sd_ctx->sd->version);
|
sd_ctx->sd->version);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -335,7 +335,7 @@ SD_API void sd_sample_params_init(sd_sample_params_t* sample_params);
|
|||||||
SD_API char* sd_sample_params_to_str(const sd_sample_params_t* sample_params);
|
SD_API char* sd_sample_params_to_str(const sd_sample_params_t* sample_params);
|
||||||
|
|
||||||
SD_API enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx);
|
SD_API enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx);
|
||||||
SD_API enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx);
|
SD_API enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_method_t sample_method);
|
||||||
|
|
||||||
SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params);
|
SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params);
|
||||||
SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params);
|
SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user