add ltx audio support

This commit is contained in:
leejet 2026-05-10 15:00:33 +08:00
parent bb63d5c2c5
commit 8b03d9bd0e
12 changed files with 1627 additions and 75 deletions

View File

@ -385,11 +385,32 @@ std::string format_frame_idx(std::string pattern, int frame_idx) {
return result; return result;
} }
static fs::path get_video_audio_sidecar_path(const SDCliParams& cli_params) {
fs::path out_path = cli_params.output_path;
fs::path base_path = out_path;
fs::path ext = out_path.has_extension() ? out_path.extension() : fs::path{};
std::string ext_lower = ext.string();
std::transform(ext_lower.begin(), ext_lower.end(), ext_lower.begin(), ::tolower);
const EncodedImageFormat output_format = encoded_image_format_from_path(out_path.string());
if (!ext.empty()) {
if (output_format == EncodedImageFormat::JPEG ||
output_format == EncodedImageFormat::PNG ||
output_format == EncodedImageFormat::WEBP ||
ext_lower == ".avi" ||
ext_lower == ".webm") {
base_path.replace_extension();
}
}
base_path += ".wav";
return base_path;
}
bool save_results(const SDCliParams& cli_params, bool save_results(const SDCliParams& cli_params,
const SDContextParams& ctx_params, const SDContextParams& ctx_params,
const SDGenerationParams& gen_params, const SDGenerationParams& gen_params,
sd_image_t* results, sd_image_t* results,
int num_results) { int num_results,
const sd_audio_t* generated_audio = nullptr) {
if (results == nullptr || num_results <= 0) { if (results == nullptr || num_results <= 0) {
return false; return false;
} }
@ -442,6 +463,21 @@ bool save_results(const SDCliParams& cli_params,
return ok; return ok;
}; };
auto write_audio_sidecar = [&](const fs::path& wav_path) {
if (generated_audio == nullptr) {
return;
}
if (write_wav_to_file(wav_path.string(),
generated_audio->data,
generated_audio->sample_count,
generated_audio->channels,
generated_audio->sample_rate)) {
LOG_INFO("save result audio to '%s'", wav_path.string().c_str());
} else {
LOG_WARN("failed to save result audio to '%s'", wav_path.string().c_str());
}
};
int sucessful_reults = 0; int sucessful_reults = 0;
if (std::regex_search(cli_params.output_path, format_specifier_regex)) { if (std::regex_search(cli_params.output_path, format_specifier_regex)) {
@ -465,8 +501,16 @@ bool save_results(const SDCliParams& cli_params,
ext = ".avi"; ext = ".avi";
fs::path video_path = base_path; fs::path video_path = base_path;
video_path += ext; video_path += ext;
if (create_video_from_sd_images(video_path.string().c_str(), results, num_results, gen_params.fps) == 0) { std::string final_ext_lower = ext.string();
std::transform(final_ext_lower.begin(), final_ext_lower.end(), final_ext_lower.begin(), ::tolower);
const bool mux_audio = generated_audio != nullptr && (final_ext_lower == ".avi" || final_ext_lower == ".webm");
if (create_video_from_sd_images(video_path.string().c_str(), results, num_results, gen_params.fps, 90, mux_audio ? generated_audio : nullptr) == 0) {
LOG_INFO("save result video to '%s'", video_path.string().c_str()); LOG_INFO("save result video to '%s'", video_path.string().c_str());
if (generated_audio != nullptr && !mux_audio) {
fs::path wav_path = video_path;
wav_path.replace_extension(".wav");
write_audio_sidecar(wav_path);
}
return true; return true;
} else { } else {
LOG_ERROR("Failed to save result video to '%s'", video_path.string().c_str()); LOG_ERROR("Failed to save result video to '%s'", video_path.string().c_str());
@ -488,6 +532,9 @@ bool save_results(const SDCliParams& cli_params,
} }
} }
LOG_INFO("%d/%d images saved", sucessful_reults, num_results); LOG_INFO("%d/%d images saved", sucessful_reults, num_results);
if (generated_audio != nullptr) {
write_audio_sidecar(get_video_audio_sidecar_path(cli_params));
}
return sucessful_reults != 0; return sucessful_reults != 0;
} }
@ -701,7 +748,8 @@ int main(int argc, const char* argv[]) {
sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(vae_decode_only, true, cli_params.taesd_preview); sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(vae_decode_only, true, cli_params.taesd_preview);
SDImageVec results; SDImageVec results;
int num_results = 0; int num_results = 0;
sd_audio_t* generated_audio = nullptr;
if (cli_params.mode == UPSCALE) { if (cli_params.mode == UPSCALE) {
num_results = 1; num_results = 1;
@ -733,7 +781,10 @@ int main(int argc, const char* argv[]) {
results.adopt(generate_image(sd_ctx.get(), &img_gen_params), num_results); results.adopt(generate_image(sd_ctx.get(), &img_gen_params), num_results);
} else if (cli_params.mode == VID_GEN) { } else if (cli_params.mode == VID_GEN) {
sd_vid_gen_params_t vid_gen_params = gen_params.to_sd_vid_gen_params_t(); sd_vid_gen_params_t vid_gen_params = gen_params.to_sd_vid_gen_params_t();
sd_image_t* generated_video = generate_video(sd_ctx.get(), &vid_gen_params, &num_results); sd_image_t* generated_video = nullptr;
if (!generate_video(sd_ctx.get(), &vid_gen_params, &generated_video, &num_results, &generated_audio)) {
generated_video = nullptr;
}
results.adopt(generated_video, num_results); results.adopt(generated_video, num_results);
} }
@ -773,9 +824,12 @@ int main(int argc, const char* argv[]) {
} }
} }
if (!save_results(cli_params, ctx_params, gen_params, results.data(), num_results)) { if (!save_results(cli_params, ctx_params, gen_params, results.data(), num_results, generated_audio)) {
free_sd_audio(generated_audio);
return 1; return 1;
} }
free_sd_audio(generated_audio);
return 0; return 0;
} }

View File

