From bb90bfa00f858c7df6502e75f31c4440d4d11fde Mon Sep 17 00:00:00 2001 From: leejet Date: Sun, 14 Jun 2026 22:46:32 +0800 Subject: [PATCH] feat: support backend-specific max-vram budgets --- docs/backend.md | 8 ++ examples/common/common.cpp | 15 ++- examples/common/common.h | 2 +- include/stable-diffusion.h | 2 +- src/core/ggml_extend_backend.cpp | 10 +- src/core/ggml_extend_backend.h | 1 + src/core/ggml_graph_cut.cpp | 154 +++++++++++++++++++++++++++++++ src/core/ggml_graph_cut.h | 12 +++ src/stable-diffusion.cpp | 58 +++++++----- 9 files changed, 223 insertions(+), 39 deletions(-) diff --git a/docs/backend.md b/docs/backend.md index 2e312236..29ac8031 100644 --- a/docs/backend.md +++ b/docs/backend.md @@ -35,6 +35,14 @@ sd-cli -m model.safetensors -p "a cat" --backend cuda0 --params-backend te=cpu,v sd-cli -m model.safetensors -p "a cat" --backend cuda0 --params-backend disk ``` +`--max-vram` can target resolved backend/device names: + +```shell +sd-cli -m model.safetensors -p "a cat" --backend diffusion=cuda0,vae=vulkan0 --max-vram cuda0=6,vulkan0=2 +``` + +The budget applies to every module running on that backend. + Module names are case-insensitive. Hyphens and underscores in module names are ignored, so `clip_vision`, `clip-vision`, and `clipvision` are equivalent. `all=`, `default=`, and `*=` can be used to set the default backend inside a mixed assignment: diff --git a/examples/common/common.cpp b/examples/common/common.cpp index 03a35c9a..e9b8bc85 100644 --- a/examples/common/common.cpp +++ b/examples/common/common.cpp @@ -431,6 +431,10 @@ ArgOptions SDContextParams::get_options() { "--rpc-servers", "comma-separated list of RPC servers to connect to for offloading, in the format host:port, e.g. localhost:50052,192.168.1.3:50052", &rpc_servers}, + {"", + "--max-vram", + "maximum VRAM budget in GiB for graph-cut segmented execution. Accepts a single value or assignments by backend/device, e.g. 6 or cuda0=6,vulkan0=4. 0 disables graph splitting; a negative value auto-detects free VRAM, sparing the specified value", + &max_vram}, }; options.int_options = { @@ -445,13 +449,6 @@ ArgOptions SDContextParams::get_options() { &chroma_t5_mask_pad}, }; - options.float_options = { - {"", - "--max-vram", - "maximum VRAM budget in GiB for graph-cut segmented execution. 0 disables graph splitting; a negative value auto-detects free VRAM, sparing the specified value (e.g. -0.5 will keep at least 0.5 GiB free)", - &max_vram}, - }; - options.bool_options = { {"", "--stream-layers", @@ -758,7 +755,7 @@ std::string SDContextParams::to_string() const { << " rng_type: " << sd_rng_type_name(rng_type) << ",\n" << " sampler_rng_type: " << sd_rng_type_name(sampler_rng_type) << ",\n" << " offload_params_to_cpu: " << (offload_params_to_cpu ? "true" : "false") << ",\n" - << " max_vram: " << max_vram << ",\n" + << " max_vram: \"" << max_vram << "\",\n" << " stream_layers: " << (stream_layers ? "true" : "false") << ",\n" << " backend: \"" << backend << "\",\n" << " params_backend: \"" << params_backend << "\",\n" @@ -836,7 +833,7 @@ sd_ctx_params_t SDContextParams::to_sd_ctx_params_t(bool taesd_preview) { sd_ctx_params.chroma_t5_mask_pad = chroma_t5_mask_pad; sd_ctx_params.qwen_image_zero_cond_t = qwen_image_zero_cond_t; sd_ctx_params.vae_format = str_to_vae_format(vae_format); - sd_ctx_params.max_vram = max_vram; + sd_ctx_params.max_vram = max_vram.c_str(); sd_ctx_params.stream_layers = stream_layers; sd_ctx_params.backend = effective_backend.c_str(); sd_ctx_params.params_backend = effective_params_backend.c_str(); diff --git a/examples/common/common.h b/examples/common/common.h index 86fcc162..55fa5ac0 100644 --- a/examples/common/common.h +++ b/examples/common/common.h @@ -144,7 +144,7 @@ struct SDContextParams { rng_type_t rng_type = CUDA_RNG; rng_type_t sampler_rng_type = RNG_TYPE_COUNT; bool offload_params_to_cpu = false; - float max_vram = 0.f; + std::string max_vram = "0"; bool stream_layers = false; std::string backend; std::string params_backend; diff --git a/include/stable-diffusion.h b/include/stable-diffusion.h index 674c9d63..00f3e4e9 100644 --- a/include/stable-diffusion.h +++ b/include/stable-diffusion.h @@ -216,7 +216,7 @@ typedef struct { int chroma_t5_mask_pad; bool qwen_image_zero_cond_t; enum sd_vae_format_t vae_format; - float max_vram; // GiB budget for graph-cut segmented param offload (0 = disabled, -1 = auto free VRAM minus 1 GiB) + const char* max_vram; // GiB budget or backend assignment spec for graph-cut segmented param offload (0 = disabled, -1 = auto) bool stream_layers; // Enable residency+prefetch streaming on top of --max-vram (no effect without --max-vram) const char* backend; const char* params_backend; diff --git a/src/core/ggml_extend_backend.cpp b/src/core/ggml_extend_backend.cpp index d8062fef..f3e2cceb 100644 --- a/src/core/ggml_extend_backend.cpp +++ b/src/core/ggml_extend_backend.cpp @@ -280,7 +280,7 @@ static std::string get_default_backend_name() { return resolve_first_device_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); } -static std::string sd_resolve_backend_name(const std::string& name) { +std::string sd_backend_resolve_name(const std::string& name) { ggml_backend_load_all_once(); std::string requested = trim_copy(name); std::string lower = lower_copy(requested); @@ -318,7 +318,7 @@ static std::string sd_resolve_backend_name(const std::string& name) { } static bool backend_name_exists(const std::string& name) { - return !sd_resolve_backend_name(name).empty(); + return !sd_backend_resolve_name(name).empty(); } static ggml_backend_t init_named_backend(const std::string& name) { @@ -328,7 +328,7 @@ static ggml_backend_t init_named_backend(const std::string& name) { return ggml_backend_init_best(); } - std::string resolved = sd_resolve_backend_name(name); + std::string resolved = sd_backend_resolve_name(name); if (resolved.empty()) { return nullptr; } @@ -599,7 +599,7 @@ bool SDBackendManager::validate(std::string* error) const { } return false; } - if (!sd_resolve_backend_name(name).empty()) { + if (!sd_backend_resolve_name(name).empty()) { return true; } if (error != nullptr) { @@ -632,7 +632,7 @@ bool SDBackendManager::validate(std::string* error) const { } ggml_backend_t SDBackendManager::init_cached_backend(const std::string& name) { - std::string resolved = sd_resolve_backend_name(name); + std::string resolved = sd_backend_resolve_name(name); std::string key = lower_copy(resolved); ggml_backend_t backend = nullptr; diff --git a/src/core/ggml_extend_backend.h b/src/core/ggml_extend_backend.h index 4db54325..9aecf97c 100644 --- a/src/core/ggml_extend_backend.h +++ b/src/core/ggml_extend_backend.h @@ -71,6 +71,7 @@ bool sd_backend_is(ggml_backend_t backend, const std::string& name); bool sd_backend_is_cpu(ggml_backend_t backend); ggml_backend_t sd_backend_cpu_init(); bool sd_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads); +std::string sd_backend_resolve_name(const std::string& name); const char* sd_backend_module_name(SDBackendModule module); void ggml_ext_im_set_f32_1d(const struct ggml_tensor* tensor, int i, float value); bool add_rpc_devices(const std::string& servers); diff --git a/src/core/ggml_graph_cut.cpp b/src/core/ggml_graph_cut.cpp index 08312aab..d4874b05 100644 --- a/src/core/ggml_graph_cut.cpp +++ b/src/core/ggml_graph_cut.cpp @@ -1,6 +1,8 @@ #include "core/ggml_graph_cut.h" #include +#include +#include #include #include #include @@ -8,6 +10,7 @@ #include #include +#include "core/ggml_extend_backend.h" #include "core/util.h" #include "ggml-alloc.h" #include "ggml-backend.h" @@ -83,6 +86,157 @@ namespace sd::ggml_graph_cut { segment.output_bytes; } + static std::string lower_ascii_copy(std::string value) { + std::transform(value.begin(), value.end(), value.begin(), [](unsigned char c) { + return static_cast(std::tolower(c)); + }); + return value; + } + + static std::string normalize_backend_budget_key(const std::string& value) { + return lower_ascii_copy(trim(value)); + } + + static bool is_default_max_vram_key(const std::string& key) { + std::string normalized = normalize_backend_budget_key(key); + return normalized == "all" || normalized == "default" || normalized == "*"; + } + + static bool parse_max_vram_budget_value(const std::string& text, float* value, std::string* error) { + float parsed = 0.f; + if (!parse_strict_float(text, parsed) || !std::isfinite(parsed)) { + if (error != nullptr) { + *error = "invalid --max-vram value '" + text + "'"; + } + return false; + } + *value = parsed; + return true; + } + + static std::vector backend_budget_keys(ggml_backend_t backend) { + std::vector keys; + if (backend == nullptr) { + return keys; + } + + ggml_backend_dev_t dev = ggml_backend_get_device(backend); + if (dev != nullptr) { + keys.push_back(normalize_backend_budget_key(ggml_backend_dev_name(dev))); + } + const char* backend_name = ggml_backend_name(backend); + if (backend_name != nullptr) { + keys.push_back(normalize_backend_budget_key(backend_name)); + } + return keys; + } + + void MaxVramAssignment::reset(float fallback_gib) { + default_gib = fallback_gib; + backend_gib.clear(); + resolved_backend_bytes.clear(); + } + + bool MaxVramAssignment::parse(const std::string& raw_spec, std::string* error) { + const std::string in = trim(raw_spec); + if (in.empty()) { + return true; + } + + for (const std::string& raw_part : split_string(in, ',')) { + const std::string part = trim(raw_part); + if (part.empty()) { + continue; + } + + const size_t eq = part.find('='); + if (eq == std::string::npos) { + float value = 0.f; + if (!parse_max_vram_budget_value(part, &value, error)) { + return false; + } + default_gib = value; + continue; + } + + const std::string key = trim(part.substr(0, eq)); + const std::string value_text = trim(part.substr(eq + 1)); + if (key.empty() || value_text.empty()) { + if (error != nullptr) { + *error = "invalid --max-vram assignment '" + part + "'"; + } + return false; + } + + float value = 0.f; + if (!parse_max_vram_budget_value(value_text, &value, error)) { + return false; + } + + if (is_default_max_vram_key(key)) { + default_gib = value; + continue; + } + + const std::string backend_key = trim(key); + if (backend_key.empty()) { + if (error != nullptr) { + *error = "invalid --max-vram backend key in '" + part + "'"; + } + return false; + } + backend_gib[backend_key] = value; + } + resolved_backend_bytes.clear(); + return true; + } + + bool MaxVramAssignment::canonicalize_backend_keys(std::string* error) { + if (backend_gib.empty()) { + return true; + } + + std::unordered_map normalized; + for (const auto& kv : backend_gib) { + std::string resolved = sd_backend_resolve_name(kv.first); + if (resolved.empty()) { + if (error != nullptr) { + *error = "unknown --max-vram backend '" + kv.first + "'"; + } + return false; + } + normalized[normalize_backend_budget_key(resolved)] = kv.second; + } + backend_gib = std::move(normalized); + resolved_backend_bytes.clear(); + return true; + } + + size_t MaxVramAssignment::bytes_for_backend(ggml_backend_t backend) { + std::vector keys = backend_budget_keys(backend); + const std::string cache_key = keys.empty() ? std::string("") : keys.front(); + auto cached = resolved_backend_bytes.find(cache_key); + if (cached != resolved_backend_bytes.end()) { + return cached->second; + } + + float budget_gib = default_gib; + if (!backend_gib.empty()) { + for (const std::string& key : keys) { + auto backend_it = backend_gib.find(key); + if (backend_it != backend_gib.end()) { + budget_gib = backend_it->second; + break; + } + } + } + + const float resolved_gib = resolve_max_vram_gib(budget_gib, backend); + const size_t bytes = max_vram_gib_to_bytes(resolved_gib); + resolved_backend_bytes[cache_key] = bytes; + return bytes; + } + size_t max_vram_gib_to_bytes(float max_vram) { if (max_vram <= 0.f) { return 0; diff --git a/src/core/ggml_graph_cut.h b/src/core/ggml_graph_cut.h index 01e9b3ad..17f2f1d7 100644 --- a/src/core/ggml_graph_cut.h +++ b/src/core/ggml_graph_cut.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -68,6 +69,17 @@ namespace sd::ggml_graph_cut { static constexpr const char* GGML_RUNNER_CUT_PREFIX = "ggml_runner_cut:"; + struct MaxVramAssignment { + float default_gib = 0.f; + std::unordered_map backend_gib; + std::unordered_map resolved_backend_bytes; + + void reset(float fallback_gib); + bool parse(const std::string& raw_spec, std::string* error); + bool canonicalize_backend_keys(std::string* error); + size_t bytes_for_backend(ggml_backend_t backend); + }; + bool is_graph_cut_tensor(const ggml_tensor* tensor); std::string make_graph_cut_name(const std::string& group, const std::string& output); void mark_graph_cut(ggml_tensor* tensor, const std::string& group, const std::string& output); diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index c74d7363..836b0f85 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include "core/ggml_extend.hpp" #include "core/ggml_graph_cut.h" @@ -188,8 +189,8 @@ public: std::string taesd_path; sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0, 0, nullptr}; bool enable_mmap = false; - float max_vram = 0.f; - bool stream_layers = false; + sd::ggml_graph_cut::MaxVramAssignment max_vram_assignment; + bool stream_layers = false; std::string backend_spec; std::string params_backend_spec; @@ -221,6 +222,10 @@ public: return module_backend; } + size_t max_graph_vram_bytes_for_module(SDBackendModule module) { + return max_vram_assignment.bytes_for_backend(backend_for(module)); + } + bool ensure_backend_pair(SDBackendModule module) { if (backend_for(module) == nullptr) { return false; @@ -314,19 +319,21 @@ public: bool init(const sd_ctx_params_t* sd_ctx_params) { n_threads = sd_ctx_params->n_threads; enable_mmap = sd_ctx_params->enable_mmap; - max_vram = sd_ctx_params->max_vram; stream_layers = sd_ctx_params->stream_layers; backend_spec = SAFE_STR(sd_ctx_params->backend); params_backend_spec = SAFE_STR(sd_ctx_params->params_backend); + max_vram_assignment.reset(0.f); + { + std::string error; + if (!max_vram_assignment.parse(SAFE_STR(sd_ctx_params->max_vram), &error)) { + LOG_ERROR("%s", error.c_str()); + return false; + } + } std::string rpc_servers_spec = SAFE_STR(sd_ctx_params->rpc_servers); add_rpc_devices(rpc_servers_spec); - if (stream_layers && max_vram == 0.f) { - LOG_WARN("--stream-layers has no effect without --max-vram set; ignoring"); - stream_layers = false; - } - bool use_tae = false; bool use_audio_vae = false; bool use_control_net = false; @@ -343,11 +350,17 @@ public: if (!init_backend()) { return false; } + { + std::string error; + if (!max_vram_assignment.canonicalize_backend_keys(&error)) { + LOG_ERROR("%s", error.c_str()); + return false; + } + } if (stream_layers && !backend_manager.params_backend_is_cpu(SDBackendModule::DIFFUSION)) { LOG_WARN("--stream-layers has no effect unless diffusion params backend is cpu; ignoring"); stream_layers = false; } - max_vram = sd::ggml_graph_cut::resolve_max_vram_gib(max_vram, backend_for(SDBackendModule::DIFFUSION)); model_manager = std::make_shared(); model_manager->set_n_threads(n_threads); @@ -564,8 +577,6 @@ public: LOG_INFO("Using circular padding for convolutions"); } - const size_t max_graph_vram_bytes = sd::ggml_graph_cut::max_vram_gib_to_bytes(max_vram); - { if (!ensure_backend_pair(SDBackendModule::TE) || !ensure_backend_pair(SDBackendModule::DIFFUSION)) { @@ -687,7 +698,7 @@ public: clip_vision = std::make_shared(backend_for(SDBackendModule::CLIP_VISION), tensor_storage_map, model_manager); - clip_vision->set_max_graph_vram_bytes(max_graph_vram_bytes); + clip_vision->set_max_graph_vram_bytes(max_graph_vram_bytes_for_module(SDBackendModule::CLIP_VISION)); if (!register_runner_params("CLIP vision", clip_vision, SDBackendModule::CLIP_VISION)) { @@ -791,7 +802,7 @@ public: } } - cond_stage_model->set_max_graph_vram_bytes(max_graph_vram_bytes); + cond_stage_model->set_max_graph_vram_bytes(max_graph_vram_bytes_for_module(SDBackendModule::TE)); if (!register_runner_params("Conditioner model", cond_stage_model, SDBackendModule::TE, @@ -799,7 +810,7 @@ public: return false; } - diffusion_model->set_max_graph_vram_bytes(max_graph_vram_bytes); + diffusion_model->set_max_graph_vram_bytes(max_graph_vram_bytes_for_module(SDBackendModule::DIFFUSION)); diffusion_model->set_stream_layers_enabled(stream_layers); if (!register_runner_params("Diffusion model", diffusion_model, @@ -809,7 +820,7 @@ public: } if (high_noise_diffusion_model) { - high_noise_diffusion_model->set_max_graph_vram_bytes(max_graph_vram_bytes); + high_noise_diffusion_model->set_max_graph_vram_bytes(max_graph_vram_bytes_for_module(SDBackendModule::DIFFUSION)); high_noise_diffusion_model->set_stream_layers_enabled(stream_layers); if (!register_runner_params("High noise diffusion model", high_noise_diffusion_model, @@ -908,7 +919,7 @@ public: } else if (use_tae && !tae_preview_only) { LOG_INFO("using TAE for encoding / decoding"); first_stage_model = create_tae(false); - first_stage_model->set_max_graph_vram_bytes(max_graph_vram_bytes); + first_stage_model->set_max_graph_vram_bytes(max_graph_vram_bytes_for_module(SDBackendModule::VAE)); if (!register_runner_params("VAE", first_stage_model, SDBackendModule::VAE, @@ -918,7 +929,7 @@ public: } else { LOG_INFO("using VAE for encoding / decoding"); first_stage_model = create_vae(); - first_stage_model->set_max_graph_vram_bytes(max_graph_vram_bytes); + first_stage_model->set_max_graph_vram_bytes(max_graph_vram_bytes_for_module(SDBackendModule::VAE)); if (!register_runner_params("VAE", first_stage_model, SDBackendModule::VAE, @@ -928,7 +939,7 @@ public: if (use_tae && tae_preview_only) { LOG_INFO("using TAE for preview"); preview_vae = create_tae(true); - preview_vae->set_max_graph_vram_bytes(max_graph_vram_bytes); + preview_vae->set_max_graph_vram_bytes(max_graph_vram_bytes_for_module(SDBackendModule::VAE)); if (!register_runner_params("preview VAE", preview_vae, SDBackendModule::VAE, @@ -2618,7 +2629,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) { sd_ctx_params->sampler_rng_type = RNG_TYPE_COUNT; sd_ctx_params->prediction = PREDICTION_COUNT; sd_ctx_params->lora_apply_mode = LORA_APPLY_AUTO; - sd_ctx_params->max_vram = 0.f; + sd_ctx_params->max_vram = nullptr; sd_ctx_params->stream_layers = false; sd_ctx_params->enable_mmap = false; sd_ctx_params->diffusion_flash_attn = false; @@ -2630,6 +2641,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) { sd_ctx_params->vae_format = SD_VAE_FORMAT_AUTO; sd_ctx_params->backend = nullptr; sd_ctx_params->params_backend = nullptr; + sd_ctx_params->rpc_servers = nullptr; } char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { @@ -2661,7 +2673,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { "rng_type: %s\n" "sampler_rng_type: %s\n" "prediction: %s\n" - "max_vram: %.3f\n" + "max_vram: %s\n" "stream_layers: %s\n" "backend: %s\n" "params_backend: %s\n" @@ -2695,7 +2707,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { sd_rng_type_name(sd_ctx_params->rng_type), sd_rng_type_name(sd_ctx_params->sampler_rng_type), sd_prediction_name(sd_ctx_params->prediction), - sd_ctx_params->max_vram, + SAFE_STR(sd_ctx_params->max_vram), BOOL_STR(sd_ctx_params->stream_layers), SAFE_STR(sd_ctx_params->backend), SAFE_STR(sd_ctx_params->params_backend), @@ -4417,7 +4429,7 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s request.hires.upscale_tile_size, sd_ctx->sd->backend_spec, sd_ctx->sd->params_backend_spec); - const size_t max_graph_vram_bytes = sd::ggml_graph_cut::max_vram_gib_to_bytes(sd_ctx->sd->max_vram); + const size_t max_graph_vram_bytes = sd_ctx->sd->max_graph_vram_bytes_for_module(SDBackendModule::UPSCALER); hires_upscaler->set_max_graph_vram_bytes(max_graph_vram_bytes); if (!hires_upscaler->load_from_file(request.hires.model_path, sd_ctx->sd->n_threads)) { @@ -4966,7 +4978,7 @@ static sd::Tensor upscale_ltx_spatial_video_latent(sd_ctx_t* sd_ctx, std::make_unique(sd_ctx->sd->backend_for(SDBackendModule::UPSCALER), model_loader.get_tensor_storage_map(), upsampler_manager); - const size_t max_graph_vram_bytes = sd::ggml_graph_cut::max_vram_gib_to_bytes(sd_ctx->sd->max_vram); + const size_t max_graph_vram_bytes = sd_ctx->sd->max_graph_vram_bytes_for_module(SDBackendModule::UPSCALER); upsampler->set_max_graph_vram_bytes(max_graph_vram_bytes); if (upsampler->model == nullptr) { LOG_ERROR("init LTX latent upsampler from metadata failed");