//! Sliding-window iterators for tensors (train_station/tensor/iterator/windows.rs).

use crate::tensor::core::utils::should_use_fast_path;
use crate::tensor::core::Tensor;
use std::iter::{ExactSizeIterator, FusedIterator};

/// Iterator over sliding windows of a tensor's elements.
///
/// Windows are `window_size`-element views whose start positions advance in
/// `step` increments from 0. Supports front (`next`) and back (`next_back`)
/// consumption.
pub struct TensorWindowsIterator<'a> {
    /// Tensor whose elements are windowed.
    pub(crate) source: &'a Tensor,
    /// Number of elements per yielded window (must be > 0).
    pub(crate) window_size: usize,
    /// Distance between consecutive window start positions (must be > 0).
    pub(crate) step: usize,
    /// Start index of the next window yielded from the front.
    pub(crate) start: usize,
    /// Start index of the next window yielded from the back.
    pub(crate) last_start: usize,
    /// Latched once iteration is exhausted from either end.
    pub(crate) finished: bool,
    /// Contiguous copy of `source` made up front when needed; when present,
    /// windows are sliced from this copy instead of `source`.
    pub(crate) owner: Option<Tensor>,
}
28
29impl<'a> TensorWindowsIterator<'a> {
30 #[inline]
31 pub fn new(source: &'a Tensor, window_size: usize, step: usize) -> Self {
32 assert!(window_size > 0, "window_size must be > 0");
33 assert!(step > 0, "step must be > 0");
34 let size = source.size();
35 let raw_last = size.saturating_sub(window_size);
38 let last_start = if window_size > size {
39 0
40 } else {
41 (raw_last / step) * step
42 };
43 let finished = window_size > size;
44 let fast = should_use_fast_path(&[source]);
45 let owner = if fast && !source.is_contiguous() && size > 0 {
46 Some(source.contiguous())
47 } else {
48 None
49 };
50 Self {
51 source,
52 window_size,
53 step,
54 start: 0,
55 last_start,
56 finished,
57 owner,
58 }
59 }
60
61 #[inline]
62 fn create_window_view(&self, start: usize) -> Tensor {
63 let base: &Tensor = match &self.owner {
65 Some(o) => o,
66 None => self.source,
67 };
68 base.slice_view(start, 1, self.window_size)
69 }
70
71 #[inline]
72 fn windows_len(&self) -> usize {
73 if self.finished {
74 0
75 } else {
76 ((self.last_start - self.start) / self.step) + 1
77 }
78 }
79}
80
81impl<'a> Iterator for TensorWindowsIterator<'a> {
82 type Item = Tensor;
83 #[inline]
84 fn next(&mut self) -> Option<Self::Item> {
85 if self.finished {
86 return None;
87 }
88 if self.start > self.last_start {
89 self.finished = true;
90 return None;
91 }
92 let s = self.start;
93 self.start = self.start.saturating_add(self.step);
94 Some(self.create_window_view(s))
95 }
96
97 #[inline]
98 fn size_hint(&self) -> (usize, Option<usize>) {
99 let n = self.windows_len();
100 (n, Some(n))
101 }
102}
103
104impl<'a> ExactSizeIterator for TensorWindowsIterator<'a> {
105 #[inline]
106 fn len(&self) -> usize {
107 self.windows_len()
108 }
109}
110
// `next` latches `finished` once the front cursor passes the back cursor, so
// repeated calls after exhaustion keep returning `None`, as `FusedIterator`
// promises.
impl<'a> FusedIterator for TensorWindowsIterator<'a> {}
112
113impl<'a> DoubleEndedIterator for TensorWindowsIterator<'a> {
114 #[inline]
115 fn next_back(&mut self) -> Option<Self::Item> {
116 if self.finished {
117 return None;
118 }
119 if self.last_start < self.start {
120 self.finished = true;
121 return None;
122 }
123 let s = self.last_start;
124 if s < self.step {
125 self.last_start = 0usize;
126 } else {
127 self.last_start -= self.step;
128 }
129 Some(self.create_window_view(s))
130 }
131}
132
impl Tensor {
    /// Returns an iterator over sliding windows of `window_size` elements
    /// with a step of 1 (each window starts one element after the previous).
    ///
    /// # Panics
    /// Panics if `window_size` is zero.
    #[inline]
    pub fn windows(&self, window_size: usize) -> TensorWindowsIterator<'_> {
        TensorWindowsIterator::new(self, window_size, 1)
    }

    /// Returns an iterator over sliding windows of `window_size` elements
    /// whose start positions advance by `step`.
    ///
    /// # Panics
    /// Panics if `window_size` or `step` is zero.
    #[inline]
    pub fn windows_step(&self, window_size: usize, step: usize) -> TensorWindowsIterator<'_> {
        TensorWindowsIterator::new(self, window_size, step)
    }

    /// Deprecated alias for [`Tensor::windows`].
    #[deprecated(note = "Use Tensor::windows(...) instead. This alias will be removed before 1.0.")]
    #[inline]
    pub fn iter_windows(&self, window_size: usize) -> TensorWindowsIterator<'_> {
        TensorWindowsIterator::new(self, window_size, 1)
    }

    /// Deprecated alias for [`Tensor::windows_step`].
    #[deprecated(
        note = "Use Tensor::windows_step(...) instead. This alias will be removed before 1.0."
    )]
    #[inline]
    pub fn iter_windows_step(&self, window_size: usize, step: usize) -> TensorWindowsIterator<'_> {
        TensorWindowsIterator::new(self, window_size, step)
    }
}
197
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gradtrack::NoGradTrack;

    // Basic front iteration: size 4, window 3, step 1 -> two windows.
    #[test]
    fn test_windows() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let wins: Vec<Tensor> = t.windows(3).collect();
        assert_eq!(wins.len(), 2);
        assert_eq!(wins[0].data(), &[1.0, 2.0, 3.0]);
        assert_eq!(wins[1].data(), &[2.0, 3.0, 4.0]);
    }

    // Non-unit step: size 5, window 2, step 2 -> the trailing element
    // (5.0) is dropped because no full window starts there.
    #[test]
    fn test_windows_step() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
        let wins: Vec<Tensor> = t.windows_step(2, 2).collect();
        assert_eq!(wins.len(), 2);
        assert_eq!(wins[0].data(), &[1.0, 2.0]);
        assert_eq!(wins[1].data(), &[3.0, 4.0]);
    }

    // Windowing along the outer dim of a 3-D tensor via split_with_sizes;
    // checks the resulting sub-tensor shapes.
    #[test]
    fn test_windows_over_3d_outerdim() {
        let vals: Vec<f32> = (0..10).map(|i| i as f32).collect();
        let t = Tensor::from_slice(&vals, vec![1, 5, 2]).unwrap();
        for b in t.iter_dim(0) {
            let wins: Vec<Tensor> = b
                .split_with_sizes(&[3, 2], 0) .into_iter()
                .collect();
            assert_eq!(wins[0].shape().dims(), vec![3, 2]);
            assert_eq!(wins[1].shape().dims(), vec![2, 2]);
        }
    }

    // Gradients flow through window-like views built from split/select/cat.
    // Row 1 contributes to both w0 and w1 (grad 4.0); rows 0 and 2 appear
    // once each (grad 2.0).
    #[test]
    fn test_windows_gradient_chain() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2])
            .unwrap()
            .with_requires_grad();
        let wins = t.split_with_sizes(&[2, 1], 0); let w0 = wins[0].mul_scalar(2.0);
        let row1 = t.select(0, 1).unsqueeze(0);
        let row2 = t.select(0, 2).unsqueeze(0);
        let w1 = Tensor::cat(&[row1, row2], 0).mul_scalar(2.0);
        let y = Tensor::cat(&[w0, w1], 0);
        let mut loss = y.sum();
        loss.backward(None);
        let g = t.grad_owned().unwrap();
        assert_eq!(g.get(&[0, 0]), 2.0);
        assert_eq!(g.get(&[1, 0]), 4.0);
        assert_eq!(g.get(&[2, 0]), 2.0);
    }

    // Gradients flow through the windows iterator itself: middle elements
    // appear in both overlapping windows, so their grads double.
    #[test]
    fn test_windows_iter_gradient_propagation() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4])
            .unwrap()
            .with_requires_grad();
        let wins: Vec<Tensor> = t.windows(3).map(|w| w.mul_scalar(2.0)).collect();
        let y = Tensor::cat(&wins, 0).sum();
        let mut loss = y;
        loss.backward(None);
        let g = t.grad_owned().unwrap();
        assert_eq!(g.data(), &[2.0, 4.0, 4.0, 2.0]);
    }

    // Interleaved next/next_back plus len()/size_hint() bookkeeping.
    // NOTE(review): this only exercises mixed front/back consumption; a
    // pure back-to-front drain is not covered here.
    #[test]
    fn test_windows_double_ended_and_size_hint() {
        let t =
            Tensor::from_slice(&(1..=8).map(|i| i as f32).collect::<Vec<_>>(), vec![8]).unwrap();
        let mut it = t.windows_step(3, 2); assert_eq!(it.size_hint(), (3, Some(3)));
        assert_eq!(it.len(), 3);

        let back = it.next_back().unwrap();
        assert_eq!(back.data(), &[5.0, 6.0, 7.0]);
        assert_eq!(it.len(), 2);
        let front = it.next().unwrap();
        assert_eq!(front.data(), &[1.0, 2.0, 3.0]);
        let mid = it.next().unwrap();
        assert_eq!(mid.data(), &[3.0, 4.0, 5.0]);
        assert!(it.next().is_none());
        assert!(it.next_back().is_none());
        assert_eq!(it.size_hint(), (0, Some(0)));
    }

    // A zero-element tensor yields no windows and reports zero length.
    #[test]
    fn test_windows_zero_sized_tensor() {
        let t = Tensor::new(vec![0]);
        let it = t.windows(3);
        assert_eq!(it.len(), 0);
        assert_eq!(it.size_hint(), (0, Some(0)));
        assert_eq!(it.collect::<Vec<_>>().len(), 0);
    }

    // Inside a NoGradTrack guard, yielded windows must not require grad.
    #[test]
    fn test_windows_no_grad_guard_disables_requires_grad() {
        let t = Tensor::from_slice(&(0..6).map(|i| i as f32).collect::<Vec<_>>(), vec![6])
            .unwrap()
            .with_requires_grad();
        let _guard = NoGradTrack::new();
        let w = t.windows(4).next().unwrap();
        assert!(!w.requires_grad());
        let y: Tensor = t.windows(4).collect_shape(vec![3, 4]);
        assert!(!y.requires_grad());
    }
}