mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-13 05:48:56 +00:00
add ref latent support for qwen image
This commit is contained in:
parent
94f4f295c1
commit
4e48e6b82b
@ -313,6 +313,8 @@ struct QwenImageModel : public DiffusionModel {
|
|||||||
diffusion_params.x,
|
diffusion_params.x,
|
||||||
diffusion_params.timesteps,
|
diffusion_params.timesteps,
|
||||||
diffusion_params.context,
|
diffusion_params.context,
|
||||||
|
diffusion_params.ref_latents,
|
||||||
|
diffusion_params.increase_ref_index,
|
||||||
output,
|
output,
|
||||||
output_ctx);
|
output_ctx);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -386,6 +386,13 @@ namespace Qwen {
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Prepares an image-like tensor for the transformer: first pads it with
// pad_to_patch_size(), then converts it with patchify(), and returns the result.
// Used for both the main latent and each reference latent so they go through
// the same preprocessing.
// NOTE(review): presumably x is [N, C, H, W] and the result is a token sequence
// [N, h*w, C * patch_size * patch_size] — confirm against patchify().
struct ggml_tensor* process_img(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* x) {
|
||||||
|
x = pad_to_patch_size(ctx, x);
|
||||||
|
x = patchify(ctx, x);
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
struct ggml_tensor* unpatchify(struct ggml_context* ctx,
|
struct ggml_tensor* unpatchify(struct ggml_context* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
int64_t h,
|
int64_t h,
|
||||||
@ -446,7 +453,8 @@ namespace Qwen {
|
|||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* timestep,
|
struct ggml_tensor* timestep,
|
||||||
struct ggml_tensor* context,
|
struct ggml_tensor* context,
|
||||||
struct ggml_tensor* pe) {
|
struct ggml_tensor* pe,
|
||||||
|
std::vector<ggml_tensor*> ref_latents = {}) {
|
||||||
// Forward pass of DiT.
|
// Forward pass of DiT.
|
||||||
// x: [N, C, H, W]
|
// x: [N, C, H, W]
|
||||||
// timestep: [N,]
|
// timestep: [N,]
|
||||||
@ -459,13 +467,26 @@ namespace Qwen {
|
|||||||
int64_t C = x->ne[2];
|
int64_t C = x->ne[2];
|
||||||
int64_t N = x->ne[3];
|
int64_t N = x->ne[3];
|
||||||
|
|
||||||
x = pad_to_patch_size(ctx, x);
|
auto img = process_img(ctx, x);
|
||||||
x = patchify(ctx, x);
|
uint64_t img_tokens = img->ne[1];
|
||||||
|
|
||||||
|
if (ref_latents.size() > 0) {
|
||||||
|
for (ggml_tensor* ref : ref_latents) {
|
||||||
|
ref = process_img(ctx, ref);
|
||||||
|
img = ggml_concat(ctx, img, ref, 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
int64_t h_len = ((H + (params.patch_size / 2)) / params.patch_size);
|
int64_t h_len = ((H + (params.patch_size / 2)) / params.patch_size);
|
||||||
int64_t w_len = ((W + (params.patch_size / 2)) / params.patch_size);
|
int64_t w_len = ((W + (params.patch_size / 2)) / params.patch_size);
|
||||||
|
|
||||||
auto out = forward_orig(ctx, backend, x, timestep, context, pe); // [N, h_len*w_len, ph*pw*C]
|
auto out = forward_orig(ctx, backend, img, timestep, context, pe); // [N, h_len*w_len, ph*pw*C]
|
||||||
|
|
||||||
|
if (out->ne[1] > img_tokens) {
|
||||||
|
out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size]
|
||||||
|
out = ggml_view_3d(ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0);
|
||||||
|
out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size]
|
||||||
|
}
|
||||||
|
|
||||||
out = unpatchify(ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w]
|
out = unpatchify(ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w]
|
||||||
|
|
||||||
@ -506,7 +527,9 @@ namespace Qwen {
|
|||||||
|
|
||||||
struct ggml_cgraph* build_graph(struct ggml_tensor* x,
|
struct ggml_cgraph* build_graph(struct ggml_tensor* x,
|
||||||
struct ggml_tensor* timesteps,
|
struct ggml_tensor* timesteps,
|
||||||
struct ggml_tensor* context) {
|
struct ggml_tensor* context,
|
||||||
|
std::vector<ggml_tensor*> ref_latents = {},
|
||||||
|
bool increase_ref_index = false) {
|
||||||
GGML_ASSERT(x->ne[3] == 1);
|
GGML_ASSERT(x->ne[3] == 1);
|
||||||
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, QWEN_IMAGE_GRAPH_SIZE, false);
|
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, QWEN_IMAGE_GRAPH_SIZE, false);
|
||||||
|
|
||||||
@ -514,18 +537,24 @@ namespace Qwen {
|
|||||||
context = to_backend(context);
|
context = to_backend(context);
|
||||||
timesteps = to_backend(timesteps);
|
timesteps = to_backend(timesteps);
|
||||||
|
|
||||||
|
for (int i = 0; i < ref_latents.size(); i++) {
|
||||||
|
ref_latents[i] = to_backend(ref_latents[i]);
|
||||||
|
}
|
||||||
|
|
||||||
pe_vec = Rope::gen_qwen_image_pe(x->ne[1],
|
pe_vec = Rope::gen_qwen_image_pe(x->ne[1],
|
||||||
x->ne[0],
|
x->ne[0],
|
||||||
qwen_image_params.patch_size,
|
qwen_image_params.patch_size,
|
||||||
x->ne[3],
|
x->ne[3],
|
||||||
context->ne[1],
|
context->ne[1],
|
||||||
|
ref_latents,
|
||||||
|
increase_ref_index,
|
||||||
qwen_image_params.theta,
|
qwen_image_params.theta,
|
||||||
qwen_image_params.axes_dim);
|
qwen_image_params.axes_dim);
|
||||||
int pos_len = pe_vec.size() / qwen_image_params.axes_dim_sum / 2;
|
int pos_len = pe_vec.size() / qwen_image_params.axes_dim_sum / 2;
|
||||||
// LOG_DEBUG("pos_len %d", pos_len);
|
// LOG_DEBUG("pos_len %d", pos_len);
|
||||||
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, qwen_image_params.axes_dim_sum / 2, pos_len);
|
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, qwen_image_params.axes_dim_sum / 2, pos_len);
|
||||||
// pe->data = pe_vec.data();
|
// pe->data = pe_vec.data();
|
||||||
// print_ggml_tensor(pe);
|
// print_ggml_tensor(pe, true, "pe");
|
||||||
// pe->data = NULL;
|
// pe->data = NULL;
|
||||||
set_backend_tensor_data(pe, pe_vec.data());
|
set_backend_tensor_data(pe, pe_vec.data());
|
||||||
|
|
||||||
@ -534,7 +563,8 @@ namespace Qwen {
|
|||||||
x,
|
x,
|
||||||
timesteps,
|
timesteps,
|
||||||
context,
|
context,
|
||||||
pe);
|
pe,
|
||||||
|
ref_latents);
|
||||||
|
|
||||||
ggml_build_forward_expand(gf, out);
|
ggml_build_forward_expand(gf, out);
|
||||||
|
|
||||||
@ -545,13 +575,15 @@ namespace Qwen {
|
|||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
struct ggml_tensor* timesteps,
|
struct ggml_tensor* timesteps,
|
||||||
struct ggml_tensor* context,
|
struct ggml_tensor* context,
|
||||||
struct ggml_tensor** output = NULL,
|
std::vector<ggml_tensor*> ref_latents = {},
|
||||||
struct ggml_context* output_ctx = NULL) {
|
bool increase_ref_index = false,
|
||||||
|
struct ggml_tensor** output = NULL,
|
||||||
|
struct ggml_context* output_ctx = NULL) {
|
||||||
// x: [N, in_channels, h, w]
|
// x: [N, in_channels, h, w]
|
||||||
// timesteps: [N, ]
|
// timesteps: [N, ]
|
||||||
// context: [N, max_position, hidden_size]
|
// context: [N, max_position, hidden_size]
|
||||||
auto get_graph = [&]() -> struct ggml_cgraph* {
|
auto get_graph = [&]() -> struct ggml_cgraph* {
|
||||||
return build_graph(x, timesteps, context);
|
return build_graph(x, timesteps, context, ref_latents, increase_ref_index);
|
||||||
};
|
};
|
||||||
|
|
||||||
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
|
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
|
||||||
@ -583,7 +615,7 @@ namespace Qwen {
|
|||||||
struct ggml_tensor* out = NULL;
|
struct ggml_tensor* out = NULL;
|
||||||
|
|
||||||
int t0 = ggml_time_ms();
|
int t0 = ggml_time_ms();
|
||||||
compute(8, x, timesteps, context, &out, work_ctx);
|
compute(8, x, timesteps, context, {}, false, &out, work_ctx);
|
||||||
int t1 = ggml_time_ms();
|
int t1 = ggml_time_ms();
|
||||||
|
|
||||||
print_ggml_tensor(out);
|
print_ggml_tensor(out);
|
||||||
|
|||||||
44
rope.hpp
44
rope.hpp
@ -151,17 +151,11 @@ struct Rope {
|
|||||||
return flatten(emb);
|
return flatten(emb);
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::vector<std::vector<float>> gen_flux_ids(int h,
|
static std::vector<std::vector<float>> gen_refs_ids(int patch_size,
|
||||||
int w,
|
|
||||||
int patch_size,
|
|
||||||
int bs,
|
int bs,
|
||||||
int context_len,
|
const std::vector<ggml_tensor*>& ref_latents,
|
||||||
std::vector<ggml_tensor*> ref_latents,
|
|
||||||
bool increase_ref_index) {
|
bool increase_ref_index) {
|
||||||
auto txt_ids = gen_txt_ids(bs, context_len);
|
std::vector<std::vector<float>> ids;
|
||||||
auto img_ids = gen_img_ids(h, w, patch_size, bs);
|
|
||||||
|
|
||||||
auto ids = concat_ids(txt_ids, img_ids, bs);
|
|
||||||
uint64_t curr_h_offset = 0;
|
uint64_t curr_h_offset = 0;
|
||||||
uint64_t curr_w_offset = 0;
|
uint64_t curr_w_offset = 0;
|
||||||
int index = 1;
|
int index = 1;
|
||||||
@ -189,13 +183,31 @@ struct Rope {
|
|||||||
return ids;
|
return ids;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Builds the position-id list for a flux forward pass.
// Text ids come from gen_txt_ids(bs, context_len) and image ids from
// gen_img_ids(h, w, patch_size, bs); the two are joined with concat_ids().
// When ref_latents is non-empty, ids produced by gen_refs_ids() for the
// reference latents are appended after the text + image ids.
// Returns the combined id list.
static std::vector<std::vector<float>> gen_flux_ids(int h,
|
||||||
|
int w,
|
||||||
|
int patch_size,
|
||||||
|
int bs,
|
||||||
|
int context_len,
|
||||||
|
const std::vector<ggml_tensor*>& ref_latents,
|
||||||
|
bool increase_ref_index) {
|
||||||
|
auto txt_ids = gen_txt_ids(bs, context_len);
|
||||||
|
auto img_ids = gen_img_ids(h, w, patch_size, bs);
|
||||||
|
|
||||||
|
auto ids = concat_ids(txt_ids, img_ids, bs);
|
||||||
|
// Reference-latent ids are only generated and appended when references are supplied.
if (ref_latents.size() > 0) {
|
||||||
|
auto refs_ids = gen_refs_ids(patch_size, bs, ref_latents, increase_ref_index);
|
||||||
|
ids = concat_ids(ids, refs_ids, bs);
|
||||||
|
}
|
||||||
|
return ids;
|
||||||
|
}
|
||||||
|
|
||||||
// Generate flux positional embeddings
|
// Generate flux positional embeddings
|
||||||
static std::vector<float> gen_flux_pe(int h,
|
static std::vector<float> gen_flux_pe(int h,
|
||||||
int w,
|
int w,
|
||||||
int patch_size,
|
int patch_size,
|
||||||
int bs,
|
int bs,
|
||||||
int context_len,
|
int context_len,
|
||||||
std::vector<ggml_tensor*> ref_latents,
|
const std::vector<ggml_tensor*>& ref_latents,
|
||||||
bool increase_ref_index,
|
bool increase_ref_index,
|
||||||
int theta,
|
int theta,
|
||||||
const std::vector<int>& axes_dim) {
|
const std::vector<int>& axes_dim) {
|
||||||
@ -207,7 +219,9 @@ struct Rope {
|
|||||||
int w,
|
int w,
|
||||||
int patch_size,
|
int patch_size,
|
||||||
int bs,
|
int bs,
|
||||||
int context_len) {
|
int context_len,
|
||||||
|
const std::vector<ggml_tensor*>& ref_latents,
|
||||||
|
bool increase_ref_index) {
|
||||||
int h_len = (h + (patch_size / 2)) / patch_size;
|
int h_len = (h + (patch_size / 2)) / patch_size;
|
||||||
int w_len = (w + (patch_size / 2)) / patch_size;
|
int w_len = (w + (patch_size / 2)) / patch_size;
|
||||||
int txt_id_start = std::max(h_len, w_len);
|
int txt_id_start = std::max(h_len, w_len);
|
||||||
@ -220,6 +234,10 @@ struct Rope {
|
|||||||
}
|
}
|
||||||
auto img_ids = gen_img_ids(h, w, patch_size, bs);
|
auto img_ids = gen_img_ids(h, w, patch_size, bs);
|
||||||
auto ids = concat_ids(txt_ids_repeated, img_ids, bs);
|
auto ids = concat_ids(txt_ids_repeated, img_ids, bs);
|
||||||
|
if (ref_latents.size() > 0) {
|
||||||
|
auto refs_ids = gen_refs_ids(patch_size, bs, ref_latents, increase_ref_index);
|
||||||
|
ids = concat_ids(ids, refs_ids, bs);
|
||||||
|
}
|
||||||
return ids;
|
return ids;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -229,9 +247,11 @@ struct Rope {
|
|||||||
int patch_size,
|
int patch_size,
|
||||||
int bs,
|
int bs,
|
||||||
int context_len,
|
int context_len,
|
||||||
|
const std::vector<ggml_tensor*>& ref_latents,
|
||||||
|
bool increase_ref_index,
|
||||||
int theta,
|
int theta,
|
||||||
const std::vector<int>& axes_dim) {
|
const std::vector<int>& axes_dim) {
|
||||||
std::vector<std::vector<float>> ids = gen_qwen_image_ids(h, w, patch_size, bs, context_len);
|
std::vector<std::vector<float>> ids = gen_qwen_image_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
|
||||||
return embed_nd(ids, bs, theta, axes_dim);
|
return embed_nd(ids, bs, theta, axes_dim);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user