From bcc9c0d0b3b04b69b45a2055b18341f14be550f8 Mon Sep 17 00:00:00 2001 From: stduhpf Date: Thu, 4 Dec 2025 15:04:27 +0100 Subject: [PATCH] feat: handle ggml compute failures without crashing the program (#1003) * Feat: handle compute failures more gracefully * fix Unreachable code after return Co-authored-by: idostyle * adjust z_image.hpp --------- Co-authored-by: idostyle Co-authored-by: leejet --- clip.hpp | 4 +-- conditioner.hpp | 4 +-- control.hpp | 10 ++++-- denoiser.hpp | 47 ++++++++++++++++++++++++++--- diffusion_model.hpp | 14 ++++----- esrgan.hpp | 4 +-- flux.hpp | 4 +-- ggml_extend.hpp | 21 ++++++++++--- llm.hpp | 4 +-- mmdit.hpp | 4 +-- model.cpp | 4 +-- pmid.hpp | 4 +-- qwen_image.hpp | 4 +-- stable-diffusion.cpp | 72 +++++++++++++++++++++++++++++++------------- t5.hpp | 4 +-- tae.hpp | 4 +-- unet.hpp | 4 +-- vae.hpp | 9 +++--- wan.hpp | 15 ++++----- z_image.hpp | 6 ++-- 20 files changed, 163 insertions(+), 79 deletions(-) diff --git a/clip.hpp b/clip.hpp index e2a892c..1f98327 100644 --- a/clip.hpp +++ b/clip.hpp @@ -963,7 +963,7 @@ struct CLIPTextModelRunner : public GGMLRunner { return gf; } - void compute(const int n_threads, + bool compute(const int n_threads, struct ggml_tensor* input_ids, int num_custom_embeddings, void* custom_embeddings_data, @@ -975,7 +975,7 @@ struct CLIPTextModelRunner : public GGMLRunner { auto get_graph = [&]() -> struct ggml_cgraph* { return build_graph(input_ids, num_custom_embeddings, custom_embeddings_data, max_token_idx, return_pooled, clip_skip); }; - GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx); + return GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx); } }; diff --git a/conditioner.hpp b/conditioner.hpp index e28e6e1..403120d 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -703,7 +703,7 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner { return gf; } - void compute(const int n_threads, + bool compute(const int n_threads, ggml_tensor* pixel_values, bool return_pooled, int clip_skip, @@ -712,7 +712,7 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner { auto get_graph = [&]() -> struct ggml_cgraph* { return build_graph(pixel_values, return_pooled, clip_skip); }; - GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx); + return GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx); } }; diff --git a/control.hpp b/control.hpp index d86f64c..f784202 100644 --- a/control.hpp +++ b/control.hpp @@ -414,7 +414,7 @@ struct ControlNet : public GGMLRunner { return gf; } - void compute(int n_threads, + bool compute(int n_threads, struct ggml_tensor* x, struct ggml_tensor* hint, struct ggml_tensor* timesteps, @@ -430,8 +430,12 @@ struct ControlNet : public GGMLRunner { return build_graph(x, hint, timesteps, context, y); }; - GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); - guided_hint_cached = true; + bool res = GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); + if (res) { + // cache guided_hint + guided_hint_cached = true; + } + return res; } bool load_from_file(const std::string& file_path, int n_threads) { diff --git a/denoiser.hpp b/denoiser.hpp index 3b6be75..32f4027 100644 --- a/denoiser.hpp +++ b/denoiser.hpp @@ -666,7 +666,7 @@ struct Flux2FlowDenoiser : public FluxFlowDenoiser { typedef std::function<ggml_tensor*(ggml_tensor*, float, int)> denoise_cb_t; // k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t -static void sample_k_diffusion(sample_method_t method, +static bool sample_k_diffusion(sample_method_t method, denoise_cb_t
model, ggml_context* work_ctx, ggml_tensor* x, @@ -685,6 +685,9 @@ static void sample_k_diffusion(sample_method_t method, // denoise ggml_tensor* denoised = model(x, sigma, i + 1); + if (denoised == nullptr) { + return false; + } // d = (x - denoised) / sigma { @@ -738,6 +741,9 @@ static void sample_k_diffusion(sample_method_t method, // denoise ggml_tensor* denoised = model(x, sigma, i + 1); + if (denoised == nullptr) { + return false; + } // d = (x - denoised) / sigma { @@ -769,6 +775,9 @@ static void sample_k_diffusion(sample_method_t method, for (int i = 0; i < steps; i++) { // denoise ggml_tensor* denoised = model(x, sigmas[i], -(i + 1)); + if (denoised == nullptr) { + return false; + } // d = (x - denoised) / sigma { @@ -803,7 +812,10 @@ static void sample_k_diffusion(sample_method_t method, } ggml_tensor* denoised = model(x2, sigmas[i + 1], i + 1); - float* vec_denoised = (float*)denoised->data; + if (denoised == nullptr) { + return false; + } + float* vec_denoised = (float*)denoised->data; for (int j = 0; j < ggml_nelements(x); j++) { float d2 = (vec_x2[j] - vec_denoised[j]) / sigmas[i + 1]; vec_d[j] = (vec_d[j] + d2) / 2; @@ -819,6 +831,9 @@ static void sample_k_diffusion(sample_method_t method, for (int i = 0; i < steps; i++) { // denoise ggml_tensor* denoised = model(x, sigmas[i], i + 1); + if (denoised == nullptr) { + return false; + } // d = (x - denoised) / sigma { @@ -855,7 +870,10 @@ static void sample_k_diffusion(sample_method_t method, } ggml_tensor* denoised = model(x2, sigma_mid, i + 1); - float* vec_denoised = (float*)denoised->data; + if (denoised == nullptr) { + return false; + } + float* vec_denoised = (float*)denoised->data; for (int j = 0; j < ggml_nelements(x); j++) { float d2 = (vec_x2[j] - vec_denoised[j]) / sigma_mid; vec_x[j] = vec_x[j] + d2 * dt_2; @@ -871,6 +889,9 @@ static void sample_k_diffusion(sample_method_t method, for (int i = 0; i < steps; i++) { // denoise ggml_tensor* denoised = model(x, sigmas[i], i + 1); + if (denoised == nullptr) { + return false; + } // get_ancestral_step float sigma_up = std::min(sigmas[i + 1], @@ -907,6 +928,9 @@ static void sample_k_diffusion(sample_method_t method, } ggml_tensor* denoised = model(x2, sigmas[i + 1], i + 1); + if (denoised == nullptr) { + return false; + } // Second half-step for (int j = 0; j < ggml_nelements(x); j++) { @@ -937,6 +961,9 @@ static void sample_k_diffusion(sample_method_t method, for (int i = 0; i < steps; i++) { // denoise ggml_tensor* denoised = model(x, sigmas[i], i + 1); + if (denoised == nullptr) { + return false; + } float t = t_fn(sigmas[i]); float t_next = t_fn(sigmas[i + 1]); @@ -976,6 +1003,9 @@ static void sample_k_diffusion(sample_method_t method, for (int i = 0; i < steps; i++) { // denoise ggml_tensor* denoised = model(x, sigmas[i], i + 1); + if (denoised == nullptr) { + return false; + } float t = t_fn(sigmas[i]); float t_next = t_fn(sigmas[i + 1]); @@ -1026,7 +1056,10 @@ static void sample_k_diffusion(sample_method_t method, // Denoising step ggml_tensor* denoised = model(x_cur, sigma, i + 1); - float* vec_denoised = (float*)denoised->data; + if (denoised == nullptr) { + return false; + } + float* vec_denoised = (float*)denoised->data; // d_cur = (x_cur - denoised) / sigma struct ggml_tensor* d_cur = ggml_dup_tensor(work_ctx, x_cur); float* vec_d_cur = (float*)d_cur->data; @@ -1169,6 +1202,9 @@ static void sample_k_diffusion(sample_method_t method, // denoise ggml_tensor* denoised = model(x, sigma, i + 1); + if (denoised == nullptr) { + return false; + } // x = denoised { @@ 
-1561,8 +1597,9 @@ static void sample_k_diffusion(sample_method_t method, default: LOG_ERROR("Attempting to sample with nonexisting sample method %i", method); - abort(); + return false; } + return true; } #endif // __DENOISER_HPP__ diff --git a/diffusion_model.hpp b/diffusion_model.hpp index 5a311f5..8c741fd 100644 --- a/diffusion_model.hpp +++ b/diffusion_model.hpp @@ -27,7 +27,7 @@ struct DiffusionParams { struct DiffusionModel { virtual std::string get_desc() = 0; - virtual void compute(int n_threads, + virtual bool compute(int n_threads, DiffusionParams diffusion_params, struct ggml_tensor** output = nullptr, struct ggml_context* output_ctx = nullptr) = 0; @@ -87,7 +87,7 @@ struct UNetModel : public DiffusionModel { unet.set_flash_attention_enabled(enabled); } - void compute(int n_threads, + bool compute(int n_threads, DiffusionParams diffusion_params, struct ggml_tensor** output = nullptr, struct ggml_context* output_ctx = nullptr) override { @@ -148,7 +148,7 @@ struct MMDiTModel : public DiffusionModel { mmdit.set_flash_attention_enabled(enabled); } - void compute(int n_threads, + bool compute(int n_threads, DiffusionParams diffusion_params, struct ggml_tensor** output = nullptr, struct ggml_context* output_ctx = nullptr) override { @@ -210,7 +210,7 @@ struct FluxModel : public DiffusionModel { flux.set_flash_attention_enabled(enabled); } - void compute(int n_threads, + bool compute(int n_threads, DiffusionParams diffusion_params, struct ggml_tensor** output = nullptr, struct ggml_context* output_ctx = nullptr) override { @@ -277,7 +277,7 @@ struct WanModel : public DiffusionModel { wan.set_flash_attention_enabled(enabled); } - void compute(int n_threads, + bool compute(int n_threads, DiffusionParams diffusion_params, struct ggml_tensor** output = nullptr, struct ggml_context* output_ctx = nullptr) override { @@ -343,7 +343,7 @@ struct QwenImageModel : public DiffusionModel { qwen_image.set_flash_attention_enabled(enabled); } - void compute(int n_threads, + bool compute(int n_threads, DiffusionParams diffusion_params, struct ggml_tensor** output = nullptr, struct ggml_context* output_ctx = nullptr) override { @@ -406,7 +406,7 @@ struct ZImageModel : public DiffusionModel { z_image.set_flash_attention_enabled(enabled); } - void compute(int n_threads, + bool compute(int n_threads, DiffusionParams diffusion_params, struct ggml_tensor** output = nullptr, struct ggml_context* output_ctx = nullptr) override { diff --git a/esrgan.hpp b/esrgan.hpp index fb09544..4cac956 100644 --- a/esrgan.hpp +++ b/esrgan.hpp @@ -353,14 +353,14 @@ struct ESRGAN : public GGMLRunner { return gf; } - void compute(const int n_threads, + bool compute(const int n_threads, struct ggml_tensor* x, ggml_tensor** output, ggml_context* output_ctx = nullptr) { auto get_graph = [&]() -> struct ggml_cgraph* { return build_graph(x); }; - GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); + return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); } }; diff --git a/flux.hpp b/flux.hpp index dc0a96f..f0c65e3 100644 --- a/flux.hpp +++ b/flux.hpp @@ -1413,7 +1413,7 @@ namespace Flux { return gf; } - void compute(int n_threads, + bool compute(int n_threads, struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, @@ -1434,7 +1434,7 @@ namespace Flux { return build_graph(x, timesteps, context, c_concat, y, guidance, ref_latents, increase_ref_index, skip_layers); }; - GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); + return 
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); } void test() { diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 1a0bd44..92dd3b8 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -1938,25 +1938,35 @@ public: return ggml_get_tensor(cache_ctx, name.c_str()); } - void compute(get_graph_cb_t get_graph, + bool compute(get_graph_cb_t get_graph, int n_threads, bool free_compute_buffer_immediately = true, struct ggml_tensor** output = nullptr, struct ggml_context* output_ctx = nullptr) { if (!offload_params_to_runtime_backend()) { LOG_ERROR("%s offload params to runtime backend failed", get_desc().c_str()); - return; + return false; + } + if (!alloc_compute_buffer(get_graph)) { + LOG_ERROR("%s alloc compute buffer failed", get_desc().c_str()); + return false; } - alloc_compute_buffer(get_graph); reset_compute_ctx(); struct ggml_cgraph* gf = get_compute_graph(get_graph); - GGML_ASSERT(ggml_gallocr_alloc_graph(compute_allocr, gf)); + if (!ggml_gallocr_alloc_graph(compute_allocr, gf)) { + LOG_ERROR("%s alloc compute graph failed", get_desc().c_str()); + return false; + } copy_data_to_backend_tensor(); if (ggml_backend_is_cpu(runtime_backend)) { ggml_backend_cpu_set_n_threads(runtime_backend, n_threads); } - ggml_backend_graph_compute(runtime_backend, gf); + ggml_status status = ggml_backend_graph_compute(runtime_backend, gf); + if (status != GGML_STATUS_SUCCESS) { + LOG_ERROR("%s compute failed: %s", get_desc().c_str(), ggml_status_to_string(status)); + return false; + } #ifdef GGML_PERF ggml_graph_print(gf); #endif @@ -1974,6 +1984,7 @@ public: if (free_compute_buffer_immediately) { free_compute_buffer(); } + return true; } void set_flash_attention_enabled(bool enabled) { diff --git a/llm.hpp b/llm.hpp index d1dd3a6..c42c564 100644 --- a/llm.hpp +++ b/llm.hpp @@ -1191,7 +1191,7 @@ namespace LLM { return gf; } - void compute(const int n_threads, + bool compute(const int n_threads, struct ggml_tensor* input_ids, std::vector> image_embeds, std::set out_layers, @@ -1200,7 +1200,7 @@ namespace LLM { auto get_graph = [&]() -> struct ggml_cgraph* { return build_graph(input_ids, image_embeds, out_layers); }; - GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx); + return GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx); } int64_t get_num_image_tokens(int64_t t, int64_t h, int64_t w) { diff --git a/mmdit.hpp b/mmdit.hpp index 247c8f6..38bdc2e 100644 --- a/mmdit.hpp +++ b/mmdit.hpp @@ -894,7 +894,7 @@ struct MMDiTRunner : public GGMLRunner { return gf; } - void compute(int n_threads, + bool compute(int n_threads, struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, @@ -910,7 +910,7 @@ struct MMDiTRunner : public GGMLRunner { return build_graph(x, timesteps, context, y, skip_layers); }; - GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); + return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); } void test() { diff --git a/model.cpp b/model.cpp index 5338d25..2b74d34 100644 --- a/model.cpp +++ b/model.cpp @@ -104,8 +104,8 @@ const char* unused_tensors[] = { "embedding_manager", "denoiser.sigmas", "text_encoders.t5xxl.transformer.encoder.embed_tokens.weight", // only used during training - "ztsnr", // Found in some SDXL vpred models - "edm_vpred.sigma_min", // Found in CosXL + "ztsnr", // Found in some SDXL vpred models + "edm_vpred.sigma_min", // Found in CosXL // TODO: find another way to avoid the "unknown tensor" for these two // "edm_vpred.sigma_max", // Used to detect 
CosXL // "v_pred", // Used to detect SDXL vpred models diff --git a/pmid.hpp b/pmid.hpp index 70d8059..d69423a 100644 --- a/pmid.hpp +++ b/pmid.hpp @@ -548,7 +548,7 @@ public: return gf; } - void compute(const int n_threads, + bool compute(const int n_threads, struct ggml_tensor* id_pixel_values, struct ggml_tensor* prompt_embeds, struct ggml_tensor* id_embeds, @@ -561,7 +561,7 @@ public: }; // GGMLRunner::compute(get_graph, n_threads, updated_prompt_embeds); - GGMLRunner::compute(get_graph, n_threads, true, updated_prompt_embeds, output_ctx); + return GGMLRunner::compute(get_graph, n_threads, true, updated_prompt_embeds, output_ctx); } }; diff --git a/qwen_image.hpp b/qwen_image.hpp index 3e4a75e..eeb823d 100644 --- a/qwen_image.hpp +++ b/qwen_image.hpp @@ -588,7 +588,7 @@ namespace Qwen { return gf; } - void compute(int n_threads, + bool compute(int n_threads, struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, @@ -603,7 +603,7 @@ namespace Qwen { return build_graph(x, timesteps, context, ref_latents, increase_ref_index); }; - GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); + return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); } void test() { diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 11e2bbc..8ac6b73 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -1683,8 +1683,11 @@ public: std::vector<struct ggml_tensor*> controls; if (control_hint != nullptr && control_net != nullptr) { - control_net->compute(n_threads, noised_input, control_hint, timesteps, cond.c_crossattn, cond.c_vector); - controls = control_net->controls; + if (control_net->compute(n_threads, noised_input, control_hint, timesteps, cond.c_crossattn, cond.c_vector)) { + controls = control_net->controls; + } else { + LOG_ERROR("controlnet compute failed"); + } // print_ggml_tensor(controls[12]); // GGML_ASSERT(0); } @@ -1716,9 +1719,12 @@ public: bool skip_model = easycache_before_condition(active_condition, *active_output); if (!skip_model) { - work_diffusion_model->compute(n_threads, - diffusion_params, - active_output); + if (!work_diffusion_model->compute(n_threads, + diffusion_params, + active_output)) { + LOG_ERROR("diffusion model compute failed"); + return nullptr; + } easycache_after_condition(active_condition, *active_output); } @@ -1728,8 +1734,11 @@ public: if (has_unconditioned) { // uncond if (!current_step_skipped && control_hint != nullptr && control_net != nullptr) { - control_net->compute(n_threads, noised_input, control_hint, timesteps, uncond.c_crossattn, uncond.c_vector); - controls = control_net->controls; + if (control_net->compute(n_threads, noised_input, control_hint, timesteps, uncond.c_crossattn, uncond.c_vector)) { + controls = control_net->controls; + } else { + LOG_ERROR("controlnet compute failed"); + } } current_step_skipped = easycache_step_is_skipped(); diffusion_params.controls = controls; @@ -1738,9 +1747,12 @@ public: diffusion_params.y = uncond.c_vector; bool skip_uncond = easycache_before_condition(&uncond, out_uncond); if (!skip_uncond) { - work_diffusion_model->compute(n_threads, - diffusion_params, - &out_uncond); + if (!work_diffusion_model->compute(n_threads, + diffusion_params, + &out_uncond)) { + LOG_ERROR("diffusion model compute failed"); + return nullptr; + } easycache_after_condition(&uncond, out_uncond); } negative_data = (float*)out_uncond->data; @@ -1753,9 +1765,12 @@ public: diffusion_params.y = img_cond.c_vector; bool skip_img_cond = easycache_before_condition(&img_cond, out_img_cond);
if (!skip_img_cond) { - work_diffusion_model->compute(n_threads, - diffusion_params, - &out_img_cond); + if (!work_diffusion_model->compute(n_threads, + diffusion_params, + &out_img_cond)) { + LOG_ERROR("diffusion model compute failed"); + return nullptr; + } easycache_after_condition(&img_cond, out_img_cond); } img_cond_data = (float*)out_img_cond->data; @@ -1772,9 +1787,12 @@ public: diffusion_params.c_concat = cond.c_concat; diffusion_params.y = cond.c_vector; diffusion_params.skip_layers = skip_layers; - work_diffusion_model->compute(n_threads, - diffusion_params, - &out_skip); + if (!work_diffusion_model->compute(n_threads, + diffusion_params, + &out_skip)) { + LOG_ERROR("diffusion model compute failed"); + return nullptr; + } } skip_layer_data = (float*)out_skip->data; } @@ -1837,7 +1855,15 @@ public: return denoised; }; - sample_k_diffusion(method, denoise, work_ctx, x, sigmas, sampler_rng, eta); + if (!sample_k_diffusion(method, denoise, work_ctx, x, sigmas, sampler_rng, eta)) { + LOG_ERROR("Diffusion model sampling failed"); + if (control_net) { + control_net->free_control_ctx(); + control_net->free_compute_buffer(); + } + diffusion_model->free_compute_buffer(); + return NULL; + } if (easycache_enabled) { size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0; @@ -3064,10 +3090,14 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, nullptr, 1.0f, easycache_params); - // print_ggml_tensor(x_0); - int64_t sampling_end = ggml_time_ms(); - LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000); - final_latents.push_back(x_0); + int64_t sampling_end = ggml_time_ms(); + if (x_0 != nullptr) { + // print_ggml_tensor(x_0); + LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000); + final_latents.push_back(x_0); + } else { + LOG_ERROR("sampling for image %d/%d failed after %.2fs", b + 1, batch_count, (sampling_end - sampling_start) * 1.0f / 1000); + } } if (sd_ctx->sd->free_params_immediately) { diff --git a/t5.hpp b/t5.hpp index 4cc8e12..4370a56 100644 --- a/t5.hpp +++ b/t5.hpp @@ -820,7 +820,7 @@ struct T5Runner : public GGMLRunner { return gf; } - void compute(const int n_threads, + bool compute(const int n_threads, struct ggml_tensor* input_ids, struct ggml_tensor* attention_mask, ggml_tensor** output, @@ -828,7 +828,7 @@ struct T5Runner : public GGMLRunner { auto get_graph = [&]() -> struct ggml_cgraph* { return build_graph(input_ids, attention_mask); }; - GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx); + return GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx); } static std::vector _relative_position_bucket(const std::vector& relative_position, diff --git a/tae.hpp b/tae.hpp index 568e409..7f3ca44 100644 --- a/tae.hpp +++ b/tae.hpp @@ -247,7 +247,7 @@ struct TinyAutoEncoder : public GGMLRunner { return gf; } - void compute(const int n_threads, + bool compute(const int n_threads, struct ggml_tensor* z, bool decode_graph, struct ggml_tensor** output, @@ -256,7 +256,7 @@ struct TinyAutoEncoder : public GGMLRunner { return build_graph(z, decode_graph); }; - GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); + return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); } }; diff --git a/unet.hpp b/unet.hpp index de05f46..ec7578e 100644 --- a/unet.hpp +++ b/unet.hpp @@ -645,7 +645,7 @@ struct UNetModelRunner : public GGMLRunner { return gf; } - void compute(int n_threads, + bool compute(int n_threads, struct ggml_tensor* x, struct 
ggml_tensor* timesteps, struct ggml_tensor* context, @@ -665,7 +665,7 @@ struct UNetModelRunner : public GGMLRunner { return build_graph(x, timesteps, context, c_concat, y, num_video_frames, controls, control_strength); }; - GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); + return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); } void test() { diff --git a/vae.hpp b/vae.hpp index 281a5ca..ad5db1b 100644 --- a/vae.hpp +++ b/vae.hpp @@ -617,7 +617,7 @@ public: struct VAE : public GGMLRunner { VAE(ggml_backend_t backend, bool offload_params_to_cpu) : GGMLRunner(backend, offload_params_to_cpu) {} - virtual void compute(const int n_threads, + virtual bool compute(const int n_threads, struct ggml_tensor* z, bool decode_graph, struct ggml_tensor** output, @@ -629,7 +629,7 @@ struct FakeVAE : public VAE { FakeVAE(ggml_backend_t backend, bool offload_params_to_cpu) : VAE(backend, offload_params_to_cpu) {} - void compute(const int n_threads, + bool compute(const int n_threads, struct ggml_tensor* z, bool decode_graph, struct ggml_tensor** output, @@ -641,6 +641,7 @@ struct FakeVAE : public VAE { float value = ggml_ext_tensor_get_f32(z, i0, i1, i2, i3); ggml_ext_tensor_set_f32(*output, value, i0, i1, i2, i3); }); + return true; } void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) override {} @@ -711,7 +712,7 @@ struct AutoEncoderKL : public VAE { return gf; } - void compute(const int n_threads, + bool compute(const int n_threads, struct ggml_tensor* z, bool decode_graph, struct ggml_tensor** output, @@ -722,7 +723,7 @@ struct AutoEncoderKL : public VAE { }; // ggml_set_f32(z, 0.5f); // print_ggml_tensor(z); - GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); + return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); } void test() { diff --git a/wan.hpp b/wan.hpp index 3e02a7b..75333bf 100644 --- a/wan.hpp +++ b/wan.hpp @@ -1175,7 +1175,7 @@ namespace WAN { return gf; } - void compute(const int n_threads, + bool compute(const int n_threads, struct ggml_tensor* z, bool decode_graph, struct ggml_tensor** output, @@ -1184,7 +1184,7 @@ namespace WAN { auto get_graph = [&]() -> struct ggml_cgraph* { return build_graph(z, decode_graph); }; - GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx); + return GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx); } else { // chunk 1 result is weird ae.clear_cache(); int64_t t = z->ne[2]; @@ -1193,11 +1193,11 @@ namespace WAN { return build_graph_partial(z, decode_graph, i); }; struct ggml_tensor* out = nullptr; - GGMLRunner::compute(get_graph, n_threads, true, &out, output_ctx); + bool res = GGMLRunner::compute(get_graph, n_threads, true, &out, output_ctx); ae.clear_cache(); if (t == 1) { *output = out; - return; + return res; } *output = ggml_new_tensor_4d(output_ctx, GGML_TYPE_F32, out->ne[0], out->ne[1], (t - 1) * 4 + 1, out->ne[3]); @@ -1221,11 +1221,12 @@ namespace WAN { out = ggml_new_tensor_4d(output_ctx, GGML_TYPE_F32, out->ne[0], out->ne[1], 4, out->ne[3]); for (i = 1; i < t; i++) { - GGMLRunner::compute(get_graph, n_threads, true, &out); + res = res && GGMLRunner::compute(get_graph, n_threads, true, &out); ae.clear_cache(); copy_to_output(); } free_cache_ctx_and_buffer(); + return res; } } @@ -2194,7 +2195,7 @@ namespace WAN { return gf; } - void compute(int n_threads, + bool compute(int n_threads, struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, @@ -2209,7
+2210,7 @@ namespace WAN { return build_graph(x, timesteps, context, clip_fea, c_concat, time_dim_concat, vace_context, vace_strength); }; - GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); + return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); } void test() { diff --git a/z_image.hpp b/z_image.hpp index 888a895..bc554f1 100644 --- a/z_image.hpp +++ b/z_image.hpp @@ -30,7 +30,7 @@ namespace ZImage { JointAttention(int64_t hidden_size, int64_t head_dim, int64_t num_heads, int64_t num_kv_heads, bool qk_norm) : head_dim(head_dim), num_heads(num_heads), num_kv_heads(num_kv_heads), qk_norm(qk_norm) { blocks["qkv"] = std::make_shared<Linear>(hidden_size, (num_heads + num_kv_heads * 2) * head_dim, false); - float scale = 1.f; + float scale = 1.f; #if GGML_USE_HIP // Prevent NaN issues with certain ROCm setups scale = 1.f / 16.f; #endif @@ -574,7 +574,7 @@ namespace ZImage { return gf; } - void compute(int n_threads, + bool compute(int n_threads, struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, @@ -589,7 +589,7 @@ namespace ZImage { return build_graph(x, timesteps, context, ref_latents, increase_ref_index); }; - GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); + return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); } void test() {
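
Note on the new calling convention (an illustrative sketch, not part of the diff): after this patch, every runner's compute() reports backend failures through its bool return value instead of aborting, and sampling surfaces failure as a null result. A minimal caller sketch, assuming only the VAE::compute() signature and the LOG_ERROR macro shown above; decode_or_fail is a hypothetical helper name used for illustration:

    // Hypothetical wrapper showing the post-patch contract: compute()
    // returns false on failure (GGMLRunner::compute has already logged
    // the ggml_status detail), so callers unwind instead of crashing.
    static struct ggml_tensor* decode_or_fail(VAE* first_stage_model,
                                              int n_threads,
                                              struct ggml_tensor* z,
                                              struct ggml_context* output_ctx) {
        struct ggml_tensor* result = nullptr;
        if (!first_stage_model->compute(n_threads, z, /*decode_graph=*/true, &result, output_ctx)) {
            LOG_ERROR("vae compute failed");
            return nullptr;  // callers check for null, as generate_image_internal now does for x_0
        }
        return result;
    }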