Merge pull request #68 from DarthAffe/sd.cpp_update

Updated sd.cpp to 43a70e8
This commit is contained in:
DarthAffe 2025-12-30 11:27:28 +01:00 committed by GitHub
commit f324960942
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 378 additions and 123 deletions

View File

@ -150,6 +150,11 @@ typedef struct {
float rel_size_y; float rel_size_y;
} sd_tiling_params_t; } sd_tiling_params_t;
typedef struct {
const char* name;
const char* path;
} sd_embedding_t;
typedef struct { typedef struct {
const char* model_path; const char* model_path;
const char* clip_l_path; const char* clip_l_path;
@ -164,7 +169,8 @@ typedef struct {
const char* taesd_path; const char* taesd_path;
const char* control_net_path; const char* control_net_path;
const char* lora_model_dir; const char* lora_model_dir;
const char* embedding_dir; const sd_embedding_t* embeddings;
uint32_t embedding_count;
const char* photo_maker_path; const char* photo_maker_path;
const char* tensor_type_rules; const char* tensor_type_rules;
bool vae_decode_only; bool vae_decode_only;
@ -219,6 +225,8 @@ typedef struct {
int sample_steps; int sample_steps;
float eta; float eta;
int shifted_timestep; int shifted_timestep;
float* custom_sigmas;
int custom_sigmas_count;
} sd_sample_params_t; } sd_sample_params_t;
typedef struct { typedef struct {
@ -236,6 +244,14 @@ typedef struct {
} sd_easycache_params_t; } sd_easycache_params_t;
typedef struct { typedef struct {
bool is_high_noise;
float multiplier;
const char* path;
} sd_lora_t;
typedef struct {
const sd_lora_t* loras;
uint32_t lora_count;
const char* prompt; const char* prompt;
const char* negative_prompt; const char* negative_prompt;
int clip_skip; int clip_skip;
@ -259,6 +275,8 @@ typedef struct {
} sd_img_gen_params_t; } sd_img_gen_params_t;
typedef struct { typedef struct {
const sd_lora_t* loras;
uint32_t lora_count;
const char* prompt; const char* prompt;
const char* negative_prompt; const char* negative_prompt;
int clip_skip; int clip_skip;
@ -331,7 +349,8 @@ typedef struct upscaler_ctx_t upscaler_ctx_t;
SD_API upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path, SD_API upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path,
bool offload_params_to_cpu, bool offload_params_to_cpu,
bool direct, bool direct,
int n_threads); int n_threads,
int tile_size);
SD_API void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx); SD_API void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx);
SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx,
@ -353,6 +372,9 @@ SD_API bool preprocess_canny(sd_image_t image,
float strong, float strong,
bool inverse); bool inverse);
SD_API const char* sd_commit(void);
SD_API const char* sd_version(void);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -1,4 +1,5 @@
using System; using System;
using System.Collections.Generic;
using JetBrains.Annotations; using JetBrains.Annotations;
namespace StableDiffusion.NET; namespace StableDiffusion.NET;
@ -24,8 +25,11 @@ public sealed class DiffusionModelParameter
/// <summary> /// <summary>
/// path to embeddings /// path to embeddings
/// </summary> /// </summary>
[Obsolete("Use Embeddings instead")]
public string EmbeddingsDirectory { get; set; } = string.Empty; public string EmbeddingsDirectory { get; set; } = string.Empty;
public List<Embedding> Embeddings { get; } = [];
/// <summary> /// <summary>
/// path to control net model /// path to control net model
/// </summary> /// </summary>

View File

@ -0,0 +1,18 @@
using System.Diagnostics.CodeAnalysis;
namespace StableDiffusion.NET;
public sealed class Embedding
{
public required string Name { get; init; }
public required string Path { get; init; }
public Embedding() { }
[SetsRequiredMembers]
public Embedding(string name, string path)
{
this.Name = name;
this.Path = path;
}
}

View File

@ -37,6 +37,7 @@ public static class DiffusionModelBuilderExtension
return parameter; return parameter;
} }
[Obsolete("Use WithEmbedding instead")]
public static DiffusionModelParameter WithEmbeddingSupport(this DiffusionModelParameter parameter, string embeddingsDirectory) public static DiffusionModelParameter WithEmbeddingSupport(this DiffusionModelParameter parameter, string embeddingsDirectory)
{ {
ArgumentNullException.ThrowIfNull(embeddingsDirectory); ArgumentNullException.ThrowIfNull(embeddingsDirectory);
@ -46,6 +47,15 @@ public static class DiffusionModelBuilderExtension
return parameter; return parameter;
} }
public static DiffusionModelParameter WithEmbedding(this DiffusionModelParameter parameter, Embedding embedding)
{
ArgumentNullException.ThrowIfNull(embedding);
parameter.Embeddings.Add(embedding);
return parameter;
}
public static DiffusionModelParameter WithControlNet(this DiffusionModelParameter parameter, string controlNetPath) public static DiffusionModelParameter WithControlNet(this DiffusionModelParameter parameter, string controlNetPath)
{ {
ArgumentNullException.ThrowIfNull(controlNetPath); ArgumentNullException.ThrowIfNull(controlNetPath);

View File

@ -1,4 +1,5 @@
using HPPH; using System.Collections.Generic;
using HPPH;
using JetBrains.Annotations; using JetBrains.Annotations;
namespace StableDiffusion.NET; namespace StableDiffusion.NET;
@ -61,6 +62,8 @@ public sealed class ImageGenerationParameter
public EasyCache EasyCache { get; } = new(); public EasyCache EasyCache { get; } = new();
public List<Lora> Loras { get; } = [];
#endregion #endregion
public static ImageGenerationParameter Create() => new(); public static ImageGenerationParameter Create() => new();

View File

@ -0,0 +1,18 @@
using System.Diagnostics.CodeAnalysis;
namespace StableDiffusion.NET;
public sealed class Lora
{
public bool IsHighNoise { get; set; } = false;
public float Multiplier { get; set; } = 1f;
public required string Path { get; init; }
public Lora() { }
[SetsRequiredMembers]
public Lora(string path)
{
this.Path = path;
}
}

View File

@ -26,5 +26,7 @@ public sealed class SampleParameter
public int ShiftedTimestep { get; set; } = 0; public int ShiftedTimestep { get; set; } = 0;
public float[] CustomSigmas { get; set; } = [];
internal SampleParameter() { } internal SampleParameter() { }
} }

