mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-09 15:56:39 +00:00
feat: add taeltx2/taeltx2.3 support (#1531)
This commit is contained in:
parent
c51ec7cad9
commit
bdd937f29a
@ -679,7 +679,9 @@ public:
|
|||||||
auto create_tae = [&]() -> std::shared_ptr<VAE> {
|
auto create_tae = [&]() -> std::shared_ptr<VAE> {
|
||||||
if (sd_version_is_wan(version) ||
|
if (sd_version_is_wan(version) ||
|
||||||
sd_version_is_qwen_image(version) ||
|
sd_version_is_qwen_image(version) ||
|
||||||
sd_version_is_anima(version)) {
|
sd_version_is_anima(version) ||
|
||||||
|
sd_version_is_ltxav(version)
|
||||||
|
) {
|
||||||
return std::make_shared<TinyVideoAutoEncoder>(backend_for(SDBackendModule::VAE),
|
return std::make_shared<TinyVideoAutoEncoder>(backend_for(SDBackendModule::VAE),
|
||||||
params_backend_for(SDBackendModule::VAE),
|
params_backend_for(SDBackendModule::VAE),
|
||||||
tensor_storage_map,
|
tensor_storage_map,
|
||||||
|
|||||||
52
src/tae.hpp
52
src/tae.hpp
@ -322,13 +322,21 @@ class TinyVideoEncoder : public UnaryBlock {
|
|||||||
int patch_size = 1;
|
int patch_size = 1;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
TinyVideoEncoder(int z_channels = 4, int patch_size = 1)
|
int t_downscale = 1;
|
||||||
|
TinyVideoEncoder(int z_channels = 4, int patch_size = 1, std::vector<bool> time_downscale = {true, true, false})
|
||||||
: z_channels(z_channels), patch_size(patch_size) {
|
: z_channels(z_channels), patch_size(patch_size) {
|
||||||
|
// self.t_downscale = 2**sum(t.stride == 2 for t in self.encoder if isinstance(t, TPool))
|
||||||
|
t_downscale = 1;
|
||||||
|
for (bool downscale : time_downscale) {
|
||||||
|
if (downscale) {
|
||||||
|
t_downscale *= 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
int index = 0;
|
int index = 0;
|
||||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels * patch_size * patch_size, hidden, {3, 3}, {1, 1}, {1, 1}));
|
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels * patch_size * patch_size, hidden, {3, 3}, {1, 1}, {1, 1}));
|
||||||
index++; // nn.ReLU()
|
index++; // nn.ReLU()
|
||||||
for (int i = 0; i < num_layers; i++) {
|
for (int i = 0; i < num_layers; i++) {
|
||||||
int stride = i == num_layers - 1 ? 1 : 2;
|
int stride = time_downscale[i] ? 2 : 1;
|
||||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TPool(hidden, stride));
|
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TPool(hidden, stride));
|
||||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(hidden, hidden, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false));
|
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(hidden, hidden, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false));
|
||||||
for (int j = 0; j < num_blocks; j++) {
|
for (int j = 0; j < num_blocks; j++) {
|
||||||
@ -375,15 +383,22 @@ class TinyVideoDecoder : public UnaryBlock {
|
|||||||
static const int num_layers = 3;
|
static const int num_layers = 3;
|
||||||
int channels[num_layers + 1] = {256, 128, 64, 64};
|
int channels[num_layers + 1] = {256, 128, 64, 64};
|
||||||
int patch_size = 1;
|
int patch_size = 1;
|
||||||
|
int t_upscale = 1;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
TinyVideoDecoder(int z_channels = 4, int patch_size = 1)
|
TinyVideoDecoder(int z_channels = 4, int patch_size = 1, std::vector<bool> time_upscale = {false, true, true})
|
||||||
: z_channels(z_channels), patch_size(patch_size) {
|
: z_channels(z_channels), patch_size(patch_size) {
|
||||||
|
t_upscale = 1;
|
||||||
|
for (bool upscale : time_upscale) {
|
||||||
|
if (upscale) {
|
||||||
|
t_upscale *= 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
int index = 1; // Clamp()
|
int index = 1; // Clamp()
|
||||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, channels[0], {3, 3}, {1, 1}, {1, 1}));
|
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, channels[0], {3, 3}, {1, 1}, {1, 1}));
|
||||||
index++; // nn.ReLU()
|
index++; // nn.ReLU()
|
||||||
for (int i = 0; i < num_layers; i++) {
|
for (int i = 0; i < num_layers; i++) {
|
||||||
int stride = i == 0 ? 1 : 2;
|
int stride = time_upscale[i] ? 2 : 1;
|
||||||
for (int j = 0; j < num_blocks; j++) {
|
for (int j = 0; j < num_blocks; j++) {
|
||||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new MemBlock(channels[i], channels[i]));
|
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new MemBlock(channels[i], channels[i]));
|
||||||
}
|
}
|
||||||
@ -404,7 +419,7 @@ public:
|
|||||||
ggml_ext_scale(ctx->ggml_ctx, z, 1.0f / 3.0f)),
|
ggml_ext_scale(ctx->ggml_ctx, z, 1.0f / 3.0f)),
|
||||||
3.0f,
|
3.0f,
|
||||||
true);
|
true);
|
||||||
|
|
||||||
h = first_conv->forward(ctx, h);
|
h = first_conv->forward(ctx, h);
|
||||||
h = ggml_relu_inplace(ctx->ggml_ctx, h);
|
h = ggml_relu_inplace(ctx->ggml_ctx, h);
|
||||||
int index = 3;
|
int index = 3;
|
||||||
@ -430,8 +445,8 @@ public:
|
|||||||
if (patch_size > 1) {
|
if (patch_size > 1) {
|
||||||
h = unpatchify(ctx->ggml_ctx, h, patch_size, 1);
|
h = unpatchify(ctx->ggml_ctx, h, patch_size, 1);
|
||||||
}
|
}
|
||||||
// shape(W, H, 3, 3 + T) => shape(W, H, 3, T)
|
// shape(W, H, 3, (t_upscale - 1) + T) => shape(W, H, 3, T)
|
||||||
h = ggml_view_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2], h->ne[3] - 3, h->nb[1], h->nb[2], h->nb[3], 3 * h->nb[3]);
|
h = ggml_view_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2], h->ne[3] - (t_upscale - 1), h->nb[1], h->nb[2], h->nb[3], (t_upscale - 1) * h->nb[3]);
|
||||||
return h;
|
return h;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -442,7 +457,9 @@ protected:
|
|||||||
SDVersion version;
|
SDVersion version;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
int z_channels = 16;
|
int z_channels = 16;
|
||||||
|
std::vector<bool> time_downscale = {true, true, false};
|
||||||
|
std::vector<bool> time_upscale = {false, true, true};
|
||||||
|
|
||||||
public:
|
public:
|
||||||
TAEHV(bool decode_only = true, SDVersion version = VERSION_WAN2)
|
TAEHV(bool decode_only = true, SDVersion version = VERSION_WAN2)
|
||||||
@ -451,21 +468,26 @@ public:
|
|||||||
if (version == VERSION_WAN2_2_TI2V) {
|
if (version == VERSION_WAN2_2_TI2V) {
|
||||||
z_channels = 48;
|
z_channels = 48;
|
||||||
patch = 2;
|
patch = 2;
|
||||||
|
} else if (sd_version_is_ltxav(version)) {
|
||||||
|
z_channels = 128;
|
||||||
|
patch = 4;
|
||||||
|
time_downscale = {true, true, true};
|
||||||
|
time_upscale = {true, true, true};
|
||||||
}
|
}
|
||||||
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoDecoder(z_channels, patch));
|
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoDecoder(z_channels, patch, time_upscale));
|
||||||
if (!decode_only) {
|
if (!decode_only) {
|
||||||
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoEncoder(z_channels, patch));
|
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoEncoder(z_channels, patch, time_downscale));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor* decode(GGMLRunnerContext* ctx, ggml_tensor* z) {
|
ggml_tensor* decode(GGMLRunnerContext* ctx, ggml_tensor* z) {
|
||||||
auto decoder = std::dynamic_pointer_cast<TinyVideoDecoder>(blocks["decoder"]);
|
auto decoder = std::dynamic_pointer_cast<TinyVideoDecoder>(blocks["decoder"]);
|
||||||
if (sd_version_is_wan(version)) {
|
if (sd_version_is_wan(version) || sd_version_is_ltxav(version)) {
|
||||||
// (W, H, C, T) -> (W, H, T, C)
|
// (W, H, C, T) -> (W, H, T, C)
|
||||||
z = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, z, 0, 1, 3, 2));
|
z = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, z, 0, 1, 3, 2));
|
||||||
}
|
}
|
||||||
auto result = decoder->forward(ctx, z);
|
auto result = decoder->forward(ctx, z);
|
||||||
if (sd_version_is_wan(version)) {
|
if (sd_version_is_wan(version) || sd_version_is_ltxav(version)) {
|
||||||
// (W, H, C, T) -> (W, H, T, C)
|
// (W, H, C, T) -> (W, H, T, C)
|
||||||
result = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, result, 0, 1, 3, 2));
|
result = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, result, 0, 1, 3, 2));
|
||||||
}
|
}
|
||||||
@ -477,10 +499,10 @@ public:
|
|||||||
// (W, H, T, C) -> (W, H, C, T)
|
// (W, H, T, C) -> (W, H, C, T)
|
||||||
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2));
|
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2));
|
||||||
int64_t num_frames = x->ne[3];
|
int64_t num_frames = x->ne[3];
|
||||||
if (num_frames % 4) {
|
if (num_frames % encoder->t_downscale) {
|
||||||
// pad to multiple of 4 at the end
|
// pad to multiple of encoder->t_downscale at the end
|
||||||
auto last_frame = ggml_view_4d(ctx->ggml_ctx, x, x->ne[0], x->ne[1], x->ne[2], 1, x->nb[1], x->nb[2], x->nb[3], (num_frames - 1) * x->nb[3]);
|
auto last_frame = ggml_view_4d(ctx->ggml_ctx, x, x->ne[0], x->ne[1], x->ne[2], 1, x->nb[1], x->nb[2], x->nb[3], (num_frames - 1) * x->nb[3]);
|
||||||
for (int i = 0; i < 4 - num_frames % 4; i++) {
|
for (int i = 0; i < encoder->t_downscale - num_frames % encoder->t_downscale; i++) {
|
||||||
x = ggml_concat(ctx->ggml_ctx, x, last_frame, 3);
|
x = ggml_concat(ctx->ggml_ctx, x, last_frame, 3);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user