diff --git a/rope.hpp b/rope.hpp index 551c8ab..295c9a2 100644 --- a/rope.hpp +++ b/rope.hpp @@ -222,8 +222,8 @@ namespace Rope { int context_len, const std::vector& ref_latents, bool increase_ref_index) { - int h_len = (h + (patch_size / 2)) / patch_size / 2; - int w_len = (w + (patch_size / 2)) / patch_size / 2; + int h_len = (h + (patch_size / 2)) / patch_size; + int w_len = (w + (patch_size / 2)) / patch_size; int txt_id_start = std::max(h_len, w_len); auto txt_ids = linspace(txt_id_start, context_len + txt_id_start, context_len); std::vector> txt_ids_repeated(bs * context_len, std::vector(3));