svod_tensor/rand/
primitive.rs1use snafu::ResultExt;
21use svod_dtype::{DType, DeviceSpec, ScalarDType};
22use svod_ir::{ConstValue, UOp, shape::Shape, shape::to_vec_usize};
23
24use crate::{Error, Result, Tensor, UOpSnafu};
25
26use super::state;
27
28impl Tensor {
29 pub fn rand(shape: &[usize]) -> Result<Tensor> {
34 Self::rand_with(shape, DType::Float32, DeviceSpec::Cpu)
35 }
36
37 pub fn rand_with(shape: &[usize], dtype: DType, device: DeviceSpec) -> Result<Tensor> {
42 let scalar = dtype.scalar().ok_or_else(|| Error::SymbolicShapeUnsupported {
43 operation: format!("Tensor::rand: non-scalar dtype {dtype:?}"),
44 })?;
45 if !scalar.is_float() {
46 return Err(Error::SymbolicShapeUnsupported {
47 operation: format!(
48 "Tensor::rand: float dtype required, got {scalar:?}; use Tensor::randint for integers"
49 ),
50 });
51 }
52 let numel: usize = shape.iter().product();
53 if numel == 0 {
54 return Tensor::zeros(shape, dtype);
55 }
56 let num_words = (numel * scalar.bytes()).div_ceil(4) as u64;
58 let (seed, counter_val) = state::next_counter(&device, num_words);
59 let bits = random_bits(&seed, counter_val, num_words as usize)?;
60 bits_to_rand(&bits, shape, dtype)
61 }
62}
63
64fn random_bits(seed: &Tensor, counter_val: u64, num: usize) -> Result<Tensor> {
68 let u32_dt = DType::Scalar(ScalarDType::UInt32);
69
70 let c_low = Tensor::full(&[1], (counter_val & 0xFFFF_FFFF) as u32, u32_dt.clone())?;
73 let c_high = Tensor::full(&[1], (counter_val >> 32) as u32, u32_dt.clone())?;
74
75 let new_key = threefry_random_bits(seed, &c_low, &c_high)?;
77
78 let half = num.div_ceil(2);
82 let counts0 = Tensor::arange(0, Some(half as i64), None)?.cast(u32_dt.clone())?;
83 let half_t = Tensor::full(&[half], half as u32, u32_dt)?;
84 let counts1 = counts0.try_add(&half_t)?;
85
86 let bits_full = threefry_random_bits(&new_key, &counts0, &counts1)?;
88 bits_full.try_shrink([(0usize, num)])
89}
90
91pub(crate) fn threefry_random_bits(key: &Tensor, counts0: &Tensor, counts1: &Tensor) -> Result<Tensor> {
98 let u32_dt = DType::Scalar(ScalarDType::UInt32);
99 let u64_dt = DType::Scalar(ScalarDType::UInt64);
100 let counts_shape: Shape = counts0.shape()?;
101
102 let shift_32 = Tensor::full(&to_vec_usize(&counts_shape).context(UOpSnafu)?, 32u32, u64_dt.clone())?;
103
104 let c0_u64 = counts0.cast(u64_dt.clone())?;
106 let c1_u64 = counts1.cast(u64_dt.clone())?;
107 let c1_shifted = c1_u64.try_shl(&shift_32)?;
108 let x = c1_shifted.try_bitor(&c0_u64)?;
109
110 let k_shift_32 = Tensor::full(&[1], 32u32, u64_dt.clone())?;
112 let k0 = key.try_shrink([(0usize, 1usize)])?.cast(u64_dt.clone())?;
113 let k1 = key.try_shrink([(1usize, 2usize)])?.cast(u64_dt.clone())?;
114 let key_packed = k1.try_shl(&k_shift_32)?.try_bitor(&k0)?;
115 let key_broadcast = key_packed.broadcast_to(&counts_shape)?;
116
117 let result_uop = UOp::threefry(x.uop().clone(), key_broadcast.uop().clone()).context(UOpSnafu)?;
120 let result = Tensor::from_lazy(result_uop);
121
122 let mask_u64 = Tensor::full(&to_vec_usize(&counts_shape).context(UOpSnafu)?, 0xFFFF_FFFFu64, u64_dt)?;
124 let lo_u64 = result.try_bitand(&mask_u64)?;
125 let lo = lo_u64.cast(u32_dt.clone())?;
126 let hi_u64 = result.try_shr(&shift_32)?.try_bitand(&mask_u64)?;
127 let hi = hi_u64.cast(u32_dt)?;
128
129 Tensor::cat(&[&lo, &hi], 0)
130}
131
132fn bits_to_rand(bits: &Tensor, shape: &[usize], dtype: DType) -> Result<Tensor> {
141 let scalar = dtype.scalar().expect("scalar dtype validated by rand_with");
142 let (_, nmant) = scalar.finfo();
143 let uint_dt = DType::Scalar(scalar.float_to_uint());
144 let total_bits = (scalar.bytes() * 8) as u32;
145 let shift = total_bits - nmant;
146
147 let uint_bits = bits.bitcast(uint_dt.clone())?;
150 let bits_shape_concrete = to_vec_usize(&uint_bits.shape()?).context(UOpSnafu)?;
151
152 let shift_t = Tensor::full(&bits_shape_concrete, ConstValue::UInt(shift as u64), uint_dt.clone())?;
153 let shifted = uint_bits.try_shr(&shift_t)?;
154
155 let one_bits = ConstValue::UInt(one_bits_for(scalar));
156 let one_bits_t = Tensor::full(&bits_shape_concrete, one_bits, uint_dt)?;
157 let or_ed = shifted.try_bitor(&one_bits_t)?;
158
159 let in_one_two = or_ed.bitcast(dtype.clone())?;
160 let one_f = Tensor::full(&bits_shape_concrete, ConstValue::Float(1.0), dtype)?;
161 let in_unit = in_one_two.try_sub(&one_f)?;
162
163 let numel: usize = shape.iter().product();
166 let trimmed = in_unit.try_shrink([(0usize, numel)])?;
167 let isize_shape: Vec<isize> = shape.iter().map(|&d| d as isize).collect();
168 trimmed.try_reshape(&isize_shape)
169}
170
171fn one_bits_for(s: ScalarDType) -> u64 {
175 match s {
176 ScalarDType::Float16 => 0x3C00,
177 ScalarDType::BFloat16 => 0x3F80,
178 ScalarDType::Float32 => 0x3F80_0000,
179 ScalarDType::Float64 => 0x3FF0_0000_0000_0000,
180 _ => panic!("one_bits_for: non-float dtype {s:?}"),
181 }
182}