mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-01-02 18:53:36 +00:00
feat: add scale_rope support (#1121)
This commit is contained in:
parent
60abda56e0
commit
88ec9d30b1
28
rope.hpp
28
rope.hpp
@ -91,14 +91,23 @@ namespace Rope {
|
|||||||
int axes_dim_num,
|
int axes_dim_num,
|
||||||
int index = 0,
|
int index = 0,
|
||||||
int h_offset = 0,
|
int h_offset = 0,
|
||||||
int w_offset = 0) {
|
int w_offset = 0,
|
||||||
|
bool scale_rope = false) {
|
||||||
int h_len = (h + (patch_size / 2)) / patch_size;
|
int h_len = (h + (patch_size / 2)) / patch_size;
|
||||||
int w_len = (w + (patch_size / 2)) / patch_size;
|
int w_len = (w + (patch_size / 2)) / patch_size;
|
||||||
|
|
||||||
std::vector<std::vector<float>> img_ids(h_len * w_len, std::vector<float>(axes_dim_num, 0.0));
|
std::vector<std::vector<float>> img_ids(h_len * w_len, std::vector<float>(axes_dim_num, 0.0));
|
||||||
|
|
||||||
std::vector<float> row_ids = linspace<float>(h_offset, h_len - 1 + h_offset, h_len);
|
int h_start = h_offset;
|
||||||
std::vector<float> col_ids = linspace<float>(w_offset, w_len - 1 + w_offset, w_len);
|
int w_start = w_offset;
|
||||||
|
|
||||||
|
if (scale_rope) {
|
||||||
|
h_start -= h_len / 2;
|
||||||
|
w_start -= w_len / 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> row_ids = linspace<float>(h_start, h_start + h_len - 1, h_len);
|
||||||
|
std::vector<float> col_ids = linspace<float>(w_start, w_start + w_len - 1, w_len);
|
||||||
|
|
||||||
for (int i = 0; i < h_len; ++i) {
|
for (int i = 0; i < h_len; ++i) {
|
||||||
for (int j = 0; j < w_len; ++j) {
|
for (int j = 0; j < w_len; ++j) {
|
||||||
@ -171,7 +180,8 @@ namespace Rope {
|
|||||||
int axes_dim_num,
|
int axes_dim_num,
|
||||||
const std::vector<ggml_tensor*>& ref_latents,
|
const std::vector<ggml_tensor*>& ref_latents,
|
||||||
bool increase_ref_index,
|
bool increase_ref_index,
|
||||||
float ref_index_scale) {
|
float ref_index_scale,
|
||||||
|
bool scale_rope) {
|
||||||
std::vector<std::vector<float>> ids;
|
std::vector<std::vector<float>> ids;
|
||||||
uint64_t curr_h_offset = 0;
|
uint64_t curr_h_offset = 0;
|
||||||
uint64_t curr_w_offset = 0;
|
uint64_t curr_w_offset = 0;
|
||||||
@ -185,6 +195,7 @@ namespace Rope {
|
|||||||
} else {
|
} else {
|
||||||
h_offset = curr_h_offset;
|
h_offset = curr_h_offset;
|
||||||
}
|
}
|
||||||
|
scale_rope = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto ref_ids = gen_flux_img_ids(ref->ne[1],
|
auto ref_ids = gen_flux_img_ids(ref->ne[1],
|
||||||
@ -194,7 +205,8 @@ namespace Rope {
|
|||||||
axes_dim_num,
|
axes_dim_num,
|
||||||
static_cast<int>(index * ref_index_scale),
|
static_cast<int>(index * ref_index_scale),
|
||||||
h_offset,
|
h_offset,
|
||||||
w_offset);
|
w_offset,
|
||||||
|
scale_rope);
|
||||||
ids = concat_ids(ids, ref_ids, bs);
|
ids = concat_ids(ids, ref_ids, bs);
|
||||||
|
|
||||||
if (increase_ref_index) {
|
if (increase_ref_index) {
|
||||||
@ -222,7 +234,7 @@ namespace Rope {
|
|||||||
|
|
||||||
auto ids = concat_ids(txt_ids, img_ids, bs);
|
auto ids = concat_ids(txt_ids, img_ids, bs);
|
||||||
if (ref_latents.size() > 0) {
|
if (ref_latents.size() > 0) {
|
||||||
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, ref_index_scale);
|
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, ref_index_scale, false);
|
||||||
ids = concat_ids(ids, refs_ids, bs);
|
ids = concat_ids(ids, refs_ids, bs);
|
||||||
}
|
}
|
||||||
return ids;
|
return ids;
|
||||||
@ -271,10 +283,10 @@ namespace Rope {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
int axes_dim_num = 3;
|
int axes_dim_num = 3;
|
||||||
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num);
|
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, 0, 0, 0, true);
|
||||||
auto ids = concat_ids(txt_ids_repeated, img_ids, bs);
|
auto ids = concat_ids(txt_ids_repeated, img_ids, bs);
|
||||||
if (ref_latents.size() > 0) {
|
if (ref_latents.size() > 0) {
|
||||||
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, 1.f);
|
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, 1.f, true);
|
||||||
ids = concat_ids(ids, refs_ids, bs);
|
ids = concat_ids(ids, refs_ids, bs);
|
||||||
}
|
}
|
||||||
return ids;
|
return ids;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user