train_station/tensor/iterator/chunks.rs

//! Linear chunk iterators over tensors as contiguous views
//!
//! Gradients are preserved implicitly for with-grad paths: each yielded chunk is
//! created via `Tensor::slice_view(start, 1, len)`, which registers a
//! `GradFn::View` with a `ViewMapping::LinearRange { start, step: 1, length: len }`
//! when the source requires gradients and gradient tracking is enabled.
//!
//! Performance routing is decided once at construction:
//! - In no-grad fast mode, if the source is non-contiguous, the iterator performs a
//!   single one-time `contiguous()` materialization and holds an internal owner.
//!   Subsequent chunk views are taken from this contiguous owner (no per-chunk copies).
//! - In with-grad mode, zero-copy views are created directly from the source.

use crate::tensor::core::utils::should_use_fast_path;
use crate::tensor::core::Tensor;
use std::iter::{ExactSizeIterator, FusedIterator};
17
18pub struct TensorChunksIterator<'a> {
19    pub(crate) source: &'a Tensor,
20    pub(crate) chunk_size: usize,
21    pub(crate) position: usize,
22    pub(crate) end: usize,
23    // One-time contiguous owner for fast path on non-contiguous sources
24    pub(crate) owner: Option<Tensor>,
25}
26
27impl<'a> TensorChunksIterator<'a> {
28    #[inline]
29    pub fn new(source: &'a Tensor, chunk_size: usize) -> Self {
30        assert!(chunk_size > 0, "chunk_size must be > 0");
31        let fast = should_use_fast_path(&[source]);
32        // One-time contiguous materialization policy for fast path
33        let owner = if fast && !source.is_contiguous() && source.size() > 0 {
34            Some(source.contiguous())
35        } else {
36            None
37        };
38        Self {
39            source,
40            chunk_size,
41            position: 0,
42            end: source.size(),
43            owner,
44        }
45    }
46
47    #[inline]
48    fn create_chunk_view(&self, start: usize, len: usize) -> Tensor {
49        if len == 0 {
50            return Tensor::new(vec![0]);
51        }
52        // Choose base tensor once according to construction-time policy
53        let base: &Tensor = match &self.owner {
54            Some(o) => o,
55            None => self.source,
56        };
57        base.slice_view(start, 1, len)
58    }
59}
60
61impl<'a> Iterator for TensorChunksIterator<'a> {
62    type Item = Tensor;
63    #[inline]
64    fn next(&mut self) -> Option<Self::Item> {
65        if self.position >= self.end {
66            return None;
67        }
68        let start = self.position;
69        let remaining = self.end - start;
70        let take = remaining.min(self.chunk_size);
71        self.position += take;
72        Some(self.create_chunk_view(start, take))
73    }
74    #[inline]
75    fn size_hint(&self) -> (usize, Option<usize>) {
76        let remaining = self.end.saturating_sub(self.position);
77        let n = if remaining == 0 {
78            0
79        } else {
80            remaining.div_ceil(self.chunk_size)
81        };
82        (n, Some(n))
83    }
84}
85
86impl<'a> ExactSizeIterator for TensorChunksIterator<'a> {
87    #[inline]
88    fn len(&self) -> usize {
89        let remaining = self.end.saturating_sub(self.position);
90        if remaining == 0 {
91            0
92        } else {
93            remaining.div_ceil(self.chunk_size)
94        }
95    }
96}
97
98impl<'a> FusedIterator for TensorChunksIterator<'a> {}
99
100impl<'a> DoubleEndedIterator for TensorChunksIterator<'a> {
101    #[inline]
102    fn next_back(&mut self) -> Option<Self::Item> {
103        if self.position >= self.end {
104            return None;
105        }
106        let remaining = self.end - self.position;
107        // Match standard slice chunk semantics: the last chunk may be a smaller remainder
108        let rem = remaining % self.chunk_size;
109        let take = if rem == 0 { self.chunk_size } else { rem };
110        self.end -= take;
111        Some(self.create_chunk_view(self.end, take))
112    }
113}
114
115pub struct TensorChunksExactIterator<'a> {
116    pub(crate) source: &'a Tensor,
117    pub(crate) chunk_size: usize,
118    pub(crate) position: usize,
119    pub(crate) exact_end: usize,
120    pub(crate) remainder_start: usize,
121    pub(crate) remainder_len: usize,
122    // One-time contiguous owner for fast path on non-contiguous sources
123    pub(crate) owner: Option<Tensor>,
124}
125
126impl<'a> TensorChunksExactIterator<'a> {
127    #[inline]
128    pub fn new(source: &'a Tensor, chunk_size: usize) -> Self {
129        assert!(chunk_size > 0, "chunk_size must be > 0");
130        let size = source.size();
131        let exact_chunks = size / chunk_size;
132        let exact_end = exact_chunks * chunk_size;
133        let remainder_len = size - exact_end;
134        let fast = should_use_fast_path(&[source]);
135        let owner = if fast && !source.is_contiguous() && size > 0 {
136            Some(source.contiguous())
137        } else {
138            None
139        };
140        Self {
141            source,
142            chunk_size,
143            position: 0,
144            exact_end,
145            remainder_start: exact_end,
146            remainder_len,
147            owner,
148        }
149    }
150
151    #[inline]
152    pub fn remainder(&self) -> Tensor {
153        if self.remainder_len == 0 {
154            Tensor::new(vec![0])
155        } else {
156            let base: &Tensor = match &self.owner {
157                Some(o) => o,
158                None => self.source,
159            };
160            base.slice_view(self.remainder_start, 1, self.remainder_len)
161        }
162    }
163
164    #[inline]
165    fn create_chunk_view(&self, start: usize) -> Tensor {
166        let base: &Tensor = match &self.owner {
167            Some(o) => o,
168            None => self.source,
169        };
170        base.slice_view(start, 1, self.chunk_size)
171    }
172}
173
174impl<'a> Iterator for TensorChunksExactIterator<'a> {
175    type Item = Tensor;
176    #[inline]
177    fn next(&mut self) -> Option<Self::Item> {
178        if self.position >= self.exact_end {
179            return None;
180        }
181        let start = self.position;
182        self.position += self.chunk_size;
183        Some(self.create_chunk_view(start))
184    }
185    #[inline]
186    fn size_hint(&self) -> (usize, Option<usize>) {
187        let remaining = self.exact_end.saturating_sub(self.position);
188        let n = remaining / self.chunk_size;
189        (n, Some(n))
190    }
191}
192
193impl<'a> ExactSizeIterator for TensorChunksExactIterator<'a> {
194    #[inline]
195    fn len(&self) -> usize {
196        (self.exact_end.saturating_sub(self.position)) / self.chunk_size
197    }
198}
199
200impl<'a> FusedIterator for TensorChunksExactIterator<'a> {}
201
202impl<'a> DoubleEndedIterator for TensorChunksExactIterator<'a> {
203    #[inline]
204    fn next_back(&mut self) -> Option<Self::Item> {
205        if self.position >= self.exact_end {
206            return None;
207        }
208        self.exact_end = self.exact_end.saturating_sub(self.chunk_size);
209        Some(self.create_chunk_view(self.exact_end))
210    }
211}
212
213impl Tensor {
214    /// Standard slice-like chunks iterator. Use this instead of `iter_chunks`.
215    ///
216    /// Iterates over contiguous or view-backed slices of the tensor with the
217    /// specified chunk size. In no-grad fast mode, a single contiguous owner may
218    /// be materialized to optimize subsequent views.
219    ///
220    /// # Arguments
221    ///
222    /// * `chunk_size` - Number of elements per chunk (must be > 0)
223    ///
224    /// # Examples
225    ///
226    /// ```
227    /// use train_station::tensor::TensorCollectExt;
228    /// use train_station::Tensor;
229    ///
230    /// let t = Tensor::from_slice(&(1..=6).map(|i| i as f32).collect::<Vec<_>>(), vec![6]).unwrap();
231    /// let y = t.chunks(2).map(|c| c.mul_scalar(2.0)).collect_shape(vec![6]);
232    /// assert_eq!(y.data(), &[2.0, 4.0, 6.0, 8.0, 10.0, 12.0]);
233    /// ```
234    #[inline]
235    pub fn chunks(&self, chunk_size: usize) -> TensorChunksIterator<'_> {
236        TensorChunksIterator::new(self, chunk_size)
237    }
238
239    /// Standard slice-like exact chunks iterator. Use this instead of `iter_chunks_exact`.
240    ///
241    /// Yields only the exact chunks of size `chunk_size`, exposing any remainder
242    /// via `remainder()`. See `chunks()` for a variant that yields the remainder as
243    /// the last (smaller) chunk.
244    ///
245    /// # Examples
246    ///
247    /// ```
248    /// use train_station::Tensor;
249    ///
250    /// let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
251    /// let mut it = t.chunks_exact(2);
252    /// assert_eq!(it.next().unwrap().data(), &[1.0, 2.0]);
253    /// assert_eq!(it.next().unwrap().data(), &[3.0, 4.0]);
254    /// assert_eq!(it.remainder().data(), &[5.0]);
255    /// ```
256    #[inline]
257    pub fn chunks_exact(&self, chunk_size: usize) -> TensorChunksExactIterator<'_> {
258        TensorChunksExactIterator::new(self, chunk_size)
259    }
260
261    #[deprecated(note = "Use Tensor::chunks(...) instead. This alias will be removed before 1.0.")]
262    #[inline]
263    pub fn iter_chunks(&self, chunk_size: usize) -> TensorChunksIterator<'_> {
264        TensorChunksIterator::new(self, chunk_size)
265    }
266
267    #[deprecated(
268        note = "Use Tensor::chunks_exact(...) instead. This alias will be removed before 1.0."
269    )]
270    #[inline]
271    pub fn iter_chunks_exact(&self, chunk_size: usize) -> TensorChunksExactIterator<'_> {
272        TensorChunksExactIterator::new(self, chunk_size)
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gradtrack::NoGradTrack;
    // `collect_shape` is a trait method; `use super::*` does not re-export the
    // parent module's private imports, so the trait must be imported here.
    use crate::tensor::TensorCollectExt;

    /// A remainder-producing size yields a smaller final chunk.
    #[test]
    fn test_chunks() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
        let v: Vec<Tensor> = t.chunks(2).collect();
        assert_eq!(v.len(), 3);
        assert_eq!(v[0].data(), &[1.0, 2.0]);
        assert_eq!(v[1].data(), &[3.0, 4.0]);
        assert_eq!(v[2].data(), &[5.0]);
    }

    /// Exact chunks skip the remainder, which stays reachable via `remainder()`.
    #[test]
    fn test_chunks_exact() {
        let t = Tensor::from_slice(&[10.0, 20.0, 30.0, 40.0, 50.0], vec![5]).unwrap();
        let mut it = t.chunks_exact(2);
        let a = it.next().unwrap();
        let b = it.next().unwrap();
        assert!(it.next().is_none());
        assert_eq!(a.data(), &[10.0, 20.0]);
        assert_eq!(b.data(), &[30.0, 40.0]);
        assert_eq!(it.remainder().data(), &[50.0]);
    }

    /// Chunking composes with outer-dimension iteration on a 3D tensor.
    #[test]
    fn test_chunks_over_3d_outerdim() {
        // 3D tensor [B=2, T=3, F=4]
        let vals: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let t = Tensor::from_slice(&vals, vec![2, 3, 4]).unwrap();
        // Iterate outermost dim via iter_dim(0), then chunk time dimension T with size 2
        let mut collected: Vec<Vec<Tensor>> = Vec::new();
        for b in t.iter_dim(0) {
            let chunks: Vec<Tensor> = b.split(2, 0); // split along current outer dim (was T)
            // Each chunk should have shape [2,4] then [1,4]
            assert_eq!(chunks[0].shape().dims(), vec![2, 4]);
            assert_eq!(chunks[1].shape().dims(), vec![1, 4]);
            collected.push(chunks);
        }
        assert_eq!(collected.len(), 2);
    }

    /// Gradients flow back through split -> op -> cat -> sum.
    #[test]
    fn test_chunks_gradient_after_collect() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
            .unwrap()
            .with_requires_grad();
        // Split columns in chunks of 2, operate per chunk, re-concat and sum
        let parts = t.split(2, 1);
        let parts2: Vec<Tensor> = parts.into_iter().map(|p| p.mul_scalar(3.0)).collect();
        let y = Tensor::cat(&parts2, 1);
        let mut loss = y.sum();
        loss.backward(None);
        let g = t.grad_owned().unwrap();
        // Each element gets grad 3.0
        assert_eq!(g.data(), &[3.0, 3.0, 3.0, 3.0, 3.0, 3.0]);
    }

    /// Gradients flow back through the chunks iterator views.
    #[test]
    fn test_chunks_iter_gradient_propagation() {
        let t = Tensor::from_slice(&[10.0, 20.0, 30.0, 40.0], vec![4])
            .unwrap()
            .with_requires_grad();
        // Use chunks iterator over 2-sized windows, map op, collect and sum
        let parts: Vec<Tensor> = t.chunks(2).map(|c| c.add_scalar(1.0)).collect();
        let y = Tensor::cat(&parts, 0);
        let mut loss = y.sum();
        loss.backward(None);
        let g = t.grad_owned().unwrap();
        assert_eq!(g.data(), &[1.0, 1.0, 1.0, 1.0]);
    }

    /// Back iteration yields the remainder first; lengths stay consistent.
    #[test]
    fn test_chunks_double_ended_and_size_hint() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
        let mut it = t.chunks(2);
        assert_eq!(it.size_hint(), (3, Some(3)));
        assert_eq!(it.len(), 3);

        let last = it.next_back().unwrap();
        assert_eq!(last.data(), &[5.0]);
        assert_eq!(it.len(), 2);

        let first = it.next().unwrap();
        assert_eq!(first.data(), &[1.0, 2.0]);
        assert_eq!(it.len(), 1);

        let middle = it.next().unwrap();
        assert_eq!(middle.data(), &[3.0, 4.0]);
        assert!(it.next().is_none());
        assert!(it.next_back().is_none());
        assert_eq!(it.size_hint(), (0, Some(0)));
    }

    /// Chunking a zero-sized tensor yields nothing.
    #[test]
    fn test_chunks_zero_sized_tensor() {
        let t = Tensor::new(vec![0]);
        let it = t.chunks(3);
        assert_eq!(it.len(), 0);
        assert_eq!(it.size_hint(), (0, Some(0)));
        assert_eq!(it.collect::<Vec<_>>().len(), 0);
    }

    /// Under a `NoGradTrack` guard, yielded views must not require grad.
    #[test]
    fn test_chunks_no_grad_guard_disables_requires_grad() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4])
            .unwrap()
            .with_requires_grad();
        let _guard = NoGradTrack::new();
        let mut it = t.chunks(2);
        let a = it.next().unwrap();
        let b = it.next().unwrap();
        assert!(!a.requires_grad());
        assert!(!b.requires_grad());
        // Collect under guard should produce no-grad result
        let y: Tensor = t.chunks(2).collect_shape(vec![2, 2]);
        assert!(!y.requires_grad());
    }
}
396}