train_station/tensor/iterator/windows.rs

use crate::tensor::core::Tensor;

use std::iter::{ExactSizeIterator, FusedIterator};

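/// Iterator over sliding windows of a tensor's elements.
///
/// Each item is a view of `window_size` consecutive elements of `source`, with
/// successive windows starting `step` elements apart.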
pub struct TensorWindowsIterator<'a> {
    pub(crate) source: &'a Tensor,
    pub(crate) window_size: usize,
    pub(crate) step: usize,
    pub(crate) start: usize,
    pub(crate) last_start: usize,
    pub(crate) finished: bool,
}

impl<'a> TensorWindowsIterator<'a> {
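    /// Creates a windows iterator over `source`.
    ///
    /// Yields `((size - window_size) / step) + 1` windows when
    /// `window_size <= size`, and none otherwise.
    ///
    /// # Panics
    ///
    /// Panics if `window_size` or `step` is zero.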
    #[inline]
    pub fn new(source: &'a Tensor, window_size: usize, step: usize) -> Self {
        assert!(window_size > 0, "window_size must be > 0");
        assert!(step > 0, "step must be > 0");
        let size = source.size();
        let finished = window_size > size;
        // Align the last start position down to a multiple of `step` so that
        // forward and backward iteration visit exactly the same windows.
        let last_start = (size.saturating_sub(window_size) / step) * step;
        Self {
            source,
            window_size,
            step,
            start: 0,
            last_start,
            finished,
        }
    }

    #[inline]
    fn create_window_view(&self, start: usize) -> Tensor {
        self.source.slice_view(start, 1, self.window_size)
    }

    #[inline]
    fn windows_len(&self) -> usize {
        // Treat `start > last_start` as empty as well: after the last forward
        // window has been yielded, `start` has already moved past `last_start`
        // while `finished` is not yet set, and the subtraction would underflow.
        if self.finished || self.start > self.last_start {
            0
        } else {
            ((self.last_start - self.start) / self.step) + 1
        }
    }
}

impl<'a> Iterator for TensorWindowsIterator<'a> {
    type Item = Tensor;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        if self.finished {
            return None;
        }
        if self.start > self.last_start {
            self.finished = true;
            return None;
        }
        let s = self.start;
        self.start = self.start.saturating_add(self.step);
        Some(self.create_window_view(s))
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let n = self.windows_len();
        (n, Some(n))
    }
}

impl<'a> ExactSizeIterator for TensorWindowsIterator<'a> {
    #[inline]
    fn len(&self) -> usize {
        self.windows_len()
    }
}

impl<'a> FusedIterator for TensorWindowsIterator<'a> {}

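// Backward iteration yields the same windows as the forward pass, starting at
// `last_start` and stepping down by `step` until the front window is reached.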
impl<'a> DoubleEndedIterator for TensorWindowsIterator<'a> {
    #[inline]
    fn next_back(&mut self) -> Option<Self::Item> {
        if self.finished {
            return None;
        }
        if self.last_start < self.start {
            self.finished = true;
            return None;
        }
        let s = self.last_start;
        if s < self.step {
            // `s` is the frontmost window position; nothing remains below it,
            // so mark the iterator finished instead of re-yielding this window.
            self.finished = true;
        } else {
            self.last_start -= self.step;
        }
        Some(self.create_window_view(s))
    }
}

impl Tensor {
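    /// Returns an iterator over all contiguous windows of `window_size`
    /// elements, advancing one element at a time.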
    #[inline]
    pub fn iter_windows(&self, window_size: usize) -> TensorWindowsIterator<'_> {
        TensorWindowsIterator::new(self, window_size, 1)
    }

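    /// Returns an iterator over windows of `window_size` elements whose start
    /// positions are `step` elements apart.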
    #[inline]
    pub fn iter_windows_step(&self, window_size: usize, step: usize) -> TensorWindowsIterator<'_> {
        TensorWindowsIterator::new(self, window_size, step)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_windows() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let wins: Vec<Tensor> = t.iter_windows(3).collect();
        assert_eq!(wins.len(), 2);
        assert_eq!(wins[0].data(), &[1.0, 2.0, 3.0]);
        assert_eq!(wins[1].data(), &[2.0, 3.0, 4.0]);
    }
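
    // A sketch, not part of the original suite: with the aligned `last_start`
    // and the termination fix in `next_back`, reverse iteration should visit
    // the same windows as the forward pass, in the opposite order.
    #[test]
    fn test_windows_rev() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let wins: Vec<Tensor> = t.iter_windows(3).rev().collect();
        assert_eq!(wins.len(), 2);
        assert_eq!(wins[0].data(), &[2.0, 3.0, 4.0]);
        assert_eq!(wins[1].data(), &[1.0, 2.0, 3.0]);
    }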

    #[test]
    fn test_windows_step() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
        let wins: Vec<Tensor> = t.iter_windows_step(2, 2).collect();
        assert_eq!(wins.len(), 2);
        assert_eq!(wins[0].data(), &[1.0, 2.0]);
        assert_eq!(wins[1].data(), &[3.0, 4.0]);
    }
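
    // A sketch, not part of the original suite: `len` should report the
    // remaining window count, ((size - window_size) / step) + 1 up front and 0
    // once exhausted (relying on the underflow guard in `windows_len`).
    #[test]
    fn test_windows_len_tracks_remaining() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
        let mut it = t.iter_windows_step(2, 2);
        assert_eq!(it.size_hint(), (2, Some(2)));
        assert_eq!(it.len(), 2);
        it.next();
        assert_eq!(it.len(), 1);
        it.next();
        assert_eq!(it.len(), 0);
        assert!(it.next().is_none());
    }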

    #[test]
    fn test_windows_over_3d_outerdim() {
        let vals: Vec<f32> = (0..10).map(|i| i as f32).collect();
        let t = Tensor::from_slice(&vals, vec![1, 5, 2]).unwrap();
        for b in t.iter_dim(0) {
            let wins: Vec<Tensor> = b
                .split_with_sizes(&[3, 2], 0)
                .into_iter()
                .collect();
            assert_eq!(wins[0].shape().dims(), vec![3, 2]);
            assert_eq!(wins[1].shape().dims(), vec![2, 2]);
        }
    }

    #[test]
    fn test_windows_gradient_chain() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2])
            .unwrap()
            .with_requires_grad();
        let wins = t.split_with_sizes(&[2, 1], 0);
        let w0 = wins[0].mul_scalar(2.0);
        let row1 = t.select(0, 1).unsqueeze(0);
        let row2 = t.select(0, 2).unsqueeze(0);
        let w1 = Tensor::cat(&[row1, row2], 0).mul_scalar(2.0);
        let y = Tensor::cat(&[w0, w1], 0);
        let mut loss = y.sum();
        loss.backward(None);
        let g = t.grad_owned().unwrap();
        // Row 0 feeds only into w0, row 2 only into w1, and row 1 into both,
        // so its gradient accumulates twice.
        assert_eq!(g.get(&[0, 0]), 2.0);
        assert_eq!(g.get(&[1, 0]), 4.0);
        assert_eq!(g.get(&[2, 0]), 2.0);
    }
}