train_station/tensor/iterator/
windows.rs

//! Iterators over fixed-size, possibly overlapping windows of a tensor's elements, each yielded as a view.
2
3use crate::tensor::core::Tensor;
4use std::iter::{ExactSizeIterator, FusedIterator};
5
6pub struct TensorWindowsIterator<'a> {
7    pub(crate) source: &'a Tensor,
8    pub(crate) window_size: usize,
9    pub(crate) step: usize,
10    pub(crate) start: usize,
11    pub(crate) last_start: usize,
12    pub(crate) finished: bool,
13}
14
15impl<'a> TensorWindowsIterator<'a> {
16    #[inline]
17    pub fn new(source: &'a Tensor, window_size: usize, step: usize) -> Self {
18        assert!(window_size > 0, "window_size must be > 0");
19        assert!(step > 0, "step must be > 0");
20        let size = source.size();
21        let last_start = size.saturating_sub(window_size);
22        let finished = window_size > size;
23        Self {
24            source,
25            window_size,
26            step,
27            start: 0,
28            last_start,
29            finished,
30        }
31    }
32
33    #[inline]
34    fn create_window_view(&self, start: usize) -> Tensor {
35        self.source.slice_view(start, 1, self.window_size)
36    }
37
38    #[inline]
39    fn windows_len(&self) -> usize {
40        if self.finished {
41            0
42        } else {
43            ((self.last_start - self.start) / self.step) + 1
44        }
45    }
46}
47
48impl<'a> Iterator for TensorWindowsIterator<'a> {
49    type Item = Tensor;
50    #[inline]
51    fn next(&mut self) -> Option<Self::Item> {
52        if self.finished {
53            return None;
54        }
55        if self.start > self.last_start {
56            self.finished = true;
57            return None;
58        }
59        let s = self.start;
60        self.start = self.start.saturating_add(self.step);
61        Some(self.create_window_view(s))
62    }
63
64    #[inline]
65    fn size_hint(&self) -> (usize, Option<usize>) {
66        let n = self.windows_len();
67        (n, Some(n))
68    }
69}
70
71impl<'a> ExactSizeIterator for TensorWindowsIterator<'a> {
72    #[inline]
73    fn len(&self) -> usize {
74        self.windows_len()
75    }
76}
77
78impl<'a> FusedIterator for TensorWindowsIterator<'a> {}
79
80impl<'a> DoubleEndedIterator for TensorWindowsIterator<'a> {
81    #[inline]
82    fn next_back(&mut self) -> Option<Self::Item> {
83        if self.finished {
84            return None;
85        }
86        if self.last_start < self.start {
87            self.finished = true;
88            return None;
89        }
90        let s = self.last_start;
91        if s < self.step {
92            self.last_start = 0usize;
93        } else {
94            self.last_start -= self.step;
95        }
96        Some(self.create_window_view(s))
97    }
98}
99
100impl Tensor {
101    #[inline]
102    pub fn iter_windows(&self, window_size: usize) -> TensorWindowsIterator<'_> {
103        TensorWindowsIterator::new(self, window_size, 1)
104    }
105
106    #[inline]
107    pub fn iter_windows_step(&self, window_size: usize, step: usize) -> TensorWindowsIterator<'_> {
108        TensorWindowsIterator::new(self, window_size, step)
109    }
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_windows() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let wins: Vec<Tensor> = t.iter_windows(3).collect();
        assert_eq!(wins.len(), 2);
        assert_eq!(wins[0].data(), &[1.0, 2.0, 3.0]);
        assert_eq!(wins[1].data(), &[2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_windows_step() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
        let wins: Vec<Tensor> = t.iter_windows_step(2, 2).collect();
        assert_eq!(wins.len(), 2);
        assert_eq!(wins[0].data(), &[1.0, 2.0]);
        assert_eq!(wins[1].data(), &[3.0, 4.0]);
    }

    #[test]
    fn test_windows_reports_exact_len() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
        assert_eq!(t.iter_windows(3).len(), 3);
        assert_eq!(t.iter_windows_step(2, 2).len(), 2);
    }

    #[test]
    fn test_windows_larger_than_tensor_is_empty() {
        let t = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let mut it = t.iter_windows(3);
        assert_eq!(it.len(), 0);
        assert!(it.next().is_none());
    }

    #[test]
    #[should_panic(expected = "window_size must be > 0")]
    fn test_windows_zero_size_panics() {
        let t = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let _ = t.iter_windows(0);
    }

    #[test]
    fn test_windows_over_3d_outerdim() {
        // 3D tensor [B=1, T=5, F=2].
        let vals: Vec<f32> = (0..10).map(|i| i as f32).collect();
        let t = Tensor::from_slice(&vals, vec![1, 5, 2]).unwrap();
        // For each outer slice (batch), split the time axis into chunks of 3 and 2 rows.
        // NOTE: this exercises `split_with_sizes` on outer-dim views, not the window
        // iterator; the chunks are rows [0..3] and [3..5] and do not overlap.
        for b in t.iter_dim(0) {
            // `b` has shape [5, 2].
            let wins: Vec<Tensor> = b.split_with_sizes(&[3, 2], 0).into_iter().collect();
            assert_eq!(wins[0].shape().dims(), vec![3, 2]);
            assert_eq!(wins[1].shape().dims(), vec![2, 2]);
        }
    }

    #[test]
    fn test_windows_gradient_chain() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2])
            .unwrap()
            .with_requires_grad();
        // Build overlapping windows along dim 0 of size 2 with step 1: rows [0..2] and [1..3].
        let wins = t.split_with_sizes(&[2, 1], 0); // first window: rows [0, 1]
        let w0 = wins[0].mul_scalar(2.0);
        // Build the second window (rows [1, 2]) explicitly.
        let row1 = t.select(0, 1).unsqueeze(0);
        let row2 = t.select(0, 2).unsqueeze(0);
        let w1 = Tensor::cat(&[row1, row2], 0).mul_scalar(2.0);
        let y = Tensor::cat(&[w0, w1], 0);
        let mut loss = y.sum();
        loss.backward(None);
        let g = t.grad_owned().unwrap();
        // Row 1 appears in both windows, so its gradient accumulates twice (2 + 2);
        // rows 0 and 2 appear once.
        assert_eq!(g.get(&[0, 0]), 2.0);
        assert_eq!(g.get(&[1, 0]), 4.0);
        assert_eq!(g.get(&[2, 0]), 2.0);
    }
}