@ -348,6 +348,10 @@ ArgOptions SDContextParams::get_options() {
"--vae", "--vae",
"path to standalone vae model", "path to standalone vae model",
&vae_path}, &vae_path},
{"",
"--audio-vae",
"path to standalone LTX audio vae model",
&audio_vae_path},
{"", {"",
"--taesd", "--taesd",
"path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)", "path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)",
@ -667,6 +671,7 @@ std::string SDContextParams::to_string() const {
<< " high_noise_diffusion_model_path: \"" << high_noise_diffusion_model_path << "\",\n" << " high_noise_diffusion_model_path: \"" << high_noise_diffusion_model_path << "\",\n"
<< " embeddings_connectors_path: \"" << embeddings_connectors_path << "\",\n" << " embeddings_connectors_path: \"" << embeddings_connectors_path << "\",\n"
<< " vae_path: \"" << vae_path << "\",\n" << " vae_path: \"" << vae_path << "\",\n"
<< " audio_vae_path: \"" << audio_vae_path << "\",\n"
<< " taesd_path: \"" << taesd_path << "\",\n" << " taesd_path: \"" << taesd_path << "\",\n"
<< " esrgan_path: \"" << esrgan_path << "\",\n" << " esrgan_path: \"" << esrgan_path << "\",\n"
<< " control_net_path: \"" << control_net_path << "\",\n" << " control_net_path: \"" << control_net_path << "\",\n"
@ -725,6 +730,7 @@ sd_ctx_params_t SDContextParams::to_sd_ctx_params_t(bool vae_decode_only, bool f
high_noise_diffusion_model_path.c_str(), high_noise_diffusion_model_path.c_str(),
embeddings_connectors_path.c_str(), embeddings_connectors_path.c_str(),
vae_path.c_str(), vae_path.c_str(),
audio_vae_path.c_str(),
taesd_path.c_str(), taesd_path.c_str(),
control_net_path.c_str(), control_net_path.c_str(),
embedding_vec.data(), embedding_vec.data(),

View File

@ -94,6 +94,7 @@ struct SDContextParams {
std::string high_noise_diffusion_model_path; std::string high_noise_diffusion_model_path;
std::string embeddings_connectors_path; std::string embeddings_connectors_path;
std::string vae_path; std::string vae_path;
std::string audio_vae_path;
std::string taesd_path; std::string taesd_path;
std::string esrgan_path; std::string esrgan_path;
std::string control_net_path; std::string control_net_path;

View File

@ -613,6 +613,13 @@ typedef struct {
uint32_t size; uint32_t size;
} avi_index_entry; } avi_index_entry;
typedef struct {
char fourcc[4];
uint32_t flags;
uint32_t offset;
uint32_t size;
} avi_chunk_index_entry;
void write_u32_le(FILE* f, uint32_t val) { void write_u32_le(FILE* f, uint32_t val) {
fwrite(&val, 4, 1, f); fwrite(&val, 4, 1, f);
} }
@ -647,6 +654,33 @@ void write_fourcc(std::vector<uint8_t>& data, const char* fourcc) {
data.insert(data.end(), fourcc, fourcc + 4); data.insert(data.end(), fourcc, fourcc + 4);
} }
static std::vector<uint8_t> audio_to_pcm16_bytes(const sd_audio_t* audio) {
if (audio == nullptr || audio->data == nullptr || audio->sample_count == 0 || audio->channels == 0 || audio->sample_rate == 0) {
return {};
}
const size_t pcm_samples = static_cast<size_t>(audio->sample_count) * static_cast<size_t>(audio->channels);
std::vector<uint8_t> bytes(pcm_samples * sizeof(int16_t));
auto* pcm = reinterpret_cast<int16_t*>(bytes.data());
for (size_t i = 0; i < pcm_samples; ++i) {
const float sample = std::clamp(audio->data[i], -1.0f, 1.0f);
pcm[i] = static_cast<int16_t>(std::lrint(sample * 32767.0f));
}
return bytes;
}
static std::pair<uint64_t, uint64_t> audio_sample_range_for_video_frame(const sd_audio_t* audio, int frame_idx, int num_frames, int fps) {
if (audio == nullptr || fps <= 0 || num_frames <= 0) {
return {0, 0};
}
const uint64_t total = audio->sample_count;
const uint64_t start = static_cast<uint64_t>((static_cast<long double>(frame_idx) * total) / num_frames);
const uint64_t end = frame_idx + 1 == num_frames
? total
: static_cast<uint64_t>((static_cast<long double>(frame_idx + 1) * total) / num_frames);
return {start, std::max(start, end)};
}
EncodedImageFormat encoded_image_format_from_path(const std::string& path) { EncodedImageFormat encoded_image_format_from_path(const std::string& path) {
std::string ext = fs::path(path).extension().string(); std::string ext = fs::path(path).extension().string();
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower); std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
@ -776,7 +810,7 @@ uint8_t* load_image_from_memory(const char* image_bytes,
return load_image_common(true, image_bytes, len, width, height, expected_width, expected_height, expected_channel); return load_image_common(true, image_bytes, len, width, height, expected_width, expected_height, expected_channel);
} }
std::vector<uint8_t> create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images, int num_images, int fps, int quality) { std::vector<uint8_t> create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images, int num_images, int fps, int quality, const sd_audio_t* audio) {
if (num_images == 0) { if (num_images == 0) {
fprintf(stderr, "Error: Image array is empty.\n"); fprintf(stderr, "Error: Image array is empty.\n");
return {}; return {};
@ -793,7 +827,13 @@ std::vector<uint8_t> create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images
// stb_image_write changes JPEG sampling behavior above quality 90. // stb_image_write changes JPEG sampling behavior above quality 90.
// MJPG AVI playback is more compatible when we keep the encoder on the // MJPG AVI playback is more compatible when we keep the encoder on the
// <= 90 path. // <= 90 path.
const int mjpg_quality = std::clamp(quality, 1, 90); const int mjpg_quality = std::clamp(quality, 1, 90);
const bool has_audio = audio != nullptr && audio->data != nullptr && audio->sample_count > 0 && audio->channels > 0 && audio->sample_rate > 0;
const std::vector<uint8_t> audio_pcm = audio_to_pcm16_bytes(audio);
const uint16_t audio_bits_per_sample = 16;
const uint16_t audio_block_align = has_audio ? static_cast<uint16_t>(audio->channels * (audio_bits_per_sample / 8)) : 0;
const uint32_t audio_byte_rate = has_audio ? static_cast<uint32_t>(audio->sample_rate * audio_block_align) : 0;
const uint32_t audio_data_size = has_audio ? static_cast<uint32_t>(audio_pcm.size()) : 0;
std::vector<uint8_t> avi_data; std::vector<uint8_t> avi_data;
avi_data.reserve(static_cast<size_t>(num_images) * 1024); avi_data.reserve(static_cast<size_t>(num_images) * 1024);
@ -804,7 +844,11 @@ std::vector<uint8_t> create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images
write_fourcc(avi_data, "AVI "); write_fourcc(avi_data, "AVI ");
write_fourcc(avi_data, "LIST"); write_fourcc(avi_data, "LIST");
write_u32_le(avi_data, 4 + 8 + 56 + 8 + 4 + 8 + 56 + 8 + 40); uint32_t hdrl_size = 4 + 8 + 56 + 8 + 4 + 8 + 56 + 8 + 40;
if (has_audio) {
hdrl_size += 8 + (4 + 8 + 56 + 8 + 16);
}
write_u32_le(avi_data, hdrl_size);
write_fourcc(avi_data, "hdrl"); write_fourcc(avi_data, "hdrl");
write_fourcc(avi_data, "avih"); write_fourcc(avi_data, "avih");
@ -815,7 +859,7 @@ std::vector<uint8_t> create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images
write_u32_le(avi_data, 0x110); write_u32_le(avi_data, 0x110);
write_u32_le(avi_data, num_images); write_u32_le(avi_data, num_images);
write_u32_le(avi_data, 0); write_u32_le(avi_data, 0);
write_u32_le(avi_data, 1); write_u32_le(avi_data, has_audio ? 2 : 1);
write_u32_le(avi_data, width * height * 3); write_u32_le(avi_data, width * height * 3);
write_u32_le(avi_data, width); write_u32_le(avi_data, width);
write_u32_le(avi_data, height); write_u32_le(avi_data, height);
@ -862,12 +906,48 @@ std::vector<uint8_t> create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images
write_u32_le(avi_data, 0); write_u32_le(avi_data, 0);
write_u32_le(avi_data, 0); write_u32_le(avi_data, 0);
if (has_audio) {
write_fourcc(avi_data, "LIST");
write_u32_le(avi_data, 4 + 8 + 56 + 8 + 16);
write_fourcc(avi_data, "strl");
write_fourcc(avi_data, "strh");
write_u32_le(avi_data, 56);
write_fourcc(avi_data, "auds");
write_u32_le(avi_data, 0);
write_u32_le(avi_data, 0);
write_u16_le(avi_data, 0);
write_u16_le(avi_data, 0);
write_u32_le(avi_data, 0);
write_u32_le(avi_data, audio_block_align);
write_u32_le(avi_data, audio_byte_rate);
write_u32_le(avi_data, 0);
write_u32_le(avi_data, static_cast<uint32_t>(audio->sample_count));
write_u32_le(avi_data, audio_data_size);
write_u32_le(avi_data, static_cast<uint32_t>(-1));
write_u32_le(avi_data, audio_block_align);
write_u16_le(avi_data, 0);
write_u16_le(avi_data, 0);
write_u16_le(avi_data, 0);
write_u16_le(avi_data, 0);
write_fourcc(avi_data, "strf");
write_u32_le(avi_data, 16);
write_u16_le(avi_data, 1);
write_u16_le(avi_data, static_cast<uint16_t>(audio->channels));
write_u32_le(avi_data, audio->sample_rate);
write_u32_le(avi_data, audio_byte_rate);
write_u16_le(avi_data, audio_block_align);
write_u16_le(avi_data, audio_bits_per_sample);
}
write_fourcc(avi_data, "LIST"); write_fourcc(avi_data, "LIST");
const size_t movi_size_pos = avi_data.size(); const size_t movi_size_pos = avi_data.size();
write_u32_le(avi_data, 0); write_u32_le(avi_data, 0);
write_fourcc(avi_data, "movi"); write_fourcc(avi_data, "movi");
std::vector<avi_index_entry> index(static_cast<size_t>(num_images)); std::vector<avi_chunk_index_entry> index;
index.reserve(static_cast<size_t>(num_images) + (has_audio ? 1 : 0));
std::vector<uint8_t> jpeg_data; std::vector<uint8_t> jpeg_data;
for (int i = 0; i < num_images; i++) { for (int i = 0; i < num_images; i++) {
@ -884,27 +964,46 @@ std::vector<uint8_t> create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images
return {}; return {};
} }
index[i].offset = static_cast<uint32_t>(avi_data.size()); avi_chunk_index_entry video_entry = {};
memcpy(video_entry.fourcc, "00dc", 4);
video_entry.flags = 0x10;
video_entry.offset = static_cast<uint32_t>(avi_data.size());
write_fourcc(avi_data, "00dc"); write_fourcc(avi_data, "00dc");
write_u32_le(avi_data, static_cast<uint32_t>(jpeg_data.size())); write_u32_le(avi_data, static_cast<uint32_t>(jpeg_data.size()));
index[i].size = (uint32_t)jpeg_data.size(); video_entry.size = static_cast<uint32_t>(jpeg_data.size());
avi_data.insert(avi_data.end(), jpeg_data.begin(), jpeg_data.end()); avi_data.insert(avi_data.end(), jpeg_data.begin(), jpeg_data.end());
index.push_back(video_entry);
if (jpeg_data.size() % 2) { if (jpeg_data.size() % 2) {
avi_data.push_back(0); avi_data.push_back(0);
} }
} }
if (has_audio && !audio_pcm.empty()) {
avi_chunk_index_entry audio_entry = {};
memcpy(audio_entry.fourcc, "01wb", 4);
audio_entry.flags = 0;
audio_entry.offset = static_cast<uint32_t>(avi_data.size());
audio_entry.size = static_cast<uint32_t>(audio_pcm.size());
write_fourcc(avi_data, "01wb");
write_u32_le(avi_data, static_cast<uint32_t>(audio_pcm.size()));
avi_data.insert(avi_data.end(), audio_pcm.begin(), audio_pcm.end());
index.push_back(audio_entry);
if (audio_pcm.size() % 2 != 0) {
avi_data.push_back(0);
}
}
const size_t movi_size = avi_data.size() - movi_size_pos - 4; const size_t movi_size = avi_data.size() - movi_size_pos - 4;
patch_u32_le(avi_data, movi_size_pos, static_cast<uint32_t>(movi_size)); patch_u32_le(avi_data, movi_size_pos, static_cast<uint32_t>(movi_size));
write_fourcc(avi_data, "idx1"); write_fourcc(avi_data, "idx1");
write_u32_le(avi_data, num_images * 16); write_u32_le(avi_data, static_cast<uint32_t>(index.size() * 16));
for (int i = 0; i < num_images; i++) { for (const auto& entry : index) {
write_fourcc(avi_data, "00dc"); write_fourcc(avi_data, entry.fourcc);
write_u32_le(avi_data, 0x10); write_u32_le(avi_data, entry.flags);
write_u32_le(avi_data, index[i].offset); write_u32_le(avi_data, entry.offset);
write_u32_le(avi_data, index[i].size); write_u32_le(avi_data, entry.size);
} }
const size_t file_size = avi_data.size() - riff_size_pos - 4; const size_t file_size = avi_data.size() - riff_size_pos - 4;
@ -913,8 +1012,8 @@ std::vector<uint8_t> create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images
return avi_data; return avi_data;
} }
int create_mjpg_avi_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality) { int create_mjpg_avi_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality, const sd_audio_t* audio) {
std::vector<uint8_t> avi_data = create_mjpg_avi_from_sd_images_to_vector(images, num_images, fps, quality); std::vector<uint8_t> avi_data = create_mjpg_avi_from_sd_images_to_vector(images, num_images, fps, quality, audio);
if (avi_data.empty()) { if (avi_data.empty()) {
return -1; return -1;
} }
@ -1044,7 +1143,7 @@ int create_animated_webp_from_sd_images(const char* filename, sd_image_t* images
#endif #endif
#ifdef SD_USE_WEBM #ifdef SD_USE_WEBM
std::vector<uint8_t> create_webm_from_sd_images_to_vector(sd_image_t* images, int num_images, int fps, int quality) { std::vector<uint8_t> create_webm_from_sd_images_to_vector(sd_image_t* images, int num_images, int fps, int quality, const sd_audio_t* audio) {
if (num_images == 0) { if (num_images == 0) {
fprintf(stderr, "Error: Image array is empty.\n"); fprintf(stderr, "Error: Image array is empty.\n");
return {}; return {};
@ -1089,6 +1188,25 @@ std::vector<uint8_t> create_webm_from_sd_images_to_vector(sd_image_t* images, in
video_track->set_display_height(static_cast<uint64_t>(height)); video_track->set_display_height(static_cast<uint64_t>(height));
video_track->set_frame_rate(static_cast<double>(fps)); video_track->set_frame_rate(static_cast<double>(fps));
} }
uint64_t audio_track_number = 0;
std::vector<uint8_t> audio_pcm = audio_to_pcm16_bytes(audio);
if (audio != nullptr && !audio_pcm.empty()) {
audio_track_number = segment.AddAudioTrack(static_cast<int32_t>(audio->sample_rate), static_cast<int32_t>(audio->channels), 0);
if (audio_track_number == 0) {
fprintf(stderr, "Error: Failed to add audio track.\n");
return -1;
}
auto* audio_track = static_cast<mkvmuxer::AudioTrack*>(segment.GetTrackByNumber(audio_track_number));
if (audio_track == nullptr) {
fprintf(stderr, "Error: Failed to get audio track.\n");
return -1;
}
audio_track->set_codec_id("A_PCM/INT/LIT");
audio_track->set_bit_depth(16);
audio_track->set_sample_rate(static_cast<double>(audio->sample_rate));
audio_track->set_channels(audio->channels);
}
segment.GetSegmentInfo()->set_writing_app("stable-diffusion.cpp"); segment.GetSegmentInfo()->set_writing_app("stable-diffusion.cpp");
segment.GetSegmentInfo()->set_muxing_app("stable-diffusion.cpp"); segment.GetSegmentInfo()->set_muxing_app("stable-diffusion.cpp");
@ -1118,6 +1236,23 @@ std::vector<uint8_t> create_webm_from_sd_images_to_vector(sd_image_t* images, in
return -1; return -1;
} }
if (audio_track_number != 0) {
auto [audio_begin, audio_end] = audio_sample_range_for_video_frame(audio, i, num_images, fps);
const uint64_t frame_samples = audio_end - audio_begin;
if (frame_samples > 0) {
const uint64_t frame_bytes = frame_samples * audio->channels * sizeof(int16_t);
const uint8_t* frame_ptr = audio_pcm.data() + audio_begin * audio->channels * sizeof(int16_t);
if (!segment.AddFrame(frame_ptr,
frame_bytes,
audio_track_number,
timestamp_ns,
true)) {
fprintf(stderr, "Error: Failed to mux audio chunk %d into WebM.\n", i);
return -1;
}
}
}
timestamp_ns += frame_duration_ns; timestamp_ns += frame_duration_ns;
} }
@ -1133,8 +1268,8 @@ std::vector<uint8_t> create_webm_from_sd_images_to_vector(sd_image_t* images, in
return writer.data(); return writer.data();
} }
int create_webm_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality) { int create_webm_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality, const sd_audio_t* audio) {
std::vector<uint8_t> webm_data = create_webm_from_sd_images_to_vector(images, num_images, fps, quality); std::vector<uint8_t> webm_data = create_webm_from_sd_images_to_vector(images, num_images, fps, quality, audio);
if (webm_data.empty()) { if (webm_data.empty()) {
return -1; return -1;
} }
@ -1150,7 +1285,8 @@ std::vector<uint8_t> create_video_from_sd_images_to_vector(const std::string& ou
sd_image_t* images, sd_image_t* images,
int num_images, int num_images,
int fps, int fps,
int quality) { int quality,
const sd_audio_t* audio) {
std::string format = output_format; std::string format = output_format;
std::transform(format.begin(), format.end(), format.begin(), std::transform(format.begin(), format.end(), format.begin(),
[](unsigned char c) { return static_cast<char>(tolower(c)); }); [](unsigned char c) { return static_cast<char>(tolower(c)); });
@ -1160,7 +1296,7 @@ std::vector<uint8_t> create_video_from_sd_images_to_vector(const std::string& ou
#ifdef SD_USE_WEBM #ifdef SD_USE_WEBM
if (format == "webm") { if (format == "webm") {
return create_webm_from_sd_images_to_vector(images, num_images, fps, quality); return create_webm_from_sd_images_to_vector(images, num_images, fps, quality, audio);
} }
#endif #endif
@ -1170,14 +1306,14 @@ std::vector<uint8_t> create_video_from_sd_images_to_vector(const std::string& ou
} }
#endif #endif
return create_mjpg_avi_from_sd_images_to_vector(images, num_images, fps, quality); return create_mjpg_avi_from_sd_images_to_vector(images, num_images, fps, quality, audio);
} }
int create_video_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality) { int create_video_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality, const sd_audio_t* audio) {
std::string path = filename ? filename : ""; std::string path = filename ? filename : "";
auto pos = path.find_last_of('.'); auto pos = path.find_last_of('.');
std::string ext = pos == std::string::npos ? "" : path.substr(pos); std::string ext = pos == std::string::npos ? "" : path.substr(pos);
std::vector<uint8_t> video_data = create_video_from_sd_images_to_vector(ext, images, num_images, fps, quality); std::vector<uint8_t> video_data = create_video_from_sd_images_to_vector(ext, images, num_images, fps, quality, audio);
if (video_data.empty()) { if (video_data.empty()) {
return -1; return -1;
} }
@ -1187,3 +1323,54 @@ int create_video_from_sd_images(const char* filename, sd_image_t* images, int nu
} }
return 0; return 0;
} }
bool write_wav_to_file(const std::string& path,
const float* interleaved_samples,
uint64_t sample_count,
uint32_t channels,
uint32_t sample_rate) {
if (interleaved_samples == nullptr || sample_count == 0 || channels == 0 || sample_rate == 0) {
return false;
}
std::ofstream file(path, std::ios::binary);
if (!file.is_open()) {
return false;
}
uint32_t bits_per_sample = 16;
uint32_t bytes_per_sample = bits_per_sample / 8;
uint32_t block_align = channels * bytes_per_sample;
uint32_t byte_rate = sample_rate * block_align;
uint32_t data_size = static_cast<uint32_t>(sample_count * channels * bytes_per_sample);
uint32_t riff_size = 36 + data_size;
file.write("RIFF", 4);
file.write(reinterpret_cast<const char*>(&riff_size), sizeof(riff_size));
file.write("WAVE", 4);
file.write("fmt ", 4);
uint32_t fmt_size = 16;
uint16_t audio_format = 1;
uint16_t wav_channels = static_cast<uint16_t>(channels);
uint16_t wav_block_align = static_cast<uint16_t>(block_align);
uint16_t wav_bits_per_sample = static_cast<uint16_t>(bits_per_sample);
file.write(reinterpret_cast<const char*>(&fmt_size), sizeof(fmt_size));
file.write(reinterpret_cast<const char*>(&audio_format), sizeof(audio_format));
file.write(reinterpret_cast<const char*>(&wav_channels), sizeof(wav_channels));
file.write(reinterpret_cast<const char*>(&sample_rate), sizeof(sample_rate));
file.write(reinterpret_cast<const char*>(&byte_rate), sizeof(byte_rate));
file.write(reinterpret_cast<const char*>(&wav_block_align), sizeof(wav_block_align));
file.write(reinterpret_cast<const char*>(&wav_bits_per_sample), sizeof(wav_bits_per_sample));
file.write("data", 4);
file.write(reinterpret_cast<const char*>(&data_size), sizeof(data_size));
std::vector<int16_t> pcm(sample_count * channels);
for (size_t i = 0; i < pcm.size(); ++i) {
float sample = std::max(-1.0f, std::min(1.0f, interleaved_samples[i]));
pcm[i] = static_cast<int16_t>(std::lrint(sample * 32767.0f));
}
file.write(reinterpret_cast<const char*>(pcm.data()), static_cast<std::streamsize>(pcm.size() * sizeof(int16_t)));
return file.good();
}

