mirror of
https://github.com/DarthAffe/StableDiffusion.NET.git
synced 2025-12-12 13:28:35 +00:00
Added backend abstraction to allow loading them at runtime depending on the capabilities of the system
This commit is contained in:
parent
11e9c3d0ec
commit
e9ecb141db
44
StableDiffusion.NET/Backends/Backends.cs
Normal file
44
StableDiffusion.NET/Backends/Backends.cs
Normal file
@ -0,0 +1,44 @@
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
|
||||
namespace StableDiffusion.NET;
|
||||
|
||||
public static class Backends
|
||||
{
|
||||
#region Properties & Fields
|
||||
|
||||
public static CpuBackend CpuBackend { get; } = new();
|
||||
public static CudaBackend CudaBackend { get; } = new();
|
||||
public static RocmBackend RocmBackend { get; } = new();
|
||||
|
||||
private static readonly List<IBackend> CUSTOM_BACKENDS = [];
|
||||
public static IReadOnlyList<IBackend> CustomBackends => CUSTOM_BACKENDS.AsReadOnly();
|
||||
|
||||
public static IEnumerable<IBackend> RegisteredBackends => [CpuBackend, CudaBackend, RocmBackend, .. CUSTOM_BACKENDS];
|
||||
public static IEnumerable<IBackend> AvailableBackends => RegisteredBackends.Where(x => x.IsAvailable);
|
||||
public static IEnumerable<IBackend> ActiveBackends => AvailableBackends.Where(x => x.IsEnabled);
|
||||
|
||||
public static List<string> SearchPaths { get; } = [];
|
||||
|
||||
#endregion
|
||||
|
||||
#region Methods
|
||||
|
||||
public static bool RegisterBackend(IBackend backend)
|
||||
{
|
||||
if (backend is NET.CpuBackend or NET.CudaBackend or NET.RocmBackend)
|
||||
throw new ArgumentException("Default backends can't be registered again.");
|
||||
|
||||
if (CUSTOM_BACKENDS.Contains(backend))
|
||||
return false;
|
||||
|
||||
CUSTOM_BACKENDS.Add(backend);
|
||||
return true;
|
||||
}
|
||||
|
||||
public static bool UnregisterBackend(IBackend backend)
|
||||
=> CUSTOM_BACKENDS.Remove(backend);
|
||||
|
||||
#endregion
|
||||
}
|
||||
102
StableDiffusion.NET/Backends/CpuBackend.cs
Normal file
102
StableDiffusion.NET/Backends/CpuBackend.cs
Normal file
@ -0,0 +1,102 @@
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.ComponentModel;
|
||||
using System.Runtime.InteropServices;
|
||||
using StableDiffusion.NET.Extensions;
|
||||
|
||||
namespace StableDiffusion.NET;
|
||||
|
||||
public class CpuBackend : IBackend
|
||||
{
|
||||
#region Properties & Fields
|
||||
|
||||
public bool IsEnabled { get; set; } = true;
|
||||
|
||||
public int Priority => 0;
|
||||
|
||||
public bool IsAvailable => (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)
|
||||
|| RuntimeInformation.IsOSPlatform(OSPlatform.Linux)
|
||||
|| RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
|
||||
&& (RuntimeInformation.OSArchitecture == Architecture.X64);
|
||||
|
||||
public string PathPart => Avx.GetDescription();
|
||||
|
||||
private readonly List<AvxLevel> _availableAvxLevels = [];
|
||||
public IEnumerable<AvxLevel> AvailableAvxLevels => _availableAvxLevels.AsReadOnly();
|
||||
|
||||
private AvxLevel _avx;
|
||||
public AvxLevel Avx
|
||||
{
|
||||
get => _avx;
|
||||
set
|
||||
{
|
||||
if (!_availableAvxLevels.Contains(value)) throw new ArgumentException("The selected AVX-Level is not supported on this system.");
|
||||
_avx = value;
|
||||
}
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Constructors
|
||||
|
||||
internal CpuBackend()
|
||||
{
|
||||
_availableAvxLevels.Add(AvxLevel.None);
|
||||
Avx = AvxLevel.None;
|
||||
|
||||
if (System.Runtime.Intrinsics.X86.Avx.IsSupported)
|
||||
{
|
||||
_availableAvxLevels.Add(AvxLevel.Avx);
|
||||
Avx = AvxLevel.Avx;
|
||||
}
|
||||
|
||||
if (System.Runtime.Intrinsics.X86.Avx2.IsSupported)
|
||||
{
|
||||
_availableAvxLevels.Add(AvxLevel.Avx2);
|
||||
Avx = AvxLevel.Avx2;
|
||||
}
|
||||
|
||||
if (CheckAvx512())
|
||||
{
|
||||
_availableAvxLevels.Add(AvxLevel.Avx512);
|
||||
Avx = AvxLevel.Avx512;
|
||||
}
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Methods
|
||||
|
||||
private static bool CheckAvx512()
|
||||
{
|
||||
if (!System.Runtime.Intrinsics.X86.X86Base.IsSupported)
|
||||
return false;
|
||||
|
||||
(_, int _, int ecx, _) = System.Runtime.Intrinsics.X86.X86Base.CpuId(7, 0);
|
||||
|
||||
bool vnni = (ecx & 0b_1000_0000_0000) != 0;
|
||||
|
||||
bool f = System.Runtime.Intrinsics.X86.Avx512F.IsSupported;
|
||||
bool bw = System.Runtime.Intrinsics.X86.Avx512BW.IsSupported;
|
||||
bool vbmi = System.Runtime.Intrinsics.X86.Avx512Vbmi.IsSupported;
|
||||
|
||||
return vnni && vbmi && bw && f;
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
public enum AvxLevel
|
||||
{
|
||||
[Description("")]
|
||||
None,
|
||||
|
||||
[Description("avx")]
|
||||
Avx,
|
||||
|
||||
[Description("avx2")]
|
||||
Avx2,
|
||||
|
||||
[Description("avx512")]
|
||||
Avx512,
|
||||
}
|
||||
}
|
||||
114
StableDiffusion.NET/Backends/CudaBackend.cs
Normal file
114
StableDiffusion.NET/Backends/CudaBackend.cs
Normal file
@ -0,0 +1,114 @@
|
||||
using System.IO;
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Text.Json;
|
||||
using System;
|
||||
|
||||
namespace StableDiffusion.NET;
|
||||
|
||||
public class CudaBackend : IBackend
|
||||
{
|
||||
#region Constants
|
||||
|
||||
private const string CUDA_VERSION_FILE = "version.json";
|
||||
|
||||
#endregion
|
||||
|
||||
#region Properties & Fields
|
||||
|
||||
public bool IsEnabled { get; set; } = true;
|
||||
|
||||
public int Priority => 10;
|
||||
|
||||
public bool IsAvailable => (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)
|
||||
|| RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
|
||||
&& (RuntimeInformation.OSArchitecture == Architecture.X64)
|
||||
&& CudaVersion is 11 or 12;
|
||||
|
||||
public string PathPart => CudaVersion switch
|
||||
{
|
||||
11 => "cuda11",
|
||||
12 => "cuda12",
|
||||
_ => string.Empty
|
||||
};
|
||||
|
||||
public int CudaVersion { get; }
|
||||
|
||||
#endregion
|
||||
|
||||
#region Constructors
|
||||
|
||||
internal CudaBackend()
|
||||
{
|
||||
CudaVersion = GetCudaMajorVersion();
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Methods
|
||||
|
||||
private static int GetCudaMajorVersion()
|
||||
{
|
||||
try
|
||||
{
|
||||
string? cudaPath;
|
||||
string version = "";
|
||||
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
|
||||
{
|
||||
cudaPath = Environment.GetEnvironmentVariable("CUDA_PATH");
|
||||
if (cudaPath is null) return -1;
|
||||
|
||||
version = GetCudaVersionFromPath(cudaPath);
|
||||
}
|
||||
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
|
||||
{
|
||||
cudaPath = "/usr/local/bin/cuda";
|
||||
version = GetCudaVersionFromPath(cudaPath);
|
||||
if (string.IsNullOrEmpty(version))
|
||||
{
|
||||
cudaPath = Environment.GetEnvironmentVariable("LD_LIBRARY_PATH");
|
||||
if (cudaPath is null)
|
||||
return -1;
|
||||
|
||||
foreach (string path in cudaPath.Split(':'))
|
||||
{
|
||||
version = GetCudaVersionFromPath(Path.Combine(path, ".."));
|
||||
if (string.IsNullOrEmpty(version))
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (string.IsNullOrEmpty(version))
|
||||
return -1;
|
||||
|
||||
version = version.Split('.')[0];
|
||||
if (int.TryParse(version, out int majorVersion))
|
||||
return majorVersion;
|
||||
}
|
||||
catch { /* No version or error */ }
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
private static string GetCudaVersionFromPath(string cudaPath)
|
||||
{
|
||||
try
|
||||
{
|
||||
string json = File.ReadAllText(Path.Combine(cudaPath, CUDA_VERSION_FILE));
|
||||
using JsonDocument document = JsonDocument.Parse(json);
|
||||
JsonElement root = document.RootElement;
|
||||
JsonElement cublasNode = root.GetProperty("libcublas");
|
||||
JsonElement versionNode = cublasNode.GetProperty("version");
|
||||
if (versionNode.ValueKind == JsonValueKind.Undefined)
|
||||
return string.Empty;
|
||||
|
||||
return versionNode.GetString() ?? "";
|
||||
}
|
||||
catch (Exception)
|
||||
{
|
||||
return string.Empty;
|
||||
}
|
||||
}
|
||||
|
||||
#endregion
|
||||
}
|
||||
9
StableDiffusion.NET/Backends/IBackend.cs
Normal file
9
StableDiffusion.NET/Backends/IBackend.cs
Normal file
@ -0,0 +1,9 @@
|
||||
namespace StableDiffusion.NET;
|
||||
|
||||
public interface IBackend
|
||||
{
|
||||
bool IsEnabled { get; set; }
|
||||
public int Priority { get; }
|
||||
bool IsAvailable { get; }
|
||||
string PathPart { get; }
|
||||
}
|
||||
22
StableDiffusion.NET/Backends/RocmBackend.cs
Normal file
22
StableDiffusion.NET/Backends/RocmBackend.cs
Normal file
@ -0,0 +1,22 @@
|
||||
namespace StableDiffusion.NET;
|
||||
|
||||
public class RocmBackend : IBackend
|
||||
{
|
||||
#region Properties & Fields
|
||||
|
||||
public bool IsEnabled { get; set; } = true;
|
||||
|
||||
public int Priority => 10;
|
||||
|
||||
public bool IsAvailable => false;
|
||||
|
||||
public string PathPart { get; } = string.Empty;
|
||||
|
||||
#endregion
|
||||
|
||||
#region Constructors
|
||||
|
||||
internal RocmBackend() { }
|
||||
|
||||
#endregion
|
||||
}
|
||||
16
StableDiffusion.NET/Extensions/EnumExtension.cs
Normal file
16
StableDiffusion.NET/Extensions/EnumExtension.cs
Normal file
@ -0,0 +1,16 @@
|
||||
using System;
|
||||
using System.ComponentModel;
|
||||
|
||||
namespace StableDiffusion.NET.Extensions;
|
||||
|
||||
internal static class EnumExtension
|
||||
{
|
||||
public static string GetDescription(this Enum value)
|
||||
{
|
||||
DescriptionAttribute[]? attributes = (DescriptionAttribute[]?)value.GetType().GetField(value.ToString())?.GetCustomAttributes(typeof(DescriptionAttribute), false);
|
||||
|
||||
return attributes?.Length > 0
|
||||
? attributes[0].Description
|
||||
: value.ToString();
|
||||
}
|
||||
}
|
||||
103
StableDiffusion.NET/Native/Native.Load.cs
Normal file
103
StableDiffusion.NET/Native/Native.Load.cs
Normal file
@ -0,0 +1,103 @@
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Reflection;
|
||||
using System.Runtime.InteropServices;
|
||||
|
||||
namespace StableDiffusion.NET;
|
||||
|
||||
internal static partial class Native
|
||||
{
|
||||
#region Properties & Fields
|
||||
|
||||
private static nint _loadedLibraryHandle;
|
||||
|
||||
#endregion
|
||||
|
||||
#region Constructors
|
||||
|
||||
static Native()
|
||||
{
|
||||
NativeLibrary.SetDllImportResolver(typeof(Native).Assembly, ResolveDllImport);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Methods
|
||||
|
||||
private static nint ResolveDllImport(string libraryname, Assembly assembly, DllImportSearchPath? searchpath)
|
||||
{
|
||||
if (libraryname != LIB_NAME) return nint.Zero;
|
||||
if (_loadedLibraryHandle != nint.Zero) return _loadedLibraryHandle;
|
||||
|
||||
_loadedLibraryHandle = TryLoadLibrary();
|
||||
|
||||
return _loadedLibraryHandle;
|
||||
}
|
||||
|
||||
private static nint TryLoadLibrary()
|
||||
{
|
||||
GetPlatformPathParts(out string os, out string fileExtension, out string libPrefix);
|
||||
|
||||
foreach (IBackend backend in Backends.ActiveBackends.OrderBy(x => x.Priority))
|
||||
{
|
||||
string path = Path.Combine("runtimes", os, "native", backend.PathPart, $"{libPrefix}{LIB_NAME}{fileExtension}");
|
||||
|
||||
string fullPath = TryFindPath(path);
|
||||
nint result = TryLoad(fullPath);
|
||||
|
||||
if (result != nint.Zero)
|
||||
return result;
|
||||
}
|
||||
|
||||
return nint.Zero;
|
||||
|
||||
static nint TryLoad(string path)
|
||||
{
|
||||
if (NativeLibrary.TryLoad(path, out nint handle))
|
||||
return handle;
|
||||
|
||||
return nint.Zero;
|
||||
}
|
||||
|
||||
static string TryFindPath(string filename)
|
||||
{
|
||||
IEnumerable<string> searchPaths = [.. Backends.SearchPaths, AppDomain.CurrentDomain.BaseDirectory, Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location) ?? ""];
|
||||
foreach (string path in searchPaths)
|
||||
{
|
||||
string candidate = Path.Combine(path, filename);
|
||||
if (File.Exists(candidate))
|
||||
return candidate;
|
||||
}
|
||||
|
||||
return filename;
|
||||
}
|
||||
}
|
||||
|
||||
private static void GetPlatformPathParts(out string os, out string fileExtension, out string libPrefix)
|
||||
{
|
||||
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
|
||||
{
|
||||
os = "win-x64";
|
||||
fileExtension = ".dll";
|
||||
libPrefix = "";
|
||||
}
|
||||
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
|
||||
{
|
||||
os = "linux-x64";
|
||||
fileExtension = ".so";
|
||||
libPrefix = "lib";
|
||||
}
|
||||
else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
|
||||
{
|
||||
fileExtension = ".dylib";
|
||||
os = "osx-x64";
|
||||
libPrefix = "lib";
|
||||
}
|
||||
else
|
||||
throw new NotSupportedException("Your operating system is not supported.");
|
||||
}
|
||||
|
||||
#endregion
|
||||
}
|
||||
@ -1,5 +1,7 @@
|
||||
<wpf:ResourceDictionary xml:space="preserve" xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml" xmlns:s="clr-namespace:System;assembly=mscorlib" xmlns:ss="urn:shemas-jetbrains-com:settings-storage-xaml" xmlns:wpf="http://schemas.microsoft.com/winfx/2006/xaml/presentation">
|
||||
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=attributes/@EntryIndexedValue">True</s:Boolean>
|
||||
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=backends/@EntryIndexedValue">True</s:Boolean>
|
||||
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=enums/@EntryIndexedValue">True</s:Boolean>
|
||||
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=eventargs/@EntryIndexedValue">True</s:Boolean>
|
||||
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=extensions/@EntryIndexedValue">True</s:Boolean></wpf:ResourceDictionary>
|
||||
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=extensions/@EntryIndexedValue">True</s:Boolean>
|
||||
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=native/@EntryIndexedValue">True</s:Boolean></wpf:ResourceDictionary>
|
||||
Loading…
x
Reference in New Issue
Block a user