add cache support to ggml runner

leejet 2025-08-30 23:53:51 +08:00
parent aa5566f005
commit fed78a3f1a
3 changed files with 126 additions and 97 deletions
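
At its core the change adds a persistent, name-keyed tensor cache to GGMLRunner, so that selected graph tensors survive free_compute_buffer() and can be fed back into the next graph build. The new public surface, with signatures copied from the hunks below (the one-line comments are editorial summaries, not from the commit):

    void cache(const std::string name, struct ggml_tensor* tensor);         // tag a graph tensor; compute() snapshots it after the run
    struct ggml_tensor* get_cache_tensor_by_name(const std::string& name);  // fetch a snapshot while building the next graph (NULL if absent)
    void free_cache_ctx_and_buffer();                                        // release the snapshot context and its backend buffer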


@ -1307,6 +1307,9 @@ protected:
ggml_backend_buffer_t runtime_params_buffer = NULL;
bool params_on_runtime_backend = false;
struct ggml_context* cache_ctx = NULL;
ggml_backend_buffer_t cache_buffer = NULL;
struct ggml_context* compute_ctx = NULL;
struct ggml_gallocr* compute_allocr = NULL;
@ -1314,6 +1317,8 @@ protected:
ggml_tensor* one_tensor = NULL;
std::map<struct ggml_tensor*, const void*> backend_tensor_data_map;
std::map<std::string, struct ggml_tensor*> cache_tensor_map; // name -> tensor
const std::string final_result_name = "ggml_runner_final_result_tensor";
void alloc_params_ctx() {
struct ggml_init_params params;
@ -1340,6 +1345,23 @@ protected:
}
}
void alloc_cache_ctx() {
struct ggml_init_params params;
params.mem_size = static_cast<size_t>(MAX_PARAMS_TENSOR_NUM * ggml_tensor_overhead());
params.mem_buffer = NULL;
params.no_alloc = true;
cache_ctx = ggml_init(params);
GGML_ASSERT(cache_ctx != NULL);
}
void free_cache_ctx() {
if (cache_ctx != NULL) {
ggml_free(cache_ctx);
cache_ctx = NULL;
}
}
void alloc_compute_ctx() {
struct ggml_init_params params;
params.mem_size = static_cast<size_t>(ggml_tensor_overhead() * MAX_GRAPH_SIZE + ggml_graph_overhead());
@ -1370,6 +1392,8 @@ protected:
struct ggml_cgraph* get_compute_graph(get_graph_cb_t get_graph) {
prepare_build_in_tensor_before();
struct ggml_cgraph* gf = get_graph();
auto result = ggml_graph_node(gf, -1);
ggml_set_name(result, final_result_name.c_str());
prepare_build_in_tensor_after(gf);
return gf;
}
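
Tagging the last graph node with a fixed name lets compute() recover the output via ggml_get_tensor() rather than assuming the result is still the last node, which stops being true once extra cache tensors are expanded onto the graph (as WanVAERunner does further down). A minimal self-contained illustration of the same name-based lookup against plain ggml; the tensor name, sizes, and CPU backend here are illustrative only:

    #include <cstdio>
    #include "ggml.h"
    #include "ggml-alloc.h"
    #include "ggml-backend.h"
    #include "ggml-cpu.h"   // ggml_backend_cpu_init (header layout depends on the ggml version)

    int main() {
        ggml_backend_t backend = ggml_backend_cpu_init();

        struct ggml_init_params ip;
        ip.mem_size   = ggml_tensor_overhead() * 64 + ggml_graph_overhead();
        ip.mem_buffer = NULL;
        ip.no_alloc   = true;   // tensor data lives in a backend buffer, not in this context
        struct ggml_context* ctx = ggml_init(ip);

        struct ggml_tensor* a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
        struct ggml_tensor* b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
        struct ggml_tensor* c = ggml_add(ctx, a, b);

        struct ggml_cgraph* gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, c);

        // tag the final node with a fixed name ...
        ggml_set_name(ggml_graph_node(gf, -1), "final_result");

        ggml_gallocr_t galloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
        ggml_gallocr_alloc_graph(galloc, gf);

        float ones[4] = {1.0f, 1.0f, 1.0f, 1.0f};
        ggml_backend_tensor_set(a, ones, 0, sizeof(ones));
        ggml_backend_tensor_set(b, ones, 0, sizeof(ones));
        ggml_backend_graph_compute(backend, gf);

        // ... so the result can be recovered by name instead of by node position
        struct ggml_tensor* result = ggml_get_tensor(ctx, "final_result");
        float out[4];
        ggml_backend_tensor_get(result, out, 0, sizeof(out));
        printf("result[0] = %.1f\n", out[0]);   // 2.0

        ggml_gallocr_free(galloc);
        ggml_free(ctx);
        ggml_backend_free(backend);
        return 0;
    }
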
@ -1399,7 +1423,43 @@ protected:
return true;
}
void cpy_data_to_backend_tensor() {
void free_cache_buffer() {
if (cache_buffer != NULL) {
ggml_backend_buffer_free(cache_buffer);
cache_buffer = NULL;
}
}
void copy_cache_tensors_to_cache_buffer() {
if (cache_tensor_map.size() == 0) {
return;
}
free_cache_ctx_and_buffer();
alloc_cache_ctx();
GGML_ASSERT(cache_buffer == NULL);
std::map<ggml_tensor*, ggml_tensor*> runtime_tensor_to_cache_tensor;
for (auto kv : cache_tensor_map) {
auto cache_tensor = ggml_dup_tensor(cache_ctx, kv.second);
ggml_set_name(cache_tensor, kv.first.c_str());
runtime_tensor_to_cache_tensor[kv.second] = cache_tensor;
}
size_t num_tensors = ggml_tensor_num(cache_ctx);
cache_buffer = ggml_backend_alloc_ctx_tensors(cache_ctx, runtime_backend);
GGML_ASSERT(cache_buffer != NULL);
for (auto kv : runtime_tensor_to_cache_tensor) {
ggml_backend_tensor_copy(kv.first, kv.second);
}
ggml_backend_synchronize(runtime_backend);
cache_tensor_map.clear();
size_t cache_buffer_size = ggml_backend_buffer_get_size(cache_buffer);
LOG_DEBUG("%s cache backend buffer size = % 6.2f MB(%s) (%i tensors)",
get_desc().c_str(),
cache_buffer_size / (1024.f * 1024.f),
ggml_backend_is_cpu(runtime_backend) ? "RAM" : "VRAM",
(int)num_tensors);
}
void copy_data_to_backend_tensor() {
for (auto& kv : backend_tensor_data_map) {
auto tensor = kv.first;
auto data = kv.second;
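
copy_cache_tensors_to_cache_buffer() above duplicates each requested tensor's metadata into cache_ctx, backs the whole context with one buffer on the runtime backend, and copies the live data over with ggml_backend_tensor_copy(). The same sequence in isolation (a sketch against plain ggml; the CPU backend, the sizes, and the "feat_idx:0" name are placeholders, not taken from the commit):

    #include <cstdio>
    #include "ggml.h"
    #include "ggml-alloc.h"
    #include "ggml-backend.h"
    #include "ggml-cpu.h"   // ggml_backend_cpu_init (header layout depends on the ggml version)

    int main() {
        ggml_backend_t backend = ggml_backend_cpu_init();

        // a "runtime" tensor standing in for a node of a just-computed graph
        struct ggml_init_params ip;
        ip.mem_size   = 4 * ggml_tensor_overhead();
        ip.mem_buffer = NULL;
        ip.no_alloc   = true;
        struct ggml_context* runtime_ctx   = ggml_init(ip);
        struct ggml_tensor*  runtime_t     = ggml_new_tensor_1d(runtime_ctx, GGML_TYPE_F32, 8);
        ggml_backend_buffer_t runtime_buf  = ggml_backend_alloc_ctx_tensors(runtime_ctx, backend);
        float src[8] = {0, 1, 2, 3, 4, 5, 6, 7};
        ggml_backend_tensor_set(runtime_t, src, 0, sizeof(src));

        // snapshot: duplicate metadata into a metadata-only cache context ...
        struct ggml_context* cache_ctx = ggml_init(ip);
        struct ggml_tensor*  cache_t   = ggml_dup_tensor(cache_ctx, runtime_t);
        ggml_set_name(cache_t, "feat_idx:0");

        // ... back every tensor of cache_ctx with one backend buffer ...
        ggml_backend_buffer_t cache_buffer = ggml_backend_alloc_ctx_tensors(cache_ctx, backend);

        // ... and copy the live data across (device-to-device when not on CPU)
        ggml_backend_tensor_copy(runtime_t, cache_t);
        ggml_backend_synchronize(backend);

        struct ggml_tensor* found = ggml_get_tensor(cache_ctx, "feat_idx:0");  // what get_cache_tensor_by_name() returns
        printf("cache buffer: %zu bytes, '%s' preserved\n",
               ggml_backend_buffer_get_size(cache_buffer), found->name);

        // the runtime side can now be freed; the snapshot survives in cache_buffer
        ggml_backend_buffer_free(runtime_buf);
        ggml_free(runtime_ctx);
        ggml_backend_buffer_free(cache_buffer);
        ggml_free(cache_ctx);
        ggml_backend_free(backend);
        return 0;
    }
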
@ -1510,6 +1570,7 @@ public:
if (params_backend != runtime_backend) {
ggml_backend_free(params_backend);
}
free_cache_ctx_and_buffer();
}
void reset_compute_ctx() {
@ -1549,6 +1610,11 @@ public:
return 0;
}
void free_cache_ctx_and_buffer() {
free_cache_buffer();
free_cache_ctx();
}
void free_compute_buffer() {
if (compute_allocr != NULL) {
ggml_gallocr_free(compute_allocr);
@ -1579,6 +1645,17 @@ public:
}
}
void cache(const std::string name, struct ggml_tensor* tensor) {
cache_tensor_map[name] = tensor;
}
struct ggml_tensor* get_cache_tensor_by_name(const std::string& name) {
if (cache_ctx == NULL) {
return NULL;
}
return ggml_get_tensor(cache_ctx, name.c_str());
}
void compute(get_graph_cb_t get_graph,
int n_threads,
bool free_compute_buffer_immediately = true,
@ -1592,7 +1669,7 @@ public:
reset_compute_ctx();
struct ggml_cgraph* gf = get_compute_graph(get_graph);
GGML_ASSERT(ggml_gallocr_alloc_graph(compute_allocr, gf));
cpy_data_to_backend_tensor();
copy_data_to_backend_tensor();
if (ggml_backend_is_cpu(runtime_backend)) {
ggml_backend_cpu_set_n_threads(runtime_backend, n_threads);
}
@ -1601,8 +1678,9 @@ public:
#ifdef GGML_PERF
ggml_graph_print(gf);
#endif
copy_cache_tensors_to_cache_buffer();
if (output != NULL) {
auto result = ggml_graph_node(gf, -1);
auto result = ggml_get_tensor(compute_ctx, final_result_name.c_str());
if (*output == NULL && output_ctx != NULL) {
*output = ggml_dup_tensor(output_ctx, result);
}
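
The intended calling pattern pairs the two hooks across compute() calls: tag tensors with cache() while one graph is being built, then pull the persisted copies back with get_cache_tensor_by_name() when the next graph is built. A condensed sketch of that pattern, modeled on the WanVAERunner changes further down; it is not a standalone program (GGMLRunner members are assumed), the build_graph_chunk name is mine, and the "feat_idx:N" naming comes from the diff:

    // inside a GGMLRunner subclass's graph builder (sketch)
    struct ggml_cgraph* build_graph_chunk(struct ggml_cgraph* gf, std::vector<ggml_tensor*>& feat_map) {
        // 1) before building: reload whatever the previous chunk persisted (NULL on the first chunk)
        for (size_t feat_idx = 0; feat_idx < feat_map.size(); feat_idx++) {
            feat_map[feat_idx] = get_cache_tensor_by_name("feat_idx:" + std::to_string(feat_idx));
        }

        // ... build this chunk's graph, filling feat_map with fresh cache tensors ...

        // 2) after building: ask the runner to persist them once the graph has run
        for (size_t feat_idx = 0; feat_idx < feat_map.size(); feat_idx++) {
            if (feat_map[feat_idx] != NULL) {
                cache("feat_idx:" + std::to_string(feat_idx), feat_map[feat_idx]);
                ggml_build_forward_expand(gf, feat_map[feat_idx]);  // make sure they are actually computed
            }
        }
        return gf;
    }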


@ -2384,7 +2384,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps + high_noise_sample_steps);
struct ggml_init_params params;
params.mem_size = static_cast<size_t>(100 * 1024) * 1024; // 100 MB
params.mem_size = static_cast<size_t>(200 * 1024) * 1024; // 200 MB
params.mem_size += width * height * frames * 3 * sizeof(float) * 2;
params.mem_buffer = NULL;
params.no_alloc = false;
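
The output context for generate_video is sized as a fixed base plus two float RGB buffers covering the full frame stack; the base grows from 100 MB to 200 MB here. A quick check of what the formula yields (the 832x480, 81-frame settings are hypothetical, not from the commit):

    #include <cstdio>
    #include <cstddef>

    int main() {
        // hypothetical video settings, for illustration only
        size_t width = 832, height = 480, frames = 81;

        size_t mem_size = static_cast<size_t>(200 * 1024) * 1024;      // 200 MB base
        mem_size += width * height * frames * 3 * sizeof(float) * 2;   // two full-size RGB float buffers

        printf("params.mem_size ~= %.1f MB\n", mem_size / (1024.0 * 1024.0));  // ~940 MB for these settings
        return 0;
    }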

wan.hpp

@ -14,8 +14,6 @@ namespace WAN {
constexpr int CACHE_T = 2;
constexpr int WAN_GRAPH_SIZE = 10240;
#define Rep ((struct ggml_tensor*)1)
class CausalConv3d : public GGMLBlock {
protected:
int64_t in_channels;
@ -147,7 +145,8 @@ namespace WAN {
struct ggml_tensor* x,
int64_t b,
std::vector<struct ggml_tensor*>& feat_cache,
int& feat_idx) {
int& feat_idx,
int chunk_idx) {
// x: [b*c, t, h, w]
GGML_ASSERT(b == 1);
int64_t c = x->ne[3] / b;
@ -158,34 +157,33 @@ namespace WAN {
if (mode == "upsample3d") {
if (feat_cache.size() > 0) {
int idx = feat_idx;
if (feat_cache[idx] == NULL) {
feat_cache[idx] = Rep; // Rep
feat_idx += 1;
feat_idx += 1;
if (chunk_idx == 0) {
// feat_cache[idx] == NULL, pass
} else {
auto time_conv = std::dynamic_pointer_cast<CausalConv3d>(blocks["time_conv"]);
auto cache_x = ggml_slice(ctx, x, 2, -CACHE_T, x->ne[2]);
if (cache_x->ne[2] < 2 && feat_cache[idx] != NULL && feat_cache[idx] != Rep) {
if (cache_x->ne[2] < 2 && feat_cache[idx] != NULL) { // chunk_idx >= 2
// cache last frame of last two chunk
cache_x = ggml_concat(ctx,
ggml_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
cache_x,
2);
}
if (cache_x->ne[2] < 2 && feat_cache[idx] != NULL && feat_cache[idx] == Rep) {
if (chunk_idx == 1 && cache_x->ne[2] < 2) { // Rep
cache_x = ggml_pad_ext(ctx, cache_x, 0, 0, 0, 0, (int)cache_x->ne[2], 0, 0, 0);
// aka cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device),cache_x],dim=2)
}
if (feat_cache[idx] == Rep) {
if (chunk_idx == 1) {
x = time_conv->forward(ctx, x);
} else {
x = time_conv->forward(ctx, x, feat_cache[idx]);
}
feat_cache[idx] = cache_x;
feat_idx += 1;
x = ggml_reshape_4d(ctx, x, w * h, t, c, 2); // (2, c, t, h*w)
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 3, 1, 2)); // (c, t, 2, h*w)
x = ggml_reshape_4d(ctx, x, w, h, 2 * t, c); // (c, t*2, h, w)
}
}
}
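
The Rep sentinel (a fake tensor pointer meaning "first chunk already seen, zero-pad next time") is gone; the branches above key directly off chunk_idx. Spelled out as plain C++, purely as an editorial summary of the three cases, not code from the commit:

    #include <cstdio>

    // Summary of the upsample3d cache handling above, keyed on chunk_idx
    // (the real work uses ggml ops; these strings only name the cases).
    static const char* upsample3d_cache_action(int chunk_idx) {
        if (chunk_idx == 0) return "skip time_conv; nothing cached for this slot yet";
        if (chunk_idx == 1) return "zero-pad cache_x along time, time_conv without cached state, store cache_x";
        return "prepend last cached frame to cache_x, time_conv with feat_cache[idx], store cache_x";
    }

    int main() {
        for (int chunk_idx = 0; chunk_idx < 3; chunk_idx++) {
            printf("chunk_idx %d: %s\n", chunk_idx, upsample3d_cache_action(chunk_idx));
        }
        return 0;
    }
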
@ -429,7 +427,8 @@ namespace WAN {
struct ggml_tensor* x,
int64_t b,
std::vector<struct ggml_tensor*>& feat_cache,
int& feat_idx) {
int& feat_idx,
int chunk_idx) {
// x: [b*c, t, h, w]
GGML_ASSERT(b == 1);
struct ggml_tensor* x_copy = x;
@ -447,7 +446,7 @@ namespace WAN {
if (down_flag) {
std::string block_name = "downsamples." + std::to_string(i);
auto block = std::dynamic_pointer_cast<Resample>(blocks[block_name]);
x = block->forward(ctx, x, b, feat_cache, feat_idx);
x = block->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx);
}
auto shortcut = avg_shortcut->forward(ctx, x_copy, b);
@ -491,7 +490,7 @@ namespace WAN {
int64_t b,
std::vector<struct ggml_tensor*>& feat_cache,
int& feat_idx,
bool first_chunk = false) {
int chunk_idx) {
// x: [b*c, t, h, w]
GGML_ASSERT(b == 1);
struct ggml_tensor* x_copy = x;
@ -507,10 +506,10 @@ namespace WAN {
if (up_flag) {
std::string block_name = "upsamples." + std::to_string(i);
auto block = std::dynamic_pointer_cast<Resample>(blocks[block_name]);
x = block->forward(ctx, x, b, feat_cache, feat_idx);
x = block->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx);
auto avg_shortcut = std::dynamic_pointer_cast<DupUp3D>(blocks["avg_shortcut"]);
auto shortcut = avg_shortcut->forward(ctx, x_copy, first_chunk, b);
auto shortcut = avg_shortcut->forward(ctx, x_copy, chunk_idx == 0, b);
x = ggml_add(ctx, x, shortcut);
}
@ -566,6 +565,8 @@ namespace WAN {
v = ggml_reshape_3d(ctx, v, h * w, c, n); // [t, c, h * w]
x = ggml_nn_attention(ctx, q, k, v, false); // [t, h * w, c]
// v = ggml_cont(ctx, ggml_torch_permute(ctx, v, 1, 0, 2, 3)); // [t, h * w, c]
// x = ggml_nn_attention_ext(ctx, q, k, v, q->ne[2], NULL, false, false, true);
x = ggml_nn_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [t, c, h * w]
x = ggml_reshape_4d(ctx, x, w, h, c, n); // [t, c, h, w]
@ -656,7 +657,8 @@ namespace WAN {
struct ggml_tensor* x,
int64_t b,
std::vector<struct ggml_tensor*>& feat_cache,
int& feat_idx) {
int& feat_idx,
int chunk_idx) {
// x: [b*c, t, h, w]
GGML_ASSERT(b == 1);
auto conv1 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]);
@ -695,7 +697,7 @@ namespace WAN {
if (wan2_2) {
auto layer = std::dynamic_pointer_cast<Down_ResidualBlock>(blocks["downsamples." + std::to_string(index++)]);
x = layer->forward(ctx, x, b, feat_cache, feat_idx);
x = layer->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx);
} else {
for (int j = 0; j < num_res_blocks; j++) {
auto layer = std::dynamic_pointer_cast<ResidualBlock>(blocks["downsamples." + std::to_string(index++)]);
@ -706,7 +708,7 @@ namespace WAN {
if (i != dim_mult.size() - 1) {
auto layer = std::dynamic_pointer_cast<Resample>(blocks["downsamples." + std::to_string(index++)]);
x = layer->forward(ctx, x, b, feat_cache, feat_idx);
x = layer->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx);
}
}
}
@ -827,7 +829,7 @@ namespace WAN {
int64_t b,
std::vector<struct ggml_tensor*>& feat_cache,
int& feat_idx,
bool first_chunk = false) {
int chunk_idx) {
// x: [b*c, t, h, w]
GGML_ASSERT(b == 1);
auto conv1 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]);
@ -871,7 +873,7 @@ namespace WAN {
if (wan2_2) {
auto layer = std::dynamic_pointer_cast<Up_ResidualBlock>(blocks["upsamples." + std::to_string(index++)]);
x = layer->forward(ctx, x, b, feat_cache, feat_idx, first_chunk);
x = layer->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx);
} else {
for (int j = 0; j < num_res_blocks + 1; j++) {
auto layer = std::dynamic_pointer_cast<ResidualBlock>(blocks["upsamples." + std::to_string(index++)]);
@ -882,7 +884,7 @@ namespace WAN {
if (i != dim_mult.size() - 1) {
auto layer = std::dynamic_pointer_cast<Resample>(blocks["upsamples." + std::to_string(index++)]);
x = layer->forward(ctx, x, b, feat_cache, feat_idx);
x = layer->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx);
}
}
}
@ -1034,10 +1036,10 @@ namespace WAN {
_enc_conv_idx = 0;
if (i == 0) {
auto in = ggml_slice(ctx, x, 2, 0, 1); // [b*c, 1, h, w]
out = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx);
out = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx, i);
} else {
auto in = ggml_slice(ctx, x, 2, 1 + 4 * (i - 1), 1 + 4 * i); // [b*c, 4, h, w]
auto out_ = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx);
auto out_ = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx, i);
out = ggml_concat(ctx, out, out_, 2);
}
}
@ -1065,10 +1067,10 @@ namespace WAN {
_conv_idx = 0;
if (i == 0) {
auto in = ggml_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, true);
out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i);
} else {
auto in = ggml_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
auto out_ = decoder->forward(ctx, in, b, _feat_map, _conv_idx);
auto out_ = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i);
out = ggml_concat(ctx, out, out_, 2);
}
}
@ -1092,33 +1094,17 @@ namespace WAN {
auto x = conv2->forward(ctx, z);
auto in = ggml_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
_conv_idx = 0;
auto out = decoder->forward(ctx, in, b, _feat_map, _conv_idx);
auto out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i);
if (wan2_2) {
out = unpatchify(ctx, out, 2, b);
}
return out;
}
};
struct FeatCache {
std::vector<float> data;
std::vector<int64_t> shape;
bool is_rep = false;
FeatCache() = default;
FeatCache(ggml_backend_t backend, ggml_tensor* tensor) {
shape = {tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]};
data.resize(shape[0] * shape[1] * shape[2] * shape[3]);
ggml_backend_tensor_get_and_sync(backend, tensor, (void*)data.data(), 0, ggml_nbytes(tensor));
}
ggml_tensor* to_ggml_tensor(ggml_context* ctx) {
return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, shape[0], shape[1], shape[2], shape[3]);
}
};
struct WanVAERunner : public VAE {
bool decode_only = true;
WanVAE ae;
std::vector<FeatCache> _feat_vec_map;
WanVAERunner(ggml_backend_t backend,
bool offload_params_to_cpu,
@ -1128,11 +1114,6 @@ namespace WAN {
SDVersion version = VERSION_WAN2)
: decode_only(decode_only), ae(decode_only, version == VERSION_WAN2_2_TI2V), VAE(backend, offload_params_to_cpu) {
ae.init(params_ctx, tensor_types, prefix);
rest_feat_vec_map();
}
void rest_feat_vec_map() {
_feat_vec_map = std::vector<FeatCache>(ae._conv_num, FeatCache());
}
std::string get_desc() {
@ -1160,15 +1141,9 @@ namespace WAN {
ae.clear_cache();
for (int64_t feat_idx = 0; feat_idx < _feat_vec_map.size(); feat_idx++) {
FeatCache& feat_cache_vec = _feat_vec_map[feat_idx];
if (feat_cache_vec.is_rep) {
ae._feat_map[feat_idx] = Rep;
} else if (feat_cache_vec.data.size() > 0) {
ggml_tensor* feat_cache = feat_cache_vec.to_ggml_tensor(compute_ctx);
set_backend_tensor_data(feat_cache, feat_cache_vec.data.data());
ae._feat_map[feat_idx] = feat_cache;
}
for (int64_t feat_idx = 0; feat_idx < ae._feat_map.size(); feat_idx++) {
auto feat_cache = get_cache_tensor_by_name("feat_idx:" + std::to_string(feat_idx));
ae._feat_map[feat_idx] = feat_cache;
}
z = to_backend(z);
@ -1177,7 +1152,8 @@ namespace WAN {
for (int64_t feat_idx = 0; feat_idx < ae._feat_map.size(); feat_idx++) {
ggml_tensor* feat_cache = ae._feat_map[feat_idx];
if (feat_cache != NULL && feat_cache != Rep) {
if (feat_cache != NULL) {
cache("feat_idx:" + std::to_string(feat_idx), feat_cache);
ggml_build_forward_expand(gf, feat_cache);
}
}
@ -1197,7 +1173,7 @@ namespace WAN {
return build_graph(z, decode_graph);
};
GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
} else { // broken
} else { // chunk 1 result is weird
ae.clear_cache();
int64_t t = z->ne[2];
int64_t i = 0;
@ -1205,21 +1181,10 @@ namespace WAN {
return build_graph_partial(z, decode_graph, i);
};
struct ggml_tensor* out = NULL;
GGMLRunner::compute(get_graph, n_threads, false, &out, output_ctx);
for (int64_t feat_idx = 0; feat_idx < _feat_vec_map.size(); feat_idx++) {
ggml_tensor* feat_cache = ae._feat_map[feat_idx];
if (feat_cache == Rep) {
FeatCache feat_cache_vec;
feat_cache_vec.is_rep = true;
_feat_vec_map[feat_idx] = feat_cache_vec;
} else if (feat_cache != NULL) {
_feat_vec_map[feat_idx] = FeatCache(runtime_backend, feat_cache);
}
}
GGMLRunner::free_compute_buffer();
GGMLRunner::compute(get_graph, n_threads, true, &out, output_ctx);
ae.clear_cache();
if (t == 1) {
*output = out;
ae.clear_cache();
return;
}
@ -1244,25 +1209,11 @@ namespace WAN {
out = ggml_new_tensor_4d(output_ctx, GGML_TYPE_F32, out->ne[0], out->ne[1], 4, out->ne[3]);
for (i = 1; i < t; i++) {
GGMLRunner::compute(get_graph, n_threads, false, &out);
for (int64_t feat_idx = 0; feat_idx < _feat_vec_map.size(); feat_idx++) {
ggml_tensor* feat_cache = ae._feat_map[feat_idx];
if (feat_cache == Rep) {
FeatCache feat_cache_vec;
feat_cache_vec.is_rep = true;
_feat_vec_map[feat_idx] = feat_cache_vec;
} else if (feat_cache != NULL) {
_feat_vec_map[feat_idx] = FeatCache(runtime_backend, feat_cache);
}
}
GGMLRunner::compute(get_graph, n_threads, true, &out);
ae.clear_cache();
GGMLRunner::free_compute_buffer();
copy_to_output();
}
free_cache_ctx_and_buffer();
}
}