train_station/tensor/iterator/
viewdim.rs1use crate::tensor::core::Tensor;
5use std::iter::{ExactSizeIterator, FusedIterator};
6
7pub struct TensorDimIterator<'a> {
9 source: &'a Tensor,
10 dim: usize,
11 index: usize,
12 end: usize,
13}
14
15impl<'a> TensorDimIterator<'a> {
16 #[inline]
17 pub fn new(source: &'a Tensor, dim: usize) -> Self {
18 let rank = source.shape().rank();
19 assert!(rank > 0, "iter_dim: cannot iterate dims of a 0-D tensor");
20 let dim = if dim < rank {
21 dim
22 } else {
23 panic!("iter_dim: dim {} out of bounds for rank {}", dim, rank)
24 };
25 let end = source.shape().dims()[dim];
26 Self {
27 source,
28 dim,
29 index: 0,
30 end,
31 }
32 }
33
34 #[inline]
35 fn slice_at(&self, i: usize) -> Tensor {
36 self.source.select(self.dim, i)
41 }
42}
43
44impl<'a> Iterator for TensorDimIterator<'a> {
45 type Item = Tensor;
46
47 #[inline]
48 fn next(&mut self) -> Option<Self::Item> {
49 if self.index >= self.end {
50 return None;
51 }
52 let i = self.index;
53 self.index += 1;
54 Some(self.slice_at(i))
55 }
56
57 #[inline]
58 fn size_hint(&self) -> (usize, Option<usize>) {
59 let rem = self.end - self.index;
60 (rem, Some(rem))
61 }
62}
63
64impl<'a> ExactSizeIterator for TensorDimIterator<'a> {
65 #[inline]
66 fn len(&self) -> usize {
67 self.end - self.index
68 }
69}
70
71impl<'a> FusedIterator for TensorDimIterator<'a> {}
72
73impl<'a> std::iter::DoubleEndedIterator for TensorDimIterator<'a> {
74 #[inline]
75 fn next_back(&mut self) -> Option<Self::Item> {
76 if self.index >= self.end {
77 return None;
78 }
79 self.end -= 1;
80 Some(self.slice_at(self.end))
81 }
82
83 #[inline]
84 fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
85 if self.index >= self.end {
86 return None;
87 }
88 let new_end = self.end.saturating_sub(n + 1);
89 if new_end < self.index {
90 self.index = self.end;
91 return None;
92 }
93 self.end = new_end;
94 Some(self.slice_at(self.end))
95 }
96}
97
98impl Tensor {
99 #[inline]
102 pub fn iter_dim(&self, dim: usize) -> TensorDimIterator<'_> {
103 TensorDimIterator::new(self, dim)
104 }
105 #[inline]
107 pub fn iter(&self) -> TensorDimIterator<'_> {
108 self.iter_dim(0)
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_iter_dim_shapes() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();

        // Step through dim 0 manually so we also observe the iterator ending
        // after exactly two rows.
        let mut rows = t.iter_dim(0);
        let first = rows.next().unwrap();
        let second = rows.next().unwrap();
        assert!(rows.next().is_none());

        assert_eq!(first.shape().dims(), vec![3]);
        assert_eq!(first.data(), &[1.0, 2.0, 3.0]);
        assert_eq!(second.shape().dims(), vec![3]);
        assert_eq!(second.data(), &[4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_iter_dim_rank3() {
        let vals: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let t = Tensor::from_slice(&vals, vec![2, 3, 4]).unwrap();

        // Slicing the middle axis of a [2, 3, 4] tensor gives three slices;
        // the first one must have dropped that axis.
        let slices: Vec<Tensor> = t.iter_dim(1).collect();
        assert_eq!(slices.len(), 3);
        let head = &slices[0];
        assert_eq!(head.shape().rank(), 2);
        assert_eq!(head.shape().dims(), vec![2, 4]);
    }

    #[test]
    fn test_iter_dim_gradient_propagation() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
            .unwrap()
            .with_requires_grad();

        // Every element of `t` contributes 2 * x to the summed loss (however
        // `collect` recombines the rows), so the expected gradient is all 2s.
        let doubled: Tensor = t.iter_dim(0).map(|row| row.mul_scalar(2.0)).collect();
        let mut loss = doubled.sum();
        loss.backward(None);

        let grad = t.grad_owned().unwrap();
        assert_eq!(grad.data(), &[2.0, 2.0, 2.0, 2.0]);
    }
}