feat: add taeltx2_3_wide support (#1535)

This commit is contained in:
stduhpf 2026-05-21 16:34:12 +02:00 committed by GitHub
parent ef92a0027e
commit 47d8198b69
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 258 additions and 16 deletions

View File

@ -1602,6 +1602,23 @@ __STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) {
return num; return num;
} }
// Concatenates every tensor in `tensors` along dimension `dim` and returns
// the resulting graph node.
//
// The concatenation is built as a pairwise reduction: neighbours are joined
// two at a time, level by level, so the depth of the ggml_concat chain grows
// with log2(count) rather than with count. Element order is preserved; an
// odd trailing element is carried unchanged to the next level.
//
// ctx:     ggml context the ggml_concat nodes are created in.
// tensors: tensors to join, in order. The vector is left untouched.
// dim:     dimension index forwarded to ggml_concat.
//
// Returns nullptr when `tensors` is empty, and the single element itself
// (no concat node) when it holds exactly one tensor.
__STATIC_INLINE__ ggml_tensor* ggml_ext_vec_concat(ggml_context* ctx,
                                                   std::vector<ggml_tensor*>& tensors,
                                                   int dim) {
    if (tensors.empty()) {
        // Nothing to concatenate; indexing element 0 would be undefined behaviour.
        return nullptr;
    }

    // Reduce on a private copy so the caller's vector is not clobbered.
    std::vector<ggml_tensor*> level(tensors);
    while (level.size() > 1) {
        std::vector<ggml_tensor*> next_level;
        next_level.reserve((level.size() + 1) / 2);
        for (size_t i = 0; i + 1 < level.size(); i += 2) {
            next_level.push_back(ggml_concat(ctx, level[i], level[i + 1], dim));
        }
        if (level.size() % 2 != 0) {
            // Unpaired last element moves up a level as-is.
            next_level.push_back(level.back());
        }
        level = std::move(next_level);
    }
    return level[0];
}
/* SDXL with LoRA requires more space */ /* SDXL with LoRA requires more space */
#define MAX_PARAMS_TENSOR_NUM 32768 #define MAX_PARAMS_TENSOR_NUM 32768
#define MAX_GRAPH_SIZE 327680 #define MAX_GRAPH_SIZE 327680
@ -3139,6 +3156,163 @@ public:
} }
}; };
// 2D convolution block with channel groups.
//
// forward() picks one of three strategies:
//   * groups == 1: a single regular convolution;
//   * groups == in_channels == out_channels: ggml's depthwise convolution;
//   * otherwise: the input channels and the weight's output channels are cut
//     into `groups` slices, each slice is convolved on its own, and the
//     results are concatenated back along dimension 2.
//
// The (first, second) pairs for kernel_size/stride/padding/dilation are passed
// to ggml as (dim 1, dim 0) — presumably (height, width); confirm against the
// other conv blocks in this file.
//
// Preconditions (not checked here): groups > 0 and both in_channels and
// out_channels are multiples of groups.
class Conv2d_grouped : public UnaryBlock {
protected:
    int64_t in_channels;
    int64_t out_channels;
    int groups;
    std::pair<int, int> kernel_size;
    std::pair<int, int> stride;
    std::pair<int, int> padding;
    std::pair<int, int> dilation;
    bool bias;
    float scale = 1.f;
    std::string prefix;

