train_station/tensor/iterator/
value.rs

1//! High-performance value iterators over tensor data
2
3use crate::tensor::core::Tensor;
4use crate::tensor::iterator::collect::optimized_copy;
5use std::iter::FusedIterator;
6use std::marker::PhantomData;
7
8/// Read-only iterator over tensor element values (by value)
9///
10/// - Contiguous tensors: zero-copy over underlying memory
11/// - Non-contiguous tensors: materializes a contiguous view once, then iterates
12pub struct TensorValuesIter<'a> {
13    ptr: *const f32,
14    len: usize,
15    pos: usize,
16    // Keep an owned contiguous buffer alive for non-contiguous tensors
17    #[allow(dead_code)]
18    owner: Option<Tensor>,
19    _marker: PhantomData<&'a f32>,
20}
21
22impl<'a> TensorValuesIter<'a> {
23    #[inline]
24    pub(crate) fn new(t: &'a Tensor) -> Self {
25        // Treat any zero-stride view as non-contiguous to avoid unsafe flat iteration
26        let has_zero_stride = t.strides().contains(&0);
27        if (t.is_contiguous() && !has_zero_stride) || t.size() == 0 {
28            let len = t.size();
29            let ptr = unsafe { t.as_ptr() };
30            Self {
31                ptr,
32                len,
33                pos: 0,
34                owner: None,
35                _marker: PhantomData,
36            }
37        } else {
38            // Materialize once into a contiguous, aligned buffer WITHOUT grad wiring
39            let len = t.size();
40            let mut owned = Tensor::new_uninitialized(vec![len]);
41            unsafe {
42                copy_strided_to_contiguous_1d(t, owned.as_mut_ptr());
43            }
44            let ptr = unsafe { owned.as_ptr() };
45            Self {
46                ptr,
47                len,
48                pos: 0,
49                owner: Some(owned),
50                _marker: PhantomData,
51            }
52        }
53    }
54}
55
56impl<'a> Iterator for TensorValuesIter<'a> {
57    type Item = f32;
58
59    #[inline]
60    fn next(&mut self) -> Option<Self::Item> {
61        if self.pos >= self.len {
62            return None;
63        }
64        let i = self.pos;
65        self.pos += 1;
66        unsafe { Some(*self.ptr.add(i)) }
67    }
68
69    #[inline]
70    fn size_hint(&self) -> (usize, Option<usize>) {
71        let remaining = self.len.saturating_sub(self.pos);
72        (remaining, Some(remaining))
73    }
74}
75
76impl<'a> ExactSizeIterator for TensorValuesIter<'a> {
77    #[inline]
78    fn len(&self) -> usize {
79        self.len.saturating_sub(self.pos)
80    }
81}
82
83impl<'a> FusedIterator for TensorValuesIter<'a> {}
84
85impl Tensor {
86    /// Iterate over tensor values efficiently (by value)
87    ///
88    /// - Contiguous tensors: zero-copy over the underlying memory
89    /// - Non-contiguous tensors: materializes a contiguous view once, then iterates
90    ///
91    /// Gradient tracking note: This iterator yields values (f32) and does not
92    /// carry per-element gradient semantics. Use whole-tensor ops for autograd.
93    #[inline]
94    pub fn iter_values(&self) -> TensorValuesIter<'_> {
95        TensorValuesIter::new(self)
96    }
97
98    /// Iterate mutably over tensor values (contiguous only)
99    ///
100    /// This returns a standard slice iterator over the underlying memory.
101    ///
102    /// Panics if the tensor is not contiguous. Use `contiguous()` beforehand
103    /// to materialize a contiguous buffer if needed.
104    #[inline]
105    pub fn iter_values_mut(&mut self) -> std::slice::IterMut<'_, f32> {
106        assert!(
107            self.is_contiguous(),
108            "iter_values_mut requires contiguous tensor; call contiguous() first"
109        );
110        self.data_mut().iter_mut()
111    }
112}
113
114/// Copy an arbitrary-strided tensor into a 1D contiguous buffer (aligned)
115#[inline]
116unsafe fn copy_strided_to_contiguous_1d(src: &Tensor, dst_ptr: *mut f32) {
117    let size = src.size();
118    if size == 0 {
119        return;
120    }
121    let rank = src.shape().rank();
122    let src_base = src.as_ptr();
123
124    if rank >= 1 && src.stride(rank - 1) == 1 {
125        // Fast row-by-row path: last dim contiguous
126        let dims = src.shape().dims();
127        let row_len = dims[rank - 1];
128        let outer: usize = if rank == 1 {
129            1
130        } else {
131            dims[..rank - 1].iter().product()
132        };
133        let strides = src.strides();
134        let mut coords = vec![0usize; rank];
135        for outer_idx in 0..outer {
136            if rank > 1 {
137                let mut tmp = outer_idx;
138                for i in (0..rank - 1).rev() {
139                    let d = dims[i];
140                    coords[i] = if d == 0 { 0 } else { tmp % d };
141                    if d != 0 {
142                        tmp /= d;
143                    }
144                }
145            }
146            coords[rank - 1] = 0;
147            let mut src_off = 0usize;
148            for i in 0..rank {
149                src_off += coords[i] * strides[i];
150            }
151            // destination linear row index
152            let mut dst_row_index = 0usize;
153            if rank > 1 {
154                for i in 0..rank - 1 {
155                    dst_row_index = dst_row_index * dims[i] + coords[i];
156                }
157            }
158            let dst_off = dst_row_index * row_len;
159            optimized_copy(src_base.add(src_off), dst_ptr.add(dst_off), row_len);
160        }
161        return;
162    }
163
164    // General fallback: compute coordinates for each element
165    let dims = src.shape().dims();
166    for dst_idx in 0..size {
167        let mut coords = vec![0usize; rank];
168        let mut tmp = dst_idx;
169        for i in (0..rank).rev() {
170            let d = dims[i];
171            coords[i] = if d == 0 { 0 } else { tmp % d };
172            if d != 0 {
173                tmp /= d;
174            }
175        }
176        let src_off = src.memory_offset(&coords);
177        *dst_ptr.add(dst_idx) = *src_base.add(src_off);
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// `[0.0, 1.0, ..., (n - 1) as f32]`, used as deterministic fixture data.
    fn ramp(n: usize) -> Vec<f32> {
        (0..n).map(|i| i as f32).collect()
    }

    #[test]
    fn test_iter_values_contiguous() {
        let t = Tensor::from_slice(&ramp(8), vec![8]).unwrap();
        let vals: Vec<f32> = t.iter_values().collect();
        assert_eq!(vals, ramp(8));
    }

    #[test]
    fn test_iter_values_permuted_is_row_major() {
        // A transposed view is expected to be non-contiguous, exercising the copy path;
        // the logical row-major values must come out either way.
        let base = Tensor::from_slice(&ramp(6), vec![2, 3]).unwrap();
        let transposed = base.permute(vec![1, 0]);
        let vals: Vec<f32> = transposed.iter_values().collect();
        assert_eq!(vals, vec![0.0, 3.0, 1.0, 4.0, 2.0, 5.0]);
    }

    #[test]
    fn test_iter_values_exact_size_and_fused() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let mut it = t.iter_values();
        assert_eq!(it.len(), 3);
        assert_eq!(it.next(), Some(1.0));
        assert_eq!(it.len(), 2);
        assert_eq!(it.size_hint(), (2, Some(2)));
        assert_eq!(it.by_ref().count(), 2);
        assert_eq!(it.len(), 0);
        assert_eq!(it.next(), None);
        assert_eq!(it.next(), None);
    }

    #[test]
    fn test_iter_values_mut_contiguous() {
        let mut t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        for v in t.iter_values_mut() {
            *v *= 2.0;
        }
        assert_eq!(t.data(), &[2.0, 4.0, 6.0, 8.0]);
    }

    #[test]
    fn test_iter_values_mut_after_contiguous() {
        let base = Tensor::from_slice(&ramp(6), vec![2, 3]).unwrap();
        let mut dense = base.permute(vec![1, 0]).contiguous();
        for v in dense.iter_values_mut() {
            *v += 10.0;
        }
        let vals: Vec<f32> = dense.iter_values().collect();
        assert_eq!(vals, vec![10.0, 13.0, 11.0, 14.0, 12.0, 15.0]);
    }

    #[test]
    #[should_panic(expected = "iter_values_mut requires contiguous tensor")]
    fn test_iter_values_mut_panics_non_contiguous() {
        let base = Tensor::from_slice(&ramp(6), vec![2, 3]).unwrap();
        let mut transposed = base.permute(vec![1, 0]);
        let _ = transposed.iter_values_mut();
    }
}