View File

@ -24,5 +24,7 @@ public sealed class UpscaleModelParameter
/// </summary> /// </summary>
public bool ConvDirect { get; set; } = false; public bool ConvDirect { get; set; } = false;
public int TileSize { get; set; } = 128;
public static UpscaleModelParameter Create() => new(); public static UpscaleModelParameter Create() => new();
} }

View File

@ -1,5 +1,6 @@
using HPPH; using HPPH;
using JetBrains.Annotations; using JetBrains.Annotations;
using System.Collections.Generic;
namespace StableDiffusion.NET; namespace StableDiffusion.NET;
@ -40,6 +41,8 @@ public sealed class VideoGenerationParameter
public EasyCache EasyCache { get; } = new(); public EasyCache EasyCache { get; } = new();
public List<Lora> Loras { get; } = [];
#endregion #endregion
public static VideoGenerationParameter Create() => new(); public static VideoGenerationParameter Create() => new();

View File

@ -41,7 +41,8 @@ public sealed unsafe class UpscaleModel : IDisposable
_ctx = Native.new_upscaler_ctx(ModelParameter.ModelPath, _ctx = Native.new_upscaler_ctx(ModelParameter.ModelPath,
ModelParameter.OffloadParamsToCPU, ModelParameter.OffloadParamsToCPU,
ModelParameter.ConvDirect, ModelParameter.ConvDirect,
ModelParameter.ThreadCount); ModelParameter.ThreadCount,
ModelParameter.TileSize);
if (_ctx == null) throw new NullReferenceException("Failed to initialize upscale-model."); if (_ctx == null) throw new NullReferenceException("Failed to initialize upscale-model.");
} }

View File

