make wan a little faster

This commit is contained in:
leejet 2025-08-10 18:29:30 +08:00
parent 00f790d0e9
commit 73f76e6d96
3 changed files with 52 additions and 81 deletions

View File

@ -608,6 +608,14 @@ __STATIC_INLINE__ void ggml_tensor_scale_output(struct ggml_tensor* src) {
}
}
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_cont(struct ggml_context* ctx,
struct ggml_tensor* x) {
if (ggml_is_contiguous(x)) {
return x;
}
return ggml_cont(ctx, x);
}
// torch like permute
__STATIC_INLINE__ struct ggml_tensor* ggml_torch_permute(struct ggml_context* ctx,
struct ggml_tensor* x,
@ -799,7 +807,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_linear(struct ggml_context* ctx,
struct ggml_tensor* b) {
x = ggml_mul_mat(ctx, w, x);
if (b != NULL) {
x = ggml_add(ctx, x, b);
x = ggml_add_inplace(ctx, x, b);
}
return x;
}
@ -822,7 +830,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_2d(struct ggml_context* ctx,
if (b != NULL) {
b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
// b = ggml_repeat(ctx, b, x);
x = ggml_add(ctx, x, b);
x = ggml_add_inplace(ctx, x, b);
}
return x;
}
@ -851,45 +859,11 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_3d(struct ggml_context* ctx,
if (b != NULL) {
b = ggml_reshape_4d(ctx, b, 1, 1, 1, b->ne[0]); // [OC, 1, 1, 1]
x = ggml_add(ctx, x, b);
x = ggml_add_inplace(ctx, x, b);
}
return x;
}
// w: [OCIC, KD, 1 * 1]
// x: [N, IC, IH, IW]
// b: [OC,]
// result: [N, OC, OH, OW]
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_3d_nx1x1_bak(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* w,
struct ggml_tensor* b,
int s2 = 1,
int p2 = 1,
int d2 = 1) {
GGML_ASSERT(w->ne[0] == 1);
// timesteps = x.shape[0]
// x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
// x = conv3d(x)
// return rearrange(x, "b c t h w -> (b t) c h w")
int64_t T = x->ne[3];
int64_t B = x->ne[3] / T;
int64_t C = x->ne[2];
int64_t H = x->ne[1];
int64_t W = x->ne[0];
x = ggml_reshape_4d(ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w)
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w)
x = ggml_conv_2d(ctx, w, x, 1, s2, 0, p2, 1, d2); // [B, OC, T, OH * OW]
if (b != NULL) {
b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
x = ggml_add(ctx, x, b);
}
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
x = ggml_reshape_4d(ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w
return x; // [B*T, OC, OH, OW]
}
// w: [OCIC, KD, 1 * 1]
// x: [N, IC, ID, IH*IW]
// b: [OC,]
@ -991,13 +965,13 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
C = q->ne[0];
N = q->ne[2];
d_head = C / n_head;
q = ggml_reshape_4d(ctx, q, d_head, n_head, L_q, N); // [N, L_q, n_head, d_head]
q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, L_q, d_head]
q = ggml_reshape_3d(ctx, q, d_head, L_q, n_head * N); // [N * n_head, L_q, d_head]
q = ggml_reshape_4d(ctx, q, d_head, n_head, L_q, N); // [N, L_q, n_head, d_head]
q = ggml_nn_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, L_q, d_head]
q = ggml_reshape_3d(ctx, q, d_head, L_q, n_head * N); // [N * n_head, L_q, d_head]
k = ggml_reshape_4d(ctx, k, d_head, n_head, L_k, N); // [N, L_k, n_head, d_head]
k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, L_k, d_head]
k = ggml_reshape_3d(ctx, k, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head]
k = ggml_reshape_4d(ctx, k, d_head, n_head, L_k, N); // [N, L_k, n_head, d_head]
k = ggml_nn_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, L_k, d_head]
k = ggml_reshape_3d(ctx, k, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head]
v = ggml_reshape_4d(ctx, v, d_head, n_head, L_k, N); // [N, L_k, n_head, d_head]
} else {
@ -1047,14 +1021,14 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
// LOG_DEBUG(" padding k and v dim1 by %d", kv_pad);
k = ggml_pad(ctx, k, 0, kv_pad, 0, 0);
}
k = ggml_cast(ctx, k, GGML_TYPE_F16);
// k = ggml_cast(ctx, k, GGML_TYPE_F16);
v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3)); // [N, n_head, L_k, d_head]
v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head]
v = ggml_nn_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3)); // [N, n_head, L_k, d_head]
v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head]
if (kv_pad != 0) {
v = ggml_pad(ctx, v, 0, kv_pad, 0, 0);
}
v = ggml_cast(ctx, v, GGML_TYPE_F16);
// v = ggml_cast(ctx, v, GGML_TYPE_F16);
if (mask != nullptr) {
mask = ggml_transpose(ctx, mask);
@ -1074,8 +1048,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
// kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_k, kqv->nb[1], kqv->nb[2], 0);
kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_q, kqv->nb[1], kqv->nb[2], 0);
} else {
v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, L_k]
v = ggml_reshape_3d(ctx, v, L_k, d_head, n_head * N); // [N * n_head, d_head, L_k]
v = ggml_nn_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, L_k]
v = ggml_reshape_3d(ctx, v, L_k, d_head, n_head * N); // [N * n_head, d_head, L_k]
auto kq = ggml_mul_mat(ctx, k, q); // [N * n_head, L_q, L_k]
kq = ggml_scale_inplace(ctx, kq, scale);
@ -1093,7 +1067,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
kqv = ggml_permute(ctx, kqv, 0, 2, 1, 3); // [N, L_q, n_head, d_head]
}
kqv = ggml_cont(ctx, kqv);
kqv = ggml_nn_cont(ctx, kqv);
kqv = ggml_reshape_3d(ctx, kqv, d_head * n_head, L_q, N); // [N, L_q, C]
return kqv;
@ -1106,9 +1080,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_layer_norm(struct ggml_context* ct
float eps = EPS) {
x = ggml_norm(ctx, x, eps);
if (w != NULL) {
x = ggml_mul(ctx, x, w);
x = ggml_mul_inplace(ctx, x, w);
if (b != NULL) {
x = ggml_add(ctx, x, b);
x = ggml_add_inplace(ctx, x, b);
}
}
return x;
@ -1127,9 +1101,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_group_norm(struct ggml_context* ct
const float eps = 1e-6f; // default eps parameter
x = ggml_group_norm(ctx, x, num_groups, eps);
if (w != NULL && b != NULL) {
x = ggml_mul(ctx, x, w);
x = ggml_mul_inplace(ctx, x, w);
// b = ggml_repeat(ctx, b, x);
x = ggml_add(ctx, x, b);
x = ggml_add_inplace(ctx, x, b);
}
return x;
}
@ -1874,7 +1848,7 @@ public:
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* w = params["weight"];
x = ggml_rms_norm(ctx, x, eps);
x = ggml_mul(ctx, x, w);
x = ggml_mul_inplace(ctx, x, w);
return x;
}
};

