1use super::{ConstShape, Shape};
8
9#[derive(Clone, Debug)]
11pub struct Indices<S: Shape> {
12 n: usize,
13 i: usize,
14 rev_i: usize,
15 shape: S,
16}
17
18impl<S: Shape> Indices<S> {
19 pub fn from_shape(shape: S) -> Self {
21 let n = shape.iter().product::<usize>();
22
23 Self {
24 n,
25 i: 0,
26 rev_i: n,
27 shape,
28 }
29 }
30}
31
32impl<const D: usize> Iterator for Indices<ConstShape<D>> {
33 type Item = [usize; D];
34
35 fn next(&mut self) -> Option<Self::Item> {
36 if self.i < self.rev_i {
37 let idx = compute_index_unchecked(self.i, self.n, self.shape);
38 self.i += 1;
39 Some(idx)
40 } else {
41 None
42 }
43 }
44
45 fn size_hint(&self) -> (usize, Option<usize>) {
46 let len = self.rev_i - self.i;
47 (len, Some(len))
48 }
49}
50
51impl<const D: usize> DoubleEndedIterator for Indices<ConstShape<D>> {
52 fn next_back(&mut self) -> Option<Self::Item> {
53 if self.i < self.rev_i {
54 self.rev_i -= 1;
55 let idx = compute_index_unchecked(self.rev_i, self.n, self.shape);
56 Some(idx)
57 } else {
58 None
59 }
60 }
61}
62
63impl<S: Shape> ExactSizeIterator for Indices<S> where Indices<S>: Iterator {}
64
65impl<S: Shape> std::iter::FusedIterator for Indices<S> where Indices<S>: Iterator {}
66
/// Decomposes the flat offset `flat` into one coordinate per axis of `shape`.
///
/// `n` is the running stride: it is divided by each extent in turn, the
/// quotient of `flat` by that stride becomes the coordinate for the axis, and
/// the remainder is carried on to the next axis. The first axis is therefore
/// the most significant one.
///
/// The arguments are not validated. Callers are expected to pass `n` equal to
/// the product of `shape`; a stride that reaches zero (for example because of
/// a zero extent) results in a division-by-zero panic.
fn compute_index_unchecked<const D: usize>(
    mut flat: usize,
    mut n: usize,
    shape: [usize; D],
) -> [usize; D] {
    let mut coords = [0; D];
    for (coord, &extent) in coords.iter_mut().zip(shape.iter()) {
        n /= extent;
        *coord = flat / n;
        flat %= n;
    }
    coords
}
80
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks the flat-offset decomposition for one, two and three axes.
    #[test]
    fn test_compute_index() {
        let one_axis = compute_index_unchecked(3, 4, [4]);
        assert_eq!(one_axis, [3]);

        let two_axes = compute_index_unchecked(16, 28, [4, 7]);
        assert_eq!(two_axes, [2, 2]);

        let three_axes = compute_index_unchecked(3, 6, [1, 3, 2]);
        assert_eq!(three_axes, [0, 1, 1]);
    }

    /// Forward iteration over a single axis, checking `len()` along the way.
    #[test]
    fn test_indices_1d() {
        let mut indices = Indices::from_shape([4]);
        assert_eq!(indices.len(), 4);

        for expected in &[[0], [1]] {
            assert_eq!(indices.next(), Some(*expected));
        }
        assert_eq!(indices.len(), 2);

        for expected in &[[2], [3]] {
            assert_eq!(indices.next(), Some(*expected));
        }
        assert_eq!(indices.len(), 0);
        assert!(indices.next().is_none());
    }

    /// Forward iteration over two axes: the last axis varies fastest.
    #[test]
    fn test_indices_2d() {
        let mut indices = Indices::from_shape([2, 3]);
        assert_eq!(indices.len(), 6);

        for expected in &[[0, 0], [0, 1], [0, 2]] {
            assert_eq!(indices.next(), Some(*expected));
        }
        assert_eq!(indices.len(), 3);

        for expected in &[[1, 0], [1, 1], [1, 2]] {
            assert_eq!(indices.next(), Some(*expected));
        }
        assert_eq!(indices.len(), 0);
        assert!(indices.next().is_none());
    }

    /// Interleaves front and back consumption until the cursors meet.
    #[test]
    fn test_indices_2d_mixed_direction() {
        let mut indices = Indices::from_shape([2, 3]);
        assert_eq!(indices.len(), 6);

        let front = indices.next();
        assert_eq!(front, Some([0, 0]));
        let back = indices.next_back();
        assert_eq!(back, Some([1, 2]));
        let back = indices.next_back();
        assert_eq!(back, Some([1, 1]));
        assert_eq!(indices.len(), 3);

        let front = indices.next();
        assert_eq!(front, Some([0, 1]));
        let back = indices.next_back();
        assert_eq!(back, Some([1, 0]));
        let front = indices.next();
        assert_eq!(front, Some([0, 2]));
        assert_eq!(indices.len(), 0);
        assert!(indices.next().is_none());
    }

    /// Forward iteration over three axes, one of which has extent one.
    #[test]
    fn test_indices_3d() {
        let mut indices = Indices::from_shape([2, 1, 3]);
        assert_eq!(indices.len(), 6);

        for expected in &[[0, 0, 0], [0, 0, 1], [0, 0, 2]] {
            assert_eq!(indices.next(), Some(*expected));
        }
        assert_eq!(indices.len(), 3);

        for expected in &[[1, 0, 0], [1, 0, 1], [1, 0, 2]] {
            assert_eq!(indices.next(), Some(*expected));
        }
        assert_eq!(indices.len(), 0);
        assert!(indices.next().is_none());
    }

    /// The same three-axis shape consumed through `rev()`.
    #[test]
    fn test_indices_3d_rev() {
        let mut reversed = Indices::from_shape([2, 1, 3]).rev();
        assert_eq!(reversed.len(), 6);

        for expected in &[[1, 0, 2], [1, 0, 1], [1, 0, 0]] {
            assert_eq!(reversed.next(), Some(*expected));
        }
        assert_eq!(reversed.len(), 3);

        for expected in &[[0, 0, 2], [0, 0, 1], [0, 0, 0]] {
            assert_eq!(reversed.next(), Some(*expected));
        }
        assert_eq!(reversed.len(), 0);
        assert!(reversed.next().is_none());
    }
}