#ifndef __SD_TENSOR_HPP__ #define __SD_TENSOR_HPP__ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "rng.hpp" namespace sd { template class Tensor; inline std::vector tensor_unravel_index(int64_t flat, const std::vector& shape); [[noreturn]] inline void tensor_throw_invalid_argument(const std::string& message) { std::fprintf(stderr, "sd::Tensor error: %s\n", message.c_str()); std::fflush(stderr); throw std::invalid_argument(message); } inline std::string tensor_shape_to_string(const std::vector& shape) { std::ostringstream oss; oss << "["; for (size_t i = 0; i < shape.size(); ++i) { if (i != 0) { oss << ", "; } oss << shape[i]; } oss << "]"; return oss.str(); } inline int64_t tensor_numel(const std::vector& shape) { if (shape.empty()) { return 0; } int64_t numel = 1; for (int64_t dim : shape) { if (dim < 0) { tensor_throw_invalid_argument("Tensor shape must be non-negative, got shape=" + tensor_shape_to_string(shape)); } numel *= dim; } return numel; } template class Tensor { public: Tensor() = default; explicit Tensor(std::vector shape) : data_(static_cast(tensor_numel(shape))), shape_(std::move(shape)) { } Tensor(std::vector shape, std::vector data) : data_(std::move(data)), shape_(std::move(shape)) { if (static_cast(data_.size()) != tensor_numel(shape_)) { tensor_throw_invalid_argument("Tensor data size does not match shape: data.size()=" + std::to_string(data_.size()) + ", shape=" + tensor_shape_to_string(shape_) + ", numel=" + std::to_string(tensor_numel(shape_))); } } const std::vector& shape() const { return shape_; } int64_t dim() const { return static_cast(shape_.size()); } int64_t numel() const { return static_cast(data_.size()); } bool empty() const { return data_.empty(); } T* data() { return data_.data(); } const T* data() const { return data_.data(); } std::vector& values() { return data_; } const std::vector& values() const { return data_; } void resize(std::vector 
shape) { shape_ = std::move(shape); data_.resize(static_cast(tensor_numel(shape_))); } Tensor& reshape_(std::vector shape) { if (tensor_numel(shape) != numel()) { tensor_throw_invalid_argument("Tensor reshape changes element count: from shape=" + tensor_shape_to_string(shape_) + " (numel=" + std::to_string(numel()) + ") to shape=" + tensor_shape_to_string(shape) + " (numel=" + std::to_string(tensor_numel(shape)) + ")"); } shape_ = std::move(shape); return *this; } Tensor reshape(std::vector shape) const { Tensor result = *this; result.reshape_(std::move(shape)); return result; } Tensor& squeeze_() { std::vector new_shape; new_shape.reserve(shape_.size()); for (int64_t dim : shape_) { if (dim != 1) { new_shape.push_back(dim); } } shape_ = std::move(new_shape); return *this; } Tensor& squeeze_(size_t dim) { if (dim >= shape_.size()) { tensor_throw_invalid_argument("Tensor squeeze dimension out of range: dim=" + std::to_string(dim) + ", shape=" + tensor_shape_to_string(shape_)); } if (shape_[dim] != 1) { tensor_throw_invalid_argument("Tensor squeeze requires dimension size 1: dim=" + std::to_string(dim) + ", shape=" + tensor_shape_to_string(shape_)); } shape_.erase(shape_.begin() + static_cast(dim)); return *this; } Tensor squeeze() const { Tensor result = *this; result.squeeze_(); return result; } Tensor squeeze(size_t dim) const { Tensor result = *this; result.squeeze_(dim); return result; } Tensor& unsqueeze_(size_t dim) { if (dim > shape_.size()) { tensor_throw_invalid_argument("Tensor unsqueeze dimension out of range: dim=" + std::to_string(dim) + ", shape=" + tensor_shape_to_string(shape_)); } shape_.insert(shape_.begin() + static_cast(dim), 1); return *this; } Tensor unsqueeze(size_t dim) const { Tensor result = *this; result.unsqueeze_(dim); return result; } Tensor permute(const std::vector& dims) const { if (dims.size() != static_cast(dim())) { tensor_throw_invalid_argument("Tensor permute requires one dimension index per axis: tensor_shape=" + 
tensor_shape_to_string(shape_) + ", dims_size=" + std::to_string(dims.size())); } std::vector seen(dims.size(), false); std::vector out_shape(dims.size(), 1); for (size_t i = 0; i < dims.size(); ++i) { size_t dim_index = dims[i]; if (dim_index >= dims.size() || seen[dim_index]) { tensor_throw_invalid_argument("Tensor permute dimensions must be a valid permutation: tensor_shape=" + tensor_shape_to_string(shape_)); } seen[dim_index] = true; out_shape[i] = shape_[dim_index]; } Tensor result(out_shape); if (result.numel() == 0) { return result; } for (int64_t flat = 0; flat < result.numel(); ++flat) { std::vector out_coord = tensor_unravel_index(flat, out_shape); std::vector src_coord(static_cast(dim()), 0); for (size_t i = 0; i < dims.size(); ++i) { src_coord[dims[i]] = out_coord[i]; } result[flat] = index(src_coord); } return result; } Tensor& permute_(const std::vector& dims) { *this = permute(dims); return *this; } void fill_(const T& value) { std::fill(data_.begin(), data_.end(), value); } Tensor& masked_fill_(const Tensor& mask, const T& value); T mean() const; static Tensor zeros(std::vector shape) { return Tensor(std::move(shape)); } static Tensor zeros_like(const Tensor& other) { return zeros(other.shape()); } static Tensor ones(std::vector shape) { return full(std::move(shape), static_cast(1)); } static Tensor ones_like(const Tensor& other) { return ones(other.shape()); } static Tensor full(std::vector shape, const T& value) { Tensor tensor(std::move(shape)); tensor.fill_(value); return tensor; } static Tensor randn(std::vector shape, const std::shared_ptr& rng) { static_assert(std::is_same_v, "Tensor::randn currently requires Tensor"); if (!rng) { tensor_throw_invalid_argument("Tensor randn requires a valid RNG"); } const uint32_t size = static_cast(tensor_numel(shape)); return Tensor(std::move(shape), rng->randn(size)); } static Tensor randn_like(const Tensor& other, const std::shared_ptr& rng) { return randn(other.shape(), rng); } static Tensor 
from_vector(std::vector data) { const int64_t size = static_cast(data.size()); return Tensor({size}, std::move(data)); } T& index(const std::vector& coord) { return data_.at(offset_of(coord)); } const T& index(const std::vector& coord) const { return data_.at(offset_of(coord)); } template && ...)>> T& index(Indices... indices) { return index(std::vector{static_cast(indices)...}); } template && ...)>> const T& index(Indices... indices) const { return index(std::vector{static_cast(indices)...}); } T& operator[](int64_t index) { return data_.at(static_cast(index)); } const T& operator[](int64_t index) const { return data_.at(static_cast(index)); } private: size_t offset_of(const std::vector& coord) const { if (coord.size() != shape_.size()) { tensor_throw_invalid_argument("Tensor index rank mismatch: coord_rank=" + std::to_string(coord.size()) + ", shape=" + tensor_shape_to_string(shape_)); } size_t offset = 0; size_t stride = 1; for (size_t i = 0; i < shape_.size(); ++i) { if (coord[i] < 0 || coord[i] >= shape_[i]) { tensor_throw_invalid_argument("Tensor index out of range: shape=" + tensor_shape_to_string(shape_)); } offset += static_cast(coord[i]) * stride; stride *= static_cast(shape_[i]); } return offset; } std::vector data_; std::vector shape_; }; template inline T Tensor::mean() const { if (empty()) { return T{}; } T sum = T{}; for (const T& value : data_) { sum += value; } return sum / static_cast(numel()); } template <> inline float Tensor::mean() const { if (empty()) { return 0.0f; } double sum = 0.0; for (float value : data_) { sum += static_cast(value); } return static_cast(sum / static_cast(numel())); } template inline void tensor_check_same_shape(const Tensor& lhs, const Tensor& rhs) { if (lhs.shape() != rhs.shape()) { tensor_throw_invalid_argument("Tensor shapes must match: lhs_shape=" + tensor_shape_to_string(lhs.shape()) + ", rhs_shape=" + tensor_shape_to_string(rhs.shape())); } } inline std::vector tensor_broadcast_shape(const std::vector& lhs, const 
std::vector& rhs) { size_t ndim = std::max(lhs.size(), rhs.size()); std::vector shape(ndim, 1); for (size_t i = 0; i < ndim; ++i) { int64_t lhs_dim = lhs.size() > i ? lhs[i] : 1; int64_t rhs_dim = rhs.size() > i ? rhs[i] : 1; if (lhs_dim != rhs_dim && lhs_dim != 1 && rhs_dim != 1) { tensor_throw_invalid_argument("Tensor shapes are not broadcastable: lhs_shape=" + tensor_shape_to_string(lhs) + ", rhs_shape=" + tensor_shape_to_string(rhs)); } shape[i] = std::max(lhs_dim, rhs_dim); } return shape; } inline std::vector tensor_unravel_index(int64_t flat, const std::vector& shape) { std::vector coord(shape.size(), 0); for (size_t i = 0; i < shape.size(); ++i) { if (shape[i] <= 0) { tensor_throw_invalid_argument("Tensor unravel_index requires positive shape: shape=" + tensor_shape_to_string(shape)); } coord[i] = flat % shape[i]; flat /= shape[i]; } return coord; } inline std::vector tensor_compute_strides(const std::vector& shape) { std::vector strides(shape.size(), 1); int64_t stride = 1; for (size_t i = 0; i < shape.size(); ++i) { strides[i] = stride; stride *= shape[i]; } return strides; }

