sfs_core/array/
iter.rs

1//! Array iterators.
2//!
3//! The types in this module are constructed via methods on [`Array`],
4//! and generally expose no functionality other than being iterable.
5
6use std::iter::FusedIterator;
7
8use super::{Array, Axis, Shape, View};
9
10/// An iterator over [`View`]s along an axis of an [`Array`].
11///
12/// See [`Array::iter_axis`] for details.
13#[derive(Debug)]
14pub struct AxisIter<'a, T> {
15    array: &'a Array<T>,
16    axis: Axis,
17    index: usize,
18}
19
20impl<'a, T> AxisIter<'a, T> {
21    pub(super) fn new(array: &'a Array<T>, axis: Axis) -> Self {
22        Self {
23            array,
24            axis,
25            index: 0,
26        }
27    }
28}
29
30impl<'a, T> Iterator for AxisIter<'a, T> {
31    type Item = View<'a, T>;
32
33    fn next(&mut self) -> Option<Self::Item> {
34        let view = self.array.get_axis(self.axis, self.index)?;
35        self.index += 1;
36        Some(view)
37    }
38
39    fn size_hint(&self) -> (usize, Option<usize>) {
40        let n = self.array.shape[self.axis.0];
41        (n, Some(n))
42    }
43}
44
45impl<'a, T> ExactSizeIterator for AxisIter<'a, T> {}
46
47impl<'a, T> FusedIterator for AxisIter<'a, T> {}
48
49/// An iterator over indices of elements in an array in row-major order.
50///
51/// See [`Array::iter_indices`] for details.
52#[derive(Debug)]
53pub struct IndicesIter<'a> {
54    shape: &'a Shape,
55    index: usize,
56    total: usize,
57}
58
59impl<'a> IndicesIter<'a> {
60    pub(crate) fn shape(&self) -> &'a Shape {
61        self.shape
62    }
63
64    pub(crate) fn new<T>(array: &'a Array<T>) -> Self {
65        Self::from_shape(array.shape())
66    }
67
68    pub(crate) fn from_shape(shape: &'a Shape) -> Self {
69        Self {
70            shape,
71            index: 0,
72            total: shape.elements(),
73        }
74    }
75}
76
77impl<'a> Iterator for IndicesIter<'a> {
78    type Item = Vec<usize>;
79
80    fn next(&mut self) -> Option<Self::Item> {
81        (self.index < self.total).then(|| {
82            self.index += 1;
83            self.shape.index_from_flat_unchecked(self.index - 1)
84        })
85    }
86
87    fn size_hint(&self) -> (usize, Option<usize>) {
88        let len = self.total - self.index;
89        (len, Some(len))
90    }
91}
92
93impl<'a> ExactSizeIterator for IndicesIter<'a> {}
94
95impl<'a> FusedIterator for IndicesIter<'a> {}
96
#[cfg(test)]
mod tests {
    use super::*;

    /// Indices of a 1-d array come out as single-element vectors, in order.
    #[test]
    fn test_iter_indices_1d() {
        let array = Array::from_zeros(4);
        let mut indices = array.iter_indices();

        assert_eq!(indices.len(), 4);

        for i in 0..2 {
            assert_eq!(indices.next(), Some(vec![i]));
        }

        assert_eq!(indices.len(), 2);

        for i in 2..4 {
            assert_eq!(indices.next(), Some(vec![i]));
        }

        assert_eq!(indices.len(), 0);
        assert_eq!(indices.next(), None);
    }

    /// Indices of a 2-d array vary fastest in the last dimension.
    #[test]
    fn test_iter_indices_2d() {
        let array = Array::from_zeros([2, 3]);
        let mut indices = array.iter_indices();

        assert_eq!(indices.len(), 6);

        for col in 0..3 {
            assert_eq!(indices.next(), Some(vec![0, col]));
        }

        assert_eq!(indices.len(), 3);

        for col in 0..3 {
            assert_eq!(indices.next(), Some(vec![1, col]));
        }

        assert_eq!(indices.len(), 0);
        assert_eq!(indices.next(), None);
    }

    /// A dimension of length one contributes a constant zero to every index.
    #[test]
    fn test_iter_indices_3d() {
        let array = Array::from_zeros([2, 1, 3]);
        let mut indices = array.iter_indices();

        for last in 0..3 {
            assert_eq!(indices.next(), Some(vec![0, 0, last]));
        }

        assert_eq!(indices.len(), 3);

        for last in 0..3 {
            assert_eq!(indices.next(), Some(vec![1, 0, last]));
        }

        assert_eq!(indices.len(), 0);
        assert_eq!(indices.next(), None);
    }
}