sprs_rssn/sparse/
slicing.rs

//! This module provides implementations to slice a matrix along the desired dimension.
//! A sealed `Range` trait lets callers pass standard ranges (e.g. `2..7`), giving
//! an idiomatic API.
3
4use crate::range::Range;
5use crate::{CsMatBase, CsMatViewI, CsMatViewMutI, SpIndex};
6use std::ops::{Deref, DerefMut};
7
8impl<N, I: SpIndex, Iptr: SpIndex, IptrStorage, IStorage, DStorage>
9    CsMatBase<N, I, IptrStorage, IStorage, DStorage, Iptr>
10where
11    IptrStorage: Deref<Target = [Iptr]>,
12    IStorage: Deref<Target = [I]>,
13    DStorage: Deref<Target = [N]>,
14{
15    /// Slice the outer dimension of the matrix according to the specified
16    /// range.
17    pub fn slice_outer<S: Range>(
18        &self,
19        range: S,
20    ) -> CsMatViewI<'_, N, I, Iptr> {
21        self.view().slice_outer_rbr(range)
22    }
23}
24
25impl<N, I: SpIndex, Iptr: SpIndex, IptrStorage, IStorage, DStorage>
26    CsMatBase<N, I, IptrStorage, IStorage, DStorage, Iptr>
27where
28    IptrStorage: Deref<Target = [Iptr]>,
29    IStorage: Deref<Target = [I]>,
30    DStorage: DerefMut<Target = [N]>,
31{
32    /// Slice the outer dimension of the matrix according to the specified
33    /// range.
34    pub fn slice_outer_mut<S: Range>(
35        &mut self,
36        range: S,
37    ) -> CsMatViewMutI<'_, N, I, Iptr> {
38        let start = range.start();
39        let end = range.end().unwrap_or_else(|| self.outer_dims());
40        assert!(end >= start, "Invalid view");
41
42        let outer_inds_slice = self.indptr.outer_inds_slice(start, end);
43        let (nrows, ncols) = match self.storage() {
44            crate::CSR => ((end - start), self.ncols),
45            crate::CSC => (self.nrows, (end - start)),
46        };
47        CsMatViewMutI {
48            nrows,
49            ncols,
50            storage: self.storage,
51            indptr: self.indptr.middle_slice(range),
52            indices: &self.indices[outer_inds_slice.clone()],
53            data: &mut self.data[outer_inds_slice],
54        }
55    }
56}
57
58impl<'a, N, I, Iptr> crate::CsMatViewI<'a, N, I, Iptr>
59where
60    I: crate::SpIndex,
61    Iptr: crate::SpIndex,
62{
63    /// Slice the outer dimension of the matrix according to the specified
64    /// range.
65    pub fn slice_outer_rbr<S>(
66        &self,
67        range: S,
68    ) -> crate::CsMatViewI<'a, N, I, Iptr>
69    where
70        S: Range,
71    {
72        let start = range.start();
73        let end = range.end().unwrap_or_else(|| self.outer_dims());
74        assert!(end >= start, "Invalid view");
75
76        let outer_inds_slice = self.indptr.outer_inds_slice(start, end);
77        let (nrows, ncols) = match self.storage() {
78            crate::CSR => ((end - start), self.ncols),
79            crate::CSC => (self.nrows, (end - start)),
80        };
81        crate::CsMatViewI {
82            nrows,
83            ncols,
84            storage: self.storage,
85            indptr: self.indptr.middle_slice_rbr(range),
86            indices: &self.indices[outer_inds_slice.clone()],
87            data: &self.data[outer_inds_slice],
88        }
89    }
90}
91
#[cfg(test)]
mod tests {
    use crate::CsMat;

    /// Slicing rows 2..7 of an identity CSR matrix keeps exactly the
    /// diagonal entries of those rows, re-indexed from row 0.
    #[test]
    fn slice_outer() {
        let size = 11;
        let csr: CsMat<f64> = CsMat::eye(size);
        let sliced = csr.slice_outer(2..7);
        let mut iter = sliced.into_iter();
        assert_eq!(iter.next().unwrap(), (&1., (0, 2)));
        assert_eq!(iter.next().unwrap(), (&1., (1, 3)));
        assert_eq!(iter.next().unwrap(), (&1., (2, 4)));
        assert_eq!(iter.next().unwrap(), (&1., (3, 5)));
        assert_eq!(iter.next().unwrap(), (&1., (4, 6)));
        assert!(iter.next().is_none());
    }

    /// Shapes and non-zero counts for CSR and CSC storage, open-ended and
    /// empty ranges.
    #[test]
    fn slice_outer_shapes() {
        let size = 11;
        let csr: CsMat<f64> = CsMat::eye(size);

        // CSR: the outer dimension is the rows.
        let sliced = csr.slice_outer(2..7);
        assert_eq!(sliced.shape(), (5, size));
        assert_eq!(sliced.nnz(), 5);

        // An open-ended range extends to the last row.
        let tail = csr.slice_outer(8..);
        assert_eq!(tail.shape(), (3, size));
        assert_eq!(tail.nnz(), 3);

        // An empty range yields an empty view.
        let empty = csr.slice_outer(4..4);
        assert_eq!(empty.shape(), (0, size));
        assert_eq!(empty.nnz(), 0);

        // CSC: the outer dimension is the columns.
        let csc: CsMat<f64> = CsMat::eye_csc(size);
        let sliced = csc.slice_outer(2..7);
        assert_eq!(sliced.shape(), (size, 5));
        assert_eq!(sliced.nnz(), 5);
    }

    /// Writes through a mutable outer slice are visible in the parent matrix
    /// and only touch the sliced rows.
    #[test]
    fn slice_outer_mut() {
        let mut csr: CsMat<f64> = CsMat::eye(5);
        {
            let mut sliced = csr.slice_outer_mut(1..3);
            assert_eq!(sliced.shape(), (2, 5));
            for val in sliced.data_mut() {
                *val = 2.;
            }
        }
        assert_eq!(csr.data(), &[1., 2., 2., 1., 1.]);
    }

    #[test]
    #[should_panic]
    fn slice_outer_out_of_bounds() {
        let csr: CsMat<f64> = CsMat::eye(4);
        let _ = csr.slice_outer(2..10);
    }

    #[test]
    #[should_panic]
    fn slice_outer_reversed_range() {
        let csr: CsMat<f64> = CsMat::eye(4);
        // Bound through variables so the reversed range is built at runtime.
        let (start, end) = (3, 1);
        let _ = csr.slice_outer(start..end);
    }
}