// Calls fn(flat, lhs_offset, rhs_offset) once for every element of out_shape,
// in increasing flat order.
//
// out_shape        shape being iterated; flat indices follow tensor_compute_strides(out_shape).
// lhs_shape_raw /  shape and strides of each operand. Dimension i of an operand is
// lhs_strides_raw  matched with dimension i of out_shape; dimensions the operand does
// rhs_shape_raw /  not have are treated as size 1. A size-1 dimension contributes
// rhs_strides_raw  nothing to the operand offset, so its single entry is reused.
// fn               callable invoked as fn(int64_t flat, int64_t lhs_offset, int64_t rhs_offset).
//
// Throws std::invalid_argument (via tensor_throw_invalid_argument) when an operand
// has more dimensions than out_shape, has fewer strides than dimensions, or has a
// dimension that is neither 1 nor equal to the matching out_shape dimension.
// Without this check such operands would be written past the end of the padded
// shape/stride vectors below, or silently read only part of the operand. This is
// reachable from the in-place operators, which iterate over lhs.shape() no matter
// how large rhs is.
template <typename F>
inline void tensor_for_each_broadcast_offset(const std::vector<int64_t>& out_shape,
                                             const std::vector<int64_t>& lhs_shape_raw,
                                             const std::vector<int64_t>& lhs_strides_raw,
                                             const std::vector<int64_t>& rhs_shape_raw,
                                             const std::vector<int64_t>& rhs_strides_raw,
                                             F&& fn) {
    const size_t ndim = out_shape.size();

    // Reject operands that cannot be expanded to out_shape before touching any storage.
    const auto require_expandable = [&](const std::vector<int64_t>& shape,
                                        const std::vector<int64_t>& strides,
                                        const char* operand) {
        bool expandable = shape.size() <= ndim && strides.size() >= shape.size();
        for (size_t i = 0; expandable && i < shape.size(); ++i) {
            expandable = shape[i] == 1 || shape[i] == out_shape[i];
        }
        if (!expandable) {
            tensor_throw_invalid_argument(std::string("Tensor broadcast operand cannot be expanded to output shape: operand=") +
                                          operand +
                                          ", operand_shape=" + tensor_shape_to_string(shape) +
                                          ", out_shape=" + tensor_shape_to_string(out_shape));
        }
    };
    require_expandable(lhs_shape_raw, lhs_strides_raw, "lhs");
    require_expandable(rhs_shape_raw, rhs_strides_raw, "rhs");

    const std::vector<int64_t> out_strides = tensor_compute_strides(out_shape);

    // Pad both operands to ndim dimensions: missing dimensions become size 1 / stride 0.
    std::vector<int64_t> lhs_shape(ndim, 1);
    std::vector<int64_t> lhs_strides(ndim, 0);
    std::vector<int64_t> rhs_shape(ndim, 1);
    std::vector<int64_t> rhs_strides(ndim, 0);
    for (size_t i = 0; i < lhs_shape_raw.size(); ++i) {
        lhs_shape[i]   = lhs_shape_raw[i];
        lhs_strides[i] = lhs_strides_raw[i];
    }
    for (size_t i = 0; i < rhs_shape_raw.size(); ++i) {
        rhs_shape[i]   = rhs_shape_raw[i];
        rhs_strides[i] = rhs_strides_raw[i];
    }

    const int64_t numel = tensor_numel(out_shape);
    for (int64_t flat = 0; flat < numel; ++flat) {
        int64_t remaining  = flat;
        int64_t lhs_offset = 0;
        int64_t rhs_offset = 0;
        // Peel coordinates from the largest stride down to the smallest.
        for (size_t i = ndim; i-- > 0;) {
            const int64_t coord = remaining / out_strides[i];
            remaining %= out_strides[i];
            if (lhs_shape[i] != 1) {
                lhs_offset += coord * lhs_strides[i];
            }
            if (rhs_shape[i] != 1) {
                rhs_offset += coord * rhs_strides[i];
            }
        }
        fn(flat, lhs_offset, rhs_offset);
    }
}

