feat: add taef2 support (#1211)

This commit is contained in:
Oleg Skutte 2026-01-19 19:39:36 +04:00 committed by GitHub
parent c6206fb351
commit e50e1f253d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

54
tae.hpp
View File

@ -17,22 +17,43 @@ class TAEBlock : public UnaryBlock {
protected:
int n_in;
int n_out;
bool use_midblock_gn;
public:
TAEBlock(int n_in, int n_out)
: n_in(n_in), n_out(n_out) {
TAEBlock(int n_in, int n_out, bool use_midblock_gn = false)
: n_in(n_in), n_out(n_out), use_midblock_gn(use_midblock_gn) {
blocks["conv.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_out, {3, 3}, {1, 1}, {1, 1}));
blocks["conv.2"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_out, n_out, {3, 3}, {1, 1}, {1, 1}));
blocks["conv.4"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_out, n_out, {3, 3}, {1, 1}, {1, 1}));
if (n_in != n_out) {
blocks["skip"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_out, {1, 1}, {1, 1}, {1, 1}, {1, 1}, false));
}
if (use_midblock_gn) {
int n_gn = n_in * 4;
blocks["pool.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_gn, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
blocks["pool.1"] = std::shared_ptr<GGMLBlock>(new GroupNorm(4, n_gn));
// pool.2 is ReLU, handled in forward
blocks["pool.3"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_gn, n_in, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
}
}
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
// x: [n, n_in, h, w]
// return: [n, n_out, h, w]
if (use_midblock_gn) {
auto pool_0 = std::dynamic_pointer_cast<Conv2d>(blocks["pool.0"]);
auto pool_1 = std::dynamic_pointer_cast<GroupNorm>(blocks["pool.1"]);
auto pool_3 = std::dynamic_pointer_cast<Conv2d>(blocks["pool.3"]);
auto p = pool_0->forward(ctx, x);
p = pool_1->forward(ctx, p);
p = ggml_relu_inplace(ctx->ggml_ctx, p);
p = pool_3->forward(ctx, p);
x = ggml_add(ctx->ggml_ctx, x, p);
}
auto conv_0 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.0"]);
auto conv_2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.2"]);
auto conv_4 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.4"]);
@ -62,7 +83,7 @@ class TinyEncoder : public UnaryBlock {
int num_blocks = 3;
public:
TinyEncoder(int z_channels = 4)
TinyEncoder(int z_channels = 4, bool use_midblock_gn = false)
: z_channels(z_channels) {
int index = 0;
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, channels, {3, 3}, {1, 1}, {1, 1}));
@ -80,7 +101,7 @@ public:
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false));
for (int i = 0; i < num_blocks; i++) {
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels));
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels, use_midblock_gn));
}
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, z_channels, {3, 3}, {1, 1}, {1, 1}));
@ -107,7 +128,7 @@ class TinyDecoder : public UnaryBlock {
int num_blocks = 3;
public:
TinyDecoder(int z_channels = 4)
TinyDecoder(int z_channels = 4, bool use_midblock_gn = false)
: z_channels(z_channels) {
int index = 0;
@ -115,7 +136,7 @@ public:
index++; // nn.ReLU()
for (int i = 0; i < num_blocks; i++) {
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels));
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels, use_midblock_gn));
}
index++; // nn.Upsample()
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, false));
@ -470,29 +491,44 @@ public:
class TAESD : public GGMLBlock {
protected:
bool decode_only;
bool taef2 = false;
public:
TAESD(bool decode_only = true, SDVersion version = VERSION_SD1)
: decode_only(decode_only) {
int z_channels = 4;
bool use_midblock_gn = false;
taef2 = sd_version_is_flux2(version);
if (sd_version_is_dit(version)) {
z_channels = 16;
}
blocks["decoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyDecoder(z_channels));
if (taef2) {
z_channels = 32;
use_midblock_gn = true;
}
blocks["decoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyDecoder(z_channels, use_midblock_gn));
if (!decode_only) {
blocks["encoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyEncoder(z_channels));
blocks["encoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyEncoder(z_channels, use_midblock_gn));
}
}
struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
auto decoder = std::dynamic_pointer_cast<TinyDecoder>(blocks["decoder.layers"]);
if (taef2) {
z = unpatchify(ctx->ggml_ctx, z, 2);
}
return decoder->forward(ctx, z);
}
struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
auto encoder = std::dynamic_pointer_cast<TinyEncoder>(blocks["encoder.layers"]);
return encoder->forward(ctx, x);
auto z = encoder->forward(ctx, x);
if (taef2) {
z = patchify(ctx->ggml_ctx, z, 2);
}
return z;
}
};