mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-13 05:48:56 +00:00
add wan vae suppport
This commit is contained in:
parent
f6b9aa1a43
commit
e3f9366857
@ -27,6 +27,8 @@
|
|||||||
#define SAFE_STR(s) ((s) ? (s) : "")
|
#define SAFE_STR(s) ((s) ? (s) : "")
|
||||||
#define BOOL_STR(b) ((b) ? "true" : "false")
|
#define BOOL_STR(b) ((b) ? "true" : "false")
|
||||||
|
|
||||||
|
#include "wan.hpp"
|
||||||
|
|
||||||
const char* modes_str[] = {
|
const char* modes_str[] = {
|
||||||
"img_gen",
|
"img_gen",
|
||||||
"vid_gen",
|
"vid_gen",
|
||||||
@ -744,6 +746,11 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
|
|||||||
|
|
||||||
int main(int argc, const char* argv[]) {
|
int main(int argc, const char* argv[]) {
|
||||||
SDParams params;
|
SDParams params;
|
||||||
|
// params.verbose = true;
|
||||||
|
// sd_set_log_callback(sd_log_cb, (void*)¶ms);
|
||||||
|
|
||||||
|
// WAN::WanVAERunner::load_from_file_and_test(argv[1]);
|
||||||
|
// return 0;
|
||||||
|
|
||||||
parse_args(argc, argv, params);
|
parse_args(argc, argv, params);
|
||||||
|
|
||||||
|
|||||||
216
ggml_extend.hpp
216
ggml_extend.hpp
@ -210,7 +210,7 @@ __STATIC_INLINE__ void print_ggml_tensor(struct ggml_tensor* tensor, bool shape_
|
|||||||
if (tensor->type == GGML_TYPE_F32) {
|
if (tensor->type == GGML_TYPE_F32) {
|
||||||
printf(" [%d, %d, %d, %d] = %f\n", i, j, k, l, ggml_tensor_get_f32(tensor, l, k, j, i));
|
printf(" [%d, %d, %d, %d] = %f\n", i, j, k, l, ggml_tensor_get_f32(tensor, l, k, j, i));
|
||||||
} else if (tensor->type == GGML_TYPE_F16) {
|
} else if (tensor->type == GGML_TYPE_F16) {
|
||||||
printf(" [%d, %d, %d, %d] = %i\n", i, j, k, l, ggml_tensor_get_f16(tensor, l, k, j, i));
|
printf(" [%d, %d, %d, %d] = %f\n", i, j, k, l, ggml_fp16_to_fp32(ggml_tensor_get_f16(tensor, l, k, j, i)));
|
||||||
} else if (tensor->type == GGML_TYPE_I32) {
|
} else if (tensor->type == GGML_TYPE_I32) {
|
||||||
printf(" [%d, %d, %d, %d] = %i\n", i, j, k, l, ggml_tensor_get_i32(tensor, l, k, j, i));
|
printf(" [%d, %d, %d, %d] = %i\n", i, j, k, l, ggml_tensor_get_i32(tensor, l, k, j, i));
|
||||||
}
|
}
|
||||||
@ -598,6 +598,116 @@ __STATIC_INLINE__ void ggml_tensor_scale_output(struct ggml_tensor* src) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// torch like permute
|
||||||
|
__STATIC_INLINE__ struct ggml_tensor* ggml_torch_permute(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* x,
|
||||||
|
int axis0,
|
||||||
|
int axis1,
|
||||||
|
int axis2,
|
||||||
|
int axis3) {
|
||||||
|
int torch_axes[4] = {axis0, axis1, axis2, axis3};
|
||||||
|
|
||||||
|
int ggml_axes[4] = {0};
|
||||||
|
for (int i = 0; i < 4; ++i) {
|
||||||
|
int found = 0;
|
||||||
|
for (int j = 0; j < 4; ++j) {
|
||||||
|
if (torch_axes[j] == i) {
|
||||||
|
ggml_axes[i] = j;
|
||||||
|
found = 1;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
GGML_ASSERT(found && "Invalid permute input: must be a permutation of 0-3");
|
||||||
|
}
|
||||||
|
|
||||||
|
return ggml_permute(ctx, x, ggml_axes[0], ggml_axes[1], ggml_axes[2], ggml_axes[3]);
|
||||||
|
}
|
||||||
|
|
||||||
|
__STATIC_INLINE__ struct ggml_tensor* ggml_slice(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* x,
|
||||||
|
int64_t dim,
|
||||||
|
int64_t start,
|
||||||
|
int64_t end) {
|
||||||
|
GGML_ASSERT(dim >= 0 && dim < 4);
|
||||||
|
if (x->ne[dim] == 1) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
while (start < 0) {
|
||||||
|
start = x->ne[dim] + start;
|
||||||
|
}
|
||||||
|
while (end < 0) {
|
||||||
|
end = x->ne[dim] + end;
|
||||||
|
}
|
||||||
|
GGML_ASSERT(end > start);
|
||||||
|
GGML_ASSERT(start >= 0 && start < x->ne[dim]);
|
||||||
|
GGML_ASSERT(end > start && end <= x->ne[dim]);
|
||||||
|
|
||||||
|
int perm[4] = {0, 1, 2, 3};
|
||||||
|
for (int i = dim; i < 3; ++i)
|
||||||
|
perm[i] = perm[i + 1];
|
||||||
|
perm[3] = dim;
|
||||||
|
|
||||||
|
int inv_perm[4];
|
||||||
|
for (int i = 0; i < 4; ++i)
|
||||||
|
inv_perm[perm[i]] = i;
|
||||||
|
|
||||||
|
if (dim != 3) {
|
||||||
|
x = ggml_torch_permute(ctx, x, perm[0], perm[1], perm[2], perm[3]);
|
||||||
|
x = ggml_cont(ctx, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
x = ggml_view_4d(
|
||||||
|
ctx, x,
|
||||||
|
x->ne[0], x->ne[1], x->ne[2], end - start,
|
||||||
|
x->nb[1], x->nb[2], x->nb[3], x->nb[3] * start);
|
||||||
|
|
||||||
|
if (dim != 3) {
|
||||||
|
x = ggml_torch_permute(ctx, x, inv_perm[0], inv_perm[1], inv_perm[2], inv_perm[3]);
|
||||||
|
}
|
||||||
|
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
// example: [N, 3*C, H, W] => ([N, C, H, W], [N, C, H, W], [N, C, H, W])
|
||||||
|
__STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_chunk(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* x,
|
||||||
|
int num,
|
||||||
|
int64_t dim) {
|
||||||
|
GGML_ASSERT(dim >= 0 && dim < 4);
|
||||||
|
GGML_ASSERT(x->ne[dim] % num == 0);
|
||||||
|
|
||||||
|
int perm[4] = {0, 1, 2, 3};
|
||||||
|
for (int i = dim; i < 3; ++i)
|
||||||
|
perm[i] = perm[i + 1];
|
||||||
|
perm[3] = dim;
|
||||||
|
|
||||||
|
int inv_perm[4];
|
||||||
|
for (int i = 0; i < 4; ++i)
|
||||||
|
inv_perm[perm[i]] = i;
|
||||||
|
|
||||||
|
if (dim != 3) {
|
||||||
|
x = ggml_torch_permute(ctx, x, perm[0], perm[1], perm[2], perm[3]);
|
||||||
|
x = ggml_cont(ctx, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<struct ggml_tensor*> chunks;
|
||||||
|
int64_t chunk_size = x->ne[3] / num;
|
||||||
|
for (int i = 0; i < num; i++) {
|
||||||
|
auto chunk = ggml_view_4d(
|
||||||
|
ctx, x,
|
||||||
|
x->ne[0], x->ne[1], x->ne[2], chunk_size,
|
||||||
|
x->nb[1], x->nb[2], x->nb[3], x->nb[3] * i * chunk_size);
|
||||||
|
|
||||||
|
if (dim != 3) {
|
||||||
|
chunk = ggml_torch_permute(ctx, chunk, inv_perm[0], inv_perm[1], inv_perm[2], inv_perm[3]);
|
||||||
|
chunk = ggml_cont(ctx, chunk);
|
||||||
|
}
|
||||||
|
chunks.push_back(chunk);
|
||||||
|
}
|
||||||
|
|
||||||
|
return chunks;
|
||||||
|
}
|
||||||
|
|
||||||
typedef std::function<void(ggml_tensor*, ggml_tensor*, bool)> on_tile_process;
|
typedef std::function<void(ggml_tensor*, ggml_tensor*, bool)> on_tile_process;
|
||||||
|
|
||||||
// Tiling
|
// Tiling
|
||||||
@ -706,6 +816,36 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_2d(struct ggml_context* ctx,
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// w: [OC*IC, KD, KH, KW]
|
||||||
|
// x: [N*IC, ID, IH, IW]
|
||||||
|
// b: [OC,]
|
||||||
|
// result: [N*OC, OD, OH, OW]
|
||||||
|
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_3d(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* x,
|
||||||
|
struct ggml_tensor* w,
|
||||||
|
struct ggml_tensor* b,
|
||||||
|
int64_t IC,
|
||||||
|
int s0 = 1,
|
||||||
|
int s1 = 1,
|
||||||
|
int s2 = 1,
|
||||||
|
int p0 = 0,
|
||||||
|
int p1 = 0,
|
||||||
|
int p2 = 0,
|
||||||
|
int d0 = 1,
|
||||||
|
int d1 = 1,
|
||||||
|
int d2 = 1) {
|
||||||
|
int64_t OC = w->ne[3] / IC;
|
||||||
|
int64_t N = x->ne[3] / IC;
|
||||||
|
x = ggml_conv_3d(ctx, w, x, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2);
|
||||||
|
|
||||||
|
|
||||||
|
if (b != NULL) {
|
||||||
|
b = ggml_reshape_4d(ctx, b, 1, 1, 1, b->ne[0]); // [OC, 1, 1, 1]
|
||||||
|
x = ggml_add(ctx, x, b);
|
||||||
|
}
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
// w: [OC,IC, KD, 1 * 1]
|
// w: [OC,IC, KD, 1 * 1]
|
||||||
// x: [N, IC, IH, IW]
|
// x: [N, IC, IH, IW]
|
||||||
// b: [OC,]
|
// b: [OC,]
|
||||||
@ -773,6 +913,26 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> split_qkv(struct ggml_context
|
|||||||
return {q, k, v};
|
return {q, k, v};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// qkv: [N, 3*C, H, W]
|
||||||
|
// return: ([N, C, H, W], [N, C, H, W], [N, C, H, W])
|
||||||
|
__STATIC_INLINE__ std::vector<struct ggml_tensor*> split_image_qkv(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* qkv) {
|
||||||
|
int64_t W = qkv->ne[0];
|
||||||
|
int64_t H = qkv->ne[1];
|
||||||
|
int64_t C = qkv->ne[2] / 3;
|
||||||
|
int64_t N = qkv->ne[3];
|
||||||
|
int64_t nb1 = qkv->nb[1];
|
||||||
|
int64_t nb2 = qkv->nb[2];
|
||||||
|
qkv = ggml_reshape_4d(ctx, qkv, W * H, C, 3, N); // [N, 3, C, H*W]
|
||||||
|
qkv = ggml_cont(ctx, ggml_torch_permute(ctx, qkv, 0, 1, 3, 2)); // [3, N, C, H*W]
|
||||||
|
|
||||||
|
int64_t offset = qkv->nb[2] * qkv->ne[2];
|
||||||
|
auto q = ggml_view_4d(ctx, qkv, W, H, C, N, nb1, nb2, qkv->nb[3], offset * 0); // [N, C, H, W]
|
||||||
|
auto k = ggml_view_4d(ctx, qkv, W, H, C, N, nb1, nb2, qkv->nb[3], offset * 1); // [N, C, H, W]
|
||||||
|
auto v = ggml_view_4d(ctx, qkv, W, H, C, N, nb1, nb2, qkv->nb[3], offset * 2); // [N, C, H, W]
|
||||||
|
return {q, k, v};
|
||||||
|
}
|
||||||
|
|
||||||
// q: [N * n_head, n_token, d_head]
|
// q: [N * n_head, n_token, d_head]
|
||||||
// k: [N * n_head, n_k, d_head]
|
// k: [N * n_head, n_k, d_head]
|
||||||
// v: [N * n_head, d_head, n_k]
|
// v: [N * n_head, d_head, n_k]
|
||||||
@ -1095,7 +1255,7 @@ __STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) {
|
|||||||
|
|
||||||
/* SDXL with LoRA requires more space */
|
/* SDXL with LoRA requires more space */
|
||||||
#define MAX_PARAMS_TENSOR_NUM 32768
|
#define MAX_PARAMS_TENSOR_NUM 32768
|
||||||
#define MAX_GRAPH_SIZE 32768
|
#define MAX_GRAPH_SIZE 327680
|
||||||
|
|
||||||
typedef std::map<std::string, enum ggml_type> String2GGMLType;
|
typedef std::map<std::string, enum ggml_type> String2GGMLType;
|
||||||
|
|
||||||
@ -1547,6 +1707,58 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class Conv3d : public UnaryBlock {
|
||||||
|
protected:
|
||||||
|
int64_t in_channels;
|
||||||
|
int64_t out_channels;
|
||||||
|
std::tuple<int, int, int> kernel_size;
|
||||||
|
std::tuple<int, int, int> stride;
|
||||||
|
std::tuple<int, int, int> padding;
|
||||||
|
std::tuple<int, int, int> dilation;
|
||||||
|
bool bias;
|
||||||
|
|
||||||
|
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") {
|
||||||
|
enum ggml_type wtype = GGML_TYPE_F16;
|
||||||
|
params["weight"] = ggml_new_tensor_4d(ctx,
|
||||||
|
wtype,
|
||||||
|
std::get<2>(kernel_size),
|
||||||
|
std::get<1>(kernel_size),
|
||||||
|
std::get<0>(kernel_size),
|
||||||
|
in_channels * out_channels);
|
||||||
|
if (bias) {
|
||||||
|
params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
Conv3d(int64_t in_channels,
|
||||||
|
int64_t out_channels,
|
||||||
|
std::tuple<int, int, int> kernel_size,
|
||||||
|
std::tuple<int, int, int> stride = {1, 1, 1},
|
||||||
|
std::tuple<int, int, int> padding = {0, 0, 0},
|
||||||
|
std::tuple<int, int, int> dilation = {1, 1, 1},
|
||||||
|
bool bias = true)
|
||||||
|
: in_channels(in_channels),
|
||||||
|
out_channels(out_channels),
|
||||||
|
kernel_size(kernel_size),
|
||||||
|
stride(stride),
|
||||||
|
padding(padding),
|
||||||
|
dilation(dilation),
|
||||||
|
bias(bias) {}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
||||||
|
struct ggml_tensor* w = params["weight"];
|
||||||
|
struct ggml_tensor* b = NULL;
|
||||||
|
if (bias) {
|
||||||
|
b = params["bias"];
|
||||||
|
}
|
||||||
|
return ggml_nn_conv_3d(ctx, x, w, b, in_channels,
|
||||||
|
std::get<2>(stride), std::get<1>(stride), std::get<0>(stride),
|
||||||
|
std::get<2>(padding), std::get<1>(padding), std::get<0>(padding),
|
||||||
|
std::get<2>(dilation), std::get<1>(dilation), std::get<0>(dilation));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
class LayerNorm : public UnaryBlock {
|
class LayerNorm : public UnaryBlock {
|
||||||
protected:
|
protected:
|
||||||
int64_t normalized_shape;
|
int64_t normalized_shape;
|
||||||
|
|||||||
74
ltxv.hpp
Normal file
74
ltxv.hpp
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
#ifndef __LTXV_HPP__
|
||||||
|
#define __LTXV_HPP__
|
||||||
|
|
||||||
|
#include "common.hpp"
|
||||||
|
#include "ggml_extend.hpp"
|
||||||
|
|
||||||
|
namespace LTXV {
|
||||||
|
|
||||||
|
class CausalConv3d : public GGMLBlock {
|
||||||
|
protected:
|
||||||
|
int time_kernel_size;
|
||||||
|
|
||||||
|
public:
|
||||||
|
CausalConv3d(int64_t in_channels,
|
||||||
|
int64_t out_channels,
|
||||||
|
int kernel_size = 3,
|
||||||
|
std::tuple<int> stride = {1, 1, 1},
|
||||||
|
int dilation = 1,
|
||||||
|
bool bias = true) {
|
||||||
|
time_kernel_size = kernel_size / 2;
|
||||||
|
blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv3d(in_channels,
|
||||||
|
out_channels,
|
||||||
|
{kernel_size, kernel_size, kernel_size},
|
||||||
|
stride,
|
||||||
|
{0, kernel_size / 2, kernel_size / 2},
|
||||||
|
{dilation, 1, 1},
|
||||||
|
bias));
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* x,
|
||||||
|
bool causal = true) {
|
||||||
|
// x: [N*IC, ID, IH, IW]
|
||||||
|
// result: [N*OC, OD, OH, OW]
|
||||||
|
auto conv = std::dynamic_pointer_cast<Conv3d>(blocks["conv"]);
|
||||||
|
if (causal) {
|
||||||
|
auto h = ggml_cont(ctx, ggml_permute(ctx, x, 0, 1, 3, 2)); // [ID, N*IC, IH, IW]
|
||||||
|
auto first_frame = ggml_view_3d(ctx, h, h->ne[0], h->ne[1], h->ne[2], h->nb[1], h->nb[2], 0); // [N*IC, IH, IW]
|
||||||
|
first_frame = ggml_reshape_4d(ctx, first_frame, first_frame->ne[0], first_frame->ne[1], 1, first_frame->ne[2]); // [N*IC, 1, IH, IW]
|
||||||
|
auto first_frame_pad = first_frame;
|
||||||
|
for (int i = 1; i < time_kernel_size - 1; i++) {
|
||||||
|
first_frame_pad = ggml_concat(ctx, first_frame_pad, first_frame, 2);
|
||||||
|
}
|
||||||
|
x = ggml_concat(ctx, first_frame_pad, x, 2);
|
||||||
|
} else {
|
||||||
|
auto h = ggml_cont(ctx, ggml_permute(ctx, x, 0, 1, 3, 2)); // [ID, N*IC, IH, IW]
|
||||||
|
int64_t offset = h->nb[2] * h->ne[2];
|
||||||
|
|
||||||
|
auto first_frame = ggml_view_3d(ctx, h, h->ne[0], h->ne[1], h->ne[2], h->nb[1], h->nb[2], 0); // [N*IC, IH, IW]
|
||||||
|
first_frame = ggml_reshape_4d(ctx, first_frame, first_frame->ne[0], first_frame->ne[1], 1, first_frame->ne[2]); // [N*IC, 1, IH, IW]
|
||||||
|
auto first_frame_pad = first_frame;
|
||||||
|
for (int i = 1; i < (time_kernel_size - 1) / 2; i++) {
|
||||||
|
first_frame_pad = ggml_concat(ctx, first_frame_pad, first_frame, 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto last_frame = ggml_view_3d(ctx, h, h->ne[0], h->ne[1], h->ne[2], h->nb[1], h->nb[2], offset * (h->ne[3] - 1)); // [N*IC, IH, IW]
|
||||||
|
last_frame = ggml_reshape_4d(ctx, last_frame, last_frame->ne[0], last_frame->ne[1], 1, last_frame->ne[2]); // [N*IC, 1, IH, IW]
|
||||||
|
auto last_frame_pad = last_frame;
|
||||||
|
for (int i = 1; i < (time_kernel_size - 1) / 2; i++) {
|
||||||
|
last_frame_pad = ggml_concat(ctx, last_frame_pad, last_frame, 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
x = ggml_concat(ctx, first_frame_pad, x, 2);
|
||||||
|
x = ggml_concat(ctx, x, last_frame_pad, 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
x = conv->forward(ctx, x);
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif
|
||||||
18
model.cpp
18
model.cpp
@ -684,6 +684,13 @@ void preprocess_tensor(TensorStorage tensor_storage,
|
|||||||
tensor_storage.unsqueeze();
|
tensor_storage.unsqueeze();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// wan vae
|
||||||
|
if (ends_with(new_name, "gamma")) {
|
||||||
|
tensor_storage.reverse_ne();
|
||||||
|
tensor_storage.n_dims = 1;
|
||||||
|
tensor_storage.reverse_ne();
|
||||||
|
}
|
||||||
|
|
||||||
tensor_storage.name = new_name;
|
tensor_storage.name = new_name;
|
||||||
|
|
||||||
if (new_name.find("cond_stage_model") != std::string::npos &&
|
if (new_name.find("cond_stage_model") != std::string::npos &&
|
||||||
@ -1085,7 +1092,7 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
|
|||||||
|
|
||||||
// https://huggingface.co/docs/safetensors/index
|
// https://huggingface.co/docs/safetensors/index
|
||||||
bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const std::string& prefix) {
|
bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const std::string& prefix) {
|
||||||
LOG_DEBUG("init from '%s'", file_path.c_str());
|
LOG_DEBUG("init from '%s', prefix = '%s'", file_path.c_str(), prefix.c_str());
|
||||||
file_paths_.push_back(file_path);
|
file_paths_.push_back(file_path);
|
||||||
size_t file_index = file_paths_.size() - 1;
|
size_t file_index = file_paths_.size() - 1;
|
||||||
std::ifstream file(file_path, std::ios::binary);
|
std::ifstream file(file_path, std::ios::binary);
|
||||||
@ -1171,12 +1178,11 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (n_dims == 5) {
|
if (n_dims == 5) {
|
||||||
if (ne[3] == 1 && ne[4] == 1) {
|
|
||||||
n_dims = 4;
|
n_dims = 4;
|
||||||
} else {
|
ne[0] = ne[0]*ne[1];
|
||||||
LOG_ERROR("invalid tensor '%s'", name.c_str());
|
ne[1] = ne[2];
|
||||||
return false;
|
ne[2] = ne[3];
|
||||||
}
|
ne[3] = ne[4];
|
||||||
}
|
}
|
||||||
|
|
||||||
// ggml_n_dims returns 1 for scalars
|
// ggml_n_dims returns 1 for scalars
|
||||||
|
|||||||
833
wan.hpp
Normal file
833
wan.hpp
Normal file
@ -0,0 +1,833 @@
|
|||||||
|
#ifndef __WAN_HPP__
|
||||||
|
#define __WAN_HPP__
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
#include "common.hpp"
|
||||||
|
#include "ggml_extend.hpp"
|
||||||
|
|
||||||
|
namespace WAN {
|
||||||
|
|
||||||
|
constexpr int CACHE_T = 2;
|
||||||
|
|
||||||
|
class CausalConv3d : public GGMLBlock {
|
||||||
|
protected:
|
||||||
|
int64_t in_channels;
|
||||||
|
int64_t out_channels;
|
||||||
|
std::tuple<int, int, int> kernel_size;
|
||||||
|
std::tuple<int, int, int> stride;
|
||||||
|
std::tuple<int, int, int> padding;
|
||||||
|
std::tuple<int, int, int> dilation;
|
||||||
|
bool bias;
|
||||||
|
|
||||||
|
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
|
||||||
|
params["weight"] = ggml_new_tensor_4d(ctx,
|
||||||
|
GGML_TYPE_F16,
|
||||||
|
std::get<2>(kernel_size),
|
||||||
|
std::get<1>(kernel_size),
|
||||||
|
std::get<0>(kernel_size),
|
||||||
|
in_channels * out_channels);
|
||||||
|
if (bias) {
|
||||||
|
params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
CausalConv3d(int64_t in_channels,
|
||||||
|
int64_t out_channels,
|
||||||
|
std::tuple<int, int, int> kernel_size,
|
||||||
|
std::tuple<int, int, int> stride = {1, 1, 1},
|
||||||
|
std::tuple<int, int, int> padding = {0, 0, 0},
|
||||||
|
std::tuple<int, int, int> dilation = {1, 1, 1},
|
||||||
|
bool bias = true)
|
||||||
|
: in_channels(in_channels),
|
||||||
|
out_channels(out_channels),
|
||||||
|
kernel_size(kernel_size),
|
||||||
|
stride(stride),
|
||||||
|
padding(padding),
|
||||||
|
dilation(dilation),
|
||||||
|
bias(bias) {}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* cache_x = NULL) {
|
||||||
|
// x: [N*IC, ID, IH, IW]
|
||||||
|
// result: x: [N*OC, ID, IH, IW]
|
||||||
|
struct ggml_tensor* w = params["weight"];
|
||||||
|
struct ggml_tensor* b = NULL;
|
||||||
|
if (bias) {
|
||||||
|
b = params["bias"];
|
||||||
|
}
|
||||||
|
|
||||||
|
int lp0 = std::get<2>(padding);
|
||||||
|
int rp0 = std::get<2>(padding);
|
||||||
|
int lp1 = std::get<1>(padding);
|
||||||
|
int rp1 = std::get<1>(padding);
|
||||||
|
int lp2 = 2 * std::get<0>(padding);
|
||||||
|
int rp2 = 0;
|
||||||
|
|
||||||
|
if (cache_x != NULL && std::get<0>(padding) > 0) {
|
||||||
|
x = ggml_concat(ctx, cache_x, x, 2);
|
||||||
|
lp2 -= (int)cache_x->ne[2];
|
||||||
|
}
|
||||||
|
|
||||||
|
x = ggml_pad_ext(ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, 0, 0);
|
||||||
|
return ggml_nn_conv_3d(ctx, x, w, b, in_channels,
|
||||||
|
std::get<2>(stride), std::get<1>(stride), std::get<0>(stride),
|
||||||
|
0, 0, 0,
|
||||||
|
std::get<2>(dilation), std::get<1>(dilation), std::get<0>(dilation));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class RMS_norm : public UnaryBlock {
|
||||||
|
protected:
|
||||||
|
int64_t dim;
|
||||||
|
|
||||||
|
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
|
||||||
|
ggml_type wtype = GGML_TYPE_F32;
|
||||||
|
params["gamma"] = ggml_new_tensor_1d(ctx, wtype, dim);
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
RMS_norm(int64_t dim)
|
||||||
|
: dim(dim) {}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
||||||
|
// x: [N*IC, ID, IH, IW], IC == dim
|
||||||
|
// assert N == 1
|
||||||
|
|
||||||
|
struct ggml_tensor* w = params["gamma"];
|
||||||
|
auto h = ggml_cont(ctx, ggml_torch_permute(ctx, x, 3, 0, 1, 2)); // [ID, IH, IW, N*IC]
|
||||||
|
h = ggml_rms_norm(ctx, h, 1e-12);
|
||||||
|
h = ggml_mul(ctx, h, w);
|
||||||
|
h = ggml_cont(ctx, ggml_torch_permute(ctx, h, 1, 2, 3, 0));
|
||||||
|
|
||||||
|
return h;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class Resample : public GGMLBlock {
|
||||||
|
protected:
|
||||||
|
int64_t dim;
|
||||||
|
std::string mode;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Resample(int64_t dim, const std::string& mode)
|
||||||
|
: dim(dim), mode(mode) {
|
||||||
|
if (mode == "upsample2d") {
|
||||||
|
blocks["resample.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim / 2, {3, 3}, {1, 1}, {1, 1}));
|
||||||
|
} else if (mode == "upsample3d") {
|
||||||
|
blocks["resample.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim / 2, {3, 3}, {1, 1}, {1, 1}));
|
||||||
|
blocks["time_conv"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(dim, dim * 2, {3, 1, 1}, {1, 1, 1}, {1, 0, 0}));
|
||||||
|
} else if (mode == "downsample2d") {
|
||||||
|
blocks["resample.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim, {3, 3}, {2, 2}));
|
||||||
|
} else if (mode == "downsample3d") {
|
||||||
|
blocks["resample.1"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim, {3, 3}, {2, 2}));
|
||||||
|
blocks["time_conv"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(dim, dim, {3, 1, 1}, {2, 1, 1}, {0, 0, 0}));
|
||||||
|
} else if (mode == "none") {
|
||||||
|
// nn.Identity()
|
||||||
|
} else {
|
||||||
|
GGML_ASSERT(false && "invalid mode");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* x,
|
||||||
|
int64_t b,
|
||||||
|
std::vector<struct ggml_tensor*>& feat_cache,
|
||||||
|
int& feat_idx) {
|
||||||
|
// x: [b*c, t, h, w]
|
||||||
|
GGML_ASSERT(b == 1);
|
||||||
|
int64_t c = x->ne[3] / b;
|
||||||
|
int64_t t = x->ne[2];
|
||||||
|
int64_t h = x->ne[1];
|
||||||
|
int64_t w = x->ne[0];
|
||||||
|
|
||||||
|
struct ggml_tensor* Rep = (struct ggml_tensor*)1;
|
||||||
|
|
||||||
|
if (mode == "upsample3d") {
|
||||||
|
if (feat_cache.size() > 0) {
|
||||||
|
int idx = feat_idx;
|
||||||
|
if (feat_cache[idx] == NULL) {
|
||||||
|
feat_cache[idx] = Rep; // Rep
|
||||||
|
feat_idx += 1;
|
||||||
|
} else {
|
||||||
|
auto time_conv = std::dynamic_pointer_cast<CausalConv3d>(blocks["time_conv"]);
|
||||||
|
|
||||||
|
auto cache_x = ggml_slice(ctx, x, 2, -CACHE_T, x->ne[2]);
|
||||||
|
if (cache_x->ne[2] < 2 && feat_cache[idx] != NULL && feat_cache[idx] != Rep) {
|
||||||
|
// cache last frame of last two chunk
|
||||||
|
cache_x = ggml_concat(ctx,
|
||||||
|
ggml_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
|
||||||
|
cache_x,
|
||||||
|
2);
|
||||||
|
}
|
||||||
|
if (cache_x->ne[1] < 2 && feat_cache[idx] != NULL && feat_cache[idx] == Rep) {
|
||||||
|
cache_x = ggml_pad_ext(ctx, cache_x, 0, 0, 1, 1, (int)cache_x->ne[2], 0, 0, 0);
|
||||||
|
// aka cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device),cache_x],dim=2)
|
||||||
|
}
|
||||||
|
if (feat_cache[idx] == Rep) {
|
||||||
|
x = time_conv->forward(ctx, x);
|
||||||
|
} else {
|
||||||
|
x = time_conv->forward(ctx, x, feat_cache[idx]);
|
||||||
|
}
|
||||||
|
feat_cache[idx] = cache_x;
|
||||||
|
feat_idx += 1;
|
||||||
|
x = ggml_reshape_4d(ctx, x, w * h, t, c, 2); // (2, c, t, h*w)
|
||||||
|
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 3, 1, 2)); // (c, t, 2, h*w)
|
||||||
|
x = ggml_reshape_4d(ctx, x, w, h, 2 * t, c); // (c, t*2, h, w)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t = x->ne[2];
|
||||||
|
if (mode != "none") {
|
||||||
|
auto resample_1 = std::dynamic_pointer_cast<Conv2d>(blocks["resample.1"]);
|
||||||
|
|
||||||
|
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2)); // (t, c, h, w)
|
||||||
|
if (mode == "upsample2d") {
|
||||||
|
x = ggml_upscale(ctx, x, 2, GGML_SCALE_MODE_NEAREST);
|
||||||
|
} else if (mode == "upsample3d") {
|
||||||
|
x = ggml_upscale(ctx, x, 2, GGML_SCALE_MODE_NEAREST);
|
||||||
|
} else if (mode == "downsample2d") {
|
||||||
|
x = ggml_pad(ctx, x, 1, 1, 0, 0);
|
||||||
|
} else if (mode == "downsample3d") {
|
||||||
|
x = ggml_pad(ctx, x, 1, 1, 0, 0);
|
||||||
|
}
|
||||||
|
x = resample_1->forward(ctx, x);
|
||||||
|
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2)); // (c, t, h, w)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (mode == "downsample3d") {
|
||||||
|
if (feat_cache.size() > 0) {
|
||||||
|
int idx = feat_idx;
|
||||||
|
if (feat_cache[idx] == NULL) {
|
||||||
|
feat_cache[idx] = x;
|
||||||
|
feat_idx += 1;
|
||||||
|
} else {
|
||||||
|
auto time_conv = std::dynamic_pointer_cast<CausalConv3d>(blocks["time_conv"]);
|
||||||
|
|
||||||
|
auto cache_x = ggml_slice(ctx, x, 2, -1, x->ne[2]);
|
||||||
|
x = ggml_concat(ctx,
|
||||||
|
ggml_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
|
||||||
|
x,
|
||||||
|
2);
|
||||||
|
x = time_conv->forward(ctx, x);
|
||||||
|
feat_cache[idx] = cache_x;
|
||||||
|
feat_idx += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class ResidualBlock : public GGMLBlock {
|
||||||
|
protected:
|
||||||
|
int64_t in_dim;
|
||||||
|
int64_t out_dim;
|
||||||
|
|
||||||
|
public:
|
||||||
|
ResidualBlock(int64_t in_dim, int64_t out_dim)
|
||||||
|
: in_dim(in_dim), out_dim(out_dim) {
|
||||||
|
blocks["residual.0"] = std::shared_ptr<GGMLBlock>(new RMS_norm(in_dim));
|
||||||
|
// residual.1 is nn.SiLU()
|
||||||
|
blocks["residual.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(in_dim, out_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
||||||
|
blocks["residual.3"] = std::shared_ptr<GGMLBlock>(new RMS_norm(out_dim));
|
||||||
|
// residual.4 is nn.SiLU()
|
||||||
|
// residual.5 is nn.Dropout()
|
||||||
|
blocks["residual.6"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, out_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
||||||
|
if (in_dim != out_dim) {
|
||||||
|
blocks["shortcut"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(in_dim, out_dim, {1, 1, 1}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* x,
|
||||||
|
int64_t b,
|
||||||
|
std::vector<struct ggml_tensor*>& feat_cache,
|
||||||
|
int& feat_idx) {
|
||||||
|
// x: [b*c, t, h, w]
|
||||||
|
GGML_ASSERT(b == 1);
|
||||||
|
struct ggml_tensor* h = x;
|
||||||
|
if (in_dim != out_dim) {
|
||||||
|
auto shortcut = std::dynamic_pointer_cast<CausalConv3d>(blocks["shortcut"]);
|
||||||
|
|
||||||
|
h = shortcut->forward(ctx, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < 7; i++) {
|
||||||
|
if (i == 0 || i == 3) { // RMS_norm
|
||||||
|
auto layer = std::dynamic_pointer_cast<RMS_norm>(blocks["residual." + std::to_string(i)]);
|
||||||
|
x = layer->forward(ctx, x);
|
||||||
|
} else if (i == 2 || i == 6) { // CausalConv3d
|
||||||
|
auto layer = std::dynamic_pointer_cast<CausalConv3d>(blocks["residual." + std::to_string(i)]);
|
||||||
|
|
||||||
|
if (feat_cache.size() > 0) {
|
||||||
|
int idx = feat_idx;
|
||||||
|
auto cache_x = ggml_slice(ctx, x, 2, -CACHE_T, x->ne[2]);
|
||||||
|
if (cache_x->ne[2] < 2 && feat_cache[idx] != NULL) {
|
||||||
|
// cache last frame of last two chunk
|
||||||
|
cache_x = ggml_concat(ctx,
|
||||||
|
ggml_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
|
||||||
|
cache_x,
|
||||||
|
2);
|
||||||
|
}
|
||||||
|
|
||||||
|
x = layer->forward(ctx, x, feat_cache[idx]);
|
||||||
|
feat_cache[idx] = cache_x;
|
||||||
|
feat_idx += 1;
|
||||||
|
}
|
||||||
|
} else if (i == 1 || i == 4) {
|
||||||
|
x = ggml_silu(ctx, x);
|
||||||
|
} else { // i == 5
|
||||||
|
// nn.Dropout(), ignore
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
x = ggml_add(ctx, x, h);
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class AttentionBlock : public GGMLBlock {
|
||||||
|
protected:
|
||||||
|
int64_t dim;
|
||||||
|
|
||||||
|
public:
|
||||||
|
AttentionBlock(int64_t dim)
|
||||||
|
: dim(dim) {
|
||||||
|
blocks["norm"] = std::shared_ptr<GGMLBlock>(new RMS_norm(dim));
|
||||||
|
blocks["to_qkv"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim * 3, {1, 1}));
|
||||||
|
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Conv2d(dim, dim, {1, 1}));
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* x,
|
||||||
|
int64_t b) {
|
||||||
|
// x: [b*c, t, h, w]
|
||||||
|
GGML_ASSERT(b == 1);
|
||||||
|
auto norm = std::dynamic_pointer_cast<RMS_norm>(blocks["norm"]);
|
||||||
|
auto to_qkv = std::dynamic_pointer_cast<Conv2d>(blocks["to_qkv"]);
|
||||||
|
auto proj = std::dynamic_pointer_cast<Conv2d>(blocks["proj"]);
|
||||||
|
|
||||||
|
auto identity = x;
|
||||||
|
|
||||||
|
x = norm->forward(ctx, x);
|
||||||
|
|
||||||
|
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2)); // (t, c, h, w)
|
||||||
|
|
||||||
|
const int64_t n = x->ne[3];
|
||||||
|
const int64_t c = x->ne[2];
|
||||||
|
const int64_t h = x->ne[1];
|
||||||
|
const int64_t w = x->ne[0];
|
||||||
|
|
||||||
|
auto qkv = to_qkv->forward(ctx, x);
|
||||||
|
auto qkv_vec = split_image_qkv(ctx, qkv);
|
||||||
|
|
||||||
|
auto q = qkv_vec[0];
|
||||||
|
q = ggml_cont(ctx, ggml_torch_permute(ctx, q, 2, 0, 1, 3)); // [t, h, w, c]
|
||||||
|
q = ggml_reshape_3d(ctx, q, c, h * w, n); // [t, h * w, c]
|
||||||
|
|
||||||
|
auto k = qkv_vec[1];
|
||||||
|
k = ggml_cont(ctx, ggml_torch_permute(ctx, k, 2, 0, 1, 3)); // [t, h, w, c]
|
||||||
|
k = ggml_reshape_3d(ctx, k, c, h * w, n); // [t, h * w, c]
|
||||||
|
|
||||||
|
auto v = qkv_vec[2];
|
||||||
|
v = ggml_reshape_3d(ctx, v, h * w, c, n); // [t, c, h * w]
|
||||||
|
|
||||||
|
x = ggml_nn_attention(ctx, q, k, v, false); // [t, h * w, c]
|
||||||
|
|
||||||
|
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [t, c, h * w]
|
||||||
|
x = ggml_reshape_4d(ctx, x, w, h, c, n); // [t, c, h, w]
|
||||||
|
|
||||||
|
x = proj->forward(ctx, x);
|
||||||
|
|
||||||
|
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2)); // (c, t, h, w)
|
||||||
|
|
||||||
|
x = ggml_add(ctx, x, identity);
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class Encoder3d : public GGMLBlock {
|
||||||
|
protected:
|
||||||
|
int64_t dim;
|
||||||
|
int64_t z_dim;
|
||||||
|
std::vector<int> dim_mult;
|
||||||
|
int num_res_blocks;
|
||||||
|
std::vector<bool> temperal_downsample;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Encoder3d(int64_t dim = 128,
|
||||||
|
int64_t z_dim = 4,
|
||||||
|
std::vector<int> dim_mult = {1, 2, 4, 4},
|
||||||
|
int num_res_blocks = 2,
|
||||||
|
std::vector<bool> temperal_downsample = {false, true, true})
|
||||||
|
: dim(dim), z_dim(z_dim), dim_mult(dim_mult), num_res_blocks(num_res_blocks), temperal_downsample(temperal_downsample) {
|
||||||
|
// attn_scales is always []
|
||||||
|
std::vector<int64_t> dims = {dim};
|
||||||
|
for (int u : dim_mult) {
|
||||||
|
dims.push_back(dim * u);
|
||||||
|
}
|
||||||
|
|
||||||
|
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(3, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
||||||
|
|
||||||
|
int index = 0;
|
||||||
|
int64_t in_dim;
|
||||||
|
int64_t out_dim;
|
||||||
|
for (int i = 0; i < dims.size() - 1; i++) {
|
||||||
|
in_dim = dims[i];
|
||||||
|
out_dim = dims[i + 1];
|
||||||
|
for (int j = 0; j < num_res_blocks; j++) {
|
||||||
|
auto block = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim));
|
||||||
|
blocks["downsamples." + std::to_string(index++)] = block;
|
||||||
|
in_dim = out_dim;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (i != dim_mult.size() - 1) {
|
||||||
|
std::string mode = temperal_downsample[i] ? "downsample3d" : "downsample2d";
|
||||||
|
auto block = std::shared_ptr<GGMLBlock>(new Resample(out_dim, mode));
|
||||||
|
blocks["downsamples." + std::to_string(index++)] = block;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
blocks["middle.0"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(out_dim, out_dim));
|
||||||
|
blocks["middle.1"] = std::shared_ptr<GGMLBlock>(new AttentionBlock(out_dim));
|
||||||
|
blocks["middle.2"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(out_dim, out_dim));
|
||||||
|
|
||||||
|
blocks["head.0"] = std::shared_ptr<GGMLBlock>(new RMS_norm(out_dim));
|
||||||
|
// head.1 is nn.SiLU()
|
||||||
|
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, z_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* x,
|
||||||
|
int64_t b,
|
||||||
|
std::vector<struct ggml_tensor*>& feat_cache,
|
||||||
|
int& feat_idx) {
|
||||||
|
// x: [b*c, t, h, w]
|
||||||
|
GGML_ASSERT(b == 1);
|
||||||
|
auto conv1 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]);
|
||||||
|
auto middle_0 = std::dynamic_pointer_cast<ResidualBlock>(blocks["middle.0"]);
|
||||||
|
auto middle_1 = std::dynamic_pointer_cast<AttentionBlock>(blocks["middle.1"]);
|
||||||
|
auto middle_2 = std::dynamic_pointer_cast<ResidualBlock>(blocks["middle.2"]);
|
||||||
|
auto head_0 = std::dynamic_pointer_cast<RMS_norm>(blocks["head.0"]);
|
||||||
|
auto head_2 = std::dynamic_pointer_cast<CausalConv3d>(blocks["head.2"]);
|
||||||
|
|
||||||
|
// conv1
|
||||||
|
if (feat_cache.size() > 0) {
|
||||||
|
int idx = feat_idx;
|
||||||
|
auto cache_x = ggml_slice(ctx, x, 2, -CACHE_T, x->ne[2]);
|
||||||
|
if (cache_x->ne[2] < 2 && feat_cache[idx] != NULL) {
|
||||||
|
// cache last frame of last two chunk
|
||||||
|
cache_x = ggml_concat(ctx,
|
||||||
|
ggml_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
|
||||||
|
cache_x,
|
||||||
|
2);
|
||||||
|
}
|
||||||
|
|
||||||
|
x = conv1->forward(ctx, x, feat_cache[idx]);
|
||||||
|
feat_cache[idx] = cache_x;
|
||||||
|
feat_idx += 1;
|
||||||
|
} else {
|
||||||
|
x = conv1->forward(ctx, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
// downsamples
|
||||||
|
std::vector<int64_t> dims = {dim};
|
||||||
|
for (int u : dim_mult) {
|
||||||
|
dims.push_back(dim * u);
|
||||||
|
}
|
||||||
|
int index = 0;
|
||||||
|
for (int i = 0; i < dims.size() - 1; i++) {
|
||||||
|
for (int j = 0; j < num_res_blocks; j++) {
|
||||||
|
auto layer = std::dynamic_pointer_cast<ResidualBlock>(blocks["downsamples." + std::to_string(index++)]);
|
||||||
|
|
||||||
|
x = layer->forward(ctx, x, b, feat_cache, feat_idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (i != dim_mult.size() - 1) {
|
||||||
|
auto layer = std::dynamic_pointer_cast<Resample>(blocks["downsamples." + std::to_string(index++)]);
|
||||||
|
|
||||||
|
x = layer->forward(ctx, x, b, feat_cache, feat_idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// middle
|
||||||
|
x = middle_0->forward(ctx, x, b, feat_cache, feat_idx);
|
||||||
|
x = middle_1->forward(ctx, x, b);
|
||||||
|
x = middle_2->forward(ctx, x, b, feat_cache, feat_idx);
|
||||||
|
|
||||||
|
// head
|
||||||
|
x = head_0->forward(ctx, x);
|
||||||
|
x = ggml_silu(ctx, x);
|
||||||
|
if (feat_cache.size() > 0) {
|
||||||
|
int idx = feat_idx;
|
||||||
|
auto cache_x = ggml_slice(ctx, x, 2, -CACHE_T, x->ne[2]);
|
||||||
|
if (cache_x->ne[2] < 2 && feat_cache[idx] != NULL) {
|
||||||
|
// cache last frame of last two chunk
|
||||||
|
cache_x = ggml_concat(ctx,
|
||||||
|
ggml_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
|
||||||
|
cache_x,
|
||||||
|
2);
|
||||||
|
}
|
||||||
|
|
||||||
|
x = head_2->forward(ctx, x, feat_cache[idx]);
|
||||||
|
feat_cache[idx] = cache_x;
|
||||||
|
feat_idx += 1;
|
||||||
|
} else {
|
||||||
|
x = head_2->forward(ctx, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class Decoder3d : public GGMLBlock {
|
||||||
|
protected:
|
||||||
|
int64_t dim;
|
||||||
|
int64_t z_dim;
|
||||||
|
std::vector<int> dim_mult;
|
||||||
|
int num_res_blocks;
|
||||||
|
std::vector<bool> temperal_upsample;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Decoder3d(int64_t dim = 128,
|
||||||
|
int64_t z_dim = 4,
|
||||||
|
std::vector<int> dim_mult = {1, 2, 4, 4},
|
||||||
|
int num_res_blocks = 2,
|
||||||
|
std::vector<bool> temperal_upsample = {true, true, false})
|
||||||
|
: dim(dim), z_dim(z_dim), dim_mult(dim_mult), num_res_blocks(num_res_blocks), temperal_upsample(temperal_upsample) {
|
||||||
|
// attn_scales is always []
|
||||||
|
std::vector<int64_t> dims = {dim_mult[dim_mult.size() - 1] * dim};
|
||||||
|
for (int i = static_cast<int>(dim_mult.size()) - 1; i >= 0; i--) {
|
||||||
|
dims.push_back(dim * dim_mult[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// init block
|
||||||
|
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
||||||
|
|
||||||
|
// middle blocks
|
||||||
|
blocks["middle.0"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(dims[0], dims[0]));
|
||||||
|
blocks["middle.1"] = std::shared_ptr<GGMLBlock>(new AttentionBlock(dims[0]));
|
||||||
|
blocks["middle.2"] = std::shared_ptr<GGMLBlock>(new ResidualBlock(dims[0], dims[0]));
|
||||||
|
|
||||||
|
// upsample blocks
|
||||||
|
int index = 0;
|
||||||
|
int64_t in_dim;
|
||||||
|
int64_t out_dim;
|
||||||
|
for (int i = 0; i < dims.size() - 1; i++) {
|
||||||
|
in_dim = dims[i];
|
||||||
|
out_dim = dims[i + 1];
|
||||||
|
LOG_DEBUG("in_dim %u out_dim %u", in_dim, out_dim);
|
||||||
|
if (i == 1 || i == 2 || i == 3) {
|
||||||
|
in_dim = in_dim / 2;
|
||||||
|
}
|
||||||
|
for (int j = 0; j < num_res_blocks + 1; j++) {
|
||||||
|
auto block = std::shared_ptr<GGMLBlock>(new ResidualBlock(in_dim, out_dim));
|
||||||
|
blocks["upsamples." + std::to_string(index++)] = block;
|
||||||
|
in_dim = out_dim;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (i != dim_mult.size() - 1) {
|
||||||
|
std::string mode = temperal_upsample[i] ? "upsample3d" : "upsample2d";
|
||||||
|
auto block = std::shared_ptr<GGMLBlock>(new Resample(out_dim, mode));
|
||||||
|
blocks["upsamples." + std::to_string(index++)] = block;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// output blocks
|
||||||
|
blocks["head.0"] = std::shared_ptr<GGMLBlock>(new RMS_norm(out_dim));
|
||||||
|
// head.1 is nn.SiLU()
|
||||||
|
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, 3, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* x,
|
||||||
|
int64_t b,
|
||||||
|
std::vector<struct ggml_tensor*>& feat_cache,
|
||||||
|
int& feat_idx) {
|
||||||
|
// x: [b*c, t, h, w]
|
||||||
|
GGML_ASSERT(b == 1);
|
||||||
|
auto conv1 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]);
|
||||||
|
auto middle_0 = std::dynamic_pointer_cast<ResidualBlock>(blocks["middle.0"]);
|
||||||
|
auto middle_1 = std::dynamic_pointer_cast<AttentionBlock>(blocks["middle.1"]);
|
||||||
|
auto middle_2 = std::dynamic_pointer_cast<ResidualBlock>(blocks["middle.2"]);
|
||||||
|
auto head_0 = std::dynamic_pointer_cast<RMS_norm>(blocks["head.0"]);
|
||||||
|
auto head_2 = std::dynamic_pointer_cast<CausalConv3d>(blocks["head.2"]);
|
||||||
|
|
||||||
|
// conv1
|
||||||
|
if (feat_cache.size() > 0) {
|
||||||
|
int idx = feat_idx;
|
||||||
|
auto cache_x = ggml_slice(ctx, x, 2, -CACHE_T, x->ne[2]);
|
||||||
|
if (cache_x->ne[2] < 2 && feat_cache[idx] != NULL) {
|
||||||
|
// cache last frame of last two chunk
|
||||||
|
cache_x = ggml_concat(ctx,
|
||||||
|
ggml_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
|
||||||
|
cache_x,
|
||||||
|
2);
|
||||||
|
}
|
||||||
|
|
||||||
|
x = conv1->forward(ctx, x, feat_cache[idx]);
|
||||||
|
feat_cache[idx] = cache_x;
|
||||||
|
feat_idx += 1;
|
||||||
|
} else {
|
||||||
|
x = conv1->forward(ctx, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
// middle
|
||||||
|
x = middle_0->forward(ctx, x, b, feat_cache, feat_idx);
|
||||||
|
x = middle_1->forward(ctx, x, b);
|
||||||
|
x = middle_2->forward(ctx, x, b, feat_cache, feat_idx);
|
||||||
|
|
||||||
|
// upsamples
|
||||||
|
std::vector<int64_t> dims = {dim_mult[dim_mult.size() - 1] * dim};
|
||||||
|
for (int i = static_cast<int>(dim_mult.size()) - 1; i >= 0; i--) {
|
||||||
|
dims.push_back(dim * dim_mult[i]);
|
||||||
|
}
|
||||||
|
int index = 0;
|
||||||
|
for (int i = 0; i < dims.size() - 1; i++) {
|
||||||
|
for (int j = 0; j < num_res_blocks + 1; j++) {
|
||||||
|
auto layer = std::dynamic_pointer_cast<ResidualBlock>(blocks["upsamples." + std::to_string(index++)]);
|
||||||
|
|
||||||
|
x = layer->forward(ctx, x, b, feat_cache, feat_idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (i != dim_mult.size() - 1) {
|
||||||
|
auto layer = std::dynamic_pointer_cast<Resample>(blocks["upsamples." + std::to_string(index++)]);
|
||||||
|
|
||||||
|
x = layer->forward(ctx, x, b, feat_cache, feat_idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// head
|
||||||
|
x = head_0->forward(ctx, x);
|
||||||
|
x = ggml_silu(ctx, x);
|
||||||
|
if (feat_cache.size() > 0) {
|
||||||
|
int idx = feat_idx;
|
||||||
|
auto cache_x = ggml_slice(ctx, x, 2, -CACHE_T, x->ne[2]);
|
||||||
|
if (cache_x->ne[2] < 2 && feat_cache[idx] != NULL) {
|
||||||
|
// cache last frame of last two chunk
|
||||||
|
cache_x = ggml_concat(ctx,
|
||||||
|
ggml_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
|
||||||
|
cache_x,
|
||||||
|
2);
|
||||||
|
}
|
||||||
|
|
||||||
|
x = head_2->forward(ctx, x, feat_cache[idx]);
|
||||||
|
feat_cache[idx] = cache_x;
|
||||||
|
feat_idx += 1;
|
||||||
|
} else {
|
||||||
|
x = head_2->forward(ctx, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class WanVAE : public GGMLBlock {
|
||||||
|
protected:
|
||||||
|
bool decode_only = true;
|
||||||
|
int64_t dim = 96;
|
||||||
|
int64_t z_dim = 16;
|
||||||
|
std::vector<int> dim_mult = {1, 2, 4, 4};
|
||||||
|
int num_res_blocks = 2;
|
||||||
|
std::vector<bool> temperal_upsample = {true, true, false};
|
||||||
|
std::vector<bool> temperal_downsample = {false, true, true};
|
||||||
|
|
||||||
|
int _conv_num = 33;
|
||||||
|
int _conv_idx = 0;
|
||||||
|
std::vector<struct ggml_tensor*> _feat_map;
|
||||||
|
int _enc_conv_num = 28;
|
||||||
|
int _enc_conv_idx = 0;
|
||||||
|
std::vector<struct ggml_tensor*> _enc_feat_map;
|
||||||
|
|
||||||
|
void clear_cache() {
|
||||||
|
_conv_idx = 0;
|
||||||
|
_feat_map = std::vector<struct ggml_tensor*>(_conv_num, NULL);
|
||||||
|
_enc_conv_idx = 0;
|
||||||
|
_enc_feat_map = std::vector<struct ggml_tensor*>(_enc_conv_num, NULL);
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
WanVAE(bool decode_only = true)
|
||||||
|
: decode_only(decode_only) {
|
||||||
|
// attn_scales is always []
|
||||||
|
if (!decode_only) {
|
||||||
|
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, temperal_downsample));
|
||||||
|
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim * 2, z_dim * 2, {1, 1, 1}));
|
||||||
|
}
|
||||||
|
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder3d(dim, z_dim, dim_mult, num_res_blocks, temperal_upsample));
|
||||||
|
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim, z_dim, {1, 1, 1}));
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* encode(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* x,
|
||||||
|
int64_t b = 1) {
|
||||||
|
// x: [b*c, t, h, w]
|
||||||
|
GGML_ASSERT(b == 1);
|
||||||
|
GGML_ASSERT(decode_only == false);
|
||||||
|
|
||||||
|
clear_cache();
|
||||||
|
|
||||||
|
auto encoder = std::dynamic_pointer_cast<Encoder3d>(blocks["encoder"]);
|
||||||
|
auto conv1 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]);
|
||||||
|
|
||||||
|
int64_t t = x->ne[2];
|
||||||
|
int64_t iter_ = 1 + (t - 1) / 4;
|
||||||
|
struct ggml_tensor* out;
|
||||||
|
for (int i = 0; i < iter_; i++) {
|
||||||
|
_enc_conv_idx = 0;
|
||||||
|
if (i == 0) {
|
||||||
|
auto in = ggml_slice(ctx, x, 2, 0, 1); // [b*c, 1, h, w]
|
||||||
|
out = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx);
|
||||||
|
} else {
|
||||||
|
auto in = ggml_slice(ctx, x, 2, 1 + 4 * (i - 1), 1 + 4 * i); // [b*c, 4, h, w]
|
||||||
|
auto out_ = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx);
|
||||||
|
out = ggml_concat(ctx, out, out_, 2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out = conv1->forward(ctx, out);
|
||||||
|
auto mu = ggml_chunk(ctx, out, 2, 3)[0];
|
||||||
|
clear_cache();
|
||||||
|
return mu;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* decode(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* z,
|
||||||
|
int64_t b = 1) {
|
||||||
|
// z: [b*c, t, h, w]
|
||||||
|
GGML_ASSERT(b == 1);
|
||||||
|
|
||||||
|
clear_cache();
|
||||||
|
|
||||||
|
auto decoder = std::dynamic_pointer_cast<Decoder3d>(blocks["decoder"]);
|
||||||
|
auto conv2 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv2"]);
|
||||||
|
|
||||||
|
int64_t iter_ = z->ne[2];
|
||||||
|
auto x = conv2->forward(ctx, z);
|
||||||
|
struct ggml_tensor* out;
|
||||||
|
for (int64_t i = 0; i < iter_; i++) {
|
||||||
|
_conv_idx = 0;
|
||||||
|
if (i == 0) {
|
||||||
|
auto in = ggml_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
|
||||||
|
out = decoder->forward(ctx, in, b, _feat_map, _conv_idx);
|
||||||
|
} else {
|
||||||
|
auto in = ggml_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
|
||||||
|
auto out_ = decoder->forward(ctx, in, b, _feat_map, _conv_idx);
|
||||||
|
out = ggml_concat(ctx, out, out_, 2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
clear_cache();
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct WanVAERunner : public GGMLRunner {
|
||||||
|
bool decode_only = true;
|
||||||
|
WanVAE ae;
|
||||||
|
|
||||||
|
WanVAERunner(ggml_backend_t backend,
|
||||||
|
const String2GGMLType& tensor_types = {},
|
||||||
|
const std::string prefix = "",
|
||||||
|
bool decode_only = false)
|
||||||
|
: decode_only(decode_only), ae(decode_only), GGMLRunner(backend) {
|
||||||
|
ae.init(params_ctx, tensor_types, prefix);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string get_desc() {
|
||||||
|
return "wan_vae";
|
||||||
|
}
|
||||||
|
|
||||||
|
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
|
||||||
|
ae.get_param_tensors(tensors, prefix);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
|
||||||
|
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, 20480, false);
|
||||||
|
|
||||||
|
z = to_backend(z);
|
||||||
|
|
||||||
|
struct ggml_tensor* out = decode_graph ? ae.decode(compute_ctx, z) : ae.encode(compute_ctx, z);
|
||||||
|
|
||||||
|
ggml_build_forward_expand(gf, out);
|
||||||
|
|
||||||
|
return gf;
|
||||||
|
}
|
||||||
|
|
||||||
|
void compute(const int n_threads,
|
||||||
|
struct ggml_tensor* z,
|
||||||
|
bool decode_graph,
|
||||||
|
struct ggml_tensor** output,
|
||||||
|
struct ggml_context* output_ctx = NULL) {
|
||||||
|
auto get_graph = [&]() -> struct ggml_cgraph* {
|
||||||
|
return build_graph(z, decode_graph);
|
||||||
|
};
|
||||||
|
// ggml_set_f32(z, 0.5f);
|
||||||
|
// print_ggml_tensor(z);
|
||||||
|
GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
void test() {
|
||||||
|
struct ggml_init_params params;
|
||||||
|
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
|
||||||
|
params.mem_buffer = NULL;
|
||||||
|
params.no_alloc = false;
|
||||||
|
|
||||||
|
struct ggml_context* work_ctx = ggml_init(params);
|
||||||
|
GGML_ASSERT(work_ctx != NULL);
|
||||||
|
|
||||||
|
if (true) {
|
||||||
|
// cpu f32, pass
|
||||||
|
// cpu f16, pass
|
||||||
|
// cuda f16, pass
|
||||||
|
// cuda f32, pass
|
||||||
|
auto z = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 1, 16);
|
||||||
|
z = load_tensor_from_file(work_ctx, "wan_vae_z.bin");
|
||||||
|
// ggml_set_f32(z, 0.5f);
|
||||||
|
print_ggml_tensor(z);
|
||||||
|
struct ggml_tensor* out = NULL;
|
||||||
|
|
||||||
|
int64_t t0 = ggml_time_ms();
|
||||||
|
compute(8, z, true, &out, work_ctx);
|
||||||
|
int64_t t1 = ggml_time_ms();
|
||||||
|
|
||||||
|
print_ggml_tensor(out);
|
||||||
|
LOG_DEBUG("decode test done in %ldms", t1 - t0);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
static void load_from_file_and_test(const std::string& file_path) {
|
||||||
|
ggml_backend_t backend = ggml_backend_cuda_init(0);
|
||||||
|
// ggml_backend_t backend = ggml_backend_cpu_init();
|
||||||
|
ggml_type model_data_type = GGML_TYPE_F32;
|
||||||
|
std::shared_ptr<WanVAERunner> vae = std::shared_ptr<WanVAERunner>(new WanVAERunner(backend));
|
||||||
|
{
|
||||||
|
LOG_INFO("loading from '%s'", file_path.c_str());
|
||||||
|
|
||||||
|
vae->alloc_params_buffer();
|
||||||
|
std::map<std::string, ggml_tensor*> tensors;
|
||||||
|
vae->get_param_tensors(tensors, "first_stage_model");
|
||||||
|
|
||||||
|
ModelLoader model_loader;
|
||||||
|
if (!model_loader.init_from_file(file_path, "vae.")) {
|
||||||
|
LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool success = model_loader.load_tensors(tensors, backend);
|
||||||
|
|
||||||
|
if (!success) {
|
||||||
|
LOG_ERROR("load tensors from model loader failed");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_INFO("vae model loaded");
|
||||||
|
}
|
||||||
|
vae->test();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif
|
||||||
Loading…
x
Reference in New Issue
Block a user