1use std::iter::FusedIterator;
7
8use super::{Array, Axis, Shape, View};
9
10#[derive(Debug)]
14pub struct AxisIter<'a, T> {
15 array: &'a Array<T>,
16 axis: Axis,
17 index: usize,
18}
19
20impl<'a, T> AxisIter<'a, T> {
21 pub(super) fn new(array: &'a Array<T>, axis: Axis) -> Self {
22 Self {
23 array,
24 axis,
25 index: 0,
26 }
27 }
28}
29
30impl<'a, T> Iterator for AxisIter<'a, T> {
31 type Item = View<'a, T>;
32
33 fn next(&mut self) -> Option<Self::Item> {
34 let view = self.array.get_axis(self.axis, self.index)?;
35 self.index += 1;
36 Some(view)
37 }
38
39 fn size_hint(&self) -> (usize, Option<usize>) {
40 let n = self.array.shape[self.axis.0];
41 (n, Some(n))
42 }
43}
44
45impl<'a, T> ExactSizeIterator for AxisIter<'a, T> {}
46
47impl<'a, T> FusedIterator for AxisIter<'a, T> {}
48
49#[derive(Debug)]
53pub struct IndicesIter<'a> {
54 shape: &'a Shape,
55 index: usize,
56 total: usize,
57}
58
59impl<'a> IndicesIter<'a> {
60 pub(crate) fn shape(&self) -> &'a Shape {
61 self.shape
62 }
63
64 pub(crate) fn new<T>(array: &'a Array<T>) -> Self {
65 Self::from_shape(array.shape())
66 }
67
68 pub(crate) fn from_shape(shape: &'a Shape) -> Self {
69 Self {
70 shape,
71 index: 0,
72 total: shape.elements(),
73 }
74 }
75}
76
77impl<'a> Iterator for IndicesIter<'a> {
78 type Item = Vec<usize>;
79
80 fn next(&mut self) -> Option<Self::Item> {
81 (self.index < self.total).then(|| {
82 self.index += 1;
83 self.shape.index_from_flat_unchecked(self.index - 1)
84 })
85 }
86
87 fn size_hint(&self) -> (usize, Option<usize>) {
88 let len = self.total - self.index;
89 (len, Some(len))
90 }
91}
92
93impl<'a> ExactSizeIterator for IndicesIter<'a> {}
94
95impl<'a> FusedIterator for IndicesIter<'a> {}
96
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that the next items of `iter` are exactly `expected`, in order.
    fn assert_next_indices<I>(iter: &mut I, expected: &[&[usize]])
    where
        I: Iterator<Item = Vec<usize>>,
    {
        for &index in expected {
            assert_eq!(iter.next(), Some(index.to_vec()));
        }
    }

    #[test]
    fn test_iter_indices_1d() {
        let array = Array::from_zeros(4);
        let mut iter = array.iter_indices();

        assert_eq!(iter.len(), 4);
        assert_next_indices(&mut iter, &[&[0], &[1]]);

        assert_eq!(iter.len(), 2);
        assert_next_indices(&mut iter, &[&[2], &[3]]);

        assert_eq!(iter.len(), 0);
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_iter_indices_2d() {
        let array = Array::from_zeros([2, 3]);
        let mut iter = array.iter_indices();

        assert_eq!(iter.len(), 6);
        assert_next_indices(&mut iter, &[&[0, 0], &[0, 1], &[0, 2]]);

        assert_eq!(iter.len(), 3);
        assert_next_indices(&mut iter, &[&[1, 0], &[1, 1], &[1, 2]]);

        assert_eq!(iter.len(), 0);
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_iter_indices_3d() {
        let array = Array::from_zeros([2, 1, 3]);
        let mut iter = array.iter_indices();

        assert_next_indices(&mut iter, &[&[0, 0, 0], &[0, 0, 1], &[0, 0, 2]]);

        assert_eq!(iter.len(), 3);
        assert_next_indices(&mut iter, &[&[1, 0, 0], &[1, 0, 1], &[1, 0, 2]]);

        assert_eq!(iter.len(), 0);
        assert!(iter.next().is_none());
    }
}