Added backend abstraction to allow loading them at runtime depending on the capabilities of the system

This commit is contained in:
Darth Affe 2024-03-21 23:05:18 +01:00
parent 11e9c3d0ec
commit e9ecb141db
9 changed files with 413 additions and 1 deletions

View File

@ -0,0 +1,44 @@
using System;
using System.Collections.Generic;
using System.Linq;
namespace StableDiffusion.NET;
public static class Backends
{
#region Properties & Fields
public static CpuBackend CpuBackend { get; } = new();
public static CudaBackend CudaBackend { get; } = new();
public static RocmBackend RocmBackend { get; } = new();
private static readonly List<IBackend> CUSTOM_BACKENDS = [];
public static IReadOnlyList<IBackend> CustomBackends => CUSTOM_BACKENDS.AsReadOnly();
public static IEnumerable<IBackend> RegisteredBackends => [CpuBackend, CudaBackend, RocmBackend, .. CUSTOM_BACKENDS];
public static IEnumerable<IBackend> AvailableBackends => RegisteredBackends.Where(x => x.IsAvailable);
public static IEnumerable<IBackend> ActiveBackends => AvailableBackends.Where(x => x.IsEnabled);
public static List<string> SearchPaths { get; } = [];
#endregion
#region Methods
public static bool RegisterBackend(IBackend backend)
{
if (backend is NET.CpuBackend or NET.CudaBackend or NET.RocmBackend)
throw new ArgumentException("Default backends can't be registered again.");
if (CUSTOM_BACKENDS.Contains(backend))
return false;
CUSTOM_BACKENDS.Add(backend);
return true;
}
public static bool UnregisterBackend(IBackend backend)
=> CUSTOM_BACKENDS.Remove(backend);
#endregion
}

View File

@ -0,0 +1,102 @@
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Runtime.InteropServices;
using StableDiffusion.NET.Extensions;
namespace StableDiffusion.NET;
public class CpuBackend : IBackend
{
#region Properties & Fields
public bool IsEnabled { get; set; } = true;
public int Priority => 0;
public bool IsAvailable => (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)
|| RuntimeInformation.IsOSPlatform(OSPlatform.Linux)
|| RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
&& (RuntimeInformation.OSArchitecture == Architecture.X64);
public string PathPart => Avx.GetDescription();
private readonly List<AvxLevel> _availableAvxLevels = [];
public IEnumerable<AvxLevel> AvailableAvxLevels => _availableAvxLevels.AsReadOnly();
private AvxLevel _avx;
public AvxLevel Avx
{
get => _avx;
set
{
if (!_availableAvxLevels.Contains(value)) throw new ArgumentException("The selected AVX-Level is not supported on this system.");
_avx = value;
}
}
#endregion
#region Constructors
internal CpuBackend()
{
_availableAvxLevels.Add(AvxLevel.None);
Avx = AvxLevel.None;
if (System.Runtime.Intrinsics.X86.Avx.IsSupported)
{
_availableAvxLevels.Add(AvxLevel.Avx);
Avx = AvxLevel.Avx;
}
if (System.Runtime.Intrinsics.X86.Avx2.IsSupported)
{
_availableAvxLevels.Add(AvxLevel.Avx2);
Avx = AvxLevel.Avx2;
}
if (CheckAvx512())
{
_availableAvxLevels.Add(AvxLevel.Avx512);
Avx = AvxLevel.Avx512;
}
}
#endregion
#region Methods
private static bool CheckAvx512()
{
if (!System.Runtime.Intrinsics.X86.X86Base.IsSupported)
return false;
(_, int _, int ecx, _) = System.Runtime.Intrinsics.X86.X86Base.CpuId(7, 0);
bool vnni = (ecx & 0b_1000_0000_0000) != 0;
bool f = System.Runtime.Intrinsics.X86.Avx512F.IsSupported;
bool bw = System.Runtime.Intrinsics.X86.Avx512BW.IsSupported;
bool vbmi = System.Runtime.Intrinsics.X86.Avx512Vbmi.IsSupported;
return vnni && vbmi && bw && f;
}
#endregion
public enum AvxLevel
{
[Description("")]
None,
[Description("avx")]
Avx,
[Description("avx2")]
Avx2,
[Description("avx512")]
Avx512,
}
}

View File

