mirror of
https://github.com/DarthAffe/StableDiffusion.NET.git
synced 2025-12-12 21:38:45 +00:00
137 lines
4.3 KiB
C#
137 lines
4.3 KiB
C#
using System.IO;
|
|
using System.Runtime.InteropServices;
|
|
using System.Text.Json;
|
|
using System;
|
|
using System.Collections;
|
|
using System.Linq;
|
|
using System.Text.RegularExpressions;
|
|
using JetBrains.Annotations;
|
|
|
|
namespace StableDiffusion.NET;
|
|
|
|
/// <summary>
/// <see cref="IBackend"/> implementation describing the CUDA backend.
/// The installed CUDA major version is detected once, when the instance is constructed,
/// by locating and parsing the <c>version.json</c> file of a CUDA installation.
/// </summary>
[PublicAPI]
public partial class CudaBackend : IBackend
{
    #region Constants

    // Name of the JSON file inside a CUDA installation directory that is parsed for version information.
    private const string CUDA_VERSION_FILE = "version.json";

    #endregion

    #region Properties & Fields

    /// <summary>
    /// Gets or sets whether this backend is enabled. Defaults to <c>true</c>.
    /// </summary>
    public bool IsEnabled { get; set; } = true;

    /// <summary>
    /// Gets or sets the priority of this backend. Defaults to <c>10</c>.
    /// </summary>
    public int Priority { get; set; } = 10;

    /// <summary>
    /// Gets whether this backend can be used on the current machine:
    /// the OS must be Windows or Linux, the OS architecture must be x64
    /// and the detected <see cref="CudaVersion"/> must be 12.
    /// </summary>
    public bool IsAvailable => (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)
                             || RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
                            && (RuntimeInformation.OSArchitecture == Architecture.X64)
                            && CudaVersion is 12;

    /// <summary>
    /// Gets the path segment associated with the detected CUDA version:
    /// <c>"cuda11"</c> or <c>"cuda12"</c>, or an empty string for any other version.
    /// </summary>
    public string PathPart => CudaVersion switch
    {
        11 => "cuda11",
        12 => "cuda12",
        _ => string.Empty
    };

    /// <summary>
    /// Gets the CUDA major version detected at construction time, or <c>-1</c> if none could be determined.
    /// </summary>
    public int CudaVersion { get; }

    #endregion

    #region Constructors

    /// <summary>
    /// Creates the backend and detects the installed CUDA major version.
    /// </summary>
    internal CudaBackend()
    {
        CudaVersion = GetCudaMajorVersion();
    }

    #endregion

    #region Methods

    /// <summary>
    /// Detects the major version of the installed CUDA toolkit.
    /// </summary>
    /// <returns>
    /// The major version (e.g. <c>12</c>), or <c>-1</c> if the platform is neither Windows nor Linux,
    /// no installation could be located, or the version could not be parsed.
    /// </returns>
    private static int GetCudaMajorVersion()
    {
        try
        {
            string version;
            if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
                version = GetWindowsCudaVersion();
            else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
                version = GetLinuxCudaVersion();
            else
                return -1;

            if (string.IsNullOrEmpty(version))
                return -1;

            // Only the part before the first '.' (the major version) is of interest.
            if (int.TryParse(version.Split('.')[0], out int majorVersion))
                return majorVersion;
        }
        catch { /* Detection is best-effort: any failure is treated as "no CUDA version found". */ }

        return -1;
    }

    /// <summary>
    /// Reads the CUDA version string on Windows, using the <c>CUDA_PATH</c> environment variable or,
    /// if that is not set, the <c>CUDA_PATH_*</c> variable with the highest major version in its name.
    /// </summary>
    /// <returns>The version string, or an empty string if no installation path is known or it holds no readable version.</returns>
    private static string GetWindowsCudaVersion()
    {
        string? cudaPath = Environment.GetEnvironmentVariable("CUDA_PATH");

        if (cudaPath is null)
        {
            IDictionary environmentVariables = Environment.GetEnvironmentVariables();
            string? key = environmentVariables.Keys.Cast<string>()
                                              .Where(name => name.StartsWith("CUDA_PATH_", StringComparison.OrdinalIgnoreCase))
                                              .Select(name => (Name: name, Match: CudaPathRegex().Match(name)))
                                              .Where(candidate => candidate.Match.Success)
                                              .OrderByDescending(candidate => int.Parse(candidate.Match.Groups["majorVersion"].Value))
                                              .Select(candidate => candidate.Name)
                                              .FirstOrDefault();
            if (key is not null)
                cudaPath = Environment.GetEnvironmentVariable(key);
        }

        return cudaPath is null ? string.Empty : GetCudaVersionFromPath(cudaPath);
    }

    /// <summary>
    /// Reads the CUDA version string on Linux, first from a fixed default directory and then from
    /// the parent directory of each entry in <c>LD_LIBRARY_PATH</c>.
    /// </summary>
    /// <returns>The first non-empty version string found, or an empty string if none was found.</returns>
    private static string GetLinuxCudaVersion()
    {
        // NOTE(review): CUDA toolkits are commonly installed under /usr/local/cuda - confirm that this path is intended.
        string version = GetCudaVersionFromPath("/usr/local/bin/cuda");
        if (!string.IsNullOrEmpty(version))
            return version;

        string? libraryPath = Environment.GetEnvironmentVariable("LD_LIBRARY_PATH");
        if (libraryPath is null)
            return string.Empty;

        foreach (string path in libraryPath.Split(':'))
        {
            // The version file is looked up one level above each library directory.
            version = GetCudaVersionFromPath(Path.Combine(path, ".."));

            // Stop at the first directory that yields a version; keep scanning otherwise.
            if (!string.IsNullOrEmpty(version))
                return version;
        }

        return string.Empty;
    }

    /// <summary>
    /// Reads the <c>libcublas.version</c> value from the <c>version.json</c> file in the given directory.
    /// </summary>
    /// <param name="cudaPath">The directory expected to contain the version file.</param>
    /// <returns>
    /// The version string, or an empty string if the file is missing, unreadable, not valid JSON,
    /// or does not contain a string value at <c>libcublas.version</c>.
    /// </returns>
    private static string GetCudaVersionFromPath(string cudaPath)
    {
        try
        {
            string json = File.ReadAllText(Path.Combine(cudaPath, CUDA_VERSION_FILE));
            using JsonDocument document = JsonDocument.Parse(json);

            if (!document.RootElement.TryGetProperty("libcublas", out JsonElement cublasNode)
             || !cublasNode.TryGetProperty("version", out JsonElement versionNode))
                return string.Empty;

            return versionNode.GetString() ?? string.Empty;
        }
        catch (Exception)
        {
            // Missing file, I/O error, malformed JSON or unexpected JSON shape: report "no version".
            return string.Empty;
        }
    }

    /// <summary>
    /// Matches environment variable names of the form <c>CUDA_PATH_V12_4</c> (case-insensitive)
    /// and captures the leading number in the <c>majorVersion</c> group.
    /// </summary>
    [GeneratedRegex(@"CUDA_PATH_V?(?<majorVersion>\d+)_?\d*", RegexOptions.IgnoreCase)]
    private static partial Regex CudaPathRegex();

    #endregion
}