feat: support Qwen2D VAE (#1714)

This commit is contained in:
stduhpf 2026-06-28 16:50:57 +02:00 committed by GitHub
parent d77b8f5ee8
commit 7b5f34d93e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -113,6 +113,26 @@ namespace WAN {
} }
}; };
class Conv2dBut3d : public Conv2d {
public:
using Conv2d::Conv2d;
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
ggml_tensor* x_swapped = ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2);
x_swapped = ggml_cont(ctx->ggml_ctx, x_swapped);
ggml_tensor* out = Conv2d::forward(ctx, x_swapped);
ggml_tensor* out_swapped = ggml_permute(ctx->ggml_ctx, out, 0, 1, 3, 2);
out_swapped = ggml_cont(ctx->ggml_ctx, out_swapped);
return out_swapped;
}
};
class Resample : public GGMLBlock { class Resample : public GGMLBlock {
protected: protected:
int64_t dim; int64_t dim;
@ -338,21 +358,34 @@ namespace WAN {
protected: protected:
int64_t in_dim; int64_t in_dim;
int64_t out_dim; int64_t out_dim;
bool is_2D;
public: public:
ResidualBlock(int64_t in_dim, int64_t out_dim) ResidualBlock(int64_t in_dim, int64_t out_dim, bool is_2D = false)
: in_dim(in_dim), out_dim(out_dim) { : in_dim(in_dim), out_dim(out_dim), is_2D(is_2D) {
blocks["residual.0"] = std::shared_ptr<GGMLBlock>(new RMS_norm(in_dim)); blocks["residual.0"] = std::shared_ptr<GGMLBlock>(new RMS_norm(in_dim));
// residual.1 is nn.SiLU() // residual.1 is nn.SiLU()
if (is_2D) {
blocks["residual.2"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(in_dim, out_dim, {3, 3}, {1, 1}, {1, 1}));
} else {
blocks["residual.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(in_dim, out_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); blocks["residual.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(in_dim, out_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
}
blocks["residual.3"] = std::shared_ptr<GGMLBlock>(new RMS_norm(out_dim)); blocks["residual.3"] = std::shared_ptr<GGMLBlock>(new RMS_norm(out_dim));
// residual.4 is nn.SiLU() // residual.4 is nn.SiLU()
// residual.5 is nn.Dropout() // residual.5 is nn.Dropout()
if (is_2D) {
blocks["residual.6"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(out_dim, out_dim, {3, 3}, {1, 1}, {1, 1}));
} else {
blocks["residual.6"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, out_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); blocks["residual.6"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, out_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
}
if (in_dim != out_dim) { if (in_dim != out_dim) {
if (is_2D) {
blocks["shortcut"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(in_dim, out_dim, {1, 1}));
} else {
blocks["shortcut"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(in_dim, out_dim, {1, 1, 1})); blocks["shortcut"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(in_dim, out_dim, {1, 1, 1}));
} }
} }
}
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* forward(GGMLRunnerContext* ctx,
ggml_tensor* x, ggml_tensor* x,
@ -363,10 +396,16 @@ namespace WAN {
GGML_ASSERT(b == 1); GGML_ASSERT(b == 1);
ggml_tensor* h = x; ggml_tensor* h = x;
if (in_dim != out_dim) { if (in_dim != out_dim) {
if (is_2D) {
auto shortcut = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["shortcut"]);
h = shortcut->forward(ctx, x);
} else {
auto shortcut = std::dynamic_pointer_cast<CausalConv3d>(blocks["shortcut"]); auto shortcut = std::dynamic_pointer_cast<CausalConv3d>(blocks["shortcut"]);
h = shortcut->forward(ctx, x); h = shortcut->forward(ctx, x);
} }
}
for (int i = 0; i < 7; i++) { for (int i = 0; i < 7; i++) {
if (i == 0 || i == 3) { // RMS_norm if (i == 0 || i == 3) { // RMS_norm
@ -385,8 +424,13 @@ namespace WAN {
cache_x, cache_x,
2); 2);
} }
if (is_2D) {
auto layer = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["residual." + std::to_string(i)]);
x = layer->forward(ctx, x);
} else {
x = layer->forward(ctx, x, feat_cache[idx]); x = layer->forward(ctx, x, feat_cache[idx]);
}
feat_cache[idx] = cache_x; feat_cache[idx] = cache_x;
feat_idx += 1; feat_idx += 1;
} }
@ -412,13 +456,14 @@ namespace WAN {
int64_t out_dim, int64_t out_dim,
int mult, int mult,
bool temperal_downsample = false, bool temperal_downsample = false,
bool down_flag = false) bool down_flag = false,
bool is_2D = false)
: mult(mult), down_flag(down_flag) { : mult(mult), down_flag(down_flag) {
blocks["avg_shortcut"] = std::shared_ptr<GGMLBlock>(new AvgDown3D(in_dim, out_dim, temperal_downsample ? 2 : 1, down_flag ? 2 : 1)); blocks["avg_shortcut"] = std::shared_ptr<GGMLBlock>(new AvgDown3D(in_dim, out_dim, temperal_downsample ? 2 : 1, down_flag ? 2 : 1));
int i = 0; int i = 0;
for (; i < mult; i++) { for (; i < mult; i++) {
blocks["downsamples." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim)); blocks["downsamples." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim, is_2D));
in_dim = out_dim; in_dim = out_dim;
} }
if (down_flag) { if (down_flag) {
@ -472,7 +517,8 @@ namespace WAN {
int64_t out_dim, int64_t out_dim,
int mult, int mult,
bool temperal_upsample = false, bool temperal_upsample = false,
bool up_flag = false) bool up_flag = false,
bool is_2D = false)
: mult(mult), up_flag(up_flag) { : mult(mult), up_flag(up_flag) {
if (up_flag) { if (up_flag) {
blocks["avg_shortcut"] = std::shared_ptr<GGMLBlock>(new DupUp3D(in_dim, out_dim, temperal_upsample ? 2 : 1, up_flag ? 2 : 1)); blocks["avg_shortcut"] = std::shared_ptr<GGMLBlock>(new DupUp3D(in_dim, out_dim, temperal_upsample ? 2 : 1, up_flag ? 2 : 1));
@ -480,7 +526,7 @@ namespace WAN {
int i = 0; int i = 0;
for (; i < mult; i++) { for (; i < mult; i++) {
blocks["upsamples." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim)); blocks["upsamples." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim, is_2D));
in_dim = out_dim; in_dim = out_dim;
} }
if (up_flag) { if (up_flag) {
@ -592,6 +638,7 @@ namespace WAN {
std::vector<int> dim_mult; std::vector<int> dim_mult;
int num_res_blocks; int num_res_blocks;
std::vector<bool> temperal_downsample; std::vector<bool> temperal_downsample;
bool is_2D = false;
public: public:
Encoder3d(int64_t dim = 128, Encoder3d(int64_t dim = 128,
@ -599,23 +646,26 @@ namespace WAN {
std::vector<int> dim_mult = {1, 2, 4, 4}, std::vector<int> dim_mult = {1, 2, 4, 4},
int num_res_blocks = 2, int num_res_blocks = 2,
std::vector<bool> temperal_downsample = {false, true, true}, std::vector<bool> temperal_downsample = {false, true, true},
bool wan2_2 = false) bool wan2_2 = false,
bool is_2D = false)
: dim(dim), : dim(dim),
z_dim(z_dim), z_dim(z_dim),
dim_mult(dim_mult), dim_mult(dim_mult),
num_res_blocks(num_res_blocks), num_res_blocks(num_res_blocks),
temperal_downsample(temperal_downsample), temperal_downsample(temperal_downsample),
wan2_2(wan2_2) { wan2_2(wan2_2),
is_2D(is_2D) {
// attn_scales is always [] // attn_scales is always []
std::vector<int64_t> dims = {dim}; std::vector<int64_t> dims = {dim};
for (int u : dim_mult) { for (int u : dim_mult) {
dims.push_back(dim * u); dims.push_back(dim * u);
} }
if (wan2_2) { int64_t input_dim = wan2_2 ? 12 : 3;
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(12, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); if (is_2D) {
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(input_dim, dims[0], {3, 3}, {1, 1}, {1, 1}));
} else { } else {
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(3, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(input_dim, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
} }
int index = 0; int index = 0;
@ -630,12 +680,13 @@ namespace WAN {
out_dim, out_dim,
num_res_blocks, num_res_blocks,
t_down_flag, t_down_flag,
i != dim_mult.size() - 1)); i != dim_mult.size() - 1,
is_2D));
blocks["downsamples." + std::to_string(index++)] = block; blocks["downsamples." + std::to_string(index++)] = block;
} else { } else {
for (int j = 0; j < num_res_blocks; j++) { for (int j = 0; j < num_res_blocks; j++) {
auto block = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim)); auto block = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim, is_2D));
blocks["downsamples." + std::to_string(index++)] = block; blocks["downsamples." + std::to_string(index++)] = block;
in_dim = out_dim; in_dim = out_dim;
} }
@ -648,14 +699,18 @@ namespace WAN {
} }
} }
blocks["middle.0"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(out_dim, out_dim)); blocks["middle.0"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(out_dim, out_dim, is_2D));
blocks["middle.1"] = std::shared_ptr<GGMLBlock>(new AttentionBlock(out_dim)); blocks["middle.1"] = std::shared_ptr<GGMLBlock>(new AttentionBlock(out_dim));
blocks["middle.2"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(out_dim, out_dim)); blocks["middle.2"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(out_dim, out_dim, is_2D));
blocks["head.0"] = std::shared_ptr<GGMLBlock>(new RMS_norm(out_dim)); blocks["head.0"] = std::shared_ptr<GGMLBlock>(new RMS_norm(out_dim));
// head.1 is nn.SiLU() // head.1 is nn.SiLU()
if (is_2D) {
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(out_dim, z_dim, {3, 3}, {1, 1}, {1, 1}));
} else {
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, z_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, z_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
} }
}
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* forward(GGMLRunnerContext* ctx,
ggml_tensor* x, ggml_tensor* x,
@ -673,7 +728,10 @@ namespace WAN {
auto head_2 = std::dynamic_pointer_cast<CausalConv3d>(blocks["head.2"]); auto head_2 = std::dynamic_pointer_cast<CausalConv3d>(blocks["head.2"]);
// conv1 // conv1
if (feat_cache.size() > 0) { if (is_2D) {
auto conv1 = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["conv1"]);
x = conv1->forward(ctx, x);
} else if (feat_cache.size() > 0) {
int idx = feat_idx; int idx = feat_idx;
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]); auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]);
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) { if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
@ -728,7 +786,10 @@ namespace WAN {
// head // head
x = head_0->forward(ctx, x); x = head_0->forward(ctx, x);
x = ggml_silu(ctx->ggml_ctx, x); x = ggml_silu(ctx->ggml_ctx, x);
if (feat_cache.size() > 0) { if (is_2D) {
auto head_2 = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["head.2"]);
x = head_2->forward(ctx, x);
} else if (feat_cache.size() > 0) {
int idx = feat_idx; int idx = feat_idx;
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]); auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]);
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) { if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
@ -758,6 +819,7 @@ namespace WAN {
std::vector<int> dim_mult; std::vector<int> dim_mult;
int num_res_blocks; int num_res_blocks;
std::vector<bool> temperal_upsample; std::vector<bool> temperal_upsample;
bool is_2D = false;
public: public:
Decoder3d(int64_t dim = 128, Decoder3d(int64_t dim = 128,
@ -765,13 +827,15 @@ namespace WAN {
std::vector<int> dim_mult = {1, 2, 4, 4}, std::vector<int> dim_mult = {1, 2, 4, 4},
int num_res_blocks = 2, int num_res_blocks = 2,
std::vector<bool> temperal_upsample = {true, true, false}, std::vector<bool> temperal_upsample = {true, true, false},
bool wan2_2 = false) bool wan2_2 = false,
bool is_2D = false)
: dim(dim), : dim(dim),
z_dim(z_dim), z_dim(z_dim),
dim_mult(dim_mult), dim_mult(dim_mult),
num_res_blocks(num_res_blocks), num_res_blocks(num_res_blocks),
temperal_upsample(temperal_upsample), temperal_upsample(temperal_upsample),
wan2_2(wan2_2) { wan2_2(wan2_2),
is_2D(is_2D) {
// attn_scales is always [] // attn_scales is always []
std::vector<int64_t> dims = {dim_mult[dim_mult.size() - 1] * dim}; std::vector<int64_t> dims = {dim_mult[dim_mult.size() - 1] * dim};
for (int i = static_cast<int>(dim_mult.size()) - 1; i >= 0; i--) { for (int i = static_cast<int>(dim_mult.size()) - 1; i >= 0; i--) {
@ -779,12 +843,16 @@ namespace WAN {
} }
// init block // init block
if(is_2D){
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(z_dim, dims[0], {3, 3}, {1, 1}, {1, 1}));
}else{
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
}
// middle blocks // middle blocks
blocks["middle.0"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(dims[0], dims[0])); blocks["middle.0"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(dims[0], dims[0], is_2D));
blocks["middle.1"] = std::shared_ptr<GGMLBlock>(new AttentionBlock(dims[0])); blocks["middle.1"] = std::shared_ptr<GGMLBlock>(new AttentionBlock(dims[0]));
blocks["middle.2"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(dims[0], dims[0])); blocks["middle.2"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(dims[0], dims[0], is_2D));
// upsample blocks // upsample blocks
int index = 0; int index = 0;
@ -799,7 +867,8 @@ namespace WAN {
out_dim, out_dim,
num_res_blocks + 1, num_res_blocks + 1,
t_up_flag, t_up_flag,
i != dim_mult.size() - 1)); i != dim_mult.size() - 1,
is_2D));
blocks["upsamples." + std::to_string(index++)] = block; blocks["upsamples." + std::to_string(index++)] = block;
} else { } else {
@ -807,7 +876,7 @@ namespace WAN {
in_dim = in_dim / 2; in_dim = in_dim / 2;
} }
for (int j = 0; j < num_res_blocks + 1; j++) { for (int j = 0; j < num_res_blocks + 1; j++) {
auto block = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim)); auto block = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim, is_2D));
blocks["upsamples." + std::to_string(index++)] = block; blocks["upsamples." + std::to_string(index++)] = block;
in_dim = out_dim; in_dim = out_dim;
} }
@ -822,12 +891,13 @@ namespace WAN {
// output blocks // output blocks
blocks["head.0"] = std::shared_ptr<GGMLBlock>(new RMS_norm(out_dim)); blocks["head.0"] = std::shared_ptr<GGMLBlock>(new RMS_norm(out_dim));
int64_t final_dim = wan2_2 ? 12 : 3;
// head.1 is nn.SiLU() // head.1 is nn.SiLU()
if (wan2_2) { if (is_2D) {
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, 12, {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); blocks["head.2"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(out_dim, final_dim, {3, 3}, {1, 1}, {1, 1}));
} else { } else {
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, 3, {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, final_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
} }
} }
@ -847,7 +917,10 @@ namespace WAN {
auto head_2 = std::dynamic_pointer_cast<CausalConv3d>(blocks["head.2"]); auto head_2 = std::dynamic_pointer_cast<CausalConv3d>(blocks["head.2"]);
// conv1 // conv1
if (feat_cache.size() > 0) { if (is_2D) {
auto conv1 = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["conv1"]);
x = conv1->forward(ctx, x);
} else if (feat_cache.size() > 0) {
int idx = feat_idx; int idx = feat_idx;
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]); auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]);
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) { if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
@ -902,7 +975,10 @@ namespace WAN {
// head // head
x = head_0->forward(ctx, x); x = head_0->forward(ctx, x);
x = ggml_silu(ctx->ggml_ctx, x); x = ggml_silu(ctx->ggml_ctx, x);
if (feat_cache.size() > 0) { if (is_2D) {
auto head_2 = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["head.2"]);
x = head_2->forward(ctx, x);
} else if (feat_cache.size() > 0) {
int idx = feat_idx; int idx = feat_idx;
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]); auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]);
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) { if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
@ -935,6 +1011,7 @@ namespace WAN {
int num_res_blocks = 2; int num_res_blocks = 2;
std::vector<bool> temperal_upsample = {true, true, false}; std::vector<bool> temperal_upsample = {true, true, false};
std::vector<bool> temperal_downsample = {false, true, true}; std::vector<bool> temperal_downsample = {false, true, true};
bool is_2D = false;
int _conv_num = 33; int _conv_num = 33;
int _conv_idx = 0; int _conv_idx = 0;
@ -951,8 +1028,8 @@ namespace WAN {
} }
public: public:
WanVAE(bool decode_only = true, bool wan2_2 = false) WanVAE(bool decode_only = true, bool wan2_2 = false, bool is_2D = false)
: decode_only(decode_only), wan2_2(wan2_2) { : decode_only(decode_only), wan2_2(wan2_2), is_2D(is_2D) {
// attn_scales is always [] // attn_scales is always []
if (wan2_2) { if (wan2_2) {
dim = 160; dim = 160;
@ -962,13 +1039,27 @@ namespace WAN {
_conv_num = 34; _conv_num = 34;
_enc_conv_num = 26; _enc_conv_num = 26;
} }
if(is_2D){
temperal_upsample = {false, false, false};
temperal_downsample = {false, false, false};
}
if (!decode_only) { if (!decode_only) {
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, temperal_downsample, wan2_2)); blocks["encoder"] = std::shared_ptr<GGMLBlock>(new Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, temperal_downsample, wan2_2, is_2D));
if (is_2D) {
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(z_dim * 2, z_dim * 2, {1, 1}));
} else {
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim * 2, z_dim * 2, {1, 1, 1})); blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim * 2, z_dim * 2, {1, 1, 1}));
} }
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder3d(dec_dim, z_dim, dim_mult, num_res_blocks, temperal_upsample, wan2_2)); }
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder3d(dec_dim, z_dim, dim_mult, num_res_blocks, temperal_upsample, wan2_2, is_2D));
if (is_2D) {
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(z_dim, z_dim, {1, 1}));
} else {
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim, z_dim, {1, 1, 1})); blocks["conv2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim, z_dim, {1, 1, 1}));
} }
}
static ggml_tensor* patchify(ggml_context* ctx, static ggml_tensor* patchify(ggml_context* ctx,
ggml_tensor* x, ggml_tensor* x,
@ -1054,7 +1145,12 @@ namespace WAN {
out = ggml_concat(ctx->ggml_ctx, out, out_, 2); out = ggml_concat(ctx->ggml_ctx, out, out_, 2);
} }
} }
if (is_2D) {
auto conv1 = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["conv1"]);
out = conv1->forward(ctx, out); out = conv1->forward(ctx, out);
} else {
out = conv1->forward(ctx, out);
}
auto mu = ggml_ext_chunk(ctx->ggml_ctx, out, 2, 3)[0]; auto mu = ggml_ext_chunk(ctx->ggml_ctx, out, 2, 3)[0];
// sd::ggml_graph_cut::mark_graph_cut(mu, "wan_vae.encode.final", "mu"); // sd::ggml_graph_cut::mark_graph_cut(mu, "wan_vae.encode.final", "mu");
clear_cache(); clear_cache();
@ -1073,7 +1169,13 @@ namespace WAN {
auto conv2 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv2"]); auto conv2 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv2"]);
int64_t iter_ = z->ne[2]; int64_t iter_ = z->ne[2];
auto x = conv2->forward(ctx, z); auto x = z;
if(is_2D){
auto conv2 = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["conv2"]);
x = conv2->forward(ctx, z);
} else {
x = conv2->forward(ctx, z);
}
// sd::ggml_graph_cut::mark_graph_cut(x, "wan_vae.decode.prelude", "x"); // sd::ggml_graph_cut::mark_graph_cut(x, "wan_vae.decode.prelude", "x");
ggml_tensor* out; ggml_tensor* out;
for (int i = 0; i < iter_; i++) { for (int i = 0; i < iter_; i++) {
@ -1129,7 +1231,20 @@ namespace WAN {
bool decode_only = false, bool decode_only = false,
SDVersion version = VERSION_WAN2, SDVersion version = VERSION_WAN2,
std::shared_ptr<RunnerWeightManager> weight_manager = nullptr) std::shared_ptr<RunnerWeightManager> weight_manager = nullptr)
: VAE(version, backend, prefix, weight_manager), decode_only(decode_only), ae(decode_only, version == VERSION_WAN2_2_TI2V) { : VAE(version, backend, prefix, weight_manager), decode_only(decode_only) {
bool is_2D = false;
for (const auto& [name, tensor_storage] : tensor_storage_map) {
if (ends_with(name, "decoder.conv1.weight")) {
if (tensor_storage.ne[2] > 3) {
is_2D = true;
}
break;
}
}
if (is_2D) {
LOG_DEBUG("USING 2D VAE");
}
ae = WanVAE(decode_only, version == VERSION_WAN2_2_TI2V, is_2D);
ae.init(params_ctx, tensor_storage_map, prefix); ae.init(params_ctx, tensor_storage_map, prefix);
} }