train_station/tensor/iterator/
element.rs1use crate::tensor::core::Tensor;
4use std::iter::{ExactSizeIterator, FusedIterator};
5
6pub struct TensorElementIterator<'a> {
11 pub(crate) source: &'a Tensor,
12 pub(crate) position: usize,
13 pub(crate) end: usize,
14}
15
16impl<'a> TensorElementIterator<'a> {
17 #[inline]
18 pub fn new(tensor: &'a Tensor) -> Self {
19 Self {
20 source: tensor,
21 position: 0,
22 end: tensor.size(),
23 }
24 }
25
26 #[inline]
27 pub fn with_range(tensor: &'a Tensor, start: usize, end: usize) -> Self {
28 let end = end.min(tensor.size());
29 let start = start.min(end);
30 Self {
31 source: tensor,
32 position: start,
33 end,
34 }
35 }
36
37 #[inline]
38 fn create_element_view(&self, index: usize) -> Tensor {
39 debug_assert!(index < self.source.size());
40 self.source.element_view(index)
41 }
42}
43
44impl<'a> Iterator for TensorElementIterator<'a> {
45 type Item = Tensor;
46
47 #[inline]
48 fn next(&mut self) -> Option<Self::Item> {
49 if self.position < self.end {
50 let view = self.create_element_view(self.position);
51 self.position += 1;
52 Some(view)
53 } else {
54 None
55 }
56 }
57
58 #[inline]
59 fn size_hint(&self) -> (usize, Option<usize>) {
60 let remaining = self.end - self.position;
61 (remaining, Some(remaining))
62 }
63
64 #[inline]
65 fn count(self) -> usize {
66 self.end - self.position
67 }
68
69 #[inline]
70 fn nth(&mut self, n: usize) -> Option<Self::Item> {
71 let new_pos = self.position.saturating_add(n);
72 if new_pos < self.end {
73 self.position = new_pos + 1;
74 Some(self.create_element_view(new_pos))
75 } else {
76 self.position = self.end;
77 None
78 }
79 }
80
81 #[inline]
82 fn last(self) -> Option<Self::Item> {
83 if self.position < self.end {
84 let last_idx = self.end - 1;
85 Some(self.create_element_view(last_idx))
86 } else {
87 None
88 }
89 }
90}
91
92impl<'a> ExactSizeIterator for TensorElementIterator<'a> {
93 #[inline]
94 fn len(&self) -> usize {
95 self.end - self.position
96 }
97}
98
99impl<'a> FusedIterator for TensorElementIterator<'a> {}
100
101impl<'a> DoubleEndedIterator for TensorElementIterator<'a> {
102 #[inline]
103 fn next_back(&mut self) -> Option<Self::Item> {
104 if self.position < self.end {
105 self.end -= 1;
106 Some(self.create_element_view(self.end))
107 } else {
108 None
109 }
110 }
111
112 #[inline]
113 fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
114 let new_end = self.end.saturating_sub(n + 1);
115 if new_end >= self.position {
116 self.end = new_end;
117 Some(self.create_element_view(self.end))
118 } else {
119 self.position = self.end;
120 None
121 }
122 }
123}
124
125impl Tensor {
126 #[inline]
128 pub fn iter_elements(&self) -> TensorElementIterator<'_> {
129 TensorElementIterator::new(self)
130 }
131
132 #[inline]
134 pub fn iter_range(&self, start: usize, end: usize) -> TensorElementIterator<'_> {
135 TensorElementIterator::with_range(self, start, end)
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects the scalar value of every view yielded by `iter`.
    fn values(iter: impl Iterator<Item = Tensor>) -> Vec<f32> {
        iter.map(|e| e.value()).collect()
    }

    /// Builds the four-element fixture shared by most tests below.
    fn four() -> Tensor {
        Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap()
    }

    #[test]
    fn test_basic_iteration() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let vals: Vec<f32> = t.iter_elements().map(|e| e.value()).collect();
        assert_eq!(vals, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_range_iteration() {
        let t = four();
        let vals: Vec<f32> = t.iter_range(1, 3).map(|e| e.value()).collect();
        assert_eq!(vals, vec![2.0, 3.0]);
    }

    #[test]
    fn test_range_is_clamped() {
        let t = four();
        assert_eq!(values(t.iter_range(2, 100)), vec![3.0, 4.0]);
        assert_eq!(t.iter_range(5, 10).len(), 0);
        assert_eq!(t.iter_range(3, 1).count(), 0);
    }

    #[test]
    fn test_reverse_iteration() {
        let t = four();
        assert_eq!(values(t.iter_elements().rev()), vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_exact_size_tracks_both_ends() {
        let t = four();
        let mut it = t.iter_elements();
        assert_eq!(it.len(), 4);
        assert!(it.next().is_some());
        assert!(it.next_back().is_some());
        assert_eq!(it.len(), 2);
        assert_eq!(it.size_hint(), (2, Some(2)));
    }

    #[test]
    fn test_nth_and_last() {
        let t = four();
        let mut it = t.iter_elements();
        assert_eq!(it.nth(1).map(|e| e.value()), Some(2.0));
        assert_eq!(it.last().map(|e| e.value()), Some(4.0));

        let mut past_end = t.iter_elements();
        assert!(past_end.nth(10).is_none());
        assert!(past_end.next().is_none());
    }

    #[test]
    fn test_nth_back_in_range() {
        let t = four();
        let mut it = t.iter_elements();
        assert_eq!(it.nth_back(1).map(|e| e.value()), Some(3.0));
        assert_eq!(it.len(), 2);
        assert_eq!(values(it), vec![1.0, 2.0]);
    }
}