1250 lines
53 KiB
C++

#ifndef __SD_TENSOR_HPP__
#define __SD_TENSOR_HPP__
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <initializer_list>
#include <memory>
#include <numeric>
#include <optional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include "rng.hpp"
namespace sd {
template <typename T>
class Tensor;
inline std::vector<int64_t> tensor_unravel_index(int64_t flat, const std::vector<int64_t>& shape);
/// Reports `message` on stderr (prefixed with "sd::Tensor error: "), flushes
/// stderr, and then throws std::invalid_argument carrying the same message.
[[noreturn]] inline void tensor_throw_invalid_argument(const std::string& message) {
    const char* const text = message.c_str();
    std::fprintf(stderr, "sd::Tensor error: %s\n", text);
    // Flush explicitly so the diagnostic is emitted before the stack unwinds.
    std::fflush(stderr);
    throw std::invalid_argument(message);
}
/// Formats a shape as "[d0, d1, ...]"; an empty shape becomes "[]".
inline std::string tensor_shape_to_string(const std::vector<int64_t>& shape) {
    std::ostringstream text;
    text << "[";
    // The separator is empty for the first extent and ", " for every later one.
    const char* separator = "";
    for (const int64_t extent : shape) {
        text << separator << extent;
        separator = ", ";
    }
    text << "]";
    return text.str();
}
/// Returns the number of elements described by `shape`.
///
/// An empty shape yields 0. Throws std::invalid_argument (through
/// tensor_throw_invalid_argument) when any extent is negative or when the
/// product of the extents does not fit in int64_t.
inline int64_t tensor_numel(const std::vector<int64_t>& shape) {
    if (shape.empty()) {
        return 0;
    }
    // First pass: every extent is validated, and a zero extent anywhere makes
    // the whole product zero regardless of how large the other extents are.
    bool has_zero_extent = false;
    for (int64_t dim : shape) {
        if (dim < 0) {
            tensor_throw_invalid_argument("Tensor shape must be non-negative, got shape=" +
                                          tensor_shape_to_string(shape));
        }
        has_zero_extent = has_zero_extent || dim == 0;
    }
    if (has_zero_extent) {
        return 0;
    }
    // Second pass: all extents are >= 1 here, so the division below is safe.
    int64_t numel = 1;
    for (int64_t dim : shape) {
        // Signed overflow is undefined behavior; detect it before multiplying.
        if (numel > INT64_MAX / dim) {
            tensor_throw_invalid_argument("Tensor shape element count overflows int64, got shape=" +
                                          tensor_shape_to_string(shape));
        }
        numel *= dim;
    }
    return numel;
}
/// Owning N-dimensional array backed by a single contiguous std::vector<T>.
///
/// Element offsets are computed with axis 0 varying fastest (see offset_of):
/// the stride of axis i is the product of the extents of all axes below i.
/// numel() reports the storage size, while the shape is tracked separately.
template <typename T>
class Tensor {
public:
    /// Creates a tensor with no axes and no elements.
    Tensor() = default;

    /// Allocates value-initialized storage for `shape`.
    /// Throws std::invalid_argument if `shape` contains a negative extent.
    explicit Tensor(std::vector<int64_t> shape)
        : data_(static_cast<size_t>(tensor_numel(shape))), shape_(std::move(shape)) {
    }

    /// Adopts `data` as storage for `shape`.
    /// Throws std::invalid_argument when data.size() differs from the element
    /// count implied by `shape`.
    Tensor(std::vector<int64_t> shape, std::vector<T> data)
        : data_(std::move(data)), shape_(std::move(shape)) {
        if (static_cast<int64_t>(data_.size()) != tensor_numel(shape_)) {
            tensor_throw_invalid_argument("Tensor data size does not match shape: data.size()=" +
                                          std::to_string(data_.size()) + ", shape=" +
                                          tensor_shape_to_string(shape_) + ", numel=" +
                                          std::to_string(tensor_numel(shape_)));
        }
    }

    /// Extents of every axis.
    const std::vector<int64_t>& shape() const {
        return shape_;
    }

    /// Number of axes.
    int64_t dim() const {
        return static_cast<int64_t>(shape_.size());
    }

    /// Number of stored elements (size of the backing vector).
    int64_t numel() const {
        return static_cast<int64_t>(data_.size());
    }

    /// True when no elements are stored.
    bool empty() const {
        return data_.empty();
    }

    /// Raw pointer to the first stored element.
    T* data() {
        return data_.data();
    }

    const T* data() const {
        return data_.data();
    }

    /// Direct access to the backing vector.
    std::vector<T>& values() {
        return data_;
    }

    const std::vector<T>& values() const {
        return data_;
    }

    /// Changes the shape and resizes the storage to match it.
    /// Throws std::invalid_argument if `shape` contains a negative extent.
    void resize(std::vector<int64_t> shape) {
        // Validate the shape and resize the storage before touching shape_, so
        // a throw from tensor_numel() or from the allocation cannot leave the
        // tensor with a shape that disagrees with its storage.
        const size_t count = static_cast<size_t>(tensor_numel(shape));
        data_.resize(count);
        shape_ = std::move(shape);
    }

    /// Replaces the shape in place; the element count must stay the same.
    /// Throws std::invalid_argument otherwise.
    Tensor& reshape_(std::vector<int64_t> shape) {
        if (tensor_numel(shape) != numel()) {
            tensor_throw_invalid_argument("Tensor reshape changes element count: from shape=" +
                                          tensor_shape_to_string(shape_) + " (numel=" +
                                          std::to_string(numel()) + ") to shape=" +
                                          tensor_shape_to_string(shape) + " (numel=" +
                                          std::to_string(tensor_numel(shape)) + ")");
        }
        shape_ = std::move(shape);
        return *this;
    }

    /// Copying variant of reshape_().
    Tensor reshape(std::vector<int64_t> shape) const {
        Tensor result = *this;
        result.reshape_(std::move(shape));
        return result;
    }

    /// Removes every axis whose extent is 1.
    Tensor& squeeze_() {
        std::vector<int64_t> new_shape;
        new_shape.reserve(shape_.size());
        for (int64_t dim : shape_) {
            if (dim != 1) {
                new_shape.push_back(dim);
            }
        }
        shape_ = std::move(new_shape);
        return *this;
    }

    /// Removes axis `dim`, which must exist and have extent 1.
    /// Throws std::invalid_argument otherwise.
    Tensor& squeeze_(size_t dim) {
        if (dim >= shape_.size()) {
            tensor_throw_invalid_argument("Tensor squeeze dimension out of range: dim=" +
                                          std::to_string(dim) + ", shape=" +
                                          tensor_shape_to_string(shape_));
        }
        if (shape_[dim] != 1) {
            tensor_throw_invalid_argument("Tensor squeeze requires dimension size 1: dim=" +
                                          std::to_string(dim) + ", shape=" +
                                          tensor_shape_to_string(shape_));
        }
        shape_.erase(shape_.begin() + static_cast<std::ptrdiff_t>(dim));
        return *this;
    }

    /// Copying variant of squeeze_().
    Tensor squeeze() const {
        Tensor result = *this;
        result.squeeze_();
        return result;
    }

    /// Copying variant of squeeze_(dim).
    Tensor squeeze(size_t dim) const {
        Tensor result = *this;
        result.squeeze_(dim);
        return result;
    }

    /// Inserts an axis of extent 1 at position `dim` (0..rank inclusive).
    /// Throws std::invalid_argument when `dim` is larger than the rank.
    Tensor& unsqueeze_(size_t dim) {
        if (dim > shape_.size()) {
            tensor_throw_invalid_argument("Tensor unsqueeze dimension out of range: dim=" +
                                          std::to_string(dim) + ", shape=" +
                                          tensor_shape_to_string(shape_));
        }
        shape_.insert(shape_.begin() + static_cast<std::ptrdiff_t>(dim), 1);
        return *this;
    }

    /// Copying variant of unsqueeze_(dim).
    Tensor unsqueeze(size_t dim) const {
        Tensor result = *this;
        result.unsqueeze_(dim);
        return result;
    }

