diff --git a/src/model/vae/tae.hpp b/src/model/vae/tae.hpp index 7c6e1d35..a78e5e96 100644 --- a/src/model/vae/tae.hpp +++ b/src/model/vae/tae.hpp @@ -548,7 +548,7 @@ public: } auto result = decoder->forward(ctx, z); if (sd_version_is_wan(version) || sd_version_is_ltxav(version)) { - // (W, H, C, T) -> (W, H, T, C) + // (W, H, T, C) -> (W, H, C, T) result = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, result, 0, 1, 3, 2)); } return result; @@ -556,8 +556,10 @@ public: ggml_tensor* encode(GGMLRunnerContext* ctx, ggml_tensor* x) { auto encoder = std::dynamic_pointer_cast(blocks["encoder"]); - // (W, H, T, C) -> (W, H, C, T) - x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); + if (sd_version_is_wan(version) || sd_version_is_ltxav(version)) { + // (W, H, T, C) -> (W, H, C, T) + x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); + } int64_t num_frames = x->ne[3]; if (num_frames % encoder->t_downscale) { // pad to multiple of encoder->t_downscale at the end @@ -567,7 +569,10 @@ public: } } x = encoder->forward(ctx, x); - x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); + if (sd_version_is_wan(version) || sd_version_is_ltxav(version)) { + // (W, H, C, T) -> (W, H, T, C) + x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); + } return x; } };