train_station/tensor/iterator/chunks.rs

1//! Linear chunk iterators over tensors as contiguous views
2
3use crate::tensor::core::Tensor;
4use std::iter::{ExactSizeIterator, FusedIterator};
5
/// Borrowing iterator over fixed-size contiguous chunks of a tensor.
///
/// Walks the flattened element range `[position, end)` of `source`, yielding
/// `chunk_size`-element views; the final chunk may be shorter than
/// `chunk_size` when the length is not an even multiple.
pub struct TensorChunksIterator<'a> {
    /// Tensor being chunked; yielded views borrow from it for `'a`.
    pub(crate) source: &'a Tensor,
    /// Elements per chunk (must be > 0; enforced by `new`).
    pub(crate) chunk_size: usize,
    /// Linear element index of the next chunk yielded from the front.
    pub(crate) position: usize,
    /// One past the last element still to be yielded (moved down by `next_back`).
    pub(crate) end: usize,
}
12
13impl<'a> TensorChunksIterator<'a> {
14    #[inline]
15    pub fn new(source: &'a Tensor, chunk_size: usize) -> Self {
16        assert!(chunk_size > 0, "chunk_size must be > 0");
17        Self {
18            source,
19            chunk_size,
20            position: 0,
21            end: source.size(),
22        }
23    }
24
25    #[inline]
26    fn create_chunk_view(&self, start: usize, len: usize) -> Tensor {
27        if len == 0 {
28            return Tensor::new(vec![0]);
29        }
30        // Force contiguous slice to avoid stepped views exposing stride>1 in iterators
31        let v = self.source.slice_view(start, 1, len);
32        if v.is_contiguous() {
33            v
34        } else {
35            v.contiguous()
36        }
37    }
38}
39
40impl<'a> Iterator for TensorChunksIterator<'a> {
41    type Item = Tensor;
42    #[inline]
43    fn next(&mut self) -> Option<Self::Item> {
44        if self.position >= self.end {
45            return None;
46        }
47        let start = self.position;
48        let remaining = self.end - start;
49        let take = remaining.min(self.chunk_size);
50        self.position += take;
51        Some(self.create_chunk_view(start, take))
52    }
53    #[inline]
54    fn size_hint(&self) -> (usize, Option<usize>) {
55        let remaining = self.end.saturating_sub(self.position);
56        let n = if remaining == 0 {
57            0
58        } else {
59            remaining.div_ceil(self.chunk_size)
60        };
61        (n, Some(n))
62    }
63}
64
65impl<'a> ExactSizeIterator for TensorChunksIterator<'a> {
66    #[inline]
67    fn len(&self) -> usize {
68        let remaining = self.end.saturating_sub(self.position);
69        if remaining == 0 {
70            0
71        } else {
72            remaining.div_ceil(self.chunk_size)
73        }
74    }
75}
76
// Sound: once `next` returns `None`, `position >= end` holds permanently.
impl<'a> FusedIterator for TensorChunksIterator<'a> {}
78
79impl<'a> DoubleEndedIterator for TensorChunksIterator<'a> {
80    #[inline]
81    fn next_back(&mut self) -> Option<Self::Item> {
82        if self.position >= self.end {
83            return None;
84        }
85        let remaining = self.end - self.position;
86        let take = remaining.min(self.chunk_size);
87        self.end -= take;
88        Some(self.create_chunk_view(self.end, take))
89    }
90}
91
/// Borrowing iterator yielding only full `chunk_size`-element chunks.
///
/// Leftover elements (fewer than `chunk_size`) are never yielded by `next`;
/// they are available via `remainder()`.
pub struct TensorChunksExactIterator<'a> {
    /// Tensor being chunked; yielded views borrow from it for `'a`.
    pub(crate) source: &'a Tensor,
    /// Elements per chunk (must be > 0; enforced by `new`).
    pub(crate) chunk_size: usize,
    /// Linear element index of the next chunk yielded from the front.
    pub(crate) position: usize,
    /// End of the full-chunk region (a multiple of `chunk_size`; moved down by `next_back`).
    pub(crate) exact_end: usize,
    /// Start of the leftover region (fixed at construction).
    pub(crate) remainder_start: usize,
    /// Length of the leftover region, in `[0, chunk_size)`.
    pub(crate) remainder_len: usize,
}
100
101impl<'a> TensorChunksExactIterator<'a> {
102    #[inline]
103    pub fn new(source: &'a Tensor, chunk_size: usize) -> Self {
104        assert!(chunk_size > 0, "chunk_size must be > 0");
105        let size = source.size();
106        let exact_chunks = size / chunk_size;
107        let exact_end = exact_chunks * chunk_size;
108        let remainder_len = size - exact_end;
109        Self {
110            source,
111            chunk_size,
112            position: 0,
113            exact_end,
114            remainder_start: exact_end,
115            remainder_len,
116        }
117    }
118
119    #[inline]
120    pub fn remainder(&self) -> Tensor {
121        if self.remainder_len == 0 {
122            Tensor::new(vec![0])
123        } else {
124            let v = self
125                .source
126                .slice_view(self.remainder_start, 1, self.remainder_len);
127            if v.is_contiguous() {
128                v
129            } else {
130                v.contiguous()
131            }
132        }
133    }
134
135    #[inline]
136    fn create_chunk_view(&self, start: usize) -> Tensor {
137        let v = self.source.slice_view(start, 1, self.chunk_size);
138        if v.is_contiguous() {
139            v
140        } else {
141            v.contiguous()
142        }
143    }
144}
145
146impl<'a> Iterator for TensorChunksExactIterator<'a> {
147    type Item = Tensor;
148    #[inline]
149    fn next(&mut self) -> Option<Self::Item> {
150        if self.position >= self.exact_end {
151            return None;
152        }
153        let start = self.position;
154        self.position += self.chunk_size;
155        Some(self.create_chunk_view(start))
156    }
157    #[inline]
158    fn size_hint(&self) -> (usize, Option<usize>) {
159        let remaining = self.exact_end.saturating_sub(self.position);
160        let n = remaining / self.chunk_size;
161        (n, Some(n))
162    }
163}
164
165impl<'a> ExactSizeIterator for TensorChunksExactIterator<'a> {
166    #[inline]
167    fn len(&self) -> usize {
168        (self.exact_end.saturating_sub(self.position)) / self.chunk_size
169    }
170}
171
// Sound: once `next` returns `None`, `position >= exact_end` holds permanently.
impl<'a> FusedIterator for TensorChunksExactIterator<'a> {}
173
174impl<'a> DoubleEndedIterator for TensorChunksExactIterator<'a> {
175    #[inline]
176    fn next_back(&mut self) -> Option<Self::Item> {
177        if self.position >= self.exact_end {
178            return None;
179        }
180        self.exact_end = self.exact_end.saturating_sub(self.chunk_size);
181        Some(self.create_chunk_view(self.exact_end))
182    }
183}
184
185impl Tensor {
186    #[inline]
187    pub fn iter_chunks(&self, chunk_size: usize) -> TensorChunksIterator<'_> {
188        TensorChunksIterator::new(self, chunk_size)
189    }
190
191    #[inline]
192    pub fn iter_chunks_exact(&self, chunk_size: usize) -> TensorChunksExactIterator<'_> {
193        TensorChunksExactIterator::new(self, chunk_size)
194    }
195
196    /// Iterate with an auto-tuned chunk size for cache-friendly processing
197    ///
198    /// Heuristic:
199    /// - Target ~64 KiB blocks (16K f32 elements) for good L1/L2 behavior
200    /// - Clamp to [4K, 262_144] elements
201    /// - Round to a multiple of SIMD lane width when possible
202    #[inline]
203    pub fn iter_fast_chunks(&self) -> TensorChunksIterator<'_> {
204        let n = self.size();
205        if n == 0 {
206            return TensorChunksIterator::new(self, 1);
207        }
208        // 64 KiB in f32 elements
209        let mut sz = 16_384usize;
210        // Adjust by tensor size rough scale
211        if n < 16_384 {
212            sz = 4_096;
213        }
214        if n > 1_048_576 {
215            sz = 65_536;
216        }
217        // Align to SIMD lane if available
218        let lane = crate::tensor::core::Tensor::simd_lane_width_elems_runtime();
219        if lane > 1 {
220            sz = sz.div_ceil(lane) * lane;
221        }
222        // Clamp
223        sz = sz.clamp(4_096, 262_144);
224        TensorChunksIterator::new(self, sz)
225    }
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Uneven split: 5 elements in chunks of 2 -> lengths [2, 2, 1].
    #[test]
    fn test_chunks() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
        let chunks: Vec<Tensor> = t.iter_chunks(2).collect();
        assert_eq!(chunks.len(), 3);
        let expected: [&[f32]; 3] = [&[1.0, 2.0], &[3.0, 4.0], &[5.0]];
        for (chunk, want) in chunks.iter().zip(expected) {
            assert_eq!(chunk.data(), want);
        }
    }

    /// Exact chunks of 2 from 5 elements: two full chunks plus a remainder.
    #[test]
    fn test_chunks_exact() {
        let t = Tensor::from_slice(&[10.0, 20.0, 30.0, 40.0, 50.0], vec![5]).unwrap();
        let mut it = t.iter_chunks_exact(2);
        let first = it.next().unwrap();
        let second = it.next().unwrap();
        assert!(it.next().is_none());
        assert_eq!(first.data(), &[10.0, 20.0]);
        assert_eq!(second.data(), &[30.0, 40.0]);
        assert_eq!(it.remainder().data(), &[50.0]);
    }

    /// Chunking the time axis of a [B, T, F] tensor after peeling the batch dim.
    #[test]
    fn test_chunks_over_3d_outerdim() {
        // 3D tensor [B=2, T=3, F=4]
        let vals: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let t = Tensor::from_slice(&vals, vec![2, 3, 4]).unwrap();
        // Peel the outermost dim via iter_dim(0), then split the (former) T axis in twos.
        let mut collected: Vec<Vec<Tensor>> = Vec::new();
        for batch in t.iter_dim(0) {
            // Split along the current outer dim (was T): shapes [2,4] then [1,4].
            let chunks: Vec<Tensor> = batch.split(2, 0);
            assert_eq!(chunks[0].shape().dims(), vec![2, 4]);
            assert_eq!(chunks[1].shape().dims(), vec![1, 4]);
            collected.push(chunks);
        }
        assert_eq!(collected.len(), 2);
    }

    /// Gradients flow through split -> per-chunk scale -> cat -> sum.
    #[test]
    fn test_chunks_gradient_after_collect() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
            .unwrap()
            .with_requires_grad();
        // Split columns in chunks of 2, scale each chunk, re-concat and sum.
        let scaled: Vec<Tensor> = t.split(2, 1).into_iter().map(|p| p.mul_scalar(3.0)).collect();
        let y = Tensor::cat(&scaled, 1);
        let mut loss = y.sum();
        loss.backward(None);
        // d(sum(3 * t))/dt == 3 for every element.
        let grad = t.grad_owned().unwrap();
        assert_eq!(grad.data(), &[3.0, 3.0, 3.0, 3.0, 3.0, 3.0]);
    }
}