mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-01-02 10:43:35 +00:00
feat: add scale_rope support (#1121)
This commit is contained in:
parent
60abda56e0
commit
88ec9d30b1
28
rope.hpp
28
rope.hpp
@ -91,14 +91,23 @@ namespace Rope {
|
||||
int axes_dim_num,
|
||||
int index = 0,
|
||||
int h_offset = 0,
|
||||
int w_offset = 0) {
|
||||
int w_offset = 0,
|
||||
bool scale_rope = false) {
|
||||
int h_len = (h + (patch_size / 2)) / patch_size;
|
||||
int w_len = (w + (patch_size / 2)) / patch_size;
|
||||
|
||||
std::vector<std::vector<float>> img_ids(h_len * w_len, std::vector<float>(axes_dim_num, 0.0));
|
||||
|
||||
std::vector<float> row_ids = linspace<float>(h_offset, h_len - 1 + h_offset, h_len);
|
||||
std::vector<float> col_ids = linspace<float>(w_offset, w_len - 1 + w_offset, w_len);
|
||||
int h_start = h_offset;
|
||||
int w_start = w_offset;
|
||||
|
||||
if (scale_rope) {
|
||||
h_start -= h_len / 2;
|
||||
w_start -= w_len / 2;
|
||||
}
|
||||
|
||||
std::vector<float> row_ids = linspace<float>(h_start, h_start + h_len - 1, h_len);
|
||||
std::vector<float> col_ids = linspace<float>(w_start, w_start + w_len - 1, w_len);
|
||||
|
||||
for (int i = 0; i < h_len; ++i) {
|
||||
for (int j = 0; j < w_len; ++j) {
|
||||
@ -171,7 +180,8 @@ namespace Rope {
|
||||
int axes_dim_num,
|
||||
const std::vector<ggml_tensor*>& ref_latents,
|
||||
bool increase_ref_index,
|
||||
float ref_index_scale) {
|
||||
float ref_index_scale,
|
||||
bool scale_rope) {
|
||||
std::vector<std::vector<float>> ids;
|
||||
uint64_t curr_h_offset = 0;
|
||||
uint64_t curr_w_offset = 0;
|
||||
@ -185,6 +195,7 @@ namespace Rope {
|
||||
} else {
|
||||
h_offset = curr_h_offset;
|
||||
}
|
||||
scale_rope = false;
|
||||
}
|
||||
|
||||
auto ref_ids = gen_flux_img_ids(ref->ne[1],
|
||||
@ -194,7 +205,8 @@ namespace Rope {
|
||||
axes_dim_num,
|
||||
static_cast<int>(index * ref_index_scale),
|
||||
h_offset,
|
||||
w_offset);
|
||||
w_offset,
|
||||
scale_rope);
|
||||
ids = concat_ids(ids, ref_ids, bs);
|
||||
|
||||
if (increase_ref_index) {
|
||||
@ -222,7 +234,7 @@ namespace Rope {
|
||||
|
||||
auto ids = concat_ids(txt_ids, img_ids, bs);
|
||||
if (ref_latents.size() > 0) {
|
||||
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, ref_index_scale);
|
||||
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, ref_index_scale, false);
|
||||
ids = concat_ids(ids, refs_ids, bs);
|
||||
}
|
||||
return ids;
|
||||
@ -271,10 +283,10 @@ namespace Rope {
|
||||
}
|
||||
}
|
||||
int axes_dim_num = 3;
|
||||
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num);
|
||||
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, 0, 0, 0, true);
|
||||
auto ids = concat_ids(txt_ids_repeated, img_ids, bs);
|
||||
if (ref_latents.size() > 0) {
|
||||
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, 1.f);
|
||||
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, 1.f, true);
|
||||
ids = concat_ids(ids, refs_ids, bs);
|
||||
}
|
||||
return ids;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user