feat: enable VAE tiling for video generation (#1152)

* enable VAE tiling for video generation

* format code

* eliminate compilation warning

---------

Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
rmatif 2026-01-08 16:23:05 +01:00 committed by GitHub
parent 27b5f17401
commit 0e52afc651
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 16 additions and 6 deletions

View File

@ -757,6 +757,7 @@ int main(int argc, const char* argv[]) {
gen_params.seed, gen_params.seed,
gen_params.video_frames, gen_params.video_frames,
gen_params.vace_strength, gen_params.vace_strength,
ctx_params.vae_tiling_params,
gen_params.cache_params, gen_params.cache_params,
}; };

View File

@ -2489,10 +2489,15 @@ public:
ne2 = 1; ne2 = 1;
ne3 = C * x->ne[3]; ne3 = C * x->ne[3];
} else { } else {
if (!use_tiny_autoencoder) { int64_t out_channels = C;
C *= 2; bool encode_outputs_mu = use_tiny_autoencoder ||
sd_version_is_wan(version) ||
sd_version_is_flux2(version) ||
version == VERSION_CHROMA_RADIANCE;
if (!encode_outputs_mu) {
out_channels *= 2;
} }
ne2 = C; ne2 = out_channels;
ne3 = x->ne[3]; ne3 = x->ne[3];
} }
result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, ne2, ne3); result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, ne2, ne3);
@ -2633,7 +2638,7 @@ public:
} }
process_latent_out(x); process_latent_out(x);
// x = load_tensor_from_file(work_ctx, "wan_vae_z.bin"); // x = load_tensor_from_file(work_ctx, "wan_vae_z.bin");
if (vae_tiling_params.enabled && !decode_video) { if (vae_tiling_params.enabled) {
float tile_overlap; float tile_overlap;
int tile_size_x, tile_size_y; int tile_size_x, tile_size_y;
get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, vae_tiling_params, x->ne[0], x->ne[1]); get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, vae_tiling_params, x->ne[0], x->ne[1]);
@ -2651,7 +2656,7 @@ public:
first_stage_model->free_compute_buffer(); first_stage_model->free_compute_buffer();
process_vae_output_tensor(result); process_vae_output_tensor(result);
} else { } else {
if (vae_tiling_params.enabled && !decode_video) { if (vae_tiling_params.enabled) {
// split latent in 64x64 tiles and compute in several steps // split latent in 64x64 tiles and compute in several steps
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
tae_first_stage->compute(n_threads, in, true, &out); tae_first_stage->compute(n_threads, in, true, &out);
@ -3046,7 +3051,8 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
"height: %d\n" "height: %d\n"
"sample_params: %s\n" "sample_params: %s\n"
"strength: %.2f\n" "strength: %.2f\n"
"seed: %" PRId64 "\n" "seed: %" PRId64
"\n"
"batch_count: %d\n" "batch_count: %d\n"
"ref_images_count: %d\n" "ref_images_count: %d\n"
"auto_resize_ref_image: %s\n" "auto_resize_ref_image: %s\n"
@ -3099,6 +3105,7 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
sd_vid_gen_params->video_frames = 6; sd_vid_gen_params->video_frames = 6;
sd_vid_gen_params->moe_boundary = 0.875f; sd_vid_gen_params->moe_boundary = 0.875f;
sd_vid_gen_params->vace_strength = 1.f; sd_vid_gen_params->vace_strength = 1.f;
sd_vid_gen_params->vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
sd_cache_params_init(&sd_vid_gen_params->cache); sd_cache_params_init(&sd_vid_gen_params->cache);
} }
@ -3728,6 +3735,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
if (sd_ctx == nullptr || sd_vid_gen_params == nullptr) { if (sd_ctx == nullptr || sd_vid_gen_params == nullptr) {
return nullptr; return nullptr;
} }
sd_ctx->sd->vae_tiling_params = sd_vid_gen_params->vae_tiling_params;
std::string prompt = SAFE_STR(sd_vid_gen_params->prompt); std::string prompt = SAFE_STR(sd_vid_gen_params->prompt);
std::string negative_prompt = SAFE_STR(sd_vid_gen_params->negative_prompt); std::string negative_prompt = SAFE_STR(sd_vid_gen_params->negative_prompt);

View File

@ -319,6 +319,7 @@ typedef struct {
int64_t seed; int64_t seed;
int video_frames; int video_frames;
float vace_strength; float vace_strength;
sd_tiling_params_t vae_tiling_params;
sd_cache_params_t cache; sd_cache_params_t cache;
} sd_vid_gen_params_t; } sd_vid_gen_params_t;