1use crate::rand::distributions::Distribution;
4use crate::rand::Rng;
5use crate::rand::SeedableRng;
6use sprs::indexing::SpIndex;
7use sprs::{CsMat, CsMatI};
8
/// Re-export of the `rand` crate.
///
/// The functions of this crate are written against `rand` items
/// (`Rng`, `SeedableRng`, `distributions::Distribution`), so exposing the
/// dependency here lets callers name exactly those items.
pub mod rand {
    pub use rand::*;
}
14
/// Re-export of the `rand_distr` crate.
///
/// `rand_csr_std` draws its values from `rand_distr::StandardNormal`; the
/// re-export lets callers pick distributions from the same crate when
/// calling `rand_csr` directly.
pub mod rand_distr {
    pub use rand_distr::*;
}
20
21pub fn rand_csr<R, N, D, I>(
24 rng: &mut R,
25 dist: D,
26 shape: (usize, usize),
27 density: f64,
28) -> CsMatI<N, I>
29where
30 R: Rng + ?Sized,
31 D: Distribution<N>,
32 N: Copy,
33 I: SpIndex,
34{
35 assert!((0.0..=1.0).contains(&density));
36 let exp_nnz =
37 (density * (shape.0 as f64) * (shape.1 as f64)).ceil() as usize;
38 let mut indptr = Vec::with_capacity(shape.0 + 1);
39 let mut indices = Vec::with_capacity(exp_nnz);
40 let mut data = Vec::with_capacity(exp_nnz);
41 for _ in 0..exp_nnz {
43 indices.push(I::from_usize(rng.gen_range(0, shape.0)));
44 data.push(dist.sample(rng));
48 }
49 indices.sort_unstable();
50 indptr.push(I::from_usize(0));
51 let mut count = 0;
52 for &row in &indices {
53 while indptr.len() != row.index() + 1 {
54 indptr.push(I::from_usize(count));
55 }
56 count += 1;
57 }
58 while indptr.len() != shape.0 + 1 {
59 indptr.push(I::from_usize(count));
60 }
61 assert_eq!(indptr.last().unwrap().index(), exp_nnz);
62 indices.clear();
63 for row in 0..shape.0 {
64 let start = indptr[row].index();
65 let end = indptr[row + 1].index();
66 for _ in start..end {
67 loop {
68 let col = I::from_usize(rng.gen_range(0, shape.1));
69 let loc = indices[start..].binary_search(&col);
70 match loc {
71 Ok(_) => {
72 continue;
73 }
74 Err(loc) => {
75 indices.insert(start + loc, col);
76 break;
77 }
78 }
79 }
80 }
81 indices[start..end].sort_unstable();
82 }
83
84 CsMatI::new(shape, indptr, indices, data)
85}
86
87pub fn rand_csr_std(shape: (usize, usize), density: f64) -> CsMat<f64> {
90 let mut rng = rand_pcg::Pcg64Mcg::from_entropy();
91 rand_csr(&mut rng, crate::rand_distr::StandardNormal, shape, density)
92}
93
#[cfg(test)]
mod tests {
    use rand::distributions::Standard;
    use rand::SeedableRng;
    use sprs::CsMat;

    /// With a `(0, 0)` shape there is nowhere to store a value, whatever the
    /// requested density.
    #[test]
    fn empty_random_mat() {
        let mut thread_rng = rand::thread_rng();
        let mat: CsMat<f64> =
            super::rand_csr(&mut thread_rng, Standard, (0, 0), 0.3);
        assert_eq!(mat.nnz(), 0);
    }

    /// The density of generated matrices stays close to the requested one,
    /// both for a 2D shape and for a single long row.
    #[test]
    fn random_csr() {
        let mut seeded_rng = rand::rngs::StdRng::seed_from_u64(1234);

        let rectangular: CsMat<f64> =
            super::rand_csr(&mut seeded_rng, Standard, (100, 70), 0.3);
        let density = rectangular.density();
        assert!(density > 0.25);
        assert!(density < 0.35);

        let single_row: CsMat<f64> =
            super::rand_csr(&mut seeded_rng, Standard, (1, 10000), 0.3);
        let density = single_row.density();
        assert!(density > 0.28);
        assert!(density < 0.32);
    }

    /// `rand_csr_std` keeps the requested shape, and its values are centred
    /// on zero.
    #[test]
    fn random_csr_std() {
        let mat = super::rand_csr_std((100, 1000), 0.2);
        assert_eq!(mat.shape(), (100, 1000));
        let values = mat.data();
        let mean = values.iter().sum::<f64>() / values.len() as f64;
        assert!(mean.abs() < 0.05);
    }
}