svod_tensor/rand/
distributions.rs1use svod_dtype::DType;
15use svod_ir::ConstValue;
16
17use crate::{Error, Result, Tensor};
18
/// Full turn in radians (2π), used to map a uniform sample onto an angle in Box–Muller.
const TWO_PI: f64 = std::f64::consts::TAU;
20
/// Fan-in of a weight tensor: the product of every dimension after the leading one.
///
/// The result is clamped to at least 1, so rank-0/rank-1 shapes and shapes containing a
/// zero-sized trailing dimension never yield 0 (callers divide by this value).
fn fan_in(shape: &[usize]) -> usize {
    let trailing: &[usize] = shape.get(1..).unwrap_or(&[]);
    std::cmp::max(trailing.iter().product::<usize>(), 1)
}
26
27impl Tensor {
28 pub fn uniform(shape: &[usize], low: f64, high: f64) -> Result<Tensor> {
32 Self::uniform_with_dtype(shape, low, high, DType::Float32)
33 }
34
35 pub fn uniform_with_dtype(shape: &[usize], low: f64, high: f64, dtype: DType) -> Result<Tensor> {
42 if low >= high {
43 return Err(Error::ParamRange {
44 op: "Tensor::uniform",
45 param: "low/high",
46 value: format!("low={low}, high={high}"),
47 constraint: "low < high",
48 });
49 }
50 let u = Tensor::rand(shape)?;
51 let scale = u.broadcast_scalar(ConstValue::Float(high - low))?;
52 let scaled = u.try_mul(&scale)?.cast(dtype)?;
53 let offset = scaled.broadcast_scalar(ConstValue::Float(low))?;
54 scaled.try_add(&offset)
55 }
56
57 pub fn randn(shape: &[usize]) -> Result<Tensor> {
63 let mut combined_shape: Vec<usize> = Vec::with_capacity(shape.len() + 1);
65 combined_shape.push(2);
66 combined_shape.extend_from_slice(shape);
67 let src = Tensor::rand(&combined_shape)?;
68
69 let mut shrink_u1: Vec<Option<(isize, isize)>> = Vec::with_capacity(combined_shape.len());
71 shrink_u1.push(Some((0, 1)));
72 shrink_u1.extend(std::iter::repeat_n(None, shape.len()));
73 let mut shrink_u2: Vec<Option<(isize, isize)>> = Vec::with_capacity(combined_shape.len());
74 shrink_u2.push(Some((1, 2)));
75 shrink_u2.extend(std::iter::repeat_n(None, shape.len()));
76 let target_shape: Vec<isize> = shape.iter().map(|&d| d as isize).collect();
77 let u1 = src.try_shrink(shrink_u1)?.try_reshape(&target_shape)?;
78 let u2 = src.try_shrink(shrink_u2)?.try_reshape(&target_shape)?;
79
80 let two_pi = u1.broadcast_scalar(ConstValue::Float(TWO_PI))?;
82 let theta = u1.try_mul(&two_pi)?.cos()?;
83 let one = u2.broadcast_scalar(ConstValue::Float(1.0))?;
84 let neg_two = u2.broadcast_scalar(ConstValue::Float(-2.0))?;
85 let r = one.try_sub(&u2)?.try_log()?.try_mul(&neg_two)?.try_sqrt()?;
86 theta.try_mul(&r)
87 }
88
89 pub fn normal(shape: &[usize], mean: f64, std: f64) -> Result<Tensor> {
91 if std < 0.0 {
92 return Err(Error::ParamRange {
93 op: "Tensor::normal",
94 param: "std",
95 value: format!("{std}"),
96 constraint: ">= 0",
97 });
98 }
99 let z = Tensor::randn(shape)?;
100 let std_t = z.broadcast_scalar(ConstValue::Float(std))?;
101 let mean_t = z.broadcast_scalar(ConstValue::Float(mean))?;
102 z.try_mul(&std_t)?.try_add(&mean_t)
103 }
104
105 pub fn randint(shape: &[usize], low: i64, high: i64) -> Result<Tensor> {
112 if low >= high {
113 return Err(Error::ParamRange {
114 op: "Tensor::randint",
115 param: "low/high",
116 value: format!("low={low}, high={high}"),
117 constraint: "low < high",
118 });
119 }
120 let scaled = Tensor::rand(shape)?;
121 let range = scaled.broadcast_scalar(ConstValue::Float((high - low) as f64))?;
122 let truncated = scaled.try_mul(&range)?.cast(DType::Int32)?;
123 let offset = truncated.broadcast_scalar(ConstValue::Int(low))?;
124 truncated.try_add(&offset)
125 }
126
127 pub fn scaled_uniform(shape: &[usize]) -> Result<Tensor> {
129 let numel: usize = shape.iter().copied().product::<usize>().max(1);
130 let scale = (numel as f64).powf(-0.5);
131 let u = Tensor::uniform(shape, -1.0, 1.0)?;
132 let scale_t = u.broadcast_scalar(ConstValue::Float(scale))?;
133 u.try_mul(&scale_t)
134 }
135
136 pub fn glorot_uniform(shape: &[usize]) -> Result<Tensor> {
138 Self::glorot_uniform_with_dtype(shape, DType::Float32)
139 }
140
141 pub fn glorot_uniform_with_dtype(shape: &[usize], dtype: DType) -> Result<Tensor> {
144 if shape.is_empty() {
145 return Err(Error::ParamRange {
146 op: "Tensor::glorot_uniform",
147 param: "shape",
148 value: "[]".to_string(),
149 constraint: "at least 1D",
150 });
151 }
152 let fan_in_v = fan_in(shape);
153 let fan_out_v = shape[0];
154 let bound = (6.0 / (fan_out_v + fan_in_v) as f64).sqrt();
155 Self::uniform_with_dtype(shape, -bound, bound, dtype)
156 }
157
158 pub fn kaiming_uniform(shape: &[usize], a: f64) -> Result<Tensor> {
160 Self::kaiming_uniform_with_dtype(shape, a, DType::Float32)
161 }
162
163 pub fn kaiming_uniform_with_dtype(shape: &[usize], a: f64, dtype: DType) -> Result<Tensor> {
171 if shape.is_empty() {
172 return Err(Error::ParamRange {
173 op: "Tensor::kaiming_uniform",
174 param: "shape",
175 value: "[]".to_string(),
176 constraint: "at least 1D",
177 });
178 }
179 let bound = (6.0 / ((1.0 + a * a) * fan_in(shape) as f64)).sqrt();
180 Self::uniform_with_dtype(shape, -bound, bound, dtype)
181 }
182
183 pub fn kaiming_normal(shape: &[usize], a: f64) -> Result<Tensor> {
186 if shape.is_empty() {
187 return Err(Error::ParamRange {
188 op: "Tensor::kaiming_normal",
189 param: "shape",
190 value: "[]".to_string(),
191 constraint: "at least 1D",
192 });
193 }
194 let std = (2.0 / ((1.0 + a * a) * fan_in(shape) as f64)).sqrt();
195 let z = Tensor::randn(shape)?;
196 let std_t = z.broadcast_scalar(ConstValue::Float(std))?;
197 z.try_mul(&std_t)
198 }
199}