    // Allocates the parameter tensors: an F16 weight with
    // ne = [kernel_size.second, kernel_size.first, in_channels / groups, out_channels]
    // and, when enabled, an F32 bias with out_channels elements. The prefix is
    // remembered so forward() can hand it to the weight adapter.
    void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map, const std::string prefix = "") override {
        this->prefix = prefix;
        params["weight"] = ggml_new_tensor_4d(ctx, GGML_TYPE_F16,
                                              kernel_size.second, kernel_size.first,
                                              in_channels / groups, out_channels);
        if (bias) {
            params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);
        }
    }

    // Runs one ordinary (non-grouped) convolution of `x` with weight `w` and
    // optional bias `b`, going through the weight adapter when one is set.
    ggml_tensor* conv_dense(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* w, ggml_tensor* b) {
        if (ctx->weight_adapter) {
            WeightAdapter::ForwardParams forward_params;
            forward_params.op_type           = WeightAdapter::ForwardParams::op_type_t::OP_CONV2D;
            forward_params.conv2d.s0         = stride.second;
            forward_params.conv2d.s1         = stride.first;
            forward_params.conv2d.p0         = padding.second;
            forward_params.conv2d.p1         = padding.first;
            forward_params.conv2d.d0         = dilation.second;
            forward_params.conv2d.d1         = dilation.first;
            forward_params.conv2d.direct     = ctx->conv2d_direct_enabled;
            forward_params.conv2d.circular_x = ctx->circular_x_enabled;
            forward_params.conv2d.circular_y = ctx->circular_y_enabled;
            forward_params.conv2d.scale      = scale;
            return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, ctx->backend, x, w, b, prefix, forward_params);
        }
        return ggml_ext_conv_2d(ctx->ggml_ctx, x, w, b,
                                stride.second, stride.first,
                                padding.second, padding.first,
                                dilation.second, dilation.first,
                                ctx->conv2d_direct_enabled,
                                ctx->circular_x_enabled,
                                ctx->circular_y_enabled,
                                scale);
    }

