train_station/tensor/iterator/
element.rs1use crate::tensor::core::Tensor;
10use std::iter::{ExactSizeIterator, FusedIterator};
11
12pub struct TensorElementIterator<'a> {
17 pub(crate) source: &'a Tensor,
18 pub(crate) position: usize,
19 pub(crate) end: usize,
20}
21
22impl<'a> TensorElementIterator<'a> {
23 #[inline]
24 pub fn new(tensor: &'a Tensor) -> Self {
25 Self {
26 source: tensor,
27 position: 0,
28 end: tensor.size(),
29 }
30 }
31
32 #[inline]
33 pub fn with_range(tensor: &'a Tensor, start: usize, end: usize) -> Self {
34 let end = end.min(tensor.size());
35 let start = start.min(end);
36 Self {
37 source: tensor,
38 position: start,
39 end,
40 }
41 }
42
43 #[inline]
44 fn create_element_view(&self, index: usize) -> Tensor {
45 debug_assert!(index < self.source.size());
46 self.source.element_view(index)
47 }
48}
49
50impl<'a> Iterator for TensorElementIterator<'a> {
51 type Item = Tensor;
52
53 #[inline]
54 fn next(&mut self) -> Option<Self::Item> {
55 if self.position < self.end {
56 let view = self.create_element_view(self.position);
57 self.position += 1;
58 Some(view)
59 } else {
60 None
61 }
62 }
63
64 #[inline]
65 fn size_hint(&self) -> (usize, Option<usize>) {
66 let remaining = self.end - self.position;
67 (remaining, Some(remaining))
68 }
69
70 #[inline]
71 fn count(self) -> usize {
72 self.end - self.position
73 }
74
75 #[inline]
76 fn nth(&mut self, n: usize) -> Option<Self::Item> {
77 let new_pos = self.position.saturating_add(n);
78 if new_pos < self.end {
79 self.position = new_pos + 1;
80 Some(self.create_element_view(new_pos))
81 } else {
82 self.position = self.end;
83 None
84 }
85 }
86
87 #[inline]
88 fn last(self) -> Option<Self::Item> {
89 if self.position < self.end {
90 let last_idx = self.end - 1;
91 Some(self.create_element_view(last_idx))
92 } else {
93 None
94 }
95 }
96}
97
98impl<'a> ExactSizeIterator for TensorElementIterator<'a> {
99 #[inline]
100 fn len(&self) -> usize {
101 self.end - self.position
102 }
103}
104
105impl<'a> FusedIterator for TensorElementIterator<'a> {}
106
107impl<'a> DoubleEndedIterator for TensorElementIterator<'a> {
108 #[inline]
109 fn next_back(&mut self) -> Option<Self::Item> {
110 if self.position < self.end {
111 self.end -= 1;
112 Some(self.create_element_view(self.end))
113 } else {
114 None
115 }
116 }
117
118 #[inline]
119 fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
120 let new_end = self.end.saturating_sub(n + 1);
121 if new_end >= self.position {
122 self.end = new_end;
123 Some(self.create_element_view(self.end))
124 } else {
125 self.position = self.end;
126 None
127 }
128 }
129}
130
131impl Tensor {
132 #[inline]
158 pub fn iter_elements(&self) -> TensorElementIterator<'_> {
159 TensorElementIterator::new(self)
160 }
161
162 #[inline]
181 pub fn iter_range(&self, start: usize, end: usize) -> TensorElementIterator<'_> {
182 TensorElementIterator::with_range(self, start, end)
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gradtrack::NoGradTrack;

    /// Reads the scalar value out of every element view yielded by `iter`.
    fn values<I: Iterator<Item = Tensor>>(iter: I) -> Vec<f32> {
        iter.map(|elem| elem.value()).collect()
    }

    #[test]
    fn test_basic_iteration() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        assert_eq!(values(tensor.iter_elements()), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_range_iteration() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        assert_eq!(values(tensor.iter_range(1, 3)), vec![2.0, 3.0]);
    }

    #[test]
    fn test_iter_flat_gradient_propagation() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])
            .unwrap()
            .with_requires_grad();

        // Shift every element by one, then collect the results back into one tensor.
        let shifted: Tensor = source
            .iter_elements()
            .map(|elem| elem.add_scalar(1.0))
            .collect();
        let mut loss = shifted.sum();
        loss.backward(None);

        // d(sum(x + 1))/dx is 1 for every element.
        let grad = source.grad_owned().unwrap();
        assert_eq!(grad.shape().dims(), vec![5]);
        assert_eq!(grad.data(), &[1.0, 1.0, 1.0, 1.0, 1.0]);
    }

    #[test]
    fn test_element_iterator_double_ended_and_exact_size() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let mut iter = tensor.iter_elements();
        assert_eq!(iter.len(), 4);
        assert_eq!(iter.size_hint(), (4, Some(4)));
        // Consume one element from each end: the back yields 4.0, the front 1.0 ...
        assert_eq!(iter.next_back().unwrap().value(), 4.0);
        assert_eq!(iter.next().unwrap().value(), 1.0);
        // ... leaving [2.0, 3.0]; skipping one lands on 3.0 and drains the rest.
        assert_eq!(iter.nth(1).unwrap().value(), 3.0);
        assert!(iter.next().is_none());
        assert_eq!(tensor.iter_elements().last().unwrap().value(), 4.0);
    }

    #[test]
    fn test_iter_range_clamping_and_zero_sized() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        // An `end` past the tensor is clamped to its size.
        assert_eq!(values(tensor.iter_range(2, 10)), vec![3.0]);

        let empty = Tensor::new(vec![0]);
        let iter = empty.iter_elements();
        assert_eq!(iter.len(), 0);
        assert_eq!(iter.size_hint(), (0, Some(0)));
    }

    #[test]
    fn test_iter_no_grad_guard() {
        let tensor = Tensor::from_slice(&[1.0, 2.0], vec![2])
            .unwrap()
            .with_requires_grad();
        // Keep the guard alive for the rest of the test.
        let _guard = NoGradTrack::new();
        let first = tensor.iter_elements().next().unwrap();
        assert!(!first.requires_grad());
    }
}