tensor_rs/tensor_impl/gen_tensor/
rand.rs1use super::GenTensor;
2use crate::tensor_trait::rand::Random;
3
4use rand::prelude::*;
5use rand::Rng;
6use rand_distr::{Normal, Uniform, Distribution, StandardNormal};
7
8impl<T> Random for GenTensor<T>
9where T: num_traits::Float + rand_distr::uniform::SampleUniform,
10 StandardNormal: Distribution<T> {
11 type TensorType = GenTensor<T>;
12 type ElementType = T;
13
14 fn rand_usize(rng: &mut StdRng,
15 dim: &[usize],
16 left: usize, right: usize) -> Self::TensorType {
17 let elem = dim.iter().product();
18
19 let mut dta = Vec::<Self::ElementType>::with_capacity(elem);
20 for _i in 0..elem {
21 let v: usize = rng.gen_range(left..right);
22 dta.push(T::from(v).unwrap());
23 }
24 GenTensor::new_raw(&dta, dim)
25 }
26
27 fn bernoulli() -> Self::TensorType {
28 unimplemented!();
29 }
30 fn cauchy() -> Self::TensorType {
31 unimplemented!();
32 }
33 fn exponential() -> Self::TensorType {
34 unimplemented!();
35 }
36 fn geometric() -> Self::TensorType {
37 unimplemented!();
38 }
39 fn log_normal() -> Self::TensorType {
40 unimplemented!();
41 }
42 fn normal(rng: &mut StdRng,
43 dim: &[usize],
44 mean: Self::ElementType,
45 std: Self::ElementType) -> Self::TensorType {
46 let elem = dim.iter().product();
47
48 let mut dta = Vec::<Self::ElementType>::with_capacity(elem);
49 let normal = Normal::<Self::ElementType>::new(mean, std).expect("");
50 for _i in 0..elem {
51 dta.push(normal.sample(rng));
52 }
53 GenTensor::new_raw(&dta, dim)
54 }
55 fn uniform(rng: &mut StdRng,
56 dim: &[usize],
57 from: Self::ElementType,
58 to: Self::ElementType) -> Self::TensorType {
59 let elem: usize = dim.iter().product();
60
61 let mut dta = Vec::<Self::ElementType>::with_capacity(elem);
62 let normal = Uniform::<Self::ElementType>::new(from, to);
63 for _i in 0..elem {
64 dta.push(normal.sample(rng));
65 }
66 GenTensor::new_raw(&dta, dim)
67 }
68}
69
70
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor_trait::compare_tensor::CompareTensor;

    /// `uniform` samples from the half-open range `[from, to)`, so every
    /// element must satisfy `0 <= x < 10`. The samples are checked directly
    /// so that the valid lower bound `0.` is accepted.
    #[test]
    fn uniform_in_range() {
        let mut rng = StdRng::seed_from_u64(671);
        let m = GenTensor::<f64>::uniform(&mut rng, &[2, 2], 0., 10.);
        assert!(m.all(&|x: f64| x >= 0. && x < 10.));
    }

    /// `normal` with finite parameters must only produce finite samples.
    #[test]
    fn normal_is_finite() {
        let mut rng = StdRng::seed_from_u64(671);
        let m = GenTensor::<f64>::normal(&mut rng, &[2, 2], 0., 1.);
        assert!(m.all(&|x: f64| x.is_finite()));
    }

    /// `rand_usize` yields whole numbers in the half-open range `[left, right)`.
    #[test]
    fn rand_usize_in_range() {
        let mut rng = StdRng::seed_from_u64(671);
        let m = GenTensor::<f64>::rand_usize(&mut rng, &[3, 3], 2, 5);
        assert!(m.all(&|x: f64| x >= 2. && x < 5. && x.fract() == 0.));
    }
}