Skip to main content

svod_tensor/rand/
primitive.rs

//! `Tensor::rand` — uniform `[0, 1)` float draws backed by `BinaryOp::Threefry`.
//!
//! Pipeline:
//!
//! 1. `state::next_counter(device, num)` returns `(seed_buf, counter_val)` —
//!    counter advances atomically per call. `counter_val` is stamped into the
//!    graph as a CONST u64; the seed remains a BUFFER, which is what prevents
//!    the rand output from const-folding.
//! 2. Derive a per-call `new_key` via one THREEFRY pass over `(c_low, c_high)`
//!    against the seed.
//! 3. Build `counts0 = arange(ceil(num / 2))` and
//!    `counts1 = counts0 + ceil(num / 2)`, run the bulk THREEFRY pass, then
//!    truncate its output to `num` uint32 random words.
//! 4. `bits_to_rand`: shift-right by `(bitsize - mantissa_bits)`, OR with the
//!    bit pattern of `1.0`, bitcast back to the target float dtype, then
//!    subtract 1.0 to land in `[0, 1)`.
//!
//! Supports `Float16`, `BFloat16`, `Float32`, `Float64` on `DeviceSpec::Cpu`.
//! Multi-device support is a straightforward extension of the same pipeline.
19
20use snafu::ResultExt;
21use svod_dtype::{DType, DeviceSpec, ScalarDType};
22use svod_ir::{ConstValue, UOp, shape::Shape, shape::to_vec_usize};
23
24use crate::{Error, Result, Tensor, UOpSnafu};
25
26use super::state;
27
28impl Tensor {
29    /// Uniform `[0, 1)` random tensor with float32 dtype on the default CPU device.
30    ///
31    /// THREEFRY-backed; deterministic for a fixed seed (set via
32    /// [`crate::rand::manual_seed`]).
33    pub fn rand(shape: &[usize]) -> Result<Tensor> {
34        Self::rand_with(shape, DType::Float32, DeviceSpec::Cpu)
35    }
36
37    /// Variant of [`Tensor::rand`] with explicit dtype and device.
38    ///
39    /// Supported dtypes: `Float16`, `BFloat16`, `Float32`, `Float64`. Integer
40    /// dtypes are not supported here — use `Tensor::randint` instead.
41    pub fn rand_with(shape: &[usize], dtype: DType, device: DeviceSpec) -> Result<Tensor> {
42        let scalar = dtype.scalar().ok_or_else(|| Error::SymbolicShapeUnsupported {
43            operation: format!("Tensor::rand: non-scalar dtype {dtype:?}"),
44        })?;
45        if !scalar.is_float() {
46            return Err(Error::SymbolicShapeUnsupported {
47                operation: format!(
48                    "Tensor::rand: float dtype required, got {scalar:?}; use Tensor::randint for integers"
49                ),
50            });
51        }
52        let numel: usize = shape.iter().product();
53        if numel == 0 {
54            return Tensor::zeros(shape, dtype);
55        }
56        // Number of uint32 words needed to cover `numel * itemsize` bytes.
57        let num_words = (numel * scalar.bytes()).div_ceil(4) as u64;
58        let (seed, counter_val) = state::next_counter(&device, num_words);
59        let bits = random_bits(&seed, counter_val, num_words as usize)?;
60        bits_to_rand(&bits, shape, dtype)
61    }
62}
63
64/// Produce `num` uint32 random words by stamping `counter_val` as a CONST u64
65/// into the THREEFRY graph. Single-chunk (no outer loop): `num` is bounded by
66/// `usize`, which is more than enough for any realistic tensor shape.
67fn random_bits(seed: &Tensor, counter_val: u64, num: usize) -> Result<Tensor> {
68    let u32_dt = DType::Scalar(ScalarDType::UInt32);
69
70    // c_low, c_high as `[1]` u32 Tensors. CONST is fine — the per-call key
71    // derivation only depends on `seed` (BUFFER) for non-foldability.
72    let c_low = Tensor::full(&[1], (counter_val & 0xFFFF_FFFF) as u32, u32_dt.clone())?;
73    let c_high = Tensor::full(&[1], (counter_val >> 32) as u32, u32_dt.clone())?;
74
75    // Step 1: per-call key derivation. THREEFRY(seed, [c_low, c_high]) → `[2]` u32.
76    let new_key = threefry_random_bits(seed, &c_low, &c_high)?;
77
78    // Step 2: build counts0 = arange(half), counts1 = counts0 + half.
79    // `arange_with_dtype` only fast-paths on `ConstValue::Int`, so go through
80    // the i64 entry point and cast to u32.
81    let half = num.div_ceil(2);
82    let counts0 = Tensor::arange(0, Some(half as i64), None)?.cast(u32_dt.clone())?;
83    let half_t = Tensor::full(&[half], half as u32, u32_dt)?;
84    let counts1 = counts0.try_add(&half_t)?;
85
86    // Step 3: bulk THREEFRY pass. Returns `[2 * half]` u32; truncate to `num`.
87    let bits_full = threefry_random_bits(&new_key, &counts0, &counts1)?;
88    bits_full.try_shrink([(0usize, num)])
89}
90
91/// Pack `(counts1, counts0)` into u64, run THREEFRY against a broadcast u64
92/// key, then split the result back into `(low_u32, high_u32)` and concat. The
93/// result shape is `[2 * counts0.len()]` u32.
94///
95/// Pub(crate) so the test module can pin its output against JAX's
96/// `jax.extend.random.threefry_2x32` reference values.
97pub(crate) fn threefry_random_bits(key: &Tensor, counts0: &Tensor, counts1: &Tensor) -> Result<Tensor> {
98    let u32_dt = DType::Scalar(ScalarDType::UInt32);
99    let u64_dt = DType::Scalar(ScalarDType::UInt64);
100    let counts_shape: Shape = counts0.shape()?;
101
102    let shift_32 = Tensor::full(&to_vec_usize(&counts_shape).context(UOpSnafu)?, 32u32, u64_dt.clone())?;
103
104    // x = (counts1 << 32) | counts0  (u64)
105    let c0_u64 = counts0.cast(u64_dt.clone())?;
106    let c1_u64 = counts1.cast(u64_dt.clone())?;
107    let c1_shifted = c1_u64.try_shl(&shift_32)?;
108    let x = c1_shifted.try_bitor(&c0_u64)?;
109
110    // key_packed = (key[1] << 32) | key[0], then broadcast from [1] to counts_shape.
111    let k_shift_32 = Tensor::full(&[1], 32u32, u64_dt.clone())?;
112    let k0 = key.try_shrink([(0usize, 1usize)])?.cast(u64_dt.clone())?;
113    let k1 = key.try_shrink([(1usize, 2usize)])?.cast(u64_dt.clone())?;
114    let key_packed = k1.try_shl(&k_shift_32)?.try_bitor(&k0)?;
115    let key_broadcast = key_packed.broadcast_to(&counts_shape)?;
116
117    // THREEFRY at the UOp level (no Tensor-level wrapper for it yet — the
118    // op is otherwise only used inside the rangeify decomp).
119    let result_uop = UOp::threefry(x.uop().clone(), key_broadcast.uop().clone()).context(UOpSnafu)?;
120    let result = Tensor::from_lazy(result_uop);
121
122    // Split each u64 result into two u32s, concat: `[low_0, …, low_{N-1}, high_0, …, high_{N-1}]`.
123    let mask_u64 = Tensor::full(&to_vec_usize(&counts_shape).context(UOpSnafu)?, 0xFFFF_FFFFu64, u64_dt)?;
124    let lo_u64 = result.try_bitand(&mask_u64)?;
125    let lo = lo_u64.cast(u32_dt.clone())?;
126    let hi_u64 = result.try_shr(&shift_32)?.try_bitand(&mask_u64)?;
127    let hi = hi_u64.cast(u32_dt)?;
128
129    Tensor::cat(&[&lo, &hi], 0)
130}
131
132/// Convert raw u32 random bits → float in `[0, 1)` via mantissa-fill.
133///
134/// `bits` has shape `[num_u32_words]`. Output has shape `shape` and dtype `dtype`.
135///
136/// For f32 this is straightforward (u32 → u32 = identity bitcast → shift+OR →
137/// bitcast to f32). For f16/bf16 (2 bytes) we bitcast u32 → u16 which doubles
138/// the element count; for f64 (8 bytes) we bitcast u32 → u64 which halves it.
139/// The size-changing bitcast is handled by [`Tensor::bitcast`].
140fn bits_to_rand(bits: &Tensor, shape: &[usize], dtype: DType) -> Result<Tensor> {
141    let scalar = dtype.scalar().expect("scalar dtype validated by rand_with");
142    let (_, nmant) = scalar.finfo();
143    let uint_dt = DType::Scalar(scalar.float_to_uint());
144    let total_bits = (scalar.bytes() * 8) as u32;
145    let shift = total_bits - nmant;
146
147    // Bitcast u32 bits → matching uint of the target float dtype. May change
148    // element count when `scalar.bytes() != 4`.
149    let uint_bits = bits.bitcast(uint_dt.clone())?;
150    let bits_shape_concrete = to_vec_usize(&uint_bits.shape()?).context(UOpSnafu)?;
151
152    let shift_t = Tensor::full(&bits_shape_concrete, ConstValue::UInt(shift as u64), uint_dt.clone())?;
153    let shifted = uint_bits.try_shr(&shift_t)?;
154
155    let one_bits = ConstValue::UInt(one_bits_for(scalar));
156    let one_bits_t = Tensor::full(&bits_shape_concrete, one_bits, uint_dt)?;
157    let or_ed = shifted.try_bitor(&one_bits_t)?;
158
159    let in_one_two = or_ed.bitcast(dtype.clone())?;
160    let one_f = Tensor::full(&bits_shape_concrete, ConstValue::Float(1.0), dtype)?;
161    let in_unit = in_one_two.try_sub(&one_f)?;
162
163    // Bits-to-floats may produce more elements than needed (e.g. odd-numel f16
164    // doubles to even count). Truncate to `numel` then reshape.
165    let numel: usize = shape.iter().product();
166    let trimmed = in_unit.try_shrink([(0usize, numel)])?;
167    let isize_shape: Vec<isize> = shape.iter().map(|&d| d as isize).collect();
168    trimmed.try_reshape(&isize_shape)
169}
170
171/// Bit pattern of `1.0` in the given float dtype, widened to `u64` for the
172/// (uniform) `ConstValue::UInt` constructor. Values are well-known and
173/// verifiable via `half::f16::from_f32(1.0).to_bits()` etc.
174fn one_bits_for(s: ScalarDType) -> u64 {
175    match s {
176        ScalarDType::Float16 => 0x3C00,
177        ScalarDType::BFloat16 => 0x3F80,
178        ScalarDType::Float32 => 0x3F80_0000,
179        ScalarDType::Float64 => 0x3FF0_0000_0000_0000,
180        _ => panic!("one_bits_for: non-float dtype {s:?}"),
181    }
182}