longcat rope ids

This commit is contained in:
Stéphane du Hamel 2025-12-06 02:44:20 +01:00
parent 7ba7febef2
commit 1241323c4a
4 changed files with 56 additions and 23 deletions

View File

@ -1807,6 +1807,17 @@ struct LLMEmbedder : public Conditioner {
prompt_attn_range.second = static_cast<int>(prompt.size());
prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
} else if (sd_version_is_longcat(version)) {
prompt_template_encode_start_idx = 36;
// prompt_template_encode_end_idx = 5;
prompt = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n";
prompt_attn_range.first = static_cast<int>(prompt.size());
prompt += conditioner_params.text;
prompt_attn_range.second = static_cast<int>(prompt.size());
prompt += "<|im_end|>\n<|im_start|>assistant\n";
} else {
prompt_template_encode_start_idx = 34;

View File

@ -1341,7 +1341,7 @@ namespace Flux {
}
if (flux_params.diffusers_style) {
LOG_INFO("Using diffusers-style naming");
LOG_INFO("Using diffusers-style attention blocks");
}
flux = Flux(flux_params);
@ -1455,7 +1455,6 @@ namespace Flux {
} else if (version == VERSION_OVIS_IMAGE) {
txt_arange_dims = {1, 2};
}
pe_vec = Rope::gen_flux_pe(x->ne[1],
x->ne[0],
flux_params.patch_size,
@ -1466,9 +1465,9 @@ namespace Flux {
increase_ref_index,
flux_params.ref_index_scale,
flux_params.theta,
flux_params.axes_dim);
flux_params.axes_dim,
sd_version_is_longcat(version));
int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;
// LOG_DEBUG("pos_len %d", pos_len);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len);
// pe->data = pe_vec.data();
// print_ggml_tensor(pe);

View File

@ -2214,7 +2214,7 @@ public:
in_features(in_features),
out_features_vec(out_features_vec),
bias(bias),
force_f32(true),
force_f32(force_f32),
force_prec_f32(force_prec_f32),
scale(scale) {}
@ -2224,21 +2224,29 @@ public:
if (bias) {
b = params["bias"];
}
// concat all weights and biases together
for (int i = 1; i < out_features_vec.size(); i++) {
w = ggml_concat(ctx->ggml_ctx, w, params["weight." + std::to_string(i)], 1);
if (bias) {
b = ggml_concat(ctx->ggml_ctx, b, params["bias." + std::to_string(i)], 0);
}
}
if (ctx->weight_adapter) {
// concat all weights and biases together so it runs in one linear layer
for (int i = 1; i < out_features_vec.size(); i++) {
w = ggml_concat(ctx->ggml_ctx, w, params["weight." + std::to_string(i)], 1);
if (bias) {
b = ggml_concat(ctx->ggml_ctx, b, params["bias." + std::to_string(i)], 0);
}
}
WeightAdapter::ForwardParams forward_params;
forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_LINEAR;
forward_params.linear.force_prec_f32 = force_prec_f32;
forward_params.linear.scale = scale;
return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params);
}
return ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale);
auto x0 = ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale);
for (int i = 1; i < out_features_vec.size(); i++) {
auto wi = params["weight." + std::to_string(i)];
auto bi = bias ? params["bias." + std::to_string(i)] : nullptr;
auto xi = ggml_ext_linear(ctx->ggml_ctx, x, wi, bi, force_prec_f32, scale);
x0 = ggml_concat(ctx->ggml_ctx, x0, xi, 0);
}
return x0;
}
};

View File

@ -84,7 +84,16 @@ namespace Rope {
return txt_ids;
}
__STATIC_INLINE__ std::vector<std::vector<float>> gen_flux_img_ids(int h,
__STATIC_INLINE__ std::vector<std::vector<float>> gen_longcat_txt_ids(int bs, int context_len, int axes_dim_num) {
auto txt_ids = std::vector<std::vector<float>>(bs * context_len, std::vector<float>(axes_dim_num, 0.0f));
for (int i = 0; i < bs * context_len; i++) {
txt_ids[i][1] = (i % context_len);
txt_ids[i][2] = (i % context_len);
}
return txt_ids;
}
__STATIC_INLINE__ std::vector<std::vector<float>> gen_flux_img_ids(int h,
int w,
int patch_size,
int bs,
@ -94,7 +103,6 @@ namespace Rope {
int w_offset = 0) {
int h_len = (h + (patch_size / 2)) / patch_size;
int w_len = (w + (patch_size / 2)) / patch_size;
std::vector<std::vector<float>> img_ids(h_len * w_len, std::vector<float>(axes_dim_num, 0.0));
std::vector<float> row_ids = linspace<float>(h_offset, h_len - 1 + h_offset, h_len);
@ -169,13 +177,14 @@ namespace Rope {
__STATIC_INLINE__ std::vector<std::vector<float>> gen_refs_ids(int patch_size,
int bs,
int axes_dim_num,
int start_index,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index,
float ref_index_scale) {
std::vector<std::vector<float>> ids;
uint64_t curr_h_offset = 0;
uint64_t curr_w_offset = 0;
int index = 1;
int index = start_index;
for (ggml_tensor* ref : ref_latents) {
uint64_t h_offset = 0;
uint64_t w_offset = 0;
@ -216,13 +225,17 @@ namespace Rope {
std::set<int> txt_arange_dims,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index,
float ref_index_scale) {
auto txt_ids = gen_flux_txt_ids(bs, context_len, axes_dim_num, txt_arange_dims);
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num);
float ref_index_scale,
bool is_longcat) {
int start_index = is_longcat ? 1 : 0;
auto txt_ids = is_longcat ? gen_longcat_txt_ids(bs, context_len, axes_dim_num) : gen_flux_txt_ids(bs, context_len, axes_dim_num, txt_arange_dims);
int offset = is_longcat ? context_len : 0;
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, start_index, offset, offset);
auto ids = concat_ids(txt_ids, img_ids, bs);
if (ref_latents.size() > 0) {
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, ref_index_scale);
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, start_index + 1, ref_latents, increase_ref_index, ref_index_scale);
ids = concat_ids(ids, refs_ids, bs);
}
return ids;
@ -239,7 +252,8 @@ namespace Rope {
bool increase_ref_index,
float ref_index_scale,
int theta,
const std::vector<int>& axes_dim) {
const std::vector<int>& axes_dim,
bool is_longcat) {
std::vector<std::vector<float>> ids = gen_flux_ids(h,
w,
patch_size,
@ -249,7 +263,8 @@ namespace Rope {
txt_arange_dims,
ref_latents,
increase_ref_index,
ref_index_scale);
ref_index_scale,
is_longcat);
return embed_nd(ids, bs, theta, axes_dim);
}
@ -274,7 +289,7 @@ namespace Rope {
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num);
auto ids = concat_ids(txt_ids_repeated, img_ids, bs);
if (ref_latents.size() > 0) {
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, 1.f);
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, 1, ref_latents, increase_ref_index, 1.f);
ids = concat_ids(ids, refs_ids, bs);
}
return ids;