@ -1,55 +1,18 @@
using System.Runtime.InteropServices.Marshalling; using System.Collections.Generic;
using System.IO;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
namespace StableDiffusion.NET; namespace StableDiffusion.NET;
[CustomMarshaller(typeof(DiffusionModelParameter), MarshalMode.ManagedToUnmanagedIn, typeof(DiffusionModelParameterMarshaller))] [CustomMarshaller(typeof(DiffusionModelParameter), MarshalMode.ManagedToUnmanagedIn, typeof(DiffusionModelParameterMarshallerIn))]
[CustomMarshaller(typeof(DiffusionModelParameter), MarshalMode.ManagedToUnmanagedRef, typeof(DiffusionModelParameterMarshaller))] [CustomMarshaller(typeof(DiffusionModelParameter), MarshalMode.ManagedToUnmanagedOut, typeof(DiffusionModelParameterMarshaller))]
[CustomMarshaller(typeof(DiffusionModelParameter), MarshalMode.ManagedToUnmanagedRef, typeof(DiffusionModelParameterMarshallerRef))]
internal static unsafe class DiffusionModelParameterMarshaller internal static unsafe class DiffusionModelParameterMarshaller
{ {
public static Native.Types.sd_ctx_params_t ConvertToUnmanaged(DiffusionModelParameter managed)
=> new()
{
model_path = AnsiStringMarshaller.ConvertToUnmanaged(managed.ModelPath),
clip_l_path = AnsiStringMarshaller.ConvertToUnmanaged(managed.ClipLPath),
clip_g_path = AnsiStringMarshaller.ConvertToUnmanaged(managed.ClipGPath),
clip_vision_path = AnsiStringMarshaller.ConvertToUnmanaged(managed.ClipVisionPath),
t5xxl_path = AnsiStringMarshaller.ConvertToUnmanaged(managed.T5xxlPath),
llm_path = AnsiStringMarshaller.ConvertToUnmanaged(managed.LLMPath),
llm_vision_path = AnsiStringMarshaller.ConvertToUnmanaged(managed.LLMVisionPath),
diffusion_model_path = AnsiStringMarshaller.ConvertToUnmanaged(managed.DiffusionModelPath),
high_noise_diffusion_model_path = AnsiStringMarshaller.ConvertToUnmanaged(managed.HighNoiseDiffusionModelPath),
vae_path = AnsiStringMarshaller.ConvertToUnmanaged(managed.VaePath),
taesd_path = AnsiStringMarshaller.ConvertToUnmanaged(managed.TaesdPath),
control_net_path = AnsiStringMarshaller.ConvertToUnmanaged(managed.ControlNetPath),
lora_model_dir = AnsiStringMarshaller.ConvertToUnmanaged(managed.LoraModelDirectory),
embedding_dir = AnsiStringMarshaller.ConvertToUnmanaged(managed.EmbeddingsDirectory),
photo_maker_path = AnsiStringMarshaller.ConvertToUnmanaged(managed.StackedIdEmbeddingsDirectory),
tensor_type_rules = AnsiStringMarshaller.ConvertToUnmanaged(managed.TensorTypeRules),
vae_decode_only = (sbyte)(managed.VaeDecodeOnly ? 1 : 0),
free_params_immediately = (sbyte)(managed.FreeParamsImmediately ? 1 : 0),
n_threads = managed.ThreadCount,
wtype = managed.Quantization,
rng_type = managed.RngType,
sampler_rng_type = managed.SamplerRngType,
prediction = managed.Prediction,
lora_apply_mode = managed.LoraApplyMode,
offload_params_to_cpu = (sbyte)(managed.OffloadParamsToCPU ? 1 : 0),
keep_clip_on_cpu = (sbyte)(managed.KeepClipOnCPU ? 1 : 0),
keep_control_net_on_cpu = (sbyte)(managed.KeepControlNetOnCPU ? 1 : 0),
keep_vae_on_cpu = (sbyte)(managed.KeepVaeOnCPU ? 1 : 0),
diffusion_flash_attn = (sbyte)(managed.FlashAttention ? 1 : 0),
tae_preview_only = (sbyte)(managed.TaePreviewOnly ? 1 : 0),
diffusion_conv_direct = (sbyte)(managed.DiffusionConvDirect ? 1 : 0),
vae_conv_direct = (sbyte)(managed.VaeConvDirect ? 1 : 0),
force_sdxl_vae_conv_scale = (sbyte)(managed.ForceSdxlVaeConvScale ? 1 : 0),
chroma_use_dit_mask = (sbyte)(managed.ChromaUseDitMap ? 1 : 0),
chroma_use_t5_mask = (sbyte)(managed.ChromaEnableT5Map ? 1 : 0),
chroma_t5_mask_pad = managed.ChromaT5MaskPad,
flow_shift = managed.FlowShift
};
public static DiffusionModelParameter ConvertToManaged(Native.Types.sd_ctx_params_t unmanaged) public static DiffusionModelParameter ConvertToManaged(Native.Types.sd_ctx_params_t unmanaged)
=> new() {
DiffusionModelParameter parameter = new()
{ {
ModelPath = AnsiStringMarshaller.ConvertToManaged(unmanaged.model_path) ?? string.Empty, ModelPath = AnsiStringMarshaller.ConvertToManaged(unmanaged.model_path) ?? string.Empty,
ClipLPath = AnsiStringMarshaller.ConvertToManaged(unmanaged.clip_l_path) ?? string.Empty, ClipLPath = AnsiStringMarshaller.ConvertToManaged(unmanaged.clip_l_path) ?? string.Empty,
@ -64,7 +27,6 @@ internal static unsafe class DiffusionModelParameterMarshaller
TaesdPath = AnsiStringMarshaller.ConvertToManaged(unmanaged.taesd_path) ?? string.Empty, TaesdPath = AnsiStringMarshaller.ConvertToManaged(unmanaged.taesd_path) ?? string.Empty,
ControlNetPath = AnsiStringMarshaller.ConvertToManaged(unmanaged.control_net_path) ?? string.Empty, ControlNetPath = AnsiStringMarshaller.ConvertToManaged(unmanaged.control_net_path) ?? string.Empty,
LoraModelDirectory = AnsiStringMarshaller.ConvertToManaged(unmanaged.lora_model_dir) ?? string.Empty, LoraModelDirectory = AnsiStringMarshaller.ConvertToManaged(unmanaged.lora_model_dir) ?? string.Empty,
EmbeddingsDirectory = AnsiStringMarshaller.ConvertToManaged(unmanaged.embedding_dir) ?? string.Empty,
StackedIdEmbeddingsDirectory = AnsiStringMarshaller.ConvertToManaged(unmanaged.photo_maker_path) ?? string.Empty, StackedIdEmbeddingsDirectory = AnsiStringMarshaller.ConvertToManaged(unmanaged.photo_maker_path) ?? string.Empty,
TensorTypeRules = AnsiStringMarshaller.ConvertToManaged(unmanaged.tensor_type_rules) ?? string.Empty, TensorTypeRules = AnsiStringMarshaller.ConvertToManaged(unmanaged.tensor_type_rules) ?? string.Empty,
VaeDecodeOnly = unmanaged.vae_decode_only == 1, VaeDecodeOnly = unmanaged.vae_decode_only == 1,
@ -90,21 +52,155 @@ internal static unsafe class DiffusionModelParameterMarshaller
FlowShift = unmanaged.flow_shift FlowShift = unmanaged.flow_shift
}; };
public static void Free(Native.Types.sd_ctx_params_t unmanaged) for (int i = 0; i < unmanaged.embedding_count; i++)
{ {
AnsiStringMarshaller.Free(unmanaged.model_path); Native.Types.sd_embedding_t embedding = unmanaged.embeddings[i];
AnsiStringMarshaller.Free(unmanaged.clip_l_path); parameter.Embeddings.Add(new Embedding
AnsiStringMarshaller.Free(unmanaged.clip_g_path); {
AnsiStringMarshaller.Free(unmanaged.t5xxl_path); Name = AnsiStringMarshaller.ConvertToManaged(embedding.name) ?? string.Empty,
AnsiStringMarshaller.Free(unmanaged.llm_path); Path = AnsiStringMarshaller.ConvertToManaged(embedding.path) ?? string.Empty
AnsiStringMarshaller.Free(unmanaged.llm_vision_path); });
AnsiStringMarshaller.Free(unmanaged.diffusion_model_path); }
AnsiStringMarshaller.Free(unmanaged.vae_path);
AnsiStringMarshaller.Free(unmanaged.taesd_path); return parameter;
AnsiStringMarshaller.Free(unmanaged.control_net_path); }
AnsiStringMarshaller.Free(unmanaged.lora_model_dir);
AnsiStringMarshaller.Free(unmanaged.embedding_dir); internal ref struct DiffusionModelParameterMarshallerIn
AnsiStringMarshaller.Free(unmanaged.photo_maker_path); {
AnsiStringMarshaller.Free(unmanaged.tensor_type_rules); private Native.Types.sd_ctx_params_t _ctxParams;
private Native.Types.sd_embedding_t* _embeddings;
public DiffusionModelParameterMarshallerIn() { }
public void FromManaged(DiffusionModelParameter managed)
{
//_embeddings = (Native.Types.sd_embedding_t*)NativeMemory.Alloc((nuint)managed.Embeddings.Count, (nuint)Marshal.SizeOf<Native.Types.sd_embedding_t>());
//for (int i = 0; i < managed.Embeddings.Count; i++)
//{
// Embedding embedding = managed.Embeddings[i];
// _embeddings[i] = new Native.Types.sd_embedding_t
// {
// name = AnsiStringMarshaller.ConvertToUnmanaged(embedding.Name),
// path = AnsiStringMarshaller.ConvertToUnmanaged(embedding.Path),
// };
//}
//HACK DarthAffe 25.12.2025 Workaround to support EmbeddingsDir till the next major release
List<Embedding> embeddings = [];
{
embeddings.AddRange(managed.Embeddings);
try
{
if (!string.IsNullOrWhiteSpace(managed.EmbeddingsDirectory) && Directory.Exists(managed.EmbeddingsDirectory))
{
foreach (string file in Directory.GetFiles(managed.EmbeddingsDirectory))
embeddings.Add(new Embedding(Path.GetFileNameWithoutExtension(file), file));
}
}
catch { /**/ }
_embeddings = (Native.Types.sd_embedding_t*)NativeMemory.Alloc((nuint)embeddings.Count, (nuint)Marshal.SizeOf<Native.Types.sd_embedding_t>());
for (int i = 0; i < embeddings.Count; i++)
{
Embedding embedding = embeddings[i];
_embeddings[i] = new Native.Types.sd_embedding_t
{
name = AnsiStringMarshaller.ConvertToUnmanaged(embedding.Name),
path = AnsiStringMarshaller.ConvertToUnmanaged(embedding.Path),
};
}
}
_ctxParams = new Native.Types.sd_ctx_params_t
{
model_path = AnsiStringMarshaller.ConvertToUnmanaged(managed.ModelPath),
clip_l_path = AnsiStringMarshaller.ConvertToUnmanaged(managed.ClipLPath),
clip_g_path = AnsiStringMarshaller.ConvertToUnmanaged(managed.ClipGPath),
clip_vision_path = AnsiStringMarshaller.ConvertToUnmanaged(managed.ClipVisionPath),
t5xxl_path = AnsiStringMarshaller.ConvertToUnmanaged(managed.T5xxlPath),
llm_path = AnsiStringMarshaller.ConvertToUnmanaged(managed.LLMPath),
llm_vision_path = AnsiStringMarshaller.ConvertToUnmanaged(managed.LLMVisionPath),
diffusion_model_path = AnsiStringMarshaller.ConvertToUnmanaged(managed.DiffusionModelPath),
high_noise_diffusion_model_path = AnsiStringMarshaller.ConvertToUnmanaged(managed.HighNoiseDiffusionModelPath),
vae_path = AnsiStringMarshaller.ConvertToUnmanaged(managed.VaePath),
taesd_path = AnsiStringMarshaller.ConvertToUnmanaged(managed.TaesdPath),
control_net_path = AnsiStringMarshaller.ConvertToUnmanaged(managed.ControlNetPath),
lora_model_dir = AnsiStringMarshaller.ConvertToUnmanaged(managed.LoraModelDirectory),
embeddings = _embeddings,
embedding_count = (uint)embeddings.Count,
photo_maker_path = AnsiStringMarshaller.ConvertToUnmanaged(managed.StackedIdEmbeddingsDirectory),
tensor_type_rules = AnsiStringMarshaller.ConvertToUnmanaged(managed.TensorTypeRules),
vae_decode_only = (sbyte)(managed.VaeDecodeOnly ? 1 : 0),
free_params_immediately = (sbyte)(managed.FreeParamsImmediately ? 1 : 0),
n_threads = managed.ThreadCount,
wtype = managed.Quantization,
rng_type = managed.RngType,
sampler_rng_type = managed.SamplerRngType,
prediction = managed.Prediction,
lora_apply_mode = managed.LoraApplyMode,
offload_params_to_cpu = (sbyte)(managed.OffloadParamsToCPU ? 1 : 0),
keep_clip_on_cpu = (sbyte)(managed.KeepClipOnCPU ? 1 : 0),
keep_control_net_on_cpu = (sbyte)(managed.KeepControlNetOnCPU ? 1 : 0),
keep_vae_on_cpu = (sbyte)(managed.KeepVaeOnCPU ? 1 : 0),
diffusion_flash_attn = (sbyte)(managed.FlashAttention ? 1 : 0),
tae_preview_only = (sbyte)(managed.TaePreviewOnly ? 1 : 0),
diffusion_conv_direct = (sbyte)(managed.DiffusionConvDirect ? 1 : 0),
vae_conv_direct = (sbyte)(managed.VaeConvDirect ? 1 : 0),
force_sdxl_vae_conv_scale = (sbyte)(managed.ForceSdxlVaeConvScale ? 1 : 0),
chroma_use_dit_mask = (sbyte)(managed.ChromaUseDitMap ? 1 : 0),
chroma_use_t5_mask = (sbyte)(managed.ChromaEnableT5Map ? 1 : 0),
chroma_t5_mask_pad = managed.ChromaT5MaskPad,
flow_shift = managed.FlowShift
};
}
public Native.Types.sd_ctx_params_t ToUnmanaged() => _ctxParams;
public void Free()
{
AnsiStringMarshaller.Free(_ctxParams.model_path);
AnsiStringMarshaller.Free(_ctxParams.clip_l_path);
AnsiStringMarshaller.Free(_ctxParams.clip_g_path);
AnsiStringMarshaller.Free(_ctxParams.t5xxl_path);
AnsiStringMarshaller.Free(_ctxParams.llm_path);
AnsiStringMarshaller.Free(_ctxParams.llm_vision_path);
AnsiStringMarshaller.Free(_ctxParams.diffusion_model_path);
AnsiStringMarshaller.Free(_ctxParams.vae_path);
AnsiStringMarshaller.Free(_ctxParams.taesd_path);
AnsiStringMarshaller.Free(_ctxParams.control_net_path);
AnsiStringMarshaller.Free(_ctxParams.lora_model_dir);
AnsiStringMarshaller.Free(_ctxParams.photo_maker_path);
AnsiStringMarshaller.Free(_ctxParams.tensor_type_rules);
for (int i = 0; i < _ctxParams.embedding_count; i++)
{
AnsiStringMarshaller.Free(_ctxParams.embeddings[i].name);
AnsiStringMarshaller.Free(_ctxParams.embeddings[i].path);
}
if (_embeddings != null)
NativeMemory.Free(_embeddings);
}
}
internal ref struct DiffusionModelParameterMarshallerRef()
{
private DiffusionModelParameterMarshallerIn _inMarshaller = new();
private DiffusionModelParameter? _parameter;
public void FromManaged(DiffusionModelParameter managed) => _inMarshaller.FromManaged(managed);
public Native.Types.sd_ctx_params_t ToUnmanaged() => _inMarshaller.ToUnmanaged();
public void FromUnmanaged(Native.Types.sd_ctx_params_t unmanaged) => _parameter = ConvertToManaged(unmanaged);
public DiffusionModelParameter ToManaged() => _parameter!;
public void Free() => _inMarshaller.Free();
} }
} }

