wip pipeline

leejet 2025-12-21 14:23:41 +08:00
parent 58f7b789cb
commit 13a19ee6fd
12 changed files with 371 additions and 200 deletions

View File

@ -16,7 +16,7 @@ struct DiffusionParams {
struct ggml_tensor* y = nullptr;
struct ggml_tensor* guidance = nullptr;
std::vector<ggml_tensor*> ref_latents = {};
bool increase_ref_index = false;
Rope::RefIndexMode ref_index_mode = Rope::RefIndexMode::FIXED;
int num_video_frames = -1;
std::vector<struct ggml_tensor*> controls = {};
float control_strength = 0.f;
@ -222,7 +222,7 @@ struct FluxModel : public DiffusionModel {
diffusion_params.y,
diffusion_params.guidance,
diffusion_params.ref_latents,
diffusion_params.increase_ref_index,
diffusion_params.ref_index_mode,
output,
output_ctx,
diffusion_params.skip_layers);
@ -352,7 +352,7 @@ struct QwenImageModel : public DiffusionModel {
diffusion_params.timesteps,
diffusion_params.context,
diffusion_params.ref_latents,
true, // increase_ref_index
Rope::RefIndexMode::INCREASE,
output,
output_ctx);
}
@ -415,7 +415,7 @@ struct ZImageModel : public DiffusionModel {
diffusion_params.timesteps,
diffusion_params.context,
diffusion_params.ref_latents,
true, // increase_ref_index
Rope::RefIndexMode::INCREASE,
output,
output_ctx);
}

View File

@ -615,6 +615,7 @@ int main(int argc, const char* argv[]) {
results = generate_image(sd_ctx, &img_gen_params);
num_results = gen_params.batch_count;
num_results = 4;
} else if (cli_params.mode == VID_GEN) {
sd_vid_gen_params_t vid_gen_params = {
gen_params.lora_vec.data(),

View File

@ -1388,7 +1388,7 @@ namespace Flux {
struct ggml_tensor* y,
struct ggml_tensor* guidance,
std::vector<ggml_tensor*> ref_latents = {},
bool increase_ref_index = false,
Rope::RefIndexMode ref_index_mode = Rope::RefIndexMode::FIXED,
std::vector<int> skip_layers = {}) {
GGML_ASSERT(x->ne[3] == 1);
struct ggml_cgraph* gf = new_graph_custom(FLUX_GRAPH_SIZE);
@ -1426,7 +1426,7 @@ namespace Flux {
std::set<int> txt_arange_dims;
if (sd_version_is_flux2(version)) {
txt_arange_dims = {3};
increase_ref_index = true;
ref_index_mode = Rope::RefIndexMode::INCREASE;
} else if (version == VERSION_OVIS_IMAGE) {
txt_arange_dims = {1, 2};
}
@ -1438,7 +1438,7 @@ namespace Flux {
context->ne[1],
txt_arange_dims,
ref_latents,
increase_ref_index,
ref_index_mode,
flux_params.ref_index_scale,
flux_params.theta,
flux_params.axes_dim);
@ -1489,7 +1489,7 @@ namespace Flux {
struct ggml_tensor* y,
struct ggml_tensor* guidance,
std::vector<ggml_tensor*> ref_latents = {},
bool increase_ref_index = false,
Rope::RefIndexMode ref_index_mode = Rope::RefIndexMode::FIXED,
struct ggml_tensor** output = nullptr,
struct ggml_context* output_ctx = nullptr,
std::vector<int> skip_layers = std::vector<int>()) {
@ -1499,7 +1499,7 @@ namespace Flux {
// y: [N, adm_in_channels] or [1, adm_in_channels]
// guidance: [N, ]
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(x, timesteps, context, c_concat, y, guidance, ref_latents, increase_ref_index, skip_layers);
return build_graph(x, timesteps, context, c_concat, y, guidance, ref_latents, ref_index_mode, skip_layers);
};
return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
@ -1542,7 +1542,7 @@ namespace Flux {
struct ggml_tensor* out = nullptr;
int t0 = ggml_time_ms();
compute(8, x, timesteps, context, nullptr, y, guidance, {}, false, &out, work_ctx);
compute(8, x, timesteps, context, nullptr, y, guidance, {}, Rope::RefIndexMode::FIXED, &out, work_ctx);
int t1 = ggml_time_ms();
print_ggml_tensor(out);

View File

@ -386,7 +386,7 @@ __STATIC_INLINE__ uint8_t* ggml_tensor_to_sd_image(struct ggml_tensor* input, ui
int64_t width = input->ne[0];
int64_t height = input->ne[1];
int64_t channels = input->ne[2];
GGML_ASSERT(channels == 3 && input->type == GGML_TYPE_F32);
GGML_ASSERT(input->type == GGML_TYPE_F32);
if (image_data == nullptr) {
image_data = (uint8_t*)malloc(width * height * channels);
}
@ -1200,7 +1200,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
bool diag_mask_inf = false,
bool skip_reshape = false,
bool flash_attn = false,
float kv_scale = 1.0f) { // avoid overflow
float kv_scale = 1.0f / 256.f) { // avoid overflow
int64_t L_q;
int64_t L_k;
int64_t C;
@ -2142,7 +2142,7 @@ public:
bool bias = true,
bool force_f32 = false,
bool force_prec_f32 = false,
float scale = 1.f)
float scale = 1.f / 256.f)
: in_features(in_features),
out_features(out_features),
bias(bias),

View File

@ -114,7 +114,7 @@ static inline bool sd_version_is_wan(SDVersion version) {
}
static inline bool sd_version_is_qwen_image(SDVersion version) {
if (version == VERSION_QWEN_IMAGE || VERSION_QWEN_IMAGE_LAYERED) {
if (version == VERSION_QWEN_IMAGE || version == VERSION_QWEN_IMAGE_LAYERED) {
return true;
}
return false;

View File

@ -67,6 +67,7 @@ namespace Qwen {
auto timesteps_proj = ggml_ext_timestep_embedding(ctx->ggml_ctx, timesteps, 256, 10000, 1.f);
auto timesteps_emb = timestep_embedder->forward(ctx, timesteps_proj);
if (use_additional_t_cond) {
GGML_ASSERT(addition_t_cond != nullptr);
auto addition_t_embedding = std::dynamic_pointer_cast<Embedding>(blocks["addition_t_embedding"]);
auto addition_t_emb = addition_t_embedding->forward(ctx, addition_t_cond);
@ -382,10 +383,11 @@ namespace Qwen {
struct ggml_tensor* patchify(struct ggml_context* ctx,
struct ggml_tensor* x) {
// x: [N, C, H, W]
// return: [N, h*w, C * patch_size * patch_size]
int64_t N = x->ne[3];
int64_t C = x->ne[2];
// x: [N*C, T, H, W]
// return: [N, T*h*w, C * patch_size * patch_size]
int64_t N = 1;
int64_t C = x->ne[3] / N;
int64_t T = x->ne[2];
int64_t H = x->ne[1];
int64_t W = x->ne[0];
int64_t p = params.patch_size;
@ -394,27 +396,31 @@ namespace Qwen {
GGML_ASSERT(h * p == H && w * p == W);
x = ggml_reshape_4d(ctx, x, p, w, p, h * C * N); // [N*C*h, p, w, p]
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, w, p, p]
x = ggml_reshape_4d(ctx, x, p * p, w * h, C, N); // [N, C, h*w, p*p]
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, h*w, C, p*p]
x = ggml_reshape_3d(ctx, x, p * p * C, w * h, N); // [N, h*w, C*p*p]
x = ggml_reshape_4d(ctx, x, p, w, p, h * T * C * N); // [N*C*T*h, p, w, p]
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*T*h, w, p, p]
x = ggml_reshape_4d(ctx, x, p * p, w * h * T, C, N); // [N, C, T*h*w, p*p]
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, T*h*w, C, p*p]
x = ggml_reshape_3d(ctx, x, p * p * C, w * h * T, N); // [N, T*h*w, C*p*p]
return x;
}
struct ggml_tensor* process_img(struct ggml_context* ctx,
struct ggml_tensor* x) {
x = pad_to_patch_size(ctx, x);
if (x->ne[3] == 1) { // [N, C, H, W] => [N*C, 1, H, W]
x = ggml_reshape_4d(ctx, x, x->ne[0], x->ne[1], 1, x->ne[2]);
}
x = patchify(ctx, x);
return x;
}
struct ggml_tensor* unpatchify(struct ggml_context* ctx,
struct ggml_tensor* x,
int64_t T,
int64_t h,
int64_t w) {
// x: [N, h*w, C*patch_size*patch_size]
// return: [N, C, H, W]
// x: [N, T*h*w, C*patch_size*patch_size]
// return: [N*C, T, H, W]
int64_t N = x->ne[2];
int64_t C = x->ne[0] / params.patch_size / params.patch_size;
int64_t H = h * params.patch_size;
@ -423,11 +429,11 @@ namespace Qwen {
GGML_ASSERT(C * p * p == x->ne[0]);
x = ggml_reshape_4d(ctx, x, p * p, C, w * h, N); // [N, h*w, C, p*p]
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, C, h*w, p*p]
x = ggml_reshape_4d(ctx, x, p, p, w, h * C * N); // [N*C*h, w, p, p]
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, p, w, p]
x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, h*p, w*p]
x = ggml_reshape_4d(ctx, x, p * p, C, w * h * T, N); // [N, T*h*w, C, p*p]
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, C, T*h*w, p*p]
x = ggml_reshape_4d(ctx, x, p, p, w, h * T * C * N); // [N*C*T*h, w, p, p]
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*T*h, p, w, p]
x = ggml_reshape_4d(ctx, x, W, H, T, C * N); // [N*C, T, h*p, w*p]
return x;
}
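Aside (not part of the diff): the reworked patchify/unpatchify pair treats the layered latent like a short video, so the token count scales with the frame dimension T. A minimal standalone sketch of the shape bookkeeping, assuming N == 1 and made-up dimensions:
// Shape bookkeeping for patchify/unpatchify with a frame dimension T (sketch only).
#include <cassert>
#include <cstdio>
int main() {
    const long long C = 16, T = 5, H = 128, W = 128, p = 2; // p = patch_size (hypothetical values)
    assert(H % p == 0 && W % p == 0);
    const long long h = H / p, w = W / p;
    // patchify:   [N*C, T, H, W]    -> [N, T*h*w, C*p*p]
    // unpatchify: [N, T*h*w, C*p*p] -> [N*C, T, h*p, w*p]
    printf("tokens = %lld, features per token = %lld\n", T * h * w, C * p * p);
    return 0;
}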
@ -435,6 +441,7 @@ namespace Qwen {
struct ggml_tensor* forward_orig(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* timestep,
struct ggml_tensor* addition_t_cond,
struct ggml_tensor* context,
struct ggml_tensor* pe) {
auto time_text_embed = std::dynamic_pointer_cast<QwenTimestepProjEmbeddings>(blocks["time_text_embed"]);
@ -444,7 +451,7 @@ namespace Qwen {
auto norm_out = std::dynamic_pointer_cast<AdaLayerNormContinuous>(blocks["norm_out"]);
auto proj_out = std::dynamic_pointer_cast<Linear>(blocks["proj_out"]);
auto t_emb = time_text_embed->forward(ctx, timestep);
auto t_emb = time_text_embed->forward(ctx, timestep, addition_t_cond);
auto img = img_in->forward(ctx, x);
auto txt = txt_norm->forward(ctx, context);
txt = txt_in->forward(ctx, txt);
@ -466,11 +473,12 @@ namespace Qwen {
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x,
struct ggml_tensor* timestep,
struct ggml_tensor* addition_t_cond,
struct ggml_tensor* context,
struct ggml_tensor* pe,
std::vector<ggml_tensor*> ref_latents = {}) {
// Forward pass of DiT.
// x: [N, C, H, W]
// x: [N, C, H, W] or [N*C, T, H, W]
// timestep: [N,]
// context: [N, L, D]
// pe: [L, d_head/2, 2, 2]
@ -478,8 +486,15 @@ namespace Qwen {
int64_t W = x->ne[0];
int64_t H = x->ne[1];
int64_t C = x->ne[2];
int64_t N = x->ne[3];
int64_t T = 1;
int64_t N = 1;
int64_t C;
if (x->ne[3] == 1) {
C = x->ne[2];
} else {
T = x->ne[2];
C = x->ne[3];
}
auto img = process_img(ctx->ggml_ctx, x);
uint64_t img_tokens = img->ne[1];
@ -494,7 +509,7 @@ namespace Qwen {
int64_t h_len = ((H + (params.patch_size / 2)) / params.patch_size);
int64_t w_len = ((W + (params.patch_size / 2)) / params.patch_size);
auto out = forward_orig(ctx, img, timestep, context, pe); // [N, h_len*w_len, ph*pw*C]
auto out = forward_orig(ctx, img, timestep, addition_t_cond, context, pe); // [N, T*h_len*w_len, ph*pw*C]
if (out->ne[1] > img_tokens) {
out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size]
@ -502,7 +517,7 @@ namespace Qwen {
out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size]
}
out = unpatchify(ctx->ggml_ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w]
out = unpatchify(ctx->ggml_ctx, out, T, h_len, w_len); // [N*C, T, H + pad_h, W + pad_w]
// slice
out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H); // [N, C, H, W + pad_w]
@ -517,6 +532,7 @@ namespace Qwen {
QwenImageParams qwen_image_params;
QwenImageModel qwen_image;
std::vector<float> pe_vec;
std::vector<int> additional_t_cond_vec;
SDVersion version;
QwenImageRunner(ggml_backend_t backend,
@ -524,7 +540,7 @@ namespace Qwen {
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "",
SDVersion version = VERSION_QWEN_IMAGE)
: GGMLRunner(backend, offload_params_to_cpu) {
: GGMLRunner(backend, offload_params_to_cpu), version(version) {
qwen_image_params.num_layers = 0;
for (auto pair : tensor_storage_map) {
std::string tensor_name = pair.first;
@ -563,25 +579,39 @@ namespace Qwen {
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
std::vector<ggml_tensor*> ref_latents = {},
bool increase_ref_index = false) {
GGML_ASSERT(x->ne[3] == 1);
Rope::RefIndexMode ref_index_mode = Rope::RefIndexMode::INCREASE) {
int N = 1;
int T = 1;
if (x->ne[3] != 1) {
T = x->ne[2];
}
struct ggml_cgraph* gf = new_graph_custom(QWEN_IMAGE_GRAPH_SIZE);
x = to_backend(x);
context = to_backend(context);
timesteps = to_backend(timesteps);
struct ggml_tensor* addition_t_cond = nullptr;
if (version == VERSION_QWEN_IMAGE_LAYERED) {
additional_t_cond_vec = {0};
addition_t_cond = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_I32, 1);
set_backend_tensor_data(addition_t_cond, additional_t_cond_vec.data());
ref_index_mode = Rope::RefIndexMode::DECREASE;
}
for (int i = 0; i < ref_latents.size(); i++) {
ref_latents[i] = to_backend(ref_latents[i]);
}
pe_vec = Rope::gen_qwen_image_pe(x->ne[1],
pe_vec = Rope::gen_qwen_image_pe(T,
x->ne[1],
x->ne[0],
qwen_image_params.patch_size,
x->ne[3],
N,
context->ne[1],
ref_latents,
increase_ref_index,
ref_index_mode,
qwen_image_params.theta,
qwen_image_params.axes_dim);
int pos_len = pe_vec.size() / qwen_image_params.axes_dim_sum / 2;
@ -597,6 +627,7 @@ namespace Qwen {
struct ggml_tensor* out = qwen_image.forward(&runner_ctx,
x,
timesteps,
addition_t_cond,
context,
pe,
ref_latents);
@ -611,14 +642,14 @@ namespace Qwen {
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
std::vector<ggml_tensor*> ref_latents = {},
bool increase_ref_index = false,
Rope::RefIndexMode ref_index_mode = Rope::RefIndexMode::INCREASE,
struct ggml_tensor** output = nullptr,
struct ggml_context* output_ctx = nullptr) {
// x: [N, in_channels, h, w]
// x: [N, C, H, W] or [N*C, T, H, W]
// timesteps: [N, ]
// context: [N, max_position, hidden_size]
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(x, timesteps, context, ref_latents, increase_ref_index);
return build_graph(x, timesteps, context, ref_latents, ref_index_mode);
};
return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
@ -650,7 +681,7 @@ namespace Qwen {
struct ggml_tensor* out = nullptr;
int t0 = ggml_time_ms();
compute(8, x, timesteps, context, {}, false, &out, work_ctx);
compute(8, x, timesteps, context, {}, Rope::RefIndexMode::FIXED, &out, work_ctx);
int t1 = ggml_time_ms();
print_ggml_tensor(out);

rope.hpp (119)
View File

@ -5,6 +5,12 @@
#include "ggml_extend.hpp"
namespace Rope {
enum class RefIndexMode {
FIXED,
INCREASE,
DECREASE,
};
template <class T>
__STATIC_INLINE__ std::vector<T> linspace(T start, T end, int num) {
std::vector<T> result(num);
@ -170,21 +176,26 @@ namespace Rope {
int bs,
int axes_dim_num,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index,
RefIndexMode ref_index_mode,
float ref_index_scale) {
int index = 0;
std::vector<std::vector<float>> ids;
uint64_t curr_h_offset = 0;
uint64_t curr_w_offset = 0;
int index = 1;
for (ggml_tensor* ref : ref_latents) {
uint64_t h_offset = 0;
uint64_t w_offset = 0;
if (!increase_ref_index) {
if (ref_index_mode == RefIndexMode::FIXED) {
index = 1;
if (ref->ne[1] + curr_h_offset > ref->ne[0] + curr_w_offset) {
w_offset = curr_w_offset;
} else {
h_offset = curr_h_offset;
}
} else if (ref_index_mode == RefIndexMode::INCREASE) {
index++;
} else if (ref_index_mode == RefIndexMode::DECREASE) {
index--;
}
auto ref_ids = gen_flux_img_ids(ref->ne[1],
@ -197,10 +208,6 @@ namespace Rope {
w_offset);
ids = concat_ids(ids, ref_ids, bs);
if (increase_ref_index) {
index++;
}
curr_h_offset = std::max(curr_h_offset, ref->ne[1] + h_offset);
curr_w_offset = std::max(curr_w_offset, ref->ne[0] + w_offset);
}
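Aside (not part of the diff): a standalone sketch of how the reference index evolves under the three RefIndexMode values, mirroring the loop above. The number of reference latents is made up and the h/w offset handling of FIXED mode is omitted:
#include <cstdio>
#include <vector>
enum class RefIndexMode { FIXED, INCREASE, DECREASE };
int main() {
    std::vector<int> ref_sizes = {64, 64, 64}; // three hypothetical reference latents
    for (RefIndexMode mode : {RefIndexMode::FIXED, RefIndexMode::INCREASE, RefIndexMode::DECREASE}) {
        int index = 1; // matches the new starting value in gen_refs_ids
        printf("mode %d indices:", (int)mode);
        for (int i = 0; i < (int)ref_sizes.size(); i++) {
            if (mode == RefIndexMode::FIXED) {
                index = 1;  // every ref keeps index 1; the h/w offsets (omitted here) separate them
            } else if (mode == RefIndexMode::INCREASE) {
                index++;    // refs get 2, 3, 4, ...
            } else {
                index--;    // DECREASE: refs get 0, -1, -2, ... (Qwen-Image layered uses this mode)
            }
            printf(" %d", index);
        }
        printf("\n");
    }
    return 0;
}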
@ -215,14 +222,14 @@ namespace Rope {
int context_len,
std::set<int> txt_arange_dims,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index,
RefIndexMode ref_index_mode,
float ref_index_scale) {
auto txt_ids = gen_flux_txt_ids(bs, context_len, axes_dim_num, txt_arange_dims);
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num);
auto ids = concat_ids(txt_ids, img_ids, bs);
if (ref_latents.size() > 0) {
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, ref_index_scale);
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, ref_index_mode, ref_index_scale);
ids = concat_ids(ids, refs_ids, bs);
}
return ids;
@ -236,7 +243,7 @@ namespace Rope {
int context_len,
std::set<int> txt_arange_dims,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index,
RefIndexMode ref_index_mode,
float ref_index_scale,
int theta,
const std::vector<int>& axes_dim) {
@ -248,52 +255,11 @@ namespace Rope {
context_len,
txt_arange_dims,
ref_latents,
increase_ref_index,
ref_index_mode,
ref_index_scale);
return embed_nd(ids, bs, theta, axes_dim);
}
__STATIC_INLINE__ std::vector<std::vector<float>> gen_qwen_image_ids(int h,
int w,
int patch_size,
int bs,
int context_len,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index) {
int h_len = (h + (patch_size / 2)) / patch_size;
int w_len = (w + (patch_size / 2)) / patch_size;
int txt_id_start = std::max(h_len, w_len);
auto txt_ids = linspace<float>(txt_id_start, context_len + txt_id_start, context_len);
std::vector<std::vector<float>> txt_ids_repeated(bs * context_len, std::vector<float>(3));
for (int i = 0; i < bs; ++i) {
for (int j = 0; j < txt_ids.size(); ++j) {
txt_ids_repeated[i * txt_ids.size() + j] = {txt_ids[j], txt_ids[j], txt_ids[j]};
}
}
int axes_dim_num = 3;
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num);
auto ids = concat_ids(txt_ids_repeated, img_ids, bs);
if (ref_latents.size() > 0) {
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, 1.f);
ids = concat_ids(ids, refs_ids, bs);
}
return ids;
}
// Generate qwen_image positional embeddings
__STATIC_INLINE__ std::vector<float> gen_qwen_image_pe(int h,
int w,
int patch_size,
int bs,
int context_len,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index,
int theta,
const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_qwen_image_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
return embed_nd(ids, bs, theta, axes_dim);
}
__STATIC_INLINE__ std::vector<std::vector<float>> gen_vid_ids(int t,
int h,
int w,
@ -334,6 +300,49 @@ namespace Rope {
return vid_ids_repeated;
}
__STATIC_INLINE__ std::vector<std::vector<float>> gen_qwen_image_ids(int t,
int h,
int w,
int patch_size,
int bs,
int context_len,
const std::vector<ggml_tensor*>& ref_latents,
RefIndexMode ref_index_mode) {
int h_len = (h + (patch_size / 2)) / patch_size;
int w_len = (w + (patch_size / 2)) / patch_size;
int txt_id_start = std::max(h_len, w_len);
auto txt_ids = linspace<float>(txt_id_start, context_len + txt_id_start, context_len);
std::vector<std::vector<float>> txt_ids_repeated(bs * context_len, std::vector<float>(3));
for (int i = 0; i < bs; ++i) {
for (int j = 0; j < txt_ids.size(); ++j) {
txt_ids_repeated[i * txt_ids.size() + j] = {txt_ids[j], txt_ids[j], txt_ids[j]};
}
}
int axes_dim_num = 3;
auto img_ids = gen_vid_ids(t, h, w, 1, patch_size, patch_size, bs);
auto ids = concat_ids(txt_ids_repeated, img_ids, bs);
if (ref_latents.size() > 0) {
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, ref_index_mode, 1.f);
ids = concat_ids(ids, refs_ids, bs);
}
return ids;
}
// Generate qwen_image positional embeddings
__STATIC_INLINE__ std::vector<float> gen_qwen_image_pe(int t,
int h,
int w,
int patch_size,
int bs,
int context_len,
const std::vector<ggml_tensor*>& ref_latents,
RefIndexMode ref_index_mode,
int theta,
const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_qwen_image_ids(t, h, w, patch_size, bs, context_len, ref_latents, ref_index_mode);
return embed_nd(ids, bs, theta, axes_dim);
}
// Generate wan positional embeddings
__STATIC_INLINE__ std::vector<float> gen_wan_pe(int t,
int h,
@ -395,7 +404,7 @@ namespace Rope {
int context_len,
int seq_multi_of,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index) {
RefIndexMode ref_index_mode) {
int padded_context_len = context_len + bound_mod(context_len, seq_multi_of);
auto txt_ids = std::vector<std::vector<float>>(bs * padded_context_len, std::vector<float>(3, 0.0f));
for (int i = 0; i < bs * padded_context_len; i++) {
@ -426,10 +435,10 @@ namespace Rope {
int context_len,
int seq_multi_of,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index,
RefIndexMode ref_index_mode,
int theta,
const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_z_image_ids(h, w, patch_size, bs, context_len, seq_multi_of, ref_latents, increase_ref_index);
std::vector<std::vector<float>> ids = gen_z_image_ids(h, w, patch_size, bs, context_len, seq_multi_of, ref_latents, ref_index_mode);
return embed_nd(ids, bs, theta, axes_dim);
}

View File

@ -1699,7 +1699,7 @@ public:
diffusion_params.timesteps = timesteps;
diffusion_params.guidance = guidance_tensor;
diffusion_params.ref_latents = ref_latents;
diffusion_params.increase_ref_index = increase_ref_index;
diffusion_params.ref_index_mode = increase_ref_index ? Rope::RefIndexMode::INCREASE : Rope::RefIndexMode::FIXED;
diffusion_params.controls = controls;
diffusion_params.control_strength = control_strength;
diffusion_params.vace_context = vace_context;
@ -1940,6 +1940,28 @@ public:
return latent_channel;
}
int get_image_channels() {
int image_channel = 3;
if (version == VERSION_QWEN_IMAGE_LAYERED) {
image_channel = 4;
}
return image_channel;
}
void ensure_image_channels(sd_image_f32_t* image) {
if (image->channel == get_image_channels()) {
return;
}
if (get_image_channels() == 4) {
sd_image_f32_t new_image = sd_image_to_rgba(*image);
free(image->data);
image->data = new_image.data;
image->channel = new_image.channel;
return;
}
GGML_ABORT("invalid image channels");
}
int get_image_seq_len(int h, int w) {
int vae_scale_factor = get_vae_scale_factor();
return (h / vae_scale_factor) * (w / vae_scale_factor);
@ -2265,7 +2287,7 @@ public:
const int vae_scale_factor = get_vae_scale_factor();
int64_t W = x->ne[0] * vae_scale_factor;
int64_t H = x->ne[1] * vae_scale_factor;
int64_t C = 3;
int64_t C = get_image_channels();
ggml_tensor* result = nullptr;
if (decode_video) {
int T = x->ne[2];
@ -3066,7 +3088,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
sd_ctx->sd->rng->manual_seed(cur_seed);
sd_ctx->sd->sampler_rng->manual_seed(cur_seed);
struct ggml_tensor* x_t = init_latent;
struct ggml_tensor* noise = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x_t);
ggml_ext_im_set_randn_f32(noise, sd_ctx->sd->rng);
int start_merge_step = -1;
@ -3121,11 +3143,25 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
std::vector<struct ggml_tensor*> decoded_images; // collect decoded images
for (size_t i = 0; i < final_latents.size(); i++) {
t1 = ggml_time_ms();
struct ggml_tensor* img = sd_ctx->sd->decode_first_stage(work_ctx, final_latents[i] /* x_0 */);
// print_ggml_tensor(img);
if (sd_ctx->sd->version == VERSION_QWEN_IMAGE_LAYERED) {
int layers = 4;
for (int layer_index = 0; layer_index < layers; layer_index++) {
ggml_tensor* final_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, final_latents[i]->ne[0], final_latents[i]->ne[1], final_latents[i]->ne[3], 1);
ggml_ext_tensor_iter(final_latent, [&](ggml_tensor* final_latent, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = ggml_ext_tensor_get_f32(final_latents[i], i0, i1, layer_index + 1, i2);
ggml_ext_tensor_set_f32(final_latent, value, i0, i1, i2, i3);
});
struct ggml_tensor* img = sd_ctx->sd->decode_first_stage(work_ctx, final_latent);
if (img != nullptr) {
decoded_images.push_back(img);
}
}
} else {
struct ggml_tensor* img = sd_ctx->sd->decode_first_stage(work_ctx, final_latents[i] /* x_0 */);
if (img != nullptr) {
decoded_images.push_back(img);
}
}
int64_t t2 = ggml_time_ms();
LOG_INFO("latent %" PRId64 " decoded, taking %.2fs", i + 1, (t2 - t1) * 1.0f / 1000);
}
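Aside (not part of the diff): a plain-array analogue of the per-layer slicing above, to make the layout explicit. The combined latent is laid out video-style, [W, H, T, C] in ggml's ne order with T = layers + 1; frame 0 is skipped and frames 1..layers are extracted and decoded one by one. All dimensions below are hypothetical.
#include <cstdio>
#include <vector>
int main() {
    const int W = 8, H = 8, C = 16, layers = 4, T = layers + 1;
    std::vector<float> combined(W * H * T * C, 0.0f); // value(w, h, t, c), w fastest
    auto at = [&](int w, int h, int t, int c) -> float& {
        return combined[((c * T + t) * H + h) * W + w];
    };
    for (int layer = 0; layer < layers; layer++) {
        std::vector<float> layer_latent(W * H * C); // one [W, H, C] latent per layer
        for (int c = 0; c < C; c++)
            for (int h = 0; h < H; h++)
                for (int w = 0; w < W; w++)
                    layer_latent[(c * H + h) * W + w] = at(w, h, layer + 1, c);
        printf("layer %d: %d values extracted\n", layer, (int)layer_latent.size());
    }
    return 0;
}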
@ -3138,7 +3174,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
sd_ctx->sd->lora_stat();
sd_image_t* result_images = (sd_image_t*)calloc(batch_count, sizeof(sd_image_t));
sd_image_t* result_images = (sd_image_t*)calloc(decoded_images.size(), sizeof(sd_image_t));
if (result_images == nullptr) {
ggml_free(work_ctx);
return nullptr;
@ -3147,7 +3183,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
for (size_t i = 0; i < decoded_images.size(); i++) {
result_images[i].width = width;
result_images[i].height = height;
result_images[i].channel = 3;
result_images[i].channel = sd_ctx->sd->get_image_channels();
result_images[i].data = ggml_tensor_to_sd_image(decoded_images[i]);
}
ggml_free(work_ctx);
@ -3159,6 +3195,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
sd_ctx->sd->vae_tiling_params = sd_img_gen_params->vae_tiling_params;
int width = sd_img_gen_params->width;
int height = sd_img_gen_params->height;
int image_channels = sd_ctx->sd->get_image_channels();
int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor();
int diffusion_model_down_factor = sd_ctx->sd->get_diffusion_model_down_factor();
@ -3238,9 +3275,14 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
sigma_sched.assign(sigmas.begin() + sample_steps - t_enc - 1, sigmas.end());
sigmas = sigma_sched;
ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, image_channels, 1);
ggml_tensor* mask_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 1, 1);
sd_image_t init_image = sd_img_gen_params->init_image;
if (image_channels != init_image.channel && image_channels == 4) {
init_image = sd_image_to_rgba(init_image);
}
sd_image_to_ggml_tensor(sd_img_gen_params->mask_image, mask_img);
sd_image_to_ggml_tensor(sd_img_gen_params->init_image, init_img);
@ -3333,8 +3375,13 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
LOG_WARN("This is an inpainting model, this should only be used in img2img mode with a mask");
}
if (sd_ctx->sd->version == VERSION_QWEN_IMAGE_LAYERED) {
int layers = 4;
init_latent = sd_ctx->sd->generate_init_latent(work_ctx, width, height, layers + 1, true);
} else {
init_latent = sd_ctx->sd->generate_init_latent(work_ctx, width, height);
}
}
sd_guidance_params_t guidance = sd_img_gen_params->sample_params.guidance;
std::vector<sd_image_t*> ref_images;
@ -3343,7 +3390,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
}
std::vector<uint8_t> empty_image_data;
sd_image_t empty_image = {(uint32_t)width, (uint32_t)height, 3, nullptr};
sd_image_t empty_image = {(uint32_t)width, (uint32_t)height, image_channels, nullptr};
if (ref_images.empty() && sd_version_is_unet_edit(sd_ctx->sd->version)) {
LOG_WARN("This model needs at least one reference image; using an empty reference");
empty_image_data.resize(width * height * 3);
@ -3380,11 +3427,13 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
LOG_DEBUG("resize vae ref image %d from %dx%d to %dx%d", i, ref_image.height, ref_image.width, resized_image.height, resized_image.width);
sd_ctx->sd->ensure_image_channels(&resized_image);
img = ggml_new_tensor_4d(work_ctx,
GGML_TYPE_F32,
resized_image.width,
resized_image.height,
3,
resized_image.channel,
1);
sd_image_f32_to_ggml_tensor(resized_image, img);
free(resized_image.data);
@ -3399,8 +3448,6 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
sd_image_to_ggml_tensor(*ref_images[i], img);
}
// print_ggml_tensor(img, false, "img");
ggml_tensor* latent = sd_ctx->sd->encode_first_stage(work_ctx, img);
ref_latents.push_back(latent);
}

View File

@ -378,6 +378,72 @@ sd_image_f32_t sd_image_t_to_sd_image_f32_t(sd_image_t image) {
return converted_image;
}
sd_image_f32_t sd_image_to_rgba(sd_image_f32_t image) {
sd_image_f32_t rgba_image;
rgba_image.width = image.width;
rgba_image.height = image.height;
rgba_image.channel = 4;
size_t total_pixels = (size_t)image.width * image.height;
rgba_image.data = (float*)malloc(total_pixels * 4 * sizeof(float));
for (size_t i = 0; i < total_pixels; i++) {
if (image.channel == 3) {
// RGB -> RGBA
rgba_image.data[i * 4 + 0] = image.data[i * 3 + 0]; // R
rgba_image.data[i * 4 + 1] = image.data[i * 3 + 1]; // G
rgba_image.data[i * 4 + 2] = image.data[i * 3 + 2]; // B
rgba_image.data[i * 4 + 3] = 1.0f; // A (fully opaque)
} else if (image.channel == 1) {
// Gray -> RGBA
float gray = image.data[i];
rgba_image.data[i * 4 + 0] = gray; // R
rgba_image.data[i * 4 + 1] = gray; // G
rgba_image.data[i * 4 + 2] = gray; // B
rgba_image.data[i * 4 + 3] = 1.0f; // A (fully opaque)
} else if (image.channel == 4) {
// Already RGBA
memcpy(rgba_image.data, image.data, total_pixels * 4 * sizeof(float));
break;
}
}
return rgba_image;
}
sd_image_t sd_image_to_rgba(sd_image_t image) {
sd_image_t rgba_image;
rgba_image.width = image.width;
rgba_image.height = image.height;
rgba_image.channel = 4;
size_t total_pixels = (size_t)image.width * image.height;
rgba_image.data = (uint8_t*)malloc(total_pixels * 4 * sizeof(uint8_t));
for (size_t i = 0; i < total_pixels; i++) {
if (image.channel == 3) {
// RGB -> RGBA
rgba_image.data[i * 4 + 0] = image.data[i * 3 + 0]; // R
rgba_image.data[i * 4 + 1] = image.data[i * 3 + 1]; // G
rgba_image.data[i * 4 + 2] = image.data[i * 3 + 2]; // B
rgba_image.data[i * 4 + 3] = 255; // A (fully opaque)
} else if (image.channel == 1) {
// Gray -> RGBA
uint8_t gray = image.data[i];
rgba_image.data[i * 4 + 0] = gray; // R
rgba_image.data[i * 4 + 1] = gray; // G
rgba_image.data[i * 4 + 2] = gray; // B
rgba_image.data[i * 4 + 3] = 255; // A (fully opaque)
} else if (image.channel == 4) {
// Already RGBA
memcpy(rgba_image.data, image.data, total_pixels * 4 * sizeof(uint8_t));
break;
}
}
return rgba_image;
}
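Aside (not part of the diff): a minimal usage sketch for the uint8 overload added above. It assumes the repo headers are on the include path and that sd_image_t has the field order {width, height, channel, data} used elsewhere in this commit; the 2x2 test image is made up.
#include <cstdint>
#include <cstdlib>
#include "util.h"
int main() {
    uint8_t pixels[2 * 2 * 3] = {0};         // tiny all-black RGB image
    sd_image_t rgb  = {2, 2, 3, pixels};     // width, height, channel, data
    sd_image_t rgba = sd_image_to_rgba(rgb); // 4 channels, alpha filled with 255
    free(rgba.data);                         // the helper malloc()s a new buffer
    return 0;
}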
// Function to perform double linear interpolation
float interpolate(float v1, float v2, float v3, float v4, float x_ratio, float y_ratio) {
return v1 * (1 - x_ratio) * (1 - y_ratio) + v2 * x_ratio * (1 - y_ratio) + v3 * (1 - x_ratio) * y_ratio + v4 * x_ratio * y_ratio;

util.h (3)
View File

@ -39,6 +39,9 @@ void normalize_sd_image_f32_t(sd_image_f32_t image, float means[3], float stds[3
sd_image_f32_t sd_image_t_to_sd_image_f32_t(sd_image_t image);
sd_image_f32_t sd_image_to_rgba(sd_image_f32_t image);
sd_image_t sd_image_to_rgba(sd_image_t image);
sd_image_f32_t resize_sd_image_f32_t(sd_image_f32_t image, int target_width, int target_height);
sd_image_f32_t clip_preprocess(sd_image_f32_t image, int target_width, int target_height);

wan.hpp (150)
View File

@ -590,7 +590,7 @@ namespace WAN {
class Encoder3d : public GGMLBlock {
protected:
bool wan2_2;
bool use_down_res_block;
int64_t dim;
int64_t z_dim;
std::vector<int> dim_mult;
@ -600,27 +600,24 @@ namespace WAN {
public:
Encoder3d(int64_t dim = 128,
int64_t z_dim = 4,
int64_t in_channels = 3,
std::vector<int> dim_mult = {1, 2, 4, 4},
int num_res_blocks = 2,
std::vector<bool> temperal_downsample = {false, true, true},
bool wan2_2 = false)
bool use_down_res_block = false)
: dim(dim),
z_dim(z_dim),
dim_mult(dim_mult),
num_res_blocks(num_res_blocks),
temperal_downsample(temperal_downsample),
wan2_2(wan2_2) {
use_down_res_block(use_down_res_block) {
// attn_scales is always []
std::vector<int64_t> dims = {dim};
for (int u : dim_mult) {
dims.push_back(dim * u);
}
if (wan2_2) {
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(12, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
} else {
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(3, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
}
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(in_channels, dims[0], {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
int index = 0;
int64_t in_dim;
@ -628,7 +625,7 @@ namespace WAN {
for (int i = 0; i < dims.size() - 1; i++) {
in_dim = dims[i];
out_dim = dims[i + 1];
if (wan2_2) {
if (use_down_res_block) {
bool t_down_flag = i < temperal_downsample.size() ? temperal_downsample[i] : false;
auto block = std::shared_ptr<GGMLBlock>(new Down_ResidualBlock(in_dim,
out_dim,
@ -702,7 +699,7 @@ namespace WAN {
}
int index = 0;
for (int i = 0; i < dims.size() - 1; i++) {
if (wan2_2) {
if (use_down_res_block) {
auto layer = std::dynamic_pointer_cast<Down_ResidualBlock>(blocks["downsamples." + std::to_string(index++)]);
x = layer->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx);
@ -753,7 +750,7 @@ namespace WAN {
class Decoder3d : public GGMLBlock {
protected:
bool wan2_2;
bool use_up_res_block;
int64_t dim;
int64_t z_dim;
std::vector<int> dim_mult;
@ -763,16 +760,17 @@ namespace WAN {
public:
Decoder3d(int64_t dim = 128,
int64_t z_dim = 4,
int64_t out_channels = 3,
std::vector<int> dim_mult = {1, 2, 4, 4},
int num_res_blocks = 2,
std::vector<bool> temperal_upsample = {true, true, false},
bool wan2_2 = false)
bool use_up_res_block = false)
: dim(dim),
z_dim(z_dim),
dim_mult(dim_mult),
num_res_blocks(num_res_blocks),
temperal_upsample(temperal_upsample),
wan2_2(wan2_2) {
use_up_res_block(use_up_res_block) {
// attn_scales is always []
std::vector<int64_t> dims = {dim_mult[dim_mult.size() - 1] * dim};
for (int i = static_cast<int>(dim_mult.size()) - 1; i >= 0; i--) {
@ -794,7 +792,7 @@ namespace WAN {
for (int i = 0; i < dims.size() - 1; i++) {
in_dim = dims[i];
out_dim = dims[i + 1];
if (wan2_2) {
if (use_up_res_block) {
bool t_up_flag = i < temperal_upsample.size() ? temperal_upsample[i] : false;
auto block = std::shared_ptr<GGMLBlock>(new Up_ResidualBlock(in_dim,
out_dim,
@ -824,12 +822,7 @@ namespace WAN {
// output blocks
blocks["head.0"] = std::shared_ptr<GGMLBlock>(new RMS_norm(out_dim));
// head.1 is nn.SiLU()
if (wan2_2) {
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, 12, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
} else {
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, 3, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
}
blocks["head.2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(out_dim, out_channels, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}));
}
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
@ -878,7 +871,7 @@ namespace WAN {
}
int index = 0;
for (int i = 0; i < dims.size() - 1; i++) {
if (wan2_2) {
if (use_up_res_block) {
auto layer = std::dynamic_pointer_cast<Up_ResidualBlock>(blocks["upsamples." + std::to_string(index++)]);
x = layer->forward(ctx, x, b, feat_cache, feat_idx, chunk_idx);
@ -922,10 +915,11 @@ namespace WAN {
}
};
class WanVAE : public GGMLBlock {
public:
bool wan2_2 = false;
struct WanVAEParams {
bool decode_only = true;
bool use_up_down_res_block = false;
int patch_size = 1;
int64_t input_channels = 3;
int64_t dim = 96;
int64_t dec_dim = 96;
int64_t z_dim = 16;
@ -933,39 +927,49 @@ namespace WAN {
int num_res_blocks = 2;
std::vector<bool> temperal_upsample = {true, true, false};
std::vector<bool> temperal_downsample = {false, true, true};
int _conv_num = 33;
int _enc_conv_num = 28;
};
class WanVAE : public GGMLBlock {
protected:
WanVAEParams vae_params;
public:
int _conv_idx = 0;
std::vector<struct ggml_tensor*> _feat_map;
int _enc_conv_num = 28;
int _enc_conv_idx = 0;
std::vector<struct ggml_tensor*> _enc_feat_map;
void clear_cache() {
_conv_idx = 0;
_feat_map = std::vector<struct ggml_tensor*>(_conv_num, nullptr);
_feat_map = std::vector<struct ggml_tensor*>(vae_params._conv_num, nullptr);
_enc_conv_idx = 0;
_enc_feat_map = std::vector<struct ggml_tensor*>(_enc_conv_num, nullptr);
_enc_feat_map = std::vector<struct ggml_tensor*>(vae_params._enc_conv_num, nullptr);
}
public:
WanVAE(bool decode_only = true, bool wan2_2 = false)
: decode_only(decode_only), wan2_2(wan2_2) {
explicit WanVAE(const WanVAEParams& vae_params)
: vae_params(vae_params) {
// attn_scales is always []
if (wan2_2) {
dim = 160;
dec_dim = 256;
z_dim = 48;
_conv_num = 34;
_enc_conv_num = 26;
if (!vae_params.decode_only) {
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new Encoder3d(vae_params.dim,
vae_params.z_dim * 2,
vae_params.input_channels,
vae_params.dim_mult,
vae_params.num_res_blocks,
vae_params.temperal_downsample,
vae_params.use_up_down_res_block));
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(vae_params.z_dim * 2, vae_params.z_dim * 2, {1, 1, 1}));
}
if (!decode_only) {
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, temperal_downsample, wan2_2));
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim * 2, z_dim * 2, {1, 1, 1}));
}
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder3d(dec_dim, z_dim, dim_mult, num_res_blocks, temperal_upsample, wan2_2));
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim, z_dim, {1, 1, 1}));
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder3d(vae_params.dec_dim,
vae_params.z_dim,
vae_params.input_channels,
vae_params.dim_mult,
vae_params.num_res_blocks,
vae_params.temperal_upsample,
vae_params.use_up_down_res_block));
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(vae_params.z_dim, vae_params.z_dim, {1, 1, 1}));
}
struct ggml_tensor* patchify(struct ggml_context* ctx,
@ -1026,13 +1030,11 @@ namespace WAN {
int64_t b = 1) {
// x: [b*c, t, h, w]
GGML_ASSERT(b == 1);
GGML_ASSERT(decode_only == false);
GGML_ASSERT(vae_params.decode_only == false);
clear_cache();
if (wan2_2) {
x = patchify(ctx->ggml_ctx, x, 2, b);
}
x = patchify(ctx->ggml_ctx, x, vae_params.patch_size, b);
auto encoder = std::dynamic_pointer_cast<Encoder3d>(blocks["encoder"]);
auto conv1 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]);
@ -1082,9 +1084,9 @@ namespace WAN {
out = ggml_concat(ctx->ggml_ctx, out, out_, 2);
}
}
if (wan2_2) {
out = unpatchify(ctx->ggml_ctx, out, 2, b);
}
out = unpatchify(ctx->ggml_ctx, out, vae_params.patch_size, b);
clear_cache();
return out;
}
@ -1103,16 +1105,14 @@ namespace WAN {
auto in = ggml_ext_slice(ctx->ggml_ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
_conv_idx = 0;
auto out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i);
if (wan2_2) {
out = unpatchify(ctx->ggml_ctx, out, 2, b);
}
out = unpatchify(ctx->ggml_ctx, out, vae_params.patch_size, b);
return out;
}
};
struct WanVAERunner : public VAE {
bool decode_only = true;
WanVAE ae;
WanVAEParams vae_params;
std::unique_ptr<WanVAE> ae;
WanVAERunner(ggml_backend_t backend,
bool offload_params_to_cpu,
@ -1120,8 +1120,22 @@ namespace WAN {
const std::string prefix = "",
bool decode_only = false,
SDVersion version = VERSION_WAN2)
: decode_only(decode_only), ae(decode_only, version == VERSION_WAN2_2_TI2V), VAE(backend, offload_params_to_cpu) {
ae.init(params_ctx, tensor_storage_map, prefix);
: VAE(backend, offload_params_to_cpu) {
vae_params.decode_only = decode_only;
if (version == VERSION_WAN2_2_TI2V) {
vae_params.dim = 160;
vae_params.dec_dim = 256;
vae_params.z_dim = 48;
vae_params.input_channels = 12;
vae_params.patch_size = 2;
vae_params.use_up_down_res_block = true;
vae_params._conv_num = 34;
vae_params._enc_conv_num = 26;
} else if (version == VERSION_QWEN_IMAGE_LAYERED) {
vae_params.input_channels = 4;
}
ae = std::make_unique<WanVAE>(vae_params);
ae->init(params_ctx, tensor_storage_map, prefix);
}
std::string get_desc() override {
@ -1129,7 +1143,7 @@ namespace WAN {
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) override {
ae.get_param_tensors(tensors, prefix);
ae->get_param_tensors(tensors, prefix);
}
struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
@ -1139,7 +1153,7 @@ namespace WAN {
auto runner_ctx = get_context();
struct ggml_tensor* out = decode_graph ? ae.decode(&runner_ctx, z) : ae.encode(&runner_ctx, z);
struct ggml_tensor* out = decode_graph ? ae->decode(&runner_ctx, z) : ae->encode(&runner_ctx, z);
ggml_build_forward_expand(gf, out);
@ -1149,21 +1163,21 @@ namespace WAN {
struct ggml_cgraph* build_graph_partial(struct ggml_tensor* z, bool decode_graph, int64_t i) {
struct ggml_cgraph* gf = new_graph_custom(20480);
ae.clear_cache();
ae->clear_cache();
for (int64_t feat_idx = 0; feat_idx < ae._feat_map.size(); feat_idx++) {
for (int64_t feat_idx = 0; feat_idx < ae->_feat_map.size(); feat_idx++) {
auto feat_cache = get_cache_tensor_by_name("feat_idx:" + std::to_string(feat_idx));
ae._feat_map[feat_idx] = feat_cache;
ae->_feat_map[feat_idx] = feat_cache;
}
z = to_backend(z);
auto runner_ctx = get_context();
struct ggml_tensor* out = decode_graph ? ae.decode_partial(&runner_ctx, z, i) : ae.encode(&runner_ctx, z);
struct ggml_tensor* out = decode_graph ? ae->decode_partial(&runner_ctx, z, i) : ae->encode(&runner_ctx, z);
for (int64_t feat_idx = 0; feat_idx < ae._feat_map.size(); feat_idx++) {
ggml_tensor* feat_cache = ae._feat_map[feat_idx];
for (int64_t feat_idx = 0; feat_idx < ae->_feat_map.size(); feat_idx++) {
ggml_tensor* feat_cache = ae->_feat_map[feat_idx];
if (feat_cache != nullptr) {
cache("feat_idx:" + std::to_string(feat_idx), feat_cache);
ggml_build_forward_expand(gf, feat_cache);
@ -1186,7 +1200,7 @@ namespace WAN {
};
return GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
} else { // chunk 1 result is weird
ae.clear_cache();
ae->clear_cache();
int64_t t = z->ne[2];
int64_t i = 0;
auto get_graph = [&]() -> struct ggml_cgraph* {
@ -1194,7 +1208,7 @@ namespace WAN {
};
struct ggml_tensor* out = nullptr;
bool res = GGMLRunner::compute(get_graph, n_threads, true, &out, output_ctx);
ae.clear_cache();
ae->clear_cache();
if (t == 1) {
*output = out;
return res;
@ -1222,7 +1236,7 @@ namespace WAN {
for (i = 1; i < t; i++) {
res = res || GGMLRunner::compute(get_graph, n_threads, true, &out);
ae.clear_cache();
ae->clear_cache();
copy_to_output();
}
free_cache_ctx_and_buffer();

View File

@ -531,7 +531,7 @@ namespace ZImage {
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
std::vector<ggml_tensor*> ref_latents = {},
bool increase_ref_index = false) {
Rope::RefIndexMode ref_index_mode = Rope::RefIndexMode::FIXED) {
GGML_ASSERT(x->ne[3] == 1);
struct ggml_cgraph* gf = new_graph_custom(Z_IMAGE_GRAPH_SIZE);
@ -550,7 +550,7 @@ namespace ZImage {
context->ne[1],
SEQ_MULTI_OF,
ref_latents,
increase_ref_index,
Rope::RefIndexMode::INCREASE,
z_image_params.theta,
z_image_params.axes_dim);
int pos_len = pe_vec.size() / z_image_params.axes_dim_sum / 2;
@ -579,14 +579,14 @@ namespace ZImage {
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
std::vector<ggml_tensor*> ref_latents = {},
bool increase_ref_index = false,
Rope::RefIndexMode ref_index_mode = Rope::RefIndexMode::FIXED,
struct ggml_tensor** output = nullptr,
struct ggml_context* output_ctx = nullptr) {
// x: [N, in_channels, h, w]
// timesteps: [N, ]
// context: [N, max_position, hidden_size]
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(x, timesteps, context, ref_latents, increase_ref_index);
return build_graph(x, timesteps, context, ref_latents, ref_index_mode);
};
return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
@ -618,7 +618,7 @@ namespace ZImage {
struct ggml_tensor* out = nullptr;
int t0 = ggml_time_ms();
compute(8, x, timesteps, context, {}, false, &out, work_ctx);
compute(8, x, timesteps, context, {}, Rope::RefIndexMode::INCREASE, &out, work_ctx);
int t1 = ggml_time_ms();
print_ggml_tensor(out);