Correct the RoPE positional offset applied to image and reference-latent tokens

For longcat models, pass the text-context offset through to gen_refs_ids (via a new defaulted base_offset parameter) so reference-latent RoPE positions start after the text tokens, matching the offset already applied to image ids. Also: rename locals in the multi-output linear helper (x0 → out, xi → curr_out) for clarity, enable the vision encoder for longcat unless the pipeline is VAE-decode-only, set flow_shift = 3.0 for longcat, and route longcat through the direct latent = vae_output path.
This commit is contained in:
Stéphane du Hamel 2025-12-06 16:06:32 +01:00
parent 37c5e3eca4
commit a907fe2851
3 changed files with 18 additions and 13 deletions

View File

@ -2238,15 +2238,15 @@ public:
forward_params.linear.scale = scale;
return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params);
}
auto x0 = ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale);
auto out = ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale);
for (int i = 1; i < out_features_vec.size(); i++) {
auto wi = params["weight." + std::to_string(i)];
auto bi = bias ? params["bias." + std::to_string(i)] : nullptr;
auto xi = ggml_ext_linear(ctx->ggml_ctx, x, wi, bi, force_prec_f32, scale);
x0 = ggml_concat(ctx->ggml_ctx, x0, xi, 0);
auto wi = params["weight." + std::to_string(i)];
auto bi = bias ? params["bias." + std::to_string(i)] : nullptr;
auto curr_out = ggml_ext_linear(ctx->ggml_ctx, x, wi, bi, force_prec_f32, scale);
out = ggml_concat(ctx->ggml_ctx, out, curr_out, 0);
}
return x0;
return out;
}
};

View File

@ -180,10 +180,11 @@ namespace Rope {
int start_index,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index,
float ref_index_scale) {
float ref_index_scale,
int base_offset = 0) {
std::vector<std::vector<float>> ids;
uint64_t curr_h_offset = 0;
uint64_t curr_w_offset = 0;
uint64_t curr_h_offset = base_offset;
uint64_t curr_w_offset = base_offset;
int index = start_index;
for (ggml_tensor* ref : ref_latents) {
uint64_t h_offset = 0;
@ -227,15 +228,15 @@ namespace Rope {
bool increase_ref_index,
float ref_index_scale,
bool is_longcat) {
int start_index = is_longcat ? 1 : 0;
int x_index = is_longcat ? 1 : 0;
auto txt_ids = is_longcat ? gen_longcat_txt_ids(bs, context_len, axes_dim_num) : gen_flux_txt_ids(bs, context_len, axes_dim_num, txt_arange_dims);
int offset = is_longcat ? context_len : 0;
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, start_index, offset, offset);
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, x_index, offset, offset);
auto ids = concat_ids(txt_ids, img_ids, bs);
if (ref_latents.size() > 0) {
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, start_index + 1, ref_latents, increase_ref_index, ref_index_scale);
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, x_index + 1, ref_latents, increase_ref_index, ref_index_scale, offset);
ids = concat_ids(ids, refs_ids, bs);
}
return ids;

View File

@ -456,6 +456,9 @@ public:
sd_ctx_params->chroma_use_dit_mask);
} else if (sd_version_is_longcat(version)) {
bool enable_vision = false;
if (!vae_decode_only) {
enable_vision = true;
}
cond_stage_model = std::make_shared<LLMEmbedder>(clip_backend,
offload_params_to_cpu,
tensor_storage_map,
@ -850,7 +853,7 @@ public:
flow_shift = 1.15f;
}
}
if(sd_version_is_longcat(version)) {
if (sd_version_is_longcat(version)) {
flow_shift = 3.0f;
}
}
@ -2244,6 +2247,7 @@ public:
sd_version_is_qwen_image(version) ||
sd_version_is_wan(version) ||
sd_version_is_flux2(version) ||
sd_version_is_longcat(version) ||
version == VERSION_CHROMA_RADIANCE) {
latent = vae_output;
} else if (version == VERSION_SD1_PIX2PIX) {