mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-13 05:48:56 +00:00
add cache support to ggml runner
This commit is contained in:
parent
aa5566f005
commit
fed78a3f1a
@ -1307,6 +1307,9 @@ protected:
|
||||
ggml_backend_buffer_t runtime_params_buffer = NULL;
|
||||
bool params_on_runtime_backend = false;
|
||||
|
||||
struct ggml_context* cache_ctx = NULL;
|
||||
ggml_backend_buffer_t cache_buffer = NULL;
|
||||
|
||||
struct ggml_context* compute_ctx = NULL;
|
||||
struct ggml_gallocr* compute_allocr = NULL;
|
||||
|
||||
@ -1314,6 +1317,8 @@ protected:
|
||||
ggml_tensor* one_tensor = NULL;
|
||||
|
||||
std::map<struct ggml_tensor*, const void*> backend_tensor_data_map;
|
||||
std::map<std::string, struct ggml_tensor*> cache_tensor_map; // name -> tensor
|
||||
const std::string final_result_name = "ggml_runner_final_result_tensor";
|
||||
|
||||
void alloc_params_ctx() {
|
||||
struct ggml_init_params params;
|
||||
@ -1340,6 +1345,23 @@ protected:
|
||||
}
|
||||
}
|
||||
|
||||
void alloc_cache_ctx() {
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = static_cast<size_t>(MAX_PARAMS_TENSOR_NUM * ggml_tensor_overhead());
|
||||
params.mem_buffer = NULL;
|
||||
params.no_alloc = true;
|
||||
|
||||
cache_ctx = ggml_init(params);
|
||||
GGML_ASSERT(cache_ctx != NULL);
|
||||
}
|
||||
|
||||
void free_cache_ctx() {
|
||||
if (cache_ctx != NULL) {
|
||||
ggml_free(cache_ctx);
|
||||
cache_ctx = NULL;
|
||||
}
|
||||
}
|
||||
|
||||
void alloc_compute_ctx() {
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = static_cast<size_t>(ggml_tensor_overhead() * MAX_GRAPH_SIZE + ggml_graph_overhead());
|
||||
@ -1370,6 +1392,8 @@ protected:
|
||||
struct ggml_cgraph* get_compute_graph(get_graph_cb_t get_graph) {
|
||||
prepare_build_in_tensor_before();
|
||||
struct ggml_cgraph* gf = get_graph();
|
||||
auto result = ggml_graph_node(gf, -1);
|
||||
ggml_set_name(result, final_result_name.c_str());
|
||||
prepare_build_in_tensor_after(gf);
|
||||
return gf;
|
||||
}
|
||||
@ -1399,7 +1423,43 @@ protected:
|
||||
return true;
|
||||
}
|
||||
|
||||
void cpy_data_to_backend_tensor() {
|
||||
void free_cache_buffer() {
|
||||
if (cache_buffer != NULL) {
|
||||
ggml_backend_buffer_free(cache_buffer);
|
||||
cache_buffer = NULL;
|
||||
}
|
||||
}
|
||||
|
||||
void copy_cache_tensors_to_cache_buffer() {
|
||||
if (cache_tensor_map.size() == 0) {
|
||||
return;
|
||||
}
|
||||
free_cache_ctx_and_buffer();
|
||||
alloc_cache_ctx();
|
||||
GGML_ASSERT(cache_buffer == NULL);
|
||||
std::map<ggml_tensor*, ggml_tensor*> runtime_tensor_to_cache_tensor;
|
||||
for (auto kv : cache_tensor_map) {
|
||||
auto cache_tensor = ggml_dup_tensor(cache_ctx, kv.second);
|
||||
ggml_set_name(cache_tensor, kv.first.c_str());
|
||||
runtime_tensor_to_cache_tensor[kv.second] = cache_tensor;
|
||||
}
|
||||
size_t num_tensors = ggml_tensor_num(cache_ctx);
|
||||
cache_buffer = ggml_backend_alloc_ctx_tensors(cache_ctx, runtime_backend);
|
||||
GGML_ASSERT(cache_buffer != NULL);
|
||||
for (auto kv : runtime_tensor_to_cache_tensor) {
|
||||
ggml_backend_tensor_copy(kv.first, kv.second);
|
||||
}
|
||||
ggml_backend_synchronize(runtime_backend);
|
||||
cache_tensor_map.clear();
|
||||
size_t cache_buffer_size = ggml_backend_buffer_get_size(cache_buffer);
|
||||
LOG_DEBUG("%s cache backend buffer size = % 6.2f MB(%s) (%i tensors)",
|
||||
get_desc().c_str(),
|
||||
cache_buffer_size / (1024.f * 1024.f),
|
||||
ggml_backend_is_cpu(runtime_backend) ? "RAM" : "VRAM",
|
||||
num_tensors);
|
||||
}
|
||||
|
||||
void copy_data_to_backend_tensor() {
|
||||
for (auto& kv : backend_tensor_data_map) {
|
||||
auto tensor = kv.first;
|
||||
auto data = kv.second;
|
||||
@ -1510,6 +1570,7 @@ public:
|
||||
if (params_backend != runtime_backend) {
|
||||
ggml_backend_free(params_backend);
|
||||
}
|
||||
free_cache_ctx_and_buffer();
|
||||
}
|
||||
|
||||
void reset_compute_ctx() {
|
||||
@ -1549,6 +1610,11 @@ public:
|
||||
return 0;
|
||||
}
|
||||
|
||||
void free_cache_ctx_and_buffer() {
|
||||
free_cache_buffer();
|
||||
free_cache_ctx();
|
||||
}
|
||||
|
||||
void free_compute_buffer() {
|
||||
if (compute_allocr != NULL) {
|
||||
ggml_gallocr_free(compute_allocr);
|
||||
@ -1579,6 +1645,17 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
void cache(const std::string name, struct ggml_tensor* tensor) {
|
||||
cache_tensor_map[name] = tensor;
|
||||
}
|
||||
|
||||
struct ggml_tensor* get_cache_tensor_by_name(const std::string& name) {
|
||||
if (cache_ctx == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
return ggml_get_tensor(cache_ctx, name.c_str());
|
||||
}
|
||||
|
||||
void compute(get_graph_cb_t get_graph,
|
||||
int n_threads,
|
||||
bool free_compute_buffer_immediately = true,
|
||||
@ -1592,7 +1669,7 @@ public:
|
||||
reset_compute_ctx();
|
||||
struct ggml_cgraph* gf = get_compute_graph(get_graph);
|
||||
GGML_ASSERT(ggml_gallocr_alloc_graph(compute_allocr, gf));
|
||||
cpy_data_to_backend_tensor();
|
||||
copy_data_to_backend_tensor();
|
||||
if (ggml_backend_is_cpu(runtime_backend)) {
|
||||
ggml_backend_cpu_set_n_threads(runtime_backend, n_threads);
|
||||
}
|
||||
@ -1601,8 +1678,9 @@ public:
|
||||
#ifdef GGML_PERF
|
||||
ggml_graph_print(gf);
|
||||
#endif
|
||||
copy_cache_tensors_to_cache_buffer();
|
||||
if (output != NULL) {
|
||||
auto result = ggml_graph_node(gf, -1);
|
||||
auto result = ggml_get_tensor(compute_ctx, final_result_name.c_str());
|
||||
if (*output == NULL && output_ctx != NULL) {
|
||||
*output = ggml_dup_tensor(output_ctx, result);
|
||||
}
|
||||
|
||||
@ -2384,7 +2384,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
||||
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps + high_noise_sample_steps);
|
||||
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = static_cast<size_t>(100 * 1024) * 1024; // 100 MB
|
||||
params.mem_size = static_cast<size_t>(200 * 1024) * 1024; // 200 MB
|
||||
params.mem_size += width * height * frames * 3 * sizeof(float) * 2;
|
||||
params.mem_buffer = NULL;
|
||||
params.no_alloc = false;
|
||||
|
||||
137
wan.hpp
137
wan.hpp
@ -14,8 +14,6 @@ namespace WAN {
|
||||
constexpr int CACHE_T = 2;
|
||||
constexpr int WAN_GRAPH_SIZE = 10240;
|
||||
|
||||
#define Rep ((struct ggml_tensor*)1)
|
||||
|
||||
class CausalConv3d : public GGMLBlock {
|
||||
protected:
|
||||
int64_t in_channels;
|
||||
@ -147,7 +145,8 @@ namespace WAN {
|
||||
struct ggml_tensor* x,
|
||||
int64_t b,
|
||||
std::vector<struct ggml_tensor*>& feat_cache,
|
||||
int& feat_idx) {
|
||||
int& feat_idx,
|
||||
int chunk_idx) {
|
||||
// x: [b*c, t, h, w]
|
||||
GGML_ASSERT(b == 1);
|
||||
int64_t c = x->ne[3] / b;
|
||||
@ -158,34 +157,33 @@ namespace WAN {
|
||||
if (mode == "upsample3d") {
|
||||
if (feat_cache.size() > 0) {
|
||||
int idx = feat_idx;
|
||||
if (feat_cache[idx] == NULL) {
|
||||
feat_cache[idx] = Rep; // Rep
|
||||
feat_idx += 1;
|
||||
feat_idx += 1;
|
||||
if (chunk_idx == 0) {
|
||||
// feat_cache[idx] == NULL, pass
|
||||
} else {
|
||||
auto time_conv = std::dynamic_pointer_cast<CausalConv3d>(blocks["time_conv"]);
|
||||
|
||||
auto cache_x = ggml_slice(ctx, x, 2, -CACHE_T, x->ne[2]);
|
||||
if (cache_x->ne[2] < 2 && feat_cache[idx] != NULL && feat_cache[idx] != Rep) {
|
||||
if (cache_x->ne[2] < 2 && feat_cache[idx] != NULL) { // chunk_idx >= 2
|
||||
// cache last frame of last two chunk
|
||||
cache_x = ggml_concat(ctx,
|
||||
ggml_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
|
||||
cache_x,
|
||||
2);
|
||||
}
|
||||
if (cache_x->ne[2] < 2 && feat_cache[idx] != NULL && feat_cache[idx] == Rep) {
|
||||
if (chunk_idx == 1 && cache_x->ne[2] < 2) { // Rep
|
||||
cache_x = ggml_pad_ext(ctx, cache_x, 0, 0, 0, 0, (int)cache_x->ne[2], 0, 0, 0);
|
||||
// aka cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device),cache_x],dim=2)
|
||||
}
|
||||
if (feat_cache[idx] == Rep) {
|
||||
if (chunk_idx == 1) {
|
||||
x = time_conv->forward(ctx, x);
|
||||
} else {
|
||||
x = time_conv->forward(ctx, x, feat_cache[idx]);
|
||||
}
|
||||
feat_cache[idx] = cache_x;
|
||||
feat_idx += 1;
|
||||
x = ggml_reshape_4d(ctx, x, w * h, t, c, 2); // (2, c, t, h*w)
|
||||
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 3, 1, 2)); // (c, t, 2, h*w)
|
||||
x = ggml_reshape_4d(ctx, x, w, h, 2 * t, c); // (c, t*2, h, w)
|
||||
x = ggml_reshape_4d(ctx, x, w * h, t, c, 2); // (2, c, t, h*w)
|
||||
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 3, 1, 2)); // (c, t, 2, h*w)
|
||||
x = ggml_reshape_4d(ctx, x, w, h, 2 * t, c); // (c, t*2, h, w)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -429,7 +427,8 @@ namespace WAN {
|
||||
struct ggml_tensor* x,
|
||||
int64_t b,
|
||||
std::vector<struct ggml_tensor*>& feat_cache,
|
||||
int& feat_idx) {
|
||||
int& feat_idx,
|
||||
int chunk_idx) {
|
||||
// x: [b*c, t, h, w]
|
||||
GGML_ASSERT(b == 1);
|
||||
struct ggml_tensor* x_copy = x;
|
||||
@ -447,7 +446,7 @@ namespace WAN {
|
||||
if (down_flag) {
|
||||
std::string block_name = "downsamples." + std::to_string(i);
|
||||
auto block = std::dynamic_pointer_cast<Resample>(blocks[block_name]);
|
||||
x = block->forward(ctx, x, b, feat_cache, feat_idx);
|
||||
x = block->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx);
|
||||
}
|
||||
|
||||
auto shortcut = avg_shortcut->forward(ctx, x_copy, b);
|
||||
@ -491,7 +490,7 @@ namespace WAN {
|
||||
int64_t b,
|
||||
std::vector<struct ggml_tensor*>& feat_cache,
|
||||
int& feat_idx,
|
||||
bool first_chunk = false) {
|
||||
int chunk_idx) {
|
||||
// x: [b*c, t, h, w]
|
||||
GGML_ASSERT(b == 1);
|
||||
struct ggml_tensor* x_copy = x;
|
||||
@ -507,10 +506,10 @@ namespace WAN {
|
||||
if (up_flag) {
|
||||
std::string block_name = "upsamples." + std::to_string(i);
|
||||
auto block = std::dynamic_pointer_cast<Resample>(blocks[block_name]);
|
||||
x = block->forward(ctx, x, b, feat_cache, feat_idx);
|
||||
x = block->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx);
|
||||
|
||||
auto avg_shortcut = std::dynamic_pointer_cast<DupUp3D>(blocks["avg_shortcut"]);
|
||||
auto shortcut = avg_shortcut->forward(ctx, x_copy, first_chunk, b);
|
||||
auto shortcut = avg_shortcut->forward(ctx, x_copy, chunk_idx == 0, b);
|
||||
|
||||
x = ggml_add(ctx, x, shortcut);
|
||||
}
|
||||
@ -566,6 +565,8 @@ namespace WAN {
|
||||
v = ggml_reshape_3d(ctx, v, h * w, c, n); // [t, c, h * w]
|
||||
|
||||
x = ggml_nn_attention(ctx, q, k, v, false); // [t, h * w, c]
|
||||
// v = ggml_cont(ctx, ggml_torch_permute(ctx, v, 1, 0, 2, 3)); // [t, h * w, c]
|
||||
// x = ggml_nn_attention_ext(ctx, q, k, v, q->ne[2], NULL, false, false, true);
|
||||
|
||||
x = ggml_nn_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [t, c, h * w]
|
||||
x = ggml_reshape_4d(ctx, x, w, h, c, n); // [t, c, h, w]
|
||||
@ -656,7 +657,8 @@ namespace WAN {
|
||||
struct ggml_tensor* x,
|
||||
int64_t b,
|
||||
std::vector<struct ggml_tensor*>& feat_cache,
|
||||
int& feat_idx) {
|
||||
int& feat_idx,
|
||||
int chunk_idx) {
|
||||
// x: [b*c, t, h, w]
|
||||
GGML_ASSERT(b == 1);
|
||||
auto conv1 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]);
|
||||
@ -695,7 +697,7 @@ namespace WAN {
|
||||
if (wan2_2) {
|
||||
auto layer = std::dynamic_pointer_cast<Down_ResidualBlock>(blocks["downsamples." + std::to_string(index++)]);
|
||||
|
||||
x = layer->forward(ctx, x, b, feat_cache, feat_idx);
|
||||
x = layer->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx);
|
||||
} else {
|
||||
for (int j = 0; j < num_res_blocks; j++) {
|
||||
auto layer = std::dynamic_pointer_cast<ResidualBlock>(blocks["downsamples." + std::to_string(index++)]);
|
||||
@ -706,7 +708,7 @@ namespace WAN {
|
||||
if (i != dim_mult.size() - 1) {
|
||||
auto layer = std::dynamic_pointer_cast<Resample>(blocks["downsamples." + std::to_string(index++)]);
|
||||
|
||||
x = layer->forward(ctx, x, b, feat_cache, feat_idx);
|
||||
x = layer->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -827,7 +829,7 @@ namespace WAN {
|
||||
int64_t b,
|
||||
std::vector<struct ggml_tensor*>& feat_cache,
|
||||
int& feat_idx,
|
||||
bool first_chunk = false) {
|
||||
int chunk_idx) {
|
||||
// x: [b*c, t, h, w]
|
||||
GGML_ASSERT(b == 1);
|
||||
auto conv1 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]);
|
||||
@ -871,7 +873,7 @@ namespace WAN {
|
||||
if (wan2_2) {
|
||||
auto layer = std::dynamic_pointer_cast<Up_ResidualBlock>(blocks["upsamples." + std::to_string(index++)]);
|
||||
|
||||
x = layer->forward(ctx, x, b, feat_cache, feat_idx, first_chunk);
|
||||
x = layer->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx);
|
||||
} else {
|
||||
for (int j = 0; j < num_res_blocks + 1; j++) {
|
||||
auto layer = std::dynamic_pointer_cast<ResidualBlock>(blocks["upsamples." + std::to_string(index++)]);
|
||||
@ -882,7 +884,7 @@ namespace WAN {
|
||||
if (i != dim_mult.size() - 1) {
|
||||
auto layer = std::dynamic_pointer_cast<Resample>(blocks["upsamples." + std::to_string(index++)]);
|
||||
|
||||
x = layer->forward(ctx, x, b, feat_cache, feat_idx);
|
||||
x = layer->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1034,10 +1036,10 @@ namespace WAN {
|
||||
_enc_conv_idx = 0;
|
||||
if (i == 0) {
|
||||
auto in = ggml_slice(ctx, x, 2, 0, 1); // [b*c, 1, h, w]
|
||||
out = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx);
|
||||
out = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx, i);
|
||||
} else {
|
||||
auto in = ggml_slice(ctx, x, 2, 1 + 4 * (i - 1), 1 + 4 * i); // [b*c, 4, h, w]
|
||||
auto out_ = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx);
|
||||
auto out_ = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx, i);
|
||||
out = ggml_concat(ctx, out, out_, 2);
|
||||
}
|
||||
}
|
||||
@ -1065,10 +1067,10 @@ namespace WAN {
|
||||
_conv_idx = 0;
|
||||
if (i == 0) {
|
||||
auto in = ggml_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
|
||||
out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, true);
|
||||
out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i);
|
||||
} else {
|
||||
auto in = ggml_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
|
||||
auto out_ = decoder->forward(ctx, in, b, _feat_map, _conv_idx);
|
||||
auto out_ = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i);
|
||||
out = ggml_concat(ctx, out, out_, 2);
|
||||
}
|
||||
}
|
||||
@ -1092,33 +1094,17 @@ namespace WAN {
|
||||
auto x = conv2->forward(ctx, z);
|
||||
auto in = ggml_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
|
||||
_conv_idx = 0;
|
||||
auto out = decoder->forward(ctx, in, b, _feat_map, _conv_idx);
|
||||
auto out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i);
|
||||
if (wan2_2) {
|
||||
out = unpatchify(ctx, out, 2, b);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
struct FeatCache {
|
||||
std::vector<float> data;
|
||||
std::vector<int64_t> shape;
|
||||
bool is_rep = false;
|
||||
|
||||
FeatCache() = default;
|
||||
|
||||
FeatCache(ggml_backend_t backend, ggml_tensor* tensor) {
|
||||
shape = {tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]};
|
||||
data.resize(shape[0] * shape[1] * shape[2] * shape[3]);
|
||||
ggml_backend_tensor_get_and_sync(backend, tensor, (void*)data.data(), 0, ggml_nbytes(tensor));
|
||||
}
|
||||
|
||||
ggml_tensor* to_ggml_tensor(ggml_context* ctx) {
|
||||
return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, shape[0], shape[1], shape[2], shape[3]);
|
||||
}
|
||||
};
|
||||
|
||||
struct WanVAERunner : public VAE {
|
||||
bool decode_only = true;
|
||||
WanVAE ae;
|
||||
std::vector<FeatCache> _feat_vec_map;
|
||||
|
||||
WanVAERunner(ggml_backend_t backend,
|
||||
bool offload_params_to_cpu,
|
||||
@ -1128,11 +1114,6 @@ namespace WAN {
|
||||
SDVersion version = VERSION_WAN2)
|
||||
: decode_only(decode_only), ae(decode_only, version == VERSION_WAN2_2_TI2V), VAE(backend, offload_params_to_cpu) {
|
||||
ae.init(params_ctx, tensor_types, prefix);
|
||||
rest_feat_vec_map();
|
||||
}
|
||||
|
||||
void rest_feat_vec_map() {
|
||||
_feat_vec_map = std::vector<FeatCache>(ae._conv_num, FeatCache());
|
||||
}
|
||||
|
||||
std::string get_desc() {
|
||||
@ -1160,15 +1141,9 @@ namespace WAN {
|
||||
|
||||
ae.clear_cache();
|
||||
|
||||
for (int64_t feat_idx = 0; feat_idx < _feat_vec_map.size(); feat_idx++) {
|
||||
FeatCache& feat_cache_vec = _feat_vec_map[feat_idx];
|
||||
if (feat_cache_vec.is_rep) {
|
||||
ae._feat_map[feat_idx] = Rep;
|
||||
} else if (feat_cache_vec.data.size() > 0) {
|
||||
ggml_tensor* feat_cache = feat_cache_vec.to_ggml_tensor(compute_ctx);
|
||||
set_backend_tensor_data(feat_cache, feat_cache_vec.data.data());
|
||||
ae._feat_map[feat_idx] = feat_cache;
|
||||
}
|
||||
for (int64_t feat_idx = 0; feat_idx < ae._feat_map.size(); feat_idx++) {
|
||||
auto feat_cache = get_cache_tensor_by_name("feat_idx:" + std::to_string(feat_idx));
|
||||
ae._feat_map[feat_idx] = feat_cache;
|
||||
}
|
||||
|
||||
z = to_backend(z);
|
||||
@ -1177,7 +1152,8 @@ namespace WAN {
|
||||
|
||||
for (int64_t feat_idx = 0; feat_idx < ae._feat_map.size(); feat_idx++) {
|
||||
ggml_tensor* feat_cache = ae._feat_map[feat_idx];
|
||||
if (feat_cache != NULL && feat_cache != Rep) {
|
||||
if (feat_cache != NULL) {
|
||||
cache("feat_idx:" + std::to_string(feat_idx), feat_cache);
|
||||
ggml_build_forward_expand(gf, feat_cache);
|
||||
}
|
||||
}
|
||||
@ -1197,7 +1173,7 @@ namespace WAN {
|
||||
return build_graph(z, decode_graph);
|
||||
};
|
||||
GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
|
||||
} else { // broken
|
||||
} else { // chunk 1 result is weird
|
||||
ae.clear_cache();
|
||||
int64_t t = z->ne[2];
|
||||
int64_t i = 0;
|
||||
@ -1205,21 +1181,10 @@ namespace WAN {
|
||||
return build_graph_partial(z, decode_graph, i);
|
||||
};
|
||||
struct ggml_tensor* out = NULL;
|
||||
GGMLRunner::compute(get_graph, n_threads, false, &out, output_ctx);
|
||||
for (int64_t feat_idx = 0; feat_idx < _feat_vec_map.size(); feat_idx++) {
|
||||
ggml_tensor* feat_cache = ae._feat_map[feat_idx];
|
||||
if (feat_cache == Rep) {
|
||||
FeatCache feat_cache_vec;
|
||||
feat_cache_vec.is_rep = true;
|
||||
_feat_vec_map[feat_idx] = feat_cache_vec;
|
||||
} else if (feat_cache != NULL) {
|
||||
_feat_vec_map[feat_idx] = FeatCache(runtime_backend, feat_cache);
|
||||
}
|
||||
}
|
||||
GGMLRunner::free_compute_buffer();
|
||||
GGMLRunner::compute(get_graph, n_threads, true, &out, output_ctx);
|
||||
ae.clear_cache();
|
||||
if (t == 1) {
|
||||
*output = out;
|
||||
ae.clear_cache();
|
||||
return;
|
||||
}
|
||||
|
||||
@ -1244,25 +1209,11 @@ namespace WAN {
|
||||
out = ggml_new_tensor_4d(output_ctx, GGML_TYPE_F32, out->ne[0], out->ne[1], 4, out->ne[3]);
|
||||
|
||||
for (i = 1; i < t; i++) {
|
||||
GGMLRunner::compute(get_graph, n_threads, false, &out);
|
||||
|
||||
for (int64_t feat_idx = 0; feat_idx < _feat_vec_map.size(); feat_idx++) {
|
||||
ggml_tensor* feat_cache = ae._feat_map[feat_idx];
|
||||
if (feat_cache == Rep) {
|
||||
FeatCache feat_cache_vec;
|
||||
feat_cache_vec.is_rep = true;
|
||||
_feat_vec_map[feat_idx] = feat_cache_vec;
|
||||
} else if (feat_cache != NULL) {
|
||||
_feat_vec_map[feat_idx] = FeatCache(runtime_backend, feat_cache);
|
||||
}
|
||||
}
|
||||
|
||||
GGMLRunner::compute(get_graph, n_threads, true, &out);
|
||||
ae.clear_cache();
|
||||
|
||||
GGMLRunner::free_compute_buffer();
|
||||
|
||||
copy_to_output();
|
||||
}
|
||||
free_cache_ctx_and_buffer();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user