View File

@ -9,9 +9,9 @@ namespace StableDiffusion.NET;
[CustomMarshaller(typeof(ImageGenerationParameter), MarshalMode.ManagedToUnmanagedIn, typeof(ImageGenerationParameterMarshallerIn))] [CustomMarshaller(typeof(ImageGenerationParameter), MarshalMode.ManagedToUnmanagedIn, typeof(ImageGenerationParameterMarshallerIn))]
[CustomMarshaller(typeof(ImageGenerationParameter), MarshalMode.ManagedToUnmanagedOut, typeof(ImageGenerationParameterMarshaller))] [CustomMarshaller(typeof(ImageGenerationParameter), MarshalMode.ManagedToUnmanagedOut, typeof(ImageGenerationParameterMarshaller))]
[CustomMarshaller(typeof(ImageGenerationParameter), MarshalMode.ManagedToUnmanagedRef, typeof(ImageGenerationParameterMarshallerRef))] [CustomMarshaller(typeof(ImageGenerationParameter), MarshalMode.ManagedToUnmanagedRef, typeof(ImageGenerationParameterMarshallerRef))]
internal static class ImageGenerationParameterMarshaller internal static unsafe class ImageGenerationParameterMarshaller
{ {
public static unsafe ImageGenerationParameter ConvertToManaged(Native.Types.sd_img_gen_params_t unmanaged) public static ImageGenerationParameter ConvertToManaged(Native.Types.sd_img_gen_params_t unmanaged)
{ {
ImageGenerationParameter parameter = new() ImageGenerationParameter parameter = new()
{ {
@ -56,28 +56,21 @@ internal static class ImageGenerationParameterMarshaller
} }
}; };
for (int i = 0; i < unmanaged.lora_count; i++)
{
Native.Types.sd_lora_t lora = unmanaged.loras[i];
parameter.Loras.Add(new Lora
{
IsHighNoise = lora.is_high_noise == 1,
Multiplier = lora.multiplier,
Path = AnsiStringMarshaller.ConvertToManaged(lora.path) ?? string.Empty
});
}
return parameter; return parameter;
} }
public static unsafe void Free(Native.Types.sd_img_gen_params_t unmanaged) internal ref struct ImageGenerationParameterMarshallerIn
{
AnsiStringMarshaller.Free(unmanaged.prompt);
AnsiStringMarshaller.Free(unmanaged.negative_prompt);
unmanaged.init_image.Free();
unmanaged.mask_image.Free();
unmanaged.control_image.Free();
if (unmanaged.ref_images != null)
ImageHelper.Free(unmanaged.ref_images, unmanaged.ref_images_count);
if (unmanaged.pm_params.id_images != null)
ImageHelper.Free(unmanaged.pm_params.id_images, unmanaged.pm_params.id_images_count);
SampleParameterMarshaller.Free(unmanaged.sample_params);
}
internal unsafe ref struct ImageGenerationParameterMarshallerIn
{ {
private SampleParameterMarshaller.SampleParameterMarshallerIn _sampleParameterMarshaller = new(); private SampleParameterMarshaller.SampleParameterMarshallerIn _sampleParameterMarshaller = new();
private Native.Types.sd_img_gen_params_t _imgGenParams; private Native.Types.sd_img_gen_params_t _imgGenParams;
@ -87,6 +80,7 @@ internal static class ImageGenerationParameterMarshaller
private Native.Types.sd_image_t _controlNetImage; private Native.Types.sd_image_t _controlNetImage;
private Native.Types.sd_image_t* _refImages; private Native.Types.sd_image_t* _refImages;
private Native.Types.sd_image_t* _pmIdImages; private Native.Types.sd_image_t* _pmIdImages;
private Native.Types.sd_lora_t* _loras;
public ImageGenerationParameterMarshallerIn() { } public ImageGenerationParameterMarshallerIn() { }
@ -99,6 +93,18 @@ internal static class ImageGenerationParameterMarshaller
_refImages = managed.RefImages == null ? null : managed.RefImages.ToSdImage(); _refImages = managed.RefImages == null ? null : managed.RefImages.ToSdImage();
_pmIdImages = managed.PhotoMaker.IdImages == null ? null : managed.PhotoMaker.IdImages.ToSdImage(); _pmIdImages = managed.PhotoMaker.IdImages == null ? null : managed.PhotoMaker.IdImages.ToSdImage();
_loras = (Native.Types.sd_lora_t*)NativeMemory.Alloc((nuint)managed.Loras.Count, (nuint)Marshal.SizeOf<Native.Types.sd_lora_t>());
for (int i = 0; i < managed.Loras.Count; i++)
{
Lora lora = managed.Loras[i];
_loras[i] = new Native.Types.sd_lora_t
{
is_high_noise = (sbyte)(lora.IsHighNoise ? 1 : 0),
multiplier = lora.Multiplier,
path = AnsiStringMarshaller.ConvertToUnmanaged(lora.Path)
};
}
if (managed.MaskImage != null) if (managed.MaskImage != null)
_maskImage = managed.MaskImage.ToSdImage(true); _maskImage = managed.MaskImage.ToSdImage(true);
else if (managed.InitImage != null) else if (managed.InitImage != null)
@ -161,7 +167,9 @@ internal static class ImageGenerationParameterMarshaller
control_strength = managed.ControlNet.Strength, control_strength = managed.ControlNet.Strength,
pm_params = photoMakerParams, pm_params = photoMakerParams,
vae_tiling_params = tilingParams, vae_tiling_params = tilingParams,
easycache = easyCache easycache = easyCache,
loras = _loras,
lora_count = (uint)managed.Loras.Count
}; };
} }
@ -184,6 +192,12 @@ internal static class ImageGenerationParameterMarshaller
ImageHelper.Free(_pmIdImages, _imgGenParams.pm_params.id_images_count); ImageHelper.Free(_pmIdImages, _imgGenParams.pm_params.id_images_count);
_sampleParameterMarshaller.Free(); _sampleParameterMarshaller.Free();
for (int i = 0; i < _imgGenParams.lora_count; i++)
AnsiStringMarshaller.Free(_imgGenParams.loras[i].path);
if (_loras != null)
NativeMemory.Free(_loras);
} }
} }

