Skip to main content

svod_tensor/rand/
distributions.rs

1//! Distribution wrappers around `Tensor::rand`.
2//!
3//! | Method | Formula |
4//! |---|---|
5//! | `uniform(shape, low, high)` | `(high - low) * rand + low` |
6//! | `randn(shape)` | Box-Muller: `cos(2π·u₁) · √(-2·ln(1 - u₂))` |
7//! | `normal(shape, mean, std)` | `std · randn + mean` |
8//! | `randint(shape, low, high)` | `((high-low)·rand).cast(int32) + low` |
9//! | `scaled_uniform(shape)` | `uniform(-1, 1) · prod(shape)^-½` |
10//! | `glorot_uniform(shape)` | `uniform(-b, b)`, `b = √(6 / (shape[0] + prod(shape[1..])))` |
11//! | `kaiming_uniform(shape, a)` | `uniform(-b, b)`, `b = √(6 / ((1+a²) · prod(shape[1..])))` |
12//! | `kaiming_normal(shape, a)` | `randn · √(2 / ((1+a²) · prod(shape[1..])))` |
13
14use svod_dtype::DType;
15use svod_ir::ConstValue;
16
17use crate::{Error, Result, Tensor};
18
/// Full turn in radians (2π); the angular scale factor used by the
/// Box-Muller transform in `Tensor::randn`.
const TWO_PI: f64 = std::f64::consts::TAU;
20
/// Returns `prod(shape[1..])`, floored at 1.
///
/// An empty or 1-D `shape` has no trailing dimensions and yields 1. The
/// final `max(1)` also lifts a zero product (some trailing dimension is 0)
/// to 1, so the result is never 0.
fn fan_in(shape: &[usize]) -> usize {
    let trailing_product: usize = match shape.split_first() {
        // Skip the leading dimension; multiply everything after it.
        Some((_, rest)) => rest.iter().product(),
        None => 1,
    };
    trailing_product.max(1)
}
26
27impl Tensor {
28    /// Uniform `[low, high)` random tensor, float32, on the default (CPU) device.
29    ///
30    /// Convenience wrapper around [`Tensor::uniform_with_dtype`] with f32 output.
31    pub fn uniform(shape: &[usize], low: f64, high: f64) -> Result<Tensor> {
32        Self::uniform_with_dtype(shape, low, high, DType::Float32)
33    }
34
35    /// Uniform `[low, high)` random tensor with explicit float dtype.
36    ///
37    /// Generates a `[0, 1)` sample at f32, scales by `(high - low)`, **casts
38    /// to the target dtype**, then adds `low`. Casting before the offset
39    /// keeps the addition honest in low-precision targets (f16/bf16) where
40    /// `low` might otherwise be lost to rounding if applied at f32.
41    pub fn uniform_with_dtype(shape: &[usize], low: f64, high: f64, dtype: DType) -> Result<Tensor> {
42        if low >= high {
43            return Err(Error::ParamRange {
44                op: "Tensor::uniform",
45                param: "low/high",
46                value: format!("low={low}, high={high}"),
47                constraint: "low < high",
48            });
49        }
50        let u = Tensor::rand(shape)?;
51        let scale = u.broadcast_scalar(ConstValue::Float(high - low))?;
52        let scaled = u.try_mul(&scale)?.cast(dtype)?;
53        let offset = scaled.broadcast_scalar(ConstValue::Float(low))?;
54        scaled.try_add(&offset)
55    }
56
57    /// Standard normal `N(0, 1)` random tensor (float32, Box-Muller).
58    ///
59    /// Each output element draws from two `[0, 1)` uniforms via one combined
60    /// `rand([2, *shape])` call, so the RNG counter advances exactly once per
61    /// `randn` invocation regardless of `shape`.
62    pub fn randn(shape: &[usize]) -> Result<Tensor> {
63        // src = rand([2, *shape])  →  one counter advance for two halves.
64        let mut combined_shape: Vec<usize> = Vec::with_capacity(shape.len() + 1);
65        combined_shape.push(2);
66        combined_shape.extend_from_slice(shape);
67        let src = Tensor::rand(&combined_shape)?;
68
69        // u1 = src[0:1, …] reshaped to shape; u2 = src[1:2, …] reshaped to shape.
70        let mut shrink_u1: Vec<Option<(isize, isize)>> = Vec::with_capacity(combined_shape.len());
71        shrink_u1.push(Some((0, 1)));
72        shrink_u1.extend(std::iter::repeat_n(None, shape.len()));
73        let mut shrink_u2: Vec<Option<(isize, isize)>> = Vec::with_capacity(combined_shape.len());
74        shrink_u2.push(Some((1, 2)));
75        shrink_u2.extend(std::iter::repeat_n(None, shape.len()));
76        let target_shape: Vec<isize> = shape.iter().map(|&d| d as isize).collect();
77        let u1 = src.try_shrink(shrink_u1)?.try_reshape(&target_shape)?;
78        let u2 = src.try_shrink(shrink_u2)?.try_reshape(&target_shape)?;
79
80        // Box-Muller: cos(2π·u1) · √(-2·ln(1 - u2))
81        let two_pi = u1.broadcast_scalar(ConstValue::Float(TWO_PI))?;
82        let theta = u1.try_mul(&two_pi)?.cos()?;
83        let one = u2.broadcast_scalar(ConstValue::Float(1.0))?;
84        let neg_two = u2.broadcast_scalar(ConstValue::Float(-2.0))?;
85        let r = one.try_sub(&u2)?.try_log()?.try_mul(&neg_two)?.try_sqrt()?;
86        theta.try_mul(&r)
87    }
88
89    /// Normal `N(mean, std)` random tensor. Requires `std >= 0`.
90    pub fn normal(shape: &[usize], mean: f64, std: f64) -> Result<Tensor> {
91        if std < 0.0 {
92            return Err(Error::ParamRange {
93                op: "Tensor::normal",
94                param: "std",
95                value: format!("{std}"),
96                constraint: ">= 0",
97            });
98        }
99        let z = Tensor::randn(shape)?;
100        let std_t = z.broadcast_scalar(ConstValue::Float(std))?;
101        let mean_t = z.broadcast_scalar(ConstValue::Float(mean))?;
102        z.try_mul(&std_t)?.try_add(&mean_t)
103    }
104
105    /// Uniform integer tensor `[low, high)`, dtype `int32`. Requires `low < high`.
106    ///
107    /// Truncates `(high - low) · rand` to int32 **before** adding `low`.
108    /// Casting after the add would truncate-toward-zero asymmetrically for
109    /// negative `low` (e.g. `low=-3, rand≈0.005` would yield `-2` instead of
110    /// the correct `-3`).
111    pub fn randint(shape: &[usize], low: i64, high: i64) -> Result<Tensor> {
112        if low >= high {
113            return Err(Error::ParamRange {
114                op: "Tensor::randint",
115                param: "low/high",
116                value: format!("low={low}, high={high}"),
117                constraint: "low < high",
118            });
119        }
120        let scaled = Tensor::rand(shape)?;
121        let range = scaled.broadcast_scalar(ConstValue::Float((high - low) as f64))?;
122        let truncated = scaled.try_mul(&range)?.cast(DType::Int32)?;
123        let offset = truncated.broadcast_scalar(ConstValue::Int(low))?;
124        truncated.try_add(&offset)
125    }
126
127    /// `uniform(-1, 1) · prod(shape)^(-½)`. Same dtype contract as `uniform`.
128    pub fn scaled_uniform(shape: &[usize]) -> Result<Tensor> {
129        let numel: usize = shape.iter().copied().product::<usize>().max(1);
130        let scale = (numel as f64).powf(-0.5);
131        let u = Tensor::uniform(shape, -1.0, 1.0)?;
132        let scale_t = u.broadcast_scalar(ConstValue::Float(scale))?;
133        u.try_mul(&scale_t)
134    }
135
136    /// Glorot/Xavier uniform initializer, float32 output.
137    pub fn glorot_uniform(shape: &[usize]) -> Result<Tensor> {
138        Self::glorot_uniform_with_dtype(shape, DType::Float32)
139    }
140
141    /// Glorot/Xavier uniform initializer with explicit dtype.
142    /// `bound = √(6 / (shape[0] + prod(shape[1..]))); uniform(-bound, bound)`.
143    pub fn glorot_uniform_with_dtype(shape: &[usize], dtype: DType) -> Result<Tensor> {
144        if shape.is_empty() {
145            return Err(Error::ParamRange {
146                op: "Tensor::glorot_uniform",
147                param: "shape",
148                value: "[]".to_string(),
149                constraint: "at least 1D",
150            });
151        }
152        let fan_in_v = fan_in(shape);
153        let fan_out_v = shape[0];
154        let bound = (6.0 / (fan_out_v + fan_in_v) as f64).sqrt();
155        Self::uniform_with_dtype(shape, -bound, bound, dtype)
156    }
157
158    /// Kaiming/He uniform initializer for ReLU-family activations, float32 output.
159    pub fn kaiming_uniform(shape: &[usize], a: f64) -> Result<Tensor> {
160        Self::kaiming_uniform_with_dtype(shape, a, DType::Float32)
161    }
162
163    /// Kaiming/He uniform initializer with explicit dtype.
164    ///
165    /// `bound = √(6 / ((1 + a²) · prod(shape[1..]))); uniform(-bound, bound)`.
166    ///
167    /// `a` is the negative slope of the activation:
168    /// - `0.0` — plain ReLU (PyTorch default).
169    /// - `0.01` — leaky-ReLU with default slope.
170    pub fn kaiming_uniform_with_dtype(shape: &[usize], a: f64, dtype: DType) -> Result<Tensor> {
171        if shape.is_empty() {
172            return Err(Error::ParamRange {
173                op: "Tensor::kaiming_uniform",
174                param: "shape",
175                value: "[]".to_string(),
176                constraint: "at least 1D",
177            });
178        }
179        let bound = (6.0 / ((1.0 + a * a) * fan_in(shape) as f64)).sqrt();
180        Self::uniform_with_dtype(shape, -bound, bound, dtype)
181    }
182
183    /// Kaiming/He normal initializer for ReLU-family activations.
184    /// `std = √(2 / ((1 + a²) · prod(shape[1..]))); randn · std`.
185    pub fn kaiming_normal(shape: &[usize], a: f64) -> Result<Tensor> {
186        if shape.is_empty() {
187            return Err(Error::ParamRange {
188                op: "Tensor::kaiming_normal",
189                param: "shape",
190                value: "[]".to_string(),
191                constraint: "at least 1D",
192            });
193        }
194        let std = (2.0 / ((1.0 + a * a) * fan_in(shape) as f64)).sqrt();
195        let z = Tensor::randn(shape)?;
196        let std_t = z.broadcast_scalar(ConstValue::Float(std))?;
197        z.try_mul(&std_t)
198    }
199}