mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-09 15:56:39 +00:00
feat: add taeltx2_3_wide support (#1535)
This commit is contained in:
parent
ef92a0027e
commit
47d8198b69
@ -1602,6 +1602,23 @@ __STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) {
|
||||
return num;
|
||||
}
|
||||
|
||||
/**
 * Concatenate a list of tensors along one dimension.
 *
 * Adjacent tensors are joined pairwise with ggml_concat and the list is
 * reduced level by level until a single tensor remains; an odd trailing
 * element is carried over unchanged to the next level. The relative order of
 * the inputs is preserved in the result.
 *
 * @param ctx     ggml context in which the concat nodes are created.
 * @param tensors Tensors to join. NOTE: the vector is consumed in place; on
 *                return it holds only the final tensor (or stays empty).
 * @param dim     Dimension index forwarded to ggml_concat.
 * @return The concatenated tensor, the sole input when only one tensor was
 *         given, or nullptr when `tensors` is empty.
 */
__STATIC_INLINE__ ggml_tensor* ggml_ext_vec_concat(ggml_context* ctx,
                                                   std::vector<ggml_tensor*>& tensors,
                                                   int dim) {
    // Nothing to concatenate: indexing tensors[0] below would be undefined behaviour.
    if (tensors.empty()) {
        return nullptr;
    }
    while (tensors.size() > 1) {
        std::vector<ggml_tensor*> next_level;
        next_level.reserve((tensors.size() + 1) / 2);
        for (size_t i = 0; i < tensors.size(); i += 2) {
            if (i + 1 < tensors.size()) {
                next_level.push_back(ggml_concat(ctx, tensors[i], tensors[i + 1], dim));
            } else {
                // Odd element out: promote it untouched to the next level.
                next_level.push_back(tensors[i]);
            }
        }
        tensors = std::move(next_level);
    }
    return tensors[0];
}
|
||||
|
||||
/* SDXL with LoRA requires more space */
|
||||
#define MAX_PARAMS_TENSOR_NUM 32768
|
||||
#define MAX_GRAPH_SIZE 327680
|
||||
@ -3139,6 +3156,163 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class Conv2d_grouped : public UnaryBlock {
|
||||
protected:
|
||||
int64_t in_channels;
|
||||
int64_t out_channels;
|
||||
int groups;
|
||||
std::pair<int, int> kernel_size;
|
||||
std::pair<int, int> stride;
|
||||
std::pair<int, int> padding;
|
||||
std::pair<int, int> dilation;
|
||||
bool bias;
|
||||
float scale = 1.f;
|
||||
std::string prefix;
|
||||
|
||||
void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map, const std::string prefix = "") override {
|
||||
this->prefix = prefix;
|
||||
enum ggml_type wtype = GGML_TYPE_F16;
|
||||
params["weight"] = ggml_new_tensor_4d(ctx, wtype, kernel_size.second, kernel_size.first, in_channels / groups, out_channels);
|
||||
if (bias) {
|
||||
enum ggml_type wtype = GGML_TYPE_F32;
|
||||
params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_channels);
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
Conv2d_grouped(int64_t in_channels,
|
||||
int64_t out_channels,
|
||||
int groups,
|
||||
std::pair<int, int> kernel_size,
|
||||
std::pair<int, int> stride = {1, 1},
|
||||
std::pair<int, int> padding = {0, 0},
|
||||
std::pair<int, int> dilation = {1, 1},
|
||||
bool bias = true)
|
||||
: in_channels(in_channels),
|
||||
out_channels(out_channels),
|
||||
groups(groups),
|
||||
kernel_size(kernel_size),
|
||||
stride(stride),
|
||||
padding(padding),
|
||||
dilation(dilation),
|
||||
bias(bias) {}
|
||||
|
||||
void set_scale(float scale_value) {
|
||||
scale = scale_value;
|
||||
}
|
||||
|
||||
std::string get_desc() {
|
||||
return "Conv2d_grouped";
|
||||
}
|
||||
|
||||
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
|
||||
ggml_tensor* w = params["weight"];
|
||||
ggml_tensor* b = nullptr;
|
||||
if (bias) {
|
||||
b = params["bias"];
|
||||
}
|
||||
|
||||
if (groups == 1) {
|
||||
if (ctx->weight_adapter) {
|
||||
WeightAdapter::ForwardParams forward_params;
|
||||
forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_CONV2D;
|
||||
forward_params.conv2d.s0 = stride.second;
|
||||
forward_params.conv2d.s1 = stride.first;
|
||||
forward_params.conv2d.p0 = padding.second;
|
||||
forward_params.conv2d.p1 = padding.first;
|
||||
forward_params.conv2d.d0 = dilation.second;
|
||||
forward_params.conv2d.d1 = dilation.first;
|
||||
forward_params.conv2d.direct = ctx->conv2d_direct_enabled;
|
||||
forward_params.conv2d.circular_x = ctx->circular_x_enabled;
|
||||
forward_params.conv2d.circular_y = ctx->circular_y_enabled;
|
||||
forward_params.conv2d.scale = scale;
|
||||
return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, ctx->backend, x, w, b, prefix, forward_params);
|
||||
}
|
||||
return ggml_ext_conv_2d(ctx->ggml_ctx, x, w, b,
|
||||
stride.second, stride.first,
|
||||
padding.second, padding.first,
|
||||
dilation.second, dilation.first,
|
||||
ctx->conv2d_direct_enabled,
|
||||
ctx->circular_x_enabled,
|
||||
ctx->circular_y_enabled,
|
||||
scale);
|
||||
}
|
||||
|
||||
if (groups == in_channels && groups == out_channels) {
|
||||
ggml_tensor* res;
|
||||
if (ctx->conv2d_direct_enabled) {
|
||||
res = ggml_conv_2d_dw_direct(ctx->ggml_ctx, x, w,
|
||||
stride.second, stride.first,
|
||||
padding.second, padding.first,
|
||||
dilation.second, dilation.first);
|
||||
} else {
|
||||
res = ggml_conv_2d_dw(ctx->ggml_ctx, x, w,
|
||||
stride.second, stride.first,
|
||||
padding.second, padding.first,
|
||||
dilation.second, dilation.first);
|
||||
}
|
||||
if (b) {
|
||||
res = ggml_add(ctx->ggml_ctx, res, b);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
int64_t ic_g = in_channels / groups;
|
||||
int64_t oc_g = out_channels / groups;
|
||||
|
||||
std::vector<ggml_tensor*> out_slices(groups);
|
||||
|
||||
for (int i = 0; i < groups; ++i) {
|
||||
size_t x_offset = i * ic_g * x->nb[2];
|
||||
ggml_tensor* x_i = ggml_view_4d(ctx->ggml_ctx, x,
|
||||
x->ne[0], x->ne[1], ic_g, x->ne[3],
|
||||
x->nb[1], x->nb[2], x->nb[3],
|
||||
x_offset);
|
||||
|
||||
size_t w_offset = i * oc_g * w->nb[3];
|
||||
ggml_tensor* w_i = ggml_view_4d(ctx->ggml_ctx, w,
|
||||
w->ne[0], w->ne[1], w->ne[2], oc_g,
|
||||
w->nb[1], w->nb[2], w->nb[3],
|
||||
w_offset);
|
||||
|
||||
ggml_tensor* b_i = nullptr;
|
||||
if (b) {
|
||||
size_t b_offset = i * oc_g * b->nb[0];
|
||||
b_i = ggml_view_1d(ctx->ggml_ctx, b, oc_g, b_offset);
|
||||
}
|
||||
|
||||
if (ctx->weight_adapter) {
|
||||
WeightAdapter::ForwardParams forward_params;
|
||||
forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_CONV2D;
|
||||
forward_params.conv2d.s0 = stride.second;
|
||||
forward_params.conv2d.s1 = stride.first;
|
||||
forward_params.conv2d.p0 = padding.second;
|
||||
forward_params.conv2d.p1 = padding.first;
|
||||
forward_params.conv2d.d0 = dilation.second;
|
||||
forward_params.conv2d.d1 = dilation.first;
|
||||
forward_params.conv2d.direct = ctx->conv2d_direct_enabled;
|
||||
forward_params.conv2d.circular_x = ctx->circular_x_enabled;
|
||||
forward_params.conv2d.circular_y = ctx->circular_y_enabled;
|
||||
forward_params.conv2d.scale = scale;
|
||||
out_slices[i] = ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, ctx->backend, x_i, w_i, b_i, prefix, forward_params);
|
||||
} else {
|
||||
out_slices[i] = ggml_ext_conv_2d(ctx->ggml_ctx, x_i, w_i, b_i,
|
||||
stride.second, stride.first,
|
||||
padding.second, padding.first,
|
||||
dilation.second, dilation.first,
|
||||
ctx->conv2d_direct_enabled,
|
||||
ctx->circular_x_enabled,
|
||||
ctx->circular_y_enabled,
|
||||
scale);
|
||||
}
|
||||
}
|
||||
|
||||
ggml_tensor* out = ggml_ext_vec_concat(ctx->ggml_ctx, out_slices, 2);
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
class Conv3d : public UnaryBlock {
|
||||
protected:
|
||||
int64_t in_channels;
|
||||
|
||||
88
src/tae.hpp
88
src/tae.hpp
@ -259,7 +259,51 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
ggml_tensor* patchify(ggml_context* ctx,
|
||||
class WideMemBlock : public GGMLBlock {
|
||||
bool has_skip_conv = false;
|
||||
|
||||
public:
|
||||
WideMemBlock(int channels, int out_channels)
|
||||
: has_skip_conv(channels != out_channels) {
|
||||
int groups = std::max(1, out_channels / 64);
|
||||
blocks["conv.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels * 2, out_channels, {1, 1}, {1, 1}));
|
||||
blocks["conv.2"] = std::shared_ptr<GGMLBlock>(new Conv2d_grouped(out_channels, out_channels, groups, {3, 3}, {1, 1}, {1, 1}));
|
||||
blocks["conv.4"] = std::shared_ptr<GGMLBlock>(new Conv2d(out_channels, out_channels, {1, 1}, {1, 1}));
|
||||
blocks["conv.6"] = std::shared_ptr<GGMLBlock>(new Conv2d_grouped(out_channels, out_channels, groups, {3, 3}, {1, 1}, {1, 1}));
|
||||
if (has_skip_conv) {
|
||||
blocks["skip"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
|
||||
}
|
||||
}
|
||||
|
||||
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* past) {
|
||||
// x: [n, channels, h, w]
|
||||
auto conv0 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.0"]);
|
||||
auto conv1 = std::dynamic_pointer_cast<Conv2d_grouped>(blocks["conv.2"]);
|
||||
auto conv2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.4"]);
|
||||
auto conv3 = std::dynamic_pointer_cast<Conv2d_grouped>(blocks["conv.6"]);
|
||||
|
||||
auto h = ggml_concat(ctx->ggml_ctx, x, past, 2);
|
||||
h = conv0->forward(ctx, h);
|
||||
h = ggml_relu_inplace(ctx->ggml_ctx, h);
|
||||
h = conv1->forward(ctx, h);
|
||||
h = ggml_relu_inplace(ctx->ggml_ctx, h);
|
||||
h = conv2->forward(ctx, h);
|
||||
h = ggml_relu_inplace(ctx->ggml_ctx, h);
|
||||
h = conv3->forward(ctx, h);
|
||||
|
||||
auto skip = x;
|
||||
if (has_skip_conv) {
|
||||
auto skip_conv = std::dynamic_pointer_cast<Conv2d>(blocks["skip"]);
|
||||
skip = skip_conv->forward(ctx, x);
|
||||
}
|
||||
h = ggml_add_inplace(ctx->ggml_ctx, h, skip);
|
||||
h = ggml_relu_inplace(ctx->ggml_ctx, h);
|
||||
return h;
|
||||
}
|
||||
};
|
||||
|
||||
ggml_tensor*
|
||||
patchify(ggml_context* ctx,
|
||||
ggml_tensor* x,
|
||||
int64_t patch_size,
|
||||
int64_t b = 1) {
|
||||
@ -325,7 +369,6 @@ public:
|
||||
int t_downscale = 1;
|
||||
TinyVideoEncoder(int z_channels = 4, int patch_size = 1, std::vector<bool> time_downscale = {true, true, false})
|
||||
: z_channels(z_channels), patch_size(patch_size) {
|
||||
// self.t_downscale = 2**sum(t.stride == 2 for t in self.encoder if isinstance(t, TPool))
|
||||
t_downscale = 1;
|
||||
for (bool downscale : time_downscale) {
|
||||
if (downscale) {
|
||||
@ -384,11 +427,18 @@ class TinyVideoDecoder : public UnaryBlock {
|
||||
int channels[num_layers + 1] = {256, 128, 64, 64};
|
||||
int patch_size = 1;
|
||||
int t_upscale = 1;
|
||||
bool is_wide = false;
|
||||
|
||||
public:
|
||||
TinyVideoDecoder(int z_channels = 4, int patch_size = 1, std::vector<bool> time_upscale = {false, true, true})
|
||||
: z_channels(z_channels), patch_size(patch_size) {
|
||||
TinyVideoDecoder(int z_channels = 4, int patch_size = 1, std::vector<bool> time_upscale = {false, true, true}, bool is_wide = false)
|
||||
: z_channels(z_channels), patch_size(patch_size), is_wide(is_wide) {
|
||||
t_upscale = 1;
|
||||
if (is_wide) {
|
||||
channels[0] = 1024;
|
||||
channels[1] = 512;
|
||||
channels[2] = 256;
|
||||
}
|
||||
|
||||
for (bool upscale : time_upscale) {
|
||||
if (upscale) {
|
||||
t_upscale *= 2;
|
||||
@ -400,8 +450,12 @@ public:
|
||||
for (int i = 0; i < num_layers; i++) {
|
||||
int stride = time_upscale[i] ? 2 : 1;
|
||||
for (int j = 0; j < num_blocks; j++) {
|
||||
if (is_wide) {
|
||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new WideMemBlock(channels[i], channels[i]));
|
||||
} else {
|
||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new MemBlock(channels[i], channels[i]));
|
||||
}
|
||||
}
|
||||
index++; // nn.Upsample()
|
||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TGrow(channels[i], stride));
|
||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels[i], channels[i + 1], {3, 3}, {1, 1}, {1, 1}, {1, 1}, false));
|
||||
@ -425,10 +479,15 @@ public:
|
||||
int index = 3;
|
||||
for (int i = 0; i < num_layers; i++) {
|
||||
for (int j = 0; j < num_blocks; j++) {
|
||||
auto block = std::dynamic_pointer_cast<MemBlock>(blocks[std::to_string(index++)]);
|
||||
auto mem = ggml_pad_ext(ctx->ggml_ctx, h, 0, 0, 0, 0, 0, 0, 1, 0);
|
||||
mem = ggml_view_4d(ctx->ggml_ctx, mem, h->ne[0], h->ne[1], h->ne[2], h->ne[3], h->nb[1], h->nb[2], h->nb[3], 0);
|
||||
if (is_wide) {
|
||||
auto block = std::dynamic_pointer_cast<WideMemBlock>(blocks[std::to_string(index++)]);
|
||||
h = block->forward(ctx, h, mem);
|
||||
} else{
|
||||
auto block = std::dynamic_pointer_cast<MemBlock>(blocks[std::to_string(index++)]);
|
||||
h = block->forward(ctx, h, mem);
|
||||
}
|
||||
}
|
||||
// upsample
|
||||
index++;
|
||||
@ -455,6 +514,7 @@ class TAEHV : public GGMLBlock {
|
||||
protected:
|
||||
bool decode_only;
|
||||
SDVersion version;
|
||||
bool is_wide;
|
||||
|
||||
public:
|
||||
int z_channels = 16;
|
||||
@ -462,8 +522,8 @@ public:
|
||||
std::vector<bool> time_upscale = {false, true, true};
|
||||
|
||||
public:
|
||||
TAEHV(bool decode_only = true, SDVersion version = VERSION_WAN2)
|
||||
: decode_only(decode_only), version(version) {
|
||||
TAEHV(bool decode_only = true, SDVersion version = VERSION_WAN2, bool is_wide = false)
|
||||
: decode_only(decode_only), version(version), is_wide(is_wide) {
|
||||
int patch = 1;
|
||||
if (version == VERSION_WAN2_2_TI2V) {
|
||||
z_channels = 48;
|
||||
@ -474,7 +534,7 @@ public:
|
||||
time_downscale = {true, true, true};
|
||||
time_upscale = {true, true, true};
|
||||
}
|
||||
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoDecoder(z_channels, patch, time_upscale));
|
||||
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoDecoder(z_channels, patch, time_upscale, is_wide));
|
||||
if (!decode_only) {
|
||||
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoEncoder(z_channels, patch, time_downscale));
|
||||
}
|
||||
@ -623,6 +683,7 @@ struct TinyImageAutoEncoder : public VAE {
|
||||
struct TinyVideoAutoEncoder : public VAE {
|
||||
TAEHV taehv;
|
||||
bool decode_only = false;
|
||||
bool is_wide = false;
|
||||
|
||||
TinyVideoAutoEncoder(ggml_backend_t backend,
|
||||
ggml_backend_t params_backend,
|
||||
@ -631,8 +692,14 @@ struct TinyVideoAutoEncoder : public VAE {
|
||||
bool decoder_only = true,
|
||||
SDVersion version = VERSION_WAN2)
|
||||
: decode_only(decoder_only),
|
||||
taehv(decoder_only, version),
|
||||
VAE(version, backend, params_backend) {
|
||||
for (auto tensor_storage : tensor_storage_map) {
|
||||
if (tensor_storage.first.find(prefix + ".3.conv.6.weight") != std::string::npos) {
|
||||
is_wide = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
taehv = TAEHV(decoder_only, version, is_wide);
|
||||
scale_input = false;
|
||||
taehv.init(params_ctx, tensor_storage_map, prefix);
|
||||
}
|
||||
@ -663,7 +730,8 @@ struct TinyVideoAutoEncoder : public VAE {
|
||||
}
|
||||
|
||||
ggml_cgraph* build_graph(const sd::Tensor<float>& z_tensor, bool decode_graph) {
|
||||
ggml_cgraph* gf = ggml_new_graph(compute_ctx);
|
||||
ggml_cgraph* gf = decode_graph && is_wide ? ggml_new_graph_custom(compute_ctx, 4096, false)
|
||||
: ggml_new_graph(compute_ctx);
|
||||
ggml_tensor* z = make_input(z_tensor);
|
||||
auto runner_ctx = get_context();
|
||||
ggml_tensor* out = decode_graph ? taehv.decode(&runner_ctx, z) : taehv.encode(&runner_ctx, z);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user