From 60abda56e07455f1d008b221de0f4ef4c4136741 Mon Sep 17 00:00:00 2001 From: stduhpf Date: Sun, 21 Dec 2025 08:35:38 +0100 Subject: [PATCH] feat: select vulkan device with env variable (#629) --- stable-diffusion.cpp | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 17fe3fd..7451993 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -165,7 +165,27 @@ public: #endif #ifdef SD_USE_VULKAN LOG_DEBUG("Using Vulkan backend"); - for (int device = 0; device < ggml_backend_vk_get_device_count(); ++device) { + size_t device = 0; + const int device_count = ggml_backend_vk_get_device_count(); + if (device_count) { + const char* SD_VK_DEVICE = getenv("SD_VK_DEVICE"); + if (SD_VK_DEVICE != nullptr) { + std::string sd_vk_device_str = SD_VK_DEVICE; + try { + device = std::stoull(sd_vk_device_str); + } catch (const std::invalid_argument&) { + LOG_WARN("SD_VK_DEVICE environment variable is not a valid integer (%s). Falling back to device 0.", SD_VK_DEVICE); + device = 0; + } catch (const std::out_of_range&) { + LOG_WARN("SD_VK_DEVICE environment variable value is out of range for `unsigned long long` type (%s). Falling back to device 0.", SD_VK_DEVICE); + device = 0; + } + if (device >= device_count) { + LOG_WARN("Cannot find targeted vulkan device (%llu). Falling back to device 0.", device); + device = 0; + } + } + LOG_INFO("Vulkan: Using device %llu", device); backend = ggml_backend_vk_init(device); } if (!backend) {