template inline Tensor& Tensor::masked_fill_(const Tensor& mask, const T& value) { if (empty()) { return *this; } tensor_broadcast_shape(shape_, mask.shape()); const std::vector data_strides = tensor_compute_strides(shape_); const std::vector mask_strides = tensor_compute_strides(mask.shape()); const uint8_t* mask_data = mask.data(); tensor_for_each_broadcast_offset(shape_, shape_, data_strides, mask.shape(), mask_strides, [&](int64_t, int64_t data_offset, int64_t mask_offset) { if (mask_data[mask_offset] != 0) { data_[static_cast(data_offset)] = value; } }); return *this; } template ::value>> inline Tensor operator<(const Tensor& lhs, Scalar rhs) { Tensor result(lhs.shape()); const T value = static_cast(rhs); for (int64_t i = 0; i < lhs.numel(); ++i) { result[i] = lhs[i] < value ? 1 : 0; } return result; } template ::value>> inline Tensor operator<(Scalar lhs, const Tensor& rhs) { Tensor result(rhs.shape()); const T value = static_cast(lhs); for (int64_t i = 0; i < rhs.numel(); ++i) { result[i] = value < rhs[i] ? 1 : 0; } return result; } template inline Tensor operator<(const Tensor& lhs, const Tensor& rhs) { const std::vector out_shape = tensor_broadcast_shape(lhs.shape(), rhs.shape()); Tensor result(out_shape); const std::vector lhs_strides = tensor_compute_strides(lhs.shape()); const std::vector rhs_strides = tensor_compute_strides(rhs.shape()); const T* lhs_data = lhs.data(); const T* rhs_data = rhs.data(); tensor_for_each_broadcast_offset(out_shape, lhs.shape(), lhs_strides, rhs.shape(), rhs_strides, [&](int64_t flat, int64_t lhs_offset, int64_t rhs_offset) { result[flat] = lhs_data[lhs_offset] < rhs_data[rhs_offset] ? 
1 : 0; }); return result; } template inline Tensor& operator+=(Tensor& lhs, const Tensor& rhs) { if (lhs.shape() == rhs.shape()) { for (int64_t i = 0; i < lhs.numel(); ++i) { lhs[i] += rhs[i]; } return lhs; } tensor_broadcast_shape(lhs.shape(), rhs.shape()); const std::vector lhs_strides = tensor_compute_strides(lhs.shape()); const std::vector rhs_strides = tensor_compute_strides(rhs.shape()); const T* rhs_data = rhs.data(); tensor_for_each_broadcast_offset(lhs.shape(), lhs.shape(), lhs_strides, rhs.shape(), rhs_strides, [&](int64_t, int64_t lhs_offset, int64_t rhs_offset) { lhs[static_cast(lhs_offset)] += rhs_data[rhs_offset]; }); return lhs; } template ::value>> inline Tensor& operator+=(Tensor& lhs, Scalar rhs) { const T value = static_cast(rhs); for (int64_t i = 0; i < lhs.numel(); ++i) { lhs[i] += value; } return lhs; } template inline Tensor& operator-=(Tensor& lhs, const Tensor& rhs) { if (lhs.shape() == rhs.shape()) { for (int64_t i = 0; i < lhs.numel(); ++i) { lhs[i] -= rhs[i]; } return lhs; } tensor_broadcast_shape(lhs.shape(), rhs.shape()); const std::vector lhs_strides = tensor_compute_strides(lhs.shape()); const std::vector rhs_strides = tensor_compute_strides(rhs.shape()); const T* rhs_data = rhs.data(); tensor_for_each_broadcast_offset(lhs.shape(), lhs.shape(), lhs_strides, rhs.shape(), rhs_strides, [&](int64_t, int64_t lhs_offset, int64_t rhs_offset) { lhs[static_cast(lhs_offset)] -= rhs_data[rhs_offset]; }); return lhs; } template ::value>> inline Tensor& operator-=(Tensor& lhs, Scalar rhs) { const T value = static_cast(rhs); for (int64_t i = 0; i < lhs.numel(); ++i) { lhs[i] -= value; } return lhs; } template inline Tensor& operator*=(Tensor& lhs, const Tensor& rhs) { if (lhs.shape() == rhs.shape()) { for (int64_t i = 0; i < lhs.numel(); ++i) { lhs[i] *= rhs[i]; } return lhs; } tensor_broadcast_shape(lhs.shape(), rhs.shape()); const std::vector lhs_strides = tensor_compute_strides(lhs.shape()); const std::vector rhs_strides = 
tensor_compute_strides(rhs.shape()); const T* rhs_data = rhs.data(); tensor_for_each_broadcast_offset(lhs.shape(), lhs.shape(), lhs_strides, rhs.shape(), rhs_strides, [&](int64_t, int64_t lhs_offset, int64_t rhs_offset) { lhs[static_cast(lhs_offset)] *= rhs_data[rhs_offset]; }); return lhs; } template ::value>> inline Tensor& operator*=(Tensor& lhs, Scalar rhs) { const T value = static_cast(rhs); for (int64_t i = 0; i < lhs.numel(); ++i) { lhs[i] *= value; } return lhs; } template inline Tensor& operator/=(Tensor& lhs, const Tensor& rhs) { if (lhs.shape() == rhs.shape()) { for (int64_t i = 0; i < lhs.numel(); ++i) { lhs[i] /= rhs[i]; } return lhs; } tensor_broadcast_shape(lhs.shape(), rhs.shape()); const std::vector lhs_strides = tensor_compute_strides(lhs.shape()); const std::vector rhs_strides = tensor_compute_strides(rhs.shape()); const T* rhs_data = rhs.data(); tensor_for_each_broadcast_offset(lhs.shape(), lhs.shape(), lhs_strides, rhs.shape(), rhs_strides, [&](int64_t, int64_t lhs_offset, int64_t rhs_offset) { lhs[static_cast(lhs_offset)] /= rhs_data[rhs_offset]; }); return lhs; } template ::value>> inline Tensor& operator/=(Tensor& lhs, Scalar rhs) { const T value = static_cast(rhs); for (int64_t i = 0; i < lhs.numel(); ++i) { lhs[i] /= value; } return lhs; } template inline Tensor operator+(Tensor lhs, const Tensor& rhs) { if (lhs.shape() != rhs.shape()) { const std::vector out_shape = tensor_broadcast_shape(lhs.shape(), rhs.shape()); Tensor result(out_shape); const std::vector lhs_strides = tensor_compute_strides(lhs.shape()); const std::vector rhs_strides = tensor_compute_strides(rhs.shape()); const T* lhs_data = lhs.data(); const T* rhs_data = rhs.data(); tensor_for_each_broadcast_offset(out_shape, lhs.shape(), lhs_strides, rhs.shape(), rhs_strides, [&](int64_t flat, int64_t lhs_offset, int64_t rhs_offset) { result[flat] = lhs_data[lhs_offset] + rhs_data[rhs_offset]; }); return result; } lhs += rhs; return lhs; } template ::value>> inline Tensor 
operator+(Tensor lhs, Scalar rhs) { lhs += rhs; return lhs; } template ::value>> inline Tensor operator+(Scalar lhs, Tensor rhs) { rhs += lhs; return rhs; } template inline Tensor operator-(Tensor lhs, const Tensor& rhs) { if (lhs.shape() != rhs.shape()) { const std::vector out_shape = tensor_broadcast_shape(lhs.shape(), rhs.shape()); Tensor result(out_shape); const std::vector lhs_strides = tensor_compute_strides(lhs.shape()); const std::vector rhs_strides = tensor_compute_strides(rhs.shape()); const T* lhs_data = lhs.data(); const T* rhs_data = rhs.data(); tensor_for_each_broadcast_offset(out_shape, lhs.shape(), lhs_strides, rhs.shape(), rhs_strides, [&](int64_t flat, int64_t lhs_offset, int64_t rhs_offset) { result[flat] = lhs_data[lhs_offset] - rhs_data[rhs_offset]; }); return result; } lhs -= rhs; return lhs; } template ::value>> inline Tensor operator-(Tensor lhs, Scalar rhs) { lhs -= rhs; return lhs; } template ::value>> inline Tensor operator-(Scalar lhs, const Tensor& rhs) { Tensor result = rhs; const T value = static_cast(lhs); for (int64_t i = 0; i < result.numel(); ++i) { result[i] = value - result[i]; } return result; } template inline Tensor operator*(Tensor lhs, const Tensor& rhs) { if (lhs.shape() != rhs.shape()) { const std::vector out_shape = tensor_broadcast_shape(lhs.shape(), rhs.shape()); Tensor result(out_shape); const std::vector lhs_strides = tensor_compute_strides(lhs.shape()); const std::vector rhs_strides = tensor_compute_strides(rhs.shape()); const T* lhs_data = lhs.data(); const T* rhs_data = rhs.data(); tensor_for_each_broadcast_offset(out_shape, lhs.shape(), lhs_strides, rhs.shape(), rhs_strides, [&](int64_t flat, int64_t lhs_offset, int64_t rhs_offset) { result[flat] = lhs_data[lhs_offset] * rhs_data[rhs_offset]; }); return result; } lhs *= rhs; return lhs; } template ::value>> inline Tensor operator*(Tensor lhs, Scalar rhs) { lhs *= rhs; return lhs; } template ::value>> inline Tensor operator*(Scalar lhs, Tensor rhs) { rhs *= lhs; 
return rhs; } template inline Tensor operator/(Tensor lhs, const Tensor& rhs) { if (lhs.shape() != rhs.shape()) { const std::vector out_shape = tensor_broadcast_shape(lhs.shape(), rhs.shape()); Tensor result(out_shape); const std::vector lhs_strides = tensor_compute_strides(lhs.shape()); const std::vector rhs_strides = tensor_compute_strides(rhs.shape()); const T* lhs_data = lhs.data(); const T* rhs_data = rhs.data(); tensor_for_each_broadcast_offset(out_shape, lhs.shape(), lhs_strides, rhs.shape(), rhs_strides, [&](int64_t flat, int64_t lhs_offset, int64_t rhs_offset) { result[flat] = lhs_data[lhs_offset] / rhs_data[rhs_offset]; }); return result; } lhs /= rhs; return lhs; } template ::value>> inline Tensor operator/(Tensor lhs, Scalar rhs) { lhs /= rhs; return lhs; } template ::value>> inline Tensor operator/(Scalar lhs, const Tensor& rhs) { Tensor result = rhs; const T value = static_cast(lhs); for (int64_t i = 0; i < result.numel(); ++i) { result[i] = value / result[i]; } return result; } template inline Tensor operator-(const Tensor& tensor) { Tensor result = tensor; for (int64_t i = 0; i < result.numel(); ++i) { result[i] = -result[i]; } return result; } template inline Tensor zeros(std::vector shape) { return Tensor::zeros(std::move(shape)); } template inline Tensor full(std::vector shape, const T& value) { return Tensor::full(std::move(shape), value); } template inline Tensor randn(std::vector shape, const std::shared_ptr& rng) { return Tensor::randn(std::move(shape), rng); } template inline Tensor randn_like(const Tensor& tensor, const std::shared_ptr& rng) { return Tensor::randn(tensor.shape(), rng); } template inline std::vector tensor_to_vector(const Tensor& tensor) { return tensor.values(); } namespace ops { enum class InterpolateMode { Nearest, NearestExact, NearestMax, NearestMin, NearestAvg, Bilinear, Bicubic, Lanczos, }; inline bool is_nearest_like_interpolate_mode(InterpolateMode mode) { return mode == InterpolateMode::Nearest || mode == 
InterpolateMode::NearestExact || mode == InterpolateMode::NearestMax || mode == InterpolateMode::NearestMin || mode == InterpolateMode::NearestAvg; } inline bool is_2d_filter_interpolate_mode(InterpolateMode mode) { return mode == InterpolateMode::Bilinear || mode == InterpolateMode::Bicubic || mode == InterpolateMode::Lanczos; } inline int64_t nearest_exact_interpolate_index(int64_t output_index, int64_t input_size, int64_t output_size) { const double scale = static_cast(input_size) / static_cast(output_size); const double center = (static_cast(output_index) + 0.5) * scale - 0.5; return std::min(std::max(static_cast(std::floor(center + 0.5)), 0), input_size - 1); } inline double linear_interpolate_weight(double x) { x = std::abs(x); return x < 1.0 ? 1.0 - x : 0.0; } inline double cubic_interpolate_weight(double x) { constexpr double a = -0.75; // Match PyTorch bicubic interpolation. x = std::abs(x); if (x <= 1.0) { return ((a + 2.0) * x - (a + 3.0)) * x * x + 1.0; } if (x < 2.0) { return ((a * x - 5.0 * a) * x + 8.0 * a) * x - 4.0 * a; } return 0.0; } inline double sinc(double x) { constexpr double pi = 3.14159265358979323846; if (std::abs(x) < 1e-12) { return 1.0; } const double pix = pi * x; return std::sin(pix) / pix; } inline double lanczos_interpolate_weight(double x) { constexpr double radius = 3.0; x = std::abs(x); if (x >= radius) { return 0.0; } return sinc(x) * sinc(x / radius); } struct InterpolateContributor { int64_t index; double weight; }; inline std::vector> make_interpolate_contributors( int64_t input_size, int64_t output_size, InterpolateMode mode, bool antialias) { std::vector> contributors(static_cast(output_size)); const double scale = static_cast(input_size) / static_cast(output_size); const double filter_scale = antialias ? 
std::max(1.0, scale) : 1.0; for (int64_t out = 0; out < output_size; ++out) { const double center = (static_cast(out) + 0.5) * scale - 0.5; int64_t start = 0; int64_t end = 0; if (mode == InterpolateMode::Bilinear) { const double support = filter_scale; start = static_cast(std::ceil(center - support)); end = static_cast(std::floor(center + support)); } else if (mode == InterpolateMode::Bicubic) { const double support = 2.0 * filter_scale; start = static_cast(std::ceil(center - support)); end = static_cast(std::floor(center + support)); } else if (mode == InterpolateMode::Lanczos) { const double support = 3.0 * filter_scale; start = static_cast(std::ceil(center - support)); end = static_cast(std::floor(center + support)); } else { tensor_throw_invalid_argument("Unsupported 2D filter interpolate mode: mode=" + std::to_string(static_cast(mode))); } double weight_sum = 0.0; std::vector& axis_contributors = contributors[static_cast(out)]; axis_contributors.reserve(static_cast(end - start + 1)); for (int64_t in = start; in <= end; ++in) { double weight = 0.0; if (mode == InterpolateMode::Bilinear) { weight = linear_interpolate_weight((center - static_cast(in)) / filter_scale); } else if (mode == InterpolateMode::Bicubic) { weight = cubic_interpolate_weight((center - static_cast(in)) / filter_scale); } else { weight = lanczos_interpolate_weight((center - static_cast(in)) / filter_scale); } if (weight == 0.0) { continue; } const int64_t clamped_index = std::min(std::max(in, 0), input_size - 1); axis_contributors.push_back({clamped_index, weight}); weight_sum += weight; } if ((antialias || mode == InterpolateMode::Lanczos) && std::abs(weight_sum) > 1e-12) { for (auto& contributor : axis_contributors) { contributor.weight /= weight_sum; } } if (axis_contributors.empty()) { const int64_t nearest = std::min( std::max(static_cast(std::floor(center + 0.5)), 0), input_size - 1); axis_contributors.push_back({nearest, 1.0}); } } return contributors; } template inline Tensor 
interpolate_2d_filter(const Tensor& input, const std::vector& output_shape, InterpolateMode mode, bool antialias) { if (input.dim() < 2) { tensor_throw_invalid_argument("2D filter interpolate requires rank >= 2: input_shape=" + tensor_shape_to_string(input.shape()) + ", output_shape=" + tensor_shape_to_string(output_shape)); } for (size_t i = 2; i < output_shape.size(); ++i) { if (input.shape()[i] != output_shape[i]) { tensor_throw_invalid_argument("2D filter interpolate only supports resizing dimensions 0 and 1: input_shape=" + tensor_shape_to_string(input.shape()) + ", output_shape=" + tensor_shape_to_string(output_shape)); } } Tensor output(output_shape); const int64_t input_width = input.shape()[0]; const int64_t input_height = input.shape()[1]; const int64_t output_width = output_shape[0]; const int64_t output_height = output_shape[1]; const int64_t input_plane = input_width * input_height; const int64_t output_plane = output_width * output_height; const int64_t plane_count = input.numel() / input_plane; auto x_contributors = make_interpolate_contributors(input_width, output_width, mode, antialias); auto y_contributors = make_interpolate_contributors(input_height, output_height, mode, antialias); for (int64_t plane = 0; plane < plane_count; ++plane) { const int64_t input_plane_offset = plane * input_plane; const int64_t output_plane_offset = plane * output_plane; for (int64_t y = 0; y < output_height; ++y) { const auto& y_axis = y_contributors[static_cast(y)]; for (int64_t x = 0; x < output_width; ++x) { const auto& x_axis = x_contributors[static_cast(x)]; double value = 0.0; for (const auto& yc : y_axis) { const int64_t input_row_offset = input_plane_offset + yc.index * input_width; for (const auto& xc : x_axis) { value += static_cast(input.data()[input_row_offset + xc.index]) * xc.weight * yc.weight; } } output.data()[output_plane_offset + y * output_width + x] = static_cast(value); } } } return output; } inline int64_t normalize_slice_bound(int64_t index, 
int64_t dim_size) { if (index < 0) { index += dim_size; } return index; } template inline std::pair resolve_slice_bounds(const Tensor& input, size_t dim, int64_t start, int64_t end) { if (dim >= static_cast(input.dim())) { tensor_throw_invalid_argument("Tensor slice dimension out of range: dim=" + std::to_string(dim) + ", rank=" + std::to_string(input.dim()) + ", input_shape=" + tensor_shape_to_string(input.shape())); } int64_t dim_size = input.shape()[dim]; start = normalize_slice_bound(start, dim_size); end = normalize_slice_bound(end, dim_size); if (start < 0 || start > dim_size || end < 0 || end > dim_size || start > end) { tensor_throw_invalid_argument("Tensor slice bounds out of range: dim=" + std::to_string(dim) + ", start=" + std::to_string(start) + ", end=" + std::to_string(end) + ", input_shape=" + tensor_shape_to_string(input.shape())); } return {start, end}; } template inline Tensor exp(const Tensor& input) { Tensor output(input.shape()); for (int64_t i = 0; i < input.numel(); ++i) { output[i] = static_cast(std::exp(static_cast(input[i]))); } return output; } template inline Tensor clamp(const Tensor& input, const T& min_value, const T& max_value) { if (min_value > max_value) { tensor_throw_invalid_argument("Tensor clamp requires min_value <= max_value"); } Tensor output(input.shape()); for (int64_t i = 0; i < input.numel(); ++i) { output[i] = std::clamp(input[i], min_value, max_value); } return output; } template inline Tensor round(const Tensor& input) { Tensor output(input.shape()); for (int64_t i = 0; i < input.numel(); ++i) { output[i] = static_cast(std::round(static_cast(input[i]))); } return output; } template inline Tensor slice(const Tensor& input, size_t dim, int64_t start, int64_t end) { auto [resolved_start, resolved_end] = resolve_slice_bounds(input, dim, start, end); std::vector out_shape = input.shape(); out_shape[dim] = resolved_end - resolved_start; Tensor output(out_shape); if (output.numel() == 0) { return output; } int64_t inner = 1; 
for (size_t i = 0; i < dim; ++i) { inner *= input.shape()[i]; } int64_t outer = 1; for (size_t i = dim + 1; i < static_cast(input.dim()); ++i) { outer *= input.shape()[i]; } int64_t src_chunk = (resolved_end - resolved_start) * inner; int64_t src_stride = input.shape()[dim] * inner; for (int64_t i = 0; i < outer; ++i) { const int64_t src_offset = i * src_stride + resolved_start * inner; const int64_t dst_offset = i * src_chunk; std::copy_n(input.data() + src_offset, src_chunk, output.data() + dst_offset); } return output; } template inline Tensor narrow(const Tensor& input, size_t dim, int64_t start, int64_t length) { if (length < 0) { tensor_throw_invalid_argument("Tensor narrow requires non-negative length: length=" + std::to_string(length) + ", input_shape=" + tensor_shape_to_string(input.shape())); } return slice(input, dim, start, start + length); } template inline void slice_assign(Tensor* dst, size_t dim, int64_t start, int64_t end, const Tensor& src) { if (dst == nullptr) { tensor_throw_invalid_argument("Tensor slice_assign requires non-null dst"); } auto [resolved_start, resolved_end] = resolve_slice_bounds(*dst, dim, start, end); if (src.dim() != dst->dim()) { tensor_throw_invalid_argument("Tensor slice_assign requires matching rank: dst_shape=" + tensor_shape_to_string(dst->shape()) + ", src_shape=" + tensor_shape_to_string(src.shape())); } std::vector expected_shape = dst->shape(); expected_shape[dim] = resolved_end - resolved_start; if (src.shape() != expected_shape) { tensor_throw_invalid_argument("Tensor slice_assign requires matching source shape: dst_shape=" + tensor_shape_to_string(dst->shape()) + ", src_shape=" + tensor_shape_to_string(src.shape()) + ", expected_src_shape=" + tensor_shape_to_string(expected_shape)); } if (src.numel() == 0) { return; } int64_t inner = 1; for (size_t i = 0; i < dim; ++i) { inner *= dst->shape()[i]; } int64_t outer = 1; for (size_t i = dim + 1; i < static_cast(dst->dim()); ++i) { outer *= dst->shape()[i]; } int64_t 
dst_chunk = (resolved_end - resolved_start) * inner; int64_t dst_stride = dst->shape()[dim] * inner; for (int64_t i = 0; i < outer; ++i) { const int64_t dst_offset = i * dst_stride + resolved_start * inner; const int64_t src_offset = i * dst_chunk; std::copy_n(src.data() + src_offset, dst_chunk, dst->data() + dst_offset); } } template inline void fill_slice(Tensor* dst, size_t dim, int64_t start, int64_t end, const T& value) { if (dst == nullptr) { tensor_throw_invalid_argument("Tensor fill_slice requires non-null dst"); } auto [resolved_start, resolved_end] = resolve_slice_bounds(*dst, dim, start, end); int64_t inner = 1; for (size_t i = 0; i < dim; ++i) { inner *= dst->shape()[i]; } int64_t outer = 1; for (size_t i = dim + 1; i < static_cast(dst->dim()); ++i) { outer *= dst->shape()[i]; } int64_t chunk = (resolved_end - resolved_start) * inner; int64_t stride = dst->shape()[dim] * inner; for (int64_t i = 0; i < outer; ++i) { const int64_t offset = i * stride + resolved_start * inner; std::fill_n(dst->data() + offset, chunk, value); } } template inline Tensor interpolate(const Tensor& input, std::vector output_shape, InterpolateMode mode = InterpolateMode::Nearest, bool align_corners = false, bool antialias = false) { const bool is_nearest_like_mode = is_nearest_like_interpolate_mode(mode); const bool is_2d_filter_mode = is_2d_filter_interpolate_mode(mode); if (!is_nearest_like_mode && !is_2d_filter_mode) { tensor_throw_invalid_argument("Unsupported interpolate mode: mode=" + std::to_string(static_cast(mode))); } if (antialias && !is_2d_filter_mode) { tensor_throw_invalid_argument("Tensor interpolate antialias requires a 2D filter mode: mode=" + std::to_string(static_cast(mode))); } if (align_corners) { tensor_throw_invalid_argument("align_corners is not supported for tensor interpolate: input_shape=" + tensor_shape_to_string(input.shape()) + ", output_shape=" + tensor_shape_to_string(output_shape)); } if (input.shape() == output_shape) { return input; } if 
(input.dim() != static_cast(output_shape.size())) { tensor_throw_invalid_argument("Tensor interpolate requires matching rank: input_dim=" + std::to_string(input.dim()) + ", output_dim=" + std::to_string(output_shape.size()) + ", input_shape=" + tensor_shape_to_string(input.shape()) + ", output_shape=" + tensor_shape_to_string(output_shape)); } for (size_t i = 0; i < output_shape.size(); ++i) { if (output_shape[i] <= 0) { tensor_throw_invalid_argument("Tensor interpolate output shape must be positive: input_shape=" + tensor_shape_to_string(input.shape()) + ", output_shape=" + tensor_shape_to_string(output_shape)); } if (input.shape()[i] <= 0) { tensor_throw_invalid_argument("Tensor interpolate input shape must be positive: input_shape=" + tensor_shape_to_string(input.shape()) + ", output_shape=" + tensor_shape_to_string(output_shape)); } } if (is_2d_filter_mode) { return interpolate_2d_filter(input, output_shape, mode, antialias); } bool has_downsampling = false; for (int64_t i = 0; i < input.dim(); ++i) { if (input.shape()[i] > output_shape[i]) { has_downsampling = true; break; } } Tensor output(std::move(output_shape)); if (mode == InterpolateMode::Nearest || mode == InterpolateMode::NearestExact || !has_downsampling) { for (int64_t flat = 0; flat < output.numel(); ++flat) { std::vector output_coord = tensor_unravel_index(flat, output.shape()); std::vector input_coord(static_cast(input.dim()), 0); for (size_t i = 0; i < static_cast(input.dim()); ++i) { if (mode == InterpolateMode::NearestExact) { input_coord[i] = nearest_exact_interpolate_index(output_coord[i], input.shape()[i], output.shape()[i]); } else { input_coord[i] = output_coord[i] * input.shape()[i] / output.shape()[i]; } } output[flat] = input.index(input_coord); } return output; } auto init_reduction = [&]() -> T { switch (mode) { case InterpolateMode::NearestMax: return std::numeric_limits::lowest(); case InterpolateMode::NearestMin: return std::numeric_limits::max(); case InterpolateMode::NearestAvg: 
return T(0); case InterpolateMode::Nearest: return T(0); case InterpolateMode::NearestExact: return T(0); case InterpolateMode::Bilinear: case InterpolateMode::Bicubic: case InterpolateMode::Lanczos: break; } tensor_throw_invalid_argument("Unsupported interpolate mode: mode=" + std::to_string(static_cast(mode))); }; auto reduce_value = [&](T& acc, const T& sample) { switch (mode) { case InterpolateMode::NearestMax: acc = std::max(acc, sample); break; case InterpolateMode::NearestMin: acc = std::min(acc, sample); break; case InterpolateMode::NearestAvg: acc += sample; break; case InterpolateMode::Nearest: break; case InterpolateMode::NearestExact: break; case InterpolateMode::Bilinear: case InterpolateMode::Bicubic: case InterpolateMode::Lanczos: break; } }; // Reduction modes only differ from nearest mode when downsampling. for (int64_t flat_out = 0; flat_out < output.numel(); ++flat_out) { std::vector output_coord = tensor_unravel_index(flat_out, output.shape()); std::vector input_start(output.dim(), 0); std::vector input_end(output.dim(), 0); for (size_t i = 0; i < static_cast(output.dim()); ++i) { const int64_t input_dim = input.shape()[i]; const int64_t output_dim = output.shape()[i]; input_start[i] = std::max(int64_t(0), static_cast(output_coord[i] * input_dim / output_dim)); input_end[i] = std::min(input_dim, ((output_coord[i] + 1) * input_dim + output_dim - 1) / output_dim); } T value = init_reduction(); bool done_window = false; std::vector current_in_coord = input_start; while (!done_window) { reduce_value(value, input.index(current_in_coord)); for (int d = static_cast(output.dim()) - 1; d >= 0; --d) { if (++current_in_coord[d] < input_end[d]) { break; } current_in_coord[d] = input_start[d]; if (d == 0) { done_window = true; } } } if (mode == InterpolateMode::NearestAvg) { int64_t window_size = 1; for (size_t i = 0; i < static_cast(output.dim()); ++i) { window_size *= (input_end[i] - input_start[i]); } value /= static_cast(window_size); } output[flat_out] 
= value; } return output; } template inline Tensor interpolate(const Tensor& input, const std::optional>& size, const std::optional>& scale_factor, InterpolateMode mode = InterpolateMode::Nearest, bool align_corners = false, bool antialias = false) { const bool is_nearest_like_mode = is_nearest_like_interpolate_mode(mode); const bool is_2d_filter_mode = is_2d_filter_interpolate_mode(mode); if (!is_nearest_like_mode && !is_2d_filter_mode) { tensor_throw_invalid_argument("Unsupported interpolate mode: mode=" + std::to_string(static_cast(mode))); } if (antialias && !is_2d_filter_mode) { tensor_throw_invalid_argument("Tensor interpolate antialias requires a 2D filter mode: mode=" + std::to_string(static_cast(mode))); } if (align_corners) { tensor_throw_invalid_argument("align_corners is not supported for tensor interpolate: input_shape=" + tensor_shape_to_string(input.shape())); } if (size.has_value() == scale_factor.has_value()) { tensor_throw_invalid_argument("Tensor interpolate requires exactly one of size or scale_factor: input_shape=" + tensor_shape_to_string(input.shape())); } std::vector output_shape = input.shape(); if (size.has_value()) { if (size->empty() || size->size() > output_shape.size()) { tensor_throw_invalid_argument("Tensor interpolate size must target low dimensions: input_shape=" + tensor_shape_to_string(input.shape()) + ", size_rank=" + std::to_string(size->size())); } for (size_t i = 0; i < size->size(); ++i) { if ((*size)[i] <= 0) { tensor_throw_invalid_argument("Tensor interpolate size must be positive: input_shape=" + tensor_shape_to_string(input.shape()) + ", size=" + tensor_shape_to_string(*size)); } output_shape[i] = (*size)[i]; } } else { if (scale_factor->empty() || scale_factor->size() > output_shape.size()) { tensor_throw_invalid_argument("Tensor interpolate scale_factor must target low dimensions: input_shape=" + tensor_shape_to_string(input.shape()) + ", scale_factor_rank=" + std::to_string(scale_factor->size())); } for (size_t i = 0; 
i < scale_factor->size(); ++i) { if ((*scale_factor)[i] <= 0.0) { tensor_throw_invalid_argument("Tensor interpolate scale_factor must be positive: input_shape=" + tensor_shape_to_string(input.shape())); } output_shape[i] = static_cast( std::floor(static_cast(output_shape[i]) * (*scale_factor)[i])); if (output_shape[i] <= 0) { tensor_throw_invalid_argument("Tensor interpolate output shape must be positive: input_shape=" + tensor_shape_to_string(input.shape()) + ", output_shape=" + tensor_shape_to_string(output_shape)); } } } return interpolate(input, std::move(output_shape), mode, align_corners, antialias); } template inline Tensor interpolate(const Tensor& input, const std::optional>& size, double scale_factor, InterpolateMode mode = InterpolateMode::Nearest, bool align_corners = false, bool antialias = false) { return interpolate(input, size, std::vector(size.has_value() ? size->size() : input.dim(), scale_factor), mode, align_corners, antialias); } template inline Tensor max_pool_2d(const Tensor& input, std::vector kernel_size, std::vector stride, std::vector padding) { if (input.dim() < 2) { tensor_throw_invalid_argument("Tensor max_pool_2d requires input_dim >= 2: input_dim=" + std::to_string(input.dim()) + ", input_shape=" + tensor_shape_to_string(input.shape())); } if (kernel_size.size() != 2 || stride.size() != 2 || padding.size() != 2) { tensor_throw_invalid_argument("Tensor max_pool_2d requires kernel_size, stride, and padding to have length 2"); } for (size_t i = 0; i < 2; ++i) { if (kernel_size[i] <= 0) { tensor_throw_invalid_argument("Tensor max_pool_2d kernel_size must be positive: kernel_size=" + tensor_shape_to_string(kernel_size)); } if (stride[i] <= 0) { tensor_throw_invalid_argument("Tensor max_pool_2d stride must be positive: stride=" + tensor_shape_to_string(stride)); } if (padding[i] < 0) { tensor_throw_invalid_argument("Tensor max_pool_2d padding must be non-negative: padding=" + tensor_shape_to_string(padding)); } } const int64_t in_height = 
input.shape()[0]; const int64_t in_width = input.shape()[1]; const int64_t out_height = (in_height + 2 * padding[0] - kernel_size[0]) / stride[0] + 1; const int64_t out_width = (in_width + 2 * padding[1] - kernel_size[1]) / stride[1] + 1; if (out_height <= 0 || out_width <= 0) { tensor_throw_invalid_argument("max_pool_2d results in invalid output dimensions: " + std::to_string(out_height) + "x" + std::to_string(out_width)); } std::vector output_shape = input.shape(); output_shape[0] = out_height; output_shape[1] = out_width; Tensor output(std::move(output_shape)); for (int64_t flat_out = 0; flat_out < output.numel(); ++flat_out) { std::vector output_coord = tensor_unravel_index(flat_out, output.shape()); std::vector input_coord = output_coord; const int64_t oh = output_coord[0]; const int64_t ow = output_coord[1]; T max_val = std::numeric_limits::lowest(); bool has_valid_input = false; for (int64_t kh = 0; kh < kernel_size[0]; ++kh) { for (int64_t kw = 0; kw < kernel_size[1]; ++kw) { const int64_t ih = oh * stride[0] + kh - padding[0]; const int64_t iw = ow * stride[1] + kw - padding[1]; if (ih >= 0 && ih < in_height && iw >= 0 && iw < in_width) { input_coord[0] = ih; input_coord[1] = iw; max_val = std::max(max_val, input.index(input_coord)); has_valid_input = true; } } } output[flat_out] = has_valid_input ? 
max_val : T(0); } return output; } template inline Tensor concat(const Tensor& lhs, const Tensor& rhs, size_t dim) { if (lhs.dim() != rhs.dim()) { tensor_throw_invalid_argument("Tensor concat requires same rank: lhs_dim=" + std::to_string(lhs.dim()) + ", rhs_dim=" + std::to_string(rhs.dim()) + ", lhs_shape=" + tensor_shape_to_string(lhs.shape()) + ", rhs_shape=" + tensor_shape_to_string(rhs.shape())); } if (dim >= static_cast(lhs.dim())) { tensor_throw_invalid_argument("Tensor concat dimension out of range: dim=" + std::to_string(dim) + ", rank=" + std::to_string(lhs.dim()) + ", lhs_shape=" + tensor_shape_to_string(lhs.shape())); } std::vector out_shape = lhs.shape(); for (size_t i = 0; i < static_cast(lhs.dim()); ++i) { if (i == dim) { continue; } if (lhs.shape()[i] != rhs.shape()[i]) { tensor_throw_invalid_argument("Tensor concat requires matching non-concat dimensions: dim=" + std::to_string(dim) + ", lhs_shape=" + tensor_shape_to_string(lhs.shape()) + ", rhs_shape=" + tensor_shape_to_string(rhs.shape())); } } out_shape[dim] += rhs.shape()[dim]; Tensor out(out_shape); int64_t inner = 1; for (size_t i = 0; i < dim; ++i) { inner *= lhs.shape()[i]; } int64_t outer = 1; for (size_t i = dim + 1; i < static_cast(lhs.dim()); ++i) { outer *= lhs.shape()[i]; } int64_t lhs_chunk = lhs.shape()[dim] * inner; int64_t rhs_chunk = rhs.shape()[dim] * inner; int64_t out_chunk = lhs_chunk + rhs_chunk; for (int64_t i = 0; i < outer; ++i) { int64_t lhs_offset = i * lhs_chunk; int64_t rhs_offset = i * rhs_chunk; int64_t out_offset = i * out_chunk; std::copy_n(lhs.data() + lhs_offset, lhs_chunk, out.data() + out_offset); std::copy_n(rhs.data() + rhs_offset, rhs_chunk, out.data() + out_offset + lhs_chunk); } return out; } template inline std::vector> chunk(const Tensor& tensor, int64_t chunks, size_t dim) { if (chunks <= 0) { tensor_throw_invalid_argument("Tensor chunk requires chunks > 0: chunks=" + std::to_string(chunks) + ", tensor_shape=" + tensor_shape_to_string(tensor.shape())); 
} if (dim >= static_cast(tensor.dim())) { tensor_throw_invalid_argument("Tensor chunk dimension out of range: dim=" + std::to_string(dim) + ", rank=" + std::to_string(tensor.dim()) + ", tensor_shape=" + tensor_shape_to_string(tensor.shape())); } const int64_t dim_size = tensor.shape()[dim]; if (dim_size == 0) { return {}; } if (dim_size % chunks != 0) { tensor_throw_invalid_argument("Tensor chunk requires the dimension size to be divisible by chunks: dim=" + std::to_string(dim) + ", dim_size=" + std::to_string(dim_size) + ", chunks=" + std::to_string(chunks) + ", tensor_shape=" + tensor_shape_to_string(tensor.shape())); } const int64_t chunk_size = dim_size / chunks; int64_t inner = 1; for (size_t i = 0; i < dim; ++i) { inner *= tensor.shape()[i]; } int64_t outer = 1; for (size_t i = dim + 1; i < static_cast(tensor.dim()); ++i) { outer *= tensor.shape()[i]; } std::vector> parts; parts.reserve(static_cast(chunks)); for (int64_t start = 0; start < dim_size; start += chunk_size) { std::vector part_shape = tensor.shape(); part_shape[dim] = chunk_size; Tensor part(part_shape); const int64_t src_chunk = chunk_size * inner; const int64_t dst_chunk = src_chunk; for (int64_t i = 0; i < outer; ++i) { const int64_t src_offset = (i * dim_size + start) * inner; const int64_t dst_offset = i * dst_chunk; std::copy_n(tensor.data() + src_offset, src_chunk, part.data() + dst_offset); } parts.push_back(std::move(part)); } return parts; } } // namespace ops } // namespace sd #endif