train_station/tensor/iterator/viewdim.rs

//! Dimension iterator: iterate over sub-tensors along a specific dimension

// Grad tracking is registered by select; no direct use here
use crate::tensor::core::Tensor;
use std::iter::FusedIterator;

/// Iterator that yields sub-tensors by slicing along a specific dimension.
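///
/// Created through [`Tensor::iter_dim`] or [`Tensor::iter`]. A minimal
/// illustrative sketch (marked `ignore`; it assumes the `from_slice` and
/// `data` APIs exercised in this module's tests):
///
/// ```ignore
/// let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
/// let rows: Vec<Tensor> = t.iter_dim(0).collect();
/// assert_eq!(rows.len(), 2); // one view per index of dim 0
/// assert_eq!(rows[0].data(), &[1.0, 2.0, 3.0]);
/// ```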
pub struct TensorDimIterator<'a> {
    source: &'a Tensor,
    dim: usize,
    index: usize,
    end: usize,
}

impl<'a> TensorDimIterator<'a> {
    #[inline]
    pub fn new(source: &'a Tensor, dim: usize) -> Self {
        let rank = source.shape().rank();
        assert!(rank > 0, "iter_dim: cannot iterate dims of a 0-D tensor");
        assert!(
            dim < rank,
            "iter_dim: dim {} out of bounds for rank {}",
            dim,
            rank
        );
        let end = source.shape().dims()[dim];
        Self {
            source,
            dim,
            index: 0,
            end,
        }
    }

    #[inline]
    fn slice_at(&self, i: usize) -> Tensor {
        // Build a view by selecting index `i` along `self.dim`.
        // `select` already registers GradFn::Select, so gradient tracking
        // is preserved without extra bookkeeping here.
        self.source.select(self.dim, i)
    }
}

impl<'a> Iterator for TensorDimIterator<'a> {
    type Item = Tensor;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        if self.index >= self.end {
            return None;
        }
        let i = self.index;
        self.index += 1;
        Some(self.slice_at(i))
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let rem = self.end - self.index;
        (rem, Some(rem))
    }
}

impl<'a> ExactSizeIterator for TensorDimIterator<'a> {
    #[inline]
    fn len(&self) -> usize {
        self.end - self.index
    }
}

impl<'a> FusedIterator for TensorDimIterator<'a> {}

impl<'a> DoubleEndedIterator for TensorDimIterator<'a> {
    #[inline]
    fn next_back(&mut self) -> Option<Self::Item> {
        if self.index >= self.end {
            return None;
        }
        self.end -= 1;
        Some(self.slice_at(self.end))
    }

    #[inline]
    fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
        // Skipping `n` elements from the back leaves nothing to yield once
        // `n` reaches the remaining length, so exhaust the iterator instead
        // of wrapping around to a stale front element.
        let remaining = self.end - self.index;
        if n >= remaining {
            self.index = self.end;
            return None;
        }
        self.end -= n + 1;
        Some(self.slice_at(self.end))
    }
}

impl Tensor {
    /// Iterate over sub-tensors along a specific dimension.
    /// Each item is a view with that dimension removed (rank reduced by one).
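    ///
    /// Illustrative sketch (marked `ignore`; it assumes the `from_slice` and
    /// `shape` APIs used in this module's tests):
    ///
    /// ```ignore
    /// let vals: Vec<f32> = (0..24).map(|x| x as f32).collect();
    /// let t = Tensor::from_slice(&vals, vec![2, 3, 4]).unwrap();
    /// // Iterating over dim 1 yields 3 views, each shaped [2, 4].
    /// for sub in t.iter_dim(1) {
    ///     assert_eq!(sub.shape().dims(), vec![2, 4]);
    /// }
    /// ```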
    #[inline]
    pub fn iter_dim(&self, dim: usize) -> TensorDimIterator<'_> {
        TensorDimIterator::new(self, dim)
    }

    /// Default iterator over the outermost dimension, yielding sub-tensors.
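    ///
    /// Equivalent to `iter_dim(0)`. A sketch under the same assumptions as
    /// the [`Tensor::iter_dim`] example:
    ///
    /// ```ignore
    /// let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    /// for row in t.iter() {
    ///     assert_eq!(row.shape().dims(), vec![2]); // dim 0 removed
    /// }
    /// ```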
    #[inline]
    pub fn iter(&self) -> TensorDimIterator<'_> {
        self.iter_dim(0)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_iter_dim_shapes() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let mut it0 = t.iter_dim(0);
        let a = it0.next().unwrap();
        let b = it0.next().unwrap();
        assert!(it0.next().is_none());
        assert_eq!(a.shape().dims(), vec![3]);
        assert_eq!(b.shape().dims(), vec![3]);
        assert_eq!(a.data(), &[1.0, 2.0, 3.0]);
        assert_eq!(b.data(), &[4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_iter_dim_rank3() {
        let vals: Vec<f32> = (0..24).map(|x| x as f32).collect();
        let t = Tensor::from_slice(&vals, vec![2, 3, 4]).unwrap();
        // Iterating over dim 1 yields items shaped [2, 4] (dim 1 removed).
        let v: Vec<Tensor> = t.iter_dim(1).collect();
        assert_eq!(v.len(), 3);
        assert_eq!(v[0].shape().rank(), 2);
        assert_eq!(v[0].shape().dims(), vec![2, 4]);
    }

    #[test]
    fn test_iter_dim_gradient_propagation() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
            .unwrap()
            .with_requires_grad();
        // Operate per-row, then collect and sum.
        let collected: Tensor = t.iter_dim(0).map(|row| row.mul_scalar(2.0)).collect();
        let mut loss = collected.sum();
        loss.backward(None);
        let g = t.grad_owned().unwrap();
        // Each element receives gradient 2.0 through the 2x scaling and sum.
        assert_eq!(g.data(), &[2.0, 2.0, 2.0, 2.0]);
    }
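
    // Sketch exercising the DoubleEndedIterator impl through rev() and
    // nth_back; assumes the same Tensor APIs as the tests above.
    #[test]
    fn test_iter_dim_reverse() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        // rev() walks the rows back to front.
        let rows: Vec<Tensor> = t.iter_dim(0).rev().collect();
        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0].data(), &[4.0, 5.0, 6.0]);
        assert_eq!(rows[1].data(), &[1.0, 2.0, 3.0]);
        // nth_back past the remaining length exhausts the iterator.
        let mut it = t.iter_dim(0);
        assert!(it.nth_back(2).is_none());
        assert!(it.next().is_none());
    }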
}