diff --git a/rope.hpp b/rope.hpp index 4abc514..12047e3 100644 --- a/rope.hpp +++ b/rope.hpp @@ -91,14 +91,23 @@ namespace Rope { int axes_dim_num, int index = 0, int h_offset = 0, - int w_offset = 0) { + int w_offset = 0, + bool scale_rope = false) { int h_len = (h + (patch_size / 2)) / patch_size; int w_len = (w + (patch_size / 2)) / patch_size; std::vector> img_ids(h_len * w_len, std::vector(axes_dim_num, 0.0)); - std::vector row_ids = linspace(h_offset, h_len - 1 + h_offset, h_len); - std::vector col_ids = linspace(w_offset, w_len - 1 + w_offset, w_len); + int h_start = h_offset; + int w_start = w_offset; + + if (scale_rope) { + h_start -= h_len / 2; + w_start -= w_len / 2; + } + + std::vector row_ids = linspace(h_start, h_start + h_len - 1, h_len); + std::vector col_ids = linspace(w_start, w_start + w_len - 1, w_len); for (int i = 0; i < h_len; ++i) { for (int j = 0; j < w_len; ++j) { @@ -171,7 +180,8 @@ namespace Rope { int axes_dim_num, const std::vector& ref_latents, bool increase_ref_index, - float ref_index_scale) { + float ref_index_scale, + bool scale_rope) { std::vector> ids; uint64_t curr_h_offset = 0; uint64_t curr_w_offset = 0; @@ -185,6 +195,7 @@ namespace Rope { } else { h_offset = curr_h_offset; } + scale_rope = false; } auto ref_ids = gen_flux_img_ids(ref->ne[1], @@ -194,7 +205,8 @@ namespace Rope { axes_dim_num, static_cast(index * ref_index_scale), h_offset, - w_offset); + w_offset, + scale_rope); ids = concat_ids(ids, ref_ids, bs); if (increase_ref_index) { @@ -222,7 +234,7 @@ namespace Rope { auto ids = concat_ids(txt_ids, img_ids, bs); if (ref_latents.size() > 0) { - auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, ref_index_scale); + auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, ref_index_scale, false); ids = concat_ids(ids, refs_ids, bs); } return ids; @@ -271,10 +283,10 @@ namespace Rope { } } int axes_dim_num = 3; - auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num); + auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, 0, 0, 0, true); auto ids = concat_ids(txt_ids_repeated, img_ids, bs); if (ref_latents.size() > 0) { - auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, 1.f); + auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, 1.f, true); ids = concat_ids(ids, refs_ids, bs); } return ids;