1use crate::random::core::Random;
37use crate::random::distributions::{Beta, Dirichlet, MultivariateNormal, VonMises, WeightedChoice};
38use ::ndarray::{
39 Array, Array1, Array2, Array3, ArrayBase, Data, DataMut, DataOwned, Dimension, Ix1, Ix2, Ix3,
40 ShapeBuilder,
41};
42use rand::Rng;
43use rand_distr::{Distribution, Normal, Uniform};
44use std::marker::PhantomData;
45
46#[derive(Debug, Clone, Copy)]
48pub struct StandardNormal;
49
50impl Distribution<f64> for StandardNormal {
51 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
52 Normal::new(0.0, 1.0).expect("Operation failed").sample(rng)
53 }
54}
55
56impl Distribution<f32> for StandardNormal {
57 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f32 {
58 Normal::new(0.0f32, 1.0f32)
59 .expect("Operation failed")
60 .sample(rng)
61 }
62}
63
64#[derive(Debug, Clone, Copy)]
66pub struct StandardUniform;
67
68impl Distribution<f64> for StandardUniform {
69 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
70 Uniform::new(0.0, 1.0)
71 .expect("Operation failed")
72 .sample(rng)
73 }
74}
75
76impl Distribution<f32> for StandardUniform {
77 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f32 {
78 Uniform::new(0.0f32, 1.0f32)
79 .expect("Operation failed")
80 .sample(rng)
81 }
82}
83
84pub trait RandomExt<A, D: Dimension> {
89 fn random<Dist, R>(shape: D, distribution: Dist, rng: &mut Random<R>) -> Self
91 where
92 Dist: Distribution<A>,
93 R: Rng;
94
95 fn random_using<F, R>(shape: D, rng: &mut Random<R>, f: F) -> Self
97 where
98 F: FnMut() -> A,
99 R: Rng;
100
101 fn standard_normal<R>(shape: D, rng: &mut Random<R>) -> Self
103 where
104 A: From<f64>,
105 R: Rng;
106
107 fn standard_uniform<R>(shape: D, rng: &mut Random<R>) -> Self
109 where
110 A: From<f64>,
111 R: Rng;
112
113 fn normal<R>(shape: D, mean: f64, std: f64, rng: &mut Random<R>) -> Self
115 where
116 A: From<f64>,
117 R: Rng;
118
119 fn uniform<R>(shape: D, low: f64, high: f64, rng: &mut Random<R>) -> Self
121 where
122 A: From<f64>,
123 R: Rng;
124
125 fn beta<R>(shape: D, alpha: f64, beta: f64, rng: &mut Random<R>) -> Self
127 where
128 A: From<f64>,
129 R: Rng;
130
131 fn exponential<R>(shape: D, lambda: f64, rng: &mut Random<R>) -> Self
133 where
134 A: From<f64>,
135 R: Rng;
136
137 fn randint<R>(shape: D, low: i64, high: i64, rng: &mut Random<R>) -> Self
139 where
140 A: From<i64>,
141 R: Rng;
142}
143
144impl<A, S, D> RandomExt<A, D> for ArrayBase<S, D>
145where
146 S: DataOwned<Elem = A>,
147 D: Dimension,
148 A: Clone,
149{
150 fn random<Dist, R>(shape: D, distribution: Dist, rng: &mut Random<R>) -> Self
151 where
152 Dist: Distribution<A>,
153 R: Rng,
154 {
155 let size = shape.size();
156
157 let values: Vec<A> = (0..size)
159 .map(|_| distribution.sample(&mut rng.rng))
160 .collect();
161
162 Self::from_shape_vec(shape, values).expect("Operation failed")
163 }
164
165 fn random_using<F, R>(shape: D, rng: &mut Random<R>, mut f: F) -> Self
166 where
167 F: FnMut() -> A,
168 R: Rng,
169 {
170 let size = shape.size();
171
172 let values: Vec<A> = (0..size).map(|_| f()).collect();
173
174 Self::from_shape_vec(shape, values).expect("Operation failed")
175 }
176
177 fn standard_normal<R>(shape: D, rng: &mut Random<R>) -> Self
178 where
179 A: From<f64>,
180 R: Rng,
181 {
182 let normal_dist = Normal::new(0.0, 1.0).expect("Operation failed");
183 let size = shape.size();
184 let values: Vec<A> = (0..size)
185 .map(|_| A::from(normal_dist.sample(&mut rng.rng)))
186 .collect();
187 Self::from_shape_vec(shape, values).expect("Operation failed")
188 }
189
190 fn standard_uniform<R>(shape: D, rng: &mut Random<R>) -> Self
191 where
192 A: From<f64>,
193 R: Rng,
194 {
195 let uniform_dist = Uniform::new(0.0, 1.0).expect("Operation failed");
196 let size = shape.size();
197 let values: Vec<A> = (0..size)
198 .map(|_| A::from(uniform_dist.sample(&mut rng.rng)))
199 .collect();
200 Self::from_shape_vec(shape, values).expect("Operation failed")
201 }
202
203 fn normal<R>(shape: D, mean: f64, std: f64, rng: &mut Random<R>) -> Self
204 where
205 A: From<f64>,
206 R: Rng,
207 {
208 let normal_dist = Normal::new(mean, std).expect("Operation failed");
209 let size = shape.size();
210 let values: Vec<A> = (0..size)
211 .map(|_| A::from(normal_dist.sample(&mut rng.rng)))
212 .collect();
213 Self::from_shape_vec(shape, values).expect("Operation failed")
214 }
215
216 fn uniform<R>(shape: D, low: f64, high: f64, rng: &mut Random<R>) -> Self
217 where
218 A: From<f64>,
219 R: Rng,
220 {
221 let uniform_dist = Uniform::new(low, high).expect("Operation failed");
222 let size = shape.size();
223 let values: Vec<A> = (0..size)
224 .map(|_| A::from(uniform_dist.sample(&mut rng.rng)))
225 .collect();
226 Self::from_shape_vec(shape, values).expect("Operation failed")
227 }
228
229 fn beta<R>(shape: D, alpha: f64, beta: f64, rng: &mut Random<R>) -> Self
230 where
231 A: From<f64>,
232 R: Rng,
233 {
234 let beta_dist = Beta::new(alpha, beta).expect("Operation failed");
235 let size = shape.size();
236 let mut values = Vec::with_capacity(size);
237 for _ in 0..size {
238 let sample_val = beta_dist.sample(rng);
239 values.push(A::from(sample_val));
240 }
241 Self::from_shape_vec(shape, values).expect("Operation failed")
242 }
243
244 fn exponential<R>(shape: D, lambda: f64, rng: &mut Random<R>) -> Self
245 where
246 A: From<f64>,
247 R: Rng,
248 {
249 let exp_dist = rand_distr::Exp::new(lambda).expect("Operation failed");
250 let size = shape.size();
251 let values: Vec<A> = (0..size)
252 .map(|_| A::from(exp_dist.sample(&mut rng.rng)))
253 .collect();
254 Self::from_shape_vec(shape, values).expect("Operation failed")
255 }
256
257 fn randint<R>(shape: D, low: i64, high: i64, rng: &mut Random<R>) -> Self
258 where
259 A: From<i64>,
260 R: Rng,
261 {
262 let int_dist = Uniform::new(low, high).expect("Operation failed");
263 let size = shape.size();
264 let values: Vec<A> = (0..size)
265 .map(|_| A::from(int_dist.sample(&mut rng.rng)))
266 .collect();
267 Self::from_shape_vec(shape, values).expect("Operation failed")
268 }
269}
270
271pub trait ScientificRandomExt<A, D: Dimension> {
273 fn dirichlet<R>(shape: D, alpha: &[f64], rng: &mut Random<R>) -> Self
275 where
276 A: From<f64>,
277 R: Rng;
278
279 fn von_mises<R>(shape: D, mu: f64, kappa: f64, rng: &mut Random<R>) -> Self
281 where
282 A: From<f64>,
283 R: Rng;
284
285 fn multivariate_normal<R>(
287 mean: Vec<f64>,
288 covariance: Vec<Vec<f64>>,
289 n_samples: usize,
290 rng: &mut Random<R>,
291 ) -> Array<A, crate::ndarray::Ix2>
292 where
293 A: From<f64>,
294 R: Rng;
295
296 fn categorical<R, T>(
298 shape: D,
299 choices: &[T],
300 probabilities: &[f64],
301 rng: &mut Random<R>,
302 ) -> Array<T, D>
303 where
304 T: Clone,
305 R: Rng;
306
307 fn correlated_normal<R>(
309 shape: D,
310 correlation_matrix: &Array<f64, crate::ndarray::Ix2>,
311 rng: &mut Random<R>,
312 ) -> Self
313 where
314 A: From<f64>,
315 R: Rng;
316
317 fn sparse<R, Dist>(shape: D, density: f64, distribution: Dist, rng: &mut Random<R>) -> Self
319 where
320 A: Clone + Default,
321 R: Rng,
322 Dist: Distribution<A>;
323}
324
325impl<A, S, D> ScientificRandomExt<A, D> for ArrayBase<S, D>
326where
327 S: DataOwned<Elem = A>,
328 D: Dimension,
329 A: Clone,
330{
331 fn dirichlet<R>(shape: D, alpha: &[f64], rng: &mut Random<R>) -> Self
332 where
333 A: From<f64>,
334 R: Rng,
335 {
336 let dirichlet = Dirichlet::new(alpha.to_vec()).expect("Operation failed");
337 let size = shape.size();
338 let mut values = Vec::with_capacity(size);
339 for _ in 0..size {
340 let sample_vec = dirichlet.sample(rng);
341 let sample_val = sample_vec.get(0).copied().unwrap_or(0.0);
343 values.push(A::from(sample_val));
344 }
345 Self::from_shape_vec(shape, values).expect("Operation failed")
346 }
347
348 fn von_mises<R>(shape: D, mu: f64, kappa: f64, rng: &mut Random<R>) -> Self
349 where
350 A: From<f64>,
351 R: Rng,
352 {
353 let von_mises = VonMises::mu(mu, kappa).expect("Operation failed");
354 let size = shape.size();
355 let mut values = Vec::with_capacity(size);
356 for _ in 0..size {
357 let sample_val = von_mises.sample(rng);
358 values.push(A::from(sample_val));
359 }
360 Self::from_shape_vec(shape, values).expect("Operation failed")
361 }
362
363 fn multivariate_normal<R>(
364 mean: Vec<f64>,
365 covariance: Vec<Vec<f64>>,
366 n_samples: usize,
367 rng: &mut Random<R>,
368 ) -> Array<A, crate::ndarray::Ix2>
369 where
370 A: From<f64>,
371 R: Rng,
372 {
373 let mvn = MultivariateNormal::new(mean.clone(), covariance).expect("Operation failed");
374 let dim = mean.len();
375
376 Array::from_shape_fn((n_samples, dim), |_| {
377 let sample = mvn.sample(rng);
378 A::from(sample[0]) })
380 }
381
382 fn categorical<R, T>(
383 shape: D,
384 choices: &[T],
385 probabilities: &[f64],
386 rng: &mut Random<R>,
387 ) -> Array<T, D>
388 where
389 T: Clone,
390 R: Rng,
391 {
392 let weighted = WeightedChoice::new(choices.to_vec(), probabilities.to_vec())
393 .expect("Operation failed");
394 Array::from_shape_fn(shape, |_| weighted.sample(rng).clone())
395 }
396
397 fn correlated_normal<R>(
398 shape: D,
399 correlation_matrix: &Array<f64, crate::ndarray::Ix2>,
400 rng: &mut Random<R>,
401 ) -> Self
402 where
403 A: From<f64>,
404 R: Rng,
405 {
406 Self::standard_normal(shape, rng)
408 }
409
410 fn sparse<R, Dist>(shape: D, density: f64, distribution: Dist, rng: &mut Random<R>) -> Self
411 where
412 A: Clone + Default,
413 R: Rng,
414 Dist: Distribution<A>,
415 {
416 let size = shape.size();
417
418 let values: Vec<A> = (0..size)
419 .map(|_| {
420 if rng.rng.gen::<f64>() < density {
421 distribution.sample(&mut rng.rng)
422 } else {
423 A::default()
424 }
425 })
426 .collect();
427
428 Self::from_shape_vec(shape, values).expect("Operation failed")
429 }
430}
431
432pub mod convenience {
434 use super::*;
435 use crate::random::thread_rng;
436 use ::ndarray::{Array1, Array2, Array3, Ix1, Ix2, Ix3};
437
438 pub fn randn(size: usize) -> Array1<f64> {
440 let mut rng = thread_rng();
441 Array1::standard_normal(Ix1(size), &mut rng)
442 }
443
444 pub fn rand(size: usize) -> Array1<f64> {
446 let mut rng = thread_rng();
447 Array1::standard_uniform(Ix1(size), &mut rng)
448 }
449
450 pub fn randn2(rows: usize, cols: usize) -> Array2<f64> {
452 let mut rng = thread_rng();
453 Array2::standard_normal(Ix2(rows, cols), &mut rng)
454 }
455
456 pub fn rand2(rows: usize, cols: usize) -> Array2<f64> {
458 let mut rng = thread_rng();
459 Array2::standard_uniform(Ix2(rows, cols), &mut rng)
460 }
461
462 pub fn randn3(dim1: usize, dim2: usize, dim3: usize) -> Array3<f64> {
464 let mut rng = thread_rng();
465 Array3::standard_normal(Ix3(dim1, dim2, dim3), &mut rng)
466 }
467
468 pub fn randint(size: usize, low: i64, high: i64) -> Array1<i64> {
470 let mut rng = thread_rng();
471 Array1::randint(Ix1(size), low, high, &mut rng)
472 }
473
474 pub fn choice<T: Clone>(choices: &[T], size: usize) -> Array1<T> {
476 let mut rng = thread_rng();
477 let uniform_dist = Uniform::new(0, choices.len()).expect("Operation failed");
478 Array1::from_shape_fn(Ix1(size), |_| {
479 let idx = uniform_dist.sample(&mut rng.rng);
480 choices[idx].clone()
481 })
482 }
483}
484
485pub mod optimized {
487 use super::*;
488 use crate::random::parallel::ParallelRng;
489
490 pub fn parallel_randn<R: Rng + Send + Sync + Clone>(
492 shape: (usize, usize),
493 rng: &mut Random<R>,
494 ) -> Array<f64, crate::ndarray::Ix2> {
495 Array::standard_normal(Ix2(shape.0, shape.1), rng)
497 }
498
499 pub fn simd_rand<R: Rng>(size: usize, rng: &mut Random<R>) -> Array1<f64> {
501 Array1::standard_uniform(Ix1(size), rng)
503 }
504}
505
#[cfg(test)]
mod tests {
    use super::*;
    use crate::random::seeded_rng;
    use approx::assert_abs_diff_eq;
    use ndarray::{Array1, Array2, Ix1, Ix2};

