diff --git a/StableDiffusion.NET/Enums/LoraApplyMode.cs b/StableDiffusion.NET/Enums/LoraApplyMode.cs
new file mode 100644
index 0000000..ff5c243
--- /dev/null
+++ b/StableDiffusion.NET/Enums/LoraApplyMode.cs
@@ -0,0 +1,12 @@
+namespace StableDiffusion.NET;
+
+/// <summary>
+/// LoRA apply mode that is copied unchanged into the native <c>lora_apply_mode</c> context parameter.
+/// NOTE(review): aliased to the native <c>lora_apply_mode_t</c>, so the member order presumably has to mirror the native enum - confirm before reordering.
+/// </summary>
+public enum LoraApplyMode
+{
+    Auto,
+    Immediately,
+    AtRuntime
+}
\ No newline at end of file
diff --git a/StableDiffusion.NET/Enums/Preview.cs b/StableDiffusion.NET/Enums/Preview.cs
new file mode 100644
index 0000000..b967342
--- /dev/null
+++ b/StableDiffusion.NET/Enums/Preview.cs
@@ -0,0 +1,13 @@
+namespace StableDiffusion.NET;
+
+/// <summary>
+/// Preview mode passed to the native <c>sd_set_preview_callback</c>; <see cref="None"/> removes the managed preview callback.
+/// NOTE(review): aliased to the native <c>preview_t</c>, so the member order presumably has to mirror the native enum - confirm before reordering.
+/// </summary>
+public enum Preview
+{
+    None,
+    Proj,
+    TAE,
+    VAE
+}
\ No newline at end of file
diff --git a/StableDiffusion.NET/EventArgs/StableDiffusionPreviewEventArgs.cs b/StableDiffusion.NET/EventArgs/StableDiffusionPreviewEventArgs.cs
new file mode 100644
index 0000000..82db62b
--- /dev/null
+++ b/StableDiffusion.NET/EventArgs/StableDiffusionPreviewEventArgs.cs
@@ -0,0 +1,21 @@
+using System;
+using HPPH;
+
+namespace StableDiffusion.NET;
+
+/// <summary>
+/// Event data carrying one preview image together with the step and noise flag it was reported for.
+/// </summary>
+/// <param name="step">Value exposed as <see cref="Step"/>.</param>
+/// <param name="isNoisy">Value exposed as <see cref="IsNoisy"/>.</param>
+/// <param name="image">Value exposed as <see cref="Image"/>.</param>
+public sealed class StableDiffusionPreviewEventArgs(int step, bool isNoisy, Image image) : EventArgs
+{
+    #region Properties & Fields
+
+    public int Step { get; } = step;
+    public bool IsNoisy { get; } = isNoisy;
+    public Image Image { get; } = image;
+
+    #endregion
+}
\ No newline at end of file
diff --git a/StableDiffusion.NET/Models/Parameter/DiffusionModelParameter.cs b/StableDiffusion.NET/Models/Parameter/DiffusionModelParameter.cs
index 64d375c..e1b2d7b 100644
--- a/StableDiffusion.NET/Models/Parameter/DiffusionModelParameter.cs
+++ b/StableDiffusion.NET/Models/Parameter/DiffusionModelParameter.cs
@@ -72,6 +72,12 @@ public sealed class DiffusionModelParameter
     ///
     public bool FlashAttention { get; set; } = false;
 
+    /// <summary>
+    /// forwarded to the native tae_preview_only flag
+    /// NOTE(review): presumably limits the tiny auto encoder to preview decoding - confirm against the native library
+    /// </summary>
+    public bool TaePreviewOnly { get; set; } = false;
+
     ///
     /// use Conv2d direct in the diffusion model
     /// This might crash if it is not supported by the backend.
@@ -91,6 +97,12 @@ public sealed class DiffusionModelParameter
 
     public Prediction Prediction { get; set; } = Prediction.Default;
 
+    /// <summary>
+    /// forwarded to the native lora_apply_mode context parameter
+    /// NOTE(review): exact semantics of the individual modes are defined by the native library - confirm there
+    /// </summary>
+    public LoraApplyMode LoraApplyMode { get; set; } = LoraApplyMode.Auto;
+
     ///
     /// quantizes on load
     /// not really useful in most cases
diff --git a/StableDiffusion.NET/Native/Marshaller/DiffusionModelParameterMarshaller.cs b/StableDiffusion.NET/Native/Marshaller/DiffusionModelParameterMarshaller.cs
index b975543..e892316 100644
--- a/StableDiffusion.NET/Native/Marshaller/DiffusionModelParameterMarshaller.cs
+++ b/StableDiffusion.NET/Native/Marshaller/DiffusionModelParameterMarshaller.cs
@@ -30,11 +30,13 @@ internal static unsafe class DiffusionModelParameterMarshaller
             wtype = managed.Quantization,
             rng_type = managed.RngType,
             prediction = managed.Prediction,
