mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-13 05:48:56 +00:00
add wan2.2 14B i2v support
This commit is contained in:
parent
079b393b6e
commit
815e9fd6e1
23
model.cpp
23
model.cpp
@ -1691,9 +1691,11 @@ SDVersion ModelLoader::get_sd_version() {
|
||||
bool has_multiple_encoders = false;
|
||||
bool is_unet = false;
|
||||
|
||||
bool is_xl = false;
|
||||
bool is_flux = false;
|
||||
bool is_wan = false;
|
||||
bool is_xl = false;
|
||||
bool is_flux = false;
|
||||
bool is_wan = false;
|
||||
int64_t patch_embedding_channels = 0;
|
||||
bool has_img_emb = false;
|
||||
|
||||
for (auto& tensor_storage : tensor_storages) {
|
||||
if (!(is_xl || is_flux)) {
|
||||
@ -1707,7 +1709,13 @@ SDVersion ModelLoader::get_sd_version() {
|
||||
return VERSION_SD3;
|
||||
}
|
||||
if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) {
|
||||
return VERSION_WAN2;
|
||||
is_wan = true;
|
||||
}
|
||||
if (tensor_storage.name.find("model.diffusion_model.patch_embedding.weight") != std::string::npos) {
|
||||
patch_embedding_channels = tensor_storage.ne[3];
|
||||
}
|
||||
if (tensor_storage.name.find("model.diffusion_model.img_emb") != std::string::npos) {
|
||||
has_img_emb = true;
|
||||
}
|
||||
if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos || tensor_storage.name.find("unet.down_blocks.") != std::string::npos) {
|
||||
is_unet = true;
|
||||
@ -1748,6 +1756,13 @@ SDVersion ModelLoader::get_sd_version() {
|
||||
}
|
||||
}
|
||||
}
|
||||
if (is_wan) {
|
||||
LOG_DEBUG("patch_embedding_channels %d", patch_embedding_channels);
|
||||
if (patch_embedding_channels == 184320 && !has_img_emb) {
|
||||
return VERSION_WAN2_2_I2V;
|
||||
}
|
||||
return VERSION_WAN2;
|
||||
}
|
||||
bool is_inpaint = input_block_weight.ne[2] == 9;
|
||||
bool is_ip2p = input_block_weight.ne[2] == 8;
|
||||
if (is_xl) {
|
||||
|
||||
3
model.h
3
model.h
@ -32,6 +32,7 @@ enum SDVersion {
|
||||
VERSION_FLUX,
|
||||
VERSION_FLUX_FILL,
|
||||
VERSION_WAN2,
|
||||
VERSION_WAN2_2_I2V,
|
||||
VERSION_COUNT,
|
||||
};
|
||||
|
||||
@ -71,7 +72,7 @@ static inline bool sd_version_is_flux(SDVersion version) {
|
||||
}
|
||||
|
||||
static inline bool sd_version_is_wan(SDVersion version) {
|
||||
if (version == VERSION_WAN2) {
|
||||
if (version == VERSION_WAN2 || VERSION_WAN2_2_I2V) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
||||
@ -38,7 +38,7 @@ const char* model_version_to_str[] = {
|
||||
"Flux",
|
||||
"Flux Fill",
|
||||
"Wan 2.x",
|
||||
};
|
||||
"Wan 2.2 I2V"};
|
||||
|
||||
const char* sampling_methods_str[] = {
|
||||
"Euler A",
|
||||
@ -2315,18 +2315,22 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
||||
|
||||
ggml_tensor* clip_vision_output = NULL;
|
||||
ggml_tensor* concat_latent = NULL;
|
||||
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B") {
|
||||
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B" ||
|
||||
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-I2V-14B") {
|
||||
LOG_INFO("IMG2VID");
|
||||
|
||||
if (sd_vid_gen_params->init_image.data) {
|
||||
clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false, -2);
|
||||
} else {
|
||||
clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false, -2, true);
|
||||
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B") {
|
||||
if (sd_vid_gen_params->init_image.data) {
|
||||
clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false, -2);
|
||||
} else {
|
||||
clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false, -2, true);
|
||||
}
|
||||
|
||||
int64_t t1 = ggml_time_ms();
|
||||
LOG_INFO("get_clip_vision_output completed, taking %" PRId64 " ms", t1 - t0);
|
||||
}
|
||||
|
||||
int64_t t1 = ggml_time_ms();
|
||||
LOG_INFO("get_clip_vision_output completed, taking %" PRId64 " ms", t1 - t0);
|
||||
|
||||
int64_t t1 = ggml_time_ms();
|
||||
ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, frames, 3);
|
||||
for (int i3 = 0; i3 < init_img->ne[3]; i3++) { // channels
|
||||
for (int i2 = 0; i2 < init_img->ne[2]; i2++) {
|
||||
|
||||
9
wan.hpp
9
wan.hpp
@ -1609,8 +1609,13 @@ namespace WAN {
|
||||
wan_params.text_len = 512;
|
||||
} else if (wan_params.num_layers == 40) {
|
||||
if (wan_params.model_type == "t2v") {
|
||||
desc = "Wan2.1-T2V-14B";
|
||||
wan_params.in_dim = 16;
|
||||
if (version == VERSION_WAN2_2_I2V) {
|
||||
desc = "Wan2.2-I2V-14B";
|
||||
wan_params.in_dim = 36;
|
||||
} else {
|
||||
desc = "Wan2.x-T2V-14B";
|
||||
wan_params.in_dim = 16;
|
||||
}
|
||||
} else {
|
||||
desc = "Wan2.1-I2V-14B";
|
||||
wan_params.in_dim = 36;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user