wip pipeline

leejet 2025-12-21 14:23:41 +08:00
parent 58f7b789cb
commit 13a19ee6fd
12 changed files with 371 additions and 200 deletions


@ -16,7 +16,7 @@ struct DiffusionParams {
struct ggml_tensor* y = nullptr; struct ggml_tensor* y = nullptr;
struct ggml_tensor* guidance = nullptr; struct ggml_tensor* guidance = nullptr;
std::vector<ggml_tensor*> ref_latents = {}; std::vector<ggml_tensor*> ref_latents = {};
bool increase_ref_index = false; Rope::RefIndexMode ref_index_mode = Rope::RefIndexMode::FIXED;
int num_video_frames = -1; int num_video_frames = -1;
std::vector<struct ggml_tensor*> controls = {}; std::vector<struct ggml_tensor*> controls = {};
float control_strength = 0.f; float control_strength = 0.f;
@ -222,7 +222,7 @@ struct FluxModel : public DiffusionModel {
diffusion_params.y, diffusion_params.y,
diffusion_params.guidance, diffusion_params.guidance,
diffusion_params.ref_latents, diffusion_params.ref_latents,
diffusion_params.increase_ref_index, diffusion_params.ref_index_mode,
output, output,
output_ctx, output_ctx,
diffusion_params.skip_layers); diffusion_params.skip_layers);
@ -352,7 +352,7 @@ struct QwenImageModel : public DiffusionModel {
diffusion_params.timesteps, diffusion_params.timesteps,
diffusion_params.context, diffusion_params.context,
diffusion_params.ref_latents, diffusion_params.ref_latents,
true, // increase_ref_index Rope::RefIndexMode::INCREASE,
output, output,
output_ctx); output_ctx);
} }
@ -415,7 +415,7 @@ struct ZImageModel : public DiffusionModel {
diffusion_params.timesteps, diffusion_params.timesteps,
diffusion_params.context, diffusion_params.context,
diffusion_params.ref_latents, diffusion_params.ref_latents,
true, // increase_ref_index Rope::RefIndexMode::INCREASE,
output, output,
output_ctx); output_ctx);
} }
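A reading aid for the hunks above (a minimal sketch, not part of the commit): the bool-to-enum migration is mechanical, and the mapping used elsewhere in this diff (see the ref_index_mode assignment in stable-diffusion.cpp below) is simply:

    // Hypothetical helper showing how the old flag maps onto the new enum.
    // The enum itself is added in rope.hpp further down.
    enum class RefIndexMode { FIXED, INCREASE, DECREASE };

    inline RefIndexMode from_increase_ref_index(bool increase_ref_index) {
        // old `false` -> FIXED    (reference latents share index 1, separated by spatial offsets)
        // old `true`  -> INCREASE (each reference latent gets the next index)
        // DECREASE is new in this commit and has no boolean equivalent
        return increase_ref_index ? RefIndexMode::INCREASE : RefIndexMode::FIXED;
    }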


@ -615,6 +615,7 @@ int main(int argc, const char* argv[]) {
results = generate_image(sd_ctx, &img_gen_params); results = generate_image(sd_ctx, &img_gen_params);
num_results = gen_params.batch_count; num_results = gen_params.batch_count;
num_results = 4;
} else if (cli_params.mode == VID_GEN) { } else if (cli_params.mode == VID_GEN) {
sd_vid_gen_params_t vid_gen_params = { sd_vid_gen_params_t vid_gen_params = {
gen_params.lora_vec.data(), gen_params.lora_vec.data(),


@ -1388,7 +1388,7 @@ namespace Flux {
struct ggml_tensor* y, struct ggml_tensor* y,
struct ggml_tensor* guidance, struct ggml_tensor* guidance,
std::vector<ggml_tensor*> ref_latents = {}, std::vector<ggml_tensor*> ref_latents = {},
bool increase_ref_index = false, Rope::RefIndexMode ref_index_mode = Rope::RefIndexMode::FIXED,
std::vector<int> skip_layers = {}) { std::vector<int> skip_layers = {}) {
GGML_ASSERT(x->ne[3] == 1); GGML_ASSERT(x->ne[3] == 1);
struct ggml_cgraph* gf = new_graph_custom(FLUX_GRAPH_SIZE); struct ggml_cgraph* gf = new_graph_custom(FLUX_GRAPH_SIZE);
@ -1426,7 +1426,7 @@ namespace Flux {
std::set<int> txt_arange_dims; std::set<int> txt_arange_dims;
if (sd_version_is_flux2(version)) { if (sd_version_is_flux2(version)) {
txt_arange_dims = {3}; txt_arange_dims = {3};
increase_ref_index = true; ref_index_mode = Rope::RefIndexMode::INCREASE;
} else if (version == VERSION_OVIS_IMAGE) { } else if (version == VERSION_OVIS_IMAGE) {
txt_arange_dims = {1, 2}; txt_arange_dims = {1, 2};
} }
@ -1438,7 +1438,7 @@ namespace Flux {
context->ne[1], context->ne[1],
txt_arange_dims, txt_arange_dims,
ref_latents, ref_latents,
increase_ref_index, ref_index_mode,
flux_params.ref_index_scale, flux_params.ref_index_scale,
flux_params.theta, flux_params.theta,
flux_params.axes_dim); flux_params.axes_dim);
@ -1489,7 +1489,7 @@ namespace Flux {
struct ggml_tensor* y, struct ggml_tensor* y,
struct ggml_tensor* guidance, struct ggml_tensor* guidance,
std::vector<ggml_tensor*> ref_latents = {}, std::vector<ggml_tensor*> ref_latents = {},
bool increase_ref_index = false, Rope::RefIndexMode ref_index_mode = Rope::RefIndexMode::FIXED,
struct ggml_tensor** output = nullptr, struct ggml_tensor** output = nullptr,
struct ggml_context* output_ctx = nullptr, struct ggml_context* output_ctx = nullptr,
std::vector<int> skip_layers = std::vector<int>()) { std::vector<int> skip_layers = std::vector<int>()) {
@ -1499,7 +1499,7 @@ namespace Flux {
// y: [N, adm_in_channels] or [1, adm_in_channels] // y: [N, adm_in_channels] or [1, adm_in_channels]
// guidance: [N, ] // guidance: [N, ]
auto get_graph = [&]() -> struct ggml_cgraph* { auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(x, timesteps, context, c_concat, y, guidance, ref_latents, increase_ref_index, skip_layers); return build_graph(x, timesteps, context, c_concat, y, guidance, ref_latents, ref_index_mode, skip_layers);
}; };
return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
@ -1542,7 +1542,7 @@ namespace Flux {
struct ggml_tensor* out = nullptr; struct ggml_tensor* out = nullptr;
int t0 = ggml_time_ms(); int t0 = ggml_time_ms();
compute(8, x, timesteps, context, nullptr, y, guidance, {}, false, &out, work_ctx); compute(8, x, timesteps, context, nullptr, y, guidance, {}, Rope::RefIndexMode::FIXED, &out, work_ctx);
int t1 = ggml_time_ms(); int t1 = ggml_time_ms();
print_ggml_tensor(out); print_ggml_tensor(out);
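One behavioural detail from the flux.hpp hunks above: for FLUX.2 the requested mode is ignored and INCREASE is forced inside build_graph. A standalone sketch of that precedence (placeholder names, not the project's API):

    enum class RefIndexMode { FIXED, INCREASE, DECREASE };

    // FLUX.2 always advances the reference index per reference latent,
    // regardless of what the caller passed; other Flux variants keep the request.
    RefIndexMode effective_flux_ref_index_mode(bool is_flux2, RefIndexMode requested) {
        return is_flux2 ? RefIndexMode::INCREASE : requested;
    }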


@ -386,7 +386,7 @@ __STATIC_INLINE__ uint8_t* ggml_tensor_to_sd_image(struct ggml_tensor* input, ui
int64_t width = input->ne[0]; int64_t width = input->ne[0];
int64_t height = input->ne[1]; int64_t height = input->ne[1];
int64_t channels = input->ne[2]; int64_t channels = input->ne[2];
GGML_ASSERT(channels == 3 && input->type == GGML_TYPE_F32); GGML_ASSERT(input->type == GGML_TYPE_F32);
if (image_data == nullptr) { if (image_data == nullptr) {
image_data = (uint8_t*)malloc(width * height * channels); image_data = (uint8_t*)malloc(width * height * channels);
} }
@ -1200,7 +1200,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
bool diag_mask_inf = false, bool diag_mask_inf = false,
bool skip_reshape = false, bool skip_reshape = false,
bool flash_attn = false, bool flash_attn = false,
float kv_scale = 1.0f) { // avoid overflow float kv_scale = 1.0f / 256.f) { // avoid overflow
int64_t L_q; int64_t L_q;
int64_t L_k; int64_t L_k;
int64_t C; int64_t C;
@ -2142,7 +2142,7 @@ public:
bool bias = true, bool bias = true,
bool force_f32 = false, bool force_f32 = false,
bool force_prec_f32 = false, bool force_prec_f32 = false,
float scale = 1.f) float scale = 1.f / 256.f)
: in_features(in_features), : in_features(in_features),
out_features(out_features), out_features(out_features),
bias(bias), bias(bias),
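Both defaults above change from 1.0f to 1.0f/256.f, which reads as an fp16 overflow guard: pre-scaling the K/V path keeps intermediate products inside half-precision range, and the factor can be undone afterwards. A toy, self-contained illustration of the magnitudes (not the project's kernel; where exactly the compensation happens is outside this hunk):

    #include <cstdio>

    int main() {
        const double FP16_MAX = 65504.0;
        const int d = 128;                       // head dimension
        double q = 30.0, k = 30.0;               // plausible activation magnitudes
        double dot = d * q * k;                  // 115200 -> would overflow to inf in fp16
        double scaled = d * q * (k / 256.0);     // 450    -> comfortably representable
        printf("unscaled dot %.0f (> %.0f), scaled dot %.1f, recovered %.0f\n",
               dot, FP16_MAX, scaled, scaled * 256.0);
        return 0;
    }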


@ -114,7 +114,7 @@ static inline bool sd_version_is_wan(SDVersion version) {
} }
static inline bool sd_version_is_qwen_image(SDVersion version) { static inline bool sd_version_is_qwen_image(SDVersion version) {
if (version == VERSION_QWEN_IMAGE || VERSION_QWEN_IMAGE_LAYERED) { if (version == VERSION_QWEN_IMAGE || version == VERSION_QWEN_IMAGE_LAYERED) {
return true; return true;
} }
return false; return false;
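The fix above corrects a short-circuit bug: in the old condition the second operand was a bare enumerator, which converts to true whenever it is non-zero, so the function matched every version. Minimal standalone reproduction (illustrative enum values, not the project's):

    #include <cstdio>

    enum SDVersionDemo { VERSION_SD1 = 0, VERSION_QWEN_IMAGE = 7, VERSION_QWEN_IMAGE_LAYERED = 8 };

    int main() {
        SDVersionDemo version = VERSION_SD1;
        bool buggy = (version == VERSION_QWEN_IMAGE || VERSION_QWEN_IMAGE_LAYERED);            // 1: bare enumerator is truthy
        bool fixed = (version == VERSION_QWEN_IMAGE || version == VERSION_QWEN_IMAGE_LAYERED); // 0
        printf("buggy=%d fixed=%d\n", buggy, fixed);
        return 0;
    }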

View File

@ -67,6 +67,7 @@ namespace Qwen {
auto timesteps_proj = ggml_ext_timestep_embedding(ctx->ggml_ctx, timesteps, 256, 10000, 1.f); auto timesteps_proj = ggml_ext_timestep_embedding(ctx->ggml_ctx, timesteps, 256, 10000, 1.f);
auto timesteps_emb = timestep_embedder->forward(ctx, timesteps_proj); auto timesteps_emb = timestep_embedder->forward(ctx, timesteps_proj);
if (use_additional_t_cond) { if (use_additional_t_cond) {
GGML_ASSERT(addition_t_cond != nullptr);
auto addition_t_embedding = std::dynamic_pointer_cast<Embedding>(blocks["addition_t_embedding"]); auto addition_t_embedding = std::dynamic_pointer_cast<Embedding>(blocks["addition_t_embedding"]);
auto addition_t_emb = addition_t_embedding->forward(ctx, addition_t_cond); auto addition_t_emb = addition_t_embedding->forward(ctx, addition_t_cond);
@ -382,10 +383,11 @@ namespace Qwen {
struct ggml_tensor* patchify(struct ggml_context* ctx, struct ggml_tensor* patchify(struct ggml_context* ctx,
struct ggml_tensor* x) { struct ggml_tensor* x) {
// x: [N, C, H, W] // x: [N*C, T, H, W]
// return: [N, h*w, C * patch_size * patch_size] // return: [N, T*h*w, C * patch_size * patch_size]
int64_t N = x->ne[3]; int64_t N = 1;
int64_t C = x->ne[2]; int64_t C = x->ne[3] / N;
int64_t T = x->ne[2];
int64_t H = x->ne[1]; int64_t H = x->ne[1];
int64_t W = x->ne[0]; int64_t W = x->ne[0];
int64_t p = params.patch_size; int64_t p = params.patch_size;
@ -394,27 +396,31 @@ namespace Qwen {
GGML_ASSERT(h * p == H && w * p == W); GGML_ASSERT(h * p == H && w * p == W);
x = ggml_reshape_4d(ctx, x, p, w, p, h * C * N); // [N*C*h, p, w, p] x = ggml_reshape_4d(ctx, x, p, w, p, h * T * C * N); // [N*C*T*h, p, w, p]
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, w, p, p] x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*T*h, w, p, p]
x = ggml_reshape_4d(ctx, x, p * p, w * h, C, N); // [N, C, h*w, p*p] x = ggml_reshape_4d(ctx, x, p * p, w * h * T, C, N); // [N, C, T*h*w, p*p]
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, h*w, C, p*p] x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, T*h*w, C, p*p]
x = ggml_reshape_3d(ctx, x, p * p * C, w * h, N); // [N, h*w, C*p*p] x = ggml_reshape_3d(ctx, x, p * p * C, w * h * T, N); // [N, T*h*w, C*p*p]
return x; return x;
} }
struct ggml_tensor* process_img(struct ggml_context* ctx, struct ggml_tensor* process_img(struct ggml_context* ctx,
struct ggml_tensor* x) { struct ggml_tensor* x) {
x = pad_to_patch_size(ctx, x); x = pad_to_patch_size(ctx, x);
if (x->ne[3] == 1) { // [N, C, H, W] => [N*C, 1, H, W]
x = ggml_reshape_4d(ctx, x, x->ne[0], x->ne[1], 1, x->ne[2]);
}
x = patchify(ctx, x); x = patchify(ctx, x);
return x; return x;
} }
struct ggml_tensor* unpatchify(struct ggml_context* ctx, struct ggml_tensor* unpatchify(struct ggml_context* ctx,
struct ggml_tensor* x, struct ggml_tensor* x,
int64_t T,
int64_t h, int64_t h,
int64_t w) { int64_t w) {
// x: [N, h*w, C*patch_size*patch_size] // x: [N, T*h*w, C*patch_size*patch_size]
// return: [N, C, H, W] // return: [N*C, T, H, W]
int64_t N = x->ne[2]; int64_t N = x->ne[2];
int64_t C = x->ne[0] / params.patch_size / params.patch_size; int64_t C = x->ne[0] / params.patch_size / params.patch_size;
int64_t H = h * params.patch_size; int64_t H = h * params.patch_size;
@ -423,11 +429,11 @@ namespace Qwen {
GGML_ASSERT(C * p * p == x->ne[0]); GGML_ASSERT(C * p * p == x->ne[0]);
x = ggml_reshape_4d(ctx, x, p * p, C, w * h, N); // [N, h*w, C, p*p] x = ggml_reshape_4d(ctx, x, p * p, C, w * h * T, N); // [N, T*h*w, C, p*p]
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, C, h*w, p*p] x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, C, T*h*w, p*p]
x = ggml_reshape_4d(ctx, x, p, p, w, h * C * N); // [N*C*h, w, p, p] x = ggml_reshape_4d(ctx, x, p, p, w, h * T * C * N); // [N*C*T*h, w, p, p]
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, p, w, p] x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*T*h, p, w, p]
x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, h*p, w*p] x = ggml_reshape_4d(ctx, x, W, H, T, C * N); // [N*C, T, h*p, w*p]
return x; return x;
} }
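The patchify/unpatchify changes above thread an explicit frame count T through the shape math, so a layered latent [N*C, T, H, W] flattens to T*h*w tokens instead of h*w. A pure-arithmetic sketch of the bookkeeping (example sizes are hypothetical):

    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t C = 16, T = 5, H = 128, W = 128, p = 2;  // e.g. 4 layers + composite
        assert(H % p == 0 && W % p == 0);      // guaranteed upstream by pad_to_patch_size()
        int64_t h = H / p, w = W / p;
        int64_t tokens    = T * h * w;         // 5 * 64 * 64 = 20480 image tokens
        int64_t token_dim = C * p * p;         // 16 * 2 * 2  = 64 values per token
        printf("tokens=%lld token_dim=%lld\n", (long long)tokens, (long long)token_dim);
        // unpatchify(ctx, x, T, h, w) inverts this back to an [N*C, T, H, W] layout
        return 0;
    }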
@ -435,6 +441,7 @@ namespace Qwen {
struct ggml_tensor* forward_orig(GGMLRunnerContext* ctx, struct ggml_tensor* forward_orig(GGMLRunnerContext* ctx,
struct ggml_tensor* x, struct ggml_tensor* x,
struct ggml_tensor* timestep, struct ggml_tensor* timestep,
struct ggml_tensor* addition_t_cond,
struct ggml_tensor* context, struct ggml_tensor* context,
struct ggml_tensor* pe) { struct ggml_tensor* pe) {
auto time_text_embed = std::dynamic_pointer_cast<QwenTimestepProjEmbeddings>(blocks["time_text_embed"]); auto time_text_embed = std::dynamic_pointer_cast<QwenTimestepProjEmbeddings>(blocks["time_text_embed"]);
@ -444,7 +451,7 @@ namespace Qwen {
auto norm_out = std::dynamic_pointer_cast<AdaLayerNormContinuous>(blocks["norm_out"]); auto norm_out = std::dynamic_pointer_cast<AdaLayerNormContinuous>(blocks["norm_out"]);
auto proj_out = std::dynamic_pointer_cast<Linear>(blocks["proj_out"]); auto proj_out = std::dynamic_pointer_cast<Linear>(blocks["proj_out"]);
auto t_emb = time_text_embed->forward(ctx, timestep); auto t_emb = time_text_embed->forward(ctx, timestep, addition_t_cond);
auto img = img_in->forward(ctx, x); auto img = img_in->forward(ctx, x);
auto txt = txt_norm->forward(ctx, context); auto txt = txt_norm->forward(ctx, context);
txt = txt_in->forward(ctx, txt); txt = txt_in->forward(ctx, txt);
@ -466,11 +473,12 @@ namespace Qwen {
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x, struct ggml_tensor* x,
struct ggml_tensor* timestep, struct ggml_tensor* timestep,
struct ggml_tensor* addition_t_cond,
struct ggml_tensor* context, struct ggml_tensor* context,
struct ggml_tensor* pe, struct ggml_tensor* pe,
std::vector<ggml_tensor*> ref_latents = {}) { std::vector<ggml_tensor*> ref_latents = {}) {
// Forward pass of DiT. // Forward pass of DiT.
// x: [N, C, H, W] // x: [N, C, H, W] or [N*C, T, H, W]
// timestep: [N,] // timestep: [N,]
// context: [N, L, D] // context: [N, L, D]
// pe: [L, d_head/2, 2, 2] // pe: [L, d_head/2, 2, 2]
@ -478,8 +486,15 @@ namespace Qwen {
int64_t W = x->ne[0]; int64_t W = x->ne[0];
int64_t H = x->ne[1]; int64_t H = x->ne[1];
int64_t C = x->ne[2]; int64_t T = 1;
int64_t N = x->ne[3]; int64_t N = 1;
int64_t C;
if (x->ne[3] == 1) {
C = x->ne[2];
} else {
T = x->ne[2];
C = x->ne[3];
}
auto img = process_img(ctx->ggml_ctx, x); auto img = process_img(ctx->ggml_ctx, x);
uint64_t img_tokens = img->ne[1]; uint64_t img_tokens = img->ne[1];
@ -494,7 +509,7 @@ namespace Qwen {
int64_t h_len = ((H + (params.patch_size / 2)) / params.patch_size); int64_t h_len = ((H + (params.patch_size / 2)) / params.patch_size);
int64_t w_len = ((W + (params.patch_size / 2)) / params.patch_size); int64_t w_len = ((W + (params.patch_size / 2)) / params.patch_size);
auto out = forward_orig(ctx, img, timestep, context, pe); // [N, h_len*w_len, ph*pw*C] auto out = forward_orig(ctx, img, timestep, addition_t_cond, context, pe); // [N, h_len*w_len, ph*pw*C]
if (out->ne[1] > img_tokens) { if (out->ne[1] > img_tokens) {
out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size] out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size]
@ -502,7 +517,7 @@ namespace Qwen {
out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size] out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size]
} }
out = unpatchify(ctx->ggml_ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w] out = unpatchify(ctx->ggml_ctx, out, T, h_len, w_len); // [N, C, H + pad_h, W + pad_w]
// slice // slice
out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H); // [N, C, H, W + pad_w] out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H); // [N, C, H, W + pad_w]
@ -517,6 +532,7 @@ namespace Qwen {
QwenImageParams qwen_image_params; QwenImageParams qwen_image_params;
QwenImageModel qwen_image; QwenImageModel qwen_image;
std::vector<float> pe_vec; std::vector<float> pe_vec;
std::vector<int> additional_t_cond_vec;
SDVersion version; SDVersion version;
QwenImageRunner(ggml_backend_t backend, QwenImageRunner(ggml_backend_t backend,
@ -524,7 +540,7 @@ namespace Qwen {
const String2TensorStorage& tensor_storage_map = {}, const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "", const std::string prefix = "",
SDVersion version = VERSION_QWEN_IMAGE) SDVersion version = VERSION_QWEN_IMAGE)
: GGMLRunner(backend, offload_params_to_cpu) { : GGMLRunner(backend, offload_params_to_cpu), version(version) {
qwen_image_params.num_layers = 0; qwen_image_params.num_layers = 0;
for (auto pair : tensor_storage_map) { for (auto pair : tensor_storage_map) {
std::string tensor_name = pair.first; std::string tensor_name = pair.first;
@ -563,25 +579,39 @@ namespace Qwen {
struct ggml_tensor* timesteps, struct ggml_tensor* timesteps,
struct ggml_tensor* context, struct ggml_tensor* context,
std::vector<ggml_tensor*> ref_latents = {}, std::vector<ggml_tensor*> ref_latents = {},
bool increase_ref_index = false) { Rope::RefIndexMode ref_index_mode = Rope::RefIndexMode::INCREASE) {
GGML_ASSERT(x->ne[3] == 1); int N = 1;
int T = 1;
if (x->ne[3] != 1) {
T = x->ne[2];
}
struct ggml_cgraph* gf = new_graph_custom(QWEN_IMAGE_GRAPH_SIZE); struct ggml_cgraph* gf = new_graph_custom(QWEN_IMAGE_GRAPH_SIZE);
x = to_backend(x); x = to_backend(x);
context = to_backend(context); context = to_backend(context);
timesteps = to_backend(timesteps); timesteps = to_backend(timesteps);
struct ggml_tensor* addition_t_cond = nullptr;
if (version == VERSION_QWEN_IMAGE_LAYERED) {
additional_t_cond_vec = {0};
addition_t_cond = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_I32, 1);
set_backend_tensor_data(addition_t_cond, additional_t_cond_vec.data());
ref_index_mode = Rope::RefIndexMode::DECREASE;
}
for (int i = 0; i < ref_latents.size(); i++) { for (int i = 0; i < ref_latents.size(); i++) {
ref_latents[i] = to_backend(ref_latents[i]); ref_latents[i] = to_backend(ref_latents[i]);
} }
pe_vec = Rope::gen_qwen_image_pe(x->ne[1], pe_vec = Rope::gen_qwen_image_pe(T,
x->ne[1],
x->ne[0], x->ne[0],
qwen_image_params.patch_size, qwen_image_params.patch_size,
x->ne[3], N,
context->ne[1], context->ne[1],
ref_latents, ref_latents,
increase_ref_index, ref_index_mode,
qwen_image_params.theta, qwen_image_params.theta,
qwen_image_params.axes_dim); qwen_image_params.axes_dim);
int pos_len = pe_vec.size() / qwen_image_params.axes_dim_sum / 2; int pos_len = pe_vec.size() / qwen_image_params.axes_dim_sum / 2;
@ -597,6 +627,7 @@ namespace Qwen {
struct ggml_tensor* out = qwen_image.forward(&runner_ctx, struct ggml_tensor* out = qwen_image.forward(&runner_ctx,
x, x,
timesteps, timesteps,
addition_t_cond,
context, context,
pe, pe,
ref_latents); ref_latents);
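Condensing the version-specific branches added to build_graph above: the layered variant feeds a single int32 additional timestep condition (value 0) into the embedder and overrides the reference-index mode to DECREASE, while plain Qwen-Image keeps the INCREASE default. A standalone stand-in (not the project's types):

    enum class RefIndexMode { FIXED, INCREASE, DECREASE };

    struct QwenGraphCond {
        bool         use_additional_t_cond;  // single int32 token, set to 0 in this diff
        RefIndexMode ref_index_mode;
    };

    QwenGraphCond qwen_graph_cond(bool is_layered, RefIndexMode requested = RefIndexMode::INCREASE) {
        if (is_layered) {
            return {true, RefIndexMode::DECREASE};  // layered model overrides the caller
        }
        return {false, requested};
    }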
@ -611,14 +642,14 @@ namespace Qwen {
struct ggml_tensor* timesteps, struct ggml_tensor* timesteps,
struct ggml_tensor* context, struct ggml_tensor* context,
std::vector<ggml_tensor*> ref_latents = {}, std::vector<ggml_tensor*> ref_latents = {},
bool increase_ref_index = false, Rope::RefIndexMode ref_index_mode = Rope::RefIndexMode::INCREASE,
struct ggml_tensor** output = nullptr, struct ggml_tensor** output = nullptr,
struct ggml_context* output_ctx = nullptr) { struct ggml_context* output_ctx = nullptr) {
// x: [N, in_channels, h, w] // x: [N, C, H, W] or [N*C, T, H, W]
// timesteps: [N, ] // timesteps: [N, ]
// context: [N, max_position, hidden_size] // context: [N, max_position, hidden_size]
auto get_graph = [&]() -> struct ggml_cgraph* { auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(x, timesteps, context, ref_latents, increase_ref_index); return build_graph(x, timesteps, context, ref_latents, ref_index_mode);
}; };
return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
@ -650,7 +681,7 @@ namespace Qwen {
struct ggml_tensor* out = nullptr; struct ggml_tensor* out = nullptr;
int t0 = ggml_time_ms(); int t0 = ggml_time_ms();
compute(8, x, timesteps, context, {}, false, &out, work_ctx); compute(8, x, timesteps, context, {}, Rope::RefIndexMode::FIXED, &out, work_ctx);
int t1 = ggml_time_ms(); int t1 = ggml_time_ms();
print_ggml_tensor(out); print_ggml_tensor(out);

rope.hpp

@ -5,6 +5,12 @@
#include "ggml_extend.hpp" #include "ggml_extend.hpp"
namespace Rope { namespace Rope {
enum class RefIndexMode {
FIXED,
INCREASE,
DECREASE,
};
template <class T> template <class T>
__STATIC_INLINE__ std::vector<T> linspace(T start, T end, int num) { __STATIC_INLINE__ std::vector<T> linspace(T start, T end, int num) {
std::vector<T> result(num); std::vector<T> result(num);
@ -170,21 +176,26 @@ namespace Rope {
int bs, int bs,
int axes_dim_num, int axes_dim_num,
const std::vector<ggml_tensor*>& ref_latents, const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index, RefIndexMode ref_index_mode,
float ref_index_scale) { float ref_index_scale) {
int index = 0;
std::vector<std::vector<float>> ids; std::vector<std::vector<float>> ids;
uint64_t curr_h_offset = 0; uint64_t curr_h_offset = 0;
uint64_t curr_w_offset = 0; uint64_t curr_w_offset = 0;
int index = 1;
for (ggml_tensor* ref : ref_latents) { for (ggml_tensor* ref : ref_latents) {
uint64_t h_offset = 0; uint64_t h_offset = 0;
uint64_t w_offset = 0; uint64_t w_offset = 0;
if (!increase_ref_index) { if (ref_index_mode == RefIndexMode::FIXED) {
index = 1;
if (ref->ne[1] + curr_h_offset > ref->ne[0] + curr_w_offset) { if (ref->ne[1] + curr_h_offset > ref->ne[0] + curr_w_offset) {
w_offset = curr_w_offset; w_offset = curr_w_offset;
} else { } else {
h_offset = curr_h_offset; h_offset = curr_h_offset;
} }
} else if (ref_index_mode == RefIndexMode::INCREASE) {
index++;
} else if (ref_index_mode == RefIndexMode::DECREASE) {
index--;
} }
auto ref_ids = gen_flux_img_ids(ref->ne[1], auto ref_ids = gen_flux_img_ids(ref->ne[1],
@ -197,10 +208,6 @@ namespace Rope {
w_offset); w_offset);
ids = concat_ids(ids, ref_ids, bs); ids = concat_ids(ids, ref_ids, bs);
if (increase_ref_index) {
index++;
}
curr_h_offset = std::max(curr_h_offset, ref->ne[1] + h_offset); curr_h_offset = std::max(curr_h_offset, ref->ne[1] + h_offset);
curr_w_offset = std::max(curr_w_offset, ref->ne[0] + w_offset); curr_w_offset = std::max(curr_w_offset, ref->ne[0] + w_offset);
} }
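As rewritten above, the per-reference index now starts at 1 and is updated before the ids are generated. A standalone sketch of the resulting index sequences for three reference latents, as the new control flow reads (FIXED additionally separates references by spatial offsets rather than by index):

    #include <cstdio>
    #include <initializer_list>
    #include <vector>

    enum class RefIndexMode { FIXED, INCREASE, DECREASE };

    std::vector<int> ref_indices(RefIndexMode mode, int n_refs) {
        std::vector<int> out;
        int index = 1;
        for (int i = 0; i < n_refs; i++) {
            if (mode == RefIndexMode::FIXED)         index = 1;   // all refs share index 1
            else if (mode == RefIndexMode::INCREASE) index++;     // 2, 3, 4, ...
            else                                     index--;     // 0, -1, -2, ...
            out.push_back(index);
        }
        return out;
    }

    int main() {
        for (auto mode : {RefIndexMode::FIXED, RefIndexMode::INCREASE, RefIndexMode::DECREASE}) {
            for (int v : ref_indices(mode, 3)) printf("%d ", v);
            printf("\n");   // prints: "1 1 1", "2 3 4", "0 -1 -2"
        }
        return 0;
    }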
@ -215,14 +222,14 @@ namespace Rope {
int context_len, int context_len,
std::set<int> txt_arange_dims, std::set<int> txt_arange_dims,
const std::vector<ggml_tensor*>& ref_latents, const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index, RefIndexMode ref_index_mode,
float ref_index_scale) { float ref_index_scale) {
auto txt_ids = gen_flux_txt_ids(bs, context_len, axes_dim_num, txt_arange_dims); auto txt_ids = gen_flux_txt_ids(bs, context_len, axes_dim_num, txt_arange_dims);
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num); auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num);
auto ids = concat_ids(txt_ids, img_ids, bs); auto ids = concat_ids(txt_ids, img_ids, bs);
if (ref_latents.size() > 0) { if (ref_latents.size() > 0) {
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, ref_index_scale); auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, ref_index_mode, ref_index_scale);
ids = concat_ids(ids, refs_ids, bs); ids = concat_ids(ids, refs_ids, bs);
} }
return ids; return ids;
@ -236,7 +243,7 @@ namespace Rope {
int context_len, int context_len,
std::set<int> txt_arange_dims, std::set<int> txt_arange_dims,
const std::vector<ggml_tensor*>& ref_latents, const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index, RefIndexMode ref_index_mode,
float ref_index_scale, float ref_index_scale,
int theta, int theta,
const std::vector<int>& axes_dim) { const std::vector<int>& axes_dim) {
@ -248,52 +255,11 @@ namespace Rope {
context_len, context_len,
txt_arange_dims, txt_arange_dims,
ref_latents, ref_latents,
increase_ref_index, ref_index_mode,
ref_index_scale); ref_index_scale);
return embed_nd(ids, bs, theta, axes_dim); return embed_nd(ids, bs, theta, axes_dim);
} }
__STATIC_INLINE__ std::vector<std::vector<float>> gen_qwen_image_ids(int h,
int w,
int patch_size,
int bs,
int context_len,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index) {
int h_len = (h + (patch_size / 2)) / patch_size;
int w_len = (w + (patch_size / 2)) / patch_size;
int txt_id_start = std::max(h_len, w_len);
auto txt_ids = linspace<float>(txt_id_start, context_len + txt_id_start, context_len);
std::vector<std::vector<float>> txt_ids_repeated(bs * context_len, std::vector<float>(3));
for (int i = 0; i < bs; ++i) {
for (int j = 0; j < txt_ids.size(); ++j) {
txt_ids_repeated[i * txt_ids.size() + j] = {txt_ids[j], txt_ids[j], txt_ids[j]};
}
}
int axes_dim_num = 3;
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num);
auto ids = concat_ids(txt_ids_repeated, img_ids, bs);
if (ref_latents.size() > 0) {
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, 1.f);
ids = concat_ids(ids, refs_ids, bs);
}
return ids;
}
// Generate qwen_image positional embeddings
__STATIC_INLINE__ std::vector<float> gen_qwen_image_pe(int h,
int w,
int patch_size,
int bs,
int context_len,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index,
int theta,
const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_qwen_image_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
return embed_nd(ids, bs, theta, axes_dim);
}
__STATIC_INLINE__ std::vector<std::vector<float>> gen_vid_ids(int t, __STATIC_INLINE__ std::vector<std::vector<float>> gen_vid_ids(int t,
int h, int h,
int w, int w,
@ -334,6 +300,49 @@ namespace Rope {
return vid_ids_repeated; return vid_ids_repeated;
} }
__STATIC_INLINE__ std::vector<std::vector<float>> gen_qwen_image_ids(int t,
int h,
int w,
int patch_size,
int bs,
int context_len,
const std::vector<ggml_tensor*>& ref_latents,
RefIndexMode ref_index_mode) {
int h_len = (h + (patch_size / 2)) / patch_size;
int w_len = (w + (patch_size / 2)) / patch_size;
int txt_id_start = std::max(h_len, w_len);
auto txt_ids = linspace<float>(txt_id_start, context_len + txt_id_start, context_len);
std::vector<std::vector<float>> txt_ids_repeated(bs * context_len, std::vector<float>(3));
for (int i = 0; i < bs; ++i) {
for (int j = 0; j < txt_ids.size(); ++j) {
txt_ids_repeated[i * txt_ids.size() + j] = {txt_ids[j], txt_ids[j], txt_ids[j]};
}
}
int axes_dim_num = 3;
auto img_ids = gen_vid_ids(t, h, w, 1, patch_size, patch_size, bs);
auto ids = concat_ids(txt_ids_repeated, img_ids, bs);
if (ref_latents.size() > 0) {
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, ref_index_mode, 1.f);
ids = concat_ids(ids, refs_ids, bs);
}
return ids;
}
// Generate qwen_image positional embeddings
__STATIC_INLINE__ std::vector<float> gen_qwen_image_pe(int t,
int h,
int w,
int patch_size,
int bs,
int context_len,
const std::vector<ggml_tensor*>& ref_latents,
RefIndexMode ref_index_mode,
int theta,
const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_qwen_image_ids(t, h, w, patch_size, bs, context_len, ref_latents, ref_index_mode);
return embed_nd(ids, bs, theta, axes_dim);
}
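The relocated gen_qwen_image_ids now builds the image ids with gen_vid_ids over T frames (temporal patch 1) and starts the text ids after the image grid. A small arithmetic sketch of the id layout (example sizes are hypothetical):

    #include <algorithm>
    #include <cstdio>

    int main() {
        int t = 5, h = 128, w = 128, patch_size = 2, context_len = 300;
        int h_len = (h + patch_size / 2) / patch_size;   // 64
        int w_len = (w + patch_size / 2) / patch_size;   // 64
        int txt_id_start = std::max(h_len, w_len);       // text ids begin at 64, after the image grid
        int n_txt_ids = context_len;                     // one repeated (id,id,id) triple per text token
        int n_img_ids = t * h_len * w_len;               // 5 * 64 * 64 = 20480 frame-major image ids
        printf("txt_id_start=%d txt_ids=%d img_ids=%d total=%d\n",
               txt_id_start, n_txt_ids, n_img_ids, n_txt_ids + n_img_ids);
        return 0;
    }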
// Generate wan positional embeddings // Generate wan positional embeddings
__STATIC_INLINE__ std::vector<float> gen_wan_pe(int t, __STATIC_INLINE__ std::vector<float> gen_wan_pe(int t,
int h, int h,
@ -395,7 +404,7 @@ namespace Rope {
int context_len, int context_len,
int seq_multi_of, int seq_multi_of,
const std::vector<ggml_tensor*>& ref_latents, const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index) { RefIndexMode ref_index_mode) {
int padded_context_len = context_len + bound_mod(context_len, seq_multi_of); int padded_context_len = context_len + bound_mod(context_len, seq_multi_of);
auto txt_ids = std::vector<std::vector<float>>(bs * padded_context_len, std::vector<float>(3, 0.0f)); auto txt_ids = std::vector<std::vector<float>>(bs * padded_context_len, std::vector<float>(3, 0.0f));
for (int i = 0; i < bs * padded_context_len; i++) { for (int i = 0; i < bs * padded_context_len; i++) {
@ -426,10 +435,10 @@ namespace Rope {
int context_len, int context_len,
int seq_multi_of, int seq_multi_of,
const std::vector<ggml_tensor*>& ref_latents, const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index, RefIndexMode ref_index_mode,
int theta, int theta,
const std::vector<int>& axes_dim) { const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_z_image_ids(h, w, patch_size, bs, context_len, seq_multi_of, ref_latents, increase_ref_index); std::vector<std::vector<float>> ids = gen_z_image_ids(h, w, patch_size, bs, context_len, seq_multi_of, ref_latents, ref_index_mode);
return embed_nd(ids, bs, theta, axes_dim); return embed_nd(ids, bs, theta, axes_dim);
} }


@ -1699,7 +1699,7 @@ public:
diffusion_params.timesteps = timesteps; diffusion_params.timesteps = timesteps;
diffusion_params.guidance = guidance_tensor; diffusion_params.guidance = guidance_tensor;
diffusion_params.ref_latents = ref_latents; diffusion_params.ref_latents = ref_latents;
diffusion_params.increase_ref_index = increase_ref_index; diffusion_params.ref_index_mode = increase_ref_index ? Rope::RefIndexMode::INCREASE : Rope::RefIndexMode::FIXED;
diffusion_params.controls = controls; diffusion_params.controls = controls;
diffusion_params.control_strength = control_strength; diffusion_params.control_strength = control_strength;
diffusion_params.vace_context = vace_context; diffusion_params.vace_context = vace_context;
@ -1940,6 +1940,28 @@ public:
return latent_channel; return latent_channel;
} }
int get_image_channels() {
int image_channel = 3;
if (version == VERSION_QWEN_IMAGE_LAYERED) {
image_channel = 4;
}
return image_channel;
}
void ensure_image_channels(sd_image_f32_t* image) {
if (image->channel == get_image_channels()) {
return;
}
if (get_image_channels() == 4) {
sd_image_f32_t new_image = sd_image_to_rgba(*image);
free(image->data);
image->data = new_image.data;
image->channel = new_image.channel;
return;
}
GGML_ABORT("invalid image channels");
}
int get_image_seq_len(int h, int w) { int get_image_seq_len(int h, int w) {
int vae_scale_factor = get_vae_scale_factor(); int vae_scale_factor = get_vae_scale_factor();
return (h / vae_scale_factor) * (w / vae_scale_factor); return (h / vae_scale_factor) * (w / vae_scale_factor);
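A toy stand-in for the ownership pattern behind the ensure_image_channels() helper added above (simplified types, not the project's sd_image_f32_t): the helper swaps the pixel buffer in place and frees the old one, so call sites such as the reference-image path later in this file keep owning exactly one buffer.

    #include <cstddef>
    #include <cstdlib>

    struct ImageF32 { int width, height, channel; float* data; };

    // analogous to sd_image_to_rgba() for the common RGB case (the real converter is added in util.cpp below)
    ImageF32 rgb_to_rgba_f32(const ImageF32& img) {
        ImageF32 out{img.width, img.height, 4, nullptr};
        size_t px = (size_t)img.width * img.height;
        out.data = (float*)malloc(px * 4 * sizeof(float));
        for (size_t i = 0; i < px; i++) {
            out.data[i * 4 + 0] = img.data[i * 3 + 0];
            out.data[i * 4 + 1] = img.data[i * 3 + 1];
            out.data[i * 4 + 2] = img.data[i * 3 + 2];
            out.data[i * 4 + 3] = 1.0f;   // opaque alpha
        }
        return out;
    }

    // analogous to ensure_image_channels(): in-place buffer swap
    void ensure_channels_rgba(ImageF32* img) {
        if (img->channel == 4) return;
        ImageF32 rgba = rgb_to_rgba_f32(*img);
        free(img->data);         // drop the old 3-channel buffer
        img->data    = rgba.data; // adopt the new RGBA buffer
        img->channel = 4;
    }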
@ -2265,7 +2287,7 @@ public:
const int vae_scale_factor = get_vae_scale_factor(); const int vae_scale_factor = get_vae_scale_factor();
int64_t W = x->ne[0] * vae_scale_factor; int64_t W = x->ne[0] * vae_scale_factor;
int64_t H = x->ne[1] * vae_scale_factor; int64_t H = x->ne[1] * vae_scale_factor;
int64_t C = 3; int64_t C = get_image_channels();
ggml_tensor* result = nullptr; ggml_tensor* result = nullptr;
if (decode_video) { if (decode_video) {
int T = x->ne[2]; int T = x->ne[2];
@ -3066,7 +3088,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
sd_ctx->sd->rng->manual_seed(cur_seed); sd_ctx->sd->rng->manual_seed(cur_seed);
sd_ctx->sd->sampler_rng->manual_seed(cur_seed); sd_ctx->sd->sampler_rng->manual_seed(cur_seed);
struct ggml_tensor* x_t = init_latent; struct ggml_tensor* x_t = init_latent;
struct ggml_tensor* noise = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1); struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x_t);
ggml_ext_im_set_randn_f32(noise, sd_ctx->sd->rng); ggml_ext_im_set_randn_f32(noise, sd_ctx->sd->rng);
int start_merge_step = -1; int start_merge_step = -1;
@ -3121,11 +3143,25 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
std::vector<struct ggml_tensor*> decoded_images; // collect decoded images std::vector<struct ggml_tensor*> decoded_images; // collect decoded images
for (size_t i = 0; i < final_latents.size(); i++) { for (size_t i = 0; i < final_latents.size(); i++) {
t1 = ggml_time_ms(); t1 = ggml_time_ms();
struct ggml_tensor* img = sd_ctx->sd->decode_first_stage(work_ctx, final_latents[i] /* x_0 */); if (sd_ctx->sd->version == VERSION_QWEN_IMAGE_LAYERED) {
// print_ggml_tensor(img); int layers = 4;
for (int layer_index = 0; layer_index < layers; layer_index++) {
ggml_tensor* final_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, final_latents[i]->ne[0], final_latents[i]->ne[1], final_latents[i]->ne[3], 1);
ggml_ext_tensor_iter(final_latent, [&](ggml_tensor* final_latent, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = ggml_ext_tensor_get_f32(final_latents[i], i0, i1, layer_index + 1, i2);
ggml_ext_tensor_set_f32(final_latent, value, i0, i1, i2, i3);
});
struct ggml_tensor* img = sd_ctx->sd->decode_first_stage(work_ctx, final_latent);
if (img != nullptr) { if (img != nullptr) {
decoded_images.push_back(img); decoded_images.push_back(img);
} }
}
} else {
struct ggml_tensor* img = sd_ctx->sd->decode_first_stage(work_ctx, final_latents[i] /* x_0 */);
if (img != nullptr) {
decoded_images.push_back(img);
}
}
int64_t t2 = ggml_time_ms(); int64_t t2 = ggml_time_ms();
LOG_INFO("latent %" PRId64 " decoded, taking %.2fs", i + 1, (t2 - t1) * 1.0f / 1000); LOG_INFO("latent %" PRId64 " decoded, taking %.2fs", i + 1, (t2 - t1) * 1.0f / 1000);
} }
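Reading the layered-decode branch above in plain array terms (a hypothetical re-derivation, since the exact tensor layout is only implied here): the final latent is treated as [W, H, layers+1, C] in ggml's ne order, plane 0 along the third dimension is skipped by the loop (presumably the composite), and each remaining plane k+1 is copied out as its own [W, H, C, 1] latent before VAE decode.

    #include <cstddef>
    #include <vector>

    // Extract layer `k` (0-based) from a latent stored as [W, H, planes, C] (ggml ne order),
    // producing a [W, H, C] buffer; mirrors the ggml_ext_tensor_iter copy above.
    std::vector<float> extract_layer(const std::vector<float>& latent,
                                     size_t W, size_t H, size_t planes, size_t C, size_t k) {
        std::vector<float> out(W * H * C);
        for (size_t c = 0; c < C; c++)
            for (size_t y = 0; y < H; y++)
                for (size_t x = 0; x < W; x++) {
                    size_t src = x + W * (y + H * ((k + 1) + planes * c));  // read plane k+1
                    size_t dst = x + W * (y + H * c);
                    out[dst] = latent[src];
                }
        return out;
    }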
@ -3138,7 +3174,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
sd_ctx->sd->lora_stat(); sd_ctx->sd->lora_stat();
sd_image_t* result_images = (sd_image_t*)calloc(batch_count, sizeof(sd_image_t)); sd_image_t* result_images = (sd_image_t*)calloc(decoded_images.size(), sizeof(sd_image_t));
if (result_images == nullptr) { if (result_images == nullptr) {
ggml_free(work_ctx); ggml_free(work_ctx);
return nullptr; return nullptr;
@ -3147,7 +3183,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
for (size_t i = 0; i < decoded_images.size(); i++) { for (size_t i = 0; i < decoded_images.size(); i++) {
result_images[i].width = width; result_images[i].width = width;
result_images[i].height = height; result_images[i].height = height;
result_images[i].channel = 3; result_images[i].channel = sd_ctx->sd->get_image_channels();
result_images[i].data = ggml_tensor_to_sd_image(decoded_images[i]); result_images[i].data = ggml_tensor_to_sd_image(decoded_images[i]);
} }
ggml_free(work_ctx); ggml_free(work_ctx);
@ -3159,6 +3195,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
sd_ctx->sd->vae_tiling_params = sd_img_gen_params->vae_tiling_params; sd_ctx->sd->vae_tiling_params = sd_img_gen_params->vae_tiling_params;
int width = sd_img_gen_params->width; int width = sd_img_gen_params->width;
int height = sd_img_gen_params->height; int height = sd_img_gen_params->height;
int image_channels = sd_ctx->sd->get_image_channels();
int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor(); int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor();
int diffusion_model_down_factor = sd_ctx->sd->get_diffusion_model_down_factor(); int diffusion_model_down_factor = sd_ctx->sd->get_diffusion_model_down_factor();
@ -3238,9 +3275,14 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
sigma_sched.assign(sigmas.begin() + sample_steps - t_enc - 1, sigmas.end()); sigma_sched.assign(sigmas.begin() + sample_steps - t_enc - 1, sigmas.end());
sigmas = sigma_sched; sigmas = sigma_sched;
ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1); ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, image_channels, 1);
ggml_tensor* mask_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 1, 1); ggml_tensor* mask_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 1, 1);
sd_image_t init_image = sd_img_gen_params->init_image;
if (image_channels != init_image.channel && image_channels == 4) {
init_image = sd_image_to_rgba(init_image);
}
sd_image_to_ggml_tensor(sd_img_gen_params->mask_image, mask_img); sd_image_to_ggml_tensor(sd_img_gen_params->mask_image, mask_img);
sd_image_to_ggml_tensor(sd_img_gen_params->init_image, init_img); sd_image_to_ggml_tensor(sd_img_gen_params->init_image, init_img);
@ -3333,8 +3375,13 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
if (sd_version_is_inpaint(sd_ctx->sd->version)) { if (sd_version_is_inpaint(sd_ctx->sd->version)) {
LOG_WARN("This is an inpainting model, this should only be used in img2img mode with a mask"); LOG_WARN("This is an inpainting model, this should only be used in img2img mode with a mask");
} }
if (sd_ctx->sd->version == VERSION_QWEN_IMAGE_LAYERED) {
int layers = 4;
init_latent = sd_ctx->sd->generate_init_latent(work_ctx, width, height, layers + 1, true);
} else {
init_latent = sd_ctx->sd->generate_init_latent(work_ctx, width, height); init_latent = sd_ctx->sd->generate_init_latent(work_ctx, width, height);
} }
}
sd_guidance_params_t guidance = sd_img_gen_params->sample_params.guidance; sd_guidance_params_t guidance = sd_img_gen_params->sample_params.guidance;
std::vector<sd_image_t*> ref_images; std::vector<sd_image_t*> ref_images;
@ -3343,7 +3390,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
} }
std::vector<uint8_t> empty_image_data; std::vector<uint8_t> empty_image_data;
sd_image_t empty_image = {(uint32_t)width, (uint32_t)height, 3, nullptr}; sd_image_t empty_image = {(uint32_t)width, (uint32_t)height, image_channels, nullptr};
if (ref_images.empty() && sd_version_is_unet_edit(sd_ctx->sd->version)) { if (ref_images.empty() && sd_version_is_unet_edit(sd_ctx->sd->version)) {
LOG_WARN("This model needs at least one reference image; using an empty reference"); LOG_WARN("This model needs at least one reference image; using an empty reference");
empty_image_data.resize(width * height * 3); empty_image_data.resize(width * height * 3);
@ -3380,11 +3427,13 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
LOG_DEBUG("resize vae ref image %d from %dx%d to %dx%d", i, ref_image.height, ref_image.width, resized_image.height, resized_image.width); LOG_DEBUG("resize vae ref image %d from %dx%d to %dx%d", i, ref_image.height, ref_image.width, resized_image.height, resized_image.width);
sd_ctx->sd->ensure_image_channels(&resized_image);
img = ggml_new_tensor_4d(work_ctx, img = ggml_new_tensor_4d(work_ctx,
GGML_TYPE_F32, GGML_TYPE_F32,
resized_image.width, resized_image.width,
resized_image.height, resized_image.height,
3, resized_image.channel,
1); 1);
sd_image_f32_to_ggml_tensor(resized_image, img); sd_image_f32_to_ggml_tensor(resized_image, img);
free(resized_image.data); free(resized_image.data);
@ -3399,8 +3448,6 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
sd_image_to_ggml_tensor(*ref_images[i], img); sd_image_to_ggml_tensor(*ref_images[i], img);
} }
// print_ggml_tensor(img, false, "img");
ggml_tensor* latent = sd_ctx->sd->encode_first_stage(work_ctx, img); ggml_tensor* latent = sd_ctx->sd->encode_first_stage(work_ctx, img);
ref_latents.push_back(latent); ref_latents.push_back(latent);
} }


@ -378,6 +378,72 @@ sd_image_f32_t sd_image_t_to_sd_image_f32_t(sd_image_t image) {
return converted_image; return converted_image;
} }
sd_image_f32_t sd_image_to_rgba(sd_image_f32_t image) {
sd_image_f32_t rgba_image;
rgba_image.width = image.width;
rgba_image.height = image.height;
rgba_image.channel = 4;
size_t total_pixels = (size_t)image.width * image.height;
rgba_image.data = (float*)malloc(total_pixels * 4 * sizeof(float));
for (size_t i = 0; i < total_pixels; i++) {
if (image.channel == 3) {
// RGB -> RGBA
rgba_image.data[i * 4 + 0] = image.data[i * 3 + 0]; // R
rgba_image.data[i * 4 + 1] = image.data[i * 3 + 1]; // G
rgba_image.data[i * 4 + 2] = image.data[i * 3 + 2]; // B
rgba_image.data[i * 4 + 3] = 1.0f; // A (fully opaque)
} else if (image.channel == 1) {
// Gray -> RGBA
float gray = image.data[i];
rgba_image.data[i * 4 + 0] = gray; // R
rgba_image.data[i * 4 + 1] = gray; // G
rgba_image.data[i * 4 + 2] = gray; // B
rgba_image.data[i * 4 + 3] = 1.0f; // A (fully opaque)
} else if (image.channel == 4) {
// Already RGBA
memcpy(rgba_image.data, image.data, total_pixels * 4 * sizeof(float));
break;
}
}
return rgba_image;
}
sd_image_t sd_image_to_rgba(sd_image_t image) {
sd_image_t rgba_image;
rgba_image.width = image.width;
rgba_image.height = image.height;
rgba_image.channel = 4;
size_t total_pixels = (size_t)image.width * image.height;
rgba_image.data = (uint8_t*)malloc(total_pixels * 4 * sizeof(uint8_t));
for (size_t i = 0; i < total_pixels; i++) {
if (image.channel == 3) {
// RGB -> RGBA
rgba_image.data[i * 4 + 0] = image.data[i * 3 + 0]; // R
rgba_image.data[i * 4 + 1] = image.data[i * 3 + 1]; // G
rgba_image.data[i * 4 + 2] = image.data[i * 3 + 2]; // B
rgba_image.data[i * 4 + 3] = 255; // A (fully opaque)
} else if (image.channel == 1) {
// Gray -> RGBA
float gray = image.data[i];
rgba_image.data[i * 4 + 0] = gray; // R
rgba_image.data[i * 4 + 1] = gray; // G
rgba_image.data[i * 4 + 2] = gray; // B
rgba_image.data[i * 4 + 3] = 255; // A (fully opaque)
} else if (image.channel == 4) {
// Already RGBA
memcpy(rgba_image.data, image.data, total_pixels * 4 * sizeof(uint8_t));
break;
}
}
return rgba_image;
}
// Function to perform double linear interpolation // Function to perform double linear interpolation
float interpolate(float v1, float v2, float v3, float v4, float x_ratio, float y_ratio) { float interpolate(float v1, float v2, float v3, float v4, float x_ratio, float y_ratio) {
return v1 * (1 - x_ratio) * (1 - y_ratio) + v2 * x_ratio * (1 - y_ratio) + v3 * (1 - x_ratio) * y_ratio + v4 * x_ratio * y_ratio; return v1 * (1 - x_ratio) * (1 - y_ratio) + v2 * x_ratio * (1 - y_ratio) + v3 * (1 - x_ratio) * y_ratio + v4 * x_ratio * y_ratio;

util.h

@ -39,6 +39,9 @@ void normalize_sd_image_f32_t(sd_image_f32_t image, float means[3], float stds[3
sd_image_f32_t sd_image_t_to_sd_image_f32_t(sd_image_t image); sd_image_f32_t sd_image_t_to_sd_image_f32_t(sd_image_t image);
sd_image_f32_t sd_image_to_rgba(sd_image_f32_t image);
sd_image_t sd_image_to_rgba(sd_image_t image);
sd_image_f32_t resize_sd_image_f32_t(sd_image_f32_t image, int target_width, int target_height); sd_image_f32_t resize_sd_image_f32_t(sd_image_f32_t image, int target_width, int target_height);
sd_image_f32_t clip_preprocess(sd_image_f32_t image, int target_width, int target_height); sd_image_f32_t clip_preprocess(sd_image_f32_t image, int target_width, int target_height);

wan.hpp

@ -590,7 +590,7 @@ namespace WAN {
class Encoder3d : public GGMLBlock { class Encoder3d : public GGMLBlock {
protected: protected:
bool wan2_2; bool use_down_res_block;
int64_t dim; int64_t dim;
int64_t z_dim; int64_t z_dim;
std::vector<int> dim_mult; std::vector<int> dim_mult;
@ -600,27 +600,24 @@ namespace WAN {
public: public:
Encoder3d(int64_t dim = 128, Encoder3d(int64_t dim = 128,
int64_t z_dim = 4, int64_t z_dim = 4,
int64_t in_channels = 3,
std::vector<int> dim_mult = {1, 2, 4, 4}, std::vector<int> dim_mult = {1, 2, 4, 4},
int num_res_blocks = 2, int num_res_blocks = 2,
std::vector<bool> temperal_downsample = {false, true, true}, std::vector<bool> temperal_downsample = {false, true, true},
bool wan2_2 = false) bool use_down_res_block = false)
: dim(dim), : dim(dim),
z_dim(z_dim), z_dim(z_dim),
dim_mult(dim_mult), dim_mult(dim_mult),
num_res_blocks(num_res_blocks), num_res_blocks(num_res_blocks),
temperal_downsample(temperal_downsample), temperal_downsample(temperal_downsample),
wan2_2(wan2_2) { use_down_res_block(use_down_res_block) {
// attn_scales is always [] // attn_scales is always []
std::vector<int64_t> dims = {dim}; std::vector<int64_t> dims = {dim};
for (int u : dim_mult) { for (int u : dim_mult) {
dims.push_back(dim * u); dims.push_back(dim * u);
} }
if (wan2_2) { blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(in_channels, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(12, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
} else {
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(3, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
}
int index = 0; int index = 0;
int64_t in_dim; int64_t in_dim;
@ -628,7 +625,7 @@ namespace WAN {
for (int i = 0; i < dims.size() - 1; i++) { for (int i = 0; i < dims.size() - 1; i++) {
in_dim = dims[i]; in_dim = dims[i];
out_dim = dims[i + 1]; out_dim = dims[i + 1];
if (wan2_2) { if (use_down_res_block) {
bool t_down_flag = i < temperal_downsample.size() ? temperal_downsample[i] : false; bool t_down_flag = i < temperal_downsample.size() ? temperal_downsample[i] : false;
auto block = std::shared_ptr<GGMLBlock>(new Down_ResidualBlock(in_dim, auto block = std::shared_ptr<GGMLBlock>(new Down_ResidualBlock(in_dim,
out_dim, out_dim,
@ -702,7 +699,7 @@ namespace WAN {
} }
int index = 0; int index = 0;
for (int i = 0; i < dims.size() - 1; i++) { for (int i = 0; i < dims.size() - 1; i++) {
if (wan2_2) { if (use_down_res_block) {
auto layer = std::dynamic_pointer_cast<Down_ResidualBlock>(blocks["downsamples." + std::to_string(index++)]); auto layer = std::dynamic_pointer_cast<Down_ResidualBlock>(blocks["downsamples." + std::to_string(index++)]);
x = layer->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx); x = layer->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx);
@ -753,7 +750,7 @@ namespace WAN {
class Decoder3d : public GGMLBlock { class Decoder3d : public GGMLBlock {
protected: protected:
bool wan2_2; bool use_up_res_block;
int64_t dim; int64_t dim;
int64_t z_dim; int64_t z_dim;
std::vector<int> dim_mult; std::vector<int> dim_mult;
@ -763,16 +760,17 @@ namespace WAN {
public: public:
Decoder3d(int64_t dim = 128, Decoder3d(int64_t dim = 128,
int64_t z_dim = 4, int64_t z_dim = 4,
int64_t out_channels = 3,
std::vector<int> dim_mult = {1, 2, 4, 4}, std::vector<int> dim_mult = {1, 2, 4, 4},
int num_res_blocks = 2, int num_res_blocks = 2,
std::vector<bool> temperal_upsample = {true, true, false}, std::vector<bool> temperal_upsample = {true, true, false},
bool wan2_2 = false) bool use_up_res_block = false)
: dim(dim), : dim(dim),
z_dim(z_dim), z_dim(z_dim),
dim_mult(dim_mult), dim_mult(dim_mult),
num_res_blocks(num_res_blocks), num_res_blocks(num_res_blocks),
temperal_upsample(temperal_upsample), temperal_upsample(temperal_upsample),
wan2_2(wan2_2) { use_up_res_block(use_up_res_block) {
// attn_scales is always [] // attn_scales is always []
std::vector<int64_t> dims = {dim_mult[dim_mult.size() - 1] * dim}; std::vector<int64_t> dims = {dim_mult[dim_mult.size() - 1] * dim};
for (int i = static_cast<int>(dim_mult.size()) - 1; i >= 0; i--) { for (int i = static_cast<int>(dim_mult.size()) - 1; i >= 0; i--) {
@ -794,7 +792,7 @@ namespace WAN {
for (int i = 0; i < dims.size() - 1; i++) { for (int i = 0; i < dims.size() - 1; i++) {
in_dim = dims[i]; in_dim = dims[i];
out_dim = dims[i + 1]; out_dim = dims[i + 1];
if (wan2_2) { if (use_up_res_block) {
bool t_up_flag = i < temperal_upsample.size() ? temperal_upsample[i] : false; bool t_up_flag = i < temperal_upsample.size() ? temperal_upsample[i] : false;
auto block = std::shared_ptr<GGMLBlock>(new Up_ResidualBlock(in_dim, auto block = std::shared_ptr<GGMLBlock>(new Up_ResidualBlock(in_dim,
out_dim, out_dim,
@ -824,12 +822,7 @@ namespace WAN {
// output blocks // output blocks
blocks["head.0"] = std::shared_ptr<GGMLBlock>(new RMS_norm(out_dim)); blocks["head.0"] = std::shared_ptr<GGMLBlock>(new RMS_norm(out_dim));
// head.1 is nn.SiLU() // head.1 is nn.SiLU()
if (wan2_2) { blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, out_channels, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, 12, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
} else {
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, 3, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
}
} }
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* forward(GGMLRunnerContext* ctx,
@ -878,7 +871,7 @@ namespace WAN {
} }
int index = 0; int index = 0;
for (int i = 0; i < dims.size() - 1; i++) { for (int i = 0; i < dims.size() - 1; i++) {
if (wan2_2) { if (use_up_res_block) {
auto layer = std::dynamic_pointer_cast<Up_ResidualBlock>(blocks["upsamples." + std::to_string(index++)]); auto layer = std::dynamic_pointer_cast<Up_ResidualBlock>(blocks["upsamples." + std::to_string(index++)]);
x = layer->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx); x = layer->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx);
@ -922,10 +915,11 @@ namespace WAN {
} }
}; };
class WanVAE : public GGMLBlock { struct WanVAEParams {
public:
bool wan2_2 = false;
bool decode_only = true; bool decode_only = true;
bool use_up_down_res_block = false;
int patch_size = 1;
int64_t input_channels = 3;
int64_t dim = 96; int64_t dim = 96;
int64_t dec_dim = 96; int64_t dec_dim = 96;
int64_t z_dim = 16; int64_t z_dim = 16;
@ -933,39 +927,49 @@ namespace WAN {
int num_res_blocks = 2; int num_res_blocks = 2;
std::vector<bool> temperal_upsample = {true, true, false}; std::vector<bool> temperal_upsample = {true, true, false};
std::vector<bool> temperal_downsample = {false, true, true}; std::vector<bool> temperal_downsample = {false, true, true};
int _conv_num = 33; int _conv_num = 33;
int _enc_conv_num = 28;
};
class WanVAE : public GGMLBlock {
protected:
WanVAEParams vae_params;
public:
int _conv_idx = 0; int _conv_idx = 0;
std::vector<struct ggml_tensor*> _feat_map; std::vector<struct ggml_tensor*> _feat_map;
int _enc_conv_num = 28;
int _enc_conv_idx = 0; int _enc_conv_idx = 0;
std::vector<struct ggml_tensor*> _enc_feat_map; std::vector<struct ggml_tensor*> _enc_feat_map;
void clear_cache() { void clear_cache() {
_conv_idx = 0; _conv_idx = 0;
_feat_map = std::vector<struct ggml_tensor*>(_conv_num, nullptr); _feat_map = std::vector<struct ggml_tensor*>(vae_params._conv_num, nullptr);
_enc_conv_idx = 0; _enc_conv_idx = 0;
_enc_feat_map = std::vector<struct ggml_tensor*>(_enc_conv_num, nullptr); _enc_feat_map = std::vector<struct ggml_tensor*>(vae_params._enc_conv_num, nullptr);
} }
public: public:
WanVAE(bool decode_only = true, bool wan2_2 = false) explicit WanVAE(const WanVAEParams& vae_params)
: decode_only(decode_only), wan2_2(wan2_2) { : vae_params(vae_params) {
// attn_scales is always [] // attn_scales is always []
if (wan2_2) { if (!vae_params.decode_only) {
dim = 160; blocks["encoder"] = std::shared_ptr<GGMLBlock>(new Encoder3d(vae_params.dim,
dec_dim = 256; vae_params.z_dim * 2,
z_dim = 48; vae_params.input_channels,
vae_params.dim_mult,
_conv_num = 34; vae_params.num_res_blocks,
_enc_conv_num = 26; vae_params.temperal_downsample,
vae_params.use_up_down_res_block));
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(vae_params.z_dim * 2, vae_params.z_dim * 2, {1, 1, 1}));
} }
if (!decode_only) { blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder3d(vae_params.dec_dim,
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, temperal_downsample, wan2_2)); vae_params.z_dim,
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim * 2, z_dim * 2, {1, 1, 1})); vae_params.input_channels,
} vae_params.dim_mult,
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder3d(dec_dim, z_dim, dim_mult, num_res_blocks, temperal_upsample, wan2_2)); vae_params.num_res_blocks,
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim, z_dim, {1, 1, 1})); vae_params.temperal_upsample,
vae_params.use_up_down_res_block));
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(vae_params.z_dim, vae_params.z_dim, {1, 1, 1}));
} }
struct ggml_tensor* patchify(struct ggml_context* ctx, struct ggml_tensor* patchify(struct ggml_context* ctx,
@ -1026,13 +1030,11 @@ namespace WAN {
int64_t b = 1) { int64_t b = 1) {
// x: [b*c, t, h, w] // x: [b*c, t, h, w]
GGML_ASSERT(b == 1); GGML_ASSERT(b == 1);
GGML_ASSERT(decode_only == false); GGML_ASSERT(vae_params.decode_only == false);
clear_cache(); clear_cache();
if (wan2_2) { x = patchify(ctx->ggml_ctx, x, vae_params.patch_size, b);
x = patchify(ctx->ggml_ctx, x, 2, b);
}
auto encoder = std::dynamic_pointer_cast<Encoder3d>(blocks["encoder"]); auto encoder = std::dynamic_pointer_cast<Encoder3d>(blocks["encoder"]);
auto conv1 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]); auto conv1 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]);
@ -1082,9 +1084,9 @@ namespace WAN {
out = ggml_concat(ctx->ggml_ctx, out, out_, 2); out = ggml_concat(ctx->ggml_ctx, out, out_, 2);
} }
} }
if (wan2_2) {
out = unpatchify(ctx->ggml_ctx, out, 2, b); out = unpatchify(ctx->ggml_ctx, out, vae_params.patch_size, b);
}
clear_cache(); clear_cache();
return out; return out;
} }
@ -1103,16 +1105,14 @@ namespace WAN {
auto in = ggml_ext_slice(ctx->ggml_ctx, x, 2, i, i + 1); // [b*c, 1, h, w] auto in = ggml_ext_slice(ctx->ggml_ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
_conv_idx = 0; _conv_idx = 0;
auto out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i); auto out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i);
if (wan2_2) { out = unpatchify(ctx->ggml_ctx, out, vae_params.patch_size, b);
out = unpatchify(ctx->ggml_ctx, out, 2, b);
}
return out; return out;
} }
}; };
struct WanVAERunner : public VAE { struct WanVAERunner : public VAE {
bool decode_only = true; WanVAEParams vae_params;
WanVAE ae; std::unique_ptr<WanVAE> ae;
WanVAERunner(ggml_backend_t backend, WanVAERunner(ggml_backend_t backend,
bool offload_params_to_cpu, bool offload_params_to_cpu,
@ -1120,8 +1120,22 @@ namespace WAN {
const std::string prefix = "", const std::string prefix = "",
bool decode_only = false, bool decode_only = false,
SDVersion version = VERSION_WAN2) SDVersion version = VERSION_WAN2)
: decode_only(decode_only), ae(decode_only, version == VERSION_WAN2_2_TI2V), VAE(backend, offload_params_to_cpu) { : VAE(backend, offload_params_to_cpu) {
ae.init(params_ctx, tensor_storage_map, prefix); vae_params.decode_only = decode_only;
if (version == VERSION_WAN2_2_TI2V) {
vae_params.dim = 160;
vae_params.dec_dim = 256;
vae_params.z_dim = 48;
vae_params.input_channels = 12;
vae_params.patch_size = 2;
vae_params.use_up_down_res_block = true;
vae_params._conv_num = 34;
vae_params._enc_conv_num = 26;
} else if (version == VERSION_QWEN_IMAGE_LAYERED) {
vae_params.input_channels = 4;
}
ae = std::make_unique<WanVAE>(vae_params);
ae->init(params_ctx, tensor_storage_map, prefix);
} }
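Standalone condensation of the per-version WanVAE configuration set up in the constructor above (the struct here is a stand-in; defaults mirror WanVAEParams in this diff):

    #include <cstdint>

    struct WanVAEParamsSketch {
        bool    decode_only           = true;
        bool    use_up_down_res_block = false;
        int     patch_size            = 1;
        int64_t input_channels        = 3;
        int64_t dim = 96, dec_dim = 96, z_dim = 16;
        int     _conv_num = 33, _enc_conv_num = 28;
    };

    WanVAEParamsSketch make_wan_vae_params(bool wan2_2_ti2v, bool qwen_image_layered, bool decode_only) {
        WanVAEParamsSketch p;
        p.decode_only = decode_only;
        if (wan2_2_ti2v) {                 // WAN 2.2 TI2V: 12-channel patched input, larger dims
            p.dim = 160; p.dec_dim = 256; p.z_dim = 48;
            p.input_channels = 12; p.patch_size = 2;
            p.use_up_down_res_block = true;
            p._conv_num = 34; p._enc_conv_num = 26;
        } else if (qwen_image_layered) {   // layered Qwen-Image: stock WAN 2.1 layout, but RGBA input
            p.input_channels = 4;
        }
        return p;
    }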
std::string get_desc() override { std::string get_desc() override {
@ -1129,7 +1143,7 @@ namespace WAN {
} }
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) override { void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) override {
ae.get_param_tensors(tensors, prefix); ae->get_param_tensors(tensors, prefix);
} }
struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) { struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
@ -1139,7 +1153,7 @@ namespace WAN {
auto runner_ctx = get_context(); auto runner_ctx = get_context();
struct ggml_tensor* out = decode_graph ? ae.decode(&runner_ctx, z) : ae.encode(&runner_ctx, z); struct ggml_tensor* out = decode_graph ? ae->decode(&runner_ctx, z) : ae->encode(&runner_ctx, z);
ggml_build_forward_expand(gf, out); ggml_build_forward_expand(gf, out);
@ -1149,21 +1163,21 @@ namespace WAN {
struct ggml_cgraph* build_graph_partial(struct ggml_tensor* z, bool decode_graph, int64_t i) { struct ggml_cgraph* build_graph_partial(struct ggml_tensor* z, bool decode_graph, int64_t i) {
struct ggml_cgraph* gf = new_graph_custom(20480); struct ggml_cgraph* gf = new_graph_custom(20480);
ae.clear_cache(); ae->clear_cache();
for (int64_t feat_idx = 0; feat_idx < ae._feat_map.size(); feat_idx++) { for (int64_t feat_idx = 0; feat_idx < ae->_feat_map.size(); feat_idx++) {
auto feat_cache = get_cache_tensor_by_name("feat_idx:" + std::to_string(feat_idx)); auto feat_cache = get_cache_tensor_by_name("feat_idx:" + std::to_string(feat_idx));
ae._feat_map[feat_idx] = feat_cache; ae->_feat_map[feat_idx] = feat_cache;
} }
z = to_backend(z); z = to_backend(z);
auto runner_ctx = get_context(); auto runner_ctx = get_context();
struct ggml_tensor* out = decode_graph ? ae.decode_partial(&runner_ctx, z, i) : ae.encode(&runner_ctx, z); struct ggml_tensor* out = decode_graph ? ae->decode_partial(&runner_ctx, z, i) : ae->encode(&runner_ctx, z);
for (int64_t feat_idx = 0; feat_idx < ae._feat_map.size(); feat_idx++) { for (int64_t feat_idx = 0; feat_idx < ae->_feat_map.size(); feat_idx++) {
ggml_tensor* feat_cache = ae._feat_map[feat_idx]; ggml_tensor* feat_cache = ae->_feat_map[feat_idx];
if (feat_cache != nullptr) { if (feat_cache != nullptr) {
cache("feat_idx:" + std::to_string(feat_idx), feat_cache); cache("feat_idx:" + std::to_string(feat_idx), feat_cache);
ggml_build_forward_expand(gf, feat_cache); ggml_build_forward_expand(gf, feat_cache);
@ -1186,7 +1200,7 @@ namespace WAN {
}; };
return GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx); return GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
} else { // chunk 1 result is weird } else { // chunk 1 result is weird
ae.clear_cache(); ae->clear_cache();
int64_t t = z->ne[2]; int64_t t = z->ne[2];
int64_t i = 0; int64_t i = 0;
auto get_graph = [&]() -> struct ggml_cgraph* { auto get_graph = [&]() -> struct ggml_cgraph* {
@ -1194,7 +1208,7 @@ namespace WAN {
}; };
struct ggml_tensor* out = nullptr; struct ggml_tensor* out = nullptr;
bool res = GGMLRunner::compute(get_graph, n_threads, true, &out, output_ctx); bool res = GGMLRunner::compute(get_graph, n_threads, true, &out, output_ctx);
ae.clear_cache(); ae->clear_cache();
if (t == 1) { if (t == 1) {
*output = out; *output = out;
return res; return res;
@ -1222,7 +1236,7 @@ namespace WAN {
for (i = 1; i < t; i++) { for (i = 1; i < t; i++) {
res = res || GGMLRunner::compute(get_graph, n_threads, true, &out); res = res || GGMLRunner::compute(get_graph, n_threads, true, &out);
ae.clear_cache(); ae->clear_cache();
copy_to_output(); copy_to_output();
} }
free_cache_ctx_and_buffer(); free_cache_ctx_and_buffer();


@ -531,7 +531,7 @@ namespace ZImage {
struct ggml_tensor* timesteps, struct ggml_tensor* timesteps,
struct ggml_tensor* context, struct ggml_tensor* context,
std::vector<ggml_tensor*> ref_latents = {}, std::vector<ggml_tensor*> ref_latents = {},
bool increase_ref_index = false) { Rope::RefIndexMode ref_index_mode = Rope::RefIndexMode::FIXED) {
GGML_ASSERT(x->ne[3] == 1); GGML_ASSERT(x->ne[3] == 1);
struct ggml_cgraph* gf = new_graph_custom(Z_IMAGE_GRAPH_SIZE); struct ggml_cgraph* gf = new_graph_custom(Z_IMAGE_GRAPH_SIZE);
@ -550,7 +550,7 @@ namespace ZImage {
context->ne[1], context->ne[1],
SEQ_MULTI_OF, SEQ_MULTI_OF,
ref_latents, ref_latents,
increase_ref_index, Rope::RefIndexMode::INCREASE,
z_image_params.theta, z_image_params.theta,
z_image_params.axes_dim); z_image_params.axes_dim);
int pos_len = pe_vec.size() / z_image_params.axes_dim_sum / 2; int pos_len = pe_vec.size() / z_image_params.axes_dim_sum / 2;
@ -579,14 +579,14 @@ namespace ZImage {
struct ggml_tensor* timesteps, struct ggml_tensor* timesteps,
struct ggml_tensor* context, struct ggml_tensor* context,
std::vector<ggml_tensor*> ref_latents = {}, std::vector<ggml_tensor*> ref_latents = {},
bool increase_ref_index = false, Rope::RefIndexMode ref_index_mode = Rope::RefIndexMode::FIXED,
struct ggml_tensor** output = nullptr, struct ggml_tensor** output = nullptr,
struct ggml_context* output_ctx = nullptr) { struct ggml_context* output_ctx = nullptr) {
// x: [N, in_channels, h, w] // x: [N, in_channels, h, w]
// timesteps: [N, ] // timesteps: [N, ]
// context: [N, max_position, hidden_size] // context: [N, max_position, hidden_size]
auto get_graph = [&]() -> struct ggml_cgraph* { auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(x, timesteps, context, ref_latents, increase_ref_index); return build_graph(x, timesteps, context, ref_latents, ref_index_mode);
}; };
return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
@ -618,7 +618,7 @@ namespace ZImage {
struct ggml_tensor* out = nullptr; struct ggml_tensor* out = nullptr;
int t0 = ggml_time_ms(); int t0 = ggml_time_ms();
compute(8, x, timesteps, context, {}, false, &out, work_ctx); compute(8, x, timesteps, context, {}, Rope::RefIndexMode::INCREASE, &out, work_ctx);
int t1 = ggml_time_ms(); int t1 = ggml_time_ms();
print_ggml_tensor(out); print_ggml_tensor(out);