From 2a7e419e4cefc868b580034b806bc7e840dcdb51 Mon Sep 17 00:00:00 2001 From: Darth Affe Date: Sat, 23 Mar 2024 03:40:34 +0100 Subject: [PATCH] Added basic hip-blas detection for windows --- StableDiffusion.NET/Backends/RocmBackend.cs | 65 +++++++++++++++++++-- 1 file changed, 60 insertions(+), 5 deletions(-) diff --git a/StableDiffusion.NET/Backends/RocmBackend.cs b/StableDiffusion.NET/Backends/RocmBackend.cs index 33a0fb8..acce69f 100644 --- a/StableDiffusion.NET/Backends/RocmBackend.cs +++ b/StableDiffusion.NET/Backends/RocmBackend.cs @@ -1,6 +1,10 @@ -namespace StableDiffusion.NET; +using System; +using System.Runtime.InteropServices; +using System.Text.RegularExpressions; -public class RocmBackend : IBackend +namespace StableDiffusion.NET; + +public partial class RocmBackend : IBackend { #region Properties & Fields @@ -8,15 +12,66 @@ public class RocmBackend : IBackend public int Priority => 10; - public bool IsAvailable => false; + public bool IsAvailable => (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) + /*|| RuntimeInformation.IsOSPlatform(OSPlatform.Linux)*/) + && (RuntimeInformation.OSArchitecture == Architecture.X64) + && RocmVersion is 5; - public string PathPart { get; } = string.Empty; + public string PathPart => RocmVersion switch + { + 5 => "rocm5", + _ => string.Empty + }; + + public int RocmVersion { get; } #endregion #region Constructors - internal RocmBackend() { } + internal RocmBackend() + { + RocmVersion = GetRocmMajorVersion(); + } + + #endregion + + #region Methods + + private static int GetRocmMajorVersion() + { + try + { + string version = ""; + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + string? rocmPath = Environment.GetEnvironmentVariable("HIP_PATH"); + + if (rocmPath == null) return -1; + + Match match = GetWindowsVersionRegex().Match(rocmPath); + if (match.Success) + version = match.Groups["version"].Value; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + //TODO DarthAffe 23.03.2024: Get some info where it's located on linux + } + + if (string.IsNullOrEmpty(version)) + return -1; + + version = version.Split('.')[0]; + if (int.TryParse(version, out int majorVersion)) + return majorVersion; + } + catch { /* No version or error */ } + + return -1; + } + + [GeneratedRegex(@".*?\\(?\d+.\d*)\\")] + private static partial Regex GetWindowsVersionRegex(); #endregion } \ No newline at end of file