perf: run LTX audio VAE decode in one ggml graph (#1538)

This commit is contained in:
leejet 2026-05-21 22:43:14 +08:00 committed by GitHub
parent 47d8198b69
commit 2e3514625a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 216 additions and 253 deletions

View File

@ -2,6 +2,7 @@
#define __SD_LTX_AUDIO_VAE_H__ #define __SD_LTX_AUDIO_VAE_H__
#include <cmath> #include <cmath>
#include <limits>
#include <numeric> #include <numeric>
#include <string> #include <string>
#include <vector> #include <vector>
@ -171,90 +172,59 @@ namespace LTXV {
} }
}; };
static sd::Tensor<float> squeeze_trailing_singleton_dims(sd::Tensor<float> tensor) { static ggml_tensor* compute_log_mel_spectrogram(GGMLRunnerContext* runner_ctx,
while (tensor.dim() > 0 && tensor.shape().back() == 1) { ggml_tensor* waveform,
tensor = tensor.squeeze(static_cast<size_t>(tensor.dim() - 1)); ggml_tensor* forward_basis,
} ggml_tensor* mel_basis,
return tensor;
}
static sd::Tensor<float> normalize_waveform_for_host(sd::Tensor<float> waveform) {
waveform = squeeze_trailing_singleton_dims(std::move(waveform));
if (waveform.empty()) {
return waveform;
}
if (waveform.dim() == 1) {
return waveform.reshape({waveform.shape()[0], 1, 1});
}
if (waveform.dim() == 2) {
return waveform.reshape({waveform.shape()[0], waveform.shape()[1], 1});
}
if (waveform.dim() == 3) {
return waveform;
}
throw std::runtime_error("Unsupported waveform rank for host processing: rank=" + std::to_string(waveform.dim()));
}
static sd::Tensor<float> load_param_tensor_f32(ggml_tensor* tensor) {
GGML_ASSERT(tensor != nullptr);
return squeeze_trailing_singleton_dims(sd::make_sd_tensor_from_ggml<float>(tensor));
}
static sd::Tensor<float> compute_log_mel_spectrogram(const sd::Tensor<float>& waveform_in,
const sd::Tensor<float>& forward_basis,
const sd::Tensor<float>& mel_basis,
int hop_length) { int hop_length) {
auto waveform = normalize_waveform_for_host(waveform_in); auto ctx = runner_ctx->ggml_ctx;
GGML_ASSERT(forward_basis.dim() >= 3); GGML_ASSERT(ctx != nullptr);
GGML_ASSERT(mel_basis.dim() >= 2); GGML_ASSERT(waveform != nullptr);
GGML_ASSERT(forward_basis != nullptr);
GGML_ASSERT(mel_basis != nullptr);
GGML_ASSERT(waveform->type == GGML_TYPE_F32);
GGML_ASSERT(forward_basis->type == GGML_TYPE_F32);
GGML_ASSERT(mel_basis->type == GGML_TYPE_F32);
GGML_ASSERT(forward_basis->ne[1] == 1);
const int64_t time = waveform.shape()[0]; const int64_t time = waveform->ne[0];
const int64_t channels = waveform.shape()[1]; const int64_t channels = waveform->ne[1];
const int64_t batch = waveform.shape()[2]; const int64_t batch = waveform->ne[2];
const int64_t filter_len = forward_basis.shape()[0]; const int64_t filter_len = forward_basis->ne[0];
const int64_t basis_freq2 = forward_basis.shape().back(); const int64_t stft_channels = forward_basis->ne[2];
const int64_t n_freqs = basis_freq2 / 2; const int64_t n_freqs = stft_channels / 2;
const int64_t n_mels = mel_basis.shape()[1]; const int64_t n_mels = mel_basis->ne[1];
const int64_t left_pad = std::max<int64_t>(0, filter_len - hop_length); const int64_t left_pad = std::max<int64_t>(0, filter_len - hop_length);
const int64_t padded_time = time + left_pad; const int64_t padded_time = time + left_pad;
const int64_t frame_count = padded_time < filter_len ? 0 : 1 + (padded_time - filter_len) / hop_length; const int64_t frame_count = padded_time < filter_len ? 0 : 1 + (padded_time - filter_len) / hop_length;
sd::Tensor<float> log_mel({n_mels, frame_count, channels, batch}); GGML_ASSERT(stft_channels % 2 == 0);
std::vector<float> padded(static_cast<size_t>(padded_time), 0.0f); GGML_ASSERT(mel_basis->ne[0] == n_freqs);
std::vector<float> magnitude(static_cast<size_t>(n_freqs), 0.0f); GGML_ASSERT(waveform->ne[3] == 1);
GGML_ASSERT(frame_count > 0);
for (int64_t b = 0; b < batch; ++b) { auto x = ggml_reshape_3d(ctx, waveform, time, 1, channels * batch);
for (int64_t c = 0; c < channels; ++c) { if (left_pad > 0) {
std::fill(padded.begin(), padded.end(), 0.0f); x = ggml_pad_ext(ctx, x, static_cast<int>(left_pad), 0, 0, 0, 0, 0, 0, 0);
for (int64_t t = 0; t < time; ++t) {
padded[static_cast<size_t>(t + left_pad)] = waveform.index(t, c, b);
} }
for (int64_t frame = 0; frame < frame_count; ++frame) { auto frames = ggml_conv_1d(ctx, forward_basis, x, hop_length, 0, 1);
const int64_t frame_offset = frame * hop_length; GGML_ASSERT(frames->ne[0] == frame_count);
for (int64_t f = 0; f < n_freqs; ++f) { GGML_ASSERT(frames->ne[1] == stft_channels);
double real = 0.0; GGML_ASSERT(frames->ne[2] == channels * batch);
double imag = 0.0;
for (int64_t k = 0; k < filter_len; ++k) {
const float sample = padded[static_cast<size_t>(frame_offset + k)];
real += static_cast<double>(sample) * static_cast<double>(forward_basis.index(k, 0, f));
imag += static_cast<double>(sample) * static_cast<double>(forward_basis.index(k, 0, f + n_freqs));
}
magnitude[static_cast<size_t>(f)] = static_cast<float>(std::sqrt(real * real + imag * imag));
}
for (int64_t m = 0; m < n_mels; ++m) { auto real = ggml_ext_slice(ctx, frames, 1, 0, n_freqs);
double mel_value = 0.0; auto imag = ggml_ext_slice(ctx, frames, 1, n_freqs, stft_channels);
for (int64_t f = 0; f < n_freqs; ++f) { auto magnitude = ggml_sqrt(ctx,
mel_value += static_cast<double>(mel_basis.index(f, m)) * static_cast<double>(magnitude[static_cast<size_t>(f)]); ggml_add(ctx,
} ggml_sqr(ctx, real),
log_mel.index(m, frame, c, b) = static_cast<float>(std::log(std::max(mel_value, 1e-5))); ggml_sqr(ctx, imag)));
}
}
}
}
return log_mel; magnitude = ggml_cont(ctx, ggml_permute(ctx, magnitude, 1, 0, 2, 3));
auto mel = ggml_mul_mat(ctx, mel_basis, magnitude);
mel = ggml_log(ctx, ggml_clamp(ctx, mel, 1e-5f, std::numeric_limits<float>::max()));
return ggml_reshape_4d(ctx, mel, n_mels, frame_count, channels, batch);
} }
static std::vector<float> build_hann_resample_filter(int ratio) { static std::vector<float> build_hann_resample_filter(int ratio) {
@ -276,75 +246,6 @@ namespace LTXV {
return filter; return filter;
} }
static sd::Tensor<float> upsample_waveform_hann(const sd::Tensor<float>& waveform_in, int ratio) {
auto waveform = normalize_waveform_for_host(waveform_in);
if (ratio <= 1) {
return waveform;
}
const int lowpass_filter_width = 6;
const double rolloff = 0.99;
const int width = static_cast<int>(std::ceil(static_cast<double>(lowpass_filter_width) / rolloff));
const int kernel_size = 2 * width * ratio + 1;
const int pad = width;
const int pad_left = 2 * width * ratio;
const int pad_right = kernel_size - ratio;
const int64_t time = waveform.shape()[0];
const int64_t channels = waveform.shape()[1];
const int64_t batch = waveform.shape()[2];
const int64_t padded_time = time + 2 * pad;
const int64_t conv_out_time = (padded_time - 1) * ratio + kernel_size;
const int64_t cropped_time = conv_out_time - pad_left - pad_right;
auto filter = build_hann_resample_filter(ratio);
sd::Tensor<float> output({cropped_time, channels, batch});
std::vector<float> padded(static_cast<size_t>(padded_time), 0.0f);
std::vector<float> conv_out(static_cast<size_t>(conv_out_time), 0.0f);
for (int64_t b = 0; b < batch; ++b) {
for (int64_t c = 0; c < channels; ++c) {
std::fill(padded.begin(), padded.end(), 0.0f);
const float first = waveform.index(0, c, b);
const float last = waveform.index(time - 1, c, b);
for (int i = 0; i < pad; ++i) {
padded[static_cast<size_t>(i)] = first;
padded[static_cast<size_t>(pad + time + i)] = last;
}
for (int64_t t = 0; t < time; ++t) {
padded[static_cast<size_t>(pad + t)] = waveform.index(t, c, b);
}
std::fill(conv_out.begin(), conv_out.end(), 0.0f);
for (int64_t t = 0; t < padded_time; ++t) {
const double sample = static_cast<double>(padded[static_cast<size_t>(t)]) * ratio;
const int64_t out_base = t * ratio;
for (int k = 0; k < kernel_size; ++k) {
conv_out[static_cast<size_t>(out_base + k)] += static_cast<float>(sample * filter[static_cast<size_t>(k)]);
}
}
for (int64_t t = 0; t < cropped_time; ++t) {
output.index(t, c, b) = conv_out[static_cast<size_t>(t + pad_left)];
}
}
}
return output;
}
static sd::Tensor<float> crop_waveform_samples(const sd::Tensor<float>& waveform_in, int64_t target_samples) {
auto waveform = normalize_waveform_for_host(waveform_in);
if (waveform.shape()[0] == target_samples) {
return waveform;
}
if (waveform.shape()[0] > target_samples) {
return sd::ops::slice(waveform, 0, 0, target_samples);
}
sd::Tensor<float> output({target_samples, waveform.shape()[1], waveform.shape()[2]});
sd::ops::slice_assign(&output, 0, 0, waveform.shape()[0], waveform);
return output;
}
static ggml_type audio_conv_weight_type(ggml_type type) { static ggml_type audio_conv_weight_type(ggml_type type) {
return type == GGML_TYPE_BF16 ? GGML_TYPE_F16 : type; return type == GGML_TYPE_BF16 ? GGML_TYPE_F16 : type;
} }
@ -413,22 +314,101 @@ namespace LTXV {
return ggml_reshape_4d(ctx, out, out->ne[0], out->ne[1], 1, 1); return ggml_reshape_4d(ctx, out, out->ne[0], out->ne[1], 1, 1);
} }
static ggml_tensor* reverse_1d_filter(ggml_context* ctx, ggml_tensor* filter) {
GGML_ASSERT(ctx != nullptr);
GGML_ASSERT(filter != nullptr);
GGML_ASSERT(filter->ne[1] == 1);
GGML_ASSERT(filter->ne[2] == 1);
GGML_ASSERT(filter->ne[3] == 1);
ggml_tensor* reversed = nullptr;
for (int64_t k = filter->ne[0] - 1; k >= 0; --k) {
auto slice = ggml_ext_slice(ctx, filter, 0, k, k + 1);
reversed = reversed == nullptr ? slice : ggml_concat(ctx, reversed, slice, 0);
}
return reversed;
}
static ggml_tensor* depthwise_conv_transpose1d(ggml_context* ctx, static ggml_tensor* depthwise_conv_transpose1d(ggml_context* ctx,
ggml_tensor* x, ggml_tensor* x,
ggml_tensor* filter, ggml_tensor* filter,
int stride) { int stride) {
GGML_ASSERT(x->ne[2] == 1 && x->ne[3] == 1); GGML_ASSERT(x->ne[2] == 1 && x->ne[3] == 1);
GGML_ASSERT(filter->ne[1] == 1); GGML_ASSERT(filter->ne[1] == 1);
GGML_ASSERT(filter->ne[2] == 1 && filter->ne[3] == 1);
ggml_tensor* out = nullptr; const int64_t time = x->ne[0];
for (int64_t c = 0; c < x->ne[1]; ++c) { const int64_t channels = x->ne[1];
auto xi = ggml_ext_slice(ctx, x, 1, c, c + 1); const int64_t kernel_size = filter->ne[0];
auto yi = ggml_conv_transpose_1d(ctx, filter, xi, stride, 0, 1); const int64_t out_time = (time - 1) * stride + kernel_size;
yi = ggml_ext_scale(ctx, yi, static_cast<float>(stride));
yi = ggml_reshape_4d(ctx, yi, yi->ne[0], 1, 1, 1); auto x_flat = ggml_reshape_3d(ctx, x, 1, time, channels);
out = out == nullptr ? yi : ggml_concat(ctx, out, yi, 1); if (stride > 1) {
auto zero_unit = ggml_ext_scale(ctx, x_flat, 0.0f);
auto zero_tail = zero_unit;
for (int i = 1; i < stride - 1; ++i) {
zero_tail = ggml_concat(ctx, zero_tail, zero_unit, 0);
} }
return out; x_flat = ggml_concat(ctx, x_flat, zero_tail, 0);
}
x_flat = ggml_reshape_3d(ctx, x_flat, time * stride, 1, channels);
auto reversed_filter = reverse_1d_filter(ctx, filter);
auto out = ggml_conv_1d(ctx, reversed_filter, x_flat, 1, static_cast<int>(kernel_size - 1), 1);
if (out->ne[0] > out_time) {
out = ggml_ext_slice(ctx, out, 0, 0, out_time);
}
GGML_ASSERT(out->ne[0] == out_time);
GGML_ASSERT(out->ne[1] == 1);
GGML_ASSERT(out->ne[2] == channels);
out = ggml_ext_scale(ctx, out, static_cast<float>(stride));
return ggml_reshape_4d(ctx, out, out_time, channels, 1, 1);
}
static ggml_tensor* upsample_waveform_hann(GGMLRunnerContext* runner_ctx,
ggml_tensor* waveform,
ggml_tensor* filter,
int ratio) {
auto ctx = runner_ctx->ggml_ctx;
GGML_ASSERT(ctx != nullptr);
GGML_ASSERT(waveform != nullptr);
GGML_ASSERT(filter != nullptr);
GGML_ASSERT(waveform->ne[3] == 1);
if (ratio <= 1) {
return waveform;
}
const int lowpass_filter_width = 6;
const double rolloff = 0.99;
const int width = static_cast<int>(std::ceil(static_cast<double>(lowpass_filter_width) / rolloff));
const int kernel_size = 2 * width * ratio + 1;
const int pad = width;
const int pad_left = 2 * width * ratio;
const int pad_right = kernel_size - ratio;
const int64_t time = waveform->ne[0];
const int64_t channels = waveform->ne[1];
const int64_t batch = waveform->ne[2];
GGML_ASSERT(filter->ne[0] == kernel_size);
auto x = ggml_reshape_3d(ctx, waveform, time, channels * batch, 1);
x = replicate_pad_1d(runner_ctx, x, pad, pad);
x = depthwise_conv_transpose1d(ctx, x, filter, ratio);
x = ggml_ext_slice(ctx, x, 0, pad_left, x->ne[0] - pad_right);
return ggml_reshape_3d(ctx, x, x->ne[0], channels, batch);
}
static ggml_tensor* crop_waveform_samples(ggml_context* ctx,
ggml_tensor* waveform,
int64_t target_samples) {
GGML_ASSERT(ctx != nullptr);
GGML_ASSERT(waveform != nullptr);
if (waveform->ne[0] == target_samples) {
return waveform;
}
GGML_ASSERT(waveform->ne[0] > target_samples);
return ggml_ext_slice(ctx, waveform, 0, 0, target_samples);
} }
struct PixelNorm2D : public UnaryBlock { struct PixelNorm2D : public UnaryBlock {
@ -950,41 +930,66 @@ namespace LTXV {
} }
} }
ggml_tensor* decode_to_mel(GGMLRunnerContext* ctx, ggml_tensor* decode(GGMLRunnerContext* ctx,
ggml_tensor* latent, ggml_tensor* latent,
int target_time, ggml_tensor* bwe_skip_filter) {
int target_freq) { int target_time = static_cast<int>(latent->ne[1]) * config.latent_downsample_factor() -
(config.latent_downsample_factor() - 1);
int target_freq = config.mel_bins;
auto decoder = std::dynamic_pointer_cast<AudioDecoder>(blocks["audio_vae.decoder"]);
auto mean = params["audio_vae.per_channel_statistics.mean-of-means"]; auto mean = params["audio_vae.per_channel_statistics.mean-of-means"];
auto stddev = params["audio_vae.per_channel_statistics.std-of-means"]; auto stddev = params["audio_vae.per_channel_statistics.std-of-means"];
auto decoder = std::dynamic_pointer_cast<AudioDecoder>(blocks["audio_vae.decoder"]); auto mel = decoder->forward(ctx, latent, mean, stddev, target_time, target_freq);
return decoder->forward(ctx, latent, mean, stddev, target_time, target_freq);
}
ggml_tensor* run_vocoder(GGMLRunnerContext* ctx, ggml_tensor* mel) {
auto vocoder = std::dynamic_pointer_cast<Vocoder>(blocks["vocoder.vocoder"]); auto vocoder = std::dynamic_pointer_cast<Vocoder>(blocks["vocoder.vocoder"]);
return vocoder->forward(ctx, mel); auto waveform = vocoder->forward(ctx, mel);
if (config.has_bwe) {
GGML_ASSERT(bwe_skip_filter != nullptr);
const int bwe_ratio = config.bwe_output_sample_rate / config.bwe_input_sample_rate;
const int64_t low_time = waveform->ne[0];
const int64_t out_time = low_time * bwe_ratio;
int64_t remainder = low_time % config.bwe_hop_length;
auto bwe_waveform = waveform;
if (remainder != 0) {
bwe_waveform = ggml_pad_ext(ctx->ggml_ctx,
bwe_waveform,
0,
static_cast<int>(config.bwe_hop_length - remainder),
0,
0,
0,
0,
0,
0);
} }
ggml_tensor* run_bwe_generator(GGMLRunnerContext* ctx, ggml_tensor* mel) { auto mel_basis = params["vocoder.mel_stft.mel_basis"];
GGML_ASSERT(config.has_bwe); auto stft_basis = params["vocoder.mel_stft.stft_fn.forward_basis"];
GGML_ASSERT(mel_basis != nullptr && stft_basis != nullptr);
auto bwe_mel = compute_log_mel_spectrogram(ctx, bwe_waveform, stft_basis, mel_basis, config.bwe_hop_length);
auto bwe_generator = std::dynamic_pointer_cast<Vocoder>(blocks["vocoder.bwe_generator"]); auto bwe_generator = std::dynamic_pointer_cast<Vocoder>(blocks["vocoder.bwe_generator"]);
return bwe_generator->forward(ctx, mel); auto residual = bwe_generator->forward(ctx, bwe_mel);
auto skip = upsample_waveform_hann(ctx,
bwe_waveform,
bwe_skip_filter,
bwe_ratio);
waveform = ggml_clamp(ctx->ggml_ctx,
ggml_add(ctx->ggml_ctx, residual, skip),
-1.0f,
1.0f);
waveform = crop_waveform_samples(ctx->ggml_ctx, waveform, out_time);
} }
ggml_tensor* mel_basis_tensor() const { return waveform;
auto iter = params.find("vocoder.mel_stft.mel_basis");
return iter == params.end() ? nullptr : iter->second;
}
ggml_tensor* stft_forward_basis_tensor() const {
auto iter = params.find("vocoder.mel_stft.stft_fn.forward_basis");
return iter == params.end() ? nullptr : iter->second;
} }
}; };
struct LTXAudioVAERunner : public GGMLRunner { struct LTXAudioVAERunner : public GGMLRunner {
LTXAudioVAEConfig config; LTXAudioVAEConfig config;
LTXAudioVAE model; LTXAudioVAE model;
sd::Tensor<float> bwe_skip_filter_tensor;
LTXAudioVAERunner(ggml_backend_t backend, LTXAudioVAERunner(ggml_backend_t backend,
ggml_backend_t params_backend, ggml_backend_t params_backend,
@ -994,6 +999,10 @@ namespace LTXV {
config(LTXAudioVAEConfig::detect_from_weights(tensor_storage_map)), config(LTXAudioVAEConfig::detect_from_weights(tensor_storage_map)),
model(config) { model(config) {
model.init(params_ctx, tensor_storage_map, prefix); model.init(params_ctx, tensor_storage_map, prefix);
if (config.has_bwe) {
const int bwe_ratio = config.bwe_output_sample_rate / config.bwe_input_sample_rate;
bwe_skip_filter_tensor = sd::Tensor<float>::from_vector(build_hann_resample_filter(bwe_ratio));
}
} }
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) { void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) {
@ -1008,77 +1017,22 @@ namespace LTXV {
return "ltx_audio_vae"; return "ltx_audio_vae";
} }
ggml_cgraph* build_base_graph(const sd::Tensor<float>& latent_tensor) {
auto latent = make_input(latent_tensor);
int target_time = static_cast<int>(latent_tensor.shape()[1]) * config.latent_downsample_factor() -
(config.latent_downsample_factor() - 1);
int target_freq = config.mel_bins;
ggml_cgraph* gf = new_graph_custom(655360);
auto runner_ctx = GGMLRunner::get_context();
auto mel = model.decode_to_mel(&runner_ctx, latent, target_time, target_freq);
auto waveform = model.run_vocoder(&runner_ctx, mel);
ggml_build_forward_expand(gf, waveform);
return gf;
}
ggml_cgraph* build_bwe_graph(const sd::Tensor<float>& mel_tensor) {
auto mel = make_input(mel_tensor);
ggml_cgraph* gf = new_graph_custom(655360);
auto runner_ctx = GGMLRunner::get_context();
auto residual = model.run_bwe_generator(&runner_ctx, mel);
ggml_build_forward_expand(gf, residual);
return gf;
}
sd::Tensor<float> compute_base_waveform(int n_threads,
const sd::Tensor<float>& latent_tensor) {
auto get_graph = [&]() -> ggml_cgraph* {
return build_base_graph(latent_tensor);
};
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), 4);
}
sd::Tensor<float> compute_bwe_residual(int n_threads,
const sd::Tensor<float>& mel_tensor) {
auto get_graph = [&]() -> ggml_cgraph* {
return build_bwe_graph(mel_tensor);
};
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), 4);
}
sd::Tensor<float> decode(int n_threads, sd::Tensor<float> decode(int n_threads,
const sd::Tensor<float>& latent_tensor) { const sd::Tensor<float>& latent_tensor) {
auto waveform = compute_base_waveform(n_threads, latent_tensor); int64_t t0 = ggml_time_ms();
if (!config.has_bwe || waveform.empty()) { auto get_graph = [&]() -> ggml_cgraph* {
return waveform; auto latent = make_input(latent_tensor);
} ggml_tensor* bwe_skip_filter = config.has_bwe ? make_input(bwe_skip_filter_tensor) : nullptr;
ggml_cgraph* gf = new_graph_custom(655360);
auto waveform_host = normalize_waveform_for_host(waveform); auto runner_ctx = GGMLRunner::get_context();
const int64_t low_time = waveform_host.shape()[0]; auto waveform = model.decode(&runner_ctx, latent, bwe_skip_filter);
const int64_t out_time = low_time * config.bwe_output_sample_rate / config.bwe_input_sample_rate; ggml_build_forward_expand(gf, waveform);
int64_t remainder = low_time % config.bwe_hop_length; return gf;
if (remainder != 0) { };
sd::Tensor<float> padded({low_time + (config.bwe_hop_length - remainder), waveform_host.shape()[1], waveform_host.shape()[2]}); auto result = restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), 4);
sd::ops::slice_assign(&padded, 0, 0, low_time, waveform_host); int64_t t1 = ggml_time_ms();
waveform_host = std::move(padded); LOG_INFO("ltx audio vae decode completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
} return result;
auto mel_basis_tensor = model.mel_basis_tensor();
auto stft_basis_tensor = model.stft_forward_basis_tensor();
GGML_ASSERT(mel_basis_tensor != nullptr && stft_basis_tensor != nullptr);
auto mel_basis = load_param_tensor_f32(mel_basis_tensor);
auto forward_basis = load_param_tensor_f32(stft_basis_tensor);
auto bwe_mel = compute_log_mel_spectrogram(waveform_host, forward_basis, mel_basis, config.bwe_hop_length);
auto residual_raw = compute_bwe_residual(n_threads, bwe_mel);
if (residual_raw.empty()) {
return waveform;
}
auto residual = normalize_waveform_for_host(residual_raw);
auto skip = upsample_waveform_hann(waveform_host, config.bwe_output_sample_rate / config.bwe_input_sample_rate);
auto combined = sd::ops::clamp(residual + skip, -1.0f, 1.0f);
auto cropped = crop_waveform_samples(combined, out_time);
return restore_trailing_singleton_dims(cropped, 4);
} }
void test(const std::string& input_path) { void test(const std::string& input_path) {

View File

@ -5218,14 +5218,24 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
sd_ctx->sd->diffusion_model->free_params_buffer(); sd_ctx->sd->diffusion_model->free_params_buffer();
} }
int64_t latent_end = ggml_time_ms();
LOG_INFO("generating latent video completed, taking %.2fs", (latent_end - latent_start) * 1.0f / 1000);
sd_audio_t* generated_audio = nullptr; sd_audio_t* generated_audio = nullptr;
if (sd_version_is_ltxav(sd_ctx->sd->version) && if (sd_version_is_ltxav(sd_ctx->sd->version) &&
latents.audio_length > 0 && latents.audio_length > 0 &&
sd_ctx->sd->audio_vae_model != nullptr) { sd_ctx->sd->audio_vae_model != nullptr) {
int64_t audio_latent_decode_start = ggml_time_ms();
auto audio_latent = unpack_ltxav_audio_latent(final_latent, auto audio_latent = unpack_ltxav_audio_latent(final_latent,
latents.audio_length, latents.audio_length,
sd_ctx->sd->get_latent_channel()); sd_ctx->sd->get_latent_channel());
if (!audio_latent.empty()) { if (!audio_latent.empty()) {
LOG_DEBUG("decode audio latent %dx%dx%dx%d",
(int)audio_latent.shape()[0],
(int)audio_latent.shape()[1],
(int)audio_latent.shape()[2],
(int)audio_latent.shape()[3]);
auto waveform = sd_ctx->sd->decode_ltx_audio_latent(audio_latent); auto waveform = sd_ctx->sd->decode_ltx_audio_latent(audio_latent);
if (!waveform.empty()) { if (!waveform.empty()) {
generated_audio = waveform_to_sd_audio(sd_ctx->sd, waveform); generated_audio = waveform_to_sd_audio(sd_ctx->sd, waveform);
@ -5233,6 +5243,8 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
LOG_WARN("LTX audio latent decode failed; continuing with silent video output"); LOG_WARN("LTX audio latent decode failed; continuing with silent video output");
} }
} }
int64_t audio_latent_decode_end = ggml_time_ms();
LOG_INFO("decoding audio latent completed, taking %.2fs", (audio_latent_decode_end - audio_latent_decode_start) * 1.0f / 1000);
} }
if (latents.video_conditioning_frame_count > 0) { if (latents.video_conditioning_frame_count > 0) {
@ -5245,9 +5257,6 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
final_latent = sd::ops::slice(final_latent, 2, latents.ref_image_num, final_latent.shape()[2]); final_latent = sd::ops::slice(final_latent, 2, latents.ref_image_num, final_latent.shape()[2]);
} }
int64_t latent_end = ggml_time_ms();
LOG_INFO("generating latent video completed, taking %.2fs", (latent_end - latent_start) * 1.0f / 1000);
auto result = decode_video_outputs(sd_ctx, latent_upscale_enabled ? hires_request : request, final_latent, num_frames_out); auto result = decode_video_outputs(sd_ctx, latent_upscale_enabled ? hires_request : request, final_latent, num_frames_out);
if (result == nullptr) { if (result == nullptr) {
free_sd_audio(generated_audio); free_sd_audio(generated_audio);