feat: add taeltx2/taeltx2.3 support (#1531)

This commit is contained in:
stduhpf 2026-05-20 16:14:05 +02:00 committed by GitHub
parent c51ec7cad9
commit bdd937f29a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 40 additions and 16 deletions

View File

@ -679,7 +679,9 @@ public:
auto create_tae = [&]() -> std::shared_ptr<VAE> { auto create_tae = [&]() -> std::shared_ptr<VAE> {
if (sd_version_is_wan(version) || if (sd_version_is_wan(version) ||
sd_version_is_qwen_image(version) || sd_version_is_qwen_image(version) ||
sd_version_is_anima(version)) { sd_version_is_anima(version) ||
sd_version_is_ltxav(version)
) {
return std::make_shared<TinyVideoAutoEncoder>(backend_for(SDBackendModule::VAE), return std::make_shared<TinyVideoAutoEncoder>(backend_for(SDBackendModule::VAE),
params_backend_for(SDBackendModule::VAE), params_backend_for(SDBackendModule::VAE),
tensor_storage_map, tensor_storage_map,

View File

@ -322,13 +322,21 @@ class TinyVideoEncoder : public UnaryBlock {
int patch_size = 1; int patch_size = 1;
public: public:
TinyVideoEncoder(int z_channels = 4, int patch_size = 1) int t_downscale = 1;
TinyVideoEncoder(int z_channels = 4, int patch_size = 1, std::vector<bool> time_downscale = {true, true, false})
: z_channels(z_channels), patch_size(patch_size) { : z_channels(z_channels), patch_size(patch_size) {
// self.t_downscale = 2**sum(t.stride == 2 for t in self.encoder if isinstance(t, TPool))
t_downscale = 1;
for (bool downscale : time_downscale) {
if (downscale) {
t_downscale *= 2;
}
}
int index = 0; int index = 0;
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels * patch_size * patch_size, hidden, {3, 3}, {1, 1}, {1, 1})); blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels * patch_size * patch_size, hidden, {3, 3}, {1, 1}, {1, 1}));
index++; // nn.ReLU() index++; // nn.ReLU()
for (int i = 0; i < num_layers; i++) { for (int i = 0; i < num_layers; i++) {
int stride = i == num_layers - 1 ? 1 : 2; int stride = time_downscale[i] ? 2 : 1;
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TPool(hidden, stride)); blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TPool(hidden, stride));
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(hidden, hidden, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false)); blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(hidden, hidden, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false));
for (int j = 0; j < num_blocks; j++) { for (int j = 0; j < num_blocks; j++) {
@ -375,15 +383,22 @@ class TinyVideoDecoder : public UnaryBlock {
static const int num_layers = 3; static const int num_layers = 3;
int channels[num_layers + 1] = {256, 128, 64, 64}; int channels[num_layers + 1] = {256, 128, 64, 64};
int patch_size = 1; int patch_size = 1;
int t_upscale = 1;
public: public:
TinyVideoDecoder(int z_channels = 4, int patch_size = 1) TinyVideoDecoder(int z_channels = 4, int patch_size = 1, std::vector<bool> time_upscale = {false, true, true})
: z_channels(z_channels), patch_size(patch_size) { : z_channels(z_channels), patch_size(patch_size) {
t_upscale = 1;
for (bool upscale : time_upscale) {
if (upscale) {
t_upscale *= 2;
}
}
int index = 1; // Clamp() int index = 1; // Clamp()
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, channels[0], {3, 3}, {1, 1}, {1, 1})); blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, channels[0], {3, 3}, {1, 1}, {1, 1}));
index++; // nn.ReLU() index++; // nn.ReLU()
for (int i = 0; i < num_layers; i++) { for (int i = 0; i < num_layers; i++) {
int stride = i == 0 ? 1 : 2; int stride = time_upscale[i] ? 2 : 1;
for (int j = 0; j < num_blocks; j++) { for (int j = 0; j < num_blocks; j++) {
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new MemBlock(channels[i], channels[i])); blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new MemBlock(channels[i], channels[i]));
} }
@ -404,7 +419,7 @@ public:
ggml_ext_scale(ctx->ggml_ctx, z, 1.0f / 3.0f)), ggml_ext_scale(ctx->ggml_ctx, z, 1.0f / 3.0f)),
3.0f, 3.0f,
true); true);
h = first_conv->forward(ctx, h); h = first_conv->forward(ctx, h);
h = ggml_relu_inplace(ctx->ggml_ctx, h); h = ggml_relu_inplace(ctx->ggml_ctx, h);
int index = 3; int index = 3;
@ -430,8 +445,8 @@ public:
if (patch_size > 1) { if (patch_size > 1) {
h = unpatchify(ctx->ggml_ctx, h, patch_size, 1); h = unpatchify(ctx->ggml_ctx, h, patch_size, 1);
} }
// shape(W, H, 3, 3 + T) => shape(W, H, 3, T) // shape(W, H, 3, (t_upscale - 1) + T) => shape(W, H, 3, T)
h = ggml_view_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2], h->ne[3] - 3, h->nb[1], h->nb[2], h->nb[3], 3 * h->nb[3]); h = ggml_view_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2], h->ne[3] - (t_upscale - 1), h->nb[1], h->nb[2], h->nb[3], (t_upscale - 1) * h->nb[3]);
return h; return h;
} }
}; };
@ -442,7 +457,9 @@ protected:
SDVersion version; SDVersion version;
public: public:
int z_channels = 16; int z_channels = 16;
std::vector<bool> time_downscale = {true, true, false};
std::vector<bool> time_upscale = {false, true, true};
public: public:
TAEHV(bool decode_only = true, SDVersion version = VERSION_WAN2) TAEHV(bool decode_only = true, SDVersion version = VERSION_WAN2)
@ -451,21 +468,26 @@ public:
if (version == VERSION_WAN2_2_TI2V) { if (version == VERSION_WAN2_2_TI2V) {
z_channels = 48; z_channels = 48;
patch = 2; patch = 2;
} else if (sd_version_is_ltxav(version)) {
z_channels = 128;
patch = 4;
time_downscale = {true, true, true};
time_upscale = {true, true, true};
} }
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoDecoder(z_channels, patch)); blocks["decoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoDecoder(z_channels, patch, time_upscale));
if (!decode_only) { if (!decode_only) {
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoEncoder(z_channels, patch)); blocks["encoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoEncoder(z_channels, patch, time_downscale));
} }
} }
ggml_tensor* decode(GGMLRunnerContext* ctx, ggml_tensor* z) { ggml_tensor* decode(GGMLRunnerContext* ctx, ggml_tensor* z) {
auto decoder = std::dynamic_pointer_cast<TinyVideoDecoder>(blocks["decoder"]); auto decoder = std::dynamic_pointer_cast<TinyVideoDecoder>(blocks["decoder"]);
if (sd_version_is_wan(version)) { if (sd_version_is_wan(version) || sd_version_is_ltxav(version)) {
// (W, H, C, T) -> (W, H, T, C) // (W, H, C, T) -> (W, H, T, C)
z = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, z, 0, 1, 3, 2)); z = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, z, 0, 1, 3, 2));
} }
auto result = decoder->forward(ctx, z); auto result = decoder->forward(ctx, z);
if (sd_version_is_wan(version)) { if (sd_version_is_wan(version) || sd_version_is_ltxav(version)) {
// (W, H, C, T) -> (W, H, T, C) // (W, H, C, T) -> (W, H, T, C)
result = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, result, 0, 1, 3, 2)); result = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, result, 0, 1, 3, 2));
} }
@ -477,10 +499,10 @@ public:
// (W, H, T, C) -> (W, H, C, T) // (W, H, T, C) -> (W, H, C, T)
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2));
int64_t num_frames = x->ne[3]; int64_t num_frames = x->ne[3];
if (num_frames % 4) { if (num_frames % encoder->t_downscale) {
// pad to multiple of 4 at the end // pad to multiple of encoder->t_downscale at the end
auto last_frame = ggml_view_4d(ctx->ggml_ctx, x, x->ne[0], x->ne[1], x->ne[2], 1, x->nb[1], x->nb[2], x->nb[3], (num_frames - 1) * x->nb[3]); auto last_frame = ggml_view_4d(ctx->ggml_ctx, x, x->ne[0], x->ne[1], x->ne[2], 1, x->nb[1], x->nb[2], x->nb[3], (num_frames - 1) * x->nb[3]);
for (int i = 0; i < 4 - num_frames % 4; i++) { for (int i = 0; i < encoder->t_downscale - num_frames % encoder->t_downscale; i++) {
x = ggml_concat(ctx->ggml_ctx, x, last_frame, 3); x = ggml_concat(ctx->ggml_ctx, x, last_frame, 3);
} }
} }