// zenu_matrix/slice/slice_dim.rs

1use std::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};
2
/// One axis of a slice: an optional half-open range `start..end` plus an
/// optional step.
///
/// Missing values fall back to defaults when the slice is applied to an axis:
/// `start` defaults to `0`, `end` to the axis length, and `step` to `1`.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct SliceDim {
    /// First index included; `None` means `0`.
    pub(crate) start: Option<usize>,
    /// Exclusive end index; `None` means the axis length. Values past the axis
    /// length are clamped when the resulting length is computed.
    pub(crate) end: Option<usize>,
    /// Distance between selected indices; `None` means `1`. A step of `0` is
    /// rejected when the slice is applied.
    pub(crate) step: Option<usize>,
}
9
10impl SliceDim {
11    #[must_use]
12    pub fn step(self, step: usize) -> Self {
13        Self {
14            start: self.start,
15            end: self.end,
16            step: Some(step),
17        }
18    }
19
20    fn validate(&self, dim: usize) -> bool {
21        let start = self.start.unwrap_or(0);
22        let end = self.end.unwrap_or(dim - 1);
23        let step = self.step.unwrap_or(1);
24
25        if start > end {
26            return false;
27        }
28
29        if start > dim {
30            return false;
31        }
32
33        if step == 0 {
34            return false;
35        }
36
37        true
38    }
39
40    fn new_dim_unchanged(&self, dim: usize) -> usize {
41        let start = self.start.unwrap_or(0);
42        let mut end = self.end.unwrap_or(dim);
43        let step = self.step.unwrap_or(1);
44
45        if end > dim {
46            end = dim;
47        }
48
49        (end - start + step - 1) / step
50    }
51
52    pub(super) fn new_dim(&self, dim: usize) -> usize {
53        if self.validate(dim) {
54            return self.new_dim_unchanged(dim);
55        }
56        panic!("invalid slice");
57    }
58
59    pub(super) fn new_stride(&self, stride: usize) -> usize {
60        let step = self.step.unwrap_or(1);
61        stride * step
62    }
63}
64
65impl From<Range<usize>> for SliceDim {
66    fn from(range: Range<usize>) -> Self {
67        SliceDim {
68            start: Some(range.start),
69            end: Some(range.end),
70            step: None,
71        }
72    }
73}
74
75impl From<RangeFull> for SliceDim {
76    fn from(_: RangeFull) -> Self {
77        SliceDim {
78            start: None,
79            end: None,
80            step: None,
81        }
82    }
83}
84
85impl From<RangeTo<usize>> for SliceDim {
86    fn from(range: RangeTo<usize>) -> Self {
87        SliceDim {
88            start: None,
89            end: Some(range.end),
90            step: None,
91        }
92    }
93}
94
95impl From<RangeFrom<usize>> for SliceDim {
96    fn from(range: RangeFrom<usize>) -> Self {
97        SliceDim {
98            start: Some(range.start),
99            end: None,
100            step: None,
101        }
102    }
103}
104
105impl From<RangeInclusive<usize>> for SliceDim {
106    fn from(range: RangeInclusive<usize>) -> Self {
107        SliceDim {
108            start: Some(*range.start()),
109            end: Some(*range.end()),
110            step: None,
111        }
112    }
113}
114
115impl From<RangeToInclusive<usize>> for SliceDim {
116    fn from(range: RangeToInclusive<usize>) -> Self {
117        SliceDim {
118            start: None,
119            end: Some(range.end + 1),
120            step: None,
121        }
122    }
123}
124
125impl From<usize> for SliceDim {
126    fn from(index: usize) -> Self {
127        SliceDim {
128            start: Some(index),
129            end: Some(index),
130            step: None,
131        }
132    }
133}
134
#[test]
fn slice_index() {
    // `0..10` on a 20-long axis keeps ten elements at the original stride.
    let first_ten = SliceDim {
        start: Some(0),
        end: Some(10),
        step: None,
    };

    assert_eq!(first_ten.new_dim(20), 10);
    assert_eq!(first_ten.new_stride(1), 1);
}
149
#[test]
fn slice_index_with_stride() {
    // Every other element of `0..10`: five elements, stride doubled.
    let every_other = SliceDim {
        start: Some(0),
        end: Some(10),
        step: Some(2),
    };

    assert_eq!(every_other.new_dim(20), 5);
    assert_eq!(every_other.new_stride(1), 2);
}
164
#[test]
fn slice_dim_full_range() {
    // No bounds and no step: the whole axis is kept unchanged.
    let whole_axis = SliceDim {
        start: None,
        end: None,
        step: None,
    };

    assert_eq!(whole_axis.new_dim(20), 20);
    assert_eq!(whole_axis.new_stride(1), 1);
}