1use crate::random::core::Random;
8use ::ndarray::{Array, Dimension, Ix1, Ix2, IxDyn};
9use rand::Rng;
10use rand_distr::{Distribution, Exp, Gamma, Normal, Uniform};
11
12pub trait OptimizedArrayRandom<T, D: Dimension> {
17 fn random_bulk<R, Dist>(shape: D, distribution: Dist, rng: &mut Random<R>) -> Self
19 where
20 R: Rng,
21 Dist: Distribution<T> + Copy;
22
23 fn random_using_bulk<R, F>(shape: D, rng: &mut Random<R>, f: F) -> Self
25 where
26 R: Rng,
27 F: FnMut(&mut Random<R>) -> T;
28}
29
30impl<T, D> OptimizedArrayRandom<T, D> for Array<T, D>
31where
32 D: Dimension,
33{
34 fn random_bulk<R, Dist>(shape: D, distribution: Dist, rng: &mut Random<R>) -> Self
35 where
36 R: Rng,
37 Dist: Distribution<T> + Copy,
38 {
39 let size = shape.size();
40 let mut data = Vec::with_capacity(size);
41
42 for _ in 0..size {
44 data.push(distribution.sample(&mut rng.rng));
45 }
46
47 Array::from_shape_vec(shape, data).expect("Operation failed")
48 }
49
50 fn random_using_bulk<R, F>(shape: D, rng: &mut Random<R>, mut f: F) -> Self
51 where
52 R: Rng,
53 F: FnMut(&mut Random<R>) -> T,
54 {
55 let size = shape.size();
56 let mut data = Vec::with_capacity(size);
57
58 for _ in 0..size {
60 data.push(f(rng));
61 }
62
63 Array::from_shape_vec(shape, data).expect("Operation failed")
64 }
65}
66
67pub fn random_uniform_array<D: Dimension>(shape: D, rng: &mut Random<impl Rng>) -> Array<f64, D> {
70 Array::random_bulk(
71 shape,
72 Uniform::new(0.0, 1.0).expect("Operation failed"),
73 rng,
74 )
75}
76
77pub fn random_normal_array<D: Dimension>(
79 shape: D,
80 mean: f64,
81 std_dev: f64,
82 rng: &mut Random<impl Rng>,
83) -> Array<f64, D> {
84 Array::random_bulk(
85 shape,
86 Normal::new(mean, std_dev).expect("Operation failed"),
87 rng,
88 )
89}
90
91pub fn random_exponential_array<D: Dimension>(
93 shape: D,
94 lambda: f64,
95 rng: &mut Random<impl Rng>,
96) -> Array<f64, D> {
97 Array::random_bulk(shape, Exp::new(lambda).expect("Operation failed"), rng)
98}
99
100pub fn random_gamma_array<D: Dimension>(
102 shape: D,
103 alpha: f64,
104 beta: f64,
105 rng: &mut Random<impl Rng>,
106) -> Array<f64, D> {
107 Array::random_bulk(
108 shape,
109 Gamma::new(alpha, beta).expect("Operation failed"),
110 rng,
111 )
112}
113
114pub fn random_sparse_array<D: Dimension>(
116 shape: D,
117 sparsity: f64,
118 rng: &mut Random<impl Rng>,
119) -> Array<f64, D> {
120 Array::random_using_bulk(shape, rng, |rng| {
121 if rng.gen_range(0.0..1.0) < sparsity {
122 0.0
123 } else {
124 rng.gen_range(-1.0..1.0)
125 }
126 })
127}
128
129pub fn random_xavier_weights(
131 fan_in: usize,
132 fan_out: usize,
133 rng: &mut Random<impl Rng>,
134) -> Array<f64, Ix2> {
135 let limit = (6.0 / (fan_in + fan_out) as f64).sqrt();
136 Array::random_bulk(
137 crate::ndarray::Ix2(fan_out, fan_in),
138 Uniform::new(-limit, limit).expect("Operation failed"),
139 rng,
140 )
141}
142
143pub fn random_he_weights(
145 fan_in: usize,
146 fan_out: usize,
147 rng: &mut Random<impl Rng>,
148) -> Array<f64, Ix2> {
149 let std_dev = (2.0 / fan_in as f64).sqrt();
150 Array::random_bulk(
151 crate::ndarray::Ix2(fan_out, fan_in),
152 Normal::new(0.0, std_dev).expect("Operation failed"),
153 rng,
154 )
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use crate::random::core::seeded_rng;
    use ::ndarray::Ix2;

    #[test]
    fn test_optimized_array_random_bulk() {
        let mut rng = seeded_rng(42);
        let shape = Ix2(5, 5);

        let array = Array::<f64, _>::random_bulk(
            shape,
            Uniform::new(0.0, 1.0).expect("Operation failed"),
            &mut rng,
        );
        assert_eq!(array.shape(), &[5, 5]);
        assert!(array.iter().all(|&x| (0.0..1.0).contains(&x)));
    }

    #[test]
    fn test_optimized_array_random_using_bulk() {
        let mut rng = seeded_rng(123);
        let shape = Ix2(3, 4);

        let array =
            Array::<i32, _>::random_using_bulk(shape, &mut rng, |rng| rng.gen_range(1..100));
        assert_eq!(array.shape(), &[3, 4]);
        assert!(array.iter().all(|&x| (1..100).contains(&x)));
    }

    #[test]
    fn test_random_uniform_array() {
        let mut rng = seeded_rng(456);
        let array = random_uniform_array(Ix2(10, 10), &mut rng);

        assert_eq!(array.shape(), &[10, 10]);
        assert!(array.iter().all(|&x| (0.0..1.0).contains(&x)));
    }

    #[test]
    fn test_random_normal_array() {
        let mut rng = seeded_rng(789);
        let array = random_normal_array(Ix2(5, 5), 0.0, 1.0, &mut rng);

        assert_eq!(array.shape(), &[5, 5]);
        // Samples from N(0, 1) must always be finite numbers.
        assert!(array.iter().all(|x| x.is_finite()));
    }

    #[test]
    fn test_random_exponential_array() {
        let mut rng = seeded_rng(101112);
        let array = random_exponential_array(Ix2(3, 3), 1.0, &mut rng);

        assert_eq!(array.shape(), &[3, 3]);
        assert!(array.iter().all(|&x| x >= 0.0));
    }

    #[test]
    fn test_random_gamma_array() {
        let mut rng = seeded_rng(131415);
        let array = random_gamma_array(Ix2(4, 4), 2.0, 1.0, &mut rng);

        assert_eq!(array.shape(), &[4, 4]);
        assert!(array.iter().all(|&x| x >= 0.0));
    }

    #[test]
    fn test_random_sparse_array() {
        let mut rng = seeded_rng(161718);
        let array = random_sparse_array(Ix2(6, 6), 0.7, &mut rng);

        assert_eq!(array.shape(), &[6, 6]);
        let zero_count = array.iter().filter(|&&x| x == 0.0).count();
        assert!(zero_count > 0);
        // Non-zero entries are drawn from [-1, 1).
        assert!(array.iter().all(|&x| (-1.0..=1.0).contains(&x)));
    }

    #[test]
    fn test_neural_network_weight_initialization() {
        let mut rng = seeded_rng(192021);

        let xavier_weights = random_xavier_weights(10, 5, &mut rng);
        assert_eq!(xavier_weights.shape(), &[5, 10]);
        // Xavier limit for fan_in = 10, fan_out = 5 is sqrt(6 / 15).
        let limit = (6.0_f64 / 15.0).sqrt();
        assert!(xavier_weights.iter().all(|&w| w.abs() <= limit));

        let he_weights = random_he_weights(10, 5, &mut rng);
        assert_eq!(he_weights.shape(), &[5, 10]);
        assert!(he_weights.iter().all(|w| w.is_finite()));
    }
}