    /// Returns a new tensor whose axis i is this tensor's axis dims[i].
    /// `dims` must be a permutation of 0..rank-1; throws std::invalid_argument
    /// otherwise.
    Tensor permute(const std::vector<size_t>& dims) const {
        if (dims.size() != static_cast<size_t>(dim())) {
            tensor_throw_invalid_argument("Tensor permute requires one dimension index per axis: tensor_shape=" +
                                          tensor_shape_to_string(shape_) + ", dims_size=" +
                                          std::to_string(dims.size()));
        }
        std::vector<bool> seen(dims.size(), false);
        std::vector<int64_t> out_shape(dims.size(), 1);
        for (size_t i = 0; i < dims.size(); ++i) {
            size_t dim_index = dims[i];
            if (dim_index >= dims.size() || seen[dim_index]) {
                tensor_throw_invalid_argument("Tensor permute dimensions must be a valid permutation: tensor_shape=" +
                                              tensor_shape_to_string(shape_));
            }
            seen[dim_index] = true;
            out_shape[i] = shape_[dim_index];
        }
        Tensor result(out_shape);
        if (result.numel() == 0) {
            return result;
        }
        // Reused for every element: dims is a permutation, so each slot is
        // overwritten on every iteration and no per-element allocation is needed.
        std::vector<int64_t> src_coord(static_cast<size_t>(dim()), 0);
        for (int64_t flat = 0; flat < result.numel(); ++flat) {
            const std::vector<int64_t> out_coord = tensor_unravel_index(flat, out_shape);
            for (size_t i = 0; i < dims.size(); ++i) {
                src_coord[dims[i]] = out_coord[i];
            }
            result[flat] = index(src_coord);
        }
        return result;
    }

    /// In-place variant of permute().
    Tensor& permute_(const std::vector<size_t>& dims) {
        *this = permute(dims);
        return *this;
    }

    /// Overwrites every element with `value`.
    void fill_(const T& value) {
        std::fill(data_.begin(), data_.end(), value);
    }

    /// Defined out of class below.
    Tensor& masked_fill_(const Tensor<uint8_t>& mask, const T& value);

    /// Defined out of class below.
    T mean() const;

    /// Tensor of `shape` with value-initialized elements.
    static Tensor zeros(std::vector<int64_t> shape) {
        return Tensor(std::move(shape));
    }

    static Tensor zeros_like(const Tensor& other) {
        return zeros(other.shape());
    }

    /// Tensor of `shape` filled with static_cast<T>(1).
    static Tensor ones(std::vector<int64_t> shape) {
        return full(std::move(shape), static_cast<T>(1));
    }

    static Tensor ones_like(const Tensor& other) {
        return ones(other.shape());
    }

    /// Tensor of `shape` with every element set to `value`.
    static Tensor full(std::vector<int64_t> shape, const T& value) {
        Tensor tensor(std::move(shape));
        tensor.fill_(value);
        return tensor;
    }

    /// Tensor of `shape` filled with the values returned by rng->randn().
    /// Only available for Tensor<float>. Throws std::invalid_argument when
    /// `rng` is null or the element count does not fit in uint32_t.
    static Tensor randn(std::vector<int64_t> shape, const std::shared_ptr<RNG>& rng) {
        static_assert(std::is_same_v<T, float>, "Tensor::randn currently requires Tensor<float>");
        if (!rng) {
            tensor_throw_invalid_argument("Tensor randn requires a valid RNG");
        }
        const int64_t count = tensor_numel(shape);
        // The count is narrowed to uint32_t below; reject values that would be
        // truncated instead of generating a wrongly sized buffer first.
        if (count > static_cast<int64_t>(UINT32_MAX)) {
            tensor_throw_invalid_argument("Tensor randn element count exceeds uint32 range: shape=" +
                                          tensor_shape_to_string(shape));
        }
        const uint32_t size = static_cast<uint32_t>(count);
        return Tensor(std::move(shape), rng->randn(size));
    }

    static Tensor randn_like(const Tensor& other, const std::shared_ptr<RNG>& rng) {
        return randn(other.shape(), rng);
    }

    /// Wraps `data` in a rank-1 tensor of extent data.size().
    static Tensor from_vector(std::vector<T> data) {
        const int64_t size = static_cast<int64_t>(data.size());
        return Tensor({size}, std::move(data));
    }

    /// Element at the multi-dimensional coordinate `coord` (one entry per
    /// axis). Throws std::invalid_argument on rank mismatch or out-of-range
    /// coordinates.
    T& index(const std::vector<int64_t>& coord) {
        return data_.at(offset_of(coord));
    }

    const T& index(const std::vector<int64_t>& coord) const {
        return data_.at(offset_of(coord));
    }

    /// Variadic convenience form: index(i0, i1, ...).
    template <typename... Indices, typename = std::enable_if_t<(std::is_convertible_v<Indices, int64_t> && ...)>>
    T& index(Indices... indices) {
        return index(std::vector<int64_t>{static_cast<int64_t>(indices)...});
    }

    template <typename... Indices, typename = std::enable_if_t<(std::is_convertible_v<Indices, int64_t> && ...)>>
    const T& index(Indices... indices) const {
        return index(std::vector<int64_t>{static_cast<int64_t>(indices)...});
    }

    /// Bounds-checked access by flat storage offset (std::vector::at).
    T& operator[](int64_t index) {
        return data_.at(static_cast<size_t>(index));
    }

    const T& operator[](int64_t index) const {
        return data_.at(static_cast<size_t>(index));
    }

private:
    /// Converts a coordinate into a flat storage offset, axis 0 fastest.
    size_t offset_of(const std::vector<int64_t>& coord) const {
        if (coord.size() != shape_.size()) {
            tensor_throw_invalid_argument("Tensor index rank mismatch: coord_rank=" +
                                          std::to_string(coord.size()) + ", shape=" +
                                          tensor_shape_to_string(shape_));
        }
        size_t offset = 0;
        size_t stride = 1;
        for (size_t i = 0; i < shape_.size(); ++i) {
            if (coord[i] < 0 || coord[i] >= shape_[i]) {
                tensor_throw_invalid_argument("Tensor index out of range: shape=" +
                                              tensor_shape_to_string(shape_));
            }
            offset += static_cast<size_t>(coord[i]) * stride;
            stride *= static_cast<size_t>(shape_[i]);
        }
        return offset;
    }

