diff --git a/StableDiffusion.NET/Backends/CudaBackend.cs b/StableDiffusion.NET/Backends/CudaBackend.cs index e96d7d2..68edde2 100644 --- a/StableDiffusion.NET/Backends/CudaBackend.cs +++ b/StableDiffusion.NET/Backends/CudaBackend.cs @@ -2,10 +2,13 @@ using System.Runtime.InteropServices; using System.Text.Json; using System; +using System.Collections; +using System.Linq; +using System.Text.RegularExpressions; namespace StableDiffusion.NET; -public class CudaBackend : IBackend +public partial class CudaBackend : IBackend { #region Constants @@ -55,7 +58,22 @@ public class CudaBackend : IBackend if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) { cudaPath = Environment.GetEnvironmentVariable("CUDA_PATH"); - if (cudaPath is null) return -1; + + if (cudaPath == null) + { + IDictionary environmentVariables = Environment.GetEnvironmentVariables(); + string? key = environmentVariables.Keys.Cast().Where(x => x.StartsWith("CUDA_PATH_", StringComparison.OrdinalIgnoreCase)) + .Select(x => (x, CudaPathRegex().Match(x))) + .Where(x => x.Item2.Success) + .Select(x => (x.x, x.Item2.Groups["majorVersion"].Value)) + .OrderByDescending(x => int.Parse(x.Value)) + .FirstOrDefault() + .x; + if (key != null) + cudaPath = Environment.GetEnvironmentVariable(key); + } + + if (cudaPath == null) return -1; version = GetCudaVersionFromPath(cudaPath); } @@ -110,5 +128,8 @@ public class CudaBackend : IBackend } } + [GeneratedRegex(@"CUDA_PATH_V?(?\d+)_?\d*", RegexOptions.IgnoreCase)] + private static partial Regex CudaPathRegex(); + #endregion } \ No newline at end of file