mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-27 00:26:38 +00:00
fix: correct TAEHV encoding for image models (#1711)
This commit is contained in:
parent
3973015ed7
commit
ec4cb8104b
@ -548,7 +548,7 @@ public:
|
||||
}
|
||||
auto result = decoder->forward(ctx, z);
|
||||
if (sd_version_is_wan(version) || sd_version_is_ltxav(version)) {
|
||||
// (W, H, C, T) -> (W, H, T, C)
|
||||
// (W, H, T, C) -> (W, H, C, T)
|
||||
result = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, result, 0, 1, 3, 2));
|
||||
}
|
||||
return result;
|
||||
@ -556,8 +556,10 @@ public:
|
||||
|
||||
ggml_tensor* encode(GGMLRunnerContext* ctx, ggml_tensor* x) {
|
||||
auto encoder = std::dynamic_pointer_cast<TinyVideoEncoder>(blocks["encoder"]);
|
||||
// (W, H, T, C) -> (W, H, C, T)
|
||||
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2));
|
||||
if (sd_version_is_wan(version) || sd_version_is_ltxav(version)) {
|
||||
// (W, H, T, C) -> (W, H, C, T)
|
||||
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2));
|
||||
}
|
||||
int64_t num_frames = x->ne[3];
|
||||
if (num_frames % encoder->t_downscale) {
|
||||
// pad to multiple of encoder->t_downscale at the end
|
||||
@ -567,7 +569,10 @@ public:
|
||||
}
|
||||
}
|
||||
x = encoder->forward(ctx, x);
|
||||
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2));
|
||||
if (sd_version_is_wan(version) || sd_version_is_ltxav(version)) {
|
||||
// (W, H, C, T) -> (W, H, T, C)
|
||||
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2));
|
||||
}
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user