View File

@ -882,8 +882,6 @@ public:
float img_cfg_scale = guidance.img_cfg;
float slg_scale = guidance.slg.scale;
LOG_DEBUG("cfg_scale %.2f", cfg_scale);
if (img_cfg_scale != cfg_scale && !sd_version_is_inpaint_or_unet_edit(version)) {
LOG_WARN("2-conditioning CFG is not supported with this model, disabling it for better performance...");
img_cfg_scale = cfg_scale;
@ -1215,7 +1213,6 @@ public:
int64_t t0 = ggml_time_ms();
if (!use_tiny_autoencoder) {
LOG_DEBUG("scale_factor %.2f", scale_factor);
process_latent_out(x);
if (vae_tiling && !decode_video) {
// split latent in 32x32 tiles and compute in several steps

46
wan.hpp
View File

@ -99,10 +99,10 @@ namespace WAN {
// assert N == 1
struct ggml_tensor* w = params["gamma"];
auto h = ggml_cont(ctx, ggml_torch_permute(ctx, x, 3, 0, 1, 2)); // [ID, IH, IW, N*IC]
auto h = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 3, 0, 1, 2)); // [ID, IH, IW, N*IC]
h = ggml_rms_norm(ctx, h, 1e-12);
h = ggml_mul(ctx, h, w);
h = ggml_cont(ctx, ggml_torch_permute(ctx, h, 1, 2, 3, 0));
h = ggml_nn_cont(ctx, ggml_torch_permute(ctx, h, 1, 2, 3, 0));
return h;
}
@ -175,9 +175,9 @@ namespace WAN {
}
feat_cache[idx] = cache_x;
feat_idx += 1;
x = ggml_reshape_4d(ctx, x, w * h, t, c, 2); // (2, c, t, h*w)
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 3, 1, 2)); // (c, t, 2, h*w)
x = ggml_reshape_4d(ctx, x, w, h, 2 * t, c); // (c, t*2, h, w)
x = ggml_reshape_4d(ctx, x, w * h, t, c, 2); // (2, c, t, h*w)
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 3, 1, 2)); // (c, t, 2, h*w)
x = ggml_reshape_4d(ctx, x, w, h, 2 * t, c); // (c, t*2, h, w)
}
}
}
@ -186,7 +186,7 @@ namespace WAN {
if (mode != "none") {
auto resample_1 = std::dynamic_pointer_cast<Conv2d>(blocks["resample.1"]);
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2)); // (t, c, h, w)
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2)); // (t, c, h, w)
if (mode == "upsample2d") {
x = ggml_upscale(ctx, x, 2, GGML_SCALE_MODE_NEAREST);
} else if (mode == "upsample3d") {
@ -197,7 +197,7 @@ namespace WAN {
x = ggml_pad(ctx, x, 1, 1, 0, 0);
}
x = resample_1->forward(ctx, x);
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2)); // (c, t, h, w)
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2)); // (c, t, h, w)
}
if (mode == "downsample3d") {
@ -318,7 +318,7 @@ namespace WAN {
x = norm->forward(ctx, x);
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2)); // (t, c, h, w)
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2)); // (t, c, h, w)
const int64_t n = x->ne[3];
const int64_t c = x->ne[2];
@ -329,24 +329,24 @@ namespace WAN {
auto qkv_vec = split_image_qkv(ctx, qkv);
auto q = qkv_vec[0];
q = ggml_cont(ctx, ggml_torch_permute(ctx, q, 2, 0, 1, 3)); // [t, h, w, c]
q = ggml_reshape_3d(ctx, q, c, h * w, n); // [t, h * w, c]
q = ggml_nn_cont(ctx, ggml_torch_permute(ctx, q, 2, 0, 1, 3)); // [t, h, w, c]
q = ggml_reshape_3d(ctx, q, c, h * w, n); // [t, h * w, c]
auto k = qkv_vec[1];
k = ggml_cont(ctx, ggml_torch_permute(ctx, k, 2, 0, 1, 3)); // [t, h, w, c]
k = ggml_reshape_3d(ctx, k, c, h * w, n); // [t, h * w, c]
k = ggml_nn_cont(ctx, ggml_torch_permute(ctx, k, 2, 0, 1, 3)); // [t, h, w, c]
k = ggml_reshape_3d(ctx, k, c, h * w, n); // [t, h * w, c]
auto v = qkv_vec[2];
v = ggml_reshape_3d(ctx, v, h * w, c, n); // [t, c, h * w]
x = ggml_nn_attention(ctx, q, k, v, false); // [t, h * w, c]
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [t, c, h * w]
x = ggml_reshape_4d(ctx, x, w, h, c, n); // [t, c, h, w]
x = ggml_nn_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [t, c, h * w]
x = ggml_reshape_4d(ctx, x, w, h, c, n); // [t, c, h, w]
x = proj->forward(ctx, x);
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2)); // (c, t, h, w)
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2)); // (c, t, h, w)
x = ggml_add(ctx, x, identity);
return x;
@ -987,11 +987,11 @@ namespace WAN {
int64_t dim = x->ne[2];
int64_t context_txt_len = context->ne[1] - context_img_len;
context = ggml_cont(ctx, ggml_torch_permute(ctx, context, 0, 2, 1, 3)); // [context_img_len + context_txt_len, N, dim]
context = ggml_nn_cont(ctx, ggml_torch_permute(ctx, context, 0, 2, 1, 3)); // [context_img_len + context_txt_len, N, dim]
auto context_img = ggml_view_3d(ctx, context, dim, N, context_img_len, context->nb[1], context->nb[2], 0);
auto context_txt = ggml_view_3d(ctx, context, dim, N, context_txt_len, context->nb[1], context->nb[2], context_txt_len * context->nb[2]);
context_img = ggml_cont(ctx, ggml_torch_permute(ctx, context_img, 0, 2, 1, 3)); // [N, context_img_len, dim]
context_txt = ggml_cont(ctx, ggml_torch_permute(ctx, context_txt, 0, 2, 1, 3)); // [N, context_txt_len, dim]
context_img = ggml_nn_cont(ctx, ggml_torch_permute(ctx, context_img, 0, 2, 1, 3)); // [N, context_img_len, dim]
context_txt = ggml_nn_cont(ctx, ggml_torch_permute(ctx, context_txt, 0, 2, 1, 3)); // [N, context_txt_len, dim]
auto q = q_proj->forward(ctx, x);
q = norm_q->forward(ctx, q);
@ -1294,13 +1294,13 @@ namespace WAN {
GGML_ASSERT(C * pt * ph * pw == x->ne[0]);
x = ggml_reshape_4d(ctx, x, C, pw * ph * pt, w_len * h_len * t_len, N); // [N, t_len*h_len*w_len, pt*ph*pw, C]
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 1, 2, 0, 3)); // [N, C, t_len*h_len*w_len, pt*ph*pw]
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 1, 2, 0, 3)); // [N, C, t_len*h_len*w_len, pt*ph*pw]
x = ggml_reshape_4d(ctx, x, pw, ph * pt, w_len, h_len * t_len * C * N); // [N*C*t_len*h_len, w_len, pt*ph, pw]
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len*h_len, pt*ph, w_len, pw]
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len*h_len, pt*ph, w_len, pw]
x = ggml_reshape_4d(ctx, x, pw * w_len, ph, pt, h_len * t_len * C * N); // [N*C*t_len*h_len, pt, ph, w_len*pw]
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len*h_len, ph, pt, w_len*pw]
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len*h_len, ph, pt, w_len*pw]
x = ggml_reshape_4d(ctx, x, pw * w_len, pt, ph * h_len, t_len * C * N); // [N*C*t_len, h_len*ph, pt, w_len*pw]
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len, pt, h_len*ph, w_len*pw]
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len, pt, h_len*ph, w_len*pw]
x = ggml_reshape_4d(ctx, x, pw * w_len, ph * h_len, pt * t_len, C * N); // [N*C*t_len, h_len*ph, pt, w_len*pw]
return x;
}
@ -1331,7 +1331,7 @@ namespace WAN {
// patch_embedding
x = patch_embedding->forward(ctx, x); // [N*dim, t_len, h_len, w_len]
x = ggml_reshape_3d(ctx, x, x->ne[0] * x->ne[1] * x->ne[2], x->ne[3] / N, N); // [N, dim, t_len*h_len*w_len]
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 1, 0, 2, 3)); // [N, t_len*h_len*w_len, dim]
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 1, 0, 2, 3)); // [N, t_len*h_len*w_len, dim]
// time_embedding
auto e = ggml_nn_timestep_embedding(ctx, timestep, params.freq_dim);