@ -0,0 +1,114 @@
using System.IO;
using System.Runtime.InteropServices;
using System.Text.Json;
using System;
namespace StableDiffusion.NET;
public class CudaBackend : IBackend
{
#region Constants
private const string CUDA_VERSION_FILE = "version.json";
#endregion
#region Properties & Fields
public bool IsEnabled { get; set; } = true;
public int Priority => 10;
public bool IsAvailable => (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)
|| RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
&& (RuntimeInformation.OSArchitecture == Architecture.X64)
&& CudaVersion is 11 or 12;
public string PathPart => CudaVersion switch
{
11 => "cuda11",
12 => "cuda12",
_ => string.Empty
};
public int CudaVersion { get; }
#endregion
#region Constructors
internal CudaBackend()
{
CudaVersion = GetCudaMajorVersion();
}
#endregion
#region Methods
private static int GetCudaMajorVersion()
{
try
{
string? cudaPath;
string version = "";
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
cudaPath = Environment.GetEnvironmentVariable("CUDA_PATH");
if (cudaPath is null) return -1;
version = GetCudaVersionFromPath(cudaPath);
}
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
{
cudaPath = "/usr/local/bin/cuda";
version = GetCudaVersionFromPath(cudaPath);
if (string.IsNullOrEmpty(version))
{
cudaPath = Environment.GetEnvironmentVariable("LD_LIBRARY_PATH");
if (cudaPath is null)
return -1;
foreach (string path in cudaPath.Split(':'))
{
version = GetCudaVersionFromPath(Path.Combine(path, ".."));
if (string.IsNullOrEmpty(version))
break;
}
}
}
if (string.IsNullOrEmpty(version))
return -1;
version = version.Split('.')[0];
if (int.TryParse(version, out int majorVersion))
return majorVersion;
}
catch { /* No version or error */ }
return -1;
}
private static string GetCudaVersionFromPath(string cudaPath)
{
try
{
string json = File.ReadAllText(Path.Combine(cudaPath, CUDA_VERSION_FILE));
using JsonDocument document = JsonDocument.Parse(json);
JsonElement root = document.RootElement;
JsonElement cublasNode = root.GetProperty("libcublas");
JsonElement versionNode = cublasNode.GetProperty("version");
if (versionNode.ValueKind == JsonValueKind.Undefined)
return string.Empty;
return versionNode.GetString() ?? "";
}
catch (Exception)
{
return string.Empty;
}
}
#endregion
}

View File

@ -0,0 +1,9 @@
namespace StableDiffusion.NET;
public interface IBackend
{
bool IsEnabled { get; set; }
public int Priority { get; }
bool IsAvailable { get; }
string PathPart { get; }
}

View File

@ -0,0 +1,22 @@
namespace StableDiffusion.NET;
public class RocmBackend : IBackend
{
#region Properties & Fields
public bool IsEnabled { get; set; } = true;
public int Priority => 10;
public bool IsAvailable => false;
public string PathPart { get; } = string.Empty;
#endregion
#region Constructors
internal RocmBackend() { }
#endregion
}

View File

@ -0,0 +1,16 @@
using System;
using System.ComponentModel;
namespace StableDiffusion.NET.Extensions;
internal static class EnumExtension
{
public static string GetDescription(this Enum value)
{
DescriptionAttribute[]? attributes = (DescriptionAttribute[]?)value.GetType().GetField(value.ToString())?.GetCustomAttributes(typeof(DescriptionAttribute), false);
return attributes?.Length > 0
? attributes[0].Description
: value.ToString();
}
}

View File

@ -0,0 +1,103 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Runtime.InteropServices;
namespace StableDiffusion.NET;
internal static partial class Native
{
#region Properties & Fields
private static nint _loadedLibraryHandle;
#endregion
#region Constructors
static Native()
{
NativeLibrary.SetDllImportResolver(typeof(Native).Assembly, ResolveDllImport);
}
#endregion
#region Methods
private static nint ResolveDllImport(string libraryname, Assembly assembly, DllImportSearchPath? searchpath)
{
if (libraryname != LIB_NAME) return nint.Zero;
if (_loadedLibraryHandle != nint.Zero) return _loadedLibraryHandle;
_loadedLibraryHandle = TryLoadLibrary();
return _loadedLibraryHandle;
}
private static nint TryLoadLibrary()
{
GetPlatformPathParts(out string os, out string fileExtension, out string libPrefix);
foreach (IBackend backend in Backends.ActiveBackends.OrderBy(x => x.Priority))
{
string path = Path.Combine("runtimes", os, "native", backend.PathPart, $"{libPrefix}{LIB_NAME}{fileExtension}");
string fullPath = TryFindPath(path);
nint result = TryLoad(fullPath);
if (result != nint.Zero)
return result;
}
return nint.Zero;
static nint TryLoad(string path)
{
if (NativeLibrary.TryLoad(path, out nint handle))
return handle;
return nint.Zero;
}
static string TryFindPath(string filename)
{
IEnumerable<string> searchPaths = [.. Backends.SearchPaths, AppDomain.CurrentDomain.BaseDirectory, Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location) ?? ""];
foreach (string path in searchPaths)
{
string candidate = Path.Combine(path, filename);
if (File.Exists(candidate))
return candidate;
}
return filename;
}
}
private static void GetPlatformPathParts(out string os, out string fileExtension, out string libPrefix)
{
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
os = "win-x64";
fileExtension = ".dll";
libPrefix = "";
}
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
{
os = "linux-x64";
fileExtension = ".so";
libPrefix = "lib";
}
else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
{
fileExtension = ".dylib";
os = "osx-x64";
libPrefix = "lib";
}
else
throw new NotSupportedException("Your operating system is not supported.");
}
#endregion
}

View File

@ -1,5 +1,7 @@
<wpf:ResourceDictionary xml:space="preserve" xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml" xmlns:s="clr-namespace:System;assembly=mscorlib" xmlns:ss="urn:shemas-jetbrains-com:settings-storage-xaml" xmlns:wpf="http://schemas.microsoft.com/winfx/2006/xaml/presentation">
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=attributes/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=backends/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=enums/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=eventargs/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=extensions/@EntryIndexedValue">True</s:Boolean></wpf:ResourceDictionary>
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=extensions/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=native/@EntryIndexedValue">True</s:Boolean></wpf:ResourceDictionary>