mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-13 05:48:56 +00:00
add ref latent support for qwen image
This commit is contained in:
parent
94f4f295c1
commit
4e48e6b82b
@ -313,6 +313,8 @@ struct QwenImageModel : public DiffusionModel {
|
||||
diffusion_params.x,
|
||||
diffusion_params.timesteps,
|
||||
diffusion_params.context,
|
||||
diffusion_params.ref_latents,
|
||||
diffusion_params.increase_ref_index,
|
||||
output,
|
||||
output_ctx);
|
||||
}
|
||||
|
||||
@ -256,7 +256,7 @@ namespace Qwen {
|
||||
auto txt_gate1 = txt_mod_param_vec[2];
|
||||
|
||||
auto [img_attn_output, txt_attn_output] = attn->forward(ctx, backend, img_modulated, txt_modulated, pe);
|
||||
|
||||
|
||||
img = ggml_add(ctx, img, ggml_mul(ctx, img_attn_output, img_gate1));
|
||||
txt = ggml_add(ctx, txt, ggml_mul(ctx, txt_attn_output, txt_gate1));
|
||||
|
||||
@ -386,6 +386,13 @@ namespace Qwen {
|
||||
return x;
|
||||
}
|
||||
|
||||
struct ggml_tensor* process_img(struct ggml_context* ctx,
|
||||
struct ggml_tensor* x) {
|
||||
x = pad_to_patch_size(ctx, x);
|
||||
x = patchify(ctx, x);
|
||||
return x;
|
||||
}
|
||||
|
||||
struct ggml_tensor* unpatchify(struct ggml_context* ctx,
|
||||
struct ggml_tensor* x,
|
||||
int64_t h,
|
||||
@ -446,7 +453,8 @@ namespace Qwen {
|
||||
struct ggml_tensor* x,
|
||||
struct ggml_tensor* timestep,
|
||||
struct ggml_tensor* context,
|
||||
struct ggml_tensor* pe) {
|
||||
struct ggml_tensor* pe,
|
||||
std::vector<ggml_tensor*> ref_latents = {}) {
|
||||
// Forward pass of DiT.
|
||||
// x: [N, C, H, W]
|
||||
// timestep: [N,]
|
||||
@ -459,13 +467,26 @@ namespace Qwen {
|
||||
int64_t C = x->ne[2];
|
||||
int64_t N = x->ne[3];
|
||||
|
||||
x = pad_to_patch_size(ctx, x);
|
||||
x = patchify(ctx, x);
|
||||
auto img = process_img(ctx, x);
|
||||
uint64_t img_tokens = img->ne[1];
|
||||
|
||||
if (ref_latents.size() > 0) {
|
||||
for (ggml_tensor* ref : ref_latents) {
|
||||
ref = process_img(ctx, ref);
|
||||
img = ggml_concat(ctx, img, ref, 1);
|
||||
}
|
||||
}
|
||||
|
||||
int64_t h_len = ((H + (params.patch_size / 2)) / params.patch_size);
|
||||
int64_t w_len = ((W + (params.patch_size / 2)) / params.patch_size);
|
||||
|
||||
auto out = forward_orig(ctx, backend, x, timestep, context, pe); // [N, h_len*w_len, ph*pw*C]
|
||||
auto out = forward_orig(ctx, backend, img, timestep, context, pe); // [N, h_len*w_len, ph*pw*C]
|
||||
|
||||
if (out->ne[1] > img_tokens) {
|
||||
out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size]
|
||||
out = ggml_view_3d(ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0);
|
||||
out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size]
|
||||
}
|
||||
|
||||
out = unpatchify(ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w]
|
||||
|
||||
@ -506,7 +527,9 @@ namespace Qwen {
|
||||
|
||||
struct ggml_cgraph* build_graph(struct ggml_tensor* x,
|
||||
struct ggml_tensor* timesteps,
|
||||
struct ggml_tensor* context) {
|
||||
struct ggml_tensor* context,
|
||||
std::vector<ggml_tensor*> ref_latents = {},
|
||||
bool increase_ref_index = false) {
|
||||
GGML_ASSERT(x->ne[3] == 1);
|
||||
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, QWEN_IMAGE_GRAPH_SIZE, false);
|
||||
|
||||
@ -514,18 +537,24 @@ namespace Qwen {
|
||||
context = to_backend(context);
|
||||
timesteps = to_backend(timesteps);
|
||||
|
||||
for (int i = 0; i < ref_latents.size(); i++) {
|
||||
ref_latents[i] = to_backend(ref_latents[i]);
|
||||
}
|
||||
|
||||
pe_vec = Rope::gen_qwen_image_pe(x->ne[1],
|
||||
x->ne[0],
|
||||
qwen_image_params.patch_size,
|
||||
x->ne[3],
|
||||
context->ne[1],
|
||||
ref_latents,
|
||||
increase_ref_index,
|
||||
qwen_image_params.theta,
|
||||
qwen_image_params.axes_dim);
|
||||
int pos_len = pe_vec.size() / qwen_image_params.axes_dim_sum / 2;
|
||||
// LOG_DEBUG("pos_len %d", pos_len);
|
||||
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, qwen_image_params.axes_dim_sum / 2, pos_len);
|
||||
// pe->data = pe_vec.data();
|
||||
// print_ggml_tensor(pe);
|
||||
// print_ggml_tensor(pe, true, "pe");
|
||||
// pe->data = NULL;
|
||||
set_backend_tensor_data(pe, pe_vec.data());
|
||||
|
||||
@ -534,7 +563,8 @@ namespace Qwen {
|
||||
x,
|
||||
timesteps,
|
||||
context,
|
||||
pe);
|
||||
pe,
|
||||
ref_latents);
|
||||
|
||||
ggml_build_forward_expand(gf, out);
|
||||
|
||||
@ -545,13 +575,15 @@ namespace Qwen {
|
||||
struct ggml_tensor* x,
|
||||
struct ggml_tensor* timesteps,
|
||||
struct ggml_tensor* context,
|
||||
struct ggml_tensor** output = NULL,
|
||||
struct ggml_context* output_ctx = NULL) {
|
||||
std::vector<ggml_tensor*> ref_latents = {},
|
||||
bool increase_ref_index = false,
|
||||
struct ggml_tensor** output = NULL,
|
||||
struct ggml_context* output_ctx = NULL) {
|
||||
// x: [N, in_channels, h, w]
|
||||
// timesteps: [N, ]
|
||||
// context: [N, max_position, hidden_size]
|
||||
auto get_graph = [&]() -> struct ggml_cgraph* {
|
||||
return build_graph(x, timesteps, context);
|
||||
return build_graph(x, timesteps, context, ref_latents, increase_ref_index);
|
||||
};
|
||||
|
||||
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
|
||||
@ -583,7 +615,7 @@ namespace Qwen {
|
||||
struct ggml_tensor* out = NULL;
|
||||
|
||||
int t0 = ggml_time_ms();
|
||||
compute(8, x, timesteps, context, &out, work_ctx);
|
||||
compute(8, x, timesteps, context, {}, false, &out, work_ctx);
|
||||
int t1 = ggml_time_ms();
|
||||
|
||||
print_ggml_tensor(out);
|
||||
|
||||
44
rope.hpp
44
rope.hpp
@ -151,17 +151,11 @@ struct Rope {
|
||||
return flatten(emb);
|
||||
}
|
||||
|
||||
static std::vector<std::vector<float>> gen_flux_ids(int h,
|
||||
int w,
|
||||
int patch_size,
|
||||
static std::vector<std::vector<float>> gen_refs_ids(int patch_size,
|
||||
int bs,
|
||||
int context_len,
|
||||
std::vector<ggml_tensor*> ref_latents,
|
||||
const std::vector<ggml_tensor*>& ref_latents,
|
||||
bool increase_ref_index) {
|
||||
auto txt_ids = gen_txt_ids(bs, context_len);
|
||||
auto img_ids = gen_img_ids(h, w, patch_size, bs);
|
||||
|
||||
auto ids = concat_ids(txt_ids, img_ids, bs);
|
||||
std::vector<std::vector<float>> ids;
|
||||
uint64_t curr_h_offset = 0;
|
||||
uint64_t curr_w_offset = 0;
|
||||
int index = 1;
|
||||
@ -189,13 +183,31 @@ struct Rope {
|
||||
return ids;
|
||||
}
|
||||
|
||||
static std::vector<std::vector<float>> gen_flux_ids(int h,
|
||||
int w,
|
||||
int patch_size,
|
||||
int bs,
|
||||
int context_len,
|
||||
const std::vector<ggml_tensor*>& ref_latents,
|
||||
bool increase_ref_index) {
|
||||
auto txt_ids = gen_txt_ids(bs, context_len);
|
||||
auto img_ids = gen_img_ids(h, w, patch_size, bs);
|
||||
|
||||
auto ids = concat_ids(txt_ids, img_ids, bs);
|
||||
if (ref_latents.size() > 0) {
|
||||
auto refs_ids = gen_refs_ids(patch_size, bs, ref_latents, increase_ref_index);
|
||||
ids = concat_ids(ids, refs_ids, bs);
|
||||
}
|
||||
return ids;
|
||||
}
|
||||
|
||||
// Generate flux positional embeddings
|
||||
static std::vector<float> gen_flux_pe(int h,
|
||||
int w,
|
||||
int patch_size,
|
||||
int bs,
|
||||
int context_len,
|
||||
std::vector<ggml_tensor*> ref_latents,
|
||||
const std::vector<ggml_tensor*>& ref_latents,
|
||||
bool increase_ref_index,
|
||||
int theta,
|
||||
const std::vector<int>& axes_dim) {
|
||||
@ -207,7 +219,9 @@ struct Rope {
|
||||
int w,
|
||||
int patch_size,
|
||||
int bs,
|
||||
int context_len) {
|
||||
int context_len,
|
||||
const std::vector<ggml_tensor*>& ref_latents,
|
||||
bool increase_ref_index) {
|
||||
int h_len = (h + (patch_size / 2)) / patch_size;
|
||||
int w_len = (w + (patch_size / 2)) / patch_size;
|
||||
int txt_id_start = std::max(h_len, w_len);
|
||||
@ -220,6 +234,10 @@ struct Rope {
|
||||
}
|
||||
auto img_ids = gen_img_ids(h, w, patch_size, bs);
|
||||
auto ids = concat_ids(txt_ids_repeated, img_ids, bs);
|
||||
if (ref_latents.size() > 0) {
|
||||
auto refs_ids = gen_refs_ids(patch_size, bs, ref_latents, increase_ref_index);
|
||||
ids = concat_ids(ids, refs_ids, bs);
|
||||
}
|
||||
return ids;
|
||||
}
|
||||
|
||||
@ -229,9 +247,11 @@ struct Rope {
|
||||
int patch_size,
|
||||
int bs,
|
||||
int context_len,
|
||||
const std::vector<ggml_tensor*>& ref_latents,
|
||||
bool increase_ref_index,
|
||||
int theta,
|
||||
const std::vector<int>& axes_dim) {
|
||||
std::vector<std::vector<float>> ids = gen_qwen_image_ids(h, w, patch_size, bs, context_len);
|
||||
std::vector<std::vector<float>> ids = gen_qwen_image_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
|
||||
return embed_nd(ids, bs, theta, axes_dim);
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user