mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-29 09:36:40 +00:00
feat: support Qwen2D VAE (#1714)
This commit is contained in:
parent
d77b8f5ee8
commit
7b5f34d93e
@ -113,6 +113,26 @@ namespace WAN {
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
class Conv2dBut3d : public Conv2d {
|
||||
public:
|
||||
using Conv2d::Conv2d;
|
||||
|
||||
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
|
||||
ggml_tensor* x_swapped = ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2);
|
||||
x_swapped = ggml_cont(ctx->ggml_ctx, x_swapped);
|
||||
|
||||
ggml_tensor* out = Conv2d::forward(ctx, x_swapped);
|
||||
|
||||
ggml_tensor* out_swapped = ggml_permute(ctx->ggml_ctx, out, 0, 1, 3, 2);
|
||||
|
||||
out_swapped = ggml_cont(ctx->ggml_ctx, out_swapped);
|
||||
|
||||
return out_swapped;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
class Resample : public GGMLBlock {
|
||||
protected:
|
||||
int64_t dim;
|
||||
@ -338,21 +358,34 @@ namespace WAN {
|
||||
protected:
|
||||
int64_t in_dim;
|
||||
int64_t out_dim;
|
||||
bool is_2D;
|
||||
|
||||
public:
|
||||
ResidualBlock(int64_t in_dim, int64_t out_dim)
|
||||
: in_dim(in_dim), out_dim(out_dim) {
|
||||
ResidualBlock(int64_t in_dim, int64_t out_dim, bool is_2D = false)
|
||||
: in_dim(in_dim), out_dim(out_dim), is_2D(is_2D) {
|
||||
blocks["residual.0"] = std::shared_ptr<GGMLBlock>(new RMS_norm(in_dim));
|
||||
// residual.1 is nn.SiLU()
|
||||
if (is_2D) {
|
||||
blocks["residual.2"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(in_dim, out_dim, {3, 3}, {1, 1}, {1, 1}));
|
||||
} else {
|
||||
blocks["residual.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(in_dim, out_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
||||
}
|
||||
blocks["residual.3"] = std::shared_ptr<GGMLBlock>(new RMS_norm(out_dim));
|
||||
// residual.4 is nn.SiLU()
|
||||
// residual.5 is nn.Dropout()
|
||||
if (is_2D) {
|
||||
blocks["residual.6"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(out_dim, out_dim, {3, 3}, {1, 1}, {1, 1}));
|
||||
} else {
|
||||
blocks["residual.6"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, out_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
||||
}
|
||||
if (in_dim != out_dim) {
|
||||
if (is_2D) {
|
||||
blocks["shortcut"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(in_dim, out_dim, {1, 1}));
|
||||
} else {
|
||||
blocks["shortcut"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(in_dim, out_dim, {1, 1, 1}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||
ggml_tensor* x,
|
||||
@ -363,10 +396,16 @@ namespace WAN {
|
||||
GGML_ASSERT(b == 1);
|
||||
ggml_tensor* h = x;
|
||||
if (in_dim != out_dim) {
|
||||
if (is_2D) {
|
||||
auto shortcut = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["shortcut"]);
|
||||
|
||||
h = shortcut->forward(ctx, x);
|
||||
} else {
|
||||
auto shortcut = std::dynamic_pointer_cast<CausalConv3d>(blocks["shortcut"]);
|
||||
|
||||
h = shortcut->forward(ctx, x);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < 7; i++) {
|
||||
if (i == 0 || i == 3) { // RMS_norm
|
||||
@ -385,8 +424,13 @@ namespace WAN {
|
||||
cache_x,
|
||||
2);
|
||||
}
|
||||
if (is_2D) {
|
||||
auto layer = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["residual." + std::to_string(i)]);
|
||||
|
||||
x = layer->forward(ctx, x);
|
||||
} else {
|
||||
x = layer->forward(ctx, x, feat_cache[idx]);
|
||||
}
|
||||
feat_cache[idx] = cache_x;
|
||||
feat_idx += 1;
|
||||
}
|
||||
@ -412,13 +456,14 @@ namespace WAN {
|
||||
int64_t out_dim,
|
||||
int mult,
|
||||
bool temperal_downsample = false,
|
||||
bool down_flag = false)
|
||||
bool down_flag = false,
|
||||
bool is_2D = false)
|
||||
: mult(mult), down_flag(down_flag) {
|
||||
blocks["avg_shortcut"] = std::shared_ptr<GGMLBlock>(new AvgDown3D(in_dim, out_dim, temperal_downsample ? 2 : 1, down_flag ? 2 : 1));
|
||||
|
||||
int i = 0;
|
||||
for (; i < mult; i++) {
|
||||
blocks["downsamples." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim));
|
||||
blocks["downsamples." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim, is_2D));
|
||||
in_dim = out_dim;
|
||||
}
|
||||
if (down_flag) {
|
||||
@ -472,7 +517,8 @@ namespace WAN {
|
||||
int64_t out_dim,
|
||||
int mult,
|
||||
bool temperal_upsample = false,
|
||||
bool up_flag = false)
|
||||
bool up_flag = false,
|
||||
bool is_2D = false)
|
||||
: mult(mult), up_flag(up_flag) {
|
||||
if (up_flag) {
|
||||
blocks["avg_shortcut"] = std::shared_ptr<GGMLBlock>(new DupUp3D(in_dim, out_dim, temperal_upsample ? 2 : 1, up_flag ? 2 : 1));
|
||||
@ -480,7 +526,7 @@ namespace WAN {
|
||||
|
||||
int i = 0;
|
||||
for (; i < mult; i++) {
|
||||
blocks["upsamples." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim));
|
||||
blocks["upsamples." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim, is_2D));
|
||||
in_dim = out_dim;
|
||||
}
|
||||
if (up_flag) {
|
||||
@ -592,6 +638,7 @@ namespace WAN {
|
||||
std::vector<int> dim_mult;
|
||||
int num_res_blocks;
|
||||
std::vector<bool> temperal_downsample;
|
||||
bool is_2D = false;
|
||||
|
||||
public:
|
||||
Encoder3d(int64_t dim = 128,
|
||||
@ -599,23 +646,26 @@ namespace WAN {
|
||||
std::vector<int> dim_mult = {1, 2, 4, 4},
|
||||
int num_res_blocks = 2,
|
||||
std::vector<bool> temperal_downsample = {false, true, true},
|
||||
bool wan2_2 = false)
|
||||
bool wan2_2 = false,
|
||||
bool is_2D = false)
|
||||
: dim(dim),
|
||||
z_dim(z_dim),
|
||||
dim_mult(dim_mult),
|
||||
num_res_blocks(num_res_blocks),
|
||||
temperal_downsample(temperal_downsample),
|
||||
wan2_2(wan2_2) {
|
||||
wan2_2(wan2_2),
|
||||
is_2D(is_2D) {
|
||||
// attn_scales is always []
|
||||
std::vector<int64_t> dims = {dim};
|
||||
for (int u : dim_mult) {
|
||||
dims.push_back(dim * u);
|
||||
}
|
||||
|
||||
if (wan2_2) {
|
||||
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(12, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
||||
int64_t input_dim = wan2_2 ? 12 : 3;
|
||||
if (is_2D) {
|
||||
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(input_dim, dims[0], {3, 3}, {1, 1}, {1, 1}));
|
||||
} else {
|
||||
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(3, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
||||
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(input_dim, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
||||
}
|
||||
|
||||
int index = 0;
|
||||
@ -630,12 +680,13 @@ namespace WAN {
|
||||
out_dim,
|
||||
num_res_blocks,
|
||||
t_down_flag,
|
||||
i != dim_mult.size() - 1));
|
||||
i != dim_mult.size() - 1,
|
||||
is_2D));
|
||||
|
||||
blocks["downsamples." + std::to_string(index++)] = block;
|
||||
} else {
|
||||
for (int j = 0; j < num_res_blocks; j++) {
|
||||
auto block = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim));
|
||||
auto block = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim, is_2D));
|
||||
blocks["downsamples." + std::to_string(index++)] = block;
|
||||
in_dim = out_dim;
|
||||
}
|
||||
@ -648,14 +699,18 @@ namespace WAN {
|
||||
}
|
||||
}
|
||||
|
||||
blocks["middle.0"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(out_dim, out_dim));
|
||||
blocks["middle.0"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(out_dim, out_dim, is_2D));
|
||||
blocks["middle.1"] = std::shared_ptr<GGMLBlock>(new AttentionBlock(out_dim));
|
||||
blocks["middle.2"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(out_dim, out_dim));
|
||||
blocks["middle.2"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(out_dim, out_dim, is_2D));
|
||||
|
||||
blocks["head.0"] = std::shared_ptr<GGMLBlock>(new RMS_norm(out_dim));
|
||||
// head.1 is nn.SiLU()
|
||||
if (is_2D) {
|
||||
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(out_dim, z_dim, {3, 3}, {1, 1}, {1, 1}));
|
||||
} else {
|
||||
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, z_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
||||
}
|
||||
}
|
||||
|
||||
ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||
ggml_tensor* x,
|
||||
@ -673,7 +728,10 @@ namespace WAN {
|
||||
auto head_2 = std::dynamic_pointer_cast<CausalConv3d>(blocks["head.2"]);
|
||||
|
||||
// conv1
|
||||
if (feat_cache.size() > 0) {
|
||||
if (is_2D) {
|
||||
auto conv1 = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["conv1"]);
|
||||
x = conv1->forward(ctx, x);
|
||||
} else if (feat_cache.size() > 0) {
|
||||
int idx = feat_idx;
|
||||
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]);
|
||||
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
|
||||
@ -728,7 +786,10 @@ namespace WAN {
|
||||
// head
|
||||
x = head_0->forward(ctx, x);
|
||||
x = ggml_silu(ctx->ggml_ctx, x);
|
||||
if (feat_cache.size() > 0) {
|
||||
if (is_2D) {
|
||||
auto head_2 = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["head.2"]);
|
||||
x = head_2->forward(ctx, x);
|
||||
} else if (feat_cache.size() > 0) {
|
||||
int idx = feat_idx;
|
||||
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]);
|
||||
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
|
||||
@ -758,6 +819,7 @@ namespace WAN {
|
||||
std::vector<int> dim_mult;
|
||||
int num_res_blocks;
|
||||
std::vector<bool> temperal_upsample;
|
||||
bool is_2D = false;
|
||||
|
||||
public:
|
||||
Decoder3d(int64_t dim = 128,
|
||||
@ -765,13 +827,15 @@ namespace WAN {
|
||||
std::vector<int> dim_mult = {1, 2, 4, 4},
|
||||
int num_res_blocks = 2,
|
||||
std::vector<bool> temperal_upsample = {true, true, false},
|
||||
bool wan2_2 = false)
|
||||
bool wan2_2 = false,
|
||||
bool is_2D = false)
|
||||
: dim(dim),
|
||||
z_dim(z_dim),
|
||||
dim_mult(dim_mult),
|
||||
num_res_blocks(num_res_blocks),
|
||||
temperal_upsample(temperal_upsample),
|
||||
wan2_2(wan2_2) {
|
||||
wan2_2(wan2_2),
|
||||
is_2D(is_2D) {
|
||||
// attn_scales is always []
|
||||
std::vector<int64_t> dims = {dim_mult[dim_mult.size() - 1] * dim};
|
||||
for (int i = static_cast<int>(dim_mult.size()) - 1; i >= 0; i--) {
|
||||
@ -779,12 +843,16 @@ namespace WAN {
|
||||
}
|
||||
|
||||
// init block
|
||||
if(is_2D){
|
||||
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(z_dim, dims[0], {3, 3}, {1, 1}, {1, 1}));
|
||||
}else{
|
||||
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
||||
}
|
||||
|
||||
// middle blocks
|
||||
blocks["middle.0"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(dims[0], dims[0]));
|
||||
blocks["middle.0"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(dims[0], dims[0], is_2D));
|
||||
blocks["middle.1"] = std::shared_ptr<GGMLBlock>(new AttentionBlock(dims[0]));
|
||||
blocks["middle.2"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(dims[0], dims[0]));
|
||||
blocks["middle.2"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(dims[0], dims[0], is_2D));
|
||||
|
||||
// upsample blocks
|
||||
int index = 0;
|
||||
@ -799,7 +867,8 @@ namespace WAN {
|
||||
out_dim,
|
||||
num_res_blocks + 1,
|
||||
t_up_flag,
|
||||
i != dim_mult.size() - 1));
|
||||
i != dim_mult.size() - 1,
|
||||
is_2D));
|
||||
|
||||
blocks["upsamples." + std::to_string(index++)] = block;
|
||||
} else {
|
||||
@ -807,7 +876,7 @@ namespace WAN {
|
||||
in_dim = in_dim / 2;
|
||||
}
|
||||
for (int j = 0; j < num_res_blocks + 1; j++) {
|
||||
auto block = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim));
|
||||
auto block = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim, is_2D));
|
||||
blocks["upsamples." + std::to_string(index++)] = block;
|
||||
in_dim = out_dim;
|
||||
}
|
||||
@ -822,12 +891,13 @@ namespace WAN {
|
||||
|
||||
// output blocks
|
||||
blocks["head.0"] = std::shared_ptr<GGMLBlock>(new RMS_norm(out_dim));
|
||||
int64_t final_dim = wan2_2 ? 12 : 3;
|
||||
// head.1 is nn.SiLU()
|
||||
if (wan2_2) {
|
||||
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, 12, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
||||
if (is_2D) {
|
||||
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(out_dim, final_dim, {3, 3}, {1, 1}, {1, 1}));
|
||||
|
||||
} else {
|
||||
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, 3, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
||||
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, final_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
||||
}
|
||||
}
|
||||
|
||||
@ -847,7 +917,10 @@ namespace WAN {
|
||||
auto head_2 = std::dynamic_pointer_cast<CausalConv3d>(blocks["head.2"]);
|
||||
|
||||
// conv1
|
||||
if (feat_cache.size() > 0) {
|
||||
if (is_2D) {
|
||||
auto conv1 = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["conv1"]);
|
||||
x = conv1->forward(ctx, x);
|
||||
} else if (feat_cache.size() > 0) {
|
||||
int idx = feat_idx;
|
||||
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]);
|
||||
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
|
||||
@ -902,7 +975,10 @@ namespace WAN {
|
||||
// head
|
||||
x = head_0->forward(ctx, x);
|
||||
x = ggml_silu(ctx->ggml_ctx, x);
|
||||
if (feat_cache.size() > 0) {
|
||||
if (is_2D) {
|
||||
auto head_2 = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["head.2"]);
|
||||
x = head_2->forward(ctx, x);
|
||||
} else if (feat_cache.size() > 0) {
|
||||
int idx = feat_idx;
|
||||
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]);
|
||||
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
|
||||
@ -935,6 +1011,7 @@ namespace WAN {
|
||||
int num_res_blocks = 2;
|
||||
std::vector<bool> temperal_upsample = {true, true, false};
|
||||
std::vector<bool> temperal_downsample = {false, true, true};
|
||||
bool is_2D = false;
|
||||
|
||||
int _conv_num = 33;
|
||||
int _conv_idx = 0;
|
||||
@ -951,8 +1028,8 @@ namespace WAN {
|
||||
}
|
||||
|
||||
public:
|
||||
WanVAE(bool decode_only = true, bool wan2_2 = false)
|
||||
: decode_only(decode_only), wan2_2(wan2_2) {
|
||||
WanVAE(bool decode_only = true, bool wan2_2 = false, bool is_2D = false)
|
||||
: decode_only(decode_only), wan2_2(wan2_2), is_2D(is_2D) {
|
||||
// attn_scales is always []
|
||||
if (wan2_2) {
|
||||
dim = 160;
|
||||
@ -962,13 +1039,27 @@ namespace WAN {
|
||||
_conv_num = 34;
|
||||
_enc_conv_num = 26;
|
||||
}
|
||||
|
||||
if(is_2D){
|
||||
temperal_upsample = {false, false, false};
|
||||
temperal_downsample = {false, false, false};
|
||||
}
|
||||
|
||||
if (!decode_only) {
|
||||
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, temperal_downsample, wan2_2));
|
||||
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, temperal_downsample, wan2_2, is_2D));
|
||||
if (is_2D) {
|
||||
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(z_dim * 2, z_dim * 2, {1, 1}));
|
||||
} else {
|
||||
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim * 2, z_dim * 2, {1, 1, 1}));
|
||||
}
|
||||
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder3d(dec_dim, z_dim, dim_mult, num_res_blocks, temperal_upsample, wan2_2));
|
||||
}
|
||||
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder3d(dec_dim, z_dim, dim_mult, num_res_blocks, temperal_upsample, wan2_2, is_2D));
|
||||
if (is_2D) {
|
||||
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(z_dim, z_dim, {1, 1}));
|
||||
} else {
|
||||
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim, z_dim, {1, 1, 1}));
|
||||
}
|
||||
}
|
||||
|
||||
static ggml_tensor* patchify(ggml_context* ctx,
|
||||
ggml_tensor* x,
|
||||
@ -1054,7 +1145,12 @@ namespace WAN {
|
||||
out = ggml_concat(ctx->ggml_ctx, out, out_, 2);
|
||||
}
|
||||
}
|
||||
if (is_2D) {
|
||||
auto conv1 = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["conv1"]);
|
||||
out = conv1->forward(ctx, out);
|
||||
} else {
|
||||
out = conv1->forward(ctx, out);
|
||||
}
|
||||
auto mu = ggml_ext_chunk(ctx->ggml_ctx, out, 2, 3)[0];
|
||||
// sd::ggml_graph_cut::mark_graph_cut(mu, "wan_vae.encode.final", "mu");
|
||||
clear_cache();
|
||||
@ -1073,7 +1169,13 @@ namespace WAN {
|
||||
auto conv2 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv2"]);
|
||||
|
||||
int64_t iter_ = z->ne[2];
|
||||
auto x = conv2->forward(ctx, z);
|
||||
auto x = z;
|
||||
if(is_2D){
|
||||
auto conv2 = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["conv2"]);
|
||||
x = conv2->forward(ctx, z);
|
||||
} else {
|
||||
x = conv2->forward(ctx, z);
|
||||
}
|
||||
// sd::ggml_graph_cut::mark_graph_cut(x, "wan_vae.decode.prelude", "x");
|
||||
ggml_tensor* out;
|
||||
for (int i = 0; i < iter_; i++) {
|
||||
@ -1129,7 +1231,20 @@ namespace WAN {
|
||||
bool decode_only = false,
|
||||
SDVersion version = VERSION_WAN2,
|
||||
std::shared_ptr<RunnerWeightManager> weight_manager = nullptr)
|
||||
: VAE(version, backend, prefix, weight_manager), decode_only(decode_only), ae(decode_only, version == VERSION_WAN2_2_TI2V) {
|
||||
: VAE(version, backend, prefix, weight_manager), decode_only(decode_only) {
|
||||
bool is_2D = false;
|
||||
for (const auto& [name, tensor_storage] : tensor_storage_map) {
|
||||
if (ends_with(name, "decoder.conv1.weight")) {
|
||||
if (tensor_storage.ne[2] > 3) {
|
||||
is_2D = true;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (is_2D) {
|
||||
LOG_DEBUG("USING 2D VAE");
|
||||
}
|
||||
ae = WanVAE(decode_only, version == VERSION_WAN2_2_TI2V, is_2D);
|
||||
ae.init(params_ctx, tensor_storage_map, prefix);
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user