From 8823dc48bcc1598eb9671da7b69e45338d0cc5a5 Mon Sep 17 00:00:00 2001 From: leejet Date: Wed, 10 Dec 2025 23:15:08 +0800 Subject: [PATCH] feat: align the spatial size to the corresponding multiple (#1073) --- ggml_extend.hpp | 8 +++++++ stable-diffusion.cpp | 53 ++++++++++++++++++++++++++++++-------------- 2 files changed, 44 insertions(+), 17 deletions(-) diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 2b4ce5d..5024eb9 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -60,6 +60,14 @@ #define SD_UNUSED(x) (void)(x) #endif +__STATIC_INLINE__ int align_up_offset(int n, int multiple) { + return (multiple - n % multiple) % multiple; +} + +__STATIC_INLINE__ int align_up(int n, int multiple) { + return n + align_up_offset(n, multiple); +} + __STATIC_INLINE__ void ggml_log_callback_default(ggml_log_level level, const char* text, void*) { switch (level) { case GGML_LOG_LEVEL_DEBUG: diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index d381bf6..1ef8512 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -1898,6 +1898,18 @@ public: return vae_scale_factor; } + int get_diffusion_model_down_factor() { + int down_factor = 8; // unet + if (sd_version_is_dit(version)) { + if (sd_version_is_wan(version)) { + down_factor = 2; + } else { + down_factor = 1; + } + } + return down_factor; + } + int get_latent_channel() { int latent_channel = 4; if (sd_version_is_dit(version)) { @@ -3133,22 +3145,19 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g sd_ctx->sd->vae_tiling_params = sd_img_gen_params->vae_tiling_params; int width = sd_img_gen_params->width; int height = sd_img_gen_params->height; - int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor(); - if (sd_version_is_dit(sd_ctx->sd->version)) { - if (width % 16 || height % 16) { - LOG_ERROR("Image dimensions must be must be a multiple of 16 on each axis for %s models. (Got %dx%d)", - model_version_to_str[sd_ctx->sd->version], - width, - height); - return nullptr; - } - } else if (width % 64 || height % 64) { - LOG_ERROR("Image dimensions must be must be a multiple of 64 on each axis for %s models. (Got %dx%d)", - model_version_to_str[sd_ctx->sd->version], - width, - height); - return nullptr; + + int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor(); + int diffusion_model_down_factor = sd_ctx->sd->get_diffusion_model_down_factor(); + int spatial_multiple = vae_scale_factor * diffusion_model_down_factor; + + int width_offset = align_up_offset(width, spatial_multiple); + int height_offset = align_up_offset(height, spatial_multiple); + if (width_offset > 0 || height_offset > 0) { + width += width_offset; + height += height_offset; + LOG_WARN("align up %dx%d to %dx%d (multiple=%d)", sd_img_gen_params->width, sd_img_gen_params->height, width, height, spatial_multiple); } + LOG_DEBUG("generate_image %dx%d", width, height); if (sd_ctx == nullptr || sd_img_gen_params == nullptr) { return nullptr; @@ -3422,9 +3431,19 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s int frames = sd_vid_gen_params->video_frames; frames = (frames - 1) / 4 * 4 + 1; int sample_steps = sd_vid_gen_params->sample_params.sample_steps; - LOG_INFO("generate_video %dx%dx%d", width, height, frames); - int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor(); + int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor(); + int diffusion_model_down_factor = sd_ctx->sd->get_diffusion_model_down_factor(); + int spatial_multiple = vae_scale_factor * diffusion_model_down_factor; + + int width_offset = align_up_offset(width, spatial_multiple); + int height_offset = align_up_offset(height, spatial_multiple); + if (width_offset > 0 || height_offset > 0) { + width += width_offset; + height += height_offset; + LOG_WARN("align up %dx%d to %dx%d (multiple=%d)", sd_vid_gen_params->width, sd_vid_gen_params->height, width, height, spatial_multiple); + } + LOG_INFO("generate_video %dx%dx%d", width, height, frames); enum sample_method_t sample_method = sd_vid_gen_params->sample_params.sample_method; if (sample_method == SAMPLE_METHOD_COUNT) {