// Parameter bundles passed to diffusion models, two small accessors for them, and
// the abstract DiffusionModelRunner interface.
//
// NOTE(review): this text appears to have lost content in extraction and is
// reproduced exactly as found. Three #include directives have no header name, and
// template argument / parameter lists are absent (`std::vector>*`, `std::variant;`,
// `template static inline`, `std::get_if(`, `std::shared_ptr weight_manager`,
// `std::map&`, bare `sd::Tensor`). Restore them from the original file; nothing
// below guesses at what they were.
#ifndef __SD_MODEL_DIFFUSION_MODEL_HPP__
#define __SD_MODEL_DIFFUSION_MODEL_HPP__

#include
#include
#include

#include "core/ggml_extend.hpp"
#include "core/tensor_ggml.hpp"
#include "model_manager.h"

// The *DiffusionExtra structs below are plain aggregates of optional inputs.
// Every pointer member defaults to nullptr and points to const data; this header
// never allocates or frees what they point to. They are presumably the
// alternatives of DiffusionExtraParams (the alternative list is not visible here),
// and which model family reads each one is inferred from the struct name only.

// Extra inputs named for UNet models: a frame count, control tensors and a strength.
struct UNetDiffusionExtra {
    // Defaults to -1; presumably a "not set" sentinel -- TODO confirm with the reader.
    int num_video_frames = -1;
    const std::vector>* controls = nullptr;
    float control_strength       = 0.f;
};

// Extra input carrying only a list of layers to skip.
struct SkipLayerDiffusionExtra {
    const std::vector* skip_layers = nullptr;
};

// Extra inputs named for Flux models: a guidance tensor plus a skip-layer list.
struct FluxDiffusionExtra {
    const sd::Tensor* guidance     = nullptr;
    const std::vector* skip_layers = nullptr;
};

// Extra inputs named for Anima models: T5 ids and their weights.
struct AnimaDiffusionExtra {
    const sd::Tensor* t5_ids     = nullptr;
    const sd::Tensor* t5_weights = nullptr;
};

// Extra inputs named for Wan models: a VACE context tensor and its strength.
struct WanDiffusionExtra {
    const sd::Tensor* vace_context = nullptr;
    float vace_strength            = 1.f;  // defaults to 1
};

// Extra inputs named for HiDream-O1 models: token-level tensors and image embeddings.
struct HiDreamO1DiffusionExtra {
    const sd::Tensor* input_ids   = nullptr;
    const sd::Tensor* input_pos   = nullptr;
    const sd::Tensor* token_types = nullptr;
    const sd::Tensor* vinput_mask = nullptr;
    const std::vector>>* image_embeds = nullptr;
};

// Extra inputs named for LTX audio/video models.
struct LTXAVDiffusionExtra {
    const sd::Tensor* audio_x         = nullptr;
    const sd::Tensor* audio_timesteps = nullptr;
    int audio_length                  = 0;
    float frame_rate                  = 24.f;  // defaults to 24; unit presumably frames/s -- confirm
    const sd::Tensor* video_positions = nullptr;
};

// Variant holding at most one of the model-specific extras.
// NOTE(review): the alternative list is missing from this text. DiffusionParams
// default-initialises the member with std::monostate{}, so std::monostate must be
// one of the alternatives.
using DiffusionExtraParams = std::variant;

// Inputs handed to DiffusionModelRunner::compute. All tensors are passed by
// pointer-to-const and default to nullptr; `extra` defaults to std::monostate{}.
// Which of the pointers a given model requires is not visible in this header.
struct DiffusionParams {
    const sd::Tensor* x              = nullptr;
    const sd::Tensor* timesteps      = nullptr;
    const sd::Tensor* context        = nullptr;
    const sd::Tensor* c_concat       = nullptr;
    const sd::Tensor* y              = nullptr;
    const std::vector>* ref_latents = nullptr;
    bool increase_ref_index          = false;
    DiffusionExtraParams extra       = std::monostate{};
};

// Returns a pointer to the extra of type T stored in params.extra.
// std::get_if yields nullptr when the variant holds a different alternative; that
// case trips GGML_ASSERT, so callers receive a non-null pointer whenever the
// function returns. The pointer refers into `params` and is only valid while
// `params` is alive.
// NOTE(review): the template parameter list, the explicit argument of std::get_if
// and the `&` in front of `params.extra` are damaged in this text (`¶ms` looks
// like a mis-decoded `&params`) -- restore from the original file.
template static inline const T* diffusion_extra_as(const DiffusionParams& params) {
    const auto* extra = std::get_if(¶ms.extra);
    GGML_ASSERT(extra != nullptr);
    return extra;
}

// Maps a possibly-null tensor pointer to a reference: *tensor when the pointer is
// non-null, otherwise a default-constructed function-local static tensor.
// The fallback object is shared by every call of the same instantiation within a
// translation unit (the function is `static`), and is const.
// NOTE(review): the template parameter list is missing from this text.
template static inline const sd::Tensor& tensor_or_empty(const sd::Tensor* tensor) {
    static const sd::Tensor kEmpty;
    return tensor != nullptr ? *tensor : kEmpty;
}

// Abstract base for diffusion model runners, layered on GGMLRunner.
// Concrete runners implement compute() and the two-argument get_param_tensors().
struct DiffusionModelRunner : public GGMLRunner {
protected:
    // Name prefix passed to get_param_tensors(tensors, prefix) by the one-argument overload.
    std::string prefix;

public:
    // Forwards `backend` and `weight_manager` to GGMLRunner and stores a copy of `prefix`.
    // `weight_manager` may be omitted, in which case nullptr is forwarded.
    DiffusionModelRunner(ggml_backend_t backend, const std::string& prefix, std::shared_ptr weight_manager = nullptr)
        : GGMLRunner(backend, weight_manager), prefix(prefix) {}

    // Runs the model on `diffusion_params` and returns the resulting tensor.
    // `n_threads` is presumably the CPU thread count for the computation -- confirm
    // against an implementation.
    virtual sd::Tensor compute(int n_threads, const DiffusionParams& diffusion_params) = 0;

    // Convenience overload: calls the two-argument overload with the stored prefix.
    // NOTE(review): a derived class that declares get_param_tensors hides this
    // overload for calls made through the derived type unless it adds
    // `using DiffusionModelRunner::get_param_tensors;`.
    void get_param_tensors(std::map& tensors) { get_param_tensors(tensors, prefix); }

    // Implemented by concrete runners; receives the output map `tensors` and the
    // name prefix to use. What gets inserted is defined by the implementation.
    virtual void get_param_tensors(std::map& tensors, const std::string& prefix) = 0;
};

#endif  // __SD_MODEL_DIFFUSION_MODEL_HPP__