// train_station/tensor/iterator/windows.rs

//! Linear window iterators over tensors as overlapping views.
//!
//! Gradients are preserved implicitly for with-grad paths: each yielded window is
//! created via `Tensor::slice_view(start, 1, window_size)`, which registers a
//! `GradFn::View` with a `ViewMapping::LinearRange { start, step: 1, length: window_size }`
//! when the source requires gradients and gradient tracking is enabled.
//!
//! Performance routing is decided once at construction:
//! - In no-grad fast mode, if the source is non-contiguous, the iterator performs a
//!   single one-time `contiguous()` materialization and holds an internal owner.
//!   Subsequent window views are taken from this contiguous owner (no per-window copies).
//! - In with-grad mode, zero-copy views are created directly from the source.
use std::iter::{ExactSizeIterator, FusedIterator};

use crate::tensor::core::utils::should_use_fast_path;
use crate::tensor::core::Tensor;
17
/// Iterator over overlapping linear windows of a tensor, yielded as view tensors.
///
/// Forward and backward cursors (`start` / `last_start`) walk the stepping grid
/// `0, step, 2*step, ...`; `finished` latches exhaustion so the iterator is fused.
pub struct TensorWindowsIterator<'a> {
    /// Source tensor windows are viewed from (used directly in with-grad mode).
    pub(crate) source: &'a Tensor,
    /// Length of each yielded window (> 0, enforced at construction).
    pub(crate) window_size: usize,
    /// Distance between consecutive window starts (> 0, enforced at construction).
    pub(crate) step: usize,
    /// Next window start for forward iteration.
    pub(crate) start: usize,
    /// Start of the rearmost remaining window, aligned to the stepping grid.
    pub(crate) last_start: usize,
    /// Latched once iteration is exhausted from either end.
    pub(crate) finished: bool,
    // One-time contiguous owner for fast path on non-contiguous sources
    pub(crate) owner: Option<Tensor>,
}
28
29impl<'a> TensorWindowsIterator<'a> {
30    #[inline]
31    pub fn new(source: &'a Tensor, window_size: usize, step: usize) -> Self {
32        assert!(window_size > 0, "window_size must be > 0");
33        assert!(step > 0, "step must be > 0");
34        let size = source.size();
35        // Align the last_start to the stepping grid so reverse iteration returns
36        // the same sequence of window starts as forward iteration in reverse order.
37        let raw_last = size.saturating_sub(window_size);
38        let last_start = if window_size > size {
39            0
40        } else {
41            (raw_last / step) * step
42        };
43        let finished = window_size > size;
44        let fast = should_use_fast_path(&[source]);
45        let owner = if fast && !source.is_contiguous() && size > 0 {
46            Some(source.contiguous())
47        } else {
48            None
49        };
50        Self {
51            source,
52            window_size,
53            step,
54            start: 0,
55            last_start,
56            finished,
57            owner,
58        }
59    }
60
61    #[inline]
62    fn create_window_view(&self, start: usize) -> Tensor {
63        // Select base tensor according to construction-time policy
64        let base: &Tensor = match &self.owner {
65            Some(o) => o,
66            None => self.source,
67        };
68        base.slice_view(start, 1, self.window_size)
69    }
70
71    #[inline]
72    fn windows_len(&self) -> usize {
73        if self.finished {
74            0
75        } else {
76            ((self.last_start - self.start) / self.step) + 1
77        }
78    }
79}
80
81impl<'a> Iterator for TensorWindowsIterator<'a> {
82    type Item = Tensor;
83    #[inline]
84    fn next(&mut self) -> Option<Self::Item> {
85        if self.finished {
86            return None;
87        }
88        if self.start > self.last_start {
89            self.finished = true;
90            return None;
91        }
92        let s = self.start;
93        self.start = self.start.saturating_add(self.step);
94        Some(self.create_window_view(s))
95    }
96
97    #[inline]
98    fn size_hint(&self) -> (usize, Option<usize>) {
99        let n = self.windows_len();
100        (n, Some(n))
101    }
102}
103
104impl<'a> ExactSizeIterator for TensorWindowsIterator<'a> {
105    #[inline]
106    fn len(&self) -> usize {
107        self.windows_len()
108    }
109}
110
// Marker impl: `finished` is latched by `next`/`next_back`, so once either end
// returns `None` the iterator keeps returning `None`.
impl<'a> FusedIterator for TensorWindowsIterator<'a> {}
112
113impl<'a> DoubleEndedIterator for TensorWindowsIterator<'a> {
114    #[inline]
115    fn next_back(&mut self) -> Option<Self::Item> {
116        if self.finished {
117            return None;
118        }
119        if self.last_start < self.start {
120            self.finished = true;
121            return None;
122        }
123        let s = self.last_start;
124        if s < self.step {
125            self.last_start = 0usize;
126        } else {
127            self.last_start -= self.step;
128        }
129        Some(self.create_window_view(s))
130    }
131}
132
133impl Tensor {
134    /// Overlapping windows iterator with step=1. Use this instead of `iter_windows`.
135    ///
136    /// Produces overlapping linear windows as view tensors. In no-grad fast mode,
137    /// a contiguous owner may be materialized once for faster subsequent views.
138    ///
139    /// # Arguments
140    ///
141    /// * `window_size` - Length of each window (> 0)
142    ///
143    /// # Examples
144    ///
145    /// ```
146    /// use train_station::Tensor;
147    ///
148    /// let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
149    /// let v: Vec<f32> = t.windows(3).map(|w| w.sum().value()).collect();
150    /// assert_eq!(v, vec![6.0, 9.0]);
151    /// ```
152    #[inline]
153    pub fn windows(&self, window_size: usize) -> TensorWindowsIterator<'_> {
154        TensorWindowsIterator::new(self, window_size, 1)
155    }
156
157    /// Overlapping windows iterator with custom step. Use this instead of `iter_windows_step`.
158    ///
159    /// Produces windows starting at positions `0, step, 2*step, ...` up to the last
160    /// valid start. Reverse iteration yields the same sequence in reverse.
161    ///
162    /// # Arguments
163    ///
164    /// * `window_size` - Length of each window (> 0)
165    /// * `step` - Step between consecutive window starts (> 0)
166    ///
167    /// # Examples
168    ///
169    /// ```
170    /// use train_station::Tensor;
171    ///
172    /// let t = Tensor::from_slice(&(1..=8).map(|i| i as f32).collect::<Vec<_>>(), vec![8]).unwrap();
173    /// let wins: Vec<Tensor> = t.windows_step(3, 2).collect();
174    /// assert_eq!(wins[0].data(), &[1.0, 2.0, 3.0]);
175    /// assert_eq!(wins[1].data(), &[3.0, 4.0, 5.0]);
176    /// assert_eq!(wins[2].data(), &[5.0, 6.0, 7.0]);
177    /// ```
178    #[inline]
179    pub fn windows_step(&self, window_size: usize, step: usize) -> TensorWindowsIterator<'_> {
180        TensorWindowsIterator::new(self, window_size, step)
181    }
182
183    #[deprecated(note = "Use Tensor::windows(...) instead. This alias will be removed before 1.0.")]
184    #[inline]
185    pub fn iter_windows(&self, window_size: usize) -> TensorWindowsIterator<'_> {
186        TensorWindowsIterator::new(self, window_size, 1)
187    }
188
189    #[deprecated(
190        note = "Use Tensor::windows_step(...) instead. This alias will be removed before 1.0."
191    )]
192    #[inline]
193    pub fn iter_windows_step(&self, window_size: usize, step: usize) -> TensorWindowsIterator<'_> {
194        TensorWindowsIterator::new(self, window_size, step)
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gradtrack::NoGradTrack;

    /// Basic overlapping windows (step=1) over a 1-D tensor.
    #[test]
    fn test_windows() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let wins: Vec<Tensor> = t.windows(3).collect();
        assert_eq!(wins.len(), 2);
        assert_eq!(wins[0].data(), &[1.0, 2.0, 3.0]);
        assert_eq!(wins[1].data(), &[2.0, 3.0, 4.0]);
    }

    /// Non-overlapping windows via a custom step; trailing leftover is dropped.
    #[test]
    fn test_windows_step() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
        let wins: Vec<Tensor> = t.windows_step(2, 2).collect();
        assert_eq!(wins.len(), 2);
        assert_eq!(wins[0].data(), &[1.0, 2.0]);
        assert_eq!(wins[1].data(), &[3.0, 4.0]);
    }

    /// Emulates time-dimension windowing on a 3-D tensor using split_with_sizes,
    /// checking only the resulting shapes.
    #[test]
    fn test_windows_over_3d_outerdim() {
        // 3D [B=1, T=5, F=2]
        let vals: Vec<f32> = (0..10).map(|i| i as f32).collect();
        let t = Tensor::from_slice(&vals, vec![1, 5, 2]).unwrap();
        // For each outer slice (batch), take time windows of size 3, step 2
        for b in t.iter_dim(0) {
            // b shape [5,2]; create windows along dim 0 of length 3 with step 2
            // emulate with split/stride windows: we'll use iter_windows_step over linear memory of b.flattened rows
            let wins: Vec<Tensor> = b
                .split_with_sizes(&[3, 2], 0) // emulate two windows [0..3] and [2..5]
                .into_iter()
                .collect();
            assert_eq!(wins[0].shape().dims(), vec![3, 2]);
            assert_eq!(wins[1].shape().dims(), vec![2, 2]);
        }
    }

    /// Gradient accumulation when the same row participates in two windows
    /// (windows built manually from select/cat rather than the iterator).
    #[test]
    fn test_windows_gradient_chain() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2])
            .unwrap()
            .with_requires_grad();
        // Create overlapping windows along dim 0 of size 2 with step 1: [rows 0..2], [1..3]
        let wins = t.split_with_sizes(&[2, 1], 0); // first window rows [0,1]
        let w0 = wins[0].mul_scalar(2.0);
        // Build second window rows [1,2] explicitly
        let row1 = t.select(0, 1).unsqueeze(0);
        let row2 = t.select(0, 2).unsqueeze(0);
        let w1 = Tensor::cat(&[row1, row2], 0).mul_scalar(2.0);
        let y = Tensor::cat(&[w0, w1], 0);
        let mut loss = y.sum();
        loss.backward(None);
        let g = t.grad_owned().unwrap();
        // Rows 1 appears in both windows → grad accum twice (2 + 2); others once
        assert_eq!(g.get(&[0, 0]), 2.0);
        assert_eq!(g.get(&[1, 0]), 4.0);
        assert_eq!(g.get(&[2, 0]), 2.0);
    }

    /// Gradients flow through views yielded by the windows iterator itself.
    #[test]
    fn test_windows_iter_gradient_propagation() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4])
            .unwrap()
            .with_requires_grad();
        // Overlapping windows of size 3, step 1; map then collect
        let wins: Vec<Tensor> = t.windows(3).map(|w| w.mul_scalar(2.0)).collect();
        let y = Tensor::cat(&wins, 0).sum();
        let mut loss = y;
        loss.backward(None);
        let g = t.grad_owned().unwrap();
        // Coverage per index: [0] in 1 window, [1] in 2, [2] in 2, [3] in 1; each window scaled by 2
        assert_eq!(g.data(), &[2.0, 4.0, 4.0, 2.0]);
    }

    /// Mixed front/back consumption plus len/size_hint bookkeeping.
    #[test]
    fn test_windows_double_ended_and_size_hint() {
        let t =
            Tensor::from_slice(&(1..=8).map(|i| i as f32).collect::<Vec<_>>(), vec![8]).unwrap();
        let mut it = t.windows_step(3, 2); // starts at 0,2,4
        assert_eq!(it.size_hint(), (3, Some(3)));
        assert_eq!(it.len(), 3);

        let back = it.next_back().unwrap();
        assert_eq!(back.data(), &[5.0, 6.0, 7.0]);
        assert_eq!(it.len(), 2);
        let front = it.next().unwrap();
        assert_eq!(front.data(), &[1.0, 2.0, 3.0]);
        let mid = it.next().unwrap();
        assert_eq!(mid.data(), &[3.0, 4.0, 5.0]);
        assert!(it.next().is_none());
        assert!(it.next_back().is_none());
        assert_eq!(it.size_hint(), (0, Some(0)));
    }

    /// A window larger than the tensor yields nothing (finished at construction).
    #[test]
    fn test_windows_zero_sized_tensor() {
        let t = Tensor::new(vec![0]);
        let it = t.windows(3);
        assert_eq!(it.len(), 0);
        assert_eq!(it.size_hint(), (0, Some(0)));
        assert_eq!(it.collect::<Vec<_>>().len(), 0);
    }

    /// Under a NoGradTrack guard, yielded views must not require gradients.
    #[test]
    fn test_windows_no_grad_guard_disables_requires_grad() {
        let t = Tensor::from_slice(&(0..6).map(|i| i as f32).collect::<Vec<_>>(), vec![6])
            .unwrap()
            .with_requires_grad();
        let _guard = NoGradTrack::new();
        let w = t.windows(4).next().unwrap();
        assert!(!w.requires_grad());
        let y: Tensor = t.windows(4).collect_shape(vec![3, 4]);
        assert!(!y.requires_grad());
    }
}
316}