// zenu_matrix/slice/dynamic.rs
use super::slice_dim::SliceDim;
use crate::{dim::DimDyn, index::SliceTrait, shape_stride::ShapeStride};
3
4#[derive(Clone, Debug, Copy, PartialEq)]
5pub struct Slice {
6 pub index: [SliceDim; 6],
7 pub len: usize,
8}
9
10impl SliceTrait for Slice {
11 type Dim = DimDyn;
12
13 fn sliced_shape_stride(&self, shape: Self::Dim, stride: Self::Dim) -> ShapeStride<Self::Dim> {
14 let mut new_shape = DimDyn::default();
15 let mut new_stride = DimDyn::default();
16
17 for i in 0..self.len {
18 match self.index[i].new_dim(shape[i]) {
19 0 => continue,
20 new_dim => {
21 new_shape.push_dim(new_dim);
22 new_stride.push_dim(self.index[i].new_stride(stride[i]));
23 }
24 }
25 }
26
27 ShapeStride::new(new_shape, new_stride)
28 }
29
30 fn sliced_offset(&self, stride: Self::Dim) -> usize {
31 let mut offset = 0;
32
33 for i in 0..self.len {
34 let start = self.index[i].start.unwrap_or(0);
35 offset += start * stride[i];
36 }
37
38 offset
40 }
41}
42
43impl From<&[SliceDim]> for Slice {
44 fn from(s: &[SliceDim]) -> Self {
45 if s.len() > 6 {
46 panic!("too many slice dimensions");
47 } else if s.len() == 1 {
48 Slice {
49 index: [
50 s[0],
51 SliceDim::default(),
52 SliceDim::default(),
53 SliceDim::default(),
54 SliceDim::default(),
55 SliceDim::default(),
56 ],
57 len: 1,
58 }
59 } else if s.len() == 2 {
60 Slice {
61 index: [
62 s[0],
63 s[1],
64 SliceDim::default(),
65 SliceDim::default(),
66 SliceDim::default(),
67 SliceDim::default(),
68 ],
69 len: 2,
70 }
71 } else if s.len() == 3 {
72 Slice {
73 index: [
74 s[0],
75 s[1],
76 s[2],
77 SliceDim::default(),
78 SliceDim::default(),
79 SliceDim::default(),
80 ],
81 len: 3,
82 }
83 } else if s.len() == 4 {
84 Slice {
85 index: [
86 s[0],
87 s[1],
88 s[2],
89 s[3],
90 SliceDim::default(),
91 SliceDim::default(),
92 ],
93 len: 4,
94 }
95 } else if s.len() == 5 {
96 Slice {
97 index: [s[0], s[1], s[2], s[3], s[4], SliceDim::default()],
98 len: 5,
99 }
100 } else {
101 Slice {
102 index: [s[0], s[1], s[2], s[3], s[4], s[5]],
103 len: 6,
104 }
105 }
106 }
107}
108
#[cfg(test)]
mod slice_dyn_slice {
    use crate::{dim::DimDyn, index::SliceTrait, slice_dynamic};

    /// Applies `(.., 1, 1..2)` to shape `[2, 3, 4]` with stride `[12, 4, 1]`
    /// and expects shape `[2, 1]` with stride `[12, 1]`.
    #[test]
    fn dyn_slice() {
        let spec = slice_dynamic!(.., 1, 1..2);
        let sliced = spec.sliced_shape_stride(DimDyn::new(&[2, 3, 4]), DimDyn::new(&[12, 4, 1]));

        let got_shape = sliced.shape();
        let got_stride = sliced.stride();

        assert_eq!(got_shape, DimDyn::new(&[2, 1]));
        assert_eq!(got_stride, DimDyn::new(&[12, 1]));
    }
}