diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index 86fefc4..0965784 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -1307,6 +1307,9 @@ protected:
     ggml_backend_buffer_t runtime_params_buffer = NULL;
     bool params_on_runtime_backend = false;
 
+    struct ggml_context* cache_ctx = NULL;
+    ggml_backend_buffer_t cache_buffer = NULL;
+
     struct ggml_context* compute_ctx = NULL;
     struct ggml_gallocr* compute_allocr = NULL;
@@ -1314,6 +1317,8 @@ protected:
     ggml_tensor* one_tensor = NULL;
     std::map<struct ggml_tensor*, const void*> backend_tensor_data_map;
+    std::map<std::string, struct ggml_tensor*> cache_tensor_map;  // name -> tensor
+    const std::string final_result_name = "ggml_runner_final_result_tensor";
 
     void alloc_params_ctx() {
         struct ggml_init_params params;
@@ -1340,6 +1345,23 @@
         }
     }
 
+    void alloc_cache_ctx() {
+        struct ggml_init_params params;
+        params.mem_size = static_cast<size_t>(MAX_PARAMS_TENSOR_NUM * ggml_tensor_overhead());
+        params.mem_buffer = NULL;
+        params.no_alloc = true;
+
+        cache_ctx = ggml_init(params);
+        GGML_ASSERT(cache_ctx != NULL);
+    }
+
+    void free_cache_ctx() {
+        if (cache_ctx != NULL) {
+            ggml_free(cache_ctx);
+            cache_ctx = NULL;
+        }
+    }
+
     void alloc_compute_ctx() {
         struct ggml_init_params params;
         params.mem_size = static_cast<size_t>(ggml_tensor_overhead() * MAX_GRAPH_SIZE + ggml_graph_overhead());
@@ -1370,6 +1392,8 @@
     struct ggml_cgraph* get_compute_graph(get_graph_cb_t get_graph) {
         prepare_build_in_tensor_before();
         struct ggml_cgraph* gf = get_graph();
+        auto result = ggml_graph_node(gf, -1);
+        ggml_set_name(result, final_result_name.c_str());
         prepare_build_in_tensor_after(gf);
         return gf;
     }
@@ -1399,7 +1423,43 @@
         return true;
     }
 
-    void cpy_data_to_backend_tensor() {
+    void free_cache_buffer() {
+        if (cache_buffer != NULL) {
+            ggml_backend_buffer_free(cache_buffer);
+            cache_buffer = NULL;
+        }
+    }
+
+    void copy_cache_tensors_to_cache_buffer() {
+        if (cache_tensor_map.size() == 0) {
+            return;
+        }
+        free_cache_ctx_and_buffer();
+        alloc_cache_ctx();
+        GGML_ASSERT(cache_buffer == NULL);
+        std::map<ggml_tensor*, ggml_tensor*> runtime_tensor_to_cache_tensor;
+        for (auto kv : cache_tensor_map) {
+            auto cache_tensor = ggml_dup_tensor(cache_ctx, kv.second);
+            ggml_set_name(cache_tensor, kv.first.c_str());
+            runtime_tensor_to_cache_tensor[kv.second] = cache_tensor;
+        }
+        size_t num_tensors = ggml_tensor_num(cache_ctx);
+        cache_buffer = ggml_backend_alloc_ctx_tensors(cache_ctx, runtime_backend);
+        GGML_ASSERT(cache_buffer != NULL);
+        for (auto kv : runtime_tensor_to_cache_tensor) {
+            ggml_backend_tensor_copy(kv.first, kv.second);
+        }
+        ggml_backend_synchronize(runtime_backend);
+        cache_tensor_map.clear();
+        size_t cache_buffer_size = ggml_backend_buffer_get_size(cache_buffer);
+        LOG_DEBUG("%s cache backend buffer size = % 6.2f MB(%s) (%i tensors)",
+                  get_desc().c_str(),
+                  cache_buffer_size / (1024.f * 1024.f),
+                  ggml_backend_is_cpu(runtime_backend) ? "RAM" : "VRAM",
+                  num_tensors);
+    }
+
+    void copy_data_to_backend_tensor() {
         for (auto& kv : backend_tensor_data_map) {
             auto tensor = kv.first;
             auto data = kv.second;
@@ -1510,6 +1570,7 @@ public:
         if (params_backend != runtime_backend) {
             ggml_backend_free(params_backend);
         }
+        free_cache_ctx_and_buffer();
     }
 
     void reset_compute_ctx() {
@@ -1549,6 +1610,11 @@ public:
         return 0;
     }
 
+    void free_cache_ctx_and_buffer() {
+        free_cache_buffer();
+        free_cache_ctx();
+    }
+
     void free_compute_buffer() {
         if (compute_allocr != NULL) {
             ggml_gallocr_free(compute_allocr);
@@ -1579,6 +1645,17 @@ public:
         }
     }
 
+    void cache(const std::string name, struct ggml_tensor* tensor) {
+        cache_tensor_map[name] = tensor;
+    }
+
+    struct ggml_tensor* get_cache_tensor_by_name(const std::string& name) {
+        if (cache_ctx == NULL) {
+            return NULL;
+        }
+        return ggml_get_tensor(cache_ctx, name.c_str());
+    }
+
     void compute(get_graph_cb_t get_graph,
                  int n_threads,
                  bool free_compute_buffer_immediately = true,
@@ -1592,7 +1669,7 @@
         reset_compute_ctx();
         struct ggml_cgraph* gf = get_compute_graph(get_graph);
         GGML_ASSERT(ggml_gallocr_alloc_graph(compute_allocr, gf));
-        cpy_data_to_backend_tensor();
+        copy_data_to_backend_tensor();
        if (ggml_backend_is_cpu(runtime_backend)) {
             ggml_backend_cpu_set_n_threads(runtime_backend, n_threads);
         }
@@ -1601,8 +1678,9 @@
 #ifdef GGML_PERF
         ggml_graph_print(gf);
 #endif
+        copy_cache_tensors_to_cache_buffer();
         if (output != NULL) {
-            auto result = ggml_graph_node(gf, -1);
+            auto result = ggml_get_tensor(compute_ctx, final_result_name.c_str());
             if (*output == NULL && output_ctx != NULL) {
                 *output = ggml_dup_tensor(output_ctx, result);
             }
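Taken together, the `ggml_extend.hpp` changes give `GGMLRunner` a generic, backend-resident snapshot mechanism: while building a graph, a subclass tags runtime tensors with `cache(name, tensor)`; after the graph runs, `compute()` copies them into the dedicated `cache_buffer`; and the next graph build reads them back with `get_cache_tensor_by_name(name)`. A minimal sketch of that round trip follows; `ToyRunner`, its `build_graph`, and the `"my_state"` name are hypothetical and not part of this diff:

```cpp
// Hypothetical GGMLRunner subclass sketching the new cache API round trip.
// Assumes ggml_extend.hpp is included; all names here are illustrative only.
struct ToyRunner : public GGMLRunner {
    ToyRunner(ggml_backend_t backend, bool offload_params_to_cpu)
        : GGMLRunner(backend, offload_params_to_cpu) {}

    std::string get_desc() { return "toy"; }

    struct ggml_cgraph* build_graph(struct ggml_tensor* x) {
        struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
        x = to_backend(x);

        // NULL on the first compute(); afterwards this returns the snapshot
        // that copy_cache_tensors_to_cache_buffer() wrote into cache_buffer.
        struct ggml_tensor* prev = get_cache_tensor_by_name("my_state");
        struct ggml_tensor* state = (prev == NULL) ? x : ggml_add(compute_ctx, x, prev);

        cache("my_state", state);              // snapshot it after this run
        ggml_build_forward_expand(gf, state);  // keep it alive in the graph
        return gf;
    }
};
```

The companion change in `get_compute_graph()`/`compute()` names the last user node `final_result_name` at build time and later fetches it by name, so the result no longer has to be the literal last node of the final graph once extra cache tensors are expanded into it.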
"RAM" : "VRAM", + num_tensors); + } + + void copy_data_to_backend_tensor() { for (auto& kv : backend_tensor_data_map) { auto tensor = kv.first; auto data = kv.second; @@ -1510,6 +1570,7 @@ public: if (params_backend != runtime_backend) { ggml_backend_free(params_backend); } + free_cache_ctx_and_buffer(); } void reset_compute_ctx() { @@ -1549,6 +1610,11 @@ public: return 0; } + void free_cache_ctx_and_buffer() { + free_cache_buffer(); + free_cache_ctx(); + } + void free_compute_buffer() { if (compute_allocr != NULL) { ggml_gallocr_free(compute_allocr); @@ -1579,6 +1645,17 @@ public: } } + void cache(const std::string name, struct ggml_tensor* tensor) { + cache_tensor_map[name] = tensor; + } + + struct ggml_tensor* get_cache_tensor_by_name(const std::string& name) { + if (cache_ctx == NULL) { + return NULL; + } + return ggml_get_tensor(cache_ctx, name.c_str()); + } + void compute(get_graph_cb_t get_graph, int n_threads, bool free_compute_buffer_immediately = true, @@ -1592,7 +1669,7 @@ public: reset_compute_ctx(); struct ggml_cgraph* gf = get_compute_graph(get_graph); GGML_ASSERT(ggml_gallocr_alloc_graph(compute_allocr, gf)); - cpy_data_to_backend_tensor(); + copy_data_to_backend_tensor(); if (ggml_backend_is_cpu(runtime_backend)) { ggml_backend_cpu_set_n_threads(runtime_backend, n_threads); } @@ -1601,8 +1678,9 @@ public: #ifdef GGML_PERF ggml_graph_print(gf); #endif + copy_cache_tensors_to_cache_buffer(); if (output != NULL) { - auto result = ggml_graph_node(gf, -1); + auto result = ggml_get_tensor(compute_ctx, final_result_name.c_str()); if (*output == NULL && output_ctx != NULL) { *output = ggml_dup_tensor(output_ctx, result); } diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 3e6110c..fdf7a65 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -2384,7 +2384,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s std::vector sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps + high_noise_sample_steps); struct ggml_init_params params; - params.mem_size = static_cast(100 * 1024) * 1024; // 100 MB + params.mem_size = static_cast(200 * 1024) * 1024; // 200 MB params.mem_size += width * height * frames * 3 * sizeof(float) * 2; params.mem_buffer = NULL; params.no_alloc = false; diff --git a/wan.hpp b/wan.hpp index bd594ab..2580818 100644 --- a/wan.hpp +++ b/wan.hpp @@ -14,8 +14,6 @@ namespace WAN { constexpr int CACHE_T = 2; constexpr int WAN_GRAPH_SIZE = 10240; -#define Rep ((struct ggml_tensor*)1) - class CausalConv3d : public GGMLBlock { protected: int64_t in_channels; @@ -147,7 +145,8 @@ namespace WAN { struct ggml_tensor* x, int64_t b, std::vector& feat_cache, - int& feat_idx) { + int& feat_idx, + int chunk_idx) { // x: [b*c, t, h, w] GGML_ASSERT(b == 1); int64_t c = x->ne[3] / b; @@ -158,34 +157,33 @@ namespace WAN { if (mode == "upsample3d") { if (feat_cache.size() > 0) { int idx = feat_idx; - if (feat_cache[idx] == NULL) { - feat_cache[idx] = Rep; // Rep - feat_idx += 1; + feat_idx += 1; + if (chunk_idx == 0) { + // feat_cache[idx] == NULL, pass } else { auto time_conv = std::dynamic_pointer_cast(blocks["time_conv"]); auto cache_x = ggml_slice(ctx, x, 2, -CACHE_T, x->ne[2]); - if (cache_x->ne[2] < 2 && feat_cache[idx] != NULL && feat_cache[idx] != Rep) { + if (cache_x->ne[2] < 2 && feat_cache[idx] != NULL) { // chunk_idx >= 2 // cache last frame of last two chunk cache_x = ggml_concat(ctx, ggml_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]), cache_x, 2); } - if (cache_x->ne[2] < 2 && feat_cache[idx] != NULL && 
@@ -429,7 +427,8 @@
                                     struct ggml_tensor* x,
                                     int64_t b,
                                     std::vector<struct ggml_tensor*>& feat_cache,
-                                    int& feat_idx) {
+                                    int& feat_idx,
+                                    int chunk_idx) {
             // x: [b*c, t, h, w]
             GGML_ASSERT(b == 1);
             struct ggml_tensor* x_copy = x;
@@ -447,7 +446,7 @@
             if (down_flag) {
                 std::string block_name = "downsamples." + std::to_string(i);
                 auto block = std::dynamic_pointer_cast<Resample>(blocks[block_name]);
-                x = block->forward(ctx, x, b, feat_cache, feat_idx);
+                x = block->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx);
             }
 
             auto shortcut = avg_shortcut->forward(ctx, x_copy, b);
@@ -491,7 +490,7 @@
                                     int64_t b,
                                     std::vector<struct ggml_tensor*>& feat_cache,
                                     int& feat_idx,
-                                    bool first_chunk = false) {
+                                    int chunk_idx) {
             // x: [b*c, t, h, w]
             GGML_ASSERT(b == 1);
             struct ggml_tensor* x_copy = x;
@@ -507,10 +506,10 @@
             if (up_flag) {
                 std::string block_name = "upsamples." + std::to_string(i);
                 auto block = std::dynamic_pointer_cast<Resample>(blocks[block_name]);
-                x = block->forward(ctx, x, b, feat_cache, feat_idx);
+                x = block->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx);
 
                 auto avg_shortcut = std::dynamic_pointer_cast<DupUp3D>(blocks["avg_shortcut"]);
-                auto shortcut = avg_shortcut->forward(ctx, x_copy, first_chunk, b);
+                auto shortcut = avg_shortcut->forward(ctx, x_copy, chunk_idx == 0, b);
 
                 x = ggml_add(ctx, x, shortcut);
             }
@@ -566,6 +565,8 @@
             v = ggml_reshape_3d(ctx, v, h * w, c, n);  // [t, c, h * w]
 
             x = ggml_nn_attention(ctx, q, k, v, false);  // [t, h * w, c]
+            // v = ggml_cont(ctx, ggml_torch_permute(ctx, v, 1, 0, 2, 3));  // [t, h * w, c]
+            // x = ggml_nn_attention_ext(ctx, q, k, v, q->ne[2], NULL, false, false, true);
 
             x = ggml_nn_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3));  // [t, c, h * w]
             x = ggml_reshape_4d(ctx, x, w, h, c, n);  // [t, c, h, w]
@@ -656,7 +657,8 @@
                                     struct ggml_tensor* x,
                                     int64_t b,
                                     std::vector<struct ggml_tensor*>& feat_cache,
-                                    int& feat_idx) {
+                                    int& feat_idx,
+                                    int chunk_idx) {
             // x: [b*c, t, h, w]
             GGML_ASSERT(b == 1);
             auto conv1 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]);
@@ -695,7 +697,7 @@
                 if (wan2_2) {
                     auto layer = std::dynamic_pointer_cast<Down_ResidualBlock>(blocks["downsamples." + std::to_string(index++)]);
-                    x = layer->forward(ctx, x, b, feat_cache, feat_idx);
+                    x = layer->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx);
                 } else {
                     for (int j = 0; j < num_res_blocks; j++) {
                         auto layer = std::dynamic_pointer_cast<ResidualBlock>(blocks["downsamples." + std::to_string(index++)]);
@@ -706,7 +708,7 @@
                     if (i != dim_mult.size() - 1) {
                         auto layer = std::dynamic_pointer_cast<Resample>(blocks["downsamples." + std::to_string(index++)]);
-                        x = layer->forward(ctx, x, b, feat_cache, feat_idx);
+                        x = layer->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx);
                     }
                 }
             }
@@ -827,7 +829,7 @@
                                     int64_t b,
                                     std::vector<struct ggml_tensor*>& feat_cache,
                                     int& feat_idx,
-                                    bool first_chunk = false) {
+                                    int chunk_idx) {
             // x: [b*c, t, h, w]
             GGML_ASSERT(b == 1);
             auto conv1 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]);
@@ -871,7 +873,7 @@
                 if (wan2_2) {
                     auto layer = std::dynamic_pointer_cast<Up_ResidualBlock>(blocks["upsamples." + std::to_string(index++)]);
-                    x = layer->forward(ctx, x, b, feat_cache, feat_idx, first_chunk);
+                    x = layer->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx);
                 } else {
                     for (int j = 0; j < num_res_blocks + 1; j++) {
                         auto layer = std::dynamic_pointer_cast<ResidualBlock>(blocks["upsamples." + std::to_string(index++)]);
@@ -882,7 +884,7 @@
                     if (i != dim_mult.size() - 1) {
                         auto layer = std::dynamic_pointer_cast<Resample>(blocks["upsamples." + std::to_string(index++)]);
-                        x = layer->forward(ctx, x, b, feat_cache, feat_idx);
+                        x = layer->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx);
                     }
                 }
             }
@@ -1034,10 +1036,10 @@
                 _enc_conv_idx = 0;
                 if (i == 0) {
                     auto in = ggml_slice(ctx, x, 2, 0, 1);  // [b*c, 1, h, w]
-                    out = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx);
+                    out = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx, i);
                 } else {
                     auto in = ggml_slice(ctx, x, 2, 1 + 4 * (i - 1), 1 + 4 * i);  // [b*c, 4, h, w]
-                    auto out_ = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx);
+                    auto out_ = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx, i);
                     out = ggml_concat(ctx, out, out_, 2);
                 }
             }
@@ -1065,10 +1067,10 @@
                 _conv_idx = 0;
                 if (i == 0) {
                     auto in = ggml_slice(ctx, x, 2, i, i + 1);  // [b*c, 1, h, w]
-                    out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, true);
+                    out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i);
                 } else {
                     auto in = ggml_slice(ctx, x, 2, i, i + 1);  // [b*c, 1, h, w]
-                    auto out_ = decoder->forward(ctx, in, b, _feat_map, _conv_idx);
+                    auto out_ = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i);
                     out = ggml_concat(ctx, out, out_, 2);
                 }
             }
@@ -1092,33 +1094,17 @@
             auto x = conv2->forward(ctx, z);
             auto in = ggml_slice(ctx, x, 2, i, i + 1);  // [b*c, 1, h, w]
             _conv_idx = 0;
-            auto out = decoder->forward(ctx, in, b, _feat_map, _conv_idx);
+            auto out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i);
+            if (wan2_2) {
+                out = unpatchify(ctx, out, 2, b);
+            }
             return out;
         }
     };
 
-    struct FeatCache {
-        std::vector<float> data;
-        std::vector<int64_t> shape;
-        bool is_rep = false;
-
-        FeatCache() = default;
-
-        FeatCache(ggml_backend_t backend, ggml_tensor* tensor) {
-            shape = {tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]};
-            data.resize(shape[0] * shape[1] * shape[2] * shape[3]);
-            ggml_backend_tensor_get_and_sync(backend, tensor, (void*)data.data(), 0, ggml_nbytes(tensor));
-        }
-
-        ggml_tensor* to_ggml_tensor(ggml_context* ctx) {
-            return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, shape[0], shape[1], shape[2], shape[3]);
-        }
-    };
-
     struct WanVAERunner : public VAE {
         bool decode_only = true;
         WanVAE ae;
-        std::vector<FeatCache> _feat_vec_map;
 
         WanVAERunner(ggml_backend_t backend,
                      bool offload_params_to_cpu,
@@ -1128,11 +1114,6 @@
                      SDVersion version = VERSION_WAN2)
             : decode_only(decode_only), ae(decode_only, version == VERSION_WAN2_2_TI2V), VAE(backend, offload_params_to_cpu) {
             ae.init(params_ctx, tensor_types, prefix);
-            rest_feat_vec_map();
-        }
-
-        void rest_feat_vec_map() {
-            _feat_vec_map = std::vector<FeatCache>(ae._conv_num, FeatCache());
         }
 
         std::string get_desc() {
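For orientation, the `encode`/`decode` loops above process the video in temporal chunks: encode consumes pixel frame 0 alone and then 4 pixel frames per latent step, while decode walks the latent one time step at a time, so `chunk_idx == i` falls straight out of the loop counter. A tiny worked example of the encode slice bounds (illustrative only):

```cpp
// Worked example of the encode chunking above: chunk 0 covers pixel frame 0,
// chunk i >= 1 covers frames [1 + 4*(i-1), 1 + 4*i), matching
// ggml_slice(ctx, x, 2, begin, end).
#include <cstdio>

int main() {
    const int t = 17;  // pixel frames; must be 1 + 4*k for some k >= 0
    const int chunks = 1 + (t - 1) / 4;
    for (int i = 0; i < chunks; i++) {
        const int begin = (i == 0) ? 0 : 1 + 4 * (i - 1);
        const int end   = (i == 0) ? 1 : 1 + 4 * i;
        printf("chunk %d -> pixel frames [%d, %d)\n", i, begin, end);
    }
    return 0;  // 17 frames -> 5 chunks -> 5 latent time steps
}
```

The `WanVAERunner` hunks below move the per-chunk feat-cache round trip onto the new `GGMLRunner` cache API.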
@@ -1160,15 +1141,9 @@
 
             ae.clear_cache();
 
-            for (int64_t feat_idx = 0; feat_idx < _feat_vec_map.size(); feat_idx++) {
-                FeatCache& feat_cache_vec = _feat_vec_map[feat_idx];
-                if (feat_cache_vec.is_rep) {
-                    ae._feat_map[feat_idx] = Rep;
-                } else if (feat_cache_vec.data.size() > 0) {
-                    ggml_tensor* feat_cache = feat_cache_vec.to_ggml_tensor(compute_ctx);
-                    set_backend_tensor_data(feat_cache, feat_cache_vec.data.data());
-                    ae._feat_map[feat_idx] = feat_cache;
-                }
+            for (int64_t feat_idx = 0; feat_idx < ae._feat_map.size(); feat_idx++) {
+                auto feat_cache = get_cache_tensor_by_name("feat_idx:" + std::to_string(feat_idx));
+                ae._feat_map[feat_idx] = feat_cache;
             }
 
             z = to_backend(z);
@@ -1177,7 +1152,8 @@
 
             for (int64_t feat_idx = 0; feat_idx < ae._feat_map.size(); feat_idx++) {
                 ggml_tensor* feat_cache = ae._feat_map[feat_idx];
-                if (feat_cache != NULL && feat_cache != Rep) {
+                if (feat_cache != NULL) {
+                    cache("feat_idx:" + std::to_string(feat_idx), feat_cache);
                     ggml_build_forward_expand(gf, feat_cache);
                 }
             }
@@ -1197,7 +1173,7 @@
                 return build_graph(z, decode_graph);
             };
             GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
-        } else {  // broken
+        } else {  // chunk 1 result is weird
             ae.clear_cache();
             int64_t t = z->ne[2];
             int64_t i = 0;
@@ -1205,21 +1181,10 @@
                 return build_graph_partial(z, decode_graph, i);
             };
             struct ggml_tensor* out = NULL;
-            GGMLRunner::compute(get_graph, n_threads, false, &out, output_ctx);
-            for (int64_t feat_idx = 0; feat_idx < _feat_vec_map.size(); feat_idx++) {
-                ggml_tensor* feat_cache = ae._feat_map[feat_idx];
-                if (feat_cache == Rep) {
-                    FeatCache feat_cache_vec;
-                    feat_cache_vec.is_rep = true;
-                    _feat_vec_map[feat_idx] = feat_cache_vec;
-                } else if (feat_cache != NULL) {
-                    _feat_vec_map[feat_idx] = FeatCache(runtime_backend, feat_cache);
-                }
-            }
-            GGMLRunner::free_compute_buffer();
+            GGMLRunner::compute(get_graph, n_threads, true, &out, output_ctx);
+            ae.clear_cache();
             if (t == 1) {
                 *output = out;
-                ae.clear_cache();
                 return;
             }
@@ -1244,25 +1209,11 @@
             out = ggml_new_tensor_4d(output_ctx, GGML_TYPE_F32, out->ne[0], out->ne[1], 4, out->ne[3]);
 
             for (i = 1; i < t; i++) {
-                GGMLRunner::compute(get_graph, n_threads, false, &out);
-
-                for (int64_t feat_idx = 0; feat_idx < _feat_vec_map.size(); feat_idx++) {
-                    ggml_tensor* feat_cache = ae._feat_map[feat_idx];
-                    if (feat_cache == Rep) {
-                        FeatCache feat_cache_vec;
-                        feat_cache_vec.is_rep = true;
-                        _feat_vec_map[feat_idx] = feat_cache_vec;
-                    } else if (feat_cache != NULL) {
-                        _feat_vec_map[feat_idx] = FeatCache(runtime_backend, feat_cache);
-                    }
-                }
-
+                GGMLRunner::compute(get_graph, n_threads, true, &out);
                 ae.clear_cache();
-
-                GGMLRunner::free_compute_buffer();
                 copy_to_output();
             }
+            free_cache_ctx_and_buffer();
         }
     }
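Net effect on the chunked decode path: the old per-chunk round trip through CPU-side `FeatCache` vectors (and the manual `free_compute_buffer()` calls) collapses into the generic cache buffer. Freeing the compute buffer immediately is now safe because the feat snapshots live in `cache_buffer`, not in graph memory. A condensed sketch of the resulting flow, omitting the first-chunk special case and the reallocation of `out` that the real loop above performs:

```cpp
// Condensed sketch of the chunked decode flow after this change; the real
// code above handles chunk 0 separately before entering this loop.
int64_t t = z->ne[2];
struct ggml_tensor* out = NULL;
for (int64_t i = 0; i < t; i++) {
    auto get_graph = [&]() -> struct ggml_cgraph* {
        // build_graph_partial() wires cached feats back in via
        // get_cache_tensor_by_name("feat_idx:N") and re-tags them with cache().
        return build_graph_partial(z, decode_graph, i);
    };
    // free_compute_buffer_immediately == true: only the small cache buffer
    // survives between chunks, not the full compute buffer.
    GGMLRunner::compute(get_graph, n_threads, true, &out, output_ctx);
    ae.clear_cache();
    copy_to_output();  // append this chunk's frames to *output
}
free_cache_ctx_and_buffer();  // drop the feat snapshots once decoding is done
```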