From 20345888a313c11826e5d511c5c00109b8ab61a6 Mon Sep 17 00:00:00 2001 From: leejet Date: Sat, 22 Nov 2025 14:00:25 +0800 Subject: [PATCH] refactor: optimize the handling of sample method (#999) --- denoiser.hpp | 28 ++++++++++----------- examples/cli/main.cpp | 6 ++++- stable-diffusion.cpp | 57 +++++++++++++++++++++++++------------------ stable-diffusion.h | 25 +++++++++---------- 4 files changed, 64 insertions(+), 52 deletions(-) diff --git a/denoiser.hpp b/denoiser.hpp index c9f9bc7..12ba8a7 100644 --- a/denoiser.hpp +++ b/denoiser.hpp @@ -640,7 +640,7 @@ static void sample_k_diffusion(sample_method_t method, size_t steps = sigmas.size() - 1; // sample_euler_ancestral switch (method) { - case EULER_A: { + case EULER_A_SAMPLE_METHOD: { struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x); @@ -693,7 +693,7 @@ static void sample_k_diffusion(sample_method_t method, } } } break; - case EULER: // Implemented without any sigma churn + case EULER_SAMPLE_METHOD: // Implemented without any sigma churn { struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x); @@ -726,7 +726,7 @@ static void sample_k_diffusion(sample_method_t method, } } } break; - case HEUN: { + case HEUN_SAMPLE_METHOD: { struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x); @@ -776,7 +776,7 @@ static void sample_k_diffusion(sample_method_t method, } } } break; - case DPM2: { + case DPM2_SAMPLE_METHOD: { struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x); @@ -828,7 +828,7 @@ static void sample_k_diffusion(sample_method_t method, } } break; - case DPMPP2S_A: { + case DPMPP2S_A_SAMPLE_METHOD: { struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x); @@ -892,7 +892,7 @@ static void sample_k_diffusion(sample_method_t method, } } } break; - case DPMPP2M: // DPM++ (2M) from Karras et al (2022) + case DPMPP2M_SAMPLE_METHOD: // DPM++ (2M) from Karras et al (2022) { struct ggml_tensor* old_denoised = ggml_dup_tensor(work_ctx, x); @@ -931,7 +931,7 @@ static void sample_k_diffusion(sample_method_t method, } } } break; - case DPMPP2Mv2: // Modified DPM++ (2M) from https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/8457 + case DPMPP2Mv2_SAMPLE_METHOD: // Modified DPM++ (2M) from https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/8457 { struct ggml_tensor* old_denoised = ggml_dup_tensor(work_ctx, x); @@ -974,7 +974,7 @@ static void sample_k_diffusion(sample_method_t method, } } } break; - case IPNDM: // iPNDM sampler from https://github.com/zju-pi/diff-sampler/tree/main/diff-solvers-main + case IPNDM_SAMPLE_METHOD: // iPNDM sampler from https://github.com/zju-pi/diff-sampler/tree/main/diff-solvers-main { int max_order = 4; ggml_tensor* x_next = x; @@ -1049,7 +1049,7 @@ static void sample_k_diffusion(sample_method_t method, } } } break; - case IPNDM_V: // iPNDM_v sampler from https://github.com/zju-pi/diff-sampler/tree/main/diff-solvers-main + case IPNDM_V_SAMPLE_METHOD: // iPNDM_v sampler from https://github.com/zju-pi/diff-sampler/tree/main/diff-solvers-main { int max_order = 4; std::vector buffer_model; @@ -1123,7 +1123,7 @@ static void sample_k_diffusion(sample_method_t method, d_cur = ggml_dup_tensor(work_ctx, x_next); } } break; - case LCM: // Latent Consistency Models + case LCM_SAMPLE_METHOD: // Latent Consistency Models { struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x); @@ -1158,8 +1158,8 @@ static void sample_k_diffusion(sample_method_t method, } } } break; - case DDIM_TRAILING: // Denoising Diffusion Implicit Models - // with the "trailing" timestep spacing + case DDIM_TRAILING_SAMPLE_METHOD: // Denoising Diffusion Implicit Models + // with the "trailing" timestep spacing { // See J. Song et al., "Denoising Diffusion Implicit // Models", arXiv:2010.02502 [cs.LG] @@ -1352,8 +1352,8 @@ static void sample_k_diffusion(sample_method_t method, // factor c_in. } } break; - case TCD: // Strategic Stochastic Sampling (Algorithm 4) in - // Trajectory Consistency Distillation + case TCD_SAMPLE_METHOD: // Strategic Stochastic Sampling (Algorithm 4) in + // Trajectory Consistency Distillation { // See J. Zheng et al., "Trajectory Consistency // Distillation: Improved Latent Consistency Distillation diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index f12f145..427364a 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -1902,10 +1902,14 @@ int main(int argc, const char* argv[]) { return 1; } - if (params.sample_params.sample_method == SAMPLE_METHOD_DEFAULT) { + if (params.sample_params.sample_method == SAMPLE_METHOD_COUNT) { params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx); } + if (params.high_noise_sample_params.sample_method == SAMPLE_METHOD_COUNT) { + params.high_noise_sample_params.sample_method = sd_get_default_sample_method(sd_ctx); + } + if (params.sample_params.scheduler == SCHEDULER_COUNT) { params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx); } diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 063661d..b129d53 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -47,8 +47,8 @@ const char* model_version_to_str[] = { }; const char* sampling_methods_str[] = { - "default", "Euler", + "Euler A", "Heun", "DPM2", "DPM++ (2s)", @@ -59,7 +59,6 @@ const char* sampling_methods_str[] = { "LCM", "DDIM \"trailing\"", "TCD", - "Euler A", }; /*================================================== Helper Functions ================================================*/ @@ -2228,8 +2227,8 @@ enum rng_type_t str_to_rng_type(const char* str) { } const char* sample_method_to_str[] = { - "default", "euler", + "euler_a", "heun", "dpm2", "dpm++2s_a", @@ -2240,7 +2239,6 @@ const char* sample_method_to_str[] = { "lcm", "ddim_trailing", "tcd", - "euler_a", }; const char* sd_sample_method_name(enum sample_method_t sample_method) { @@ -2469,7 +2467,7 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) { sample_params->guidance.slg.layer_end = 0.2f; sample_params->guidance.slg.scale = 0.f; sample_params->scheduler = SCHEDULER_COUNT; - sample_params->sample_method = SAMPLE_METHOD_DEFAULT; + sample_params->sample_method = SAMPLE_METHOD_COUNT; sample_params->sample_steps = 20; } @@ -2627,19 +2625,19 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) { enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx) { if (sd_ctx != nullptr && sd_ctx->sd != nullptr) { - SDVersion version = sd_ctx->sd->version; - if (sd_version_is_dit(version)) - return EULER; - else - return EULER_A; + if (sd_version_is_dit(sd_ctx->sd->version)) { + return EULER_SAMPLE_METHOD; + } } - return SAMPLE_METHOD_COUNT; + return EULER_A_SAMPLE_METHOD; } enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx) { - auto edm_v_denoiser = std::dynamic_pointer_cast(sd_ctx->sd->denoiser); - if (edm_v_denoiser) { - return EXPONENTIAL_SCHEDULER; + if (sd_ctx != nullptr && sd_ctx->sd != nullptr) { + auto edm_v_denoiser = std::dynamic_pointer_cast(sd_ctx->sd->denoiser); + if (edm_v_denoiser) { + return EXPONENTIAL_SCHEDULER; + } } return DISCRETE_SCHEDULER; } @@ -2827,7 +2825,6 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, int C = sd_ctx->sd->get_latent_channel(); int W = width / sd_ctx->sd->get_vae_scale_factor(); int H = height / sd_ctx->sd->get_vae_scale_factor(); - LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]); struct ggml_tensor* control_latent = nullptr; if (sd_version_is_control(sd_ctx->sd->version) && image_hint != nullptr) { @@ -3056,10 +3053,15 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g sd_ctx->sd->rng->manual_seed(seed); sd_ctx->sd->sampler_rng->manual_seed(seed); - int sample_steps = sd_img_gen_params->sample_params.sample_steps; - size_t t0 = ggml_time_ms(); + enum sample_method_t sample_method = sd_img_gen_params->sample_params.sample_method; + if (sample_method == SAMPLE_METHOD_COUNT) { + sample_method = sd_get_default_sample_method(sd_ctx); + } + LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]); + + int sample_steps = sd_img_gen_params->sample_params.sample_steps; std::vector sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps, sd_img_gen_params->sample_params.scheduler, sd_ctx->sd->version); ggml_tensor* init_latent = nullptr; @@ -3248,11 +3250,6 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); } - enum sample_method_t sample_method = sd_img_gen_params->sample_params.sample_method; - if (sample_method == SAMPLE_METHOD_DEFAULT) { - sample_method = sd_get_default_sample_method(sd_ctx); - } - sd_image_t* result_images = generate_image_internal(sd_ctx, work_ctx, init_latent, @@ -3302,6 +3299,12 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor(); + enum sample_method_t sample_method = sd_vid_gen_params->sample_params.sample_method; + if (sample_method == SAMPLE_METHOD_COUNT) { + sample_method = sd_get_default_sample_method(sd_ctx); + } + LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]); + int high_noise_sample_steps = 0; if (sd_ctx->sd->high_noise_diffusion_model) { high_noise_sample_steps = sd_vid_gen_params->high_noise_sample_params.sample_steps; @@ -3570,6 +3573,12 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s // High Noise Sample if (high_noise_sample_steps > 0) { LOG_DEBUG("sample(high noise) %dx%dx%d", W, H, T); + enum sample_method_t high_noise_sample_method = sd_vid_gen_params->high_noise_sample_params.sample_method; + if (high_noise_sample_method == SAMPLE_METHOD_COUNT) { + high_noise_sample_method = sd_get_default_sample_method(sd_ctx); + } + LOG_INFO("sampling(high noise) using %s method", sampling_methods_str[high_noise_sample_method]); + int64_t sampling_start = ggml_time_ms(); std::vector high_noise_sigmas = std::vector(sigmas.begin(), sigmas.begin() + high_noise_sample_steps + 1); @@ -3588,7 +3597,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s sd_vid_gen_params->high_noise_sample_params.guidance, sd_vid_gen_params->high_noise_sample_params.eta, sd_vid_gen_params->high_noise_sample_params.shifted_timestep, - sd_vid_gen_params->high_noise_sample_params.sample_method, + high_noise_sample_method, high_noise_sigmas, -1, {}, @@ -3625,7 +3634,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s sd_vid_gen_params->sample_params.guidance, sd_vid_gen_params->sample_params.eta, sd_vid_gen_params->sample_params.shifted_timestep, - sd_vid_gen_params->sample_params.sample_method, + sample_method, sigmas, -1, {}, diff --git a/stable-diffusion.h b/stable-diffusion.h index 83c53c7..309da9b 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -36,19 +36,18 @@ enum rng_type_t { }; enum sample_method_t { - SAMPLE_METHOD_DEFAULT, - EULER, - HEUN, - DPM2, - DPMPP2S_A, - DPMPP2M, - DPMPP2Mv2, - IPNDM, - IPNDM_V, - LCM, - DDIM_TRAILING, - TCD, - EULER_A, + EULER_SAMPLE_METHOD, + EULER_A_SAMPLE_METHOD, + HEUN_SAMPLE_METHOD, + DPM2_SAMPLE_METHOD, + DPMPP2S_A_SAMPLE_METHOD, + DPMPP2M_SAMPLE_METHOD, + DPMPP2Mv2_SAMPLE_METHOD, + IPNDM_SAMPLE_METHOD, + IPNDM_V_SAMPLE_METHOD, + LCM_SAMPLE_METHOD, + DDIM_TRAILING_SAMPLE_METHOD, + TCD_SAMPLE_METHOD, SAMPLE_METHOD_COUNT };