feat: add scale_rope support (#1121)

This commit is contained in:
leejet 2025-12-21 15:40:21 +08:00 committed by GitHub
parent 60abda56e0
commit 88ec9d30b1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -91,14 +91,23 @@ namespace Rope {
int axes_dim_num, int axes_dim_num,
int index = 0, int index = 0,
int h_offset = 0, int h_offset = 0,
int w_offset = 0) { int w_offset = 0,
bool scale_rope = false) {
int h_len = (h + (patch_size / 2)) / patch_size; int h_len = (h + (patch_size / 2)) / patch_size;
int w_len = (w + (patch_size / 2)) / patch_size; int w_len = (w + (patch_size / 2)) / patch_size;
std::vector<std::vector<float>> img_ids(h_len * w_len, std::vector<float>(axes_dim_num, 0.0)); std::vector<std::vector<float>> img_ids(h_len * w_len, std::vector<float>(axes_dim_num, 0.0));
std::vector<float> row_ids = linspace<float>(h_offset, h_len - 1 + h_offset, h_len); int h_start = h_offset;
std::vector<float> col_ids = linspace<float>(w_offset, w_len - 1 + w_offset, w_len); int w_start = w_offset;
if (scale_rope) {
h_start -= h_len / 2;
w_start -= w_len / 2;
}
std::vector<float> row_ids = linspace<float>(h_start, h_start + h_len - 1, h_len);
std::vector<float> col_ids = linspace<float>(w_start, w_start + w_len - 1, w_len);
for (int i = 0; i < h_len; ++i) { for (int i = 0; i < h_len; ++i) {
for (int j = 0; j < w_len; ++j) { for (int j = 0; j < w_len; ++j) {
@ -171,7 +180,8 @@ namespace Rope {
int axes_dim_num, int axes_dim_num,
const std::vector<ggml_tensor*>& ref_latents, const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index, bool increase_ref_index,
float ref_index_scale) { float ref_index_scale,
bool scale_rope) {
std::vector<std::vector<float>> ids; std::vector<std::vector<float>> ids;
uint64_t curr_h_offset = 0; uint64_t curr_h_offset = 0;
uint64_t curr_w_offset = 0; uint64_t curr_w_offset = 0;
@ -185,6 +195,7 @@ namespace Rope {
} else { } else {
h_offset = curr_h_offset; h_offset = curr_h_offset;
} }
scale_rope = false;
} }
auto ref_ids = gen_flux_img_ids(ref->ne[1], auto ref_ids = gen_flux_img_ids(ref->ne[1],
@ -194,7 +205,8 @@ namespace Rope {
axes_dim_num, axes_dim_num,
static_cast<int>(index * ref_index_scale), static_cast<int>(index * ref_index_scale),
h_offset, h_offset,
w_offset); w_offset,
scale_rope);
ids = concat_ids(ids, ref_ids, bs); ids = concat_ids(ids, ref_ids, bs);
if (increase_ref_index) { if (increase_ref_index) {
@ -222,7 +234,7 @@ namespace Rope {
auto ids = concat_ids(txt_ids, img_ids, bs); auto ids = concat_ids(txt_ids, img_ids, bs);
if (ref_latents.size() > 0) { if (ref_latents.size() > 0) {
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, ref_index_scale); auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, ref_index_scale, false);
ids = concat_ids(ids, refs_ids, bs); ids = concat_ids(ids, refs_ids, bs);
} }
return ids; return ids;
@ -271,10 +283,10 @@ namespace Rope {
} }
} }
int axes_dim_num = 3; int axes_dim_num = 3;
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num); auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, 0, 0, 0, true);
auto ids = concat_ids(txt_ids_repeated, img_ids, bs); auto ids = concat_ids(txt_ids_repeated, img_ids, bs);
if (ref_latents.size() > 0) { if (ref_latents.size() > 0) {
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, 1.f); auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, 1.f, true);
ids = concat_ids(ids, refs_ids, bs); ids = concat_ids(ids, refs_ids, bs);
} }
return ids; return ids;