mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-12 13:28:37 +00:00
fix: fix vae tiling for flux2 (#1025)
This commit is contained in:
parent
34a6fd4e60
commit
0743a1b3b5
@ -2096,8 +2096,9 @@ public:
|
||||
ggml_tensor* vae_encode(ggml_context* work_ctx, ggml_tensor* x, bool encode_video = false) {
|
||||
int64_t t0 = ggml_time_ms();
|
||||
ggml_tensor* result = nullptr;
|
||||
int W = x->ne[0] / get_vae_scale_factor();
|
||||
int H = x->ne[1] / get_vae_scale_factor();
|
||||
const int vae_scale_factor = get_vae_scale_factor();
|
||||
int W = x->ne[0] / vae_scale_factor;
|
||||
int H = x->ne[1] / vae_scale_factor;
|
||||
int C = get_latent_channel();
|
||||
if (vae_tiling_params.enabled && !encode_video) {
|
||||
// TODO wan2.2 vae support?
|
||||
@ -2133,7 +2134,7 @@ public:
|
||||
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
|
||||
first_stage_model->compute(n_threads, in, false, &out, work_ctx);
|
||||
};
|
||||
sd_tiling_non_square(x, result, 8, tile_size_x, tile_size_y, tile_overlap, on_tiling);
|
||||
sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, on_tiling);
|
||||
} else {
|
||||
first_stage_model->compute(n_threads, x, false, &result, work_ctx);
|
||||
}
|
||||
@ -2144,7 +2145,7 @@ public:
|
||||
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
|
||||
tae_first_stage->compute(n_threads, in, false, &out, nullptr);
|
||||
};
|
||||
sd_tiling(x, result, 8, 64, 0.5f, on_tiling);
|
||||
sd_tiling(x, result, vae_scale_factor, 64, 0.5f, on_tiling);
|
||||
} else {
|
||||
tae_first_stage->compute(n_threads, x, false, &result, work_ctx);
|
||||
}
|
||||
@ -2220,8 +2221,9 @@ public:
|
||||
}
|
||||
|
||||
ggml_tensor* decode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) {
|
||||
int64_t W = x->ne[0] * get_vae_scale_factor();
|
||||
int64_t H = x->ne[1] * get_vae_scale_factor();
|
||||
const int vae_scale_factor = get_vae_scale_factor();
|
||||
int64_t W = x->ne[0] * vae_scale_factor;
|
||||
int64_t H = x->ne[1] * vae_scale_factor;
|
||||
int64_t C = 3;
|
||||
ggml_tensor* result = nullptr;
|
||||
if (decode_video) {
|
||||
@ -2261,7 +2263,7 @@ public:
|
||||
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
|
||||
first_stage_model->compute(n_threads, in, true, &out, nullptr);
|
||||
};
|
||||
sd_tiling_non_square(x, result, 8, tile_size_x, tile_size_y, tile_overlap, on_tiling);
|
||||
sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, on_tiling);
|
||||
} else {
|
||||
first_stage_model->compute(n_threads, x, true, &result, work_ctx);
|
||||
}
|
||||
@ -2273,7 +2275,7 @@ public:
|
||||
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
|
||||
tae_first_stage->compute(n_threads, in, true, &out);
|
||||
};
|
||||
sd_tiling(x, result, 8, 64, 0.5f, on_tiling);
|
||||
sd_tiling(x, result, vae_scale_factor, 64, 0.5f, on_tiling);
|
||||
} else {
|
||||
tae_first_stage->compute(n_threads, x, true, &result);
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user