// train_station/tensor/iterator/mod.rs

1//! Iterator module for tensor element-wise operations
2//!
3//! This module provides high-performance iterators over tensor elements, where each
4//! element is represented as a view tensor of shape `[1]`. This design allows for
5//! seamless integration with Rust's standard library iterator methods while
6//! leveraging the existing tensor operation framework and gradient tracking.
7//!
8//! # Key Features
9//!
10//! - **Standard Library Compatibility**: Full implementation of Iterator, ExactSizeIterator,
11//!   DoubleEndedIterator, FusedIterator, IntoIterator, and FromIterator traits
12//! - **Gradient Tracking**: Automatic gradient propagation through element operations
13//! - **Performance Optimized**: True zero-copy views with shared memory
14//! - **SIMD Compatible**: All operations use existing optimized tensor implementations
15//! - **Memory Efficient**: Adaptive view creation based on tensor size
16//! - **Zero-Copy Operations**: Element views share memory with source tensor
17//! - **Full Tensor Operations**: Each element supports all tensor methods
18//!
19//! # Performance Characteristics
20//!
21//! - **View Creation**: O(1) per element with true zero-copy views
22//! - **Memory Overhead**: ~64 bytes per view tensor (no data copying)
23//! - **SIMD Operations**: Full utilization of existing optimizations
24//! - **Gradient Tracking**: True gradient flow with element-level accumulation
25//! - **Iterator Overhead**: Minimal performance impact for element access
26//! - **Collection Optimization**: Efficient reconstruction from element views
27//!
28//! # Examples
29//!
30//! ## Basic Element Iteration
31//!
32//! ```
33//! use train_station::Tensor;
34//!
35//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
36//!
37//! // Basic iteration over elements
38//! for element in tensor.iter() {
39//!     println!("Element value: {}", element.value());
40//! }
41//!
42//! // Collect elements into a new tensor
43//! let collected: Tensor = tensor.iter().collect();
44//! assert_eq!(collected.data(), tensor.data());
45//! ```
46//!
47//! ## Element-Wise Transformations
48//!
49//! ```
50//! use train_station::Tensor;
51//!
52//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
53//!
54//! // Apply tensor operations to each element
55//! let doubled: Tensor = tensor.iter()
56//!     .map(|elem| elem.mul_scalar(2.0))
57//!     .collect();
58//!
59//! assert_eq!(doubled.data(), &[2.0, 4.0, 6.0]);
60//!
61//! // Chain multiple operations
62//! let transformed: Tensor = tensor.iter()
63//!     .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0)) // 2x + 1
64//!     .collect();
65//!
66//! assert_eq!(transformed.data(), &[3.0, 5.0, 7.0]);
67//! ```
68//!
69//! ## Advanced Iterator Operations
70//!
71//! ```
72//! use train_station::Tensor;
73//!
74//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
75//!
76//! // Filter elements based on values
77//! let large_values: Tensor = tensor.iter()
78//!     .filter(|elem| elem.value() > 3.0)
79//!     .collect();
80//!
81//! assert_eq!(large_values.data(), &[4.0, 5.0]);
82//!
83//! // Use enumerate for indexed operations
84//! let indexed: Tensor = tensor.iter()
85//!     .enumerate()
86//!     .map(|(i, elem)| elem.add_scalar(i as f32))
87//!     .collect();
88//!
89//! assert_eq!(indexed.data(), &[1.0, 3.0, 5.0, 7.0, 9.0]);
90//! ```
91//!
92//! ## Range Iteration
93//!
94//! ```
95//! use train_station::Tensor;
96//!
97//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
98//!
99//! // Iterate over a specific range
100//! let middle: Tensor = tensor.iter_range(1, 4)
101//!     .map(|elem| elem.mul_scalar(2.0))
102//!     .collect();
103//!
104//! assert_eq!(middle.data(), &[4.0, 6.0, 8.0]);
105//! ```
106//!
107//! ## Double-Ended Iteration
108//!
109//! ```
110//! use train_station::Tensor;
111//!
112//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
113//!
114//! // Reverse iteration
115//! let reversed: Tensor = tensor.iter().rev().collect();
116//! assert_eq!(reversed.data(), &[4.0, 3.0, 2.0, 1.0]);
117//!
118//! // Iterate from both ends
119//! let mut iter = tensor.iter();
120//! assert_eq!(iter.next().unwrap().value(), 1.0);
121//! assert_eq!(iter.next_back().unwrap().value(), 4.0);
122//! ```
123//!
124//! ## Gradient Tracking
125//!
126//! ```
127//! use train_station::Tensor;
128//!
129//! let tensor = Tensor::from_slice(&[1.0, 2.0], vec![2])
130//!     .unwrap()
131//!     .with_requires_grad();
132//!
133//! // Element operations maintain gradient tracking
134//! let result: Tensor = tensor.iter()
135//!     .map(|elem| elem.mul_scalar(2.0))
136//!     .collect();
137//!
138//! assert!(result.requires_grad());
139//! assert_eq!(result.data(), &[2.0, 4.0]);
140//! ```
141//!
142//! # Design Principles
143//!
144//! - **Zero-Copy Views**: Element views share memory with source tensor
145//! - **Full Tensor Operations**: Each element supports all tensor methods
146//! - **Standard Library Integration**: Complete compatibility with Rust iterators
147//! - **Performance First**: Optimized for high-performance element access
148//! - **Gradient Preservation**: Maintains gradtrack functionality through operations
149//! - **Memory Efficiency**: Minimal overhead for element iteration
150//! - **Type Safety**: Compile-time guarantees for iterator operations
151
152use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
153use crate::tensor::core::Tensor;
154use std::iter::{FromIterator, FusedIterator};
155
156/// High-performance iterator over tensor elements as view tensors
157///
158/// Each element becomes a proper `Tensor` view of shape `[1]` that can use
159/// all existing tensor operations and gradient tracking. Implements all
160/// standard iterator traits for maximum compatibility with Rust's ecosystem.
161///
162/// This iterator provides zero-copy access to tensor elements through view
163/// tensors, enabling efficient element-wise operations while maintaining
164/// full compatibility with Rust's standard library iterator methods.
165///
166/// # Performance
167///
168/// - **Zero-Copy Views**: Each element is a view tensor sharing memory with source
169/// - **O(1) Element Access**: Constant-time view creation for each element
170/// - **Memory Efficient**: ~64 bytes overhead per element view
171/// - **SIMD Compatible**: All tensor operations use existing optimizations
172/// - **Gradient Tracking**: Full gradtrack support through element operations
173///
174/// # Implementation Details
175///
176/// The iterator creates lightweight view tensors on-demand, sharing the same
177/// memory allocation as the source tensor. This ensures zero-copy semantics
178/// while maintaining full tensor operation compatibility.
179///
180/// Each element view is created using `Tensor::element_view()`, which provides
181/// a true view of the underlying data without any copying. The view tensors
182/// support all standard tensor operations including gradient tracking.
183///
184/// # Standard Library Compatibility
185///
186/// This iterator implements all standard iterator traits:
187/// - `Iterator`: Basic iteration with `next()` and `size_hint()`
188/// - `ExactSizeIterator`: Precise size information with `len()`
189/// - `DoubleEndedIterator`: Reverse iteration with `next_back()`
190/// - `FusedIterator`: Fused iteration for better performance
191/// - `IntoIterator`: Automatic conversion for `for` loops
192///
193/// # Examples
194///
195/// ## Basic Iteration
196///
197/// ```
198/// use train_station::Tensor;
199///
200/// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
201///
202/// // Basic iteration
203/// for element in tensor.iter() {
204///     println!("Element value: {}", element.value());
205/// }
206///
207/// // Standard library methods
208/// let sum: f32 = tensor.iter()
209///     .map(|elem| elem.value())
210///     .sum();
211///
212/// assert_eq!(sum, 6.0);
213/// ```
214///
215/// ## Element Operations
216///
217/// ```
218/// use train_station::Tensor;
219///
220/// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
221///
222/// // Tensor operations on elements
223/// let transformed: Tensor = tensor.iter()
224///     .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0)) // 2x + 1
225///     .collect();
226///
227/// assert_eq!(transformed.data(), &[3.0, 5.0, 7.0]);
228/// ```
229///
230/// ## Advanced Iterator Methods
231///
232/// ```
233/// use train_station::Tensor;
234///
235/// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
236///
237/// // Filter and transform
238/// let result: Tensor = tensor.iter()
239///     .filter(|elem| elem.value() > 2.0)
240///     .map(|elem| elem.mul_scalar(10.0))
241///     .collect();
242///
243/// assert_eq!(result.data(), &[30.0, 40.0, 50.0]);
244///
245/// // Reverse iteration
246/// let reversed: Tensor = tensor.iter().rev().collect();
247/// assert_eq!(reversed.data(), &[5.0, 4.0, 3.0, 2.0, 1.0]);
248/// ```
pub struct TensorElementIterator<'a> {
    /// Reference to the source tensor whose elements are viewed
    source: &'a Tensor,
    /// Next index to yield from the front; advanced by `next`/`nth`
    position: usize,
    /// One past the last index to yield (exclusive); moved down by `next_back`
    end: usize,
}
257
258impl<'a> TensorElementIterator<'a> {
259    /// Create a new iterator over all tensor elements
260    ///
261    /// Creates an iterator that yields view tensors for each element in the
262    /// source tensor. Each element becomes a `Tensor` of shape `[1]` that
263    /// supports all tensor operations and gradient tracking.
264    ///
265    /// # Arguments
266    ///
267    /// * `tensor` - The source tensor to iterate over
268    ///
269    /// # Returns
270    ///
271    /// An iterator that yields view tensors for each element
272    ///
273    /// # Performance
274    ///
275    /// - **O(1) Creation**: Constant-time iterator initialization
276    /// - **Zero-Copy Views**: Each element is a view sharing memory with source
277    /// - **Memory Efficient**: Minimal overhead for iterator state
278    ///
279    /// # Implementation Details
280    ///
281    /// This method creates an iterator that yields view tensors for each element
282    /// in the source tensor. Each element becomes a `Tensor` of shape `[1]` that
283    /// supports all tensor operations and gradient tracking.
284    ///
285    /// The iterator provides zero-copy access to tensor elements through view
286    /// tensors, enabling efficient element-wise operations while maintaining
287    /// full compatibility with Rust's standard library iterator methods.
288    #[track_caller]
289    pub fn new(tensor: &'a Tensor) -> Self {
290        Self {
291            source: tensor,
292            position: 0,
293            end: tensor.size(),
294        }
295    }
296
297    /// Create an iterator over a specific range of elements
298    ///
299    /// Creates an iterator that yields view tensors for elements in the specified
300    /// range. The range is automatically clamped to valid tensor bounds for safety.
301    ///
302    /// # Arguments
303    ///
304    /// * `tensor` - The source tensor to iterate over
305    /// * `start` - Starting index (inclusive)
306    /// * `end` - Ending index (exclusive)
307    ///
308    /// # Returns
309    ///
310    /// An iterator that yields view tensors for elements in the specified range
311    ///
312    /// # Safety
313    ///
314    /// The range is automatically clamped to valid tensor bounds:
315    /// - `start` is clamped to `[0, tensor.size()]`
316    /// - `end` is clamped to `[start, tensor.size()]`
317    /// - Empty ranges (start >= end) are handled gracefully
318    ///
319    /// # Performance
320    ///
321    /// - **O(1) Creation**: Constant-time iterator initialization
322    /// - **Bounds Checking**: Automatic range validation and clamping
323    /// - **Zero-Copy Views**: Each element is a view sharing memory with source
324    ///
325    /// # Implementation Details
326    ///
327    /// This method creates an iterator that yields view tensors for elements in
328    /// the specified range. The range is automatically clamped to valid tensor
329    /// bounds for safety, ensuring that out-of-bounds access is handled gracefully.
330    ///
331    /// The iterator provides zero-copy access to tensor elements through view
332    /// tensors, enabling efficient element-wise operations while maintaining
333    /// full compatibility with Rust's standard library iterator methods.
334    #[track_caller]
335    pub fn with_range(tensor: &'a Tensor, start: usize, end: usize) -> Self {
336        let end = end.min(tensor.size());
337        let start = start.min(end);
338        Self {
339            source: tensor,
340            position: start,
341            end,
342        }
343    }
344
345    /// Create an optimized element view for the given position
346    ///
347    /// This method creates a true view tensor of shape `[1]` that shares memory
348    /// with the element at the specified index in the source tensor. The view
349    /// enables zero-copy element access with full gradient tracking.
350    ///
351    /// # Arguments
352    ///
353    /// * `index` - Index of the element to create a view for
354    ///
355    /// # Returns
356    ///
357    /// A view tensor of shape `[1]` representing the element at the specified index
358    ///
359    /// # Safety
360    ///
361    /// The caller must ensure that `index < self.source.size()`.
362    ///
363    /// # Performance
364    ///
365    /// - **O(1) View Creation**: Constant-time view tensor creation
366    /// - **Zero-Copy**: View shares memory with source tensor
367    /// - **Memory Efficient**: ~64 bytes overhead for view metadata
368    /// - **Gradient Tracking**: Full gradtrack support through view operations
369    ///
370    /// # Implementation Details
371    ///
372    /// This method delegates to `Tensor::element_view()` which creates a true
373    /// view of the underlying data without any copying. The view tensor supports
374    /// all standard tensor operations including gradient tracking and SIMD
375    /// optimizations.
376    fn create_element_view(&self, index: usize) -> Tensor {
377        debug_assert!(index < self.source.size());
378
379        self.source.element_view(index)
380    }
381}
382
383// ===== Core Iterator Implementation =====
384
385impl<'a> Iterator for TensorElementIterator<'a> {
386    type Item = Tensor;
387
388    #[inline]
389    fn next(&mut self) -> Option<Self::Item> {
390        if self.position < self.end {
391            let view = self.create_element_view(self.position);
392            self.position += 1;
393            Some(view)
394        } else {
395            None
396        }
397    }
398
399    #[inline]
400    fn size_hint(&self) -> (usize, Option<usize>) {
401        let remaining = self.end - self.position;
402        (remaining, Some(remaining))
403    }
404
405    #[inline]
406    fn count(self) -> usize {
407        self.end - self.position
408    }
409
410    #[inline]
411    fn nth(&mut self, n: usize) -> Option<Self::Item> {
412        let new_pos = self.position.saturating_add(n);
413        if new_pos < self.end {
414            self.position = new_pos + 1;
415            Some(self.create_element_view(new_pos))
416        } else {
417            self.position = self.end;
418            None
419        }
420    }
421
422    #[inline]
423    fn last(self) -> Option<Self::Item> {
424        if self.position < self.end {
425            let last_idx = self.end - 1;
426            Some(self.create_element_view(last_idx))
427        } else {
428            None
429        }
430    }
431}
432
impl<'a> ExactSizeIterator for TensorElementIterator<'a> {
    /// Exact number of elements remaining (`end - position`).
    ///
    /// Cannot underflow: `position` never exceeds `end` anywhere in this
    /// module.
    #[inline]
    fn len(&self) -> usize {
        self.end - self.position
    }
}
439
440impl<'a> FusedIterator for TensorElementIterator<'a> {}
441
442impl<'a> DoubleEndedIterator for TensorElementIterator<'a> {
443    #[inline]
444    fn next_back(&mut self) -> Option<Self::Item> {
445        if self.position < self.end {
446            self.end -= 1;
447            Some(self.create_element_view(self.end))
448        } else {
449            None
450        }
451    }
452
453    #[inline]
454    fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
455        let new_end = self.end.saturating_sub(n + 1);
456        if new_end >= self.position {
457            self.end = new_end;
458            Some(self.create_element_view(self.end))
459        } else {
460            self.position = self.end;
461            None
462        }
463    }
464}
465
466// ===== IntoIterator Implementation =====
467
impl<'a> IntoIterator for &'a Tensor {
    type Item = Tensor;
    type IntoIter = TensorElementIterator<'a>;

