diff --git a/rope.hpp b/rope.hpp index 2518023..7a35926 100644 --- a/rope.hpp +++ b/rope.hpp @@ -401,7 +401,7 @@ namespace Rope { int index = padded_context_len + 1; auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, index); - int img_pad_len = ((-context_len) % seq_multi_of); + int img_pad_len = bound_mod(static_cast(img_ids.size() / bs), seq_multi_of); if (img_pad_len > 0) { std::vector> img_pad_ids(bs * img_pad_len, std::vector(3, 0.f)); img_ids = concat_ids(img_ids, img_pad_ids, bs);