mirror of https://github.com/leejet/stable-diffusion.cpp.git (synced 2025-12-13 05:48:56 +00:00)

commit 5f7d98884c (parent e3f9366857)

    add wan model support
@@ -746,11 +746,11 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
 
 int main(int argc, const char* argv[]) {
     SDParams params;
-    // params.verbose = true;
-    // sd_set_log_callback(sd_log_cb, (void*)&params);
+    params.verbose = true;
+    sd_set_log_callback(sd_log_cb, (void*)&params);
 
-    // WAN::WanVAERunner::load_from_file_and_test(argv[1]);
-    // return 0;
+    WAN::WanRunner::load_from_file_and_test(argv[1]);
+    return 0;
 
     parse_args(argc, argv, params);
 
flux.hpp | 178

@@ -5,6 +5,7 @@
 
 #include "ggml_extend.hpp"
 #include "model.h"
+#include "rope.hpp"
 
 #define FLUX_GRAPH_SIZE 10240
 
@@ -610,179 +611,11 @@ namespace Flux {
     };
 
     struct Flux : public GGMLBlock {
-    public:
-        std::vector<float> linspace(float start, float end, int num) {
-            std::vector<float> result(num);
-            float step = (end - start) / (num - 1);
-            for (int i = 0; i < num; ++i) {
-                result[i] = start + i * step;
-            }
-            return result;
-        }
-
-        std::vector<std::vector<float>> transpose(const std::vector<std::vector<float>>& mat) {
-            int rows = mat.size();
-            int cols = mat[0].size();
-            std::vector<std::vector<float>> transposed(cols, std::vector<float>(rows));
-            for (int i = 0; i < rows; ++i) {
-                for (int j = 0; j < cols; ++j) {
-                    transposed[j][i] = mat[i][j];
-                }
-            }
-            return transposed;
-        }
-
-        std::vector<float> flatten(const std::vector<std::vector<float>>& vec) {
-            std::vector<float> flat_vec;
-            for (const auto& sub_vec : vec) {
-                flat_vec.insert(flat_vec.end(), sub_vec.begin(), sub_vec.end());
-            }
-            return flat_vec;
-        }
-
-        std::vector<std::vector<float>> rope(const std::vector<float>& pos, int dim, int theta) {
-            assert(dim % 2 == 0);
-            int half_dim = dim / 2;
-
-            std::vector<float> scale = linspace(0, (dim * 1.0f - 2) / dim, half_dim);
-
-            std::vector<float> omega(half_dim);
-            for (int i = 0; i < half_dim; ++i) {
-                omega[i] = 1.0 / std::pow(theta, scale[i]);
-            }
-
-            int pos_size = pos.size();
-            std::vector<std::vector<float>> out(pos_size, std::vector<float>(half_dim));
-            for (int i = 0; i < pos_size; ++i) {
-                for (int j = 0; j < half_dim; ++j) {
-                    out[i][j] = pos[i] * omega[j];
-                }
-            }
-
-            std::vector<std::vector<float>> result(pos_size, std::vector<float>(half_dim * 4));
-            for (int i = 0; i < pos_size; ++i) {
-                for (int j = 0; j < half_dim; ++j) {
-                    result[i][4 * j]     = std::cos(out[i][j]);
-                    result[i][4 * j + 1] = -std::sin(out[i][j]);
-                    result[i][4 * j + 2] = std::sin(out[i][j]);
-                    result[i][4 * j + 3] = std::cos(out[i][j]);
-                }
-            }
-
-            return result;
-        }
-
-        // Generate IDs for image patches and text
-        std::vector<std::vector<float>> gen_txt_ids(int bs, int context_len) {
-            return std::vector<std::vector<float>>(bs * context_len, std::vector<float>(3, 0.0));
-        }
-
-        std::vector<std::vector<float>> gen_img_ids(int h, int w, int patch_size, int bs, int index = 0, int h_offset = 0, int w_offset = 0) {
-            int h_len = (h + (patch_size / 2)) / patch_size;
-            int w_len = (w + (patch_size / 2)) / patch_size;
-
-            std::vector<std::vector<float>> img_ids(h_len * w_len, std::vector<float>(3, 0.0));
-
-            std::vector<float> row_ids = linspace(h_offset, h_len - 1 + h_offset, h_len);
-            std::vector<float> col_ids = linspace(w_offset, w_len - 1 + w_offset, w_len);
-
-            for (int i = 0; i < h_len; ++i) {
-                for (int j = 0; j < w_len; ++j) {
-                    img_ids[i * w_len + j][0] = index;
-                    img_ids[i * w_len + j][1] = row_ids[i];
-                    img_ids[i * w_len + j][2] = col_ids[j];
-                }
-            }
-
-            std::vector<std::vector<float>> img_ids_repeated(bs * img_ids.size(), std::vector<float>(3));
-            for (int i = 0; i < bs; ++i) {
-                for (int j = 0; j < img_ids.size(); ++j) {
-                    img_ids_repeated[i * img_ids.size() + j] = img_ids[j];
-                }
-            }
-            return img_ids_repeated;
-        }
-
-        std::vector<std::vector<float>> concat_ids(const std::vector<std::vector<float>>& a,
-                                                   const std::vector<std::vector<float>>& b,
-                                                   int bs) {
-            size_t a_len = a.size() / bs;
-            size_t b_len = b.size() / bs;
-            std::vector<std::vector<float>> ids(a.size() + b.size(), std::vector<float>(3));
-            for (int i = 0; i < bs; ++i) {
-                for (int j = 0; j < a_len; ++j) {
-                    ids[i * (a_len + b_len) + j] = a[i * a_len + j];
-                }
-                for (int j = 0; j < b_len; ++j) {
-                    ids[i * (a_len + b_len) + a_len + j] = b[i * b_len + j];
-                }
-            }
-            return ids;
-        }
-
-        std::vector<std::vector<float>> gen_ids(int h, int w, int patch_size, int bs, int context_len, std::vector<ggml_tensor*> ref_latents) {
-            auto txt_ids = gen_txt_ids(bs, context_len);
-            auto img_ids = gen_img_ids(h, w, patch_size, bs);
-
-            auto ids               = concat_ids(txt_ids, img_ids, bs);
-            uint64_t curr_h_offset = 0;
-            uint64_t curr_w_offset = 0;
-            for (ggml_tensor* ref : ref_latents) {
-                uint64_t h_offset = 0;
-                uint64_t w_offset = 0;
-                if (ref->ne[1] + curr_h_offset > ref->ne[0] + curr_w_offset) {
-                    w_offset = curr_w_offset;
-                } else {
-                    h_offset = curr_h_offset;
-                }
-
-                auto ref_ids = gen_img_ids(ref->ne[1], ref->ne[0], patch_size, bs, 1, h_offset, w_offset);
-                ids          = concat_ids(ids, ref_ids, bs);
-
-                curr_h_offset = std::max(curr_h_offset, ref->ne[1] + h_offset);
-                curr_w_offset = std::max(curr_w_offset, ref->ne[0] + w_offset);
-            }
-            return ids;
-        }
-
-        // Generate positional embeddings
-        std::vector<float> gen_pe(int h, int w, int patch_size, int bs, int context_len, std::vector<ggml_tensor*> ref_latents, int theta, const std::vector<int>& axes_dim) {
-            std::vector<std::vector<float>> ids       = gen_ids(h, w, patch_size, bs, context_len, ref_latents);
-            std::vector<std::vector<float>> trans_ids = transpose(ids);
-            size_t pos_len                            = ids.size();
-            int num_axes                              = axes_dim.size();
-            for (int i = 0; i < pos_len; i++) {
-                // std::cout << trans_ids[0][i] << " " << trans_ids[1][i] << " " << trans_ids[2][i] << std::endl;
-            }
-
-            int emb_dim = 0;
-            for (int d : axes_dim)
-                emb_dim += d / 2;
-
-            std::vector<std::vector<float>> emb(bs * pos_len, std::vector<float>(emb_dim * 2 * 2, 0.0));
-            int offset = 0;
-            for (int i = 0; i < num_axes; ++i) {
-                std::vector<std::vector<float>> rope_emb = rope(trans_ids[i], axes_dim[i], theta);  // [bs*pos_len, axes_dim[i]/2 * 2 * 2]
-                for (int b = 0; b < bs; ++b) {
-                    for (int j = 0; j < pos_len; ++j) {
-                        for (int k = 0; k < rope_emb[0].size(); ++k) {
-                            emb[b * pos_len + j][offset + k] = rope_emb[j][k];
-                        }
-                    }
-                }
-                offset += rope_emb[0].size();
-            }
-
-            return flatten(emb);
-        }
-
     public:
         FluxParams params;
         Flux() {}
         Flux(FluxParams params)
             : params(params) {
-            int64_t pe_dim = params.hidden_size / params.num_heads;
-
             blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, params.hidden_size, true));
             if (params.is_chroma) {
                 blocks["distilled_guidance_layer"] = std::shared_ptr<GGMLBlock>(new ChromaApproximator(params.in_channels, params.hidden_size));
@@ -1150,7 +983,14 @@ namespace Flux {
             ref_latents[i] = to_backend(ref_latents[i]);
         }
 
-        pe_vec = flux.gen_pe(x->ne[1], x->ne[0], 2, x->ne[3], context->ne[1], ref_latents, flux_params.theta, flux_params.axes_dim);
+        pe_vec = Rope::gen_flux_pe(x->ne[1],
+                                   x->ne[0],
+                                   2,
+                                   x->ne[3],
+                                   context->ne[1],
+                                   ref_latents,
+                                   flux_params.theta,
+                                   flux_params.axes_dim);
         int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;
         // LOG_DEBUG("pos_len %d", pos_len);
         auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len);
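
For reference (this sketch is not part of the commit): the size bookkeeping at this call site is unchanged by the switch to Rope::gen_flux_pe. The helper returns pos_len * (axes_dim_sum / 2) * 2 * 2 floats, which is exactly what the pos_len computation and the [2, 2, axes_dim_sum / 2, pos_len] F32 tensor above expect. A minimal standalone check of that arithmetic; the axes_dim values and token counts are illustrative only:

    // Sketch only: verifies the pe_vec size bookkeeping used at the call site above.
    #include <cassert>
    #include <cstdio>
    #include <vector>

    int main() {
        // Illustrative values; the real axes_dim / axes_dim_sum come from FluxParams.
        std::vector<int> axes_dim = {16, 56, 56};
        int axes_dim_sum          = 0;
        for (int d : axes_dim)
            axes_dim_sum += d;  // 128

        int pos_len = 512 + (64 * 64) / 4;  // e.g. context tokens + image patches (hypothetical)

        // gen_flux_pe returns pos_len * (axes_dim_sum / 2) * 2 * 2 floats ...
        size_t pe_vec_size = (size_t)pos_len * (axes_dim_sum / 2) * 2 * 2;

        // ... which the caller folds back into pos_len for the [2, 2, axes_dim_sum/2, pos_len] tensor.
        assert(pe_vec_size / axes_dim_sum / 2 == (size_t)pos_len);
        printf("pos_len = %d\n", pos_len);
        return 0;
    }
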
@@ -663,6 +663,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_slice(struct ggml_context* ctx,
 
     if (dim != 3) {
         x = ggml_torch_permute(ctx, x, inv_perm[0], inv_perm[1], inv_perm[2], inv_perm[3]);
+        x = ggml_cont(ctx, x);
     }
 
     return x;

@@ -837,10 +838,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_3d(struct ggml_context* ctx,
     int64_t OC = w->ne[3] / IC;
     int64_t N  = x->ne[3] / IC;
     x          = ggml_conv_3d(ctx, w, x, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2);
-
 
     if (b != NULL) {
         b = ggml_reshape_4d(ctx, b, 1, 1, 1, b->ne[0]);  // [OC, 1, 1, 1]
         x = ggml_add(ctx, x, b);
     }
     return x;

@@ -1005,7 +1005,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     //     LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
     // }
     // is there anything oddly shaped?? ping Green-Sky if you can trip this assert
-    GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0));
+    // GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0));
 
     bool can_use_flash_attn = true;
     can_use_flash_attn      = can_use_flash_attn && (d_head == 64 ||

@@ -1542,6 +1542,13 @@ public:
     virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) = 0;
 };
 
+class Identity : public UnaryBlock {
+public:
+    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
+        return x;
+    }
+};
+
 class Linear : public UnaryBlock {
 protected:
     int64_t in_features;

@@ -1556,7 +1563,7 @@ protected:
         }
         params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features);
         if (bias) {
-            enum ggml_type wtype = GGML_TYPE_F32;  //(tensor_types.ypes.find(prefix + "bias") != tensor_types.end()) ? tensor_types[prefix + "bias"] : GGML_TYPE_F32;
+            enum ggml_type wtype = GGML_TYPE_F32;
             params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_features);
         }
     }

@@ -1726,7 +1733,7 @@ protected:
                                               std::get<0>(kernel_size),
                                               in_channels * out_channels);
         if (bias) {
             params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);
         }
     }
 

@@ -1844,6 +1851,30 @@ public:
         : GroupNorm(32, num_channels, 1e-06f) {}
 };
 
+class RMSNorm : public UnaryBlock {
+protected:
+    int64_t hidden_size;
+    float eps;
+
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
+        enum ggml_type wtype = GGML_TYPE_F32;
+        params["weight"]     = ggml_new_tensor_1d(ctx, wtype, hidden_size);
+    }
+
+public:
+    RMSNorm(int64_t hidden_size,
+            float eps = 1e-06f)
+        : hidden_size(hidden_size),
+          eps(eps) {}
+
+    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
+        struct ggml_tensor* w = params["weight"];
+        x                     = ggml_rms_norm(ctx, x, eps);
+        x                     = ggml_mul(ctx, x, w);
+        return x;
+    }
+};
+
 class MultiheadAttention : public GGMLBlock {
 protected:
     int64_t embed_dim;
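
For reference (not part of the commit): the RMSNorm block added above simply multiplies the output of ggml_rms_norm by a learned per-channel weight. Assuming ggml_rms_norm normalizes over the innermost dimension of length d, each row is transformed as y[i] = x[i] / sqrt(mean(x^2) + eps) * w[i]. A small self-contained reference sketch:

    // Reference-only sketch of the per-row RMSNorm computation; not repo code.
    #include <cmath>
    #include <cstdio>
    #include <vector>

    std::vector<float> rms_norm(const std::vector<float>& x, const std::vector<float>& w, float eps) {
        double mean_sq = 0.0;
        for (float v : x)
            mean_sq += (double)v * v;
        mean_sq /= x.size();
        float inv = 1.0f / std::sqrt((float)mean_sq + eps);

        std::vector<float> y(x.size());
        for (size_t i = 0; i < x.size(); ++i)
            y[i] = x[i] * inv * w[i];  // normalize, then scale by the learned weight
        return y;
    }

    int main() {
        std::vector<float> x = {1.0f, -2.0f, 3.0f, -4.0f};
        std::vector<float> w = {1.0f, 1.0f, 1.0f, 1.0f};
        for (float v : rms_norm(x, w, 1e-6f))
            printf("%f\n", v);
        return 0;
    }
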
mmdit.hpp | 24

@@ -142,30 +142,6 @@ public:
     }
 };
 
-class RMSNorm : public UnaryBlock {
-protected:
-    int64_t hidden_size;
-    float eps;
-
-    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
-        enum ggml_type wtype = GGML_TYPE_F32;
-        params["weight"]     = ggml_new_tensor_1d(ctx, wtype, hidden_size);
-    }
-
-public:
-    RMSNorm(int64_t hidden_size,
-            float eps = 1e-06f)
-        : hidden_size(hidden_size),
-          eps(eps) {}
-
-    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
-        struct ggml_tensor* w = params["weight"];
-        x                     = ggml_rms_norm(ctx, x, eps);
-        x                     = ggml_mul(ctx, x, w);
-        return x;
-    }
-};
-
 class SelfAttention : public GGMLBlock {
 public:
     int64_t num_heads;
model.cpp | 10

@@ -1179,10 +1179,10 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
 
         if (n_dims == 5) {
             n_dims = 4;
-            ne[0] = ne[0]*ne[1];
+            ne[0] = ne[0] * ne[1];
             ne[1] = ne[2];
             ne[2] = ne[3];
             ne[3] = ne[4];
         }
 
         // ggml_n_dims returns 1 for scalars

@@ -2146,7 +2146,7 @@ bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tenso
 
 std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std::string& tensor_type_rules) {
     std::vector<std::pair<std::string, ggml_type>> result;
-    for (const auto& item : splitString(tensor_type_rules, ',')) {
+    for (const auto& item : split_string(tensor_type_rules, ',')) {
         if (item.size() == 0)
             continue;
         std::string::size_type pos = item.find('=');
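
For reference (not part of the commit): parse_tensor_type_rules now goes through the renamed split_string helper; each comma-separated item is expected to contain a '=' separating a tensor-name pattern from a type name. The rule string and the local split_string below are hypothetical stand-ins, used only to illustrate that split-on-',' / find-'=' shape; the real pattern-to-ggml_type mapping is not shown in this hunk:

    // Sketch only: mirrors the parsing structure of parse_tensor_type_rules above.
    #include <cstdio>
    #include <string>
    #include <utility>
    #include <vector>

    static std::vector<std::string> split_string(const std::string& str, char delimiter) {
        std::vector<std::string> result;
        size_t start = 0;
        size_t end   = str.find(delimiter);
        while (end != std::string::npos) {
            result.push_back(str.substr(start, end - start));
            start = end + 1;
            end   = str.find(delimiter, start);
        }
        result.push_back(str.substr(start));
        return result;
    }

    int main() {
        std::string rules = "attn=f16,ffn=q8_0";  // hypothetical rule string
        std::vector<std::pair<std::string, std::string>> parsed;
        for (const auto& item : split_string(rules, ',')) {
            if (item.size() == 0)
                continue;
            std::string::size_type pos = item.find('=');
            if (pos == std::string::npos)
                continue;
            parsed.emplace_back(item.substr(0, pos), item.substr(pos + 1));
        }
        for (const auto& p : parsed)
            printf("%s -> %s\n", p.first.c_str(), p.second.c_str());
        return 0;
    }
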
model.h | 39

@@ -31,23 +31,11 @@ enum SDVersion {
     VERSION_SD3,
     VERSION_FLUX,
     VERSION_FLUX_FILL,
+    VERSION_WAN_2_1,
+    VERSION_WAN_2_2,
     VERSION_COUNT,
 };
 
-static inline bool sd_version_is_flux(SDVersion version) {
-    if (version == VERSION_FLUX || version == VERSION_FLUX_FILL) {
-        return true;
-    }
-    return false;
-}
-
-static inline bool sd_version_is_sd3(SDVersion version) {
-    if (version == VERSION_SD3) {
-        return true;
-    }
-    return false;
-}
-
 static inline bool sd_version_is_sd1(SDVersion version) {
     if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX) {
         return true;

@@ -69,6 +57,27 @@ static inline bool sd_version_is_sdxl(SDVersion version) {
     return false;
 }
 
+static inline bool sd_version_is_sd3(SDVersion version) {
+    if (version == VERSION_SD3) {
+        return true;
+    }
+    return false;
+}
+
+static inline bool sd_version_is_flux(SDVersion version) {
+    if (version == VERSION_FLUX || version == VERSION_FLUX_FILL) {
+        return true;
+    }
+    return false;
+}
+
+static inline bool sd_version_is_wan(SDVersion version) {
+    if (version == VERSION_WAN_2_1 || version == VERSION_WAN_2_2) {
+        return true;
+    }
+    return false;
+}
+
 static inline bool sd_version_is_inpaint(SDVersion version) {
     if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL) {
         return true;

@@ -77,7 +86,7 @@ static inline bool sd_version_is_inpaint(SDVersion version) {
 }
 
 static inline bool sd_version_is_dit(SDVersion version) {
-    if (sd_version_is_flux(version) || sd_version_is_sd3(version)) {
+    if (sd_version_is_flux(version) || sd_version_is_sd3(version) || sd_version_is_wan(version)) {
         return true;
     }
     return false;
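
For reference (not part of the commit): the new WAN versions participate in the DiT code paths through sd_version_is_dit. A tiny check of how the predicates compose, assuming it is compiled inside the repo with its normal include paths:

    // Sketch only: exercises the new version predicates declared in model.h.
    #include <cstdio>
    #include "model.h"

    int main() {
        SDVersion v = VERSION_WAN_2_2;
        printf("is_wan:  %d\n", sd_version_is_wan(v));   // 1
        printf("is_dit:  %d\n", sd_version_is_dit(v));   // 1, is_dit now includes WAN
        printf("is_flux: %d\n", sd_version_is_flux(v));  // 0
        return 0;
    }
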
rope.hpp | 252 (new file)

@@ -0,0 +1,252 @@
+#ifndef __ROPE_HPP__
+#define __ROPE_HPP__
+
+#include <vector>
+#include "ggml_extend.hpp"
+
+struct Rope {
+    template <class T>
+    static std::vector<T> linspace(T start, T end, int num) {
+        std::vector<T> result(num);
+        if (num == 1) {
+            result[0] = start;
+            return result;
+        }
+        T step = (end - start) / (num - 1);
+        for (int i = 0; i < num; ++i) {
+            result[i] = start + i * step;
+        }
+        return result;
+    }
+
+    static std::vector<std::vector<float>> transpose(const std::vector<std::vector<float>>& mat) {
+        int rows = mat.size();
+        int cols = mat[0].size();
+        std::vector<std::vector<float>> transposed(cols, std::vector<float>(rows));
+        for (int i = 0; i < rows; ++i) {
+            for (int j = 0; j < cols; ++j) {
+                transposed[j][i] = mat[i][j];
+            }
+        }
+        return transposed;
+    }
+
+    static std::vector<float> flatten(const std::vector<std::vector<float>>& vec) {
+        std::vector<float> flat_vec;
+        for (const auto& sub_vec : vec) {
+            flat_vec.insert(flat_vec.end(), sub_vec.begin(), sub_vec.end());
+        }
+        return flat_vec;
+    }
+
+    static std::vector<std::vector<float>> rope(const std::vector<float>& pos, int dim, int theta) {
+        assert(dim % 2 == 0);
+        int half_dim = dim / 2;
+
+        std::vector<float> scale = linspace(0.f, (dim * 1.f - 2) / dim, half_dim);
+
+        std::vector<float> omega(half_dim);
+        for (int i = 0; i < half_dim; ++i) {
+            omega[i] = 1.0 / std::pow(theta, scale[i]);
+        }
+
+        int pos_size = pos.size();
+        std::vector<std::vector<float>> out(pos_size, std::vector<float>(half_dim));
+        for (int i = 0; i < pos_size; ++i) {
+            for (int j = 0; j < half_dim; ++j) {
+                out[i][j] = pos[i] * omega[j];
+            }
+        }
+
+        std::vector<std::vector<float>> result(pos_size, std::vector<float>(half_dim * 4));
+        for (int i = 0; i < pos_size; ++i) {
+            for (int j = 0; j < half_dim; ++j) {
+                result[i][4 * j]     = std::cos(out[i][j]);
+                result[i][4 * j + 1] = -std::sin(out[i][j]);
+                result[i][4 * j + 2] = std::sin(out[i][j]);
+                result[i][4 * j + 3] = std::cos(out[i][j]);
+            }
+        }
+
+        return result;
+    }
+
+    // Generate IDs for image patches and text
+    static std::vector<std::vector<float>> gen_txt_ids(int bs, int context_len) {
+        return std::vector<std::vector<float>>(bs * context_len, std::vector<float>(3, 0.0));
+    }
+
+    static std::vector<std::vector<float>> gen_img_ids(int h, int w, int patch_size, int bs, int index = 0, int h_offset = 0, int w_offset = 0) {
+        int h_len = (h + (patch_size / 2)) / patch_size;
+        int w_len = (w + (patch_size / 2)) / patch_size;
+
+        std::vector<std::vector<float>> img_ids(h_len * w_len, std::vector<float>(3, 0.0));
+
+        std::vector<float> row_ids = linspace<float>(h_offset, h_len - 1 + h_offset, h_len);
+        std::vector<float> col_ids = linspace<float>(w_offset, w_len - 1 + w_offset, w_len);
+
+        for (int i = 0; i < h_len; ++i) {
+            for (int j = 0; j < w_len; ++j) {
+                img_ids[i * w_len + j][0] = index;
+                img_ids[i * w_len + j][1] = row_ids[i];
+                img_ids[i * w_len + j][2] = col_ids[j];
+            }
+        }
+
+        std::vector<std::vector<float>> img_ids_repeated(bs * img_ids.size(), std::vector<float>(3));
+        for (int i = 0; i < bs; ++i) {
+            for (int j = 0; j < img_ids.size(); ++j) {
+                img_ids_repeated[i * img_ids.size() + j] = img_ids[j];
+            }
+        }
+        return img_ids_repeated;
+    }
+
+    static std::vector<std::vector<float>> concat_ids(const std::vector<std::vector<float>>& a,
+                                                      const std::vector<std::vector<float>>& b,
+                                                      int bs) {
+        size_t a_len = a.size() / bs;
+        size_t b_len = b.size() / bs;
+        std::vector<std::vector<float>> ids(a.size() + b.size(), std::vector<float>(3));
+        for (int i = 0; i < bs; ++i) {
+            for (int j = 0; j < a_len; ++j) {
+                ids[i * (a_len + b_len) + j] = a[i * a_len + j];
+            }
+            for (int j = 0; j < b_len; ++j) {
+                ids[i * (a_len + b_len) + a_len + j] = b[i * b_len + j];
+            }
+        }
+        return ids;
+    }
+
+    static std::vector<float> embed_nd(const std::vector<std::vector<float>>& ids,
+                                       int bs,
+                                       int theta,
+                                       const std::vector<int>& axes_dim) {
+        std::vector<std::vector<float>> trans_ids = transpose(ids);
+        size_t pos_len                            = ids.size() / bs;
+        int num_axes                              = axes_dim.size();
+        // for (int i = 0; i < pos_len; i++) {
+        //     std::cout << trans_ids[0][i] << " " << trans_ids[1][i] << " " << trans_ids[2][i] << std::endl;
+        // }
+
+        int emb_dim = 0;
+        for (int d : axes_dim)
+            emb_dim += d / 2;
+
+        std::vector<std::vector<float>> emb(bs * pos_len, std::vector<float>(emb_dim * 2 * 2, 0.0));
+        int offset = 0;
+        for (int i = 0; i < num_axes; ++i) {
+            std::vector<std::vector<float>> rope_emb = rope(trans_ids[i], axes_dim[i], theta);  // [bs*pos_len, axes_dim[i]/2 * 2 * 2]
+            for (int b = 0; b < bs; ++b) {
+                for (int j = 0; j < pos_len; ++j) {
+                    for (int k = 0; k < rope_emb[0].size(); ++k) {
+                        emb[b * pos_len + j][offset + k] = rope_emb[j][k];
+                    }
+                }
+            }
+            offset += rope_emb[0].size();
+        }
+
+        return flatten(emb);
+    }
+
+    static std::vector<std::vector<float>> gen_flux_ids(int h,
+                                                        int w,
+                                                        int patch_size,
+                                                        int bs,
+                                                        int context_len,
+                                                        std::vector<ggml_tensor*> ref_latents) {
+        auto txt_ids = gen_txt_ids(bs, context_len);
+        auto img_ids = gen_img_ids(h, w, patch_size, bs);
+
+        auto ids               = concat_ids(txt_ids, img_ids, bs);
+        uint64_t curr_h_offset = 0;
+        uint64_t curr_w_offset = 0;
+        for (ggml_tensor* ref : ref_latents) {
+            uint64_t h_offset = 0;
+            uint64_t w_offset = 0;
+            if (ref->ne[1] + curr_h_offset > ref->ne[0] + curr_w_offset) {
+                w_offset = curr_w_offset;
+            } else {
+                h_offset = curr_h_offset;
+            }
+
+            auto ref_ids = gen_img_ids(ref->ne[1], ref->ne[0], patch_size, bs, 1, h_offset, w_offset);
+            ids          = concat_ids(ids, ref_ids, bs);
+
+            curr_h_offset = std::max(curr_h_offset, ref->ne[1] + h_offset);
+            curr_w_offset = std::max(curr_w_offset, ref->ne[0] + w_offset);
+        }
+        return ids;
+    }
+
+    // Generate flux positional embeddings
+    static std::vector<float> gen_flux_pe(int h,
+                                          int w,
+                                          int patch_size,
+                                          int bs,
+                                          int context_len,
+                                          std::vector<ggml_tensor*> ref_latents,
+                                          int theta,
+                                          const std::vector<int>& axes_dim) {
+        std::vector<std::vector<float>> ids = gen_flux_ids(h, w, patch_size, bs, context_len, ref_latents);
+        return embed_nd(ids, bs, theta, axes_dim);
+    }
+
+    static std::vector<std::vector<float>> gen_vid_ids(int t,
+                                                       int h,
+                                                       int w,
+                                                       int pt,
+                                                       int ph,
+                                                       int pw,
+                                                       int bs,
+                                                       int t_offset = 0,
+                                                       int h_offset = 0,
+                                                       int w_offset = 0) {
+        int t_len = (t + (pt / 2)) / pt;
+        int h_len = (h + (ph / 2)) / ph;
+        int w_len = (w + (pw / 2)) / pw;
+
+        std::vector<std::vector<float>> vid_ids(t_len * h_len * w_len, std::vector<float>(3, 0.0));
+
+        std::vector<float> t_ids = linspace<float>(t_offset, t_len - 1 + t_offset, t_len);
+        std::vector<float> h_ids = linspace<float>(h_offset, h_len - 1 + h_offset, h_len);
+        std::vector<float> w_ids = linspace<float>(w_offset, w_len - 1 + w_offset, w_len);
+
+        for (int i = 0; i < t_len; ++i) {
+            for (int j = 0; j < h_len; ++j) {
+                for (int k = 0; k < w_len; ++k) {
+                    int idx         = i * h_len * w_len + j * w_len + k;
+                    vid_ids[idx][0] = t_ids[i];
+                    vid_ids[idx][1] = h_ids[j];
+                    vid_ids[idx][2] = w_ids[k];
+                }
+            }
+        }
+
+        std::vector<std::vector<float>> vid_ids_repeated(bs * vid_ids.size(), std::vector<float>(3));
+        for (int i = 0; i < bs; ++i) {
+            for (int j = 0; j < vid_ids.size(); ++j) {
+                vid_ids_repeated[i * vid_ids.size() + j] = vid_ids[j];
+            }
+        }
+        return vid_ids_repeated;
+    }
+
+    // Generate wan positional embeddings
+    static std::vector<float> gen_wan_pe(int t,
+                                         int h,
+                                         int w,
+                                         int pt,
+                                         int ph,
+                                         int pw,
+                                         int bs,
+                                         int theta,
+                                         const std::vector<int>& axes_dim) {
+        std::vector<std::vector<float>> ids = gen_vid_ids(t, h, w, pt, ph, pw, bs);
+        return embed_nd(ids, bs, theta, axes_dim);
+    }
+};  // struct Rope
+
+#endif __ROPE_HPP__
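
For reference (not part of the commit): rope.hpp is header-only, so the new helpers can be exercised on the CPU side. A sketch using the axes_dim values that the WanParams struct later in this commit defaults to ({44, 42, 42}) and patch sizes (1, 2, 2); the 9x60x104 latent grid is illustrative, and because rope.hpp pulls in ggml_extend.hpp this only builds inside the project:

    // Sketch only: size check for the new Rope::gen_wan_pe helper.
    #include <cstdio>
    #include <vector>
    #include "rope.hpp"

    int main() {
        int bs    = 1;
        int theta = 10000;

        std::vector<int> axes_dim = {44, 42, 42};  // WanParams defaults added in wan.hpp
        std::vector<float> pe     = Rope::gen_wan_pe(/*t=*/9, /*h=*/60, /*w=*/104,
                                                     /*pt=*/1, /*ph=*/2, /*pw=*/2,
                                                     bs, theta, axes_dim);

        int axes_dim_sum = 44 + 42 + 42;  // 128
        size_t pos_len   = pe.size() / axes_dim_sum / 2;

        // Each position contributes sum(axes_dim)/2 * 2 * 2 = 256 floats.
        printf("positions: %zu, floats per position: %zu\n", pos_len, pe.size() / pos_len);
        return 0;
    }
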
util.cpp | 2

@@ -290,7 +290,7 @@ std::string path_join(const std::string& p1, const std::string& p2) {
     return p1 + "/" + p2;
 }
 
-std::vector<std::string> splitString(const std::string& str, char delimiter) {
+std::vector<std::string> split_string(const std::string& str, char delimiter) {
     std::vector<std::string> result;
     size_t start = 0;
     size_t end   = str.find(delimiter);
util.h | 2

@@ -48,7 +48,7 @@ sd_image_f32_t resize_sd_image_f32_t(sd_image_f32_t image, int target_width, int
 sd_image_f32_t clip_preprocess(sd_image_f32_t image, int size);
 
 std::string path_join(const std::string& p1, const std::string& p2);
-std::vector<std::string> splitString(const std::string& str, char delimiter);
+std::vector<std::string> split_string(const std::string& str, char delimiter);
 void pretty_progress(int step, int steps, float time);
 
 void log_printf(sd_log_level_t level, const char* file, int line, const char* format, ...);
wan.hpp | 832

@@ -4,11 +4,14 @@
 #include <map>
 
 #include "common.hpp"
+#include "flux.hpp"
 #include "ggml_extend.hpp"
+#include "rope.hpp"
 
 namespace WAN {
 
     constexpr int CACHE_T = 2;
+    constexpr int WAN_GRAPH_SIZE = 10240;
 
     class CausalConv3d : public GGMLBlock {
     protected:

@@ -21,14 +24,14 @@ namespace WAN {
         bool bias;
 
         void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
             params["weight"] = ggml_new_tensor_4d(ctx,
                                                   GGML_TYPE_F16,
                                                   std::get<2>(kernel_size),
                                                   std::get<1>(kernel_size),
                                                   std::get<0>(kernel_size),
                                                   in_channels * out_channels);
             if (bias) {
                 params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);
             }
         }
 

@@ -95,10 +98,10 @@ namespace WAN {
             // assert N == 1
 
             struct ggml_tensor* w = params["gamma"];
             auto h = ggml_cont(ctx, ggml_torch_permute(ctx, x, 3, 0, 1, 2));  // [ID, IH, IW, N*IC]
             h      = ggml_rms_norm(ctx, h, 1e-12);
             h      = ggml_mul(ctx, h, w);
             h      = ggml_cont(ctx, ggml_torch_permute(ctx, h, 1, 2, 3, 0));
 
             return h;
         }

@@ -258,7 +261,7 @@ namespace WAN {
             for (int i = 0; i < 7; i++) {
                 if (i == 0 || i == 3) {  // RMS_norm
                     auto layer = std::dynamic_pointer_cast<RMS_norm>(blocks["residual." + std::to_string(i)]);
                     x          = layer->forward(ctx, x);
                 } else if (i == 2 || i == 6) {  // CausalConv3d
                     auto layer = std::dynamic_pointer_cast<CausalConv3d>(blocks["residual." + std::to_string(i)]);
 

@@ -312,7 +315,7 @@ namespace WAN {
 
             auto identity = x;
 
             x = norm->forward(ctx, x);
 
             x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 1, 3, 2));  // (t, c, h, w)
 

@@ -783,7 +786,7 @@ namespace WAN {
             // cuda f16, pass
             // cuda f32, pass
             auto z = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 1, 16);
             z      = load_tensor_from_file(work_ctx, "wan_vae_z.bin");
             // ggml_set_f32(z, 0.5f);
             print_ggml_tensor(z);
             struct ggml_tensor* out = NULL;

@@ -798,7 +801,7 @@ namespace WAN {
         };
 
         static void load_from_file_and_test(const std::string& file_path) {
             ggml_backend_t backend = ggml_backend_cuda_init(0);
             // ggml_backend_t backend = ggml_backend_cpu_init();
             ggml_type model_data_type        = GGML_TYPE_F32;
             std::shared_ptr<WanVAERunner> vae = std::shared_ptr<WanVAERunner>(new WanVAERunner(backend));
@ -828,6 +831,799 @@ namespace WAN {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
};
|
class WanSelfAttention : public GGMLBlock {
|
||||||
|
public:
|
||||||
|
int64_t num_heads;
|
||||||
|
int64_t head_dim;
|
||||||
|
bool flash_attn;
|
||||||
|
|
||||||
#endif
|
public:
|
||||||
|
WanSelfAttention(int64_t dim,
|
||||||
|
int64_t num_heads,
|
||||||
|
bool qk_norm = true,
|
||||||
|
float eps = 1e-6,
|
||||||
|
bool flash_attn = false)
|
||||||
|
: num_heads(num_heads), flash_attn(flash_attn) {
|
||||||
|
head_dim = dim / num_heads;
|
||||||
|
blocks["q"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
|
||||||
|
blocks["k"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
|
||||||
|
blocks["v"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
|
||||||
|
blocks["o"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
|
||||||
|
|
||||||
|
if (qk_norm) {
|
||||||
|
blocks["norm_q"] = std::shared_ptr<GGMLBlock>(new RMSNorm(dim, eps));
|
||||||
|
blocks["norm_k"] = std::shared_ptr<GGMLBlock>(new RMSNorm(dim, eps));
|
||||||
|
} else {
|
||||||
|
blocks["norm_q"] = std::shared_ptr<GGMLBlock>(new Identity());
|
||||||
|
blocks["norm_k"] = std::shared_ptr<GGMLBlock>(new Identity());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* x,
|
||||||
|
struct ggml_tensor* pe,
|
||||||
|
struct ggml_tensor* mask = NULL) {
|
||||||
|
// x: [N, n_token, dim]
|
||||||
|
// pe: [n_token, d_head/2, 2, 2]
|
||||||
|
// return [N, n_token, dim]
|
||||||
|
int64_t N = x->ne[2];
|
||||||
|
int64_t n_token = x->ne[1];
|
||||||
|
|
||||||
|
auto q_proj = std::dynamic_pointer_cast<Linear>(blocks["q"]);
|
||||||
|
auto k_proj = std::dynamic_pointer_cast<Linear>(blocks["k"]);
|
||||||
|
auto v_proj = std::dynamic_pointer_cast<Linear>(blocks["v"]);
|
||||||
|
auto o_proj = std::dynamic_pointer_cast<Linear>(blocks["o"]);
|
||||||
|
auto norm_q = std::dynamic_pointer_cast<UnaryBlock>(blocks["norm_q"]);
|
||||||
|
auto norm_k = std::dynamic_pointer_cast<UnaryBlock>(blocks["norm_k"]);
|
||||||
|
|
||||||
|
auto q = q_proj->forward(ctx, x);
|
||||||
|
q = norm_q->forward(ctx, q);
|
||||||
|
auto k = k_proj->forward(ctx, x);
|
||||||
|
k = norm_k->forward(ctx, k);
|
||||||
|
auto v = v_proj->forward(ctx, x); // [N, n_token, n_head*d_head]
|
||||||
|
|
||||||
|
q = ggml_reshape_4d(ctx, q, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head]
|
||||||
|
k = ggml_reshape_4d(ctx, k, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head]
|
||||||
|
v = ggml_reshape_4d(ctx, v, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head]
|
||||||
|
|
||||||
|
x = Flux::attention(ctx, q, k, v, pe, mask, flash_attn); // [N, n_token, dim]
|
||||||
|
|
||||||
|
x = o_proj->forward(ctx, x); // [N, n_token, dim]
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class WanCrossAttention : public WanSelfAttention {
|
||||||
|
public:
|
||||||
|
WanCrossAttention(int64_t dim,
|
||||||
|
int64_t num_heads,
|
||||||
|
bool qk_norm = true,
|
||||||
|
float eps = 1e-6,
|
||||||
|
bool flash_attn = false)
|
||||||
|
: WanSelfAttention(dim, num_heads, qk_norm, eps, flash_attn) {}
|
||||||
|
virtual struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* x,
|
||||||
|
struct ggml_tensor* context,
|
||||||
|
int64_t context_img_len) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
class WanT2VCrossAttention : public WanCrossAttention {
|
||||||
|
public:
|
||||||
|
WanT2VCrossAttention(int64_t dim,
|
||||||
|
int64_t num_heads,
|
||||||
|
bool qk_norm = true,
|
||||||
|
float eps = 1e-6,
|
||||||
|
bool flash_attn = false)
|
||||||
|
: WanCrossAttention(dim, num_heads, qk_norm, eps, flash_attn) {}
|
||||||
|
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* x,
|
||||||
|
struct ggml_tensor* context,
|
||||||
|
int64_t context_img_len) {
|
||||||
|
// x: [N, n_token, dim]
|
||||||
|
// context: [N, n_context, dim]
|
||||||
|
// context_img_len: unused
|
||||||
|
// return [N, n_token, dim]
|
||||||
|
int64_t N = x->ne[2];
|
||||||
|
int64_t n_token = x->ne[1];
|
||||||
|
|
||||||
|
auto q_proj = std::dynamic_pointer_cast<Linear>(blocks["q"]);
|
||||||
|
auto k_proj = std::dynamic_pointer_cast<Linear>(blocks["k"]);
|
||||||
|
auto v_proj = std::dynamic_pointer_cast<Linear>(blocks["v"]);
|
||||||
|
auto o_proj = std::dynamic_pointer_cast<Linear>(blocks["o"]);
|
||||||
|
auto norm_q = std::dynamic_pointer_cast<UnaryBlock>(blocks["norm_q"]);
|
||||||
|
auto norm_k = std::dynamic_pointer_cast<UnaryBlock>(blocks["norm_k"]);
|
||||||
|
|
||||||
|
auto q = q_proj->forward(ctx, x);
|
||||||
|
q = norm_q->forward(ctx, q);
|
||||||
|
auto k = k_proj->forward(ctx, context); // [N, n_context, dim]
|
||||||
|
k = norm_k->forward(ctx, k);
|
||||||
|
auto v = v_proj->forward(ctx, context); // [N, n_context, dim]
|
||||||
|
|
||||||
|
x = ggml_nn_attention_ext(ctx, q, k, v, num_heads); // [N, n_token, dim]
|
||||||
|
|
||||||
|
x = o_proj->forward(ctx, x); // [N, n_token, dim]
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class WanI2VCrossAttention : public WanCrossAttention {
|
||||||
|
public:
|
||||||
|
WanI2VCrossAttention(int64_t dim,
|
||||||
|
int64_t num_heads,
|
||||||
|
bool qk_norm = true,
|
||||||
|
float eps = 1e-6,
|
||||||
|
bool flash_attn = false)
|
||||||
|
: WanCrossAttention(dim, num_heads, qk_norm, eps, flash_attn) {
|
||||||
|
blocks["k_img"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
|
||||||
|
blocks["v_img"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
|
||||||
|
|
||||||
|
if (qk_norm) {
|
||||||
|
blocks["norm_k_img"] = std::shared_ptr<GGMLBlock>(new RMSNorm(dim, eps));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* x,
|
||||||
|
struct ggml_tensor* context,
|
||||||
|
int64_t context_img_len) {
|
||||||
|
// x: [N, n_token, dim]
|
||||||
|
// context: [N, context_img_len + context_txt_len, dim]
|
||||||
|
// return [N, n_token, dim]
|
||||||
|
|
||||||
|
auto q_proj = std::dynamic_pointer_cast<Linear>(blocks["q"]);
|
||||||
|
auto k_proj = std::dynamic_pointer_cast<Linear>(blocks["k"]);
|
||||||
|
auto v_proj = std::dynamic_pointer_cast<Linear>(blocks["v"]);
|
||||||
|
auto o_proj = std::dynamic_pointer_cast<Linear>(blocks["o"]);
|
||||||
|
|
||||||
|
auto k_img_proj = std::dynamic_pointer_cast<Linear>(blocks["k_img"]);
|
||||||
|
auto v_img_proj = std::dynamic_pointer_cast<Linear>(blocks["v_img"]);
|
||||||
|
|
||||||
|
auto norm_q = std::dynamic_pointer_cast<UnaryBlock>(blocks["norm_q"]);
|
||||||
|
auto norm_k = std::dynamic_pointer_cast<UnaryBlock>(blocks["norm_k"]);
|
||||||
|
auto norm_k_img = std::dynamic_pointer_cast<UnaryBlock>(blocks["norm_k_img"]);
|
||||||
|
|
||||||
|
int64_t N = x->ne[2];
|
||||||
|
int64_t n_token = x->ne[1];
|
||||||
|
int64_t dim = x->ne[2];
|
||||||
|
int64_t context_txt_len = context->ne[1] - context_img_len;
|
||||||
|
|
||||||
|
context = ggml_cont(ctx, ggml_torch_permute(ctx, context, 0, 2, 1, 3)); // [context_img_len + context_txt_len, N, dim]
|
||||||
|
auto context_img = ggml_view_3d(ctx, context, dim, N, context_img_len, context->nb[1], context->nb[2], 0);
|
||||||
|
auto context_txt = ggml_view_3d(ctx, context, dim, N, context_txt_len, context->nb[1], context->nb[2], context_txt_len * context->nb[2]);
|
||||||
|
context_img = ggml_cont(ctx, ggml_torch_permute(ctx, context_img, 0, 2, 1, 3)); // [N, context_img_len, dim]
|
||||||
|
context_txt = ggml_cont(ctx, ggml_torch_permute(ctx, context_txt, 0, 2, 1, 3)); // [N, context_txt_len, dim]
|
||||||
|
|
||||||
|
auto q = q_proj->forward(ctx, x);
|
||||||
|
q = norm_q->forward(ctx, q);
|
||||||
|
auto k = k_proj->forward(ctx, context_txt); // [N, context_txt_len, dim]
|
||||||
|
k = norm_k->forward(ctx, k);
|
||||||
|
auto v = v_proj->forward(ctx, context_txt); // [N, context_txt_len, dim]
|
||||||
|
|
||||||
|
auto k_img = k_img_proj->forward(ctx, context_img); // [N, context_img_len, dim]
|
||||||
|
k_img = norm_k_img->forward(ctx, k_img);
|
||||||
|
auto v_img = v_img_proj->forward(ctx, context_img); // [N, context_img_len, dim]
|
||||||
|
|
||||||
|
auto img_x = ggml_nn_attention_ext(ctx, q, k_img, v_img, num_heads); // [N, n_token, dim]
|
||||||
|
x = ggml_nn_attention_ext(ctx, q, k, v, num_heads); // [N, n_token, dim]
|
||||||
|
|
||||||
|
x = ggml_add(ctx, x, img_x);
|
||||||
|
|
||||||
|
x = o_proj->forward(ctx, x); // [N, n_token, dim]
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class WanAttentionBlock : public GGMLBlock {
|
||||||
|
protected:
|
||||||
|
int dim;
|
||||||
|
|
||||||
|
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
|
||||||
|
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
|
||||||
|
params["modulation"] = ggml_new_tensor_3d(ctx, wtype, dim, 6, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
WanAttentionBlock(bool t2v_cross_attn,
|
||||||
|
int64_t dim,
|
||||||
|
int64_t ffn_dim,
|
||||||
|
int64_t num_heads,
|
||||||
|
bool qk_norm = true,
|
||||||
|
bool cross_attn_norm = false,
|
||||||
|
float eps = 1e-6,
|
||||||
|
bool flash_attn = false)
|
||||||
|
: dim(dim) {
|
||||||
|
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
|
||||||
|
blocks["self_attn"] = std::shared_ptr<GGMLBlock>(new WanSelfAttention(dim, num_heads, qk_norm, eps));
|
||||||
|
if (cross_attn_norm) {
|
||||||
|
blocks["norm3"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, true));
|
||||||
|
} else {
|
||||||
|
blocks["norm3"] = std::shared_ptr<GGMLBlock>(new Identity());
|
||||||
|
}
|
||||||
|
if (t2v_cross_attn) {
|
||||||
|
blocks["cross_attn"] = std::shared_ptr<GGMLBlock>(new WanT2VCrossAttention(dim, num_heads, qk_norm, eps));
|
||||||
|
} else {
|
||||||
|
blocks["cross_attn"] = std::shared_ptr<GGMLBlock>(new WanI2VCrossAttention(dim, num_heads, qk_norm, eps));
|
||||||
|
}
|
||||||
|
|
||||||
|
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
|
||||||
|
|
||||||
|
blocks["ffn.0"] = std::shared_ptr<GGMLBlock>(new Linear(dim, ffn_dim));
|
||||||
|
// ffn.1 is nn.GELU(approximate='tanh')
|
||||||
|
blocks["ffn.2"] = std::shared_ptr<GGMLBlock>(new Linear(ffn_dim, dim));
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* x,
|
||||||
|
struct ggml_tensor* e,
|
||||||
|
struct ggml_tensor* pe,
|
||||||
|
struct ggml_tensor* context,
|
||||||
|
int64_t context_img_len = 257) {
|
||||||
|
// x: [N, n_token, dim]
|
||||||
|
// e: [N, 6, dim]
|
||||||
|
// context: [N, context_img_len + context_txt_len, dim]
|
||||||
|
// return [N, n_token, dim]
|
||||||
|
|
||||||
|
auto modulation = params["modulation"];
|
||||||
|
e = ggml_add(ctx, modulation, e); // [N, 6, dim]
|
||||||
|
auto es = ggml_chunk(ctx, e, 6, 1); // ([N, 1, dim], ...)
|
||||||
|
|
||||||
|
auto norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]);
|
||||||
|
auto self_attn = std::dynamic_pointer_cast<WanSelfAttention>(blocks["self_attn"]);
|
||||||
|
auto norm3 = std::dynamic_pointer_cast<UnaryBlock>(blocks["norm3"]);
|
||||||
|
auto cross_attn = std::dynamic_pointer_cast<WanCrossAttention>(blocks["cross_attn"]);
|
||||||
|
auto norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm2"]);
|
||||||
|
auto ffn_0 = std::dynamic_pointer_cast<Linear>(blocks["ffn.0"]);
|
||||||
|
auto ffn_2 = std::dynamic_pointer_cast<Linear>(blocks["ffn.2"]);
|
||||||
|
|
||||||
|
// self-attention
|
||||||
|
auto y = norm1->forward(ctx, x);
|
||||||
|
y = ggml_add(ctx, y, ggml_mul(ctx, y, es[1]));
|
||||||
|
y = ggml_add(ctx, y, es[0]);
|
||||||
|
y = self_attn->forward(ctx, y, pe);
|
||||||
|
|
||||||
|
x = ggml_add(ctx, x, ggml_mul(ctx, y, es[2]));
|
||||||
|
|
||||||
|
// cross-attention
|
||||||
|
x = ggml_add(ctx,
|
||||||
|
x,
|
||||||
|
cross_attn->forward(ctx, norm3->forward(ctx, x), context, context_img_len));
|
||||||
|
|
||||||
|
// ffn
|
||||||
|
y = norm2->forward(ctx, x);
|
||||||
|
y = ggml_add(ctx, y, ggml_mul(ctx, y, es[4]));
|
||||||
|
y = ggml_add(ctx, y, es[3]);
|
||||||
|
|
||||||
|
y = ffn_0->forward(ctx, y);
|
||||||
|
y = ggml_gelu_inplace(ctx, y);
|
||||||
|
y = ffn_2->forward(ctx, y);
|
||||||
|
|
||||||
|
x = ggml_add(ctx, x, ggml_mul(ctx, y, es[5]));
|
||||||
|
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class Head : public GGMLBlock {
|
||||||
|
protected:
|
||||||
|
int dim;
|
||||||
|
|
||||||
|
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
|
||||||
|
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
|
||||||
|
params["modulation"] = ggml_new_tensor_3d(ctx, wtype, dim, 2, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
Head(int64_t dim,
|
||||||
|
int64_t out_dim,
|
||||||
|
std::tuple<int, int, int> patch_size,
|
||||||
|
float eps = 1e-6)
|
||||||
|
: dim(dim) {
|
||||||
|
out_dim = out_dim * std::get<0>(patch_size) * std::get<1>(patch_size) * std::get<2>(patch_size);
|
||||||
|
|
||||||
|
blocks["norm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
|
||||||
|
blocks["head"] = std::shared_ptr<GGMLBlock>(new Linear(dim, out_dim));
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* x,
|
||||||
|
struct ggml_tensor* e) {
|
||||||
|
// x: [N, n_token, dim]
|
||||||
|
// e: [N, dim]
|
||||||
|
// return [N, n_token, out_dim]
|
||||||
|
|
||||||
|
auto modulation = params["modulation"];
|
||||||
|
e = ggml_add(ctx, modulation, ggml_reshape_3d(ctx, e, e->ne[0], 1, e->ne[1])); // [N, 2, dim]
|
||||||
|
auto es = ggml_chunk(ctx, e, 2, 1); // ([N, 1, dim], ...)
|
||||||
|
|
||||||
|
auto norm = std::dynamic_pointer_cast<LayerNorm>(blocks["norm"]);
|
||||||
|
auto head = std::dynamic_pointer_cast<Linear>(blocks["head"]);
|
||||||
|
|
||||||
|
x = norm->forward(ctx, x);
|
||||||
|
x = ggml_add(ctx, x, ggml_mul(ctx, x, es[1]));
|
||||||
|
x = ggml_add(ctx, x, es[0]);
|
||||||
|
x = head->forward(ctx, x);
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class MLPProj : public GGMLBlock {
|
||||||
|
protected:
|
||||||
|
int in_dim;
|
||||||
|
int flf_pos_embed_token_number;
|
||||||
|
|
||||||
|
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
|
||||||
|
if (flf_pos_embed_token_number > 0) {
|
||||||
|
params["emb_pos"] = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, in_dim, flf_pos_embed_token_number, 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
MLPProj(int64_t in_dim,
|
||||||
|
int64_t out_dim,
|
||||||
|
int64_t flf_pos_embed_token_number = 0)
|
||||||
|
: in_dim(in_dim), flf_pos_embed_token_number(flf_pos_embed_token_number) {
|
||||||
|
blocks["proj.0"] = std::shared_ptr<GGMLBlock>(new LayerNorm(in_dim));
|
||||||
|
blocks["proj.1"] = std::shared_ptr<GGMLBlock>(new Linear(in_dim, in_dim));
|
||||||
|
// proj.2 is nn.GELU()
|
||||||
|
blocks["proj.3"] = std::shared_ptr<GGMLBlock>(new Linear(in_dim, out_dim));
|
||||||
|
blocks["proj.4"] = std::shared_ptr<GGMLBlock>(new LayerNorm(out_dim));
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* image_embeds) {
|
||||||
|
if (flf_pos_embed_token_number > 0) {
|
||||||
|
auto emb_pos = params["emb_pos"];
|
||||||
|
|
||||||
|
auto a = ggml_slice(ctx, image_embeds, 1, 0, emb_pos->ne[1]);
|
||||||
|
auto b = ggml_slice(ctx, emb_pos, 1, 0, image_embeds->ne[1]);
|
||||||
|
|
||||||
|
image_embeds = ggml_add(ctx, a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto proj_0 = std::dynamic_pointer_cast<LayerNorm>(blocks["proj.0"]);
|
||||||
|
auto proj_1 = std::dynamic_pointer_cast<Linear>(blocks["proj.1"]);
|
||||||
|
auto proj_3 = std::dynamic_pointer_cast<Linear>(blocks["proj.3"]);
|
||||||
|
auto proj_4 = std::dynamic_pointer_cast<LayerNorm>(blocks["proj.4"]);
|
||||||
|
|
||||||
|
auto x = proj_0->forward(ctx, image_embeds);
|
||||||
|
x = proj_1->forward(ctx, x);
|
||||||
|
x = ggml_gelu_inplace(ctx, x);
|
||||||
|
x = proj_3->forward(ctx, x);
|
||||||
|
x = proj_4->forward(ctx, x);
|
||||||
|
|
||||||
|
return x; // clip_extra_context_tokens
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct WanParams {
|
||||||
|
std::string model_type = "t2v";
|
||||||
|
std::tuple<int, int, int> patch_size = {1, 2, 2};
|
||||||
|
int64_t text_len = 512;
|
||||||
|
int64_t in_dim = 16;
|
||||||
|
int64_t dim = 2048;
|
||||||
|
int64_t ffn_dim = 8192;
|
||||||
|
int64_t freq_dim = 256;
|
||||||
|
int64_t text_dim = 4096;
|
||||||
|
int64_t out_dim = 16;
|
||||||
|
int64_t num_heads = 16;
|
||||||
|
int64_t num_layers = 32;
|
||||||
|
bool qk_norm = true;
|
||||||
|
bool cross_attn_norm = true;
|
||||||
|
float eps = 1e-6;
|
||||||
|
int64_t flf_pos_embed_token_number = 0;
|
||||||
|
int theta = 10000;
|
||||||
|
// wan2.1 1.3B: 1536/12, wan2.1/2.2 14B: 5120/40, wan2.2 5B: 3074/24
|
||||||
|
std::vector<int> axes_dim = {44, 42, 42};
|
||||||
|
int64_t axes_dim_sum = 128;
|
||||||
|
};
|
||||||
|
|
||||||
|
class WanModel : public GGMLBlock {
|
||||||
|
protected:
|
||||||
|
WanParams params;
|
||||||
|
|
||||||
|
public:
|
||||||
|
WanModel() {}
|
||||||
|
WanModel(WanParams params)
|
||||||
|
: params(params) {
|
||||||
|
// patch_embedding
|
||||||
|
blocks["patch_embedding"] = std::shared_ptr<GGMLBlock>(new Conv3d(params.in_dim, params.dim, params.patch_size, params.patch_size));
|
||||||
|
|
||||||
|
// text_embedding
|
||||||
|
blocks["text_embedding.0"] = std::shared_ptr<GGMLBlock>(new Linear(params.text_dim, params.dim));
|
||||||
|
// text_embedding.1 is nn.GELU()
|
||||||
|
blocks["text_embedding.2"] = std::shared_ptr<GGMLBlock>(new Linear(params.dim, params.dim));
|
||||||
|
|
||||||
|
// time_embedding
|
||||||
|
blocks["time_embedding.0"] = std::shared_ptr<GGMLBlock>(new Linear(params.freq_dim, params.dim));
|
||||||
|
// time_embedding.1 is nn.SiLU()
|
||||||
|
blocks["time_embedding.2"] = std::shared_ptr<GGMLBlock>(new Linear(params.dim, params.dim));
|
||||||
|
|
||||||
|
// time_projection.0 is nn.SiLU()
|
||||||
|
blocks["time_projection.1"] = std::shared_ptr<GGMLBlock>(new Linear(params.dim, params.dim * 6));
|
||||||
|
|
||||||
|
// blocks
|
||||||
|
for (int i = 0; i < params.num_layers; i++) {
|
||||||
|
auto block = std::shared_ptr<GGMLBlock>(new WanAttentionBlock(params.model_type == "t2v",
|
||||||
|
params.dim,
|
||||||
|
params.ffn_dim,
|
||||||
|
params.num_heads,
|
||||||
|
params.qk_norm,
|
||||||
|
params.cross_attn_norm,
|
||||||
|
params.eps));
|
||||||
|
blocks["blocks." + std::to_string(i)] = block;
|
||||||
|
}
|
||||||
|
|
||||||
|
// head
|
||||||
|
blocks["head"] = std::shared_ptr<GGMLBlock>(new Head(params.dim, params.out_dim, params.patch_size, params.eps));
|
||||||
|
|
||||||
|
// img_emb
|
||||||
|
if (params.model_type == "i2v") {
|
||||||
|
blocks["img_emb"] = std::shared_ptr<GGMLBlock>(new MLPProj(1280, params.dim, params.flf_pos_embed_token_number));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* x) {
|
||||||
|
int64_t W = x->ne[0];
|
||||||
|
int64_t H = x->ne[1];
|
||||||
|
int64_t T = x->ne[1];
|
||||||
|
|
||||||
|
int pad_t = (std::get<0>(params.patch_size) - T % std::get<0>(params.patch_size)) % std::get<0>(params.patch_size);
|
||||||
|
int pad_h = (std::get<1>(params.patch_size) - H % std::get<1>(params.patch_size)) % std::get<1>(params.patch_size);
|
||||||
|
int pad_w = (std::get<2>(params.patch_size) - W % std::get<2>(params.patch_size)) % std::get<2>(params.patch_size);
|
||||||
|
x = ggml_pad(ctx, x, pad_w, pad_h, pad_t, 0); // [N*C, T + pad_t, H + pad_h, W + pad_w]
|
||||||
|
|
||||||
|
return x;
|
||||||
|
}

    struct ggml_tensor* unpatchify(struct ggml_context* ctx,
                                   struct ggml_tensor* x,
                                   int64_t t_len,
                                   int64_t h_len,
                                   int64_t w_len) {
        // x: [N, t_len*h_len*w_len, pt*ph*pw*C]
        // return: [N*C, t_len*pt, h_len*ph, w_len*pw]
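        // ggml tensors are contiguous with ne[0] varying fastest, so the inverse of patchify is
        // done as alternating reshape/permute steps: each step peels one patch factor
        // (pw, then ph, then pt) off the token dimension and interleaves it back into the
        // corresponding width/height/time axis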
        int64_t N = x->ne[3];
        int64_t pt = std::get<0>(params.patch_size);
        int64_t ph = std::get<1>(params.patch_size);
        int64_t pw = std::get<2>(params.patch_size);
        int64_t C = x->ne[0] / pt / ph / pw;

        GGML_ASSERT(C * pt * ph * pw == x->ne[0]);

        x = ggml_reshape_4d(ctx, x, C, pw * ph * pt, w_len * h_len * t_len, N);   // [N, t_len*h_len*w_len, pt*ph*pw, C]
        x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 1, 2, 0, 3));               // [N, C, t_len*h_len*w_len, pt*ph*pw]
        x = ggml_reshape_4d(ctx, x, pw, ph * pt, w_len, h_len * t_len * C * N);   // [N*C*t_len*h_len, w_len, pt*ph, pw]
        x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3));               // [N*C*t_len*h_len, pt*ph, w_len, pw]
        x = ggml_reshape_4d(ctx, x, pw * w_len, ph, pt, h_len * t_len * C * N);   // [N*C*t_len*h_len, pt, ph, w_len*pw]
        x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3));               // [N*C*t_len*h_len, ph, pt, w_len*pw]
        x = ggml_reshape_4d(ctx, x, pw * w_len, pt, ph * h_len, t_len * C * N);   // [N*C*t_len, h_len*ph, pt, w_len*pw]
        x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 1, 3));               // [N*C*t_len, pt, h_len*ph, w_len*pw]
        x = ggml_reshape_4d(ctx, x, pw * w_len, ph * h_len, pt * t_len, C * N);   // [N*C, t_len*pt, h_len*ph, w_len*pw]
        return x;
    }

    struct ggml_tensor* forward_orig(struct ggml_context* ctx,
                                     struct ggml_tensor* x,
                                     struct ggml_tensor* timestep,
                                     struct ggml_tensor* context,
                                     struct ggml_tensor* pe,
                                     struct ggml_tensor* clip_fea = NULL,
                                     int64_t N = 1) {
        // x: [N*C, T, H, W], C => in_dim
        // timestep: [N,]
        // context: [N, L, text_dim]
        // return: [N, t_len*h_len*w_len, out_dim*pt*ph*pw]

        auto patch_embedding = std::dynamic_pointer_cast<Conv3d>(blocks["patch_embedding"]);

        auto text_embedding_0 = std::dynamic_pointer_cast<Linear>(blocks["text_embedding.0"]);
        auto text_embedding_2 = std::dynamic_pointer_cast<Linear>(blocks["text_embedding.2"]);

        auto time_embedding_0 = std::dynamic_pointer_cast<Linear>(blocks["time_embedding.0"]);
        auto time_embedding_2 = std::dynamic_pointer_cast<Linear>(blocks["time_embedding.2"]);
        auto time_projection_1 = std::dynamic_pointer_cast<Linear>(blocks["time_projection.1"]);

        auto head = std::dynamic_pointer_cast<Head>(blocks["head"]);

        // patch_embedding
        x = patch_embedding->forward(ctx, x);                                          // [N*dim, t_len, h_len, w_len]
        x = ggml_reshape_3d(ctx, x, x->ne[0] * x->ne[1] * x->ne[2], x->ne[3] / N, N);  // [N, dim, t_len*h_len*w_len]
        x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 1, 0, 2, 3));                    // [N, t_len*h_len*w_len, dim]
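        // the Conv3d output [N*dim, t_len, h_len, w_len] is flattened and transposed into a token
        // sequence of length t_len*h_len*w_len with one dim-sized vector per patch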

        // time_embedding
        auto e = ggml_nn_timestep_embedding(ctx, timestep, params.freq_dim);
        e = time_embedding_0->forward(ctx, e);
        e = ggml_silu_inplace(ctx, e);
        e = time_embedding_2->forward(ctx, e);  // [N, dim]
        // time_projection
        auto e0 = ggml_silu(ctx, e);
        e0 = time_projection_1->forward(ctx, e0);
        e0 = ggml_reshape_3d(ctx, e0, e0->ne[0] / 6, 6, e0->ne[1]);  // [N, 6, dim]
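        // e conditions the final head, while e0 carries six modulation vectors per sample that each
        // WanAttentionBlock consumes (presumably the usual DiT-style shift/scale/gate pairs for the
        // attention and FFN paths)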

        context = text_embedding_0->forward(ctx, context);
        context = ggml_gelu(ctx, context);
        context = text_embedding_2->forward(ctx, context);  // [N, context_txt_len, dim]

        int64_t context_img_len = 0;
        if (clip_fea != NULL) {
            if (params.model_type == "i2v") {
                auto img_emb = std::dynamic_pointer_cast<MLPProj>(blocks["img_emb"]);
                auto context_img = img_emb->forward(ctx, clip_fea);   // [N, context_img_len, dim]
                context = ggml_concat(ctx, context_img, context, 1);  // [N, context_img_len + context_txt_len, dim]
            }
            context_img_len = clip_fea->ne[1];  // 257
        }
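        // for i2v checkpoints the 257 projected CLIP image tokens are prepended to the text tokens;
        // context_img_len records how many leading context tokens come from the image and is passed
        // to every block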

        for (int i = 0; i < params.num_layers; i++) {
            auto block = std::dynamic_pointer_cast<WanAttentionBlock>(blocks["blocks." + std::to_string(i)]);

            x = block->forward(ctx, x, e0, pe, context, context_img_len);
        }

        x = head->forward(ctx, x, e);  // [N, t_len*h_len*w_len, pt*ph*pw*out_dim]

        return x;
    }

    struct ggml_tensor* forward(struct ggml_context* ctx,
                                struct ggml_tensor* x,
                                struct ggml_tensor* timestep,
                                struct ggml_tensor* context,
                                struct ggml_tensor* pe,
                                struct ggml_tensor* clip_fea = NULL,
                                struct ggml_tensor* time_dim_concat = NULL,
                                int64_t N = 1) {
        // Forward pass of DiT.
        // x: [N*C, T, H, W]
        // timestep: [N,]
        // context: [N, L, D]
        // pe: [L, d_head/2, 2, 2]
        // time_dim_concat: [N*C, T2, H, W]
        // return: [N*C, T, H, W]

        GGML_ASSERT(N == 1);

        int64_t W = x->ne[0];
        int64_t H = x->ne[1];
        int64_t T = x->ne[2];
        int64_t C = x->ne[3];

        x = pad_to_patch_size(ctx, x);

        int64_t t_len = ((T + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size));
        int64_t h_len = ((H + (std::get<1>(params.patch_size) / 2)) / std::get<1>(params.patch_size));
        int64_t w_len = ((W + (std::get<2>(params.patch_size) / 2)) / std::get<2>(params.patch_size));
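        // with the patch sizes used here (1 or 2), (X + p/2) / p equals ceil(X / p), i.e. the
        // padded extent divided by the patch size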

        if (time_dim_concat != NULL) {
            time_dim_concat = pad_to_patch_size(ctx, time_dim_concat);
            x = ggml_concat(ctx, x, time_dim_concat, 2);  // [N*C, (T+pad_t) + (T2+pad_t2), H + pad_h, W + pad_w]
            t_len = ((x->ne[2] + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size));
        }
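        // extra frames (e.g. conditioning latents) can be appended along the time axis here; t_len
        // is recomputed for the concatenated tensor, and the slice back to T below drops the extra
        // frames from the output again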

        auto out = forward_orig(ctx, x, timestep, context, pe, clip_fea, N);  // [N, t_len*h_len*w_len, pt*ph*pw*C]

        out = unpatchify(ctx, out, t_len, h_len, w_len);  // [N*C, (T+pad_t) + (T2+pad_t2), H + pad_h, W + pad_w]

        // slice

        out = ggml_slice(ctx, out, 2, 0, T);  // [N*C, T, H + pad_h, W + pad_w]
        out = ggml_slice(ctx, out, 1, 0, H);  // [N*C, T, H, W + pad_w]
        out = ggml_slice(ctx, out, 0, 0, W);  // [N*C, T, H, W]

        return out;
    }
};

struct WanRunner : public GGMLRunner {
public:
    WanParams wan_params;
    WanModel wan;
    std::vector<float> pe_vec;
    SDVersion version;

    WanRunner(ggml_backend_t backend,
              const String2GGMLType& tensor_types = {},
              const std::string prefix = "",
              SDVersion version = VERSION_WAN_2_1)
        : GGMLRunner(backend) {
        wan_params.num_layers = 0;
        for (auto pair : tensor_types) {
            std::string tensor_name = pair.first;
            if (tensor_name.find("model.diffusion_model.") == std::string::npos)
                continue;
            size_t pos = tensor_name.find("blocks.");
            if (pos != std::string::npos) {
                tensor_name = tensor_name.substr(pos);  // remove prefix
                auto items = split_string(tensor_name, '.');
                if (items.size() > 1) {
                    int block_index = atoi(items[1].c_str());
                    if (block_index + 1 > wan_params.num_layers) {
                        wan_params.num_layers = block_index + 1;
                    }
                }
            }
            if (tensor_name.find("img_emb") != std::string::npos) {
                wan_params.model_type = "i2v";
            }
        }
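        // num_layers is inferred from the highest "blocks.N." index seen in the tensor names, and
        // the presence of img_emb.* tensors marks an i2v checkpoint; the branches below map the
        // layer count to the known Wan 2.1 configurations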

        if (wan_params.num_layers == 30) {
            LOG_INFO("Wan2.1-T2V-1.3B");
            wan_params.dim = 1536;
            wan_params.eps = 1e-06;
            wan_params.ffn_dim = 8960;
            wan_params.freq_dim = 256;
            wan_params.in_dim = 16;
            wan_params.num_heads = 12;
            wan_params.out_dim = 16;
            wan_params.text_len = 512;
        } else if (wan_params.num_layers == 40) {
            if (wan_params.model_type == "t2v") {
                LOG_INFO("Wan2.1-T2V-14B");
            } else {
                LOG_INFO("Wan2.1-I2V-14B");
            }
            wan_params.dim = 5120;
            wan_params.eps = 1e-06;
            wan_params.ffn_dim = 13824;
            wan_params.freq_dim = 256;
            wan_params.in_dim = 16;
            wan_params.num_heads = 40;
            wan_params.out_dim = 16;
            wan_params.text_len = 512;
        } else {
            GGML_ABORT("invalid num_layers(%d) of wan", (int)wan_params.num_layers);
        }

        wan = WanModel(wan_params);
        wan.init(params_ctx, tensor_types, prefix);
    }

    std::string get_desc() {
        return "wan";
    }

    void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
        wan.get_param_tensors(tensors, prefix);
    }

    struct ggml_cgraph* build_graph(struct ggml_tensor* x,
                                    struct ggml_tensor* timesteps,
                                    struct ggml_tensor* context,
                                    struct ggml_tensor* clip_fea = NULL,
                                    struct ggml_tensor* time_dim_concat = NULL) {
        struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, WAN_GRAPH_SIZE, false);

        x = to_backend(x);
        timesteps = to_backend(timesteps);
        context = to_backend(context);
        clip_fea = to_backend(clip_fea);
        time_dim_concat = to_backend(time_dim_concat);

        pe_vec = Rope::gen_wan_pe(x->ne[2],
                                  x->ne[1],
                                  x->ne[0],
                                  std::get<0>(wan_params.patch_size),
                                  std::get<1>(wan_params.patch_size),
                                  std::get<2>(wan_params.patch_size),
                                  1,
                                  wan_params.theta,
                                  wan_params.axes_dim);
        int pos_len = pe_vec.size() / wan_params.axes_dim_sum / 2;
        // LOG_DEBUG("pos_len %d", pos_len);
        auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, wan_params.axes_dim_sum / 2, pos_len);
        // pe->data = pe_vec.data();
        // print_ggml_tensor(pe);
        // pe->data = NULL;
        set_backend_tensor_data(pe, pe_vec.data());
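        // pe ends up as [pos_len, axes_dim_sum/2, 2, 2] (pos_len slowest), matching the
        // "pe: [L, d_head/2, 2, 2]" layout documented in WanModel::forward; pos_len is the number
        // of RoPE positions, one per patch token of the padded latent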

        struct ggml_tensor* out = wan.forward(compute_ctx,
                                              x,
                                              timesteps,
                                              context,
                                              pe,
                                              clip_fea,
                                              time_dim_concat);

        ggml_build_forward_expand(gf, out);

        return gf;
    }

    void compute(int n_threads,
                 struct ggml_tensor* x,
                 struct ggml_tensor* timesteps,
                 struct ggml_tensor* context,
                 struct ggml_tensor* clip_fea = NULL,
                 struct ggml_tensor* time_dim_concat = NULL,
                 struct ggml_tensor** output = NULL,
                 struct ggml_context* output_ctx = NULL) {
        auto get_graph = [&]() -> struct ggml_cgraph* {
            return build_graph(x, timesteps, context, clip_fea, time_dim_concat);
        };

        GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
    }

    void test() {
        struct ggml_init_params params;
        params.mem_size = static_cast<size_t>(20 * 1024 * 1024);  // 20 MB
        params.mem_buffer = NULL;
        params.no_alloc = false;

        struct ggml_context* work_ctx = ggml_init(params);
        GGML_ASSERT(work_ctx != NULL);
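        // the reference tensors loaded below ("wan_dit_x.bin", "wan_dit_context.bin") are presumably
        // dumped from the PyTorch implementation so the ggml output can be compared against it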

        {
            // cpu f16: pass
            // cuda f16: pass
            // cpu q8_0: pass
            // auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 104, 60, 1, 16);
            // ggml_set_f32(x, 0.01f);
            auto x = load_tensor_from_file(work_ctx, "wan_dit_x.bin");
            // print_ggml_tensor(x);

            std::vector<float> timesteps_vec(1, 999.f);
            auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);

            // auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 4096, 512, 1);
            // ggml_set_f32(context, 0.01f);
            auto context = load_tensor_from_file(work_ctx, "wan_dit_context.bin");
            // print_ggml_tensor(context);

            struct ggml_tensor* out = NULL;

            int t0 = ggml_time_ms();
            compute(8, x, timesteps, context, NULL, NULL, &out, work_ctx);
            int t1 = ggml_time_ms();

            print_ggml_tensor(out);
            LOG_DEBUG("wan test done in %dms", t1 - t0);
        }
    }

    static void load_from_file_and_test(const std::string& file_path) {
        ggml_backend_t backend = ggml_backend_cuda_init(0);
        // ggml_backend_t backend = ggml_backend_cpu_init();
        ggml_type model_data_type = GGML_TYPE_Q8_0;
        LOG_INFO("loading from '%s'", file_path.c_str());

        ModelLoader model_loader;
        if (!model_loader.init_from_file(file_path, "model.diffusion_model.")) {
            LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str());
            return;
        }

        auto tensor_types = model_loader.tensor_storages_types;
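        // force every *.weight tensor to the test quantization type; non-weight tensors
        // (biases, norms) keep the type stored in the checkpoint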
        for (auto& item : tensor_types) {
            LOG_DEBUG("%s %u", item.first.c_str(), item.second);
            if (ends_with(item.first, "weight")) {
                item.second = model_data_type;
            }
        }

        std::shared_ptr<WanRunner> wan = std::shared_ptr<WanRunner>(new WanRunner(backend,
                                                                                  tensor_types,
                                                                                  "model.diffusion_model"));

        wan->alloc_params_buffer();
        std::map<std::string, ggml_tensor*> tensors;
        wan->get_param_tensors(tensors, "model.diffusion_model");

        bool success = model_loader.load_tensors(tensors, backend);

        if (!success) {
            LOG_ERROR("load tensors from model loader failed");
            return;
        }

        LOG_INFO("wan model loaded");

        wan->test();
    }
};

}  // namespace WAN

#endif  // __WAN_HPP__