add wan model support

leejet 2025-08-06 00:29:53 +08:00
parent e3f9366857
commit 5f7d98884c
10 changed files with 1146 additions and 242 deletions


@@ -746,11 +746,11 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
int main(int argc, const char* argv[]) {
SDParams params;
- // params.verbose = true;
- // sd_set_log_callback(sd_log_cb, (void*)&params);
- // WAN::WanVAERunner::load_from_file_and_test(argv[1]);
- // return 0;
+ params.verbose = true;
+ sd_set_log_callback(sd_log_cb, (void*)&params);
+ WAN::WanRunner::load_from_file_and_test(argv[1]);
+ return 0;
parse_args(argc, argv, params);

flux.hpp

@@ -5,6 +5,7 @@
#include "ggml_extend.hpp"
#include "model.h"
#include "rope.hpp"
#define FLUX_GRAPH_SIZE 10240
@@ -610,179 +611,11 @@ namespace Flux {
};
struct Flux : public GGMLBlock {
public:
std::vector<float> linspace(float start, float end, int num) {
std::vector<float> result(num);
float step = (end - start) / (num - 1);
for (int i = 0; i < num; ++i) {
result[i] = start + i * step;
}
return result;
}
std::vector<std::vector<float>> transpose(const std::vector<std::vector<float>>& mat) {
int rows = mat.size();
int cols = mat[0].size();
std::vector<std::vector<float>> transposed(cols, std::vector<float>(rows));
for (int i = 0; i < rows; ++i) {
for (int j = 0; j < cols; ++j) {
transposed[j][i] = mat[i][j];
}
}
return transposed;
}
std::vector<float> flatten(const std::vector<std::vector<float>>& vec) {
std::vector<float> flat_vec;
for (const auto& sub_vec : vec) {
flat_vec.insert(flat_vec.end(), sub_vec.begin(), sub_vec.end());
}
return flat_vec;
}
std::vector<std::vector<float>> rope(const std::vector<float>& pos, int dim, int theta) {
assert(dim % 2 == 0);
int half_dim = dim / 2;
std::vector<float> scale = linspace(0, (dim * 1.0f - 2) / dim, half_dim);
std::vector<float> omega(half_dim);
for (int i = 0; i < half_dim; ++i) {
omega[i] = 1.0 / std::pow(theta, scale[i]);
}
int pos_size = pos.size();
std::vector<std::vector<float>> out(pos_size, std::vector<float>(half_dim));
for (int i = 0; i < pos_size; ++i) {
for (int j = 0; j < half_dim; ++j) {
out[i][j] = pos[i] * omega[j];
}
}
std::vector<std::vector<float>> result(pos_size, std::vector<float>(half_dim * 4));
for (int i = 0; i < pos_size; ++i) {
for (int j = 0; j < half_dim; ++j) {
result[i][4 * j] = std::cos(out[i][j]);
result[i][4 * j + 1] = -std::sin(out[i][j]);
result[i][4 * j + 2] = std::sin(out[i][j]);
result[i][4 * j + 3] = std::cos(out[i][j]);
}
}
return result;
}
// Generate IDs for image patches and text
std::vector<std::vector<float>> gen_txt_ids(int bs, int context_len) {
return std::vector<std::vector<float>>(bs * context_len, std::vector<float>(3, 0.0));
}
std::vector<std::vector<float>> gen_img_ids(int h, int w, int patch_size, int bs, int index = 0, int h_offset = 0, int w_offset = 0) {
int h_len = (h + (patch_size / 2)) / patch_size;
int w_len = (w + (patch_size / 2)) / patch_size;
std::vector<std::vector<float>> img_ids(h_len * w_len, std::vector<float>(3, 0.0));
std::vector<float> row_ids = linspace(h_offset, h_len - 1 + h_offset, h_len);
std::vector<float> col_ids = linspace(w_offset, w_len - 1 + w_offset, w_len);
for (int i = 0; i < h_len; ++i) {
for (int j = 0; j < w_len; ++j) {
img_ids[i * w_len + j][0] = index;
img_ids[i * w_len + j][1] = row_ids[i];
img_ids[i * w_len + j][2] = col_ids[j];
}
}
std::vector<std::vector<float>> img_ids_repeated(bs * img_ids.size(), std::vector<float>(3));
for (int i = 0; i < bs; ++i) {
for (int j = 0; j < img_ids.size(); ++j) {
img_ids_repeated[i * img_ids.size() + j] = img_ids[j];
}
}
return img_ids_repeated;
}
std::vector<std::vector<float>> concat_ids(const std::vector<std::vector<float>>& a,
const std::vector<std::vector<float>>& b,
int bs) {
size_t a_len = a.size() / bs;
size_t b_len = b.size() / bs;
std::vector<std::vector<float>> ids(a.size() + b.size(), std::vector<float>(3));
for (int i = 0; i < bs; ++i) {
for (int j = 0; j < a_len; ++j) {
ids[i * (a_len + b_len) + j] = a[i * a_len + j];
}
for (int j = 0; j < b_len; ++j) {
ids[i * (a_len + b_len) + a_len + j] = b[i * b_len + j];
}
}
return ids;
}
std::vector<std::vector<float>> gen_ids(int h, int w, int patch_size, int bs, int context_len, std::vector<ggml_tensor*> ref_latents) {
auto txt_ids = gen_txt_ids(bs, context_len);
auto img_ids = gen_img_ids(h, w, patch_size, bs);
auto ids = concat_ids(txt_ids, img_ids, bs);
uint64_t curr_h_offset = 0;
uint64_t curr_w_offset = 0;
for (ggml_tensor* ref : ref_latents) {
uint64_t h_offset = 0;
uint64_t w_offset = 0;
if (ref->ne[1] + curr_h_offset > ref->ne[0] + curr_w_offset) {
w_offset = curr_w_offset;
} else {
h_offset = curr_h_offset;
}
auto ref_ids = gen_img_ids(ref->ne[1], ref->ne[0], patch_size, bs, 1, h_offset, w_offset);
ids = concat_ids(ids, ref_ids, bs);
curr_h_offset = std::max(curr_h_offset, ref->ne[1] + h_offset);
curr_w_offset = std::max(curr_w_offset, ref->ne[0] + w_offset);
}
return ids;
}
// Generate positional embeddings
std::vector<float> gen_pe(int h, int w, int patch_size, int bs, int context_len, std::vector<ggml_tensor*> ref_latents, int theta, const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_ids(h, w, patch_size, bs, context_len, ref_latents);
std::vector<std::vector<float>> trans_ids = transpose(ids);
size_t pos_len = ids.size();
int num_axes = axes_dim.size();
for (int i = 0; i < pos_len; i++) {
// std::cout << trans_ids[0][i] << " " << trans_ids[1][i] << " " << trans_ids[2][i] << std::endl;
}
int emb_dim = 0;
for (int d : axes_dim)
emb_dim += d / 2;
std::vector<std::vector<float>> emb(bs * pos_len, std::vector<float>(emb_dim * 2 * 2, 0.0));
int offset = 0;
for (int i = 0; i < num_axes; ++i) {
std::vector<std::vector<float>> rope_emb = rope(trans_ids[i], axes_dim[i], theta); // [bs*pos_len, axes_dim[i]/2 * 2 * 2]
for (int b = 0; b < bs; ++b) {
for (int j = 0; j < pos_len; ++j) {
for (int k = 0; k < rope_emb[0].size(); ++k) {
emb[b * pos_len + j][offset + k] = rope_emb[j][k];
}
}
}
offset += rope_emb[0].size();
}
return flatten(emb);
}
public:
FluxParams params;
Flux() {}
Flux(FluxParams params)
: params(params) {
int64_t pe_dim = params.hidden_size / params.num_heads;
blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, params.hidden_size, true));
if (params.is_chroma) {
blocks["distilled_guidance_layer"] = std::shared_ptr<GGMLBlock>(new ChromaApproximator(params.in_channels, params.hidden_size));
@@ -1150,7 +983,14 @@
ref_latents[i] = to_backend(ref_latents[i]);
}
- pe_vec = flux.gen_pe(x->ne[1], x->ne[0], 2, x->ne[3], context->ne[1], ref_latents, flux_params.theta, flux_params.axes_dim);
+ pe_vec = Rope::gen_flux_pe(x->ne[1],
x->ne[0],
2,
x->ne[3],
context->ne[1],
ref_latents,
flux_params.theta,
flux_params.axes_dim);
int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;
// LOG_DEBUG("pos_len %d", pos_len);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len);


@@ -663,6 +663,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_slice(struct ggml_context* ctx,
if (dim != 3) {
x = ggml_torch_permute(ctx, x, inv_perm[0], inv_perm[1], inv_perm[2], inv_perm[3]);
x = ggml_cont(ctx, x);
}
return x;
@@ -838,7 +839,6 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_3d(struct ggml_context* ctx,
int64_t N = x->ne[3] / IC;
x = ggml_conv_3d(ctx, w, x, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2);
if (b != NULL) {
b = ggml_reshape_4d(ctx, b, 1, 1, 1, b->ne[0]); // [OC, 1, 1, 1]
x = ggml_add(ctx, x, b);
@@ -1005,7 +1005,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
// LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
// }
// is there anything oddly shaped?? ping Green-Sky if you can trip this assert
- GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0));
+ // GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0));
bool can_use_flash_attn = true;
can_use_flash_attn = can_use_flash_attn && (d_head == 64 ||
@@ -1542,6 +1542,13 @@
virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) = 0;
};
class Identity : public UnaryBlock {
public:
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
return x;
}
};
class Linear : public UnaryBlock {
protected:
int64_t in_features;
@@ -1556,7 +1563,7 @@
}
params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features);
if (bias) {
- enum ggml_type wtype = GGML_TYPE_F32; //(tensor_types.ypes.find(prefix + "bias") != tensor_types.end()) ? tensor_types[prefix + "bias"] : GGML_TYPE_F32;
+ enum ggml_type wtype = GGML_TYPE_F32;
params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_features);
}
}
@@ -1844,6 +1851,30 @@ public:
: GroupNorm(32, num_channels, 1e-06f) {}
};
class RMSNorm : public UnaryBlock {
protected:
int64_t hidden_size;
float eps;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
enum ggml_type wtype = GGML_TYPE_F32;
params["weight"] = ggml_new_tensor_1d(ctx, wtype, hidden_size);
}
public:
RMSNorm(int64_t hidden_size,
float eps = 1e-06f)
: hidden_size(hidden_size),
eps(eps) {}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* w = params["weight"];
x = ggml_rms_norm(ctx, x, eps);
x = ggml_mul(ctx, x, w);
return x;
}
};
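// Usage sketch (illustrative, not part of this commit): ggml_rms_norm divides x by
// sqrt(mean(x^2) + eps) over the innermost dimension, and the ggml_mul above applies
// the learned per-channel weight, so the block computes x * weight / sqrt(mean(x^2) + eps)
// with no bias term. It is consumed like any other UnaryBlock, e.g. in the WAN attention
// blocks further down:
//   blocks["norm_q"] = std::shared_ptr<GGMLBlock>(new RMSNorm(dim, eps));
//   auto norm_q = std::dynamic_pointer_cast<UnaryBlock>(blocks["norm_q"]);
//   q = norm_q->forward(ctx, q);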
class MultiheadAttention : public GGMLBlock {
protected:
int64_t embed_dim;


@@ -142,30 +142,6 @@ public:
}
};
class RMSNorm : public UnaryBlock {
protected:
int64_t hidden_size;
float eps;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
enum ggml_type wtype = GGML_TYPE_F32;
params["weight"] = ggml_new_tensor_1d(ctx, wtype, hidden_size);
}
public:
RMSNorm(int64_t hidden_size,
float eps = 1e-06f)
: hidden_size(hidden_size),
eps(eps) {}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* w = params["weight"];
x = ggml_rms_norm(ctx, x, eps);
x = ggml_mul(ctx, x, w);
return x;
}
};
class SelfAttention : public GGMLBlock {
public:
int64_t num_heads;


@@ -1179,7 +1179,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
if (n_dims == 5) {
n_dims = 4;
- ne[0] = ne[0]*ne[1];
+ ne[0] = ne[0] * ne[1];
ne[1] = ne[2];
ne[2] = ne[3];
ne[3] = ne[4];
@@ -2146,7 +2146,7 @@ bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tenso
std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std::string& tensor_type_rules) {
std::vector<std::pair<std::string, ggml_type>> result;
- for (const auto& item : splitString(tensor_type_rules, ',')) {
+ for (const auto& item : split_string(tensor_type_rules, ',')) {
if (item.size() == 0)
continue;
std::string::size_type pos = item.find('=');

model.h

@@ -31,23 +31,11 @@ enum SDVersion {
VERSION_SD3,
VERSION_FLUX,
VERSION_FLUX_FILL,
VERSION_WAN_2_1,
VERSION_WAN_2_2,
VERSION_COUNT,
};
static inline bool sd_version_is_flux(SDVersion version) {
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL) {
return true;
}
return false;
}
static inline bool sd_version_is_sd3(SDVersion version) {
if (version == VERSION_SD3) {
return true;
}
return false;
}
static inline bool sd_version_is_sd1(SDVersion version) {
if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX) {
return true;
@@ -69,6 +57,27 @@ static inline bool sd_version_is_sdxl(SDVersion version) {
return false;
}
static inline bool sd_version_is_sd3(SDVersion version) {
if (version == VERSION_SD3) {
return true;
}
return false;
}
static inline bool sd_version_is_flux(SDVersion version) {
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL) {
return true;
}
return false;
}
static inline bool sd_version_is_wan(SDVersion version) {
if (version == VERSION_WAN_2_1 || version == VERSION_WAN_2_2) {
return true;
}
return false;
}
static inline bool sd_version_is_inpaint(SDVersion version) {
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL) {
return true;
@@ -77,7 +86,7 @@ static inline bool sd_version_is_inpaint(SDVersion version) {
}
static inline bool sd_version_is_dit(SDVersion version) {
- if (sd_version_is_flux(version) || sd_version_is_sd3(version)) {
+ if (sd_version_is_flux(version) || sd_version_is_sd3(version) || sd_version_is_wan(version)) {
return true;
}
return false;

rope.hpp (new file)

@@ -0,0 +1,252 @@
#ifndef __ROPE_HPP__
#define __ROPE_HPP__
#include <vector>
#include "ggml_extend.hpp"
struct Rope {
template <class T>
static std::vector<T> linspace(T start, T end, int num) {
std::vector<T> result(num);
if (num == 1) {
result[0] = start;
return result;
}
T step = (end - start) / (num - 1);
for (int i = 0; i < num; ++i) {
result[i] = start + i * step;
}
return result;
}
static std::vector<std::vector<float>> transpose(const std::vector<std::vector<float>>& mat) {
int rows = mat.size();
int cols = mat[0].size();
std::vector<std::vector<float>> transposed(cols, std::vector<float>(rows));
for (int i = 0; i < rows; ++i) {
for (int j = 0; j < cols; ++j) {
transposed[j][i] = mat[i][j];
}
}
return transposed;
}
static std::vector<float> flatten(const std::vector<std::vector<float>>& vec) {
std::vector<float> flat_vec;
for (const auto& sub_vec : vec) {
flat_vec.insert(flat_vec.end(), sub_vec.begin(), sub_vec.end());
}
return flat_vec;
}
static std::vector<std::vector<float>> rope(const std::vector<float>& pos, int dim, int theta) {
assert(dim % 2 == 0);
int half_dim = dim / 2;
std::vector<float> scale = linspace(0.f, (dim * 1.f - 2) / dim, half_dim);
std::vector<float> omega(half_dim);
for (int i = 0; i < half_dim; ++i) {
omega[i] = 1.0 / std::pow(theta, scale[i]);
}
int pos_size = pos.size();
std::vector<std::vector<float>> out(pos_size, std::vector<float>(half_dim));
for (int i = 0; i < pos_size; ++i) {
for (int j = 0; j < half_dim; ++j) {
out[i][j] = pos[i] * omega[j];
}
}
std::vector<std::vector<float>> result(pos_size, std::vector<float>(half_dim * 4));
for (int i = 0; i < pos_size; ++i) {
for (int j = 0; j < half_dim; ++j) {
result[i][4 * j] = std::cos(out[i][j]);
result[i][4 * j + 1] = -std::sin(out[i][j]);
result[i][4 * j + 2] = std::sin(out[i][j]);
result[i][4 * j + 3] = std::cos(out[i][j]);
}
}
return result;
}
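// Layout note (illustrative): for each (position, frequency) pair the four values are
// the 2x2 rotation [cos, -sin, sin, cos], so each row holds dim/2 * 4 floats. After
// embed_nd()/flatten() below, the runners copy this buffer into an F32 tensor of shape
// [2, 2, axes_dim_sum/2, pos_len], which the attention path uses to rotate q and k.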
// Generate IDs for image patches and text
static std::vector<std::vector<float>> gen_txt_ids(int bs, int context_len) {
return std::vector<std::vector<float>>(bs * context_len, std::vector<float>(3, 0.0));
}
static std::vector<std::vector<float>> gen_img_ids(int h, int w, int patch_size, int bs, int index = 0, int h_offset = 0, int w_offset = 0) {
int h_len = (h + (patch_size / 2)) / patch_size;
int w_len = (w + (patch_size / 2)) / patch_size;
std::vector<std::vector<float>> img_ids(h_len * w_len, std::vector<float>(3, 0.0));
std::vector<float> row_ids = linspace<float>(h_offset, h_len - 1 + h_offset, h_len);
std::vector<float> col_ids = linspace<float>(w_offset, w_len - 1 + w_offset, w_len);
for (int i = 0; i < h_len; ++i) {
for (int j = 0; j < w_len; ++j) {
img_ids[i * w_len + j][0] = index;
img_ids[i * w_len + j][1] = row_ids[i];
img_ids[i * w_len + j][2] = col_ids[j];
}
}
std::vector<std::vector<float>> img_ids_repeated(bs * img_ids.size(), std::vector<float>(3));
for (int i = 0; i < bs; ++i) {
for (int j = 0; j < img_ids.size(); ++j) {
img_ids_repeated[i * img_ids.size() + j] = img_ids[j];
}
}
return img_ids_repeated;
}
static std::vector<std::vector<float>> concat_ids(const std::vector<std::vector<float>>& a,
const std::vector<std::vector<float>>& b,
int bs) {
size_t a_len = a.size() / bs;
size_t b_len = b.size() / bs;
std::vector<std::vector<float>> ids(a.size() + b.size(), std::vector<float>(3));
for (int i = 0; i < bs; ++i) {
for (int j = 0; j < a_len; ++j) {
ids[i * (a_len + b_len) + j] = a[i * a_len + j];
}
for (int j = 0; j < b_len; ++j) {
ids[i * (a_len + b_len) + a_len + j] = b[i * b_len + j];
}
}
return ids;
}
static std::vector<float> embed_nd(const std::vector<std::vector<float>>& ids,
int bs,
int theta,
const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> trans_ids = transpose(ids);
size_t pos_len = ids.size() / bs;
int num_axes = axes_dim.size();
// for (int i = 0; i < pos_len; i++) {
// std::cout << trans_ids[0][i] << " " << trans_ids[1][i] << " " << trans_ids[2][i] << std::endl;
// }
int emb_dim = 0;
for (int d : axes_dim)
emb_dim += d / 2;
std::vector<std::vector<float>> emb(bs * pos_len, std::vector<float>(emb_dim * 2 * 2, 0.0));
int offset = 0;
for (int i = 0; i < num_axes; ++i) {
std::vector<std::vector<float>> rope_emb = rope(trans_ids[i], axes_dim[i], theta); // [bs*pos_len, axes_dim[i]/2 * 2 * 2]
for (int b = 0; b < bs; ++b) {
for (int j = 0; j < pos_len; ++j) {
for (int k = 0; k < rope_emb[0].size(); ++k) {
emb[b * pos_len + j][offset + k] = rope_emb[j][k];
}
}
}
offset += rope_emb[0].size();
}
return flatten(emb);
}
static std::vector<std::vector<float>> gen_flux_ids(int h,
int w,
int patch_size,
int bs,
int context_len,
std::vector<ggml_tensor*> ref_latents) {
auto txt_ids = gen_txt_ids(bs, context_len);
auto img_ids = gen_img_ids(h, w, patch_size, bs);
auto ids = concat_ids(txt_ids, img_ids, bs);
uint64_t curr_h_offset = 0;
uint64_t curr_w_offset = 0;
for (ggml_tensor* ref : ref_latents) {
uint64_t h_offset = 0;
uint64_t w_offset = 0;
if (ref->ne[1] + curr_h_offset > ref->ne[0] + curr_w_offset) {
w_offset = curr_w_offset;
} else {
h_offset = curr_h_offset;
}
auto ref_ids = gen_img_ids(ref->ne[1], ref->ne[0], patch_size, bs, 1, h_offset, w_offset);
ids = concat_ids(ids, ref_ids, bs);
curr_h_offset = std::max(curr_h_offset, ref->ne[1] + h_offset);
curr_w_offset = std::max(curr_w_offset, ref->ne[0] + w_offset);
}
return ids;
}
// Generate flux positional embeddings
static std::vector<float> gen_flux_pe(int h,
int w,
int patch_size,
int bs,
int context_len,
std::vector<ggml_tensor*> ref_latents,
int theta,
const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_flux_ids(h, w, patch_size, bs, context_len, ref_latents);
return embed_nd(ids, bs, theta, axes_dim);
}
static std::vector<std::vector<float>> gen_vid_ids(int t,
int h,
int w,
int pt,
int ph,
int pw,
int bs,
int t_offset = 0,
int h_offset = 0,
int w_offset = 0) {
int t_len = (t + (pt / 2)) / pt;
int h_len = (h + (ph / 2)) / ph;
int w_len = (w + (pw / 2)) / pw;
std::vector<std::vector<float>> vid_ids(t_len * h_len * w_len, std::vector<float>(3, 0.0));
std::vector<float> t_ids = linspace<float>(t_offset, t_len - 1 + t_offset, t_len);
std::vector<float> h_ids = linspace<float>(h_offset, h_len - 1 + h_offset, h_len);
std::vector<float> w_ids = linspace<float>(w_offset, w_len - 1 + w_offset, w_len);
for (int i = 0; i < t_len; ++i) {
for (int j = 0; j < h_len; ++j) {
for (int k = 0; k < w_len; ++k) {
int idx = i * h_len * w_len + j * w_len + k;
vid_ids[idx][0] = t_ids[i];
vid_ids[idx][1] = h_ids[j];
vid_ids[idx][2] = w_ids[k];
}
}
}
std::vector<std::vector<float>> vid_ids_repeated(bs * vid_ids.size(), std::vector<float>(3));
for (int i = 0; i < bs; ++i) {
for (int j = 0; j < vid_ids.size(); ++j) {
vid_ids_repeated[i * vid_ids.size() + j] = vid_ids[j];
}
}
return vid_ids_repeated;
}
// Generate wan positional embeddings
static std::vector<float> gen_wan_pe(int t,
int h,
int w,
int pt,
int ph,
int pw,
int bs,
int theta,
const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_vid_ids(t, h, w, pt, ph, pw, bs);
return embed_nd(ids, bs, theta, axes_dim);
}
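// Worked example (illustrative, using the commented-out test shape in WanRunner::test,
// x = 104x60x1x16, patch_size = {1, 2, 2}, bs = 1):
//   t_len = (1 + 0) / 1 = 1, h_len = (60 + 1) / 2 = 30, w_len = (104 + 1) / 2 = 52
//   pos_len = t_len * h_len * w_len = 1560
//   pe_vec.size() = bs * pos_len * axes_dim_sum * 2 = 1560 * 128 * 2 = 399360
// which WanRunner::build_graph reshapes into a [2, 2, axes_dim_sum/2, pos_len] F32 tensor
// (pos_len recovered as pe_vec.size() / axes_dim_sum / 2 = 1560).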
}; // struct Rope
#endif  // __ROPE_HPP__


@@ -290,7 +290,7 @@ std::string path_join(const std::string& p1, const std::string& p2) {
return p1 + "/" + p2;
}
- std::vector<std::string> splitString(const std::string& str, char delimiter) {
+ std::vector<std::string> split_string(const std::string& str, char delimiter) {
std::vector<std::string> result;
size_t start = 0;
size_t end = str.find(delimiter);

util.h

@@ -48,7 +48,7 @@ sd_image_f32_t resize_sd_image_f32_t(sd_image_f32_t image, int target_width, int
sd_image_f32_t clip_preprocess(sd_image_f32_t image, int size);
std::string path_join(const std::string& p1, const std::string& p2);
- std::vector<std::string> splitString(const std::string& str, char delimiter);
+ std::vector<std::string> split_string(const std::string& str, char delimiter);
void pretty_progress(int step, int steps, float time);
void log_printf(sd_log_level_t level, const char* file, int line, const char* format, ...);

wan.hpp

@@ -4,11 +4,14 @@
#include <map>
#include "common.hpp"
#include "flux.hpp"
#include "ggml_extend.hpp"
#include "rope.hpp"
namespace WAN {
constexpr int CACHE_T = 2;
constexpr int WAN_GRAPH_SIZE = 10240;
class CausalConv3d : public GGMLBlock {
protected:
@@ -828,6 +831,799 @@ namespace WAN {
}
};
class WanSelfAttention : public GGMLBlock {
public:
int64_t num_heads;
int64_t head_dim;
bool flash_attn;
public:
WanSelfAttention(int64_t dim,
int64_t num_heads,
bool qk_norm = true,
float eps = 1e-6,
bool flash_attn = false)
: num_heads(num_heads), flash_attn(flash_attn) {
head_dim = dim / num_heads;
blocks["q"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
blocks["k"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
blocks["v"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
blocks["o"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
if (qk_norm) {
blocks["norm_q"] = std::shared_ptr<GGMLBlock>(new RMSNorm(dim, eps));
blocks["norm_k"] = std::shared_ptr<GGMLBlock>(new RMSNorm(dim, eps));
} else {
blocks["norm_q"] = std::shared_ptr<GGMLBlock>(new Identity());
blocks["norm_k"] = std::shared_ptr<GGMLBlock>(new Identity());
}
}
virtual struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* pe,
struct ggml_tensor* mask = NULL) {
// x: [N, n_token, dim]
// pe: [n_token, d_head/2, 2, 2]
// return [N, n_token, dim]
int64_t N = x->ne[2];
int64_t n_token = x->ne[1];
auto q_proj = std::dynamic_pointer_cast<Linear>(blocks["q"]);
auto k_proj = std::dynamic_pointer_cast<Linear>(blocks["k"]);
auto v_proj = std::dynamic_pointer_cast<Linear>(blocks["v"]);
auto o_proj = std::dynamic_pointer_cast<Linear>(blocks["o"]);
auto norm_q = std::dynamic_pointer_cast<UnaryBlock>(blocks["norm_q"]);
auto norm_k = std::dynamic_pointer_cast<UnaryBlock>(blocks["norm_k"]);
auto q = q_proj->forward(ctx, x);
q = norm_q->forward(ctx, q);
auto k = k_proj->forward(ctx, x);
k = norm_k->forward(ctx, k);
auto v = v_proj->forward(ctx, x); // [N, n_token, n_head*d_head]
q = ggml_reshape_4d(ctx, q, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head]
k = ggml_reshape_4d(ctx, k, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head]
v = ggml_reshape_4d(ctx, v, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head]
x = Flux::attention(ctx, q, k, v, pe, mask, flash_attn); // [N, n_token, dim]
x = o_proj->forward(ctx, x); // [N, n_token, dim]
return x;
}
};
class WanCrossAttention : public WanSelfAttention {
public:
WanCrossAttention(int64_t dim,
int64_t num_heads,
bool qk_norm = true,
float eps = 1e-6,
bool flash_attn = false)
: WanSelfAttention(dim, num_heads, qk_norm, eps, flash_attn) {}
virtual struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* context,
int64_t context_img_len) = 0;
};
class WanT2VCrossAttention : public WanCrossAttention {
public:
WanT2VCrossAttention(int64_t dim,
int64_t num_heads,
bool qk_norm = true,
float eps = 1e-6,
bool flash_attn = false)
: WanCrossAttention(dim, num_heads, qk_norm, eps, flash_attn) {}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* context,
int64_t context_img_len) {
// x: [N, n_token, dim]
// context: [N, n_context, dim]
// context_img_len: unused
// return [N, n_token, dim]
int64_t N = x->ne[2];
int64_t n_token = x->ne[1];
auto q_proj = std::dynamic_pointer_cast<Linear>(blocks["q"]);
auto k_proj = std::dynamic_pointer_cast<Linear>(blocks["k"]);
auto v_proj = std::dynamic_pointer_cast<Linear>(blocks["v"]);
auto o_proj = std::dynamic_pointer_cast<Linear>(blocks["o"]);
auto norm_q = std::dynamic_pointer_cast<UnaryBlock>(blocks["norm_q"]);
auto norm_k = std::dynamic_pointer_cast<UnaryBlock>(blocks["norm_k"]);
auto q = q_proj->forward(ctx, x);
q = norm_q->forward(ctx, q);
auto k = k_proj->forward(ctx, context); // [N, n_context, dim]
k = norm_k->forward(ctx, k);
auto v = v_proj->forward(ctx, context); // [N, n_context, dim]
x = ggml_nn_attention_ext(ctx, q, k, v, num_heads); // [N, n_token, dim]
x = o_proj->forward(ctx, x); // [N, n_token, dim]
return x;
}
};
class WanI2VCrossAttention : public WanCrossAttention {
public:
WanI2VCrossAttention(int64_t dim,
int64_t num_heads,
bool qk_norm = true,
float eps = 1e-6,
bool flash_attn = false)
: WanCrossAttention(dim, num_heads, qk_norm, eps, flash_attn) {
blocks["k_img"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
blocks["v_img"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
if (qk_norm) {
blocks["norm_k_img"] = std::shared_ptr<GGMLBlock>(new RMSNorm(dim, eps));
}
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* context,
int64_t context_img_len) {
// x: [N, n_token, dim]
// context: [N, context_img_len + context_txt_len, dim]
// return [N, n_token, dim]
auto q_proj = std::dynamic_pointer_cast<Linear>(blocks["q"]);
auto k_proj = std::dynamic_pointer_cast<Linear>(blocks["k"]);
auto v_proj = std::dynamic_pointer_cast<Linear>(blocks["v"]);
auto o_proj = std::dynamic_pointer_cast<Linear>(blocks["o"]);
auto k_img_proj = std::dynamic_pointer_cast<Linear>(blocks["k_img"]);
auto v_img_proj = std::dynamic_pointer_cast<Linear>(blocks["v_img"]);
auto norm_q = std::dynamic_pointer_cast<UnaryBlock>(blocks["norm_q"]);
auto norm_k = std::dynamic_pointer_cast<UnaryBlock>(blocks["norm_k"]);
auto norm_k_img = std::dynamic_pointer_cast<UnaryBlock>(blocks["norm_k_img"]);
int64_t N = x->ne[2];
int64_t n_token = x->ne[1];
int64_t dim = x->ne[0];
int64_t context_txt_len = context->ne[1] - context_img_len;
context = ggml_cont(ctx, ggml_torch_permute(ctx, context, 0, 2, 1, 3)); // [context_img_len + context_txt_len, N, dim]
auto context_img = ggml_view_3d(ctx, context, dim, N, context_img_len, context->nb[1], context->nb[2], 0);
auto context_txt = ggml_view_3d(ctx, context, dim, N, context_txt_len, context->nb[1], context->nb[2], context_img_len * context->nb[2]);
context_img = ggml_cont(ctx, ggml_torch_permute(ctx, context_img, 0, 2, 1, 3)); // [N, context_img_len, dim]
context_txt = ggml_cont(ctx, ggml_torch_permute(ctx, context_txt, 0, 2, 1, 3)); // [N, context_txt_len, dim]
auto q = q_proj->forward(ctx, x);
q = norm_q->forward(ctx, q);
auto k = k_proj->forward(ctx, context_txt); // [N, context_txt_len, dim]
k = norm_k->forward(ctx, k);
auto v = v_proj->forward(ctx, context_txt); // [N, context_txt_len, dim]
auto k_img = k_img_proj->forward(ctx, context_img); // [N, context_img_len, dim]
k_img = norm_k_img->forward(ctx, k_img);
auto v_img = v_img_proj->forward(ctx, context_img); // [N, context_img_len, dim]
auto img_x = ggml_nn_attention_ext(ctx, q, k_img, v_img, num_heads); // [N, n_token, dim]
x = ggml_nn_attention_ext(ctx, q, k, v, num_heads); // [N, n_token, dim]
x = ggml_add(ctx, x, img_x);
x = o_proj->forward(ctx, x); // [N, n_token, dim]
return x;
}
};
class WanAttentionBlock : public GGMLBlock {
protected:
int dim;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
params["modulation"] = ggml_new_tensor_3d(ctx, wtype, dim, 6, 1);
}
public:
WanAttentionBlock(bool t2v_cross_attn,
int64_t dim,
int64_t ffn_dim,
int64_t num_heads,
bool qk_norm = true,
bool cross_attn_norm = false,
float eps = 1e-6,
bool flash_attn = false)
: dim(dim) {
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
blocks["self_attn"] = std::shared_ptr<GGMLBlock>(new WanSelfAttention(dim, num_heads, qk_norm, eps));
if (cross_attn_norm) {
blocks["norm3"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, true));
} else {
blocks["norm3"] = std::shared_ptr<GGMLBlock>(new Identity());
}
if (t2v_cross_attn) {
blocks["cross_attn"] = std::shared_ptr<GGMLBlock>(new WanT2VCrossAttention(dim, num_heads, qk_norm, eps));
} else {
blocks["cross_attn"] = std::shared_ptr<GGMLBlock>(new WanI2VCrossAttention(dim, num_heads, qk_norm, eps));
}
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
blocks["ffn.0"] = std::shared_ptr<GGMLBlock>(new Linear(dim, ffn_dim));
// ffn.1 is nn.GELU(approximate='tanh')
blocks["ffn.2"] = std::shared_ptr<GGMLBlock>(new Linear(ffn_dim, dim));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* e,
struct ggml_tensor* pe,
struct ggml_tensor* context,
int64_t context_img_len = 257) {
// x: [N, n_token, dim]
// e: [N, 6, dim]
// context: [N, context_img_len + context_txt_len, dim]
// return [N, n_token, dim]
auto modulation = params["modulation"];
e = ggml_add(ctx, modulation, e); // [N, 6, dim]
auto es = ggml_chunk(ctx, e, 6, 1); // ([N, 1, dim], ...)
auto norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]);
auto self_attn = std::dynamic_pointer_cast<WanSelfAttention>(blocks["self_attn"]);
auto norm3 = std::dynamic_pointer_cast<UnaryBlock>(blocks["norm3"]);
auto cross_attn = std::dynamic_pointer_cast<WanCrossAttention>(blocks["cross_attn"]);
auto norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm2"]);
auto ffn_0 = std::dynamic_pointer_cast<Linear>(blocks["ffn.0"]);
auto ffn_2 = std::dynamic_pointer_cast<Linear>(blocks["ffn.2"]);
// self-attention
auto y = norm1->forward(ctx, x);
y = ggml_add(ctx, y, ggml_mul(ctx, y, es[1]));
y = ggml_add(ctx, y, es[0]);
y = self_attn->forward(ctx, y, pe);
x = ggml_add(ctx, x, ggml_mul(ctx, y, es[2]));
// cross-attention
x = ggml_add(ctx,
x,
cross_attn->forward(ctx, norm3->forward(ctx, x), context, context_img_len));
// ffn
y = norm2->forward(ctx, x);
y = ggml_add(ctx, y, ggml_mul(ctx, y, es[4]));
y = ggml_add(ctx, y, es[3]);
y = ffn_0->forward(ctx, y);
y = ggml_gelu_inplace(ctx, y);
y = ffn_2->forward(ctx, y);
x = ggml_add(ctx, x, ggml_mul(ctx, y, es[5]));
return x;
}
};
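// Note (illustrative): the six modulation chunks follow the usual adaLN pattern.
// es[0]/es[1] shift and scale the norm1 output and es[2] gates the self-attention
// residual, while es[3]/es[4]/es[5] do the same for the FFN branch
// (y = norm(x) * (1 + scale) + shift, then x = x + gate * branch(y)).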
class Head : public GGMLBlock {
protected:
int dim;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
params["modulation"] = ggml_new_tensor_3d(ctx, wtype, dim, 2, 1);
}
public:
Head(int64_t dim,
int64_t out_dim,
std::tuple<int, int, int> patch_size,
float eps = 1e-6)
: dim(dim) {
out_dim = out_dim * std::get<0>(patch_size) * std::get<1>(patch_size) * std::get<2>(patch_size);
blocks["norm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
blocks["head"] = std::shared_ptr<GGMLBlock>(new Linear(dim, out_dim));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* e) {
// x: [N, n_token, dim]
// e: [N, dim]
// return [N, n_token, out_dim]
auto modulation = params["modulation"];
e = ggml_add(ctx, modulation, ggml_reshape_3d(ctx, e, e->ne[0], 1, e->ne[1])); // [N, 2, dim]
auto es = ggml_chunk(ctx, e, 2, 1); // ([N, 1, dim], ...)
auto norm = std::dynamic_pointer_cast<LayerNorm>(blocks["norm"]);
auto head = std::dynamic_pointer_cast<Linear>(blocks["head"]);
x = norm->forward(ctx, x);
x = ggml_add(ctx, x, ggml_mul(ctx, x, es[1]));
x = ggml_add(ctx, x, es[0]);
x = head->forward(ctx, x);
return x;
}
};
class MLPProj : public GGMLBlock {
protected:
int in_dim;
int flf_pos_embed_token_number;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
if (flf_pos_embed_token_number > 0) {
params["emb_pos"] = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, in_dim, flf_pos_embed_token_number, 1);
}
}
public:
MLPProj(int64_t in_dim,
int64_t out_dim,
int64_t flf_pos_embed_token_number = 0)
: in_dim(in_dim), flf_pos_embed_token_number(flf_pos_embed_token_number) {
blocks["proj.0"] = std::shared_ptr<GGMLBlock>(new LayerNorm(in_dim));
blocks["proj.1"] = std::shared_ptr<GGMLBlock>(new Linear(in_dim, in_dim));
// proj.2 is nn.GELU()
blocks["proj.3"] = std::shared_ptr<GGMLBlock>(new Linear(in_dim, out_dim));
blocks["proj.4"] = std::shared_ptr<GGMLBlock>(new LayerNorm(out_dim));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* image_embeds) {
if (flf_pos_embed_token_number > 0) {
auto emb_pos = params["emb_pos"];
auto a = ggml_slice(ctx, image_embeds, 1, 0, emb_pos->ne[1]);
auto b = ggml_slice(ctx, emb_pos, 1, 0, image_embeds->ne[1]);
image_embeds = ggml_add(ctx, a, b);
}
auto proj_0 = std::dynamic_pointer_cast<LayerNorm>(blocks["proj.0"]);
auto proj_1 = std::dynamic_pointer_cast<Linear>(blocks["proj.1"]);
auto proj_3 = std::dynamic_pointer_cast<Linear>(blocks["proj.3"]);
auto proj_4 = std::dynamic_pointer_cast<LayerNorm>(blocks["proj.4"]);
auto x = proj_0->forward(ctx, image_embeds);
x = proj_1->forward(ctx, x);
x = ggml_gelu_inplace(ctx, x);
x = proj_3->forward(ctx, x);
x = proj_4->forward(ctx, x);
return x; // clip_extra_context_tokens
}
};
struct WanParams {
std::string model_type = "t2v";
std::tuple<int, int, int> patch_size = {1, 2, 2};
int64_t text_len = 512;
int64_t in_dim = 16;
int64_t dim = 2048;
int64_t ffn_dim = 8192;
int64_t freq_dim = 256;
int64_t text_dim = 4096;
int64_t out_dim = 16;
int64_t num_heads = 16;
int64_t num_layers = 32;
bool qk_norm = true;
bool cross_attn_norm = true;
float eps = 1e-6;
int64_t flf_pos_embed_token_number = 0;
int theta = 10000;
// wan2.1 1.3B: 1536/12, wan2.1/2.2 14B: 5120/40, wan2.2 5B: 3072/24
std::vector<int> axes_dim = {44, 42, 42};
int64_t axes_dim_sum = 128;
};
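// Note (illustrative): the defaults above are effectively placeholders; WanRunner
// below overrides them from the tensor names, e.g. num_layers == 30 selects the
// 1.3B config (dim 1536, ffn_dim 8960, 12 heads) and num_layers == 40 selects the
// 14B config (dim 5120, ffn_dim 13824, 40 heads).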
class WanModel : public GGMLBlock {
protected:
WanParams params;
public:
WanModel() {}
WanModel(WanParams params)
: params(params) {
// patch_embedding
blocks["patch_embedding"] = std::shared_ptr<GGMLBlock>(new Conv3d(params.in_dim, params.dim, params.patch_size, params.patch_size));
// text_embedding
blocks["text_embedding.0"] = std::shared_ptr<GGMLBlock>(new Linear(params.text_dim, params.dim));
// text_embedding.1 is nn.GELU()
blocks["text_embedding.2"] = std::shared_ptr<GGMLBlock>(new Linear(params.dim, params.dim));
// time_embedding
blocks["time_embedding.0"] = std::shared_ptr<GGMLBlock>(new Linear(params.freq_dim, params.dim));
// time_embedding.1 is nn.SiLU()
blocks["time_embedding.2"] = std::shared_ptr<GGMLBlock>(new Linear(params.dim, params.dim));
// time_projection.0 is nn.SiLU()
blocks["time_projection.1"] = std::shared_ptr<GGMLBlock>(new Linear(params.dim, params.dim * 6));
// blocks
for (int i = 0; i < params.num_layers; i++) {
auto block = std::shared_ptr<GGMLBlock>(new WanAttentionBlock(params.model_type == "t2v",
params.dim,
params.ffn_dim,
params.num_heads,
params.qk_norm,
params.cross_attn_norm,
params.eps));
blocks["blocks." + std::to_string(i)] = block;
}
// head
blocks["head"] = std::shared_ptr<GGMLBlock>(new Head(params.dim, params.out_dim, params.patch_size, params.eps));
// img_emb
if (params.model_type == "i2v") {
blocks["img_emb"] = std::shared_ptr<GGMLBlock>(new MLPProj(1280, params.dim, params.flf_pos_embed_token_number));
}
}
struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx,
struct ggml_tensor* x) {
int64_t W = x->ne[0];
int64_t H = x->ne[1];
int64_t T = x->ne[2];
int pad_t = (std::get<0>(params.patch_size) - T % std::get<0>(params.patch_size)) % std::get<0>(params.patch_size);
int pad_h = (std::get<1>(params.patch_size) - H % std::get<1>(params.patch_size)) % std::get<1>(params.patch_size);
int pad_w = (std::get<2>(params.patch_size) - W % std::get<2>(params.patch_size)) % std::get<2>(params.patch_size);
x = ggml_pad(ctx, x, pad_w, pad_h, pad_t, 0); // [N*C, T + pad_t, H + pad_h, W + pad_w]
return x;
}
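// Example (illustrative): with patch_size {1, 2, 2}, a 1x61x105 (T x H x W) latent gets
// pad_t = 0, pad_h = 1, pad_w = 1, i.e. ggml_pad zero-extends the high end of each axis
// to 1x62x106 so every dimension is divisible by its patch size.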
struct ggml_tensor* unpatchify(struct ggml_context* ctx,
struct ggml_tensor* x,
int64_t t_len,
int64_t h_len,
int64_t w_len) {
// x: [N, t_len*h_len*w_len, pt*ph*pw*C]
// return: [N*C, t_len*pt, h_len*ph, w_len*pw]
int64_t N = x->ne[3];
int64_t pt = std::get<0>(params.patch_size);
int64_t ph = std::get<1>(params.patch_size);
int64_t pw = std::get<2>(params.patch_size);
int64_t C = x->ne[0] / pt / ph / pw;
GGML_ASSERT(C * pt * ph * pw == x->ne[0]);
x = ggml_reshape_4d(ctx, x, C, pw * ph * pt, w_len * h_len * t_len, N); // [N, t_len*h_len*w_len, pt*ph*pw, C]
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 1, 2, 0, 3)); // [N, C, t_len*h_len*w_len, pt*ph*pw]
x = ggml_reshape_4d(ctx, x, pw, ph * pt, w_len, h_len * t_len * C * N); // [N*C*t_len*h_len, w_len, pt*ph, pw]
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len*h_len, pt*ph, w_len, pw]
x = ggml_reshape_4d(ctx, x, pw * w_len, ph, pt, h_len * t_len * C * N); // [N*C*t_len*h_len, pt, ph, w_len*pw]
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len*h_len, ph, pt, w_len*pw]
x = ggml_reshape_4d(ctx, x, pw * w_len, pt, ph * h_len, t_len * C * N); // [N*C*t_len, h_len*ph, pt, w_len*pw]
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len, pt, h_len*ph, w_len*pw]
x = ggml_reshape_4d(ctx, x, pw * w_len, ph * h_len, pt * t_len, C * N); // [N*C, t_len*pt, h_len*ph, w_len*pw]
return x;
}
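// Net effect (illustrative): the inverse of the patch_embedding patchify step; each
// token's pt*ph*pw*C values are scattered back into a dense
// [N*C, t_len*pt, h_len*ph, w_len*pw] grid by the reshape/permute chain above.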
struct ggml_tensor* forward_orig(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* timestep,
struct ggml_tensor* context,
struct ggml_tensor* pe,
struct ggml_tensor* clip_fea = NULL,
int64_t N = 1) {
// x: [N*C, T, H, W], C => in_dim
// timestep: [N,]
// context: [N, L, text_dim]
// return: [N, t_len*h_len*w_len, out_dim*pt*ph*pw]
auto patch_embedding = std::dynamic_pointer_cast<Conv3d>(blocks["patch_embedding"]);
auto text_embedding_0 = std::dynamic_pointer_cast<Linear>(blocks["text_embedding.0"]);
auto text_embedding_2 = std::dynamic_pointer_cast<Linear>(blocks["text_embedding.2"]);
auto time_embedding_0 = std::dynamic_pointer_cast<Linear>(blocks["time_embedding.0"]);
auto time_embedding_2 = std::dynamic_pointer_cast<Linear>(blocks["time_embedding.2"]);
auto time_projection_1 = std::dynamic_pointer_cast<Linear>(blocks["time_projection.1"]);
auto head = std::dynamic_pointer_cast<Head>(blocks["head"]);
// patch_embedding
x = patch_embedding->forward(ctx, x); // [N*dim, t_len, h_len, w_len]
x = ggml_reshape_3d(ctx, x, x->ne[0] * x->ne[1] * x->ne[2], x->ne[3] / N, N); // [N, dim, t_len*h_len*w_len]
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 1, 0, 2, 3)); // [N, t_len*h_len*w_len, dim]
// time_embedding
auto e = ggml_nn_timestep_embedding(ctx, timestep, params.freq_dim);
e = time_embedding_0->forward(ctx, e);
e = ggml_silu_inplace(ctx, e);
e = time_embedding_2->forward(ctx, e); // [N, dim]
// time_projection
auto e0 = ggml_silu(ctx, e);
e0 = time_projection_1->forward(ctx, e0);
e0 = ggml_reshape_3d(ctx, e0, e0->ne[0] / 6, 6, e0->ne[1]); // [N, 6, dim]
context = text_embedding_0->forward(ctx, context);
context = ggml_gelu(ctx, context);
context = text_embedding_2->forward(ctx, context); // [N, context_txt_len, dim]
int64_t context_img_len = 0;
if (clip_fea != NULL) {
if (params.model_type == "i2v") {
auto img_emb = std::dynamic_pointer_cast<MLPProj>(blocks["img_emb"]);
auto context_img = img_emb->forward(ctx, clip_fea); // [N, context_img_len, dim]
context = ggml_concat(ctx, context_img, context, 1); // [N, context_img_len + context_txt_len, dim]
}
context_img_len = clip_fea->ne[1]; // 257
}
for (int i = 0; i < params.num_layers; i++) {
auto block = std::dynamic_pointer_cast<WanAttentionBlock>(blocks["blocks." + std::to_string(i)]);
x = block->forward(ctx, x, e0, pe, context, context_img_len);
}
x = head->forward(ctx, x, e); // [N, t_len*h_len*w_len, pt*ph*pw*out_dim]
return x;
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* timestep,
struct ggml_tensor* context,
struct ggml_tensor* pe,
struct ggml_tensor* clip_fea = NULL,
struct ggml_tensor* time_dim_concat = NULL,
int64_t N = 1) {
// Forward pass of DiT.
// x: [N*C, T, H, W]
// timestep: [N,]
// context: [N, L, D]
// pe: [L, d_head/2, 2, 2]
// time_dim_concat: [N*C, T2, H, W]
// return: [N*C, T, H, W]
GGML_ASSERT(N == 1);
int64_t W = x->ne[0];
int64_t H = x->ne[1];
int64_t T = x->ne[2];
int64_t C = x->ne[3];
x = pad_to_patch_size(ctx, x);
int64_t t_len = ((T + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size));
int64_t h_len = ((H + (std::get<1>(params.patch_size) / 2)) / std::get<1>(params.patch_size));
int64_t w_len = ((W + (std::get<2>(params.patch_size) / 2)) / std::get<2>(params.patch_size));
if (time_dim_concat != NULL) {
time_dim_concat = pad_to_patch_size(ctx, time_dim_concat);
x = ggml_concat(ctx, x, time_dim_concat, 2); // [N*C, (T+pad_t) + (T2+pad_t2), H + pad_h, W + pad_w]
t_len = ((x->ne[2] + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size));
}
auto out = forward_orig(ctx, x, timestep, context, pe, clip_fea, N); // [N, t_len*h_len*w_len, pt*ph*pw*C]
out = unpatchify(ctx, out, t_len, h_len, w_len); // [N*C, (T+pad_t) + (T2+pad_t2), H + pad_h, W + pad_w]
// slice
out = ggml_slice(ctx, out, 2, 0, T); // [N*C, T, H + pad_h, W + pad_w]
out = ggml_slice(ctx, out, 1, 0, H); // [N*C, T, H, W + pad_w]
out = ggml_slice(ctx, out, 0, 0, W); // [N*C, T, H, W]
return out;
}
};
struct WanRunner : public GGMLRunner {
public:
WanParams wan_params;
WanModel wan;
std::vector<float> pe_vec;
SDVersion version;
WanRunner(ggml_backend_t backend,
const String2GGMLType& tensor_types = {},
const std::string prefix = "",
SDVersion version = VERSION_WAN_2_1)
: GGMLRunner(backend) {
wan_params.num_layers = 0;
for (auto pair : tensor_types) {
std::string tensor_name = pair.first;
if (tensor_name.find("model.diffusion_model.") == std::string::npos)
continue;
size_t pos = tensor_name.find("blocks.");
if (pos != std::string::npos) {
tensor_name = tensor_name.substr(pos); // remove prefix
auto items = split_string(tensor_name, '.');
if (items.size() > 1) {
int block_index = atoi(items[1].c_str());
if (block_index + 1 > wan_params.num_layers) {
wan_params.num_layers = block_index + 1;
}
}
}
if (tensor_name.find("img_emb") != std::string::npos) {
wan_params.model_type = "i2v";
}
}
if (wan_params.num_layers == 30) {
LOG_INFO("Wan2.1-T2V-1.3B");
wan_params.dim = 1536;
wan_params.eps = 1e-06;
wan_params.ffn_dim = 8960;
wan_params.freq_dim = 256;
wan_params.in_dim = 16;
wan_params.num_heads = 12;
wan_params.out_dim = 16;
wan_params.text_len = 512;
} else if (wan_params.num_layers == 40) {
if (wan_params.model_type == "t2v") {
LOG_INFO("Wan2.1-T2V-14B");
} else {
LOG_INFO("Wan2.1-I2V-14B");
}
wan_params.dim = 5120;
wan_params.eps = 1e-06;
wan_params.ffn_dim = 13824;
wan_params.freq_dim = 256;
wan_params.in_dim = 16;
wan_params.num_heads = 40;
wan_params.out_dim = 16;
wan_params.text_len = 512;
} else {
GGML_ABORT("invalid num_layers(%d) of wan", wan_params.num_layers);
}
wan = WanModel(wan_params);
wan.init(params_ctx, tensor_types, prefix);
}
std::string get_desc() {
return "wan";
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
wan.get_param_tensors(tensors, prefix);
}
struct ggml_cgraph* build_graph(struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
struct ggml_tensor* clip_fea = NULL,
struct ggml_tensor* time_dim_concat = NULL) {
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, WAN_GRAPH_SIZE, false);
x = to_backend(x);
timesteps = to_backend(timesteps);
context = to_backend(context);
clip_fea = to_backend(clip_fea);
time_dim_concat = to_backend(time_dim_concat);
pe_vec = Rope::gen_wan_pe(x->ne[2],
x->ne[1],
x->ne[0],
std::get<0>(wan_params.patch_size),
std::get<1>(wan_params.patch_size),
std::get<2>(wan_params.patch_size),
1,
wan_params.theta,
wan_params.axes_dim);
int pos_len = pe_vec.size() / wan_params.axes_dim_sum / 2;
// LOG_DEBUG("pos_len %d", pos_len);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, wan_params.axes_dim_sum / 2, pos_len);
// pe->data = pe_vec.data();
// print_ggml_tensor(pe);
// pe->data = NULL;
set_backend_tensor_data(pe, pe_vec.data());
struct ggml_tensor* out = wan.forward(compute_ctx,
x,
timesteps,
context,
pe,
clip_fea,
time_dim_concat);
ggml_build_forward_expand(gf, out);
return gf;
}
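// Note (assumption): pe_vec is kept as a member rather than a local presumably so the
// host buffer handed to set_backend_tensor_data above stays valid until the graph is
// actually evaluated by GGMLRunner::compute.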
void compute(int n_threads,
struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
struct ggml_tensor* clip_fea = NULL,
struct ggml_tensor* time_dim_concat = NULL,
struct ggml_tensor** output = NULL,
struct ggml_context* output_ctx = NULL) {
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(x, timesteps, context, clip_fea, time_dim_concat);
};
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
}
void test() {
struct ggml_init_params params;
params.mem_size = static_cast<size_t>(20 * 1024 * 1024); // 20 MB
params.mem_buffer = NULL;
params.no_alloc = false;
struct ggml_context* work_ctx = ggml_init(params);
GGML_ASSERT(work_ctx != NULL);
{
// cpu f16: pass
// cuda f16: pass
// cpu q8_0: pass
// auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 104, 60, 1, 16);
// ggml_set_f32(x, 0.01f);
auto x = load_tensor_from_file(work_ctx, "wan_dit_x.bin");
// print_ggml_tensor(x);
std::vector<float> timesteps_vec(1, 999.f);
auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);
// auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 4096, 512, 1);
// ggml_set_f32(context, 0.01f);
auto context = load_tensor_from_file(work_ctx, "wan_dit_context.bin");
// print_ggml_tensor(context);
struct ggml_tensor* out = NULL;
int t0 = ggml_time_ms();
compute(8, x, timesteps, context, NULL, NULL, &out, work_ctx);
int t1 = ggml_time_ms();
print_ggml_tensor(out);
LOG_DEBUG("wan test done in %dms", t1 - t0);
}
}
static void load_from_file_and_test(const std::string& file_path) {
ggml_backend_t backend = ggml_backend_cuda_init(0);
// ggml_backend_t backend = ggml_backend_cpu_init();
ggml_type model_data_type = GGML_TYPE_Q8_0;
LOG_INFO("loading from '%s'", file_path.c_str());
ModelLoader model_loader;
if (!model_loader.init_from_file(file_path, "model.diffusion_model.")) {
LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str());
return;
}
auto tensor_types = model_loader.tensor_storages_types;
for (auto& item : tensor_types) {
LOG_DEBUG("%s %u", item.first.c_str(), item.second);
if (ends_with(item.first, "weight")) {
item.second = model_data_type;
}
}
std::shared_ptr<WanRunner> wan = std::shared_ptr<WanRunner>(new WanRunner(backend,
tensor_types,
"model.diffusion_model"));
wan->alloc_params_buffer();
std::map<std::string, ggml_tensor*> tensors;
wan->get_param_tensors(tensors, "model.diffusion_model");
bool success = model_loader.load_tensors(tensors, backend);
if (!success) {
LOG_ERROR("load tensors from model loader failed");
return;
}
LOG_INFO("wan model loaded");
wan->test();
}
};
} // namespace WAN
#endif // __WAN_HPP__