1use rayon::iter::{
2 plumbing::{bridge, Consumer, Producer, ProducerCallback, UnindexedConsumer},
3 IndexedParallelIterator, ParallelIterator,
4};
5
6use crate::saf::{AsSafView, Saf, SafView};
7
8pub trait IntoBlockIterator<const N: usize> {
10 type Item: AsSafView<N>;
12 type Iter: ExactSizeIterator<Item = Self::Item>;
14
15 fn into_block_iter(self, block_size: usize) -> Self::Iter;
18}
19
20impl<'a, const N: usize> IntoBlockIterator<N> for &'a Saf<N> {
21 type Item = SafView<'a, N>;
22 type Iter = BlockIter<'a, N>;
23
24 fn into_block_iter(self, block_size: usize) -> Self::Iter {
25 BlockIter::new(self.view(), block_size)
26 }
27}
28
29impl<'a, const N: usize> IntoBlockIterator<N> for SafView<'a, N> {
30 type Item = SafView<'a, N>;
31 type Iter = BlockIter<'a, N>;
32
33 fn into_block_iter(self, block_size: usize) -> Self::Iter {
34 BlockIter::new(self, block_size)
35 }
36}
37
38impl<'a, 'b, const N: usize> IntoBlockIterator<N> for &'b SafView<'a, N> {
39 type Item = SafView<'a, N>;
40 type Iter = BlockIter<'a, N>;
41
42 fn into_block_iter(self, block_size: usize) -> Self::Iter {
43 BlockIter::new(*self, block_size)
44 }
45}
46
47pub trait IntoParallelBlockIterator<const N: usize> {
49 type Item: AsSafView<N>;
51 type Iter: IndexedParallelIterator<Item = Self::Item>;
53
54 fn into_par_block_iter(self, block_size: usize) -> Self::Iter;
57}
58
59impl<'a, const N: usize> IntoParallelBlockIterator<N> for &'a Saf<N> {
60 type Item = SafView<'a, N>;
61 type Iter = ParBlockIter<'a, N>;
62
63 fn into_par_block_iter(self, block_size: usize) -> Self::Iter {
64 ParBlockIter::new(self.view(), block_size)
65 }
66}
67
68impl<'a, const N: usize> IntoParallelBlockIterator<N> for SafView<'a, N> {
69 type Item = SafView<'a, N>;
70 type Iter = ParBlockIter<'a, N>;
71
72 fn into_par_block_iter(self, block_size: usize) -> Self::Iter {
73 ParBlockIter::new(self, block_size)
74 }
75}
76
77impl<'a, 'b, const N: usize> IntoParallelBlockIterator<N> for &'b SafView<'a, N> {
78 type Item = SafView<'a, N>;
79 type Iter = ParBlockIter<'a, N>;
80
81 fn into_par_block_iter(self, block_size: usize) -> Self::Iter {
82 ParBlockIter::new(*self, block_size)
83 }
84}
85
/// Sequential iterator over consecutive blocks of a [`SafView`].
///
/// Each item is a [`SafView`] over one chunk of the underlying values and carries the same
/// `shape` as the view the iterator was built from. All chunks have the same length except
/// possibly the last, which may be shorter.
///
/// Created through [`IntoBlockIterator::into_block_iter`]. It is also the sequential iterator
/// produced by `ParBlockIter`'s rayon producer.
///
/// Cloning is cheap: the iterator only holds a slice iterator and a small shape array.
#[derive(Clone, Debug)]
pub struct BlockIter<'a, const N: usize> {
    /// Chunks of the underlying values, one per block.
    iter: ::std::slice::Chunks<'a, f32>,
    /// Shape attached to every yielded block.
    shape: [usize; N],
}
92
93impl<'a, const N: usize> BlockIter<'a, N> {
94 pub(in crate::saf) fn new(saf: SafView<'a, N>, block_size: usize) -> Self {
95 let iter = saf.values.chunks(saf.width() * block_size);
96
97 Self {
98 iter,
99 shape: saf.shape,
100 }
101 }
102}
103
104impl<'a, const N: usize> Iterator for BlockIter<'a, N> {
105 type Item = SafView<'a, N>;
106
107 #[inline]
108 fn next(&mut self) -> Option<Self::Item> {
109 self.iter
110 .next()
111 .map(|item| SafView::new_unchecked(item, self.shape))
112 }
113
114 fn size_hint(&self) -> (usize, Option<usize>) {
115 self.iter.size_hint()
116 }
117}
118
119impl<'a, const N: usize> ExactSizeIterator for BlockIter<'a, N> {
120 fn len(&self) -> usize {
121 self.iter.len()
122 }
123}
124
125impl<'a, const N: usize> DoubleEndedIterator for BlockIter<'a, N> {
126 fn next_back(&mut self) -> Option<Self::Item> {
127 self.iter
128 .next_back()
129 .map(|item| SafView::new_unchecked(item, self.shape))
130 }
131}
132
/// Rayon parallel iterator over consecutive blocks of a [`SafView`].
///
/// Yields the same blocks, in the same order, as the sequential [`BlockIter`] built from the
/// same view and block size. Rayon splits the work by handing block-aligned sub-slices of the
/// values to separate producers.
///
/// Created through [`IntoParallelBlockIterator::into_par_block_iter`]. Cloning is cheap, since
/// every field is `Copy`.
#[derive(Clone, Debug)]
pub struct ParBlockIter<'a, const N: usize> {
    /// All values to be split into blocks.
    values: &'a [f32],
    /// Shape attached to every yielded block.
    shape: [usize; N],
    /// Number of values per block; the final block may be shorter.
    chunk_size: usize,
}
140
141impl<'a, const N: usize> ParBlockIter<'a, N> {
142 pub(in crate::saf) fn new(saf: SafView<'a, N>, block_size: usize) -> Self {
143 Self {
144 values: saf.values,
145 shape: saf.shape,
146 chunk_size: saf.width() * block_size,
147 }
148 }
149}
150
151impl<'a, const N: usize> ParallelIterator for ParBlockIter<'a, N> {
158 type Item = SafView<'a, N>;
159
160 fn drive_unindexed<C>(self, consumer: C) -> C::Result
161 where
162 C: UnindexedConsumer<Self::Item>,
163 {
164 bridge(self, consumer)
165 }
166
167 fn opt_len(&self) -> Option<usize> {
168 Some(self.len())
169 }
170}
171
172impl<'a, const N: usize> IndexedParallelIterator for ParBlockIter<'a, N> {
173 fn drive<C>(self, consumer: C) -> C::Result
174 where
175 C: Consumer<Self::Item>,
176 {
177 bridge(self, consumer)
178 }
179
180 fn len(&self) -> usize {
181 let n = self.values.len();
182 if n == 0 {
183 0
184 } else {
185 (n - 1) / self.chunk_size + 1
186 }
187 }
188
189 fn with_producer<CB>(self, callback: CB) -> CB::Output
190 where
191 CB: ProducerCallback<Self::Item>,
192 {
193 callback.callback(BlockProducer {
194 values: self.values,
195 shape: self.shape,
196 chunk_size: self.chunk_size,
197 })
198 }
199}
200
/// Rayon producer over a contiguous run of values, yielding blocks of `chunk_size` values each.
///
/// Splitting cuts the values at a multiple of `chunk_size` (clamped to the slice length). The
/// two halves therefore yield exactly the blocks the unsplit producer would have yielded.
struct BlockProducer<'a, const N: usize> {
    /// Shape attached to every yielded block.
    shape: [usize; N],
    /// Number of values per block; the final block may be shorter.
    chunk_size: usize,
    /// Values this producer is responsible for.
    values: &'a [f32],
}
206
207impl<'a, const N: usize> Producer for BlockProducer<'a, N> {
208 type Item = SafView<'a, N>;
209 type IntoIter = BlockIter<'a, N>;
210
211 fn into_iter(self) -> Self::IntoIter {
212 BlockIter {
213 iter: self.values.chunks(self.chunk_size),
214 shape: self.shape,
215 }
216 }
217
218 fn split_at(self, index: usize) -> (Self, Self) {
219 let elem_index = self.values.len().min(index * self.chunk_size);
220 let (left, right) = self.values.split_at(elem_index);
221
222 (
223 Self {
224 values: left,
225 shape: self.shape,
226 chunk_size: self.chunk_size,
227 },
228 Self {
229 values: right,
230 shape: self.shape,
231 chunk_size: self.chunk_size,
232 },
233 )
234 }
235}