mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-27 00:26:38 +00:00
fix: correct TAEHV encoding for image models (#1711)
This commit is contained in:
parent
3973015ed7
commit
ec4cb8104b
@ -548,7 +548,7 @@ public:
|
|||||||
}
|
}
|
||||||
auto result = decoder->forward(ctx, z);
|
auto result = decoder->forward(ctx, z);
|
||||||
if (sd_version_is_wan(version) || sd_version_is_ltxav(version)) {
|
if (sd_version_is_wan(version) || sd_version_is_ltxav(version)) {
|
||||||
// (W, H, C, T) -> (W, H, T, C)
|
// (W, H, T, C) -> (W, H, C, T)
|
||||||
result = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, result, 0, 1, 3, 2));
|
result = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, result, 0, 1, 3, 2));
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
@ -556,8 +556,10 @@ public:
|
|||||||
|
|
||||||
ggml_tensor* encode(GGMLRunnerContext* ctx, ggml_tensor* x) {
|
ggml_tensor* encode(GGMLRunnerContext* ctx, ggml_tensor* x) {
|
||||||
auto encoder = std::dynamic_pointer_cast<TinyVideoEncoder>(blocks["encoder"]);
|
auto encoder = std::dynamic_pointer_cast<TinyVideoEncoder>(blocks["encoder"]);
|
||||||
// (W, H, T, C) -> (W, H, C, T)
|
if (sd_version_is_wan(version) || sd_version_is_ltxav(version)) {
|
||||||
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2));
|
// (W, H, T, C) -> (W, H, C, T)
|
||||||
|
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2));
|
||||||
|
}
|
||||||
int64_t num_frames = x->ne[3];
|
int64_t num_frames = x->ne[3];
|
||||||
if (num_frames % encoder->t_downscale) {
|
if (num_frames % encoder->t_downscale) {
|
||||||
// pad to multiple of encoder->t_downscale at the end
|
// pad to multiple of encoder->t_downscale at the end
|
||||||
@ -567,7 +569,10 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
x = encoder->forward(ctx, x);
|
x = encoder->forward(ctx, x);
|
||||||
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2));
|
if (sd_version_is_wan(version) || sd_version_is_ltxav(version)) {
|
||||||
|
// (W, H, C, T) -> (W, H, T, C)
|
||||||
|
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2));
|
||||||
|
}
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user