add wan2.1 t2v support

This commit is contained in:
leejet 2025-08-15 00:37:30 +08:00
parent 73f76e6d96
commit 3a2840f9fb
9 changed files with 476 additions and 92 deletions

215
examples/cli/avi_writer.h Normal file
View File

@ -0,0 +1,215 @@
#ifndef __AVI_WRITER_H__
#define __AVI_WRITER_H__
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "stable-diffusion.h"
#include "stb_image_write.h"
// One 'idx1' index entry per frame: where the frame's '00dc' chunk lives
// in the output file and how many payload bytes it contains.
typedef struct {
uint32_t offset;  // file offset of the chunk's "00dc" fourcc (recorded as ftell - 8 at write time)
uint32_t size;    // JPEG payload byte count (the 2-byte alignment padding is not included)
} avi_index_entry;
// Write a 32-bit integer to `f` in little-endian byte order.
// Serializes byte-by-byte so the result is correct on any host; the previous
// fwrite(&val, 4, 1, f) silently assumed a little-endian machine.
// `static inline` avoids multiple-definition link errors when this header is
// included from more than one translation unit.
static inline void write_u32_le(FILE* f, uint32_t val) {
    uint8_t bytes[4] = {
        (uint8_t)(val & 0xFF),
        (uint8_t)((val >> 8) & 0xFF),
        (uint8_t)((val >> 16) & 0xFF),
        (uint8_t)((val >> 24) & 0xFF),
    };
    fwrite(bytes, 1, 4, f);
}
// Write a 16-bit integer to `f` in little-endian byte order.
// Byte-wise serialization makes the output host-endianness independent
// (the previous fwrite(&val, 2, 1, f) was wrong on big-endian hosts), and
// `static inline` prevents duplicate symbols when the header is multiply
// included.
static inline void write_u16_le(FILE* f, uint16_t val) {
    uint8_t bytes[2] = {
        (uint8_t)(val & 0xFF),
        (uint8_t)((val >> 8) & 0xFF),
    };
    fwrite(bytes, 1, 2, f);
}
/**
 * Create an MJPG AVI file from an array of sd_image_t images.
 * Each frame is encoded to an in-memory JPEG using stb_image_write and
 * stored as a '00dc' chunk inside the 'movi' list; an 'idx1' index is
 * appended afterwards so players can seek.
 *
 * All frames are assumed to share the first frame's width/height/channel
 * count (required for a well-formed single-stream AVI).
 *
 * @param filename Output AVI file name.
 * @param images Array of input images.
 * @param num_images Number of images in the array (must be >= 1).
 * @param fps Frames per second for the video (must be >= 1).
 * @param quality JPEG quality (0-100).
 * @return 0 on success, -1 on failure.
 */
