#ifndef __DIFFUSION_MODEL_H__ #define __DIFFUSION_MODEL_H__ #include #include "anima.hpp" #include "ernie_image.hpp" #include "flux.hpp" #include "mmdit.hpp" #include "qwen_image.hpp" #include "tensor_ggml.hpp" #include "unet.hpp" #include "wan.hpp" #include "z_image.hpp" struct DiffusionParams { const sd::Tensor* x = nullptr; const sd::Tensor* timesteps = nullptr; const sd::Tensor* context = nullptr; const sd::Tensor* c_concat = nullptr; const sd::Tensor* y = nullptr; const sd::Tensor* t5_ids = nullptr; const sd::Tensor* t5_weights = nullptr; const sd::Tensor* guidance = nullptr; const std::vector>* ref_latents = nullptr; bool increase_ref_index = false; int num_video_frames = -1; const std::vector>* controls = nullptr; float control_strength = 0.f; const sd::Tensor* vace_context = nullptr; float vace_strength = 1.f; const std::vector* skip_layers = nullptr; }; template static inline const sd::Tensor& tensor_or_empty(const sd::Tensor* tensor) { static const sd::Tensor kEmpty; return tensor != nullptr ? *tensor : kEmpty; } struct DiffusionModel { virtual std::string get_desc() = 0; virtual sd::Tensor compute(int n_threads, const DiffusionParams& diffusion_params) = 0; virtual void alloc_params_buffer() = 0; virtual void free_params_buffer() = 0; virtual void free_compute_buffer() = 0; virtual void get_param_tensors(std::map& tensors) = 0; virtual size_t get_params_buffer_size() = 0; virtual void set_weight_adapter(const std::shared_ptr& adapter){}; virtual int64_t get_adm_in_channels() = 0; virtual void set_flash_attention_enabled(bool enabled) = 0; virtual void set_circular_axes(bool circular_x, bool circular_y) = 0; }; struct UNetModel : public DiffusionModel { UNetModelRunner unet; UNetModel(ggml_backend_t backend, bool offload_params_to_cpu, const String2TensorStorage& tensor_storage_map = {}, SDVersion version = VERSION_SD1) : unet(backend, offload_params_to_cpu, tensor_storage_map, "model.diffusion_model", version) { } std::string get_desc() override { return unet.get_desc(); } void alloc_params_buffer() override { unet.alloc_params_buffer(); } void free_params_buffer() override { unet.free_params_buffer(); } void free_compute_buffer() override { unet.free_compute_buffer(); } void get_param_tensors(std::map& tensors) override { unet.get_param_tensors(tensors, "model.diffusion_model"); } size_t get_params_buffer_size() override { return unet.get_params_buffer_size(); } void set_weight_adapter(const std::shared_ptr& adapter) override { unet.set_weight_adapter(adapter); } int64_t get_adm_in_channels() override { return unet.unet.adm_in_channels; } void set_flash_attention_enabled(bool enabled) { unet.set_flash_attention_enabled(enabled); } void set_circular_axes(bool circular_x, bool circular_y) override { unet.set_circular_axes(circular_x, circular_y); } sd::Tensor compute(int n_threads, const DiffusionParams& diffusion_params) override { GGML_ASSERT(diffusion_params.x != nullptr); GGML_ASSERT(diffusion_params.timesteps != nullptr); static const std::vector> empty_controls; return unet.compute(n_threads, *diffusion_params.x, *diffusion_params.timesteps, tensor_or_empty(diffusion_params.context), tensor_or_empty(diffusion_params.c_concat), tensor_or_empty(diffusion_params.y), diffusion_params.num_video_frames, diffusion_params.controls ? *diffusion_params.controls : empty_controls, diffusion_params.control_strength); } }; struct MMDiTModel : public DiffusionModel { MMDiTRunner mmdit; MMDiTModel(ggml_backend_t backend, bool offload_params_to_cpu, const String2TensorStorage& tensor_storage_map = {}) : mmdit(backend, offload_params_to_cpu, tensor_storage_map, "model.diffusion_model") { } std::string get_desc() override { return mmdit.get_desc(); } void alloc_params_buffer() override { mmdit.alloc_params_buffer(); } void free_params_buffer() override { mmdit.free_params_buffer(); } void free_compute_buffer() override { mmdit.free_compute_buffer(); } void get_param_tensors(std::map& tensors) override { mmdit.get_param_tensors(tensors, "model.diffusion_model"); } size_t get_params_buffer_size() override { return mmdit.get_params_buffer_size(); } void set_weight_adapter(const std::shared_ptr& adapter) override { mmdit.set_weight_adapter(adapter); } int64_t get_adm_in_channels() override { return 768 + 1280; } void set_flash_attention_enabled(bool enabled) { mmdit.set_flash_attention_enabled(enabled); } void set_circular_axes(bool circular_x, bool circular_y) override { mmdit.set_circular_axes(circular_x, circular_y); } sd::Tensor compute(int n_threads, const DiffusionParams& diffusion_params) override { GGML_ASSERT(diffusion_params.x != nullptr); GGML_ASSERT(diffusion_params.timesteps != nullptr); static const std::vector empty_skip_layers; return mmdit.compute(n_threads, *diffusion_params.x, *diffusion_params.timesteps, tensor_or_empty(diffusion_params.context), tensor_or_empty(diffusion_params.y), diffusion_params.skip_layers ? *diffusion_params.skip_layers : empty_skip_layers); } }; struct FluxModel : public DiffusionModel { Flux::FluxRunner flux; FluxModel(ggml_backend_t backend, bool offload_params_to_cpu, const String2TensorStorage& tensor_storage_map = {}, SDVersion version = VERSION_FLUX, bool use_mask = false) : flux(backend, offload_params_to_cpu, tensor_storage_map, "model.diffusion_model", version, use_mask) { } std::string get_desc() override { return flux.get_desc(); } void alloc_params_buffer() override { flux.alloc_params_buffer(); } void free_params_buffer() override { flux.free_params_buffer(); } void free_compute_buffer() override { flux.free_compute_buffer(); } void get_param_tensors(std::map& tensors) override { flux.get_param_tensors(tensors, "model.diffusion_model"); } size_t get_params_buffer_size() override { return flux.get_params_buffer_size(); } void set_weight_adapter(const std::shared_ptr& adapter) override { flux.set_weight_adapter(adapter); } int64_t get_adm_in_channels() override { return 768; } void set_flash_attention_enabled(bool enabled) { flux.set_flash_attention_enabled(enabled); } void set_circular_axes(bool circular_x, bool circular_y) override { flux.set_circular_axes(circular_x, circular_y); } sd::Tensor compute(int n_threads, const DiffusionParams& diffusion_params) override { GGML_ASSERT(diffusion_params.x != nullptr); GGML_ASSERT(diffusion_params.timesteps != nullptr); static const std::vector> empty_ref_latents; static const std::vector empty_skip_layers; return flux.compute(n_threads, *diffusion_params.x, *diffusion_params.timesteps, tensor_or_empty(diffusion_params.context), tensor_or_empty(diffusion_params.c_concat), tensor_or_empty(diffusion_params.y), tensor_or_empty(diffusion_params.guidance), diffusion_params.ref_latents ? *diffusion_params.ref_latents : empty_ref_latents, diffusion_params.increase_ref_index, diffusion_params.skip_layers ? *diffusion_params.skip_layers : empty_skip_layers); } }; struct AnimaModel : public DiffusionModel { std::string prefix; Anima::AnimaRunner anima; AnimaModel(ggml_backend_t backend, bool offload_params_to_cpu, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "model.diffusion_model") : prefix(prefix), anima(backend, offload_params_to_cpu, tensor_storage_map, prefix) { } std::string get_desc() override { return anima.get_desc(); } void alloc_params_buffer() override { anima.alloc_params_buffer(); } void free_params_buffer() override { anima.free_params_buffer(); } void free_compute_buffer() override { anima.free_compute_buffer(); } void get_param_tensors(std::map& tensors) override { anima.get_param_tensors(tensors, prefix); } size_t get_params_buffer_size() override { return anima.get_params_buffer_size(); } void set_weight_adapter(const std::shared_ptr& adapter) override { anima.set_weight_adapter(adapter); } int64_t get_adm_in_channels() override { return 768; } void set_flash_attention_enabled(bool enabled) { anima.set_flash_attention_enabled(enabled); } void set_circular_axes(bool circular_x, bool circular_y) override { anima.set_circular_axes(circular_x, circular_y); } sd::Tensor compute(int n_threads, const DiffusionParams& diffusion_params) override { GGML_ASSERT(diffusion_params.x != nullptr); GGML_ASSERT(diffusion_params.timesteps != nullptr); return anima.compute(n_threads, *diffusion_params.x, *diffusion_params.timesteps, tensor_or_empty(diffusion_params.context), tensor_or_empty(diffusion_params.t5_ids), tensor_or_empty(diffusion_params.t5_weights)); } }; struct WanModel : public DiffusionModel { std::string prefix; WAN::WanRunner wan; WanModel(ggml_backend_t backend, bool offload_params_to_cpu, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "model.diffusion_model", SDVersion version = VERSION_WAN2) : prefix(prefix), wan(backend, offload_params_to_cpu, tensor_storage_map, prefix, version) { } std::string get_desc() override { return wan.get_desc(); } void alloc_params_buffer() override { wan.alloc_params_buffer(); } void free_params_buffer() override { wan.free_params_buffer(); } void free_compute_buffer() override { wan.free_compute_buffer(); } void get_param_tensors(std::map& tensors) override { wan.get_param_tensors(tensors, prefix); } size_t get_params_buffer_size() override { return wan.get_params_buffer_size(); } void set_weight_adapter(const std::shared_ptr& adapter) override { wan.set_weight_adapter(adapter); } int64_t get_adm_in_channels() override { return 768; } void set_flash_attention_enabled(bool enabled) { wan.set_flash_attention_enabled(enabled); } void set_circular_axes(bool circular_x, bool circular_y) override { wan.set_circular_axes(circular_x, circular_y); } sd::Tensor compute(int n_threads, const DiffusionParams& diffusion_params) override { GGML_ASSERT(diffusion_params.x != nullptr); GGML_ASSERT(diffusion_params.timesteps != nullptr); return wan.compute(n_threads, *diffusion_params.x, *diffusion_params.timesteps, tensor_or_empty(diffusion_params.context), tensor_or_empty(diffusion_params.y), tensor_or_empty(diffusion_params.c_concat), sd::Tensor(), tensor_or_empty(diffusion_params.vace_context), diffusion_params.vace_strength); } }; struct QwenImageModel : public DiffusionModel { std::string prefix; Qwen::QwenImageRunner qwen_image; QwenImageModel(ggml_backend_t backend, bool offload_params_to_cpu, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "model.diffusion_model", SDVersion version = VERSION_QWEN_IMAGE, bool zero_cond_t = false) : prefix(prefix), qwen_image(backend, offload_params_to_cpu, tensor_storage_map, prefix, version, zero_cond_t) { } std::string get_desc() override { return qwen_image.get_desc(); } void alloc_params_buffer() override { qwen_image.alloc_params_buffer(); } void free_params_buffer() override { qwen_image.free_params_buffer(); } void free_compute_buffer() override { qwen_image.free_compute_buffer(); } void get_param_tensors(std::map& tensors) override { qwen_image.get_param_tensors(tensors, prefix); } size_t get_params_buffer_size() override { return qwen_image.get_params_buffer_size(); } void set_weight_adapter(const std::shared_ptr& adapter) override { qwen_image.set_weight_adapter(adapter); } int64_t get_adm_in_channels() override { return 768; } void set_flash_attention_enabled(bool enabled) { qwen_image.set_flash_attention_enabled(enabled); } void set_circular_axes(bool circular_x, bool circular_y) override { qwen_image.set_circular_axes(circular_x, circular_y); } sd::Tensor compute(int n_threads, const DiffusionParams& diffusion_params) override { GGML_ASSERT(diffusion_params.x != nullptr); GGML_ASSERT(diffusion_params.timesteps != nullptr); static const std::vector> empty_ref_latents; return qwen_image.compute(n_threads, *diffusion_params.x, *diffusion_params.timesteps, tensor_or_empty(diffusion_params.context), diffusion_params.ref_latents ? *diffusion_params.ref_latents : empty_ref_latents, true); } }; struct ZImageModel : public DiffusionModel { std::string prefix; ZImage::ZImageRunner z_image; ZImageModel(ggml_backend_t backend, bool offload_params_to_cpu, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "model.diffusion_model", SDVersion version = VERSION_Z_IMAGE) : prefix(prefix), z_image(backend, offload_params_to_cpu, tensor_storage_map, prefix, version) { } std::string get_desc() override { return z_image.get_desc(); } void alloc_params_buffer() override { z_image.alloc_params_buffer(); } void free_params_buffer() override { z_image.free_params_buffer(); } void free_compute_buffer() override { z_image.free_compute_buffer(); } void get_param_tensors(std::map& tensors) override { z_image.get_param_tensors(tensors, prefix); } size_t get_params_buffer_size() override { return z_image.get_params_buffer_size(); } void set_weight_adapter(const std::shared_ptr& adapter) override { z_image.set_weight_adapter(adapter); } int64_t get_adm_in_channels() override { return 768; } void set_flash_attention_enabled(bool enabled) { z_image.set_flash_attention_enabled(enabled); } void set_circular_axes(bool circular_x, bool circular_y) override { z_image.set_circular_axes(circular_x, circular_y); } sd::Tensor compute(int n_threads, const DiffusionParams& diffusion_params) override { GGML_ASSERT(diffusion_params.x != nullptr); GGML_ASSERT(diffusion_params.timesteps != nullptr); static const std::vector> empty_ref_latents; return z_image.compute(n_threads, *diffusion_params.x, *diffusion_params.timesteps, tensor_or_empty(diffusion_params.context), diffusion_params.ref_latents ? *diffusion_params.ref_latents : empty_ref_latents, true); } }; struct ErnieImageModel : public DiffusionModel { std::string prefix; ErnieImage::ErnieImageRunner ernie_image; ErnieImageModel(ggml_backend_t backend, bool offload_params_to_cpu, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "model.diffusion_model") : prefix(prefix), ernie_image(backend, offload_params_to_cpu, tensor_storage_map, prefix) { } std::string get_desc() override { return ernie_image.get_desc(); } void alloc_params_buffer() override { ernie_image.alloc_params_buffer(); } void free_params_buffer() override { ernie_image.free_params_buffer(); } void free_compute_buffer() override { ernie_image.free_compute_buffer(); } void get_param_tensors(std::map& tensors) override { ernie_image.get_param_tensors(tensors, prefix); } size_t get_params_buffer_size() override { return ernie_image.get_params_buffer_size(); } void set_weight_adapter(const std::shared_ptr& adapter) override { ernie_image.set_weight_adapter(adapter); } int64_t get_adm_in_channels() override { return 768; } void set_flash_attention_enabled(bool enabled) { ernie_image.set_flash_attention_enabled(enabled); } void set_circular_axes(bool circular_x, bool circular_y) override { ernie_image.set_circular_axes(circular_x, circular_y); } sd::Tensor compute(int n_threads, const DiffusionParams& diffusion_params) override { GGML_ASSERT(diffusion_params.x != nullptr); GGML_ASSERT(diffusion_params.timesteps != nullptr); return ernie_image.compute(n_threads, *diffusion_params.x, *diffusion_params.timesteps, tensor_or_empty(diffusion_params.context)); } }; #endif