longcat rope ids

commit 1241323c4a (parent 7ba7febef2)
Author: Stéphane du Hamel
Date: 2025-12-06 02:44:20 +01:00
4 changed files with 56 additions and 23 deletions


@@ -1807,6 +1807,17 @@ struct LLMEmbedder : public Conditioner {
             prompt_attn_range.second = static_cast<int>(prompt.size());
             prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
+        } else if (sd_version_is_longcat(version)) {
+            prompt_template_encode_start_idx = 36;
+            // prompt_template_encode_end_idx = 5;
+            prompt = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n";
+            prompt_attn_range.first = static_cast<int>(prompt.size());
+            prompt += conditioner_params.text;
+            prompt_attn_range.second = static_cast<int>(prompt.size());
+            prompt += "<|im_end|>\n<|im_start|>assistant\n";
         } else {
             prompt_template_encode_start_idx = 34;
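In the conditioner, the new branch wraps the user prompt in LongCat's Qwen-style chat template. prompt_attn_range brackets the raw user text inside the templated string, and prompt_template_encode_start_idx = 36 is presumably the number of leading template tokens to crop from the encoder output (the commented-out end index suggests trailing-crop support was considered too). As a minimal sketch of how a byte range like this is typically turned into a token-level mask, assume a hypothetical token_spans vector mapping each token to its [begin, end) byte range; this helper is illustrative, not code from this repo:

    #include <cstddef>
    #include <utility>
    #include <vector>

    std::vector<float> build_attn_mask(const std::vector<std::pair<size_t, size_t>>& token_spans,
                                       std::pair<size_t, size_t> prompt_attn_range) {
        std::vector<float> mask(token_spans.size(), 0.0f);
        for (size_t i = 0; i < token_spans.size(); i++) {
            // keep tokens overlapping the user-text span, drop the
            // chat-template boilerplate around it
            bool overlaps = token_spans[i].second > prompt_attn_range.first &&
                            token_spans[i].first < prompt_attn_range.second;
            mask[i] = overlaps ? 1.0f : 0.0f;
        }
        return mask;
    }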


@@ -1341,7 +1341,7 @@ namespace Flux {
             }
             if (flux_params.diffusers_style) {
-                LOG_INFO("Using diffusers-style naming");
+                LOG_INFO("Using diffusers-style attention blocks");
             }
             flux = Flux(flux_params);
@@ -1455,7 +1455,6 @@ namespace Flux {
             } else if (version == VERSION_OVIS_IMAGE) {
                 txt_arange_dims = {1, 2};
             }
             pe_vec = Rope::gen_flux_pe(x->ne[1],
                                        x->ne[0],
                                        flux_params.patch_size,
@@ -1466,9 +1465,9 @@ namespace Flux {
                                        increase_ref_index,
                                        flux_params.ref_index_scale,
                                        flux_params.theta,
-                                       flux_params.axes_dim);
+                                       flux_params.axes_dim,
+                                       sd_version_is_longcat(version));
             int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;
-            // LOG_DEBUG("pos_len %d", pos_len);
             auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len);
             // pe->data = pe_vec.data();
             // print_ggml_tensor(pe);
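The only behavioral change here is threading sd_version_is_longcat(version) into Rope::gen_flux_pe. For orientation, the pos_len arithmetic below the call: gen_flux_pe packs axes_dim_sum / 2 frequency pairs per position, each stored as a 2x2 rotation block, so dividing pe_vec.size() by axes_dim_sum and then by 2 recovers the position count, matching the ggml_new_tensor_4d shape (2, 2, axes_dim_sum / 2, pos_len). A toy check, assuming FLUX-style axes_dim = {16, 56, 56} and illustrative token counts:

    #include <cassert>
    #include <cstddef>

    int main() {
        const int axes_dim_sum = 16 + 56 + 56;  // 128 for FLUX-style axes_dim
        const int n_positions  = 512 + 4096;    // e.g. 512 txt tokens + 64x64 patches
        // gen_flux_pe emits 2*2 floats per frequency pair per position
        const size_t pe_vec_size = (size_t)n_positions * (axes_dim_sum / 2) * 2 * 2;
        const int pos_len = (int)(pe_vec_size / axes_dim_sum / 2);
        assert(pos_len == n_positions);
        return 0;
    }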


@@ -2214,7 +2214,7 @@ public:
           in_features(in_features),
           out_features_vec(out_features_vec),
           bias(bias),
-          force_f32(true),
+          force_f32(force_f32),
           force_prec_f32(force_prec_f32),
           scale(scale) {}
@@ -2224,21 +2224,29 @@ public:
         if (bias) {
             b = params["bias"];
         }
-        // concat all weights and biases together
-        for (int i = 1; i < out_features_vec.size(); i++) {
-            w = ggml_concat(ctx->ggml_ctx, w, params["weight." + std::to_string(i)], 1);
-            if (bias) {
-                b = ggml_concat(ctx->ggml_ctx, b, params["bias." + std::to_string(i)], 0);
-            }
-        }
         if (ctx->weight_adapter) {
+            // concat all weights and biases together so it runs in one linear layer
+            for (int i = 1; i < out_features_vec.size(); i++) {
+                w = ggml_concat(ctx->ggml_ctx, w, params["weight." + std::to_string(i)], 1);
+                if (bias) {
+                    b = ggml_concat(ctx->ggml_ctx, b, params["bias." + std::to_string(i)], 0);
+                }
+            }
             WeightAdapter::ForwardParams forward_params;
             forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_LINEAR;
             forward_params.linear.force_prec_f32 = force_prec_f32;
             forward_params.linear.scale = scale;
             return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params);
         }
-        return ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale);
+        auto x0 = ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale);
+        for (int i = 1; i < out_features_vec.size(); i++) {
+            auto wi = params["weight." + std::to_string(i)];
+            auto bi = bias ? params["bias." + std::to_string(i)] : nullptr;
+            auto xi = ggml_ext_linear(ctx->ggml_ctx, x, wi, bi, force_prec_f32, scale);
+            x0 = ggml_concat(ctx->ggml_ctx, x0, xi, 0);
+        }
+        return x0;
     }
 };
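Two things change in this multi-projection linear. First, force_f32(true) in the constructor was a bug that ignored the caller's force_f32 argument; it now forwards it. Second, the weights are no longer concatenated up front in every case: the fused weight is built only when a weight_adapter (LoRA) needs a single linear, while the default path now runs each projection separately and concatenates the outputs. Both paths compute the same thing, because a linear whose weight is the row-wise concatenation of several matrices yields exactly the concatenation of the individual layers' outputs; each output row reads only its own weight row. A toy demonstration of that identity (plain C++, not repo code):

    #include <cassert>
    #include <vector>

    // y = W x, with W stored row-major as (out x in)
    static std::vector<float> matvec(const std::vector<float>& W,
                                     const std::vector<float>& x,
                                     int out, int in) {
        std::vector<float> y(out, 0.0f);
        for (int o = 0; o < out; o++)
            for (int i = 0; i < in; i++)
                y[o] += W[o * in + i] * x[i];
        return y;
    }

    int main() {
        const int in = 3;
        std::vector<float> W1 = {1, 2, 3, 4, 5, 6};  // 2x3
        std::vector<float> W2 = {7, 8, 9};           // 1x3
        std::vector<float> x  = {1, -1, 2};

        // path A: run each projection, then concatenate the outputs
        auto ya = matvec(W1, x, 2, in);
        auto y2 = matvec(W2, x, 1, in);
        ya.insert(ya.end(), y2.begin(), y2.end());

        // path B: concatenate the weights, run one fused linear
        std::vector<float> W12 = {1, 2, 3, 4, 5, 6, 7, 8, 9};  // 3x3 = concat(W1, W2)
        auto yb = matvec(W12, x, 3, in);

        assert(ya == yb);  // both give {5, 11, 17}
        return 0;
    }

A plausible motivation, not stated in the commit, is that the default path avoids ggml_concat over the raw (possibly quantized) weight tensors, at the cost of several smaller matmuls instead of one fused one.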


@@ -84,7 +84,16 @@ namespace Rope {
         return txt_ids;
     }

+    __STATIC_INLINE__ std::vector<std::vector<float>> gen_longcat_txt_ids(int bs, int context_len, int axes_dim_num) {
+        auto txt_ids = std::vector<std::vector<float>>(bs * context_len, std::vector<float>(axes_dim_num, 0.0f));
+        for (int i = 0; i < bs * context_len; i++) {
+            txt_ids[i][1] = (i % context_len);
+            txt_ids[i][2] = (i % context_len);
+        }
+        return txt_ids;
+    }
+
     __STATIC_INLINE__ std::vector<std::vector<float>> gen_flux_img_ids(int h,
                                                                        int w,
                                                                        int patch_size,
                                                                        int bs,
@@ -94,7 +103,6 @@ namespace Rope {
                                                                        int w_offset = 0) {
         int h_len = (h + (patch_size / 2)) / patch_size;
         int w_len = (w + (patch_size / 2)) / patch_size;
         std::vector<std::vector<float>> img_ids(h_len * w_len, std::vector<float>(axes_dim_num, 0.0));
         std::vector<float> row_ids = linspace<float>(h_offset, h_len - 1 + h_offset, h_len);
@@ -169,13 +177,14 @@ namespace Rope {
     __STATIC_INLINE__ std::vector<std::vector<float>> gen_refs_ids(int patch_size,
                                                                    int bs,
                                                                    int axes_dim_num,
+                                                                   int start_index,
                                                                    const std::vector<ggml_tensor*>& ref_latents,
                                                                    bool increase_ref_index,
                                                                    float ref_index_scale) {
         std::vector<std::vector<float>> ids;
         uint64_t curr_h_offset = 0;
         uint64_t curr_w_offset = 0;
-        int index = 1;
+        int index = start_index;
         for (ggml_tensor* ref : ref_latents) {
             uint64_t h_offset = 0;
             uint64_t w_offset = 0;
@@ -216,13 +225,17 @@ namespace Rope {
                                                                std::set<int> txt_arange_dims,
                                                                const std::vector<ggml_tensor*>& ref_latents,
                                                                bool increase_ref_index,
-                                                               float ref_index_scale) {
-        auto txt_ids = gen_flux_txt_ids(bs, context_len, axes_dim_num, txt_arange_dims);
-        auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num);
+                                                               float ref_index_scale,
+                                                               bool is_longcat) {
+        int start_index = is_longcat ? 1 : 0;
+        auto txt_ids = is_longcat ? gen_longcat_txt_ids(bs, context_len, axes_dim_num) : gen_flux_txt_ids(bs, context_len, axes_dim_num, txt_arange_dims);
+        int offset = is_longcat ? context_len : 0;
+        auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, start_index, offset, offset);
         auto ids = concat_ids(txt_ids, img_ids, bs);
         if (ref_latents.size() > 0) {
-            auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, ref_index_scale);
+            auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, start_index + 1, ref_latents, increase_ref_index, ref_index_scale);
             ids = concat_ids(ids, refs_ids, bs);
         }
         return ids;
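Putting the Rope pieces together for LongCat: text tokens keep axis 0 at zero but walk the row/col axes diagonally (gen_longcat_txt_ids above), the image grid takes axis-0 index start_index = 1 with its rows/cols offset by context_len so they continue past the text positions, and reference latents count up from start_index + 1 = 2. With is_longcat false everything collapses to the previous behavior (index 0, no offset, refs from 1). A standalone toy reconstruction of the combined id list, assuming the new sixth argument of gen_flux_img_ids is the value written to axis 0:

    #include <array>
    #include <cstdio>
    #include <vector>

    int main() {
        const int context_len = 4;       // toy txt length
        const int h_len = 2, w_len = 2;  // toy latent grid in patches

        std::vector<std::array<float, 3>> ids;  // {index, row, col}

        // txt ids: diagonal positions, axis 0 stays 0 (as in gen_longcat_txt_ids)
        for (int i = 0; i < context_len; i++)
            ids.push_back({0.0f, (float)i, (float)i});

        // img ids: axis 0 = start_index (1), rows/cols offset by context_len so
        // the image grid continues after the text positions instead of
        // overlapping them
        for (int r = 0; r < h_len; r++)
            for (int c = 0; c < w_len; c++)
                ids.push_back({1.0f, (float)(context_len + r), (float)(context_len + c)});

        for (const auto& id : ids)
            std::printf("{%g, %g, %g}\n", id[0], id[1], id[2]);
        return 0;
    }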
@@ -239,7 +252,8 @@ namespace Rope {
                                                    bool increase_ref_index,
                                                    float ref_index_scale,
                                                    int theta,
-                                                   const std::vector<int>& axes_dim) {
+                                                   const std::vector<int>& axes_dim,
+                                                   bool is_longcat) {
         std::vector<std::vector<float>> ids = gen_flux_ids(h,
                                                            w,
                                                            patch_size,
@@ -249,7 +263,8 @@ namespace Rope {
                                                            txt_arange_dims,
                                                            ref_latents,
                                                            increase_ref_index,
-                                                           ref_index_scale);
+                                                           ref_index_scale,
+                                                           is_longcat);
         return embed_nd(ids, bs, theta, axes_dim);
     }
@@ -274,7 +289,7 @@ namespace Rope {
         auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num);
         auto ids = concat_ids(txt_ids_repeated, img_ids, bs);
         if (ref_latents.size() > 0) {
-            auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, 1.f);
+            auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, 1, ref_latents, increase_ref_index, 1.f);
             ids = concat_ids(ids, refs_ids, bs);
         }
         return ids;