sprs_rand/
lib.rs

1//! Random sparse matrix generation
2
3use crate::rand::distributions::Distribution;
4use crate::rand::Rng;
5use crate::rand::SeedableRng;
6use sprs::indexing::SpIndex;
7use sprs::{CsMat, CsMatI};
8
9/// Re-export [`rand`](https://docs.rs/rand/0.7.3/rand/)
10/// for version compatibility
11pub mod rand {
12    pub use rand::*;
13}
14
15/// Re-export [`rand_distr`](https://docs.rs/rand_distr/0.2.2/rand_distr)
16/// for version compatibility
17pub mod rand_distr {
18    pub use rand_distr::*;
19}
20
21/// Generate a random sparse matrix matching the given density and sampling
22/// the values of its non-zero elements from the provided distribution.
23pub fn rand_csr<R, N, D, I>(
24    rng: &mut R,
25    dist: D,
26    shape: (usize, usize),
27    density: f64,
28) -> CsMatI<N, I>
29where
30    R: Rng + ?Sized,
31    D: Distribution<N>,
32    N: Copy,
33    I: SpIndex,
34{
35    assert!((0.0..=1.0).contains(&density));
36    let exp_nnz =
37        (density * (shape.0 as f64) * (shape.1 as f64)).ceil() as usize;
38    let mut indptr = Vec::with_capacity(shape.0 + 1);
39    let mut indices = Vec::with_capacity(exp_nnz);
40    let mut data = Vec::with_capacity(exp_nnz);
41    // sample row indices
42    for _ in 0..exp_nnz {
43        indices.push(I::from_usize(rng.gen_range(0, shape.0)));
44        // Note: there won't be any correspondence between the data
45        // sampled here and the row sampled before, but this does not matter
46        // as we are sampling.
47        data.push(dist.sample(rng));
48    }
49    indices.sort_unstable();
50    indptr.push(I::from_usize(0));
51    let mut count = 0;
52    for &row in &indices {
53        while indptr.len() != row.index() + 1 {
54            indptr.push(I::from_usize(count));
55        }
56        count += 1;
57    }
58    while indptr.len() != shape.0 + 1 {
59        indptr.push(I::from_usize(count));
60    }
61    assert_eq!(indptr.last().unwrap().index(), exp_nnz);
62    indices.clear();
63    for row in 0..shape.0 {
64        let start = indptr[row].index();
65        let end = indptr[row + 1].index();
66        for _ in start..end {
67            loop {
68                let col = I::from_usize(rng.gen_range(0, shape.1));
69                let loc = indices[start..].binary_search(&col);
70                match loc {
71                    Ok(_) => {
72                        continue;
73                    }
74                    Err(loc) => {
75                        indices.insert(start + loc, col);
76                        break;
77                    }
78                }
79            }
80        }
81        indices[start..end].sort_unstable();
82    }
83
84    CsMatI::new(shape, indptr, indices, data)
85}
86
87/// Convenient wrapper for the common case of sampling a matrix with standard
88/// normal distribution of the nnz values, using a lightweight rng.
89pub fn rand_csr_std(shape: (usize, usize), density: f64) -> CsMat<f64> {
90    let mut rng = rand_pcg::Pcg64Mcg::from_entropy();
91    rand_csr(&mut rng, crate::rand_distr::StandardNormal, shape, density)
92}
93
#[cfg(test)]
mod tests {
    use super::{rand_csr, rand_csr_std};
    use rand::distributions::Standard;
    use rand::SeedableRng;
    use sprs::CsMat;

    /// A matrix with an empty shape must come out without any stored entry.
    #[test]
    fn empty_random_mat() {
        let mut rng = rand::thread_rng();
        let shape = (0, 0);
        let empty: CsMat<f64> = rand_csr(&mut rng, Standard, shape, 0.3);
        assert_eq!(empty.nnz(), 0);
    }

    /// With a fixed seed, the realized density stays close to the requested
    /// one, both for a regular shape and for a single very wide row.
    #[test]
    fn random_csr() {
        let mut rng = rand::rngs::StdRng::seed_from_u64(1234);

        let regular: CsMat<f64> = rand_csr(&mut rng, Standard, (100, 70), 0.3);
        let regular_density = regular.density();
        assert!(regular_density > 0.25);
        assert!(regular_density < 0.35);

        let wide: CsMat<f64> = rand_csr(&mut rng, Standard, (1, 10000), 0.3);
        let wide_density = wide.density();
        assert!(wide_density > 0.28);
        assert!(wide_density < 0.32);
    }

    /// The entropy-seeded wrapper yields the requested shape and values
    /// whose mean is close to zero.
    #[test]
    fn random_csr_std() {
        let shape = (100, 1000);
        let mat = rand_csr_std(shape, 0.2);
        assert_eq!(mat.shape(), shape);
        // The seed is not under our control, so the density is not checked.
        // The mean of the stored values is a safe thing to verify though.
        let values = mat.data();
        let abs_mean = values.iter().sum::<f64>().abs() / (values.len() as f64);
        assert!(abs_mean < 0.05);
    }
}