mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-09 15:56:39 +00:00
perf: run LTX audio VAE decode in one ggml graph (#1538)
This commit is contained in:
parent
47d8198b69
commit
2e3514625a
@ -2,6 +2,7 @@
|
|||||||
#define __SD_LTX_AUDIO_VAE_H__
|
#define __SD_LTX_AUDIO_VAE_H__
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
#include <limits>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
@ -171,90 +172,59 @@ namespace LTXV {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
static sd::Tensor<float> squeeze_trailing_singleton_dims(sd::Tensor<float> tensor) {
|
static ggml_tensor* compute_log_mel_spectrogram(GGMLRunnerContext* runner_ctx,
|
||||||
while (tensor.dim() > 0 && tensor.shape().back() == 1) {
|
ggml_tensor* waveform,
|
||||||
tensor = tensor.squeeze(static_cast<size_t>(tensor.dim() - 1));
|
ggml_tensor* forward_basis,
|
||||||
}
|
ggml_tensor* mel_basis,
|
||||||
return tensor;
|
|
||||||
}
|
|
||||||
|
|
||||||
static sd::Tensor<float> normalize_waveform_for_host(sd::Tensor<float> waveform) {
|
|
||||||
waveform = squeeze_trailing_singleton_dims(std::move(waveform));
|
|
||||||
if (waveform.empty()) {
|
|
||||||
return waveform;
|
|
||||||
}
|
|
||||||
if (waveform.dim() == 1) {
|
|
||||||
return waveform.reshape({waveform.shape()[0], 1, 1});
|
|
||||||
}
|
|
||||||
if (waveform.dim() == 2) {
|
|
||||||
return waveform.reshape({waveform.shape()[0], waveform.shape()[1], 1});
|
|
||||||
}
|
|
||||||
if (waveform.dim() == 3) {
|
|
||||||
return waveform;
|
|
||||||
}
|
|
||||||
throw std::runtime_error("Unsupported waveform rank for host processing: rank=" + std::to_string(waveform.dim()));
|
|
||||||
}
|
|
||||||
|
|
||||||
static sd::Tensor<float> load_param_tensor_f32(ggml_tensor* tensor) {
|
|
||||||
GGML_ASSERT(tensor != nullptr);
|
|
||||||
return squeeze_trailing_singleton_dims(sd::make_sd_tensor_from_ggml<float>(tensor));
|
|
||||||
}
|
|
||||||
|
|
||||||
static sd::Tensor<float> compute_log_mel_spectrogram(const sd::Tensor<float>& waveform_in,
|
|
||||||
const sd::Tensor<float>& forward_basis,
|
|
||||||
const sd::Tensor<float>& mel_basis,
|
|
||||||
int hop_length) {
|
int hop_length) {
|
||||||
auto waveform = normalize_waveform_for_host(waveform_in);
|
auto ctx = runner_ctx->ggml_ctx;
|
||||||
GGML_ASSERT(forward_basis.dim() >= 3);
|
GGML_ASSERT(ctx != nullptr);
|
||||||
GGML_ASSERT(mel_basis.dim() >= 2);
|
GGML_ASSERT(waveform != nullptr);
|
||||||
|
GGML_ASSERT(forward_basis != nullptr);
|
||||||
|
GGML_ASSERT(mel_basis != nullptr);
|
||||||
|
GGML_ASSERT(waveform->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(forward_basis->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(mel_basis->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(forward_basis->ne[1] == 1);
|
||||||
|
|
||||||
const int64_t time = waveform.shape()[0];
|
const int64_t time = waveform->ne[0];
|
||||||
const int64_t channels = waveform.shape()[1];
|
const int64_t channels = waveform->ne[1];
|
||||||
const int64_t batch = waveform.shape()[2];
|
const int64_t batch = waveform->ne[2];
|
||||||
const int64_t filter_len = forward_basis.shape()[0];
|
const int64_t filter_len = forward_basis->ne[0];
|
||||||
const int64_t basis_freq2 = forward_basis.shape().back();
|
const int64_t stft_channels = forward_basis->ne[2];
|
||||||
const int64_t n_freqs = basis_freq2 / 2;
|
const int64_t n_freqs = stft_channels / 2;
|
||||||
const int64_t n_mels = mel_basis.shape()[1];
|
const int64_t n_mels = mel_basis->ne[1];
|
||||||
const int64_t left_pad = std::max<int64_t>(0, filter_len - hop_length);
|
const int64_t left_pad = std::max<int64_t>(0, filter_len - hop_length);
|
||||||
const int64_t padded_time = time + left_pad;
|
const int64_t padded_time = time + left_pad;
|
||||||
const int64_t frame_count = padded_time < filter_len ? 0 : 1 + (padded_time - filter_len) / hop_length;
|
const int64_t frame_count = padded_time < filter_len ? 0 : 1 + (padded_time - filter_len) / hop_length;
|
||||||
|
|
||||||
sd::Tensor<float> log_mel({n_mels, frame_count, channels, batch});
|
GGML_ASSERT(stft_channels % 2 == 0);
|
||||||
std::vector<float> padded(static_cast<size_t>(padded_time), 0.0f);
|
GGML_ASSERT(mel_basis->ne[0] == n_freqs);
|
||||||
std::vector<float> magnitude(static_cast<size_t>(n_freqs), 0.0f);
|
GGML_ASSERT(waveform->ne[3] == 1);
|
||||||
|
GGML_ASSERT(frame_count > 0);
|
||||||
|
|
||||||
for (int64_t b = 0; b < batch; ++b) {
|
auto x = ggml_reshape_3d(ctx, waveform, time, 1, channels * batch);
|
||||||
for (int64_t c = 0; c < channels; ++c) {
|
if (left_pad > 0) {
|
||||||
std::fill(padded.begin(), padded.end(), 0.0f);
|
x = ggml_pad_ext(ctx, x, static_cast<int>(left_pad), 0, 0, 0, 0, 0, 0, 0);
|
||||||
for (int64_t t = 0; t < time; ++t) {
|
|
||||||
padded[static_cast<size_t>(t + left_pad)] = waveform.index(t, c, b);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int64_t frame = 0; frame < frame_count; ++frame) {
|
auto frames = ggml_conv_1d(ctx, forward_basis, x, hop_length, 0, 1);
|
||||||
const int64_t frame_offset = frame * hop_length;
|
GGML_ASSERT(frames->ne[0] == frame_count);
|
||||||
for (int64_t f = 0; f < n_freqs; ++f) {
|
GGML_ASSERT(frames->ne[1] == stft_channels);
|
||||||
double real = 0.0;
|
GGML_ASSERT(frames->ne[2] == channels * batch);
|
||||||
double imag = 0.0;
|
|
||||||
for (int64_t k = 0; k < filter_len; ++k) {
|
|
||||||
const float sample = padded[static_cast<size_t>(frame_offset + k)];
|
|
||||||
real += static_cast<double>(sample) * static_cast<double>(forward_basis.index(k, 0, f));
|
|
||||||
imag += static_cast<double>(sample) * static_cast<double>(forward_basis.index(k, 0, f + n_freqs));
|
|
||||||
}
|
|
||||||
magnitude[static_cast<size_t>(f)] = static_cast<float>(std::sqrt(real * real + imag * imag));
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int64_t m = 0; m < n_mels; ++m) {
|
auto real = ggml_ext_slice(ctx, frames, 1, 0, n_freqs);
|
||||||
double mel_value = 0.0;
|
auto imag = ggml_ext_slice(ctx, frames, 1, n_freqs, stft_channels);
|
||||||
for (int64_t f = 0; f < n_freqs; ++f) {
|
auto magnitude = ggml_sqrt(ctx,
|
||||||
mel_value += static_cast<double>(mel_basis.index(f, m)) * static_cast<double>(magnitude[static_cast<size_t>(f)]);
|
ggml_add(ctx,
|
||||||
}
|
ggml_sqr(ctx, real),
|
||||||
log_mel.index(m, frame, c, b) = static_cast<float>(std::log(std::max(mel_value, 1e-5)));
|
ggml_sqr(ctx, imag)));
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return log_mel;
|
magnitude = ggml_cont(ctx, ggml_permute(ctx, magnitude, 1, 0, 2, 3));
|
||||||
|
auto mel = ggml_mul_mat(ctx, mel_basis, magnitude);
|
||||||
|
mel = ggml_log(ctx, ggml_clamp(ctx, mel, 1e-5f, std::numeric_limits<float>::max()));
|
||||||
|
|
||||||
|
return ggml_reshape_4d(ctx, mel, n_mels, frame_count, channels, batch);
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::vector<float> build_hann_resample_filter(int ratio) {
|
static std::vector<float> build_hann_resample_filter(int ratio) {
|
||||||
@ -276,75 +246,6 @@ namespace LTXV {
|
|||||||
return filter;
|
return filter;
|
||||||
}
|
}
|
||||||
|
|
||||||
static sd::Tensor<float> upsample_waveform_hann(const sd::Tensor<float>& waveform_in, int ratio) {
|
|
||||||
auto waveform = normalize_waveform_for_host(waveform_in);
|
|
||||||
if (ratio <= 1) {
|
|
||||||
return waveform;
|
|
||||||
}
|
|
||||||
|
|
||||||
const int lowpass_filter_width = 6;
|
|
||||||
const double rolloff = 0.99;
|
|
||||||
const int width = static_cast<int>(std::ceil(static_cast<double>(lowpass_filter_width) / rolloff));
|
|
||||||
const int kernel_size = 2 * width * ratio + 1;
|
|
||||||
const int pad = width;
|
|
||||||
const int pad_left = 2 * width * ratio;
|
|
||||||
const int pad_right = kernel_size - ratio;
|
|
||||||
const int64_t time = waveform.shape()[0];
|
|
||||||
const int64_t channels = waveform.shape()[1];
|
|
||||||
const int64_t batch = waveform.shape()[2];
|
|
||||||
const int64_t padded_time = time + 2 * pad;
|
|
||||||
const int64_t conv_out_time = (padded_time - 1) * ratio + kernel_size;
|
|
||||||
const int64_t cropped_time = conv_out_time - pad_left - pad_right;
|
|
||||||
auto filter = build_hann_resample_filter(ratio);
|
|
||||||
|
|
||||||
sd::Tensor<float> output({cropped_time, channels, batch});
|
|
||||||
std::vector<float> padded(static_cast<size_t>(padded_time), 0.0f);
|
|
||||||
std::vector<float> conv_out(static_cast<size_t>(conv_out_time), 0.0f);
|
|
||||||
|
|
||||||
for (int64_t b = 0; b < batch; ++b) {
|
|
||||||
for (int64_t c = 0; c < channels; ++c) {
|
|
||||||
std::fill(padded.begin(), padded.end(), 0.0f);
|
|
||||||
const float first = waveform.index(0, c, b);
|
|
||||||
const float last = waveform.index(time - 1, c, b);
|
|
||||||
for (int i = 0; i < pad; ++i) {
|
|
||||||
padded[static_cast<size_t>(i)] = first;
|
|
||||||
padded[static_cast<size_t>(pad + time + i)] = last;
|
|
||||||
}
|
|
||||||
for (int64_t t = 0; t < time; ++t) {
|
|
||||||
padded[static_cast<size_t>(pad + t)] = waveform.index(t, c, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::fill(conv_out.begin(), conv_out.end(), 0.0f);
|
|
||||||
for (int64_t t = 0; t < padded_time; ++t) {
|
|
||||||
const double sample = static_cast<double>(padded[static_cast<size_t>(t)]) * ratio;
|
|
||||||
const int64_t out_base = t * ratio;
|
|
||||||
for (int k = 0; k < kernel_size; ++k) {
|
|
||||||
conv_out[static_cast<size_t>(out_base + k)] += static_cast<float>(sample * filter[static_cast<size_t>(k)]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int64_t t = 0; t < cropped_time; ++t) {
|
|
||||||
output.index(t, c, b) = conv_out[static_cast<size_t>(t + pad_left)];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return output;
|
|
||||||
}
|
|
||||||
|
|
||||||
static sd::Tensor<float> crop_waveform_samples(const sd::Tensor<float>& waveform_in, int64_t target_samples) {
|
|
||||||
auto waveform = normalize_waveform_for_host(waveform_in);
|
|
||||||
if (waveform.shape()[0] == target_samples) {
|
|
||||||
return waveform;
|
|
||||||
}
|
|
||||||
if (waveform.shape()[0] > target_samples) {
|
|
||||||
return sd::ops::slice(waveform, 0, 0, target_samples);
|
|
||||||
}
|
|
||||||
sd::Tensor<float> output({target_samples, waveform.shape()[1], waveform.shape()[2]});
|
|
||||||
sd::ops::slice_assign(&output, 0, 0, waveform.shape()[0], waveform);
|
|
||||||
return output;
|
|
||||||
}
|
|
||||||
|
|
||||||
static ggml_type audio_conv_weight_type(ggml_type type) {
|
static ggml_type audio_conv_weight_type(ggml_type type) {
|
||||||
return type == GGML_TYPE_BF16 ? GGML_TYPE_F16 : type;
|
return type == GGML_TYPE_BF16 ? GGML_TYPE_F16 : type;
|
||||||
}
|
}
|
||||||
@ -413,22 +314,101 @@ namespace LTXV {
|
|||||||
return ggml_reshape_4d(ctx, out, out->ne[0], out->ne[1], 1, 1);
|
return ggml_reshape_4d(ctx, out, out->ne[0], out->ne[1], 1, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static ggml_tensor* reverse_1d_filter(ggml_context* ctx, ggml_tensor* filter) {
|
||||||
|
GGML_ASSERT(ctx != nullptr);
|
||||||
|
GGML_ASSERT(filter != nullptr);
|
||||||
|
GGML_ASSERT(filter->ne[1] == 1);
|
||||||
|
GGML_ASSERT(filter->ne[2] == 1);
|
||||||
|
GGML_ASSERT(filter->ne[3] == 1);
|
||||||
|
|
||||||
|
ggml_tensor* reversed = nullptr;
|
||||||
|
for (int64_t k = filter->ne[0] - 1; k >= 0; --k) {
|
||||||
|
auto slice = ggml_ext_slice(ctx, filter, 0, k, k + 1);
|
||||||
|
reversed = reversed == nullptr ? slice : ggml_concat(ctx, reversed, slice, 0);
|
||||||
|
}
|
||||||
|
return reversed;
|
||||||
|
}
|
||||||
|
|
||||||
static ggml_tensor* depthwise_conv_transpose1d(ggml_context* ctx,
|
static ggml_tensor* depthwise_conv_transpose1d(ggml_context* ctx,
|
||||||
ggml_tensor* x,
|
ggml_tensor* x,
|
||||||
ggml_tensor* filter,
|
ggml_tensor* filter,
|
||||||
int stride) {
|
int stride) {
|
||||||
GGML_ASSERT(x->ne[2] == 1 && x->ne[3] == 1);
|
GGML_ASSERT(x->ne[2] == 1 && x->ne[3] == 1);
|
||||||
GGML_ASSERT(filter->ne[1] == 1);
|
GGML_ASSERT(filter->ne[1] == 1);
|
||||||
|
GGML_ASSERT(filter->ne[2] == 1 && filter->ne[3] == 1);
|
||||||
|
|
||||||
ggml_tensor* out = nullptr;
|
const int64_t time = x->ne[0];
|
||||||
for (int64_t c = 0; c < x->ne[1]; ++c) {
|
const int64_t channels = x->ne[1];
|
||||||
auto xi = ggml_ext_slice(ctx, x, 1, c, c + 1);
|
const int64_t kernel_size = filter->ne[0];
|
||||||
auto yi = ggml_conv_transpose_1d(ctx, filter, xi, stride, 0, 1);
|
const int64_t out_time = (time - 1) * stride + kernel_size;
|
||||||
yi = ggml_ext_scale(ctx, yi, static_cast<float>(stride));
|
|
||||||
yi = ggml_reshape_4d(ctx, yi, yi->ne[0], 1, 1, 1);
|
auto x_flat = ggml_reshape_3d(ctx, x, 1, time, channels);
|
||||||
out = out == nullptr ? yi : ggml_concat(ctx, out, yi, 1);
|
if (stride > 1) {
|
||||||
|
auto zero_unit = ggml_ext_scale(ctx, x_flat, 0.0f);
|
||||||
|
auto zero_tail = zero_unit;
|
||||||
|
for (int i = 1; i < stride - 1; ++i) {
|
||||||
|
zero_tail = ggml_concat(ctx, zero_tail, zero_unit, 0);
|
||||||
}
|
}
|
||||||
return out;
|
x_flat = ggml_concat(ctx, x_flat, zero_tail, 0);
|
||||||
|
}
|
||||||
|
x_flat = ggml_reshape_3d(ctx, x_flat, time * stride, 1, channels);
|
||||||
|
|
||||||
|
auto reversed_filter = reverse_1d_filter(ctx, filter);
|
||||||
|
auto out = ggml_conv_1d(ctx, reversed_filter, x_flat, 1, static_cast<int>(kernel_size - 1), 1);
|
||||||
|
if (out->ne[0] > out_time) {
|
||||||
|
out = ggml_ext_slice(ctx, out, 0, 0, out_time);
|
||||||
|
}
|
||||||
|
GGML_ASSERT(out->ne[0] == out_time);
|
||||||
|
GGML_ASSERT(out->ne[1] == 1);
|
||||||
|
GGML_ASSERT(out->ne[2] == channels);
|
||||||
|
|
||||||
|
out = ggml_ext_scale(ctx, out, static_cast<float>(stride));
|
||||||
|
return ggml_reshape_4d(ctx, out, out_time, channels, 1, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
static ggml_tensor* upsample_waveform_hann(GGMLRunnerContext* runner_ctx,
|
||||||
|
ggml_tensor* waveform,
|
||||||
|
ggml_tensor* filter,
|
||||||
|
int ratio) {
|
||||||
|
auto ctx = runner_ctx->ggml_ctx;
|
||||||
|
GGML_ASSERT(ctx != nullptr);
|
||||||
|
GGML_ASSERT(waveform != nullptr);
|
||||||
|
GGML_ASSERT(filter != nullptr);
|
||||||
|
GGML_ASSERT(waveform->ne[3] == 1);
|
||||||
|
if (ratio <= 1) {
|
||||||
|
return waveform;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int lowpass_filter_width = 6;
|
||||||
|
const double rolloff = 0.99;
|
||||||
|
const int width = static_cast<int>(std::ceil(static_cast<double>(lowpass_filter_width) / rolloff));
|
||||||
|
const int kernel_size = 2 * width * ratio + 1;
|
||||||
|
const int pad = width;
|
||||||
|
const int pad_left = 2 * width * ratio;
|
||||||
|
const int pad_right = kernel_size - ratio;
|
||||||
|
const int64_t time = waveform->ne[0];
|
||||||
|
const int64_t channels = waveform->ne[1];
|
||||||
|
const int64_t batch = waveform->ne[2];
|
||||||
|
|
||||||
|
GGML_ASSERT(filter->ne[0] == kernel_size);
|
||||||
|
|
||||||
|
auto x = ggml_reshape_3d(ctx, waveform, time, channels * batch, 1);
|
||||||
|
x = replicate_pad_1d(runner_ctx, x, pad, pad);
|
||||||
|
x = depthwise_conv_transpose1d(ctx, x, filter, ratio);
|
||||||
|
x = ggml_ext_slice(ctx, x, 0, pad_left, x->ne[0] - pad_right);
|
||||||
|
return ggml_reshape_3d(ctx, x, x->ne[0], channels, batch);
|
||||||
|
}
|
||||||
|
|
||||||
|
static ggml_tensor* crop_waveform_samples(ggml_context* ctx,
|
||||||
|
ggml_tensor* waveform,
|
||||||
|
int64_t target_samples) {
|
||||||
|
GGML_ASSERT(ctx != nullptr);
|
||||||
|
GGML_ASSERT(waveform != nullptr);
|
||||||
|
if (waveform->ne[0] == target_samples) {
|
||||||
|
return waveform;
|
||||||
|
}
|
||||||
|
GGML_ASSERT(waveform->ne[0] > target_samples);
|
||||||
|
return ggml_ext_slice(ctx, waveform, 0, 0, target_samples);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct PixelNorm2D : public UnaryBlock {
|
struct PixelNorm2D : public UnaryBlock {
|
||||||
@ -950,41 +930,66 @@ namespace LTXV {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor* decode_to_mel(GGMLRunnerContext* ctx,
|
ggml_tensor* decode(GGMLRunnerContext* ctx,
|
||||||
ggml_tensor* latent,
|
ggml_tensor* latent,
|
||||||
int target_time,
|
ggml_tensor* bwe_skip_filter) {
|
||||||
int target_freq) {
|
int target_time = static_cast<int>(latent->ne[1]) * config.latent_downsample_factor() -
|
||||||
|
(config.latent_downsample_factor() - 1);
|
||||||
|
int target_freq = config.mel_bins;
|
||||||
|
|
||||||
|
auto decoder = std::dynamic_pointer_cast<AudioDecoder>(blocks["audio_vae.decoder"]);
|
||||||
auto mean = params["audio_vae.per_channel_statistics.mean-of-means"];
|
auto mean = params["audio_vae.per_channel_statistics.mean-of-means"];
|
||||||
auto stddev = params["audio_vae.per_channel_statistics.std-of-means"];
|
auto stddev = params["audio_vae.per_channel_statistics.std-of-means"];
|
||||||
auto decoder = std::dynamic_pointer_cast<AudioDecoder>(blocks["audio_vae.decoder"]);
|
auto mel = decoder->forward(ctx, latent, mean, stddev, target_time, target_freq);
|
||||||
return decoder->forward(ctx, latent, mean, stddev, target_time, target_freq);
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_tensor* run_vocoder(GGMLRunnerContext* ctx, ggml_tensor* mel) {
|
|
||||||
auto vocoder = std::dynamic_pointer_cast<Vocoder>(blocks["vocoder.vocoder"]);
|
auto vocoder = std::dynamic_pointer_cast<Vocoder>(blocks["vocoder.vocoder"]);
|
||||||
return vocoder->forward(ctx, mel);
|
auto waveform = vocoder->forward(ctx, mel);
|
||||||
|
|
||||||
|
if (config.has_bwe) {
|
||||||
|
GGML_ASSERT(bwe_skip_filter != nullptr);
|
||||||
|
const int bwe_ratio = config.bwe_output_sample_rate / config.bwe_input_sample_rate;
|
||||||
|
const int64_t low_time = waveform->ne[0];
|
||||||
|
const int64_t out_time = low_time * bwe_ratio;
|
||||||
|
int64_t remainder = low_time % config.bwe_hop_length;
|
||||||
|
auto bwe_waveform = waveform;
|
||||||
|
if (remainder != 0) {
|
||||||
|
bwe_waveform = ggml_pad_ext(ctx->ggml_ctx,
|
||||||
|
bwe_waveform,
|
||||||
|
0,
|
||||||
|
static_cast<int>(config.bwe_hop_length - remainder),
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0);
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor* run_bwe_generator(GGMLRunnerContext* ctx, ggml_tensor* mel) {
|
auto mel_basis = params["vocoder.mel_stft.mel_basis"];
|
||||||
GGML_ASSERT(config.has_bwe);
|
auto stft_basis = params["vocoder.mel_stft.stft_fn.forward_basis"];
|
||||||
|
GGML_ASSERT(mel_basis != nullptr && stft_basis != nullptr);
|
||||||
|
auto bwe_mel = compute_log_mel_spectrogram(ctx, bwe_waveform, stft_basis, mel_basis, config.bwe_hop_length);
|
||||||
auto bwe_generator = std::dynamic_pointer_cast<Vocoder>(blocks["vocoder.bwe_generator"]);
|
auto bwe_generator = std::dynamic_pointer_cast<Vocoder>(blocks["vocoder.bwe_generator"]);
|
||||||
return bwe_generator->forward(ctx, mel);
|
auto residual = bwe_generator->forward(ctx, bwe_mel);
|
||||||
|
|
||||||
|
auto skip = upsample_waveform_hann(ctx,
|
||||||
|
bwe_waveform,
|
||||||
|
bwe_skip_filter,
|
||||||
|
bwe_ratio);
|
||||||
|
waveform = ggml_clamp(ctx->ggml_ctx,
|
||||||
|
ggml_add(ctx->ggml_ctx, residual, skip),
|
||||||
|
-1.0f,
|
||||||
|
1.0f);
|
||||||
|
waveform = crop_waveform_samples(ctx->ggml_ctx, waveform, out_time);
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor* mel_basis_tensor() const {
|
return waveform;
|
||||||
auto iter = params.find("vocoder.mel_stft.mel_basis");
|
|
||||||
return iter == params.end() ? nullptr : iter->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_tensor* stft_forward_basis_tensor() const {
|
|
||||||
auto iter = params.find("vocoder.mel_stft.stft_fn.forward_basis");
|
|
||||||
return iter == params.end() ? nullptr : iter->second;
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct LTXAudioVAERunner : public GGMLRunner {
|
struct LTXAudioVAERunner : public GGMLRunner {
|
||||||
LTXAudioVAEConfig config;
|
LTXAudioVAEConfig config;
|
||||||
LTXAudioVAE model;
|
LTXAudioVAE model;
|
||||||
|
sd::Tensor<float> bwe_skip_filter_tensor;
|
||||||
|
|
||||||
LTXAudioVAERunner(ggml_backend_t backend,
|
LTXAudioVAERunner(ggml_backend_t backend,
|
||||||
ggml_backend_t params_backend,
|
ggml_backend_t params_backend,
|
||||||
@ -994,6 +999,10 @@ namespace LTXV {
|
|||||||
config(LTXAudioVAEConfig::detect_from_weights(tensor_storage_map)),
|
config(LTXAudioVAEConfig::detect_from_weights(tensor_storage_map)),
|
||||||
model(config) {
|
model(config) {
|
||||||
model.init(params_ctx, tensor_storage_map, prefix);
|
model.init(params_ctx, tensor_storage_map, prefix);
|
||||||
|
if (config.has_bwe) {
|
||||||
|
const int bwe_ratio = config.bwe_output_sample_rate / config.bwe_input_sample_rate;
|
||||||
|
bwe_skip_filter_tensor = sd::Tensor<float>::from_vector(build_hann_resample_filter(bwe_ratio));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) {
|
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) {
|
||||||
@ -1008,77 +1017,22 @@ namespace LTXV {
|
|||||||
return "ltx_audio_vae";
|
return "ltx_audio_vae";
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_cgraph* build_base_graph(const sd::Tensor<float>& latent_tensor) {
|
|
||||||
auto latent = make_input(latent_tensor);
|
|
||||||
int target_time = static_cast<int>(latent_tensor.shape()[1]) * config.latent_downsample_factor() -
|
|
||||||
(config.latent_downsample_factor() - 1);
|
|
||||||
int target_freq = config.mel_bins;
|
|
||||||
|
|
||||||
ggml_cgraph* gf = new_graph_custom(655360);
|
|
||||||
auto runner_ctx = GGMLRunner::get_context();
|
|
||||||
auto mel = model.decode_to_mel(&runner_ctx, latent, target_time, target_freq);
|
|
||||||
auto waveform = model.run_vocoder(&runner_ctx, mel);
|
|
||||||
ggml_build_forward_expand(gf, waveform);
|
|
||||||
return gf;
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_cgraph* build_bwe_graph(const sd::Tensor<float>& mel_tensor) {
|
|
||||||
auto mel = make_input(mel_tensor);
|
|
||||||
ggml_cgraph* gf = new_graph_custom(655360);
|
|
||||||
auto runner_ctx = GGMLRunner::get_context();
|
|
||||||
auto residual = model.run_bwe_generator(&runner_ctx, mel);
|
|
||||||
ggml_build_forward_expand(gf, residual);
|
|
||||||
return gf;
|
|
||||||
}
|
|
||||||
|
|
||||||
sd::Tensor<float> compute_base_waveform(int n_threads,
|
|
||||||
const sd::Tensor<float>& latent_tensor) {
|
|
||||||
auto get_graph = [&]() -> ggml_cgraph* {
|
|
||||||
return build_base_graph(latent_tensor);
|
|
||||||
};
|
|
||||||
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), 4);
|
|
||||||
}
|
|
||||||
|
|
||||||
sd::Tensor<float> compute_bwe_residual(int n_threads,
|
|
||||||
const sd::Tensor<float>& mel_tensor) {
|
|
||||||
auto get_graph = [&]() -> ggml_cgraph* {
|
|
||||||
return build_bwe_graph(mel_tensor);
|
|
||||||
};
|
|
||||||
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), 4);
|
|
||||||
}
|
|
||||||
|
|
||||||
sd::Tensor<float> decode(int n_threads,
|
sd::Tensor<float> decode(int n_threads,
|
||||||
const sd::Tensor<float>& latent_tensor) {
|
const sd::Tensor<float>& latent_tensor) {
|
||||||
auto waveform = compute_base_waveform(n_threads, latent_tensor);
|
int64_t t0 = ggml_time_ms();
|
||||||
if (!config.has_bwe || waveform.empty()) {
|
auto get_graph = [&]() -> ggml_cgraph* {
|
||||||
return waveform;
|
auto latent = make_input(latent_tensor);
|
||||||
}
|
ggml_tensor* bwe_skip_filter = config.has_bwe ? make_input(bwe_skip_filter_tensor) : nullptr;
|
||||||
|
ggml_cgraph* gf = new_graph_custom(655360);
|
||||||
auto waveform_host = normalize_waveform_for_host(waveform);
|
auto runner_ctx = GGMLRunner::get_context();
|
||||||
const int64_t low_time = waveform_host.shape()[0];
|
auto waveform = model.decode(&runner_ctx, latent, bwe_skip_filter);
|
||||||
const int64_t out_time = low_time * config.bwe_output_sample_rate / config.bwe_input_sample_rate;
|
ggml_build_forward_expand(gf, waveform);
|
||||||
int64_t remainder = low_time % config.bwe_hop_length;
|
return gf;
|
||||||
if (remainder != 0) {
|
};
|
||||||
sd::Tensor<float> padded({low_time + (config.bwe_hop_length - remainder), waveform_host.shape()[1], waveform_host.shape()[2]});
|
auto result = restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), 4);
|
||||||
sd::ops::slice_assign(&padded, 0, 0, low_time, waveform_host);
|
int64_t t1 = ggml_time_ms();
|
||||||
waveform_host = std::move(padded);
|
LOG_INFO("ltx audio vae decode completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
|
||||||
}
|
return result;
|
||||||
|
|
||||||
auto mel_basis_tensor = model.mel_basis_tensor();
|
|
||||||
auto stft_basis_tensor = model.stft_forward_basis_tensor();
|
|
||||||
GGML_ASSERT(mel_basis_tensor != nullptr && stft_basis_tensor != nullptr);
|
|
||||||
auto mel_basis = load_param_tensor_f32(mel_basis_tensor);
|
|
||||||
auto forward_basis = load_param_tensor_f32(stft_basis_tensor);
|
|
||||||
auto bwe_mel = compute_log_mel_spectrogram(waveform_host, forward_basis, mel_basis, config.bwe_hop_length);
|
|
||||||
auto residual_raw = compute_bwe_residual(n_threads, bwe_mel);
|
|
||||||
if (residual_raw.empty()) {
|
|
||||||
return waveform;
|
|
||||||
}
|
|
||||||
auto residual = normalize_waveform_for_host(residual_raw);
|
|
||||||
auto skip = upsample_waveform_hann(waveform_host, config.bwe_output_sample_rate / config.bwe_input_sample_rate);
|
|
||||||
auto combined = sd::ops::clamp(residual + skip, -1.0f, 1.0f);
|
|
||||||
auto cropped = crop_waveform_samples(combined, out_time);
|
|
||||||
return restore_trailing_singleton_dims(cropped, 4);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void test(const std::string& input_path) {
|
void test(const std::string& input_path) {
|
||||||
|
|||||||
@ -5218,14 +5218,24 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
|
|||||||
sd_ctx->sd->diffusion_model->free_params_buffer();
|
sd_ctx->sd->diffusion_model->free_params_buffer();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int64_t latent_end = ggml_time_ms();
|
||||||
|
LOG_INFO("generating latent video completed, taking %.2fs", (latent_end - latent_start) * 1.0f / 1000);
|
||||||
|
|
||||||
sd_audio_t* generated_audio = nullptr;
|
sd_audio_t* generated_audio = nullptr;
|
||||||
if (sd_version_is_ltxav(sd_ctx->sd->version) &&
|
if (sd_version_is_ltxav(sd_ctx->sd->version) &&
|
||||||
latents.audio_length > 0 &&
|
latents.audio_length > 0 &&
|
||||||
sd_ctx->sd->audio_vae_model != nullptr) {
|
sd_ctx->sd->audio_vae_model != nullptr) {
|
||||||
|
int64_t audio_latent_decode_start = ggml_time_ms();
|
||||||
|
|
||||||
auto audio_latent = unpack_ltxav_audio_latent(final_latent,
|
auto audio_latent = unpack_ltxav_audio_latent(final_latent,
|
||||||
latents.audio_length,
|
latents.audio_length,
|
||||||
sd_ctx->sd->get_latent_channel());
|
sd_ctx->sd->get_latent_channel());
|
||||||
if (!audio_latent.empty()) {
|
if (!audio_latent.empty()) {
|
||||||
|
LOG_DEBUG("decode audio latent %dx%dx%dx%d",
|
||||||
|
(int)audio_latent.shape()[0],
|
||||||
|
(int)audio_latent.shape()[1],
|
||||||
|
(int)audio_latent.shape()[2],
|
||||||
|
(int)audio_latent.shape()[3]);
|
||||||
auto waveform = sd_ctx->sd->decode_ltx_audio_latent(audio_latent);
|
auto waveform = sd_ctx->sd->decode_ltx_audio_latent(audio_latent);
|
||||||
if (!waveform.empty()) {
|
if (!waveform.empty()) {
|
||||||
generated_audio = waveform_to_sd_audio(sd_ctx->sd, waveform);
|
generated_audio = waveform_to_sd_audio(sd_ctx->sd, waveform);
|
||||||
@ -5233,6 +5243,8 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
|
|||||||
LOG_WARN("LTX audio latent decode failed; continuing with silent video output");
|
LOG_WARN("LTX audio latent decode failed; continuing with silent video output");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
int64_t audio_latent_decode_end = ggml_time_ms();
|
||||||
|
LOG_INFO("decoding audio latent completed, taking %.2fs", (audio_latent_decode_end - audio_latent_decode_start) * 1.0f / 1000);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (latents.video_conditioning_frame_count > 0) {
|
if (latents.video_conditioning_frame_count > 0) {
|
||||||
@ -5245,9 +5257,6 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
|
|||||||
final_latent = sd::ops::slice(final_latent, 2, latents.ref_image_num, final_latent.shape()[2]);
|
final_latent = sd::ops::slice(final_latent, 2, latents.ref_image_num, final_latent.shape()[2]);
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t latent_end = ggml_time_ms();
|
|
||||||
LOG_INFO("generating latent video completed, taking %.2fs", (latent_end - latent_start) * 1.0f / 1000);
|
|
||||||
|
|
||||||
auto result = decode_video_outputs(sd_ctx, latent_upscale_enabled ? hires_request : request, final_latent, num_frames_out);
|
auto result = decode_video_outputs(sd_ctx, latent_upscale_enabled ? hires_request : request, final_latent, num_frames_out);
|
||||||
if (result == nullptr) {
|
if (result == nullptr) {
|
||||||
free_sd_audio(generated_audio);
|
free_sd_audio(generated_audio);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user