static int create_mjpg_avi_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality = 90) {
    if (images == NULL || num_images <= 0) {  // also reject negative counts, not just 0
        fprintf(stderr, "Error: Image array is empty.\n");
        return -1;
    }
    if (fps <= 0) {
        // Guards the 1000000 / fps division below.
        fprintf(stderr, "Error: fps must be positive.\n");
        return -1;
    }
    FILE* f = fopen(filename, "wb");
    if (!f) {
        perror("Error opening file for writing");
        return -1;
    }
    uint32_t width    = images[0].width;
    uint32_t height   = images[0].height;
    uint32_t channels = images[0].channel;
    if (channels != 3 && channels != 4) {
        fprintf(stderr, "Error: Unsupported channel count: %u\n", channels);
        fclose(f);
        return -1;
    }
    // --- RIFF AVI Header ---
    fwrite("RIFF", 4, 1, f);
    long riff_size_pos = ftell(f);
    write_u32_le(f, 0);  // Placeholder for file size (patched at the end)
    fwrite("AVI ", 4, 1, f);
    // 'hdrl' LIST (header list): "hdrl" + avih chunk + strl LIST
    fwrite("LIST", 4, 1, f);
    write_u32_le(f, 4 + 8 + 56 + 8 + 4 + 8 + 56 + 8 + 40);
    fwrite("hdrl", 4, 1, f);
    // 'avih' chunk (AVI main header, 56 bytes)
    fwrite("avih", 4, 1, f);
    write_u32_le(f, 56);
    write_u32_le(f, 1000000 / fps);       // Microseconds per frame
    write_u32_le(f, 0);                   // Max bytes per second
    write_u32_le(f, 0);                   // Padding granularity
    write_u32_le(f, 0x110);               // Flags (HASINDEX | ISINTERLEAVED)
    write_u32_le(f, num_images);          // Total frames
    write_u32_le(f, 0);                   // Initial frames
    write_u32_le(f, 1);                   // Number of streams
    write_u32_le(f, width * height * 3);  // Suggested buffer size
    write_u32_le(f, width);
    write_u32_le(f, height);
    write_u32_le(f, 0);  // Reserved
    write_u32_le(f, 0);  // Reserved
    write_u32_le(f, 0);  // Reserved
    write_u32_le(f, 0);  // Reserved
    // 'strl' LIST (stream list): "strl" + strh chunk + strf chunk
    fwrite("LIST", 4, 1, f);
    write_u32_le(f, 4 + 8 + 56 + 8 + 40);
    fwrite("strl", 4, 1, f);
    // 'strh' chunk (stream header, 56 bytes)
    fwrite("strh", 4, 1, f);
    write_u32_le(f, 56);
    fwrite("vids", 4, 1, f);              // Stream type: video
    fwrite("MJPG", 4, 1, f);              // Codec: Motion JPEG
    write_u32_le(f, 0);                   // Flags
    write_u16_le(f, 0);                   // Priority
    write_u16_le(f, 0);                   // Language
    write_u32_le(f, 0);                   // Initial frames
    write_u32_le(f, 1);                   // Scale
    write_u32_le(f, fps);                 // Rate (frame rate = Rate / Scale)
    write_u32_le(f, 0);                   // Start
    write_u32_le(f, num_images);          // Length (in frames)
    write_u32_le(f, width * height * 3);  // Suggested buffer size
    write_u32_le(f, (uint32_t)-1);        // Quality (-1 = default)
    write_u32_le(f, 0);                   // Sample size (0 = variable)
    write_u16_le(f, 0);                   // rcFrame.left
    write_u16_le(f, 0);                   // rcFrame.top
    write_u16_le(f, 0);                   // rcFrame.right
    write_u16_le(f, 0);                   // rcFrame.bottom
    // 'strf' chunk (stream format: BITMAPINFOHEADER)
    fwrite("strf", 4, 1, f);
    write_u32_le(f, 40);
    write_u32_le(f, 40);  // biSize
    write_u32_le(f, width);
    write_u32_le(f, height);
    write_u16_le(f, 1);                   // biPlanes
    write_u16_le(f, 24);                  // biBitCount
    fwrite("MJPG", 4, 1, f);              // biCompression (FOURCC)
    write_u32_le(f, width * height * 3);  // biSizeImage
    write_u32_le(f, 0);                   // XPelsPerMeter
    write_u32_le(f, 0);                   // YPelsPerMeter
    write_u32_le(f, 0);                   // Colors used
    write_u32_le(f, 0);                   // Colors important
    // 'movi' LIST (video frames)
    fwrite("LIST", 4, 1, f);
    long movi_size_pos = ftell(f);
    write_u32_le(f, 0);  // Placeholder for movi size (patched after the frames)
    fwrite("movi", 4, 1, f);
    avi_index_entry* index = (avi_index_entry*)malloc(sizeof(avi_index_entry) * num_images);
    if (!index) {
        fclose(f);
        return -1;
    }
    // In-memory JPEG accumulator filled by the stb write callback.
    struct jpeg_buffer {
        uint8_t* buf;
        size_t size;
        bool oom;  // set when realloc fails so the frame loop can abort cleanly
    };
    // Captureless lambda converts to the plain function pointer stb expects.
    auto write_to_buf = [](void* context, void* data, int size) {
        jpeg_buffer* jb = (jpeg_buffer*)context;
        if (jb->oom) {
            return;  // already failed; discard further data
        }
        uint8_t* grown = (uint8_t*)realloc(jb->buf, jb->size + size);
        if (grown == NULL) {
            jb->oom = true;  // keep the old buf pointer so it can still be freed
            return;
        }
        jb->buf = grown;
        memcpy(jb->buf + jb->size, data, size);
        jb->size += size;
    };
    // Encode and write each frame as a '00dc' (compressed video) chunk.
    for (int i = 0; i < num_images; i++) {
        jpeg_buffer jb = {NULL, 0, false};
        int ok = stbi_write_jpg_to_func(
            write_to_buf,
            &jb,
            images[i].width,
            images[i].height,
            channels,
            images[i].data,
            quality);
        if (!ok || jb.oom || jb.buf == NULL) {
            fprintf(stderr, "Error: JPEG encoding failed for frame %d.\n", i);
            free(jb.buf);
            free(index);
            fclose(f);
            return -1;
        }
        fwrite("00dc", 4, 1, f);
        write_u32_le(f, (uint32_t)jb.size);
        // NOTE(review): offsets recorded here are absolute file offsets of the
        // chunk fourcc; many players accept this, though the AVI spec describes
        // offsets relative to the 'movi' list — confirm with target players.
        index[i].offset = (uint32_t)(ftell(f) - 8);
        index[i].size   = (uint32_t)jb.size;
        fwrite(jb.buf, 1, jb.size, f);
        if (jb.size % 2) {
            fputc(0, f);  // RIFF chunks are 2-byte aligned; pad byte not counted in size
        }
        free(jb.buf);
    }
    // Patch the 'movi' LIST size now that all frames are written.
    long cur_pos   = ftell(f);
    long movi_size = cur_pos - movi_size_pos - 4;
    fseek(f, movi_size_pos, SEEK_SET);
    write_u32_le(f, movi_size);
    fseek(f, cur_pos, SEEK_SET);
    // Write the 'idx1' index (16 bytes per frame).
    fwrite("idx1", 4, 1, f);
    write_u32_le(f, num_images * 16);
    for (int i = 0; i < num_images; i++) {
        fwrite("00dc", 4, 1, f);
        write_u32_le(f, 0x10);  // AVIIF_KEYFRAME: every MJPG frame is independently decodable
        write_u32_le(f, index[i].offset);
        write_u32_le(f, index[i].size);
    }
    // Patch the top-level RIFF size.
    cur_pos        = ftell(f);
    long file_size = cur_pos - riff_size_pos - 4;
    fseek(f, riff_size_pos, SEEK_SET);
    write_u32_le(f, file_size);
    fseek(f, cur_pos, SEEK_SET);
    free(index);
    fclose(f);
    return 0;
}
#endif // __AVI_WRITER_H__

