mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-13 05:48:56 +00:00
add wan2.1 i2v support
This commit is contained in:
parent
9b29de27a8
commit
d83867b8e9
15
clip.hpp
15
clip.hpp
@ -851,16 +851,21 @@ public:
|
||||
blocks["visual_projection"] = std::shared_ptr<GGMLBlock>(new CLIPProjection(hidden_size, projection_dim, transpose_proj_w));
|
||||
}
|
||||
|
||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* pixel_values) {
|
||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||
struct ggml_tensor* pixel_values,
|
||||
bool return_pooled = true) {
|
||||
// pixel_values: [N, num_channels, image_size, image_size]
|
||||
// return: [N, projection_dim]
|
||||
// return: [N, projection_dim] if return_pooled else [N, n_token, hidden_size]
|
||||
auto vision_model = std::dynamic_pointer_cast<CLIPVisionModel>(blocks["vision_model"]);
|
||||
auto visual_projection = std::dynamic_pointer_cast<CLIPProjection>(blocks["visual_projection"]);
|
||||
|
||||
auto x = vision_model->forward(ctx, pixel_values); // [N, hidden_size]
|
||||
x = visual_projection->forward(ctx, x); // [N, projection_dim]
|
||||
auto x = vision_model->forward(ctx, pixel_values, return_pooled); // [N, hidden_size] or [N, n_token, hidden_size]
|
||||
|
||||
return x; // [N, projection_dim]
|
||||
if (return_pooled) {
|
||||
x = visual_projection->forward(ctx, x); // [N, projection_dim]
|
||||
}
|
||||
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -622,7 +622,7 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {
|
||||
FrozenCLIPVisionEmbedder(ggml_backend_t backend,
|
||||
bool offload_params_to_cpu,
|
||||
const String2GGMLType& tensor_types = {})
|
||||
: vision_model(OPEN_CLIP_VIT_H_14, true), GGMLRunner(backend, offload_params_to_cpu) {
|
||||
: vision_model(OPEN_CLIP_VIT_H_14), GGMLRunner(backend, offload_params_to_cpu) {
|
||||
vision_model.init(params_ctx, tensor_types, "cond_stage_model.transformer");
|
||||
}
|
||||
|
||||
@ -634,12 +634,12 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {
|
||||
vision_model.get_param_tensors(tensors, "cond_stage_model.transformer");
|
||||
}
|
||||
|
||||
struct ggml_cgraph* build_graph(struct ggml_tensor* pixel_values) {
|
||||
struct ggml_cgraph* build_graph(struct ggml_tensor* pixel_values, bool return_pooled) {
|
||||
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
|
||||
|
||||
pixel_values = to_backend(pixel_values);
|
||||
|
||||
struct ggml_tensor* hidden_states = vision_model.forward(compute_ctx, pixel_values);
|
||||
struct ggml_tensor* hidden_states = vision_model.forward(compute_ctx, pixel_values, return_pooled);
|
||||
|
||||
ggml_build_forward_expand(gf, hidden_states);
|
||||
|
||||
@ -648,10 +648,11 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {
|
||||
|
||||
void compute(const int n_threads,
|
||||
ggml_tensor* pixel_values,
|
||||
bool return_pooled,
|
||||
ggml_tensor** output,
|
||||
ggml_context* output_ctx) {
|
||||
auto get_graph = [&]() -> struct ggml_cgraph* {
|
||||
return build_graph(pixel_values);
|
||||
return build_graph(pixel_values, return_pooled);
|
||||
};
|
||||
GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
|
||||
}
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
#include "wan.hpp"
|
||||
|
||||
struct DiffusionModel {
|
||||
virtual std::string get_desc() = 0;
|
||||
virtual void compute(int n_threads,
|
||||
struct ggml_tensor* x,
|
||||
struct ggml_tensor* timesteps,
|
||||
@ -40,6 +41,10 @@ struct UNetModel : public DiffusionModel {
|
||||
: unet(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, flash_attn) {
|
||||
}
|
||||
|
||||
std::string get_desc() {
|
||||
return unet.get_desc();
|
||||
}
|
||||
|
||||
void alloc_params_buffer() {
|
||||
unet.alloc_params_buffer();
|
||||
}
|
||||
@ -92,6 +97,10 @@ struct MMDiTModel : public DiffusionModel {
|
||||
: mmdit(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model") {
|
||||
}
|
||||
|
||||
std::string get_desc() {
|
||||
return mmdit.get_desc();
|
||||
}
|
||||
|
||||
void alloc_params_buffer() {
|
||||
mmdit.alloc_params_buffer();
|
||||
}
|
||||
@ -146,6 +155,10 @@ struct FluxModel : public DiffusionModel {
|
||||
: flux(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, flash_attn, use_mask) {
|
||||
}
|
||||
|
||||
std::string get_desc() {
|
||||
return flux.get_desc();
|
||||
}
|
||||
|
||||
void alloc_params_buffer() {
|
||||
flux.alloc_params_buffer();
|
||||
}
|
||||
@ -199,6 +212,10 @@ struct WanModel : public DiffusionModel {
|
||||
: wan(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, flash_attn) {
|
||||
}
|
||||
|
||||
std::string get_desc() {
|
||||
return wan.get_desc();
|
||||
}
|
||||
|
||||
void alloc_params_buffer() {
|
||||
wan.alloc_params_buffer();
|
||||
}
|
||||
@ -237,7 +254,7 @@ struct WanModel : public DiffusionModel {
|
||||
struct ggml_tensor** output = NULL,
|
||||
struct ggml_context* output_ctx = NULL,
|
||||
std::vector<int> skip_layers = std::vector<int>()) {
|
||||
return wan.compute(n_threads, x, timesteps, context, NULL, NULL, output, output_ctx);
|
||||
return wan.compute(n_threads, x, timesteps, context, y, c_concat, NULL, output, output_ctx);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -53,6 +53,7 @@ struct SDParams {
|
||||
std::string model_path;
|
||||
std::string clip_l_path;
|
||||
std::string clip_g_path;
|
||||
std::string clip_vision_path;
|
||||
std::string t5xxl_path;
|
||||
std::string diffusion_model_path;
|
||||
std::string vae_path;
|
||||
@ -123,6 +124,7 @@ void print_params(SDParams params) {
|
||||
printf(" wtype: %s\n", params.wtype < SD_TYPE_COUNT ? sd_type_name(params.wtype) : "unspecified");
|
||||
printf(" clip_l_path: %s\n", params.clip_l_path.c_str());
|
||||
printf(" clip_g_path: %s\n", params.clip_g_path.c_str());
|
||||
printf(" clip_vision_path: %s\n", params.clip_vision_path.c_str());
|
||||
printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str());
|
||||
printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str());
|
||||
printf(" vae_path: %s\n", params.vae_path.c_str());
|
||||
@ -186,6 +188,7 @@ void print_usage(int argc, const char* argv[]) {
|
||||
printf(" --diffusion-model path to the standalone diffusion model\n");
|
||||
printf(" --clip_l path to the clip-l text encoder\n");
|
||||
printf(" --clip_g path to the clip-g text encoder\n");
|
||||
printf(" --clip_vision path to the clip-vision encoder\n");
|
||||
printf(" --t5xxl path to the t5xxl text encoder\n");
|
||||
printf(" --vae [VAE] path to vae\n");
|
||||
printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n");
|
||||
@ -414,6 +417,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
|
||||
{"-m", "--model", "", &params.model_path},
|
||||
{"", "--clip_l", "", &params.clip_l_path},
|
||||
{"", "--clip_g", "", &params.clip_g_path},
|
||||
{"", "--clip_vision", "", &params.clip_vision_path},
|
||||
{"", "--t5xxl", "", &params.t5xxl_path},
|
||||
{"", "--diffusion-model", "", &params.diffusion_model_path},
|
||||
{"", "--vae", "", &params.vae_path},
|
||||
@ -927,10 +931,15 @@ int main(int argc, const char* argv[]) {
|
||||
}
|
||||
}
|
||||
|
||||
if (params.mode == VID_GEN) {
|
||||
vae_decode_only = false;
|
||||
}
|
||||
|
||||
sd_ctx_params_t sd_ctx_params = {
|
||||
params.model_path.c_str(),
|
||||
params.clip_l_path.c_str(),
|
||||
params.clip_g_path.c_str(),
|
||||
params.clip_vision_path.c_str(),
|
||||
params.t5xxl_path.c_str(),
|
||||
params.diffusion_model_path.c_str(),
|
||||
params.vae_path.c_str(),
|
||||
|
||||
@ -589,7 +589,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_tensor_concat(struct ggml_context* ct
|
||||
}
|
||||
|
||||
// convert values from [0, 1] to [-1, 1]
|
||||
__STATIC_INLINE__ void ggml_tensor_scale_input(struct ggml_tensor* src) {
|
||||
__STATIC_INLINE__ void process_vae_input_tensor(struct ggml_tensor* src) {
|
||||
int64_t nelements = ggml_nelements(src);
|
||||
float* data = (float*)src->data;
|
||||
for (int i = 0; i < nelements; i++) {
|
||||
@ -599,7 +599,7 @@ __STATIC_INLINE__ void ggml_tensor_scale_input(struct ggml_tensor* src) {
|
||||
}
|
||||
|
||||
// convert values from [-1, 1] to [0, 1]
|
||||
__STATIC_INLINE__ void ggml_tensor_scale_output(struct ggml_tensor* src) {
|
||||
__STATIC_INLINE__ void process_vae_output_tensor(struct ggml_tensor* src) {
|
||||
int64_t nelements = ggml_nelements(src);
|
||||
float* data = (float*)src->data;
|
||||
for (int i = 0; i < nelements; i++) {
|
||||
|
||||
14
model.cpp
14
model.cpp
@ -89,6 +89,7 @@ const char* unused_tensors[] = {
|
||||
"posterior_mean_coef1",
|
||||
"posterior_mean_coef2",
|
||||
"cond_stage_model.transformer.text_model.embeddings.position_ids",
|
||||
"cond_stage_model.transformer.vision_model.embeddings.position_ids",
|
||||
"cond_stage_model.model.logit_scale",
|
||||
"cond_stage_model.model.text_projection",
|
||||
"conditioner.embedders.0.transformer.text_model.embeddings.position_ids",
|
||||
@ -142,6 +143,11 @@ std::unordered_map<std::string, std::string> open_clip_to_hk_clip_resblock = {
|
||||
{"mlp.c_proj.weight", "mlp.fc2.weight"},
|
||||
};
|
||||
|
||||
std::unordered_map<std::string, std::string> cond_model_name_map = {
|
||||
{"transformer.vision_model.pre_layrnorm.weight", "transformer.vision_model.pre_layernorm.weight"},
|
||||
{"transformer.vision_model.pre_layrnorm.bias", "transformer.vision_model.pre_layernorm.bias"},
|
||||
};
|
||||
|
||||
std::unordered_map<std::string, std::string> vae_decoder_name_map = {
|
||||
{"first_stage_model.decoder.mid.attn_1.to_k.bias", "first_stage_model.decoder.mid.attn_1.k.bias"},
|
||||
{"first_stage_model.decoder.mid.attn_1.to_k.weight", "first_stage_model.decoder.mid.attn_1.k.weight"},
|
||||
@ -180,7 +186,7 @@ std::unordered_map<std::string, std::string> pmid_v2_name_map = {
|
||||
"pmid.qformer_perceiver.token_proj.fc2.weight"},
|
||||
};
|
||||
|
||||
std::string convert_open_clip_to_hf_clip(const std::string& name) {
|
||||
std::string convert_cond_model_name(const std::string& name) {
|
||||
std::string new_name = name;
|
||||
std::string prefix;
|
||||
if (contains(new_name, ".enc.")) {
|
||||
@ -269,6 +275,10 @@ std::string convert_open_clip_to_hf_clip(const std::string& name) {
|
||||
new_name = open_clip_to_hf_clip_model[new_name];
|
||||
}
|
||||
|
||||
if (cond_model_name_map.find(new_name) != cond_model_name_map.end()) {
|
||||
new_name = cond_model_name_map[new_name];
|
||||
}
|
||||
|
||||
std::string open_clip_resblock_prefix = "model.transformer.resblocks.";
|
||||
std::string hf_clip_resblock_prefix = "transformer.text_model.encoder.layers.";
|
||||
|
||||
@ -564,7 +574,7 @@ std::string convert_tensor_name(std::string name) {
|
||||
// }
|
||||
std::string new_name = name;
|
||||
if (starts_with(name, "cond_stage_model.") || starts_with(name, "conditioner.embedders.") || starts_with(name, "text_encoders.") || ends_with(name, ".vision_model.visual_projection.weight")) {
|
||||
new_name = convert_open_clip_to_hf_clip(name);
|
||||
new_name = convert_cond_model_name(name);
|
||||
} else if (starts_with(name, "first_stage_model.decoder")) {
|
||||
new_name = convert_vae_decoder_name(name);
|
||||
} else if (starts_with(name, "pmid.qformer_perceiver")) {
|
||||
|
||||
@ -94,7 +94,7 @@ public:
|
||||
float scale_factor = 0.18215f;
|
||||
|
||||
std::shared_ptr<Conditioner> cond_stage_model;
|
||||
std::shared_ptr<FrozenCLIPVisionEmbedder> clip_vision; // for svd
|
||||
std::shared_ptr<FrozenCLIPVisionEmbedder> clip_vision; // for svd or wan2.1 i2v
|
||||
std::shared_ptr<DiffusionModel> diffusion_model;
|
||||
std::shared_ptr<VAE> first_stage_model;
|
||||
std::shared_ptr<TinyAutoEncoder> tae_first_stage;
|
||||
@ -225,6 +225,14 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
if (strlen(SAFE_STR(sd_ctx_params->clip_vision_path)) > 0) {
|
||||
LOG_INFO("loading clip_vision from '%s'", sd_ctx_params->clip_vision_path);
|
||||
std::string prefix = "cond_stage_model.transformer.";
|
||||
if (!model_loader.init_from_file(sd_ctx_params->clip_vision_path, prefix)) {
|
||||
LOG_WARN("loading clip_vision from '%s' failed", sd_ctx_params->clip_vision_path);
|
||||
}
|
||||
}
|
||||
|
||||
if (strlen(SAFE_STR(sd_ctx_params->t5xxl_path)) > 0) {
|
||||
LOG_INFO("loading t5xxl from '%s'", sd_ctx_params->t5xxl_path);
|
||||
if (!model_loader.init_from_file(sd_ctx_params->t5xxl_path, "text_encoders.t5xxl.transformer.")) {
|
||||
@ -374,6 +382,13 @@ public:
|
||||
model_loader.tensor_storages_types,
|
||||
version,
|
||||
sd_ctx_params->diffusion_flash_attn);
|
||||
if (diffusion_model->get_desc() == "Wan2.1-I2V-14B") {
|
||||
clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types);
|
||||
clip_vision->alloc_params_buffer();
|
||||
clip_vision->get_param_tensors(tensors);
|
||||
}
|
||||
} else { // SD1.x SD2.x SDXL
|
||||
if (strstr(SAFE_STR(sd_ctx_params->stacked_id_embed_dir), "v2")) {
|
||||
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
|
||||
@ -581,7 +596,7 @@ public:
|
||||
size_t total_params_size = total_params_ram_size + total_params_vram_size;
|
||||
LOG_INFO(
|
||||
"total params memory size = %.2fMB (VRAM %.2fMB, RAM %.2fMB): "
|
||||
"clip %.2fMB(%s), unet %.2fMB(%s), vae %.2fMB(%s), controlnet %.2fMB(%s), pmid %.2fMB(%s)",
|
||||
"text_encoders %.2fMB(%s), diffusion_model %.2fMB(%s), vae %.2fMB(%s), controlnet %.2fMB(%s), pmid %.2fMB(%s)",
|
||||
total_params_size / 1024.0 / 1024.0,
|
||||
total_params_vram_size / 1024.0 / 1024.0,
|
||||
total_params_ram_size / 1024.0 / 1024.0,
|
||||
@ -812,6 +827,42 @@ public:
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_tensor* get_clip_vision_output(ggml_context* work_ctx,
|
||||
sd_image_t init_image,
|
||||
bool return_pooled = true,
|
||||
bool zero_out_masked = false) {
|
||||
ggml_tensor* output = NULL;
|
||||
if (zero_out_masked) {
|
||||
if (return_pooled) {
|
||||
output = ggml_new_tensor_1d(work_ctx,
|
||||
GGML_TYPE_F32,
|
||||
clip_vision->vision_model.projection_dim);
|
||||
} else {
|
||||
output = ggml_new_tensor_2d(work_ctx,
|
||||
GGML_TYPE_F32,
|
||||
clip_vision->vision_model.hidden_size,
|
||||
257);
|
||||
}
|
||||
|
||||
ggml_set_f32(output, 0.f);
|
||||
} else {
|
||||
sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(init_image);
|
||||
sd_image_f32_t resized_image = clip_preprocess(image, clip_vision->vision_model.image_size);
|
||||
free(image.data);
|
||||
image.data = NULL;
|
||||
|
||||
ggml_tensor* pixel_values = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1);
|
||||
sd_image_f32_to_tensor(resized_image.data, pixel_values, false);
|
||||
free(resized_image.data);
|
||||
resized_image.data = NULL;
|
||||
|
||||
// print_ggml_tensor(pixel_values);
|
||||
clip_vision->compute(n_threads, pixel_values, return_pooled, &output, work_ctx);
|
||||
// print_ggml_tensor(c_crossattn);
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
SDCondition get_svd_condition(ggml_context* work_ctx,
|
||||
sd_image_t init_image,
|
||||
int width,
|
||||
@ -822,27 +873,7 @@ public:
|
||||
bool zero_out_masked = false) {
|
||||
// c_crossattn
|
||||
int64_t t0 = ggml_time_ms();
|
||||
struct ggml_tensor* c_crossattn = NULL;
|
||||
{
|
||||
if (zero_out_masked) {
|
||||
c_crossattn = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, clip_vision->vision_model.projection_dim);
|
||||
ggml_set_f32(c_crossattn, 0.f);
|
||||
} else {
|
||||
sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(init_image);
|
||||
sd_image_f32_t resized_image = clip_preprocess(image, clip_vision->vision_model.image_size);
|
||||
free(image.data);
|
||||
image.data = NULL;
|
||||
|
||||
ggml_tensor* pixel_values = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1);
|
||||
sd_image_f32_to_tensor(resized_image.data, pixel_values, false);
|
||||
free(resized_image.data);
|
||||
resized_image.data = NULL;
|
||||
|
||||
// print_ggml_tensor(pixel_values);
|
||||
clip_vision->compute(n_threads, pixel_values, &c_crossattn, work_ctx);
|
||||
// print_ggml_tensor(c_crossattn);
|
||||
}
|
||||
}
|
||||
struct ggml_tensor* c_crossattn = get_clip_vision_output(work_ctx, init_image, true, zero_out_masked);
|
||||
|
||||
// c_concat
|
||||
struct ggml_tensor* c_concat = NULL;
|
||||
@ -1161,32 +1192,15 @@ public:
|
||||
return latent;
|
||||
}
|
||||
|
||||
ggml_tensor* encode_first_stage(ggml_context* work_ctx, ggml_tensor* x) {
|
||||
int64_t W = x->ne[0] / 8;
|
||||
int64_t H = x->ne[1] / 8;
|
||||
int64_t C = 8;
|
||||
if (use_tiny_autoencoder) {
|
||||
C = 4;
|
||||
} else {
|
||||
if (sd_version_is_sd3(version)) {
|
||||
C = 32;
|
||||
} else if (sd_version_is_flux(version)) {
|
||||
C = 32;
|
||||
}
|
||||
}
|
||||
ggml_tensor* result = ggml_new_tensor_4d(work_ctx,
|
||||
GGML_TYPE_F32,
|
||||
W,
|
||||
H,
|
||||
C,
|
||||
x->ne[3]);
|
||||
ggml_tensor* encode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) {
|
||||
int64_t t0 = ggml_time_ms();
|
||||
ggml_tensor* result = NULL;
|
||||
if (!use_tiny_autoencoder) {
|
||||
ggml_tensor_scale_input(x);
|
||||
first_stage_model->compute(n_threads, x, false, &result, NULL);
|
||||
process_vae_input_tensor(x);
|
||||
first_stage_model->compute(n_threads, x, false, &result, work_ctx);
|
||||
first_stage_model->free_compute_buffer();
|
||||
} else {
|
||||
tae_first_stage->compute(n_threads, x, false, &result, NULL);
|
||||
tae_first_stage->compute(n_threads, x, false, &result, work_ctx);
|
||||
tae_first_stage->free_compute_buffer();
|
||||
}
|
||||
|
||||
@ -1195,6 +1209,31 @@ public:
|
||||
return result;
|
||||
}
|
||||
|
||||
void process_latent_in(ggml_tensor* latent) {
|
||||
if (sd_version_is_wan(version)) {
|
||||
GGML_ASSERT(latent->ne[3] == 16);
|
||||
std::vector<float> latents_mean_vec = {-0.7571f, -0.7089f, -0.9113f, 0.1075f, -0.1745f, 0.9653f, -0.1517f, 1.5508f,
|
||||
0.4134f, -0.0715f, 0.5517f, -0.3632f, -0.1922f, -0.9497f, 0.2503f, -0.2921f};
|
||||
std::vector<float> latents_std_vec = {2.8184f, 1.4541f, 2.3275f, 2.6558f, 1.2196f, 1.7708f, 2.6052f, 2.0743f,
|
||||
3.2687f, 2.1526f, 2.8652f, 1.5579f, 1.6382f, 1.1253f, 2.8251f, 1.9160f};
|
||||
for (int i = 0; i < latent->ne[3]; i++) {
|
||||
float mean = latents_mean_vec[i];
|
||||
float std_ = latents_std_vec[i];
|
||||
for (int j = 0; j < latent->ne[2]; j++) {
|
||||
for (int k = 0; k < latent->ne[1]; k++) {
|
||||
for (int l = 0; l < latent->ne[0]; l++) {
|
||||
float value = ggml_tensor_get_f32(latent, l, k, j, i);
|
||||
value = (value - mean) * scale_factor / std_;
|
||||
ggml_tensor_set_f32(latent, value, l, k, j, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
ggml_tensor_scale(latent, scale_factor);
|
||||
}
|
||||
}
|
||||
|
||||
void process_latent_out(ggml_tensor* latent) {
|
||||
if (sd_version_is_wan(version)) {
|
||||
GGML_ASSERT(latent->ne[3] == 16);
|
||||
@ -1259,7 +1298,7 @@ public:
|
||||
first_stage_model->compute(n_threads, x, true, &result, work_ctx);
|
||||
}
|
||||
first_stage_model->free_compute_buffer();
|
||||
ggml_tensor_scale_output(result);
|
||||
process_vae_output_tensor(result);
|
||||
} else {
|
||||
if (vae_tiling && !decode_video) {
|
||||
// split latent in 64x64 tiles and compute in several steps
|
||||
@ -1404,6 +1443,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
|
||||
"model_path: %s\n"
|
||||
"clip_l_path: %s\n"
|
||||
"clip_g_path: %s\n"
|
||||
"clip_vision_path: %s\n"
|
||||
"t5xxl_path: %s\n"
|
||||
"diffusion_model_path: %s\n"
|
||||
"vae_path: %s\n"
|
||||
@ -1430,6 +1470,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
|
||||
SAFE_STR(sd_ctx_params->model_path),
|
||||
SAFE_STR(sd_ctx_params->clip_l_path),
|
||||
SAFE_STR(sd_ctx_params->clip_g_path),
|
||||
SAFE_STR(sd_ctx_params->clip_vision_path),
|
||||
SAFE_STR(sd_ctx_params->t5xxl_path),
|
||||
SAFE_STR(sd_ctx_params->diffusion_model_path),
|
||||
SAFE_STR(sd_ctx_params->vae_path),
|
||||
@ -2183,7 +2224,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
||||
int height = sd_vid_gen_params->height;
|
||||
int frames = sd_vid_gen_params->video_frames;
|
||||
frames = (frames - 1) / 4 * 4 + 1;
|
||||
LOG_INFO("img2vid %dx%dx%d", width, height, frames);
|
||||
LOG_INFO("generate_video %dx%dx%d", width, height, frames);
|
||||
|
||||
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sd_vid_gen_params->sample_steps);
|
||||
|
||||
@ -2209,6 +2250,66 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
||||
|
||||
int64_t t0 = ggml_time_ms();
|
||||
|
||||
ggml_tensor* clip_vision_output = NULL;
|
||||
ggml_tensor* concat_latent = NULL;
|
||||
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B") {
|
||||
LOG_INFO("IMG2VID");
|
||||
|
||||
if (sd_vid_gen_params->init_image.data) {
|
||||
clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false);
|
||||
} else {
|
||||
clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false, true);
|
||||
}
|
||||
|
||||
int64_t t1 = ggml_time_ms();
|
||||
LOG_INFO("get_clip_vision_output completed, taking %" PRId64 " ms", t1 - t0);
|
||||
|
||||
ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, frames, 3);
|
||||
for (int i3 = 0; i3 < init_img->ne[3]; i3++) { // channels
|
||||
for (int i2 = 0; i2 < init_img->ne[2]; i2++) {
|
||||
for (int i1 = 0; i1 < init_img->ne[1]; i1++) { // height
|
||||
for (int i0 = 0; i0 < init_img->ne[0]; i0++) { // width
|
||||
float value = 0.5f;
|
||||
if (i2 == 0 && sd_vid_gen_params->init_image.data) { // start image
|
||||
value = *(sd_vid_gen_params->init_image.data + i1 * width * 3 + i0 * 3 + i3);
|
||||
value /= 255.f;
|
||||
}
|
||||
ggml_tensor_set_f32(init_img, value, i0, i1, i2, i3);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
concat_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img); // [b*c, t, h/8, w/8]
|
||||
|
||||
int64_t t2 = ggml_time_ms();
|
||||
LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1);
|
||||
|
||||
sd_ctx->sd->process_latent_in(concat_latent);
|
||||
|
||||
ggml_tensor* concat_mask = ggml_new_tensor_4d(work_ctx,
|
||||
GGML_TYPE_F32,
|
||||
concat_latent->ne[0],
|
||||
concat_latent->ne[1],
|
||||
concat_latent->ne[2],
|
||||
4); // [b*4, t, w/8, h/8]
|
||||
for (int i3 = 0; i3 < concat_mask->ne[3]; i3++) {
|
||||
for (int i2 = 0; i2 < concat_mask->ne[2]; i2++) {
|
||||
for (int i1 = 0; i1 < concat_mask->ne[1]; i1++) {
|
||||
for (int i0 = 0; i0 < concat_mask->ne[0]; i0++) {
|
||||
float value = 0.0f;
|
||||
if (i2 == 0 && sd_vid_gen_params->init_image.data) { // start image
|
||||
value = 1.0f;
|
||||
}
|
||||
ggml_tensor_set_f32(concat_mask, value, i0, i1, i2, i3);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
concat_latent = ggml_tensor_concat(work_ctx, concat_mask, concat_latent, 3); // [b*(c+4), t, h/8, w/8]
|
||||
}
|
||||
|
||||
ggml_tensor* init_latent = generate_init_latent(sd_ctx, work_ctx, width, height, frames, true);
|
||||
int sample_steps = sigmas.size() - 1;
|
||||
// Apply lora
|
||||
@ -2216,7 +2317,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
||||
|
||||
// Get learned condition
|
||||
bool zero_out_masked = true;
|
||||
t0 = ggml_time_ms();
|
||||
int64_t t1 = ggml_time_ms();
|
||||
SDCondition cond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
|
||||
sd_ctx->sd->n_threads,
|
||||
prompt,
|
||||
@ -2225,19 +2326,23 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
||||
height,
|
||||
sd_ctx->sd->diffusion_model->get_adm_in_channels(),
|
||||
zero_out_masked);
|
||||
cond.c_concat = concat_latent;
|
||||
cond.c_vector = clip_vision_output;
|
||||
SDCondition uncond;
|
||||
if (sd_vid_gen_params->guidance.txt_cfg != 1.0) {
|
||||
uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
|
||||
sd_ctx->sd->n_threads,
|
||||
negative_prompt,
|
||||
sd_vid_gen_params->clip_skip,
|
||||
width,
|
||||
height,
|
||||
sd_ctx->sd->diffusion_model->get_adm_in_channels(),
|
||||
zero_out_masked);
|
||||
uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
|
||||
sd_ctx->sd->n_threads,
|
||||
negative_prompt,
|
||||
sd_vid_gen_params->clip_skip,
|
||||
width,
|
||||
height,
|
||||
sd_ctx->sd->diffusion_model->get_adm_in_channels(),
|
||||
zero_out_masked);
|
||||
uncond.c_concat = concat_latent;
|
||||
uncond.c_vector = clip_vision_output;
|
||||
}
|
||||
int64_t t1 = ggml_time_ms();
|
||||
LOG_INFO("get_learned_condition completed, taking %" PRId64 " ms", t1 - t0);
|
||||
int64_t t2 = ggml_time_ms();
|
||||
LOG_INFO("get_learned_condition completed, taking %" PRId64 " ms", t2 - t1);
|
||||
|
||||
if (sd_ctx->sd->free_params_immediately) {
|
||||
sd_ctx->sd->cond_stage_model->free_params_buffer();
|
||||
@ -2280,11 +2385,11 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
||||
sd_ctx->sd->diffusion_model->free_params_buffer();
|
||||
}
|
||||
|
||||
int64_t t3 = ggml_time_ms();
|
||||
LOG_INFO("generating latent video completed, taking %.2fs", (t3 - t1) * 1.0f / 1000);
|
||||
int64_t t4 = ggml_time_ms();
|
||||
LOG_INFO("generating latent video completed, taking %.2fs", (t4 - t2) * 1.0f / 1000);
|
||||
struct ggml_tensor* vid = sd_ctx->sd->decode_first_stage(work_ctx, final_latent, true);
|
||||
int64_t t4 = ggml_time_ms();
|
||||
LOG_INFO("decode_first_stage completed, taking %.2fs", (t4 - t3) * 1.0f / 1000);
|
||||
int64_t t5 = ggml_time_ms();
|
||||
LOG_INFO("decode_first_stage completed, taking %.2fs", (t5 - t4) * 1.0f / 1000);
|
||||
if (sd_ctx->sd->free_params_immediately) {
|
||||
sd_ctx->sd->first_stage_model->free_params_buffer();
|
||||
}
|
||||
@ -2304,7 +2409,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
||||
}
|
||||
ggml_free(work_ctx);
|
||||
|
||||
LOG_INFO("img2vid completed in %.2fs", (t4 - t0) * 1.0f / 1000);
|
||||
LOG_INFO("img2vid completed in %.2fs", (t5 - t0) * 1.0f / 1000);
|
||||
|
||||
return result_images;
|
||||
}
|
||||
|
||||
@ -115,6 +115,7 @@ typedef struct {
|
||||
const char* model_path;
|
||||
const char* clip_l_path;
|
||||
const char* clip_g_path;
|
||||
const char* clip_vision_path;
|
||||
const char* t5xxl_path;
|
||||
const char* diffusion_model_path;
|
||||
const char* vae_path;
|
||||
|
||||
30
wan.hpp
30
wan.hpp
@ -1124,12 +1124,12 @@ namespace WAN {
|
||||
|
||||
int64_t N = x->ne[2];
|
||||
int64_t n_token = x->ne[1];
|
||||
int64_t dim = x->ne[2];
|
||||
int64_t dim = x->ne[0];
|
||||
int64_t context_txt_len = context->ne[1] - context_img_len;
|
||||
|
||||
context = ggml_nn_cont(ctx, ggml_torch_permute(ctx, context, 0, 2, 1, 3)); // [context_img_len + context_txt_len, N, dim]
|
||||
auto context_img = ggml_view_3d(ctx, context, dim, N, context_img_len, context->nb[1], context->nb[2], 0);
|
||||
auto context_txt = ggml_view_3d(ctx, context, dim, N, context_txt_len, context->nb[1], context->nb[2], context_txt_len * context->nb[2]);
|
||||
auto context_txt = ggml_view_3d(ctx, context, dim, N, context_txt_len, context->nb[1], context->nb[2], context_img_len * context->nb[2]);
|
||||
context_img = ggml_nn_cont(ctx, ggml_torch_permute(ctx, context_img, 0, 2, 1, 3)); // [N, context_img_len, dim]
|
||||
context_txt = ggml_nn_cont(ctx, ggml_torch_permute(ctx, context_txt, 0, 2, 1, 3)); // [N, context_txt_len, dim]
|
||||
|
||||
@ -1478,6 +1478,7 @@ namespace WAN {
|
||||
e = time_embedding_0->forward(ctx, e);
|
||||
e = ggml_silu_inplace(ctx, e);
|
||||
e = time_embedding_2->forward(ctx, e); // [N, dim]
|
||||
|
||||
// time_projection
|
||||
auto e0 = ggml_silu(ctx, e);
|
||||
e0 = time_projection_1->forward(ctx, e0);
|
||||
@ -1559,6 +1560,7 @@ namespace WAN {
|
||||
|
||||
struct WanRunner : public GGMLRunner {
|
||||
public:
|
||||
std::string desc = "wan";
|
||||
WanParams wan_params;
|
||||
Wan wan;
|
||||
std::vector<float> pe_vec;
|
||||
@ -1594,7 +1596,7 @@ namespace WAN {
|
||||
}
|
||||
|
||||
if (wan_params.num_layers == 30) {
|
||||
LOG_INFO("Wan2.1-T2V-1.3B");
|
||||
desc = "Wan2.1-T2V-1.3B";
|
||||
wan_params.dim = 1536;
|
||||
wan_params.eps = 1e-06;
|
||||
wan_params.ffn_dim = 8960;
|
||||
@ -1605,15 +1607,16 @@ namespace WAN {
|
||||
wan_params.text_len = 512;
|
||||
} else if (wan_params.num_layers == 40) {
|
||||
if (wan_params.model_type == "t2v") {
|
||||
LOG_INFO("Wan2.1-T2V-14B");
|
||||
desc = "Wan2.1-T2V-14B";
|
||||
wan_params.in_dim = 16;
|
||||
} else {
|
||||
LOG_INFO("Wan2.1-I2V-14B");
|
||||
desc = "Wan2.1-I2V-14B";
|
||||
wan_params.in_dim = 36;
|
||||
}
|
||||
wan_params.dim = 5120;
|
||||
wan_params.eps = 1e-06;
|
||||
wan_params.ffn_dim = 13824;
|
||||
wan_params.freq_dim = 256;
|
||||
wan_params.in_dim = 16;
|
||||
wan_params.num_heads = 40;
|
||||
wan_params.out_dim = 16;
|
||||
wan_params.text_len = 512;
|
||||
@ -1621,12 +1624,14 @@ namespace WAN {
|
||||
GGML_ABORT("invalid num_layers(%d) of wan", wan_params.num_layers);
|
||||
}
|
||||
|
||||
LOG_INFO("%s", desc.c_str());
|
||||
|
||||
wan = Wan(wan_params);
|
||||
wan.init(params_ctx, tensor_types, prefix);
|
||||
}
|
||||
|
||||
std::string get_desc() {
|
||||
return "wan";
|
||||
return desc;
|
||||
}
|
||||
|
||||
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
|
||||
@ -1637,6 +1642,7 @@ namespace WAN {
|
||||
struct ggml_tensor* timesteps,
|
||||
struct ggml_tensor* context,
|
||||
struct ggml_tensor* clip_fea = NULL,
|
||||
struct ggml_tensor* c_concat = NULL,
|
||||
struct ggml_tensor* time_dim_concat = NULL) {
|
||||
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, WAN_GRAPH_SIZE, false);
|
||||
|
||||
@ -1644,6 +1650,7 @@ namespace WAN {
|
||||
timesteps = to_backend(timesteps);
|
||||
context = to_backend(context);
|
||||
clip_fea = to_backend(clip_fea);
|
||||
c_concat = to_backend(c_concat);
|
||||
time_dim_concat = to_backend(time_dim_concat);
|
||||
|
||||
pe_vec = Rope::gen_wan_pe(x->ne[2],
|
||||
@ -1663,6 +1670,10 @@ namespace WAN {
|
||||
// pe->data = NULL;
|
||||
set_backend_tensor_data(pe, pe_vec.data());
|
||||
|
||||
if (c_concat != NULL) {
|
||||
x = ggml_concat(compute_ctx, x, c_concat, 3);
|
||||
}
|
||||
|
||||
struct ggml_tensor* out = wan.forward(compute_ctx,
|
||||
x,
|
||||
timesteps,
|
||||
@ -1681,11 +1692,12 @@ namespace WAN {
|
||||
struct ggml_tensor* timesteps,
|
||||
struct ggml_tensor* context,
|
||||
struct ggml_tensor* clip_fea = NULL,
|
||||
struct ggml_tensor* c_concat = NULL,
|
||||
struct ggml_tensor* time_dim_concat = NULL,
|
||||
struct ggml_tensor** output = NULL,
|
||||
struct ggml_context* output_ctx = NULL) {
|
||||
auto get_graph = [&]() -> struct ggml_cgraph* {
|
||||
return build_graph(x, timesteps, context, clip_fea, time_dim_concat);
|
||||
return build_graph(x, timesteps, context, clip_fea, c_concat, time_dim_concat);
|
||||
};
|
||||
|
||||
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
|
||||
@ -1720,7 +1732,7 @@ namespace WAN {
|
||||
struct ggml_tensor* out = NULL;
|
||||
|
||||
int t0 = ggml_time_ms();
|
||||
compute(8, x, timesteps, context, NULL, NULL, &out, work_ctx);
|
||||
compute(8, x, timesteps, context, NULL, NULL, NULL, &out, work_ctx);
|
||||
int t1 = ggml_time_ms();
|
||||
|
||||
print_ggml_tensor(out);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user