View File

@ -9,9 +9,9 @@ namespace StableDiffusion.NET;
[CustomMarshaller(typeof(SampleParameter), MarshalMode.ManagedToUnmanagedIn, typeof(SampleParameterMarshallerIn))] [CustomMarshaller(typeof(SampleParameter), MarshalMode.ManagedToUnmanagedIn, typeof(SampleParameterMarshallerIn))]
[CustomMarshaller(typeof(SampleParameter), MarshalMode.ManagedToUnmanagedOut, typeof(SampleParameterMarshaller))] [CustomMarshaller(typeof(SampleParameter), MarshalMode.ManagedToUnmanagedOut, typeof(SampleParameterMarshaller))]
[CustomMarshaller(typeof(SampleParameter), MarshalMode.ManagedToUnmanagedRef, typeof(SampleParameterMarshallerRef))] [CustomMarshaller(typeof(SampleParameter), MarshalMode.ManagedToUnmanagedRef, typeof(SampleParameterMarshallerRef))]
internal static class SampleParameterMarshaller internal static unsafe class SampleParameterMarshaller
{ {
public static unsafe SampleParameter ConvertToManaged(Native.Types.sd_sample_params_t unmanaged) public static SampleParameter ConvertToManaged(Native.Types.sd_sample_params_t unmanaged)
{ {
SampleParameter parameter = new() SampleParameter parameter = new()
{ {
@ -33,32 +33,34 @@ internal static class SampleParameterMarshaller
SampleMethod = unmanaged.sample_method, SampleMethod = unmanaged.sample_method,
SampleSteps = unmanaged.sample_steps, SampleSteps = unmanaged.sample_steps,
Eta = unmanaged.eta, Eta = unmanaged.eta,
ShiftedTimestep = unmanaged.shifted_timestep ShiftedTimestep = unmanaged.shifted_timestep,
CustomSigmas = new float[unmanaged.custom_sigmas_count]
}; };
if (unmanaged.guidance.slg.layers != null) if (unmanaged.guidance.slg.layers != null)
new Span<int>(unmanaged.guidance.slg.layers, (int)unmanaged.guidance.slg.layer_count).CopyTo(parameter.Guidance.Slg.Layers); new Span<int>(unmanaged.guidance.slg.layers, (int)unmanaged.guidance.slg.layer_count).CopyTo(parameter.Guidance.Slg.Layers);
if (unmanaged.custom_sigmas != null)
new Span<float>(unmanaged.custom_sigmas, unmanaged.custom_sigmas_count).CopyTo(parameter.CustomSigmas);
return parameter; return parameter;
} }
public static unsafe void Free(Native.Types.sd_sample_params_t unmanaged) internal ref struct SampleParameterMarshallerIn
{
if (unmanaged.guidance.slg.layers != null)
NativeMemory.Free(unmanaged.guidance.slg.layers);
}
internal unsafe ref struct SampleParameterMarshallerIn
{ {
private Native.Types.sd_sample_params_t _sampleParams; private Native.Types.sd_sample_params_t _sampleParams;
private int* _slgLayers; private int* _slgLayers;
private float* _customSigmas;
public void FromManaged(SampleParameter managed) public void FromManaged(SampleParameter managed)
{ {
_slgLayers = (int*)NativeMemory.Alloc((nuint)managed.Guidance.Slg.Layers.Length, (nuint)Marshal.SizeOf<int>()); _slgLayers = (int*)NativeMemory.Alloc((nuint)managed.Guidance.Slg.Layers.Length, (nuint)Marshal.SizeOf<int>());
managed.Guidance.Slg.Layers.AsSpan().CopyTo(new Span<int>(_slgLayers, managed.Guidance.Slg.Layers.Length)); managed.Guidance.Slg.Layers.AsSpan().CopyTo(new Span<int>(_slgLayers, managed.Guidance.Slg.Layers.Length));
_customSigmas = (float*)NativeMemory.Alloc((nuint)managed.CustomSigmas.Length, (nuint)Marshal.SizeOf<float>());
managed.CustomSigmas.AsSpan().CopyTo(new Span<float>(_customSigmas, managed.CustomSigmas.Length));
Native.Types.sd_slg_params_t slg = new() Native.Types.sd_slg_params_t slg = new()
{ {
layers = _slgLayers, layers = _slgLayers,
@ -84,7 +86,9 @@ internal static class SampleParameterMarshaller
sample_method = managed.SampleMethod, sample_method = managed.SampleMethod,
sample_steps = managed.SampleSteps, sample_steps = managed.SampleSteps,
eta = managed.Eta, eta = managed.Eta,
shifted_timestep = managed.ShiftedTimestep shifted_timestep = managed.ShiftedTimestep,
custom_sigmas = _customSigmas,
custom_sigmas_count = managed.CustomSigmas.Length
}; };
} }
@ -94,6 +98,9 @@ internal static class SampleParameterMarshaller
{ {
if (_slgLayers != null) if (_slgLayers != null)
NativeMemory.Free(_slgLayers); NativeMemory.Free(_slgLayers);
if (_customSigmas != null)
NativeMemory.Free(_customSigmas);
} }
} }

View File

@ -1,5 +1,6 @@
// ReSharper disable MemberCanBeMadeStatic.Global // ReSharper disable MemberCanBeMadeStatic.Global
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling; using System.Runtime.InteropServices.Marshalling;
namespace StableDiffusion.NET; namespace StableDiffusion.NET;
@ -7,9 +8,9 @@ namespace StableDiffusion.NET;
[CustomMarshaller(typeof(VideoGenerationParameter), MarshalMode.ManagedToUnmanagedIn, typeof(VideoGenerationParameterMarshallerIn))] [CustomMarshaller(typeof(VideoGenerationParameter), MarshalMode.ManagedToUnmanagedIn, typeof(VideoGenerationParameterMarshallerIn))]
[CustomMarshaller(typeof(VideoGenerationParameter), MarshalMode.ManagedToUnmanagedOut, typeof(VideoGenerationParameterMarshaller))] [CustomMarshaller(typeof(VideoGenerationParameter), MarshalMode.ManagedToUnmanagedOut, typeof(VideoGenerationParameterMarshaller))]
[CustomMarshaller(typeof(VideoGenerationParameter), MarshalMode.ManagedToUnmanagedRef, typeof(VideoGenerationParameterMarshallerRef))] [CustomMarshaller(typeof(VideoGenerationParameter), MarshalMode.ManagedToUnmanagedRef, typeof(VideoGenerationParameterMarshallerRef))]
internal static class VideoGenerationParameterMarshaller internal static unsafe class VideoGenerationParameterMarshaller
{ {
public static unsafe VideoGenerationParameter ConvertToManaged(Native.Types.sd_vid_gen_params_t unmanaged) public static VideoGenerationParameter ConvertToManaged(Native.Types.sd_vid_gen_params_t unmanaged)
{ {
VideoGenerationParameter parameter = new() VideoGenerationParameter parameter = new()
{ {
@ -37,25 +38,21 @@ internal static class VideoGenerationParameterMarshaller
} }
}; };
for (int i = 0; i < unmanaged.lora_count; i++)
{
Native.Types.sd_lora_t lora = unmanaged.loras[i];
parameter.Loras.Add(new Lora
{
IsHighNoise = lora.is_high_noise == 1,
Multiplier = lora.multiplier,
Path = AnsiStringMarshaller.ConvertToManaged(lora.path) ?? string.Empty
});
}
return parameter; return parameter;
} }
public static unsafe void Free(Native.Types.sd_vid_gen_params_t unmanaged) internal ref struct VideoGenerationParameterMarshallerIn
{
AnsiStringMarshaller.Free(unmanaged.prompt);
AnsiStringMarshaller.Free(unmanaged.negative_prompt);
unmanaged.init_image.Free();
unmanaged.end_image.Free();
if (unmanaged.control_frames != null)
ImageHelper.Free(unmanaged.control_frames, unmanaged.control_frames_size);
SampleParameterMarshaller.Free(unmanaged.sample_params);
SampleParameterMarshaller.Free(unmanaged.high_noise_sample_params);
}
internal unsafe ref struct VideoGenerationParameterMarshallerIn
{ {
private SampleParameterMarshaller.SampleParameterMarshallerIn _sampleParameterMarshaller = new(); private SampleParameterMarshaller.SampleParameterMarshallerIn _sampleParameterMarshaller = new();
private SampleParameterMarshaller.SampleParameterMarshallerIn _highNoiseSampleParameterMarshaller = new(); private SampleParameterMarshaller.SampleParameterMarshallerIn _highNoiseSampleParameterMarshaller = new();
@ -64,6 +61,7 @@ internal static class VideoGenerationParameterMarshaller
private Native.Types.sd_image_t _initImage; private Native.Types.sd_image_t _initImage;
private Native.Types.sd_image_t _endImage; private Native.Types.sd_image_t _endImage;
private Native.Types.sd_image_t* _controlFrames; private Native.Types.sd_image_t* _controlFrames;
private Native.Types.sd_lora_t* _loras;
public VideoGenerationParameterMarshallerIn() { } public VideoGenerationParameterMarshallerIn() { }
@ -76,6 +74,18 @@ internal static class VideoGenerationParameterMarshaller
_endImage = managed.EndImage?.ToSdImage() ?? new Native.Types.sd_image_t(); _endImage = managed.EndImage?.ToSdImage() ?? new Native.Types.sd_image_t();
_controlFrames = managed.ControlFrames == null ? null : managed.ControlFrames.ToSdImage(); _controlFrames = managed.ControlFrames == null ? null : managed.ControlFrames.ToSdImage();
_loras = (Native.Types.sd_lora_t*)NativeMemory.Alloc((nuint)managed.Loras.Count, (nuint)Marshal.SizeOf<Native.Types.sd_lora_t>());
for (int i = 0; i < managed.Loras.Count; i++)
{
Lora lora = managed.Loras[i];
_loras[i] = new Native.Types.sd_lora_t
{
is_high_noise = (sbyte)(lora.IsHighNoise ? 1 : 0),
multiplier = lora.Multiplier,
path = AnsiStringMarshaller.ConvertToUnmanaged(lora.Path)
};
}
Native.Types.sd_easycache_params_t easyCache = new() Native.Types.sd_easycache_params_t easyCache = new()
{ {
enabled = (sbyte)(managed.EasyCache.IsEnabled ? 1 : 0), enabled = (sbyte)(managed.EasyCache.IsEnabled ? 1 : 0),
@ -103,6 +113,8 @@ internal static class VideoGenerationParameterMarshaller
video_frames = managed.FrameCount, video_frames = managed.FrameCount,
vace_strength = managed.VaceStrength, vace_strength = managed.VaceStrength,
easycache = easyCache, easycache = easyCache,
loras = _loras,
lora_count = (uint)managed.Loras.Count
}; };
} }
@ -121,6 +133,12 @@ internal static class VideoGenerationParameterMarshaller
_sampleParameterMarshaller.Free(); _sampleParameterMarshaller.Free();
_highNoiseSampleParameterMarshaller.Free(); _highNoiseSampleParameterMarshaller.Free();
for (int i = 0; i < _vidGenParams.lora_count; i++)
AnsiStringMarshaller.Free(_vidGenParams.loras[i].path);
if (_loras != null)
NativeMemory.Free(_loras);
} }
} }

