zenu_matrix/slice/
dynamic.rs

1use super::slice_dim::SliceDim;
2use crate::{dim::DimDyn, index::SliceTrait, shape_stride::ShapeStride};
3
4#[derive(Clone, Debug, Copy, PartialEq)]
5pub struct Slice {
6    pub index: [SliceDim; 6],
7    pub len: usize,
8}
9
10impl SliceTrait for Slice {
11    type Dim = DimDyn;
12
13    fn sliced_shape_stride(&self, shape: Self::Dim, stride: Self::Dim) -> ShapeStride<Self::Dim> {
14        let mut new_shape = DimDyn::default();
15        let mut new_stride = DimDyn::default();
16
17        for i in 0..self.len {
18            match self.index[i].new_dim(shape[i]) {
19                0 => continue,
20                new_dim => {
21                    new_shape.push_dim(new_dim);
22                    new_stride.push_dim(self.index[i].new_stride(stride[i]));
23                }
24            }
25        }
26
27        ShapeStride::new(new_shape, new_stride)
28    }
29
30    fn sliced_offset(&self, stride: Self::Dim) -> usize {
31        let mut offset = 0;
32
33        for i in 0..self.len {
34            let start = self.index[i].start.unwrap_or(0);
35            offset += start * stride[i];
36        }
37
38        // offset + original_offset
39        offset
40    }
41}
42
43impl From<&[SliceDim]> for Slice {
44    fn from(s: &[SliceDim]) -> Self {
45        if s.len() > 6 {
46            panic!("too many slice dimensions");
47        } else if s.len() == 1 {
48            Slice {
49                index: [
50                    s[0],
51                    SliceDim::default(),
52                    SliceDim::default(),
53                    SliceDim::default(),
54                    SliceDim::default(),
55                    SliceDim::default(),
56                ],
57                len: 1,
58            }
59        } else if s.len() == 2 {
60            Slice {
61                index: [
62                    s[0],
63                    s[1],
64                    SliceDim::default(),
65                    SliceDim::default(),
66                    SliceDim::default(),
67                    SliceDim::default(),
68                ],
69                len: 2,
70            }
71        } else if s.len() == 3 {
72            Slice {
73                index: [
74                    s[0],
75                    s[1],
76                    s[2],
77                    SliceDim::default(),
78                    SliceDim::default(),
79                    SliceDim::default(),
80                ],
81                len: 3,
82            }
83        } else if s.len() == 4 {
84            Slice {
85                index: [
86                    s[0],
87                    s[1],
88                    s[2],
89                    s[3],
90                    SliceDim::default(),
91                    SliceDim::default(),
92                ],
93                len: 4,
94            }
95        } else if s.len() == 5 {
96            Slice {
97                index: [s[0], s[1], s[2], s[3], s[4], SliceDim::default()],
98                len: 5,
99            }
100        } else {
101            Slice {
102                index: [s[0], s[1], s[2], s[3], s[4], s[5]],
103                len: 6,
104            }
105        }
106    }
107}
108
#[cfg(test)]
mod slice_dyn_slice {
    use crate::{dim::DimDyn, index::SliceTrait, slice_dynamic};

    /// Slicing a `[2, 3, 4]` view (strides `[12, 4, 1]`) with `(.., 1, 1..2)`
    /// keeps axis 0 as-is, removes the integer-indexed axis 1, and narrows
    /// axis 2 to one element, leaving shape `[2, 1]` with strides `[12, 1]`.
    #[test]
    fn dyn_slice() {
        let sliced = slice_dynamic!(.., 1, 1..2)
            .sliced_shape_stride(DimDyn::new(&[2, 3, 4]), DimDyn::new(&[12, 4, 1]));

        assert_eq!(sliced.shape(), DimDyn::new(&[2, 1]));
        assert_eq!(sliced.stride(), DimDyn::new(&[12, 1]));
    }
}