+            lora_apply_mode = managed.LoraApplyMode,
             offload_params_to_cpu = (sbyte)(managed.OffloadParamsToCPU ? 1 : 0),
             keep_clip_on_cpu = (sbyte)(managed.KeepClipOnCPU ? 1 : 0),
             keep_control_net_on_cpu = (sbyte)(managed.KeepControlNetOnCPU ? 1 : 0),
             keep_vae_on_cpu = (sbyte)(managed.KeepVaeOnCPU ? 1 : 0),
             diffusion_flash_attn = (sbyte)(managed.FlashAttention ? 1 : 0),
+            tae_preview_only = (sbyte)(managed.TaePreviewOnly ? 1 : 0),
             diffusion_conv_direct = (sbyte)(managed.DiffusionConvDirect ? 1 : 0),
             vae_conv_direct = (sbyte)(managed.VaeConvDirect ? 1 : 0),
             force_sdxl_vae_conv_scale = (sbyte)(managed.ForceSdxlVaeConvScale ? 1 : 0),
@@ -68,11 +70,13 @@ internal static unsafe class DiffusionModelParameterMarshaller
             Quantization = unmanaged.wtype,
             RngType = unmanaged.rng_type,
             Prediction = unmanaged.prediction,
+            LoraApplyMode = unmanaged.lora_apply_mode,
             OffloadParamsToCPU = unmanaged.offload_params_to_cpu == 1,
             KeepClipOnCPU = unmanaged.keep_clip_on_cpu == 1,
             KeepControlNetOnCPU = unmanaged.keep_control_net_on_cpu == 1,
             KeepVaeOnCPU = unmanaged.keep_vae_on_cpu == 1,
             FlashAttention = unmanaged.diffusion_flash_attn == 1,
+            TaePreviewOnly = unmanaged.tae_preview_only == 1,
             DiffusionConvDirect = unmanaged.diffusion_conv_direct == 1,
             VaeConvDirect = unmanaged.vae_conv_direct == 1,
             ForceSdxlVaeConvScale = unmanaged.force_sdxl_vae_conv_scale == 1,
diff --git a/StableDiffusion.NET/Native/Native.cs b/StableDiffusion.NET/Native/Native.cs
index b007d65..32c6620 100644
--- a/StableDiffusion.NET/Native/Native.cs
+++ b/StableDiffusion.NET/Native/Native.cs
@@ -23,6 +23,8 @@ using sd_img_gen_params_t = ImageGenerationParameter;
 using sd_log_level_t = LogLevel;
 using sd_type_t = Quantization;
 using sd_vid_gen_params_t = VideoGenerationParameter;
+using lora_apply_mode_t = LoraApplyMode;
+using preview_t = Preview;
 using size_t = nuint;
 using uint32_t = uint;
 using uint8_t = byte;
@@ -73,11 +75,13 @@ internal unsafe partial class Native
             public sd_type_t wtype;
             public rng_type_t rng_type;
             public prediction_t prediction;
+            public lora_apply_mode_t lora_apply_mode;
             public sbyte offload_params_to_cpu;
             public sbyte keep_clip_on_cpu;
             public sbyte keep_control_net_on_cpu;
             public sbyte keep_vae_on_cpu;
             public sbyte diffusion_flash_attn;
+            public sbyte tae_preview_only;
             public sbyte diffusion_conv_direct;
             public sbyte vae_conv_direct;
             public sbyte force_sdxl_vae_conv_scale;
@@ -188,6 +192,7 @@ internal unsafe partial class Native
 
     internal delegate void sd_log_cb_t(sd_log_level_t level, [MarshalAs(UnmanagedType.LPStr)] string text, void* data);
     internal delegate void sd_progress_cb_t(int step, int steps, float time, void* data);
+    internal delegate void sd_preview_cb_t(int step, int frame_count, sd_image_t* frames, [MarshalAs(UnmanagedType.I1)] bool is_noisy);
 
     #endregion
 
@@ -199,6 +204,9 @@ internal unsafe partial class Native
     [LibraryImport(LIB_NAME, EntryPoint = "sd_set_progress_callback")]
     internal static partial void sd_set_progress_callback(sd_progress_cb_t cb, void* data);
 
+    [LibraryImport(LIB_NAME, EntryPoint = "sd_set_preview_callback")]
+    internal static partial void sd_set_preview_callback(sd_preview_cb_t? cb, preview_t mode, int interval, [MarshalAs(UnmanagedType.I1)] bool denoised, [MarshalAs(UnmanagedType.I1)] bool noisy);
+
     [LibraryImport(LIB_NAME, EntryPoint = "get_num_physical_cores")]
     internal static partial int32_t get_num_physical_cores();
 