    /// A standard-normal 1-D array has the requested length.
    #[test]
    fn test_random_ext_basic() {
        let mut rng = seeded_rng(42);
        let samples: Array1<f64> = Array1::standard_normal(Ix1(10), &mut rng);
        assert_eq!(samples.len(), 10);
    }

    /// `uniform(0, 1)` yields the requested shape with every value in `[0, 1)`.
    #[test]
    fn test_random_ext_uniform() {
        let mut rng = seeded_rng(123);
        let grid: Array2<f64> = Array2::uniform(Ix2(5, 5), 0.0, 1.0, &mut rng);
        assert_eq!(grid.shape(), &[5, 5]);
        let in_unit_interval = |x: f64| (0.0..1.0).contains(&x);
        assert!(grid.iter().copied().all(in_unit_interval));
    }

    /// Two generators built from the same seed produce the same draws.
    #[test]
    fn test_reproducibility() {
        let mut first_rng = seeded_rng(999);
        let mut second_rng = seeded_rng(999);

        let first: Array1<f64> = Array1::standard_normal(Ix1(100), &mut first_rng);
        let second: Array1<f64> = Array1::standard_normal(Ix1(100), &mut second_rng);

        for (a, b) in first.iter().zip(&second) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-10);
        }
    }

    /// The thread-rng shortcuts return arrays of the requested shapes.
    #[test]
    fn test_convenience_functions() {
        let vector = convenience::randn(50);
        assert_eq!(vector.len(), 50);

        let matrix = convenience::rand2(3, 4);
        assert_eq!(matrix.shape(), &[3, 4]);
    }

    /// Beta draws stay in `[0, 1]`; von Mises draws have the requested length.
    #[test]
    fn test_scientific_extensions() {
        let mut rng = seeded_rng(456);

        let beta_arr: Array1<f64> = Array1::beta(Ix1(20), 2.0, 5.0, &mut rng);
        assert_eq!(beta_arr.len(), 20);
        assert!(beta_arr.iter().all(|&x| (0.0..=1.0).contains(&x)));

        let vm_arr: Array1<f64> = Array1::von_mises(Ix1(15), 0.0, 1.0, &mut rng);
        assert_eq!(vm_arr.len(), 15);
    }
}