#ifndef __RNG_MT19937_HPP__ #define __RNG_MT19937_HPP__ #include #include #include "rng.hpp" // RNG imitiating torch cpu randn on CPU. // Port from pytorch, original license: https://github.com/pytorch/pytorch/blob/d01a7b0241ed1c4cded7e7ca097249feb343f072/LICENSE // Ref: https://github.com/pytorch/pytorch/blob/d01a7b0241ed1c4cded7e7ca097249feb343f072/aten/src/ATen/core/TransformationHelper.h, for uniform_real // Ref: https://github.com/pytorch/pytorch/blob/d01a7b0241ed1c4cded7e7ca097249feb343f072/aten/src/ATen/native/cpu/DistributionTemplates.h, for normal_kernel/normal_fill/normal_fill_16 // Ref: https://github.com/pytorch/pytorch/blob/d01a7b0241ed1c4cded7e7ca097249feb343f072/aten/src/ATen/core/MT19937RNGEngine.h, for mt19937_engine // Ref: https://github.com/pytorch/pytorch/blob/d01a7b0241ed1c4cded7e7ca097249feb343f072/aten/src/ATen/core/DistributionsHelper.h, for uniform_real_distribution/normal_distribution class MT19937RNG : public RNG { static const int N = 624; static const int M = 397; static const uint32_t MATRIX_A = 0x9908b0dfU; static const uint32_t UMASK = 0x80000000U; static const uint32_t LMASK = 0x7fffffffU; struct State { uint64_t seed_; int left_; bool seeded_; uint32_t next_; std::array state_; bool has_next_gauss = false; double next_gauss = 0.0f; }; State s; uint32_t mix_bits(uint32_t u, uint32_t v) { return (u & UMASK) | (v & LMASK); } uint32_t twist(uint32_t u, uint32_t v) { return (mix_bits(u, v) >> 1) ^ ((v & 1) ? MATRIX_A : 0); } void next_state() { uint32_t* p = s.state_.data(); s.left_ = N; s.next_ = 0; for (int j = N - M + 1; --j; p++) p[0] = p[M] ^ twist(p[0], p[1]); for (int j = M; --j; p++) p[0] = p[M - N] ^ twist(p[0], p[1]); p[0] = p[M - N] ^ twist(p[0], s.state_[0]); } uint32_t rand_uint32() { if (--s.left_ == 0) next_state(); uint32_t y = s.state_[s.next_++]; y ^= (y >> 11); y ^= (y << 7) & 0x9d2c5680U; y ^= (y << 15) & 0xefc60000U; y ^= (y >> 18); return y; } uint64_t rand_uint64() { uint64_t high = (uint64_t)rand_uint32(); uint64_t low = (uint64_t)rand_uint32(); return (high << 32) | low; } template T uniform_real(V val, T from, T to) { constexpr auto MASK = static_cast((static_cast(1) << std::numeric_limits::digits) - 1); constexpr auto DIVISOR = static_cast(1) / (static_cast(1) << std::numeric_limits::digits); T x = (val & MASK) * DIVISOR; return (x * (to - from) + from); } double normal_double_value(double mean, double std) { if (s.has_next_gauss) { s.has_next_gauss = false; return s.next_gauss; } double u1 = uniform_real(rand_uint64(), 0., 1.); // double double u2 = uniform_real(rand_uint64(), 0., 1.); // double double r = std::sqrt(-2.0 * std::log1p(-u2)); double theta = 2.0 * 3.14159265358979323846 * u1; double value = r * std::cos(theta) * std + mean; s.next_gauss = r * std::sin(theta) * std + mean; s.has_next_gauss = true; return value; } void normal_fill_16(float* data, float mean, float std) { for (int j = 0; j < 8; ++j) { float u1 = 1.0f - data[j]; float u2 = data[j + 8]; float r = std::sqrt(-2.0f * std::log(u1)); float theta = 2.0f * 3.14159265358979323846 * u2; data[j] = r * std::cos(theta) * std + mean; data[j + 8] = r * std::sin(theta) * std + mean; } } void randn(float* data, int64_t size, float mean = 0.0f, float std = 1.0f) { if (size >= 16) { for (int64_t i = 0; i < size; i++) { data[i] = uniform_real(rand_uint32(), 0.f, 1.f); } for (int64_t i = 0; i < size - 15; i += 16) { normal_fill_16(data + i, mean, std); } if (size % 16 != 0) { // Recompute the last 16 values. data = data + size - 16; for (int64_t i = 0; i < 16; i++) { data[i] = uniform_real(rand_uint32(), 0.f, 1.f); } normal_fill_16(data, mean, std); } } else { // Strange handling, hard to understand, but keeping it consistent with PyTorch. for (int64_t i = 0; i < size; i++) { data[i] = (float)normal_double_value(mean, std); } } } public: MT19937RNG(uint64_t seed = 0) { manual_seed(seed); } void manual_seed(uint64_t seed) override { s.seed_ = seed; s.seeded_ = true; s.state_[0] = (uint32_t)(seed & 0xffffffffU); for (int j = 1; j < N; j++) { uint32_t prev = s.state_[j - 1]; s.state_[j] = 1812433253U * (prev ^ (prev >> 30)) + j; } s.left_ = 1; s.next_ = 0; s.has_next_gauss = false; } std::vector randn(uint32_t n) override { std::vector out; out.resize(n); randn((float*)out.data(), out.size()); return out; } }; #endif // __RNG_MT19937_HPP__