View File

@ -57,11 +57,13 @@ int create_mjpg_avi_from_sd_images(const char* filename,
sd_image_t* images, sd_image_t* images,
int num_images, int num_images,
int fps, int fps,
int quality = 90); int quality = 90,
const sd_audio_t* audio = nullptr);
std::vector<uint8_t> create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images, std::vector<uint8_t> create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images,
int num_images, int num_images,
int fps, int fps,
int quality = 90); int quality = 90,
const sd_audio_t* audio = nullptr);
#ifdef SD_USE_WEBP #ifdef SD_USE_WEBP
int create_animated_webp_from_sd_images(const char* filename, int create_animated_webp_from_sd_images(const char* filename,
@ -80,22 +82,32 @@ int create_webm_from_sd_images(const char* filename,
sd_image_t* images, sd_image_t* images,
int num_images, int num_images,
int fps, int fps,
int quality = 90); int quality = 90,
const sd_audio_t* audio = nullptr);
std::vector<uint8_t> create_webm_from_sd_images_to_vector(sd_image_t* images, std::vector<uint8_t> create_webm_from_sd_images_to_vector(sd_image_t* images,
int num_images, int num_images,
int fps, int fps,
int quality = 90); int quality = 90,
const sd_audio_t* audio = nullptr);
#endif #endif
int create_video_from_sd_images(const char* filename, int create_video_from_sd_images(const char* filename,
sd_image_t* images, sd_image_t* images,
int num_images, int num_images,
int fps, int fps,
int quality = 90); int quality = 90,
const sd_audio_t* audio = nullptr);
std::vector<uint8_t> create_video_from_sd_images_to_vector(const std::string& output_format, std::vector<uint8_t> create_video_from_sd_images_to_vector(const std::string& output_format,
sd_image_t* images, sd_image_t* images,
int num_images, int num_images,
int fps, int fps,
int quality = 90); int quality = 90,
const sd_audio_t* audio = nullptr);
bool write_wav_to_file(const std::string& path,
const float* interleaved_samples,
uint64_t sample_count,
uint32_t channels,
uint32_t sample_rate);
#endif // __MEDIA_IO_H__ #endif // __MEDIA_IO_H__

View File

@ -231,16 +231,21 @@ bool execute_vid_gen_job(ServerRuntime& runtime,
sd_vid_gen_params_t params = job.vid_gen.to_sd_vid_gen_params_t(); sd_vid_gen_params_t params = job.vid_gen.to_sd_vid_gen_params_t();
SDImageVec results; SDImageVec results;
int num_results = 0; int num_results = 0;
sd_audio_t* generated_audio = nullptr;
{ {
std::lock_guard<std::mutex> lock(*runtime.sd_ctx_mutex); std::lock_guard<std::mutex> lock(*runtime.sd_ctx_mutex);
sd_image_t* raw_results = generate_video(runtime.sd_ctx, &params, &num_results); sd_image_t* raw_results = nullptr;
if (!generate_video(runtime.sd_ctx, &params, &raw_results, &num_results, &generated_audio)) {
raw_results = nullptr;
}
results.adopt(raw_results, num_results); results.adopt(raw_results, num_results);
} }
num_results = results.count(); num_results = results.count();
if (num_results <= 0) { if (num_results <= 0) {
free_sd_audio(generated_audio);
error_message = "generate_video returned no results"; error_message = "generate_video returned no results";
return false; return false;
} }
@ -249,7 +254,9 @@ bool execute_vid_gen_job(ServerRuntime& runtime,
results.data(), results.data(),
num_results, num_results,
job.vid_gen.gen_params.fps, job.vid_gen.gen_params.fps,
job.vid_gen.output_compression); job.vid_gen.output_compression,
generated_audio);
free_sd_audio(generated_audio);
if (video_bytes.empty()) { if (video_bytes.empty()) {
error_message = "failed to encode generated video container"; error_message = "failed to encode generated video container";
return false; return false;

2
ggml

@ -1 +1 @@
Subproject commit 404fcb9d7c96989569e68c9e7881ee3465a05c50 Subproject commit 7f4ab364b2843921e795d6890d0f42dd5e5d6b63

View File

@ -174,6 +174,7 @@ typedef struct {
const char* high_noise_diffusion_model_path; const char* high_noise_diffusion_model_path;
const char* embeddings_connectors_path; const char* embeddings_connectors_path;
const char* vae_path; const char* vae_path;
const char* audio_vae_path;
const char* taesd_path; const char* taesd_path;
const char* control_net_path; const char* control_net_path;
const sd_embedding_t* embeddings; const sd_embedding_t* embeddings;
@ -208,6 +209,13 @@ typedef struct {
float max_vram; float max_vram;
} sd_ctx_params_t; } sd_ctx_params_t;
typedef struct {
uint32_t sample_rate;
uint32_t channels;
uint64_t sample_count;
float* data;
} sd_audio_t;
typedef struct { typedef struct {
uint32_t width; uint32_t width;
uint32_t height; uint32_t height;
@ -407,6 +415,7 @@ SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params);
SD_API sd_ctx_t* new_sd_ctx(const sd_ctx_params_t* sd_ctx_params); SD_API sd_ctx_t* new_sd_ctx(const sd_ctx_params_t* sd_ctx_params);
SD_API void free_sd_ctx(sd_ctx_t* sd_ctx); SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);
SD_API void free_sd_audio(sd_audio_t* audio);
SD_API void sd_sample_params_init(sd_sample_params_t* sample_params); SD_API void sd_sample_params_init(sd_sample_params_t* sample_params);
SD_API char* sd_sample_params_to_str(const sd_sample_params_t* sample_params); SD_API char* sd_sample_params_to_str(const sd_sample_params_t* sample_params);
@ -419,7 +428,11 @@ SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_para
SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params); SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params);
SD_API void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params); SD_API void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params);
SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params, int* num_frames_out); SD_API bool generate_video(sd_ctx_t* sd_ctx,
const sd_vid_gen_params_t* sd_vid_gen_params,
sd_image_t** frames_out,
int* num_frames_out,
sd_audio_t** audio_out);
typedef struct upscaler_ctx_t upscaler_ctx_t; typedef struct upscaler_ctx_t upscaler_ctx_t;

