mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-24 23:26:43 +00:00
Compare commits
2 Commits
ef92a0027e
...
2e3514625a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2e3514625a | ||
|
|
47d8198b69 |
@ -1602,6 +1602,23 @@ __STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) {
|
|||||||
return num;
|
return num;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Concatenate a list of tensors along `dim` with a pairwise (tree-shaped) reduction.
 *
 * In every round adjacent tensors are joined two at a time with ggml_concat; an
 * unpaired trailing tensor is carried over unchanged, so the original order of the
 * inputs is preserved in the result. The number of tensors roughly halves per round.
 *
 * @param ctx     ggml context used to create the concat nodes.
 * @param tensors Tensors to join, in order. Must not be empty.
 *                NOTE: the vector is used as scratch space; on return it holds only
 *                the final tensor, not the original inputs.
 * @param dim     Dimension index forwarded to ggml_concat.
 * @return The concatenated tensor (the sole input itself when only one is given).
 */
__STATIC_INLINE__ ggml_tensor* ggml_ext_vec_concat(ggml_context* ctx,
                                                   std::vector<ggml_tensor*>& tensors,
                                                   int dim) {
    // Reading tensors[0] from an empty vector would be undefined behaviour.
    GGML_ASSERT(!tensors.empty());

    while (tensors.size() > 1) {
        std::vector<ggml_tensor*> next_level;
        next_level.reserve((tensors.size() + 1) / 2);
        for (size_t i = 0; i < tensors.size(); i += 2) {
            if (i + 1 < tensors.size()) {
                next_level.push_back(ggml_concat(ctx, tensors[i], tensors[i + 1], dim));
            } else {
                // Odd count: the last tensor has no partner in this round.
                next_level.push_back(tensors[i]);
            }
        }
        tensors = std::move(next_level);
    }
    return tensors[0];
}
|
||||||
|
|
||||||
/* SDXL with LoRA requires more space */
|
/* SDXL with LoRA requires more space */
|
||||||
#define MAX_PARAMS_TENSOR_NUM 32768
|
#define MAX_PARAMS_TENSOR_NUM 32768
|
||||||
#define MAX_GRAPH_SIZE 327680
|
#define MAX_GRAPH_SIZE 327680
|
||||||
@ -3139,6 +3156,163 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class Conv2d_grouped : public UnaryBlock {
|
||||||
|
protected:
|
||||||
|
int64_t in_channels;
|
||||||
|
int64_t out_channels;
|
||||||
|
int groups;
|
||||||
|
std::pair<int, int> kernel_size;
|
||||||
|
std::pair<int, int> stride;
|
||||||
|
std::pair<int, int> padding;
|
||||||
|
std::pair<int, int> dilation;
|
||||||
|
bool bias;
|
||||||
|
float scale = 1.f;
|
||||||
|
std::string prefix;
|
||||||
|
|
||||||
|
void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map, const std::string prefix = "") override {
|
||||||
|
this->prefix = prefix;
|
||||||
|
enum ggml_type wtype = GGML_TYPE_F16;
|
||||||
|
params["weight"] = ggml_new_tensor_4d(ctx, wtype, kernel_size.second, kernel_size.first, in_channels / groups, out_channels);
|
||||||
|
if (bias) {
|
||||||
|
enum ggml_type wtype = GGML_TYPE_F32;
|
||||||
|
params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_channels);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
Conv2d_grouped(int64_t in_channels,
|
||||||
|
int64_t out_channels,
|
||||||
|
int groups,
|
||||||
|
std::pair<int, int> kernel_size,
|
||||||
|
std::pair<int, int> stride = {1, 1},
|
||||||
|
std::pair<int, int> padding = {0, 0},
|
||||||
|
std::pair<int, int> dilation = {1, 1},
|
||||||
|
bool bias = true)
|
||||||
|
: in_channels(in_channels),
|
||||||
|
out_channels(out_channels),
|
||||||
|
groups(groups),
|
||||||
|
kernel_size(kernel_size),
|
||||||
|
stride(stride),
|
||||||
|
padding(padding),
|
||||||
|
dilation(dilation),
|
||||||
|
bias(bias) {}
|
||||||
|
|
||||||
|
void set_scale(float scale_value) {
|
||||||
|
scale = scale_value;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string get_desc() {
|
||||||
|
return "Conv2d_grouped";
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
|
||||||
|
ggml_tensor* w = params["weight"];
|
||||||
|
ggml_tensor* b = nullptr;
|
||||||
|
if (bias) {
|
||||||
|
b = params["bias"];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (groups == 1) {
|
||||||
|
if (ctx->weight_adapter) {
|
||||||
|
WeightAdapter::ForwardParams forward_params;
|
||||||
|
forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_CONV2D;
|
||||||
|
forward_params.conv2d.s0 = stride.second;
|
||||||
|
forward_params.conv2d.s1 = stride.first;
|
||||||
|
forward_params.conv2d.p0 = padding.second;
|
||||||
|
forward_params.conv2d.p1 = padding.first;
|
||||||
|
forward_params.conv2d.d0 = dilation.second;
|
||||||
|
forward_params.conv2d.d1 = dilation.first;
|
||||||
|
forward_params.conv2d.direct = ctx->conv2d_direct_enabled;
|
||||||
|
forward_params.conv2d.circular_x = ctx->circular_x_enabled;
|
||||||
|
forward_params.conv2d.circular_y = ctx->circular_y_enabled;
|
||||||
|
forward_params.conv2d.scale = scale;
|
||||||
|
return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, ctx->backend, x, w, b, prefix, forward_params);
|
||||||
|
}
|
||||||
|
return ggml_ext_conv_2d(ctx->ggml_ctx, x, w, b,
|
||||||
|
stride.second, stride.first,
|
||||||
|
padding.second, padding.first,
|
||||||
|
dilation.second, dilation.first,
|
||||||
|
ctx->conv2d_direct_enabled,
|
||||||
|
ctx->circular_x_enabled,
|
||||||
|
ctx->circular_y_enabled,
|
||||||
|
scale);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (groups == in_channels && groups == out_channels) {
|
||||||
|
ggml_tensor* res;
|
||||||
|
if (ctx->conv2d_direct_enabled) {
|
||||||
|
res = ggml_conv_2d_dw_direct(ctx->ggml_ctx, x, w,
|
||||||
|
stride.second, stride.first,
|
||||||
|
padding.second, padding.first,
|
||||||
|
dilation.second, dilation.first);
|
||||||
|
} else {
|
||||||
|
res = ggml_conv_2d_dw(ctx->ggml_ctx, x, w,
|
||||||
|
stride.second, stride.first,
|
||||||
|
padding.second, padding.first,
|
||||||
|
dilation.second, dilation.first);
|
||||||
|
}
|
||||||
|
if (b) {
|
||||||
|
res = ggml_add(ctx->ggml_ctx, res, b);
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t ic_g = in_channels / groups;
|
||||||
|
int64_t oc_g = out_channels / groups;
|
||||||
|
|
||||||
|
std::vector<ggml_tensor*> out_slices(groups);
|
||||||
|
|
||||||
|
for (int i = 0; i < groups; ++i) {
|
||||||
|
size_t x_offset = i * ic_g * x->nb[2];
|
||||||
|
ggml_tensor* x_i = ggml_view_4d(ctx->ggml_ctx, x,
|
||||||
|
x->ne[0], x->ne[1], ic_g, x->ne[3],
|
||||||
|
x->nb[1], x->nb[2], x->nb[3],
|
||||||
|
x_offset);
|
||||||
|
|
||||||
|
size_t w_offset = i * oc_g * w->nb[3];
|
||||||
|
ggml_tensor* w_i = ggml_view_4d(ctx->ggml_ctx, w,
|
||||||
|
w->ne[0], w->ne[1], w->ne[2], oc_g,
|
||||||
|
w->nb[1], w->nb[2], w->nb[3],
|
||||||
|
w_offset);
|
||||||
|
|
||||||
|
ggml_tensor* b_i = nullptr;
|
||||||
|
if (b) {
|
||||||
|
size_t b_offset = i * oc_g * b->nb[0];
|
||||||
|
b_i = ggml_view_1d(ctx->ggml_ctx, b, oc_g, b_offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ctx->weight_adapter) {
|
||||||
|
WeightAdapter::ForwardParams forward_params;
|
||||||
|
forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_CONV2D;
|
||||||
|
forward_params.conv2d.s0 = stride.second;
|
||||||
|
forward_params.conv2d.s1 = stride.first;
|
||||||
|
forward_params.conv2d.p0 = padding.second;
|
||||||
|
forward_params.conv2d.p1 = padding.first;
|
||||||
|
forward_params.conv2d.d0 = dilation.second;
|
||||||
|
forward_params.conv2d.d1 = dilation.first;
|
||||||
|
forward_params.conv2d.direct = ctx->conv2d_direct_enabled;
|
||||||
|
forward_params.conv2d.circular_x = ctx->circular_x_enabled;
|
||||||
|
forward_params.conv2d.circular_y = ctx->circular_y_enabled;
|
||||||
|
forward_params.conv2d.scale = scale;
|
||||||
|
out_slices[i] = ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, ctx->backend, x_i, w_i, b_i, prefix, forward_params);
|
||||||
|
} else {
|
||||||
|
out_slices[i] = ggml_ext_conv_2d(ctx->ggml_ctx, x_i, w_i, b_i,
|
||||||
|
stride.second, stride.first,
|
||||||
|
padding.second, padding.first,
|
||||||
|
dilation.second, dilation.first,
|
||||||
|
ctx->conv2d_direct_enabled,
|
||||||
|
ctx->circular_x_enabled,
|
||||||
|
ctx->circular_y_enabled,
|
||||||
|
scale);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* out = ggml_ext_vec_concat(ctx->ggml_ctx, out_slices, 2);
|
||||||
|
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
class Conv3d : public UnaryBlock {
|
class Conv3d : public UnaryBlock {
|
||||||
protected:
|
protected:
|
||||||
int64_t in_channels;
|
int64_t in_channels;
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
#define __SD_LTX_AUDIO_VAE_H__
|
#define __SD_LTX_AUDIO_VAE_H__
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
#include <limits>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
@ -171,90 +172,59 @@ namespace LTXV {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
static sd::Tensor<float> squeeze_trailing_singleton_dims(sd::Tensor<float> tensor) {
|
static ggml_tensor* compute_log_mel_spectrogram(GGMLRunnerContext* runner_ctx,
|
||||||
while (tensor.dim() > 0 && tensor.shape().back() == 1) {
|
ggml_tensor* waveform,
|
||||||
tensor = tensor.squeeze(static_cast<size_t>(tensor.dim() - 1));
|
ggml_tensor* forward_basis,
|
||||||
}
|
ggml_tensor* mel_basis,
|
||||||
return tensor;
|
|
||||||
}
|
|
||||||
|
|
||||||
static sd::Tensor<float> normalize_waveform_for_host(sd::Tensor<float> waveform) {
|
|
||||||
waveform = squeeze_trailing_singleton_dims(std::move(waveform));
|
|
||||||
if (waveform.empty()) {
|
|
||||||
return waveform;
|
|
||||||
}
|
|
||||||
if (waveform.dim() == 1) {
|
|
||||||
return waveform.reshape({waveform.shape()[0], 1, 1});
|
|
||||||
}
|
|
||||||
if (waveform.dim() == 2) {
|
|
||||||
return waveform.reshape({waveform.shape()[0], waveform.shape()[1], 1});
|
|
||||||
}
|
|
||||||
if (waveform.dim() == 3) {
|
|
||||||
return waveform;
|
|
||||||
}
|
|
||||||
throw std::runtime_error("Unsupported waveform rank for host processing: rank=" + std::to_string(waveform.dim()));
|
|
||||||
}
|
|
||||||
|
|
||||||
static sd::Tensor<float> load_param_tensor_f32(ggml_tensor* tensor) {
|
|
||||||
GGML_ASSERT(tensor != nullptr);
|
|
||||||
return squeeze_trailing_singleton_dims(sd::make_sd_tensor_from_ggml<float>(tensor));
|
|
||||||
}
|
|
||||||
|
|
||||||
static sd::Tensor<float> compute_log_mel_spectrogram(const sd::Tensor<float>& waveform_in,
|
|
||||||
const sd::Tensor<float>& forward_basis,
|
|
||||||
const sd::Tensor<float>& mel_basis,
|
|
||||||
int hop_length) {
|
int hop_length) {
|
||||||
auto waveform = normalize_waveform_for_host(waveform_in);
|
auto ctx = runner_ctx->ggml_ctx;
|
||||||
GGML_ASSERT(forward_basis.dim() >= 3);
|
GGML_ASSERT(ctx != nullptr);
|
||||||
GGML_ASSERT(mel_basis.dim() >= 2);
|
GGML_ASSERT(waveform != nullptr);
|
||||||
|
GGML_ASSERT(forward_basis != nullptr);
|
||||||
|
GGML_ASSERT(mel_basis != nullptr);
|
||||||
|
GGML_ASSERT(waveform->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(forward_basis->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(mel_basis->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(forward_basis->ne[1] == 1);
|
||||||
|
|
||||||
const int64_t time = waveform.shape()[0];
|
const int64_t time = waveform->ne[0];
|
||||||
const int64_t channels = waveform.shape()[1];
|
const int64_t channels = waveform->ne[1];
|
||||||
const int64_t batch = waveform.shape()[2];
|
const int64_t batch = waveform->ne[2];
|
||||||
const int64_t filter_len = forward_basis.shape()[0];
|
const int64_t filter_len = forward_basis->ne[0];
|
||||||
const int64_t basis_freq2 = forward_basis.shape().back();
|
const int64_t stft_channels = forward_basis->ne[2];
|
||||||
const int64_t n_freqs = basis_freq2 / 2;
|
const int64_t n_freqs = stft_channels / 2;
|
||||||
const int64_t n_mels = mel_basis.shape()[1];
|
const int64_t n_mels = mel_basis->ne[1];
|
||||||
const int64_t left_pad = std::max<int64_t>(0, filter_len - hop_length);
|
const int64_t left_pad = std::max<int64_t>(0, filter_len - hop_length);
|
||||||
const int64_t padded_time = time + left_pad;
|
const int64_t padded_time = time + left_pad;
|
||||||
const int64_t frame_count = padded_time < filter_len ? 0 : 1 + (padded_time - filter_len) / hop_length;
|
const int64_t frame_count = padded_time < filter_len ? 0 : 1 + (padded_time - filter_len) / hop_length;
|
||||||
|
|
||||||
sd::Tensor<float> log_mel({n_mels, frame_count, channels, batch});
|
GGML_ASSERT(stft_channels % 2 == 0);
|
||||||
std::vector<float> padded(static_cast<size_t>(padded_time), 0.0f);
|
GGML_ASSERT(mel_basis->ne[0] == n_freqs);
|
||||||
std::vector<float> magnitude(static_cast<size_t>(n_freqs), 0.0f);
|
GGML_ASSERT(waveform->ne[3] == 1);
|
||||||
|
GGML_ASSERT(frame_count > 0);
|
||||||
|
|
||||||
for (int64_t b = 0; b < batch; ++b) {
|
auto x = ggml_reshape_3d(ctx, waveform, time, 1, channels * batch);
|
||||||
for (int64_t c = 0; c < channels; ++c) {
|
if (left_pad > 0) {
|
||||||
std::fill(padded.begin(), padded.end(), 0.0f);
|
x = ggml_pad_ext(ctx, x, static_cast<int>(left_pad), 0, 0, 0, 0, 0, 0, 0);
|
||||||
for (int64_t t = 0; t < time; ++t) {
|
|
||||||
padded[static_cast<size_t>(t + left_pad)] = waveform.index(t, c, b);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int64_t frame = 0; frame < frame_count; ++frame) {
|
auto frames = ggml_conv_1d(ctx, forward_basis, x, hop_length, 0, 1);
|
||||||
const int64_t frame_offset = frame * hop_length;
|
GGML_ASSERT(frames->ne[0] == frame_count);
|
||||||
for (int64_t f = 0; f < n_freqs; ++f) {
|
GGML_ASSERT(frames->ne[1] == stft_channels);
|
||||||
double real = 0.0;
|
GGML_ASSERT(frames->ne[2] == channels * batch);
|
||||||
double imag = 0.0;
|
|
||||||
for (int64_t k = 0; k < filter_len; ++k) {
|
|
||||||
const float sample = padded[static_cast<size_t>(frame_offset + k)];
|
|
||||||
real += static_cast<double>(sample) * static_cast<double>(forward_basis.index(k, 0, f));
|
|
||||||
imag += static_cast<double>(sample) * static_cast<double>(forward_basis.index(k, 0, f + n_freqs));
|
|
||||||
}
|
|
||||||
magnitude[static_cast<size_t>(f)] = static_cast<float>(std::sqrt(real * real + imag * imag));
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int64_t m = 0; m < n_mels; ++m) {
|
auto real = ggml_ext_slice(ctx, frames, 1, 0, n_freqs);
|
||||||
double mel_value = 0.0;
|
auto imag = ggml_ext_slice(ctx, frames, 1, n_freqs, stft_channels);
|
||||||
for (int64_t f = 0; f < n_freqs; ++f) {
|
auto magnitude = ggml_sqrt(ctx,
|
||||||
mel_value += static_cast<double>(mel_basis.index(f, m)) * static_cast<double>(magnitude[static_cast<size_t>(f)]);
|
ggml_add(ctx,
|
||||||
}
|
ggml_sqr(ctx, real),
|
||||||
log_mel.index(m, frame, c, b) = static_cast<float>(std::log(std::max(mel_value, 1e-5)));
|
ggml_sqr(ctx, imag)));
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return log_mel;
|
magnitude = ggml_cont(ctx, ggml_permute(ctx, magnitude, 1, 0, 2, 3));
|
||||||
|
auto mel = ggml_mul_mat(ctx, mel_basis, magnitude);
|
||||||
|
mel = ggml_log(ctx, ggml_clamp(ctx, mel, 1e-5f, std::numeric_limits<float>::max()));
|
||||||
|
|
||||||
|
return ggml_reshape_4d(ctx, mel, n_mels, frame_count, channels, batch);
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::vector<float> build_hann_resample_filter(int ratio) {
|
static std::vector<float> build_hann_resample_filter(int ratio) {
|
||||||
@ -276,75 +246,6 @@ namespace LTXV {
|
|||||||
return filter;
|
return filter;
|
||||||
}
|
}
|
||||||
|
|
||||||
static sd::Tensor<float> upsample_waveform_hann(const sd::Tensor<float>& waveform_in, int ratio) {
|
|
||||||
auto waveform = normalize_waveform_for_host(waveform_in);
|
|
||||||
if (ratio <= 1) {
|
|
||||||
return waveform;
|
|
||||||
}
|
|
||||||
|
|
||||||
const int lowpass_filter_width = 6;
|
|
||||||
const double rolloff = 0.99;
|
|
||||||
const int width = static_cast<int>(std::ceil(static_cast<double>(lowpass_filter_width) / rolloff));
|
|
||||||
const int kernel_size = 2 * width * ratio + 1;
|
|
||||||
const int pad = width;
|
|
||||||
const int pad_left = 2 * width * ratio;
|
|
||||||
const int pad_right = kernel_size - ratio;
|
|
||||||
const int64_t time = waveform.shape()[0];
|
|
||||||
const int64_t channels = waveform.shape()[1];
|
|
||||||
const int64_t batch = waveform.shape()[2];
|
|
||||||
const int64_t padded_time = time + 2 * pad;
|
|
||||||
const int64_t conv_out_time = (padded_time - 1) * ratio + kernel_size;
|
|
||||||
const int64_t cropped_time = conv_out_time - pad_left - pad_right;
|
|
||||||
auto filter = build_hann_resample_filter(ratio);
|
|
||||||
|
|
||||||
sd::Tensor<float> output({cropped_time, channels, batch});
|
|
||||||
std::vector<float> padded(static_cast<size_t>(padded_time), 0.0f);
|
|
||||||
std::vector<float> conv_out(static_cast<size_t>(conv_out_time), 0.0f);
|
|
||||||
|
|
||||||
for (int64_t b = 0; b < batch; ++b) {
|
|
||||||
for (int64_t c = 0; c < channels; ++c) {
|
|
||||||
std::fill(padded.begin(), padded.end(), 0.0f);
|
|
||||||
const float first = waveform.index(0, c, b);
|
|
||||||
const float last = waveform.index(time - 1, c, b);
|
|
||||||
for (int i = 0; i < pad; ++i) {
|
|
||||||
padded[static_cast<size_t>(i)] = first;
|
|
||||||
padded[static_cast<size_t>(pad + time + i)] = last;
|
|
||||||
}
|
|
||||||
for (int64_t t = 0; t < time; ++t) {
|
|
||||||
padded[static_cast<size_t>(pad + t)] = waveform.index(t, c, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::fill(conv_out.begin(), conv_out.end(), 0.0f);
|
|
||||||
for (int64_t t = 0; t < padded_time; ++t) {
|
|
||||||
const double sample = static_cast<double>(padded[static_cast<size_t>(t)]) * ratio;
|
|
||||||
const int64_t out_base = t * ratio;
|
|
||||||
for (int k = 0; k < kernel_size; ++k) {
|
|
||||||
conv_out[static_cast<size_t>(out_base + k)] += static_cast<float>(sample * filter[static_cast<size_t>(k)]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int64_t t = 0; t < cropped_time; ++t) {
|
|
||||||
output.index(t, c, b) = conv_out[static_cast<size_t>(t + pad_left)];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return output;
|
|
||||||
}
|
|
||||||
|
|
||||||
static sd::Tensor<float> crop_waveform_samples(const sd::Tensor<float>& waveform_in, int64_t target_samples) {
|
|
||||||
auto waveform = normalize_waveform_for_host(waveform_in);
|
|
||||||
if (waveform.shape()[0] == target_samples) {
|
|
||||||
return waveform;
|
|
||||||
}
|
|
||||||
if (waveform.shape()[0] > target_samples) {
|
|
||||||
return sd::ops::slice(waveform, 0, 0, target_samples);
|
|
||||||
}
|
|
||||||
sd::Tensor<float> output({target_samples, waveform.shape()[1], waveform.shape()[2]});
|
|
||||||
sd::ops::slice_assign(&output, 0, 0, waveform.shape()[0], waveform);
|
|
||||||
return output;
|
|
||||||
}
|
|
||||||
|
|
||||||
static ggml_type audio_conv_weight_type(ggml_type type) {
|
static ggml_type audio_conv_weight_type(ggml_type type) {
|
||||||
return type == GGML_TYPE_BF16 ? GGML_TYPE_F16 : type;
|
return type == GGML_TYPE_BF16 ? GGML_TYPE_F16 : type;
|
||||||
}
|
}
|
||||||
@ -413,22 +314,101 @@ namespace LTXV {
|
|||||||
return ggml_reshape_4d(ctx, out, out->ne[0], out->ne[1], 1, 1);
|
return ggml_reshape_4d(ctx, out, out->ne[0], out->ne[1], 1, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Builds graph nodes that yield `filter` with its dimension-0 elements in reverse
// order. The filter must be effectively 1D (ne[1..3] == 1) and non-empty.
// The result is assembled by slicing one element at a time, from the last index
// down to the first, and appending each slice along dimension 0.
static ggml_tensor* reverse_1d_filter(ggml_context* ctx, ggml_tensor* filter) {
    GGML_ASSERT(ctx != nullptr);
    GGML_ASSERT(filter != nullptr);
    GGML_ASSERT(filter->ne[1] == 1);
    GGML_ASSERT(filter->ne[2] == 1);
    GGML_ASSERT(filter->ne[3] == 1);
    // With zero taps the loop below never runs and nullptr would be returned
    // to the caller; fail loudly here instead.
    GGML_ASSERT(filter->ne[0] > 0);

    ggml_tensor* reversed = nullptr;
    for (int64_t k = filter->ne[0] - 1; k >= 0; --k) {
        auto slice = ggml_ext_slice(ctx, filter, 0, k, k + 1);
        reversed   = reversed == nullptr ? slice : ggml_concat(ctx, reversed, slice, 0);
    }
    return reversed;
}
|
||||||
|
|
||||||
static ggml_tensor* depthwise_conv_transpose1d(ggml_context* ctx,
|
static ggml_tensor* depthwise_conv_transpose1d(ggml_context* ctx,
|
||||||
ggml_tensor* x,
|
ggml_tensor* x,
|
||||||
ggml_tensor* filter,
|
ggml_tensor* filter,
|
||||||
int stride) {
|
int stride) {
|
||||||
GGML_ASSERT(x->ne[2] == 1 && x->ne[3] == 1);
|
GGML_ASSERT(x->ne[2] == 1 && x->ne[3] == 1);
|
||||||
GGML_ASSERT(filter->ne[1] == 1);
|
GGML_ASSERT(filter->ne[1] == 1);
|
||||||
|
GGML_ASSERT(filter->ne[2] == 1 && filter->ne[3] == 1);
|
||||||
|
|
||||||
ggml_tensor* out = nullptr;
|
const int64_t time = x->ne[0];
|
||||||
for (int64_t c = 0; c < x->ne[1]; ++c) {
|
const int64_t channels = x->ne[1];
|
||||||
auto xi = ggml_ext_slice(ctx, x, 1, c, c + 1);
|
const int64_t kernel_size = filter->ne[0];
|
||||||
auto yi = ggml_conv_transpose_1d(ctx, filter, xi, stride, 0, 1);
|
const int64_t out_time = (time - 1) * stride + kernel_size;
|
||||||
yi = ggml_ext_scale(ctx, yi, static_cast<float>(stride));
|
|
||||||
yi = ggml_reshape_4d(ctx, yi, yi->ne[0], 1, 1, 1);
|
auto x_flat = ggml_reshape_3d(ctx, x, 1, time, channels);
|
||||||
out = out == nullptr ? yi : ggml_concat(ctx, out, yi, 1);
|
if (stride > 1) {
|
||||||
|
auto zero_unit = ggml_ext_scale(ctx, x_flat, 0.0f);
|
||||||
|
auto zero_tail = zero_unit;
|
||||||
|
for (int i = 1; i < stride - 1; ++i) {
|
||||||
|
zero_tail = ggml_concat(ctx, zero_tail, zero_unit, 0);
|
||||||
}
|
}
|
||||||
return out;
|
x_flat = ggml_concat(ctx, x_flat, zero_tail, 0);
|
||||||
|
}
|
||||||
|
x_flat = ggml_reshape_3d(ctx, x_flat, time * stride, 1, channels);
|
||||||
|
|
||||||
|
auto reversed_filter = reverse_1d_filter(ctx, filter);
|
||||||
|
auto out = ggml_conv_1d(ctx, reversed_filter, x_flat, 1, static_cast<int>(kernel_size - 1), 1);
|
||||||
|
if (out->ne[0] > out_time) {
|
||||||
|
out = ggml_ext_slice(ctx, out, 0, 0, out_time);
|
||||||
|
}
|
||||||
|
GGML_ASSERT(out->ne[0] == out_time);
|
||||||
|
GGML_ASSERT(out->ne[1] == 1);
|
||||||
|
GGML_ASSERT(out->ne[2] == channels);
|
||||||
|
|
||||||
|
out = ggml_ext_scale(ctx, out, static_cast<float>(stride));
|
||||||
|
return ggml_reshape_4d(ctx, out, out_time, channels, 1, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Graph version of the resampling skip path: upsamples `waveform` by the
// integer factor `ratio` using the supplied resampling `filter`.
//   runner_ctx - runner context; its ggml context is used for all new nodes
//   waveform   - input with ne = [time, channels, batch, 1]
//   filter     - 1D kernel; must have exactly 2 * width * ratio + 1 taps
//   ratio      - upsampling factor; values <= 1 return `waveform` unchanged
// Steps: flatten channels and batch together, replicate-pad both ends by
// `width`, run the depthwise transposed convolution with stride `ratio`, trim
// the leading/trailing samples, and restore the [time, channels, batch] shape.
static ggml_tensor* upsample_waveform_hann(GGMLRunnerContext* runner_ctx,
                                           ggml_tensor* waveform,
                                           ggml_tensor* filter,
                                           int ratio) {
    ggml_context* ctx = runner_ctx->ggml_ctx;
    GGML_ASSERT(ctx != nullptr);
    GGML_ASSERT(waveform != nullptr);
    GGML_ASSERT(filter != nullptr);
    GGML_ASSERT(waveform->ne[3] == 1);

    if (ratio <= 1) {
        return waveform;
    }

    // Kernel geometry derived from the low-pass width and roll-off constants.
    const int lowpass_filter_width = 6;
    const double rolloff           = 0.99;
    const int width                = static_cast<int>(std::ceil(static_cast<double>(lowpass_filter_width) / rolloff));
    const int trim_front           = 2 * width * ratio;
    const int taps                 = trim_front + 1;
    const int trim_back            = taps - ratio;

    GGML_ASSERT(filter->ne[0] == taps);

    const int64_t n_samples  = waveform->ne[0];
    const int64_t n_channels = waveform->ne[1];
    const int64_t n_batch    = waveform->ne[2];

    // Treat every (channel, batch) pair as an independent row.
    ggml_tensor* rows      = ggml_reshape_3d(ctx, waveform, n_samples, n_channels * n_batch, 1);
    ggml_tensor* padded    = replicate_pad_1d(runner_ctx, rows, width, width);
    ggml_tensor* upsampled = depthwise_conv_transpose1d(ctx, padded, filter, ratio);
    ggml_tensor* trimmed   = ggml_ext_slice(ctx, upsampled, 0, trim_front, upsampled->ne[0] - trim_back);

    return ggml_reshape_3d(ctx, trimmed, trimmed->ne[0], n_channels, n_batch);
}
|
||||||
|
|
||||||
|
// Returns `waveform` limited to its first `target_samples` entries along
// dimension 0. A waveform that already has exactly that length is returned
// as-is; a shorter waveform is rejected by assertion (no padding is done).
static ggml_tensor* crop_waveform_samples(ggml_context* ctx,
                                          ggml_tensor* waveform,
                                          int64_t target_samples) {
    GGML_ASSERT(ctx != nullptr);
    GGML_ASSERT(waveform != nullptr);

    const bool needs_crop = waveform->ne[0] != target_samples;
    if (needs_crop) {
        GGML_ASSERT(waveform->ne[0] > target_samples);
        return ggml_ext_slice(ctx, waveform, 0, 0, target_samples);
    }
    return waveform;
}
|
||||||
|
|
||||||
struct PixelNorm2D : public UnaryBlock {
|
struct PixelNorm2D : public UnaryBlock {
|
||||||
@ -950,41 +930,66 @@ namespace LTXV {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor* decode_to_mel(GGMLRunnerContext* ctx,
|
ggml_tensor* decode(GGMLRunnerContext* ctx,
|
||||||
ggml_tensor* latent,
|
ggml_tensor* latent,
|
||||||
int target_time,
|
ggml_tensor* bwe_skip_filter) {
|
||||||
int target_freq) {
|
int target_time = static_cast<int>(latent->ne[1]) * config.latent_downsample_factor() -
|
||||||
|
(config.latent_downsample_factor() - 1);
|
||||||
|
int target_freq = config.mel_bins;
|
||||||
|
|
||||||
|
auto decoder = std::dynamic_pointer_cast<AudioDecoder>(blocks["audio_vae.decoder"]);
|
||||||
auto mean = params["audio_vae.per_channel_statistics.mean-of-means"];
|
auto mean = params["audio_vae.per_channel_statistics.mean-of-means"];
|
||||||
auto stddev = params["audio_vae.per_channel_statistics.std-of-means"];
|
auto stddev = params["audio_vae.per_channel_statistics.std-of-means"];
|
||||||
auto decoder = std::dynamic_pointer_cast<AudioDecoder>(blocks["audio_vae.decoder"]);
|
auto mel = decoder->forward(ctx, latent, mean, stddev, target_time, target_freq);
|
||||||
return decoder->forward(ctx, latent, mean, stddev, target_time, target_freq);
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_tensor* run_vocoder(GGMLRunnerContext* ctx, ggml_tensor* mel) {
|
|
||||||
auto vocoder = std::dynamic_pointer_cast<Vocoder>(blocks["vocoder.vocoder"]);
|
auto vocoder = std::dynamic_pointer_cast<Vocoder>(blocks["vocoder.vocoder"]);
|
||||||
return vocoder->forward(ctx, mel);
|
auto waveform = vocoder->forward(ctx, mel);
|
||||||
|
|
||||||
|
if (config.has_bwe) {
|
||||||
|
GGML_ASSERT(bwe_skip_filter != nullptr);
|
||||||
|
const int bwe_ratio = config.bwe_output_sample_rate / config.bwe_input_sample_rate;
|
||||||
|
const int64_t low_time = waveform->ne[0];
|
||||||
|
const int64_t out_time = low_time * bwe_ratio;
|
||||||
|
int64_t remainder = low_time % config.bwe_hop_length;
|
||||||
|
auto bwe_waveform = waveform;
|
||||||
|
if (remainder != 0) {
|
||||||
|
bwe_waveform = ggml_pad_ext(ctx->ggml_ctx,
|
||||||
|
bwe_waveform,
|
||||||
|
0,
|
||||||
|
static_cast<int>(config.bwe_hop_length - remainder),
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0);
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor* run_bwe_generator(GGMLRunnerContext* ctx, ggml_tensor* mel) {
|
auto mel_basis = params["vocoder.mel_stft.mel_basis"];
|
||||||
GGML_ASSERT(config.has_bwe);
|
auto stft_basis = params["vocoder.mel_stft.stft_fn.forward_basis"];
|
||||||
|
GGML_ASSERT(mel_basis != nullptr && stft_basis != nullptr);
|
||||||
|
auto bwe_mel = compute_log_mel_spectrogram(ctx, bwe_waveform, stft_basis, mel_basis, config.bwe_hop_length);
|
||||||
auto bwe_generator = std::dynamic_pointer_cast<Vocoder>(blocks["vocoder.bwe_generator"]);
|
auto bwe_generator = std::dynamic_pointer_cast<Vocoder>(blocks["vocoder.bwe_generator"]);
|
||||||
return bwe_generator->forward(ctx, mel);
|
auto residual = bwe_generator->forward(ctx, bwe_mel);
|
||||||
|
|
||||||
|
auto skip = upsample_waveform_hann(ctx,
|
||||||
|
bwe_waveform,
|
||||||
|
bwe_skip_filter,
|
||||||
|
bwe_ratio);
|
||||||
|
waveform = ggml_clamp(ctx->ggml_ctx,
|
||||||
|
ggml_add(ctx->ggml_ctx, residual, skip),
|
||||||
|
-1.0f,
|
||||||
|
1.0f);
|
||||||
|
waveform = crop_waveform_samples(ctx->ggml_ctx, waveform, out_time);
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor* mel_basis_tensor() const {
|
return waveform;
|
||||||
auto iter = params.find("vocoder.mel_stft.mel_basis");
|
|
||||||
return iter == params.end() ? nullptr : iter->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_tensor* stft_forward_basis_tensor() const {
|
|
||||||
auto iter = params.find("vocoder.mel_stft.stft_fn.forward_basis");
|
|
||||||
return iter == params.end() ? nullptr : iter->second;
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct LTXAudioVAERunner : public GGMLRunner {
|
struct LTXAudioVAERunner : public GGMLRunner {
|
||||||
LTXAudioVAEConfig config;
|
LTXAudioVAEConfig config;
|
||||||
LTXAudioVAE model;
|
LTXAudioVAE model;
|
||||||
|
sd::Tensor<float> bwe_skip_filter_tensor;
|
||||||
|
|
||||||
LTXAudioVAERunner(ggml_backend_t backend,
|
LTXAudioVAERunner(ggml_backend_t backend,
|
||||||
ggml_backend_t params_backend,
|
ggml_backend_t params_backend,
|
||||||
@ -994,6 +999,10 @@ namespace LTXV {
|
|||||||
config(LTXAudioVAEConfig::detect_from_weights(tensor_storage_map)),
|
config(LTXAudioVAEConfig::detect_from_weights(tensor_storage_map)),
|
||||||
model(config) {
|
model(config) {
|
||||||
model.init(params_ctx, tensor_storage_map, prefix);
|
model.init(params_ctx, tensor_storage_map, prefix);
|
||||||
|
if (config.has_bwe) {
|
||||||
|
const int bwe_ratio = config.bwe_output_sample_rate / config.bwe_input_sample_rate;
|
||||||
|
bwe_skip_filter_tensor = sd::Tensor<float>::from_vector(build_hann_resample_filter(bwe_ratio));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) {
|
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) {
|
||||||
@ -1008,77 +1017,22 @@ namespace LTXV {
|
|||||||
return "ltx_audio_vae";
|
return "ltx_audio_vae";
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_cgraph* build_base_graph(const sd::Tensor<float>& latent_tensor) {
|
|
||||||
auto latent = make_input(latent_tensor);
|
|
||||||
int target_time = static_cast<int>(latent_tensor.shape()[1]) * config.latent_downsample_factor() -
|
|
||||||
(config.latent_downsample_factor() - 1);
|
|
||||||
int target_freq = config.mel_bins;
|
|
||||||
|
|
||||||
ggml_cgraph* gf = new_graph_custom(655360);
|
|
||||||
auto runner_ctx = GGMLRunner::get_context();
|
|
||||||
auto mel = model.decode_to_mel(&runner_ctx, latent, target_time, target_freq);
|
|
||||||
auto waveform = model.run_vocoder(&runner_ctx, mel);
|
|
||||||
ggml_build_forward_expand(gf, waveform);
|
|
||||||
return gf;
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_cgraph* build_bwe_graph(const sd::Tensor<float>& mel_tensor) {
|
|
||||||
auto mel = make_input(mel_tensor);
|
|
||||||
ggml_cgraph* gf = new_graph_custom(655360);
|
|
||||||
auto runner_ctx = GGMLRunner::get_context();
|
|
||||||
auto residual = model.run_bwe_generator(&runner_ctx, mel);
|
|
||||||
ggml_build_forward_expand(gf, residual);
|
|
||||||
return gf;
|
|
||||||
}
|
|
||||||
|
|
||||||
sd::Tensor<float> compute_base_waveform(int n_threads,
|
|
||||||
const sd::Tensor<float>& latent_tensor) {
|
|
||||||
auto get_graph = [&]() -> ggml_cgraph* {
|
|
||||||
return build_base_graph(latent_tensor);
|
|
||||||
};
|
|
||||||
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), 4);
|
|
||||||
}
|
|
||||||
|
|
||||||
sd::Tensor<float> compute_bwe_residual(int n_threads,
|
|
||||||
const sd::Tensor<float>& mel_tensor) {
|
|
||||||
auto get_graph = [&]() -> ggml_cgraph* {
|
|
||||||
return build_bwe_graph(mel_tensor);
|
|
||||||
};
|
|
||||||
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), 4);
|
|
||||||
}
|
|
||||||
|
|
||||||
sd::Tensor<float> decode(int n_threads,
|
sd::Tensor<float> decode(int n_threads,
|
||||||
const sd::Tensor<float>& latent_tensor) {
|
const sd::Tensor<float>& latent_tensor) {
|
||||||
auto waveform = compute_base_waveform(n_threads, latent_tensor);
|
int64_t t0 = ggml_time_ms();
|
||||||
if (!config.has_bwe || waveform.empty()) {
|
auto get_graph = [&]() -> ggml_cgraph* {
|
||||||
return waveform;
|
auto latent = make_input(latent_tensor);
|
||||||
}
|
ggml_tensor* bwe_skip_filter = config.has_bwe ? make_input(bwe_skip_filter_tensor) : nullptr;
|
||||||
|
ggml_cgraph* gf = new_graph_custom(655360);
|
||||||
auto waveform_host = normalize_waveform_for_host(waveform);
|
auto runner_ctx = GGMLRunner::get_context();
|
||||||
const int64_t low_time = waveform_host.shape()[0];
|
auto waveform = model.decode(&runner_ctx, latent, bwe_skip_filter);
|
||||||
const int64_t out_time = low_time * config.bwe_output_sample_rate / config.bwe_input_sample_rate;
|
ggml_build_forward_expand(gf, waveform);
|
||||||
int64_t remainder = low_time % config.bwe_hop_length;
|
return gf;
|
||||||
if (remainder != 0) {
|
};
|
||||||
sd::Tensor<float> padded({low_time + (config.bwe_hop_length - remainder), waveform_host.shape()[1], waveform_host.shape()[2]});
|
auto result = restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), 4);
|
||||||
sd::ops::slice_assign(&padded, 0, 0, low_time, waveform_host);
|
int64_t t1 = ggml_time_ms();
|
||||||
waveform_host = std::move(padded);
|
LOG_INFO("ltx audio vae decode completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
|
||||||
}
|
return result;
|
||||||
|
|
||||||
auto mel_basis_tensor = model.mel_basis_tensor();
|
|
||||||
auto stft_basis_tensor = model.stft_forward_basis_tensor();
|
|
||||||
GGML_ASSERT(mel_basis_tensor != nullptr && stft_basis_tensor != nullptr);
|
|
||||||
auto mel_basis = load_param_tensor_f32(mel_basis_tensor);
|
|
||||||
auto forward_basis = load_param_tensor_f32(stft_basis_tensor);
|
|
||||||
auto bwe_mel = compute_log_mel_spectrogram(waveform_host, forward_basis, mel_basis, config.bwe_hop_length);
|
|
||||||
auto residual_raw = compute_bwe_residual(n_threads, bwe_mel);
|
|
||||||
if (residual_raw.empty()) {
|
|
||||||
return waveform;
|
|
||||||
}
|
|
||||||
auto residual = normalize_waveform_for_host(residual_raw);
|
|
||||||
auto skip = upsample_waveform_hann(waveform_host, config.bwe_output_sample_rate / config.bwe_input_sample_rate);
|
|
||||||
auto combined = sd::ops::clamp(residual + skip, -1.0f, 1.0f);
|
|
||||||
auto cropped = crop_waveform_samples(combined, out_time);
|
|
||||||
return restore_trailing_singleton_dims(cropped, 4);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void test(const std::string& input_path) {
|
void test(const std::string& input_path) {
|
||||||
|
|||||||
@ -5218,14 +5218,24 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
|
|||||||
sd_ctx->sd->diffusion_model->free_params_buffer();
|
sd_ctx->sd->diffusion_model->free_params_buffer();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int64_t latent_end = ggml_time_ms();
|
||||||
|
LOG_INFO("generating latent video completed, taking %.2fs", (latent_end - latent_start) * 1.0f / 1000);
|
||||||
|
|
||||||
sd_audio_t* generated_audio = nullptr;
|
sd_audio_t* generated_audio = nullptr;
|
||||||
if (sd_version_is_ltxav(sd_ctx->sd->version) &&
|
if (sd_version_is_ltxav(sd_ctx->sd->version) &&
|
||||||
latents.audio_length > 0 &&
|
latents.audio_length > 0 &&
|
||||||
sd_ctx->sd->audio_vae_model != nullptr) {
|
sd_ctx->sd->audio_vae_model != nullptr) {
|
||||||
|
int64_t audio_latent_decode_start = ggml_time_ms();
|
||||||
|
|
||||||
auto audio_latent = unpack_ltxav_audio_latent(final_latent,
|
auto audio_latent = unpack_ltxav_audio_latent(final_latent,
|
||||||
latents.audio_length,
|
latents.audio_length,
|
||||||
sd_ctx->sd->get_latent_channel());
|
sd_ctx->sd->get_latent_channel());
|
||||||
if (!audio_latent.empty()) {
|
if (!audio_latent.empty()) {
|
||||||
|
LOG_DEBUG("decode audio latent %dx%dx%dx%d",
|
||||||
|
(int)audio_latent.shape()[0],
|
||||||
|
(int)audio_latent.shape()[1],
|
||||||
|
(int)audio_latent.shape()[2],
|
||||||
|
(int)audio_latent.shape()[3]);
|
||||||
auto waveform = sd_ctx->sd->decode_ltx_audio_latent(audio_latent);
|
auto waveform = sd_ctx->sd->decode_ltx_audio_latent(audio_latent);
|
||||||
if (!waveform.empty()) {
|
if (!waveform.empty()) {
|
||||||
generated_audio = waveform_to_sd_audio(sd_ctx->sd, waveform);
|
generated_audio = waveform_to_sd_audio(sd_ctx->sd, waveform);
|
||||||
@ -5233,6 +5243,8 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
|
|||||||
LOG_WARN("LTX audio latent decode failed; continuing with silent video output");
|
LOG_WARN("LTX audio latent decode failed; continuing with silent video output");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
int64_t audio_latent_decode_end = ggml_time_ms();
|
||||||
|
LOG_INFO("decoding audio latent completed, taking %.2fs", (audio_latent_decode_end - audio_latent_decode_start) * 1.0f / 1000);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (latents.video_conditioning_frame_count > 0) {
|
if (latents.video_conditioning_frame_count > 0) {
|
||||||
@ -5245,9 +5257,6 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
|
|||||||
final_latent = sd::ops::slice(final_latent, 2, latents.ref_image_num, final_latent.shape()[2]);
|
final_latent = sd::ops::slice(final_latent, 2, latents.ref_image_num, final_latent.shape()[2]);
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t latent_end = ggml_time_ms();
|
|
||||||
LOG_INFO("generating latent video completed, taking %.2fs", (latent_end - latent_start) * 1.0f / 1000);
|
|
||||||
|
|
||||||
auto result = decode_video_outputs(sd_ctx, latent_upscale_enabled ? hires_request : request, final_latent, num_frames_out);
|
auto result = decode_video_outputs(sd_ctx, latent_upscale_enabled ? hires_request : request, final_latent, num_frames_out);
|
||||||
if (result == nullptr) {
|
if (result == nullptr) {
|
||||||
free_sd_audio(generated_audio);
|
free_sd_audio(generated_audio);
|
||||||
|
|||||||
88
src/tae.hpp
88
src/tae.hpp
@ -259,7 +259,51 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
ggml_tensor* patchify(ggml_context* ctx,
|
class WideMemBlock : public GGMLBlock {
|
||||||
|
bool has_skip_conv = false;
|
||||||
|
|
||||||
|
public:
|
||||||
|
WideMemBlock(int channels, int out_channels)
|
||||||
|
: has_skip_conv(channels != out_channels) {
|
||||||
|
int groups = std::max(1, out_channels / 64);
|
||||||
|
blocks["conv.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels * 2, out_channels, {1, 1}, {1, 1}));
|
||||||
|
blocks["conv.2"] = std::shared_ptr<GGMLBlock>(new Conv2d_grouped(out_channels, out_channels, groups, {3, 3}, {1, 1}, {1, 1}));
|
||||||
|
blocks["conv.4"] = std::shared_ptr<GGMLBlock>(new Conv2d(out_channels, out_channels, {1, 1}, {1, 1}));
|
||||||
|
blocks["conv.6"] = std::shared_ptr<GGMLBlock>(new Conv2d_grouped(out_channels, out_channels, groups, {3, 3}, {1, 1}, {1, 1}));
|
||||||
|
if (has_skip_conv) {
|
||||||
|
blocks["skip"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* past) {
|
||||||
|
// x: [n, channels, h, w]
|
||||||
|
auto conv0 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.0"]);
|
||||||
|
auto conv1 = std::dynamic_pointer_cast<Conv2d_grouped>(blocks["conv.2"]);
|
||||||
|
auto conv2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.4"]);
|
||||||
|
auto conv3 = std::dynamic_pointer_cast<Conv2d_grouped>(blocks["conv.6"]);
|
||||||
|
|
||||||
|
auto h = ggml_concat(ctx->ggml_ctx, x, past, 2);
|
||||||
|
h = conv0->forward(ctx, h);
|
||||||
|
h = ggml_relu_inplace(ctx->ggml_ctx, h);
|
||||||
|
h = conv1->forward(ctx, h);
|
||||||
|
h = ggml_relu_inplace(ctx->ggml_ctx, h);
|
||||||
|
h = conv2->forward(ctx, h);
|
||||||
|
h = ggml_relu_inplace(ctx->ggml_ctx, h);
|
||||||
|
h = conv3->forward(ctx, h);
|
||||||
|
|
||||||
|
auto skip = x;
|
||||||
|
if (has_skip_conv) {
|
||||||
|
auto skip_conv = std::dynamic_pointer_cast<Conv2d>(blocks["skip"]);
|
||||||
|
skip = skip_conv->forward(ctx, x);
|
||||||
|
}
|
||||||
|
h = ggml_add_inplace(ctx->ggml_ctx, h, skip);
|
||||||
|
h = ggml_relu_inplace(ctx->ggml_ctx, h);
|
||||||
|
return h;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
ggml_tensor*
|
||||||
|
patchify(ggml_context* ctx,
|
||||||
ggml_tensor* x,
|
ggml_tensor* x,
|
||||||
int64_t patch_size,
|
int64_t patch_size,
|
||||||
int64_t b = 1) {
|
int64_t b = 1) {
|
||||||
@ -325,7 +369,6 @@ public:
|
|||||||
int t_downscale = 1;
|
int t_downscale = 1;
|
||||||
TinyVideoEncoder(int z_channels = 4, int patch_size = 1, std::vector<bool> time_downscale = {true, true, false})
|
TinyVideoEncoder(int z_channels = 4, int patch_size = 1, std::vector<bool> time_downscale = {true, true, false})
|
||||||
: z_channels(z_channels), patch_size(patch_size) {
|
: z_channels(z_channels), patch_size(patch_size) {
|
||||||
// self.t_downscale = 2**sum(t.stride == 2 for t in self.encoder if isinstance(t, TPool))
|
|
||||||
t_downscale = 1;
|
t_downscale = 1;
|
||||||
for (bool downscale : time_downscale) {
|
for (bool downscale : time_downscale) {
|
||||||
if (downscale) {
|
if (downscale) {
|
||||||
@ -384,11 +427,18 @@ class TinyVideoDecoder : public UnaryBlock {
|
|||||||
int channels[num_layers + 1] = {256, 128, 64, 64};
|
int channels[num_layers + 1] = {256, 128, 64, 64};
|
||||||
int patch_size = 1;
|
int patch_size = 1;
|
||||||
int t_upscale = 1;
|
int t_upscale = 1;
|
||||||
|
bool is_wide = false;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
TinyVideoDecoder(int z_channels = 4, int patch_size = 1, std::vector<bool> time_upscale = {false, true, true})
|
TinyVideoDecoder(int z_channels = 4, int patch_size = 1, std::vector<bool> time_upscale = {false, true, true}, bool is_wide = false)
|
||||||
: z_channels(z_channels), patch_size(patch_size) {
|
: z_channels(z_channels), patch_size(patch_size), is_wide(is_wide) {
|
||||||
t_upscale = 1;
|
t_upscale = 1;
|
||||||
|
if (is_wide) {
|
||||||
|
channels[0] = 1024;
|
||||||
|
channels[1] = 512;
|
||||||
|
channels[2] = 256;
|
||||||
|
}
|
||||||
|
|
||||||
for (bool upscale : time_upscale) {
|
for (bool upscale : time_upscale) {
|
||||||
if (upscale) {
|
if (upscale) {
|
||||||
t_upscale *= 2;
|
t_upscale *= 2;
|
||||||
@ -400,8 +450,12 @@ public:
|
|||||||
for (int i = 0; i < num_layers; i++) {
|
for (int i = 0; i < num_layers; i++) {
|
||||||
int stride = time_upscale[i] ? 2 : 1;
|
int stride = time_upscale[i] ? 2 : 1;
|
||||||
for (int j = 0; j < num_blocks; j++) {
|
for (int j = 0; j < num_blocks; j++) {
|
||||||
|
if (is_wide) {
|
||||||
|
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new WideMemBlock(channels[i], channels[i]));
|
||||||
|
} else {
|
||||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new MemBlock(channels[i], channels[i]));
|
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new MemBlock(channels[i], channels[i]));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
index++; // nn.Upsample()
|
index++; // nn.Upsample()
|
||||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TGrow(channels[i], stride));
|
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TGrow(channels[i], stride));
|
||||||
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels[i], channels[i + 1], {3, 3}, {1, 1}, {1, 1}, {1, 1}, false));
|
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels[i], channels[i + 1], {3, 3}, {1, 1}, {1, 1}, {1, 1}, false));
|
||||||
@ -425,10 +479,15 @@ public:
|
|||||||
int index = 3;
|
int index = 3;
|
||||||
for (int i = 0; i < num_layers; i++) {
|
for (int i = 0; i < num_layers; i++) {
|
||||||
for (int j = 0; j < num_blocks; j++) {
|
for (int j = 0; j < num_blocks; j++) {
|
||||||
auto block = std::dynamic_pointer_cast<MemBlock>(blocks[std::to_string(index++)]);
|
|
||||||
auto mem = ggml_pad_ext(ctx->ggml_ctx, h, 0, 0, 0, 0, 0, 0, 1, 0);
|
auto mem = ggml_pad_ext(ctx->ggml_ctx, h, 0, 0, 0, 0, 0, 0, 1, 0);
|
||||||
mem = ggml_view_4d(ctx->ggml_ctx, mem, h->ne[0], h->ne[1], h->ne[2], h->ne[3], h->nb[1], h->nb[2], h->nb[3], 0);
|
mem = ggml_view_4d(ctx->ggml_ctx, mem, h->ne[0], h->ne[1], h->ne[2], h->ne[3], h->nb[1], h->nb[2], h->nb[3], 0);
|
||||||
|
if (is_wide) {
|
||||||
|
auto block = std::dynamic_pointer_cast<WideMemBlock>(blocks[std::to_string(index++)]);
|
||||||
h = block->forward(ctx, h, mem);
|
h = block->forward(ctx, h, mem);
|
||||||
|
} else{
|
||||||
|
auto block = std::dynamic_pointer_cast<MemBlock>(blocks[std::to_string(index++)]);
|
||||||
|
h = block->forward(ctx, h, mem);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// upsample
|
// upsample
|
||||||
index++;
|
index++;
|
||||||
@ -455,6 +514,7 @@ class TAEHV : public GGMLBlock {
|
|||||||
protected:
|
protected:
|
||||||
bool decode_only;
|
bool decode_only;
|
||||||
SDVersion version;
|
SDVersion version;
|
||||||
|
bool is_wide;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
int z_channels = 16;
|
int z_channels = 16;
|
||||||
@ -462,8 +522,8 @@ public:
|
|||||||
std::vector<bool> time_upscale = {false, true, true};
|
std::vector<bool> time_upscale = {false, true, true};
|
||||||
|
|
||||||
public:
|
public:
|
||||||
TAEHV(bool decode_only = true, SDVersion version = VERSION_WAN2)
|
TAEHV(bool decode_only = true, SDVersion version = VERSION_WAN2, bool is_wide = false)
|
||||||
: decode_only(decode_only), version(version) {
|
: decode_only(decode_only), version(version), is_wide(is_wide) {
|
||||||
int patch = 1;
|
int patch = 1;
|
||||||
if (version == VERSION_WAN2_2_TI2V) {
|
if (version == VERSION_WAN2_2_TI2V) {
|
||||||
z_channels = 48;
|
z_channels = 48;
|
||||||
@ -474,7 +534,7 @@ public:
|
|||||||
time_downscale = {true, true, true};
|
time_downscale = {true, true, true};
|
||||||
time_upscale = {true, true, true};
|
time_upscale = {true, true, true};
|
||||||
}
|
}
|
||||||
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoDecoder(z_channels, patch, time_upscale));
|
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoDecoder(z_channels, patch, time_upscale, is_wide));
|
||||||
if (!decode_only) {
|
if (!decode_only) {
|
||||||
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoEncoder(z_channels, patch, time_downscale));
|
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new TinyVideoEncoder(z_channels, patch, time_downscale));
|
||||||
}
|
}
|
||||||
@ -623,6 +683,7 @@ struct TinyImageAutoEncoder : public VAE {
|
|||||||
struct TinyVideoAutoEncoder : public VAE {
|
struct TinyVideoAutoEncoder : public VAE {
|
||||||
TAEHV taehv;
|
TAEHV taehv;
|
||||||
bool decode_only = false;
|
bool decode_only = false;
|
||||||
|
bool is_wide = false;
|
||||||
|
|
||||||
TinyVideoAutoEncoder(ggml_backend_t backend,
|
TinyVideoAutoEncoder(ggml_backend_t backend,
|
||||||
ggml_backend_t params_backend,
|
ggml_backend_t params_backend,
|
||||||
@ -631,8 +692,14 @@ struct TinyVideoAutoEncoder : public VAE {
|
|||||||
bool decoder_only = true,
|
bool decoder_only = true,
|
||||||
SDVersion version = VERSION_WAN2)
|
SDVersion version = VERSION_WAN2)
|
||||||
: decode_only(decoder_only),
|
: decode_only(decoder_only),
|
||||||
taehv(decoder_only, version),
|
|
||||||
VAE(version, backend, params_backend) {
|
VAE(version, backend, params_backend) {
|
||||||
|
for (auto tensor_storage : tensor_storage_map) {
|
||||||
|
if (tensor_storage.first.find(prefix + ".3.conv.6.weight") != std::string::npos) {
|
||||||
|
is_wide = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
taehv = TAEHV(decoder_only, version, is_wide);
|
||||||
scale_input = false;
|
scale_input = false;
|
||||||
taehv.init(params_ctx, tensor_storage_map, prefix);
|
taehv.init(params_ctx, tensor_storage_map, prefix);
|
||||||
}
|
}
|
||||||
@ -663,7 +730,8 @@ struct TinyVideoAutoEncoder : public VAE {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ggml_cgraph* build_graph(const sd::Tensor<float>& z_tensor, bool decode_graph) {
|
ggml_cgraph* build_graph(const sd::Tensor<float>& z_tensor, bool decode_graph) {
|
||||||
ggml_cgraph* gf = ggml_new_graph(compute_ctx);
|
ggml_cgraph* gf = decode_graph && is_wide ? ggml_new_graph_custom(compute_ctx, 4096, false)
|
||||||
|
: ggml_new_graph(compute_ctx);
|
||||||
ggml_tensor* z = make_input(z_tensor);
|
ggml_tensor* z = make_input(z_tensor);
|
||||||
auto runner_ctx = get_context();
|
auto runner_ctx = get_context();
|
||||||
ggml_tensor* out = decode_graph ? taehv.decode(&runner_ctx, z) : taehv.encode(&runner_ctx, z);
|
ggml_tensor* out = decode_graph ? taehv.decode(&runner_ctx, z) : taehv.encode(&runner_ctx, z);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user