Updated to support SDXL and clip scip (commit 78ad76f)

This commit is contained in:
Darth Affe 2023-12-29 01:35:30 +01:00
parent 76be341505
commit 289abe4d20
4 changed files with 9 additions and 6 deletions

View File

@ -6,12 +6,15 @@ public class ModelParameter
public int ThreadCount { get; set; } = 8;
public bool VaeDecodeOnly { get; set; } = false;
public bool VaeTiling { get; set; } = false;
public string TaesdPath { get; set; } = string.Empty;
public string ESRGANPath { get; set; } = string.Empty;
public string LoraModelDir { get; set; } = string.Empty;
public RngType RngType { get; set; } = RngType.Standard;
public string VaePath { get; set; } = string.Empty;
public Quantization Quantization { get; set; } = Quantization.Default;
public Schedule Schedule { get; set; } = Schedule.Default;
public int ClipSkip { get; set; } = -1;
#endregion
}

View File

@ -66,11 +66,11 @@ internal unsafe partial class Native
internal static partial void stable_diffusion_full_params_set_strength(stable_diffusion_full_params* @params, float strength);
[LibraryImport(LIB_NAME, EntryPoint = "stable_diffusion_init")]
internal static partial stable_diffusion_ctx* stable_diffusion_init(int n_threads, [MarshalAs(UnmanagedType.I1)] bool vae_decode_only, [MarshalAs(UnmanagedType.LPStr)] string taesd_path, [MarshalAs(UnmanagedType.I1)] bool free_params_immediately, [MarshalAs(UnmanagedType.LPStr)] string lora_model_dir, [MarshalAs(UnmanagedType.LPStr)] string rng_type);
internal static partial stable_diffusion_ctx* stable_diffusion_init(int n_threads, [MarshalAs(UnmanagedType.I1)] bool vae_decode_only, [MarshalAs(UnmanagedType.LPStr)] string taesd_path, [MarshalAs(UnmanagedType.LPStr)] string esrgan_path, [MarshalAs(UnmanagedType.I1)] bool free_params_immediately, [MarshalAs(UnmanagedType.I1)] bool vae_tiling, [MarshalAs(UnmanagedType.LPStr)] string lora_model_dir, [MarshalAs(UnmanagedType.LPStr)] string rng_type);
[LibraryImport(LIB_NAME, EntryPoint = "stable_diffusion_load_from_file")]
[return: MarshalAs(UnmanagedType.I1)]
internal static partial bool stable_diffusion_load_from_file(stable_diffusion_ctx* ctx, [MarshalAs(UnmanagedType.LPStr)] string file_path, [MarshalAs(UnmanagedType.LPStr)] string vae_path, [MarshalAs(UnmanagedType.LPStr)] string wtype, [MarshalAs(UnmanagedType.LPStr)] string schedule);
internal static partial bool stable_diffusion_load_from_file(stable_diffusion_ctx* ctx, [MarshalAs(UnmanagedType.LPStr)] string file_path, [MarshalAs(UnmanagedType.LPStr)] string vae_path, [MarshalAs(UnmanagedType.LPStr)] string wtype, [MarshalAs(UnmanagedType.LPStr)] string schedule, int clip_skip);
[LibraryImport(LIB_NAME, EntryPoint = "stable_diffusion_predict_image")]
internal static partial byte* stable_diffusion_predict_image(stable_diffusion_ctx* ctx, stable_diffusion_full_params* @params, [MarshalAs(UnmanagedType.LPStr)] string prompt);

View File

@ -34,10 +34,10 @@ public sealed unsafe class StableDiffusionModel : IDisposable
private void Initialize()
{
_ctx = Native.stable_diffusion_init(_parameter.ThreadCount, _parameter.VaeDecodeOnly, _parameter.TaesdPath, false, _parameter.LoraModelDir, _parameter.RngType.GetNativeName() ?? "STD_DEFAULT_RNG");
_ctx = Native.stable_diffusion_init(_parameter.ThreadCount, _parameter.VaeDecodeOnly, _parameter.TaesdPath, _parameter.ESRGANPath, false, _parameter.VaeTiling, _parameter.LoraModelDir, _parameter.RngType.GetNativeName() ?? "STD_DEFAULT_RNG");
if (_ctx == null) throw new NullReferenceException("Failed to initialize Stable Diffusion");
bool success = Native.stable_diffusion_load_from_file(_ctx, _modelPath, _parameter.VaePath, _parameter.Quantization.GetNativeName() ?? "DEFAULT", _parameter.Schedule.GetNativeName() ?? "DEFAULT");
bool success = Native.stable_diffusion_load_from_file(_ctx, _modelPath, _parameter.VaePath, _parameter.Quantization.GetNativeName() ?? "DEFAULT", _parameter.Schedule.GetNativeName() ?? "DEFAULT", _parameter.ClipSkip);
if (!success) throw new IOException("Failed to load model");
}

View File

@ -1,10 +1,10 @@
if not exist stable-diffusion.cpp-build (
git clone https://github.com/seasonjs/stable-diffusion.cpp-build
git clone https://github.com/DarthAffe/stable-diffusion.cpp-build
)
cd stable-diffusion.cpp-build
git fetch
git checkout 4b95d98404bbfe91698fd41b0f514656e358163a
git checkout b518ce72f1ba448f164e58961b1513ccacc95006
if not exist build (
mkdir build