View File

@ -10,6 +10,7 @@
#include <vector> #include <vector>
// #include "preprocessing.hpp" // #include "preprocessing.hpp"
#include "avi_writer.h"
#include "stable-diffusion.h" #include "stable-diffusion.h"
#define STB_IMAGE_IMPLEMENTATION #define STB_IMAGE_IMPLEMENTATION
@ -83,6 +84,7 @@ struct SDParams {
int batch_count = 1; int batch_count = 1;
int video_frames = 1; int video_frames = 1;
int fps = 24;
sample_method_t sample_method = EULER_A; sample_method_t sample_method = EULER_A;
schedule_t schedule = DEFAULT; schedule_t schedule = DEFAULT;
@ -166,6 +168,8 @@ void print_params(SDParams params) {
printf(" chroma_use_dit_mask: %s\n", params.chroma_use_dit_mask ? "true" : "false"); printf(" chroma_use_dit_mask: %s\n", params.chroma_use_dit_mask ? "true" : "false");
printf(" chroma_use_t5_mask: %s\n", params.chroma_use_t5_mask ? "true" : "false"); printf(" chroma_use_t5_mask: %s\n", params.chroma_use_t5_mask ? "true" : "false");
printf(" chroma_t5_mask_pad: %d\n", params.chroma_t5_mask_pad); printf(" chroma_t5_mask_pad: %d\n", params.chroma_t5_mask_pad);
printf(" video_frames: %d\n", params.video_frames);
printf(" fps: %d\n", params.fps);
} }
void print_usage(int argc, const char* argv[]) { void print_usage(int argc, const char* argv[]) {
@ -224,7 +228,7 @@ void print_usage(int argc, const char* argv[]) {
printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n"); printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
printf(" -b, --batch-count COUNT number of images to generate\n"); printf(" -b, --batch-count COUNT number of images to generate\n");
printf(" --schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete)\n"); printf(" --schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete)\n");
printf(" --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n"); printf(" --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n");
printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n"); printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n");
printf(" --vae-tiling process vae in tiles to reduce memory usage\n"); printf(" --vae-tiling process vae in tiles to reduce memory usage\n");
printf(" --vae-on-cpu keep vae in cpu (for low vram)\n"); printf(" --vae-on-cpu keep vae in cpu (for low vram)\n");
@ -238,6 +242,8 @@ void print_usage(int argc, const char* argv[]) {
printf(" --chroma-disable-dit-mask disable dit mask for chroma\n"); printf(" --chroma-disable-dit-mask disable dit mask for chroma\n");
printf(" --chroma-enable-t5-mask enable t5 mask for chroma\n"); printf(" --chroma-enable-t5-mask enable t5 mask for chroma\n");
printf(" --chroma-t5-mask-pad PAD_SIZE t5 mask pad size of chroma\n"); printf(" --chroma-t5-mask-pad PAD_SIZE t5 mask pad size of chroma\n");
printf(" --video-frames video frames (default: 1)\n");
printf(" --fps fps (default: 24)\n");
printf(" -v, --verbose print extra info\n"); printf(" -v, --verbose print extra info\n");
} }
@ -435,6 +441,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
{"", "--clip-skip", "", &params.clip_skip}, {"", "--clip-skip", "", &params.clip_skip},
{"-b", "--batch-count", "", &params.batch_count}, {"-b", "--batch-count", "", &params.batch_count},
{"", "--chroma-t5-mask-pad", "", &params.chroma_t5_mask_pad}, {"", "--chroma-t5-mask-pad", "", &params.chroma_t5_mask_pad},
{"", "--video-frames", "", &params.video_frames},
{"", "--fps", "", &params.fps},
}; };
options.float_options = { options.float_options = {
@ -656,6 +664,16 @@ void parse_args(int argc, const char** argv, SDParams& params) {
fprintf(stderr, "warning: --tensor-type-rules is currently supported only for conversion\n"); fprintf(stderr, "warning: --tensor-type-rules is currently supported only for conversion\n");
} }
if (params.mode == VID_GEN && params.video_frames <= 0) {
fprintf(stderr, "warning: --video-frames must be at least 1\n");
exit(1);
}
if (params.mode == VID_GEN && params.fps <= 0) {
fprintf(stderr, "warning: --fps must be at least 1\n");
exit(1);
}
if (params.upscale_repeats < 1) { if (params.upscale_repeats < 1) {
fprintf(stderr, "error: upscale multiplier must be at least 1\n"); fprintf(stderr, "error: upscale multiplier must be at least 1\n");
exit(1); exit(1);
@ -983,7 +1001,7 @@ int main(int argc, const char* argv[]) {
mask_image_buffer}; mask_image_buffer};
sd_image_t* results; sd_image_t* results;
int expected_num_results = 1; int num_results = 1;
if (params.mode == IMG_GEN) { if (params.mode == IMG_GEN) {
sd_img_gen_params_t img_gen_params = { sd_img_gen_params_t img_gen_params = {
params.prompt.c_str(), params.prompt.c_str(),
@ -1009,8 +1027,8 @@ int main(int argc, const char* argv[]) {
params.input_id_images_path.c_str(), params.input_id_images_path.c_str(),
}; };
results = generate_image(sd_ctx, &img_gen_params); results = generate_image(sd_ctx, &img_gen_params);
expected_num_results = params.batch_count; num_results = params.batch_count;
} else if (params.mode == VID_GEN) { } else if (params.mode == VID_GEN) {
sd_vid_gen_params_t vid_gen_params = { sd_vid_gen_params_t vid_gen_params = {
params.prompt.c_str(), params.prompt.c_str(),
@ -1028,8 +1046,7 @@ int main(int argc, const char* argv[]) {
params.video_frames, params.video_frames,
}; };
results = generate_video(sd_ctx, &vid_gen_params); results = generate_video(sd_ctx, &vid_gen_params, &num_results);
expected_num_results = params.video_frames;
} }
if (results == NULL) { if (results == NULL) {
@ -1065,45 +1082,59 @@ int main(int argc, const char* argv[]) {
} }
} }
std::string dummy_name, ext, lc_ext; std::string base_path;
std::string file_ext;
std::string file_ext_lower;
bool is_jpg; bool is_jpg;
size_t last = params.output_path.find_last_of("."); size_t last_dot_pos = params.output_path.find_last_of(".");
size_t last_path = std::min(params.output_path.find_last_of("/"), size_t last_slash_pos = std::min(params.output_path.find_last_of("/"),
params.output_path.find_last_of("\\")); params.output_path.find_last_of("\\"));
if (last != std::string::npos // filename has extension if (last_dot_pos != std::string::npos && (last_slash_pos == std::string::npos || last_dot_pos > last_slash_pos)) { // filename has extension
&& (last_path == std::string::npos || last > last_path)) { base_path = params.output_path.substr(0, last_dot_pos);
dummy_name = params.output_path.substr(0, last); file_ext = file_ext_lower = params.output_path.substr(last_dot_pos);
ext = lc_ext = params.output_path.substr(last); std::transform(file_ext.begin(), file_ext.end(), file_ext_lower.begin(), ::tolower);
std::transform(ext.begin(), ext.end(), lc_ext.begin(), ::tolower); is_jpg = (file_ext_lower == ".jpg" || file_ext_lower == ".jpeg" || file_ext_lower == ".jpe");
is_jpg = lc_ext == ".jpg" || lc_ext == ".jpeg" || lc_ext == ".jpe";
} else { } else {
dummy_name = params.output_path; base_path = params.output_path;
ext = lc_ext = ""; file_ext = file_ext_lower = "";
is_jpg = false; is_jpg = false;
} }
// appending ".png" to absent or unknown extension
if (!is_jpg && lc_ext != ".png") { if (params.mode == VID_GEN && num_results > 1) {
dummy_name += ext; std::string vid_output_path = params.output_path;
ext = ".png"; if (file_ext_lower == ".png") {
vid_output_path = base_path + ".avi";
}
create_mjpg_avi_from_sd_images(vid_output_path.c_str(), results, num_results, params.fps);
printf("save result MJPG AVI video to '%s'\n", vid_output_path.c_str());
} else {
// appending ".png" to absent or unknown extension
if (!is_jpg && file_ext_lower != ".png") {
base_path += file_ext;
file_ext = ".png";
}
for (int i = 0; i < num_results; i++) {
if (results[i].data == NULL) {
continue;
}
std::string final_image_path = i > 0 ? base_path + "_" + std::to_string(i + 1) + file_ext : base_path + file_ext;
if (is_jpg) {
stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
results[i].data, 90, get_image_params(params, params.seed + i).c_str());
printf("save result JPEG image to '%s'\n", final_image_path.c_str());
} else {
stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
results[i].data, 0, get_image_params(params, params.seed + i).c_str());
printf("save result PNG image to '%s'\n", final_image_path.c_str());
}
}
} }
for (int i = 0; i < expected_num_results; i++) {
if (results[i].data == NULL) { free(results);
continue; for (int i = 0; i < num_results; i++) {
}
std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ext : dummy_name + ext;
if (is_jpg) {
stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
results[i].data, 90, get_image_params(params, params.seed + i).c_str());
printf("save result JPEG image to '%s'\n", final_image_path.c_str());
} else {
stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
results[i].data, 0, get_image_params(params, params.seed + i).c_str());
printf("save result PNG image to '%s'\n", final_image_path.c_str());
}
free(results[i].data); free(results[i].data);
results[i].data = NULL; results[i].data = NULL;
} }
free(results);
free_sd_ctx(sd_ctx); free_sd_ctx(sd_ctx);
free(control_image_buffer); free(control_image_buffer);
free(input_image_buffer); free(input_image_buffer);

