#ifndef __SD_GUIDANCE_H__ #define __SD_GUIDANCE_H__ #include #include #include #include "tensor.hpp" namespace sd::guidance { struct GuiderOutput { sd::Tensor pred; sd::Tensor pred_cond; sd::Tensor pred_uncond; sd::Tensor pred_img_cond; sd::Tensor pred_skip_layer; }; struct GuidanceInput { int step = 0; size_t schedule_size = 0; const sd::Tensor* pred_cond = nullptr; const sd::Tensor* pred_uncond = nullptr; const sd::Tensor* pred_img_cond = nullptr; std::function()> predict_skip_layer; }; class BaseGuidance { public: virtual ~BaseGuidance() = default; virtual GuiderOutput forward(const GuidanceInput& input, GuiderOutput previous) const = 0; }; class ClassifierFreeGuidance : public BaseGuidance { float guidance_scale_ = 1.0f; float image_guidance_scale_ = 1.0f; public: ClassifierFreeGuidance(float guidance_scale, float image_guidance_scale); GuiderOutput forward(const GuidanceInput& input, GuiderOutput previous) const override; }; class SkipLayerGuidance : public BaseGuidance { std::vector layers_; float scale_ = 0.0f; float start_ = 0.0f; float stop_ = 1.0f; public: SkipLayerGuidance(std::vector layers, float scale, float start, float stop); bool is_enabled_for_step(const GuidanceInput& input) const; const std::vector& layers() const; GuiderOutput forward(const GuidanceInput& input, GuiderOutput previous) const override; }; } // namespace sd::guidance #endif // __SD_GUIDANCE_H__