feat: add scale_rope support (#1121)

This commit is contained in:
leejet 2025-12-21 15:40:21 +08:00 committed by GitHub
parent 60abda56e0
commit 88ec9d30b1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -91,14 +91,23 @@ namespace Rope {
int axes_dim_num,
int index = 0,
int h_offset = 0,
int w_offset = 0) {
int w_offset = 0,
bool scale_rope = false) {
int h_len = (h + (patch_size / 2)) / patch_size;
int w_len = (w + (patch_size / 2)) / patch_size;
std::vector<std::vector<float>> img_ids(h_len * w_len, std::vector<float>(axes_dim_num, 0.0));
std::vector<float> row_ids = linspace<float>(h_offset, h_len - 1 + h_offset, h_len);
std::vector<float> col_ids = linspace<float>(w_offset, w_len - 1 + w_offset, w_len);
int h_start = h_offset;
int w_start = w_offset;
if (scale_rope) {
h_start -= h_len / 2;
w_start -= w_len / 2;
}
std::vector<float> row_ids = linspace<float>(h_start, h_start + h_len - 1, h_len);
std::vector<float> col_ids = linspace<float>(w_start, w_start + w_len - 1, w_len);
for (int i = 0; i < h_len; ++i) {
for (int j = 0; j < w_len; ++j) {
@ -171,7 +180,8 @@ namespace Rope {
int axes_dim_num,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index,
float ref_index_scale) {
float ref_index_scale,
bool scale_rope) {
std::vector<std::vector<float>> ids;
uint64_t curr_h_offset = 0;
uint64_t curr_w_offset = 0;
@ -185,6 +195,7 @@ namespace Rope {
} else {
h_offset = curr_h_offset;
}
scale_rope = false;
}
auto ref_ids = gen_flux_img_ids(ref->ne[1],
@ -194,7 +205,8 @@ namespace Rope {
axes_dim_num,
static_cast<int>(index * ref_index_scale),
h_offset,
w_offset);
w_offset,
scale_rope);
ids = concat_ids(ids, ref_ids, bs);
if (increase_ref_index) {
@ -222,7 +234,7 @@ namespace Rope {
auto ids = concat_ids(txt_ids, img_ids, bs);
if (ref_latents.size() > 0) {
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, ref_index_scale);
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, ref_index_scale, false);
ids = concat_ids(ids, refs_ids, bs);
}
return ids;
@ -271,10 +283,10 @@ namespace Rope {
}
}
int axes_dim_num = 3;
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num);
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, 0, 0, 0, true);
auto ids = concat_ids(txt_ids_repeated, img_ids, bs);
if (ref_latents.size() > 0) {
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, 1.f);
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, 1.f, true);
ids = concat_ids(ids, refs_ids, bs);
}
return ids;