mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-29 09:36:40 +00:00
feat: support Qwen2D VAE (#1714)
This commit is contained in:
parent
d77b8f5ee8
commit
7b5f34d93e
@ -113,6 +113,26 @@ namespace WAN {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2dBut3d : public Conv2d {
|
||||||
|
public:
|
||||||
|
using Conv2d::Conv2d;
|
||||||
|
|
||||||
|
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
|
||||||
|
ggml_tensor* x_swapped = ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2);
|
||||||
|
x_swapped = ggml_cont(ctx->ggml_ctx, x_swapped);
|
||||||
|
|
||||||
|
ggml_tensor* out = Conv2d::forward(ctx, x_swapped);
|
||||||
|
|
||||||
|
ggml_tensor* out_swapped = ggml_permute(ctx->ggml_ctx, out, 0, 1, 3, 2);
|
||||||
|
|
||||||
|
out_swapped = ggml_cont(ctx->ggml_ctx, out_swapped);
|
||||||
|
|
||||||
|
return out_swapped;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
class Resample : public GGMLBlock {
|
class Resample : public GGMLBlock {
|
||||||
protected:
|
protected:
|
||||||
int64_t dim;
|
int64_t dim;
|
||||||
@ -338,21 +358,34 @@ namespace WAN {
|
|||||||
protected:
|
protected:
|
||||||
int64_t in_dim;
|
int64_t in_dim;
|
||||||
int64_t out_dim;
|
int64_t out_dim;
|
||||||
|
bool is_2D;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
ResidualBlock(int64_t in_dim, int64_t out_dim)
|
ResidualBlock(int64_t in_dim, int64_t out_dim, bool is_2D = false)
|
||||||
: in_dim(in_dim), out_dim(out_dim) {
|
: in_dim(in_dim), out_dim(out_dim), is_2D(is_2D) {
|
||||||
blocks["residual.0"] = std::shared_ptr<GGMLBlock>(new RMS_norm(in_dim));
|
blocks["residual.0"] = std::shared_ptr<GGMLBlock>(new RMS_norm(in_dim));
|
||||||
// residual.1 is nn.SiLU()
|
// residual.1 is nn.SiLU()
|
||||||
|
if (is_2D) {
|
||||||
|
blocks["residual.2"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(in_dim, out_dim, {3, 3}, {1, 1}, {1, 1}));
|
||||||
|
} else {
|
||||||
blocks["residual.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(in_dim, out_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
blocks["residual.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(in_dim, out_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
||||||
|
}
|
||||||
blocks["residual.3"] = std::shared_ptr<GGMLBlock>(new RMS_norm(out_dim));
|
blocks["residual.3"] = std::shared_ptr<GGMLBlock>(new RMS_norm(out_dim));
|
||||||
// residual.4 is nn.SiLU()
|
// residual.4 is nn.SiLU()
|
||||||
// residual.5 is nn.Dropout()
|
// residual.5 is nn.Dropout()
|
||||||
|
if (is_2D) {
|
||||||
|
blocks["residual.6"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(out_dim, out_dim, {3, 3}, {1, 1}, {1, 1}));
|
||||||
|
} else {
|
||||||
blocks["residual.6"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, out_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
blocks["residual.6"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, out_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
||||||
|
}
|
||||||
if (in_dim != out_dim) {
|
if (in_dim != out_dim) {
|
||||||
|
if (is_2D) {
|
||||||
|
blocks["shortcut"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(in_dim, out_dim, {1, 1}));
|
||||||
|
} else {
|
||||||
blocks["shortcut"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(in_dim, out_dim, {1, 1, 1}));
|
blocks["shortcut"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(in_dim, out_dim, {1, 1, 1}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ggml_tensor* forward(GGMLRunnerContext* ctx,
|
ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_tensor* x,
|
ggml_tensor* x,
|
||||||
@ -363,10 +396,16 @@ namespace WAN {
|
|||||||
GGML_ASSERT(b == 1);
|
GGML_ASSERT(b == 1);
|
||||||
ggml_tensor* h = x;
|
ggml_tensor* h = x;
|
||||||
if (in_dim != out_dim) {
|
if (in_dim != out_dim) {
|
||||||
|
if (is_2D) {
|
||||||
|
auto shortcut = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["shortcut"]);
|
||||||
|
|
||||||
|
h = shortcut->forward(ctx, x);
|
||||||
|
} else {
|
||||||
auto shortcut = std::dynamic_pointer_cast<CausalConv3d>(blocks["shortcut"]);
|
auto shortcut = std::dynamic_pointer_cast<CausalConv3d>(blocks["shortcut"]);
|
||||||
|
|
||||||
h = shortcut->forward(ctx, x);
|
h = shortcut->forward(ctx, x);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (int i = 0; i < 7; i++) {
|
for (int i = 0; i < 7; i++) {
|
||||||
if (i == 0 || i == 3) { // RMS_norm
|
if (i == 0 || i == 3) { // RMS_norm
|
||||||
@ -385,8 +424,13 @@ namespace WAN {
|
|||||||
cache_x,
|
cache_x,
|
||||||
2);
|
2);
|
||||||
}
|
}
|
||||||
|
if (is_2D) {
|
||||||
|
auto layer = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["residual." + std::to_string(i)]);
|
||||||
|
|
||||||
|
x = layer->forward(ctx, x);
|
||||||
|
} else {
|
||||||
x = layer->forward(ctx, x, feat_cache[idx]);
|
x = layer->forward(ctx, x, feat_cache[idx]);
|
||||||
|
}
|
||||||
feat_cache[idx] = cache_x;
|
feat_cache[idx] = cache_x;
|
||||||
feat_idx += 1;
|
feat_idx += 1;
|
||||||
}
|
}
|
||||||
@ -412,13 +456,14 @@ namespace WAN {
|
|||||||
int64_t out_dim,
|
int64_t out_dim,
|
||||||
int mult,
|
int mult,
|
||||||
bool temperal_downsample = false,
|
bool temperal_downsample = false,
|
||||||
bool down_flag = false)
|
bool down_flag = false,
|
||||||
|
bool is_2D = false)
|
||||||
: mult(mult), down_flag(down_flag) {
|
: mult(mult), down_flag(down_flag) {
|
||||||
blocks["avg_shortcut"] = std::shared_ptr<GGMLBlock>(new AvgDown3D(in_dim, out_dim, temperal_downsample ? 2 : 1, down_flag ? 2 : 1));
|
blocks["avg_shortcut"] = std::shared_ptr<GGMLBlock>(new AvgDown3D(in_dim, out_dim, temperal_downsample ? 2 : 1, down_flag ? 2 : 1));
|
||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
for (; i < mult; i++) {
|
for (; i < mult; i++) {
|
||||||
blocks["downsamples." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim));
|
blocks["downsamples." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim, is_2D));
|
||||||
in_dim = out_dim;
|
in_dim = out_dim;
|
||||||
}
|
}
|
||||||
if (down_flag) {
|
if (down_flag) {
|
||||||
@ -472,7 +517,8 @@ namespace WAN {
|
|||||||
int64_t out_dim,
|
int64_t out_dim,
|
||||||
int mult,
|
int mult,
|
||||||
bool temperal_upsample = false,
|
bool temperal_upsample = false,
|
||||||
bool up_flag = false)
|
bool up_flag = false,
|
||||||
|
bool is_2D = false)
|
||||||
: mult(mult), up_flag(up_flag) {
|
: mult(mult), up_flag(up_flag) {
|
||||||
if (up_flag) {
|
if (up_flag) {
|
||||||
blocks["avg_shortcut"] = std::shared_ptr<GGMLBlock>(new DupUp3D(in_dim, out_dim, temperal_upsample ? 2 : 1, up_flag ? 2 : 1));
|
blocks["avg_shortcut"] = std::shared_ptr<GGMLBlock>(new DupUp3D(in_dim, out_dim, temperal_upsample ? 2 : 1, up_flag ? 2 : 1));
|
||||||
@ -480,7 +526,7 @@ namespace WAN {
|
|||||||
|
|
||||||
int i = 0;
|
int i = 0;
|
||||||
for (; i < mult; i++) {
|
for (; i < mult; i++) {
|
||||||
blocks["upsamples." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim));
|
blocks["upsamples." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim, is_2D));
|
||||||
in_dim = out_dim;
|
in_dim = out_dim;
|
||||||
}
|
}
|
||||||
if (up_flag) {
|
if (up_flag) {
|
||||||
@ -592,6 +638,7 @@ namespace WAN {
|
|||||||
std::vector<int> dim_mult;
|
std::vector<int> dim_mult;
|
||||||
int num_res_blocks;
|
int num_res_blocks;
|
||||||
std::vector<bool> temperal_downsample;
|
std::vector<bool> temperal_downsample;
|
||||||
|
bool is_2D = false;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
Encoder3d(int64_t dim = 128,
|
Encoder3d(int64_t dim = 128,
|
||||||
@ -599,23 +646,26 @@ namespace WAN {
|
|||||||
std::vector<int> dim_mult = {1, 2, 4, 4},
|
std::vector<int> dim_mult = {1, 2, 4, 4},
|
||||||
int num_res_blocks = 2,
|
int num_res_blocks = 2,
|
||||||
std::vector<bool> temperal_downsample = {false, true, true},
|
std::vector<bool> temperal_downsample = {false, true, true},
|
||||||
bool wan2_2 = false)
|
bool wan2_2 = false,
|
||||||
|
bool is_2D = false)
|
||||||
: dim(dim),
|
: dim(dim),
|
||||||
z_dim(z_dim),
|
z_dim(z_dim),
|
||||||
dim_mult(dim_mult),
|
dim_mult(dim_mult),
|
||||||
num_res_blocks(num_res_blocks),
|
num_res_blocks(num_res_blocks),
|
||||||
temperal_downsample(temperal_downsample),
|
temperal_downsample(temperal_downsample),
|
||||||
wan2_2(wan2_2) {
|
wan2_2(wan2_2),
|
||||||
|
is_2D(is_2D) {
|
||||||
// attn_scales is always []
|
// attn_scales is always []
|
||||||
std::vector<int64_t> dims = {dim};
|
std::vector<int64_t> dims = {dim};
|
||||||
for (int u : dim_mult) {
|
for (int u : dim_mult) {
|
||||||
dims.push_back(dim * u);
|
dims.push_back(dim * u);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (wan2_2) {
|
int64_t input_dim = wan2_2 ? 12 : 3;
|
||||||
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(12, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
if (is_2D) {
|
||||||
|
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(input_dim, dims[0], {3, 3}, {1, 1}, {1, 1}));
|
||||||
} else {
|
} else {
|
||||||
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(3, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(input_dim, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
||||||
}
|
}
|
||||||
|
|
||||||
int index = 0;
|
int index = 0;
|
||||||
@ -630,12 +680,13 @@ namespace WAN {
|
|||||||
out_dim,
|
out_dim,
|
||||||
num_res_blocks,
|
num_res_blocks,
|
||||||
t_down_flag,
|
t_down_flag,
|
||||||
i != dim_mult.size() - 1));
|
i != dim_mult.size() - 1,
|
||||||
|
is_2D));
|
||||||
|
|
||||||
blocks["downsamples." + std::to_string(index++)] = block;
|
blocks["downsamples." + std::to_string(index++)] = block;
|
||||||
} else {
|
} else {
|
||||||
for (int j = 0; j < num_res_blocks; j++) {
|
for (int j = 0; j < num_res_blocks; j++) {
|
||||||
auto block = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim));
|
auto block = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim, is_2D));
|
||||||
blocks["downsamples." + std::to_string(index++)] = block;
|
blocks["downsamples." + std::to_string(index++)] = block;
|
||||||
in_dim = out_dim;
|
in_dim = out_dim;
|
||||||
}
|
}
|
||||||
@ -648,14 +699,18 @@ namespace WAN {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
blocks["middle.0"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(out_dim, out_dim));
|
blocks["middle.0"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(out_dim, out_dim, is_2D));
|
||||||
blocks["middle.1"] = std::shared_ptr<GGMLBlock>(new AttentionBlock(out_dim));
|
blocks["middle.1"] = std::shared_ptr<GGMLBlock>(new AttentionBlock(out_dim));
|
||||||
blocks["middle.2"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(out_dim, out_dim));
|
blocks["middle.2"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(out_dim, out_dim, is_2D));
|
||||||
|
|
||||||
blocks["head.0"] = std::shared_ptr<GGMLBlock>(new RMS_norm(out_dim));
|
blocks["head.0"] = std::shared_ptr<GGMLBlock>(new RMS_norm(out_dim));
|
||||||
// head.1 is nn.SiLU()
|
// head.1 is nn.SiLU()
|
||||||
|
if (is_2D) {
|
||||||
|
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(out_dim, z_dim, {3, 3}, {1, 1}, {1, 1}));
|
||||||
|
} else {
|
||||||
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, z_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, z_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ggml_tensor* forward(GGMLRunnerContext* ctx,
|
ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||||
ggml_tensor* x,
|
ggml_tensor* x,
|
||||||
@ -673,7 +728,10 @@ namespace WAN {
|
|||||||
auto head_2 = std::dynamic_pointer_cast<CausalConv3d>(blocks["head.2"]);
|
auto head_2 = std::dynamic_pointer_cast<CausalConv3d>(blocks["head.2"]);
|
||||||
|
|
||||||
// conv1
|
// conv1
|
||||||
if (feat_cache.size() > 0) {
|
if (is_2D) {
|
||||||
|
auto conv1 = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["conv1"]);
|
||||||
|
x = conv1->forward(ctx, x);
|
||||||
|
} else if (feat_cache.size() > 0) {
|
||||||
int idx = feat_idx;
|
int idx = feat_idx;
|
||||||
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]);
|
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]);
|
||||||
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
|
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
|
||||||
@ -728,7 +786,10 @@ namespace WAN {
|
|||||||
// head
|
// head
|
||||||
x = head_0->forward(ctx, x);
|
x = head_0->forward(ctx, x);
|
||||||
x = ggml_silu(ctx->ggml_ctx, x);
|
x = ggml_silu(ctx->ggml_ctx, x);
|
||||||
if (feat_cache.size() > 0) {
|
if (is_2D) {
|
||||||
|
auto head_2 = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["head.2"]);
|
||||||
|
x = head_2->forward(ctx, x);
|
||||||
|
} else if (feat_cache.size() > 0) {
|
||||||
int idx = feat_idx;
|
int idx = feat_idx;
|
||||||
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]);
|
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]);
|
||||||
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
|
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
|
||||||
@ -758,6 +819,7 @@ namespace WAN {
|
|||||||
std::vector<int> dim_mult;
|
std::vector<int> dim_mult;
|
||||||
int num_res_blocks;
|
int num_res_blocks;
|
||||||
std::vector<bool> temperal_upsample;
|
std::vector<bool> temperal_upsample;
|
||||||
|
bool is_2D = false;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
Decoder3d(int64_t dim = 128,
|
Decoder3d(int64_t dim = 128,
|
||||||
@ -765,13 +827,15 @@ namespace WAN {
|
|||||||
std::vector<int> dim_mult = {1, 2, 4, 4},
|
std::vector<int> dim_mult = {1, 2, 4, 4},
|
||||||
int num_res_blocks = 2,
|
int num_res_blocks = 2,
|
||||||
std::vector<bool> temperal_upsample = {true, true, false},
|
std::vector<bool> temperal_upsample = {true, true, false},
|
||||||
bool wan2_2 = false)
|
bool wan2_2 = false,
|
||||||
|
bool is_2D = false)
|
||||||
: dim(dim),
|
: dim(dim),
|
||||||
z_dim(z_dim),
|
z_dim(z_dim),
|
||||||
dim_mult(dim_mult),
|
dim_mult(dim_mult),
|
||||||
num_res_blocks(num_res_blocks),
|
num_res_blocks(num_res_blocks),
|
||||||
temperal_upsample(temperal_upsample),
|
temperal_upsample(temperal_upsample),
|
||||||
wan2_2(wan2_2) {
|
wan2_2(wan2_2),
|
||||||
|
is_2D(is_2D) {
|
||||||
// attn_scales is always []
|
// attn_scales is always []
|
||||||
std::vector<int64_t> dims = {dim_mult[dim_mult.size() - 1] * dim};
|
std::vector<int64_t> dims = {dim_mult[dim_mult.size() - 1] * dim};
|
||||||
for (int i = static_cast<int>(dim_mult.size()) - 1; i >= 0; i--) {
|
for (int i = static_cast<int>(dim_mult.size()) - 1; i >= 0; i--) {
|
||||||
@ -779,12 +843,16 @@ namespace WAN {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// init block
|
// init block
|
||||||
|
if(is_2D){
|
||||||
|
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(z_dim, dims[0], {3, 3}, {1, 1}, {1, 1}));
|
||||||
|
}else{
|
||||||
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
||||||
|
}
|
||||||
|
|
||||||
// middle blocks
|
// middle blocks
|
||||||
blocks["middle.0"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(dims[0], dims[0]));
|
blocks["middle.0"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(dims[0], dims[0], is_2D));
|
||||||
blocks["middle.1"] = std::shared_ptr<GGMLBlock>(new AttentionBlock(dims[0]));
|
blocks["middle.1"] = std::shared_ptr<GGMLBlock>(new AttentionBlock(dims[0]));
|
||||||
blocks["middle.2"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(dims[0], dims[0]));
|
blocks["middle.2"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(dims[0], dims[0], is_2D));
|
||||||
|
|
||||||
// upsample blocks
|
// upsample blocks
|
||||||
int index = 0;
|
int index = 0;
|
||||||
@ -799,7 +867,8 @@ namespace WAN {
|
|||||||
out_dim,
|
out_dim,
|
||||||
num_res_blocks + 1,
|
num_res_blocks + 1,
|
||||||
t_up_flag,
|
t_up_flag,
|
||||||
i != dim_mult.size() - 1));
|
i != dim_mult.size() - 1,
|
||||||
|
is_2D));
|
||||||
|
|
||||||
blocks["upsamples." + std::to_string(index++)] = block;
|
blocks["upsamples." + std::to_string(index++)] = block;
|
||||||
} else {
|
} else {
|
||||||
@ -807,7 +876,7 @@ namespace WAN {
|
|||||||
in_dim = in_dim / 2;
|
in_dim = in_dim / 2;
|
||||||
}
|
}
|
||||||
for (int j = 0; j < num_res_blocks + 1; j++) {
|
for (int j = 0; j < num_res_blocks + 1; j++) {
|
||||||
auto block = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim));
|
auto block = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim, is_2D));
|
||||||
blocks["upsamples." + std::to_string(index++)] = block;
|
blocks["upsamples." + std::to_string(index++)] = block;
|
||||||
in_dim = out_dim;
|
in_dim = out_dim;
|
||||||
}
|
}
|
||||||
@ -822,12 +891,13 @@ namespace WAN {
|
|||||||
|
|
||||||
// output blocks
|
// output blocks
|
||||||
blocks["head.0"] = std::shared_ptr<GGMLBlock>(new RMS_norm(out_dim));
|
blocks["head.0"] = std::shared_ptr<GGMLBlock>(new RMS_norm(out_dim));
|
||||||
|
int64_t final_dim = wan2_2 ? 12 : 3;
|
||||||
// head.1 is nn.SiLU()
|
// head.1 is nn.SiLU()
|
||||||
if (wan2_2) {
|
if (is_2D) {
|
||||||
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, 12, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(out_dim, final_dim, {3, 3}, {1, 1}, {1, 1}));
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, 3, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, final_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -847,7 +917,10 @@ namespace WAN {
|
|||||||
auto head_2 = std::dynamic_pointer_cast<CausalConv3d>(blocks["head.2"]);
|
auto head_2 = std::dynamic_pointer_cast<CausalConv3d>(blocks["head.2"]);
|
||||||
|
|
||||||
// conv1
|
// conv1
|
||||||
if (feat_cache.size() > 0) {
|
if (is_2D) {
|
||||||
|
auto conv1 = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["conv1"]);
|
||||||
|
x = conv1->forward(ctx, x);
|
||||||
|
} else if (feat_cache.size() > 0) {
|
||||||
int idx = feat_idx;
|
int idx = feat_idx;
|
||||||
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]);
|
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]);
|
||||||
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
|
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
|
||||||
@ -902,7 +975,10 @@ namespace WAN {
|
|||||||
// head
|
// head
|
||||||
x = head_0->forward(ctx, x);
|
x = head_0->forward(ctx, x);
|
||||||
x = ggml_silu(ctx->ggml_ctx, x);
|
x = ggml_silu(ctx->ggml_ctx, x);
|
||||||
if (feat_cache.size() > 0) {
|
if (is_2D) {
|
||||||
|
auto head_2 = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["head.2"]);
|
||||||
|
x = head_2->forward(ctx, x);
|
||||||
|
} else if (feat_cache.size() > 0) {
|
||||||
int idx = feat_idx;
|
int idx = feat_idx;
|
||||||
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]);
|
auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]);
|
||||||
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
|
if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) {
|
||||||
@ -935,6 +1011,7 @@ namespace WAN {
|
|||||||
int num_res_blocks = 2;
|
int num_res_blocks = 2;
|
||||||
std::vector<bool> temperal_upsample = {true, true, false};
|
std::vector<bool> temperal_upsample = {true, true, false};
|
||||||
std::vector<bool> temperal_downsample = {false, true, true};
|
std::vector<bool> temperal_downsample = {false, true, true};
|
||||||
|
bool is_2D = false;
|
||||||
|
|
||||||
int _conv_num = 33;
|
int _conv_num = 33;
|
||||||
int _conv_idx = 0;
|
int _conv_idx = 0;
|
||||||
@ -951,8 +1028,8 @@ namespace WAN {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
WanVAE(bool decode_only = true, bool wan2_2 = false)
|
WanVAE(bool decode_only = true, bool wan2_2 = false, bool is_2D = false)
|
||||||
: decode_only(decode_only), wan2_2(wan2_2) {
|
: decode_only(decode_only), wan2_2(wan2_2), is_2D(is_2D) {
|
||||||
// attn_scales is always []
|
// attn_scales is always []
|
||||||
if (wan2_2) {
|
if (wan2_2) {
|
||||||
dim = 160;
|
dim = 160;
|
||||||
@ -962,13 +1039,27 @@ namespace WAN {
|
|||||||
_conv_num = 34;
|
_conv_num = 34;
|
||||||
_enc_conv_num = 26;
|
_enc_conv_num = 26;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if(is_2D){
|
||||||
|
temperal_upsample = {false, false, false};
|
||||||
|
temperal_downsample = {false, false, false};
|
||||||
|
}
|
||||||
|
|
||||||
if (!decode_only) {
|
if (!decode_only) {
|
||||||
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, temperal_downsample, wan2_2));
|
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, temperal_downsample, wan2_2, is_2D));
|
||||||
|
if (is_2D) {
|
||||||
|
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(z_dim * 2, z_dim * 2, {1, 1}));
|
||||||
|
} else {
|
||||||
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim * 2, z_dim * 2, {1, 1, 1}));
|
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim * 2, z_dim * 2, {1, 1, 1}));
|
||||||
}
|
}
|
||||||
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder3d(dec_dim, z_dim, dim_mult, num_res_blocks, temperal_upsample, wan2_2));
|
}
|
||||||
|
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder3d(dec_dim, z_dim, dim_mult, num_res_blocks, temperal_upsample, wan2_2, is_2D));
|
||||||
|
if (is_2D) {
|
||||||
|
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new Conv2dBut3d(z_dim, z_dim, {1, 1}));
|
||||||
|
} else {
|
||||||
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim, z_dim, {1, 1, 1}));
|
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim, z_dim, {1, 1, 1}));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static ggml_tensor* patchify(ggml_context* ctx,
|
static ggml_tensor* patchify(ggml_context* ctx,
|
||||||
ggml_tensor* x,
|
ggml_tensor* x,
|
||||||
@ -1054,7 +1145,12 @@ namespace WAN {
|
|||||||
out = ggml_concat(ctx->ggml_ctx, out, out_, 2);
|
out = ggml_concat(ctx->ggml_ctx, out, out_, 2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (is_2D) {
|
||||||
|
auto conv1 = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["conv1"]);
|
||||||
out = conv1->forward(ctx, out);
|
out = conv1->forward(ctx, out);
|
||||||
|
} else {
|
||||||
|
out = conv1->forward(ctx, out);
|
||||||
|
}
|
||||||
auto mu = ggml_ext_chunk(ctx->ggml_ctx, out, 2, 3)[0];
|
auto mu = ggml_ext_chunk(ctx->ggml_ctx, out, 2, 3)[0];
|
||||||
// sd::ggml_graph_cut::mark_graph_cut(mu, "wan_vae.encode.final", "mu");
|
// sd::ggml_graph_cut::mark_graph_cut(mu, "wan_vae.encode.final", "mu");
|
||||||
clear_cache();
|
clear_cache();
|
||||||
@ -1073,7 +1169,13 @@ namespace WAN {
|
|||||||
auto conv2 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv2"]);
|
auto conv2 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv2"]);
|
||||||
|
|
||||||
int64_t iter_ = z->ne[2];
|
int64_t iter_ = z->ne[2];
|
||||||
auto x = conv2->forward(ctx, z);
|
auto x = z;
|
||||||
|
if(is_2D){
|
||||||
|
auto conv2 = std::dynamic_pointer_cast<Conv2dBut3d>(blocks["conv2"]);
|
||||||
|
x = conv2->forward(ctx, z);
|
||||||
|
} else {
|
||||||
|
x = conv2->forward(ctx, z);
|
||||||
|
}
|
||||||
// sd::ggml_graph_cut::mark_graph_cut(x, "wan_vae.decode.prelude", "x");
|
// sd::ggml_graph_cut::mark_graph_cut(x, "wan_vae.decode.prelude", "x");
|
||||||
ggml_tensor* out;
|
ggml_tensor* out;
|
||||||
for (int i = 0; i < iter_; i++) {
|
for (int i = 0; i < iter_; i++) {
|
||||||
@ -1129,7 +1231,20 @@ namespace WAN {
|
|||||||
bool decode_only = false,
|
bool decode_only = false,
|
||||||
SDVersion version = VERSION_WAN2,
|
SDVersion version = VERSION_WAN2,
|
||||||
std::shared_ptr<RunnerWeightManager> weight_manager = nullptr)
|
std::shared_ptr<RunnerWeightManager> weight_manager = nullptr)
|
||||||
: VAE(version, backend, prefix, weight_manager), decode_only(decode_only), ae(decode_only, version == VERSION_WAN2_2_TI2V) {
|
: VAE(version, backend, prefix, weight_manager), decode_only(decode_only) {
|
||||||
|
bool is_2D = false;
|
||||||
|
for (const auto& [name, tensor_storage] : tensor_storage_map) {
|
||||||
|
if (ends_with(name, "decoder.conv1.weight")) {
|
||||||
|
if (tensor_storage.ne[2] > 3) {
|
||||||
|
is_2D = true;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (is_2D) {
|
||||||
|
LOG_DEBUG("USING 2D VAE");
|
||||||
|
}
|
||||||
|
ae = WanVAE(decode_only, version == VERSION_WAN2_2_TI2V, is_2D);
|
||||||
ae.init(params_ctx, tensor_storage_map, prefix);
|
ae.init(params_ctx, tensor_storage_map, prefix);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user