    std::vector<T> data_;
    std::vector<int64_t> shape_;
};
/// Arithmetic mean of all elements, accumulated in T; T{} for an empty tensor.
template <typename T>
inline T Tensor<T>::mean() const {
    if (data_.empty()) {
        return T{};
    }
    T total = T{};
    for (auto it = data_.begin(); it != data_.end(); ++it) {
        total += *it;
    }
    const T count = static_cast<T>(numel());
    return total / count;
}
/// float specialization: accumulates in double and narrows the quotient back
/// to float; returns 0.0f for an empty tensor.
template <>
inline float Tensor<float>::mean() const {
    if (data_.empty()) {
        return 0.0f;
    }
    double total = 0.0;
    for (auto it = data_.begin(); it != data_.end(); ++it) {
        total += static_cast<double>(*it);
    }
    const double count = static_cast<double>(numel());
    return static_cast<float>(total / count);
}
/// Throws std::invalid_argument unless both tensors have identical shapes.
template <typename T>
inline void tensor_check_same_shape(const Tensor<T>& lhs, const Tensor<T>& rhs) {
    const std::vector<int64_t>& left = lhs.shape();
    const std::vector<int64_t>& right = rhs.shape();
    if (left == right) {
        return;
    }
    tensor_throw_invalid_argument("Tensor shapes must match: lhs_shape=" +
                                  tensor_shape_to_string(left) + ", rhs_shape=" +
                                  tensor_shape_to_string(right));
}
/// Computes the broadcast of two shapes, aligning them at axis 0.
///
/// A shape with fewer axes is treated as if padded with extent-1 axes at the
/// high end. Two extents are compatible when they are equal or one of them is
/// 1; the result takes the extent that is not 1. Throws std::invalid_argument
/// for incompatible shapes.
inline std::vector<int64_t> tensor_broadcast_shape(const std::vector<int64_t>& lhs, const std::vector<int64_t>& rhs) {
    const size_t ndim = std::max(lhs.size(), rhs.size());
    std::vector<int64_t> shape(ndim, 1);
    for (size_t i = 0; i < ndim; ++i) {
        const int64_t lhs_dim = lhs.size() > i ? lhs[i] : 1;
        const int64_t rhs_dim = rhs.size() > i ? rhs[i] : 1;
        if (lhs_dim != rhs_dim && lhs_dim != 1 && rhs_dim != 1) {
            tensor_throw_invalid_argument("Tensor shapes are not broadcastable: lhs_shape=" +
                                          tensor_shape_to_string(lhs) + ", rhs_shape=" +
                                          tensor_shape_to_string(rhs));
        }
        // Take the extent that is not 1 rather than the maximum: broadcasting
        // an extent-1 axis against an extent-0 axis must give 0, otherwise the
        // result would have elements that the empty operand cannot supply.
        shape[i] = lhs_dim == 1 ? rhs_dim : lhs_dim;
    }
    return shape;
}
/// Splits a flat offset into one coordinate per axis, axis 0 varying fastest.
/// Throws std::invalid_argument when an extent that is reached is not positive.
inline std::vector<int64_t> tensor_unravel_index(int64_t flat, const std::vector<int64_t>& shape) {
    std::vector<int64_t> coord(shape.size(), 0);
    int64_t remaining = flat;
    auto slot = coord.begin();
    for (const int64_t extent : shape) {
        if (extent <= 0) {
            tensor_throw_invalid_argument("Tensor unravel_index requires positive shape: shape=" +
                                          tensor_shape_to_string(shape));
        }
        *slot = remaining % extent;
        ++slot;
        remaining /= extent;
    }
    return coord;
}
/// Returns the element stride of every axis for a densely packed layout in
/// which axis 0 has stride 1 and each later axis has the product of all lower
/// extents.
inline std::vector<int64_t> tensor_compute_strides(const std::vector<int64_t>& shape) {
    std::vector<int64_t> strides;
    strides.reserve(shape.size());
    int64_t running = 1;
    for (const int64_t extent : shape) {
        strides.push_back(running);
        running *= extent;
    }
    return strides;
}
template <typename F>
inline void tensor_for_each_broadcast_offset(const std::vector<int64_t>& out_shape,
const std::vector<int64_t>& lhs_shape_raw,
const std::vector<int64_t>& lhs_strides_raw,
const std::vector<int64_t>& rhs_shape_raw,
const std::vector<int64_t>& rhs_strides_raw,
F&& fn) {
const size_t ndim = out_shape.size();
std::vector<int64_t> out_strides = tensor_compute_strides(out_shape);
std::vector<int64_t> lhs_shape(ndim, 1);
std::vector<int64_t> lhs_strides(ndim, 0);
std::vector<int64_t> rhs_shape(ndim, 1);
std::vector<int64_t> rhs_strides(ndim, 0);
for (size_t i = 0; i < lhs_shape_raw.size(); ++i) {
lhs_shape[i] = lhs_shape_raw[i];
lhs_strides[i] = lhs_strides_raw[i];
}
for (size_t i = 0; i < rhs_shape_raw.size(); ++i) {
rhs_shape[i] = rhs_shape_raw[i];
rhs_strides[i] = rhs_strides_raw[i];
}
const int64_t numel = tensor_numel(out_shape);
for (int64_t flat = 0; flat < numel; ++flat) {
int64_t remaining = flat;
int64_t lhs_offset = 0;
int64_t rhs_offset = 0;
for (size_t i = ndim; i-- > 0;) {
int64_t coord = remaining / out_strides[i];
remaining %= out_strides[i];
if (lhs_shape[i] != 1) {
lhs_offset += coord * lhs_strides[i];
}
if (rhs_shape[i] != 1) {
rhs_offset += coord * rhs_strides[i];
}
}
fn(flat, lhs_offset, rhs_offset);
}
}
/// Sets every element whose (broadcast) mask entry is non-zero to `value`.
///
/// The mask is broadcast into this tensor's shape: it may have fewer axes, and
/// each mask extent must be 1 or equal to the matching extent of this tensor.
/// Throws std::invalid_argument otherwise. An empty tensor is returned as is.
template <typename T>
inline Tensor<T>& Tensor<T>::masked_fill_(const Tensor<uint8_t>& mask, const T& value) {
    if (empty()) {
        return *this;
    }
    // Rejects shapes that are not broadcast-compatible at all.
    tensor_broadcast_shape(shape_, mask.shape());
    // Compatibility alone is not enough for an in-place fill: the iteration
    // below walks this tensor's shape, so the mask must fit into it. A mask
    // with more axes or a larger extent would be indexed out of bounds or only
    // partially applied.
    const std::vector<int64_t>& mask_shape = mask.shape();
    bool mask_fits = mask_shape.size() <= shape_.size();
    for (size_t i = 0; mask_fits && i < mask_shape.size(); ++i) {
        mask_fits = mask_shape[i] == 1 || mask_shape[i] == shape_[i];
    }
    if (!mask_fits) {
        tensor_throw_invalid_argument("Tensor masked_fill_ cannot broadcast mask into tensor: tensor_shape=" +
                                      tensor_shape_to_string(shape_) + ", mask_shape=" +
                                      tensor_shape_to_string(mask_shape));
    }
    const std::vector<int64_t> data_strides = tensor_compute_strides(shape_);
    const std::vector<int64_t> mask_strides = tensor_compute_strides(mask.shape());
    const uint8_t* mask_data = mask.data();
    tensor_for_each_broadcast_offset(shape_,
                                     shape_,
                                     data_strides,
                                     mask.shape(),
                                     mask_strides,
                                     [&](int64_t, int64_t data_offset, int64_t mask_offset) {
                                         if (mask_data[mask_offset] != 0) {
                                             data_[static_cast<size_t>(data_offset)] = value;
                                         }
                                     });
    return *this;
}
/// Element-wise `tensor < scalar`; the scalar is converted to T first.
/// Returns a uint8 tensor holding 1 where the comparison is true, else 0.
template <typename T, typename Scalar, typename = std::enable_if_t<std::is_arithmetic<Scalar>::value>>
inline Tensor<uint8_t> operator<(const Tensor<T>& lhs, Scalar rhs) {
    Tensor<uint8_t> mask(lhs.shape());
    const T threshold = static_cast<T>(rhs);
    const int64_t count = lhs.numel();
    for (int64_t pos = 0; pos < count; ++pos) {
        const bool is_less = lhs[pos] < threshold;
        mask[pos] = is_less ? 1 : 0;
    }
    return mask;
}
/// Element-wise `scalar < tensor`; the scalar is converted to T first.
/// Returns a uint8 tensor holding 1 where the comparison is true, else 0.
template <typename T, typename Scalar, typename = std::enable_if_t<std::is_arithmetic<Scalar>::value>>
inline Tensor<uint8_t> operator<(Scalar lhs, const Tensor<T>& rhs) {
    Tensor<uint8_t> mask(rhs.shape());
    const T threshold = static_cast<T>(lhs);
    const int64_t count = rhs.numel();
    for (int64_t pos = 0; pos < count; ++pos) {
        const bool is_less = threshold < rhs[pos];
        mask[pos] = is_less ? 1 : 0;
    }
    return mask;
}
/// Element-wise `lhs < rhs` with broadcasting.
/// Returns a uint8 tensor of the broadcast shape holding 1 where true, else 0.
template <typename T>
inline Tensor<uint8_t> operator<(const Tensor<T>& lhs, const Tensor<T>& rhs) {
    const std::vector<int64_t> broadcast_shape = tensor_broadcast_shape(lhs.shape(), rhs.shape());
    Tensor<uint8_t> mask(broadcast_shape);
    const std::vector<int64_t> strides_a = tensor_compute_strides(lhs.shape());
    const std::vector<int64_t> strides_b = tensor_compute_strides(rhs.shape());
    const T* a = lhs.data();
    const T* b = rhs.data();
    const auto compare = [&mask, a, b](int64_t flat, int64_t offset_a, int64_t offset_b) {
        mask[flat] = a[offset_a] < b[offset_b] ? 1 : 0;
    };
    tensor_for_each_broadcast_offset(broadcast_shape, lhs.shape(), strides_a, rhs.shape(), strides_b, compare);
    return mask;
}
/// In-place element-wise addition, broadcasting `rhs` into `lhs`.
///
/// `lhs` keeps its shape, so `rhs` may have fewer axes and each of its extents
/// must be 1 or equal to the matching extent of `lhs`. Throws
/// std::invalid_argument otherwise.
template <typename T>
inline Tensor<T>& operator+=(Tensor<T>& lhs, const Tensor<T>& rhs) {
    if (lhs.shape() == rhs.shape()) {
        for (int64_t i = 0; i < lhs.numel(); ++i) {
            lhs[i] += rhs[i];
        }
        return lhs;
    }
    // Rejects shapes that are not broadcast-compatible at all.
    tensor_broadcast_shape(lhs.shape(), rhs.shape());
    // The iteration below walks lhs's shape, so rhs must fit into it; a larger
    // rhs would be indexed out of bounds or only partially applied.
    const std::vector<int64_t>& lhs_shape = lhs.shape();
    const std::vector<int64_t>& rhs_shape = rhs.shape();
    bool rhs_fits = rhs_shape.size() <= lhs_shape.size();
    for (size_t i = 0; rhs_fits && i < rhs_shape.size(); ++i) {
        rhs_fits = rhs_shape[i] == 1 || rhs_shape[i] == lhs_shape[i];
    }
    if (!rhs_fits) {
        tensor_throw_invalid_argument("Tensor operator+= cannot broadcast rhs into lhs: lhs_shape=" +
                                      tensor_shape_to_string(lhs_shape) + ", rhs_shape=" +
                                      tensor_shape_to_string(rhs_shape));
    }
    const std::vector<int64_t> lhs_strides = tensor_compute_strides(lhs.shape());
    const std::vector<int64_t> rhs_strides = tensor_compute_strides(rhs.shape());
    const T* rhs_data = rhs.data();
    tensor_for_each_broadcast_offset(lhs.shape(),
                                     lhs.shape(),
                                     lhs_strides,
                                     rhs.shape(),
                                     rhs_strides,
                                     [&](int64_t, int64_t lhs_offset, int64_t rhs_offset) {
                                         lhs[static_cast<int64_t>(lhs_offset)] += rhs_data[rhs_offset];
                                     });
    return lhs;
}
/// Adds a scalar (converted to T) to every element in place.
template <typename T, typename Scalar, typename = std::enable_if_t<std::is_arithmetic<Scalar>::value>>
inline Tensor<T>& operator+=(Tensor<T>& lhs, Scalar rhs) {
    const T addend = static_cast<T>(rhs);
    for (T& element : lhs.values()) {
        element += addend;
    }
    return lhs;
}
/// In-place element-wise subtraction, broadcasting `rhs` into `lhs`.
///
/// `lhs` keeps its shape, so `rhs` may have fewer axes and each of its extents
/// must be 1 or equal to the matching extent of `lhs`. Throws
/// std::invalid_argument otherwise.
template <typename T>
inline Tensor<T>& operator-=(Tensor<T>& lhs, const Tensor<T>& rhs) {
    if (lhs.shape() == rhs.shape()) {
        for (int64_t i = 0; i < lhs.numel(); ++i) {
            lhs[i] -= rhs[i];
        }
        return lhs;
    }
    // Rejects shapes that are not broadcast-compatible at all.
    tensor_broadcast_shape(lhs.shape(), rhs.shape());
    // The iteration below walks lhs's shape, so rhs must fit into it; a larger
    // rhs would be indexed out of bounds or only partially applied.
    const std::vector<int64_t>& lhs_shape = lhs.shape();
    const std::vector<int64_t>& rhs_shape = rhs.shape();
    bool rhs_fits = rhs_shape.size() <= lhs_shape.size();
    for (size_t i = 0; rhs_fits && i < rhs_shape.size(); ++i) {
        rhs_fits = rhs_shape[i] == 1 || rhs_shape[i] == lhs_shape[i];
    }
    if (!rhs_fits) {
        tensor_throw_invalid_argument("Tensor operator-= cannot broadcast rhs into lhs: lhs_shape=" +
                                      tensor_shape_to_string(lhs_shape) + ", rhs_shape=" +
                                      tensor_shape_to_string(rhs_shape));
    }
    const std::vector<int64_t> lhs_strides = tensor_compute_strides(lhs.shape());
    const std::vector<int64_t> rhs_strides = tensor_compute_strides(rhs.shape());
    const T* rhs_data = rhs.data();
    tensor_for_each_broadcast_offset(lhs.shape(),
                                     lhs.shape(),
                                     lhs_strides,
                                     rhs.shape(),
                                     rhs_strides,
                                     [&](int64_t, int64_t lhs_offset, int64_t rhs_offset) {
                                         lhs[static_cast<int64_t>(lhs_offset)] -= rhs_data[rhs_offset];
                                     });
    return lhs;
}
/// Subtracts a scalar (converted to T) from every element in place.
template <typename T, typename Scalar, typename = std::enable_if_t<std::is_arithmetic<Scalar>::value>>
inline Tensor<T>& operator-=(Tensor<T>& lhs, Scalar rhs) {
    const T subtrahend = static_cast<T>(rhs);
    for (T& element : lhs.values()) {
        element -= subtrahend;
    }
    return lhs;
}
/// In-place element-wise multiplication, broadcasting `rhs` into `lhs`.
///
/// `lhs` keeps its shape, so `rhs` may have fewer axes and each of its extents
/// must be 1 or equal to the matching extent of `lhs`. Throws
/// std::invalid_argument otherwise.
template <typename T>
inline Tensor<T>& operator*=(Tensor<T>& lhs, const Tensor<T>& rhs) {
    if (lhs.shape() == rhs.shape()) {
        for (int64_t i = 0; i < lhs.numel(); ++i) {
            lhs[i] *= rhs[i];
        }
        return lhs;
    }
    // Rejects shapes that are not broadcast-compatible at all.
    tensor_broadcast_shape(lhs.shape(), rhs.shape());
    // The iteration below walks lhs's shape, so rhs must fit into it; a larger
    // rhs would be indexed out of bounds or only partially applied.
    const std::vector<int64_t>& lhs_shape = lhs.shape();
    const std::vector<int64_t>& rhs_shape = rhs.shape();
    bool rhs_fits = rhs_shape.size() <= lhs_shape.size();
    for (size_t i = 0; rhs_fits && i < rhs_shape.size(); ++i) {
        rhs_fits = rhs_shape[i] == 1 || rhs_shape[i] == lhs_shape[i];
    }
    if (!rhs_fits) {
        tensor_throw_invalid_argument("Tensor operator*= cannot broadcast rhs into lhs: lhs_shape=" +
                                      tensor_shape_to_string(lhs_shape) + ", rhs_shape=" +
                                      tensor_shape_to_string(rhs_shape));
    }
    const std::vector<int64_t> lhs_strides = tensor_compute_strides(lhs.shape());
    const std::vector<int64_t> rhs_strides = tensor_compute_strides(rhs.shape());
    const T* rhs_data = rhs.data();
    tensor_for_each_broadcast_offset(lhs.shape(),
                                     lhs.shape(),
                                     lhs_strides,
                                     rhs.shape(),
                                     rhs_strides,
                                     [&](int64_t, int64_t lhs_offset, int64_t rhs_offset) {
                                         lhs[static_cast<int64_t>(lhs_offset)] *= rhs_data[rhs_offset];
                                     });
    return lhs;
}
/// Multiplies every element by a scalar (converted to T) in place.
template <typename T, typename Scalar, typename = std::enable_if_t<std::is_arithmetic<Scalar>::value>>
inline Tensor<T>& operator*=(Tensor<T>& lhs, Scalar rhs) {
    const T factor = static_cast<T>(rhs);
    for (T& element : lhs.values()) {
        element *= factor;
    }
    return lhs;
}
/// In-place element-wise division, broadcasting `rhs` into `lhs`.
///
/// `lhs` keeps its shape, so `rhs` may have fewer axes and each of its extents
/// must be 1 or equal to the matching extent of `lhs`. Throws
/// std::invalid_argument otherwise.
template <typename T>
inline Tensor<T>& operator/=(Tensor<T>& lhs, const Tensor<T>& rhs) {
    if (lhs.shape() == rhs.shape()) {
        for (int64_t i = 0; i < lhs.numel(); ++i) {
            lhs[i] /= rhs[i];
        }
        return lhs;
    }
    // Rejects shapes that are not broadcast-compatible at all.
    tensor_broadcast_shape(lhs.shape(), rhs.shape());
    // The iteration below walks lhs's shape, so rhs must fit into it; a larger
    // rhs would be indexed out of bounds or only partially applied.
    const std::vector<int64_t>& lhs_shape = lhs.shape();
    const std::vector<int64_t>& rhs_shape = rhs.shape();
    bool rhs_fits = rhs_shape.size() <= lhs_shape.size();
    for (size_t i = 0; rhs_fits && i < rhs_shape.size(); ++i) {
        rhs_fits = rhs_shape[i] == 1 || rhs_shape[i] == lhs_shape[i];
    }
    if (!rhs_fits) {
        tensor_throw_invalid_argument("Tensor operator/= cannot broadcast rhs into lhs: lhs_shape=" +
                                      tensor_shape_to_string(lhs_shape) + ", rhs_shape=" +
                                      tensor_shape_to_string(rhs_shape));
    }
    const std::vector<int64_t> lhs_strides = tensor_compute_strides(lhs.shape());
    const std::vector<int64_t> rhs_strides = tensor_compute_strides(rhs.shape());
    const T* rhs_data = rhs.data();
    tensor_for_each_broadcast_offset(lhs.shape(),
                                     lhs.shape(),
                                     lhs_strides,
                                     rhs.shape(),
                                     rhs_strides,
                                     [&](int64_t, int64_t lhs_offset, int64_t rhs_offset) {
                                         lhs[static_cast<int64_t>(lhs_offset)] /= rhs_data[rhs_offset];
                                     });
    return lhs;
}
/// Divides every element by a scalar (converted to T) in place.
template <typename T, typename Scalar, typename = std::enable_if_t<std::is_arithmetic<Scalar>::value>>
inline Tensor<T>& operator/=(Tensor<T>& lhs, Scalar rhs) {
    const T divisor = static_cast<T>(rhs);
    for (T& element : lhs.values()) {
        element /= divisor;
    }
    return lhs;
}
/// Element-wise sum with broadcasting; the result has the broadcast shape.
template <typename T>
inline Tensor<T> operator+(Tensor<T> lhs, const Tensor<T>& rhs) {
    // Identical shapes: accumulate straight into the by-value copy of lhs.
    if (lhs.shape() == rhs.shape()) {
        lhs += rhs;
        return lhs;
    }
    const std::vector<int64_t> broadcast_shape = tensor_broadcast_shape(lhs.shape(), rhs.shape());
    Tensor<T> sum(broadcast_shape);
    const std::vector<int64_t> strides_a = tensor_compute_strides(lhs.shape());
    const std::vector<int64_t> strides_b = tensor_compute_strides(rhs.shape());
    const T* a = lhs.data();
    const T* b = rhs.data();
    const auto combine = [&sum, a, b](int64_t flat, int64_t offset_a, int64_t offset_b) {
        sum[flat] = a[offset_a] + b[offset_b];
    };
    tensor_for_each_broadcast_offset(broadcast_shape, lhs.shape(), strides_a, rhs.shape(), strides_b, combine);
    return sum;
}
/// Returns `tensor + scalar`, applied element-wise.
template <typename T, typename Scalar, typename = std::enable_if_t<std::is_arithmetic<Scalar>::value>>
inline Tensor<T> operator+(Tensor<T> lhs, Scalar rhs) {
    Tensor<T> sum = std::move(lhs);
    sum += rhs;
    return sum;
}
/// Returns `scalar + tensor`, computed element-wise as `tensor += scalar`.
template <typename T, typename Scalar, typename = std::enable_if_t<std::is_arithmetic<Scalar>::value>>
inline Tensor<T> operator+(Scalar lhs, Tensor<T> rhs) {
    Tensor<T> sum = std::move(rhs);
    sum += lhs;
    return sum;
}
/// Element-wise difference with broadcasting; the result has the broadcast shape.
template <typename T>
inline Tensor<T> operator-(Tensor<T> lhs, const Tensor<T>& rhs) {
    // Identical shapes: subtract straight from the by-value copy of lhs.
    if (lhs.shape() == rhs.shape()) {
        lhs -= rhs;
        return lhs;
    }
    const std::vector<int64_t> broadcast_shape = tensor_broadcast_shape(lhs.shape(), rhs.shape());
    Tensor<T> difference(broadcast_shape);
    const std::vector<int64_t> strides_a = tensor_compute_strides(lhs.shape());
    const std::vector<int64_t> strides_b = tensor_compute_strides(rhs.shape());
    const T* a = lhs.data();
    const T* b = rhs.data();
    const auto combine = [&difference, a, b](int64_t flat, int64_t offset_a, int64_t offset_b) {
        difference[flat] = a[offset_a] - b[offset_b];
    };
    tensor_for_each_broadcast_offset(broadcast_shape, lhs.shape(), strides_a, rhs.shape(), strides_b, combine);
    return difference;
}
/// Returns `tensor - scalar`, applied element-wise.
template <typename T, typename Scalar, typename = std::enable_if_t<std::is_arithmetic<Scalar>::value>>
inline Tensor<T> operator-(Tensor<T> lhs, Scalar rhs) {
    Tensor<T> difference = std::move(lhs);
    difference -= rhs;
    return difference;
}
/// Returns `scalar - tensor`, applied element-wise; the scalar is converted to T.
template <typename T, typename Scalar, typename = std::enable_if_t<std::is_arithmetic<Scalar>::value>>
inline Tensor<T> operator-(Scalar lhs, const Tensor<T>& rhs) {
    const T minuend = static_cast<T>(lhs);
    Tensor<T> difference = rhs;
    for (T& element : difference.values()) {
        element = minuend - element;
    }
    return difference;
}
/// Element-wise product with broadcasting; the result has the broadcast shape.
template <typename T>
inline Tensor<T> operator*(Tensor<T> lhs, const Tensor<T>& rhs) {
    // Identical shapes: multiply straight into the by-value copy of lhs.
    if (lhs.shape() == rhs.shape()) {
        lhs *= rhs;
        return lhs;
    }
    const std::vector<int64_t> broadcast_shape = tensor_broadcast_shape(lhs.shape(), rhs.shape());
    Tensor<T> product(broadcast_shape);
    const std::vector<int64_t> strides_a = tensor_compute_strides(lhs.shape());
    const std::vector<int64_t> strides_b = tensor_compute_strides(rhs.shape());
    const T* a = lhs.data();
    const T* b = rhs.data();
    const auto combine = [&product, a, b](int64_t flat, int64_t offset_a, int64_t offset_b) {
        product[flat] = a[offset_a] * b[offset_b];
    };
    tensor_for_each_broadcast_offset(broadcast_shape, lhs.shape(), strides_a, rhs.shape(), strides_b, combine);
    return product;
}
/// Returns `tensor * scalar`, applied element-wise.
template <typename T, typename Scalar, typename = std::enable_if_t<std::is_arithmetic<Scalar>::value>>
inline Tensor<T> operator*(Tensor<T> lhs, Scalar rhs) {
    Tensor<T> product = std::move(lhs);
    product *= rhs;
    return product;
}
/// Returns `scalar * tensor`, computed element-wise as `tensor *= scalar`.
template <typename T, typename Scalar, typename = std::enable_if_t<std::is_arithmetic<Scalar>::value>>
inline Tensor<T> operator*(Scalar lhs, Tensor<T> rhs) {
    Tensor<T> product = std::move(rhs);
    product *= lhs;
    return product;
}
/// Element-wise quotient with broadcasting; the result has the broadcast shape.
template <typename T>
inline Tensor<T> operator/(Tensor<T> lhs, const Tensor<T>& rhs) {
    // Identical shapes: divide the by-value copy of lhs in place.
    if (lhs.shape() == rhs.shape()) {
        lhs /= rhs;
        return lhs;
    }
    const std::vector<int64_t> broadcast_shape = tensor_broadcast_shape(lhs.shape(), rhs.shape());
    Tensor<T> quotient(broadcast_shape);
    const std::vector<int64_t> strides_a = tensor_compute_strides(lhs.shape());
    const std::vector<int64_t> strides_b = tensor_compute_strides(rhs.shape());
    const T* a = lhs.data();
    const T* b = rhs.data();
    const auto combine = [&quotient, a, b](int64_t flat, int64_t offset_a, int64_t offset_b) {
        quotient[flat] = a[offset_a] / b[offset_b];
    };
    tensor_for_each_broadcast_offset(broadcast_shape, lhs.shape(), strides_a, rhs.shape(), strides_b, combine);
    return quotient;
}
/// Returns `tensor / scalar`, applied element-wise.
template <typename T, typename Scalar, typename = std::enable_if_t<std::is_arithmetic<Scalar>::value>>
inline Tensor<T> operator/(Tensor<T> lhs, Scalar rhs) {
    Tensor<T> quotient = std::move(lhs);
    quotient /= rhs;
    return quotient;
}
/// Returns `scalar / tensor`, applied element-wise; the scalar is converted to T.
template <typename T, typename Scalar, typename = std::enable_if_t<std::is_arithmetic<Scalar>::value>>
inline Tensor<T> operator/(Scalar lhs, const Tensor<T>& rhs) {
    const T dividend = static_cast<T>(lhs);
    Tensor<T> quotient = rhs;
    for (T& element : quotient.values()) {
        element = dividend / element;
    }
    return quotient;
}
/// Unary minus: returns a copy with every element negated.
template <typename T>
inline Tensor<T> operator-(const Tensor<T>& tensor) {
    Tensor<T> negated = tensor;
    for (T& element : negated.values()) {
        element = -element;
    }
    return negated;
}
/// Free-function shorthand for Tensor<T>::zeros(shape).
template <typename T>
inline Tensor<T> zeros(std::vector<int64_t> shape) {
    Tensor<T> result = Tensor<T>::zeros(std::move(shape));
    return result;
}
/// Free-function shorthand for Tensor<T>::full(shape, value).
template <typename T>
inline Tensor<T> full(std::vector<int64_t> shape, const T& value) {
    Tensor<T> result = Tensor<T>::full(std::move(shape), value);
    return result;
}
/// Free-function shorthand for Tensor<T>::randn(shape, rng).
template <typename T>
inline Tensor<T> randn(std::vector<int64_t> shape, const std::shared_ptr<RNG>& rng) {
    Tensor<T> result = Tensor<T>::randn(std::move(shape), rng);
    return result;
}
/// Draws a random tensor with the same shape as `tensor` via Tensor<T>::randn.
template <typename T>
inline Tensor<T> randn_like(const Tensor<T>& tensor, const std::shared_ptr<RNG>& rng) {
    const std::vector<int64_t>& like_shape = tensor.shape();
    return Tensor<T>::randn(like_shape, rng);
}
/// Returns a copy of the tensor's elements in storage order.
template <typename T>
inline std::vector<T> tensor_to_vector(const Tensor<T>& tensor) {
    const std::vector<T>& stored = tensor.values();
    std::vector<T> copy(stored);
    return copy;
}
namespace ops {
// Sampling modes accepted by the interpolate() overloads below; they reject
// any mode other than Nearest.
enum class InterpolateMode {
// Each output element copies the input element at the proportionally scaled
// coordinate (integer division, see interpolate()).
Nearest,
};
/// Maps a negative slice bound to `index + dim_size` (counting from the end);
/// non-negative bounds are returned unchanged. No range check is performed.
inline int64_t normalize_slice_bound(int64_t index, int64_t dim_size) {
    return index < 0 ? index + dim_size : index;
}
/// Validates a half-open slice [start, end) along axis `dim` of `input` and
/// returns the bounds with negative values shifted by the axis extent.
/// Throws std::invalid_argument when `dim` is not an axis of `input` or the
/// shifted bounds do not satisfy 0 <= start <= end <= extent.
template <typename T>
inline std::pair<int64_t, int64_t> resolve_slice_bounds(const Tensor<T>& input,
                                                        size_t dim,
                                                        int64_t start,
                                                        int64_t end) {
    const size_t rank = static_cast<size_t>(input.dim());
    if (dim >= rank) {
        tensor_throw_invalid_argument("Tensor slice dimension out of range: dim=" +
                                      std::to_string(dim) + ", rank=" +
                                      std::to_string(input.dim()) + ", input_shape=" +
                                      tensor_shape_to_string(input.shape()));
    }
    const int64_t extent = input.shape()[dim];
    // Negative bounds count from the end of the axis.
    if (start < 0) {
        start += extent;
    }
    if (end < 0) {
        end += extent;
    }
    const bool in_range = 0 <= start && start <= end && end <= extent;
    if (!in_range) {
        tensor_throw_invalid_argument("Tensor slice bounds out of range: dim=" +
                                      std::to_string(dim) + ", start=" +
                                      std::to_string(start) + ", end=" +
                                      std::to_string(end) + ", input_shape=" +
                                      tensor_shape_to_string(input.shape()));
    }
    return std::make_pair(start, end);
}
/// Element-wise exponential, evaluated in double and converted back to T.
template <typename T>
inline Tensor<T> exp(const Tensor<T>& input) {
    Tensor<T> result(input.shape());
    const int64_t count = input.numel();
    for (int64_t pos = 0; pos < count; ++pos) {
        const double x = static_cast<double>(input[pos]);
        result[pos] = static_cast<T>(std::exp(x));
    }
    return result;
}
/// Element-wise clamp of `input` to [min_value, max_value].
/// Throws std::invalid_argument when min_value > max_value.
template <typename T>
inline Tensor<T> clamp(const Tensor<T>& input, const T& min_value, const T& max_value) {
    const bool bounds_inverted = min_value > max_value;
    if (bounds_inverted) {
        tensor_throw_invalid_argument("Tensor clamp requires min_value <= max_value");
    }
    Tensor<T> result(input.shape());
    const int64_t count = input.numel();
    for (int64_t pos = 0; pos < count; ++pos) {
        result[pos] = std::clamp(input[pos], min_value, max_value);
    }
    return result;
}
/// Element-wise std::round, evaluated in double and converted back to T.
template <typename T>
inline Tensor<T> round(const Tensor<T>& input) {
    Tensor<T> result(input.shape());
    const int64_t count = input.numel();
    for (int64_t pos = 0; pos < count; ++pos) {
        const double x = static_cast<double>(input[pos]);
        result[pos] = static_cast<T>(std::round(x));
    }
    return result;
}
/// Copies the half-open range [start, end) of axis `dim` into a new tensor.
/// Bounds are validated by resolve_slice_bounds, which throws
/// std::invalid_argument for an invalid axis or range.
template <typename T>
inline Tensor<T> slice(const Tensor<T>& input,
                       size_t dim,
                       int64_t start,
                       int64_t end) {
    const std::pair<int64_t, int64_t> bounds = resolve_slice_bounds(input, dim, start, end);
    const int64_t first = bounds.first;
    const int64_t length = bounds.second - bounds.first;
    const std::vector<int64_t>& in_shape = input.shape();
    std::vector<int64_t> out_shape = in_shape;
    out_shape[dim] = length;
    Tensor<T> output(out_shape);
    if (output.numel() == 0) {
        return output;
    }
    const auto multiply = [](int64_t acc, int64_t extent) { return acc * extent; };
    const auto axis = in_shape.begin() + static_cast<std::ptrdiff_t>(dim);
    // Product of the extents below `dim`: the element stride of the sliced axis.
    const int64_t below = std::accumulate(in_shape.begin(), axis, int64_t{1}, multiply);
    // Product of the extents above `dim`: how many contiguous runs to copy.
    const int64_t above = std::accumulate(axis + 1, in_shape.end(), int64_t{1}, multiply);
    const int64_t run = length * below;
    const int64_t pitch = in_shape[dim] * below;
    for (int64_t block = 0; block < above; ++block) {
        const T* src = input.data() + (block * pitch + first * below);
        T* dst = output.data() + block * run;
        std::copy_n(src, run, dst);
    }
    return output;
}
/// Returns the `length` elements of axis `dim` beginning at `start`.
///
/// A negative `start` counts from the end of the axis. Throws
/// std::invalid_argument when `length` is negative, when `start` lies before
/// the beginning of the axis, or when slice() rejects the resulting range.
template <typename T>
inline Tensor<T> narrow(const Tensor<T>& input,
                        size_t dim,
                        int64_t start,
                        int64_t length) {
    if (length < 0) {
        tensor_throw_invalid_argument("Tensor narrow requires non-negative length: length=" +
                                      std::to_string(length) + ", input_shape=" +
                                      tensor_shape_to_string(input.shape()));
    }
    // Resolve a negative start before computing the end. Otherwise a request
    // that reaches the end of the axis (start + length == 0) hands slice() an
    // end of 0, which is not treated as "from the end" and is rejected.
    // An out-of-range dim is left for slice() to report.
    int64_t first = start;
    if (start < 0 && dim < static_cast<size_t>(input.dim())) {
        first = start + input.shape()[dim];
        if (first < 0) {
            tensor_throw_invalid_argument("Tensor narrow start out of range: dim=" +
                                          std::to_string(dim) + ", start=" +
                                          std::to_string(start) + ", input_shape=" +
                                          tensor_shape_to_string(input.shape()));
        }
    }
    return slice(input, dim, first, first + length);
}
/// Overwrites the half-open range [start, end) of axis `dim` in `*dst` with
/// the contents of `src`, whose shape must equal dst's shape with axis `dim`
/// replaced by `end - start`. Throws std::invalid_argument for a null `dst`,
/// invalid bounds, or a mismatching source shape.
template <typename T>
inline void slice_assign(Tensor<T>* dst,
                         size_t dim,
                         int64_t start,
                         int64_t end,
                         const Tensor<T>& src) {
    if (dst == nullptr) {
        tensor_throw_invalid_argument("Tensor slice_assign requires non-null dst");
    }
    const std::pair<int64_t, int64_t> bounds = resolve_slice_bounds(*dst, dim, start, end);
    const int64_t first = bounds.first;
    const int64_t length = bounds.second - bounds.first;
    if (src.dim() != dst->dim()) {
        tensor_throw_invalid_argument("Tensor slice_assign requires matching rank: dst_shape=" +
                                      tensor_shape_to_string(dst->shape()) + ", src_shape=" +
                                      tensor_shape_to_string(src.shape()));
    }
    const std::vector<int64_t>& target_shape = dst->shape();
    std::vector<int64_t> expected_shape = target_shape;
    expected_shape[dim] = length;
    if (src.shape() != expected_shape) {
        tensor_throw_invalid_argument("Tensor slice_assign requires matching source shape: dst_shape=" +
                                      tensor_shape_to_string(dst->shape()) + ", src_shape=" +
                                      tensor_shape_to_string(src.shape()) + ", expected_src_shape=" +
                                      tensor_shape_to_string(expected_shape));
    }
    if (src.numel() == 0) {
        return;
    }
    const size_t rank = static_cast<size_t>(dst->dim());
    // Product of the extents below `dim`: the element stride of the sliced axis.
    int64_t below = 1;
    for (size_t axis = 0; axis < dim; ++axis) {
        below *= target_shape[axis];
    }
    // Product of the extents above `dim`: how many contiguous runs to copy.
    int64_t above = 1;
    for (size_t axis = dim + 1; axis < rank; ++axis) {
        above *= target_shape[axis];
    }
    const int64_t run = length * below;
    const int64_t pitch = target_shape[dim] * below;
    for (int64_t block = 0; block < above; ++block) {
        const T* from = src.data() + block * run;
        T* to = dst->data() + (block * pitch + first * below);
        std::copy_n(from, run, to);
    }
}
/// Sets every element in the half-open range [start, end) of axis `dim` of
/// `*dst` to `value`. Throws std::invalid_argument for a null `dst` or
/// invalid bounds.
template <typename T>
inline void fill_slice(Tensor<T>* dst,
                       size_t dim,
                       int64_t start,
                       int64_t end,
                       const T& value) {
    if (dst == nullptr) {
        tensor_throw_invalid_argument("Tensor fill_slice requires non-null dst");
    }
    const std::pair<int64_t, int64_t> bounds = resolve_slice_bounds(*dst, dim, start, end);
    const int64_t first = bounds.first;
    const int64_t length = bounds.second - bounds.first;
    const std::vector<int64_t>& target_shape = dst->shape();
    const size_t rank = static_cast<size_t>(dst->dim());
    // Product of the extents below `dim`: the element stride of the sliced axis.
    int64_t below = 1;
    for (size_t axis = 0; axis < dim; ++axis) {
        below *= target_shape[axis];
    }
    // Product of the extents above `dim`: how many contiguous runs to fill.
    int64_t above = 1;
    for (size_t axis = dim + 1; axis < rank; ++axis) {
        above *= target_shape[axis];
    }
    const int64_t run = length * below;
    const int64_t pitch = target_shape[dim] * below;
    for (int64_t block = 0; block < above; ++block) {
        T* to = dst->data() + (block * pitch + first * below);
        std::fill_n(to, run, value);
    }
}
/// Nearest-neighbour resampling of `input` to `output_shape`.
///
/// Each output coordinate c along an axis reads input coordinate
/// c * input_extent / output_extent (integer division). Returns a copy of
/// `input` when the shapes already match. Throws std::invalid_argument for
/// any mode other than Nearest, for align_corners == true, for a rank
/// mismatch, or for non-positive extents.
template <typename T>
inline Tensor<T> interpolate(const Tensor<T>& input,
                             std::vector<int64_t> output_shape,
                             InterpolateMode mode = InterpolateMode::Nearest,
                             bool align_corners = false) {
    if (mode != InterpolateMode::Nearest) {
        tensor_throw_invalid_argument("Only nearest interpolate mode is implemented, got mode=" +
                                      std::to_string(static_cast<int>(mode)));
    }
    if (align_corners) {
        tensor_throw_invalid_argument("align_corners is not supported for nearest interpolate: input_shape=" +
                                      tensor_shape_to_string(input.shape()) + ", output_shape=" +
                                      tensor_shape_to_string(output_shape));
    }
    const std::vector<int64_t>& in_shape = input.shape();
    if (in_shape == output_shape) {
        return input;
    }
    const size_t rank = static_cast<size_t>(input.dim());
    if (rank != output_shape.size()) {
        tensor_throw_invalid_argument("Tensor interpolate requires matching rank: input_dim=" +
                                      std::to_string(input.dim()) + ", output_dim=" +
                                      std::to_string(output_shape.size()) + ", input_shape=" +
                                      tensor_shape_to_string(input.shape()) + ", output_shape=" +
                                      tensor_shape_to_string(output_shape));
    }
    for (size_t axis = 0; axis < output_shape.size(); ++axis) {
        if (output_shape[axis] <= 0) {
            tensor_throw_invalid_argument("Tensor interpolate output shape must be positive: input_shape=" +
                                          tensor_shape_to_string(input.shape()) + ", output_shape=" +
                                          tensor_shape_to_string(output_shape));
        }
        if (in_shape[axis] <= 0) {
            tensor_throw_invalid_argument("Tensor interpolate input shape must be positive: input_shape=" +
                                          tensor_shape_to_string(input.shape()) + ", output_shape=" +
                                          tensor_shape_to_string(output_shape));
        }
    }
    Tensor<T> output(std::move(output_shape));
    const std::vector<int64_t>& out_shape = output.shape();
    const int64_t total = output.numel();
    // Every slot is rewritten for each output element, so one buffer suffices.
    std::vector<int64_t> source_coord(rank, 0);
    for (int64_t flat = 0; flat < total; ++flat) {
        const std::vector<int64_t> target_coord = tensor_unravel_index(flat, out_shape);
        for (size_t axis = 0; axis < rank; ++axis) {
            source_coord[axis] = target_coord[axis] * in_shape[axis] / out_shape[axis];
        }
        output[flat] = input.index(source_coord);
    }
    return output;
}
/// Nearest-neighbour resampling driven by either a target `size` or a
/// per-axis `scale_factor`; exactly one of the two must be provided.
///
/// Entry i of `size` / `scale_factor` applies to axis i of `input`; axes
/// beyond the supplied entries keep their extent. Scaled extents are
/// floor(extent * factor). Throws std::invalid_argument for unsupported
/// mode/align_corners, when both or neither option is set, when the option is
/// empty or longer than the rank, or when a resulting extent is not positive.
template <typename T>
inline Tensor<T> interpolate(const Tensor<T>& input,
                             const std::optional<std::vector<int64_t>>& size,
                             const std::optional<std::vector<double>>& scale_factor,
                             InterpolateMode mode = InterpolateMode::Nearest,
                             bool align_corners = false) {
    if (mode != InterpolateMode::Nearest) {
        tensor_throw_invalid_argument("Only nearest interpolate mode is implemented, got mode=" +
                                      std::to_string(static_cast<int>(mode)));
    }
    if (align_corners) {
        tensor_throw_invalid_argument("align_corners is not supported for nearest interpolate: input_shape=" +
                                      tensor_shape_to_string(input.shape()));
    }
    const bool has_size = size.has_value();
    const bool has_scale = scale_factor.has_value();
    if (has_size == has_scale) {
        tensor_throw_invalid_argument("Tensor interpolate requires exactly one of size or scale_factor: input_shape=" +
                                      tensor_shape_to_string(input.shape()));
    }
    std::vector<int64_t> output_shape = input.shape();
    const size_t rank = output_shape.size();
    if (has_size) {
        const std::vector<int64_t>& requested = *size;
        if (requested.empty() || requested.size() > rank) {
            tensor_throw_invalid_argument("Tensor interpolate size must target low dimensions: input_shape=" +
                                          tensor_shape_to_string(input.shape()) + ", size_rank=" +
                                          std::to_string(requested.size()));
        }
        for (size_t axis = 0; axis < requested.size(); ++axis) {
            if (requested[axis] <= 0) {
                tensor_throw_invalid_argument("Tensor interpolate size must be positive: input_shape=" +
                                              tensor_shape_to_string(input.shape()) + ", size=" +
                                              tensor_shape_to_string(requested));
            }
            output_shape[axis] = requested[axis];
        }
    } else {
        const std::vector<double>& factors = *scale_factor;
        if (factors.empty() || factors.size() > rank) {
            tensor_throw_invalid_argument("Tensor interpolate scale_factor must target low dimensions: input_shape=" +
                                          tensor_shape_to_string(input.shape()) + ", scale_factor_rank=" +
                                          std::to_string(factors.size()));
        }
        for (size_t axis = 0; axis < factors.size(); ++axis) {
            if (factors[axis] <= 0.0) {
                tensor_throw_invalid_argument("Tensor interpolate scale_factor must be positive: input_shape=" +
                                              tensor_shape_to_string(input.shape()));
            }
            const double scaled = static_cast<double>(output_shape[axis]) * factors[axis];
            output_shape[axis] = static_cast<int64_t>(std::floor(scaled));
            if (output_shape[axis] <= 0) {
                tensor_throw_invalid_argument("Tensor interpolate output shape must be positive: input_shape=" +
                                              tensor_shape_to_string(input.shape()) + ", output_shape=" +
                                              tensor_shape_to_string(output_shape));
            }
        }
    }
    return interpolate(input, std::move(output_shape), mode, align_corners);
}
// Convenience overload of interpolate() that takes a single scalar scale
// factor and replicates it into a per-dimension scale vector before
// forwarding to the overload that accepts a vector of scale factors.
//
// The replicated vector has one entry per element of `size` when `size` holds
// a value, otherwise one entry per dimension of `input`.
//
// NOTE(review): the forwarded overload appears to require exactly one of
// size / scale_factor to be set. Since a scale vector is always supplied
// here, passing an engaged `size` presumably always ends in an
// invalid_argument error -- confirm that this is intended.
template <typename T>
inline Tensor<T> interpolate(const Tensor<T>& input,
                             const std::optional<std::vector<int64_t>>& size,
                             double scale_factor,
                             InterpolateMode mode = InterpolateMode::Nearest,
                             bool align_corners = false) {
    // Number of dimensions the scalar factor is replicated across.
    size_t scaled_rank = static_cast<size_t>(input.dim());
    if (size.has_value()) {
        scaled_rank = size->size();
    }

    std::vector<double> per_dim_scale(scaled_rank, scale_factor);
    return interpolate(input, size, std::move(per_dim_scale), mode, align_corners);
}
// Concatenates `lhs` and `rhs` along axis `dim` and returns the result as a
// new tensor.
//
// Both tensors must have the same rank, `dim` must be a valid axis index, and
// every axis other than `dim` must have the same extent in both tensors.
// Any violation is reported through tensor_throw_invalid_argument().
//
// The offset arithmetic treats the product of the extents below `dim` as one
// contiguous run, i.e. shape index 0 is the fastest-varying axis in memory.
template <typename T>
inline Tensor<T> concat(const Tensor<T>& lhs, const Tensor<T>& rhs, size_t dim) {
    const std::vector<int64_t>& lhs_shape = lhs.shape();
    const std::vector<int64_t>& rhs_shape = rhs.shape();
    const size_t rank = static_cast<size_t>(lhs.dim());

    if (lhs.dim() != rhs.dim()) {
        std::string message = "Tensor concat requires same rank: lhs_dim=";
        message += std::to_string(lhs.dim());
        message += ", rhs_dim=";
        message += std::to_string(rhs.dim());
        message += ", lhs_shape=";
        message += tensor_shape_to_string(lhs_shape);
        message += ", rhs_shape=";
        message += tensor_shape_to_string(rhs_shape);
        tensor_throw_invalid_argument(message);
    }
    if (dim >= rank) {
        std::string message = "Tensor concat dimension out of range: dim=";
        message += std::to_string(dim);
        message += ", rank=";
        message += std::to_string(lhs.dim());
        message += ", lhs_shape=";
        message += tensor_shape_to_string(lhs_shape);
        tensor_throw_invalid_argument(message);
    }
    for (size_t axis = 0; axis < rank; ++axis) {
        if (axis != dim && lhs_shape[axis] != rhs_shape[axis]) {
            std::string message = "Tensor concat requires matching non-concat dimensions: dim=";
            message += std::to_string(dim);
            message += ", lhs_shape=";
            message += tensor_shape_to_string(lhs_shape);
            message += ", rhs_shape=";
            message += tensor_shape_to_string(rhs_shape);
            tensor_throw_invalid_argument(message);
        }
    }

    // The result keeps lhs's extents except along `dim`, where both add up.
    std::vector<int64_t> joined_shape = lhs_shape;
    joined_shape[dim] += rhs_shape[dim];
    Tensor<T> out(joined_shape);

    // `inner`: elements covered by one step along `dim`.
    // `outer`: number of slabs formed by the axes above `dim`.
    int64_t inner = 1;
    int64_t outer = 1;
    for (size_t axis = 0; axis < rank; ++axis) {
        if (axis < dim) {
            inner *= lhs_shape[axis];
        } else if (axis > dim) {
            outer *= lhs_shape[axis];
        }
    }

    const int64_t lhs_block = lhs_shape[dim] * inner;
    const int64_t rhs_block = rhs_shape[dim] * inner;

    // For each slab, lay down the lhs block followed by the rhs block. The
    // output cursor advances by exactly lhs_block + rhs_block per slab.
    auto lhs_src = lhs.data();
    auto rhs_src = rhs.data();
    auto dst = out.data();
    for (int64_t slab = 0; slab < outer; ++slab) {
        dst = std::copy_n(lhs_src, lhs_block, dst);
        dst = std::copy_n(rhs_src, rhs_block, dst);
        lhs_src += lhs_block;
        rhs_src += rhs_block;
    }
    return out;
}
// Splits `tensor` into `chunks` equally sized pieces along axis `dim` and
// returns them in order of increasing position along that axis.
//
// Requirements, reported through tensor_throw_invalid_argument() when
// violated: `chunks` must be positive, `dim` must be a valid axis index, and
// the extent of `dim` must be divisible by `chunks`. When the extent of `dim`
// is zero an empty vector is returned (before the divisibility check).
//
// The offset arithmetic treats the product of the extents below `dim` as one
// contiguous run, i.e. shape index 0 is the fastest-varying axis in memory.
template <typename T>
inline std::vector<Tensor<T>> chunk(const Tensor<T>& tensor, int64_t chunks, size_t dim) {
    const std::vector<int64_t>& shape = tensor.shape();
    const size_t rank = static_cast<size_t>(tensor.dim());

    if (chunks <= 0) {
        std::string message = "Tensor chunk requires chunks > 0: chunks=";
        message += std::to_string(chunks);
        message += ", tensor_shape=";
        message += tensor_shape_to_string(shape);
        tensor_throw_invalid_argument(message);
    }
    if (dim >= rank) {
        std::string message = "Tensor chunk dimension out of range: dim=";
        message += std::to_string(dim);
        message += ", rank=";
        message += std::to_string(tensor.dim());
        message += ", tensor_shape=";
        message += tensor_shape_to_string(shape);
        tensor_throw_invalid_argument(message);
    }

    const int64_t dim_size = shape[dim];
    if (dim_size == 0) {
        return {};
    }
    if (dim_size % chunks != 0) {
        std::string message = "Tensor chunk requires the dimension size to be divisible by chunks: dim=";
        message += std::to_string(dim);
        message += ", dim_size=";
        message += std::to_string(dim_size);
        message += ", chunks=";
        message += std::to_string(chunks);
        message += ", tensor_shape=";
        message += tensor_shape_to_string(shape);
        tensor_throw_invalid_argument(message);
    }
    // dim_size > 0 and divisible by chunks, so every piece has extent >= 1.
    const int64_t chunk_size = dim_size / chunks;

    // `inner`: elements covered by one step along `dim`.
    // `outer`: number of slabs formed by the axes above `dim`.
    int64_t inner = 1;
    int64_t outer = 1;
    for (size_t axis = 0; axis < rank; ++axis) {
        if (axis < dim) {
            inner *= shape[axis];
        } else if (axis > dim) {
            outer *= shape[axis];
        }
    }

    // Every piece shares the same shape and the same per-slab block length.
    std::vector<int64_t> part_shape = shape;
    part_shape[dim] = chunk_size;
    const int64_t block = chunk_size * inner;

    std::vector<Tensor<T>> parts;
    parts.reserve(static_cast<size_t>(chunks));
    for (int64_t part_index = 0; part_index < chunks; ++part_index) {
        const int64_t first = part_index * chunk_size;
        Tensor<T> piece(part_shape);
        auto dst = piece.data();
        for (int64_t slab = 0; slab < outer; ++slab) {
            // Source offset is computed by index so no pointer is ever formed
            // beyond the end of the source buffer.
            dst = std::copy_n(tensor.data() + (slab * dim_size + first) * inner, block, dst);
        }
        parts.push_back(std::move(piece));
    }
    return parts;
}
} // namespace ops
} // namespace sd
#endif