feat: support for cancelling generations (#1124)

* feat: support for canceling the ongoing generation

* return partial image batches on cancel

---------

Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
Wagner Bruna 2026-06-15 13:36:38 -03:00 committed by GitHub
parent 146b6cc49e
commit 5a34bc7f6e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 141 additions and 4 deletions

View File

@ -452,6 +452,17 @@ SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params);
SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params); SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params);
SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params); SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params);
enum sd_cancel_mode_t {
// Stop the current generation as soon as possible.
SD_CANCEL_ALL,
// Finish the current image sample, then skip additional batch latents and return completed images.
SD_CANCEL_NEW_LATENTS,
// Clear a pending cancellation request.
SD_CANCEL_RESET
};
SD_API void sd_cancel_generation(sd_ctx_t* sd_ctx, enum sd_cancel_mode_t mode);
SD_API void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params); SD_API void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params);
SD_API bool generate_video(sd_ctx_t* sd_ctx, SD_API bool generate_video(sd_ctx_t* sd_ctx,
const sd_vid_gen_params_t* sd_vid_gen_params, const sd_vid_gen_params_t* sd_vid_gen_params,

View File

@ -53,6 +53,8 @@
const char* sd_vae_format_name(enum sd_vae_format_t format); const char* sd_vae_format_name(enum sd_vae_format_t format);
static SDVersion sd_vae_format_to_version(enum sd_vae_format_t format, SDVersion fallback); static SDVersion sd_vae_format_to_version(enum sd_vae_format_t format, SDVersion fallback);
#include <atomic>
const char* model_version_to_str[] = { const char* model_version_to_str[] = {
"SD 1.x", "SD 1.x",
"SD 1.x Inpaint", "SD 1.x Inpaint",
@ -159,6 +161,9 @@ static float get_cache_reuse_threshold(const sd_cache_params_t& params) {
/*=============================================== StableDiffusionGGML ================================================*/ /*=============================================== StableDiffusionGGML ================================================*/
static_assert(std::atomic<sd_cancel_mode_t>::is_always_lock_free,
"sd_cancel_mode_t must be lock-free");
class StableDiffusionGGML { class StableDiffusionGGML {
public: public:
SDBackendManager backend_manager; SDBackendManager backend_manager;
@ -222,6 +227,20 @@ public:
return module_backend; return module_backend;
} }
std::atomic<sd_cancel_mode_t> cancellation_flag = SD_CANCEL_RESET;
void set_cancel_flag(enum sd_cancel_mode_t flag) {
cancellation_flag.store(flag, std::memory_order_release);
}
void reset_cancel_flag() {
set_cancel_flag(SD_CANCEL_RESET);
}
enum sd_cancel_mode_t get_cancel_flag() {
return cancellation_flag.load(std::memory_order_acquire);
}
size_t max_graph_vram_bytes_for_module(SDBackendModule module) { size_t max_graph_vram_bytes_for_module(SDBackendModule module) {
return max_vram_assignment.bytes_for_backend(backend_for(module)); return max_vram_assignment.bytes_for_backend(backend_for(module));
} }
@ -1941,6 +1960,11 @@ public:
SamplePreviewContext preview = prepare_sample_preview_context(); SamplePreviewContext preview = prepare_sample_preview_context();
auto denoise = [&](const sd::Tensor<float>& x, float sigma, int step) -> sd::guidance::GuiderOutput { auto denoise = [&](const sd::Tensor<float>& x, float sigma, int step) -> sd::guidance::GuiderOutput {
if (get_cancel_flag() == SD_CANCEL_ALL) {
LOG_DEBUG("cancelling generation");
return {};
}
if (step == 1 || step == -1) { if (step == 1 || step == -1) {
pretty_progress(0, (int)steps, 0); pretty_progress(0, (int)steps, 0);
last_progress_us = ggml_time_us(); last_progress_us = ggml_time_us();
@ -2963,6 +2987,15 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) {
free(sd_ctx); free(sd_ctx);
} }
SD_API void sd_cancel_generation(sd_ctx_t* sd_ctx, enum sd_cancel_mode_t mode) {
if (sd_ctx && sd_ctx->sd) {
if (mode < SD_CANCEL_ALL || mode > SD_CANCEL_RESET) {
mode = SD_CANCEL_ALL;
}
sd_ctx->sd->set_cancel_flag(mode);
}
}
static sd_audio_t* waveform_to_sd_audio(const StableDiffusionGGML* sd, static sd_audio_t* waveform_to_sd_audio(const StableDiffusionGGML* sd,
const sd::Tensor<float>& waveform) { const sd::Tensor<float>& waveform) {
if (sd == nullptr || waveform.empty()) { if (sd == nullptr || waveform.empty()) {
@ -4150,15 +4183,29 @@ static std::optional<ImageGenerationEmbeds> prepare_image_generation_embeds(sd_c
static sd_image_t* decode_image_outputs(sd_ctx_t* sd_ctx, static sd_image_t* decode_image_outputs(sd_ctx_t* sd_ctx,
const GenerationRequest& request, const GenerationRequest& request,
const std::vector<sd::Tensor<float>>& final_latents) { const std::vector<sd::Tensor<float>>& final_latents) {
if (final_latents.size() != static_cast<size_t>(request.batch_count)) { if (final_latents.empty()) {
LOG_ERROR("expected %d latents, got %zu", request.batch_count, final_latents.size()); LOG_ERROR("no latent images to decode");
return nullptr; return nullptr;
} }
LOG_INFO("decoding %zu latents", final_latents.size()); if (final_latents.size() > static_cast<size_t>(request.batch_count)) {
LOG_ERROR("expected at most %d latents, got %zu", request.batch_count, final_latents.size());
return nullptr;
}
if (final_latents.size() < static_cast<size_t>(request.batch_count)) {
LOG_INFO("decoding %zu/%d latents", final_latents.size(), request.batch_count);
} else {
LOG_INFO("decoding %zu latents", final_latents.size());
}
std::vector<sd::Tensor<float>> decoded_images; std::vector<sd::Tensor<float>> decoded_images;
int64_t t0 = ggml_time_ms(); int64_t t0 = ggml_time_ms();
bool cancelled = false;
for (size_t i = 0; i < final_latents.size(); i++) { for (size_t i = 0; i < final_latents.size(); i++) {
if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling latent decodings");
cancelled = true;
break;
}
int64_t t1 = ggml_time_ms(); int64_t t1 = ggml_time_ms();
sd::Tensor<float> image = sd_ctx->sd->decode_first_stage(final_latents[i]); sd::Tensor<float> image = sd_ctx->sd->decode_first_stage(final_latents[i]);
if (image.empty()) { if (image.empty()) {
@ -4172,6 +4219,10 @@ static sd_image_t* decode_image_outputs(sd_ctx_t* sd_ctx,
int64_t t4 = ggml_time_ms(); int64_t t4 = ggml_time_ms();
LOG_INFO("decode_first_stage completed, taking %.2fs", (t4 - t0) * 1.0f / 1000); LOG_INFO("decode_first_stage completed, taking %.2fs", (t4 - t0) * 1.0f / 1000);
if (decoded_images.empty()) {
LOG_ERROR(cancelled ? "cancelled before any latent images were decoded" : "no decoded images");
return nullptr;
}
sd_image_t* result_images = (sd_image_t*)calloc(request.batch_count, sizeof(sd_image_t)); sd_image_t* result_images = (sd_image_t*)calloc(request.batch_count, sizeof(sd_image_t));
if (result_images == nullptr) { if (result_images == nullptr) {
@ -4190,6 +4241,11 @@ static sd::Tensor<float> upscale_hires_latent(sd_ctx_t* sd_ctx,
const sd::Tensor<float>& latent, const sd::Tensor<float>& latent,
const GenerationRequest& request, const GenerationRequest& request,
UpscalerGGML* upscaler) { UpscalerGGML* upscaler) {
if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling hires latent upscale");
return {};
}
auto get_hires_latent_target_shape = [&]() { auto get_hires_latent_target_shape = [&]() {
std::vector<int64_t> target_shape = latent.shape(); std::vector<int64_t> target_shape = latent.shape();
if (target_shape.size() < 2) { if (target_shape.size() < 2) {
@ -4262,6 +4318,10 @@ static sd::Tensor<float> upscale_hires_latent(sd_ctx_t* sd_ctx,
sd_hires_upscaler_name(request.hires.upscaler)); sd_hires_upscaler_name(request.hires.upscaler));
return {}; return {};
} }
if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling hires image upscale");
return {};
}
sd::Tensor<float> upscaled_tensor; sd::Tensor<float> upscaled_tensor;
if (request.hires.upscaler == SD_HIRES_UPSCALER_MODEL) { if (request.hires.upscaler == SD_HIRES_UPSCALER_MODEL) {
@ -4298,6 +4358,10 @@ static sd::Tensor<float> upscale_hires_latent(sd_ctx_t* sd_ctx,
upscaled_tensor = sd::ops::clamp(upscaled_tensor, 0.0f, 1.0f); upscaled_tensor = sd::ops::clamp(upscaled_tensor, 0.0f, 1.0f);
} }
if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling hires latent encode");
return {};
}
sd::Tensor<float> upscaled_latent = sd_ctx->sd->encode_first_stage(upscaled_tensor); sd::Tensor<float> upscaled_latent = sd_ctx->sd->encode_first_stage(upscaled_tensor);
if (upscaled_latent.empty()) { if (upscaled_latent.empty()) {
LOG_ERROR("encode_first_stage failed after hires %s upscale", LOG_ERROR("encode_first_stage failed after hires %s upscale",
@ -4362,6 +4426,8 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s
return nullptr; return nullptr;
} }
sd_ctx->sd->reset_cancel_flag();
int64_t t0 = ggml_time_ms(); int64_t t0 = ggml_time_ms();
sd_ctx->sd->vae_tiling_params = sd_img_gen_params->vae_tiling_params; sd_ctx->sd->vae_tiling_params = sd_img_gen_params->vae_tiling_params;
GenerationRequest request(sd_ctx, sd_img_gen_params); GenerationRequest request(sd_ctx, sd_img_gen_params);
@ -4397,6 +4463,18 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s
std::vector<sd::Tensor<float>> final_latents; std::vector<sd::Tensor<float>> final_latents;
int64_t denoise_start = ggml_time_ms(); int64_t denoise_start = ggml_time_ms();
for (int b = 0; b < request.batch_count; b++) { for (int b = 0; b < request.batch_count; b++) {
sd_cancel_mode_t cancel = sd_ctx->sd->get_cancel_flag();
if (cancel == SD_CANCEL_ALL) {
LOG_ERROR("cancelling generation");
return nullptr;
}
if (cancel == SD_CANCEL_NEW_LATENTS) {
LOG_INFO("cancelling new latent generation, returning %zu/%d completed latents",
final_latents.size(),
request.batch_count);
break;
}
int64_t sampling_start = ggml_time_ms(); int64_t sampling_start = ggml_time_ms();
int64_t cur_seed = request.seed + b; int64_t cur_seed = request.seed + b;
LOG_INFO("generating image: %i/%i - seed %" PRId64, b + 1, request.batch_count, cur_seed); LOG_INFO("generating image: %i/%i - seed %" PRId64, b + 1, request.batch_count, cur_seed);
@ -4446,12 +4524,24 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s
LOG_INFO("generating %zu latent images completed, taking %.2fs", LOG_INFO("generating %zu latent images completed, taking %.2fs",
final_latents.size(), final_latents.size(),
(denoise_end - denoise_start) * 1.0f / 1000); (denoise_end - denoise_start) * 1.0f / 1000);
if (final_latents.empty()) {
LOG_ERROR("no latent images generated");
return nullptr;
}
if (request.hires.enabled && request.hires.target_width > 0) { if (request.hires.enabled && request.hires.target_width > 0) {
if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling generation before hires fix");
return nullptr;
}
LOG_INFO("hires fix: upscaling to %dx%d", request.hires.target_width, request.hires.target_height); LOG_INFO("hires fix: upscaling to %dx%d", request.hires.target_width, request.hires.target_height);
std::unique_ptr<UpscalerGGML> hires_upscaler; std::unique_ptr<UpscalerGGML> hires_upscaler;
if (request.hires.upscaler == SD_HIRES_UPSCALER_MODEL) { if (request.hires.upscaler == SD_HIRES_UPSCALER_MODEL) {
if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling generation before hires model load");
return nullptr;
}
LOG_INFO("hires fix: loading model upscaler from '%s'", request.hires.model_path); LOG_INFO("hires fix: loading model upscaler from '%s'", request.hires.model_path);
hires_upscaler = std::make_unique<UpscalerGGML>(sd_ctx->sd->n_threads, hires_upscaler = std::make_unique<UpscalerGGML>(sd_ctx->sd->n_threads,
false, false,
@ -4485,6 +4575,10 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s
std::vector<sd::Tensor<float>> hires_final_latents; std::vector<sd::Tensor<float>> hires_final_latents;
int64_t hires_denoise_start = ggml_time_ms(); int64_t hires_denoise_start = ggml_time_ms();
for (int b = 0; b < (int)final_latents.size(); b++) { for (int b = 0; b < (int)final_latents.size(); b++) {
if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling generation during hires fix");
return nullptr;
}
int64_t cur_seed = request.seed + b; int64_t cur_seed = request.seed + b;
sd_ctx->sd->rng->manual_seed(cur_seed); sd_ctx->sd->rng->manual_seed(cur_seed);
sd_ctx->sd->sampler_rng->manual_seed(cur_seed); sd_ctx->sd->sampler_rng->manual_seed(cur_seed);
@ -4915,6 +5009,10 @@ static sd_image_t* decode_video_outputs(sd_ctx_t* sd_ctx,
LOG_ERROR("no latent video to decode"); LOG_ERROR("no latent video to decode");
return nullptr; return nullptr;
} }
if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling video decode");
return nullptr;
}
sd::Tensor<float> video_latent = final_latent; sd::Tensor<float> video_latent = final_latent;
if (sd_version_is_ltxav(sd_ctx->sd->version) && if (sd_version_is_ltxav(sd_ctx->sd->version) &&
video_latent.shape()[3] > sd_ctx->sd->get_latent_channel()) { video_latent.shape()[3] > sd_ctx->sd->get_latent_channel()) {
@ -5160,6 +5258,9 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
if (audio_out != nullptr) { if (audio_out != nullptr) {
*audio_out = nullptr; *audio_out = nullptr;
} }
sd_ctx->sd->reset_cancel_flag();
if (num_frames_out != nullptr) { if (num_frames_out != nullptr) {
*num_frames_out = 0; *num_frames_out = 0;
} }
@ -5221,6 +5322,10 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
sd::Tensor<float> noise = sd::Tensor<float>::randn_like(x_t, sd_ctx->sd->rng); sd::Tensor<float> noise = sd::Tensor<float>::randn_like(x_t, sd_ctx->sd->rng);
if (plan.high_noise_sample_steps > 0) { if (plan.high_noise_sample_steps > 0) {
if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling generation before high-noise sampling");
return false;
}
LOG_DEBUG("sample(high noise) %dx%dx%d", W, H, T); LOG_DEBUG("sample(high noise) %dx%dx%d", W, H, T);
int64_t sampling_start = ggml_time_ms(); int64_t sampling_start = ggml_time_ms();
@ -5263,6 +5368,10 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
LOG_INFO("sampling(high noise) completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000); LOG_INFO("sampling(high noise) completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
} }
if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling generation before sampling");
return false;
}
LOG_DEBUG("sample %dx%dx%d", W, H, T); LOG_DEBUG("sample %dx%dx%d", W, H, T);
int64_t sampling_start = ggml_time_ms(); int64_t sampling_start = ggml_time_ms();
sd::Tensor<float> final_latent = sd_ctx->sd->sample(sd_ctx->sd->diffusion_model, sd::Tensor<float> final_latent = sd_ctx->sd->sample(sd_ctx->sd->diffusion_model,
@ -5299,6 +5408,10 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000); LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
if (latent_upscale_enabled) { if (latent_upscale_enabled) {
if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling generation before latent upscale");
return false;
}
int64_t upscale_start = ggml_time_ms(); int64_t upscale_start = ggml_time_ms();
sd::Tensor<float> upscaled_latent = upscale_ltx_spatial_video_latent(sd_ctx, sd::Tensor<float> upscaled_latent = upscale_ltx_spatial_video_latent(sd_ctx,
request.hires.model_path, request.hires.model_path,
@ -5358,6 +5471,10 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
} }
sd::Tensor<float> hires_denoise_mask; sd::Tensor<float> hires_denoise_mask;
sd::Tensor<float> hires_video_positions; sd::Tensor<float> hires_video_positions;
if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling generation before latent upscale refine");
return false;
}
if (!apply_ltxv_refine_image_conditioning(sd_ctx, if (!apply_ltxv_refine_image_conditioning(sd_ctx,
sd_vid_gen_params, sd_vid_gen_params,
hires_request, hires_request,
@ -5437,6 +5554,10 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
if (sd_version_is_ltxav(sd_ctx->sd->version) && if (sd_version_is_ltxav(sd_ctx->sd->version) &&
latents.audio_length > 0 && latents.audio_length > 0 &&
sd_ctx->sd->audio_vae_model != nullptr) { sd_ctx->sd->audio_vae_model != nullptr) {
if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling generation before audio decode");
return false;
}
int64_t audio_latent_decode_start = ggml_time_ms(); int64_t audio_latent_decode_start = ggml_time_ms();
auto audio_latent = unpack_ltxav_audio_latent(final_latent, auto audio_latent = unpack_ltxav_audio_latent(final_latent,
@ -5469,6 +5590,11 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
final_latent = sd::ops::slice(final_latent, 2, latents.ref_image_num, final_latent.shape()[2]); final_latent = sd::ops::slice(final_latent, 2, latents.ref_image_num, final_latent.shape()[2]);
} }
if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling generation before video decode");
free_sd_audio(generated_audio);
return false;
}
auto result = decode_video_outputs(sd_ctx, latent_upscale_enabled ? hires_request : request, final_latent, num_frames_out); auto result = decode_video_outputs(sd_ctx, latent_upscale_enabled ? hires_request : request, final_latent, num_frames_out);
if (result == nullptr) { if (result == nullptr) {
free_sd_audio(generated_audio); free_sd_audio(generated_audio);