refector: reuse some code

This commit is contained in:
leejet 2025-07-01 23:33:50 +08:00
parent 9251756086
commit 7dac89ad75
4 changed files with 32 additions and 38 deletions

View File

@ -347,12 +347,13 @@ struct EDMVDenoiser : public CompVisVDenoiser {
float min_sigma = 0.002; float min_sigma = 0.002;
float max_sigma = 120.0; float max_sigma = 120.0;
EDMVDenoiser(float min_sigma = 0.002, float max_sigma = 120.0) : min_sigma(min_sigma), max_sigma(max_sigma) { EDMVDenoiser(float min_sigma = 0.002, float max_sigma = 120.0)
: min_sigma(min_sigma), max_sigma(max_sigma) {
schedule = std::make_shared<ExponentialSchedule>(); schedule = std::make_shared<ExponentialSchedule>();
} }
float t_to_sigma(float t) { float t_to_sigma(float t) {
return std::exp(t * 4/(float)TIMESTEPS); return std::exp(t * 4 / (float)TIMESTEPS);
} }
float sigma_to_t(float s) { float sigma_to_t(float s) {

View File

@ -1566,6 +1566,29 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
return result_images; return result_images;
} }
ggml_tensor* generate_init_latent(sd_ctx_t* sd_ctx,
ggml_context* work_ctx,
int width,
int height) {
int C = 4;
if (sd_version_is_sd3(sd_ctx->sd->version)) {
C = 16;
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
C = 16;
}
int W = width / 8;
int H = height / 8;
ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
if (sd_version_is_sd3(sd_ctx->sd->version)) {
ggml_set_f32(init_latent, 0.0609f);
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
ggml_set_f32(init_latent, 0.1159f);
} else {
ggml_set_f32(init_latent, 0.f);
}
return init_latent;
}
sd_image_t* txt2img(sd_ctx_t* sd_ctx, sd_image_t* txt2img(sd_ctx_t* sd_ctx,
const char* prompt_c_str, const char* prompt_c_str,
const char* negative_prompt_c_str, const char* negative_prompt_c_str,
@ -1622,27 +1645,12 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps); std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
int C = 4;
if (sd_version_is_sd3(sd_ctx->sd->version)) {
C = 16;
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
C = 16;
}
int W = width / 8;
int H = height / 8;
ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
if (sd_version_is_sd3(sd_ctx->sd->version)) {
ggml_set_f32(init_latent, 0.0609f);
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
ggml_set_f32(init_latent, 0.1159f);
} else {
ggml_set_f32(init_latent, 0.f);
}
if (sd_version_is_inpaint(sd_ctx->sd->version)) { if (sd_version_is_inpaint(sd_ctx->sd->version)) {
LOG_WARN("This is an inpainting model, this should only be used in img2img mode with a mask"); LOG_WARN("This is an inpainting model, this should only be used in img2img mode with a mask");
} }
ggml_tensor* init_latent = generate_init_latent(sd_ctx, work_ctx, width, height);
sd_image_t* result_images = generate_image(sd_ctx, sd_image_t* result_images = generate_image(sd_ctx,
work_ctx, work_ctx,
init_latent, init_latent,
@ -2046,23 +2054,6 @@ sd_image_t* edit(sd_ctx_t* sd_ctx,
} }
sd_ctx->sd->rng->manual_seed(seed); sd_ctx->sd->rng->manual_seed(seed);
int C = 4;
if (sd_version_is_sd3(sd_ctx->sd->version)) {
C = 16;
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
C = 16;
}
int W = width / 8;
int H = height / 8;
ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
if (sd_version_is_sd3(sd_ctx->sd->version)) {
ggml_set_f32(init_latent, 0.0609f);
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
ggml_set_f32(init_latent, 0.1159f);
} else {
ggml_set_f32(init_latent, 0.f);
}
size_t t0 = ggml_time_ms(); size_t t0 = ggml_time_ms();
std::vector<struct ggml_tensor*> ref_latents; std::vector<struct ggml_tensor*> ref_latents;
@ -2085,6 +2076,8 @@ sd_image_t* edit(sd_ctx_t* sd_ctx,
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps); std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
ggml_tensor* init_latent = generate_init_latent(sd_ctx, work_ctx, width, height);
sd_image_t* result_images = generate_image(sd_ctx, sd_image_t* result_images = generate_image(sd_ctx,
work_ctx, work_ctx,
init_latent, init_latent,