#ifndef __SD_TENSOR_HPP__ #define __SD_TENSOR_HPP__ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "rng.hpp" namespace sd { template class Tensor; inline std::vector tensor_unravel_index(int64_t flat, const std::vector& shape); [[noreturn]] inline void tensor_throw_invalid_argument(const std::string& message) { std::fprintf(stderr, "sd::Tensor error: %s\n", message.c_str()); std::fflush(stderr); throw std::invalid_argument(message); } inline std::string tensor_shape_to_string(const std::vector& shape) { std::ostringstream oss; oss << "["; for (size_t i = 0; i < shape.size(); ++i) { if (i != 0) { oss << ", "; } oss << shape[i]; } oss << "]"; return oss.str(); } inline int64_t tensor_numel(const std::vector& shape) { if (shape.empty()) { return 0; } int64_t numel = 1; for (int64_t dim : shape) { if (dim < 0) { tensor_throw_invalid_argument("Tensor shape must be non-negative, got shape=" + tensor_shape_to_string(shape)); } numel *= dim; } return numel; } template class Tensor { public: Tensor() = default; explicit Tensor(std::vector shape) : data_(static_cast(tensor_numel(shape))), shape_(std::move(shape)) { } Tensor(std::vector shape, std::vector data) : data_(std::move(data)), shape_(std::move(shape)) { if (static_cast(data_.size()) != tensor_numel(shape_)) { tensor_throw_invalid_argument("Tensor data size does not match shape: data.size()=" + std::to_string(data_.size()) + ", shape=" + tensor_shape_to_string(shape_) + ", numel=" + std::to_string(tensor_numel(shape_))); } } const std::vector& shape() const { return shape_; } int64_t dim() const { return static_cast(shape_.size()); } int64_t numel() const { return static_cast(data_.size()); } bool empty() const { return data_.empty(); } T* data() { return data_.data(); } const T* data() const { return data_.data(); } std::vector& values() { return data_; } const std::vector& values() const { return data_; } void resize(std::vector shape) { shape_ = std::move(shape); data_.resize(static_cast(tensor_numel(shape_))); } Tensor& reshape_(std::vector shape) { if (tensor_numel(shape) != numel()) { tensor_throw_invalid_argument("Tensor reshape changes element count: from shape=" + tensor_shape_to_string(shape_) + " (numel=" + std::to_string(numel()) + ") to shape=" + tensor_shape_to_string(shape) + " (numel=" + std::to_string(tensor_numel(shape)) + ")"); } shape_ = std::move(shape); return *this; } Tensor reshape(std::vector shape) const { Tensor result = *this; result.reshape_(std::move(shape)); return result; } Tensor& squeeze_() { std::vector new_shape; new_shape.reserve(shape_.size()); for (int64_t dim : shape_) { if (dim != 1) { new_shape.push_back(dim); } } shape_ = std::move(new_shape); return *this; } Tensor& squeeze_(size_t dim) { if (dim >= shape_.size()) { tensor_throw_invalid_argument("Tensor squeeze dimension out of range: dim=" + std::to_string(dim) + ", shape=" + tensor_shape_to_string(shape_)); } if (shape_[dim] != 1) { tensor_throw_invalid_argument("Tensor squeeze requires dimension size 1: dim=" + std::to_string(dim) + ", shape=" + tensor_shape_to_string(shape_)); } shape_.erase(shape_.begin() + static_cast(dim)); return *this; } Tensor squeeze() const { Tensor result = *this; result.squeeze_(); return result; } Tensor squeeze(size_t dim) const { Tensor result = *this; result.squeeze_(dim); return result; } Tensor& unsqueeze_(size_t dim) { if (dim > shape_.size()) { tensor_throw_invalid_argument("Tensor unsqueeze dimension out of range: dim=" + std::to_string(dim) + ", shape=" + tensor_shape_to_string(shape_)); } shape_.insert(shape_.begin() + static_cast(dim), 1); return *this; } Tensor unsqueeze(size_t dim) const { Tensor result = *this; result.unsqueeze_(dim); return result; } Tensor permute(const std::vector& dims) const { if (dims.size() != static_cast(dim())) { tensor_throw_invalid_argument("Tensor permute requires one dimension index per axis: tensor_shape=" + tensor_shape_to_string(shape_) + ", dims_size=" + std::to_string(dims.size())); } std::vector seen(dims.size(), false); std::vector out_shape(dims.size(), 1); for (size_t i = 0; i < dims.size(); ++i) { size_t dim_index = dims[i]; if (dim_index >= dims.size() || seen[dim_index]) { tensor_throw_invalid_argument("Tensor permute dimensions must be a valid permutation: tensor_shape=" + tensor_shape_to_string(shape_)); } seen[dim_index] = true; out_shape[i] = shape_[dim_index]; } Tensor result(out_shape); if (result.numel() == 0) { return result; } for (int64_t flat = 0; flat < result.numel(); ++flat) { std::vector out_coord = tensor_unravel_index(flat, out_shape); std::vector src_coord(static_cast(dim()), 0); for (size_t i = 0; i < dims.size(); ++i) { src_coord[dims[i]] = out_coord[i]; } result[flat] = index(src_coord); } return result; } Tensor& permute_(const std::vector& dims) { *this = permute(dims); return *this; } void fill_(const T& value) { std::fill(data_.begin(), data_.end(), value); } Tensor& masked_fill_(const Tensor& mask, const T& value); T mean() const; static Tensor zeros(std::vector shape) { return Tensor(std::move(shape)); } static Tensor zeros_like(const Tensor& other) { return zeros(other.shape()); } static Tensor ones(std::vector shape) { return full(std::move(shape), static_cast(1)); } static Tensor ones_like(const Tensor& other) { return ones(other.shape()); } static Tensor full(std::vector shape, const T& value) { Tensor tensor(std::move(shape)); tensor.fill_(value); return tensor; } static Tensor randn(std::vector shape, const std::shared_ptr& rng) { static_assert(std::is_same_v, "Tensor::randn currently requires Tensor"); if (!rng) { tensor_throw_invalid_argument("Tensor randn requires a valid RNG"); } const uint32_t size = static_cast(tensor_numel(shape)); return Tensor(std::move(shape), rng->randn(size)); } static Tensor randn_like(const Tensor& other, const std::shared_ptr& rng) { return randn(other.shape(), rng); } static Tensor from_vector(std::vector data) { const int64_t size = static_cast(data.size()); return Tensor({size}, std::move(data)); } T& index(const std::vector& coord) { return data_.at(offset_of(coord)); } const T& index(const std::vector& coord) const { return data_.at(offset_of(coord)); } template && ...)>> T& index(Indices... indices) { return index(std::vector{static_cast(indices)...}); } template && ...)>> const T& index(Indices... indices) const { return index(std::vector{static_cast(indices)...}); } T& operator[](int64_t index) { return data_.at(static_cast(index)); } const T& operator[](int64_t index) const { return data_.at(static_cast(index)); } private: size_t offset_of(const std::vector& coord) const { if (coord.size() != shape_.size()) { tensor_throw_invalid_argument("Tensor index rank mismatch: coord_rank=" + std::to_string(coord.size()) + ", shape=" + tensor_shape_to_string(shape_)); } size_t offset = 0; size_t stride = 1; for (size_t i = 0; i < shape_.size(); ++i) { if (coord[i] < 0 || coord[i] >= shape_[i]) { tensor_throw_invalid_argument("Tensor index out of range: shape=" + tensor_shape_to_string(shape_)); } offset += static_cast(coord[i]) * stride; stride *= static_cast(shape_[i]); } return offset; } std::vector data_; std::vector shape_; }; template inline T Tensor::mean() const { if (empty()) { return T{}; } T sum = T{}; for (const T& value : data_) { sum += value; } return sum / static_cast(numel()); } template <> inline float Tensor::mean() const { if (empty()) { return 0.0f; } double sum = 0.0; for (float value : data_) { sum += static_cast(value); } return static_cast(sum / static_cast(numel())); } template inline void tensor_check_same_shape(const Tensor& lhs, const Tensor& rhs) { if (lhs.shape() != rhs.shape()) { tensor_throw_invalid_argument("Tensor shapes must match: lhs_shape=" + tensor_shape_to_string(lhs.shape()) + ", rhs_shape=" + tensor_shape_to_string(rhs.shape())); } } inline std::vector tensor_broadcast_shape(const std::vector& lhs, const std::vector& rhs) { size_t ndim = std::max(lhs.size(), rhs.size()); std::vector shape(ndim, 1); for (size_t i = 0; i < ndim; ++i) { int64_t lhs_dim = lhs.size() > i ? lhs[i] : 1; int64_t rhs_dim = rhs.size() > i ? rhs[i] : 1; if (lhs_dim != rhs_dim && lhs_dim != 1 && rhs_dim != 1) { tensor_throw_invalid_argument("Tensor shapes are not broadcastable: lhs_shape=" + tensor_shape_to_string(lhs) + ", rhs_shape=" + tensor_shape_to_string(rhs)); } shape[i] = std::max(lhs_dim, rhs_dim); } return shape; } inline std::vector tensor_unravel_index(int64_t flat, const std::vector& shape) { std::vector coord(shape.size(), 0); for (size_t i = 0; i < shape.size(); ++i) { if (shape[i] <= 0) { tensor_throw_invalid_argument("Tensor unravel_index requires positive shape: shape=" + tensor_shape_to_string(shape)); } coord[i] = flat % shape[i]; flat /= shape[i]; } return coord; } inline std::vector tensor_compute_strides(const std::vector& shape) { std::vector strides(shape.size(), 1); int64_t stride = 1; for (size_t i = 0; i < shape.size(); ++i) { strides[i] = stride; stride *= shape[i]; } return strides; } template inline void tensor_for_each_broadcast_offset(const std::vector& out_shape, const std::vector& lhs_shape_raw, const std::vector& lhs_strides_raw, const std::vector& rhs_shape_raw, const std::vector& rhs_strides_raw, F&& fn) { const size_t ndim = out_shape.size(); std::vector out_strides = tensor_compute_strides(out_shape); std::vector lhs_shape(ndim, 1); std::vector lhs_strides(ndim, 0); std::vector rhs_shape(ndim, 1); std::vector rhs_strides(ndim, 0); for (size_t i = 0; i < lhs_shape_raw.size(); ++i) { lhs_shape[i] = lhs_shape_raw[i]; lhs_strides[i] = lhs_strides_raw[i]; } for (size_t i = 0; i < rhs_shape_raw.size(); ++i) { rhs_shape[i] = rhs_shape_raw[i]; rhs_strides[i] = rhs_strides_raw[i]; } const int64_t numel = tensor_numel(out_shape); for (int64_t flat = 0; flat < numel; ++flat) { int64_t remaining = flat; int64_t lhs_offset = 0; int64_t rhs_offset = 0; for (size_t i = ndim; i-- > 0;) { int64_t coord = remaining / out_strides[i]; remaining %= out_strides[i]; if (lhs_shape[i] != 1) { lhs_offset += coord * lhs_strides[i]; } if (rhs_shape[i] != 1) { rhs_offset += coord * rhs_strides[i]; } } fn(flat, lhs_offset, rhs_offset); } } template inline Tensor& Tensor::masked_fill_(const Tensor& mask, const T& value) { if (empty()) { return *this; } tensor_broadcast_shape(shape_, mask.shape()); const std::vector data_strides = tensor_compute_strides(shape_); const std::vector mask_strides = tensor_compute_strides(mask.shape()); const uint8_t* mask_data = mask.data(); tensor_for_each_broadcast_offset(shape_, shape_, data_strides, mask.shape(), mask_strides, [&](int64_t, int64_t data_offset, int64_t mask_offset) { if (mask_data[mask_offset] != 0) { data_[static_cast(data_offset)] = value; } }); return *this; } template ::value>> inline Tensor operator<(const Tensor& lhs, Scalar rhs) { Tensor result(lhs.shape()); const T value = static_cast(rhs); for (int64_t i = 0; i < lhs.numel(); ++i) { result[i] = lhs[i] < value ? 1 : 0; } return result; } template ::value>> inline Tensor operator<(Scalar lhs, const Tensor& rhs) { Tensor result(rhs.shape()); const T value = static_cast(lhs); for (int64_t i = 0; i < rhs.numel(); ++i) { result[i] = value < rhs[i] ? 1 : 0; } return result; } template inline Tensor operator<(const Tensor& lhs, const Tensor& rhs) { const std::vector out_shape = tensor_broadcast_shape(lhs.shape(), rhs.shape()); Tensor result(out_shape); const std::vector lhs_strides = tensor_compute_strides(lhs.shape()); const std::vector rhs_strides = tensor_compute_strides(rhs.shape()); const T* lhs_data = lhs.data(); const T* rhs_data = rhs.data(); tensor_for_each_broadcast_offset(out_shape, lhs.shape(), lhs_strides, rhs.shape(), rhs_strides, [&](int64_t flat, int64_t lhs_offset, int64_t rhs_offset) { result[flat] = lhs_data[lhs_offset] < rhs_data[rhs_offset] ? 1 : 0; }); return result; } template inline Tensor& operator+=(Tensor& lhs, const Tensor& rhs) { if (lhs.shape() == rhs.shape()) { for (int64_t i = 0; i < lhs.numel(); ++i) { lhs[i] += rhs[i]; } return lhs; } tensor_broadcast_shape(lhs.shape(), rhs.shape()); const std::vector lhs_strides = tensor_compute_strides(lhs.shape()); const std::vector rhs_strides = tensor_compute_strides(rhs.shape()); const T* rhs_data = rhs.data(); tensor_for_each_broadcast_offset(lhs.shape(), lhs.shape(), lhs_strides, rhs.shape(), rhs_strides, [&](int64_t, int64_t lhs_offset, int64_t rhs_offset) { lhs[static_cast(lhs_offset)] += rhs_data[rhs_offset]; }); return lhs; } template ::value>> inline Tensor& operator+=(Tensor& lhs, Scalar rhs) { const T value = static_cast(rhs); for (int64_t i = 0; i < lhs.numel(); ++i) { lhs[i] += value; } return lhs; } template inline Tensor& operator-=(Tensor& lhs, const Tensor& rhs) { if (lhs.shape() == rhs.shape()) { for (int64_t i = 0; i < lhs.numel(); ++i) { lhs[i] -= rhs[i]; } return lhs; } tensor_broadcast_shape(lhs.shape(), rhs.shape()); const std::vector lhs_strides = tensor_compute_strides(lhs.shape()); const std::vector rhs_strides = tensor_compute_strides(rhs.shape()); const T* rhs_data = rhs.data(); tensor_for_each_broadcast_offset(lhs.shape(), lhs.shape(), lhs_strides, rhs.shape(), rhs_strides, [&](int64_t, int64_t lhs_offset, int64_t rhs_offset) { lhs[static_cast(lhs_offset)] -= rhs_data[rhs_offset]; }); return lhs; } template ::value>> inline Tensor& operator-=(Tensor& lhs, Scalar rhs) { const T value = static_cast(rhs); for (int64_t i = 0; i < lhs.numel(); ++i) { lhs[i] -= value; } return lhs; } template inline Tensor& operator*=(Tensor& lhs, const Tensor& rhs) { if (lhs.shape() == rhs.shape()) { for (int64_t i = 0; i < lhs.numel(); ++i) { lhs[i] *= rhs[i]; } return lhs; } tensor_broadcast_shape(lhs.shape(), rhs.shape()); const std::vector lhs_strides = tensor_compute_strides(lhs.shape()); const std::vector rhs_strides = tensor_compute_strides(rhs.shape()); const T* rhs_data = rhs.data(); tensor_for_each_broadcast_offset(lhs.shape(), lhs.shape(), lhs_strides, rhs.shape(), rhs_strides, [&](int64_t, int64_t lhs_offset, int64_t rhs_offset) { lhs[static_cast(lhs_offset)] *= rhs_data[rhs_offset]; }); return lhs; } template ::value>> inline Tensor& operator*=(Tensor& lhs, Scalar rhs) { const T value = static_cast(rhs); for (int64_t i = 0; i < lhs.numel(); ++i) { lhs[i] *= value; } return lhs; } template inline Tensor& operator/=(Tensor& lhs, const Tensor& rhs) { if (lhs.shape() == rhs.shape()) { for (int64_t i = 0; i < lhs.numel(); ++i) { lhs[i] /= rhs[i]; } return lhs; } tensor_broadcast_shape(lhs.shape(), rhs.shape()); const std::vector lhs_strides = tensor_compute_strides(lhs.shape()); const std::vector rhs_strides = tensor_compute_strides(rhs.shape()); const T* rhs_data = rhs.data(); tensor_for_each_broadcast_offset(lhs.shape(), lhs.shape(), lhs_strides, rhs.shape(), rhs_strides, [&](int64_t, int64_t lhs_offset, int64_t rhs_offset) { lhs[static_cast(lhs_offset)] /= rhs_data[rhs_offset]; }); return lhs; } template ::value>> inline Tensor& operator/=(Tensor& lhs, Scalar rhs) { const T value = static_cast(rhs); for (int64_t i = 0; i < lhs.numel(); ++i) { lhs[i] /= value; } return lhs; } template inline Tensor operator+(Tensor lhs, const Tensor& rhs) { if (lhs.shape() != rhs.shape()) { const std::vector out_shape = tensor_broadcast_shape(lhs.shape(), rhs.shape()); Tensor result(out_shape); const std::vector lhs_strides = tensor_compute_strides(lhs.shape()); const std::vector rhs_strides = tensor_compute_strides(rhs.shape()); const T* lhs_data = lhs.data(); const T* rhs_data = rhs.data(); tensor_for_each_broadcast_offset(out_shape, lhs.shape(), lhs_strides, rhs.shape(), rhs_strides, [&](int64_t flat, int64_t lhs_offset, int64_t rhs_offset) { result[flat] = lhs_data[lhs_offset] + rhs_data[rhs_offset]; }); return result; } lhs += rhs; return lhs; } template ::value>> inline Tensor operator+(Tensor lhs, Scalar rhs) { lhs += rhs; return lhs; } template ::value>> inline Tensor operator+(Scalar lhs, Tensor rhs) { rhs += lhs; return rhs; } template inline Tensor operator-(Tensor lhs, const Tensor& rhs) { if (lhs.shape() != rhs.shape()) { const std::vector out_shape = tensor_broadcast_shape(lhs.shape(), rhs.shape()); Tensor result(out_shape); const std::vector lhs_strides = tensor_compute_strides(lhs.shape()); const std::vector rhs_strides = tensor_compute_strides(rhs.shape()); const T* lhs_data = lhs.data(); const T* rhs_data = rhs.data(); tensor_for_each_broadcast_offset(out_shape, lhs.shape(), lhs_strides, rhs.shape(), rhs_strides, [&](int64_t flat, int64_t lhs_offset, int64_t rhs_offset) { result[flat] = lhs_data[lhs_offset] - rhs_data[rhs_offset]; }); return result; } lhs -= rhs; return lhs; } template ::value>> inline Tensor operator-(Tensor lhs, Scalar rhs) { lhs -= rhs; return lhs; } template ::value>> inline Tensor operator-(Scalar lhs, const Tensor& rhs) { Tensor result = rhs; const T value = static_cast(lhs); for (int64_t i = 0; i < result.numel(); ++i) { result[i] = value - result[i]; } return result; } template inline Tensor operator*(Tensor lhs, const Tensor& rhs) { if (lhs.shape() != rhs.shape()) { const std::vector out_shape = tensor_broadcast_shape(lhs.shape(), rhs.shape()); Tensor result(out_shape); const std::vector lhs_strides = tensor_compute_strides(lhs.shape()); const std::vector rhs_strides = tensor_compute_strides(rhs.shape()); const T* lhs_data = lhs.data(); const T* rhs_data = rhs.data(); tensor_for_each_broadcast_offset(out_shape, lhs.shape(), lhs_strides, rhs.shape(), rhs_strides, [&](int64_t flat, int64_t lhs_offset, int64_t rhs_offset) { result[flat] = lhs_data[lhs_offset] * rhs_data[rhs_offset]; }); return result; } lhs *= rhs; return lhs; } template ::value>> inline Tensor operator*(Tensor lhs, Scalar rhs) { lhs *= rhs; return lhs; } template ::value>> inline Tensor operator*(Scalar lhs, Tensor rhs) { rhs *= lhs; return rhs; } template inline Tensor operator/(Tensor lhs, const Tensor& rhs) { if (lhs.shape() != rhs.shape()) { const std::vector out_shape = tensor_broadcast_shape(lhs.shape(), rhs.shape()); Tensor result(out_shape); const std::vector lhs_strides = tensor_compute_strides(lhs.shape()); const std::vector rhs_strides = tensor_compute_strides(rhs.shape()); const T* lhs_data = lhs.data(); const T* rhs_data = rhs.data(); tensor_for_each_broadcast_offset(out_shape, lhs.shape(), lhs_strides, rhs.shape(), rhs_strides, [&](int64_t flat, int64_t lhs_offset, int64_t rhs_offset) { result[flat] = lhs_data[lhs_offset] / rhs_data[rhs_offset]; }); return result; } lhs /= rhs; return lhs; } template ::value>> inline Tensor operator/(Tensor lhs, Scalar rhs) { lhs /= rhs; return lhs; } template ::value>> inline Tensor operator/(Scalar lhs, const Tensor& rhs) { Tensor result = rhs; const T value = static_cast(lhs); for (int64_t i = 0; i < result.numel(); ++i) { result[i] = value / result[i]; } return result; } template inline Tensor operator-(const Tensor& tensor) { Tensor result = tensor; for (int64_t i = 0; i < result.numel(); ++i) { result[i] = -result[i]; } return result; } template inline Tensor zeros(std::vector shape) { return Tensor::zeros(std::move(shape)); } template inline Tensor full(std::vector shape, const T& value) { return Tensor::full(std::move(shape), value); } template inline Tensor randn(std::vector shape, const std::shared_ptr& rng) { return Tensor::randn(std::move(shape), rng); } template inline Tensor randn_like(const Tensor& tensor, const std::shared_ptr& rng) { return Tensor::randn(tensor.shape(), rng); } template inline std::vector tensor_to_vector(const Tensor& tensor) { return tensor.values(); } namespace ops { enum class InterpolateMode { Nearest, }; inline int64_t normalize_slice_bound(int64_t index, int64_t dim_size) { if (index < 0) { index += dim_size; } return index; } template inline std::pair resolve_slice_bounds(const Tensor& input, size_t dim, int64_t start, int64_t end) { if (dim >= static_cast(input.dim())) { tensor_throw_invalid_argument("Tensor slice dimension out of range: dim=" + std::to_string(dim) + ", rank=" + std::to_string(input.dim()) + ", input_shape=" + tensor_shape_to_string(input.shape())); } int64_t dim_size = input.shape()[dim]; start = normalize_slice_bound(start, dim_size); end = normalize_slice_bound(end, dim_size); if (start < 0 || start > dim_size || end < 0 || end > dim_size || start > end) { tensor_throw_invalid_argument("Tensor slice bounds out of range: dim=" + std::to_string(dim) + ", start=" + std::to_string(start) + ", end=" + std::to_string(end) + ", input_shape=" + tensor_shape_to_string(input.shape())); } return {start, end}; } template inline Tensor exp(const Tensor& input) { Tensor output(input.shape()); for (int64_t i = 0; i < input.numel(); ++i) { output[i] = static_cast(std::exp(static_cast(input[i]))); } return output; } template inline Tensor clamp(const Tensor& input, const T& min_value, const T& max_value) { if (min_value > max_value) { tensor_throw_invalid_argument("Tensor clamp requires min_value <= max_value"); } Tensor output(input.shape()); for (int64_t i = 0; i < input.numel(); ++i) { output[i] = std::clamp(input[i], min_value, max_value); } return output; } template inline Tensor round(const Tensor& input) { Tensor output(input.shape()); for (int64_t i = 0; i < input.numel(); ++i) { output[i] = static_cast(std::round(static_cast(input[i]))); } return output; } template inline Tensor slice(const Tensor& input, size_t dim, int64_t start, int64_t end) { auto [resolved_start, resolved_end] = resolve_slice_bounds(input, dim, start, end); std::vector out_shape = input.shape(); out_shape[dim] = resolved_end - resolved_start; Tensor output(out_shape); if (output.numel() == 0) { return output; } int64_t inner = 1; for (size_t i = 0; i < dim; ++i) { inner *= input.shape()[i]; } int64_t outer = 1; for (size_t i = dim + 1; i < static_cast(input.dim()); ++i) { outer *= input.shape()[i]; } int64_t src_chunk = (resolved_end - resolved_start) * inner; int64_t src_stride = input.shape()[dim] * inner; for (int64_t i = 0; i < outer; ++i) { const int64_t src_offset = i * src_stride + resolved_start * inner; const int64_t dst_offset = i * src_chunk; std::copy_n(input.data() + src_offset, src_chunk, output.data() + dst_offset); } return output; } template inline Tensor narrow(const Tensor& input, size_t dim, int64_t start, int64_t length) { if (length < 0) { tensor_throw_invalid_argument("Tensor narrow requires non-negative length: length=" + std::to_string(length) + ", input_shape=" + tensor_shape_to_string(input.shape())); } return slice(input, dim, start, start + length); } template inline void slice_assign(Tensor* dst, size_t dim, int64_t start, int64_t end, const Tensor& src) { if (dst == nullptr) { tensor_throw_invalid_argument("Tensor slice_assign requires non-null dst"); } auto [resolved_start, resolved_end] = resolve_slice_bounds(*dst, dim, start, end); if (src.dim() != dst->dim()) { tensor_throw_invalid_argument("Tensor slice_assign requires matching rank: dst_shape=" + tensor_shape_to_string(dst->shape()) + ", src_shape=" + tensor_shape_to_string(src.shape())); } std::vector expected_shape = dst->shape(); expected_shape[dim] = resolved_end - resolved_start; if (src.shape() != expected_shape) { tensor_throw_invalid_argument("Tensor slice_assign requires matching source shape: dst_shape=" + tensor_shape_to_string(dst->shape()) + ", src_shape=" + tensor_shape_to_string(src.shape()) + ", expected_src_shape=" + tensor_shape_to_string(expected_shape)); } if (src.numel() == 0) { return; } int64_t inner = 1; for (size_t i = 0; i < dim; ++i) { inner *= dst->shape()[i]; } int64_t outer = 1; for (size_t i = dim + 1; i < static_cast(dst->dim()); ++i) { outer *= dst->shape()[i]; } int64_t dst_chunk = (resolved_end - resolved_start) * inner; int64_t dst_stride = dst->shape()[dim] * inner; for (int64_t i = 0; i < outer; ++i) { const int64_t dst_offset = i * dst_stride + resolved_start * inner; const int64_t src_offset = i * dst_chunk; std::copy_n(src.data() + src_offset, dst_chunk, dst->data() + dst_offset); } } template inline void fill_slice(Tensor* dst, size_t dim, int64_t start, int64_t end, const T& value) { if (dst == nullptr) { tensor_throw_invalid_argument("Tensor fill_slice requires non-null dst"); } auto [resolved_start, resolved_end] = resolve_slice_bounds(*dst, dim, start, end); int64_t inner = 1; for (size_t i = 0; i < dim; ++i) { inner *= dst->shape()[i]; } int64_t outer = 1; for (size_t i = dim + 1; i < static_cast(dst->dim()); ++i) { outer *= dst->shape()[i]; } int64_t chunk = (resolved_end - resolved_start) * inner; int64_t stride = dst->shape()[dim] * inner; for (int64_t i = 0; i < outer; ++i) { const int64_t offset = i * stride + resolved_start * inner; std::fill_n(dst->data() + offset, chunk, value); } } template inline Tensor interpolate(const Tensor& input, std::vector output_shape, InterpolateMode mode = InterpolateMode::Nearest, bool align_corners = false) { if (mode != InterpolateMode::Nearest) { tensor_throw_invalid_argument("Only nearest interpolate mode is implemented, got mode=" + std::to_string(static_cast(mode))); } if (align_corners) { tensor_throw_invalid_argument("align_corners is not supported for nearest interpolate: input_shape=" + tensor_shape_to_string(input.shape()) + ", output_shape=" + tensor_shape_to_string(output_shape)); } if (input.shape() == output_shape) { return input; } if (input.dim() != static_cast(output_shape.size())) { tensor_throw_invalid_argument("Tensor interpolate requires matching rank: input_dim=" + std::to_string(input.dim()) + ", output_dim=" + std::to_string(output_shape.size()) + ", input_shape=" + tensor_shape_to_string(input.shape()) + ", output_shape=" + tensor_shape_to_string(output_shape)); } for (size_t i = 0; i < output_shape.size(); ++i) { if (output_shape[i] <= 0) { tensor_throw_invalid_argument("Tensor interpolate output shape must be positive: input_shape=" + tensor_shape_to_string(input.shape()) + ", output_shape=" + tensor_shape_to_string(output_shape)); } if (input.shape()[i] <= 0) { tensor_throw_invalid_argument("Tensor interpolate input shape must be positive: input_shape=" + tensor_shape_to_string(input.shape()) + ", output_shape=" + tensor_shape_to_string(output_shape)); } } Tensor output(std::move(output_shape)); for (int64_t flat = 0; flat < output.numel(); ++flat) { std::vector output_coord = tensor_unravel_index(flat, output.shape()); std::vector input_coord(static_cast(input.dim()), 0); for (size_t i = 0; i < static_cast(input.dim()); ++i) { input_coord[i] = output_coord[i] * input.shape()[i] / output.shape()[i]; } output[flat] = input.index(input_coord); } return output; } template inline Tensor interpolate(const Tensor& input, const std::optional>& size, const std::optional>& scale_factor, InterpolateMode mode = InterpolateMode::Nearest, bool align_corners = false) { if (mode != InterpolateMode::Nearest) { tensor_throw_invalid_argument("Only nearest interpolate mode is implemented, got mode=" + std::to_string(static_cast(mode))); } if (align_corners) { tensor_throw_invalid_argument("align_corners is not supported for nearest interpolate: input_shape=" + tensor_shape_to_string(input.shape())); } if (size.has_value() == scale_factor.has_value()) { tensor_throw_invalid_argument("Tensor interpolate requires exactly one of size or scale_factor: input_shape=" + tensor_shape_to_string(input.shape())); } std::vector output_shape = input.shape(); if (size.has_value()) { if (size->empty() || size->size() > output_shape.size()) { tensor_throw_invalid_argument("Tensor interpolate size must target low dimensions: input_shape=" + tensor_shape_to_string(input.shape()) + ", size_rank=" + std::to_string(size->size())); } for (size_t i = 0; i < size->size(); ++i) { if ((*size)[i] <= 0) { tensor_throw_invalid_argument("Tensor interpolate size must be positive: input_shape=" + tensor_shape_to_string(input.shape()) + ", size=" + tensor_shape_to_string(*size)); } output_shape[i] = (*size)[i]; } } else { if (scale_factor->empty() || scale_factor->size() > output_shape.size()) { tensor_throw_invalid_argument("Tensor interpolate scale_factor must target low dimensions: input_shape=" + tensor_shape_to_string(input.shape()) + ", scale_factor_rank=" + std::to_string(scale_factor->size())); } for (size_t i = 0; i < scale_factor->size(); ++i) { if ((*scale_factor)[i] <= 0.0) { tensor_throw_invalid_argument("Tensor interpolate scale_factor must be positive: input_shape=" + tensor_shape_to_string(input.shape())); } output_shape[i] = static_cast( std::floor(static_cast(output_shape[i]) * (*scale_factor)[i])); if (output_shape[i] <= 0) { tensor_throw_invalid_argument("Tensor interpolate output shape must be positive: input_shape=" + tensor_shape_to_string(input.shape()) + ", output_shape=" + tensor_shape_to_string(output_shape)); } } } return interpolate(input, std::move(output_shape), mode, align_corners); } template inline Tensor interpolate(const Tensor& input, const std::optional>& size, double scale_factor, InterpolateMode mode = InterpolateMode::Nearest, bool align_corners = false) { return interpolate(input, size, std::vector(size.has_value() ? size->size() : input.dim(), scale_factor), mode, align_corners); } template inline Tensor concat(const Tensor& lhs, const Tensor& rhs, size_t dim) { if (lhs.dim() != rhs.dim()) { tensor_throw_invalid_argument("Tensor concat requires same rank: lhs_dim=" + std::to_string(lhs.dim()) + ", rhs_dim=" + std::to_string(rhs.dim()) + ", lhs_shape=" + tensor_shape_to_string(lhs.shape()) + ", rhs_shape=" + tensor_shape_to_string(rhs.shape())); } if (dim >= static_cast(lhs.dim())) { tensor_throw_invalid_argument("Tensor concat dimension out of range: dim=" + std::to_string(dim) + ", rank=" + std::to_string(lhs.dim()) + ", lhs_shape=" + tensor_shape_to_string(lhs.shape())); } std::vector out_shape = lhs.shape(); for (size_t i = 0; i < static_cast(lhs.dim()); ++i) { if (i == dim) { continue; } if (lhs.shape()[i] != rhs.shape()[i]) { tensor_throw_invalid_argument("Tensor concat requires matching non-concat dimensions: dim=" + std::to_string(dim) + ", lhs_shape=" + tensor_shape_to_string(lhs.shape()) + ", rhs_shape=" + tensor_shape_to_string(rhs.shape())); } } out_shape[dim] += rhs.shape()[dim]; Tensor out(out_shape); int64_t inner = 1; for (size_t i = 0; i < dim; ++i) { inner *= lhs.shape()[i]; } int64_t outer = 1; for (size_t i = dim + 1; i < static_cast(lhs.dim()); ++i) { outer *= lhs.shape()[i]; } int64_t lhs_chunk = lhs.shape()[dim] * inner; int64_t rhs_chunk = rhs.shape()[dim] * inner; int64_t out_chunk = lhs_chunk + rhs_chunk; for (int64_t i = 0; i < outer; ++i) { int64_t lhs_offset = i * lhs_chunk; int64_t rhs_offset = i * rhs_chunk; int64_t out_offset = i * out_chunk; std::copy_n(lhs.data() + lhs_offset, lhs_chunk, out.data() + out_offset); std::copy_n(rhs.data() + rhs_offset, rhs_chunk, out.data() + out_offset + lhs_chunk); } return out; } template inline std::vector> chunk(const Tensor& tensor, int64_t chunks, size_t dim) { if (chunks <= 0) { tensor_throw_invalid_argument("Tensor chunk requires chunks > 0: chunks=" + std::to_string(chunks) + ", tensor_shape=" + tensor_shape_to_string(tensor.shape())); } if (dim >= static_cast(tensor.dim())) { tensor_throw_invalid_argument("Tensor chunk dimension out of range: dim=" + std::to_string(dim) + ", rank=" + std::to_string(tensor.dim()) + ", tensor_shape=" + tensor_shape_to_string(tensor.shape())); } const int64_t dim_size = tensor.shape()[dim]; if (dim_size == 0) { return {}; } if (dim_size % chunks != 0) { tensor_throw_invalid_argument("Tensor chunk requires the dimension size to be divisible by chunks: dim=" + std::to_string(dim) + ", dim_size=" + std::to_string(dim_size) + ", chunks=" + std::to_string(chunks) + ", tensor_shape=" + tensor_shape_to_string(tensor.shape())); } const int64_t chunk_size = dim_size / chunks; int64_t inner = 1; for (size_t i = 0; i < dim; ++i) { inner *= tensor.shape()[i]; } int64_t outer = 1; for (size_t i = dim + 1; i < static_cast(tensor.dim()); ++i) { outer *= tensor.shape()[i]; } std::vector> parts; parts.reserve(static_cast(chunks)); for (int64_t start = 0; start < dim_size; start += chunk_size) { std::vector part_shape = tensor.shape(); part_shape[dim] = chunk_size; Tensor part(part_shape); const int64_t src_chunk = chunk_size * inner; const int64_t dst_chunk = src_chunk; for (int64_t i = 0; i < outer; ++i) { const int64_t src_offset = (i * dim_size + start) * inner; const int64_t dst_offset = i * dst_chunk; std::copy_n(tensor.data() + src_offset, src_chunk, part.data() + dst_offset); } parts.push_back(std::move(part)); } return parts; } } // namespace ops } // namespace sd #endif