mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-13 05:48:56 +00:00
add cache support to ggml runner
This commit is contained in:
parent
aa5566f005
commit
fed78a3f1a
@ -1307,6 +1307,9 @@ protected:
|
|||||||
ggml_backend_buffer_t runtime_params_buffer = NULL;
|
ggml_backend_buffer_t runtime_params_buffer = NULL;
|
||||||
bool params_on_runtime_backend = false;
|
bool params_on_runtime_backend = false;
|
||||||
|
|
||||||
|
struct ggml_context* cache_ctx = NULL;
|
||||||
|
ggml_backend_buffer_t cache_buffer = NULL;
|
||||||
|
|
||||||
struct ggml_context* compute_ctx = NULL;
|
struct ggml_context* compute_ctx = NULL;
|
||||||
struct ggml_gallocr* compute_allocr = NULL;
|
struct ggml_gallocr* compute_allocr = NULL;
|
||||||
|
|
||||||
@ -1314,6 +1317,8 @@ protected:
|
|||||||
ggml_tensor* one_tensor = NULL;
|
ggml_tensor* one_tensor = NULL;
|
||||||
|
|
||||||
std::map<struct ggml_tensor*, const void*> backend_tensor_data_map;
|
std::map<struct ggml_tensor*, const void*> backend_tensor_data_map;
|
||||||
|
std::map<std::string, struct ggml_tensor*> cache_tensor_map; // name -> tensor
|
||||||
|
const std::string final_result_name = "ggml_runner_final_result_tensor";
|
||||||
|
|
||||||
void alloc_params_ctx() {
|
void alloc_params_ctx() {
|
||||||
struct ggml_init_params params;
|
struct ggml_init_params params;
|
||||||
@ -1340,6 +1345,23 @@ protected:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void alloc_cache_ctx() {
|
||||||
|
struct ggml_init_params params;
|
||||||
|
params.mem_size = static_cast<size_t>(MAX_PARAMS_TENSOR_NUM * ggml_tensor_overhead());
|
||||||
|
params.mem_buffer = NULL;
|
||||||
|
params.no_alloc = true;
|
||||||
|
|
||||||
|
cache_ctx = ggml_init(params);
|
||||||
|
GGML_ASSERT(cache_ctx != NULL);
|
||||||
|
}
|
||||||
|
|
||||||
|
void free_cache_ctx() {
|
||||||
|
if (cache_ctx != NULL) {
|
||||||
|
ggml_free(cache_ctx);
|
||||||
|
cache_ctx = NULL;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void alloc_compute_ctx() {
|
void alloc_compute_ctx() {
|
||||||
struct ggml_init_params params;
|
struct ggml_init_params params;
|
||||||
params.mem_size = static_cast<size_t>(ggml_tensor_overhead() * MAX_GRAPH_SIZE + ggml_graph_overhead());
|
params.mem_size = static_cast<size_t>(ggml_tensor_overhead() * MAX_GRAPH_SIZE + ggml_graph_overhead());
|
||||||
@ -1370,6 +1392,8 @@ protected:
|
|||||||
struct ggml_cgraph* get_compute_graph(get_graph_cb_t get_graph) {
|
struct ggml_cgraph* get_compute_graph(get_graph_cb_t get_graph) {
|
||||||
prepare_build_in_tensor_before();
|
prepare_build_in_tensor_before();
|
||||||
struct ggml_cgraph* gf = get_graph();
|
struct ggml_cgraph* gf = get_graph();
|
||||||
|
auto result = ggml_graph_node(gf, -1);
|
||||||
|
ggml_set_name(result, final_result_name.c_str());
|
||||||
prepare_build_in_tensor_after(gf);
|
prepare_build_in_tensor_after(gf);
|
||||||
return gf;
|
return gf;
|
||||||
}
|
}
|
||||||
@ -1399,7 +1423,43 @@ protected:
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void cpy_data_to_backend_tensor() {
|
void free_cache_buffer() {
|
||||||
|
if (cache_buffer != NULL) {
|
||||||
|
ggml_backend_buffer_free(cache_buffer);
|
||||||
|
cache_buffer = NULL;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void copy_cache_tensors_to_cache_buffer() {
|
||||||
|
if (cache_tensor_map.size() == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
free_cache_ctx_and_buffer();
|
||||||
|
alloc_cache_ctx();
|
||||||
|
GGML_ASSERT(cache_buffer == NULL);
|
||||||
|
std::map<ggml_tensor*, ggml_tensor*> runtime_tensor_to_cache_tensor;
|
||||||
|
for (auto kv : cache_tensor_map) {
|
||||||
|
auto cache_tensor = ggml_dup_tensor(cache_ctx, kv.second);
|
||||||
|
ggml_set_name(cache_tensor, kv.first.c_str());
|
||||||
|
runtime_tensor_to_cache_tensor[kv.second] = cache_tensor;
|
||||||
|
}
|
||||||
|
size_t num_tensors = ggml_tensor_num(cache_ctx);
|
||||||
|
cache_buffer = ggml_backend_alloc_ctx_tensors(cache_ctx, runtime_backend);
|
||||||
|
GGML_ASSERT(cache_buffer != NULL);
|
||||||
|
for (auto kv : runtime_tensor_to_cache_tensor) {
|
||||||
|
ggml_backend_tensor_copy(kv.first, kv.second);
|
||||||
|
}
|
||||||
|
ggml_backend_synchronize(runtime_backend);
|
||||||
|
cache_tensor_map.clear();
|
||||||
|
size_t cache_buffer_size = ggml_backend_buffer_get_size(cache_buffer);
|
||||||
|
LOG_DEBUG("%s cache backend buffer size = % 6.2f MB(%s) (%i tensors)",
|
||||||
|
get_desc().c_str(),
|
||||||
|
cache_buffer_size / (1024.f * 1024.f),
|
||||||
|
ggml_backend_is_cpu(runtime_backend) ? "RAM" : "VRAM",
|
||||||
|
num_tensors);
|
||||||
|
}
|
||||||
|
|
||||||
|
void copy_data_to_backend_tensor() {
|
||||||
for (auto& kv : backend_tensor_data_map) {
|
for (auto& kv : backend_tensor_data_map) {
|
||||||
auto tensor = kv.first;
|
auto tensor = kv.first;
|
||||||
auto data = kv.second;
|
auto data = kv.second;
|
||||||
@ -1510,6 +1570,7 @@ public:
|
|||||||
if (params_backend != runtime_backend) {
|
if (params_backend != runtime_backend) {
|
||||||
ggml_backend_free(params_backend);
|
ggml_backend_free(params_backend);
|
||||||
}
|
}
|
||||||
|
free_cache_ctx_and_buffer();
|
||||||
}
|
}
|
||||||
|
|
||||||
void reset_compute_ctx() {
|
void reset_compute_ctx() {
|
||||||
@ -1549,6 +1610,11 @@ public:
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void free_cache_ctx_and_buffer() {
|
||||||
|
free_cache_buffer();
|
||||||
|
free_cache_ctx();
|
||||||
|
}
|
||||||
|
|
||||||
void free_compute_buffer() {
|
void free_compute_buffer() {
|
||||||
if (compute_allocr != NULL) {
|
if (compute_allocr != NULL) {
|
||||||
ggml_gallocr_free(compute_allocr);
|
ggml_gallocr_free(compute_allocr);
|
||||||
@ -1579,6 +1645,17 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void cache(const std::string name, struct ggml_tensor* tensor) {
|
||||||
|
cache_tensor_map[name] = tensor;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* get_cache_tensor_by_name(const std::string& name) {
|
||||||
|
if (cache_ctx == NULL) {
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
return ggml_get_tensor(cache_ctx, name.c_str());
|
||||||
|
}
|
||||||
|
|
||||||
void compute(get_graph_cb_t get_graph,
|
void compute(get_graph_cb_t get_graph,
|
||||||
int n_threads,
|
int n_threads,
|
||||||
bool free_compute_buffer_immediately = true,
|
bool free_compute_buffer_immediately = true,
|
||||||
@ -1592,7 +1669,7 @@ public:
|
|||||||
reset_compute_ctx();
|
reset_compute_ctx();
|
||||||
struct ggml_cgraph* gf = get_compute_graph(get_graph);
|
struct ggml_cgraph* gf = get_compute_graph(get_graph);
|
||||||
GGML_ASSERT(ggml_gallocr_alloc_graph(compute_allocr, gf));
|
GGML_ASSERT(ggml_gallocr_alloc_graph(compute_allocr, gf));
|
||||||
cpy_data_to_backend_tensor();
|
copy_data_to_backend_tensor();
|
||||||
if (ggml_backend_is_cpu(runtime_backend)) {
|
if (ggml_backend_is_cpu(runtime_backend)) {
|
||||||
ggml_backend_cpu_set_n_threads(runtime_backend, n_threads);
|
ggml_backend_cpu_set_n_threads(runtime_backend, n_threads);
|
||||||
}
|
}
|
||||||
@ -1601,8 +1678,9 @@ public:
|
|||||||
#ifdef GGML_PERF
|
#ifdef GGML_PERF
|
||||||
ggml_graph_print(gf);
|
ggml_graph_print(gf);
|
||||||
#endif
|
#endif
|
||||||
|
copy_cache_tensors_to_cache_buffer();
|
||||||
if (output != NULL) {
|
if (output != NULL) {
|
||||||
auto result = ggml_graph_node(gf, -1);
|
auto result = ggml_get_tensor(compute_ctx, final_result_name.c_str());
|
||||||
if (*output == NULL && output_ctx != NULL) {
|
if (*output == NULL && output_ctx != NULL) {
|
||||||
*output = ggml_dup_tensor(output_ctx, result);
|
*output = ggml_dup_tensor(output_ctx, result);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2384,7 +2384,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps + high_noise_sample_steps);
|
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps + high_noise_sample_steps);
|
||||||
|
|
||||||
struct ggml_init_params params;
|
struct ggml_init_params params;
|
||||||
params.mem_size = static_cast<size_t>(100 * 1024) * 1024; // 100 MB
|
params.mem_size = static_cast<size_t>(200 * 1024) * 1024; // 200 MB
|
||||||
params.mem_size += width * height * frames * 3 * sizeof(float) * 2;
|
params.mem_size += width * height * frames * 3 * sizeof(float) * 2;
|
||||||
params.mem_buffer = NULL;
|
params.mem_buffer = NULL;
|
||||||
params.no_alloc = false;
|
params.no_alloc = false;
|
||||||
|
|||||||
137
wan.hpp
137
wan.hpp
@ -14,8 +14,6 @@ namespace WAN {
|
|||||||
constexpr int CACHE_T = 2;
|
constexpr int CACHE_T = 2;
|
||||||
constexpr int WAN_GRAPH_SIZE = 10240;
|
constexpr int WAN_GRAPH_SIZE = 10240;
|
||||||
|
|
||||||
#define Rep ((struct ggml_tensor*)1)
|
|
||||||
|
|
||||||
class CausalConv3d : public GGMLBlock {
|
class CausalConv3d : public GGMLBlock {
|
||||||
protected:
|
protected:
|
||||||
int64_t in_channels;
|
int64_t in_channels;
|
||||||
@ -147,7 +145,8 @@ namespace WAN {
|
|||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
int64_t b,
|
int64_t b,
|
||||||
std::vector<struct ggml_tensor*>& feat_cache,
|
std::vector<struct ggml_tensor*>& feat_cache,
|
||||||
int& feat_idx) {
|
int& feat_idx,
|
||||||
|
int chunk_idx) {
|
||||||
// x: [b*c, t, h, w]
|
// x: [b*c, t, h, w]
|
||||||
GGML_ASSERT(b == 1);
|
GGML_ASSERT(b == 1);
|
||||||
int64_t c = x->ne[3] / b;
|
int64_t c = x->ne[3] / b;
|
||||||
@ -158,34 +157,33 @@ namespace WAN {
|
|||||||
if (mode == "upsample3d") {
|
if (mode == "upsample3d") {
|
||||||
if (feat_cache.size() > 0) {
|
if (feat_cache.size() > 0) {
|
||||||
int idx = feat_idx;
|
int idx = feat_idx;
|
||||||
if (feat_cache[idx] == NULL) {
|
feat_idx += 1;
|
||||||
feat_cache[idx] = Rep; // Rep
|
if (chunk_idx == 0) {
|
||||||
feat_idx += 1;
|
// feat_cache[idx] == NULL, pass
|
||||||
} else {
|
} else {
|
||||||
auto time_conv = std::dynamic_pointer_cast<CausalConv3d>(blocks["time_conv"]);
|
auto time_conv = std::dynamic_pointer_cast<CausalConv3d>(blocks["time_conv"]);
|
||||||
|
|
||||||
auto cache_x = ggml_slice(ctx, x, 2, -CACHE_T, x->ne[2]);
|
auto cache_x = ggml_slice(ctx, x, 2, -CACHE_T, x->ne[2]);
|
||||||
if (cache_x->ne[2] < 2 && feat_cache[idx] != NULL && feat_cache[idx] != Rep) {
|
if (cache_x->ne[2] < 2 && feat_cache[idx] != NULL) { // chunk_idx >= 2
|
||||||
// cache last frame of last two chunk
|
// cache last frame of last two chunk
|
||||||
cache_x = ggml_concat(ctx,
|
cache_x = ggml_concat(ctx,
|
||||||
ggml_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
|
ggml_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]),
|
||||||
cache_x,
|
cache_x,
|
||||||
2);
|
2);
|
||||||
}
|
}
|
||||||
if (cache_x->ne[2] < 2 && feat_cache[idx] != NULL && feat_cache[idx] == Rep) {
|
if (chunk_idx == 1 && cache_x->ne[2] < 2) { // Rep
|
||||||
cache_x = ggml_pad_ext(ctx, cache_x, 0, 0, 0, 0, (int)cache_x->ne[2], 0, 0, 0);
|
cache_x = ggml_pad_ext(ctx, cache_x, 0, 0, 0, 0, (int)cache_x->ne[2], 0, 0, 0);
|
||||||
// aka cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device),cache_x],dim=2)
|
// aka cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device),cache_x],dim=2)
|
||||||
}
|
}
|
||||||
if (feat_cache[idx] == Rep) {
|
if (chunk_idx == 1) {
|
||||||
x = time_conv->forward(ctx, x);
|
x = time_conv->forward(ctx, x);
|
||||||
} else {
|
} else {
|
||||||
x = time_conv->forward(ctx, x, feat_cache[idx]);
|
x = time_conv->forward(ctx, x, feat_cache[idx]);
|
||||||
}
|
}
|
||||||
feat_cache[idx] = cache_x;
|
feat_cache[idx] = cache_x;
|
||||||
feat_idx += 1;
|
x = ggml_reshape_4d(ctx, x, w * h, t, c, 2); // (2, c, t, h*w)
|
||||||
x = ggml_reshape_4d(ctx, x, w * h, t, c, 2); // (2, c, t, h*w)
|
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 3, 1, 2)); // (c, t, 2, h*w)
|
||||||
x = ggml_nn_cont(ctx, ggml_torch_permute(ctx, x, 0, 3, 1, 2)); // (c, t, 2, h*w)
|
x = ggml_reshape_4d(ctx, x, w, h, 2 * t, c); // (c, t*2, h, w)
|
||||||
x = ggml_reshape_4d(ctx, x, w, h, 2 * t, c); // (c, t*2, h, w)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -429,7 +427,8 @@ namespace WAN {
|
|||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
int64_t b,
|
int64_t b,
|
||||||
std::vector<struct ggml_tensor*>& feat_cache,
|
std::vector<struct ggml_tensor*>& feat_cache,
|
||||||
int& feat_idx) {
|
int& feat_idx,
|
||||||
|
int chunk_idx) {
|
||||||
// x: [b*c, t, h, w]
|
// x: [b*c, t, h, w]
|
||||||
GGML_ASSERT(b == 1);
|
GGML_ASSERT(b == 1);
|
||||||
struct ggml_tensor* x_copy = x;
|
struct ggml_tensor* x_copy = x;
|
||||||
@ -447,7 +446,7 @@ namespace WAN {
|
|||||||
if (down_flag) {
|
if (down_flag) {
|
||||||
std::string block_name = "downsamples." + std::to_string(i);
|
std::string block_name = "downsamples." + std::to_string(i);
|
||||||
auto block = std::dynamic_pointer_cast<Resample>(blocks[block_name]);
|
auto block = std::dynamic_pointer_cast<Resample>(blocks[block_name]);
|
||||||
x = block->forward(ctx, x, b, feat_cache, feat_idx);
|
x = block->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto shortcut = avg_shortcut->forward(ctx, x_copy, b);
|
auto shortcut = avg_shortcut->forward(ctx, x_copy, b);
|
||||||
@ -491,7 +490,7 @@ namespace WAN {
|
|||||||
int64_t b,
|
int64_t b,
|
||||||
std::vector<struct ggml_tensor*>& feat_cache,
|
std::vector<struct ggml_tensor*>& feat_cache,
|
||||||
int& feat_idx,
|
int& feat_idx,
|
||||||
bool first_chunk = false) {
|
int chunk_idx) {
|
||||||
// x: [b*c, t, h, w]
|
// x: [b*c, t, h, w]
|
||||||
GGML_ASSERT(b == 1);
|
GGML_ASSERT(b == 1);
|
||||||
struct ggml_tensor* x_copy = x;
|
struct ggml_tensor* x_copy = x;
|
||||||
@ -507,10 +506,10 @@ namespace WAN {
|
|||||||
if (up_flag) {
|
if (up_flag) {
|
||||||
std::string block_name = "upsamples." + std::to_string(i);
|
std::string block_name = "upsamples." + std::to_string(i);
|
||||||
auto block = std::dynamic_pointer_cast<Resample>(blocks[block_name]);
|
auto block = std::dynamic_pointer_cast<Resample>(blocks[block_name]);
|
||||||
x = block->forward(ctx, x, b, feat_cache, feat_idx);
|
x = block->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx);
|
||||||
|
|
||||||
auto avg_shortcut = std::dynamic_pointer_cast<DupUp3D>(blocks["avg_shortcut"]);
|
auto avg_shortcut = std::dynamic_pointer_cast<DupUp3D>(blocks["avg_shortcut"]);
|
||||||
auto shortcut = avg_shortcut->forward(ctx, x_copy, first_chunk, b);
|
auto shortcut = avg_shortcut->forward(ctx, x_copy, chunk_idx == 0, b);
|
||||||
|
|
||||||
x = ggml_add(ctx, x, shortcut);
|
x = ggml_add(ctx, x, shortcut);
|
||||||
}
|
}
|
||||||
@ -566,6 +565,8 @@ namespace WAN {
|
|||||||
v = ggml_reshape_3d(ctx, v, h * w, c, n); // [t, c, h * w]
|
v = ggml_reshape_3d(ctx, v, h * w, c, n); // [t, c, h * w]
|
||||||
|
|
||||||
x = ggml_nn_attention(ctx, q, k, v, false); // [t, h * w, c]
|
x = ggml_nn_attention(ctx, q, k, v, false); // [t, h * w, c]
|
||||||
|
// v = ggml_cont(ctx, ggml_torch_permute(ctx, v, 1, 0, 2, 3)); // [t, h * w, c]
|
||||||
|
// x = ggml_nn_attention_ext(ctx, q, k, v, q->ne[2], NULL, false, false, true);
|
||||||
|
|
||||||
x = ggml_nn_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [t, c, h * w]
|
x = ggml_nn_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [t, c, h * w]
|
||||||
x = ggml_reshape_4d(ctx, x, w, h, c, n); // [t, c, h, w]
|
x = ggml_reshape_4d(ctx, x, w, h, c, n); // [t, c, h, w]
|
||||||
@ -656,7 +657,8 @@ namespace WAN {
|
|||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
int64_t b,
|
int64_t b,
|
||||||
std::vector<struct ggml_tensor*>& feat_cache,
|
std::vector<struct ggml_tensor*>& feat_cache,
|
||||||
int& feat_idx) {
|
int& feat_idx,
|
||||||
|
int chunk_idx) {
|
||||||
// x: [b*c, t, h, w]
|
// x: [b*c, t, h, w]
|
||||||
GGML_ASSERT(b == 1);
|
GGML_ASSERT(b == 1);
|
||||||
auto conv1 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]);
|
auto conv1 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]);
|
||||||
@ -695,7 +697,7 @@ namespace WAN {
|
|||||||
if (wan2_2) {
|
if (wan2_2) {
|
||||||
auto layer = std::dynamic_pointer_cast<Down_ResidualBlock>(blocks["downsamples." + std::to_string(index++)]);
|
auto layer = std::dynamic_pointer_cast<Down_ResidualBlock>(blocks["downsamples." + std::to_string(index++)]);
|
||||||
|
|
||||||
x = layer->forward(ctx, x, b, feat_cache, feat_idx);
|
x = layer->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx);
|
||||||
} else {
|
} else {
|
||||||
for (int j = 0; j < num_res_blocks; j++) {
|
for (int j = 0; j < num_res_blocks; j++) {
|
||||||
auto layer = std::dynamic_pointer_cast<ResidualBlock>(blocks["downsamples." + std::to_string(index++)]);
|
auto layer = std::dynamic_pointer_cast<ResidualBlock>(blocks["downsamples." + std::to_string(index++)]);
|
||||||
@ -706,7 +708,7 @@ namespace WAN {
|
|||||||
if (i != dim_mult.size() - 1) {
|
if (i != dim_mult.size() - 1) {
|
||||||
auto layer = std::dynamic_pointer_cast<Resample>(blocks["downsamples." + std::to_string(index++)]);
|
auto layer = std::dynamic_pointer_cast<Resample>(blocks["downsamples." + std::to_string(index++)]);
|
||||||
|
|
||||||
x = layer->forward(ctx, x, b, feat_cache, feat_idx);
|
x = layer->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -827,7 +829,7 @@ namespace WAN {
|
|||||||
int64_t b,
|
int64_t b,
|
||||||
std::vector<struct ggml_tensor*>& feat_cache,
|
std::vector<struct ggml_tensor*>& feat_cache,
|
||||||
int& feat_idx,
|
int& feat_idx,
|
||||||
bool first_chunk = false) {
|
int chunk_idx) {
|
||||||
// x: [b*c, t, h, w]
|
// x: [b*c, t, h, w]
|
||||||
GGML_ASSERT(b == 1);
|
GGML_ASSERT(b == 1);
|
||||||
auto conv1 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]);
|
auto conv1 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]);
|
||||||
@ -871,7 +873,7 @@ namespace WAN {
|
|||||||
if (wan2_2) {
|
if (wan2_2) {
|
||||||
auto layer = std::dynamic_pointer_cast<Up_ResidualBlock>(blocks["upsamples." + std::to_string(index++)]);
|
auto layer = std::dynamic_pointer_cast<Up_ResidualBlock>(blocks["upsamples." + std::to_string(index++)]);
|
||||||
|
|
||||||
x = layer->forward(ctx, x, b, feat_cache, feat_idx, first_chunk);
|
x = layer->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx);
|
||||||
} else {
|
} else {
|
||||||
for (int j = 0; j < num_res_blocks + 1; j++) {
|
for (int j = 0; j < num_res_blocks + 1; j++) {
|
||||||
auto layer = std::dynamic_pointer_cast<ResidualBlock>(blocks["upsamples." + std::to_string(index++)]);
|
auto layer = std::dynamic_pointer_cast<ResidualBlock>(blocks["upsamples." + std::to_string(index++)]);
|
||||||
@ -882,7 +884,7 @@ namespace WAN {
|
|||||||
if (i != dim_mult.size() - 1) {
|
if (i != dim_mult.size() - 1) {
|
||||||
auto layer = std::dynamic_pointer_cast<Resample>(blocks["upsamples." + std::to_string(index++)]);
|
auto layer = std::dynamic_pointer_cast<Resample>(blocks["upsamples." + std::to_string(index++)]);
|
||||||
|
|
||||||
x = layer->forward(ctx, x, b, feat_cache, feat_idx);
|
x = layer->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1034,10 +1036,10 @@ namespace WAN {
|
|||||||
_enc_conv_idx = 0;
|
_enc_conv_idx = 0;
|
||||||
if (i == 0) {
|
if (i == 0) {
|
||||||
auto in = ggml_slice(ctx, x, 2, 0, 1); // [b*c, 1, h, w]
|
auto in = ggml_slice(ctx, x, 2, 0, 1); // [b*c, 1, h, w]
|
||||||
out = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx);
|
out = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx, i);
|
||||||
} else {
|
} else {
|
||||||
auto in = ggml_slice(ctx, x, 2, 1 + 4 * (i - 1), 1 + 4 * i); // [b*c, 4, h, w]
|
auto in = ggml_slice(ctx, x, 2, 1 + 4 * (i - 1), 1 + 4 * i); // [b*c, 4, h, w]
|
||||||
auto out_ = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx);
|
auto out_ = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx, i);
|
||||||
out = ggml_concat(ctx, out, out_, 2);
|
out = ggml_concat(ctx, out, out_, 2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1065,10 +1067,10 @@ namespace WAN {
|
|||||||
_conv_idx = 0;
|
_conv_idx = 0;
|
||||||
if (i == 0) {
|
if (i == 0) {
|
||||||
auto in = ggml_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
|
auto in = ggml_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
|
||||||
out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, true);
|
out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i);
|
||||||
} else {
|
} else {
|
||||||
auto in = ggml_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
|
auto in = ggml_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
|
||||||
auto out_ = decoder->forward(ctx, in, b, _feat_map, _conv_idx);
|
auto out_ = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i);
|
||||||
out = ggml_concat(ctx, out, out_, 2);
|
out = ggml_concat(ctx, out, out_, 2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1092,33 +1094,17 @@ namespace WAN {
|
|||||||
auto x = conv2->forward(ctx, z);
|
auto x = conv2->forward(ctx, z);
|
||||||
auto in = ggml_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
|
auto in = ggml_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
|
||||||
_conv_idx = 0;
|
_conv_idx = 0;
|
||||||
auto out = decoder->forward(ctx, in, b, _feat_map, _conv_idx);
|
auto out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i);
|
||||||
|
if (wan2_2) {
|
||||||
|
out = unpatchify(ctx, out, 2, b);
|
||||||
|
}
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct FeatCache {
|
|
||||||
std::vector<float> data;
|
|
||||||
std::vector<int64_t> shape;
|
|
||||||
bool is_rep = false;
|
|
||||||
|
|
||||||
FeatCache() = default;
|
|
||||||
|
|
||||||
FeatCache(ggml_backend_t backend, ggml_tensor* tensor) {
|
|
||||||
shape = {tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]};
|
|
||||||
data.resize(shape[0] * shape[1] * shape[2] * shape[3]);
|
|
||||||
ggml_backend_tensor_get_and_sync(backend, tensor, (void*)data.data(), 0, ggml_nbytes(tensor));
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_tensor* to_ggml_tensor(ggml_context* ctx) {
|
|
||||||
return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, shape[0], shape[1], shape[2], shape[3]);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct WanVAERunner : public VAE {
|
struct WanVAERunner : public VAE {
|
||||||
bool decode_only = true;
|
bool decode_only = true;
|
||||||
WanVAE ae;
|
WanVAE ae;
|
||||||
std::vector<FeatCache> _feat_vec_map;
|
|
||||||
|
|
||||||
WanVAERunner(ggml_backend_t backend,
|
WanVAERunner(ggml_backend_t backend,
|
||||||
bool offload_params_to_cpu,
|
bool offload_params_to_cpu,
|
||||||
@ -1128,11 +1114,6 @@ namespace WAN {
|
|||||||
SDVersion version = VERSION_WAN2)
|
SDVersion version = VERSION_WAN2)
|
||||||
: decode_only(decode_only), ae(decode_only, version == VERSION_WAN2_2_TI2V), VAE(backend, offload_params_to_cpu) {
|
: decode_only(decode_only), ae(decode_only, version == VERSION_WAN2_2_TI2V), VAE(backend, offload_params_to_cpu) {
|
||||||
ae.init(params_ctx, tensor_types, prefix);
|
ae.init(params_ctx, tensor_types, prefix);
|
||||||
rest_feat_vec_map();
|
|
||||||
}
|
|
||||||
|
|
||||||
void rest_feat_vec_map() {
|
|
||||||
_feat_vec_map = std::vector<FeatCache>(ae._conv_num, FeatCache());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string get_desc() {
|
std::string get_desc() {
|
||||||
@ -1160,15 +1141,9 @@ namespace WAN {
|
|||||||
|
|
||||||
ae.clear_cache();
|
ae.clear_cache();
|
||||||
|
|
||||||
for (int64_t feat_idx = 0; feat_idx < _feat_vec_map.size(); feat_idx++) {
|
for (int64_t feat_idx = 0; feat_idx < ae._feat_map.size(); feat_idx++) {
|
||||||
FeatCache& feat_cache_vec = _feat_vec_map[feat_idx];
|
auto feat_cache = get_cache_tensor_by_name("feat_idx:" + std::to_string(feat_idx));
|
||||||
if (feat_cache_vec.is_rep) {
|
ae._feat_map[feat_idx] = feat_cache;
|
||||||
ae._feat_map[feat_idx] = Rep;
|
|
||||||
} else if (feat_cache_vec.data.size() > 0) {
|
|
||||||
ggml_tensor* feat_cache = feat_cache_vec.to_ggml_tensor(compute_ctx);
|
|
||||||
set_backend_tensor_data(feat_cache, feat_cache_vec.data.data());
|
|
||||||
ae._feat_map[feat_idx] = feat_cache;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
z = to_backend(z);
|
z = to_backend(z);
|
||||||
@ -1177,7 +1152,8 @@ namespace WAN {
|
|||||||
|
|
||||||
for (int64_t feat_idx = 0; feat_idx < ae._feat_map.size(); feat_idx++) {
|
for (int64_t feat_idx = 0; feat_idx < ae._feat_map.size(); feat_idx++) {
|
||||||
ggml_tensor* feat_cache = ae._feat_map[feat_idx];
|
ggml_tensor* feat_cache = ae._feat_map[feat_idx];
|
||||||
if (feat_cache != NULL && feat_cache != Rep) {
|
if (feat_cache != NULL) {
|
||||||
|
cache("feat_idx:" + std::to_string(feat_idx), feat_cache);
|
||||||
ggml_build_forward_expand(gf, feat_cache);
|
ggml_build_forward_expand(gf, feat_cache);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1197,7 +1173,7 @@ namespace WAN {
|
|||||||
return build_graph(z, decode_graph);
|
return build_graph(z, decode_graph);
|
||||||
};
|
};
|
||||||
GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
|
GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
|
||||||
} else { // broken
|
} else { // chunk 1 result is weird
|
||||||
ae.clear_cache();
|
ae.clear_cache();
|
||||||
int64_t t = z->ne[2];
|
int64_t t = z->ne[2];
|
||||||
int64_t i = 0;
|
int64_t i = 0;
|
||||||
@ -1205,21 +1181,10 @@ namespace WAN {
|
|||||||
return build_graph_partial(z, decode_graph, i);
|
return build_graph_partial(z, decode_graph, i);
|
||||||
};
|
};
|
||||||
struct ggml_tensor* out = NULL;
|
struct ggml_tensor* out = NULL;
|
||||||
GGMLRunner::compute(get_graph, n_threads, false, &out, output_ctx);
|
GGMLRunner::compute(get_graph, n_threads, true, &out, output_ctx);
|
||||||
for (int64_t feat_idx = 0; feat_idx < _feat_vec_map.size(); feat_idx++) {
|
ae.clear_cache();
|
||||||
ggml_tensor* feat_cache = ae._feat_map[feat_idx];
|
|
||||||
if (feat_cache == Rep) {
|
|
||||||
FeatCache feat_cache_vec;
|
|
||||||
feat_cache_vec.is_rep = true;
|
|
||||||
_feat_vec_map[feat_idx] = feat_cache_vec;
|
|
||||||
} else if (feat_cache != NULL) {
|
|
||||||
_feat_vec_map[feat_idx] = FeatCache(runtime_backend, feat_cache);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
GGMLRunner::free_compute_buffer();
|
|
||||||
if (t == 1) {
|
if (t == 1) {
|
||||||
*output = out;
|
*output = out;
|
||||||
ae.clear_cache();
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1244,25 +1209,11 @@ namespace WAN {
|
|||||||
out = ggml_new_tensor_4d(output_ctx, GGML_TYPE_F32, out->ne[0], out->ne[1], 4, out->ne[3]);
|
out = ggml_new_tensor_4d(output_ctx, GGML_TYPE_F32, out->ne[0], out->ne[1], 4, out->ne[3]);
|
||||||
|
|
||||||
for (i = 1; i < t; i++) {
|
for (i = 1; i < t; i++) {
|
||||||
GGMLRunner::compute(get_graph, n_threads, false, &out);
|
GGMLRunner::compute(get_graph, n_threads, true, &out);
|
||||||
|
|
||||||
for (int64_t feat_idx = 0; feat_idx < _feat_vec_map.size(); feat_idx++) {
|
|
||||||
ggml_tensor* feat_cache = ae._feat_map[feat_idx];
|
|
||||||
if (feat_cache == Rep) {
|
|
||||||
FeatCache feat_cache_vec;
|
|
||||||
feat_cache_vec.is_rep = true;
|
|
||||||
_feat_vec_map[feat_idx] = feat_cache_vec;
|
|
||||||
} else if (feat_cache != NULL) {
|
|
||||||
_feat_vec_map[feat_idx] = FeatCache(runtime_backend, feat_cache);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ae.clear_cache();
|
ae.clear_cache();
|
||||||
|
|
||||||
GGMLRunner::free_compute_buffer();
|
|
||||||
|
|
||||||
copy_to_output();
|
copy_to_output();
|
||||||
}
|
}
|
||||||
|
free_cache_ctx_and_buffer();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user