zarrs/array_subset/iterators/
indices_iterator.rs1use std::iter::FusedIterator;
2
3use crate::{
4 array::{unravel_index, ArrayIndices},
5 array_subset::ArraySubset,
6};
7
8use rayon::iter::{
9 plumbing::{bridge, Consumer, Producer, ProducerCallback, UnindexedConsumer},
10 IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator,
11};
12
13#[derive(Clone)]
25pub struct Indices {
26 pub(crate) subset: ArraySubset,
27 pub(crate) range: std::ops::Range<usize>,
28}
29
30impl Indices {
31 #[must_use]
33 pub fn new(subset: ArraySubset) -> Self {
34 let length = subset.num_elements_usize();
35 Self {
36 subset,
37 range: 0..length,
38 }
39 }
40
41 #[must_use]
43 pub fn new_with_start_end(
44 subset: ArraySubset,
45 range: impl std::ops::RangeBounds<usize>,
46 ) -> Self {
47 let length = subset.num_elements_usize();
48 let start = match range.start_bound() {
49 std::ops::Bound::Included(start) => *start,
50 std::ops::Bound::Excluded(start) => start.saturating_add(1),
51 std::ops::Bound::Unbounded => 0,
52 };
53 let end = match range.end_bound() {
54 std::ops::Bound::Excluded(end) => (*end).min(length),
55 std::ops::Bound::Included(end) => end.saturating_add(1).min(length),
56 std::ops::Bound::Unbounded => length,
57 };
58 Self {
59 subset,
60 range: start..end,
61 }
62 }
63
64 #[must_use]
66 pub fn len(&self) -> usize {
67 self.range.end.saturating_sub(self.range.start)
68 }
69
70 #[must_use]
72 pub fn is_empty(&self) -> bool {
73 self.len() == 0
74 }
75
76 #[must_use]
78 pub fn iter(&self) -> IndicesIterator<'_> {
79 <&Self as IntoIterator>::into_iter(self)
80 }
81}
82
83impl<'a> IntoIterator for &'a Indices {
84 type Item = ArrayIndices;
85 type IntoIter = IndicesIterator<'a>;
86
87 fn into_iter(self) -> Self::IntoIter {
88 IndicesIterator {
89 subset: &self.subset,
90 range: self.range.clone(),
91 }
92 }
93}
94
95impl<'a> IntoParallelRefIterator<'a> for &'a Indices {
96 type Item = ArrayIndices;
97 type Iter = ParIndicesIterator<'a>;
98
99 fn par_iter(&self) -> Self::Iter {
100 ParIndicesIterator {
101 subset: &self.subset,
102 range: self.range.clone(),
103 }
104 }
105}
106
107impl<'a> IntoParallelIterator for &'a Indices {
108 type Item = ArrayIndices;
109 type Iter = ParIndicesIterator<'a>;
110
111 fn into_par_iter(self) -> Self::Iter {
112 ParIndicesIterator {
113 subset: &self.subset,
114 range: self.range.clone(),
115 }
116 }
117}
118
119impl IntoIterator for Indices {
120 type Item = ArrayIndices;
121 type IntoIter = IndicesIntoIterator;
122
123 fn into_iter(self) -> Self::IntoIter {
124 IndicesIntoIterator {
125 subset: self.subset,
126 range: self.range,
127 }
128 }
129}
130
131impl IntoParallelIterator for Indices {
132 type Item = ArrayIndices;
133 type Iter = ParIndicesIntoIterator;
134
135 fn into_par_iter(self) -> Self::Iter {
136 ParIndicesIntoIterator {
137 subset: self.subset,
138 range: self.range,
139 }
140 }
141}
142
143#[derive(Clone)]
147pub struct IndicesIterator<'a> {
148 pub(crate) subset: &'a ArraySubset,
149 pub(crate) range: std::ops::Range<usize>,
150}
151
152#[derive(Clone)]
156pub struct IndicesIntoIterator {
157 pub(crate) subset: ArraySubset,
158 pub(crate) range: std::ops::Range<usize>,
159}
160
161macro_rules! impl_indices_iterator {
162 ($iterator_type:ty) => {
163 impl Iterator for $iterator_type {
164 type Item = ArrayIndices;
165
166 fn next(&mut self) -> Option<Self::Item> {
167 if self.range.start >= self.range.end {
168 return None;
169 }
170 let mut indices = unravel_index(self.range.start as u64, self.subset.shape())?;
171 std::iter::zip(indices.iter_mut(), self.subset.start())
172 .for_each(|(index, start)| *index += start);
173
174 if self.range.start < self.range.end {
175 self.range.start += 1;
176 Some(indices)
177 } else {
178 None
179 }
180 }
181
182 fn size_hint(&self) -> (usize, Option<usize>) {
183 let length = self.range.end.saturating_sub(self.range.start);
184 (length, Some(length))
185 }
186 }
187
188 impl DoubleEndedIterator for $iterator_type {
189 fn next_back(&mut self) -> Option<Self::Item> {
190 if self.range.end > self.range.start {
191 self.range.end -= 1;
192 let mut indices = unravel_index(self.range.end as u64, self.subset.shape())?;
193 std::iter::zip(indices.iter_mut(), self.subset.start())
194 .for_each(|(index, start)| *index += start);
195 Some(indices)
196 } else {
197 None
198 }
199 }
200 }
201
202 impl ExactSizeIterator for $iterator_type {}
203
204 impl FusedIterator for $iterator_type {}
205 };
206}
207
208impl_indices_iterator!(IndicesIterator<'_>);
209impl_indices_iterator!(IndicesIntoIterator);
210
211pub struct ParIndicesIterator<'a> {
215 pub(crate) subset: &'a ArraySubset,
216 pub(crate) range: std::ops::Range<usize>,
217}
218
219pub struct ParIndicesIntoIterator {
223 pub(crate) subset: ArraySubset,
224 pub(crate) range: std::ops::Range<usize>,
225}
226
227macro_rules! impl_par_chunks_iterator {
228 ($iterator_type:ty) => {
229 impl ParallelIterator for $iterator_type {
230 type Item = ArrayIndices;
231
232 fn drive_unindexed<C>(self, consumer: C) -> C::Result
233 where
234 C: UnindexedConsumer<Self::Item>,
235 {
236 bridge(self, consumer)
237 }
238
239 fn opt_len(&self) -> Option<usize> {
240 Some(self.len())
241 }
242 }
243
244 impl IndexedParallelIterator for $iterator_type {
245 fn with_producer<CB: ProducerCallback<Self::Item>>(self, callback: CB) -> CB::Output {
246 callback.callback(self)
247 }
248
249 fn drive<C: Consumer<Self::Item>>(self, consumer: C) -> C::Result {
250 bridge(self, consumer)
251 }
252
253 fn len(&self) -> usize {
254 self.range.end.saturating_sub(self.range.start)
255 }
256 }
257 };
258}
259
260impl_par_chunks_iterator!(ParIndicesIterator<'_>);
261impl_par_chunks_iterator!(ParIndicesIntoIterator);
262
263impl<'a> Producer for ParIndicesIterator<'a> {
264 type Item = ArrayIndices;
265 type IntoIter = IndicesIterator<'a>;
266
267 fn into_iter(self) -> Self::IntoIter {
268 IndicesIterator {
269 subset: self.subset,
270 range: self.range,
271 }
272 }
273
274 fn split_at(self, index: usize) -> (Self, Self) {
275 let left = ParIndicesIterator {
276 subset: self.subset,
277 range: self.range.start..self.range.start + index,
278 };
279 let right = ParIndicesIterator {
280 subset: self.subset,
281 range: (self.range.start + index)..self.range.end,
282 };
283 (left, right)
284 }
285}
286
287impl Producer for ParIndicesIntoIterator {
288 type Item = ArrayIndices;
289 type IntoIter = IndicesIntoIterator;
290
291 fn into_iter(self) -> Self::IntoIter {
292 IndicesIntoIterator {
293 subset: self.subset,
294 range: self.range,
295 }
296 }
297
298 fn split_at(self, index: usize) -> (Self, Self) {
299 let left = ParIndicesIntoIterator {
300 subset: self.subset.clone(),
301 range: self.range.start..self.range.start + index,
302 };
303 let right = ParIndicesIntoIterator {
304 subset: self.subset,
305 range: (self.range.start + index)..self.range.end,
306 };
307 (left, right)
308 }
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn indices_iterator_partial() {
        let indices =
            Indices::new_with_start_end(ArraySubset::new_with_ranges(&[1..3, 5..7]), 1..4);
        assert_eq!(indices.len(), 3);
        let mut iter = indices.iter();
        assert_eq!(iter.next(), Some(vec![1, 6]));
        assert_eq!(iter.next_back(), Some(vec![2, 6]));
        assert_eq!(iter.next(), Some(vec![2, 5]));
        assert_eq!(iter.next(), None);

        assert_eq!(
            indices.into_par_iter().map(|v| v[0] + v[1]).sum::<u64>(),
            22
        );

        let indices =
            Indices::new_with_start_end(ArraySubset::new_with_ranges(&[1..3, 5..7]), ..=0);
        assert_eq!(indices.len(), 1);
        let mut iter = indices.iter();
        assert_eq!(iter.next(), Some(vec![1, 5]));
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn indices_iterator_empty() {
        let indices =
            Indices::new_with_start_end(ArraySubset::new_with_ranges(&[1..3, 5..7]), 5..5);
        assert_eq!(indices.len(), 0);
        assert!(indices.is_empty());

        let indices =
            Indices::new_with_start_end(ArraySubset::new_with_ranges(&[1..3, 5..7]), 5..1);
        assert_eq!(indices.len(), 0);
        assert!(indices.is_empty());
    }

    // `nth`/`nth_back` must skip from either end and leave the iterator exhausted when the
    // requested position lies past the remaining range.
    #[test]
    fn indices_iterator_nth() {
        let indices = Indices::new(ArraySubset::new_with_ranges(&[1..3, 5..7]));

        let mut iter = indices.iter();
        assert_eq!(iter.nth(2), Some(vec![2, 5]));
        assert_eq!(iter.len(), 1);
        assert_eq!(iter.nth(1), None);
        assert_eq!(iter.next(), None);
        assert_eq!(iter.next_back(), None);

        let mut iter = indices.iter();
        assert_eq!(iter.nth_back(1), Some(vec![2, 5]));
        assert_eq!(iter.next_back(), Some(vec![1, 6]));
        assert_eq!(iter.len(), 1);
        assert_eq!(iter.next(), Some(vec![1, 5]));
        assert_eq!(iter.nth_back(0), None);
    }

    // Parallel iteration over a borrowed `Indices` must yield the same indices in the same
    // order as sequential iteration, and reverse iteration must mirror it.
    #[test]
    fn indices_iterator_parallel_matches_sequential() {
        let indices = Indices::new(ArraySubset::new_with_ranges(&[1..4, 5..8]));
        let sequential: Vec<ArrayIndices> = indices.iter().collect();
        assert_eq!(sequential.len(), 9);

        let parallel: Vec<ArrayIndices> = (&indices).into_par_iter().collect();
        assert_eq!(parallel, sequential);

        let reversed: Vec<ArrayIndices> = indices.iter().rev().collect();
        let mut expected = sequential.clone();
        expected.reverse();
        assert_eq!(reversed, expected);

        let empty =
            Indices::new_with_start_end(ArraySubset::new_with_ranges(&[1..3, 5..7]), 5..1);
        assert_eq!(empty.iter().count(), 0);
        assert_eq!(empty.into_par_iter().count(), 0);
    }
}