fix: fix vae tiling for flux2 (#1025)

This commit is contained in:
rmatif 2025-12-01 15:41:56 +01:00 committed by GitHub
parent 34a6fd4e60
commit 0743a1b3b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2096,8 +2096,9 @@ public:
ggml_tensor* vae_encode(ggml_context* work_ctx, ggml_tensor* x, bool encode_video = false) {
int64_t t0 = ggml_time_ms();
ggml_tensor* result = nullptr;
int W = x->ne[0] / get_vae_scale_factor();
int H = x->ne[1] / get_vae_scale_factor();
const int vae_scale_factor = get_vae_scale_factor();
int W = x->ne[0] / vae_scale_factor;
int H = x->ne[1] / vae_scale_factor;
int C = get_latent_channel();
if (vae_tiling_params.enabled && !encode_video) {
// TODO wan2.2 vae support?
@ -2133,7 +2134,7 @@ public:
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
first_stage_model->compute(n_threads, in, false, &out, work_ctx);
};
sd_tiling_non_square(x, result, 8, tile_size_x, tile_size_y, tile_overlap, on_tiling);
sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, on_tiling);
} else {
first_stage_model->compute(n_threads, x, false, &result, work_ctx);
}
@ -2144,7 +2145,7 @@ public:
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
tae_first_stage->compute(n_threads, in, false, &out, nullptr);
};
sd_tiling(x, result, 8, 64, 0.5f, on_tiling);
sd_tiling(x, result, vae_scale_factor, 64, 0.5f, on_tiling);
} else {
tae_first_stage->compute(n_threads, x, false, &result, work_ctx);
}
@ -2220,8 +2221,9 @@ public:
}
ggml_tensor* decode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) {
int64_t W = x->ne[0] * get_vae_scale_factor();
int64_t H = x->ne[1] * get_vae_scale_factor();
const int vae_scale_factor = get_vae_scale_factor();
int64_t W = x->ne[0] * vae_scale_factor;
int64_t H = x->ne[1] * vae_scale_factor;
int64_t C = 3;
ggml_tensor* result = nullptr;
if (decode_video) {
@ -2261,7 +2263,7 @@ public:
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
first_stage_model->compute(n_threads, in, true, &out, nullptr);
};
sd_tiling_non_square(x, result, 8, tile_size_x, tile_size_y, tile_overlap, on_tiling);
sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, on_tiling);
} else {
first_stage_model->compute(n_threads, x, true, &result, work_ctx);
}
@ -2273,7 +2275,7 @@ public:
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
tae_first_stage->compute(n_threads, in, true, &out);
};
sd_tiling(x, result, 8, 64, 0.5f, on_tiling);
sd_tiling(x, result, vae_scale_factor, 64, 0.5f, on_tiling);
} else {
tae_first_stage->compute(n_threads, x, true, &result);
}