1use std::iter::FusedIterator;
2
3use super::View;
4
5#[derive(Clone, Debug)]
9pub struct Iter<'a, T> {
10 view: View<'a, T>,
11 coords: Vec<usize>,
12 offset: usize,
13 index: usize,
14}
15
16impl<'a, T> Iter<'a, T> {
17 pub(super) fn new(view: View<'a, T>) -> Self {
18 Self {
19 view,
20 coords: vec![0; view.dimensions()],
21 offset: 0,
22 index: 0,
23 }
24 }
25
26 fn backstride(&self, axis: usize) -> usize {
27 self.view.strides[axis] * (self.view.shape[axis] - 1)
28 }
29
30 fn impl_next_rec(&mut self, axis: usize) -> Option<<Self as Iterator>::Item> {
31 if self.index == 0 {
32 self.index += 1;
33 return self.view.data.first();
34 };
35
36 self.coords[axis] += 1;
37 if self.coords[axis] < self.view.shape[axis] {
38 self.offset += self.view.strides[axis];
39 self.index += 1;
40 self.view.data.get(self.offset)
41 } else if axis > 0 {
42 self.coords[axis] = 0;
43 self.offset -= self.backstride(axis);
44 self.impl_next_rec(axis - 1)
45 } else {
46 None
47 }
48 }
49}
50
51impl<'a, T> Iterator for Iter<'a, T> {
52 type Item = &'a T;
53
54 fn next(&mut self) -> Option<Self::Item> {
55 self.impl_next_rec(self.view.dimensions() - 1)
56 }
57
58 fn size_hint(&self) -> (usize, Option<usize>) {
59 let n = self.view.shape.elements() - self.index;
60 (n, Some(n))
61 }
62}
63
64impl<'a, T> ExactSizeIterator for Iter<'a, T> {}
65
66impl<'a, T> FusedIterator for Iter<'a, T> {}
67
#[cfg(test)]
mod tests {
    use crate::{array::Axis, Array};

    // Checks that iterating over `$array.index_axis(Axis($axis), $index)`
    // yields exactly the listed literals, in order. Along the way it verifies
    // that `len()` counts down by one per element, reaches zero, and that the
    // iterator then returns `None`.
    //
    // The expansion is wrapped in a block so its local bindings stay private
    // to each invocation.
    macro_rules! assert_iter_eq {
        ($array:ident [axis: $axis:literal, index: $index:literal], [$($expected:literal),* $(,)?] $(,)?) => {{
            let view = $array.index_axis(Axis($axis), $index);
            let mut iter = view.iter().copied();

            let expected = vec![$($expected),*];
            let mut actual = Vec::new();

            // `remaining` is the length the iterator must report before each
            // element is pulled.
            for remaining in (1..=expected.len()).rev() {
                assert_eq!(iter.len(), remaining);
                actual.push(iter.next().unwrap());
            }

            assert_eq!(iter.len(), 0);
            assert!(iter.next().is_none());
            assert_eq!(actual, expected);
        }};
    }

    #[test]
    fn test_iter_2x2() {
        let array = Array::from_iter(0..4, [2, 2]).unwrap();

        assert_iter_eq!(array[axis: 0, index: 0], [0, 1]);
        assert_iter_eq!(array[axis: 0, index: 1], [2, 3]);

        assert_iter_eq!(array[axis: 1, index: 0], [0, 2]);
        assert_iter_eq!(array[axis: 1, index: 1], [1, 3]);
    }

    #[test]
    fn test_iter_2x3x2() {
        let array = Array::from_iter(0..12, [2, 3, 2]).unwrap();

        assert_iter_eq!(array[axis: 0, index: 0], [0, 1, 2, 3, 4, 5]);
        assert_iter_eq!(array[axis: 0, index: 1], [6, 7, 8, 9, 10, 11]);

        assert_iter_eq!(array[axis: 1, index: 0], [0, 1, 6, 7]);
        assert_iter_eq!(array[axis: 1, index: 1], [2, 3, 8, 9]);
        assert_iter_eq!(array[axis: 1, index: 2], [4, 5, 10, 11]);

        assert_iter_eq!(array[axis: 2, index: 0], [0, 2, 4, 6, 8, 10]);
        assert_iter_eq!(array[axis: 2, index: 1], [1, 3, 5, 7, 9, 11]);
    }

    #[test]
    fn test_iter_2x1x2x3() {
        let array = Array::from_iter(0..12, [2, 1, 2, 3]).unwrap();

        assert_iter_eq!(array[axis: 0, index: 0], [0, 1, 2, 3, 4, 5]);
        assert_iter_eq!(array[axis: 0, index: 1], [6, 7, 8, 9, 10, 11]);

        assert_iter_eq!(array[axis: 1, index: 0], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]);

        assert_iter_eq!(array[axis: 2, index: 0], [0, 1, 2, 6, 7, 8]);
        assert_iter_eq!(array[axis: 2, index: 1], [3, 4, 5, 9, 10, 11]);

        assert_iter_eq!(array[axis: 3, index: 0], [0, 3, 6, 9]);
        assert_iter_eq!(array[axis: 3, index: 1], [1, 4, 7, 10]);
        assert_iter_eq!(array[axis: 3, index: 2], [2, 5, 8, 11]);
    }

    // Repeated calls after exhaustion must keep returning `None`.
    #[test]
    fn test_iter_fused() {
        let array = Array::new([0.0, 1.0], [2, 1]).unwrap();
        let view = array.get_axis(Axis(0), 0).unwrap();
        let mut iter = view.iter();

        assert_eq!(iter.next(), Some(&0.0));
        assert_eq!(iter.next(), None);
        assert_eq!(iter.next(), None);
    }
}