    /// Enable `for element in &tensor { ... }` loops by constructing a
    /// full-range element iterator over the borrowed tensor.
    fn into_iter(self) -> Self::IntoIter {
        TensorElementIterator::new(self)
    }
}
476
477// ===== FromIterator Implementation =====
478
479impl FromIterator<Tensor> for Tensor {
480    /// Collect element view tensors back into a single tensor
481    ///
482    /// This method reconstructs a tensor from an iterator of element view tensors.
483    /// It includes optimizations for common patterns and maintains gradient tracking
484    /// when appropriate.
485    ///
486    /// The collection process automatically detects whether all elements are scalar
487    /// views (shape `[1]`) and uses optimized collection strategies accordingly.
488    /// Gradient tracking is preserved when any input element requires gradients.
489    ///
490    /// # Performance
491    ///
492    /// - **Optimized Collection**: Specialized paths for scalar and mixed views
493    /// - **Memory Efficient**: Direct memory copying without intermediate allocations
494    /// - **Gradient Preservation**: Maintains gradtrack functionality when enabled
495    /// - **Shape Detection**: Automatic detection of element shapes for optimization
496    ///
497    /// # Implementation Details
498    ///
499    /// The method performs the following steps:
500    /// 1. **Element Collection**: Gathers all element tensors from the iterator
501    /// 2. **Shape Analysis**: Determines if all elements are scalar views
502    /// 3. **Optimized Path**: Uses specialized collection for scalar views
503    /// 4. **General Path**: Handles mixed shapes by flattening into 1D tensor
504    /// 5. **Gradient Setup**: Preserves gradient tracking when appropriate
505    ///
506    /// # Examples
507    ///
508    /// ## Basic Collection
509    ///
510    /// ```
511    /// use train_station::Tensor;
512    ///
513    /// let original = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
514    /// let doubled: Tensor = original.iter()
515    ///     .map(|elem| elem.mul_scalar(2.0))
516    ///     .collect();
517    ///
518    /// assert_eq!(doubled.data(), &[2.0, 4.0, 6.0]);
519    /// ```
520    ///
521    /// ## Collection with Gradient Tracking
522    ///
523    /// ```
524    /// use train_station::Tensor;
525    ///
526    /// let original = Tensor::from_slice(&[1.0, 2.0], vec![2])
527    ///     .unwrap()
528    ///     .with_requires_grad();
529    ///
530    /// let result: Tensor = original.iter()
531    ///     .map(|elem| elem.mul_scalar(2.0))
532    ///     .collect();
533    ///
534    /// assert!(result.requires_grad());
535    /// assert_eq!(result.data(), &[2.0, 4.0]);
536    /// ```
537    ///
538    /// ## Empty Iterator Handling
539    ///
540    /// ```
541    /// use train_station::Tensor;
542    ///
543    /// let empty: Tensor = Vec::<Tensor>::new().into_iter().collect();
544    /// assert_eq!(empty.size(), 0);
545    /// assert_eq!(empty.shape().dims, vec![0]);
546    /// ```
547    fn from_iter<I: IntoIterator<Item = Tensor>>(iter: I) -> Self {
548        let elements: Vec<Tensor> = iter.into_iter().collect();
549
550        if elements.is_empty() {
551            return Tensor::new(vec![0]);
552        }
553
554        // Check if all elements are scalar views (shape [1])
555        let all_scalars = elements.iter().all(|e| e.shape().dims == vec![1]);
556
557        if all_scalars {
558            // Optimized path for scalar element views
559            Self::collect_scalar_views(elements)
560        } else {
561            // General path for mixed shapes
562            Self::collect_mixed_views(elements)
563        }
564    }
565}
566
567impl Tensor {
568    /// Optimized collection for scalar element views
569    ///
570    /// This method efficiently reconstructs a tensor from scalar element views,
571    /// preserving gradient tracking and using optimized memory operations.
572    ///
573    /// This is the fast path for collection when all elements are scalar views
574    /// (shape `[1]`). It performs direct memory copying and sets up gradient
575    /// tracking when any input element requires gradients.
576    ///
577    /// # Arguments
578    ///
579    /// * `elements` - Vector of scalar element view tensors
580    ///
581    /// # Returns
582    ///
583    /// A new tensor containing all element values in a 1D layout
584    ///
585    /// # Performance
586    ///
587    /// - **Direct Memory Copy**: Single-pass copying without intermediate allocations
588    /// - **Gradient Optimization**: Efficient gradient tracking setup
589    /// - **Memory Efficient**: Minimal overhead for collection process
590    /// - **SIMD Compatible**: Result tensor supports all optimizations
591    ///
592    /// # Implementation Details
593    ///
594    /// The method performs the following steps:
595    /// 1. **Allocation**: Creates uninitialized tensor with correct size
596    /// 2. **Gradient Check**: Determines if any element requires gradients
597    /// 3. **Memory Copy**: Direct copying from element views to result
598    /// 4. **Gradient Setup**: Configures gradient tracking when needed
599    /// 5. **Operation Registration**: Registers with gradtrack engine
    fn collect_scalar_views(elements: Vec<Tensor>) -> Self {
        let len = elements.len();
        // Allocate without zero-initialization; every slot is written below.
        let mut result = Self::new_uninitialized(vec![len]);

        // Determine if we can track gradients
        let requires_grad = elements.iter().any(|e| e.requires_grad());

        // Copy data from element views
        // SAFETY: `result` was allocated with exactly `len` elements, so every
        // `dst.add(i)` with `i < len` stays in bounds. The caller (`from_iter`)
        // only routes here when every element has shape `[1]`, so reading one
        // value through `element.as_ptr()` is valid.
        unsafe {
            let dst = result.as_mut_ptr();
            for (i, element) in elements.iter().enumerate() {
                *dst.add(i) = *element.as_ptr();
            }
        }

        // Set up gradient tracking only when an input needs it AND gradients
        // are globally enabled, mirroring the other gradtrack-aware ops.
        if requires_grad && is_grad_enabled() {
            result.set_requires_grad_internal(true);
            let element_ids: Vec<usize> = elements.iter().map(|e| e.id()).collect();
            let grad_fn = GradFn::ElementCollection {
                element_ids: element_ids.clone(),
                result_shape: vec![len],
            };
            result.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(result.id(), element_ids, grad_fn);
        }

        result
    }
629
630    /// General collection for mixed element shapes
631    ///
632    /// This method handles collection when elements have different shapes,
633    /// flattening all elements into a 1D tensor.
634    ///
635    /// This is the general path for collection when elements have varying shapes.
636    /// It flattens all elements into a single 1D tensor and preserves gradient
637    /// tracking when any input element requires gradients.
638    ///
639    /// # Arguments
640    ///
641    /// * `elements` - Vector of element tensors with potentially different shapes
642    ///
643    /// # Returns
644    ///
645    /// A new 1D tensor containing all flattened element values
646    ///
647    /// # Performance
648    ///
649    /// - **Flattening**: Converts all elements to 1D layout
650    /// - **Memory Copy**: Efficient copying with size calculation
651    /// - **Gradient Preservation**: Maintains gradtrack functionality
652    /// - **Mixed Shapes**: Handles elements with different dimensions
653    ///
654    /// # Implementation Details
655    ///
656    /// The method performs the following steps:
657    /// 1. **Size Calculation**: Sums sizes of all elements for total size
658    /// 2. **Allocation**: Creates uninitialized tensor with total size
659    /// 3. **Sequential Copy**: Copies each element's data sequentially
660    /// 4. **Gradient Setup**: Configures gradient tracking when needed
661    /// 5. **Operation Registration**: Registers with gradtrack engine
    fn collect_mixed_views(elements: Vec<Tensor>) -> Self {
        // For mixed shapes, flatten all elements into a 1D tensor
        let total_size: usize = elements.iter().map(|e| e.size()).sum();
        let mut result = Self::new_uninitialized(vec![total_size]);

        let requires_grad = elements.iter().any(|e| e.requires_grad());
        let mut offset = 0;

        // SAFETY: `result` holds exactly `total_size` elements (the sum of all
        // element sizes) and `offset` advances by each element's size, so every
        // destination range `[offset, offset + size)` is in bounds and the
        // ranges never overlap each other. `result` is a freshly created
        // allocation distinct from any source view, satisfying the
        // non-overlap requirement of `copy_nonoverlapping`.
        unsafe {
            let dst = result.as_mut_ptr();
            for element in &elements {
                let src = element.as_ptr();
                let size = element.size();
                std::ptr::copy_nonoverlapping(src, dst.add(offset), size);
                offset += size;
            }
        }

        // Preserve gradtrack linkage when any input participates in gradients.
        if requires_grad && is_grad_enabled() {
            result.set_requires_grad_internal(true);
            let element_ids: Vec<usize> = elements.iter().map(|e| e.id()).collect();
            let grad_fn = GradFn::ElementCollection {
                element_ids: element_ids.clone(),
                result_shape: vec![total_size],
            };
            result.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(result.id(), element_ids, grad_fn);
        }

        result
    }
693
694    /// Create an iterator over tensor elements as view tensors
695    ///
696    /// Each element becomes a `Tensor` of shape `[1]` that supports all
697    /// tensor operations and gradient tracking. This is the main entry point
698    /// for element-wise iteration with full tensor operation support.
699    ///
700    /// The iterator provides zero-copy access to tensor elements through view
701    /// tensors, enabling efficient element-wise operations while maintaining
702    /// full compatibility with Rust's standard library iterator methods.
703    ///
704    /// # Returns
705    ///
706    /// An iterator that yields view tensors for each element
707    ///
708    /// # Performance
709    ///
710    /// - **Zero-Copy Views**: Each element is a view sharing memory with source
711    /// - **O(1) Element Access**: Constant-time view creation for each element
712    /// - **Memory Efficient**: ~64 bytes overhead per element view
713    /// - **SIMD Compatible**: All tensor operations use existing optimizations
714    /// - **Gradient Tracking**: Full gradtrack support through element operations
715    ///
716    /// # Examples
717    ///
718    /// ## Basic Element Operations
719    ///
720    /// ```
721    /// use train_station::Tensor;
722    ///
723    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
724    ///
725    /// // Use any std iterator method
726    /// let result: Tensor = tensor.iter()
727    ///     .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0)) // 2x + 1
728    ///     .filter(|elem| elem.value() > 3.0)                // Keep values > 3
729    ///     .collect();
730    ///
731    /// assert_eq!(result.data(), &[5.0, 7.0]);
732    /// ```
733    ///
734    /// ## Advanced Iterator Chains
735    ///
736    /// ```
737    /// use train_station::Tensor;
738    ///
739    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
740    ///
741    /// // Chain with enumerate, zip, etc.
742    /// let indexed: Tensor = tensor.iter()
743    ///     .enumerate()
744    ///     .map(|(i, elem)| elem.add_scalar(i as f32))
745    ///     .collect();
746    ///
747    /// assert_eq!(indexed.data(), &[1.0, 3.0, 5.0, 7.0, 9.0]);
748    /// ```
749    ///
750    /// ## Double-Ended Iteration
751    ///
752    /// ```
753    /// use train_station::Tensor;
754    ///
755    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
756    ///
757    /// // Use double-ended iterator
758    /// let reversed: Tensor = tensor.iter()
759    ///     .rev()
760    ///     .collect();
761    ///
762    /// assert_eq!(reversed.data(), &[4.0, 3.0, 2.0, 1.0]);
763    /// ```
764    ///
765    /// ## Gradient Tracking
766    ///
767    /// ```
768    /// use train_station::Tensor;
769    ///
770    /// let tensor = Tensor::from_slice(&[1.0, 2.0], vec![2])
771    ///     .unwrap()
772    ///     .with_requires_grad();
773    ///
774    /// let result: Tensor = tensor.iter()
775    ///     .map(|elem| elem.mul_scalar(2.0))
776    ///     .collect();
777    ///
778    /// assert!(result.requires_grad());
779    /// assert_eq!(result.data(), &[2.0, 4.0]);
780    /// ```
781    #[track_caller]
782    pub fn iter(&self) -> TensorElementIterator<'_> {
783        TensorElementIterator::new(self)
784    }
785
786    /// Create an iterator over a range of elements
787    ///
788    /// Creates an iterator that yields view tensors for elements in the specified
789    /// range. The range is automatically clamped to valid tensor bounds for safety.
790    ///
791    /// # Arguments
792    ///
793    /// * `start` - Starting index (inclusive)
794    /// * `end` - Ending index (exclusive)
795    ///
796    /// # Returns
797    ///
798    /// An iterator that yields view tensors for elements in the specified range
799    ///
800    /// # Safety
801    ///
802    /// The range is automatically clamped to valid tensor bounds:
803    /// - `start` is clamped to `[0, tensor.size()]`
804    /// - `end` is clamped to `[start, tensor.size()]`
805    /// - Empty ranges (start >= end) are handled gracefully
806    ///
807    /// # Performance
808    ///
809    /// - **O(1) Creation**: Constant-time iterator initialization
810    /// - **Bounds Checking**: Automatic range validation and clamping
811    /// - **Zero-Copy Views**: Each element is a view sharing memory with source
812    /// - **Memory Efficient**: Minimal overhead for range iteration
813    ///
814    /// # Examples
815    ///
816    /// ## Basic Range Iteration
817    ///
818    /// ```
819    /// use train_station::Tensor;
820    ///
821    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
822    /// let middle: Tensor = tensor.iter_range(1, 4)
823    ///     .map(|elem| elem.mul_scalar(2.0))
824    ///     .collect();
825    ///
826    /// assert_eq!(middle.data(), &[4.0, 6.0, 8.0]);
827    /// ```
828    ///
829    /// ## Range with Operations
830    ///
831    /// ```
832    /// use train_station::Tensor;
833    ///
834    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
835    ///
836    /// // Apply complex operations to range
837    /// let result: Tensor = tensor.iter_range(0, 3)
838    ///     .enumerate()
839    ///     .map(|(i, elem)| elem.add_scalar(i as f32))
840    ///     .collect();
841    ///
842    /// assert_eq!(result.data(), &[1.0, 3.0, 5.0]);
843    /// ```
844    ///
845    /// ## Out of Bounds Handling
846    ///
847    /// ```
848    /// use train_station::Tensor;
849    ///
850    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
851    ///
852    /// // Out of bounds range is clamped
853    /// let empty: Tensor = tensor.iter_range(5, 10).collect();
854    /// assert_eq!(empty.size(), 0);
855    ///
856    /// // Partial out of bounds
857    /// let partial: Tensor = tensor.iter_range(1, 10).collect();
858    /// assert_eq!(partial.data(), &[2.0, 3.0]);
859    /// ```
    #[track_caller]
    pub fn iter_range(&self, start: usize, end: usize) -> TensorElementIterator<'_> {
        // Delegates to the range-aware iterator constructor, which clamps
        // `start`/`end` to the tensor bounds (see doc comment above): an
        // out-of-bounds or empty range yields an empty iterator, not a panic.
        TensorElementIterator::with_range(self, start, end)
    }