View File

@ -42,7 +42,7 @@ internal unsafe partial class Native
internal static class Types internal static class Types
{ {
[StructLayout(LayoutKind.Sequential)] [StructLayout(LayoutKind.Sequential)]
public struct sd_tiling_params_t internal struct sd_tiling_params_t
{ {
public sbyte enabled; public sbyte enabled;
public int tile_size_x; public int tile_size_x;
@ -52,6 +52,13 @@ internal unsafe partial class Native
public float rel_size_y; public float rel_size_y;
} }
[StructLayout(LayoutKind.Sequential)]
internal struct sd_embedding_t
{
public byte* name;
public byte* path;
}
[StructLayout(LayoutKind.Sequential)] [StructLayout(LayoutKind.Sequential)]
internal struct sd_ctx_params_t internal struct sd_ctx_params_t
{ {
@ -68,7 +75,8 @@ internal unsafe partial class Native
public byte* taesd_path; public byte* taesd_path;
public byte* control_net_path; public byte* control_net_path;
public byte* lora_model_dir; public byte* lora_model_dir;
public byte* embedding_dir; public sd_embedding_t* embeddings;
public uint32_t embedding_count;
public byte* photo_maker_path; public byte* photo_maker_path;
public byte* tensor_type_rules; public byte* tensor_type_rules;
public sbyte vae_decode_only; public sbyte vae_decode_only;
@ -132,6 +140,8 @@ internal unsafe partial class Native
public int sample_steps; public int sample_steps;
public float eta; public float eta;
public int shifted_timestep; public int shifted_timestep;
public float* custom_sigmas;
public int custom_sigmas_count;
} }
[StructLayout(LayoutKind.Sequential)] [StructLayout(LayoutKind.Sequential)]
@ -152,9 +162,19 @@ internal unsafe partial class Native
public float end_percent; public float end_percent;
} }
[StructLayout(LayoutKind.Sequential)]
internal struct sd_lora_t
{
public sbyte is_high_noise;
public float multiplier;
public byte* path;
}
[StructLayout(LayoutKind.Sequential)] [StructLayout(LayoutKind.Sequential)]
internal struct sd_img_gen_params_t internal struct sd_img_gen_params_t
{ {
public sd_lora_t* loras;
public uint32_t lora_count;
public byte* prompt; public byte* prompt;
public byte* negative_prompt; public byte* negative_prompt;
public int clip_skip; public int clip_skip;
@ -180,6 +200,8 @@ internal unsafe partial class Native
[StructLayout(LayoutKind.Sequential)] [StructLayout(LayoutKind.Sequential)]
internal struct sd_vid_gen_params_t internal struct sd_vid_gen_params_t
{ {
public sd_lora_t* loras;
public uint32_t lora_count;
public byte* prompt; public byte* prompt;
public byte* negative_prompt; public byte* negative_prompt;
public int clip_skip; public int clip_skip;
@ -305,7 +327,8 @@ internal unsafe partial class Native
[LibraryImport(LIB_NAME, EntryPoint = "sd_sample_params_init")] [LibraryImport(LIB_NAME, EntryPoint = "sd_sample_params_init")]
internal static partial void sd_sample_params_init([MarshalUsing(typeof(SampleParameterMarshaller))] ref sd_sample_params_t sample_params); internal static partial void sd_sample_params_init([MarshalUsing(typeof(SampleParameterMarshaller))] ref sd_sample_params_t sample_params);
[LibraryImport(LIB_NAME, EntryPoint = "sd_sample_params_to_str")] [LibraryImport(LIB_NAME, EntryPoint = "sd_sample_params_to_str")]
internal static partial char* sd_sample_params_to_str([MarshalUsing(typeof(SampleParameterMarshaller))] in sd_sample_params_t sample_params); [return: MarshalAs(UnmanagedType.LPStr)]
internal static partial string sd_sample_params_to_str([MarshalUsing(typeof(SampleParameterMarshaller))] in sd_sample_params_t sample_params);
// //
@ -339,7 +362,8 @@ internal unsafe partial class Native
internal static partial upscaler_ctx_t* new_upscaler_ctx([MarshalAs(UnmanagedType.LPStr)] string esrgan_path, internal static partial upscaler_ctx_t* new_upscaler_ctx([MarshalAs(UnmanagedType.LPStr)] string esrgan_path,
[MarshalAs(UnmanagedType.I1)] bool offload_params_to_cpu, [MarshalAs(UnmanagedType.I1)] bool offload_params_to_cpu,
[MarshalAs(UnmanagedType.I1)] bool direct, [MarshalAs(UnmanagedType.I1)] bool direct,
int n_threads); int n_threads,
int tile_size);
[LibraryImport(LIB_NAME, EntryPoint = "free_upscaler_ctx")] [LibraryImport(LIB_NAME, EntryPoint = "free_upscaler_ctx")]
internal static partial void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx); internal static partial void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx);
@ -374,5 +398,11 @@ internal unsafe partial class Native
float strong, float strong,
[MarshalAs(UnmanagedType.I1)] bool inverse); [MarshalAs(UnmanagedType.I1)] bool inverse);
[LibraryImport(LIB_NAME, EntryPoint = "sd_commit")]
internal static partial byte* sd_commit();
[LibraryImport(LIB_NAME, EntryPoint = "sd_version")]
internal static partial byte* sd_version();
#endregion #endregion
} }

