diff --git a/flux.hpp b/flux.hpp index be00e6b..3b398b4 100644 --- a/flux.hpp +++ b/flux.hpp @@ -840,7 +840,9 @@ public: context = to_backend(context); y = to_backend(y); timesteps = to_backend(timesteps); - guidance = to_backend(guidance); + if (flux_params.guidance_embed) { + guidance = to_backend(guidance); + } pe_vec = flux.gen_pe(x->ne[1], x->ne[0], 2, x->ne[3], context->ne[1], flux_params.theta, flux_params.axes_dim); int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;