mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-09 15:56:39 +00:00
feat: add taeltx2/taeltx2.3 support (#1531)
This commit is contained in:
parent
c51ec7cad9
commit
bdd937f29a
@ -679,7 +679,9 @@ public:
|
||||
auto create_tae = [&]() -> std::shared_ptr<VAE> {
|
||||
if (sd_version_is_wan(version) ||
|
||||
sd_version_is_qwen_image(version) ||
|
||||
sd_version_is_anima(version)) {
|
||||
sd_version_is_anima(version) ||
|
||||
sd_version_is_ltxav(version)
|
||||
) {
|
||||
return std::make_shared<TinyVideoAutoEncoder>(backend_for(SDBackendModule::VAE),
|
||||
params_backend_for(SDBackendModule::VAE),
|
||||
tensor_storage_map,
|
||||
|
||||
52
src/tae.hpp
52
src/tae.hpp
@ -322,13 +322,21 @@ class TinyVideoEncoder : public UnaryBlock {
|
||||
int patch_size = 1;
|
||||
|
||||
public:
|
||||
TinyVideoEncoder(int z_channels = 4, int patch_size = 1)
|
||||
int t_downscale = 1;
|
||||
TinyVideoEncoder(int z_channels = 4, int patch_size = 1, std::vector<bool> time_downscale = {true, true, false})
|
||||
: z_channels(z_channels), patch_size(patch_size) {
|
||||
// self.t_downscale = 2**sum(t.stride == 2 for t in self.encoder if isinstance(t, TPool))
|
||||
t_downscale = 1;
|
||||
for (bool downscale : time_downscale) {
|
||||
if (downscale) {
|
||||
t_downscale *= 2;
|
||||
}
|
||||
}
|
||||
int index = 0;
|
||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels * patch_size * patch_size, hidden, {3, 3}, {1, 1}, {1, 1}));
|
||||
index++; // nn.ReLU()
|
||||
for (int i = 0; i < num_layers; i++) {
|
||||
int stride = i == num_layers - 1 ? 1 : 2;
|
||||
int stride = time_downscale[i] ? 2 : 1;
|
||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TPool(hidden, stride));
|
||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(hidden, hidden, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false));
|
||||
for (int j = 0; j < num_blocks; j++) {
|
||||
@ -375,15 +383,22 @@ class TinyVideoDecoder : public UnaryBlock {
|
||||
static const int num_layers = 3;
|
||||
int channels[num_layers + 1] = {256, 128, 64, 64};
|
||||
int patch_size = 1;
|
||||
int t_upscale = 1;
|
||||
|
||||
public:
|
||||
TinyVideoDecoder(int z_channels = 4, int patch_size = 1)
|
||||
TinyVideoDecoder(int z_channels = 4, int patch_size = 1, std::vector<bool> time_upscale = {false, true, true})
|
||||
: z_channels(z_channels), patch_size(patch_size) {
|
||||
t_upscale = 1;
|
||||
for (bool upscale : time_upscale) {
|
||||
if (upscale) {
|
||||
t_upscale *= 2;
|
||||
}
|
||||
}
|
||||
int index = 1; // Clamp()
|
||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, channels[0], {3, 3}, {1, 1}, {1, 1}));
|
||||
index++; // nn.ReLU()
|
||||
for (int i = 0; i < num_layers; i++) {
|
||||
int stride = i == 0 ? 1 : 2;
|
||||
int stride = time_upscale[i] ? 2 : 1;
|
||||
for (int j = 0; j < num_blocks; j++) {
|
||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new MemBlock(channels[i], channels[i]));
|
||||
}
|
||||
@ -404,7 +419,7 @@ public:
|
||||
ggml_ext_scale(ctx->ggml_ctx, z, 1.0f / 3.0f)),
|
||||
3.0f,
|
||||
true);
|
||||
|
||||
|
||||
h = first_conv->forward(ctx, h);
|
||||
h = ggml_relu_inplace(ctx->ggml_ctx, h);
|
||||
int index = 3;
|
||||
@ -430,8 +445,8 @@ public:
|
||||
if (patch_size > 1) {
|
||||
h = unpatchify(ctx->ggml_ctx, h, patch_size, 1);
|
||||
}
|
||||
// shape(W, H, 3, 3 + T) => shape(W, H, 3, T)
|
||||
h = ggml_view_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2], h->ne[3] - 3, h->nb[1], h->nb[2], h->nb[3], 3 * h->nb[3]);
|
||||
// shape(W, H, 3, (t_upscale - 1) + T) => shape(W, H, 3, T)
|
||||
h = ggml_view_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2], h->ne[3] - (t_upscale - 1), h->nb[1], h->nb[2], h->nb[3], (t_upscale - 1) * h->nb[3]);
|
||||
return h;
|
||||
}
|
||||
};
|
||||
@ -442,7 +457,9 @@ protected:
|
||||
SDVersion version;
|
||||
|
||||
public:
|
||||
int z_channels = 16;
|
||||
int z_channels = 16;
|
||||
std::vector<bool> time_downscale = {true, true, false};
|
||||
std::vector<bool> time_upscale = {false, true, true};
|
||||
|
||||
public:
|
||||
TAEHV(bool decode_only = true, SDVersion version = VERSION_WAN2)
|
||||
@ -451,21 +468,26 @@ public:
|
||||
if (version == VERSION_WAN2_2_TI2V) {
|
||||
z_channels = 48;
|
||||
patch = 2;
|
||||
} else if (sd_version_is_ltxav(version)) {
|
||||
z_channels = 128;
|
||||
patch = 4;
|
||||
time_downscale = {true, true, true};
|
||||
time_upscale = {true, true, true};
|
||||
}
|
||||
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoDecoder(z_channels, patch));
|
||||
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoDecoder(z_channels, patch, time_upscale));
|
||||
if (!decode_only) {
|
||||
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoEncoder(z_channels, patch));
|
||||
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoEncoder(z_channels, patch, time_downscale));
|
||||
}
|
||||
}
|
||||
|
||||
ggml_tensor* decode(GGMLRunnerContext* ctx, ggml_tensor* z) {
|
||||
auto decoder = std::dynamic_pointer_cast<TinyVideoDecoder>(blocks["decoder"]);
|
||||
if (sd_version_is_wan(version)) {
|
||||
if (sd_version_is_wan(version) || sd_version_is_ltxav(version)) {
|
||||
// (W, H, C, T) -> (W, H, T, C)
|
||||
z = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, z, 0, 1, 3, 2));
|
||||
}
|
||||
auto result = decoder->forward(ctx, z);
|
||||
if (sd_version_is_wan(version)) {
|
||||
if (sd_version_is_wan(version) || sd_version_is_ltxav(version)) {
|
||||
// (W, H, C, T) -> (W, H, T, C)
|
||||
result = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, result, 0, 1, 3, 2));
|
||||
}
|
||||
@ -477,10 +499,10 @@ public:
|
||||
// (W, H, T, C) -> (W, H, C, T)
|
||||
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2));
|
||||
int64_t num_frames = x->ne[3];
|
||||
if (num_frames % 4) {
|
||||
// pad to multiple of 4 at the end
|
||||
if (num_frames % encoder->t_downscale) {
|
||||
// pad to multiple of encoder->t_downscale at the end
|
||||
auto last_frame = ggml_view_4d(ctx->ggml_ctx, x, x->ne[0], x->ne[1], x->ne[2], 1, x->nb[1], x->nb[2], x->nb[3], (num_frames - 1) * x->nb[3]);
|
||||
for (int i = 0; i < 4 - num_frames % 4; i++) {
|
||||
for (int i = 0; i < encoder->t_downscale - num_frames % encoder->t_downscale; i++) {
|
||||
x = ggml_concat(ctx->ggml_ctx, x, last_frame, 3);
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user