correct rope offset for image tokens

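Also passes the text-length offset to the reference-latent rope ids, sets enable_vision for LongCat unless running VAE decode-only, uses the VAE output directly as the latent for LongCat, and renames locals in the multi-weight linear forward.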
This commit is contained in:
Stéphane du Hamel 2025-12-06 16:06:32 +01:00
parent 37c5e3eca4
commit a907fe2851
3 changed files with 18 additions and 13 deletions

View File

@@ -2238,15 +2238,15 @@ public:
forward_params.linear.scale = scale; forward_params.linear.scale = scale;
return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params); return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params);
} }
auto x0 = ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale); auto out = ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale);
for (int i = 1; i < out_features_vec.size(); i++) { for (int i = 1; i < out_features_vec.size(); i++) {
auto wi = params["weight." + std::to_string(i)]; auto wi = params["weight." + std::to_string(i)];
auto bi = bias ? params["bias." + std::to_string(i)] : nullptr; auto bi = bias ? params["bias." + std::to_string(i)] : nullptr;
auto xi = ggml_ext_linear(ctx->ggml_ctx, x, wi, bi, force_prec_f32, scale); auto curr_out = ggml_ext_linear(ctx->ggml_ctx, x, wi, bi, force_prec_f32, scale);
x0 = ggml_concat(ctx->ggml_ctx, x0, xi, 0); out = ggml_concat(ctx->ggml_ctx, out, curr_out, 0);
} }
return x0; return out;
} }
}; };

View File

@@ -180,10 +180,11 @@ namespace Rope {
int start_index, int start_index,
const std::vector<ggml_tensor*>& ref_latents, const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index, bool increase_ref_index,
float ref_index_scale) { float ref_index_scale,
int base_offset = 0) {
std::vector<std::vector<float>> ids; std::vector<std::vector<float>> ids;
uint64_t curr_h_offset = 0; uint64_t curr_h_offset = base_offset;
uint64_t curr_w_offset = 0; uint64_t curr_w_offset = base_offset;
int index = start_index; int index = start_index;
for (ggml_tensor* ref : ref_latents) { for (ggml_tensor* ref : ref_latents) {
uint64_t h_offset = 0; uint64_t h_offset = 0;
@@ -227,15 +228,15 @@ namespace Rope {
bool increase_ref_index, bool increase_ref_index,
float ref_index_scale, float ref_index_scale,
bool is_longcat) { bool is_longcat) {
int start_index = is_longcat ? 1 : 0; int x_index = is_longcat ? 1 : 0;
auto txt_ids = is_longcat ? gen_longcat_txt_ids(bs, context_len, axes_dim_num) : gen_flux_txt_ids(bs, context_len, axes_dim_num, txt_arange_dims); auto txt_ids = is_longcat ? gen_longcat_txt_ids(bs, context_len, axes_dim_num) : gen_flux_txt_ids(bs, context_len, axes_dim_num, txt_arange_dims);
int offset = is_longcat ? context_len : 0; int offset = is_longcat ? context_len : 0;
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, start_index, offset, offset); auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, x_index, offset, offset);
auto ids = concat_ids(txt_ids, img_ids, bs); auto ids = concat_ids(txt_ids, img_ids, bs);
if (ref_latents.size() > 0) { if (ref_latents.size() > 0) {
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, start_index + 1, ref_latents, increase_ref_index, ref_index_scale); auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, x_index + 1, ref_latents, increase_ref_index, ref_index_scale, offset);
ids = concat_ids(ids, refs_ids, bs); ids = concat_ids(ids, refs_ids, bs);
} }
return ids; return ids;

View File

@@ -456,6 +456,9 @@ public:
sd_ctx_params->chroma_use_dit_mask); sd_ctx_params->chroma_use_dit_mask);
} else if (sd_version_is_longcat(version)) { } else if (sd_version_is_longcat(version)) {
bool enable_vision = false; bool enable_vision = false;
if (!vae_decode_only) {
enable_vision = true;
}
cond_stage_model = std::make_shared<LLMEmbedder>(clip_backend, cond_stage_model = std::make_shared<LLMEmbedder>(clip_backend,
offload_params_to_cpu, offload_params_to_cpu,
tensor_storage_map, tensor_storage_map,
@@ -2244,6 +2247,7 @@ public:
sd_version_is_qwen_image(version) || sd_version_is_qwen_image(version) ||
sd_version_is_wan(version) || sd_version_is_wan(version) ||
sd_version_is_flux2(version) || sd_version_is_flux2(version) ||
sd_version_is_longcat(version) ||
version == VERSION_CHROMA_RADIANCE) { version == VERSION_CHROMA_RADIANCE) {
latent = vae_output; latent = vae_output;
} else if (version == VERSION_SD1_PIX2PIX) { } else if (version == VERSION_SD1_PIX2PIX) {