sfs_core/array/view/
iter.rs

1use std::iter::FusedIterator;
2
3use super::View;
4
5/// An iterator over the elements in a [`View`].
6///
7/// See [`View::iter`] for details.
8#[derive(Clone, Debug)]
9pub struct Iter<'a, T> {
10    view: View<'a, T>,
11    coords: Vec<usize>,
12    offset: usize,
13    index: usize,
14}
15
16impl<'a, T> Iter<'a, T> {
17    pub(super) fn new(view: View<'a, T>) -> Self {
18        Self {
19            view,
20            coords: vec![0; view.dimensions()],
21            offset: 0,
22            index: 0,
23        }
24    }
25
26    fn backstride(&self, axis: usize) -> usize {
27        self.view.strides[axis] * (self.view.shape[axis] - 1)
28    }
29
30    fn impl_next_rec(&mut self, axis: usize) -> Option<<Self as Iterator>::Item> {
31        if self.index == 0 {
32            self.index += 1;
33            return self.view.data.first();
34        };
35
36        self.coords[axis] += 1;
37        if self.coords[axis] < self.view.shape[axis] {
38            self.offset += self.view.strides[axis];
39            self.index += 1;
40            self.view.data.get(self.offset)
41        } else if axis > 0 {
42            self.coords[axis] = 0;
43            self.offset -= self.backstride(axis);
44            self.impl_next_rec(axis - 1)
45        } else {
46            None
47        }
48    }
49}
50
51impl<'a, T> Iterator for Iter<'a, T> {
52    type Item = &'a T;
53
54    fn next(&mut self) -> Option<Self::Item> {
55        self.impl_next_rec(self.view.dimensions() - 1)
56    }
57
58    fn size_hint(&self) -> (usize, Option<usize>) {
59        let n = self.view.shape.elements() - self.index;
60        (n, Some(n))
61    }
62}
63
64impl<'a, T> ExactSizeIterator for Iter<'a, T> {}
65
66impl<'a, T> FusedIterator for Iter<'a, T> {}
67
#[cfg(test)]
mod tests {
    use crate::{array::Axis, Array};

    // Checks that iterating `array.index_axis(Axis(axis), index)` yields exactly the expected
    // literals in order, that `len()` counts down by one per element, and that the iterator
    // reports exhaustion afterwards.
    macro_rules! assert_iter_eq {
        ($array:ident [axis: $axis:literal, index: $index:literal], [$($expected:literal),* $(,)?] $(,)?) => {
            let view = $array.index_axis(Axis($axis), $index);
            let mut iter = view.iter().copied();

            // Same repetition operator as the matcher above (`*`).
            let expected = vec![$($expected),*];
            let mut len = expected.len();
            let mut actual = Vec::new();

            for _ in 0..expected.len() {
                assert_eq!(iter.len(), len);
                len -= 1;
                actual.push(iter.next().unwrap());
            }

            assert_eq!(iter.len(), 0);
            assert!(iter.next().is_none());
            assert_eq!(actual, expected);
        };
    }

    #[test]
    fn test_iter_2x2() {
        let array = Array::from_iter(0..4, [2, 2]).unwrap();

        assert_iter_eq!(array[axis: 0, index: 0], [0, 1]);
        assert_iter_eq!(array[axis: 0, index: 1], [2, 3]);

        assert_iter_eq!(array[axis: 1, index: 0], [0, 2]);
        assert_iter_eq!(array[axis: 1, index: 1], [1, 3]);
    }

    #[test]
    fn test_iter_2x3x2() {
        let array = Array::from_iter(0..12, [2, 3, 2]).unwrap();

        assert_iter_eq!(array[axis: 0, index: 0], [0, 1, 2, 3, 4, 5]);
        assert_iter_eq!(array[axis: 0, index: 1], [6, 7, 8, 9, 10, 11]);

        assert_iter_eq!(array[axis: 1, index: 0], [0, 1, 6, 7]);
        assert_iter_eq!(array[axis: 1, index: 1], [2, 3, 8, 9]);
        assert_iter_eq!(array[axis: 1, index: 2], [4, 5, 10, 11]);

        assert_iter_eq!(array[axis: 2, index: 0], [0, 2, 4, 6, 8, 10]);
        assert_iter_eq!(array[axis: 2, index: 1], [1, 3, 5, 7, 9, 11]);
    }

    #[test]
    fn test_iter_2x1x2x3() {
        let array = Array::from_iter(0..12, [2, 1, 2, 3]).unwrap();

        assert_iter_eq!(array[axis: 0, index: 0], [0, 1, 2, 3, 4, 5]);
        assert_iter_eq!(array[axis: 0, index: 1], [6, 7, 8, 9, 10, 11]);

        assert_iter_eq!(array[axis: 1, index: 0], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]);

        assert_iter_eq!(array[axis: 2, index: 0], [0, 1, 2, 6, 7, 8]);
        assert_iter_eq!(array[axis: 2, index: 1], [3, 4, 5, 9, 10, 11]);

        assert_iter_eq!(array[axis: 3, index: 0], [0, 3, 6, 9]);
        assert_iter_eq!(array[axis: 3, index: 1], [1, 4, 7, 10]);
        assert_iter_eq!(array[axis: 3, index: 2], [2, 5, 8, 11]);
    }

    // Calls `next` twice more after the single element has been yielded and expects `None`
    // both times.
    #[test]
    fn test_iter_fused() {
        let array = Array::new([0.0, 1.0], [2, 1]).unwrap();
        let view = array.get_axis(Axis(0), 0).unwrap();
        let mut iter = view.iter();

        assert_eq!(iter.next(), Some(&0.0));
        assert_eq!(iter.next(), None);
        assert_eq!(iter.next(), None);
    }
}