fix gen_z_image_ids

This commit is contained in:
leejet 2025-12-01 02:27:00 +08:00
parent 1798ec02ba
commit a96a1e7b93

View File

@ -401,7 +401,7 @@ namespace Rope {
int index = padded_context_len + 1;
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, index);
int img_pad_len = ((-context_len) % seq_multi_of);
int img_pad_len = bound_mod(static_cast<int>(img_ids.size() / bs), seq_multi_of);
if (img_pad_len > 0) {
std::vector<std::vector<float>> img_pad_ids(bs * img_pad_len, std::vector<float>(3, 0.f));
img_ids = concat_ids(img_ids, img_pad_ids, bs);