From 3a2840f9fb94964f198eaae1d2f15b93659e9b6b Mon Sep 17 00:00:00 2001 From: leejet Date: Fri, 15 Aug 2025 00:37:30 +0800 Subject: [PATCH] add wan2.1 t2v support --- examples/cli/avi_writer.h | 215 ++++++++++++++++++++++++++++++++++++++ examples/cli/main.cpp | 107 ++++++++++++------- format-code.sh | 2 +- ggml | 2 +- ggml_extend.hpp | 17 +-- stable-diffusion.cpp | 40 ++++--- stable-diffusion.h | 2 +- t5.hpp | 4 +- wan.hpp | 179 +++++++++++++++++++++++++++---- 9 files changed, 476 insertions(+), 92 deletions(-) create mode 100644 examples/cli/avi_writer.h diff --git a/examples/cli/avi_writer.h b/examples/cli/avi_writer.h new file mode 100644 index 0000000..68e3f7f --- /dev/null +++ b/examples/cli/avi_writer.h @@ -0,0 +1,215 @@ +#ifndef __AVI_WRITER_H__ +#define __AVI_WRITER_H__ + +#include <stdint.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> + +#include "stable-diffusion.h" + +#include "stb_image_write.h" + +typedef struct { + uint32_t offset; + uint32_t size; +} avi_index_entry; + +// Write a 32-bit integer in host byte order (AVI is little-endian, so a little-endian host is assumed) +inline void write_u32_le(FILE* f, uint32_t val) { + fwrite(&val, 4, 1, f); +} + +// Write a 16-bit integer in host byte order (AVI is little-endian, so a little-endian host is assumed) +inline void write_u16_le(FILE* f, uint16_t val) { + fwrite(&val, 2, 1, f); +} + +/** + * Create an MJPG AVI file from an array of sd_image_t images. + * Images are encoded to JPEG using stb_image_write. + * + * @param filename Output AVI file name. + * @param images Array of input images. + * @param num_images Number of images in the array. + * @param fps Frames per second for the video. + * @param quality JPEG quality (0-100). + * @return 0 on success, -1 on failure. 
+ */ +int create_mjpg_avi_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality = 90) { + if (num_images == 0) { + fprintf(stderr, "Error: Image array is empty.\n"); + return -1; + } + + FILE* f = fopen(filename, "wb"); + if (!f) { + perror("Error opening file for writing"); + return -1; + } + + uint32_t width = images[0].width; + uint32_t height = images[0].height; + uint32_t channels = images[0].channel; + if (channels != 3 && channels != 4) { + fprintf(stderr, "Error: Unsupported channel count: %u\n", channels); + fclose(f); + return -1; + } + + // --- RIFF AVI Header --- + fwrite("RIFF", 4, 1, f); + long riff_size_pos = ftell(f); + write_u32_le(f, 0); // Placeholder for file size + fwrite("AVI ", 4, 1, f); + + // 'hdrl' LIST (header list) + fwrite("LIST", 4, 1, f); + write_u32_le(f, 4 + 8 + 56 + 8 + 4 + 8 + 56 + 8 + 40); + fwrite("hdrl", 4, 1, f); + + // 'avih' chunk (AVI main header) + fwrite("avih", 4, 1, f); + write_u32_le(f, 56); + write_u32_le(f, 1000000 / fps); // Microseconds per frame + write_u32_le(f, 0); // Max bytes per second + write_u32_le(f, 0); // Padding granularity + write_u32_le(f, 0x110); // Flags (HASINDEX | ISINTERLEAVED) + write_u32_le(f, num_images); // Total frames + write_u32_le(f, 0); // Initial frames + write_u32_le(f, 1); // Number of streams + write_u32_le(f, width * height * 3); // Suggested buffer size + write_u32_le(f, width); + write_u32_le(f, height); + write_u32_le(f, 0); // Reserved + write_u32_le(f, 0); // Reserved + write_u32_le(f, 0); // Reserved + write_u32_le(f, 0); // Reserved + + // 'strl' LIST (stream list) + fwrite("LIST", 4, 1, f); + write_u32_le(f, 4 + 8 + 56 + 8 + 40); + fwrite("strl", 4, 1, f); + + // 'strh' chunk (stream header) + fwrite("strh", 4, 1, f); + write_u32_le(f, 56); + fwrite("vids", 4, 1, f); // Stream type: video + fwrite("MJPG", 4, 1, f); // Codec: Motion JPEG + write_u32_le(f, 0); // Flags + write_u16_le(f, 0); // Priority + write_u16_le(f, 0); // Language + 
write_u32_le(f, 0); // Initial frames + write_u32_le(f, 1); // Scale + write_u32_le(f, fps); // Rate + write_u32_le(f, 0); // Start + write_u32_le(f, num_images); // Length + write_u32_le(f, width * height * 3); // Suggested buffer size + write_u32_le(f, (uint32_t)-1); // Quality + write_u32_le(f, 0); // Sample size + write_u16_le(f, 0); // rcFrame.left + write_u16_le(f, 0); // rcFrame.top + write_u16_le(f, 0); // rcFrame.right + write_u16_le(f, 0); // rcFrame.bottom + + // 'strf' chunk (stream format: BITMAPINFOHEADER) + fwrite("strf", 4, 1, f); + write_u32_le(f, 40); + write_u32_le(f, 40); // biSize + write_u32_le(f, width); + write_u32_le(f, height); + write_u16_le(f, 1); // biPlanes + write_u16_le(f, 24); // biBitCount + fwrite("MJPG", 4, 1, f); // biCompression (FOURCC) + write_u32_le(f, width * height * 3); // biSizeImage + write_u32_le(f, 0); // XPelsPerMeter + write_u32_le(f, 0); // YPelsPerMeter + write_u32_le(f, 0); // Colors used + write_u32_le(f, 0); // Colors important + + // 'movi' LIST (video frames) + long movi_list_pos = ftell(f); + fwrite("LIST", 4, 1, f); + long movi_size_pos = ftell(f); + write_u32_le(f, 0); // Placeholder for movi size + fwrite("movi", 4, 1, f); + + avi_index_entry* index = (avi_index_entry*)malloc(sizeof(avi_index_entry) * num_images); + if (!index) { + fclose(f); + return -1; + } + + // Encode and write each frame as JPEG + struct { + uint8_t* buf; + size_t size; + } jpeg_data; + + for (int i = 0; i < num_images; i++) { + jpeg_data.buf = NULL; + jpeg_data.size = 0; + + // Callback function to collect JPEG data into memory + auto write_to_buf = [](void* context, void* data, int size) { + auto jd = (decltype(jpeg_data)*)context; + jd->buf = (uint8_t*)realloc(jd->buf, jd->size + size); + memcpy(jd->buf + jd->size, data, size); + jd->size += size; + }; + + // Encode to JPEG in memory + stbi_write_jpg_to_func( + write_to_buf, + &jpeg_data, + images[i].width, + images[i].height, + channels, + images[i].data, + quality); + + // 
Write '00dc' chunk (video frame) + fwrite("00dc", 4, 1, f); + write_u32_le(f, jpeg_data.size); + index[i].offset = ftell(f) - 8; + index[i].size = jpeg_data.size; + fwrite(jpeg_data.buf, 1, jpeg_data.size, f); + + // Align to even byte size + if (jpeg_data.size % 2) + fputc(0, f); + + free(jpeg_data.buf); + } + + // Finalize 'movi' size + long cur_pos = ftell(f); + long movi_size = cur_pos - movi_size_pos - 4; + fseek(f, movi_size_pos, SEEK_SET); + write_u32_le(f, movi_size); + fseek(f, cur_pos, SEEK_SET); + + // Write 'idx1' index + fwrite("idx1", 4, 1, f); + write_u32_le(f, num_images * 16); + for (int i = 0; i < num_images; i++) { + fwrite("00dc", 4, 1, f); + write_u32_le(f, 0x10); + write_u32_le(f, index[i].offset); + write_u32_le(f, index[i].size); + } + + // Finalize RIFF size + cur_pos = ftell(f); + long file_size = cur_pos - riff_size_pos - 4; + fseek(f, riff_size_pos, SEEK_SET); + write_u32_le(f, file_size); + fseek(f, cur_pos, SEEK_SET); + + fclose(f); + free(index); + + return 0; +} + +#endif // __AVI_WRITER_H__ \ No newline at end of file diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 3f0cce1..9bbe7c7 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -10,6 +10,7 @@ #include // #include "preprocessing.hpp" +#include "avi_writer.h" #include "stable-diffusion.h" #define STB_IMAGE_IMPLEMENTATION @@ -83,6 +84,7 @@ struct SDParams { int batch_count = 1; int video_frames = 1; + int fps = 24; sample_method_t sample_method = EULER_A; schedule_t schedule = DEFAULT; @@ -166,6 +168,8 @@ void print_params(SDParams params) { printf(" chroma_use_dit_mask: %s\n", params.chroma_use_dit_mask ? "true" : "false"); printf(" chroma_use_t5_mask: %s\n", params.chroma_use_t5_mask ? 
"true" : "false"); printf(" chroma_t5_mask_pad: %d\n", params.chroma_t5_mask_pad); + printf(" video_frames: %d\n", params.video_frames); + printf(" fps: %d\n", params.fps); } void print_usage(int argc, const char* argv[]) { @@ -224,7 +228,7 @@ void print_usage(int argc, const char* argv[]) { printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n"); printf(" -b, --batch-count COUNT number of images to generate\n"); printf(" --schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete)\n"); - printf(" --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n"); + printf(" --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n"); printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n"); printf(" --vae-tiling process vae in tiles to reduce memory usage\n"); printf(" --vae-on-cpu keep vae in cpu (for low vram)\n"); @@ -238,6 +242,8 @@ void print_usage(int argc, const char* argv[]) { printf(" --chroma-disable-dit-mask disable dit mask for chroma\n"); printf(" --chroma-enable-t5-mask enable t5 mask for chroma\n"); printf(" --chroma-t5-mask-pad PAD_SIZE t5 mask pad size of chroma\n"); + printf(" --video-frames video frames (default: 1)\n"); + printf(" --fps fps (default: 24)\n"); printf(" -v, --verbose print extra info\n"); } @@ -435,6 +441,8 @@ void parse_args(int argc, const char** argv, SDParams& params) { {"", "--clip-skip", "", &params.clip_skip}, {"-b", "--batch-count", "", &params.batch_count}, {"", "--chroma-t5-mask-pad", "", &params.chroma_t5_mask_pad}, + {"", "--video-frames", "", &params.video_frames}, + {"", "--fps", "", &params.fps}, }; options.float_options = { @@ -656,6 +664,16 @@ void parse_args(int argc, const char** argv, SDParams& params) { fprintf(stderr, "warning: --tensor-type-rules is currently supported only for conversion\n"); } + if (params.mode == VID_GEN && params.video_frames <= 0) { + 
fprintf(stderr, "error: --video-frames must be at least 1\n"); + exit(1); + } + + if (params.mode == VID_GEN && params.fps <= 0) { + fprintf(stderr, "error: --fps must be at least 1\n"); + exit(1); + } + if (params.upscale_repeats < 1) { fprintf(stderr, "error: upscale multiplier must be at least 1\n"); exit(1); @@ -983,7 +1001,7 @@ int main(int argc, const char* argv[]) { mask_image_buffer}; sd_image_t* results; - int expected_num_results = 1; + int num_results = 1; if (params.mode == IMG_GEN) { sd_img_gen_params_t img_gen_params = { params.prompt.c_str(), @@ -1009,8 +1027,8 @@ int main(int argc, const char* argv[]) { params.input_id_images_path.c_str(), }; - results = generate_image(sd_ctx, &img_gen_params); - expected_num_results = params.batch_count; + results = generate_image(sd_ctx, &img_gen_params); + num_results = params.batch_count; } else if (params.mode == VID_GEN) { sd_vid_gen_params_t vid_gen_params = { params.prompt.c_str(), @@ -1028,8 +1046,7 @@ int main(int argc, const char* argv[]) { params.video_frames, }; - results = generate_video(sd_ctx, &vid_gen_params); - expected_num_results = params.video_frames; + results = generate_video(sd_ctx, &vid_gen_params, &num_results); } if (results == NULL) { @@ -1065,45 +1082,59 @@ int main(int argc, const char* argv[]) { } } - std::string dummy_name, ext, lc_ext; + std::string base_path; + std::string file_ext; + std::string file_ext_lower; bool is_jpg; - size_t last = params.output_path.find_last_of("."); - size_t last_path = std::min(params.output_path.find_last_of("/"), - params.output_path.find_last_of("\\")); - if (last != std::string::npos // filename has extension - && (last_path == std::string::npos || last > last_path)) { - dummy_name = params.output_path.substr(0, last); - ext = lc_ext = params.output_path.substr(last); - std::transform(ext.begin(), ext.end(), lc_ext.begin(), ::tolower); - is_jpg = lc_ext == ".jpg" || lc_ext == ".jpeg" || lc_ext == ".jpe"; + size_t last_dot_pos = 
params.output_path.find_last_of("."); + size_t last_slash_pos = std::min(params.output_path.find_last_of("/"), + params.output_path.find_last_of("\\")); + if (last_dot_pos != std::string::npos && (last_slash_pos == std::string::npos || last_dot_pos > last_slash_pos)) { // filename has extension + base_path = params.output_path.substr(0, last_dot_pos); + file_ext = file_ext_lower = params.output_path.substr(last_dot_pos); + std::transform(file_ext.begin(), file_ext.end(), file_ext_lower.begin(), ::tolower); + is_jpg = (file_ext_lower == ".jpg" || file_ext_lower == ".jpeg" || file_ext_lower == ".jpe"); } else { - dummy_name = params.output_path; - ext = lc_ext = ""; - is_jpg = false; + base_path = params.output_path; + file_ext = file_ext_lower = ""; + is_jpg = false; } - // appending ".png" to absent or unknown extension - if (!is_jpg && lc_ext != ".png") { - dummy_name += ext; - ext = ".png"; + + if (params.mode == VID_GEN && num_results > 1) { + std::string vid_output_path = params.output_path; + if (file_ext_lower == ".png") { + vid_output_path = base_path + ".avi"; + } + create_mjpg_avi_from_sd_images(vid_output_path.c_str(), results, num_results, params.fps); + printf("save result MJPG AVI video to '%s'\n", vid_output_path.c_str()); + } else { + // appending ".png" to absent or unknown extension + if (!is_jpg && file_ext_lower != ".png") { + base_path += file_ext; + file_ext = ".png"; + } + for (int i = 0; i < num_results; i++) { + if (results[i].data == NULL) { + continue; + } + std::string final_image_path = i > 0 ? 
base_path + "_" + std::to_string(i + 1) + file_ext : base_path + file_ext; + if (is_jpg) { + stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel, + results[i].data, 90, get_image_params(params, params.seed + i).c_str()); + printf("save result JPEG image to '%s'\n", final_image_path.c_str()); + } else { + stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel, + results[i].data, 0, get_image_params(params, params.seed + i).c_str()); + printf("save result PNG image to '%s'\n", final_image_path.c_str()); + } + } } - for (int i = 0; i < expected_num_results; i++) { - if (results[i].data == NULL) { - continue; - } - std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ext : dummy_name + ext; - if (is_jpg) { - stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel, - results[i].data, 90, get_image_params(params, params.seed + i).c_str()); - printf("save result JPEG image to '%s'\n", final_image_path.c_str()); - } else { - stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel, - results[i].data, 0, get_image_params(params, params.seed + i).c_str()); - printf("save result PNG image to '%s'\n", final_image_path.c_str()); - } + + for (int i = 0; i < num_results; i++) { free(results[i].data); results[i].data = NULL; } free(results); free_sd_ctx(sd_ctx); free(control_image_buffer); free(input_image_buffer); diff --git a/format-code.sh b/format-code.sh index d42eeed..9fdba32 100644 --- a/format-code.sh +++ b/format-code.sh @@ -1,4 +1,4 @@ -for f in *.cpp *.h *.hpp examples/cli/*.cpp; do +for f in *.cpp *.h *.hpp examples/cli/*.cpp examples/cli/*.h; do [[ "$f" == vocab* ]] && continue echo "formatting '$f'" clang-format -style=file -i "$f" diff --git a/ggml b/ggml index e89bc7e..089530b 160000 --- a/ggml +++ b/ggml @@ -1 +1 @@ -Subproject commit 
e89bc7e8625f59145ee8c0b09383009c47752cd8 +Subproject commit 089530bb72e70aa9f9ecb98137dfd891c2be20c1 diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 13aa7e3..7563aed 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -988,19 +988,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* if (flash_attn) { // LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N); bool can_use_flash_attn = true; - can_use_flash_attn = can_use_flash_attn && (d_head == 64 || - d_head == 80 || - d_head == 96 || - d_head == 112 || - d_head == 128 || - d_head == 256); if (can_use_flash_attn && L_k % 256 != 0) { - // TODO(Green-Sky): might be worth just padding by default - if (L_k == 77 || L_k == 1560 || L_k == 4208 || L_k == 3952) { - kv_pad = GGML_PAD(L_k, 256) - L_k; - } else { - can_use_flash_attn = false; - } + kv_pad = GGML_PAD(L_k, 256) - L_k; } if (mask != nullptr) { @@ -1021,14 +1010,14 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* // LOG_DEBUG(" padding k and v dim1 by %d", kv_pad); k = ggml_pad(ctx, k, 0, kv_pad, 0, 0); } - // k = ggml_cast(ctx, k, GGML_TYPE_F16); + k = ggml_cast(ctx, k, GGML_TYPE_F16); v = ggml_nn_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3)); // [N, n_head, L_k, d_head] v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head] if (kv_pad != 0) { v = ggml_pad(ctx, v, 0, kv_pad, 0, 0); } - // v = ggml_cast(ctx, v, GGML_TYPE_F16); + v = ggml_cast(ctx, v, GGML_TYPE_F16); if (mask != nullptr) { mask = ggml_transpose(ctx, mask); diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 424fbad..08b9f4d 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -1191,16 +1191,20 @@ public: } ggml_tensor* decode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) { - int64_t W = x->ne[0] * 8; - int64_t H = x->ne[1] * 8; - int64_t C = 3; - ggml_tensor* result; + int64_t W = x->ne[0] * 8; + 
int64_t H = x->ne[1] * 8; + int64_t C = 3; + ggml_tensor* result = NULL; if (decode_video) { + int T = x->ne[2]; + if (sd_version_is_wan(version)) { + T = ((T - 1) * 4) + 1; + } result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, - x->ne[2], + T, 3); } else { result = ggml_new_tensor_4d(work_ctx, @@ -1214,6 +1218,7 @@ public: int64_t t0 = ggml_time_ms(); if (!use_tiny_autoencoder) { process_latent_out(x); + // x = load_tensor_from_file(work_ctx, "wan_vae_video_z.bin"); if (vae_tiling && !decode_video) { // split latent in 32x32 tiles and compute in several steps auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { @@ -1221,7 +1226,7 @@ public: }; sd_tiling(x, result, 8, 32, 0.5f, on_tiling); } else { - first_stage_model->compute(n_threads, x, true, &result, NULL); + first_stage_model->compute(n_threads, x, true, &result, work_ctx); } first_stage_model->free_compute_buffer(); ggml_tensor_scale_output(result); @@ -1882,18 +1887,20 @@ ggml_tensor* generate_init_latent(sd_ctx_t* sd_ctx, int frames = 1, bool video = false) { int C = 4; + int T = frames; if (sd_version_is_sd3(sd_ctx->sd->version)) { C = 16; } else if (sd_version_is_flux(sd_ctx->sd->version)) { C = 16; } else if (sd_version_is_wan(sd_ctx->sd->version)) { C = 16; + T = ((T - 1) / 4) + 1; } int W = width / 8; int H = height / 8; ggml_tensor* init_latent; if (video) { - init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, frames, C); + init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, T, C); } else { init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1); } @@ -2131,7 +2138,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g return result_images; } -SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params) { +SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params, int* num_frames_out) { if (sd_ctx == NULL || sd_vid_gen_params == 
NULL) { return NULL; } @@ -2142,13 +2149,14 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s int width = sd_vid_gen_params->width; int height = sd_vid_gen_params->height; int frames = sd_vid_gen_params->video_frames; + frames = (frames - 1) / 4 * 4 + 1; LOG_INFO("img2vid %dx%dx%d", width, height, frames); std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sd_vid_gen_params->sample_steps); struct ggml_init_params params; - params.mem_size = static_cast<size_t>(100 * 1024) * 1024; // 50 MB - params.mem_size += width * height * frames * 3 * sizeof(float); + params.mem_size = static_cast<size_t>(100 * 1024) * 1024; // 100 MB + params.mem_size += width * height * frames * 3 * sizeof(float) * 2; params.mem_buffer = NULL; params.no_alloc = false; // LOG_DEBUG("mem_size %u ", params.mem_size); @@ -2204,12 +2212,13 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s int W = width / 8; int H = height / 8; - int T = frames; + int T = init_latent->ne[2]; int C = 16; struct ggml_tensor* final_latent; // Sample { + LOG_DEBUG("sample %dx%dx%d", W, H, T); int64_t sampling_start = ggml_time_ms(); struct ggml_tensor* x_t = init_latent; struct ggml_tensor* noise = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, T, C); @@ -2247,15 +2256,18 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s sd_ctx->sd->first_stage_model->free_params_buffer(); } - sd_image_t* result_images = (sd_image_t*)calloc(T, sizeof(sd_image_t)); + sd_image_t* result_images = (sd_image_t*)calloc(vid->ne[2], sizeof(sd_image_t)); if (result_images == NULL) { ggml_free(work_ctx); return NULL; } + if (num_frames_out != NULL) { // the frame count is optional for callers; never dereference a NULL out-pointer + *num_frames_out = (int)vid->ne[2]; + } - for (size_t i = 0; i < T; i++) { - result_images[i].width = final_latent->ne[0] * 8; - result_images[i].height = final_latent->ne[1] * 8; + for (size_t i = 0; i < vid->ne[2]; i++) { + result_images[i].width = vid->ne[0]; + result_images[i].height = vid->ne[1]; result_images[i].channel = 3; 
result_images[i].data = sd_tensor_to_image(vid, i, true); } diff --git a/stable-diffusion.h b/stable-diffusion.h index 6c4cc96..644f930 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -231,7 +231,7 @@ SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_para SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params); SD_API void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params); -SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params); // broken +SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params, int* num_frames_out); typedef struct upscaler_ctx_t upscaler_ctx_t; diff --git a/t5.hpp b/t5.hpp index 1088532..408c256 100644 --- a/t5.hpp +++ b/t5.hpp @@ -994,8 +994,8 @@ struct T5Embedder { // cuda f16: pass // cuda f32: pass // cuda q8_0: pass - ggml_backend_t backend = ggml_backend_cuda_init(0); - // ggml_backend_t backend = ggml_backend_cpu_init(); + // ggml_backend_t backend = ggml_backend_cuda_init(0); + ggml_backend_t backend = ggml_backend_cpu_init(); ggml_type model_data_type = GGML_TYPE_F16; ModelLoader model_loader; diff --git a/wan.hpp b/wan.hpp index 763b774..25f2c17 100644 --- a/wan.hpp +++ b/wan.hpp @@ -14,6 +14,8 @@ namespace WAN { constexpr int CACHE_T = 2; constexpr int WAN_GRAPH_SIZE = 10240; +#define Rep ((struct ggml_tensor*)1) + class CausalConv3d : public GGMLBlock { protected: int64_t in_channels; @@ -68,7 +70,7 @@ namespace WAN { int lp2 = 2 * std::get<0>(padding); int rp2 = 0; - if (cache_x != NULL && std::get<0>(padding) > 0) { + if (cache_x != NULL && lp2 > 0) { x = ggml_concat(ctx, cache_x, x, 2); lp2 -= (int)cache_x->ne[2]; } @@ -145,8 +147,6 @@ namespace WAN { int64_t h = x->ne[1]; int64_t w = x->ne[0]; - struct ggml_tensor* Rep = (struct ggml_tensor*)1; - if (mode == "upsample3d") { if (feat_cache.size() > 0) { int idx = feat_idx; @@ -164,8 +164,8 @@ namespace WAN { 
cache_x, 2); } - if (cache_x->ne[1] < 2 && feat_cache[idx] != NULL && feat_cache[idx] == Rep) { - cache_x = ggml_pad_ext(ctx, cache_x, 0, 0, 1, 1, (int)cache_x->ne[2], 0, 0, 0); + if (cache_x->ne[2] < 2 && feat_cache[idx] != NULL && feat_cache[idx] == Rep) { + cache_x = ggml_pad_ext(ctx, cache_x, 0, 0, 0, 0, (int)cache_x->ne[2], 0, 0, 0); // aka cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device),cache_x],dim=2) } if (feat_cache[idx] == Rep) { @@ -629,7 +629,7 @@ namespace WAN { }; class WanVAE : public GGMLBlock { - protected: + public: bool decode_only = true; int64_t dim = 96; int64_t z_dim = 16; @@ -724,11 +724,47 @@ namespace WAN { clear_cache(); return out; } + + struct ggml_tensor* decode_partial(struct ggml_context* ctx, + struct ggml_tensor* z, + int64_t i, + int64_t b = 1) { + // z: [b*c, t, h, w] + GGML_ASSERT(b == 1); + + auto decoder = std::dynamic_pointer_cast(blocks["decoder"]); + auto conv2 = std::dynamic_pointer_cast(blocks["conv2"]); + + auto x = conv2->forward(ctx, z); + auto in = ggml_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w] + _conv_idx = 0; + auto out = decoder->forward(ctx, in, b, _feat_map, _conv_idx); + return out; + } + }; + + struct FeatCache { + std::vector data; + std::vector shape; + bool is_rep = false; + + FeatCache() = default; + + FeatCache(ggml_backend_t backend, ggml_tensor* tensor) { + shape = {tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]}; + data.resize(shape[0] * shape[1] * shape[2] * shape[3]); + ggml_backend_tensor_get_and_sync(backend, tensor, (void*)data.data(), 0, ggml_nbytes(tensor)); + } + + ggml_tensor* to_ggml_tensor(ggml_context* ctx) { + return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, shape[0], shape[1], shape[2], shape[3]); + } }; struct WanVAERunner : public VAE { bool decode_only = true; WanVAE ae; + std::vector _feat_vec_map; WanVAERunner(ggml_backend_t backend, const String2GGMLType& tensor_types = {}, @@ -736,6 +772,11 @@ namespace WAN { bool decode_only = false) : 
decode_only(decode_only), ae(decode_only), VAE(backend) { ae.init(params_ctx, tensor_types, prefix); + rest_feat_vec_map(); + } + + void rest_feat_vec_map() { + _feat_vec_map = std::vector(ae._conv_num, FeatCache()); } std::string get_desc() { @@ -747,7 +788,7 @@ namespace WAN { } struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) { - struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, 20480, false); + struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, 10240 * z->ne[2], false); z = to_backend(z); @@ -758,22 +799,120 @@ namespace WAN { return gf; } + struct ggml_cgraph* build_graph_partial(struct ggml_tensor* z, bool decode_graph, int64_t i) { + struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, 20480, false); + + ae.clear_cache(); + + for (int64_t feat_idx = 0; feat_idx < _feat_vec_map.size(); feat_idx++) { + FeatCache& feat_cache_vec = _feat_vec_map[feat_idx]; + if (feat_cache_vec.is_rep) { + ae._feat_map[feat_idx] = Rep; + } else if (feat_cache_vec.data.size() > 0) { + ggml_tensor* feat_cache = feat_cache_vec.to_ggml_tensor(compute_ctx); + set_backend_tensor_data(feat_cache, feat_cache_vec.data.data()); + ae._feat_map[feat_idx] = feat_cache; + } + } + + z = to_backend(z); + + struct ggml_tensor* out = decode_graph ? 
ae.decode_partial(compute_ctx, z, i) : ae.encode(compute_ctx, z); + + for (int64_t feat_idx = 0; feat_idx < ae._feat_map.size(); feat_idx++) { + ggml_tensor* feat_cache = ae._feat_map[feat_idx]; + if (feat_cache != NULL && feat_cache != Rep) { + ggml_build_forward_expand(gf, feat_cache); + } + } + + ggml_build_forward_expand(gf, out); + + return gf; + } + void compute(const int n_threads, struct ggml_tensor* z, bool decode_graph, struct ggml_tensor** output, struct ggml_context* output_ctx = NULL) { - auto get_graph = [&]() -> struct ggml_cgraph* { - return build_graph(z, decode_graph); - }; - // ggml_set_f32(z, 0.5f); - // print_ggml_tensor(z); - GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx); + if (true) { + auto get_graph = [&]() -> struct ggml_cgraph* { + return build_graph(z, decode_graph); + }; + GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx); + } else { // broken + ae.clear_cache(); + int64_t t = z->ne[2]; + int64_t i = 0; + auto get_graph = [&]() -> struct ggml_cgraph* { + return build_graph_partial(z, decode_graph, i); + }; + struct ggml_tensor* out = NULL; + GGMLRunner::compute(get_graph, n_threads, false, &out, output_ctx); + for (int64_t feat_idx = 0; feat_idx < _feat_vec_map.size(); feat_idx++) { + ggml_tensor* feat_cache = ae._feat_map[feat_idx]; + if (feat_cache == Rep) { + FeatCache feat_cache_vec; + feat_cache_vec.is_rep = true; + _feat_vec_map[feat_idx] = feat_cache_vec; + } else if (feat_cache != NULL) { + _feat_vec_map[feat_idx] = FeatCache(backend, feat_cache); + } + } + GGMLRunner::free_compute_buffer(); + if (t == 1) { + *output = out; + ae.clear_cache(); + return; + } + + *output = ggml_new_tensor_4d(output_ctx, GGML_TYPE_F32, out->ne[0], out->ne[1], (t - 1) * 4 + 1, out->ne[3]); + + auto copy_to_output = [&]() { + for (int64_t i3 = 0; i3 < out->ne[3]; i3++) { + for (int64_t i2 = 0; i2 < out->ne[2]; i2++) { + for (int64_t i1 = 0; i1 < out->ne[1]; i1++) { + for (int64_t i0 = 0; i0 < out->ne[0]; i0++) 
{ + float value = ggml_tensor_get_f32(out, i0, i1, i2, i3); + int64_t offset = (i == 0) ? 0 : (1 + (i - 1) * 4); + ggml_tensor_set_f32(*output, value, i0, i1, offset + i2, i3); + } + } + } + } + }; + + copy_to_output(); + + out = ggml_new_tensor_4d(output_ctx, GGML_TYPE_F32, out->ne[0], out->ne[1], 4, out->ne[3]); + + for (i = 1; i < t; i++) { + GGMLRunner::compute(get_graph, n_threads, false, &out); + + for (int64_t feat_idx = 0; feat_idx < _feat_vec_map.size(); feat_idx++) { + ggml_tensor* feat_cache = ae._feat_map[feat_idx]; + if (feat_cache == Rep) { + FeatCache feat_cache_vec; + feat_cache_vec.is_rep = true; + _feat_vec_map[feat_idx] = feat_cache_vec; + } else if (feat_cache != NULL) { + _feat_vec_map[feat_idx] = FeatCache(backend, feat_cache); + } + } + + ae.clear_cache(); + + GGMLRunner::free_compute_buffer(); + + copy_to_output(); + } + } } void test() { struct ggml_init_params params; - params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB + params.mem_size = static_cast<size_t>(1000 * 1024 * 1024); // 1000 MB params.mem_buffer = NULL; params.no_alloc = false; @@ -785,9 +924,9 @@ namespace WAN { // cpu f16, pass // cuda f16, pass // cuda f32, pass - auto z = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 1, 16); - z = load_tensor_from_file(work_ctx, "wan_vae_z.bin"); - // ggml_set_f32(z, 0.5f); + auto z = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 104, 60, 2, 16); + ggml_set_f32(z, 0.5f); + z = load_tensor_from_file(work_ctx, "wan_vae_video_z.bin"); print_ggml_tensor(z); struct ggml_tensor* out = NULL; @@ -803,7 +942,7 @@ namespace WAN { static void load_from_file_and_test(const std::string& file_path) { ggml_backend_t backend = ggml_backend_cuda_init(0); // ggml_backend_t backend = ggml_backend_cpu_init(); - ggml_type model_data_type = GGML_TYPE_F32; + ggml_type model_data_type = GGML_TYPE_F16; std::shared_ptr<WanVAERunner> vae = std::shared_ptr<WanVAERunner>(new WanVAERunner(backend)); { LOG_INFO("loading from '%s'", file_path.c_str()); @@ -1588,8 +1727,8 @@ namespace WAN { } static 
void load_from_file_and_test(const std::string& file_path) { - ggml_backend_t backend = ggml_backend_cuda_init(0); - // ggml_backend_t backend = ggml_backend_cpu_init(); + // ggml_backend_t backend = ggml_backend_cuda_init(0); + ggml_backend_t backend = ggml_backend_cpu_init(); ggml_type model_data_type = GGML_TYPE_Q8_0; LOG_INFO("loading from '%s'", file_path.c_str());