Added check for additional cuda-paths when multiple versions are installed

This commit is contained in:
Darth Affe 2024-03-22 21:33:37 +01:00
parent aa01b5687b
commit 4b8faf24d0

View File

@ -2,10 +2,13 @@
using System.Runtime.InteropServices;
using System.Text.Json;
using System;
using System.Collections;
using System.Linq;
using System.Text.RegularExpressions;
namespace StableDiffusion.NET;
public class CudaBackend : IBackend
public partial class CudaBackend : IBackend
{
#region Constants
@ -55,7 +58,22 @@ public class CudaBackend : IBackend
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
cudaPath = Environment.GetEnvironmentVariable("CUDA_PATH");
if (cudaPath is null) return -1;
if (cudaPath == null)
{
IDictionary environmentVariables = Environment.GetEnvironmentVariables();
string? key = environmentVariables.Keys.Cast<string>().Where(x => x.StartsWith("CUDA_PATH_", StringComparison.OrdinalIgnoreCase))
.Select(x => (x, CudaPathRegex().Match(x)))
.Where(x => x.Item2.Success)
.Select(x => (x.x, x.Item2.Groups["majorVersion"].Value))
.OrderByDescending(x => int.Parse(x.Value))
.FirstOrDefault()
.x;
if (key != null)
cudaPath = Environment.GetEnvironmentVariable(key);
}
if (cudaPath == null) return -1;
version = GetCudaVersionFromPath(cudaPath);
}
@ -110,5 +128,8 @@ public class CudaBackend : IBackend
}
}
[GeneratedRegex(@"CUDA_PATH_V?(?<majorVersion>\d+)_?\d*", RegexOptions.IgnoreCase)]
private static partial Regex CudaPathRegex();
#endregion
}