// train_station/tensor/iterator/viewdim.rs
use crate::tensor::core::utils::should_use_fast_path;
use crate::tensor::core::Tensor;
use std::iter::{ExactSizeIterator, FusedIterator};
7
/// Borrowing iterator over slices of a tensor along one dimension.
pub struct TensorDimIterator<'a> {
    /// Tensor being iterated; slices come from it unless `owner` is set.
    source: &'a Tensor,
    /// Dimension index being iterated over.
    dim: usize,
    /// Next front position to yield (advances forward).
    index: usize,
    /// One past the last back position still to yield (shrinks backward).
    end: usize,
    /// Contiguous copy of `source`, created on the fast path when the
    /// source is non-contiguous; slices are taken from it when present.
    owner: Option<Tensor>,
    /// True for rank-1 tensors: yield element views instead of selects.
    use_element_views: bool,
}
19
20impl<'a> TensorDimIterator<'a> {
21 #[inline]
22 pub fn new(source: &'a Tensor, dim: usize) -> Self {
23 let rank = source.shape().rank();
24 assert!(rank > 0, "iter_dim: cannot iterate dims of a 0-D tensor");
25 let dim = if dim < rank {
26 dim
27 } else {
28 panic!("iter_dim: dim {} out of bounds for rank {}", dim, rank)
29 };
30 let end = source.shape().dims()[dim];
31 let fast = should_use_fast_path(&[source]);
33 let owner = if fast && !source.is_contiguous() && source.size() > 0 {
35 Some(source.contiguous())
36 } else {
37 None
38 };
39 Self {
40 source,
41 dim,
42 index: 0,
43 end,
44 owner,
45 use_element_views: rank == 1,
46 }
47 }
48
49 #[inline]
50 fn slice_at(&self, i: usize) -> Tensor {
51 let base: &Tensor = match &self.owner {
53 Some(o) => o,
54 None => self.source,
55 };
56 if self.use_element_views {
57 base.element_view(i)
59 } else {
60 base.select(self.dim, i)
62 }
63 }
64}
65
66impl<'a> Iterator for TensorDimIterator<'a> {
67 type Item = Tensor;
68
69 #[inline]
70 fn next(&mut self) -> Option<Self::Item> {
71 if self.index >= self.end {
72 return None;
73 }
74 let i = self.index;
75 self.index += 1;
76 Some(self.slice_at(i))
77 }
78
79 #[inline]
80 fn size_hint(&self) -> (usize, Option<usize>) {
81 let rem = self.end - self.index;
82 (rem, Some(rem))
83 }
84}
85
impl<'a> ExactSizeIterator for TensorDimIterator<'a> {
    /// Number of slices remaining to be yielded.
    #[inline]
    fn len(&self) -> usize {
        self.end - self.index
    }
}
92
// Once exhausted, `next` keeps returning `None` (`index` never passes `end`).
impl<'a> FusedIterator for TensorDimIterator<'a> {}
94
95impl<'a> std::iter::DoubleEndedIterator for TensorDimIterator<'a> {
96 #[inline]
97 fn next_back(&mut self) -> Option<Self::Item> {
98 if self.index >= self.end {
99 return None;
100 }
101 self.end -= 1;
102 Some(self.slice_at(self.end))
103 }
104
105 #[inline]
106 fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
107 if self.index >= self.end {
108 return None;
109 }
110 let new_end = self.end.saturating_sub(n + 1);
111 if new_end < self.index {
112 self.index = self.end;
113 return None;
114 }
115 self.end = new_end;
116 Some(self.slice_at(self.end))
117 }
118}
119
impl Tensor {
    /// Returns a borrowing iterator over views of this tensor along `dim`.
    ///
    /// Each item is a slice with that dimension removed (e.g. rows of a
    /// 2x3 tensor have shape `[3]`); rank-1 tensors yield element views.
    ///
    /// # Panics
    /// Panics if the tensor is 0-D or `dim` is out of bounds.
    #[inline]
    pub fn iter_dim(&self, dim: usize) -> TensorDimIterator<'_> {
        TensorDimIterator::new(self, dim)
    }

    /// Iterates along the first dimension; shorthand for `iter_dim(0)`.
    #[inline]
    pub fn iter(&self) -> TensorDimIterator<'_> {
        self.iter_dim(0)
    }

    /// Iterates along the outermost dimension; alias for `iter_dim(0)`.
    #[inline]
    pub fn outer_iter(&self) -> TensorDimIterator<'_> {
        self.iter_dim(0)
    }
}
197
/// Owning variant of [`TensorDimIterator`]: consumes the source tensor.
pub struct TensorDimOwnedIterator {
    /// Tensor backing every yielded slice (a contiguous copy of the source
    /// when the fast path wanted one, otherwise the source itself).
    owner: Tensor,
    /// Dimension index being iterated over.
    dim: usize,
    /// Next front position to yield (advances forward).
    index: usize,
    /// One past the last back position still to yield (shrinks backward).
    end: usize,
    /// True for rank-1 tensors: yield element views instead of selects.
    use_element_views: bool,
}
207
208impl TensorDimOwnedIterator {
209 #[inline]
210 pub fn new(source: Tensor, dim: usize) -> Self {
211 let rank = source.shape().rank();
212 assert!(rank > 0, "iter_dim: cannot iterate dims of a 0-D tensor");
213 let dim = if dim < rank {
214 dim
215 } else {
216 panic!("iter_dim: dim {} out of bounds for rank {}", dim, rank)
217 };
218 let end = source.shape().dims()[dim];
219 let fast = should_use_fast_path(&[&source]);
220 let owner = if fast && !source.is_contiguous() && source.size() > 0 {
221 source.contiguous()
222 } else {
223 source
224 };
225 Self {
226 owner,
227 dim,
228 index: 0,
229 end,
230 use_element_views: rank == 1,
231 }
232 }
233
234 #[inline]
235 fn slice_at(&self, i: usize) -> Tensor {
236 if self.use_element_views {
237 self.owner.element_view(i)
238 } else {
239 self.owner.select(self.dim, i)
240 }
241 }
242}
243
244impl Iterator for TensorDimOwnedIterator {
245 type Item = Tensor;
246
247 #[inline]
248 fn next(&mut self) -> Option<Self::Item> {
249 if self.index >= self.end {
250 return None;
251 }
252 let i = self.index;
253 self.index += 1;
254 Some(self.slice_at(i))
255 }
256
257 #[inline]
258 fn size_hint(&self) -> (usize, Option<usize>) {
259 let rem = self.end - self.index;
260 (rem, Some(rem))
261 }
262}
263
impl ExactSizeIterator for TensorDimOwnedIterator {
    /// Number of slices remaining to be yielded.
    #[inline]
    fn len(&self) -> usize {
        self.end - self.index
    }
}
270
// Once exhausted, `next` keeps returning `None` (`index` never passes `end`).
impl FusedIterator for TensorDimOwnedIterator {}
272
273impl std::iter::DoubleEndedIterator for TensorDimOwnedIterator {
274 #[inline]
275 fn next_back(&mut self) -> Option<Self::Item> {
276 if self.index >= self.end {
277 return None;
278 }
279 self.end -= 1;
280 Some(self.slice_at(self.end))
281 }
282
283 #[inline]
284 fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
285 if self.index >= self.end {
286 return None;
287 }
288 let new_end = self.end.saturating_sub(n + 1);
289 if new_end < self.index {
290 self.index = self.end;
291 return None;
292 }
293 self.end = new_end;
294 Some(self.slice_at(self.end))
295 }
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    /// Iterating dim 0 of a 2x3 tensor yields two length-3 rows in order,
    /// and the iterator is exhausted after exactly two items.
    #[test]
    fn test_iter_dim_shapes() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let mut it0 = t.iter_dim(0);
        let a = it0.next().unwrap();
        let b = it0.next().unwrap();
        assert!(it0.next().is_none());
        assert_eq!(a.shape().dims(), vec![3]);
        assert_eq!(b.shape().dims(), vec![3]);
        assert_eq!(a.data(), &[1.0, 2.0, 3.0]);
        assert_eq!(b.data(), &[4.0, 5.0, 6.0]);
    }

    /// Iterating the middle dim of a 2x3x4 tensor yields three rank-2
    /// slices of shape [2, 4] (the iterated dimension is removed).
    #[test]
    fn test_iter_dim_rank3() {
        let vals: Vec<f32> = (0..24).map(|x| x as f32).collect();
        let t = Tensor::from_slice(&vals, vec![2, 3, 4]).unwrap();
        let v: Vec<Tensor> = t.iter_dim(1).collect();
        assert_eq!(v.len(), 3);
        assert_eq!(v[0].shape().rank(), 2);
        assert_eq!(v[0].shape().dims(), vec![2, 4]);
    }

    /// Gradients flow back through slices produced by `iter_dim`:
    /// doubling each row and summing gives d(loss)/dt = 2 everywhere.
    #[test]
    fn test_iter_dim_gradient_propagation() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
            .unwrap()
            .with_requires_grad();
        let collected: Tensor = t.iter_dim(0).map(|row| row.mul_scalar(2.0)).collect();
        let mut loss = collected.sum();
        loss.backward(None);
        let g = t.grad_owned().unwrap();
        assert_eq!(g.data(), &[2.0, 2.0, 2.0, 2.0]);
    }
}