fix: correct TAEHV encoding for image models (#1711)

This commit is contained in:
Wagner Bruna 2026-06-26 14:23:18 -03:00 committed by GitHub
parent 3973015ed7
commit ec4cb8104b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -548,7 +548,7 @@ public:
}
auto result = decoder->forward(ctx, z);
if (sd_version_is_wan(version) || sd_version_is_ltxav(version)) {
// (W, H, C, T) -> (W, H, T, C)
// (W, H, T, C) -> (W, H, C, T)
result = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, result, 0, 1, 3, 2));
}
return result;
@ -556,8 +556,10 @@ public:
ggml_tensor* encode(GGMLRunnerContext* ctx, ggml_tensor* x) {
auto encoder = std::dynamic_pointer_cast<TinyVideoEncoder>(blocks["encoder"]);
// (W, H, T, C) -> (W, H, C, T)
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2));
if (sd_version_is_wan(version) || sd_version_is_ltxav(version)) {
// (W, H, T, C) -> (W, H, C, T)
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2));
}
int64_t num_frames = x->ne[3];
if (num_frames % encoder->t_downscale) {
// pad to multiple of encoder->t_downscale at the end
@ -567,7 +569,10 @@ public:
}
}
x = encoder->forward(ctx, x);
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2));
if (sd_version_is_wan(version) || sd_version_is_ltxav(version)) {
// (W, H, C, T) -> (W, H, T, C)
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2));
}
return x;
}
};