train_station/tensor/iterator/viewdim.rs

1//! Dimension iterator: iterate over sub-tensors along a specific dimension
2
3// Grad tracking is registered by select; no direct use here
4use crate::tensor::core::utils::should_use_fast_path;
5use crate::tensor::core::Tensor;
6use std::iter::{ExactSizeIterator, FusedIterator};
7
/// Iterator that yields sub-tensors by slicing along a specific dimension.
///
/// Borrows the source tensor for the iterator's lifetime; slices are produced
/// lazily from either end (`next` / `next_back`) via `slice_at`.
pub struct TensorDimIterator<'a> {
    // Borrowed source tensor being iterated.
    source: &'a Tensor,
    // Dimension index along which slices are taken.
    dim: usize,
    // Next front index to yield; advances in `next`.
    index: usize,
    // One-past-the-last back index; retreats in `next_back`. Invariant: index <= end.
    end: usize,
    // One-time contiguous owner for fast path on non-contiguous sources
    owner: Option<Tensor>,
    // For 1D tensors, prefer element views for better grad semantics and collection perf
    use_element_views: bool,
}
19
20impl<'a> TensorDimIterator<'a> {
21    #[inline]
22    pub fn new(source: &'a Tensor, dim: usize) -> Self {
23        let rank = source.shape().rank();
24        assert!(rank > 0, "iter_dim: cannot iterate dims of a 0-D tensor");
25        let dim = if dim < rank {
26            dim
27        } else {
28            panic!("iter_dim: dim {} out of bounds for rank {}", dim, rank)
29        };
30        let end = source.shape().dims()[dim];
31        // Decide fast path once at construction (no-grad or inference)
32        let fast = should_use_fast_path(&[source]);
33        // One-time contiguous materialization for better cache/linearization
34        let owner = if fast && !source.is_contiguous() && source.size() > 0 {
35            Some(source.contiguous())
36        } else {
37            None
38        };
39        Self {
40            source,
41            dim,
42            index: 0,
43            end,
44            owner,
45            use_element_views: rank == 1,
46        }
47    }
48
49    #[inline]
50    fn slice_at(&self, i: usize) -> Tensor {
51        // Choose base tensor once according to construction-time policy
52        let base: &Tensor = match &self.owner {
53            Some(o) => o,
54            None => self.source,
55        };
56        if self.use_element_views {
57            // 1D optimization: scalar element views cooperate with grad tracking
58            base.element_view(i)
59        } else {
60            // Multi-dim: rely on select for shape/stride-correct subviews
61            base.select(self.dim, i)
62        }
63    }
64}
65
66impl<'a> Iterator for TensorDimIterator<'a> {
67    type Item = Tensor;
68
69    #[inline]
70    fn next(&mut self) -> Option<Self::Item> {
71        if self.index >= self.end {
72            return None;
73        }
74        let i = self.index;
75        self.index += 1;
76        Some(self.slice_at(i))
77    }
78
79    #[inline]
80    fn size_hint(&self) -> (usize, Option<usize>) {
81        let rem = self.end - self.index;
82        (rem, Some(rem))
83    }
84}
85
86impl<'a> ExactSizeIterator for TensorDimIterator<'a> {
87    #[inline]
88    fn len(&self) -> usize {
89        self.end - self.index
90    }
91}
92
// Sound to mark fused: once `index` reaches `end`, neither field moves again,
// so `next` keeps returning `None`.
impl<'a> FusedIterator for TensorDimIterator<'a> {}
94
95impl<'a> std::iter::DoubleEndedIterator for TensorDimIterator<'a> {
96    #[inline]
97    fn next_back(&mut self) -> Option<Self::Item> {
98        if self.index >= self.end {
99            return None;
100        }
101        self.end -= 1;
102        Some(self.slice_at(self.end))
103    }
104
105    #[inline]
106    fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
107        if self.index >= self.end {
108            return None;
109        }
110        let new_end = self.end.saturating_sub(n + 1);
111        if new_end < self.index {
112            self.index = self.end;
113            return None;
114        }
115        self.end = new_end;
116        Some(self.slice_at(self.end))
117    }
118}
119
120impl Tensor {
121    /// Iterate over sub-tensors along a specific dimension.
122    ///
123    /// Produces view tensors by slicing along the given dimension; each item
124    /// has that dimension removed (rank - 1). Views share storage and preserve
125    /// gradient tracking semantics.
126    ///
127    /// # Arguments
128    ///
129    /// * `dim` - Dimension to iterate over
130    ///
131    /// # Examples
132    ///
133    /// ```
134    /// use train_station::tensor::TensorCollectExt;
135    /// use train_station::Tensor;
136    ///
137    /// let t = Tensor::from_slice(&(1..=6).map(|i| i as f32).collect::<Vec<_>>(), vec![2, 3]).unwrap();
138    /// let out = t.iter_dim(0).map(|row| row.add_scalar(1.0)).collect_shape(vec![2, 3]);
139    /// assert_eq!(out.data(), &[2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
140    /// ```
141    #[inline]
142    pub fn iter_dim(&self, dim: usize) -> TensorDimIterator<'_> {
143        TensorDimIterator::new(self, dim)
144    }
145
146    /// Default iterator over the outermost dimension, yielding sub-tensors (N-D) or scalar views (1-D).
147    ///
148    /// This is equivalent to `iter_dim(0)` with an important optimization:
149    /// - For 1-D tensors, it yields scalar element views of shape `[1]` (same as `iter_flat()`)
150    ///   to maximize GradTrack cooperation and collection performance.
151    /// - For N-D tensors (rank > 1), it yields sub-tensors with the outermost dimension removed
152    ///   (rank − 1), suitable for row/batch-wise processing.
153    ///
154    /// Views share storage with the source tensor and preserve gradient tracking semantics.
155    /// Use `collect_shape([..])` to reconstruct shape efficiently after per-item transforms.
156    ///
157    /// # Returns
158    ///
159    /// An iterator producing view tensors for each slice along the outermost dimension
160    /// (or scalar views for 1-D).
161    ///
162    /// # Examples
163    ///
164    /// 1-D: element views (shape `[1]`) and shape-preserving collection
165    ///
166    /// ```
167    /// use train_station::tensor::TensorCollectExt;
168    /// use train_station::Tensor;
169    ///
170    /// let v = Tensor::from_slice(&(0..6).map(|i| i as f32).collect::<Vec<_>>(), vec![6]).unwrap();
171    /// let out = v.iter().map(|e| e.add_scalar(1.0)).collect_shape(vec![6]);
172    /// assert_eq!(out.data(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
173    /// ```
174    ///
175    /// 2-D: row-wise transforms and shape-preserving collection
176    ///
177    /// ```
178    /// use train_station::tensor::TensorCollectExt;
179    /// use train_station::Tensor;
180    ///
181    /// let m = Tensor::from_slice(&(1..=6).map(|i| i as f32).collect::<Vec<_>>(), vec![2, 3]).unwrap();
182    /// let y = m.iter().map(|row| row.mul_scalar(2.0)).collect_shape(vec![2, 3]);
183    /// assert_eq!(y.data(), &[2.0, 4.0, 6.0, 8.0, 10.0, 12.0]);
184    /// ```
185    #[inline]
186    pub fn iter(&self) -> TensorDimIterator<'_> {
187        self.iter_dim(0)
188    }
189
190    /// Explicit alias for outermost-dimension iteration of sub-tensors.
191    /// Equivalent to `iter_dim(0)`.
192    #[inline]
193    pub fn outer_iter(&self) -> TensorDimIterator<'_> {
194        self.iter_dim(0)
195    }
196}
197
/// Owned variant of TensorDimIterator enabling `IntoIterator for Tensor` and iterator flattening
pub struct TensorDimOwnedIterator {
    // Tensor that backs every yielded view (possibly a one-time contiguous copy of the source).
    owner: Tensor,
    // Dimension index along which slices are taken.
    dim: usize,
    // Next front index to yield; advances in `next`.
    index: usize,
    // One-past-the-last back index; retreats in `next_back`. Invariant: index <= end.
    end: usize,
    // For 1D tensors, using element views cooperates best with grad tracking and collection perf
    use_element_views: bool,
}
207
208impl TensorDimOwnedIterator {
209    #[inline]
210    pub fn new(source: Tensor, dim: usize) -> Self {
211        let rank = source.shape().rank();
212        assert!(rank > 0, "iter_dim: cannot iterate dims of a 0-D tensor");
213        let dim = if dim < rank {
214            dim
215        } else {
216            panic!("iter_dim: dim {} out of bounds for rank {}", dim, rank)
217        };
218        let end = source.shape().dims()[dim];
219        let fast = should_use_fast_path(&[&source]);
220        let owner = if fast && !source.is_contiguous() && source.size() > 0 {
221            source.contiguous()
222        } else {
223            source
224        };
225        Self {
226            owner,
227            dim,
228            index: 0,
229            end,
230            use_element_views: rank == 1,
231        }
232    }
233
234    #[inline]
235    fn slice_at(&self, i: usize) -> Tensor {
236        if self.use_element_views {
237            self.owner.element_view(i)
238        } else {
239            self.owner.select(self.dim, i)
240        }
241    }
242}
243
244impl Iterator for TensorDimOwnedIterator {
245    type Item = Tensor;
246
247    #[inline]
248    fn next(&mut self) -> Option<Self::Item> {
249        if self.index >= self.end {
250            return None;
251        }
252        let i = self.index;
253        self.index += 1;
254        Some(self.slice_at(i))
255    }
256
257    #[inline]
258    fn size_hint(&self) -> (usize, Option<usize>) {
259        let rem = self.end - self.index;
260        (rem, Some(rem))
261    }
262}
263
264impl ExactSizeIterator for TensorDimOwnedIterator {
265    #[inline]
266    fn len(&self) -> usize {
267        self.end - self.index
268    }
269}
270
// Sound to mark fused: an exhausted iterator never resets `index`/`end`,
// so `next` keeps returning `None`.
impl FusedIterator for TensorDimOwnedIterator {}
272
273impl std::iter::DoubleEndedIterator for TensorDimOwnedIterator {
274    #[inline]
275    fn next_back(&mut self) -> Option<Self::Item> {
276        if self.index >= self.end {
277            return None;
278        }
279        self.end -= 1;
280        Some(self.slice_at(self.end))
281    }
282
283    #[inline]
284    fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
285        if self.index >= self.end {
286            return None;
287        }
288        let new_end = self.end.saturating_sub(n + 1);
289        if new_end < self.index {
290            self.index = self.end;
291            return None;
292        }
293        self.end = new_end;
294        Some(self.slice_at(self.end))
295    }
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    // Front iteration over dim 0 of a 2x3 tensor yields exactly two length-3 rows.
    #[test]
    fn test_iter_dim_shapes() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let mut it0 = t.iter_dim(0);
        let a = it0.next().unwrap();
        let b = it0.next().unwrap();
        assert!(it0.next().is_none());
        assert_eq!(a.shape().dims(), vec![3]);
        assert_eq!(b.shape().dims(), vec![3]);
        assert_eq!(a.data(), &[1.0, 2.0, 3.0]);
        assert_eq!(b.data(), &[4.0, 5.0, 6.0]);
    }

    // Iterating a middle dimension removes it from each yielded sub-tensor.
    #[test]
    fn test_iter_dim_rank3() {
        let vals: Vec<f32> = (0..24).map(|x| x as f32).collect();
        let t = Tensor::from_slice(&vals, vec![2, 3, 4]).unwrap();
        // iter over dim 1 → each item is [2,4]-shaped (with dim 1 removed)
        let v: Vec<Tensor> = t.iter_dim(1).collect();
        assert_eq!(v.len(), 3);
        assert_eq!(v[0].shape().rank(), 2);
        assert_eq!(v[0].shape().dims(), vec![2, 4]);
    }

    // Gradients flow back through per-slice ops that are collected into a new tensor.
    #[test]
    fn test_iter_dim_gradient_propagation() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
            .unwrap()
            .with_requires_grad();
        // operate per-row then collect and sum
        let collected: Tensor = t.iter_dim(0).map(|row| row.mul_scalar(2.0)).collect();
        let mut loss = collected.sum();
        loss.backward(None);
        let g = t.grad_owned().unwrap();
        // Each element receives gradient 2.0 through the 2x scaling and sum
        assert_eq!(g.data(), &[2.0, 2.0, 2.0, 2.0]);
    }
}