diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 796ae33..13aa7e3 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -608,6 +608,14 @@ __STATIC_INLINE__ void ggml_tensor_scale_output(struct ggml_tensor* src) { } } +__STATIC_INLINE__ struct ggml_tensor* ggml_nn_cont(struct ggml_context* ctx, + struct ggml_tensor* x) { + if (ggml_is_contiguous(x)) { + return x; + } + return ggml_cont(ctx, x); +} + // torch like permute __STATIC_INLINE__ struct ggml_tensor* ggml_torch_permute(struct ggml_context* ctx, struct ggml_tensor* x, @@ -799,7 +807,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_linear(struct ggml_context* ctx, struct ggml_tensor* b) { x = ggml_mul_mat(ctx, w, x); if (b != NULL) { - x = ggml_add(ctx, x, b); + x = ggml_add_inplace(ctx, x, b); } return x; } @@ -822,7 +830,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_2d(struct ggml_context* ctx, if (b != NULL) { b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1); // b = ggml_repeat(ctx, b, x); - x = ggml_add(ctx, x, b); + x = ggml_add_inplace(ctx, x, b); } return x; } @@ -851,45 +859,11 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_3d(struct ggml_context* ctx, if (b != NULL) { b = ggml_reshape_4d(ctx, b, 1, 1, 1, b->ne[0]); // [OC, 1, 1, 1] - x = ggml_add(ctx, x, b); + x = ggml_add_inplace(ctx, x, b); } return x; } -// w: [OC,IC, KD, 1 * 1] -// x: [N, IC, IH, IW] -// b: [OC,] -// result: [N, OC, OH, OW] -__STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_3d_nx1x1_bak(struct ggml_context* ctx, - struct ggml_tensor* x, - struct ggml_tensor* w, - struct ggml_tensor* b, - int s2 = 1, - int p2 = 1, - int d2 = 1) { - GGML_ASSERT(w->ne[0] == 1); - // timesteps = x.shape[0] - // x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) - // x = conv3d(x) - // return rearrange(x, "b c t h w -> (b t) c h w") - int64_t T = x->ne[3]; - int64_t B = x->ne[3] / T; - int64_t C = x->ne[2]; - int64_t H = x->ne[1]; - int64_t W = x->ne[0]; - - x = ggml_reshape_4d(ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w) - x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w) - x = ggml_conv_2d(ctx, w, x, 1, s2, 0, p2, 1, d2); // [B, OC, T, OH * OW] - if (b != NULL) { - b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1); - x = ggml_add(ctx, x, b); - } - x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w) - x = ggml_reshape_4d(ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w - return x; // [B*T, OC, OH, OW] -} - // w: [OC,IC, KD, 1 * 1] // x: [N, IC, ID, IH*IW] // b: [OC,] @@ -991,13 +965,13 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* C = q->ne[0]; N = q->ne[2]; d_head = C / n_head; - q = ggml_reshape_4d(ctx, q, d_head, n_head, L_q, N); // [N, L_q, n_head, d_head] - q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, L_q, d_head] - q = ggml_reshape_3d(ctx, q, d_head, L_q, n_head * N); // [N * n_head, L_q, d_head] + q = ggml_reshape_4d(ctx, q, d_head, n_head, L_q, N); // [N, L_q, n_head, d_head] + q = ggml_nn_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, L_q, d_head] + q = ggml_reshape_3d(ctx, q, d_head, L_q, n_head * N); // [N * n_head, L_q, d_head] - k = ggml_reshape_4d(ctx, k, d_head, n_head, L_k, N); // [N, L_k, n_head, d_head] - k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, L_k, d_head] - k = ggml_reshape_3d(ctx, k, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head] + k = ggml_reshape_4d(ctx, k, d_head, n_head, L_k, N); // [N, L_k, n_head, d_head] + k = ggml_nn_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, L_k, d_head] + k = ggml_reshape_3d(ctx, k, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head] v = ggml_reshape_4d(ctx, v, d_head, n_head, L_k, N); // [N, L_k, n_head, d_head] } else { @@ -1047,14 +1021,14 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* // LOG_DEBUG(" padding k and v dim1 by %d", kv_pad); k = ggml_pad(ctx, k, 0, kv_pad, 0, 0); } - k = ggml_cast(ctx, k, GGML_TYPE_F16); + // k = ggml_cast(ctx, k, GGML_TYPE_F16); - v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3)); // [N, n_head, L_k, d_head] - v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head] + v = ggml_nn_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3)); // [N, n_head, L_k, d_head] + v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head] if (kv_pad != 0) { v = ggml_pad(ctx, v, 0, kv_pad, 0, 0); } - v = ggml_cast(ctx, v, GGML_TYPE_F16); + // v = ggml_cast(ctx, v, GGML_TYPE_F16); if (mask != nullptr) { mask = ggml_transpose(ctx, mask); @@ -1074,8 +1048,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* // kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_k, kqv->nb[1], kqv->nb[2], 0); kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_q, kqv->nb[1], kqv->nb[2], 0); } else { - v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, L_k] - v = ggml_reshape_3d(ctx, v, L_k, d_head, n_head * N); // [N * n_head, d_head, L_k] + v = ggml_nn_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, L_k] + v = ggml_reshape_3d(ctx, v, L_k, d_head, n_head * N); // [N * n_head, d_head, L_k] auto kq = ggml_mul_mat(ctx, k, q); // [N * n_head, L_q, L_k] kq = ggml_scale_inplace(ctx, kq, scale); @@ -1093,7 +1067,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* kqv = ggml_permute(ctx, kqv, 0, 2, 1, 3); // [N, L_q, n_head, d_head] } - kqv = ggml_cont(ctx, kqv); + kqv = ggml_nn_cont(ctx, kqv); kqv = ggml_reshape_3d(ctx, kqv, d_head * n_head, L_q, N); // [N, L_q, C] return kqv; @@ -1106,9 +1080,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_layer_norm(struct ggml_context* ct float eps = EPS) { x = ggml_norm(ctx, x, eps); if (w != NULL) { - x = ggml_mul(ctx, x, w); + x = ggml_mul_inplace(ctx, x, w); if (b != NULL) { - x = ggml_add(ctx, x, b); + x = ggml_add_inplace(ctx, x, b); } } return x; @@ -1127,9 +1101,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_group_norm(struct ggml_context* ct const float eps = 1e-6f; // default eps parameter x = ggml_group_norm(ctx, x, num_groups, eps); if (w != NULL && b != NULL) { - x = ggml_mul(ctx, x, w); + x = ggml_mul_inplace(ctx, x, w); // b = ggml_repeat(ctx, b, x); - x = ggml_add(ctx, x, b); + x = ggml_add_inplace(ctx, x, b); } return x; } @@ -1874,7 +1848,7 @@ public: struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { struct ggml_tensor* w = params["weight"]; x = ggml_rms_norm(ctx, x, eps); - x = ggml_mul(ctx, x, w); + x = ggml_mul_inplace(ctx, x, w); return x; } }; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 3ef767b..424fbad 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -882,8 +882,6 @@ public: float img_cfg_scale = guidance.img_cfg; float slg_scale = guidance.slg.scale; - LOG_DEBUG("cfg_scale %.2f", cfg_scale); - if (img_cfg_scale != cfg_scale && !sd_version_is_inpaint_or_unet_edit(version)) { LOG_WARN("2-conditioning CFG is not supported with this model, disabling it for better performance..."); img_cfg_scale = cfg_scale; @@ -1215,7 +1213,6 @@ public: int64_t t0 = ggml_time_ms(); if (!use_tiny_autoencoder) { - LOG_DEBUG("scale_factor %.2f", scale_factor); process_latent_out(x); if (vae_tiling && !decode_video) { // split latent in 32x32 tiles and compute in several steps diff --git a/wan.hpp b/wan.hpp index f031a23..763b774 100644 --- a/wan.hpp +++ b/wan.hpp @@ -99,10 +99,10 @@ namespace WAN { // assert N == 1 struct ggml_tensor* w = params["gamma"]; - auto h = ggml_cont(ctx, ggml_torch_permute(ctx, x, 3, 0, 1, 2)); // [ID, IH, IW, N*IC] + auto h = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 3, 0, 1, 2)); // [ID, IH, IW, N*IC] h = ggml_rms_norm(ctx, h, 1e-12); h = ggml_mul(ctx, h, w); - h = ggml_cont(ctx, ggml_torch_permute(ctx, h, 1, 2, 3, 0)); + h = ggml_nn_cont(ctx, ggml_torch_permute(ctx, h, 1, 2, 3, 0)); return h; } @@ -175,9 +175,9 @@ namespace WAN { } feat_cache[idx] = cache_x; feat_idx += 1; - x = ggml_reshape_4d(ctx, x, w * h, t, c, 2); // (2, c, t, h*w) - x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 3, 1, 2)); // (c, t, 2, h*w) - x = ggml_reshape_4d(ctx, x, w, h, 2 * t, c); // (c, t*2, h, w) + x = ggml_reshape_4d(ctx, x, w * h, t, c, 2); // (2, c, t, h*w) + x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 3, 1, 2)); // (c, t, 2, h*w) + x = ggml_reshape_4d(ctx, x, w, h, 2 * t, c); // (c, t*2, h, w) } } } @@ -186,7 +186,7 @@ namespace WAN { if (mode != "none") { auto resample_1 = std::dynamic_pointer_cast(blocks["resample.1"]); - x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2)); // (t, c, h, w) + x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2)); // (t, c, h, w) if (mode == "upsample2d") { x = ggml_upscale(ctx, x, 2, GGML_SCALE_MODE_NEAREST); } else if (mode == "upsample3d") { @@ -197,7 +197,7 @@ namespace WAN { x = ggml_pad(ctx, x, 1, 1, 0, 0); } x = resample_1->forward(ctx, x); - x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2)); // (c, t, h, w) + x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2)); // (c, t, h, w) } if (mode == "downsample3d") { @@ -318,7 +318,7 @@ namespace WAN { x = norm->forward(ctx, x); - x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2)); // (t, c, h, w) + x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2)); // (t, c, h, w) const int64_t n = x->ne[3]; const int64_t c = x->ne[2]; @@ -329,24 +329,24 @@ namespace WAN { auto qkv_vec = split_image_qkv(ctx, qkv); auto q = qkv_vec[0]; - q = ggml_cont(ctx, ggml_torch_permute(ctx, q, 2, 0, 1, 3)); // [t, h, w, c] - q = ggml_reshape_3d(ctx, q, c, h * w, n); // [t, h * w, c] + q = ggml_nn_cont(ctx, ggml_torch_permute(ctx, q, 2, 0, 1, 3)); // [t, h, w, c] + q = ggml_reshape_3d(ctx, q, c, h * w, n); // [t, h * w, c] auto k = qkv_vec[1]; - k = ggml_cont(ctx, ggml_torch_permute(ctx, k, 2, 0, 1, 3)); // [t, h, w, c] - k = ggml_reshape_3d(ctx, k, c, h * w, n); // [t, h * w, c] + k = ggml_nn_cont(ctx, ggml_torch_permute(ctx, k, 2, 0, 1, 3)); // [t, h, w, c] + k = ggml_reshape_3d(ctx, k, c, h * w, n); // [t, h * w, c] auto v = qkv_vec[2]; v = ggml_reshape_3d(ctx, v, h * w, c, n); // [t, c, h * w] x = ggml_nn_attention(ctx, q, k, v, false); // [t, h * w, c] - x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [t, c, h * w] - x = ggml_reshape_4d(ctx, x, w, h, c, n); // [t, c, h, w] + x = ggml_nn_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [t, c, h * w] + x = ggml_reshape_4d(ctx, x, w, h, c, n); // [t, c, h, w] x = proj->forward(ctx, x); - x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2)); // (c, t, h, w) + x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2)); // (c, t, h, w) x = ggml_add(ctx, x, identity); return x; @@ -987,11 +987,11 @@ namespace WAN { int64_t dim = x->ne[2]; int64_t context_txt_len = context->ne[1] - context_img_len; - context = ggml_cont(ctx, ggml_torch_permute(ctx, context, 0, 2, 1, 3)); // [context_img_len + context_txt_len, N, dim] + context = ggml_nn_cont(ctx, ggml_torch_permute(ctx, context, 0, 2, 1, 3)); // [context_img_len + context_txt_len, N, dim] auto context_img = ggml_view_3d(ctx, context, dim, N, context_img_len, context->nb[1], context->nb[2], 0); auto context_txt = ggml_view_3d(ctx, context, dim, N, context_txt_len, context->nb[1], context->nb[2], context_txt_len * context->nb[2]); - context_img = ggml_cont(ctx, ggml_torch_permute(ctx, context_img, 0, 2, 1, 3)); // [N, context_img_len, dim] - context_txt = ggml_cont(ctx, ggml_torch_permute(ctx, context_txt, 0, 2, 1, 3)); // [N, context_txt_len, dim] + context_img = ggml_nn_cont(ctx, ggml_torch_permute(ctx, context_img, 0, 2, 1, 3)); // [N, context_img_len, dim] + context_txt = ggml_nn_cont(ctx, ggml_torch_permute(ctx, context_txt, 0, 2, 1, 3)); // [N, context_txt_len, dim] auto q = q_proj->forward(ctx, x); q = norm_q->forward(ctx, q); @@ -1294,13 +1294,13 @@ namespace WAN { GGML_ASSERT(C * pt * ph * pw == x->ne[0]); x = ggml_reshape_4d(ctx, x, C, pw * ph * pt, w_len * h_len * t_len, N); // [N, t_len*h_len*w_len, pt*ph*pw, C] - x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 1, 2, 0, 3)); // [N, C, t_len*h_len*w_len, pt*ph*pw] + x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 1, 2, 0, 3)); // [N, C, t_len*h_len*w_len, pt*ph*pw] x = ggml_reshape_4d(ctx, x, pw, ph * pt, w_len, h_len * t_len * C * N); // [N*C*t_len*h_len, w_len, pt*ph, pw] - x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len*h_len, pt*ph, w_len, pw] + x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len*h_len, pt*ph, w_len, pw] x = ggml_reshape_4d(ctx, x, pw * w_len, ph, pt, h_len * t_len * C * N); // [N*C*t_len*h_len, pt, ph, w_len*pw] - x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len*h_len, ph, pt, w_len*pw] + x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len*h_len, ph, pt, w_len*pw] x = ggml_reshape_4d(ctx, x, pw * w_len, pt, ph * h_len, t_len * C * N); // [N*C*t_len, h_len*ph, pt, w_len*pw] - x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len, pt, h_len*ph, w_len*pw] + x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len, pt, h_len*ph, w_len*pw] x = ggml_reshape_4d(ctx, x, pw * w_len, ph * h_len, pt * t_len, C * N); // [N*C*t_len, h_len*ph, pt, w_len*pw] return x; } @@ -1331,7 +1331,7 @@ namespace WAN { // patch_embedding x = patch_embedding->forward(ctx, x); // [N*dim, t_len, h_len, w_len] x = ggml_reshape_3d(ctx, x, x->ne[0] * x->ne[1] * x->ne[2], x->ne[3] / N, N); // [N, dim, t_len*h_len*w_len] - x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 1, 0, 2, 3)); // [N, t_len*h_len*w_len, dim] + x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 1, 0, 2, 3)); // [N, t_len*h_len*w_len, dim] // time_embedding auto e = ggml_nn_timestep_embedding(ctx, timestep, params.freq_dim);