mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-12 13:28:37 +00:00
refactor: optimize the handling of sample method (#999)
This commit is contained in:
parent
490c51d963
commit
20345888a3
28
denoiser.hpp
28
denoiser.hpp
@ -640,7 +640,7 @@ static void sample_k_diffusion(sample_method_t method,
|
|||||||
size_t steps = sigmas.size() - 1;
|
size_t steps = sigmas.size() - 1;
|
||||||
// sample_euler_ancestral
|
// sample_euler_ancestral
|
||||||
switch (method) {
|
switch (method) {
|
||||||
case EULER_A: {
|
case EULER_A_SAMPLE_METHOD: {
|
||||||
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x);
|
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x);
|
||||||
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
|
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
|
||||||
|
|
||||||
@ -693,7 +693,7 @@ static void sample_k_diffusion(sample_method_t method,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case EULER: // Implemented without any sigma churn
|
case EULER_SAMPLE_METHOD: // Implemented without any sigma churn
|
||||||
{
|
{
|
||||||
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
|
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
|
||||||
|
|
||||||
@ -726,7 +726,7 @@ static void sample_k_diffusion(sample_method_t method,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case HEUN: {
|
case HEUN_SAMPLE_METHOD: {
|
||||||
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
|
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
|
||||||
struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x);
|
struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x);
|
||||||
|
|
||||||
@ -776,7 +776,7 @@ static void sample_k_diffusion(sample_method_t method,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case DPM2: {
|
case DPM2_SAMPLE_METHOD: {
|
||||||
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
|
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
|
||||||
struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x);
|
struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x);
|
||||||
|
|
||||||
@ -828,7 +828,7 @@ static void sample_k_diffusion(sample_method_t method,
|
|||||||
}
|
}
|
||||||
|
|
||||||
} break;
|
} break;
|
||||||
case DPMPP2S_A: {
|
case DPMPP2S_A_SAMPLE_METHOD: {
|
||||||
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x);
|
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x);
|
||||||
struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x);
|
struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x);
|
||||||
|
|
||||||
@ -892,7 +892,7 @@ static void sample_k_diffusion(sample_method_t method,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case DPMPP2M: // DPM++ (2M) from Karras et al (2022)
|
case DPMPP2M_SAMPLE_METHOD: // DPM++ (2M) from Karras et al (2022)
|
||||||
{
|
{
|
||||||
struct ggml_tensor* old_denoised = ggml_dup_tensor(work_ctx, x);
|
struct ggml_tensor* old_denoised = ggml_dup_tensor(work_ctx, x);
|
||||||
|
|
||||||
@ -931,7 +931,7 @@ static void sample_k_diffusion(sample_method_t method,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case DPMPP2Mv2: // Modified DPM++ (2M) from https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/8457
|
case DPMPP2Mv2_SAMPLE_METHOD: // Modified DPM++ (2M) from https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/8457
|
||||||
{
|
{
|
||||||
struct ggml_tensor* old_denoised = ggml_dup_tensor(work_ctx, x);
|
struct ggml_tensor* old_denoised = ggml_dup_tensor(work_ctx, x);
|
||||||
|
|
||||||
@ -974,7 +974,7 @@ static void sample_k_diffusion(sample_method_t method,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case IPNDM: // iPNDM sampler from https://github.com/zju-pi/diff-sampler/tree/main/diff-solvers-main
|
case IPNDM_SAMPLE_METHOD: // iPNDM sampler from https://github.com/zju-pi/diff-sampler/tree/main/diff-solvers-main
|
||||||
{
|
{
|
||||||
int max_order = 4;
|
int max_order = 4;
|
||||||
ggml_tensor* x_next = x;
|
ggml_tensor* x_next = x;
|
||||||
@ -1049,7 +1049,7 @@ static void sample_k_diffusion(sample_method_t method,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case IPNDM_V: // iPNDM_v sampler from https://github.com/zju-pi/diff-sampler/tree/main/diff-solvers-main
|
case IPNDM_V_SAMPLE_METHOD: // iPNDM_v sampler from https://github.com/zju-pi/diff-sampler/tree/main/diff-solvers-main
|
||||||
{
|
{
|
||||||
int max_order = 4;
|
int max_order = 4;
|
||||||
std::vector<ggml_tensor*> buffer_model;
|
std::vector<ggml_tensor*> buffer_model;
|
||||||
@ -1123,7 +1123,7 @@ static void sample_k_diffusion(sample_method_t method,
|
|||||||
d_cur = ggml_dup_tensor(work_ctx, x_next);
|
d_cur = ggml_dup_tensor(work_ctx, x_next);
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case LCM: // Latent Consistency Models
|
case LCM_SAMPLE_METHOD: // Latent Consistency Models
|
||||||
{
|
{
|
||||||
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x);
|
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x);
|
||||||
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
|
struct ggml_tensor* d = ggml_dup_tensor(work_ctx, x);
|
||||||
@ -1158,8 +1158,8 @@ static void sample_k_diffusion(sample_method_t method,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case DDIM_TRAILING: // Denoising Diffusion Implicit Models
|
case DDIM_TRAILING_SAMPLE_METHOD: // Denoising Diffusion Implicit Models
|
||||||
// with the "trailing" timestep spacing
|
// with the "trailing" timestep spacing
|
||||||
{
|
{
|
||||||
// See J. Song et al., "Denoising Diffusion Implicit
|
// See J. Song et al., "Denoising Diffusion Implicit
|
||||||
// Models", arXiv:2010.02502 [cs.LG]
|
// Models", arXiv:2010.02502 [cs.LG]
|
||||||
@ -1352,8 +1352,8 @@ static void sample_k_diffusion(sample_method_t method,
|
|||||||
// factor c_in.
|
// factor c_in.
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case TCD: // Strategic Stochastic Sampling (Algorithm 4) in
|
case TCD_SAMPLE_METHOD: // Strategic Stochastic Sampling (Algorithm 4) in
|
||||||
// Trajectory Consistency Distillation
|
// Trajectory Consistency Distillation
|
||||||
{
|
{
|
||||||
// See J. Zheng et al., "Trajectory Consistency
|
// See J. Zheng et al., "Trajectory Consistency
|
||||||
// Distillation: Improved Latent Consistency Distillation
|
// Distillation: Improved Latent Consistency Distillation
|
||||||
|
|||||||
@ -1902,10 +1902,14 @@ int main(int argc, const char* argv[]) {
|
|||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params.sample_params.sample_method == SAMPLE_METHOD_DEFAULT) {
|
if (params.sample_params.sample_method == SAMPLE_METHOD_COUNT) {
|
||||||
params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
|
params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (params.high_noise_sample_params.sample_method == SAMPLE_METHOD_COUNT) {
|
||||||
|
params.high_noise_sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
|
||||||
|
}
|
||||||
|
|
||||||
if (params.sample_params.scheduler == SCHEDULER_COUNT) {
|
if (params.sample_params.scheduler == SCHEDULER_COUNT) {
|
||||||
params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx);
|
params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -47,8 +47,8 @@ const char* model_version_to_str[] = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const char* sampling_methods_str[] = {
|
const char* sampling_methods_str[] = {
|
||||||
"default",
|
|
||||||
"Euler",
|
"Euler",
|
||||||
|
"Euler A",
|
||||||
"Heun",
|
"Heun",
|
||||||
"DPM2",
|
"DPM2",
|
||||||
"DPM++ (2s)",
|
"DPM++ (2s)",
|
||||||
@ -59,7 +59,6 @@ const char* sampling_methods_str[] = {
|
|||||||
"LCM",
|
"LCM",
|
||||||
"DDIM \"trailing\"",
|
"DDIM \"trailing\"",
|
||||||
"TCD",
|
"TCD",
|
||||||
"Euler A",
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/*================================================== Helper Functions ================================================*/
|
/*================================================== Helper Functions ================================================*/
|
||||||
@ -2228,8 +2227,8 @@ enum rng_type_t str_to_rng_type(const char* str) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const char* sample_method_to_str[] = {
|
const char* sample_method_to_str[] = {
|
||||||
"default",
|
|
||||||
"euler",
|
"euler",
|
||||||
|
"euler_a",
|
||||||
"heun",
|
"heun",
|
||||||
"dpm2",
|
"dpm2",
|
||||||
"dpm++2s_a",
|
"dpm++2s_a",
|
||||||
@ -2240,7 +2239,6 @@ const char* sample_method_to_str[] = {
|
|||||||
"lcm",
|
"lcm",
|
||||||
"ddim_trailing",
|
"ddim_trailing",
|
||||||
"tcd",
|
"tcd",
|
||||||
"euler_a",
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const char* sd_sample_method_name(enum sample_method_t sample_method) {
|
const char* sd_sample_method_name(enum sample_method_t sample_method) {
|
||||||
@ -2469,7 +2467,7 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) {
|
|||||||
sample_params->guidance.slg.layer_end = 0.2f;
|
sample_params->guidance.slg.layer_end = 0.2f;
|
||||||
sample_params->guidance.slg.scale = 0.f;
|
sample_params->guidance.slg.scale = 0.f;
|
||||||
sample_params->scheduler = SCHEDULER_COUNT;
|
sample_params->scheduler = SCHEDULER_COUNT;
|
||||||
sample_params->sample_method = SAMPLE_METHOD_DEFAULT;
|
sample_params->sample_method = SAMPLE_METHOD_COUNT;
|
||||||
sample_params->sample_steps = 20;
|
sample_params->sample_steps = 20;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2627,19 +2625,19 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) {
|
|||||||
|
|
||||||
enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx) {
|
enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx) {
|
||||||
if (sd_ctx != nullptr && sd_ctx->sd != nullptr) {
|
if (sd_ctx != nullptr && sd_ctx->sd != nullptr) {
|
||||||
SDVersion version = sd_ctx->sd->version;
|
if (sd_version_is_dit(sd_ctx->sd->version)) {
|
||||||
if (sd_version_is_dit(version))
|
return EULER_SAMPLE_METHOD;
|
||||||
return EULER;
|
}
|
||||||
else
|
|
||||||
return EULER_A;
|
|
||||||
}
|
}
|
||||||
return SAMPLE_METHOD_COUNT;
|
return EULER_A_SAMPLE_METHOD;
|
||||||
}
|
}
|
||||||
|
|
||||||
enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx) {
|
enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx) {
|
||||||
auto edm_v_denoiser = std::dynamic_pointer_cast<EDMVDenoiser>(sd_ctx->sd->denoiser);
|
if (sd_ctx != nullptr && sd_ctx->sd != nullptr) {
|
||||||
if (edm_v_denoiser) {
|
auto edm_v_denoiser = std::dynamic_pointer_cast<EDMVDenoiser>(sd_ctx->sd->denoiser);
|
||||||
return EXPONENTIAL_SCHEDULER;
|
if (edm_v_denoiser) {
|
||||||
|
return EXPONENTIAL_SCHEDULER;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return DISCRETE_SCHEDULER;
|
return DISCRETE_SCHEDULER;
|
||||||
}
|
}
|
||||||
@ -2827,7 +2825,6 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
|
|||||||
int C = sd_ctx->sd->get_latent_channel();
|
int C = sd_ctx->sd->get_latent_channel();
|
||||||
int W = width / sd_ctx->sd->get_vae_scale_factor();
|
int W = width / sd_ctx->sd->get_vae_scale_factor();
|
||||||
int H = height / sd_ctx->sd->get_vae_scale_factor();
|
int H = height / sd_ctx->sd->get_vae_scale_factor();
|
||||||
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
|
|
||||||
|
|
||||||
struct ggml_tensor* control_latent = nullptr;
|
struct ggml_tensor* control_latent = nullptr;
|
||||||
if (sd_version_is_control(sd_ctx->sd->version) && image_hint != nullptr) {
|
if (sd_version_is_control(sd_ctx->sd->version) && image_hint != nullptr) {
|
||||||
@ -3056,10 +3053,15 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
|
|||||||
sd_ctx->sd->rng->manual_seed(seed);
|
sd_ctx->sd->rng->manual_seed(seed);
|
||||||
sd_ctx->sd->sampler_rng->manual_seed(seed);
|
sd_ctx->sd->sampler_rng->manual_seed(seed);
|
||||||
|
|
||||||
int sample_steps = sd_img_gen_params->sample_params.sample_steps;
|
|
||||||
|
|
||||||
size_t t0 = ggml_time_ms();
|
size_t t0 = ggml_time_ms();
|
||||||
|
|
||||||
|
enum sample_method_t sample_method = sd_img_gen_params->sample_params.sample_method;
|
||||||
|
if (sample_method == SAMPLE_METHOD_COUNT) {
|
||||||
|
sample_method = sd_get_default_sample_method(sd_ctx);
|
||||||
|
}
|
||||||
|
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
|
||||||
|
|
||||||
|
int sample_steps = sd_img_gen_params->sample_params.sample_steps;
|
||||||
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps, sd_img_gen_params->sample_params.scheduler, sd_ctx->sd->version);
|
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps, sd_img_gen_params->sample_params.scheduler, sd_ctx->sd->version);
|
||||||
|
|
||||||
ggml_tensor* init_latent = nullptr;
|
ggml_tensor* init_latent = nullptr;
|
||||||
@ -3248,11 +3250,6 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
|
|||||||
LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
|
LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
|
||||||
}
|
}
|
||||||
|
|
||||||
enum sample_method_t sample_method = sd_img_gen_params->sample_params.sample_method;
|
|
||||||
if (sample_method == SAMPLE_METHOD_DEFAULT) {
|
|
||||||
sample_method = sd_get_default_sample_method(sd_ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
sd_image_t* result_images = generate_image_internal(sd_ctx,
|
sd_image_t* result_images = generate_image_internal(sd_ctx,
|
||||||
work_ctx,
|
work_ctx,
|
||||||
init_latent,
|
init_latent,
|
||||||
@ -3302,6 +3299,12 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
|
|
||||||
int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor();
|
int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor();
|
||||||
|
|
||||||
|
enum sample_method_t sample_method = sd_vid_gen_params->sample_params.sample_method;
|
||||||
|
if (sample_method == SAMPLE_METHOD_COUNT) {
|
||||||
|
sample_method = sd_get_default_sample_method(sd_ctx);
|
||||||
|
}
|
||||||
|
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
|
||||||
|
|
||||||
int high_noise_sample_steps = 0;
|
int high_noise_sample_steps = 0;
|
||||||
if (sd_ctx->sd->high_noise_diffusion_model) {
|
if (sd_ctx->sd->high_noise_diffusion_model) {
|
||||||
high_noise_sample_steps = sd_vid_gen_params->high_noise_sample_params.sample_steps;
|
high_noise_sample_steps = sd_vid_gen_params->high_noise_sample_params.sample_steps;
|
||||||
@ -3570,6 +3573,12 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
// High Noise Sample
|
// High Noise Sample
|
||||||
if (high_noise_sample_steps > 0) {
|
if (high_noise_sample_steps > 0) {
|
||||||
LOG_DEBUG("sample(high noise) %dx%dx%d", W, H, T);
|
LOG_DEBUG("sample(high noise) %dx%dx%d", W, H, T);
|
||||||
|
enum sample_method_t high_noise_sample_method = sd_vid_gen_params->high_noise_sample_params.sample_method;
|
||||||
|
if (high_noise_sample_method == SAMPLE_METHOD_COUNT) {
|
||||||
|
high_noise_sample_method = sd_get_default_sample_method(sd_ctx);
|
||||||
|
}
|
||||||
|
LOG_INFO("sampling(high noise) using %s method", sampling_methods_str[high_noise_sample_method]);
|
||||||
|
|
||||||
int64_t sampling_start = ggml_time_ms();
|
int64_t sampling_start = ggml_time_ms();
|
||||||
|
|
||||||
std::vector<float> high_noise_sigmas = std::vector<float>(sigmas.begin(), sigmas.begin() + high_noise_sample_steps + 1);
|
std::vector<float> high_noise_sigmas = std::vector<float>(sigmas.begin(), sigmas.begin() + high_noise_sample_steps + 1);
|
||||||
@ -3588,7 +3597,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
sd_vid_gen_params->high_noise_sample_params.guidance,
|
sd_vid_gen_params->high_noise_sample_params.guidance,
|
||||||
sd_vid_gen_params->high_noise_sample_params.eta,
|
sd_vid_gen_params->high_noise_sample_params.eta,
|
||||||
sd_vid_gen_params->high_noise_sample_params.shifted_timestep,
|
sd_vid_gen_params->high_noise_sample_params.shifted_timestep,
|
||||||
sd_vid_gen_params->high_noise_sample_params.sample_method,
|
high_noise_sample_method,
|
||||||
high_noise_sigmas,
|
high_noise_sigmas,
|
||||||
-1,
|
-1,
|
||||||
{},
|
{},
|
||||||
@ -3625,7 +3634,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
sd_vid_gen_params->sample_params.guidance,
|
sd_vid_gen_params->sample_params.guidance,
|
||||||
sd_vid_gen_params->sample_params.eta,
|
sd_vid_gen_params->sample_params.eta,
|
||||||
sd_vid_gen_params->sample_params.shifted_timestep,
|
sd_vid_gen_params->sample_params.shifted_timestep,
|
||||||
sd_vid_gen_params->sample_params.sample_method,
|
sample_method,
|
||||||
sigmas,
|
sigmas,
|
||||||
-1,
|
-1,
|
||||||
{},
|
{},
|
||||||
|
|||||||
@ -36,19 +36,18 @@ enum rng_type_t {
|
|||||||
};
|
};
|
||||||
|
|
||||||
enum sample_method_t {
|
enum sample_method_t {
|
||||||
SAMPLE_METHOD_DEFAULT,
|
EULER_SAMPLE_METHOD,
|
||||||
EULER,
|
EULER_A_SAMPLE_METHOD,
|
||||||
HEUN,
|
HEUN_SAMPLE_METHOD,
|
||||||
DPM2,
|
DPM2_SAMPLE_METHOD,
|
||||||
DPMPP2S_A,
|
DPMPP2S_A_SAMPLE_METHOD,
|
||||||
DPMPP2M,
|
DPMPP2M_SAMPLE_METHOD,
|
||||||
DPMPP2Mv2,
|
DPMPP2Mv2_SAMPLE_METHOD,
|
||||||
IPNDM,
|
IPNDM_SAMPLE_METHOD,
|
||||||
IPNDM_V,
|
IPNDM_V_SAMPLE_METHOD,
|
||||||
LCM,
|
LCM_SAMPLE_METHOD,
|
||||||
DDIM_TRAILING,
|
DDIM_TRAILING_SAMPLE_METHOD,
|
||||||
TCD,
|
TCD_SAMPLE_METHOD,
|
||||||
EULER_A,
|
|
||||||
SAMPLE_METHOD_COUNT
|
SAMPLE_METHOD_COUNT
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user