View File

@ -1,6 +1,7 @@
using HPPH; using HPPH;
using JetBrains.Annotations; using JetBrains.Annotations;
using System; using System;
using System.Runtime.InteropServices.Marshalling;
namespace StableDiffusion.NET; namespace StableDiffusion.NET;
@ -15,6 +16,8 @@ public static unsafe class StableDiffusionCpp
private static Native.sd_preview_cb_t? _previewCallback; private static Native.sd_preview_cb_t? _previewCallback;
// ReSharper restore NotAccessedField.Local // ReSharper restore NotAccessedField.Local
public static string ExpectedSDCommit => "43a70e8";
#endregion #endregion
#region Events #region Events
@ -62,6 +65,10 @@ public static unsafe class StableDiffusionCpp
public static int GetNumPhysicalCores() => Native.sd_get_num_physical_cores(); public static int GetNumPhysicalCores() => Native.sd_get_num_physical_cores();
public static string GetSDCommit() => AnsiStringMarshaller.ConvertToManaged(Native.sd_commit()) ?? string.Empty;
public static string GetSDVersion() => AnsiStringMarshaller.ConvertToManaged(Native.sd_version()) ?? string.Empty;
public static Image<ColorRGB> PreprocessCanny(CannyParameter parameter) public static Image<ColorRGB> PreprocessCanny(CannyParameter parameter)
{ {
parameter.Validate(); parameter.Validate();