View File

@ -1,4 +1,4 @@
for f in *.cpp *.h *.hpp examples/cli/*.cpp; do for f in *.cpp *.h *.hpp examples/cli/*.cpp examples/cli/*.h; do
[[ "$f" == vocab* ]] && continue [[ "$f" == vocab* ]] && continue
echo "formatting '$f'" echo "formatting '$f'"
clang-format -style=file -i "$f" clang-format -style=file -i "$f"

2
ggml

@ -1 +1 @@
Subproject commit e89bc7e8625f59145ee8c0b09383009c47752cd8 Subproject commit 089530bb72e70aa9f9ecb98137dfd891c2be20c1

View File

@ -988,19 +988,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
if (flash_attn) { if (flash_attn) {
// LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N); // LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
bool can_use_flash_attn = true; bool can_use_flash_attn = true;
can_use_flash_attn = can_use_flash_attn && (d_head == 64 ||
d_head == 80 ||
d_head == 96 ||
d_head == 112 ||
d_head == 128 ||
d_head == 256);
if (can_use_flash_attn && L_k % 256 != 0) { if (can_use_flash_attn && L_k % 256 != 0) {
// TODO(Green-Sky): might be worth just padding by default kv_pad = GGML_PAD(L_k, 256) - L_k;
if (L_k == 77 || L_k == 1560 || L_k == 4208 || L_k == 3952) {
kv_pad = GGML_PAD(L_k, 256) - L_k;
} else {
can_use_flash_attn = false;
}
} }
if (mask != nullptr) { if (mask != nullptr) {
@ -1021,14 +1010,14 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
// LOG_DEBUG(" padding k and v dim1 by %d", kv_pad); // LOG_DEBUG(" padding k and v dim1 by %d", kv_pad);
k = ggml_pad(ctx, k, 0, kv_pad, 0, 0); k = ggml_pad(ctx, k, 0, kv_pad, 0, 0);
} }
// k = ggml_cast(ctx, k, GGML_TYPE_F16); k = ggml_cast(ctx, k, GGML_TYPE_F16);
v = ggml_nn_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3)); // [N, n_head, L_k, d_head] v = ggml_nn_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3)); // [N, n_head, L_k, d_head]
v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head] v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head]
if (kv_pad != 0) { if (kv_pad != 0) {
v = ggml_pad(ctx, v, 0, kv_pad, 0, 0); v = ggml_pad(ctx, v, 0, kv_pad, 0, 0);
} }
// v = ggml_cast(ctx, v, GGML_TYPE_F16); v = ggml_cast(ctx, v, GGML_TYPE_F16);
if (mask != nullptr) { if (mask != nullptr) {
mask = ggml_transpose(ctx, mask); mask = ggml_transpose(ctx, mask);

View File

@ -1191,16 +1191,20 @@ public:
} }
ggml_tensor* decode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) { ggml_tensor* decode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) {
int64_t W = x->ne[0] * 8; int64_t W = x->ne[0] * 8;
int64_t H = x->ne[1] * 8; int64_t H = x->ne[1] * 8;
int64_t C = 3; int64_t C = 3;
ggml_tensor* result; ggml_tensor* result = NULL;
if (decode_video) { if (decode_video) {
int T = x->ne[2];
if (sd_version_is_wan(version)) {
T = ((T - 1) * 4) + 1;
}
result = ggml_new_tensor_4d(work_ctx, result = ggml_new_tensor_4d(work_ctx,
GGML_TYPE_F32, GGML_TYPE_F32,
W, W,
H, H,
x->ne[2], T,
3); 3);
} else { } else {
result = ggml_new_tensor_4d(work_ctx, result = ggml_new_tensor_4d(work_ctx,
@ -1214,6 +1218,7 @@ public:
int64_t t0 = ggml_time_ms(); int64_t t0 = ggml_time_ms();
if (!use_tiny_autoencoder) { if (!use_tiny_autoencoder) {
process_latent_out(x); process_latent_out(x);
// x = load_tensor_from_file(work_ctx, "wan_vae_video_z.bin");
if (vae_tiling && !decode_video) { if (vae_tiling && !decode_video) {
// split latent in 32x32 tiles and compute in several steps // split latent in 32x32 tiles and compute in several steps
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
@ -1221,7 +1226,7 @@ public:
}; };
sd_tiling(x, result, 8, 32, 0.5f, on_tiling); sd_tiling(x, result, 8, 32, 0.5f, on_tiling);
} else { } else {
first_stage_model->compute(n_threads, x, true, &result, NULL); first_stage_model->compute(n_threads, x, true, &result, work_ctx);
} }
first_stage_model->free_compute_buffer(); first_stage_model->free_compute_buffer();
ggml_tensor_scale_output(result); ggml_tensor_scale_output(result);
@ -1882,18 +1887,20 @@ ggml_tensor* generate_init_latent(sd_ctx_t* sd_ctx,
int frames = 1, int frames = 1,
bool video = false) { bool video = false) {
int C = 4; int C = 4;
int T = frames;
if (sd_version_is_sd3(sd_ctx->sd->version)) { if (sd_version_is_sd3(sd_ctx->sd->version)) {
C = 16; C = 16;
} else if (sd_version_is_flux(sd_ctx->sd->version)) { } else if (sd_version_is_flux(sd_ctx->sd->version)) {
C = 16; C = 16;
} else if (sd_version_is_wan(sd_ctx->sd->version)) { } else if (sd_version_is_wan(sd_ctx->sd->version)) {
C = 16; C = 16;
T = ((T - 1) / 4) + 1;
} }
int W = width / 8; int W = width / 8;
int H = height / 8; int H = height / 8;
ggml_tensor* init_latent; ggml_tensor* init_latent;
if (video) { if (video) {
init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, frames, C); init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, T, C);
} else { } else {
init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1); init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
} }
@ -2131,7 +2138,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
return result_images; return result_images;
} }
SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params) { SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params, int* num_frames_out) {
if (sd_ctx == NULL || sd_vid_gen_params == NULL) { if (sd_ctx == NULL || sd_vid_gen_params == NULL) {
return NULL; return NULL;
} }
@ -2142,13 +2149,14 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
int width = sd_vid_gen_params->width; int width = sd_vid_gen_params->width;
int height = sd_vid_gen_params->height; int height = sd_vid_gen_params->height;
int frames = sd_vid_gen_params->video_frames; int frames = sd_vid_gen_params->video_frames;
frames = (frames - 1) / 4 * 4 + 1;
LOG_INFO("img2vid %dx%dx%d", width, height, frames); LOG_INFO("img2vid %dx%dx%d", width, height, frames);
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sd_vid_gen_params->sample_steps); std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sd_vid_gen_params->sample_steps);
struct ggml_init_params params; struct ggml_init_params params;
params.mem_size = static_cast<size_t>(100 * 1024) * 1024; // 50 MB params.mem_size = static_cast<size_t>(100 * 1024) * 1024; // 100 MB
params.mem_size += width * height * frames * 3 * sizeof(float); params.mem_size += width * height * frames * 3 * sizeof(float) * 2;
params.mem_buffer = NULL; params.mem_buffer = NULL;
params.no_alloc = false; params.no_alloc = false;
// LOG_DEBUG("mem_size %u ", params.mem_size); // LOG_DEBUG("mem_size %u ", params.mem_size);
@ -2204,12 +2212,13 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
int W = width / 8; int W = width / 8;
int H = height / 8; int H = height / 8;
int T = frames; int T = init_latent->ne[2];
int C = 16; int C = 16;
struct ggml_tensor* final_latent; struct ggml_tensor* final_latent;
// Sample // Sample
{ {
LOG_DEBUG("sample %dx%dx%d", W, H, T);
int64_t sampling_start = ggml_time_ms(); int64_t sampling_start = ggml_time_ms();
struct ggml_tensor* x_t = init_latent; struct ggml_tensor* x_t = init_latent;
struct ggml_tensor* noise = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, T, C); struct ggml_tensor* noise = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, T, C);
@ -2247,15 +2256,16 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
sd_ctx->sd->first_stage_model->free_params_buffer(); sd_ctx->sd->first_stage_model->free_params_buffer();
} }
sd_image_t* result_images = (sd_image_t*)calloc(T, sizeof(sd_image_t)); sd_image_t* result_images = (sd_image_t*)calloc(vid->ne[2], sizeof(sd_image_t));
if (result_images == NULL) { if (result_images == NULL) {
ggml_free(work_ctx); ggml_free(work_ctx);
return NULL; return NULL;
} }
*num_frames_out = vid->ne[2];
for (size_t i = 0; i < T; i++) { for (size_t i = 0; i < vid->ne[2]; i++) {
result_images[i].width = final_latent->ne[0] * 8; result_images[i].width = vid->ne[0];
result_images[i].height = final_latent->ne[1] * 8; result_images[i].height = vid->ne[1];
result_images[i].channel = 3; result_images[i].channel = 3;
result_images[i].data = sd_tensor_to_image(vid, i, true); result_images[i].data = sd_tensor_to_image(vid, i, true);
} }

View File

@ -231,7 +231,7 @@ SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_para
SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params); SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params);
SD_API void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params); SD_API void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params);
SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params); // broken SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params, int* num_frames_out);
typedef struct upscaler_ctx_t upscaler_ctx_t; typedef struct upscaler_ctx_t upscaler_ctx_t;

4
t5.hpp
View File

@ -994,8 +994,8 @@ struct T5Embedder {
// cuda f16: pass // cuda f16: pass
// cuda f32: pass // cuda f32: pass
// cuda q8_0: pass // cuda q8_0: pass
ggml_backend_t backend = ggml_backend_cuda_init(0); // ggml_backend_t backend = ggml_backend_cuda_init(0);
// ggml_backend_t backend = ggml_backend_cpu_init(); ggml_backend_t backend = ggml_backend_cpu_init();
ggml_type model_data_type = GGML_TYPE_F16; ggml_type model_data_type = GGML_TYPE_F16;
ModelLoader model_loader; ModelLoader model_loader;

179
wan.hpp
View File

@ -14,6 +14,8 @@ namespace WAN {
constexpr int CACHE_T = 2; constexpr int CACHE_T = 2;
constexpr int WAN_GRAPH_SIZE = 10240; constexpr int WAN_GRAPH_SIZE = 10240;
#define Rep ((struct ggml_tensor*)1)
class CausalConv3d : public GGMLBlock { class CausalConv3d : public GGMLBlock {
protected: protected:
int64_t in_channels; int64_t in_channels;
@ -68,7 +70,7 @@ namespace WAN {
int lp2 = 2 * std::get<0>(padding); int lp2 = 2 * std::get<0>(padding);
int rp2 = 0; int rp2 = 0;
if (cache_x != NULL && std::get<0>(padding) > 0) { if (cache_x != NULL && lp2 > 0) {
x = ggml_concat(ctx, cache_x, x, 2); x = ggml_concat(ctx, cache_x, x, 2);
lp2 -= (int)cache_x->ne[2]; lp2 -= (int)cache_x->ne[2];
} }
@ -145,8 +147,6 @@ namespace WAN {
int64_t h = x->ne[1]; int64_t h = x->ne[1];
int64_t w = x->ne[0]; int64_t w = x->ne[0];
struct ggml_tensor* Rep = (struct ggml_tensor*)1;
if (mode == "upsample3d") { if (mode == "upsample3d") {
if (feat_cache.size() > 0) { if (feat_cache.size() > 0) {
int idx = feat_idx; int idx = feat_idx;
@ -164,8 +164,8 @@ namespace WAN {
cache_x, cache_x,
2); 2);
} }
if (cache_x->ne[1] < 2 && feat_cache[idx] != NULL && feat_cache[idx] == Rep) { if (cache_x->ne[2] < 2 && feat_cache[idx] != NULL && feat_cache[idx] == Rep) {
cache_x = ggml_pad_ext(ctx, cache_x, 0, 0, 1, 1, (int)cache_x->ne[2], 0, 0, 0); cache_x = ggml_pad_ext(ctx, cache_x, 0, 0, 0, 0, (int)cache_x->ne[2], 0, 0, 0);
// aka cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device),cache_x],dim=2) // aka cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device),cache_x],dim=2)
} }
if (feat_cache[idx] == Rep) { if (feat_cache[idx] == Rep) {
@ -629,7 +629,7 @@ namespace WAN {
}; };
class WanVAE : public GGMLBlock { class WanVAE : public GGMLBlock {
protected: public:
bool decode_only = true; bool decode_only = true;
int64_t dim = 96; int64_t dim = 96;
int64_t z_dim = 16; int64_t z_dim = 16;
@ -724,11 +724,47 @@ namespace WAN {
clear_cache(); clear_cache();
return out; return out;
} }
struct ggml_tensor* decode_partial(struct ggml_context* ctx,
struct ggml_tensor* z,
int64_t i,
int64_t b = 1) {
// z: [b*c, t, h, w]
GGML_ASSERT(b == 1);
auto decoder = std::dynamic_pointer_cast<Decoder3d>(blocks["decoder"]);
auto conv2 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv2"]);
auto x = conv2->forward(ctx, z);
auto in = ggml_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w]
_conv_idx = 0;
auto out = decoder->forward(ctx, in, b, _feat_map, _conv_idx);
return out;
}
};
struct FeatCache {
std::vector<float> data;
std::vector<int64_t> shape;
bool is_rep = false;
FeatCache() = default;
FeatCache(ggml_backend_t backend, ggml_tensor* tensor) {
shape = {tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]};
data.resize(shape[0] * shape[1] * shape[2] * shape[3]);
ggml_backend_tensor_get_and_sync(backend, tensor, (void*)data.data(), 0, ggml_nbytes(tensor));
}
ggml_tensor* to_ggml_tensor(ggml_context* ctx) {
return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, shape[0], shape[1], shape[2], shape[3]);
}
}; };
struct WanVAERunner : public VAE { struct WanVAERunner : public VAE {
bool decode_only = true; bool decode_only = true;
WanVAE ae; WanVAE ae;
std::vector<FeatCache> _feat_vec_map;
WanVAERunner(ggml_backend_t backend, WanVAERunner(ggml_backend_t backend,
const String2GGMLType& tensor_types = {}, const String2GGMLType& tensor_types = {},
@ -736,6 +772,11 @@ namespace WAN {
bool decode_only = false) bool decode_only = false)
: decode_only(decode_only), ae(decode_only), VAE(backend) { : decode_only(decode_only), ae(decode_only), VAE(backend) {
ae.init(params_ctx, tensor_types, prefix); ae.init(params_ctx, tensor_types, prefix);
rest_feat_vec_map();
}
void rest_feat_vec_map() {
_feat_vec_map = std::vector<FeatCache>(ae._conv_num, FeatCache());
} }
std::string get_desc() { std::string get_desc() {
@ -747,7 +788,7 @@ namespace WAN {
} }
struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) { struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, 20480, false); struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, 10240 * z->ne[2], false);
z = to_backend(z); z = to_backend(z);
@ -758,22 +799,120 @@ namespace WAN {
return gf; return gf;
} }
struct ggml_cgraph* build_graph_partial(struct ggml_tensor* z, bool decode_graph, int64_t i) {
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, 20480, false);
ae.clear_cache();
for (int64_t feat_idx = 0; feat_idx < _feat_vec_map.size(); feat_idx++) {
FeatCache& feat_cache_vec = _feat_vec_map[feat_idx];
if (feat_cache_vec.is_rep) {
ae._feat_map[feat_idx] = Rep;
} else if (feat_cache_vec.data.size() > 0) {
ggml_tensor* feat_cache = feat_cache_vec.to_ggml_tensor(compute_ctx);
set_backend_tensor_data(feat_cache, feat_cache_vec.data.data());
ae._feat_map[feat_idx] = feat_cache;
}
}
z = to_backend(z);
struct ggml_tensor* out = decode_graph ? ae.decode_partial(compute_ctx, z, i) : ae.encode(compute_ctx, z);
for (int64_t feat_idx = 0; feat_idx < ae._feat_map.size(); feat_idx++) {
ggml_tensor* feat_cache = ae._feat_map[feat_idx];
if (feat_cache != NULL && feat_cache != Rep) {
ggml_build_forward_expand(gf, feat_cache);
}
}
ggml_build_forward_expand(gf, out);
return gf;
}
void compute(const int n_threads, void compute(const int n_threads,
struct ggml_tensor* z, struct ggml_tensor* z,
bool decode_graph, bool decode_graph,
struct ggml_tensor** output, struct ggml_tensor** output,
struct ggml_context* output_ctx = NULL) { struct ggml_context* output_ctx = NULL) {
auto get_graph = [&]() -> struct ggml_cgraph* { if (true) {
return build_graph(z, decode_graph); auto get_graph = [&]() -> struct ggml_cgraph* {
}; return build_graph(z, decode_graph);
// ggml_set_f32(z, 0.5f); };
// print_ggml_tensor(z); GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx); } else { // broken
ae.clear_cache();
int64_t t = z->ne[2];
int64_t i = 0;
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph_partial(z, decode_graph, i);
};
struct ggml_tensor* out = NULL;
GGMLRunner::compute(get_graph, n_threads, false, &out, output_ctx);
for (int64_t feat_idx = 0; feat_idx < _feat_vec_map.size(); feat_idx++) {
ggml_tensor* feat_cache = ae._feat_map[feat_idx];
if (feat_cache == Rep) {
FeatCache feat_cache_vec;
feat_cache_vec.is_rep = true;
_feat_vec_map[feat_idx] = feat_cache_vec;
} else if (feat_cache != NULL) {
_feat_vec_map[feat_idx] = FeatCache(backend, feat_cache);
}
}
GGMLRunner::free_compute_buffer();
if (t == 1) {
*output = out;
ae.clear_cache();
return;
}
*output = ggml_new_tensor_4d(output_ctx, GGML_TYPE_F32, out->ne[0], out->ne[1], (t - 1) * 4 + 1, out->ne[3]);
auto copy_to_output = [&]() {
for (int64_t i3 = 0; i3 < out->ne[3]; i3++) {
for (int64_t i2 = 0; i2 < out->ne[2]; i2++) {
for (int64_t i1 = 0; i1 < out->ne[1]; i1++) {
for (int64_t i0 = 0; i0 < out->ne[0]; i0++) {
float value = ggml_tensor_get_f32(out, i0, i1, i2, i3);
int64_t offset = (i == 0) ? 0 : (1 + (i - 1) * 4);
ggml_tensor_set_f32(*output, value, i0, i1, offset + i2, i3);
}
}
}
}
};
copy_to_output();
out = ggml_new_tensor_4d(output_ctx, GGML_TYPE_F32, out->ne[0], out->ne[1], 4, out->ne[3]);
for (i = 1; i < t; i++) {
GGMLRunner::compute(get_graph, n_threads, false, &out);
for (int64_t feat_idx = 0; feat_idx < _feat_vec_map.size(); feat_idx++) {
ggml_tensor* feat_cache = ae._feat_map[feat_idx];
if (feat_cache == Rep) {
FeatCache feat_cache_vec;
feat_cache_vec.is_rep = true;
_feat_vec_map[feat_idx] = feat_cache_vec;
} else if (feat_cache != NULL) {
_feat_vec_map[feat_idx] = FeatCache(backend, feat_cache);
}
}
ae.clear_cache();
GGMLRunner::free_compute_buffer();
copy_to_output();
}
}
} }
void test() { void test() {
struct ggml_init_params params; struct ggml_init_params params;
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB params.mem_size = static_cast<size_t>(1000 * 1024 * 1024); // 1000 MB
params.mem_buffer = NULL; params.mem_buffer = NULL;
params.no_alloc = false; params.no_alloc = false;
@ -785,9 +924,9 @@ namespace WAN {
// cpu f16, pass // cpu f16, pass
// cuda f16, pass // cuda f16, pass
// cuda f32, pass // cuda f32, pass
auto z = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 1, 16); auto z = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 104, 60, 2, 16);
z = load_tensor_from_file(work_ctx, "wan_vae_z.bin"); ggml_set_f32(z, 0.5f);
// ggml_set_f32(z, 0.5f); z = load_tensor_from_file(work_ctx, "wan_vae_video_z.bin");
print_ggml_tensor(z); print_ggml_tensor(z);
struct ggml_tensor* out = NULL; struct ggml_tensor* out = NULL;
@ -803,7 +942,7 @@ namespace WAN {
static void load_from_file_and_test(const std::string& file_path) { static void load_from_file_and_test(const std::string& file_path) {
ggml_backend_t backend = ggml_backend_cuda_init(0); ggml_backend_t backend = ggml_backend_cuda_init(0);
// ggml_backend_t backend = ggml_backend_cpu_init(); // ggml_backend_t backend = ggml_backend_cpu_init();
ggml_type model_data_type = GGML_TYPE_F32; ggml_type model_data_type = GGML_TYPE_F16;
std::shared_ptr<WanVAERunner> vae = std::shared_ptr<WanVAERunner>(new WanVAERunner(backend)); std::shared_ptr<WanVAERunner> vae = std::shared_ptr<WanVAERunner>(new WanVAERunner(backend));
{ {
LOG_INFO("loading from '%s'", file_path.c_str()); LOG_INFO("loading from '%s'", file_path.c_str());
@ -1588,8 +1727,8 @@ namespace WAN {
} }
static void load_from_file_and_test(const std::string& file_path) { static void load_from_file_and_test(const std::string& file_path) {
ggml_backend_t backend = ggml_backend_cuda_init(0); // ggml_backend_t backend = ggml_backend_cuda_init(0);
// ggml_backend_t backend = ggml_backend_cpu_init(); ggml_backend_t backend = ggml_backend_cpu_init();
ggml_type model_data_type = GGML_TYPE_Q8_0; ggml_type model_data_type = GGML_TYPE_Q8_0;
LOG_INFO("loading from '%s'", file_path.c_str()); LOG_INFO("loading from '%s'", file_path.c_str());