mirror of https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-13 05:48:56 +00:00

commit 73f76e6d96 (parent 00f790d0e9)
make wan a little faster
ggml_extend.hpp

@@ -608,6 +608,14 @@ __STATIC_INLINE__ void ggml_tensor_scale_output(struct ggml_tensor* src) {
     }
 }
 
+__STATIC_INLINE__ struct ggml_tensor* ggml_nn_cont(struct ggml_context* ctx,
+                                                   struct ggml_tensor* x) {
+    if (ggml_is_contiguous(x)) {
+        return x;
+    }
+    return ggml_cont(ctx, x);
+}
+
 // torch like permute
 __STATIC_INLINE__ struct ggml_tensor* ggml_torch_permute(struct ggml_context* ctx,
                                                          struct ggml_tensor* x,
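The new ggml_nn_cont helper above is the heart of the speedup: the old code called ggml_cont unconditionally, which always appends a copy node to the compute graph even when the tensor is already contiguous. A minimal sketch of the difference, assuming standard ggml semantics (shapes are made up):

    // Reshaping a contiguous tensor yields a contiguous view, so
    // ggml_nn_cont returns it unchanged and adds no copy node;
    // plain ggml_cont would have copied it anyway.
    struct ggml_tensor* t = ggml_reshape_3d(ctx, x, 64, 128, 8);
    t = ggml_nn_cont(ctx, t);  // already contiguous -> no-op here
    // A genuine permute produces a non-contiguous view, so both
    // helpers fall through to a real copy:
    t = ggml_nn_cont(ctx, ggml_permute(ctx, t, 1, 0, 2, 3));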
@@ -799,7 +807,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_linear(struct ggml_context* ctx,
                                                      struct ggml_tensor* b) {
     x = ggml_mul_mat(ctx, w, x);
     if (b != NULL) {
-        x = ggml_add(ctx, x, b);
+        x = ggml_add_inplace(ctx, x, b);
     }
     return x;
 }
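The other recurring change is swapping out-of-place ggml_add/ggml_mul for their _inplace variants. The in-place form writes the result back into the first operand's buffer instead of allocating a new tensor, which is safe here because the matmul output is a fresh intermediate nothing else reads. A rough sketch, assuming standard ggml semantics:

    struct ggml_tensor* y = ggml_mul_mat(ctx, w, x);  // fresh intermediate
    y = ggml_add_inplace(ctx, y, b);  // result reuses y's buffer, no new allocation
    // Avoid the in-place form when the destination tensor is still needed
    // elsewhere in the graph, e.g. a residual branch read after the add.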
@@ -822,7 +830,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_2d(struct ggml_context* ctx,
     if (b != NULL) {
         b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
         // b = ggml_repeat(ctx, b, x);
-        x = ggml_add(ctx, x, b);
+        x = ggml_add_inplace(ctx, x, b);
     }
     return x;
 }
@@ -851,45 +859,11 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_3d(struct ggml_context* ctx,
 
     if (b != NULL) {
         b = ggml_reshape_4d(ctx, b, 1, 1, 1, b->ne[0]);  // [OC, 1, 1, 1]
-        x = ggml_add(ctx, x, b);
+        x = ggml_add_inplace(ctx, x, b);
     }
     return x;
 }
 
-// w: [OC,IC, KD, 1 * 1]
-// x: [N, IC, IH, IW]
-// b: [OC,]
-// result: [N, OC, OH, OW]
-__STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_3d_nx1x1_bak(struct ggml_context* ctx,
-                                                                struct ggml_tensor* x,
-                                                                struct ggml_tensor* w,
-                                                                struct ggml_tensor* b,
-                                                                int s2 = 1,
-                                                                int p2 = 1,
-                                                                int d2 = 1) {
-    GGML_ASSERT(w->ne[0] == 1);
-    // timesteps = x.shape[0]
-    // x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
-    // x = conv3d(x)
-    // return rearrange(x, "b c t h w -> (b t) c h w")
-    int64_t T = x->ne[3];
-    int64_t B = x->ne[3] / T;
-    int64_t C = x->ne[2];
-    int64_t H = x->ne[1];
-    int64_t W = x->ne[0];
-
-    x = ggml_reshape_4d(ctx, x, W * H, C, T, B);           // (b t) c h w -> b t c (h w)
-    x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));  // b t c (h w) -> b c t (h w)
-    x = ggml_conv_2d(ctx, w, x, 1, s2, 0, p2, 1, d2);      // [B, OC, T, OH * OW]
-    if (b != NULL) {
-        b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
-        x = ggml_add(ctx, x, b);
-    }
-    x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));  // b c t (h w) -> b t c (h w)
-    x = ggml_reshape_4d(ctx, x, W, H, C, T * B);           // b t c (h w) -> (b t) c h w
-    return x;  // [B*T, OC, OH, OW]
-}
-
 // w: [OC,IC, KD, 1 * 1]
 // x: [N, IC, ID, IH*IW]
 // b: [OC,]
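In the conv helpers above, the bias is first reshaped so its single non-unit dimension lines up with the output-channel axis; ggml_add then broadcasts the unit dimensions across the output, and the in-place variant appears to broadcast the same way, which this commit relies on. A small illustration under those assumptions:

    // conv output x: [N, OC, OH, OW], in ggml order ne = {OW, OH, OC, N};
    // b: [OC] -> [1, 1, OC, 1] so it broadcasts over OW, OH and N.
    b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
    x = ggml_add_inplace(ctx, x, b);  // per-channel bias, no ggml_repeat needed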
@@ -991,13 +965,13 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     C      = q->ne[0];
     N      = q->ne[2];
     d_head = C / n_head;
-    q = ggml_reshape_4d(ctx, q, d_head, n_head, L_q, N);   // [N, L_q, n_head, d_head]
-    q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3));  // [N, n_head, L_q, d_head]
-    q = ggml_reshape_3d(ctx, q, d_head, L_q, n_head * N);  // [N * n_head, L_q, d_head]
+    q = ggml_reshape_4d(ctx, q, d_head, n_head, L_q, N);      // [N, L_q, n_head, d_head]
+    q = ggml_nn_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3));  // [N, n_head, L_q, d_head]
+    q = ggml_reshape_3d(ctx, q, d_head, L_q, n_head * N);     // [N * n_head, L_q, d_head]
 
-    k = ggml_reshape_4d(ctx, k, d_head, n_head, L_k, N);   // [N, L_k, n_head, d_head]
-    k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3));  // [N, n_head, L_k, d_head]
-    k = ggml_reshape_3d(ctx, k, d_head, L_k, n_head * N);  // [N * n_head, L_k, d_head]
+    k = ggml_reshape_4d(ctx, k, d_head, n_head, L_k, N);      // [N, L_k, n_head, d_head]
+    k = ggml_nn_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3));  // [N, n_head, L_k, d_head]
+    k = ggml_reshape_3d(ctx, k, d_head, L_k, n_head * N);     // [N * n_head, L_k, d_head]
 
     v = ggml_reshape_4d(ctx, v, d_head, n_head, L_k, N);  // [N, L_k, n_head, d_head]
 } else {
@@ -1047,14 +1021,14 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
             // LOG_DEBUG(" padding k and v dim1 by %d", kv_pad);
             k = ggml_pad(ctx, k, 0, kv_pad, 0, 0);
         }
-        k = ggml_cast(ctx, k, GGML_TYPE_F16);
+        // k = ggml_cast(ctx, k, GGML_TYPE_F16);
 
-        v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));  // [N, n_head, L_k, d_head]
-        v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N);  // [N * n_head, L_k, d_head]
+        v = ggml_nn_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));  // [N, n_head, L_k, d_head]
+        v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N);     // [N * n_head, L_k, d_head]
         if (kv_pad != 0) {
             v = ggml_pad(ctx, v, 0, kv_pad, 0, 0);
         }
-        v = ggml_cast(ctx, v, GGML_TYPE_F16);
+        // v = ggml_cast(ctx, v, GGML_TYPE_F16);
 
         if (mask != nullptr) {
             mask = ggml_transpose(ctx, mask);
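Commenting out the two F16 casts above means k and v now reach the flash-attention path in their original type. Presumably the ggml flash-attention kernels in use accept non-F16 K/V directly, so the explicit casts only added graph nodes and memory traffic; if a backend still requires F16 inputs here, these lines would have to return.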
@@ -1074,8 +1048,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
         // kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_k, kqv->nb[1], kqv->nb[2], 0);
         kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_q, kqv->nb[1], kqv->nb[2], 0);
     } else {
-        v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3));  // [N, n_head, d_head, L_k]
-        v = ggml_reshape_3d(ctx, v, L_k, d_head, n_head * N);  // [N * n_head, d_head, L_k]
+        v = ggml_nn_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3));  // [N, n_head, d_head, L_k]
+        v = ggml_reshape_3d(ctx, v, L_k, d_head, n_head * N);     // [N * n_head, d_head, L_k]
 
         auto kq = ggml_mul_mat(ctx, k, q);  // [N * n_head, L_q, L_k]
         kq      = ggml_scale_inplace(ctx, kq, scale);
@@ -1093,7 +1067,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
         kqv = ggml_permute(ctx, kqv, 0, 2, 1, 3);  // [N, L_q, n_head, d_head]
     }
 
-    kqv = ggml_cont(ctx, kqv);
+    kqv = ggml_nn_cont(ctx, kqv);
     kqv = ggml_reshape_3d(ctx, kqv, d_head * n_head, L_q, N);  // [N, L_q, C]
 
     return kqv;
@@ -1106,9 +1080,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_layer_norm(struct ggml_context* ct
                                                          float eps = EPS) {
     x = ggml_norm(ctx, x, eps);
     if (w != NULL) {
-        x = ggml_mul(ctx, x, w);
+        x = ggml_mul_inplace(ctx, x, w);
         if (b != NULL) {
-            x = ggml_add(ctx, x, b);
+            x = ggml_add_inplace(ctx, x, b);
         }
     }
     return x;
@@ -1127,9 +1101,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_group_norm(struct ggml_context* ct
     const float eps = 1e-6f;  // default eps parameter
     x = ggml_group_norm(ctx, x, num_groups, eps);
     if (w != NULL && b != NULL) {
-        x = ggml_mul(ctx, x, w);
+        x = ggml_mul_inplace(ctx, x, w);
         // b = ggml_repeat(ctx, b, x);
-        x = ggml_add(ctx, x, b);
+        x = ggml_add_inplace(ctx, x, b);
     }
     return x;
 }
@@ -1874,7 +1848,7 @@ public:
     struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
         struct ggml_tensor* w = params["weight"];
         x = ggml_rms_norm(ctx, x, eps);
-        x = ggml_mul(ctx, x, w);
+        x = ggml_mul_inplace(ctx, x, w);
         return x;
     }
 };
stable-diffusion.cpp

@@ -882,8 +882,6 @@ public:
     float img_cfg_scale = guidance.img_cfg;
     float slg_scale     = guidance.slg.scale;
 
-    LOG_DEBUG("cfg_scale %.2f", cfg_scale);
-
     if (img_cfg_scale != cfg_scale && !sd_version_is_inpaint_or_unet_edit(version)) {
         LOG_WARN("2-conditioning CFG is not supported with this model, disabling it for better performance...");
         img_cfg_scale = cfg_scale;
@@ -1215,7 +1213,6 @@ public:
 
     int64_t t0 = ggml_time_ms();
     if (!use_tiny_autoencoder) {
-        LOG_DEBUG("scale_factor %.2f", scale_factor);
         process_latent_out(x);
         if (vae_tiling && !decode_video) {
             // split latent in 32x32 tiles and compute in several steps
wan.hpp (46 changed lines)
@@ -99,10 +99,10 @@ namespace WAN {
         // assert N == 1
 
         struct ggml_tensor* w = params["gamma"];
-        auto h = ggml_cont(ctx, ggml_torch_permute(ctx, x, 3, 0, 1, 2));  // [ID, IH, IW, N*IC]
+        auto h = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 3, 0, 1, 2));  // [ID, IH, IW, N*IC]
         h = ggml_rms_norm(ctx, h, 1e-12);
         h = ggml_mul(ctx, h, w);
-        h = ggml_cont(ctx, ggml_torch_permute(ctx, h, 1, 2, 3, 0));
+        h = ggml_nn_cont(ctx, ggml_torch_permute(ctx, h, 1, 2, 3, 0));
 
         return h;
     }
@@ -175,9 +175,9 @@ namespace WAN {
                     }
                     feat_cache[idx] = cache_x;
                     feat_idx += 1;
-                    x = ggml_reshape_4d(ctx, x, w * h, t, c, 2);                 // (2, c, t, h*w)
-                    x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 3, 1, 2));  // (c, t, 2, h*w)
-                    x = ggml_reshape_4d(ctx, x, w, h, 2 * t, c);                 // (c, t*2, h, w)
+                    x = ggml_reshape_4d(ctx, x, w * h, t, c, 2);                    // (2, c, t, h*w)
+                    x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 3, 1, 2));  // (c, t, 2, h*w)
+                    x = ggml_reshape_4d(ctx, x, w, h, 2 * t, c);                    // (c, t*2, h, w)
                 }
             }
         }
@@ -186,7 +186,7 @@ namespace WAN {
         if (mode != "none") {
             auto resample_1 = std::dynamic_pointer_cast<Conv2d>(blocks["resample.1"]);
 
-            x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2));  // (t, c, h, w)
+            x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2));  // (t, c, h, w)
             if (mode == "upsample2d") {
                 x = ggml_upscale(ctx, x, 2, GGML_SCALE_MODE_NEAREST);
             } else if (mode == "upsample3d") {
@@ -197,7 +197,7 @@ namespace WAN {
                 x = ggml_pad(ctx, x, 1, 1, 0, 0);
             }
             x = resample_1->forward(ctx, x);
-            x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2));  // (c, t, h, w)
+            x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2));  // (c, t, h, w)
         }
 
         if (mode == "downsample3d") {
@@ -318,7 +318,7 @@ namespace WAN {
 
         x = norm->forward(ctx, x);
 
-        x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2));  // (t, c, h, w)
+        x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2));  // (t, c, h, w)
 
         const int64_t n = x->ne[3];
         const int64_t c = x->ne[2];
@@ -329,24 +329,24 @@ namespace WAN {
         auto qkv_vec = split_image_qkv(ctx, qkv);
 
         auto q = qkv_vec[0];
-        q = ggml_cont(ctx, ggml_torch_permute(ctx, q, 2, 0, 1, 3));  // [t, h, w, c]
-        q = ggml_reshape_3d(ctx, q, c, h * w, n);                    // [t, h * w, c]
+        q = ggml_nn_cont(ctx, ggml_torch_permute(ctx, q, 2, 0, 1, 3));  // [t, h, w, c]
+        q = ggml_reshape_3d(ctx, q, c, h * w, n);                       // [t, h * w, c]
 
         auto k = qkv_vec[1];
-        k = ggml_cont(ctx, ggml_torch_permute(ctx, k, 2, 0, 1, 3));  // [t, h, w, c]
-        k = ggml_reshape_3d(ctx, k, c, h * w, n);                    // [t, h * w, c]
+        k = ggml_nn_cont(ctx, ggml_torch_permute(ctx, k, 2, 0, 1, 3));  // [t, h, w, c]
+        k = ggml_reshape_3d(ctx, k, c, h * w, n);                       // [t, h * w, c]
 
         auto v = qkv_vec[2];
         v = ggml_reshape_3d(ctx, v, h * w, c, n);  // [t, c, h * w]
 
         x = ggml_nn_attention(ctx, q, k, v, false);  // [t, h * w, c]
 
-        x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3));  // [t, c, h * w]
-        x = ggml_reshape_4d(ctx, x, w, h, c, n);               // [t, c, h, w]
+        x = ggml_nn_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3));  // [t, c, h * w]
+        x = ggml_reshape_4d(ctx, x, w, h, c, n);                  // [t, c, h, w]
 
         x = proj->forward(ctx, x);
 
-        x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2));  // (c, t, h, w)
+        x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2));  // (c, t, h, w)
 
         x = ggml_add(ctx, x, identity);
         return x;
@@ -987,11 +987,11 @@ namespace WAN {
         int64_t dim             = x->ne[2];
         int64_t context_txt_len = context->ne[1] - context_img_len;
 
-        context = ggml_cont(ctx, ggml_torch_permute(ctx, context, 0, 2, 1, 3));  // [context_img_len + context_txt_len, N, dim]
+        context = ggml_nn_cont(ctx, ggml_torch_permute(ctx, context, 0, 2, 1, 3));  // [context_img_len + context_txt_len, N, dim]
         auto context_img = ggml_view_3d(ctx, context, dim, N, context_img_len, context->nb[1], context->nb[2], 0);
         auto context_txt = ggml_view_3d(ctx, context, dim, N, context_txt_len, context->nb[1], context->nb[2], context_txt_len * context->nb[2]);
-        context_img = ggml_cont(ctx, ggml_torch_permute(ctx, context_img, 0, 2, 1, 3));  // [N, context_img_len, dim]
-        context_txt = ggml_cont(ctx, ggml_torch_permute(ctx, context_txt, 0, 2, 1, 3));  // [N, context_txt_len, dim]
+        context_img = ggml_nn_cont(ctx, ggml_torch_permute(ctx, context_img, 0, 2, 1, 3));  // [N, context_img_len, dim]
+        context_txt = ggml_nn_cont(ctx, ggml_torch_permute(ctx, context_txt, 0, 2, 1, 3));  // [N, context_txt_len, dim]
 
         auto q = q_proj->forward(ctx, x);
         q      = norm_q->forward(ctx, q);
@@ -1294,13 +1294,13 @@ namespace WAN {
         GGML_ASSERT(C * pt * ph * pw == x->ne[0]);
 
         x = ggml_reshape_4d(ctx, x, C, pw * ph * pt, w_len * h_len * t_len, N);  // [N, t_len*h_len*w_len, pt*ph*pw, C]
-        x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 1, 2, 0, 3));     // [N, C, t_len*h_len*w_len, pt*ph*pw]
+        x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 1, 2, 0, 3));  // [N, C, t_len*h_len*w_len, pt*ph*pw]
         x = ggml_reshape_4d(ctx, x, pw, ph * pt, w_len, h_len * t_len * C * N);  // [N*C*t_len*h_len, w_len, pt*ph, pw]
-        x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3));     // [N*C*t_len*h_len, pt*ph, w_len, pw]
+        x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3));  // [N*C*t_len*h_len, pt*ph, w_len, pw]
         x = ggml_reshape_4d(ctx, x, pw * w_len, ph, pt, h_len * t_len * C * N);  // [N*C*t_len*h_len, pt, ph, w_len*pw]
-        x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3));     // [N*C*t_len*h_len, ph, pt, w_len*pw]
+        x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3));  // [N*C*t_len*h_len, ph, pt, w_len*pw]
         x = ggml_reshape_4d(ctx, x, pw * w_len, pt, ph * h_len, t_len * C * N);  // [N*C*t_len, h_len*ph, pt, w_len*pw]
-        x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3));     // [N*C*t_len, pt, h_len*ph, w_len*pw]
+        x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3));  // [N*C*t_len, pt, h_len*ph, w_len*pw]
         x = ggml_reshape_4d(ctx, x, pw * w_len, ph * h_len, pt * t_len, C * N);  // [N*C*t_len, h_len*ph, pt, w_len*pw]
         return x;
     }
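A note on why wan.hpp leans on ggml_torch_permute throughout: ggml stores dimensions innermost-first (ne[0] is the fastest-varying axis), the reverse of the PyTorch shape order the trailing comments are written in. A sketch of the correspondence, assuming the helper simply maps torch-style axis indices onto ggml's reversed order:

    // A torch tensor [N, C, H, W] is stored in ggml with ne = {W, H, C, N}.
    // Swapping the two outermost torch axes therefore means swapping
    // ne[2] and ne[3], which plain ggml_permute writes as:
    x = ggml_nn_cont(ctx, ggml_permute(ctx, x, 0, 1, 3, 2));  // swap ne[2] <-> ne[3]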
@@ -1331,7 +1331,7 @@ namespace WAN {
         // patch_embedding
         x = patch_embedding->forward(ctx, x);  // [N*dim, t_len, h_len, w_len]
         x = ggml_reshape_3d(ctx, x, x->ne[0] * x->ne[1] * x->ne[2], x->ne[3] / N, N);  // [N, dim, t_len*h_len*w_len]
-        x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 1, 0, 2, 3));     // [N, t_len*h_len*w_len, dim]
+        x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 1, 0, 2, 3));  // [N, t_len*h_len*w_len, dim]
 
         // time_embedding
         auto e = ggml_nn_timestep_embedding(ctx, timestep, params.freq_dim);