feat: add circular RoPE support for ideogram4 (#1627)

This commit is contained in:
stduhpf 2026-06-13 07:06:34 +02:00 committed by GitHub
parent 1b702a51e7
commit c20769b2c8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 37 additions and 6 deletions

View File

@ -253,7 +253,8 @@ namespace Rope {
int bs, int bs,
float theta, float theta,
int head_dim, int head_dim,
const std::vector<int>& mrope_section) { const std::vector<int>& mrope_section,
const std::vector<std::vector<int>>& axis_wrap_dims = {}) {
GGML_ASSERT(bs > 0); GGML_ASSERT(bs > 0);
GGML_ASSERT(head_dim % 2 == 0); GGML_ASSERT(head_dim % 2 == 0);
GGML_ASSERT(mrope_section.size() >= 3); GGML_ASSERT(mrope_section.size() >= 3);
@ -265,7 +266,11 @@ namespace Rope {
std::vector<std::vector<std::vector<float>>> axis_embs; std::vector<std::vector<std::vector<float>>> axis_embs;
axis_embs.reserve(3); axis_embs.reserve(3);
for (int axis = 0; axis < 3; ++axis) { for (int axis = 0; axis < 3; ++axis) {
axis_embs.push_back(rope(trans_ids[axis], head_dim, theta)); std::vector<int> axis_wrap;
if (axis < static_cast<int>(axis_wrap_dims.size())) {
axis_wrap = axis_wrap_dims[axis];
}
axis_embs.push_back(rope(trans_ids[axis], head_dim, theta, axis_wrap));
} }
std::vector<std::vector<float>> emb = axis_embs[0]; std::vector<std::vector<float>> emb = axis_embs[0];

View File

@ -151,7 +151,9 @@ namespace Ideogram4 {
int context_len, int context_len,
int head_dim, int head_dim,
int rope_theta, int rope_theta,
const std::vector<int>& mrope_section) { const std::vector<int>& mrope_section,
bool circular_x = false,
bool circular_y = false) {
GGML_ASSERT(bs == 1); GGML_ASSERT(bs == 1);
std::vector<std::vector<float>> ids(static_cast<size_t>(bs) * (context_len + grid_h * grid_w), std::vector<std::vector<float>> ids(static_cast<size_t>(bs) * (context_len + grid_h * grid_w),
std::vector<float>(3, 0.f)); std::vector<float>(3, 0.f));
@ -169,7 +171,29 @@ namespace Ideogram4 {
} }
} }
return Rope::embed_interleaved_mrope(ids, bs, static_cast<float>(rope_theta), head_dim, mrope_section); std::vector<std::vector<int>> axis_wrap_dims(3);
if (circular_y || circular_x) {
size_t total_len = static_cast<size_t>(bs) * (context_len + grid_h * grid_w);
axis_wrap_dims[1].assign(total_len, 0);
axis_wrap_dims[2].assign(total_len, 0);
if (circular_y) {
for (size_t idx = static_cast<size_t>(context_len); idx < total_len; ++idx) {
axis_wrap_dims[1][idx] = grid_h;
}
}
if (circular_x) {
for (size_t idx = static_cast<size_t>(context_len); idx < total_len; ++idx) {
axis_wrap_dims[2][idx] = grid_w;
}
}
}
return Rope::embed_interleaved_mrope(ids,
bs,
static_cast<float>(rope_theta),
head_dim,
mrope_section,
axis_wrap_dims);
} }
class Ideogram4Attention : public GGMLBlock { class Ideogram4Attention : public GGMLBlock {
@ -480,13 +504,16 @@ namespace Ideogram4 {
int64_t pos_len = context_len + grid_h * grid_w; int64_t pos_len = context_len + grid_h * grid_w;
int64_t head_dim = config.emb_dim / config.num_heads; int64_t head_dim = config.emb_dim / config.num_heads;
auto runner_ctx = get_context();
pe_vec = gen_ideogram4_pe(static_cast<int>(grid_h), pe_vec = gen_ideogram4_pe(static_cast<int>(grid_h),
static_cast<int>(grid_w), static_cast<int>(grid_w),
static_cast<int>(x->ne[3]), static_cast<int>(x->ne[3]),
static_cast<int>(context_len), static_cast<int>(context_len),
static_cast<int>(head_dim), static_cast<int>(head_dim),
static_cast<int>(config.rope_theta), static_cast<int>(config.rope_theta),
config.mrope_section); config.mrope_section,
runner_ctx.circular_x_enabled,
runner_ctx.circular_y_enabled);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, head_dim / 2, pos_len); auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, head_dim / 2, pos_len);
set_backend_tensor_data(pe, pe_vec.data()); set_backend_tensor_data(pe, pe_vec.data());
@ -497,7 +524,6 @@ namespace Ideogram4 {
auto indicator = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_I32, pos_len, x->ne[3]); auto indicator = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_I32, pos_len, x->ne[3]);
set_backend_tensor_data(indicator, image_indicator_vec.data()); set_backend_tensor_data(indicator, image_indicator_vec.data());
auto runner_ctx = get_context();
ggml_tensor* out = active_model.forward(&runner_ctx, x, timesteps, context, pe, indicator); ggml_tensor* out = active_model.forward(&runner_ctx, x, timesteps, context, pe, indicator);
ggml_build_forward_expand(gf, out); ggml_build_forward_expand(gf, out);
return gf; return gf;