View File

@ -1698,14 +1698,14 @@ struct WeightAdapter {
}; };
struct GGMLRunnerContext { struct GGMLRunnerContext {
ggml_backend_t backend = nullptr; ggml_backend_t backend = nullptr;
ggml_context* ggml_ctx = nullptr; ggml_context* ggml_ctx = nullptr;
bool flash_attn_enabled = false; bool flash_attn_enabled = false;
bool conv2d_direct_enabled = false; bool conv2d_direct_enabled = false;
bool circular_x_enabled = false; bool circular_x_enabled = false;
bool circular_y_enabled = false; bool circular_y_enabled = false;
std::shared_ptr<WeightAdapter> weight_adapter = nullptr; std::shared_ptr<WeightAdapter> weight_adapter = nullptr;
std::unordered_map<ggml_tensor*, std::string>* debug_tensors = nullptr; std::vector<std::pair<ggml_tensor*, std::string>>* debug_tensors = nullptr;
std::function<ggml_tensor*(const std::string&)> get_cache_tensor; std::function<ggml_tensor*(const std::string&)> get_cache_tensor;
std::function<void(const std::string&, ggml_tensor*)> cache_tensor; std::function<void(const std::string&, ggml_tensor*)> cache_tensor;
@ -1713,8 +1713,14 @@ struct GGMLRunnerContext {
if (debug_tensors == nullptr || tensor == nullptr) { if (debug_tensors == nullptr || tensor == nullptr) {
return; return;
} }
ggml_set_output(tensor); ggml_tensor* snapshot = tensor;
(*debug_tensors)[tensor] = name; if (!ggml_is_contiguous(snapshot) || snapshot->view_src != nullptr) {
snapshot = ggml_cont(ggml_ctx, snapshot);
}
ggml_tensor* dst = ggml_dup_tensor(ggml_ctx, snapshot);
snapshot = ggml_cpy(ggml_ctx, snapshot, dst);
ggml_set_output(snapshot);
debug_tensors->push_back({snapshot, name});
} }
ggml_tensor* load_cache_tensor(const std::string& name) const { ggml_tensor* load_cache_tensor(const std::string& name) const {
@ -1768,7 +1774,7 @@ protected:
std::map<ggml_tensor*, const void*> backend_tensor_data_map; std::map<ggml_tensor*, const void*> backend_tensor_data_map;
std::map<std::string, ggml_tensor*> cache_tensor_map; // name -> tensor std::map<std::string, ggml_tensor*> cache_tensor_map; // name -> tensor
std::unordered_map<ggml_tensor*, std::string> debug_tensors; std::vector<std::pair<ggml_tensor*, std::string>> debug_tensors;
const std::string final_result_name = "ggml_runner_final_result_tensor"; const std::string final_result_name = "ggml_runner_final_result_tensor";
bool flash_attn_enabled = false; bool flash_attn_enabled = false;

1109
src/ltx_audio_vae.h Normal file

File diff suppressed because it is too large Load Diff

View File

@ -1059,7 +1059,6 @@ namespace LTXV {
auto v2a_gate = get_ada_values(ctx, v2a_gate_table, a_cross_gate_timestep, a_dim, 1)[0]; auto v2a_gate = get_ada_values(ctx, v2a_gate_table, a_cross_gate_timestep, a_dim, 1)[0];
ax = ggml_add(ctx->ggml_ctx, ax, apply_gate(ctx->ggml_ctx, v2a_out, v2a_gate)); ax = ggml_add(ctx->ggml_ctx, ax, apply_gate(ctx->ggml_ctx, v2a_out, v2a_gate));
} }
auto a_ff_mods = get_ada_values(ctx, a_table, a_timestep, a_dim, cross_attention_adaln ? 9 : 6, 3, 3); auto a_ff_mods = get_ada_values(ctx, a_table, a_timestep, a_dim, cross_attention_adaln ? 9 : 6, 3, 3);
auto ax_scaled = rms_norm(ctx->ggml_ctx, ax); auto ax_scaled = rms_norm(ctx->ggml_ctx, ax);
ax_scaled = Flux::modulate(ctx->ggml_ctx, ax_scaled, a_ff_mods[0], a_ff_mods[1], true); ax_scaled = Flux::modulate(ctx->ggml_ctx, ax_scaled, a_ff_mods[0], a_ff_mods[1], true);
@ -1183,7 +1182,9 @@ namespace LTXV {
} }
ggml_tensor* patchify_audio(GGMLRunnerContext* ctx, ggml_tensor* ax) { ggml_tensor* patchify_audio(GGMLRunnerContext* ctx, ggml_tensor* ax) {
ax = ggml_reshape_3d(ctx->ggml_ctx, ax, ax->ne[0] * ax->ne[2], ax->ne[1], ax->ne[3]); // ax: [b, c, t, f]
ax = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, ax, 0, 2, 1, 3)); // [b, t, c, f]
ax = ggml_reshape_3d(ctx->ggml_ctx, ax, ax->ne[0] * ax->ne[1], ax->ne[2], ax->ne[3]); // [b, t, c*f]
return ax; return ax;
} }
@ -1191,7 +1192,9 @@ namespace LTXV {
if (ax == nullptr) { if (ax == nullptr) {
return nullptr; return nullptr;
} }
return ggml_reshape_4d(ctx->ggml_ctx, ax, cfg.audio_frequency_bins, audio_length, cfg.num_audio_channels, ax->ne[2]); ax = ggml_reshape_4d(ctx->ggml_ctx, ax, cfg.audio_frequency_bins, cfg.num_audio_channels, audio_length, ax->ne[2]); // [b, t, c, f]
ax = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, ax, 0, 2, 1, 3)); // [b, c, t, f]
return ax;
} }
std::pair<ggml_tensor*, ggml_tensor*> preprocess_contexts(GGMLRunnerContext* ctx, std::pair<ggml_tensor*, ggml_tensor*> preprocess_contexts(GGMLRunnerContext* ctx,

View File

@ -14,6 +14,7 @@
#include "diffusion_model.hpp" #include "diffusion_model.hpp"
#include "esrgan.hpp" #include "esrgan.hpp"
#include "lora.hpp" #include "lora.hpp"
#include "ltx_audio_vae.h"
#include "ltx_vae.hpp" #include "ltx_vae.hpp"
#include "pmid.hpp" #include "pmid.hpp"
#include "sample-cache.h" #include "sample-cache.h"
@ -114,6 +115,7 @@ public:
ggml_backend_t clip_backend = nullptr; ggml_backend_t clip_backend = nullptr;
ggml_backend_t control_net_backend = nullptr; ggml_backend_t control_net_backend = nullptr;
ggml_backend_t vae_backend = nullptr; ggml_backend_t vae_backend = nullptr;
ggml_backend_t audio_backend = nullptr;
SDVersion version; SDVersion version;
bool vae_decode_only = false; bool vae_decode_only = false;
@ -134,6 +136,7 @@ public:
std::shared_ptr<DiffusionModel> high_noise_diffusion_model; std::shared_ptr<DiffusionModel> high_noise_diffusion_model;
std::shared_ptr<VAE> first_stage_model; std::shared_ptr<VAE> first_stage_model;
std::shared_ptr<VAE> preview_vae; std::shared_ptr<VAE> preview_vae;
std::shared_ptr<LTXV::LTXAudioVAERunner> audio_vae_model;
std::shared_ptr<ControlNet> control_net; std::shared_ptr<ControlNet> control_net;
std::shared_ptr<PhotoMakerIDEncoder> pmid_model; std::shared_ptr<PhotoMakerIDEncoder> pmid_model;
std::shared_ptr<LoraModel> pmid_lora; std::shared_ptr<LoraModel> pmid_lora;
@ -171,6 +174,9 @@ public:
if (vae_backend != backend) { if (vae_backend != backend) {
ggml_backend_free(vae_backend); ggml_backend_free(vae_backend);
} }
if (audio_backend != nullptr && audio_backend != backend) {
ggml_backend_free(audio_backend);
}
ggml_backend_free(backend); ggml_backend_free(backend);
} }
@ -195,7 +201,8 @@ public:
offload_params_to_cpu = sd_ctx_params->offload_params_to_cpu; offload_params_to_cpu = sd_ctx_params->offload_params_to_cpu;
max_vram = sd_ctx_params->max_vram; max_vram = sd_ctx_params->max_vram;
bool use_tae = false; bool use_tae = false;
bool use_audio_vae = false;
rng = get_rng(sd_ctx_params->rng_type); rng = get_rng(sd_ctx_params->rng_type);
if (sd_ctx_params->sampler_rng_type != RNG_TYPE_COUNT && sd_ctx_params->sampler_rng_type != sd_ctx_params->rng_type) { if (sd_ctx_params->sampler_rng_type != RNG_TYPE_COUNT && sd_ctx_params->sampler_rng_type != sd_ctx_params->rng_type) {
@ -294,6 +301,22 @@ public:
use_tae = true; use_tae = true;
} }
if (strlen(SAFE_STR(sd_ctx_params->embeddings_connectors_path)) > 0) {
LOG_INFO("loading embeddings connectors from '%s'", sd_ctx_params->embeddings_connectors_path);
if (!model_loader.init_from_file(sd_ctx_params->embeddings_connectors_path)) {
LOG_WARN("loading embeddings connectors from '%s' failed", sd_ctx_params->embeddings_connectors_path);
}
}
if (strlen(SAFE_STR(sd_ctx_params->audio_vae_path)) > 0) {
LOG_INFO("loading LTX audio VAE from '%s'", sd_ctx_params->audio_vae_path);
if (!model_loader.init_from_file(sd_ctx_params->audio_vae_path)) {
LOG_WARN("loading LTX audio VAE weights from '%s' failed", sd_ctx_params->audio_vae_path);
} else {
use_audio_vae = true;
}
}
model_loader.convert_tensors_name(); model_loader.convert_tensors_name();
version = model_loader.get_sd_version(); version = model_loader.get_sd_version();
@ -302,17 +325,6 @@ public:
return false; return false;
} }
if (strlen(SAFE_STR(sd_ctx_params->embeddings_connectors_path)) > 0) {
if (sd_version_is_ltxav(version)) {
LOG_INFO("loading embeddings connectors from '%s'", sd_ctx_params->embeddings_connectors_path);
if (!model_loader.init_from_file(sd_ctx_params->embeddings_connectors_path)) {
LOG_WARN("loading embeddings connectors from '%s' failed", sd_ctx_params->embeddings_connectors_path);
}
} else {
LOG_WARN("ignoring embeddings connectors for non-LTXAV model: '%s'", sd_ctx_params->embeddings_connectors_path);
}
}
auto& tensor_storage_map = model_loader.get_tensor_storage_map(); auto& tensor_storage_map = model_loader.get_tensor_storage_map();
LOG_INFO("Version: %s ", model_version_to_str[version]); LOG_INFO("Version: %s ", model_version_to_str[version]);
@ -682,6 +694,20 @@ public:
} }
} }
if (use_audio_vae) {
if (sd_ctx_params->keep_vae_on_cpu && !ggml_backend_is_cpu(backend)) {
LOG_INFO("LTX audio VAE: Using CPU backend");
audio_backend = ggml_backend_cpu_init();
} else {
audio_backend = backend;
}
audio_vae_model = std::make_shared<LTXV::LTXAudioVAERunner>(audio_backend,
false,
tensor_storage_map);
audio_vae_model->alloc_params_buffer();
audio_vae_model->get_param_tensors(tensors, "");
}
if (sd_ctx_params->vae_conv_direct) { if (sd_ctx_params->vae_conv_direct) {
LOG_INFO("Using Conv2d direct in the vae model"); LOG_INFO("Using Conv2d direct in the vae model");
first_stage_model->set_conv2d_direct_enabled(true); first_stage_model->set_conv2d_direct_enabled(true);
@ -815,6 +841,9 @@ public:
ignore_tensors.insert("tae.encoder"); ignore_tensors.insert("tae.encoder");
ignore_tensors.insert("text_encoders.llm.visual."); ignore_tensors.insert("text_encoders.llm.visual.");
} }
if (audio_vae_model) {
ignore_tensors.insert("audio_vae.encoder");
}
if (version == VERSION_OVIS_IMAGE) { if (version == VERSION_OVIS_IMAGE) {
ignore_tensors.insert("text_encoders.llm.vision_model."); ignore_tensors.insert("text_encoders.llm.vision_model.");
ignore_tensors.insert("text_encoders.llm.visual_tokenizer."); ignore_tensors.insert("text_encoders.llm.visual_tokenizer.");
@ -847,6 +876,9 @@ public:
if (preview_vae) { if (preview_vae) {
vae_params_mem_size += preview_vae->get_params_buffer_size(); vae_params_mem_size += preview_vae->get_params_buffer_size();
} }
if (audio_vae_model) {
vae_params_mem_size += audio_vae_model->get_params_buffer_size();
}
size_t control_net_params_mem_size = 0; size_t control_net_params_mem_size = 0;
if (control_net) { if (control_net) {
if (!control_net->load_from_file(SAFE_STR(sd_ctx_params->control_net_path), n_threads)) { if (!control_net->load_from_file(SAFE_STR(sd_ctx_params->control_net_path), n_threads)) {
@ -1938,6 +1970,17 @@ public:
return first_stage_model->decode(n_threads, latents, vae_tiling_params, decode_video, circular_x, circular_y); return first_stage_model->decode(n_threads, latents, vae_tiling_params, decode_video, circular_x, circular_y);
} }
sd::Tensor<float> decode_ltx_audio_latent(const sd::Tensor<float>& audio_latent) {
if (audio_vae_model == nullptr || audio_latent.empty()) {
return {};
}
auto waveform = audio_vae_model->decode(n_threads, audio_latent);
if (free_params_immediately) {
audio_vae_model->free_params_buffer();
}
return waveform;
}
void set_flow_shift(float flow_shift = INFINITY) { void set_flow_shift(float flow_shift = INFINITY) {
auto flow_denoiser = std::dynamic_pointer_cast<DiscreteFlowDenoiser>(denoiser); auto flow_denoiser = std::dynamic_pointer_cast<DiscreteFlowDenoiser>(denoiser);
if (flow_denoiser) { if (flow_denoiser) {
@ -2243,6 +2286,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
"high_noise_diffusion_model_path: %s\n" "high_noise_diffusion_model_path: %s\n"
"embeddings_connectors_path: %s\n" "embeddings_connectors_path: %s\n"
"vae_path: %s\n" "vae_path: %s\n"
"audio_vae_path: %s\n"
"taesd_path: %s\n" "taesd_path: %s\n"
"control_net_path: %s\n" "control_net_path: %s\n"
"photo_maker_path: %s\n" "photo_maker_path: %s\n"
@ -2277,6 +2321,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path), SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path),
SAFE_STR(sd_ctx_params->embeddings_connectors_path), SAFE_STR(sd_ctx_params->embeddings_connectors_path),
SAFE_STR(sd_ctx_params->vae_path), SAFE_STR(sd_ctx_params->vae_path),
SAFE_STR(sd_ctx_params->audio_vae_path),
SAFE_STR(sd_ctx_params->taesd_path), SAFE_STR(sd_ctx_params->taesd_path),
SAFE_STR(sd_ctx_params->control_net_path), SAFE_STR(sd_ctx_params->control_net_path),
SAFE_STR(sd_ctx_params->photo_maker_path), SAFE_STR(sd_ctx_params->photo_maker_path),
@ -2504,6 +2549,45 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) {
free(sd_ctx); free(sd_ctx);
} }
static sd_audio_t* waveform_to_sd_audio(const StableDiffusionGGML* sd,
const sd::Tensor<float>& waveform) {
if (sd == nullptr || waveform.empty()) {
return nullptr;
}
int64_t sample_count = waveform.shape()[0];
int64_t channels = waveform.shape().size() > 1 ? waveform.shape()[1] : 1;
if (sample_count <= 0 || channels <= 0) {
return nullptr;
}
sd_audio_t* audio = (sd_audio_t*)malloc(sizeof(sd_audio_t));
if (audio == nullptr) {
return nullptr;
}
audio->sample_rate = static_cast<uint32_t>(sd->audio_vae_model != nullptr ? sd->audio_vae_model->config.output_sample_rate() : 0);
audio->channels = static_cast<uint32_t>(channels);
audio->sample_count = static_cast<uint64_t>(sample_count);
size_t sample_bytes = waveform.numel() * sizeof(float);
audio->data = (float*)malloc(sample_bytes);
if (audio->data == nullptr) {
free(audio);
return nullptr;
}
std::memcpy(audio->data, waveform.data(), sample_bytes);
return audio;
}
void free_sd_audio(sd_audio_t* audio) {
if (audio == nullptr) {
return;
}
free(audio->data);
audio->data = nullptr;
free(audio);
}
SD_API bool sd_ctx_supports_image_generation(const sd_ctx_t* sd_ctx) { SD_API bool sd_ctx_supports_image_generation(const sd_ctx_t* sd_ctx) {
if (sd_ctx == nullptr || sd_ctx->sd == nullptr) { if (sd_ctx == nullptr || sd_ctx->sd == nullptr) {
return false; return false;
@ -2939,6 +3023,37 @@ static sd::Tensor<float> pack_ltxav_audio_and_video_latents(const sd::Tensor<flo
return packed; return packed;
} }
static sd::Tensor<float> unpack_ltxav_audio_latent(const sd::Tensor<float>& packed_latent,
int audio_length,
int video_channels) {
if (packed_latent.empty() || audio_length <= 0) {
return {};
}
GGML_ASSERT(packed_latent.dim() == 4 || packed_latent.dim() == 5);
int64_t width = packed_latent.shape()[0];
int64_t height = packed_latent.shape()[1];
int64_t frames = packed_latent.shape()[2];
int64_t total_channels = packed_latent.shape()[3];
int64_t spatial_size = width * height * frames;
if (total_channels <= video_channels) {
return {};
}
constexpr int kLtxavAudioFrequencyBins = 16;
constexpr int kLtxavAudioChannels = 8;
int64_t required_values = static_cast<int64_t>(audio_length) * kLtxavAudioFrequencyBins * kLtxavAudioChannels;
int64_t packed_values = (total_channels - video_channels) * spatial_size;
if (packed_values < required_values) {
return {};
}
sd::Tensor<float> audio_latent({kLtxavAudioFrequencyBins, audio_length, kLtxavAudioChannels, 1});
const float* audio_src = packed_latent.data() + static_cast<size_t>(video_channels) * static_cast<size_t>(spatial_size);
std::copy_n(audio_src, static_cast<size_t>(required_values), audio_latent.data());
return audio_latent;
}
static int get_ltxav_num_audio_latents(int frames, int fps) { static int get_ltxav_num_audio_latents(int frames, int fps) {
GGML_ASSERT(frames > 0); GGML_ASSERT(frames > 0);
GGML_ASSERT(fps > 0); GGML_ASSERT(fps > 0);
@ -3722,8 +3837,10 @@ static std::optional<ImageGenerationLatents> prepare_video_generation_latents(sd
} }
if (sd_version_is_ltxav(sd_ctx->sd->version)) { if (sd_version_is_ltxav(sd_ctx->sd->version)) {
latents.audio_length = 0; constexpr int kLtxavAudioFrequencyBins = 16;
latents.audio_latent = {}; constexpr int kLtxavAudioChannels = 8;
latents.audio_length = get_ltxav_num_audio_latents(request->frames, request->fps);
latents.audio_latent = sd::zeros<float>({kLtxavAudioFrequencyBins, latents.audio_length, kLtxavAudioChannels, 1});
} }
if (sd_version_is_ltxav(sd_ctx->sd->version)) { if (sd_version_is_ltxav(sd_ctx->sd->version)) {
@ -3903,8 +4020,9 @@ static std::optional<ImageGenerationLatents> prepare_video_generation_latents(sd
latents.init_latent = sd_ctx->sd->generate_init_latent(request->width, request->height, request->frames, true); latents.init_latent = sd_ctx->sd->generate_init_latent(request->width, request->height, request->frames, true);
} }
// Pipeline-level audio support is temporarily disabled. Keep the model-side if (sd_version_is_ltxav(sd_ctx->sd->version) && !latents.audio_latent.empty()) {
// AV implementation intact, but feed pure video latents through vid_gen. latents.init_latent = pack_ltxav_audio_and_video_latents(latents.init_latent, latents.audio_latent);
}
return latents; return latents;
} }
@ -3996,9 +4114,19 @@ static sd_image_t* decode_video_outputs(sd_ctx_t* sd_ctx,
return result_images; return result_images;
} }
SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params, int* num_frames_out) { SD_API bool generate_video(sd_ctx_t* sd_ctx,
const sd_vid_gen_params_t* sd_vid_gen_params,
sd_image_t** frames_out,
int* num_frames_out,
sd_audio_t** audio_out) {
if (sd_ctx == nullptr || sd_vid_gen_params == nullptr) { if (sd_ctx == nullptr || sd_vid_gen_params == nullptr) {
return nullptr; return false;
}
if (frames_out != nullptr) {
*frames_out = nullptr;
}
if (audio_out != nullptr) {
*audio_out = nullptr;
} }
if (num_frames_out != nullptr) { if (num_frames_out != nullptr) {
*num_frames_out = 0; *num_frames_out = 0;
@ -4014,7 +4142,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
SamplePlan plan(sd_ctx, sd_vid_gen_params, request); SamplePlan plan(sd_ctx, sd_vid_gen_params, request);
auto latent_inputs_opt = prepare_video_generation_latents(sd_ctx, sd_vid_gen_params, &request); auto latent_inputs_opt = prepare_video_generation_latents(sd_ctx, sd_vid_gen_params, &request);
if (!latent_inputs_opt.has_value()) { if (!latent_inputs_opt.has_value()) {
return nullptr; return false;
} }
ImageGenerationLatents latents = std::move(*latent_inputs_opt); ImageGenerationLatents latents = std::move(*latent_inputs_opt);
ImageGenerationEmbeds embeds = prepare_video_generation_embeds(sd_ctx, ImageGenerationEmbeds embeds = prepare_video_generation_embeds(sd_ctx,
@ -4115,10 +4243,27 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
} }
if (final_latent.empty()) { if (final_latent.empty()) {
LOG_ERROR("sampling failed after %.2fs", (sampling_end - sampling_start) * 1.0f / 1000); LOG_ERROR("sampling failed after %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
return nullptr; return false;
} }
LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000); LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
sd_audio_t* generated_audio = nullptr;
if (sd_version_is_ltxav(sd_ctx->sd->version) &&
latents.audio_length > 0 &&
sd_ctx->sd->audio_vae_model != nullptr) {
auto audio_latent = unpack_ltxav_audio_latent(final_latent,
latents.audio_length,
sd_ctx->sd->get_latent_channel());
if (!audio_latent.empty()) {
auto waveform = sd_ctx->sd->decode_ltx_audio_latent(audio_latent);
if (!waveform.empty()) {
generated_audio = waveform_to_sd_audio(sd_ctx->sd, waveform);
} else {
LOG_WARN("LTX audio latent decode failed; continuing with silent video output");
}
}
}
if (latents.ref_image_num > 0) { if (latents.ref_image_num > 0) {
final_latent = sd::ops::slice(final_latent, 2, latents.ref_image_num, final_latent.shape()[2]); final_latent = sd::ops::slice(final_latent, 2, latents.ref_image_num, final_latent.shape()[2]);
} }
@ -4128,12 +4273,21 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
auto result = decode_video_outputs(sd_ctx, request, final_latent, num_frames_out); auto result = decode_video_outputs(sd_ctx, request, final_latent, num_frames_out);
if (result == nullptr) { if (result == nullptr) {
return nullptr; free_sd_audio(generated_audio);
return false;
} }
sd_ctx->sd->lora_stat(); sd_ctx->sd->lora_stat();
int64_t t1 = ggml_time_ms(); int64_t t1 = ggml_time_ms();
LOG_INFO("generate_video completed in %.2fs", (t1 - t0) * 1.0f / 1000); LOG_INFO("generate_video completed in %.2fs", (t1 - t0) * 1.0f / 1000);
return result; if (frames_out != nullptr) {
*frames_out = result;
}
if (audio_out != nullptr) {
*audio_out = generated_audio;
} else {
free_sd_audio(generated_audio);
}
return true;
} }