train_station/tensor/iterator/element.rs

//! Iterator over tensor elements as zero-copy scalar views
//!
//! Gradients are preserved implicitly: each yielded scalar is created via
//! `Tensor::element_view`, which registers a `GradFn::View` with a
//! `ViewMapping::LinearRange { start, step: 1, length: 1 }`. When collecting
//! mapped elements back into a tensor and reducing (e.g., `sum()`), gradients
//! propagate correctly to the original tensor without extra flags.
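//!
//! A compact sketch of that flow, mirroring `test_iter_flat_gradient_propagation`
//! at the bottom of this file (marked `ignore` so the executable test is not
//! duplicated as a doctest):
//!
//! ```ignore
//! let t = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap().with_requires_grad();
//! let shifted: Tensor = t.iter_elements().map(|e| e.add_scalar(1.0)).collect();
//! let mut loss = shifted.sum();
//! loss.backward(None);
//! // Each source element receives d(loss)/dt_i = 1.0
//! assert_eq!(t.grad_owned().unwrap().data(), &[1.0, 1.0, 1.0]);
//! ```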

use crate::tensor::core::Tensor;
use std::iter::{ExactSizeIterator, FusedIterator};

/// High-performance iterator over tensor elements as view tensors
///
/// Each element becomes a `Tensor` view of shape `[1]` that shares memory with
/// the source and preserves gradient tracking.
pub struct TensorElementIterator<'a> {
    pub(crate) source: &'a Tensor,
    pub(crate) position: usize,
    pub(crate) end: usize,
}

impl<'a> TensorElementIterator<'a> {
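    /// Create an iterator over all elements of `tensor` in row-major order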
    #[inline]
    pub fn new(tensor: &'a Tensor) -> Self {
        Self {
            source: tensor,
            position: 0,
            end: tensor.size(),
        }
    }

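    /// Create an iterator over the elements in `start..end`
    ///
    /// `end` is clamped to the tensor's size and `start` is clamped to `end`, so
    /// out-of-range or inverted ranges yield an empty iterator.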
    #[inline]
    pub fn with_range(tensor: &'a Tensor, start: usize, end: usize) -> Self {
        let end = end.min(tensor.size());
        let start = start.min(end);
        Self {
            source: tensor,
            position: start,
            end,
        }
    }

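    /// Create a `[1]`-shaped view of the element at `index` via `Tensor::element_view`
    ///
    /// Debug builds assert that `index` is in bounds for the source tensor.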
    #[inline]
    fn create_element_view(&self, index: usize) -> Tensor {
        debug_assert!(index < self.source.size());
        self.source.element_view(index)
    }
}

impl<'a> Iterator for TensorElementIterator<'a> {
    type Item = Tensor;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        if self.position < self.end {
            let view = self.create_element_view(self.position);
            self.position += 1;
            Some(view)
        } else {
            None
        }
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.end - self.position;
        (remaining, Some(remaining))
    }

    #[inline]
    fn count(self) -> usize {
        self.end - self.position
    }

    #[inline]
    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        let new_pos = self.position.saturating_add(n);
        if new_pos < self.end {
            self.position = new_pos + 1;
            Some(self.create_element_view(new_pos))
        } else {
            self.position = self.end;
            None
        }
    }

    #[inline]
    fn last(self) -> Option<Self::Item> {
        if self.position < self.end {
            let last_idx = self.end - 1;
            Some(self.create_element_view(last_idx))
        } else {
            None
        }
    }
}

impl<'a> ExactSizeIterator for TensorElementIterator<'a> {
    #[inline]
    fn len(&self) -> usize {
        self.end - self.position
    }
}

impl<'a> FusedIterator for TensorElementIterator<'a> {}

impl<'a> DoubleEndedIterator for TensorElementIterator<'a> {
    #[inline]
    fn next_back(&mut self) -> Option<Self::Item> {
        if self.position < self.end {
            self.end -= 1;
            Some(self.create_element_view(self.end))
        } else {
            None
        }
    }

    #[inline]
    fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
        // Compare against the remaining element count instead of a saturating
        // subtraction: when `n + 1 > self.end` the old subtraction wrapped to 0
        // and incorrectly re-yielded the front element instead of returning None.
        let remaining = self.end - self.position;
        if n < remaining {
            self.end -= n + 1;
            Some(self.create_element_view(self.end))
        } else {
            self.position = self.end;
            None
        }
    }
}

impl Tensor {
    /// Create an iterator over scalar elements (flattened view)
    ///
    /// Each yielded item is a `[1]`-shaped `Tensor` view that shares storage with
    /// the source. This iterator is GradTrack-aware; element operations propagate
    /// gradients to the original tensor when gradients are enabled.
    ///
    /// # Returns
    ///
    /// An iterator producing scalar view tensors in row-major order.
    ///
    /// # Examples
    ///
    /// Collect transformed elements back to the original shape using `collect_shape`:
    ///
    /// ```
    /// use train_station::tensor::TensorCollectExt;
    /// use train_station::Tensor;
    ///
    /// let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    /// let y = x
    ///     .iter_elements()
    ///     .map(|e| e.mul_scalar(2.0))
    ///     .collect_shape(vec![2, 2]);
    /// assert_eq!(y.data(), &[2.0, 4.0, 6.0, 8.0]);
    /// ```
    #[inline]
    pub fn iter_elements(&self) -> TensorElementIterator<'_> {
        TensorElementIterator::new(self)
    }

    /// Create an iterator over a clamped range of elements
    ///
    /// Produces scalar view tensors from `start..end` (clamped to `[0, size]`).
    ///
    /// # Arguments
    ///
    /// * `start` - Start index (inclusive)
    /// * `end` - End index (exclusive)
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let x = Tensor::from_slice(&(0..6).map(|i| i as f32).collect::<Vec<_>>(), vec![6]).unwrap();
    /// let vals: Vec<f32> = x.iter_range(2, 5).map(|e| e.value()).collect();
    /// assert_eq!(vals, vec![2.0, 3.0, 4.0]);
    /// ```
    #[inline]
    pub fn iter_range(&self, start: usize, end: usize) -> TensorElementIterator<'_> {
        TensorElementIterator::with_range(self, start, end)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::gradtrack::NoGradTrack;

    #[test]
    fn test_basic_iteration() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let vals: Vec<f32> = t.iter_elements().map(|e| e.value()).collect();
        assert_eq!(vals, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_range_iteration() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let vals: Vec<f32> = t.iter_range(1, 3).map(|e| e.value()).collect();
        assert_eq!(vals, vec![2.0, 3.0]);
    }

    #[test]
    fn test_iter_flat_gradient_propagation() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])
            .unwrap()
            .with_requires_grad();

        // Map add_scalar over scalar views, collect, then sum
        let collected: Tensor = t.iter_elements().map(|e| e.add_scalar(1.0)).collect();
        let mut loss = collected.sum();
        loss.backward(None);

        let g = t.grad_owned().unwrap();
        assert_eq!(g.shape().dims(), vec![5]);
        assert_eq!(g.data(), &[1.0, 1.0, 1.0, 1.0, 1.0]);
    }
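
    #[test]
    fn test_iter_elements_gradient_scaling() {
        // Companion sketch to the test above: scaling every element view by 2.0
        // should produce a gradient of 2.0 for each source element, since
        // d(sum(2 * x_i))/dx_i = 2.0.
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])
            .unwrap()
            .with_requires_grad();
        let doubled: Tensor = t.iter_elements().map(|e| e.mul_scalar(2.0)).collect();
        let mut loss = doubled.sum();
        loss.backward(None);
        let g = t.grad_owned().unwrap();
        assert_eq!(g.data(), &[2.0, 2.0, 2.0]);
    }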

    #[test]
    fn test_element_iterator_double_ended_and_exact_size() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let mut it = t.iter_elements();
        assert_eq!(it.len(), 4);
        assert_eq!(it.size_hint(), (4, Some(4)));
        assert_eq!(it.next_back().unwrap().value(), 4.0);
        assert_eq!(it.next().unwrap().value(), 1.0);
        assert_eq!(it.nth(1).unwrap().value(), 3.0);
        assert!(it.next().is_none());
        assert_eq!(t.iter_elements().last().unwrap().value(), 4.0);
    }

    #[test]
    fn test_iter_range_clamping_and_zero_sized() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let vals: Vec<f32> = t.iter_range(2, 10).map(|e| e.value()).collect();
        assert_eq!(vals, vec![3.0]);
        let empty = Tensor::new(vec![0]);
        let it = empty.iter_elements();
        assert_eq!(it.len(), 0);
        assert_eq!(it.size_hint(), (0, Some(0)));
    }

    #[test]
    fn test_iter_no_grad_guard() {
        let t = Tensor::from_slice(&[1.0, 2.0], vec![2])
            .unwrap()
            .with_requires_grad();
        let _g = NoGradTrack::new();
        let v = t.iter_elements().next().unwrap();
        assert!(!v.requires_grad());
    }
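
    #[test]
    fn test_nth_back_past_front_returns_none() {
        // Sanity check on DoubleEndedIterator::nth_back: skipping more elements
        // than remain must exhaust the iterator and return None rather than
        // yielding the front element again.
        let t = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let mut it = t.iter_elements();
        assert!(it.nth_back(5).is_none());
        assert!(it.next().is_none());

        let mut it2 = t.iter_elements();
        assert_eq!(it2.nth_back(1).unwrap().value(), 1.0);
        assert!(it2.nth_back(0).is_none());
    }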
}