mirror of https://github.com/leejet/stable-diffusion.cpp.git
make wan a little faster
This commit is contained in: parent 00f790d0e9, commit 73f76e6d96
ggml_extend.hpp

@@ -608,6 +608,14 @@ __STATIC_INLINE__ void ggml_tensor_scale_output(struct ggml_tensor* src) {
     }
 }
 
+__STATIC_INLINE__ struct ggml_tensor* ggml_nn_cont(struct ggml_context* ctx,
+                                                   struct ggml_tensor* x) {
+    if (ggml_is_contiguous(x)) {
+        return x;
+    }
+    return ggml_cont(ctx, x);
+}
+
 // torch like permute
 __STATIC_INLINE__ struct ggml_tensor* ggml_torch_permute(struct ggml_context* ctx,
                                                          struct ggml_tensor* x,
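Worth spelling out, since nearly every other hunk in this commit leans on it: ggml_cont unconditionally appends a copy node to the compute graph, even when its input is already contiguous in memory. The new ggml_nn_cont guard makes that copy conditional, so the many call sites below only pay for a copy when a permute actually produced a strided view. A minimal standalone sketch (the helper body is taken from this hunk; the main() harness is illustrative only, not part of the patch):

```cpp
#include "ggml.h"
#include <cstdio>

// helper as introduced by this commit (plain static here instead of __STATIC_INLINE__)
static struct ggml_tensor* ggml_nn_cont(struct ggml_context* ctx, struct ggml_tensor* x) {
    if (ggml_is_contiguous(x)) {
        return x;  // already contiguous: return the node unchanged, no copy added
    }
    return ggml_cont(ctx, x);  // otherwise materialize a contiguous copy
}

int main() {
    struct ggml_init_params params = {16 * 1024 * 1024, nullptr, false};
    struct ggml_context* ctx       = ggml_init(params);

    struct ggml_tensor* a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 3);  // fresh tensors are contiguous
    struct ggml_tensor* p = ggml_permute(ctx, a, 1, 0, 2, 3);              // permute yields a strided view

    std::printf("a reused: %d\n", ggml_nn_cont(ctx, a) == a);  // 1: no copy node appended
    std::printf("p reused: %d\n", ggml_nn_cont(ctx, p) == p);  // 0: falls back to ggml_cont

    ggml_free(ctx);
    return 0;
}
```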
@@ -799,7 +807,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_linear(struct ggml_context* ctx,
                                                      struct ggml_tensor* b) {
     x = ggml_mul_mat(ctx, w, x);
     if (b != NULL) {
-        x = ggml_add(ctx, x, b);
+        x = ggml_add_inplace(ctx, x, b);
     }
     return x;
 }
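The second recurring change replaces out-of-place ggml_add / ggml_mul with their _inplace variants. An in-place op writes its result into the first operand's existing buffer instead of allocating a new result tensor in the graph, saving an allocation and a pass over memory. This is only safe when nothing else reads the original value; here the matmul output feeds the bias add and nothing else. A simplified sketch of the patched pattern (names follow the diff; this is not the full header):

```cpp
#include "ggml.h"

// bias-add pattern from ggml_nn_linear after this commit (simplified sketch)
static struct ggml_tensor* linear_with_bias(struct ggml_context* ctx,
                                            struct ggml_tensor* x,
                                            struct ggml_tensor* w,
                                            struct ggml_tensor* b) {
    x = ggml_mul_mat(ctx, w, x);  // fresh tensor holding the matmul result
    if (b != NULL) {
        // ggml_add would allocate another result tensor; the in-place form
        // reuses the matmul output buffer, which no other node reads
        x = ggml_add_inplace(ctx, x, b);
    }
    return x;
}
```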
@@ -822,7 +830,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_2d(struct ggml_context* ctx,
     if (b != NULL) {
         b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
         // b = ggml_repeat(ctx, b, x);
-        x = ggml_add(ctx, x, b);
+        x = ggml_add_inplace(ctx, x, b);
     }
     return x;
 }
@@ -851,45 +859,11 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_3d(struct ggml_context* ctx,
 
     if (b != NULL) {
         b = ggml_reshape_4d(ctx, b, 1, 1, 1, b->ne[0]);  // [OC, 1, 1, 1]
-        x = ggml_add(ctx, x, b);
+        x = ggml_add_inplace(ctx, x, b);
     }
     return x;
 }
 
-// w: [OC,IC, KD, 1 * 1]
-// x: [N, IC, IH, IW]
-// b: [OC,]
-// result: [N, OC, OH, OW]
-__STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_3d_nx1x1_bak(struct ggml_context* ctx,
-                                                                struct ggml_tensor* x,
-                                                                struct ggml_tensor* w,
-                                                                struct ggml_tensor* b,
-                                                                int s2 = 1,
-                                                                int p2 = 1,
-                                                                int d2 = 1) {
-    GGML_ASSERT(w->ne[0] == 1);
-    // timesteps = x.shape[0]
-    // x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
-    // x = conv3d(x)
-    // return rearrange(x, "b c t h w -> (b t) c h w")
-    int64_t T = x->ne[3];
-    int64_t B = x->ne[3] / T;
-    int64_t C = x->ne[2];
-    int64_t H = x->ne[1];
-    int64_t W = x->ne[0];
-
-    x = ggml_reshape_4d(ctx, x, W * H, C, T, B);           // (b t) c h w -> b t c (h w)
-    x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));  // b t c (h w) -> b c t (h w)
-    x = ggml_conv_2d(ctx, w, x, 1, s2, 0, p2, 1, d2);      // [B, OC, T, OH * OW]
-    if (b != NULL) {
-        b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
-        x = ggml_add(ctx, x, b);
-    }
-    x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));  // b c t (h w) -> b t c (h w)
-    x = ggml_reshape_4d(ctx, x, W, H, C, T * B);           // b t c (h w) -> (b t) c h w
-    return x;  // [B*T, OC, OH, OW]
-}
-
 // w: [OC,IC, KD, 1 * 1]
 // x: [N, IC, ID, IH*IW]
 // b: [OC,]
@@ -992,11 +966,11 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* ctx,
     N      = q->ne[2];
     d_head = C / n_head;
     q = ggml_reshape_4d(ctx, q, d_head, n_head, L_q, N);      // [N, L_q, n_head, d_head]
-    q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3));     // [N, n_head, L_q, d_head]
+    q = ggml_nn_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3));  // [N, n_head, L_q, d_head]
     q = ggml_reshape_3d(ctx, q, d_head, L_q, n_head * N);     // [N * n_head, L_q, d_head]
 
     k = ggml_reshape_4d(ctx, k, d_head, n_head, L_k, N);      // [N, L_k, n_head, d_head]
-    k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3));     // [N, n_head, L_k, d_head]
+    k = ggml_nn_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3));  // [N, n_head, L_k, d_head]
     k = ggml_reshape_3d(ctx, k, d_head, L_k, n_head * N);     // [N * n_head, L_k, d_head]
 
     v = ggml_reshape_4d(ctx, v, d_head, n_head, L_k, N);      // [N, L_k, n_head, d_head]
@@ -1047,14 +1021,14 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* ctx,
             // LOG_DEBUG(" padding k and v dim1 by %d", kv_pad);
             k = ggml_pad(ctx, k, 0, kv_pad, 0, 0);
         }
-        k = ggml_cast(ctx, k, GGML_TYPE_F16);
+        // k = ggml_cast(ctx, k, GGML_TYPE_F16);
 
-        v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));     // [N, n_head, L_k, d_head]
+        v = ggml_nn_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));  // [N, n_head, L_k, d_head]
         v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N);     // [N * n_head, L_k, d_head]
         if (kv_pad != 0) {
             v = ggml_pad(ctx, v, 0, kv_pad, 0, 0);
         }
-        v = ggml_cast(ctx, v, GGML_TYPE_F16);
+        // v = ggml_cast(ctx, v, GGML_TYPE_F16);
 
         if (mask != nullptr) {
             mask = ggml_transpose(ctx, mask);
@@ -1074,7 +1048,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* ctx,
         // kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_k, kqv->nb[1], kqv->nb[2], 0);
         kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_q, kqv->nb[1], kqv->nb[2], 0);
     } else {
-        v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3));     // [N, n_head, d_head, L_k]
+        v = ggml_nn_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3));  // [N, n_head, d_head, L_k]
         v = ggml_reshape_3d(ctx, v, L_k, d_head, n_head * N);     // [N * n_head, d_head, L_k]
 
         auto kq = ggml_mul_mat(ctx, k, q);  // [N * n_head, L_q, L_k]
@@ -1093,7 +1067,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* ctx,
         kqv = ggml_permute(ctx, kqv, 0, 2, 1, 3);  // [N, L_q, n_head, d_head]
     }
 
-    kqv = ggml_cont(ctx, kqv);
+    kqv = ggml_nn_cont(ctx, kqv);
     kqv = ggml_reshape_3d(ctx, kqv, d_head * n_head, L_q, N);  // [N, L_q, C]
 
     return kqv;
@@ -1106,9 +1080,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_layer_norm(struct ggml_context* ctx,
                                                          float eps = EPS) {
     x = ggml_norm(ctx, x, eps);
     if (w != NULL) {
-        x = ggml_mul(ctx, x, w);
+        x = ggml_mul_inplace(ctx, x, w);
         if (b != NULL) {
-            x = ggml_add(ctx, x, b);
+            x = ggml_add_inplace(ctx, x, b);
         }
     }
     return x;
@@ -1127,9 +1101,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_group_norm(struct ggml_context* ctx,
     const float eps = 1e-6f;  // default eps parameter
     x = ggml_group_norm(ctx, x, num_groups, eps);
     if (w != NULL && b != NULL) {
-        x = ggml_mul(ctx, x, w);
+        x = ggml_mul_inplace(ctx, x, w);
         // b = ggml_repeat(ctx, b, x);
-        x = ggml_add(ctx, x, b);
+        x = ggml_add_inplace(ctx, x, b);
     }
     return x;
 }
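Same idea in the norm helpers: after ggml_norm (or ggml_group_norm) only the scaled and shifted value is ever used, so the scale and shift can reuse the normalized tensor's buffer. A runnable sketch of the patched ggml_nn_layer_norm shape, which now allocates one result tensor instead of three. It assumes a ggml build that provides ggml_graph_compute_with_ctx (declared in ggml-cpu.h in current trees; older single-header builds expose it from ggml.h):

```cpp
#include "ggml.h"
#include "ggml-cpu.h"  // ggml_graph_compute_with_ctx lives here in current ggml trees
#include <cstdio>

int main() {
    struct ggml_init_params params = {16 * 1024 * 1024, nullptr, false};
    struct ggml_context* ctx       = ggml_init(params);

    struct ggml_tensor* x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    struct ggml_tensor* w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    struct ggml_tensor* b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    for (int i = 0; i < 4; ++i) {
        ggml_set_f32_1d(x, i, (float)i);
        ggml_set_f32_1d(w, i, 2.0f);
        ggml_set_f32_1d(b, i, 0.5f);
    }

    // normalize, then scale and shift in place, mirroring the patched helper:
    // both in-place ops write back into the ggml_norm output buffer
    struct ggml_tensor* y = ggml_norm(ctx, x, 1e-6f);
    y = ggml_mul_inplace(ctx, y, w);
    y = ggml_add_inplace(ctx, y, b);

    struct ggml_cgraph* gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads*/ 1);

    for (int i = 0; i < 4; ++i) {
        std::printf("y[%d] = %.3f\n", i, ggml_get_f32_1d(y, i));
    }
    ggml_free(ctx);
    return 0;
}
```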
@@ -1874,7 +1848,7 @@ public:
     struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
         struct ggml_tensor* w = params["weight"];
         x = ggml_rms_norm(ctx, x, eps);
-        x = ggml_mul(ctx, x, w);
+        x = ggml_mul_inplace(ctx, x, w);
         return x;
     }
 };
stable-diffusion.cpp

@@ -882,8 +882,6 @@ public:
         float img_cfg_scale = guidance.img_cfg;
         float slg_scale     = guidance.slg.scale;
 
-        LOG_DEBUG("cfg_scale %.2f", cfg_scale);
-
         if (img_cfg_scale != cfg_scale && !sd_version_is_inpaint_or_unet_edit(version)) {
             LOG_WARN("2-conditioning CFG is not supported with this model, disabling it for better performance...");
             img_cfg_scale = cfg_scale;
@@ -1215,7 +1213,6 @@ public:
 
         int64_t t0 = ggml_time_ms();
         if (!use_tiny_autoencoder) {
-            LOG_DEBUG("scale_factor %.2f", scale_factor);
             process_latent_out(x);
             if (vae_tiling && !decode_video) {
                 // split latent in 32x32 tiles and compute in several steps
wan.hpp (36 lines changed)
@@ -99,10 +99,10 @@ namespace WAN {
             // assert N == 1
 
             struct ggml_tensor* w = params["gamma"];
-            auto h = ggml_cont(ctx, ggml_torch_permute(ctx, x, 3, 0, 1, 2));     // [ID, IH, IW, N*IC]
+            auto h = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 3, 0, 1, 2));  // [ID, IH, IW, N*IC]
             h = ggml_rms_norm(ctx, h, 1e-12);
             h = ggml_mul(ctx, h, w);
-            h = ggml_cont(ctx, ggml_torch_permute(ctx, h, 1, 2, 3, 0));
+            h = ggml_nn_cont(ctx, ggml_torch_permute(ctx, h, 1, 2, 3, 0));
 
             return h;
         }
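The wan.hpp changes are the same ggml_nn_cont substitution applied around permutes; the RMS_norm block above permutes so the channel axis lands in ne[0] because ggml_rms_norm normalizes over the innermost dimension. Two permute flavors appear in this file: plain ggml_permute takes, for each source axis i, the destination position of that axis, while ggml_torch_permute is this repo's wrapper that, per its comment, follows torch's argument convention. Either way the result is a strided view, and whether a copy is needed afterwards depends on the resulting layout, which is exactly what ggml_nn_cont now checks. A small runnable sketch of the ggml_permute convention (shapes are illustrative):

```cpp
#include "ggml.h"
#include <cstdio>

int main() {
    struct ggml_init_params params = {1 * 1024 * 1024, nullptr, false};
    struct ggml_context* ctx       = ggml_init(params);

    struct ggml_tensor* t = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 4, 3, 2);  // ne = {4, 3, 2, 1}
    // argument i gives the *destination* position of source axis i:
    // axis 0 (size 4) -> slot 2, axis 1 (size 3) -> slot 0, axis 2 (size 2) -> slot 1
    struct ggml_tensor* v = ggml_permute(ctx, t, 2, 0, 1, 3);

    std::printf("v->ne = {%lld, %lld, %lld, %lld}, contiguous = %d\n",
                (long long)v->ne[0], (long long)v->ne[1],
                (long long)v->ne[2], (long long)v->ne[3],
                (int)ggml_is_contiguous(v));  // {3, 2, 4, 1}, contiguous = 0
    ggml_free(ctx);
    return 0;
}
```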
@@ -176,7 +176,7 @@ namespace WAN {
                         feat_cache[idx] = cache_x;
                         feat_idx += 1;
                         x = ggml_reshape_4d(ctx, x, w * h, t, c, 2);                    // (2, c, t, h*w)
-                        x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 3, 1, 2));     // (c, t, 2, h*w)
+                        x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 3, 1, 2));  // (c, t, 2, h*w)
                         x = ggml_reshape_4d(ctx, x, w, h, 2 * t, c);                    // (c, t*2, h, w)
                     }
                 }
@@ -186,7 +186,7 @@ namespace WAN {
             if (mode != "none") {
                 auto resample_1 = std::dynamic_pointer_cast<Conv2d>(blocks["resample.1"]);
 
-                x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2));     // (t, c, h, w)
+                x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2));  // (t, c, h, w)
                 if (mode == "upsample2d") {
                     x = ggml_upscale(ctx, x, 2, GGML_SCALE_MODE_NEAREST);
                 } else if (mode == "upsample3d") {
@@ -197,7 +197,7 @@ namespace WAN {
                     x = ggml_pad(ctx, x, 1, 1, 0, 0);
                 }
                 x = resample_1->forward(ctx, x);
-                x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2));     // (c, t, h, w)
+                x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2));  // (c, t, h, w)
             }
 
             if (mode == "downsample3d") {
@@ -318,7 +318,7 @@ namespace WAN {
 
             x = norm->forward(ctx, x);
 
-            x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2));     // (t, c, h, w)
+            x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2));  // (t, c, h, w)
 
             const int64_t n = x->ne[3];
             const int64_t c = x->ne[2];
@@ -329,11 +329,11 @@ namespace WAN {
             auto qkv_vec = split_image_qkv(ctx, qkv);
 
             auto q = qkv_vec[0];
-            q = ggml_cont(ctx, ggml_torch_permute(ctx, q, 2, 0, 1, 3));     // [t, h, w, c]
+            q = ggml_nn_cont(ctx, ggml_torch_permute(ctx, q, 2, 0, 1, 3));  // [t, h, w, c]
             q = ggml_reshape_3d(ctx, q, c, h * w, n);  // [t, h * w, c]
 
             auto k = qkv_vec[1];
-            k = ggml_cont(ctx, ggml_torch_permute(ctx, k, 2, 0, 1, 3));     // [t, h, w, c]
+            k = ggml_nn_cont(ctx, ggml_torch_permute(ctx, k, 2, 0, 1, 3));  // [t, h, w, c]
             k = ggml_reshape_3d(ctx, k, c, h * w, n);  // [t, h * w, c]
 
             auto v = qkv_vec[2];
@@ -341,12 +341,12 @@ namespace WAN {
 
             x = ggml_nn_attention(ctx, q, k, v, false);  // [t, h * w, c]
 
-            x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3));     // [t, c, h * w]
+            x = ggml_nn_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3));  // [t, c, h * w]
             x = ggml_reshape_4d(ctx, x, w, h, c, n);  // [t, c, h, w]
 
             x = proj->forward(ctx, x);
 
-            x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2));     // (c, t, h, w)
+            x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2));  // (c, t, h, w)
 
             x = ggml_add(ctx, x, identity);
             return x;
@@ -987,11 +987,11 @@ namespace WAN {
             int64_t dim             = x->ne[2];
             int64_t context_txt_len = context->ne[1] - context_img_len;
 
-            context = ggml_cont(ctx, ggml_torch_permute(ctx, context, 0, 2, 1, 3));     // [context_img_len + context_txt_len, N, dim]
+            context = ggml_nn_cont(ctx, ggml_torch_permute(ctx, context, 0, 2, 1, 3));  // [context_img_len + context_txt_len, N, dim]
             auto context_img = ggml_view_3d(ctx, context, dim, N, context_img_len, context->nb[1], context->nb[2], 0);
             auto context_txt = ggml_view_3d(ctx, context, dim, N, context_txt_len, context->nb[1], context->nb[2], context_txt_len * context->nb[2]);
-            context_img = ggml_cont(ctx, ggml_torch_permute(ctx, context_img, 0, 2, 1, 3));     // [N, context_img_len, dim]
-            context_txt = ggml_cont(ctx, ggml_torch_permute(ctx, context_txt, 0, 2, 1, 3));     // [N, context_txt_len, dim]
+            context_img = ggml_nn_cont(ctx, ggml_torch_permute(ctx, context_img, 0, 2, 1, 3));  // [N, context_img_len, dim]
+            context_txt = ggml_nn_cont(ctx, ggml_torch_permute(ctx, context_txt, 0, 2, 1, 3));  // [N, context_txt_len, dim]
 
             auto q = q_proj->forward(ctx, x);
             q      = norm_q->forward(ctx, q);
@@ -1294,13 +1294,13 @@ namespace WAN {
             GGML_ASSERT(C * pt * ph * pw == x->ne[0]);
 
             x = ggml_reshape_4d(ctx, x, C, pw * ph * pt, w_len * h_len * t_len, N);  // [N, t_len*h_len*w_len, pt*ph*pw, C]
-            x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 1, 2, 0, 3));     // [N, C, t_len*h_len*w_len, pt*ph*pw]
+            x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 1, 2, 0, 3));  // [N, C, t_len*h_len*w_len, pt*ph*pw]
             x = ggml_reshape_4d(ctx, x, pw, ph * pt, w_len, h_len * t_len * C * N);  // [N*C*t_len*h_len, w_len, pt*ph, pw]
-            x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3));     // [N*C*t_len*h_len, pt*ph, w_len, pw]
+            x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3));  // [N*C*t_len*h_len, pt*ph, w_len, pw]
             x = ggml_reshape_4d(ctx, x, pw * w_len, ph, pt, h_len * t_len * C * N);  // [N*C*t_len*h_len, pt, ph, w_len*pw]
-            x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3));     // [N*C*t_len*h_len, ph, pt, w_len*pw]
+            x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3));  // [N*C*t_len*h_len, ph, pt, w_len*pw]
             x = ggml_reshape_4d(ctx, x, pw * w_len, pt, ph * h_len, t_len * C * N);  // [N*C*t_len, h_len*ph, pt, w_len*pw]
-            x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3));     // [N*C*t_len, pt, h_len*ph, w_len*pw]
+            x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3));  // [N*C*t_len, pt, h_len*ph, w_len*pw]
             x = ggml_reshape_4d(ctx, x, pw * w_len, ph * h_len, pt * t_len, C * N);  // [N*C*t_len, h_len*ph, pt, w_len*pw]
             return x;
         }
@@ -1331,7 +1331,7 @@ namespace WAN {
         // patch_embedding
         x = patch_embedding->forward(ctx, x);  // [N*dim, t_len, h_len, w_len]
         x = ggml_reshape_3d(ctx, x, x->ne[0] * x->ne[1] * x->ne[2], x->ne[3] / N, N);  // [N, dim, t_len*h_len*w_len]
-        x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 1, 0, 2, 3));     // [N, t_len*h_len*w_len, dim]
+        x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 1, 0, 2, 3));  // [N, t_len*h_len*w_len, dim]
 
         // time_embedding
         auto e = ggml_nn_timestep_embedding(ctx, timestep, params.freq_dim);