1use crate::rand::distr::Distribution;
4use crate::rand::Rng;
5use crate::rand::RngExt;
6use crate::rand::SeedableRng;
7use sprs::indexing::SpIndex;
8use sprs::{CsMat, CsMatI};
9
10pub mod rand {
13 pub use rand::*;
14}
15
16pub mod rand_distr {
19 pub use rand_distr::*;
20}
21
22pub fn rand_csr<R, N, D, I>(
25 rng: &mut R,
26 dist: D,
27 shape: (usize, usize),
28 density: f64,
29) -> CsMatI<N, I>
30where
31 R: Rng + ?Sized,
32 D: Distribution<N>,
33 N: Copy,
34 I: SpIndex,
35{
36 assert!((0.0..=1.0).contains(&density));
37 let exp_nnz =
38 (density * (shape.0 as f64) * (shape.1 as f64)).ceil() as usize;
39 let mut indptr = Vec::with_capacity(shape.0 + 1);
40 let mut indices = Vec::with_capacity(exp_nnz);
41 let mut data = Vec::with_capacity(exp_nnz);
42 for _ in 0..exp_nnz {
44 indices.push(I::from_usize(rng.random_range(0..shape.0)));
45 data.push(dist.sample(rng));
49 }
50 indices.sort_unstable();
51 indptr.push(I::from_usize(0));
52 let mut count = 0;
53 for &row in &indices {
54 while indptr.len() != row.index() + 1 {
55 indptr.push(I::from_usize(count));
56 }
57 count += 1;
58 }
59 while indptr.len() != shape.0 + 1 {
60 indptr.push(I::from_usize(count));
61 }
62 assert_eq!(indptr.last().unwrap().index(), exp_nnz);
63 indices.clear();
64 for row in 0..shape.0 {
65 let start = indptr[row].index();
66 let end = indptr[row + 1].index();
67 for _ in start..end {
68 loop {
69 let col = I::from_usize(rng.random_range(0..shape.1));
70 let loc = indices[start..].binary_search(&col);
71 if let Err(loc) = loc {
72 indices.insert(start + loc, col);
73 break;
74 }
75 }
76 }
77 indices[start..end].sort_unstable();
78 }
79
80 CsMatI::new(shape, indptr, indices, data)
81}
82
83pub fn rand_csr_std(shape: (usize, usize), density: f64) -> CsMat<f64> {
86 let mut rng = rand_pcg::Pcg64Mcg::from_rng(&mut rand::rng());
87 rand_csr(&mut rng, crate::rand_distr::StandardNormal, shape, density)
88}
89
#[cfg(test)]
mod tests {
    use rand::distr::StandardUniform;
    use rand::SeedableRng;
    use sprs::CsMat;

    /// A 0x0 shape must produce a matrix with no stored entries.
    #[test]
    fn empty_random_mat() {
        let mut thread_rng = rand::rng();
        let mat: CsMat<f64> =
            super::rand_csr(&mut thread_rng, StandardUniform, (0, 0), 0.3);
        assert_eq!(mat.nnz(), 0);
    }

    /// The realised density should stay close to the requested one, both for
    /// a regular shape and for a single very wide row.
    #[test]
    fn random_csr() {
        let mut seeded = rand::rngs::StdRng::seed_from_u64(1234);

        let regular: CsMat<f64> =
            super::rand_csr(&mut seeded, StandardUniform, (100, 70), 0.3);
        let regular_density = regular.density();
        assert!(regular_density > 0.25);
        assert!(regular_density < 0.35);

        // The second matrix continues the same seeded stream.
        let wide: CsMat<f64> =
            super::rand_csr(&mut seeded, StandardUniform, (1, 10000), 0.3);
        let wide_density = wide.density();
        assert!(wide_density > 0.28);
        assert!(wide_density < 0.32);
    }

    /// Values from `rand_csr_std` come from a standard normal, so their mean
    /// should be close to zero.
    #[test]
    fn random_csr_std() {
        let mat = super::rand_csr_std((100, 1000), 0.2);
        assert_eq!(mat.shape(), (100, 1000));
        let values = mat.data();
        let abs_mean = values.iter().sum::<f64>().abs() / (values.len() as f64);
        assert!(abs_mean < 0.05);
    }
}