public:
    Conv2d_grouped(int64_t in_channels,
                   int64_t out_channels,
                   int groups,
                   std::pair<int, int> kernel_size,
                   std::pair<int, int> stride   = {1, 1},
                   std::pair<int, int> padding  = {0, 0},
                   std::pair<int, int> dilation = {1, 1},
                   bool bias                    = true)
        : in_channels(in_channels),
          out_channels(out_channels),
          groups(groups),
          kernel_size(kernel_size),
          stride(stride),
          padding(padding),
          dilation(dilation),
          bias(bias) {}

    // Sets the scale value forwarded to the convolution helpers.
    void set_scale(float scale_value) {
        scale = scale_value;
    }

    std::string get_desc() {
        return "Conv2d_grouped";
    }

    // Builds the graph nodes for the grouped convolution of `x` and returns
    // the output node. See the class comment for the three strategies.
    ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
        ggml_tensor* w = params["weight"];
        ggml_tensor* b = nullptr;
        if (bias) {
            b = params["bias"];
        }

        if (groups == 1) {
            return conv_dense(ctx, x, w, b);
        }

        if (groups == in_channels && groups == out_channels) {
            // Depthwise case. ggml's depthwise ops take the kernel first and
            // the data second.
            // NOTE(review): this path applies neither `scale` nor the weight
            // adapter — confirm whether that is intended.
            ggml_tensor* res;
            if (ctx->conv2d_direct_enabled) {
                res = ggml_conv_2d_dw_direct(ctx->ggml_ctx, w, x,
                                             stride.second, stride.first,
                                             padding.second, padding.first,
                                             dilation.second, dilation.first);
            } else {
                res = ggml_conv_2d_dw(ctx->ggml_ctx, w, x,
                                      stride.second, stride.first,
                                      padding.second, padding.first,
                                      dilation.second, dilation.first);
            }
            if (b) {
                // The bias is 1-D; reshape it to [1, 1, C, 1] so ggml_add
                // broadcasts it over the channel dimension of the result.
                ggml_tensor* b_4d = ggml_reshape_4d(ctx->ggml_ctx, b, 1, 1, b->ne[0], 1);
                res               = ggml_add(ctx->ggml_ctx, res, b_4d);
            }
            return res;
        }

        // General grouped case: convolve each channel group separately.
        int64_t ic_g = in_channels / groups;
        int64_t oc_g = out_channels / groups;
        std::vector<ggml_tensor*> out_slices(groups);
        for (int i = 0; i < groups; ++i) {
            // Input channels [i * ic_g, (i + 1) * ic_g) as a view on dimension 2.
            size_t x_offset  = i * ic_g * x->nb[2];
            ggml_tensor* x_i = ggml_view_4d(ctx->ggml_ctx, x,
                                            x->ne[0], x->ne[1], ic_g, x->ne[3],
                                            x->nb[1], x->nb[2], x->nb[3],
                                            x_offset);

            // Output channels [i * oc_g, (i + 1) * oc_g) of the weight (dimension 3).
            size_t w_offset  = i * oc_g * w->nb[3];
            ggml_tensor* w_i = ggml_view_4d(ctx->ggml_ctx, w,
                                            w->ne[0], w->ne[1], w->ne[2], oc_g,
                                            w->nb[1], w->nb[2], w->nb[3],
                                            w_offset);

            // Matching slice of the bias, if there is one.
            ggml_tensor* b_i = nullptr;
            if (b) {
                size_t b_offset = i * oc_g * b->nb[0];
                b_i             = ggml_view_1d(ctx->ggml_ctx, b, oc_g, b_offset);
            }

            // NOTE(review): every slice is passed to the weight adapter under
            // the same `prefix`; verify the adapter handles a sliced weight.
            out_slices[i] = conv_dense(ctx, x_i, w_i, b_i);
        }

        // Stitch the per-group outputs back together along dimension 2.
        ggml_tensor* out = ggml_ext_vec_concat(ctx->ggml_ctx, out_slices, 2);
        return out;
    }
};
class Conv3d : public UnaryBlock { class Conv3d : public UnaryBlock {
protected: protected:
int64_t in_channels; int64_t in_channels;

View File

@ -259,10 +259,54 @@ public:
} }
}; };
ggml_tensor* patchify(ggml_context* ctx, class WideMemBlock : public GGMLBlock {
ggml_tensor* x, bool has_skip_conv = false;
int64_t patch_size,
int64_t b = 1) { public:
WideMemBlock(int channels, int out_channels)
: has_skip_conv(channels != out_channels) {
int groups = std::max(1, out_channels / 64);
blocks["conv.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels * 2, out_channels, {1, 1}, {1, 1}));
blocks["conv.2"] = std::shared_ptr<GGMLBlock>(new Conv2d_grouped(out_channels, out_channels, groups, {3, 3}, {1, 1}, {1, 1}));
blocks["conv.4"] = std::shared_ptr<GGMLBlock>(new Conv2d(out_channels, out_channels, {1, 1}, {1, 1}));
blocks["conv.6"] = std::shared_ptr<GGMLBlock>(new Conv2d_grouped(out_channels, out_channels, groups, {3, 3}, {1, 1}, {1, 1}));
if (has_skip_conv) {
blocks["skip"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
}
}
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* past) {
// x: [n, channels, h, w]
auto conv0 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.0"]);
auto conv1 = std::dynamic_pointer_cast<Conv2d_grouped>(blocks["conv.2"]);
auto conv2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.4"]);
auto conv3 = std::dynamic_pointer_cast<Conv2d_grouped>(blocks["conv.6"]);
auto h = ggml_concat(ctx->ggml_ctx, x, past, 2);
h = conv0->forward(ctx, h);
h = ggml_relu_inplace(ctx->ggml_ctx, h);
h = conv1->forward(ctx, h);
h = ggml_relu_inplace(ctx->ggml_ctx, h);
h = conv2->forward(ctx, h);
h = ggml_relu_inplace(ctx->ggml_ctx, h);
h = conv3->forward(ctx, h);
auto skip = x;
if (has_skip_conv) {
auto skip_conv = std::dynamic_pointer_cast<Conv2d>(blocks["skip"]);
skip = skip_conv->forward(ctx, x);
}
h = ggml_add_inplace(ctx->ggml_ctx, h, skip);
h = ggml_relu_inplace(ctx->ggml_ctx, h);
return h;
}
};
ggml_tensor*
patchify(ggml_context* ctx,
ggml_tensor* x,
int64_t patch_size,
int64_t b = 1) {
// x: [f, b*c, h*q, w*r] // x: [f, b*c, h*q, w*r]
// return: [f, b*c*r*q, h, w] // return: [f, b*c*r*q, h, w]
if (patch_size == 1) { if (patch_size == 1) {
@ -325,7 +369,6 @@ public:
int t_downscale = 1; int t_downscale = 1;
TinyVideoEncoder(int z_channels = 4, int patch_size = 1, std::vector<bool> time_downscale = {true, true, false}) TinyVideoEncoder(int z_channels = 4, int patch_size = 1, std::vector<bool> time_downscale = {true, true, false})
: z_channels(z_channels), patch_size(patch_size) { : z_channels(z_channels), patch_size(patch_size) {
// self.t_downscale = 2**sum(t.stride == 2 for t in self.encoder if isinstance(t, TPool))
t_downscale = 1; t_downscale = 1;
for (bool downscale : time_downscale) { for (bool downscale : time_downscale) {
if (downscale) { if (downscale) {
@ -384,11 +427,18 @@ class TinyVideoDecoder : public UnaryBlock {
int channels[num_layers + 1] = {256, 128, 64, 64}; int channels[num_layers + 1] = {256, 128, 64, 64};
int patch_size = 1; int patch_size = 1;
int t_upscale = 1; int t_upscale = 1;
bool is_wide = false;
public: public:
TinyVideoDecoder(int z_channels = 4, int patch_size = 1, std::vector<bool> time_upscale = {false, true, true}) TinyVideoDecoder(int z_channels = 4, int patch_size = 1, std::vector<bool> time_upscale = {false, true, true}, bool is_wide = false)
: z_channels(z_channels), patch_size(patch_size) { : z_channels(z_channels), patch_size(patch_size), is_wide(is_wide) {
t_upscale = 1; t_upscale = 1;
if (is_wide) {
channels[0] = 1024;
channels[1] = 512;
channels[2] = 256;
}
for (bool upscale : time_upscale) { for (bool upscale : time_upscale) {
if (upscale) { if (upscale) {
t_upscale *= 2; t_upscale *= 2;
@ -400,7 +450,11 @@ public:
for (int i = 0; i < num_layers; i++) { for (int i = 0; i < num_layers; i++) {
int stride = time_upscale[i] ? 2 : 1; int stride = time_upscale[i] ? 2 : 1;
for (int j = 0; j < num_blocks; j++) { for (int j = 0; j < num_blocks; j++) {
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new MemBlock(channels[i], channels[i])); if (is_wide) {
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new WideMemBlock(channels[i], channels[i]));
} else {
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new MemBlock(channels[i], channels[i]));
}
} }
index++; // nn.Upsample() index++; // nn.Upsample()
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TGrow(channels[i], stride)); blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TGrow(channels[i], stride));
@ -425,10 +479,15 @@ public:
int index = 3; int index = 3;
for (int i = 0; i < num_layers; i++) { for (int i = 0; i < num_layers; i++) {
for (int j = 0; j < num_blocks; j++) { for (int j = 0; j < num_blocks; j++) {
auto block = std::dynamic_pointer_cast<MemBlock>(blocks[std::to_string(index++)]);
auto mem = ggml_pad_ext(ctx->ggml_ctx, h, 0, 0, 0, 0, 0, 0, 1, 0); auto mem = ggml_pad_ext(ctx->ggml_ctx, h, 0, 0, 0, 0, 0, 0, 1, 0);
mem = ggml_view_4d(ctx->ggml_ctx, mem, h->ne[0], h->ne[1], h->ne[2], h->ne[3], h->nb[1], h->nb[2], h->nb[3], 0); mem = ggml_view_4d(ctx->ggml_ctx, mem, h->ne[0], h->ne[1], h->ne[2], h->ne[3], h->nb[1], h->nb[2], h->nb[3], 0);
h = block->forward(ctx, h, mem); if (is_wide) {
auto block = std::dynamic_pointer_cast<WideMemBlock>(blocks[std::to_string(index++)]);
h = block->forward(ctx, h, mem);
} else{
auto block = std::dynamic_pointer_cast<MemBlock>(blocks[std::to_string(index++)]);
h = block->forward(ctx, h, mem);
}
} }
// upsample // upsample
index++; index++;
@ -455,6 +514,7 @@ class TAEHV : public GGMLBlock {
protected: protected:
bool decode_only; bool decode_only;
SDVersion version; SDVersion version;
bool is_wide;
public: public:
int z_channels = 16; int z_channels = 16;
@ -462,8 +522,8 @@ public:
std::vector<bool> time_upscale = {false, true, true}; std::vector<bool> time_upscale = {false, true, true};
public: public:
TAEHV(bool decode_only = true, SDVersion version = VERSION_WAN2) TAEHV(bool decode_only = true, SDVersion version = VERSION_WAN2, bool is_wide = false)
: decode_only(decode_only), version(version) { : decode_only(decode_only), version(version), is_wide(is_wide) {
int patch = 1; int patch = 1;
if (version == VERSION_WAN2_2_TI2V) { if (version == VERSION_WAN2_2_TI2V) {
z_channels = 48; z_channels = 48;
@ -474,7 +534,7 @@ public:
time_downscale = {true, true, true}; time_downscale = {true, true, true};
time_upscale = {true, true, true}; time_upscale = {true, true, true};
} }
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoDecoder(z_channels, patch, time_upscale)); blocks["decoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoDecoder(z_channels, patch, time_upscale, is_wide));
if (!decode_only) { if (!decode_only) {
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoEncoder(z_channels, patch, time_downscale)); blocks["encoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoEncoder(z_channels, patch, time_downscale));
} }
@ -623,7 +683,8 @@ struct TinyImageAutoEncoder : public VAE {
struct TinyVideoAutoEncoder : public VAE { struct TinyVideoAutoEncoder : public VAE {
TAEHV taehv; TAEHV taehv;
bool decode_only = false; bool decode_only = false;
bool is_wide = false;
TinyVideoAutoEncoder(ggml_backend_t backend, TinyVideoAutoEncoder(ggml_backend_t backend,
ggml_backend_t params_backend, ggml_backend_t params_backend,
const String2TensorStorage& tensor_storage_map, const String2TensorStorage& tensor_storage_map,
@ -631,8 +692,14 @@ struct TinyVideoAutoEncoder : public VAE {
bool decoder_only = true, bool decoder_only = true,
SDVersion version = VERSION_WAN2) SDVersion version = VERSION_WAN2)
: decode_only(decoder_only), : decode_only(decoder_only),
taehv(decoder_only, version),
VAE(version, backend, params_backend) { VAE(version, backend, params_backend) {
for (auto tensor_storage : tensor_storage_map) {
if (tensor_storage.first.find(prefix + ".3.conv.6.weight") != std::string::npos) {
is_wide = true;
break;
}
}
taehv = TAEHV(decoder_only, version, is_wide);
scale_input = false; scale_input = false;
taehv.init(params_ctx, tensor_storage_map, prefix); taehv.init(params_ctx, tensor_storage_map, prefix);
} }
@ -663,7 +730,8 @@ struct TinyVideoAutoEncoder : public VAE {
} }
ggml_cgraph* build_graph(const sd::Tensor<float>& z_tensor, bool decode_graph) { ggml_cgraph* build_graph(const sd::Tensor<float>& z_tensor, bool decode_graph) {
ggml_cgraph* gf = ggml_new_graph(compute_ctx); ggml_cgraph* gf = decode_graph && is_wide ? ggml_new_graph_custom(compute_ctx, 4096, false)
: ggml_new_graph(compute_ctx);
ggml_tensor* z = make_input(z_tensor); ggml_tensor* z = make_input(z_tensor);
auto runner_ctx = get_context(); auto runner_ctx = get_context();
ggml_tensor* out = decode_graph ? taehv.decode(&runner_ctx, z) : taehv.encode(&runner_ctx, z); ggml_tensor* out = decode_graph ? taehv.decode(&runner_ctx, z) : taehv.encode(&runner_ctx, z);