fix schnell support

This commit is contained in:
leejet 2024-08-21 21:44:12 +08:00
parent 8650b87e0e
commit 77ca8e3f48

View File

@ -840,7 +840,9 @@ public:
context = to_backend(context);
y = to_backend(y);
timesteps = to_backend(timesteps);
guidance = to_backend(guidance);
if (flux_params.guidance_embed) {
guidance = to_backend(guidance);
}
pe_vec = flux.gen_pe(x->ne[1], x->ne[0], 2, x->ne[3], context->ne[1], flux_params.theta, flux_params.axes_dim);
int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;