From 8b03d9bd0e64d78d99c0362371acf2b1f88eff51 Mon Sep 17 00:00:00 2001 From: leejet Date: Sun, 10 May 2026 15:00:33 +0800 Subject: [PATCH] add ltx audio support --- examples/cli/main.cpp | 64 +- examples/common/common.cpp | 6 + examples/common/common.h | 1 + examples/common/media_io.cpp | 233 ++++++- examples/common/media_io.h | 24 +- examples/server/async_jobs.cpp | 13 +- ggml | 2 +- include/stable-diffusion.h | 15 +- src/ggml_extend.hpp | 28 +- src/ltx_audio_vae.h | 1109 ++++++++++++++++++++++++++++++++ src/ltxv.hpp | 9 +- src/stable-diffusion.cpp | 198 +++++- 12 files changed, 1627 insertions(+), 75 deletions(-) create mode 100644 src/ltx_audio_vae.h diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 27513f47..d13ca6d9 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -385,11 +385,32 @@ std::string format_frame_idx(std::string pattern, int frame_idx) { return result; } +static fs::path get_video_audio_sidecar_path(const SDCliParams& cli_params) { + fs::path out_path = cli_params.output_path; + fs::path base_path = out_path; + fs::path ext = out_path.has_extension() ? 
out_path.extension() : fs::path{}; + std::string ext_lower = ext.string(); + std::transform(ext_lower.begin(), ext_lower.end(), ext_lower.begin(), ::tolower); + const EncodedImageFormat output_format = encoded_image_format_from_path(out_path.string()); + if (!ext.empty()) { + if (output_format == EncodedImageFormat::JPEG || + output_format == EncodedImageFormat::PNG || + output_format == EncodedImageFormat::WEBP || + ext_lower == ".avi" || + ext_lower == ".webm") { + base_path.replace_extension(); + } + } + base_path += ".wav"; + return base_path; +} + bool save_results(const SDCliParams& cli_params, const SDContextParams& ctx_params, const SDGenerationParams& gen_params, sd_image_t* results, - int num_results) { + int num_results, + const sd_audio_t* generated_audio = nullptr) { if (results == nullptr || num_results <= 0) { return false; } @@ -442,6 +463,21 @@ bool save_results(const SDCliParams& cli_params, return ok; }; + auto write_audio_sidecar = [&](const fs::path& wav_path) { + if (generated_audio == nullptr) { + return; + } + if (write_wav_to_file(wav_path.string(), + generated_audio->data, + generated_audio->sample_count, + generated_audio->channels, + generated_audio->sample_rate)) { + LOG_INFO("save result audio to '%s'", wav_path.string().c_str()); + } else { + LOG_WARN("failed to save result audio to '%s'", wav_path.string().c_str()); + } + }; + int sucessful_reults = 0; if (std::regex_search(cli_params.output_path, format_specifier_regex)) { @@ -465,8 +501,16 @@ bool save_results(const SDCliParams& cli_params, ext = ".avi"; fs::path video_path = base_path; video_path += ext; - if (create_video_from_sd_images(video_path.string().c_str(), results, num_results, gen_params.fps) == 0) { + std::string final_ext_lower = ext.string(); + std::transform(final_ext_lower.begin(), final_ext_lower.end(), final_ext_lower.begin(), ::tolower); + const bool mux_audio = generated_audio != nullptr && (final_ext_lower == ".avi" || final_ext_lower == ".webm"); + if 
(create_video_from_sd_images(video_path.string().c_str(), results, num_results, gen_params.fps, 90, mux_audio ? generated_audio : nullptr) == 0) { LOG_INFO("save result video to '%s'", video_path.string().c_str()); + if (generated_audio != nullptr && !mux_audio) { + fs::path wav_path = video_path; + wav_path.replace_extension(".wav"); + write_audio_sidecar(wav_path); + } return true; } else { LOG_ERROR("Failed to save result video to '%s'", video_path.string().c_str()); @@ -488,6 +532,9 @@ bool save_results(const SDCliParams& cli_params, } } LOG_INFO("%d/%d images saved", sucessful_reults, num_results); + if (generated_audio != nullptr) { + write_audio_sidecar(get_video_audio_sidecar_path(cli_params)); + } return sucessful_reults != 0; } @@ -701,7 +748,8 @@ int main(int argc, const char* argv[]) { sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(vae_decode_only, true, cli_params.taesd_preview); SDImageVec results; - int num_results = 0; + int num_results = 0; + sd_audio_t* generated_audio = nullptr; if (cli_params.mode == UPSCALE) { num_results = 1; @@ -733,7 +781,10 @@ int main(int argc, const char* argv[]) { results.adopt(generate_image(sd_ctx.get(), &img_gen_params), num_results); } else if (cli_params.mode == VID_GEN) { sd_vid_gen_params_t vid_gen_params = gen_params.to_sd_vid_gen_params_t(); - sd_image_t* generated_video = generate_video(sd_ctx.get(), &vid_gen_params, &num_results); + sd_image_t* generated_video = nullptr; + if (!generate_video(sd_ctx.get(), &vid_gen_params, &generated_video, &num_results, &generated_audio)) { + generated_video = nullptr; + } results.adopt(generated_video, num_results); } @@ -773,9 +824,12 @@ int main(int argc, const char* argv[]) { } } - if (!save_results(cli_params, ctx_params, gen_params, results.data(), num_results)) { + if (!save_results(cli_params, ctx_params, gen_params, results.data(), num_results, generated_audio)) { + free_sd_audio(generated_audio); return 1; } + free_sd_audio(generated_audio); + return 
0; } diff --git a/examples/common/common.cpp b/examples/common/common.cpp index dca63922..5107012d 100644 --- a/examples/common/common.cpp +++ b/examples/common/common.cpp @@ -348,6 +348,10 @@ ArgOptions SDContextParams::get_options() { "--vae", "path to standalone vae model", &vae_path}, + {"", + "--audio-vae", + "path to standalone LTX audio vae model", + &audio_vae_path}, {"", "--taesd", "path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)", @@ -667,6 +671,7 @@ std::string SDContextParams::to_string() const { << " high_noise_diffusion_model_path: \"" << high_noise_diffusion_model_path << "\",\n" << " embeddings_connectors_path: \"" << embeddings_connectors_path << "\",\n" << " vae_path: \"" << vae_path << "\",\n" + << " audio_vae_path: \"" << audio_vae_path << "\",\n" << " taesd_path: \"" << taesd_path << "\",\n" << " esrgan_path: \"" << esrgan_path << "\",\n" << " control_net_path: \"" << control_net_path << "\",\n" @@ -725,6 +730,7 @@ sd_ctx_params_t SDContextParams::to_sd_ctx_params_t(bool vae_decode_only, bool f high_noise_diffusion_model_path.c_str(), embeddings_connectors_path.c_str(), vae_path.c_str(), + audio_vae_path.c_str(), taesd_path.c_str(), control_net_path.c_str(), embedding_vec.data(), diff --git a/examples/common/common.h b/examples/common/common.h index ba10b101..f1d38083 100644 --- a/examples/common/common.h +++ b/examples/common/common.h @@ -94,6 +94,7 @@ struct SDContextParams { std::string high_noise_diffusion_model_path; std::string embeddings_connectors_path; std::string vae_path; + std::string audio_vae_path; std::string taesd_path; std::string esrgan_path; std::string control_net_path; diff --git a/examples/common/media_io.cpp b/examples/common/media_io.cpp index e2e1ca5a..506c67f4 100644 --- a/examples/common/media_io.cpp +++ b/examples/common/media_io.cpp @@ -613,6 +613,13 @@ typedef struct { uint32_t size; } avi_index_entry; +typedef struct { + char fourcc[4]; + uint32_t flags; + uint32_t offset; + uint32_t size; +} 
avi_chunk_index_entry; + void write_u32_le(FILE* f, uint32_t val) { fwrite(&val, 4, 1, f); } @@ -647,6 +654,33 @@ void write_fourcc(std::vector& data, const char* fourcc) { data.insert(data.end(), fourcc, fourcc + 4); } +static std::vector audio_to_pcm16_bytes(const sd_audio_t* audio) { + if (audio == nullptr || audio->data == nullptr || audio->sample_count == 0 || audio->channels == 0 || audio->sample_rate == 0) { + return {}; + } + + const size_t pcm_samples = static_cast(audio->sample_count) * static_cast(audio->channels); + std::vector bytes(pcm_samples * sizeof(int16_t)); + auto* pcm = reinterpret_cast(bytes.data()); + for (size_t i = 0; i < pcm_samples; ++i) { + const float sample = std::clamp(audio->data[i], -1.0f, 1.0f); + pcm[i] = static_cast(std::lrint(sample * 32767.0f)); + } + return bytes; +} + +static std::pair audio_sample_range_for_video_frame(const sd_audio_t* audio, int frame_idx, int num_frames, int fps) { + if (audio == nullptr || fps <= 0 || num_frames <= 0) { + return {0, 0}; + } + const uint64_t total = audio->sample_count; + const uint64_t start = static_cast((static_cast(frame_idx) * total) / num_frames); + const uint64_t end = frame_idx + 1 == num_frames + ? 
total + : static_cast((static_cast(frame_idx + 1) * total) / num_frames); + return {start, std::max(start, end)}; +} + EncodedImageFormat encoded_image_format_from_path(const std::string& path) { std::string ext = fs::path(path).extension().string(); std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower); @@ -776,7 +810,7 @@ uint8_t* load_image_from_memory(const char* image_bytes, return load_image_common(true, image_bytes, len, width, height, expected_width, expected_height, expected_channel); } -std::vector create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images, int num_images, int fps, int quality) { +std::vector create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images, int num_images, int fps, int quality, const sd_audio_t* audio) { if (num_images == 0) { fprintf(stderr, "Error: Image array is empty.\n"); return {}; @@ -793,7 +827,13 @@ std::vector create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images // stb_image_write changes JPEG sampling behavior above quality 90. // MJPG AVI playback is more compatible when we keep the encoder on the // <= 90 path. - const int mjpg_quality = std::clamp(quality, 1, 90); + const int mjpg_quality = std::clamp(quality, 1, 90); + const bool has_audio = audio != nullptr && audio->data != nullptr && audio->sample_count > 0 && audio->channels > 0 && audio->sample_rate > 0; + const std::vector audio_pcm = audio_to_pcm16_bytes(audio); + const uint16_t audio_bits_per_sample = 16; + const uint16_t audio_block_align = has_audio ? static_cast(audio->channels * (audio_bits_per_sample / 8)) : 0; + const uint32_t audio_byte_rate = has_audio ? static_cast(audio->sample_rate * audio_block_align) : 0; + const uint32_t audio_data_size = has_audio ? 
static_cast(audio_pcm.size()) : 0; std::vector avi_data; avi_data.reserve(static_cast(num_images) * 1024); @@ -804,7 +844,11 @@ std::vector create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images write_fourcc(avi_data, "AVI "); write_fourcc(avi_data, "LIST"); - write_u32_le(avi_data, 4 + 8 + 56 + 8 + 4 + 8 + 56 + 8 + 40); + uint32_t hdrl_size = 4 + 8 + 56 + 8 + 4 + 8 + 56 + 8 + 40; + if (has_audio) { + hdrl_size += 8 + (4 + 8 + 56 + 8 + 16); + } + write_u32_le(avi_data, hdrl_size); write_fourcc(avi_data, "hdrl"); write_fourcc(avi_data, "avih"); @@ -815,7 +859,7 @@ std::vector create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images write_u32_le(avi_data, 0x110); write_u32_le(avi_data, num_images); write_u32_le(avi_data, 0); - write_u32_le(avi_data, 1); + write_u32_le(avi_data, has_audio ? 2 : 1); write_u32_le(avi_data, width * height * 3); write_u32_le(avi_data, width); write_u32_le(avi_data, height); @@ -862,12 +906,48 @@ std::vector create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images write_u32_le(avi_data, 0); write_u32_le(avi_data, 0); + if (has_audio) { + write_fourcc(avi_data, "LIST"); + write_u32_le(avi_data, 4 + 8 + 56 + 8 + 16); + write_fourcc(avi_data, "strl"); + + write_fourcc(avi_data, "strh"); + write_u32_le(avi_data, 56); + write_fourcc(avi_data, "auds"); + write_u32_le(avi_data, 0); + write_u32_le(avi_data, 0); + write_u16_le(avi_data, 0); + write_u16_le(avi_data, 0); + write_u32_le(avi_data, 0); + write_u32_le(avi_data, audio_block_align); + write_u32_le(avi_data, audio_byte_rate); + write_u32_le(avi_data, 0); + write_u32_le(avi_data, static_cast(audio->sample_count)); + write_u32_le(avi_data, audio_data_size); + write_u32_le(avi_data, static_cast(-1)); + write_u32_le(avi_data, audio_block_align); + write_u16_le(avi_data, 0); + write_u16_le(avi_data, 0); + write_u16_le(avi_data, 0); + write_u16_le(avi_data, 0); + + write_fourcc(avi_data, "strf"); + write_u32_le(avi_data, 16); + write_u16_le(avi_data, 1); + write_u16_le(avi_data, 
static_cast(audio->channels)); + write_u32_le(avi_data, audio->sample_rate); + write_u32_le(avi_data, audio_byte_rate); + write_u16_le(avi_data, audio_block_align); + write_u16_le(avi_data, audio_bits_per_sample); + } + write_fourcc(avi_data, "LIST"); const size_t movi_size_pos = avi_data.size(); write_u32_le(avi_data, 0); write_fourcc(avi_data, "movi"); - std::vector index(static_cast(num_images)); + std::vector index; + index.reserve(static_cast(num_images) + (has_audio ? 1 : 0)); std::vector jpeg_data; for (int i = 0; i < num_images; i++) { @@ -884,27 +964,46 @@ std::vector create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images return {}; } - index[i].offset = static_cast(avi_data.size()); + avi_chunk_index_entry video_entry = {}; + memcpy(video_entry.fourcc, "00dc", 4); + video_entry.flags = 0x10; + video_entry.offset = static_cast(avi_data.size()); write_fourcc(avi_data, "00dc"); write_u32_le(avi_data, static_cast(jpeg_data.size())); - index[i].size = (uint32_t)jpeg_data.size(); + video_entry.size = static_cast(jpeg_data.size()); avi_data.insert(avi_data.end(), jpeg_data.begin(), jpeg_data.end()); + index.push_back(video_entry); if (jpeg_data.size() % 2) { avi_data.push_back(0); } } + if (has_audio && !audio_pcm.empty()) { + avi_chunk_index_entry audio_entry = {}; + memcpy(audio_entry.fourcc, "01wb", 4); + audio_entry.flags = 0; + audio_entry.offset = static_cast(avi_data.size()); + audio_entry.size = static_cast(audio_pcm.size()); + write_fourcc(avi_data, "01wb"); + write_u32_le(avi_data, static_cast(audio_pcm.size())); + avi_data.insert(avi_data.end(), audio_pcm.begin(), audio_pcm.end()); + index.push_back(audio_entry); + if (audio_pcm.size() % 2 != 0) { + avi_data.push_back(0); + } + } + const size_t movi_size = avi_data.size() - movi_size_pos - 4; patch_u32_le(avi_data, movi_size_pos, static_cast(movi_size)); write_fourcc(avi_data, "idx1"); - write_u32_le(avi_data, num_images * 16); - for (int i = 0; i < num_images; i++) { - write_fourcc(avi_data, 
"00dc"); - write_u32_le(avi_data, 0x10); - write_u32_le(avi_data, index[i].offset); - write_u32_le(avi_data, index[i].size); + write_u32_le(avi_data, static_cast(index.size() * 16)); + for (const auto& entry : index) { + write_fourcc(avi_data, entry.fourcc); + write_u32_le(avi_data, entry.flags); + write_u32_le(avi_data, entry.offset); + write_u32_le(avi_data, entry.size); } const size_t file_size = avi_data.size() - riff_size_pos - 4; @@ -913,8 +1012,8 @@ std::vector create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images return avi_data; } -int create_mjpg_avi_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality) { - std::vector avi_data = create_mjpg_avi_from_sd_images_to_vector(images, num_images, fps, quality); +int create_mjpg_avi_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality, const sd_audio_t* audio) { + std::vector avi_data = create_mjpg_avi_from_sd_images_to_vector(images, num_images, fps, quality, audio); if (avi_data.empty()) { return -1; } @@ -1044,7 +1143,7 @@ int create_animated_webp_from_sd_images(const char* filename, sd_image_t* images #endif #ifdef SD_USE_WEBM -std::vector create_webm_from_sd_images_to_vector(sd_image_t* images, int num_images, int fps, int quality) { +std::vector create_webm_from_sd_images_to_vector(sd_image_t* images, int num_images, int fps, int quality, const sd_audio_t* audio) { if (num_images == 0) { fprintf(stderr, "Error: Image array is empty.\n"); return {}; @@ -1089,6 +1188,25 @@ std::vector create_webm_from_sd_images_to_vector(sd_image_t* images, in video_track->set_display_height(static_cast(height)); video_track->set_frame_rate(static_cast(fps)); } + + uint64_t audio_track_number = 0; + std::vector audio_pcm = audio_to_pcm16_bytes(audio); + if (audio != nullptr && !audio_pcm.empty()) { + audio_track_number = segment.AddAudioTrack(static_cast(audio->sample_rate), static_cast(audio->channels), 0); + if (audio_track_number == 0) 
{ + fprintf(stderr, "Error: Failed to add audio track.\n"); + return -1; + } + auto* audio_track = static_cast(segment.GetTrackByNumber(audio_track_number)); + if (audio_track == nullptr) { + fprintf(stderr, "Error: Failed to get audio track.\n"); + return -1; + } + audio_track->set_codec_id("A_PCM/INT/LIT"); + audio_track->set_bit_depth(16); + audio_track->set_sample_rate(static_cast(audio->sample_rate)); + audio_track->set_channels(audio->channels); + } segment.GetSegmentInfo()->set_writing_app("stable-diffusion.cpp"); segment.GetSegmentInfo()->set_muxing_app("stable-diffusion.cpp"); @@ -1118,6 +1236,23 @@ std::vector create_webm_from_sd_images_to_vector(sd_image_t* images, in return -1; } + if (audio_track_number != 0) { + auto [audio_begin, audio_end] = audio_sample_range_for_video_frame(audio, i, num_images, fps); + const uint64_t frame_samples = audio_end - audio_begin; + if (frame_samples > 0) { + const uint64_t frame_bytes = frame_samples * audio->channels * sizeof(int16_t); + const uint8_t* frame_ptr = audio_pcm.data() + audio_begin * audio->channels * sizeof(int16_t); + if (!segment.AddFrame(frame_ptr, + frame_bytes, + audio_track_number, + timestamp_ns, + true)) { + fprintf(stderr, "Error: Failed to mux audio chunk %d into WebM.\n", i); + return -1; + } + } + } + timestamp_ns += frame_duration_ns; } @@ -1133,8 +1268,8 @@ std::vector create_webm_from_sd_images_to_vector(sd_image_t* images, in return writer.data(); } -int create_webm_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality) { - std::vector webm_data = create_webm_from_sd_images_to_vector(images, num_images, fps, quality); +int create_webm_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality, const sd_audio_t* audio) { + std::vector webm_data = create_webm_from_sd_images_to_vector(images, num_images, fps, quality, audio); if (webm_data.empty()) { return -1; } @@ -1150,7 +1285,8 @@ std::vector 
create_video_from_sd_images_to_vector(const std::string& ou sd_image_t* images, int num_images, int fps, - int quality) { + int quality, + const sd_audio_t* audio) { std::string format = output_format; std::transform(format.begin(), format.end(), format.begin(), [](unsigned char c) { return static_cast(tolower(c)); }); @@ -1160,7 +1296,7 @@ std::vector create_video_from_sd_images_to_vector(const std::string& ou #ifdef SD_USE_WEBM if (format == "webm") { - return create_webm_from_sd_images_to_vector(images, num_images, fps, quality); + return create_webm_from_sd_images_to_vector(images, num_images, fps, quality, audio); } #endif @@ -1170,14 +1306,14 @@ std::vector create_video_from_sd_images_to_vector(const std::string& ou } #endif - return create_mjpg_avi_from_sd_images_to_vector(images, num_images, fps, quality); + return create_mjpg_avi_from_sd_images_to_vector(images, num_images, fps, quality, audio); } -int create_video_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality) { +int create_video_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality, const sd_audio_t* audio) { std::string path = filename ? filename : ""; auto pos = path.find_last_of('.'); std::string ext = pos == std::string::npos ? 
"" : path.substr(pos); - std::vector video_data = create_video_from_sd_images_to_vector(ext, images, num_images, fps, quality); + std::vector video_data = create_video_from_sd_images_to_vector(ext, images, num_images, fps, quality, audio); if (video_data.empty()) { return -1; } @@ -1187,3 +1323,54 @@ int create_video_from_sd_images(const char* filename, sd_image_t* images, int nu } return 0; } + +bool write_wav_to_file(const std::string& path, + const float* interleaved_samples, + uint64_t sample_count, + uint32_t channels, + uint32_t sample_rate) { + if (interleaved_samples == nullptr || sample_count == 0 || channels == 0 || sample_rate == 0) { + return false; + } + + std::ofstream file(path, std::ios::binary); + if (!file.is_open()) { + return false; + } + + uint32_t bits_per_sample = 16; + uint32_t bytes_per_sample = bits_per_sample / 8; + uint32_t block_align = channels * bytes_per_sample; + uint32_t byte_rate = sample_rate * block_align; + uint32_t data_size = static_cast(sample_count * channels * bytes_per_sample); + uint32_t riff_size = 36 + data_size; + + file.write("RIFF", 4); + file.write(reinterpret_cast(&riff_size), sizeof(riff_size)); + file.write("WAVE", 4); + file.write("fmt ", 4); + + uint32_t fmt_size = 16; + uint16_t audio_format = 1; + uint16_t wav_channels = static_cast(channels); + uint16_t wav_block_align = static_cast(block_align); + uint16_t wav_bits_per_sample = static_cast(bits_per_sample); + file.write(reinterpret_cast(&fmt_size), sizeof(fmt_size)); + file.write(reinterpret_cast(&audio_format), sizeof(audio_format)); + file.write(reinterpret_cast(&wav_channels), sizeof(wav_channels)); + file.write(reinterpret_cast(&sample_rate), sizeof(sample_rate)); + file.write(reinterpret_cast(&byte_rate), sizeof(byte_rate)); + file.write(reinterpret_cast(&wav_block_align), sizeof(wav_block_align)); + file.write(reinterpret_cast(&wav_bits_per_sample), sizeof(wav_bits_per_sample)); + + file.write("data", 4); + file.write(reinterpret_cast(&data_size), 
sizeof(data_size)); + + std::vector pcm(sample_count * channels); + for (size_t i = 0; i < pcm.size(); ++i) { + float sample = std::max(-1.0f, std::min(1.0f, interleaved_samples[i])); + pcm[i] = static_cast(std::lrint(sample * 32767.0f)); + } + file.write(reinterpret_cast(pcm.data()), static_cast(pcm.size() * sizeof(int16_t))); + return file.good(); +} diff --git a/examples/common/media_io.h b/examples/common/media_io.h index 6b3f6f88..0f7679d7 100644 --- a/examples/common/media_io.h +++ b/examples/common/media_io.h @@ -57,11 +57,13 @@ int create_mjpg_avi_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, - int quality = 90); + int quality = 90, + const sd_audio_t* audio = nullptr); std::vector create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images, int num_images, int fps, - int quality = 90); + int quality = 90, + const sd_audio_t* audio = nullptr); #ifdef SD_USE_WEBP int create_animated_webp_from_sd_images(const char* filename, @@ -80,22 +82,32 @@ int create_webm_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, - int quality = 90); + int quality = 90, + const sd_audio_t* audio = nullptr); std::vector create_webm_from_sd_images_to_vector(sd_image_t* images, int num_images, int fps, - int quality = 90); + int quality = 90, + const sd_audio_t* audio = nullptr); #endif int create_video_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, - int quality = 90); + int quality = 90, + const sd_audio_t* audio = nullptr); std::vector create_video_from_sd_images_to_vector(const std::string& output_format, sd_image_t* images, int num_images, int fps, - int quality = 90); + int quality = 90, + const sd_audio_t* audio = nullptr); + +bool write_wav_to_file(const std::string& path, + const float* interleaved_samples, + uint64_t sample_count, + uint32_t channels, + uint32_t sample_rate); #endif // __MEDIA_IO_H__ diff --git a/examples/server/async_jobs.cpp 
b/examples/server/async_jobs.cpp index e8e9d8ad..4e635cc3 100644 --- a/examples/server/async_jobs.cpp +++ b/examples/server/async_jobs.cpp @@ -231,16 +231,21 @@ bool execute_vid_gen_job(ServerRuntime& runtime, sd_vid_gen_params_t params = job.vid_gen.to_sd_vid_gen_params_t(); SDImageVec results; - int num_results = 0; + int num_results = 0; + sd_audio_t* generated_audio = nullptr; { std::lock_guard lock(*runtime.sd_ctx_mutex); - sd_image_t* raw_results = generate_video(runtime.sd_ctx, ¶ms, &num_results); + sd_image_t* raw_results = nullptr; + if (!generate_video(runtime.sd_ctx, ¶ms, &raw_results, &num_results, &generated_audio)) { + raw_results = nullptr; + } results.adopt(raw_results, num_results); } num_results = results.count(); if (num_results <= 0) { + free_sd_audio(generated_audio); error_message = "generate_video returned no results"; return false; } @@ -249,7 +254,9 @@ bool execute_vid_gen_job(ServerRuntime& runtime, results.data(), num_results, job.vid_gen.gen_params.fps, - job.vid_gen.output_compression); + job.vid_gen.output_compression, + generated_audio); + free_sd_audio(generated_audio); if (video_bytes.empty()) { error_message = "failed to encode generated video container"; return false; diff --git a/ggml b/ggml index 404fcb9d..7f4ab364 160000 --- a/ggml +++ b/ggml @@ -1 +1 @@ -Subproject commit 404fcb9d7c96989569e68c9e7881ee3465a05c50 +Subproject commit 7f4ab364b2843921e795d6890d0f42dd5e5d6b63 diff --git a/include/stable-diffusion.h b/include/stable-diffusion.h index 9e342303..ccc76acc 100644 --- a/include/stable-diffusion.h +++ b/include/stable-diffusion.h @@ -174,6 +174,7 @@ typedef struct { const char* high_noise_diffusion_model_path; const char* embeddings_connectors_path; const char* vae_path; + const char* audio_vae_path; const char* taesd_path; const char* control_net_path; const sd_embedding_t* embeddings; @@ -208,6 +209,13 @@ typedef struct { float max_vram; } sd_ctx_params_t; +typedef struct { + uint32_t sample_rate; + uint32_t channels; + 
uint64_t sample_count; + float* data; +} sd_audio_t; + typedef struct { uint32_t width; uint32_t height; @@ -407,6 +415,7 @@ SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params); SD_API sd_ctx_t* new_sd_ctx(const sd_ctx_params_t* sd_ctx_params); SD_API void free_sd_ctx(sd_ctx_t* sd_ctx); +SD_API void free_sd_audio(sd_audio_t* audio); SD_API void sd_sample_params_init(sd_sample_params_t* sample_params); SD_API char* sd_sample_params_to_str(const sd_sample_params_t* sample_params); @@ -419,7 +428,11 @@ SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_para SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params); SD_API void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params); -SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params, int* num_frames_out); +SD_API bool generate_video(sd_ctx_t* sd_ctx, + const sd_vid_gen_params_t* sd_vid_gen_params, + sd_image_t** frames_out, + int* num_frames_out, + sd_audio_t** audio_out); typedef struct upscaler_ctx_t upscaler_ctx_t; diff --git a/src/ggml_extend.hpp b/src/ggml_extend.hpp index edea25b5..ab98e0c4 100644 --- a/src/ggml_extend.hpp +++ b/src/ggml_extend.hpp @@ -1698,14 +1698,14 @@ struct WeightAdapter { }; struct GGMLRunnerContext { - ggml_backend_t backend = nullptr; - ggml_context* ggml_ctx = nullptr; - bool flash_attn_enabled = false; - bool conv2d_direct_enabled = false; - bool circular_x_enabled = false; - bool circular_y_enabled = false; - std::shared_ptr weight_adapter = nullptr; - std::unordered_map* debug_tensors = nullptr; + ggml_backend_t backend = nullptr; + ggml_context* ggml_ctx = nullptr; + bool flash_attn_enabled = false; + bool conv2d_direct_enabled = false; + bool circular_x_enabled = false; + bool circular_y_enabled = false; + std::shared_ptr weight_adapter = nullptr; + std::vector>* debug_tensors = nullptr; std::function get_cache_tensor; std::function 
cache_tensor; @@ -1713,8 +1713,14 @@ struct GGMLRunnerContext { if (debug_tensors == nullptr || tensor == nullptr) { return; } - ggml_set_output(tensor); - (*debug_tensors)[tensor] = name; + ggml_tensor* snapshot = tensor; + if (!ggml_is_contiguous(snapshot) || snapshot->view_src != nullptr) { + snapshot = ggml_cont(ggml_ctx, snapshot); + } + ggml_tensor* dst = ggml_dup_tensor(ggml_ctx, snapshot); + snapshot = ggml_cpy(ggml_ctx, snapshot, dst); + ggml_set_output(snapshot); + debug_tensors->push_back({snapshot, name}); } ggml_tensor* load_cache_tensor(const std::string& name) const { @@ -1768,7 +1774,7 @@ protected: std::map backend_tensor_data_map; std::map cache_tensor_map; // name -> tensor - std::unordered_map debug_tensors; + std::vector> debug_tensors; const std::string final_result_name = "ggml_runner_final_result_tensor"; bool flash_attn_enabled = false; diff --git a/src/ltx_audio_vae.h b/src/ltx_audio_vae.h new file mode 100644 index 00000000..d5ee30bc --- /dev/null +++ b/src/ltx_audio_vae.h @@ -0,0 +1,1109 @@ +#ifndef __SD_LTX_AUDIO_VAE_H__ +#define __SD_LTX_AUDIO_VAE_H__ + +#include +#include +#include +#include + +#include "ggml_extend.hpp" + +namespace LTXV { + + struct LTXAudioVAEConfig { + int sample_rate = 16000; + int mel_hop_length = 160; + int n_fft = 1024; + int mel_bins = 64; + int latent_channels = 8; + int latent_frequency_bins = 16; + int audio_channels = 2; + int decoder_channels = 128; + std::vector decoder_channel_multipliers = {1, 2, 4}; + int decoder_num_res_blocks = 2; + int base_upsample_initial_channel = 1536; + std::vector base_upsample_rates = {5, 2, 2, 2, 2, 2}; + std::vector base_upsample_kernel_sizes = {11, 4, 4, 4, 4, 4}; + std::vector base_resblock_kernel_sizes = {3, 7, 11}; + std::vector> base_resblock_dilation_sizes = {{1, 3, 5}, {1, 3, 5}, {1, 3, 5}}; + bool has_bwe = false; + int bwe_input_sample_rate = 16000; + int bwe_output_sample_rate = 48000; + int bwe_hop_length = 80; + int bwe_n_fft = 512; + int bwe_num_mels = 64; + 
int bwe_upsample_initial_channel = 512; + std::vector bwe_upsample_rates = {6, 5, 2, 2, 2}; + std::vector bwe_upsample_kernel_sizes = {12, 11, 4, 4, 4}; + std::vector bwe_resblock_kernel_sizes = {3, 7, 11}; + std::vector> bwe_resblock_dilation_sizes = {{1, 3, 5}, {1, 3, 5}, {1, 3, 5}}; + + int latent_downsample_factor() const { + return 4; + } + + int base_output_sample_rate() const { + int upsample_factor = 1; + for (int rate : base_upsample_rates) { + upsample_factor *= rate; + } + return sample_rate * upsample_factor / mel_hop_length; + } + + int output_sample_rate() const { + if (has_bwe) { + return bwe_output_sample_rate; + } + return base_output_sample_rate(); + } + + static LTXAudioVAEConfig detect_from_weights(const String2TensorStorage& tensor_storage_map) { + LTXAudioVAEConfig config; + + auto require = [&](const std::string& name) -> const TensorStorage* { + auto iter = tensor_storage_map.find(name); + if (iter == tensor_storage_map.end()) { + return nullptr; + } + return &iter->second; + }; + + const TensorStorage* decoder_conv_in = require("audio_vae.decoder.conv_in.conv.weight"); + const TensorStorage* decoder_conv_out = require("audio_vae.decoder.conv_out.conv.weight"); + const TensorStorage* latent_std = require("audio_vae.per_channel_statistics.std-of-means"); + const TensorStorage* vocoder_conv_pre = require("vocoder.vocoder.conv_pre.weight"); + const TensorStorage* vocoder_conv_post = require("vocoder.vocoder.conv_post.weight"); + if (decoder_conv_in == nullptr || decoder_conv_out == nullptr || latent_std == nullptr || + vocoder_conv_pre == nullptr || vocoder_conv_post == nullptr) { + return config; + } + + config.sample_rate = 16000; + config.mel_hop_length = 160; + config.n_fft = 1024; + config.base_upsample_rates = {5, 2, 2, 2, 2, 2}; + config.base_resblock_dilation_sizes = {{1, 3, 5}, {1, 3, 5}, {1, 3, 5}}; + + config.latent_channels = static_cast(decoder_conv_in->ne[2]); + config.audio_channels = static_cast(decoder_conv_out->ne[3]); + 
config.latent_frequency_bins = static_cast(latent_std->ne[0] / std::max(1, decoder_conv_in->ne[2])); + config.mel_bins = config.latent_frequency_bins * config.latent_downsample_factor(); + config.base_upsample_initial_channel = static_cast(vocoder_conv_pre->ne[2]); + + if (latent_std->ne[0] % std::max(1, decoder_conv_in->ne[2]) != 0) { + return config; + } + + std::vector> level_channels; + for (const auto& pair : tensor_storage_map) { + const std::string& name = pair.first; + const std::string prefix = "audio_vae.decoder.up."; + const std::string suffix = ".block.0.conv1.conv.weight"; + if (!starts_with(name, prefix) || !ends_with(name, suffix)) { + continue; + } + std::string level_str = name.substr(prefix.size(), name.size() - prefix.size() - suffix.size()); + int level = std::stoi(level_str); + level_channels.push_back({level, static_cast(pair.second.ne[3])}); + } + std::sort(level_channels.begin(), level_channels.end()); + if (level_channels.empty()) { + return config; + } + config.decoder_channels = level_channels.front().second; + config.decoder_channel_multipliers.clear(); + for (const auto& level_channel : level_channels) { + config.decoder_channel_multipliers.push_back(level_channel.second / std::max(1, config.decoder_channels)); + } + + int block_count = 0; + while (tensor_storage_map.find("audio_vae.decoder.up.0.block." + std::to_string(block_count) + ".conv1.conv.weight") != tensor_storage_map.end()) { + ++block_count; + } + if (block_count <= 0) { + return config; + } + config.decoder_num_res_blocks = block_count - 1; + + config.base_upsample_kernel_sizes.clear(); + for (int i = 0;; ++i) { + auto iter = tensor_storage_map.find("vocoder.vocoder.ups." 
+ std::to_string(i) + ".weight"); + if (iter == tensor_storage_map.end()) { + break; + } + config.base_upsample_kernel_sizes.push_back(static_cast(iter->second.ne[0])); + } + if (config.base_upsample_kernel_sizes.size() != config.base_upsample_rates.size()) { + return config; + } + + config.base_resblock_kernel_sizes.clear(); + for (int i = 0;; ++i) { + auto iter = tensor_storage_map.find("vocoder.vocoder.resblocks." + std::to_string(i) + ".convs1.0.weight"); + if (iter == tensor_storage_map.end()) { + break; + } + config.base_resblock_kernel_sizes.push_back(static_cast(iter->second.ne[0])); + } + if (config.base_resblock_kernel_sizes.size() < 3) { + return config; + } + config.base_resblock_kernel_sizes.resize(3); + + config.has_bwe = tensor_storage_map.find("vocoder.bwe_generator.conv_pre.weight") != tensor_storage_map.end(); + if (config.has_bwe) { + config.bwe_input_sample_rate = 16000; + config.bwe_output_sample_rate = 48000; + config.bwe_hop_length = 80; + config.bwe_n_fft = 512; + config.bwe_num_mels = 64; + config.bwe_upsample_initial_channel = 512; + config.bwe_upsample_rates = {6, 5, 2, 2, 2}; + config.bwe_upsample_kernel_sizes = {12, 11, 4, 4, 4}; + config.bwe_resblock_kernel_sizes = {3, 7, 11}; + config.bwe_resblock_dilation_sizes = {{1, 3, 5}, {1, 3, 5}, {1, 3, 5}}; + } + + if (config.audio_channels != 2 || config.latent_channels != 8 || config.mel_bins != 64) { + return config; + } + return config; + } + }; + + static sd::Tensor squeeze_trailing_singleton_dims(sd::Tensor tensor) { + while (tensor.dim() > 0 && tensor.shape().back() == 1) { + tensor = tensor.squeeze(static_cast(tensor.dim() - 1)); + } + return tensor; + } + + static sd::Tensor normalize_waveform_for_host(sd::Tensor waveform) { + waveform = squeeze_trailing_singleton_dims(std::move(waveform)); + if (waveform.empty()) { + return waveform; + } + if (waveform.dim() == 1) { + return waveform.reshape({waveform.shape()[0], 1, 1}); + } + if (waveform.dim() == 2) { + return 
waveform.reshape({waveform.shape()[0], waveform.shape()[1], 1}); + } + if (waveform.dim() == 3) { + return waveform; + } + throw std::runtime_error("Unsupported waveform rank for host processing: rank=" + std::to_string(waveform.dim())); + } + + static sd::Tensor load_param_tensor_f32(ggml_tensor* tensor) { + GGML_ASSERT(tensor != nullptr); + return squeeze_trailing_singleton_dims(sd::make_sd_tensor_from_ggml(tensor)); + } + + static sd::Tensor compute_log_mel_spectrogram(const sd::Tensor& waveform_in, + const sd::Tensor& forward_basis, + const sd::Tensor& mel_basis, + int hop_length) { + auto waveform = normalize_waveform_for_host(waveform_in); + GGML_ASSERT(forward_basis.dim() >= 3); + GGML_ASSERT(mel_basis.dim() >= 2); + + const int64_t time = waveform.shape()[0]; + const int64_t channels = waveform.shape()[1]; + const int64_t batch = waveform.shape()[2]; + const int64_t filter_len = forward_basis.shape()[0]; + const int64_t basis_freq2 = forward_basis.shape().back(); + const int64_t n_freqs = basis_freq2 / 2; + const int64_t n_mels = mel_basis.shape()[1]; + const int64_t left_pad = std::max(0, filter_len - hop_length); + const int64_t padded_time = time + left_pad; + const int64_t frame_count = padded_time < filter_len ? 
/// Builds the Hann-windowed sinc low-pass kernel used for integer-ratio
/// waveform upsampling.
///
/// The kernel has 2 * width * ratio + 1 taps, where
/// width = ceil(lowpass_filter_width / rolloff) = 7 for the fixed
/// lowpass_filter_width = 6 and rolloff = 0.99. Each tap is
/// sinc(t) * cos^2(clamp(t) * pi / 12) * rolloff / ratio.
///
/// @param ratio  integer upsampling ratio, must be >= 1.
/// @return the filter taps, symmetric around the centre tap.
/// @throws std::runtime_error when ratio < 1 (this would otherwise divide by
///         zero or request a negative-sized vector).
static std::vector<float> build_hann_resample_filter(int ratio) {
    if (ratio < 1) {
        throw std::runtime_error("Invalid resample ratio: " + std::to_string(ratio));
    }

    constexpr double kPi               = 3.14159265358979323846;
    constexpr double kRolloff          = 0.99;
    constexpr int kLowpassFilterWidth  = 6;
    constexpr double kLowpassWidthReal = static_cast<double>(kLowpassFilterWidth);

    const int width              = static_cast<int>(std::ceil(kLowpassWidthReal / kRolloff));
    const int kernel_size        = 2 * width * ratio + 1;
    const double half_lowpass_pi = kPi / kLowpassFilterWidth / 2.0;

    std::vector<float> filter(static_cast<size_t>(kernel_size), 0.0f);
    for (int i = 0; i < kernel_size; ++i) {
        // Tap position in input-sample units, centred on zero.
        const double t = (static_cast<double>(i) / ratio - width) * kRolloff;
        // The window is evaluated on the clamped position so it reaches zero
        // at +/- lowpass_filter_width and stays there.
        const double t_clamped = std::clamp(t, -kLowpassWidthReal, kLowpassWidthReal);
        const double window_root = std::cos(t_clamped * half_lowpass_pi);
        const double window      = window_root * window_root;
        const double sinc        = t == 0.0 ? 1.0 : std::sin(kPi * t) / (kPi * t);
        filter[static_cast<size_t>(i)] = static_cast<float>(sinc * window * kRolloff / ratio);
    }
    return filter;
}
conv_out[static_cast(out_base + k)] += static_cast(sample * filter[static_cast(k)]); + } + } + + for (int64_t t = 0; t < cropped_time; ++t) { + output.index(t, c, b) = conv_out[static_cast(t + pad_left)]; + } + } + } + + return output; + } + + static sd::Tensor crop_waveform_samples(const sd::Tensor& waveform_in, int64_t target_samples) { + auto waveform = normalize_waveform_for_host(waveform_in); + if (waveform.shape()[0] == target_samples) { + return waveform; + } + if (waveform.shape()[0] > target_samples) { + return sd::ops::slice(waveform, 0, 0, target_samples); + } + sd::Tensor output({target_samples, waveform.shape()[1], waveform.shape()[2]}); + sd::ops::slice_assign(&output, 0, 0, waveform.shape()[0], waveform); + return output; + } + + static ggml_type audio_conv_weight_type(ggml_type type) { + return type == GGML_TYPE_BF16 ? GGML_TYPE_F16 : type; + } + + static ggml_tensor* repeat_1d_value(ggml_context* ctx, ggml_tensor* x, int64_t count) { + GGML_ASSERT(x->ne[0] == 1); + ggml_tensor* target = ggml_new_tensor_4d(ctx, x->type, count, x->ne[1], x->ne[2], x->ne[3]); + return ggml_repeat(ctx, x, target); + } + + static ggml_tensor* replicate_pad_1d(ggml_context* ctx, ggml_tensor* x, int64_t left, int64_t right) { + if (left > 0) { + auto first = ggml_ext_slice(ctx, x, 0, 0, 1); + x = ggml_concat(ctx, repeat_1d_value(ctx, first, left), x, 0); + } + if (right > 0) { + auto last = ggml_ext_slice(ctx, x, 0, x->ne[0] - 1, x->ne[0]); + x = ggml_concat(ctx, x, repeat_1d_value(ctx, last, right), 0); + } + return x; + } + + static ggml_tensor* tile_depthwise_filter_1d(ggml_context* ctx, ggml_tensor* filter, int64_t channels) { + ggml_tensor* base = filter; + if (ggml_n_dims(base) == 3) { + base = ggml_reshape_4d(ctx, base, base->ne[0], 1, 1, 1); + } else if (ggml_n_dims(base) == 1) { + base = ggml_reshape_4d(ctx, base, base->ne[0], 1, 1, 1); + } + ggml_tensor* target = ggml_new_tensor_4d(ctx, base->type, base->ne[0], 1, channels, 1); + return ggml_repeat(ctx, base, 
target); + } + + static ggml_tensor* depthwise_conv1d(ggml_context* ctx, + ggml_tensor* x, + ggml_tensor* filter, + int stride, + int padding) { + GGML_ASSERT(x->ne[3] == 1); + auto tiled = tile_depthwise_filter_1d(ctx, filter, x->ne[1]); + auto out = ggml_conv_1d_dw(ctx, tiled, x, stride, padding, 1); + return ggml_reshape_4d(ctx, out, out->ne[0], out->ne[1], 1, 1); + } + + static ggml_tensor* depthwise_conv_transpose1d(ggml_context* ctx, + ggml_tensor* x, + ggml_tensor* filter, + int stride) { + GGML_ASSERT(x->ne[2] == 1 && x->ne[3] == 1); + GGML_ASSERT(filter->ne[1] == 1); + + ggml_tensor* out = nullptr; + for (int64_t c = 0; c < x->ne[1]; ++c) { + auto xi = ggml_ext_slice(ctx, x, 1, c, c + 1); + auto yi = ggml_conv_transpose_1d(ctx, filter, xi, stride, 0, 1); + yi = ggml_ext_scale(ctx, yi, static_cast(stride)); + yi = ggml_reshape_4d(ctx, yi, yi->ne[0], 1, 1, 1); + out = out == nullptr ? yi : ggml_concat(ctx, out, yi, 1); + } + return out; + } + + struct PixelNorm2D : public UnaryBlock { + float eps = 1e-6f; + + explicit PixelNorm2D(float eps = 1e-6f) + : eps(eps) {} + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override { + auto h = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 2, 0, 1, 3)); + h = ggml_rms_norm(ctx->ggml_ctx, h, eps); + h = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, h, 1, 2, 0, 3)); + return h; + } + }; + + struct HeightCausalConv2D : public UnaryBlock { + std::pair kernel_size; + + HeightCausalConv2D(int64_t in_channels, + int64_t out_channels, + std::pair kernel_size, + std::pair stride = {1, 1}, + bool bias = true) + : kernel_size(kernel_size) { + blocks["conv"] = std::make_shared(in_channels, out_channels, kernel_size, stride, std::pair{0, 0}, std::pair{1, 1}, bias); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override { + auto conv = std::dynamic_pointer_cast(blocks["conv"]); + int pad_h = kernel_size.first - 1; + int pad_w = kernel_size.second - 
1; + x = ggml_ext_pad_ext(ctx->ggml_ctx, + x, + pad_w / 2, + pad_w - pad_w / 2, + pad_h, + 0, + 0, + 0, + 0, + 0); + x = conv->forward(ctx, x); + return x; + } + }; + + struct AudioUpsample2D : public GGMLBlock { + AudioUpsample2D(int64_t channels) { + blocks["conv"] = std::make_shared(channels, channels, std::pair{3, 3}); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { + auto conv = std::dynamic_pointer_cast(blocks["conv"]); + x = ggml_upscale(ctx->ggml_ctx, x, 2, GGML_SCALE_MODE_NEAREST); + x = conv->forward(ctx, x); + return ggml_ext_slice(ctx->ggml_ctx, x, 1, 1, x->ne[1]); + } + }; + + struct AudioResnetBlock2D : public GGMLBlock { + int64_t in_channels; + int64_t out_channels; + + AudioResnetBlock2D(int64_t in_channels, int64_t out_channels) + : in_channels(in_channels), out_channels(out_channels) { + blocks["norm1"] = std::make_shared(); + blocks["conv1"] = std::make_shared(in_channels, out_channels, std::pair{3, 3}); + blocks["norm2"] = std::make_shared(); + blocks["conv2"] = std::make_shared(out_channels, out_channels, std::pair{3, 3}); + if (in_channels != out_channels) { + blocks["nin_shortcut"] = std::make_shared(in_channels, out_channels, std::pair{1, 1}); + } + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { + auto norm1 = std::dynamic_pointer_cast(blocks["norm1"]); + auto conv1 = std::dynamic_pointer_cast(blocks["conv1"]); + auto norm2 = std::dynamic_pointer_cast(blocks["norm2"]); + auto conv2 = std::dynamic_pointer_cast(blocks["conv2"]); + + auto h = norm1->forward(ctx, x); + h = ggml_silu_inplace(ctx->ggml_ctx, h); + h = conv1->forward(ctx, h); + h = norm2->forward(ctx, h); + h = ggml_silu_inplace(ctx->ggml_ctx, h); + h = conv2->forward(ctx, h); + + if (in_channels != out_channels) { + auto shortcut = std::dynamic_pointer_cast(blocks["nin_shortcut"]); + x = shortcut->forward(ctx, x); + } + return ggml_add(ctx->ggml_ctx, x, h); + } + }; + + struct Conv1D : public UnaryBlock { + int64_t in_channels; + 
int64_t out_channels; + int kernel_size; + int stride; + int padding; + int dilation; + bool bias; + std::string prefix; + + Conv1D(int64_t in_channels, + int64_t out_channels, + int kernel_size, + int stride = 1, + int padding = 0, + int dilation = 1, + bool bias = true) + : in_channels(in_channels), + out_channels(out_channels), + kernel_size(kernel_size), + stride(stride), + padding(padding), + dilation(dilation), + bias(bias) {} + + void init_params(ggml_context* ctx, + const String2TensorStorage& tensor_storage_map = {}, + const std::string prefix = "") override { + this->prefix = prefix; + ggml_type wtype = audio_conv_weight_type(get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F16)); + params["weight"] = ggml_new_tensor_4d(ctx, wtype, kernel_size, in_channels, out_channels, 1); + if (bias) { + params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); + } + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override { + x = ggml_conv_1d(ctx->ggml_ctx, params["weight"], x, stride, padding, dilation); + if (bias) { + auto b = ggml_reshape_4d(ctx->ggml_ctx, params["bias"], 1, params["bias"]->ne[0], 1, 1); + x = ggml_add_inplace(ctx->ggml_ctx, x, b); + } + return x; + } + }; + + struct ConvTranspose1D : public UnaryBlock { + int64_t in_channels; + int64_t out_channels; + int kernel_size; + int stride; + int padding; + int dilation; + bool bias; + + ConvTranspose1D(int64_t in_channels, + int64_t out_channels, + int kernel_size, + int stride, + int padding, + int dilation = 1, + bool bias = true) + : in_channels(in_channels), + out_channels(out_channels), + kernel_size(kernel_size), + stride(stride), + padding(padding), + dilation(dilation), + bias(bias) {} + + void init_params(ggml_context* ctx, + const String2TensorStorage& tensor_storage_map = {}, + const std::string prefix = "") override { + SD_UNUSED(tensor_storage_map); + SD_UNUSED(prefix); + params["weight"] = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kernel_size, 
out_channels, in_channels, 1); + if (bias) { + params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); + } + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override { + GGML_ASSERT(dilation == 1); + x = ggml_conv_transpose_1d(ctx->ggml_ctx, params["weight"], x, stride, 0, dilation); + if (padding > 0) { + x = ggml_ext_slice(ctx->ggml_ctx, x, 0, padding, x->ne[0] - padding); + } + if (bias) { + auto b = ggml_reshape_4d(ctx->ggml_ctx, params["bias"], 1, params["bias"]->ne[0], 1, 1); + x = ggml_add_inplace(ctx->ggml_ctx, x, b); + } + return x; + } + }; + + struct SnakeBeta1D : public UnaryBlock { + int64_t channels; + float eps = 1e-9f; + + explicit SnakeBeta1D(int64_t channels) + : channels(channels) {} + + void init_params(ggml_context* ctx, + const String2TensorStorage& tensor_storage_map = {}, + const std::string prefix = "") override { + SD_UNUSED(tensor_storage_map); + SD_UNUSED(prefix); + params["alpha"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, channels); + params["beta"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, channels); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override { + auto alpha = ggml_exp(ctx->ggml_ctx, params["alpha"]); + auto beta = ggml_exp(ctx->ggml_ctx, params["beta"]); + alpha = ggml_reshape_4d(ctx->ggml_ctx, alpha, 1, alpha->ne[0], 1, 1); + beta = ggml_reshape_4d(ctx->ggml_ctx, beta, 1, beta->ne[0], 1, 1); + auto oscillation = ggml_sin(ctx->ggml_ctx, ggml_mul(ctx->ggml_ctx, x, alpha)); + oscillation = ggml_mul(ctx->ggml_ctx, oscillation, oscillation); + auto eps_tensor = ggml_ext_scale(ctx->ggml_ctx, ggml_ext_ones(ctx->ggml_ctx, 1, 1, 1, 1), eps); + oscillation = ggml_div(ctx->ggml_ctx, oscillation, ggml_add(ctx->ggml_ctx, beta, eps_tensor)); + return ggml_add(ctx->ggml_ctx, x, oscillation); + } + }; + + struct Activation1D : public GGMLBlock { + int64_t channels; + int up_ratio = 2; + int down_ratio = 2; + int up_kernel_size = 12; + int down_kernel_size = 12; + + explicit 
Activation1D(int64_t channels) + : channels(channels) { + blocks["act"] = std::make_shared(channels); + } + + void init_params(ggml_context* ctx, + const String2TensorStorage& tensor_storage_map = {}, + const std::string prefix = "") override { + ggml_type down_type = audio_conv_weight_type(get_type(prefix + "downsample.lowpass.filter", tensor_storage_map, GGML_TYPE_F16)); + params["downsample.lowpass.filter"] = ggml_new_tensor_3d(ctx, down_type, down_kernel_size, 1, 1); + params["upsample.filter"] = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, up_kernel_size, 1, 1); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { + auto act = std::dynamic_pointer_cast(blocks["act"]); + auto up_filter = params["upsample.filter"]; + auto down_filter = params["downsample.lowpass.filter"]; + + int up_pad = up_kernel_size / up_ratio - 1; + int up_pad_left = up_pad * up_ratio + (up_kernel_size - up_ratio) / 2; + int up_pad_right = up_pad * up_ratio + (up_kernel_size - up_ratio + 1) / 2; + + x = replicate_pad_1d(ctx->ggml_ctx, x, up_pad, up_pad); + x = depthwise_conv_transpose1d(ctx->ggml_ctx, x, up_filter, up_ratio); + x = ggml_ext_slice(ctx->ggml_ctx, x, 0, up_pad_left, x->ne[0] - up_pad_right); + + x = act->forward(ctx, x); + + int down_pad_left = down_kernel_size / 2 - (down_kernel_size % 2 == 0 ? 1 : 0); + int down_pad_right = down_kernel_size / 2; + x = replicate_pad_1d(ctx->ggml_ctx, x, down_pad_left, down_pad_right); + x = depthwise_conv1d(ctx->ggml_ctx, x, down_filter, down_ratio, 0); + return x; + } + }; + + struct AMPBlock1 : public GGMLBlock { + int64_t channels; + std::vector dilation; + + AMPBlock1(int64_t channels, int kernel_size, const std::vector& dilation) + : channels(channels), dilation(dilation) { + for (int i = 0; i < 3; ++i) { + blocks["acts1." + std::to_string(i)] = std::make_shared(channels); + blocks["acts2." + std::to_string(i)] = std::make_shared(channels); + blocks["convs1." 
+ std::to_string(i)] = std::make_shared(channels, + channels, + kernel_size, + 1, + (kernel_size * dilation[i] - dilation[i]) / 2, + dilation[i]); + blocks["convs2." + std::to_string(i)] = std::make_shared(channels, + channels, + kernel_size, + 1, + kernel_size / 2, + 1); + } + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { + for (int i = 0; i < 3; ++i) { + auto act1 = std::dynamic_pointer_cast(blocks["acts1." + std::to_string(i)]); + auto act2 = std::dynamic_pointer_cast(blocks["acts2." + std::to_string(i)]); + auto conv1 = std::dynamic_pointer_cast(blocks["convs1." + std::to_string(i)]); + auto conv2 = std::dynamic_pointer_cast(blocks["convs2." + std::to_string(i)]); + + auto h = act1->forward(ctx, x); + h = conv1->forward(ctx, h); + h = act2->forward(ctx, h); + h = conv2->forward(ctx, h); + x = ggml_add(ctx->ggml_ctx, x, h); + } + return x; + } + }; + + struct Vocoder : public GGMLBlock { + LTXAudioVAEConfig config; + bool use_bwe_config; + bool apply_final_activation; + + explicit Vocoder(const LTXAudioVAEConfig& config, + bool use_bwe_config = false, + bool apply_final_activation = true) + : config(config), + use_bwe_config(use_bwe_config), + apply_final_activation(apply_final_activation) { + const int mel_bins = use_bwe_config ? config.bwe_num_mels : config.mel_bins; + const int initial_channels = use_bwe_config ? config.bwe_upsample_initial_channel : config.base_upsample_initial_channel; + const std::vector& upsample_rates = use_bwe_config ? config.bwe_upsample_rates : config.base_upsample_rates; + const std::vector& upsample_kernel_sizes = use_bwe_config ? config.bwe_upsample_kernel_sizes : config.base_upsample_kernel_sizes; + const std::vector& resblock_kernel_sizes = use_bwe_config ? config.bwe_resblock_kernel_sizes : config.base_resblock_kernel_sizes; + const std::vector>& resblock_dilation_sizes = use_bwe_config ? 
config.bwe_resblock_dilation_sizes : config.base_resblock_dilation_sizes; + + int in_channels = mel_bins * config.audio_channels; + blocks["conv_pre"] = std::make_shared(in_channels, + initial_channels, + 7, + 1, + 3); + + int current_channels = initial_channels; + int resblock_index = 0; + for (size_t i = 0; i < upsample_rates.size(); ++i) { + int next_channels = initial_channels / (1 << static_cast(i + 1)); + blocks["ups." + std::to_string(i)] = std::make_shared(current_channels, + next_channels, + upsample_kernel_sizes[i], + upsample_rates[i], + (upsample_kernel_sizes[i] - upsample_rates[i]) / 2); + for (size_t j = 0; j < resblock_kernel_sizes.size(); ++j) { + blocks["resblocks." + std::to_string(resblock_index)] = std::make_shared(next_channels, + resblock_kernel_sizes[j], + resblock_dilation_sizes[j]); + ++resblock_index; + } + current_channels = next_channels; + } + + blocks["act_post"] = std::make_shared(current_channels); + blocks["conv_post"] = std::make_shared(current_channels, config.audio_channels, 7, 1, 3, 1, false); + } + + ggml_tensor* prepare_input(GGMLRunnerContext* ctx, ggml_tensor* mel) { + mel = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, mel, 1, 0, 2, 3)); + auto mels = ggml_ext_chunk(ctx->ggml_ctx, mel, 2, 2); + mel = ggml_concat(ctx->ggml_ctx, mels[0], mels[1], 1); + // mel = ggml_reshape_4d(ctx->ggml_ctx, + // mel, + // mel->ne[0], + // mel->ne[1] * mel->ne[2], + // mel->ne[3], + // 1); // [b, c*t, f] + return mel; + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* mel) { + const std::vector& upsample_rates = use_bwe_config ? config.bwe_upsample_rates : config.base_upsample_rates; + const std::vector& resblock_kernel_sizes = use_bwe_config ? 
config.bwe_resblock_kernel_sizes : config.base_resblock_kernel_sizes; + mel = prepare_input(ctx, mel); + auto conv_pre = std::dynamic_pointer_cast(blocks["conv_pre"]); + auto act_post = std::dynamic_pointer_cast(blocks["act_post"]); + auto conv_post = std::dynamic_pointer_cast(blocks["conv_post"]); + + auto x = conv_pre->forward(ctx, mel); + int resblock_index = 0; + for (size_t i = 0; i < upsample_rates.size(); ++i) { + // x = ggml_leaky_relu(ctx->ggml_ctx, x, 0.1f, false); + auto up = std::dynamic_pointer_cast(blocks["ups." + std::to_string(i)]); + x = up->forward(ctx, x); + + ggml_tensor* sum = nullptr; + for (size_t j = 0; j < resblock_kernel_sizes.size(); ++j) { + auto resblock = std::dynamic_pointer_cast(blocks["resblocks." + std::to_string(resblock_index++)]); + auto block_out = resblock->forward(ctx, x); + sum = sum == nullptr ? block_out : ggml_add(ctx->ggml_ctx, sum, block_out); + } + x = ggml_ext_scale(ctx->ggml_ctx, sum, 1.0f / static_cast(resblock_kernel_sizes.size())); + } + + x = act_post->forward(ctx, x); + x = conv_post->forward(ctx, x); + if (apply_final_activation) { + x = ggml_clamp(ctx->ggml_ctx, x, -1.0f, 1.0f); + } + return x; + } + }; + + struct AudioDecoder : public GGMLBlock { + LTXAudioVAEConfig config; + + explicit AudioDecoder(const LTXAudioVAEConfig& config) + : config(config) { + int block_in = config.decoder_channels * config.decoder_channel_multipliers.back(); + blocks["conv_in"] = std::make_shared(config.latent_channels, block_in, std::pair{3, 3}); + blocks["mid.block_1"] = std::make_shared(block_in, block_in); + blocks["mid.block_2"] = std::make_shared(block_in, block_in); + + for (int level = static_cast(config.decoder_channel_multipliers.size()) - 1; level >= 0; --level) { + int block_out = config.decoder_channels * config.decoder_channel_multipliers[level]; + for (int block_idx = 0; block_idx < config.decoder_num_res_blocks + 1; ++block_idx) { + blocks["up." + std::to_string(level) + ".block." 
+ std::to_string(block_idx)] = + std::make_shared(block_in, block_out); + block_in = block_out; + } + if (level != 0) { + blocks["up." + std::to_string(level) + ".upsample"] = std::make_shared(block_in); + } + } + + blocks["norm_out"] = std::make_shared(); + blocks["conv_out"] = std::make_shared(block_in, config.audio_channels, std::pair{3, 3}); + } + + ggml_tensor* denormalize_latent(GGMLRunnerContext* ctx, + ggml_tensor* latent, + ggml_tensor* mean, + ggml_tensor* stddev) { + latent = ggml_permute(ctx->ggml_ctx, latent, 0, 2, 1, 3); + latent = ggml_cont(ctx->ggml_ctx, latent); + latent = ggml_reshape_4d(ctx->ggml_ctx, latent, config.latent_frequency_bins * config.latent_channels, latent->ne[2], 1, latent->ne[3]); + + mean = ggml_reshape_4d(ctx->ggml_ctx, mean, mean->ne[0], 1, 1, 1); + stddev = ggml_reshape_4d(ctx->ggml_ctx, stddev, stddev->ne[0], 1, 1, 1); + latent = ggml_add(ctx->ggml_ctx, ggml_mul(ctx->ggml_ctx, latent, stddev), mean); + + latent = ggml_reshape_4d(ctx->ggml_ctx, + latent, + config.latent_frequency_bins, + config.latent_channels, + latent->ne[1], + latent->ne[3]); + latent = ggml_permute(ctx->ggml_ctx, latent, 0, 2, 1, 3); + return ggml_cont(ctx->ggml_ctx, latent); + } + + ggml_tensor* adjust_output_shape(GGMLRunnerContext* ctx, + ggml_tensor* decoded, + int target_time, + int target_freq) { + int64_t time = std::min(decoded->ne[1], target_time); + int64_t freq = std::min(decoded->ne[0], target_freq); + decoded = ggml_ext_slice(ctx->ggml_ctx, decoded, 0, 0, freq); + decoded = ggml_ext_slice(ctx->ggml_ctx, decoded, 1, 0, time); + return decoded; + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* latent, + ggml_tensor* mean, + ggml_tensor* stddev, + int target_time, + int target_freq) { + auto conv_in = std::dynamic_pointer_cast(blocks["conv_in"]); + auto mid_block_1 = std::dynamic_pointer_cast(blocks["mid.block_1"]); + auto mid_block_2 = std::dynamic_pointer_cast(blocks["mid.block_2"]); + auto norm_out = 
std::dynamic_pointer_cast(blocks["norm_out"]); + auto conv_out = std::dynamic_pointer_cast(blocks["conv_out"]); + + auto x = denormalize_latent(ctx, latent, mean, stddev); + x = conv_in->forward(ctx, x); + x = mid_block_1->forward(ctx, x); + x = mid_block_2->forward(ctx, x); + + for (int level = static_cast(config.decoder_channel_multipliers.size()) - 1; level >= 0; --level) { + for (int block_idx = 0; block_idx < config.decoder_num_res_blocks + 1; ++block_idx) { + auto block = std::dynamic_pointer_cast(blocks["up." + std::to_string(level) + ".block." + std::to_string(block_idx)]); + x = block->forward(ctx, x); + } + if (level != 0) { + auto upsample = std::dynamic_pointer_cast(blocks["up." + std::to_string(level) + ".upsample"]); + x = upsample->forward(ctx, x); + } + } + + x = norm_out->forward(ctx, x); + x = ggml_silu_inplace(ctx->ggml_ctx, x); + x = conv_out->forward(ctx, x); + return adjust_output_shape(ctx, x, target_time, target_freq); + } + }; + + struct LTXAudioVAE : public GGMLBlock { + LTXAudioVAEConfig config; + + explicit LTXAudioVAE(const LTXAudioVAEConfig& config) + : config(config) { + blocks["audio_vae.decoder"] = std::make_shared(config); + blocks["vocoder.vocoder"] = std::make_shared(config); + if (config.has_bwe) { + blocks["vocoder.bwe_generator"] = std::make_shared(config, true, false); + } + } + + void init_params(ggml_context* ctx, + const String2TensorStorage& tensor_storage_map = {}, + const std::string prefix = "") override { + GGMLBlock::init_params(ctx, tensor_storage_map, prefix); + params["audio_vae.per_channel_statistics.mean-of-means"] = + ggml_new_tensor_1d(ctx, GGML_TYPE_F32, config.latent_channels * config.latent_frequency_bins); + params["audio_vae.per_channel_statistics.std-of-means"] = + ggml_new_tensor_1d(ctx, GGML_TYPE_F32, config.latent_channels * config.latent_frequency_bins); + if (config.has_bwe) { + params["vocoder.mel_stft.mel_basis"] = + ggml_new_tensor_2d(ctx, GGML_TYPE_F32, config.bwe_n_fft / 2 + 1, 
config.bwe_num_mels); + params["vocoder.mel_stft.stft_fn.forward_basis"] = + ggml_new_tensor_3d(ctx, GGML_TYPE_F32, config.bwe_n_fft, 1, (config.bwe_n_fft / 2 + 1) * 2); + params["vocoder.mel_stft.stft_fn.inverse_basis"] = + ggml_new_tensor_3d(ctx, GGML_TYPE_F32, config.bwe_n_fft, 1, (config.bwe_n_fft / 2 + 1) * 2); + } + } + + ggml_tensor* decode_to_mel(GGMLRunnerContext* ctx, + ggml_tensor* latent, + int target_time, + int target_freq) { + auto mean = params["audio_vae.per_channel_statistics.mean-of-means"]; + auto stddev = params["audio_vae.per_channel_statistics.std-of-means"]; + auto decoder = std::dynamic_pointer_cast(blocks["audio_vae.decoder"]); + return decoder->forward(ctx, latent, mean, stddev, target_time, target_freq); + } + + ggml_tensor* run_vocoder(GGMLRunnerContext* ctx, ggml_tensor* mel) { + auto vocoder = std::dynamic_pointer_cast(blocks["vocoder.vocoder"]); + return vocoder->forward(ctx, mel); + } + + ggml_tensor* run_bwe_generator(GGMLRunnerContext* ctx, ggml_tensor* mel) { + GGML_ASSERT(config.has_bwe); + auto bwe_generator = std::dynamic_pointer_cast(blocks["vocoder.bwe_generator"]); + return bwe_generator->forward(ctx, mel); + } + + ggml_tensor* mel_basis_tensor() const { + auto iter = params.find("vocoder.mel_stft.mel_basis"); + return iter == params.end() ? nullptr : iter->second; + } + + ggml_tensor* stft_forward_basis_tensor() const { + auto iter = params.find("vocoder.mel_stft.stft_fn.forward_basis"); + return iter == params.end() ? 
nullptr : iter->second; + } + }; + + struct LTXAudioVAERunner : public GGMLRunner { + LTXAudioVAEConfig config; + LTXAudioVAE model; + + LTXAudioVAERunner(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2TensorStorage& tensor_storage_map, + const std::string& prefix = "") + : GGMLRunner(backend, offload_params_to_cpu), + config(LTXAudioVAEConfig::detect_from_weights(tensor_storage_map)), + model(config) { + model.init(params_ctx, tensor_storage_map, prefix); + } + + void get_param_tensors(std::map& tensors, const std::string prefix) { + model.get_param_tensors(tensors, prefix); + } + + size_t get_params_buffer_size() { + return model.get_params_mem_size(); + } + + std::string get_desc() { + return "ltx_audio_vae"; + } + + ggml_cgraph* build_base_graph(const sd::Tensor& latent_tensor) { + auto latent = make_input(latent_tensor); + int target_time = static_cast(latent_tensor.shape()[1]) * config.latent_downsample_factor() - + (config.latent_downsample_factor() - 1); + int target_freq = config.mel_bins; + + ggml_cgraph* gf = new_graph_custom(655360); + auto runner_ctx = GGMLRunner::get_context(); + auto mel = model.decode_to_mel(&runner_ctx, latent, target_time, target_freq); + auto waveform = model.run_vocoder(&runner_ctx, mel); + ggml_build_forward_expand(gf, waveform); + return gf; + } + + ggml_cgraph* build_bwe_graph(const sd::Tensor& mel_tensor) { + auto mel = make_input(mel_tensor); + ggml_cgraph* gf = new_graph_custom(655360); + auto runner_ctx = GGMLRunner::get_context(); + auto residual = model.run_bwe_generator(&runner_ctx, mel); + ggml_build_forward_expand(gf, residual); + return gf; + } + + sd::Tensor compute_base_waveform(int n_threads, + const sd::Tensor& latent_tensor) { + auto get_graph = [&]() -> ggml_cgraph* { + return build_base_graph(latent_tensor); + }; + return restore_trailing_singleton_dims(GGMLRunner::compute(get_graph, n_threads, false), 4); + } + + sd::Tensor compute_bwe_residual(int n_threads, + const sd::Tensor& 
mel_tensor) { + auto get_graph = [&]() -> ggml_cgraph* { + return build_bwe_graph(mel_tensor); + }; + return restore_trailing_singleton_dims(GGMLRunner::compute(get_graph, n_threads, false), 4); + } + + sd::Tensor decode(int n_threads, + const sd::Tensor& latent_tensor) { + auto waveform = compute_base_waveform(n_threads, latent_tensor); + if (!config.has_bwe || waveform.empty()) { + return waveform; + } + + auto waveform_host = normalize_waveform_for_host(waveform); + const int64_t low_time = waveform_host.shape()[0]; + const int64_t out_time = low_time * config.bwe_output_sample_rate / config.bwe_input_sample_rate; + int64_t remainder = low_time % config.bwe_hop_length; + if (remainder != 0) { + sd::Tensor padded({low_time + (config.bwe_hop_length - remainder), waveform_host.shape()[1], waveform_host.shape()[2]}); + sd::ops::slice_assign(&padded, 0, 0, low_time, waveform_host); + waveform_host = std::move(padded); + } + + auto mel_basis_tensor = model.mel_basis_tensor(); + auto stft_basis_tensor = model.stft_forward_basis_tensor(); + GGML_ASSERT(mel_basis_tensor != nullptr && stft_basis_tensor != nullptr); + auto mel_basis = load_param_tensor_f32(mel_basis_tensor); + auto forward_basis = load_param_tensor_f32(stft_basis_tensor); + auto bwe_mel = compute_log_mel_spectrogram(waveform_host, forward_basis, mel_basis, config.bwe_hop_length); + auto residual_raw = compute_bwe_residual(n_threads, bwe_mel); + if (residual_raw.empty()) { + return waveform; + } + auto residual = normalize_waveform_for_host(residual_raw); + auto skip = upsample_waveform_hann(waveform_host, config.bwe_output_sample_rate / config.bwe_input_sample_rate); + auto combined = sd::ops::clamp(residual + skip, -1.0f, 1.0f); + auto cropped = crop_waveform_samples(combined, out_time); + return restore_trailing_singleton_dims(cropped, 4); + } + + void test(const std::string& input_path) { + auto z = sd::load_tensor_from_file_as_tensor(input_path); + GGML_ASSERT(!z.empty()); + print_sd_tensor(z, false, 
"ltx_audio_vae_z"); + + int64_t t0 = ggml_time_ms(); + auto out = decode(8, z); + int64_t t1 = ggml_time_ms(); + + GGML_ASSERT(!out.empty()); + print_sd_tensor(out, false, "ltx_audio_vae_out"); + LOG_DEBUG("ltx audio vae test done in %lldms", t1 - t0); + } + + static void load_from_file_and_test(const std::string& model_path, + const std::string& input_path, + const std::string& prefix = "") { + ggml_backend_t backend = ggml_backend_cpu_init(); + // ggml_backend_t backend = ggml_backend_cuda_init(0); + LOG_INFO("loading ltx audio vae from '%s'", model_path.c_str()); + + ModelLoader model_loader; + if (!model_loader.init_from_file(model_path)) { + LOG_ERROR("init model loader from file failed: '%s'", model_path.c_str()); + return; + } + + auto& tensor_storage_map = model_loader.get_tensor_storage_map(); + auto ltx_audio_vae = std::make_shared(backend, + false, + tensor_storage_map, + prefix); + + ltx_audio_vae->alloc_params_buffer(); + std::map tensors; + ltx_audio_vae->get_param_tensors(tensors, ""); + + if (!model_loader.load_tensors(tensors)) { + LOG_ERROR("load tensors from model loader failed"); + return; + } + + LOG_INFO("ltx audio vae model loaded"); + ltx_audio_vae->test(input_path); + } + }; + +} // namespace LTXV + +#endif // __SD_LTX_AUDIO_VAE_H__ diff --git a/src/ltxv.hpp b/src/ltxv.hpp index 7f5532ec..7cb52a1f 100644 --- a/src/ltxv.hpp +++ b/src/ltxv.hpp @@ -1059,7 +1059,6 @@ namespace LTXV { auto v2a_gate = get_ada_values(ctx, v2a_gate_table, a_cross_gate_timestep, a_dim, 1)[0]; ax = ggml_add(ctx->ggml_ctx, ax, apply_gate(ctx->ggml_ctx, v2a_out, v2a_gate)); } - auto a_ff_mods = get_ada_values(ctx, a_table, a_timestep, a_dim, cross_attention_adaln ? 
9 : 6, 3, 3); auto ax_scaled = rms_norm(ctx->ggml_ctx, ax); ax_scaled = Flux::modulate(ctx->ggml_ctx, ax_scaled, a_ff_mods[0], a_ff_mods[1], true); @@ -1183,7 +1182,9 @@ namespace LTXV { } ggml_tensor* patchify_audio(GGMLRunnerContext* ctx, ggml_tensor* ax) { - ax = ggml_reshape_3d(ctx->ggml_ctx, ax, ax->ne[0] * ax->ne[2], ax->ne[1], ax->ne[3]); + // ax: [b, c, t, f] + ax = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, ax, 0, 2, 1, 3)); // [b, t, c, f] + ax = ggml_reshape_3d(ctx->ggml_ctx, ax, ax->ne[0] * ax->ne[1], ax->ne[2], ax->ne[3]); // [b, t, c*f] return ax; } @@ -1191,7 +1192,9 @@ namespace LTXV { if (ax == nullptr) { return nullptr; } - return ggml_reshape_4d(ctx->ggml_ctx, ax, cfg.audio_frequency_bins, audio_length, cfg.num_audio_channels, ax->ne[2]); + ax = ggml_reshape_4d(ctx->ggml_ctx, ax, cfg.audio_frequency_bins, cfg.num_audio_channels, audio_length, ax->ne[2]); // [b, t, c, f] + ax = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, ax, 0, 2, 1, 3)); // [b, c, t, f] + return ax; } std::pair preprocess_contexts(GGMLRunnerContext* ctx, diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index 5f9ebb54..13b5426c 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -14,6 +14,7 @@ #include "diffusion_model.hpp" #include "esrgan.hpp" #include "lora.hpp" +#include "ltx_audio_vae.h" #include "ltx_vae.hpp" #include "pmid.hpp" #include "sample-cache.h" @@ -114,6 +115,7 @@ public: ggml_backend_t clip_backend = nullptr; ggml_backend_t control_net_backend = nullptr; ggml_backend_t vae_backend = nullptr; + ggml_backend_t audio_backend = nullptr; SDVersion version; bool vae_decode_only = false; @@ -134,6 +136,7 @@ public: std::shared_ptr high_noise_diffusion_model; std::shared_ptr first_stage_model; std::shared_ptr preview_vae; + std::shared_ptr audio_vae_model; std::shared_ptr control_net; std::shared_ptr pmid_model; std::shared_ptr pmid_lora; @@ -171,6 +174,9 @@ public: if (vae_backend != backend) { 
ggml_backend_free(vae_backend); } + if (audio_backend != nullptr && audio_backend != backend) { + ggml_backend_free(audio_backend); + } ggml_backend_free(backend); } @@ -195,7 +201,8 @@ public: offload_params_to_cpu = sd_ctx_params->offload_params_to_cpu; max_vram = sd_ctx_params->max_vram; - bool use_tae = false; + bool use_tae = false; + bool use_audio_vae = false; rng = get_rng(sd_ctx_params->rng_type); if (sd_ctx_params->sampler_rng_type != RNG_TYPE_COUNT && sd_ctx_params->sampler_rng_type != sd_ctx_params->rng_type) { @@ -294,6 +301,22 @@ public: use_tae = true; } + if (strlen(SAFE_STR(sd_ctx_params->embeddings_connectors_path)) > 0) { + LOG_INFO("loading embeddings connectors from '%s'", sd_ctx_params->embeddings_connectors_path); + if (!model_loader.init_from_file(sd_ctx_params->embeddings_connectors_path)) { + LOG_WARN("loading embeddings connectors from '%s' failed", sd_ctx_params->embeddings_connectors_path); + } + } + + if (strlen(SAFE_STR(sd_ctx_params->audio_vae_path)) > 0) { + LOG_INFO("loading LTX audio VAE from '%s'", sd_ctx_params->audio_vae_path); + if (!model_loader.init_from_file(sd_ctx_params->audio_vae_path)) { + LOG_WARN("loading LTX audio VAE weights from '%s' failed", sd_ctx_params->audio_vae_path); + } else { + use_audio_vae = true; + } + } + model_loader.convert_tensors_name(); version = model_loader.get_sd_version(); @@ -302,17 +325,6 @@ public: return false; } - if (strlen(SAFE_STR(sd_ctx_params->embeddings_connectors_path)) > 0) { - if (sd_version_is_ltxav(version)) { - LOG_INFO("loading embeddings connectors from '%s'", sd_ctx_params->embeddings_connectors_path); - if (!model_loader.init_from_file(sd_ctx_params->embeddings_connectors_path)) { - LOG_WARN("loading embeddings connectors from '%s' failed", sd_ctx_params->embeddings_connectors_path); - } - } else { - LOG_WARN("ignoring embeddings connectors for non-LTXAV model: '%s'", sd_ctx_params->embeddings_connectors_path); - } - } - auto& tensor_storage_map = 
model_loader.get_tensor_storage_map(); LOG_INFO("Version: %s ", model_version_to_str[version]); @@ -682,6 +694,20 @@ public: } } + if (use_audio_vae) { + if (sd_ctx_params->keep_vae_on_cpu && !ggml_backend_is_cpu(backend)) { + LOG_INFO("LTX audio VAE: Using CPU backend"); + audio_backend = ggml_backend_cpu_init(); + } else { + audio_backend = backend; + } + audio_vae_model = std::make_shared(audio_backend, + false, + tensor_storage_map); + audio_vae_model->alloc_params_buffer(); + audio_vae_model->get_param_tensors(tensors, ""); + } + if (sd_ctx_params->vae_conv_direct) { LOG_INFO("Using Conv2d direct in the vae model"); first_stage_model->set_conv2d_direct_enabled(true); @@ -815,6 +841,9 @@ public: ignore_tensors.insert("tae.encoder"); ignore_tensors.insert("text_encoders.llm.visual."); } + if (audio_vae_model) { + ignore_tensors.insert("audio_vae.encoder"); + } if (version == VERSION_OVIS_IMAGE) { ignore_tensors.insert("text_encoders.llm.vision_model."); ignore_tensors.insert("text_encoders.llm.visual_tokenizer."); @@ -847,6 +876,9 @@ public: if (preview_vae) { vae_params_mem_size += preview_vae->get_params_buffer_size(); } + if (audio_vae_model) { + vae_params_mem_size += audio_vae_model->get_params_buffer_size(); + } size_t control_net_params_mem_size = 0; if (control_net) { if (!control_net->load_from_file(SAFE_STR(sd_ctx_params->control_net_path), n_threads)) { @@ -1938,6 +1970,17 @@ public: return first_stage_model->decode(n_threads, latents, vae_tiling_params, decode_video, circular_x, circular_y); } + sd::Tensor decode_ltx_audio_latent(const sd::Tensor& audio_latent) { + if (audio_vae_model == nullptr || audio_latent.empty()) { + return {}; + } + auto waveform = audio_vae_model->decode(n_threads, audio_latent); + if (free_params_immediately) { + audio_vae_model->free_params_buffer(); + } + return waveform; + } + void set_flow_shift(float flow_shift = INFINITY) { auto flow_denoiser = std::dynamic_pointer_cast(denoiser); if (flow_denoiser) { @@ -2243,6 
+2286,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { "high_noise_diffusion_model_path: %s\n" "embeddings_connectors_path: %s\n" "vae_path: %s\n" + "audio_vae_path: %s\n" "taesd_path: %s\n" "control_net_path: %s\n" "photo_maker_path: %s\n" @@ -2277,6 +2321,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path), SAFE_STR(sd_ctx_params->embeddings_connectors_path), SAFE_STR(sd_ctx_params->vae_path), + SAFE_STR(sd_ctx_params->audio_vae_path), SAFE_STR(sd_ctx_params->taesd_path), SAFE_STR(sd_ctx_params->control_net_path), SAFE_STR(sd_ctx_params->photo_maker_path), @@ -2504,6 +2549,45 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) { free(sd_ctx); } +static sd_audio_t* waveform_to_sd_audio(const StableDiffusionGGML* sd, + const sd::Tensor& waveform) { + if (sd == nullptr || waveform.empty()) { + return nullptr; + } + + int64_t sample_count = waveform.shape()[0]; + int64_t channels = waveform.shape().size() > 1 ? waveform.shape()[1] : 1; + if (sample_count <= 0 || channels <= 0) { + return nullptr; + } + + sd_audio_t* audio = (sd_audio_t*)malloc(sizeof(sd_audio_t)); + if (audio == nullptr) { + return nullptr; + } + + audio->sample_rate = static_cast(sd->audio_vae_model != nullptr ? 
sd->audio_vae_model->config.output_sample_rate() : 0); + audio->channels = static_cast(channels); + audio->sample_count = static_cast(sample_count); + size_t sample_bytes = waveform.numel() * sizeof(float); + audio->data = (float*)malloc(sample_bytes); + if (audio->data == nullptr) { + free(audio); + return nullptr; + } + std::memcpy(audio->data, waveform.data(), sample_bytes); + return audio; +} + +void free_sd_audio(sd_audio_t* audio) { + if (audio == nullptr) { + return; + } + free(audio->data); + audio->data = nullptr; + free(audio); +} + SD_API bool sd_ctx_supports_image_generation(const sd_ctx_t* sd_ctx) { if (sd_ctx == nullptr || sd_ctx->sd == nullptr) { return false; @@ -2939,6 +3023,37 @@ static sd::Tensor pack_ltxav_audio_and_video_latents(const sd::Tensor unpack_ltxav_audio_latent(const sd::Tensor& packed_latent, + int audio_length, + int video_channels) { + if (packed_latent.empty() || audio_length <= 0) { + return {}; + } + + GGML_ASSERT(packed_latent.dim() == 4 || packed_latent.dim() == 5); + int64_t width = packed_latent.shape()[0]; + int64_t height = packed_latent.shape()[1]; + int64_t frames = packed_latent.shape()[2]; + int64_t total_channels = packed_latent.shape()[3]; + int64_t spatial_size = width * height * frames; + if (total_channels <= video_channels) { + return {}; + } + + constexpr int kLtxavAudioFrequencyBins = 16; + constexpr int kLtxavAudioChannels = 8; + int64_t required_values = static_cast(audio_length) * kLtxavAudioFrequencyBins * kLtxavAudioChannels; + int64_t packed_values = (total_channels - video_channels) * spatial_size; + if (packed_values < required_values) { + return {}; + } + + sd::Tensor audio_latent({kLtxavAudioFrequencyBins, audio_length, kLtxavAudioChannels, 1}); + const float* audio_src = packed_latent.data() + static_cast(video_channels) * static_cast(spatial_size); + std::copy_n(audio_src, static_cast(required_values), audio_latent.data()); + return audio_latent; +} + static int get_ltxav_num_audio_latents(int 
frames, int fps) { GGML_ASSERT(frames > 0); GGML_ASSERT(fps > 0); @@ -3722,8 +3837,10 @@ static std::optional prepare_video_generation_latents(sd } if (sd_version_is_ltxav(sd_ctx->sd->version)) { - latents.audio_length = 0; - latents.audio_latent = {}; + constexpr int kLtxavAudioFrequencyBins = 16; + constexpr int kLtxavAudioChannels = 8; + latents.audio_length = get_ltxav_num_audio_latents(request->frames, request->fps); + latents.audio_latent = sd::zeros({kLtxavAudioFrequencyBins, latents.audio_length, kLtxavAudioChannels, 1}); } if (sd_version_is_ltxav(sd_ctx->sd->version)) { @@ -3903,8 +4020,9 @@ static std::optional prepare_video_generation_latents(sd latents.init_latent = sd_ctx->sd->generate_init_latent(request->width, request->height, request->frames, true); } - // Pipeline-level audio support is temporarily disabled. Keep the model-side - // AV implementation intact, but feed pure video latents through vid_gen. + if (sd_version_is_ltxav(sd_ctx->sd->version) && !latents.audio_latent.empty()) { + latents.init_latent = pack_ltxav_audio_and_video_latents(latents.init_latent, latents.audio_latent); + } return latents; } @@ -3996,9 +4114,19 @@ static sd_image_t* decode_video_outputs(sd_ctx_t* sd_ctx, return result_images; } -SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params, int* num_frames_out) { +SD_API bool generate_video(sd_ctx_t* sd_ctx, + const sd_vid_gen_params_t* sd_vid_gen_params, + sd_image_t** frames_out, + int* num_frames_out, + sd_audio_t** audio_out) { if (sd_ctx == nullptr || sd_vid_gen_params == nullptr) { - return nullptr; + return false; + } + if (frames_out != nullptr) { + *frames_out = nullptr; + } + if (audio_out != nullptr) { + *audio_out = nullptr; } if (num_frames_out != nullptr) { *num_frames_out = 0; @@ -4014,7 +4142,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s SamplePlan plan(sd_ctx, sd_vid_gen_params, request); auto latent_inputs_opt = 
prepare_video_generation_latents(sd_ctx, sd_vid_gen_params, &request); if (!latent_inputs_opt.has_value()) { - return nullptr; + return false; } ImageGenerationLatents latents = std::move(*latent_inputs_opt); ImageGenerationEmbeds embeds = prepare_video_generation_embeds(sd_ctx, @@ -4115,10 +4243,27 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s } if (final_latent.empty()) { LOG_ERROR("sampling failed after %.2fs", (sampling_end - sampling_start) * 1.0f / 1000); - return nullptr; + return false; } LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000); + sd_audio_t* generated_audio = nullptr; + if (sd_version_is_ltxav(sd_ctx->sd->version) && + latents.audio_length > 0 && + sd_ctx->sd->audio_vae_model != nullptr) { + auto audio_latent = unpack_ltxav_audio_latent(final_latent, + latents.audio_length, + sd_ctx->sd->get_latent_channel()); + if (!audio_latent.empty()) { + auto waveform = sd_ctx->sd->decode_ltx_audio_latent(audio_latent); + if (!waveform.empty()) { + generated_audio = waveform_to_sd_audio(sd_ctx->sd, waveform); + } else { + LOG_WARN("LTX audio latent decode failed; continuing with silent video output"); + } + } + } + if (latents.ref_image_num > 0) { final_latent = sd::ops::slice(final_latent, 2, latents.ref_image_num, final_latent.shape()[2]); } @@ -4128,12 +4273,21 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s auto result = decode_video_outputs(sd_ctx, request, final_latent, num_frames_out); if (result == nullptr) { - return nullptr; + free_sd_audio(generated_audio); + return false; } sd_ctx->sd->lora_stat(); int64_t t1 = ggml_time_ms(); LOG_INFO("generate_video completed in %.2fs", (t1 - t0) * 1.0f / 1000); - return result; + if (frames_out != nullptr) { + *frames_out = result; + } + if (audio_out != nullptr) { + *audio_out = generated_audio; + } else { + free_sd_audio(generated_audio); + } + return true; }