864}
865
#[cfg(test)]
mod tests {
    //! Test suite for tensor element iterator functionality
    //!
    //! Coverage spans every aspect of the iterator implementation:
    //! - Core iteration behavior
    //! - Standard library trait conformance
    //! - Gradient tracking through element operations
    //! - Performance characteristics
    //! - Boundary cases and error conditions

    use super::*;

    /// Verify that basic iteration yields ordered scalar element views
    #[test]
    fn test_basic_iteration() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();

        let views: Vec<Tensor> = source.iter().collect();
        assert_eq!(views.len(), 4);

        // Every element must be a [1]-shaped scalar tensor with the expected value
        for (idx, view) in views.iter().enumerate() {
            assert_eq!(view.shape().dims, vec![1]);
            assert_eq!(view.size(), 1);
            assert_eq!(view.value(), (idx + 1) as f32);
        }
    }

    /// Verify core Iterator trait methods (next, size_hint, count, nth, last)
    #[test]
    fn test_iterator_trait_methods() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
        let mut it = source.iter();

        // next() yields the first element
        let head = it.next().unwrap();
        assert_eq!(head.value(), 1.0);

        // size_hint() reflects the four remaining elements
        assert_eq!(it.size_hint(), (4, Some(4)));

        // count() consumes the remainder
        assert_eq!(it.count(), 4);

        // nth() skips ahead to the requested element
        let mut it = source.iter();
        assert_eq!(it.nth(2).unwrap().value(), 3.0);

        // last() drains the iterator and returns the final element
        let it = source.iter();
        assert_eq!(it.last().unwrap().value(), 5.0);
    }

    /// Verify ExactSizeIterator reports accurate, decreasing lengths
    #[test]
    fn test_exact_size_iterator() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();

        assert_eq!(source.iter().len(), 3);

        // len() must shrink by one for each consumed element
        let mut it = source.iter();
        assert_eq!(it.len(), 3);
        for remaining in (0..3).rev() {
            it.next();
            assert_eq!(it.len(), remaining);
        }
    }

    /// Verify DoubleEndedIterator supports back-to-front consumption
    #[test]
    fn test_double_ended_iterator() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let mut it = source.iter();

        // next_back() yields the final element first
        assert_eq!(it.next_back().unwrap().value(), 4.0);
        // The front is unaffected by back consumption
        assert_eq!(it.next().unwrap().value(), 1.0);

        // nth_back() skips elements from the rear
        let mut it = source.iter();
        assert_eq!(it.nth_back(1).unwrap().value(), 3.0);

        // Alternate between the two ends until exhausted
        let mut it = source.iter();
        assert_eq!(it.next().unwrap().value(), 1.0);
        assert_eq!(it.next_back().unwrap().value(), 4.0);
        assert_eq!(it.next().unwrap().value(), 2.0);
        assert_eq!(it.next_back().unwrap().value(), 3.0);
        assert!(it.next().is_none());
    }

    /// Verify IntoIterator works with for loops and explicit into_iter()
    #[test]
    fn test_into_iterator() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();

        // Implicit use through a for loop over a tensor reference
        let mut seen = Vec::new();
        for view in &source {
            seen.push(view.value());
        }
        assert_eq!(seen, vec![1.0, 2.0, 3.0]);

        // Explicit into_iter() call on the reference
        let seen: Vec<f32> = (&source).into_iter().map(|view| view.value()).collect();
        assert_eq!(seen, vec![1.0, 2.0, 3.0]);
    }

    /// Verify FromIterator allows collecting elements back into a tensor
    #[test]
    fn test_from_iterator() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();

        // Round-trip: iterate and collect into an identical tensor
        let rebuilt: Tensor = source.iter().collect();
        assert_eq!(rebuilt.shape().dims, vec![4]);
        assert_eq!(rebuilt.data(), source.data());

        // Collect with a per-element transformation applied
        let doubled: Tensor = source
            .iter()
            .map(|view| Tensor::from_slice(&[view.value() * 2.0], vec![1]).unwrap())
            .collect();
        assert_eq!(doubled.data(), &[2.0, 4.0, 6.0, 8.0]);
    }

    /// Verify standard library adapter and terminal methods
    #[test]
    fn test_std_iterator_methods() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();

        // map
        let doubled: Vec<f32> = source.iter().map(|view| view.value() * 2.0).collect();
        assert_eq!(doubled, vec![2.0, 4.0, 6.0, 8.0, 10.0]);

        // filter
        let big: Vec<f32> = source
            .iter()
            .filter(|view| view.value() > 3.0)
            .map(|view| view.value())
            .collect();
        assert_eq!(big, vec![4.0, 5.0]);

        // enumerate
        let indexed: Vec<(usize, f32)> = source
            .iter()
            .enumerate()
            .map(|(i, view)| (i, view.value()))
            .collect();
        assert_eq!(
            indexed,
            vec![(0, 1.0), (1, 2.0), (2, 3.0), (3, 4.0), (4, 5.0)]
        );

        // fold
        let total: f32 = source.iter().fold(0.0, |acc, view| acc + view.value());
        assert_eq!(total, 15.0);

        // find
        let hit = source.iter().find(|view| view.value() == 3.0);
        assert!(hit.is_some());
        assert_eq!(hit.unwrap().value(), 3.0);

        // all / any
        assert!(source.iter().all(|view| view.value() > 0.0));
        assert!(source.iter().any(|view| view.value() > 4.0));
    }

    /// Verify tensor operations apply to individual element views
    #[test]
    fn test_element_tensor_operations() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();

        // Scalar multiplication per element
        let scaled: Tensor = source.iter().map(|view| view.mul_scalar(2.0)).collect();
        assert_eq!(scaled.data(), &[2.0, 4.0, 6.0]);

        // Scalar addition per element
        let shifted: Tensor = source.iter().map(|view| view.add_scalar(10.0)).collect();
        assert_eq!(shifted.data(), &[11.0, 12.0, 13.0]);

        // Chained operations: 2x + 1 per element
        let affine: Tensor = source
            .iter()
            .map(|view| view.mul_scalar(2.0).add_scalar(1.0))
            .collect();
        assert_eq!(affine.data(), &[3.0, 5.0, 7.0]);
    }

    /// Verify gradient flags survive element-wise iterator operations
    #[test]
    fn test_gradient_tracking() {
        let source = Tensor::from_slice(&[1.0, 2.0], vec![2])
            .unwrap()
            .with_requires_grad();

        // Run an element-wise forward pass through the iterator
        let output: Tensor = source.iter().map(|view| view.mul_scalar(2.0)).collect();

        // The collected tensor requires gradients when any element does.
        // NOTE: the current implementation creates copies, so gradient
        // tracking may not propagate back to the source tensor yet; a true
        // view implementation would be needed for full propagation.
        assert!(output.requires_grad());

        // The forward computation itself must still be correct
        assert_eq!(output.data(), &[2.0, 4.0]);
    }

    /// Verify empty tensors iterate gracefully
    #[test]
    fn test_zero_sized_tensor() {
        let empty = Tensor::new(vec![0]);
        let it = empty.iter();

        assert_eq!(it.len(), 0);
        assert_eq!(it.size_hint(), (0, Some(0)));

        let rebuilt: Tensor = it.collect();
        assert_eq!(rebuilt.size(), 0);
    }

    /// Verify iter_range honors bounds, clamping, and empty ranges
    #[test]
    fn test_range_iteration() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();

        // Interior range
        let middle: Vec<f32> = source.iter_range(1, 4).map(|view| view.value()).collect();
        assert_eq!(middle, vec![2.0, 3.0, 4.0]);

        // End index past the tensor is clamped to its size
        let tail: Vec<f32> = source.iter_range(3, 10).map(|view| view.value()).collect();
        assert_eq!(tail, vec![4.0, 5.0]);

        // Degenerate range yields nothing
        let none: Vec<f32> = source.iter_range(2, 2).map(|view| view.value()).collect();
        assert_eq!(none, Vec::<f32>::new());
    }

    /// Verify long adapter chains compose correctly
    #[test]
    fn test_complex_chains() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();

        // enumerate -> filter (even indices) -> map (add index) -> collect
        let combined: Tensor = source
            .iter()
            .enumerate()
            .filter(|(i, _)| i % 2 == 0)
            .map(|(i, view)| view.add_scalar(i as f32))
            .collect();

        // Expected: [1.0 + 0, 3.0 + 2, 5.0 + 4] = [1.0, 5.0, 9.0]
        assert_eq!(combined.data(), &[1.0, 5.0, 9.0]);

        // rev() + take() walks from the back
        let tail_first: Tensor = source.iter().rev().take(3).collect();
        assert_eq!(tail_first.data(), &[6.0, 5.0, 4.0]);
    }

    /// Smoke-test iterator overhead on a larger tensor
    #[test]
    fn test_performance() {
        let values: Vec<f32> = (0..1000).map(|i| i as f32).collect();
        let big = Tensor::from_slice(&values, vec![1000]).unwrap();

        let start = std::time::Instant::now();

        let output: Tensor = big.iter().map(|view| view.mul_scalar(2.0)).collect();

        let duration = start.elapsed();
        println!("Iterator performance test took: {:?}", duration);

        // Spot-check correctness at both ends of the result
        assert_eq!(output.size(), 1000);
        assert_eq!(output.data()[0], 0.0);
        assert_eq!(output.data()[999], 1998.0);
    }
}