train_station/tensor/iterator/
value.rs1use crate::tensor::core::Tensor;
4use crate::tensor::iterator::collect::optimized_copy;
5use std::iter::FusedIterator;
6use std::marker::PhantomData;
7
/// Iterator yielding a tensor's values as `f32` in logical (row-major) order.
///
/// Borrows the source tensor for `'a`. For sources that cannot be walked in
/// place (non-contiguous or zero-stride — see `new`), a private contiguous
/// copy is materialized and kept alive in `owner`.
pub struct TensorValuesIter<'a> {
    // Base pointer of the buffer being iterated: either the source tensor's
    // own storage, or the owned contiguous copy in `owner`.
    ptr: *const f32,
    // Total number of elements to yield.
    len: usize,
    // Index of the next element to yield; `next` keeps `pos <= len`.
    pos: usize,
    // Owned contiguous copy backing `ptr` on the slow path; never read
    // directly, held only so the allocation outlives the iterator.
    #[allow(dead_code)]
    owner: Option<Tensor>,
    // Ties the iterator's lifetime to the borrowed source tensor so `ptr`
    // cannot dangle while the iterator is alive.
    _marker: PhantomData<&'a f32>,
}
21
22impl<'a> TensorValuesIter<'a> {
23 #[inline]
24 pub(crate) fn new(t: &'a Tensor) -> Self {
25 let has_zero_stride = t.strides().contains(&0);
27 if (t.is_contiguous() && !has_zero_stride) || t.size() == 0 {
28 let len = t.size();
29 let ptr = unsafe { t.as_ptr() };
30 Self {
31 ptr,
32 len,
33 pos: 0,
34 owner: None,
35 _marker: PhantomData,
36 }
37 } else {
38 let len = t.size();
40 let mut owned = Tensor::new_uninitialized(vec![len]);
41 unsafe {
42 copy_strided_to_contiguous_1d(t, owned.as_mut_ptr());
43 }
44 let ptr = unsafe { owned.as_ptr() };
45 Self {
46 ptr,
47 len,
48 pos: 0,
49 owner: Some(owned),
50 _marker: PhantomData,
51 }
52 }
53 }
54}
55
56impl<'a> Iterator for TensorValuesIter<'a> {
57 type Item = f32;
58
59 #[inline]
60 fn next(&mut self) -> Option<Self::Item> {
61 if self.pos >= self.len {
62 return None;
63 }
64 let i = self.pos;
65 self.pos += 1;
66 unsafe { Some(*self.ptr.add(i)) }
67 }
68
69 #[inline]
70 fn size_hint(&self) -> (usize, Option<usize>) {
71 let remaining = self.len.saturating_sub(self.pos);
72 (remaining, Some(remaining))
73 }
74}
75
impl<'a> ExactSizeIterator for TensorValuesIter<'a> {
    /// Exact number of values still to be yielded.
    #[inline]
    fn len(&self) -> usize {
        // `pos` never exceeds `len`, but saturate rather than risk underflow.
        self.len.saturating_sub(self.pos)
    }
}
82
// Fused: once `next` returns `None` (pos >= len) it keeps returning `None`.
impl<'a> FusedIterator for TensorValuesIter<'a> {}
84
impl Tensor {
    /// Returns an iterator over this tensor's values in logical order.
    ///
    /// Works for any layout: non-contiguous or zero-stride tensors are
    /// copied into a temporary contiguous buffer owned by the iterator.
    #[inline]
    pub fn iter_values(&self) -> TensorValuesIter<'_> {
        TensorValuesIter::new(self)
    }

    /// Returns a mutable iterator over this tensor's values.
    ///
    /// # Panics
    ///
    /// Panics if the tensor is not contiguous; call `contiguous()` first.
    #[inline]
    pub fn iter_values_mut(&mut self) -> std::slice::IterMut<'_, f32> {
        assert!(
            self.is_contiguous(),
            "iter_values_mut requires contiguous tensor; call contiguous() first"
        );
        self.data_mut().iter_mut()
    }
}
113
/// Copies the logical (row-major) element sequence of `src` into the dense
/// 1-D buffer at `dst_ptr`, undoing any striding in `src`.
///
/// Two paths:
/// - If the innermost stride is 1, each logical row is contiguous in `src`,
///   so rows are copied whole with `optimized_copy`.
/// - Otherwise every element's source offset is resolved individually via
///   `memory_offset`.
///
/// # Safety
///
/// `dst_ptr` must be valid for writing `src.size()` `f32` values, and
/// `src`'s buffer must be readable at every offset its shape and strides
/// can produce.
#[inline]
unsafe fn copy_strided_to_contiguous_1d(src: &Tensor, dst_ptr: *mut f32) {
    let size = src.size();
    if size == 0 {
        return;
    }
    let rank = src.shape().rank();
    let src_base = src.as_ptr();

    // Fast path: innermost dimension is densely packed — copy row by row.
    if rank >= 1 && src.stride(rank - 1) == 1 {
        let dims = src.shape().dims();
        let row_len = dims[rank - 1];
        // Number of rows = product of all outer dimensions (1 for rank-1).
        let outer: usize = if rank == 1 {
            1
        } else {
            dims[..rank - 1].iter().product()
        };
        let strides = src.strides();
        let mut coords = vec![0usize; rank];
        for outer_idx in 0..outer {
            // Decode the linear row index into outer coordinates
            // (row-major: last outer dimension varies fastest).
            if rank > 1 {
                let mut tmp = outer_idx;
                for i in (0..rank - 1).rev() {
                    let d = dims[i];
                    // Guard against zero-sized dims to avoid div/mod by zero.
                    coords[i] = if d == 0 { 0 } else { tmp % d };
                    if d != 0 {
                        tmp /= d;
                    }
                }
            }
            coords[rank - 1] = 0;
            // Source offset of the row start, honoring arbitrary strides.
            let mut src_off = 0usize;
            for i in 0..rank {
                src_off += coords[i] * strides[i];
            }
            // Destination row index in dense row-major order.
            let mut dst_row_index = 0usize;
            if rank > 1 {
                for i in 0..rank - 1 {
                    dst_row_index = dst_row_index * dims[i] + coords[i];
                }
            }
            let dst_off = dst_row_index * row_len;
            optimized_copy(src_base.add(src_off), dst_ptr.add(dst_off), row_len);
        }
        return;
    }

    // General path: resolve each element's source offset one at a time.
    let dims = src.shape().dims();
    for dst_idx in 0..size {
        // Decode the dense destination index into logical coordinates.
        let mut coords = vec![0usize; rank];
        let mut tmp = dst_idx;
        for i in (0..rank).rev() {
            let d = dims[i];
            coords[i] = if d == 0 { 0 } else { tmp % d };
            if d != 0 {
                tmp /= d;
            }
        }
        let src_off = src.memory_offset(&coords);
        *dst_ptr.add(dst_idx) = *src_base.add(src_off);
    }
}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// Contiguous tensors iterate their values in storage order.
    #[test]
    fn test_iter_values_contiguous() {
        let t = Tensor::from_slice(&(0..8).map(|i| i as f32).collect::<Vec<_>>(), vec![8]).unwrap();
        let vals: Vec<f32> = t.iter_values().collect();
        assert_eq!(vals, (0..8).map(|i| i as f32).collect::<Vec<_>>());
    }

    /// `iter_values_mut` allows in-place mutation of a contiguous tensor.
    #[test]
    fn test_iter_values_mut_contiguous() {
        let mut t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        for v in t.iter_values_mut() {
            *v *= 2.0;
        }
        assert_eq!(t.data(), &[2.0, 4.0, 6.0, 8.0]);
    }

    /// `iter_values_mut` must reject a non-contiguous (permuted) tensor.
    ///
    /// Fix over previous version: the trailing `owned.iter_values_mut()`
    /// call was unreachable dead code (it followed the statement that must
    /// panic), so it and the `owned` binding feeding it are removed.
    #[test]
    #[should_panic]
    fn test_iter_values_mut_panics_non_contiguous() {
        let base =
            Tensor::from_slice(&(0..6).map(|i| i as f32).collect::<Vec<_>>(), vec![2, 3]).unwrap();
        // Permuting swaps the strides, producing a non-contiguous view.
        let mut not_contig = base.permute(vec![1, 0]);
        let _ = not_contig.iter_values_mut();
    }
}