@@ -243,6 +251,21 @@ internal unsafe partial class Native
     [LibraryImport(LIB_NAME, EntryPoint = "str_to_prediction")]
     internal static partial prediction_t str_to_prediction([MarshalAs(UnmanagedType.LPStr)] string str);
 
+    // NOTE(review): a string return marshalled as LPStr is released by the marshaller after conversion - confirm the native *_name functions do not return static storage.
+    [LibraryImport(LIB_NAME, EntryPoint = "sd_preview_name")]
+    [return: MarshalAs(UnmanagedType.LPStr)]
+    internal static partial string sd_preview_name(preview_t preview);
+
+    [LibraryImport(LIB_NAME, EntryPoint = "str_to_preview")]
+    internal static partial preview_t str_to_preview([MarshalAs(UnmanagedType.LPStr)] string str);
+
+    [LibraryImport(LIB_NAME, EntryPoint = "sd_lora_apply_mode_name")]
+    [return: MarshalAs(UnmanagedType.LPStr)]
+    internal static partial string sd_lora_apply_mode_name(lora_apply_mode_t mode);
+
+    [LibraryImport(LIB_NAME, EntryPoint = "str_to_lora_apply_mode")]
+    internal static partial lora_apply_mode_t str_to_lora_apply_mode([MarshalAs(UnmanagedType.LPStr)] string str);
+
     //
 
     [LibraryImport(LIB_NAME, EntryPoint = "sd_ctx_params_init")]
diff --git a/StableDiffusion.NET/StableDiffusionCpp.cs b/StableDiffusion.NET/StableDiffusionCpp.cs
index eb3d650..48081fd 100644
--- a/StableDiffusion.NET/StableDiffusionCpp.cs
+++ b/StableDiffusion.NET/StableDiffusionCpp.cs
@@ -12,6 +12,7 @@ public static unsafe class StableDiffusionCpp
     // ReSharper disable NotAccessedField.Local - They are important, the delegate can be collected if it's not stored!
     private static Native.sd_log_cb_t? _logCallback;
     private static Native.sd_progress_cb_t? _progressCallback;
+    private static Native.sd_preview_cb_t? _previewCallback;
     // ReSharper restore NotAccessedField.Local
 
     #endregion
@@ -20,6 +21,7 @@ public static unsafe class StableDiffusionCpp
 
     public static event EventHandler? Log;
     public static event EventHandler? Progress;
+    public static event EventHandler? Preview;
 
     #endregion
 
@@ -33,6 +35,30 @@ public static unsafe class StableDiffusionCpp
         Native.sd_set_progress_callback(_progressCallback = OnNativeProgress, null);
     }
 
+    /// <summary>
+    /// Registers (or, for <c>Preview.None</c>, removes) the native preview callback that raises the <c>Preview</c> event.
+    /// </summary>
+    /// <param name="mode">Preview mode forwarded to the native library.</param>
+    /// <param name="interval">Forwarded to the native library; must not be negative.</param>
+    /// <param name="denoised">Forwarded to the native library.</param>
+    /// <param name="noisy">Forwarded to the native library.</param>
+    /// <exception cref="ArgumentOutOfRangeException"><paramref name="interval"/> is negative.</exception>
+    public static void EnablePreview(Preview mode, int interval, bool denoised, bool noisy)
+    {
+        ArgumentOutOfRangeException.ThrowIfNegative(interval);
+
+        if (mode == NET.Preview.None)
+        {
+            // Detach on the native side first; only then drop the reference that keeps the delegate alive.
+            Native.sd_set_preview_callback(null, mode, interval, denoised, noisy);
+            _previewCallback = null;
+            return;
+        }
+
+        _previewCallback ??= OnPreview;
+        Native.sd_set_preview_callback(_previewCallback, mode, interval, denoised, noisy);
+    }
+
     public static void Convert(string modelPath, string vaePath, Quantization quantization, string outputPath, string tensorTypeRules = "")
     {
         ArgumentException.ThrowIfNullOrWhiteSpace(nameof(modelPath));
@@ -89,5 +115,20 @@ public static unsafe class StableDiffusionCpp
         catch { /**/ }
     }
 
+    // Native preview callback: turns the first frame into an image and raises the Preview event.
+    // Exceptions are swallowed like in the other native callbacks of this class; presumably so nothing propagates into native code.
+    private static void OnPreview(int step, int frameCount, Native.Types.sd_image_t* frames, bool isNoisy)
+    {
+        try
+        {
+            if (frameCount <= 0 || frames == null) return;
+
+            Image image = ImageHelper.GetImage(frames, 0);
+
+            Preview?.Invoke(null, new StableDiffusionPreviewEventArgs(step, isNoisy, image));
+        }
+        catch { /**/ }
+    }
+
     #endregion
 }
\ No newline at end of file