zarrs/array_subset/iterators/
indices_iterator.rs

1use std::iter::FusedIterator;
2
3use crate::{
4    array::{unravel_index, ArrayIndices},
5    array_subset::ArraySubset,
6};
7
8use rayon::iter::{
9    plumbing::{bridge, Consumer, Producer, ProducerCallback, UnindexedConsumer},
10    IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator,
11};
12
13/// An iterator over the indices in an array subset.
14///
15/// Iterates over the last dimension fastest (i.e. C-contiguous order).
16/// For example, consider a 4x3 array with element indices
17/// ```text
18/// (0, 0)  (0, 1)  (0, 2)
19/// (1, 0)  (1, 1)  (1, 2)
20/// (2, 0)  (2, 1)  (2, 2)
21/// (3, 0)  (3, 1)  (3, 2)
22/// ```
23/// An iterator with an array subset corresponding to the lower right 2x2 region will produce `[(2, 1), (2, 2), (3, 1), (3, 2)]`.
24#[derive(Clone)]
25pub struct Indices {
26    pub(crate) subset: ArraySubset,
27    pub(crate) range: std::ops::Range<usize>,
28}
29
30impl Indices {
31    /// Create a new indices struct.
32    #[must_use]
33    pub fn new(subset: ArraySubset) -> Self {
34        let length = subset.num_elements_usize();
35        Self {
36            subset,
37            range: 0..length,
38        }
39    }
40
41    /// Create a new indices struct spanning `range`.
42    #[must_use]
43    pub fn new_with_start_end(
44        subset: ArraySubset,
45        range: impl std::ops::RangeBounds<usize>,
46    ) -> Self {
47        let length = subset.num_elements_usize();
48        let start = match range.start_bound() {
49            std::ops::Bound::Included(start) => *start,
50            std::ops::Bound::Excluded(start) => start.saturating_add(1),
51            std::ops::Bound::Unbounded => 0,
52        };
53        let end = match range.end_bound() {
54            std::ops::Bound::Excluded(end) => (*end).min(length),
55            std::ops::Bound::Included(end) => end.saturating_add(1).min(length),
56            std::ops::Bound::Unbounded => length,
57        };
58        Self {
59            subset,
60            range: start..end,
61        }
62    }
63
64    /// Return the number of indices.
65    #[must_use]
66    pub fn len(&self) -> usize {
67        self.range.end.saturating_sub(self.range.start)
68    }
69
70    /// Returns true if the number of indices is zero.
71    #[must_use]
72    pub fn is_empty(&self) -> bool {
73        self.len() == 0
74    }
75
76    /// Create a new serial iterator.
77    #[must_use]
78    pub fn iter(&self) -> IndicesIterator<'_> {
79        <&Self as IntoIterator>::into_iter(self)
80    }
81}
82
83impl<'a> IntoIterator for &'a Indices {
84    type Item = ArrayIndices;
85    type IntoIter = IndicesIterator<'a>;
86
87    fn into_iter(self) -> Self::IntoIter {
88        IndicesIterator {
89            subset: &self.subset,
90            range: self.range.clone(),
91        }
92    }
93}
94
95impl<'a> IntoParallelRefIterator<'a> for &'a Indices {
96    type Item = ArrayIndices;
97    type Iter = ParIndicesIterator<'a>;
98
99    fn par_iter(&self) -> Self::Iter {
100        ParIndicesIterator {
101            subset: &self.subset,
102            range: self.range.clone(),
103        }
104    }
105}
106
107impl<'a> IntoParallelIterator for &'a Indices {
108    type Item = ArrayIndices;
109    type Iter = ParIndicesIterator<'a>;
110
111    fn into_par_iter(self) -> Self::Iter {
112        ParIndicesIterator {
113            subset: &self.subset,
114            range: self.range.clone(),
115        }
116    }
117}
118
119impl IntoIterator for Indices {
120    type Item = ArrayIndices;
121    type IntoIter = IndicesIntoIterator;
122
123    fn into_iter(self) -> Self::IntoIter {
124        IndicesIntoIterator {
125            subset: self.subset,
126            range: self.range,
127        }
128    }
129}
130
131impl IntoParallelIterator for Indices {
132    type Item = ArrayIndices;
133    type Iter = ParIndicesIntoIterator;
134
135    fn into_par_iter(self) -> Self::Iter {
136        ParIndicesIntoIterator {
137            subset: self.subset,
138            range: self.range,
139        }
140    }
141}
142
143/// Serial indices iterator.
144///
145/// See [`Indices`].
146#[derive(Clone)]
147pub struct IndicesIterator<'a> {
148    pub(crate) subset: &'a ArraySubset,
149    pub(crate) range: std::ops::Range<usize>,
150}
151
152/// Serial indices iterator.
153///
154/// See [`Indices`].
155#[derive(Clone)]
156pub struct IndicesIntoIterator {
157    pub(crate) subset: ArraySubset,
158    pub(crate) range: std::ops::Range<usize>,
159}
160
161macro_rules! impl_indices_iterator {
162    ($iterator_type:ty) => {
163        impl Iterator for $iterator_type {
164            type Item = ArrayIndices;
165
166            fn next(&mut self) -> Option<Self::Item> {
167                if self.range.start >= self.range.end {
168                    return None;
169                }
170                let mut indices = unravel_index(self.range.start as u64, self.subset.shape())?;
171                std::iter::zip(indices.iter_mut(), self.subset.start())
172                    .for_each(|(index, start)| *index += start);
173
174                if self.range.start < self.range.end {
175                    self.range.start += 1;
176                    Some(indices)
177                } else {
178                    None
179                }
180            }
181
182            fn size_hint(&self) -> (usize, Option<usize>) {
183                let length = self.range.end.saturating_sub(self.range.start);
184                (length, Some(length))
185            }
186        }
187
188        impl DoubleEndedIterator for $iterator_type {
189            fn next_back(&mut self) -> Option<Self::Item> {
190                if self.range.end > self.range.start {
191                    self.range.end -= 1;
192                    let mut indices = unravel_index(self.range.end as u64, self.subset.shape())?;
193                    std::iter::zip(indices.iter_mut(), self.subset.start())
194                        .for_each(|(index, start)| *index += start);
195                    Some(indices)
196                } else {
197                    None
198                }
199            }
200        }
201
202        impl ExactSizeIterator for $iterator_type {}
203
204        impl FusedIterator for $iterator_type {}
205    };
206}
207
208impl_indices_iterator!(IndicesIterator<'_>);
209impl_indices_iterator!(IndicesIntoIterator);
210
211/// Parallel indices iterator.
212///
213/// See [`Indices`].
214pub struct ParIndicesIterator<'a> {
215    pub(crate) subset: &'a ArraySubset,
216    pub(crate) range: std::ops::Range<usize>,
217}
218
219/// Parallel indices iterator.
220///
221/// See [`Indices`].
222pub struct ParIndicesIntoIterator {
223    pub(crate) subset: ArraySubset,
224    pub(crate) range: std::ops::Range<usize>,
225}
226
227macro_rules! impl_par_chunks_iterator {
228    ($iterator_type:ty) => {
229        impl ParallelIterator for $iterator_type {
230            type Item = ArrayIndices;
231
232            fn drive_unindexed<C>(self, consumer: C) -> C::Result
233            where
234                C: UnindexedConsumer<Self::Item>,
235            {
236                bridge(self, consumer)
237            }
238
239            fn opt_len(&self) -> Option<usize> {
240                Some(self.len())
241            }
242        }
243
244        impl IndexedParallelIterator for $iterator_type {
245            fn with_producer<CB: ProducerCallback<Self::Item>>(self, callback: CB) -> CB::Output {
246                callback.callback(self)
247            }
248
249            fn drive<C: Consumer<Self::Item>>(self, consumer: C) -> C::Result {
250                bridge(self, consumer)
251            }
252
253            fn len(&self) -> usize {
254                self.range.end.saturating_sub(self.range.start)
255            }
256        }
257    };
258}
259
260impl_par_chunks_iterator!(ParIndicesIterator<'_>);
261impl_par_chunks_iterator!(ParIndicesIntoIterator);
262
263impl<'a> Producer for ParIndicesIterator<'a> {
264    type Item = ArrayIndices;
265    type IntoIter = IndicesIterator<'a>;
266
267    fn into_iter(self) -> Self::IntoIter {
268        IndicesIterator {
269            subset: self.subset,
270            range: self.range,
271        }
272    }
273
274    fn split_at(self, index: usize) -> (Self, Self) {
275        let left = ParIndicesIterator {
276            subset: self.subset,
277            range: self.range.start..self.range.start + index,
278        };
279        let right = ParIndicesIterator {
280            subset: self.subset,
281            range: (self.range.start + index)..self.range.end,
282        };
283        (left, right)
284    }
285}
286
287impl Producer for ParIndicesIntoIterator {
288    type Item = ArrayIndices;
289    type IntoIter = IndicesIntoIterator;
290
291    fn into_iter(self) -> Self::IntoIter {
292        IndicesIntoIterator {
293            subset: self.subset,
294            range: self.range,
295        }
296    }
297
298    fn split_at(self, index: usize) -> (Self, Self) {
299        let left = ParIndicesIntoIterator {
300            subset: self.subset.clone(),
301            range: self.range.start..self.range.start + index,
302        };
303        let right = ParIndicesIntoIterator {
304            subset: self.subset,
305            range: (self.range.start + index)..self.range.end,
306        };
307        (left, right)
308    }
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    /// The 2x2 array subset `[1..3, 5..7]` used by every test.
    fn subset_2x2() -> ArraySubset {
        ArraySubset::new_with_ranges(&[1..3, 5..7])
    }

    #[test]
    fn indices_iterator_partial() {
        // Linear positions 1, 2 and 3 of the subset.
        let indices = Indices::new_with_start_end(subset_2x2(), 1..4);
        assert_eq!(indices.len(), 3);

        // Mix front and back iteration until the two ends meet.
        let mut iter = indices.iter();
        assert_eq!(iter.next(), Some(vec![1, 6]));
        assert_eq!(iter.next_back(), Some(vec![2, 6]));
        assert_eq!(iter.next(), Some(vec![2, 5]));
        assert_eq!(iter.next(), None);

        // (1 + 6) + (2 + 5) + (2 + 6)
        let sum: u64 = indices.into_par_iter().map(|v| v[0] + v[1]).sum();
        assert_eq!(sum, 22);

        // An inclusive upper bound of 0 selects only the first element.
        let indices = Indices::new_with_start_end(subset_2x2(), ..=0);
        assert_eq!(indices.len(), 1);
        let mut iter = indices.iter();
        assert_eq!(iter.next(), Some(vec![1, 5]));
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn indices_iterator_empty() {
        // A range entirely past the end of the subset.
        let indices = Indices::new_with_start_end(subset_2x2(), 5..5);
        assert_eq!(indices.len(), 0);
        assert!(indices.is_empty());

        // An inverted range.
        let indices = Indices::new_with_start_end(subset_2x2(), 5..1);
        assert_eq!(indices.len(), 0);
        assert!(indices.is_empty());
    }
}