mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-25 07:36:38 +00:00
Compare commits
2 Commits
bb63d5c2c5
...
4fdf43a470
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4fdf43a470 | ||
|
|
8b03d9bd0e |
2
.gitmodules
vendored
2
.gitmodules
vendored
@ -1,6 +1,6 @@
|
||||
[submodule "ggml"]
|
||||
path = ggml
|
||||
url = https://github.com/ggml-org/ggml.git
|
||||
url = https://github.com/leejet/ggml.git
|
||||
[submodule "examples/server/frontend"]
|
||||
path = examples/server/frontend
|
||||
url = https://github.com/leejet/sdcpp-webui.git
|
||||
|
||||
@ -385,11 +385,32 @@ std::string format_frame_idx(std::string pattern, int frame_idx) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Builds the path of the ".wav" sidecar file that accompanies the CLI output.
//
// The sidecar shares the output path's stem: a recognised image extension
// (JPEG/PNG/WEBP) or a ".avi"/".webm" extension (matched case-insensitively)
// is stripped first, any other extension is kept, and ".wav" is then appended.
// An output path without an extension simply gets ".wav" appended.
static fs::path get_video_audio_sidecar_path(const SDCliParams& cli_params) {
    const fs::path out_path = cli_params.output_path;

    // extension() already yields an empty path when there is no extension.
    std::string ext_lower = out_path.extension().string();
    // Go through unsigned char: calling tolower with a negative char value
    // (non-ASCII bytes in a path) is undefined behaviour.
    std::transform(ext_lower.begin(), ext_lower.end(), ext_lower.begin(),
                   [](unsigned char c) { return static_cast<char>(::tolower(c)); });

    fs::path sidecar_path = out_path;
    if (!ext_lower.empty()) {
        const EncodedImageFormat output_format = encoded_image_format_from_path(out_path.string());
        const bool is_image_ext = output_format == EncodedImageFormat::JPEG ||
                                  output_format == EncodedImageFormat::PNG ||
                                  output_format == EncodedImageFormat::WEBP;
        const bool is_video_ext = ext_lower == ".avi" || ext_lower == ".webm";
        if (is_image_ext || is_video_ext) {
            sidecar_path.replace_extension();
        }
    }

    sidecar_path += ".wav";
    return sidecar_path;
}
|
||||
|
||||
bool save_results(const SDCliParams& cli_params,
|
||||
const SDContextParams& ctx_params,
|
||||
const SDGenerationParams& gen_params,
|
||||
sd_image_t* results,
|
||||
int num_results) {
|
||||
int num_results,
|
||||
const sd_audio_t* generated_audio = nullptr) {
|
||||
if (results == nullptr || num_results <= 0) {
|
||||
return false;
|
||||
}
|
||||
@ -442,6 +463,21 @@ bool save_results(const SDCliParams& cli_params,
|
||||
return ok;
|
||||
};
|
||||
|
||||
auto write_audio_sidecar = [&](const fs::path& wav_path) {
|
||||
if (generated_audio == nullptr) {
|
||||
return;
|
||||
}
|
||||
if (write_wav_to_file(wav_path.string(),
|
||||
generated_audio->data,
|
||||
generated_audio->sample_count,
|
||||
generated_audio->channels,
|
||||
generated_audio->sample_rate)) {
|
||||
LOG_INFO("save result audio to '%s'", wav_path.string().c_str());
|
||||
} else {
|
||||
LOG_WARN("failed to save result audio to '%s'", wav_path.string().c_str());
|
||||
}
|
||||
};
|
||||
|
||||
int sucessful_reults = 0;
|
||||
|
||||
if (std::regex_search(cli_params.output_path, format_specifier_regex)) {
|
||||
@ -465,8 +501,16 @@ bool save_results(const SDCliParams& cli_params,
|
||||
ext = ".avi";
|
||||
fs::path video_path = base_path;
|
||||
video_path += ext;
|
||||
if (create_video_from_sd_images(video_path.string().c_str(), results, num_results, gen_params.fps) == 0) {
|
||||
std::string final_ext_lower = ext.string();
|
||||
std::transform(final_ext_lower.begin(), final_ext_lower.end(), final_ext_lower.begin(), ::tolower);
|
||||
const bool mux_audio = generated_audio != nullptr && (final_ext_lower == ".avi" || final_ext_lower == ".webm");
|
||||
if (create_video_from_sd_images(video_path.string().c_str(), results, num_results, gen_params.fps, 90, mux_audio ? generated_audio : nullptr) == 0) {
|
||||
LOG_INFO("save result video to '%s'", video_path.string().c_str());
|
||||
if (generated_audio != nullptr && !mux_audio) {
|
||||
fs::path wav_path = video_path;
|
||||
wav_path.replace_extension(".wav");
|
||||
write_audio_sidecar(wav_path);
|
||||
}
|
||||
return true;
|
||||
} else {
|
||||
LOG_ERROR("Failed to save result video to '%s'", video_path.string().c_str());
|
||||
@ -488,6 +532,9 @@ bool save_results(const SDCliParams& cli_params,
|
||||
}
|
||||
}
|
||||
LOG_INFO("%d/%d images saved", sucessful_reults, num_results);
|
||||
if (generated_audio != nullptr) {
|
||||
write_audio_sidecar(get_video_audio_sidecar_path(cli_params));
|
||||
}
|
||||
return sucessful_reults != 0;
|
||||
}
|
||||
|
||||
@ -702,6 +749,7 @@ int main(int argc, const char* argv[]) {
|
||||
|
||||
SDImageVec results;
|
||||
int num_results = 0;
|
||||
sd_audio_t* generated_audio = nullptr;
|
||||
|
||||
if (cli_params.mode == UPSCALE) {
|
||||
num_results = 1;
|
||||
@ -733,7 +781,10 @@ int main(int argc, const char* argv[]) {
|
||||
results.adopt(generate_image(sd_ctx.get(), &img_gen_params), num_results);
|
||||
} else if (cli_params.mode == VID_GEN) {
|
||||
sd_vid_gen_params_t vid_gen_params = gen_params.to_sd_vid_gen_params_t();
|
||||
sd_image_t* generated_video = generate_video(sd_ctx.get(), &vid_gen_params, &num_results);
|
||||
sd_image_t* generated_video = nullptr;
|
||||
if (!generate_video(sd_ctx.get(), &vid_gen_params, &generated_video, &num_results, &generated_audio)) {
|
||||
generated_video = nullptr;
|
||||
}
|
||||
results.adopt(generated_video, num_results);
|
||||
}
|
||||
|
||||
@ -773,9 +824,12 @@ int main(int argc, const char* argv[]) {
|
||||
}
|
||||
}
|
||||
|
||||
if (!save_results(cli_params, ctx_params, gen_params, results.data(), num_results)) {
|
||||
if (!save_results(cli_params, ctx_params, gen_params, results.data(), num_results, generated_audio)) {
|
||||
free_sd_audio(generated_audio);
|
||||
return 1;
|
||||
}
|
||||
|
||||
free_sd_audio(generated_audio);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -348,6 +348,10 @@ ArgOptions SDContextParams::get_options() {
|
||||
"--vae",
|
||||
"path to standalone vae model",
|
||||
&vae_path},
|
||||
{"",
|
||||
"--audio-vae",
|
||||
"path to standalone LTX audio vae model",
|
||||
&audio_vae_path},
|
||||
{"",
|
||||
"--taesd",
|
||||
"path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)",
|
||||
@ -667,6 +671,7 @@ std::string SDContextParams::to_string() const {
|
||||
<< " high_noise_diffusion_model_path: \"" << high_noise_diffusion_model_path << "\",\n"
|
||||
<< " embeddings_connectors_path: \"" << embeddings_connectors_path << "\",\n"
|
||||
<< " vae_path: \"" << vae_path << "\",\n"
|
||||
<< " audio_vae_path: \"" << audio_vae_path << "\",\n"
|
||||
<< " taesd_path: \"" << taesd_path << "\",\n"
|
||||
<< " esrgan_path: \"" << esrgan_path << "\",\n"
|
||||
<< " control_net_path: \"" << control_net_path << "\",\n"
|
||||
@ -725,6 +730,7 @@ sd_ctx_params_t SDContextParams::to_sd_ctx_params_t(bool vae_decode_only, bool f
|
||||
high_noise_diffusion_model_path.c_str(),
|
||||
embeddings_connectors_path.c_str(),
|
||||
vae_path.c_str(),
|
||||
audio_vae_path.c_str(),
|
||||
taesd_path.c_str(),
|
||||
control_net_path.c_str(),
|
||||
embedding_vec.data(),
|
||||
|
||||
@ -94,6 +94,7 @@ struct SDContextParams {
|
||||
std::string high_noise_diffusion_model_path;
|
||||
std::string embeddings_connectors_path;
|
||||
std::string vae_path;
|
||||
std::string audio_vae_path;
|
||||
std::string taesd_path;
|
||||
std::string esrgan_path;
|
||||
std::string control_net_path;
|
||||
|
||||
@ -613,6 +613,13 @@ typedef struct {
|
||||
uint32_t size;
|
||||
} avi_index_entry;
|
||||
|
||||
// One record of the AVI "idx1" index, kept in memory while the "movi" list is
// being assembled and serialised afterwards as fourcc, flags, offset, size.
struct avi_chunk_index_entry {
    char fourcc[4];   // chunk id, e.g. "00dc" for a video frame or "01wb" for audio
    uint32_t flags;   // index flags; video frames are recorded with 0x10, audio with 0
    uint32_t offset;  // size of the output buffer at the moment the chunk header was appended
    uint32_t size;    // payload size in bytes, excluding the chunk header and any pad byte
};
|
||||
|
||||
void write_u32_le(FILE* f, uint32_t val) {
|
||||
fwrite(&val, 4, 1, f);
|
||||
}
|
||||
@ -647,6 +654,33 @@ void write_fourcc(std::vector<uint8_t>& data, const char* fourcc) {
|
||||
data.insert(data.end(), fourcc, fourcc + 4);
|
||||
}
|
||||
|
||||
static std::vector<uint8_t> audio_to_pcm16_bytes(const sd_audio_t* audio) {
|
||||
if (audio == nullptr || audio->data == nullptr || audio->sample_count == 0 || audio->channels == 0 || audio->sample_rate == 0) {
|
||||
return {};
|
||||
}
|
||||
|
||||
const size_t pcm_samples = static_cast<size_t>(audio->sample_count) * static_cast<size_t>(audio->channels);
|
||||
std::vector<uint8_t> bytes(pcm_samples * sizeof(int16_t));
|
||||
auto* pcm = reinterpret_cast<int16_t*>(bytes.data());
|
||||
for (size_t i = 0; i < pcm_samples; ++i) {
|
||||
const float sample = std::clamp(audio->data[i], -1.0f, 1.0f);
|
||||
pcm[i] = static_cast<int16_t>(std::lrint(sample * 32767.0f));
|
||||
}
|
||||
return bytes;
|
||||
}
|
||||
|
||||
static std::pair<uint64_t, uint64_t> audio_sample_range_for_video_frame(const sd_audio_t* audio, int frame_idx, int num_frames, int fps) {
|
||||
if (audio == nullptr || fps <= 0 || num_frames <= 0) {
|
||||
return {0, 0};
|
||||
}
|
||||
const uint64_t total = audio->sample_count;
|
||||
const uint64_t start = static_cast<uint64_t>((static_cast<long double>(frame_idx) * total) / num_frames);
|
||||
const uint64_t end = frame_idx + 1 == num_frames
|
||||
? total
|
||||
: static_cast<uint64_t>((static_cast<long double>(frame_idx + 1) * total) / num_frames);
|
||||
return {start, std::max(start, end)};
|
||||
}
|
||||
|
||||
EncodedImageFormat encoded_image_format_from_path(const std::string& path) {
|
||||
std::string ext = fs::path(path).extension().string();
|
||||
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
|
||||
@ -776,7 +810,7 @@ uint8_t* load_image_from_memory(const char* image_bytes,
|
||||
return load_image_common(true, image_bytes, len, width, height, expected_width, expected_height, expected_channel);
|
||||
}
|
||||
|
||||
std::vector<uint8_t> create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images, int num_images, int fps, int quality) {
|
||||
std::vector<uint8_t> create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images, int num_images, int fps, int quality, const sd_audio_t* audio) {
|
||||
if (num_images == 0) {
|
||||
fprintf(stderr, "Error: Image array is empty.\n");
|
||||
return {};
|
||||
@ -794,6 +828,12 @@ std::vector<uint8_t> create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images
|
||||
// MJPG AVI playback is more compatible when we keep the encoder on the
|
||||
// <= 90 path.
|
||||
const int mjpg_quality = std::clamp(quality, 1, 90);
|
||||
const bool has_audio = audio != nullptr && audio->data != nullptr && audio->sample_count > 0 && audio->channels > 0 && audio->sample_rate > 0;
|
||||
const std::vector<uint8_t> audio_pcm = audio_to_pcm16_bytes(audio);
|
||||
const uint16_t audio_bits_per_sample = 16;
|
||||
const uint16_t audio_block_align = has_audio ? static_cast<uint16_t>(audio->channels * (audio_bits_per_sample / 8)) : 0;
|
||||
const uint32_t audio_byte_rate = has_audio ? static_cast<uint32_t>(audio->sample_rate * audio_block_align) : 0;
|
||||
const uint32_t audio_data_size = has_audio ? static_cast<uint32_t>(audio_pcm.size()) : 0;
|
||||
|
||||
std::vector<uint8_t> avi_data;
|
||||
avi_data.reserve(static_cast<size_t>(num_images) * 1024);
|
||||
@ -804,7 +844,11 @@ std::vector<uint8_t> create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images
|
||||
write_fourcc(avi_data, "AVI ");
|
||||
|
||||
write_fourcc(avi_data, "LIST");
|
||||
write_u32_le(avi_data, 4 + 8 + 56 + 8 + 4 + 8 + 56 + 8 + 40);
|
||||
uint32_t hdrl_size = 4 + 8 + 56 + 8 + 4 + 8 + 56 + 8 + 40;
|
||||
if (has_audio) {
|
||||
hdrl_size += 8 + (4 + 8 + 56 + 8 + 16);
|
||||
}
|
||||
write_u32_le(avi_data, hdrl_size);
|
||||
write_fourcc(avi_data, "hdrl");
|
||||
|
||||
write_fourcc(avi_data, "avih");
|
||||
@ -815,7 +859,7 @@ std::vector<uint8_t> create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images
|
||||
write_u32_le(avi_data, 0x110);
|
||||
write_u32_le(avi_data, num_images);
|
||||
write_u32_le(avi_data, 0);
|
||||
write_u32_le(avi_data, 1);
|
||||
write_u32_le(avi_data, has_audio ? 2 : 1);
|
||||
write_u32_le(avi_data, width * height * 3);
|
||||
write_u32_le(avi_data, width);
|
||||
write_u32_le(avi_data, height);
|
||||
@ -862,12 +906,48 @@ std::vector<uint8_t> create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images
|
||||
write_u32_le(avi_data, 0);
|
||||
write_u32_le(avi_data, 0);
|
||||
|
||||
if (has_audio) {
|
||||
write_fourcc(avi_data, "LIST");
|
||||
write_u32_le(avi_data, 4 + 8 + 56 + 8 + 16);
|
||||
write_fourcc(avi_data, "strl");
|
||||
|
||||
write_fourcc(avi_data, "strh");
|
||||
write_u32_le(avi_data, 56);
|
||||
write_fourcc(avi_data, "auds");
|
||||
write_u32_le(avi_data, 0);
|
||||
write_u32_le(avi_data, 0);
|
||||
write_u16_le(avi_data, 0);
|
||||
write_u16_le(avi_data, 0);
|
||||
write_u32_le(avi_data, 0);
|
||||
write_u32_le(avi_data, audio_block_align);
|
||||
write_u32_le(avi_data, audio_byte_rate);
|
||||
write_u32_le(avi_data, 0);
|
||||
write_u32_le(avi_data, static_cast<uint32_t>(audio->sample_count));
|
||||
write_u32_le(avi_data, audio_data_size);
|
||||
write_u32_le(avi_data, static_cast<uint32_t>(-1));
|
||||
write_u32_le(avi_data, audio_block_align);
|
||||
write_u16_le(avi_data, 0);
|
||||
write_u16_le(avi_data, 0);
|
||||
write_u16_le(avi_data, 0);
|
||||
write_u16_le(avi_data, 0);
|
||||
|
||||
write_fourcc(avi_data, "strf");
|
||||
write_u32_le(avi_data, 16);
|
||||
write_u16_le(avi_data, 1);
|
||||
write_u16_le(avi_data, static_cast<uint16_t>(audio->channels));
|
||||
write_u32_le(avi_data, audio->sample_rate);
|
||||
write_u32_le(avi_data, audio_byte_rate);
|
||||
write_u16_le(avi_data, audio_block_align);
|
||||
write_u16_le(avi_data, audio_bits_per_sample);
|
||||
}
|
||||
|
||||
write_fourcc(avi_data, "LIST");
|
||||
const size_t movi_size_pos = avi_data.size();
|
||||
write_u32_le(avi_data, 0);
|
||||
write_fourcc(avi_data, "movi");
|
||||
|
||||
std::vector<avi_index_entry> index(static_cast<size_t>(num_images));
|
||||
std::vector<avi_chunk_index_entry> index;
|
||||
index.reserve(static_cast<size_t>(num_images) + (has_audio ? 1 : 0));
|
||||
std::vector<uint8_t> jpeg_data;
|
||||
|
||||
for (int i = 0; i < num_images; i++) {
|
||||
@ -884,27 +964,46 @@ std::vector<uint8_t> create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images
|
||||
return {};
|
||||
}
|
||||
|
||||
index[i].offset = static_cast<uint32_t>(avi_data.size());
|
||||
avi_chunk_index_entry video_entry = {};
|
||||
memcpy(video_entry.fourcc, "00dc", 4);
|
||||
video_entry.flags = 0x10;
|
||||
video_entry.offset = static_cast<uint32_t>(avi_data.size());
|
||||
write_fourcc(avi_data, "00dc");
|
||||
write_u32_le(avi_data, static_cast<uint32_t>(jpeg_data.size()));
|
||||
index[i].size = (uint32_t)jpeg_data.size();
|
||||
video_entry.size = static_cast<uint32_t>(jpeg_data.size());
|
||||
avi_data.insert(avi_data.end(), jpeg_data.begin(), jpeg_data.end());
|
||||
index.push_back(video_entry);
|
||||
|
||||
if (jpeg_data.size() % 2) {
|
||||
avi_data.push_back(0);
|
||||
}
|
||||
}
|
||||
|
||||
if (has_audio && !audio_pcm.empty()) {
|
||||
avi_chunk_index_entry audio_entry = {};
|
||||
memcpy(audio_entry.fourcc, "01wb", 4);
|
||||
audio_entry.flags = 0;
|
||||
audio_entry.offset = static_cast<uint32_t>(avi_data.size());
|
||||
audio_entry.size = static_cast<uint32_t>(audio_pcm.size());
|
||||
write_fourcc(avi_data, "01wb");
|
||||
write_u32_le(avi_data, static_cast<uint32_t>(audio_pcm.size()));
|
||||
avi_data.insert(avi_data.end(), audio_pcm.begin(), audio_pcm.end());
|
||||
index.push_back(audio_entry);
|
||||
if (audio_pcm.size() % 2 != 0) {
|
||||
avi_data.push_back(0);
|
||||
}
|
||||
}
|
||||
|
||||
const size_t movi_size = avi_data.size() - movi_size_pos - 4;
|
||||
patch_u32_le(avi_data, movi_size_pos, static_cast<uint32_t>(movi_size));
|
||||
|
||||
write_fourcc(avi_data, "idx1");
|
||||
write_u32_le(avi_data, num_images * 16);
|
||||
for (int i = 0; i < num_images; i++) {
|
||||
write_fourcc(avi_data, "00dc");
|
||||
write_u32_le(avi_data, 0x10);
|
||||
write_u32_le(avi_data, index[i].offset);
|
||||
write_u32_le(avi_data, index[i].size);
|
||||
write_u32_le(avi_data, static_cast<uint32_t>(index.size() * 16));
|
||||
for (const auto& entry : index) {
|
||||
write_fourcc(avi_data, entry.fourcc);
|
||||
write_u32_le(avi_data, entry.flags);
|
||||
write_u32_le(avi_data, entry.offset);
|
||||
write_u32_le(avi_data, entry.size);
|
||||
}
|
||||
|
||||
const size_t file_size = avi_data.size() - riff_size_pos - 4;
|
||||
@ -913,8 +1012,8 @@ std::vector<uint8_t> create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images
|
||||
return avi_data;
|
||||
}
|
||||
|
||||
int create_mjpg_avi_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality) {
|
||||
std::vector<uint8_t> avi_data = create_mjpg_avi_from_sd_images_to_vector(images, num_images, fps, quality);
|
||||
int create_mjpg_avi_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality, const sd_audio_t* audio) {
|
||||
std::vector<uint8_t> avi_data = create_mjpg_avi_from_sd_images_to_vector(images, num_images, fps, quality, audio);
|
||||
if (avi_data.empty()) {
|
||||
return -1;
|
||||
}
|
||||
@ -1044,7 +1143,7 @@ int create_animated_webp_from_sd_images(const char* filename, sd_image_t* images
|
||||
#endif
|
||||
|
||||
#ifdef SD_USE_WEBM
|
||||
std::vector<uint8_t> create_webm_from_sd_images_to_vector(sd_image_t* images, int num_images, int fps, int quality) {
|
||||
std::vector<uint8_t> create_webm_from_sd_images_to_vector(sd_image_t* images, int num_images, int fps, int quality, const sd_audio_t* audio) {
|
||||
if (num_images == 0) {
|
||||
fprintf(stderr, "Error: Image array is empty.\n");
|
||||
return {};
|
||||
@ -1089,6 +1188,25 @@ std::vector<uint8_t> create_webm_from_sd_images_to_vector(sd_image_t* images, in
|
||||
video_track->set_display_height(static_cast<uint64_t>(height));
|
||||
video_track->set_frame_rate(static_cast<double>(fps));
|
||||
}
|
||||
|
||||
uint64_t audio_track_number = 0;
|
||||
std::vector<uint8_t> audio_pcm = audio_to_pcm16_bytes(audio);
|
||||
if (audio != nullptr && !audio_pcm.empty()) {
|
||||
audio_track_number = segment.AddAudioTrack(static_cast<int32_t>(audio->sample_rate), static_cast<int32_t>(audio->channels), 0);
|
||||
if (audio_track_number == 0) {
|
||||
fprintf(stderr, "Error: Failed to add audio track.\n");
|
||||
return -1;
|
||||
}
|
||||
auto* audio_track = static_cast<mkvmuxer::AudioTrack*>(segment.GetTrackByNumber(audio_track_number));
|
||||
if (audio_track == nullptr) {
|
||||
fprintf(stderr, "Error: Failed to get audio track.\n");
|
||||
return -1;
|
||||
}
|
||||
audio_track->set_codec_id("A_PCM/INT/LIT");
|
||||
audio_track->set_bit_depth(16);
|
||||
audio_track->set_sample_rate(static_cast<double>(audio->sample_rate));
|
||||
audio_track->set_channels(audio->channels);
|
||||
}
|
||||
segment.GetSegmentInfo()->set_writing_app("stable-diffusion.cpp");
|
||||
segment.GetSegmentInfo()->set_muxing_app("stable-diffusion.cpp");
|
||||
|
||||
@ -1118,6 +1236,23 @@ std::vector<uint8_t> create_webm_from_sd_images_to_vector(sd_image_t* images, in
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (audio_track_number != 0) {
|
||||
auto [audio_begin, audio_end] = audio_sample_range_for_video_frame(audio, i, num_images, fps);
|
||||
const uint64_t frame_samples = audio_end - audio_begin;
|
||||
if (frame_samples > 0) {
|
||||
const uint64_t frame_bytes = frame_samples * audio->channels * sizeof(int16_t);
|
||||
const uint8_t* frame_ptr = audio_pcm.data() + audio_begin * audio->channels * sizeof(int16_t);
|
||||
if (!segment.AddFrame(frame_ptr,
|
||||
frame_bytes,
|
||||
audio_track_number,
|
||||
timestamp_ns,
|
||||
true)) {
|
||||
fprintf(stderr, "Error: Failed to mux audio chunk %d into WebM.\n", i);
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
timestamp_ns += frame_duration_ns;
|
||||
}
|
||||
|
||||
@ -1133,8 +1268,8 @@ std::vector<uint8_t> create_webm_from_sd_images_to_vector(sd_image_t* images, in
|
||||
return writer.data();
|
||||
}
|
||||
|
||||
int create_webm_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality) {
|
||||
std::vector<uint8_t> webm_data = create_webm_from_sd_images_to_vector(images, num_images, fps, quality);
|
||||
int create_webm_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality, const sd_audio_t* audio) {
|
||||
std::vector<uint8_t> webm_data = create_webm_from_sd_images_to_vector(images, num_images, fps, quality, audio);
|
||||
if (webm_data.empty()) {
|
||||
return -1;
|
||||
}
|
||||
@ -1150,7 +1285,8 @@ std::vector<uint8_t> create_video_from_sd_images_to_vector(const std::string& ou
|
||||
sd_image_t* images,
|
||||
int num_images,
|
||||
int fps,
|
||||
int quality) {
|
||||
int quality,
|
||||
const sd_audio_t* audio) {
|
||||
std::string format = output_format;
|
||||
std::transform(format.begin(), format.end(), format.begin(),
|
||||
[](unsigned char c) { return static_cast<char>(tolower(c)); });
|
||||
@ -1160,7 +1296,7 @@ std::vector<uint8_t> create_video_from_sd_images_to_vector(const std::string& ou
|
||||
|
||||
#ifdef SD_USE_WEBM
|
||||
if (format == "webm") {
|
||||
return create_webm_from_sd_images_to_vector(images, num_images, fps, quality);
|
||||
return create_webm_from_sd_images_to_vector(images, num_images, fps, quality, audio);
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -1170,14 +1306,14 @@ std::vector<uint8_t> create_video_from_sd_images_to_vector(const std::string& ou
|
||||
}
|
||||
#endif
|
||||
|
||||
return create_mjpg_avi_from_sd_images_to_vector(images, num_images, fps, quality);
|
||||
return create_mjpg_avi_from_sd_images_to_vector(images, num_images, fps, quality, audio);
|
||||
}
|
||||
|
||||
int create_video_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality) {
|
||||
int create_video_from_sd_images(const char* filename, sd_image_t* images, int num_images, int fps, int quality, const sd_audio_t* audio) {
|
||||
std::string path = filename ? filename : "";
|
||||
auto pos = path.find_last_of('.');
|
||||
std::string ext = pos == std::string::npos ? "" : path.substr(pos);
|
||||
std::vector<uint8_t> video_data = create_video_from_sd_images_to_vector(ext, images, num_images, fps, quality);
|
||||
std::vector<uint8_t> video_data = create_video_from_sd_images_to_vector(ext, images, num_images, fps, quality, audio);
|
||||
if (video_data.empty()) {
|
||||
return -1;
|
||||
}
|
||||
@ -1187,3 +1323,54 @@ int create_video_from_sd_images(const char* filename, sd_image_t* images, int nu
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Writes interleaved float audio to `path` as a 16-bit PCM RIFF/WAVE file.
//
// path                 destination file; created or truncated on success.
// interleaved_samples  sample_count * channels floats. Each value is clamped
//                      to [-1, 1] and scaled by 32767; NaN is written as
//                      silence (0).
// sample_count         number of samples per channel.
// channels             number of interleaved channels.
// sample_rate          samples per second, per channel.
//
// Returns true only when the whole file was written and closed without a
// stream error. Returns false, without creating or truncating `path`, when an
// argument is null/zero or when the audio does not fit the fixed-width WAV
// header fields (16-bit block align, 32-bit byte rate and chunk sizes).
// All multi-byte fields and samples are serialised little-endian regardless
// of host byte order.
bool write_wav_to_file(const std::string& path,
                       const float* interleaved_samples,
                       uint64_t sample_count,
                       uint32_t channels,
                       uint32_t sample_rate) {
    if (interleaved_samples == nullptr || sample_count == 0 || channels == 0 || sample_rate == 0) {
        return false;
    }

    constexpr uint32_t bits_per_sample  = 16;
    constexpr uint32_t bytes_per_sample = bits_per_sample / 8;
    constexpr uint32_t fmt_size         = 16;
    constexpr uint32_t riff_overhead    = 36;  // "WAVE" + fmt chunk + data chunk header
    constexpr uint32_t max_u16          = 0xFFFFu;
    constexpr uint64_t max_u32          = 0xFFFFFFFFull;

    // Reject audio whose header fields would otherwise be silently truncated.
    if (channels > max_u16 / bytes_per_sample) {
        return false;
    }
    const uint32_t block_align = channels * bytes_per_sample;
    if (static_cast<uint64_t>(sample_rate) * block_align > max_u32) {
        return false;
    }
    if (sample_count > (max_u32 - riff_overhead) / block_align) {
        return false;
    }

    const uint32_t byte_rate   = sample_rate * block_align;
    const uint64_t value_count = sample_count * channels;
    const uint32_t data_size   = static_cast<uint32_t>(value_count * bytes_per_sample);
    const uint32_t riff_size   = riff_overhead + data_size;

    std::vector<uint8_t> wav;
    wav.reserve(static_cast<size_t>(riff_size) + 8);  // + "RIFF" tag and its size field

    const auto put_tag = [&wav](const char* tag) {
        wav.insert(wav.end(), tag, tag + 4);
    };
    const auto put_u16 = [&wav](uint16_t value) {
        wav.push_back(static_cast<uint8_t>(value & 0xFFu));
        wav.push_back(static_cast<uint8_t>(value >> 8));
    };
    const auto put_u32 = [&wav](uint32_t value) {
        for (int shift = 0; shift < 32; shift += 8) {
            wav.push_back(static_cast<uint8_t>((value >> shift) & 0xFFu));
        }
    };

    put_tag("RIFF");
    put_u32(riff_size);
    put_tag("WAVE");

    put_tag("fmt ");
    put_u32(fmt_size);
    put_u16(1);  // format tag 1 = integer PCM
    put_u16(static_cast<uint16_t>(channels));
    put_u32(sample_rate);
    put_u32(byte_rate);
    put_u16(static_cast<uint16_t>(block_align));
    put_u16(static_cast<uint16_t>(bits_per_sample));

    put_tag("data");
    put_u32(data_size);
    for (uint64_t i = 0; i < value_count; ++i) {
        const float raw    = interleaved_samples[i];
        const float sample = std::isnan(raw) ? 0.0f : std::max(-1.0f, std::min(1.0f, raw));
        // int16 -> uint16 keeps the two's complement bit pattern.
        put_u16(static_cast<uint16_t>(static_cast<int16_t>(std::lrint(sample * 32767.0f))));
    }

    // Open only after validation so a rejected request never clobbers a file.
    std::ofstream file(path, std::ios::binary);
    if (!file.is_open()) {
        return false;
    }
    file.write(reinterpret_cast<const char*>(wav.data()), static_cast<std::streamsize>(wav.size()));
    // Close explicitly so flush errors are reflected in the return value.
    file.close();
    return !file.fail();
}
|
||||
|
||||
@ -57,11 +57,13 @@ int create_mjpg_avi_from_sd_images(const char* filename,
|
||||
sd_image_t* images,
|
||||
int num_images,
|
||||
int fps,
|
||||
int quality = 90);
|
||||
int quality = 90,
|
||||
const sd_audio_t* audio = nullptr);
|
||||
std::vector<uint8_t> create_mjpg_avi_from_sd_images_to_vector(sd_image_t* images,
|
||||
int num_images,
|
||||
int fps,
|
||||
int quality = 90);
|
||||
int quality = 90,
|
||||
const sd_audio_t* audio = nullptr);
|
||||
|
||||
#ifdef SD_USE_WEBP
|
||||
int create_animated_webp_from_sd_images(const char* filename,
|
||||
@ -80,22 +82,32 @@ int create_webm_from_sd_images(const char* filename,
|
||||
sd_image_t* images,
|
||||
int num_images,
|
||||
int fps,
|
||||
int quality = 90);
|
||||
int quality = 90,
|
||||
const sd_audio_t* audio = nullptr);
|
||||
std::vector<uint8_t> create_webm_from_sd_images_to_vector(sd_image_t* images,
|
||||
int num_images,
|
||||
int fps,
|
||||
int quality = 90);
|
||||
int quality = 90,
|
||||
const sd_audio_t* audio = nullptr);
|
||||
#endif
|
||||
|
||||
int create_video_from_sd_images(const char* filename,
|
||||
sd_image_t* images,
|
||||
int num_images,
|
||||
int fps,
|
||||
int quality = 90);
|
||||
int quality = 90,
|
||||
const sd_audio_t* audio = nullptr);
|
||||
std::vector<uint8_t> create_video_from_sd_images_to_vector(const std::string& output_format,
|
||||
sd_image_t* images,
|
||||
int num_images,
|
||||
int fps,
|
||||
int quality = 90);
|
||||
int quality = 90,
|
||||
const sd_audio_t* audio = nullptr);
|
||||
|
||||
bool write_wav_to_file(const std::string& path,
|
||||
const float* interleaved_samples,
|
||||
uint64_t sample_count,
|
||||
uint32_t channels,
|
||||
uint32_t sample_rate);
|
||||
|
||||
#endif // __MEDIA_IO_H__
|
||||
|
||||
@ -232,15 +232,20 @@ bool execute_vid_gen_job(ServerRuntime& runtime,
|
||||
|
||||
SDImageVec results;
|
||||
int num_results = 0;
|
||||
sd_audio_t* generated_audio = nullptr;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(*runtime.sd_ctx_mutex);
|
||||
sd_image_t* raw_results = generate_video(runtime.sd_ctx, ¶ms, &num_results);
|
||||
sd_image_t* raw_results = nullptr;
|
||||
if (!generate_video(runtime.sd_ctx, ¶ms, &raw_results, &num_results, &generated_audio)) {
|
||||
raw_results = nullptr;
|
||||
}
|
||||
results.adopt(raw_results, num_results);
|
||||
}
|
||||
|
||||
num_results = results.count();
|
||||
if (num_results <= 0) {
|
||||
free_sd_audio(generated_audio);
|
||||
error_message = "generate_video returned no results";
|
||||
return false;
|
||||
}
|
||||
@ -249,7 +254,9 @@ bool execute_vid_gen_job(ServerRuntime& runtime,
|
||||
results.data(),
|
||||
num_results,
|
||||
job.vid_gen.gen_params.fps,
|
||||
job.vid_gen.output_compression);
|
||||
job.vid_gen.output_compression,
|
||||
generated_audio);
|
||||
free_sd_audio(generated_audio);
|
||||
if (video_bytes.empty()) {
|
||||
error_message = "failed to encode generated video container";
|
||||
return false;
|
||||
|
||||
2
ggml
2
ggml
@ -1 +1 @@
|
||||
Subproject commit 404fcb9d7c96989569e68c9e7881ee3465a05c50
|
||||
Subproject commit 7f4ab364b2843921e795d6890d0f42dd5e5d6b63
|
||||
@ -174,6 +174,7 @@ typedef struct {
|
||||
const char* high_noise_diffusion_model_path;
|
||||
const char* embeddings_connectors_path;
|
||||
const char* vae_path;
|
||||
const char* audio_vae_path;
|
||||
const char* taesd_path;
|
||||
const char* control_net_path;
|
||||
const sd_embedding_t* embeddings;
|
||||
@ -208,6 +209,13 @@ typedef struct {
|
||||
float max_vram;
|
||||
} sd_ctx_params_t;
|
||||
|
||||
// Audio clip produced alongside generated video frames.
// Instances handed out by the library are released with free_sd_audio().
typedef struct sd_audio_t {
    uint32_t sample_rate;   // samples per second, per channel
    uint32_t channels;      // number of channels stored in `data`
    uint64_t sample_count;  // samples per channel; `data` holds sample_count * channels values
    float* data;            // interleaved float samples; consumers clamp them to [-1, 1]
} sd_audio_t;
|
||||
|
||||
typedef struct {
|
||||
uint32_t width;
|
||||
uint32_t height;
|
||||
@ -407,6 +415,7 @@ SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params);
|
||||
|
||||
SD_API sd_ctx_t* new_sd_ctx(const sd_ctx_params_t* sd_ctx_params);
|
||||
SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);
|
||||
SD_API void free_sd_audio(sd_audio_t* audio);
|
||||
|
||||
SD_API void sd_sample_params_init(sd_sample_params_t* sample_params);
|
||||
SD_API char* sd_sample_params_to_str(const sd_sample_params_t* sample_params);
|
||||
@ -419,7 +428,11 @@ SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_para
|
||||
SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params);
|
||||
|
||||
SD_API void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params);
|
||||
SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params, int* num_frames_out);
|
||||
SD_API bool generate_video(sd_ctx_t* sd_ctx,
|
||||
const sd_vid_gen_params_t* sd_vid_gen_params,
|
||||
sd_image_t** frames_out,
|
||||
int* num_frames_out,
|
||||
sd_audio_t** audio_out);
|
||||
|
||||
typedef struct upscaler_ctx_t upscaler_ctx_t;
|
||||
|
||||
|
||||
@ -1705,7 +1705,7 @@ struct GGMLRunnerContext {
|
||||
bool circular_x_enabled = false;
|
||||
bool circular_y_enabled = false;
|
||||
std::shared_ptr<WeightAdapter> weight_adapter = nullptr;
|
||||
std::unordered_map<ggml_tensor*, std::string>* debug_tensors = nullptr;
|
||||
std::vector<std::pair<ggml_tensor*, std::string>>* debug_tensors = nullptr;
|
||||
std::function<ggml_tensor*(const std::string&)> get_cache_tensor;
|
||||
std::function<void(const std::string&, ggml_tensor*)> cache_tensor;
|
||||
|
||||
@ -1713,8 +1713,14 @@ struct GGMLRunnerContext {
|
||||
if (debug_tensors == nullptr || tensor == nullptr) {
|
||||
return;
|
||||
}
|
||||
ggml_set_output(tensor);
|
||||
(*debug_tensors)[tensor] = name;
|
||||
ggml_tensor* snapshot = tensor;
|
||||
if (!ggml_is_contiguous(snapshot) || snapshot->view_src != nullptr) {
|
||||
snapshot = ggml_cont(ggml_ctx, snapshot);
|
||||
}
|
||||
ggml_tensor* dst = ggml_dup_tensor(ggml_ctx, snapshot);
|
||||
snapshot = ggml_cpy(ggml_ctx, snapshot, dst);
|
||||
ggml_set_output(snapshot);
|
||||
debug_tensors->push_back({snapshot, name});
|
||||
}
|
||||
|
||||
ggml_tensor* load_cache_tensor(const std::string& name) const {
|
||||
@ -1768,7 +1774,7 @@ protected:
|
||||
|
||||
std::map<ggml_tensor*, const void*> backend_tensor_data_map;
|
||||
std::map<std::string, ggml_tensor*> cache_tensor_map; // name -> tensor
|
||||
std::unordered_map<ggml_tensor*, std::string> debug_tensors;
|
||||
std::vector<std::pair<ggml_tensor*, std::string>> debug_tensors;
|
||||
const std::string final_result_name = "ggml_runner_final_result_tensor";
|
||||
|
||||
bool flash_attn_enabled = false;
|
||||
|
||||
1109
src/ltx_audio_vae.h
Normal file
1109
src/ltx_audio_vae.h
Normal file
File diff suppressed because it is too large
Load Diff
@ -1059,7 +1059,6 @@ namespace LTXV {
|
||||
auto v2a_gate = get_ada_values(ctx, v2a_gate_table, a_cross_gate_timestep, a_dim, 1)[0];
|
||||
ax = ggml_add(ctx->ggml_ctx, ax, apply_gate(ctx->ggml_ctx, v2a_out, v2a_gate));
|
||||
}
|
||||
|
||||
auto a_ff_mods = get_ada_values(ctx, a_table, a_timestep, a_dim, cross_attention_adaln ? 9 : 6, 3, 3);
|
||||
auto ax_scaled = rms_norm(ctx->ggml_ctx, ax);
|
||||
ax_scaled = Flux::modulate(ctx->ggml_ctx, ax_scaled, a_ff_mods[0], a_ff_mods[1], true);
|
||||
@ -1183,7 +1182,9 @@ namespace LTXV {
|
||||
}
|
||||
|
||||
// Flattens an audio latent so that each time step becomes one token.
//
// Per the shape notes below (written in torch dimension order), the input is
// [b, c, t, f]; channels and time are swapped, then channels and frequency
// bins are merged, giving [b, t, c*f]. The permuted tensor is made contiguous
// before the reshape.
// NOTE(review): the input layout is taken from the inline shape comments, not
// verified against callers — confirm.
ggml_tensor* patchify_audio(GGMLRunnerContext* ctx, ggml_tensor* ax) {
    // ax: [b, c, t, f]
    ax = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, ax, 0, 2, 1, 3));  // [b, t, c, f]
    ax = ggml_reshape_3d(ctx->ggml_ctx, ax, ax->ne[0] * ax->ne[1], ax->ne[2], ax->ne[3]);  // [b, t, c*f]
    return ax;
}
|
||||
|
||||
@ -1191,7 +1192,9 @@ namespace LTXV {
|
||||
if (ax == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return ggml_reshape_4d(ctx->ggml_ctx, ax, cfg.audio_frequency_bins, audio_length, cfg.num_audio_channels, ax->ne[2]);
|
||||
ax = ggml_reshape_4d(ctx->ggml_ctx, ax, cfg.audio_frequency_bins, cfg.num_audio_channels, audio_length, ax->ne[2]); // [b, t, c, f]
|
||||
ax = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, ax, 0, 2, 1, 3)); // [b, c, t, f]
|
||||
return ax;
|
||||
}
|
||||
|
||||
std::pair<ggml_tensor*, ggml_tensor*> preprocess_contexts(GGMLRunnerContext* ctx,
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
#include "diffusion_model.hpp"
|
||||
#include "esrgan.hpp"
|
||||
#include "lora.hpp"
|
||||
#include "ltx_audio_vae.h"
|
||||
#include "ltx_vae.hpp"
|
||||
#include "pmid.hpp"
|
||||
#include "sample-cache.h"
|
||||
@ -114,6 +115,7 @@ public:
|
||||
ggml_backend_t clip_backend = nullptr;
|
||||
ggml_backend_t control_net_backend = nullptr;
|
||||
ggml_backend_t vae_backend = nullptr;
|
||||
ggml_backend_t audio_backend = nullptr;
|
||||
|
||||
SDVersion version;
|
||||
bool vae_decode_only = false;
|
||||
@ -134,6 +136,7 @@ public:
|
||||
std::shared_ptr<DiffusionModel> high_noise_diffusion_model;
|
||||
std::shared_ptr<VAE> first_stage_model;
|
||||
std::shared_ptr<VAE> preview_vae;
|
||||
std::shared_ptr<LTXV::LTXAudioVAERunner> audio_vae_model;
|
||||
std::shared_ptr<ControlNet> control_net;
|
||||
std::shared_ptr<PhotoMakerIDEncoder> pmid_model;
|
||||
std::shared_ptr<LoraModel> pmid_lora;
|
||||
@ -171,6 +174,9 @@ public:
|
||||
if (vae_backend != backend) {
|
||||
ggml_backend_free(vae_backend);
|
||||
}
|
||||
if (audio_backend != nullptr && audio_backend != backend) {
|
||||
ggml_backend_free(audio_backend);
|
||||
}
|
||||
ggml_backend_free(backend);
|
||||
}
|
||||
|
||||
@ -196,6 +202,7 @@ public:
|
||||
max_vram = sd_ctx_params->max_vram;
|
||||
|
||||
bool use_tae = false;
|
||||
bool use_audio_vae = false;
|
||||
|
||||
rng = get_rng(sd_ctx_params->rng_type);
|
||||
if (sd_ctx_params->sampler_rng_type != RNG_TYPE_COUNT && sd_ctx_params->sampler_rng_type != sd_ctx_params->rng_type) {
|
||||
@ -294,6 +301,22 @@ public:
|
||||
use_tae = true;
|
||||
}
|
||||
|
||||
if (strlen(SAFE_STR(sd_ctx_params->embeddings_connectors_path)) > 0) {
|
||||
LOG_INFO("loading embeddings connectors from '%s'", sd_ctx_params->embeddings_connectors_path);
|
||||
if (!model_loader.init_from_file(sd_ctx_params->embeddings_connectors_path)) {
|
||||
LOG_WARN("loading embeddings connectors from '%s' failed", sd_ctx_params->embeddings_connectors_path);
|
||||
}
|
||||
}
|
||||
|
||||
if (strlen(SAFE_STR(sd_ctx_params->audio_vae_path)) > 0) {
|
||||
LOG_INFO("loading LTX audio VAE from '%s'", sd_ctx_params->audio_vae_path);
|
||||
if (!model_loader.init_from_file(sd_ctx_params->audio_vae_path)) {
|
||||
LOG_WARN("loading LTX audio VAE weights from '%s' failed", sd_ctx_params->audio_vae_path);
|
||||
} else {
|
||||
use_audio_vae = true;
|
||||
}
|
||||
}
|
||||
|
||||
model_loader.convert_tensors_name();
|
||||
|
||||
version = model_loader.get_sd_version();
|
||||
@ -302,17 +325,6 @@ public:
|
||||
return false;
|
||||
}
|
||||
|
||||
if (strlen(SAFE_STR(sd_ctx_params->embeddings_connectors_path)) > 0) {
|
||||
if (sd_version_is_ltxav(version)) {
|
||||
LOG_INFO("loading embeddings connectors from '%s'", sd_ctx_params->embeddings_connectors_path);
|
||||
if (!model_loader.init_from_file(sd_ctx_params->embeddings_connectors_path)) {
|
||||
LOG_WARN("loading embeddings connectors from '%s' failed", sd_ctx_params->embeddings_connectors_path);
|
||||
}
|
||||
} else {
|
||||
LOG_WARN("ignoring embeddings connectors for non-LTXAV model: '%s'", sd_ctx_params->embeddings_connectors_path);
|
||||
}
|
||||
}
|
||||
|
||||
auto& tensor_storage_map = model_loader.get_tensor_storage_map();
|
||||
|
||||
LOG_INFO("Version: %s ", model_version_to_str[version]);
|
||||
@ -682,6 +694,20 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
if (use_audio_vae) {
|
||||
if (sd_ctx_params->keep_vae_on_cpu && !ggml_backend_is_cpu(backend)) {
|
||||
LOG_INFO("LTX audio VAE: Using CPU backend");
|
||||
audio_backend = ggml_backend_cpu_init();
|
||||
} else {
|
||||
audio_backend = backend;
|
||||
}
|
||||
audio_vae_model = std::make_shared<LTXV::LTXAudioVAERunner>(audio_backend,
|
||||
false,
|
||||
tensor_storage_map);
|
||||
audio_vae_model->alloc_params_buffer();
|
||||
audio_vae_model->get_param_tensors(tensors, "");
|
||||
}
|
||||
|
||||
if (sd_ctx_params->vae_conv_direct) {
|
||||
LOG_INFO("Using Conv2d direct in the vae model");
|
||||
first_stage_model->set_conv2d_direct_enabled(true);
|
||||
@ -815,6 +841,9 @@ public:
|
||||
ignore_tensors.insert("tae.encoder");
|
||||
ignore_tensors.insert("text_encoders.llm.visual.");
|
||||
}
|
||||
if (audio_vae_model) {
|
||||
ignore_tensors.insert("audio_vae.encoder");
|
||||
}
|
||||
if (version == VERSION_OVIS_IMAGE) {
|
||||
ignore_tensors.insert("text_encoders.llm.vision_model.");
|
||||
ignore_tensors.insert("text_encoders.llm.visual_tokenizer.");
|
||||
@ -847,6 +876,9 @@ public:
|
||||
if (preview_vae) {
|
||||
vae_params_mem_size += preview_vae->get_params_buffer_size();
|
||||
}
|
||||
if (audio_vae_model) {
|
||||
vae_params_mem_size += audio_vae_model->get_params_buffer_size();
|
||||
}
|
||||
size_t control_net_params_mem_size = 0;
|
||||
if (control_net) {
|
||||
if (!control_net->load_from_file(SAFE_STR(sd_ctx_params->control_net_path), n_threads)) {
|
||||
@ -1938,6 +1970,17 @@ public:
|
||||
return first_stage_model->decode(n_threads, latents, vae_tiling_params, decode_video, circular_x, circular_y);
|
||||
}
|
||||
|
||||
// Decode an LTX audio latent through the audio VAE runner.
//
// Returns an empty tensor when no audio VAE runner is loaded or when
// `audio_latent` is empty; otherwise returns whatever the runner's decode()
// produced. When `free_params_immediately` is set, the runner's parameter
// buffer is released right after decoding.
sd::Tensor<float> decode_ltx_audio_latent(const sd::Tensor<float>& audio_latent) {
    const bool can_decode = audio_vae_model != nullptr && !audio_latent.empty();
    if (!can_decode) {
        return {};
    }

    auto decoded = audio_vae_model->decode(n_threads, audio_latent);

    // Release the weights only after decode() has finished using them.
    if (free_params_immediately) {
        audio_vae_model->free_params_buffer();
    }
    return decoded;
}
|
||||
|
||||
void set_flow_shift(float flow_shift = INFINITY) {
|
||||
auto flow_denoiser = std::dynamic_pointer_cast<DiscreteFlowDenoiser>(denoiser);
|
||||
if (flow_denoiser) {
|
||||
@ -2243,6 +2286,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
|
||||
"high_noise_diffusion_model_path: %s\n"
|
||||
"embeddings_connectors_path: %s\n"
|
||||
"vae_path: %s\n"
|
||||
"audio_vae_path: %s\n"
|
||||
"taesd_path: %s\n"
|
||||
"control_net_path: %s\n"
|
||||
"photo_maker_path: %s\n"
|
||||
@ -2277,6 +2321,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
|
||||
SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path),
|
||||
SAFE_STR(sd_ctx_params->embeddings_connectors_path),
|
||||
SAFE_STR(sd_ctx_params->vae_path),
|
||||
SAFE_STR(sd_ctx_params->audio_vae_path),
|
||||
SAFE_STR(sd_ctx_params->taesd_path),
|
||||
SAFE_STR(sd_ctx_params->control_net_path),
|
||||
SAFE_STR(sd_ctx_params->photo_maker_path),
|
||||
@ -2504,6 +2549,45 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) {
|
||||
free(sd_ctx);
|
||||
}
|
||||
|
||||
// Copy a decoded waveform tensor into a freshly malloc'ed sd_audio_t.
//
// - sample_count comes from waveform.shape()[0]
// - channels comes from waveform.shape()[1], or 1 when the tensor has a single dimension
// - sample_rate comes from the audio VAE config, or 0 when no audio VAE is loaded
// - data holds waveform.numel() floats copied verbatim from the tensor
//
// Returns nullptr when `sd` is null, the waveform is empty, a derived size is
// non-positive, or an allocation fails. Both the struct and its data buffer are
// allocated with malloc(); free_sd_audio() releases them.
static sd_audio_t* waveform_to_sd_audio(const StableDiffusionGGML* sd,
                                        const sd::Tensor<float>& waveform) {
    if (sd == nullptr || waveform.empty()) {
        return nullptr;
    }

    const int64_t n_samples  = waveform.shape()[0];
    const int64_t n_channels = waveform.shape().size() > 1 ? waveform.shape()[1] : 1;
    if (n_samples <= 0 || n_channels <= 0) {
        return nullptr;
    }

    auto* out = static_cast<sd_audio_t*>(malloc(sizeof(sd_audio_t)));
    if (out == nullptr) {
        return nullptr;
    }

    out->sample_rate  = static_cast<uint32_t>(sd->audio_vae_model != nullptr ? sd->audio_vae_model->config.output_sample_rate() : 0);
    out->channels     = static_cast<uint32_t>(n_channels);
    out->sample_count = static_cast<uint64_t>(n_samples);

    // The payload is sized from the whole tensor, not from n_samples * n_channels.
    const size_t payload_bytes = waveform.numel() * sizeof(float);
    out->data                  = static_cast<float*>(malloc(payload_bytes));
    if (out->data == nullptr) {
        // Do not hand back a half-initialised struct.
        free(out);
        return nullptr;
    }

    std::memcpy(out->data, waveform.data(), payload_bytes);
    return out;
}
|
||||
|
||||
// Release an sd_audio_t together with its sample buffer.
// Passing nullptr is a no-op.
void free_sd_audio(sd_audio_t* audio) {
    if (audio != nullptr) {
        // Free the sample buffer first, then the struct that owns it.
        free(audio->data);
        audio->data = nullptr;
        free(audio);
    }
}
|
||||
|
||||
SD_API bool sd_ctx_supports_image_generation(const sd_ctx_t* sd_ctx) {
|
||||
if (sd_ctx == nullptr || sd_ctx->sd == nullptr) {
|
||||
return false;
|
||||
@ -2939,6 +3023,37 @@ static sd::Tensor<float> pack_ltxav_audio_and_video_latents(const sd::Tensor<flo
|
||||
return packed;
|
||||
}
|
||||
|
||||
// Extract the audio latent from a packed audio+video latent.
//
// The first four dimensions of `packed_latent` are read as
// (width, height, frames, channels). Everything stored after the first
// `video_channels` channels is treated as flat storage for the audio latent,
// and the first 16 * audio_length * 8 floats of it are copied into a new
// {16, audio_length, 8, 1} tensor.
//
// Returns an empty tensor when the input is empty, `audio_length` is not
// positive, there are no channels beyond `video_channels`, or the extra
// channels hold fewer values than the audio latent needs.
static sd::Tensor<float> unpack_ltxav_audio_latent(const sd::Tensor<float>& packed_latent,
                                                   int audio_length,
                                                   int video_channels) {
    constexpr int kLtxavAudioFrequencyBins = 16;
    constexpr int kLtxavAudioChannels      = 8;

    if (packed_latent.empty() || audio_length <= 0) {
        return {};
    }

    GGML_ASSERT(packed_latent.dim() == 4 || packed_latent.dim() == 5);

    const int64_t dim_w          = packed_latent.shape()[0];
    const int64_t dim_h          = packed_latent.shape()[1];
    const int64_t dim_t          = packed_latent.shape()[2];
    const int64_t dim_c          = packed_latent.shape()[3];
    const int64_t channel_stride = dim_w * dim_h * dim_t;  // floats per channel

    // No channels beyond the video ones means there is nothing to unpack.
    if (dim_c <= video_channels) {
        return {};
    }

    const int64_t needed    = static_cast<int64_t>(audio_length) * kLtxavAudioFrequencyBins * kLtxavAudioChannels;
    const int64_t available = (dim_c - video_channels) * channel_stride;
    if (available < needed) {
        return {};
    }

    sd::Tensor<float> audio_latent({kLtxavAudioFrequencyBins, audio_length, kLtxavAudioChannels, 1});

    // Audio values start right after the last video channel.
    const float* audio_begin = packed_latent.data() + static_cast<size_t>(video_channels) * static_cast<size_t>(channel_stride);
    std::copy_n(audio_begin, static_cast<size_t>(needed), audio_latent.data());
    return audio_latent;
}
|
||||
|
||||
static int get_ltxav_num_audio_latents(int frames, int fps) {
|
||||
GGML_ASSERT(frames > 0);
|
||||
GGML_ASSERT(fps > 0);
|
||||
@ -3722,8 +3837,10 @@ static std::optional<ImageGenerationLatents> prepare_video_generation_latents(sd
|
||||
}
|
||||
|
||||
if (sd_version_is_ltxav(sd_ctx->sd->version)) {
|
||||
latents.audio_length = 0;
|
||||
latents.audio_latent = {};
|
||||
constexpr int kLtxavAudioFrequencyBins = 16;
|
||||
constexpr int kLtxavAudioChannels = 8;
|
||||
latents.audio_length = get_ltxav_num_audio_latents(request->frames, request->fps);
|
||||
latents.audio_latent = sd::zeros<float>({kLtxavAudioFrequencyBins, latents.audio_length, kLtxavAudioChannels, 1});
|
||||
}
|
||||
|
||||
if (sd_version_is_ltxav(sd_ctx->sd->version)) {
|
||||
@ -3903,8 +4020,9 @@ static std::optional<ImageGenerationLatents> prepare_video_generation_latents(sd
|
||||
latents.init_latent = sd_ctx->sd->generate_init_latent(request->width, request->height, request->frames, true);
|
||||
}
|
||||
|
||||
// Pipeline-level audio support is temporarily disabled. Keep the model-side
|
||||
// AV implementation intact, but feed pure video latents through vid_gen.
|
||||
if (sd_version_is_ltxav(sd_ctx->sd->version) && !latents.audio_latent.empty()) {
|
||||
latents.init_latent = pack_ltxav_audio_and_video_latents(latents.init_latent, latents.audio_latent);
|
||||
}
|
||||
|
||||
return latents;
|
||||
}
|
||||
@ -3996,9 +4114,19 @@ static sd_image_t* decode_video_outputs(sd_ctx_t* sd_ctx,
|
||||
return result_images;
|
||||
}
|
||||
|
||||
SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params, int* num_frames_out) {
|
||||
SD_API bool generate_video(sd_ctx_t* sd_ctx,
|
||||
const sd_vid_gen_params_t* sd_vid_gen_params,
|
||||
sd_image_t** frames_out,
|
||||
int* num_frames_out,
|
||||
sd_audio_t** audio_out) {
|
||||
if (sd_ctx == nullptr || sd_vid_gen_params == nullptr) {
|
||||
return nullptr;
|
||||
return false;
|
||||
}
|
||||
if (frames_out != nullptr) {
|
||||
*frames_out = nullptr;
|
||||
}
|
||||
if (audio_out != nullptr) {
|
||||
*audio_out = nullptr;
|
||||
}
|
||||
if (num_frames_out != nullptr) {
|
||||
*num_frames_out = 0;
|
||||
@ -4014,7 +4142,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
||||
SamplePlan plan(sd_ctx, sd_vid_gen_params, request);
|
||||
auto latent_inputs_opt = prepare_video_generation_latents(sd_ctx, sd_vid_gen_params, &request);
|
||||
if (!latent_inputs_opt.has_value()) {
|
||||
return nullptr;
|
||||
return false;
|
||||
}
|
||||
ImageGenerationLatents latents = std::move(*latent_inputs_opt);
|
||||
ImageGenerationEmbeds embeds = prepare_video_generation_embeds(sd_ctx,
|
||||
@ -4115,10 +4243,27 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
||||
}
|
||||
if (final_latent.empty()) {
|
||||
LOG_ERROR("sampling failed after %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
|
||||
return nullptr;
|
||||
return false;
|
||||
}
|
||||
LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
|
||||
|
||||
sd_audio_t* generated_audio = nullptr;
|
||||
if (sd_version_is_ltxav(sd_ctx->sd->version) &&
|
||||
latents.audio_length > 0 &&
|
||||
sd_ctx->sd->audio_vae_model != nullptr) {
|
||||
auto audio_latent = unpack_ltxav_audio_latent(final_latent,
|
||||
latents.audio_length,
|
||||
sd_ctx->sd->get_latent_channel());
|
||||
if (!audio_latent.empty()) {
|
||||
auto waveform = sd_ctx->sd->decode_ltx_audio_latent(audio_latent);
|
||||
if (!waveform.empty()) {
|
||||
generated_audio = waveform_to_sd_audio(sd_ctx->sd, waveform);
|
||||
} else {
|
||||
LOG_WARN("LTX audio latent decode failed; continuing with silent video output");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (latents.ref_image_num > 0) {
|
||||
final_latent = sd::ops::slice(final_latent, 2, latents.ref_image_num, final_latent.shape()[2]);
|
||||
}
|
||||
@ -4128,12 +4273,21 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
||||
|
||||
auto result = decode_video_outputs(sd_ctx, request, final_latent, num_frames_out);
|
||||
if (result == nullptr) {
|
||||
return nullptr;
|
||||
free_sd_audio(generated_audio);
|
||||
return false;
|
||||
}
|
||||
|
||||
sd_ctx->sd->lora_stat();
|
||||
|
||||
int64_t t1 = ggml_time_ms();
|
||||
LOG_INFO("generate_video completed in %.2fs", (t1 - t0) * 1.0f / 1000);
|
||||
return result;
|
||||
if (frames_out != nullptr) {
|
||||
*frames_out = result;
|
||||
}
|
||||
if (audio_out != nullptr) {
|
||||
*audio_out = generated_audio;
|
||||
} else {
|
||||
free_sd_audio(generated_audio);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user