train_station/tensor/iterator/element.rs

1//! Element iterator over tensor elements as zero-copy scalar views
2
3use crate::tensor::core::Tensor;
4use std::iter::{ExactSizeIterator, FusedIterator};
5
6/// High-performance iterator over tensor elements as view tensors
7///
8/// Each element becomes a `Tensor` view of shape `[1]` that shares memory with
9/// the source and preserves gradient tracking.
10pub struct TensorElementIterator<'a> {
11    pub(crate) source: &'a Tensor,
12    pub(crate) position: usize,
13    pub(crate) end: usize,
14}
15
16impl<'a> TensorElementIterator<'a> {
17    #[inline]
18    pub fn new(tensor: &'a Tensor) -> Self {
19        Self {
20            source: tensor,
21            position: 0,
22            end: tensor.size(),
23        }
24    }
25
26    #[inline]
27    pub fn with_range(tensor: &'a Tensor, start: usize, end: usize) -> Self {
28        let end = end.min(tensor.size());
29        let start = start.min(end);
30        Self {
31            source: tensor,
32            position: start,
33            end,
34        }
35    }
36
37    #[inline]
38    fn create_element_view(&self, index: usize) -> Tensor {
39        debug_assert!(index < self.source.size());
40        self.source.element_view(index)
41    }
42}
43
44impl<'a> Iterator for TensorElementIterator<'a> {
45    type Item = Tensor;
46
47    #[inline]
48    fn next(&mut self) -> Option<Self::Item> {
49        if self.position < self.end {
50            let view = self.create_element_view(self.position);
51            self.position += 1;
52            Some(view)
53        } else {
54            None
55        }
56    }
57
58    #[inline]
59    fn size_hint(&self) -> (usize, Option<usize>) {
60        let remaining = self.end - self.position;
61        (remaining, Some(remaining))
62    }
63
64    #[inline]
65    fn count(self) -> usize {
66        self.end - self.position
67    }
68
69    #[inline]
70    fn nth(&mut self, n: usize) -> Option<Self::Item> {
71        let new_pos = self.position.saturating_add(n);
72        if new_pos < self.end {
73            self.position = new_pos + 1;
74            Some(self.create_element_view(new_pos))
75        } else {
76            self.position = self.end;
77            None
78        }
79    }
80
81    #[inline]
82    fn last(self) -> Option<Self::Item> {
83        if self.position < self.end {
84            let last_idx = self.end - 1;
85            Some(self.create_element_view(last_idx))
86        } else {
87            None
88        }
89    }
90}
91
92impl<'a> ExactSizeIterator for TensorElementIterator<'a> {
93    #[inline]
94    fn len(&self) -> usize {
95        self.end - self.position
96    }
97}
98
99impl<'a> FusedIterator for TensorElementIterator<'a> {}
100
101impl<'a> DoubleEndedIterator for TensorElementIterator<'a> {
102    #[inline]
103    fn next_back(&mut self) -> Option<Self::Item> {
104        if self.position < self.end {
105            self.end -= 1;
106            Some(self.create_element_view(self.end))
107        } else {
108            None
109        }
110    }
111
112    #[inline]
113    fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
114        let new_end = self.end.saturating_sub(n + 1);
115        if new_end >= self.position {
116            self.end = new_end;
117            Some(self.create_element_view(self.end))
118        } else {
119            self.position = self.end;
120            None
121        }
122    }
123}
124
125impl Tensor {
126    /// Create an iterator over scalar elements (flattened view)
127    #[inline]
128    pub fn iter_elements(&self) -> TensorElementIterator<'_> {
129        TensorElementIterator::new(self)
130    }
131
132    /// Create an iterator over a clamped range of elements
133    #[inline]
134    pub fn iter_range(&self, start: usize, end: usize) -> TensorElementIterator<'_> {
135        TensorElementIterator::with_range(self, start, end)
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects the scalar value of every view yielded by `iter`.
    fn values(iter: impl Iterator<Item = Tensor>) -> Vec<f32> {
        iter.map(|e| e.value()).collect()
    }

    #[test]
    fn test_basic_iteration() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let vals: Vec<f32> = t.iter_elements().map(|e| e.value()).collect();
        assert_eq!(vals, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_range_iteration() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let vals: Vec<f32> = t.iter_range(1, 3).map(|e| e.value()).collect();
        assert_eq!(vals, vec![2.0, 3.0]);
    }

    #[test]
    fn test_range_is_clamped() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        assert_eq!(values(t.iter_range(2, 100)), vec![3.0, 4.0]);
        // Inverted and fully out-of-range requests yield nothing instead of panicking.
        assert!(values(t.iter_range(3, 1)).is_empty());
        assert!(values(t.iter_range(10, 20)).is_empty());
    }

    #[test]
    fn test_reverse_iteration() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        assert_eq!(values(t.iter_elements().rev()), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_len_nth_and_next_back() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let mut it = t.iter_elements();
        assert_eq!(it.len(), 4);
        assert_eq!(it.nth(1).map(|e| e.value()), Some(2.0));
        assert_eq!(it.len(), 2);
        assert_eq!(it.next_back().map(|e| e.value()), Some(4.0));
        assert_eq!(it.len(), 1);
        assert!(it.nth(5).is_none());
        assert!(it.next().is_none());
    }

    #[test]
    fn test_nth_back_within_range() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let mut it = t.iter_elements();
        assert_eq!(it.nth_back(1).map(|e| e.value()), Some(3.0));
        assert_eq!(values(it), vec![1.0, 2.0]);
    }

    #[test]
    fn test_last_and_count() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        assert_eq!(t.iter_range(0, 3).last().map(|e| e.value()), Some(3.0));
        assert_eq!(t.iter_range(1, 4).count(), 3);
        assert!(t.iter_range(2, 2).last().is_none());
    }
}
156}