From 77ca8e3f48a051ef95381712a0874aa4c52b5952 Mon Sep 17 00:00:00 2001 From: leejet Date: Wed, 21 Aug 2024 21:44:12 +0800 Subject: [PATCH] fix schnell support --- flux.hpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/flux.hpp b/flux.hpp index be00e6b..3b398b4 100644 --- a/flux.hpp +++ b/flux.hpp @@ -840,7 +840,9 @@ public: context = to_backend(context); y = to_backend(y); timesteps = to_backend(timesteps); - guidance = to_backend(guidance); + if (flux_params.guidance_embed) { + guidance = to_backend(guidance); + } pe_vec = flux.gen_pe(x->ne[1], x->ne[0], 2, x->ne[3], context->ne[1], flux_params.theta, flux_params.axes_dim); int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;