train_station/tensor/iterator/
mod.rs

1//! Iterator module for tensor element-wise operations
2//!
3//! This module provides high-performance iterators over tensor elements, where each
4//! element is represented as a view tensor of shape `[1]`. This design allows for
5//! seamless integration with Rust's standard library iterator methods while
6//! leveraging the existing tensor operation framework and gradient tracking.
7//!
8//! # Key Features
9//!
10//! - **Standard Library Compatibility**: Full implementation of Iterator, ExactSizeIterator,
11//!   DoubleEndedIterator, FusedIterator, IntoIterator, and FromIterator traits
12//! - **Gradient Tracking**: Automatic gradient propagation through element operations
13//! - **Performance Optimized**: True zero-copy views with shared memory
14//! - **SIMD Compatible**: All operations use existing optimized tensor implementations
15//! - **Memory Efficient**: Adaptive view creation based on tensor size
16//! - **Zero-Copy Operations**: Element views share memory with source tensor
17//! - **Full Tensor Operations**: Each element supports all tensor methods
18//!
19//! # Performance Characteristics
20//!
21//! - **View Creation**: O(1) per element with true zero-copy views
22//! - **Memory Overhead**: ~64 bytes per view tensor (no data copying)
23//! - **SIMD Operations**: Full utilization of existing optimizations
24//! - **Gradient Tracking**: True gradient flow with element-level accumulation
25//! - **Iterator Overhead**: Minimal performance impact for element access
26//! - **Collection Optimization**: Efficient reconstruction from element views
27//!
28//! # Examples
29//!
30//! ## Basic Element Iteration
31//!
32//! ```
33//! use train_station::Tensor;
34//!
35//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
36//!
37//! // Basic iteration over elements
38//! for element in tensor.iter() {
39//!     println!("Element value: {}", element.value());
40//! }
41//!
42//! // Collect elements into a new tensor
43//! let collected: Tensor = tensor.iter().collect();
44//! assert_eq!(collected.data(), tensor.data());
45//! ```
46//!
47//! ## Element-Wise Transformations
48//!
49//! ```
50//! use train_station::Tensor;
51//!
52//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
53//!
54//! // Apply tensor operations to each element
55//! let doubled: Tensor = tensor.iter()
56//!     .map(|elem| elem.mul_scalar(2.0))
57//!     .collect();
58//!
59//! assert_eq!(doubled.data(), &[2.0, 4.0, 6.0]);
60//!
61//! // Chain multiple operations
62//! let transformed: Tensor = tensor.iter()
63//!     .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0)) // 2x + 1
64//!     .collect();
65//!
66//! assert_eq!(transformed.data(), &[3.0, 5.0, 7.0]);
67//! ```
68//!
69//! ## Advanced Iterator Operations
70//!
71//! ```
72//! use train_station::Tensor;
73//!
74//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
75//!
76//! // Filter elements based on values
77//! let large_values: Tensor = tensor.iter()
78//!     .filter(|elem| elem.value() > 3.0)
79//!     .collect();
80//!
81//! assert_eq!(large_values.data(), &[4.0, 5.0]);
82//!
83//! // Use enumerate for indexed operations
84//! let indexed: Tensor = tensor.iter()
85//!     .enumerate()
86//!     .map(|(i, elem)| elem.add_scalar(i as f32))
87//!     .collect();
88//!
89//! assert_eq!(indexed.data(), &[1.0, 3.0, 5.0, 7.0, 9.0]);
90//! ```
91//!
92//! ## Range Iteration
93//!
94//! ```
95//! use train_station::Tensor;
96//!
97//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
98//!
99//! // Iterate over a specific range
100//! let middle: Tensor = tensor.iter_range(1, 4)
101//!     .map(|elem| elem.mul_scalar(2.0))
102//!     .collect();
103//!
104//! assert_eq!(middle.data(), &[4.0, 6.0, 8.0]);
105//! ```
106//!
107//! ## Double-Ended Iteration
108//!
109//! ```
110//! use train_station::Tensor;
111//!
112//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
113//!
114//! // Reverse iteration
115//! let reversed: Tensor = tensor.iter().rev().collect();
116//! assert_eq!(reversed.data(), &[4.0, 3.0, 2.0, 1.0]);
117//!
118//! // Iterate from both ends
119//! let mut iter = tensor.iter();
120//! assert_eq!(iter.next().unwrap().value(), 1.0);
121//! assert_eq!(iter.next_back().unwrap().value(), 4.0);
122//! ```
123//!
124//! ## Gradient Tracking
125//!
126//! ```
127//! use train_station::Tensor;
128//!
129//! let tensor = Tensor::from_slice(&[1.0, 2.0], vec![2])
130//!     .unwrap()
131//!     .with_requires_grad();
132//!
133//! // Element operations maintain gradient tracking
134//! let result: Tensor = tensor.iter()
135//!     .map(|elem| elem.mul_scalar(2.0))
136//!     .collect();
137//!
138//! assert!(result.requires_grad());
139//! assert_eq!(result.data(), &[2.0, 4.0]);
140//! ```
141//!
142//! # Design Principles
143//!
144//! - **Zero-Copy Views**: Element views share memory with source tensor
145//! - **Full Tensor Operations**: Each element supports all tensor methods
146//! - **Standard Library Integration**: Complete compatibility with Rust iterators
147//! - **Performance First**: Optimized for high-performance element access
148//! - **Gradient Preservation**: Maintains gradtrack functionality through operations
149//! - **Memory Efficiency**: Minimal overhead for element iteration
150//! - **Type Safety**: Compile-time guarantees for iterator operations
151
152use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
153use crate::tensor::core::Tensor;
154use std::iter::{FromIterator, FusedIterator};
155
156/// High-performance iterator over tensor elements as view tensors
157///
158/// Each element becomes a proper `Tensor` view of shape `[1]` that can use
159/// all existing tensor operations and gradient tracking. Implements all
160/// standard iterator traits for maximum compatibility with Rust's ecosystem.
161///
162/// This iterator provides zero-copy access to tensor elements through view
163/// tensors, enabling efficient element-wise operations while maintaining
164/// full compatibility with Rust's standard library iterator methods.
165///
166/// # Performance
167///
168/// - **Zero-Copy Views**: Each element is a view tensor sharing memory with source
169/// - **O(1) Element Access**: Constant-time view creation for each element
170/// - **Memory Efficient**: ~64 bytes overhead per element view
171/// - **SIMD Compatible**: All tensor operations use existing optimizations
172/// - **Gradient Tracking**: Full gradtrack support through element operations
173///
174/// # Implementation Details
175///
176/// The iterator creates lightweight view tensors on-demand, sharing the same
177/// memory allocation as the source tensor. This ensures zero-copy semantics
178/// while maintaining full tensor operation compatibility.
179///
180/// Each element view is created using `Tensor::element_view()`, which provides
181/// a true view of the underlying data without any copying. The view tensors
182/// support all standard tensor operations including gradient tracking.
183///
184/// # Standard Library Compatibility
185///
186/// This iterator implements all standard iterator traits:
187/// - `Iterator`: Basic iteration with `next()` and `size_hint()`
188/// - `ExactSizeIterator`: Precise size information with `len()`
189/// - `DoubleEndedIterator`: Reverse iteration with `next_back()`
190/// - `FusedIterator`: Fused iteration for better performance
191/// - `IntoIterator`: Automatic conversion for `for` loops
192///
193/// # Examples
194///
195/// ## Basic Iteration
196///
197/// ```
198/// use train_station::Tensor;
199///
200/// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
201///
202/// // Basic iteration
203/// for element in tensor.iter() {
204///     println!("Element value: {}", element.value());
205/// }
206///
207/// // Standard library methods
208/// let sum: f32 = tensor.iter()
209///     .map(|elem| elem.value())
210///     .sum();
211///
212/// assert_eq!(sum, 6.0);
213/// ```
214///
215/// ## Element Operations
216///
217/// ```
218/// use train_station::Tensor;
219///
220/// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
221///
222/// // Tensor operations on elements
223/// let transformed: Tensor = tensor.iter()
224///     .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0)) // 2x + 1
225///     .collect();
226///
227/// assert_eq!(transformed.data(), &[3.0, 5.0, 7.0]);
228/// ```
229///
230/// ## Advanced Iterator Methods
231///
232/// ```
233/// use train_station::Tensor;
234///
235/// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
236///
237/// // Filter and transform
238/// let result: Tensor = tensor.iter()
239///     .filter(|elem| elem.value() > 2.0)
240///     .map(|elem| elem.mul_scalar(10.0))
241///     .collect();
242///
243/// assert_eq!(result.data(), &[30.0, 40.0, 50.0]);
244///
245/// // Reverse iteration
246/// let reversed: Tensor = tensor.iter().rev().collect();
247/// assert_eq!(reversed.data(), &[5.0, 4.0, 3.0, 2.0, 1.0]);
248/// ```
pub struct TensorElementIterator<'a> {
    /// Reference to the source tensor whose elements are yielded as views
    source: &'a Tensor,
    /// Current front position; the next index yielded by `next()`
    position: usize,
    /// End position (exclusive); shrunk by `next_back()` during reverse iteration
    end: usize,
}
257
258impl<'a> TensorElementIterator<'a> {
259    /// Create a new iterator over all tensor elements
260    ///
261    /// Creates an iterator that yields view tensors for each element in the
262    /// source tensor. Each element becomes a `Tensor` of shape `[1]` that
263    /// supports all tensor operations and gradient tracking.
264    ///
265    /// # Arguments
266    ///
267    /// * `tensor` - The source tensor to iterate over
268    ///
269    /// # Returns
270    ///
271    /// An iterator that yields view tensors for each element
272    ///
273    /// # Performance
274    ///
275    /// - **O(1) Creation**: Constant-time iterator initialization
276    /// - **Zero-Copy Views**: Each element is a view sharing memory with source
277    /// - **Memory Efficient**: Minimal overhead for iterator state
278    ///
279    /// # Implementation Details
280    ///
281    /// This method creates an iterator that yields view tensors for each element
282    /// in the source tensor. Each element becomes a `Tensor` of shape `[1]` that
283    /// supports all tensor operations and gradient tracking.
284    ///
285    /// The iterator provides zero-copy access to tensor elements through view
286    /// tensors, enabling efficient element-wise operations while maintaining
287    /// full compatibility with Rust's standard library iterator methods.
288    pub fn new(tensor: &'a Tensor) -> Self {
289        Self {
290            source: tensor,
291            position: 0,
292            end: tensor.size(),
293        }
294    }
295
296    /// Create an iterator over a specific range of elements
297    ///
298    /// Creates an iterator that yields view tensors for elements in the specified
299    /// range. The range is automatically clamped to valid tensor bounds for safety.
300    ///
301    /// # Arguments
302    ///
303    /// * `tensor` - The source tensor to iterate over
304    /// * `start` - Starting index (inclusive)
305    /// * `end` - Ending index (exclusive)
306    ///
307    /// # Returns
308    ///
309    /// An iterator that yields view tensors for elements in the specified range
310    ///
311    /// # Safety
312    ///
313    /// The range is automatically clamped to valid tensor bounds:
314    /// - `start` is clamped to `[0, tensor.size()]`
315    /// - `end` is clamped to `[start, tensor.size()]`
316    /// - Empty ranges (start >= end) are handled gracefully
317    ///
318    /// # Performance
319    ///
320    /// - **O(1) Creation**: Constant-time iterator initialization
321    /// - **Bounds Checking**: Automatic range validation and clamping
322    /// - **Zero-Copy Views**: Each element is a view sharing memory with source
323    ///
324    /// # Implementation Details
325    ///
326    /// This method creates an iterator that yields view tensors for elements in
327    /// the specified range. The range is automatically clamped to valid tensor
328    /// bounds for safety, ensuring that out-of-bounds access is handled gracefully.
329    ///
330    /// The iterator provides zero-copy access to tensor elements through view
331    /// tensors, enabling efficient element-wise operations while maintaining
332    /// full compatibility with Rust's standard library iterator methods.
333    pub fn with_range(tensor: &'a Tensor, start: usize, end: usize) -> Self {
334        let end = end.min(tensor.size());
335        let start = start.min(end);
336        Self {
337            source: tensor,
338            position: start,
339            end,
340        }
341    }
342
343    /// Create an optimized element view for the given position
344    ///
345    /// This method creates a true view tensor of shape `[1]` that shares memory
346    /// with the element at the specified index in the source tensor. The view
347    /// enables zero-copy element access with full gradient tracking.
348    ///
349    /// # Arguments
350    ///
351    /// * `index` - Index of the element to create a view for
352    ///
353    /// # Returns
354    ///
355    /// A view tensor of shape `[1]` representing the element at the specified index
356    ///
357    /// # Safety
358    ///
359    /// The caller must ensure that `index < self.source.size()`.
360    ///
361    /// # Performance
362    ///
363    /// - **O(1) View Creation**: Constant-time view tensor creation
364    /// - **Zero-Copy**: View shares memory with source tensor
365    /// - **Memory Efficient**: ~64 bytes overhead for view metadata
366    /// - **Gradient Tracking**: Full gradtrack support through view operations
367    ///
368    /// # Implementation Details
369    ///
370    /// This method delegates to `Tensor::element_view()` which creates a true
371    /// view of the underlying data without any copying. The view tensor supports
372    /// all standard tensor operations including gradient tracking and SIMD
373    /// optimizations.
374    fn create_element_view(&self, index: usize) -> Tensor {
375        debug_assert!(index < self.source.size());
376
377        self.source.element_view(index)
378    }
379}
380
381// ===== Core Iterator Implementation =====
382
383impl<'a> Iterator for TensorElementIterator<'a> {
384    type Item = Tensor;
385
386    #[inline]
387    fn next(&mut self) -> Option<Self::Item> {
388        if self.position < self.end {
389            let view = self.create_element_view(self.position);
390            self.position += 1;
391            Some(view)
392        } else {
393            None
394        }
395    }
396
397    #[inline]
398    fn size_hint(&self) -> (usize, Option<usize>) {
399        let remaining = self.end - self.position;
400        (remaining, Some(remaining))
401    }
402
403    #[inline]
404    fn count(self) -> usize {
405        self.end - self.position
406    }
407
408    #[inline]
409    fn nth(&mut self, n: usize) -> Option<Self::Item> {
410        let new_pos = self.position.saturating_add(n);
411        if new_pos < self.end {
412            self.position = new_pos + 1;
413            Some(self.create_element_view(new_pos))
414        } else {
415            self.position = self.end;
416            None
417        }
418    }
419
420    #[inline]
421    fn last(self) -> Option<Self::Item> {
422        if self.position < self.end {
423            let last_idx = self.end - 1;
424            Some(self.create_element_view(last_idx))
425        } else {
426            None
427        }
428    }
429}
430
impl<'a> ExactSizeIterator for TensorElementIterator<'a> {
    /// Number of elements remaining; exact because both range ends are known.
    #[inline]
    fn len(&self) -> usize {
        self.end - self.position
    }
}
437
438impl<'a> FusedIterator for TensorElementIterator<'a> {}
439
440impl<'a> DoubleEndedIterator for TensorElementIterator<'a> {
441    #[inline]
442    fn next_back(&mut self) -> Option<Self::Item> {
443        if self.position < self.end {
444            self.end -= 1;
445            Some(self.create_element_view(self.end))
446        } else {
447            None
448        }
449    }
450
451    #[inline]
452    fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
453        let new_end = self.end.saturating_sub(n + 1);
454        if new_end >= self.position {
455            self.end = new_end;
456            Some(self.create_element_view(self.end))
457        } else {
458            self.position = self.end;
459            None
460        }
461    }
462}
463
464// ===== IntoIterator Implementation =====
465
impl<'a> IntoIterator for &'a Tensor {
    type Item = Tensor;
    type IntoIter = TensorElementIterator<'a>;

    /// Enable `for element in &tensor` loops by iterating over every element
    /// as a shape-`[1]` view tensor (equivalent to `tensor.iter()`).
    fn into_iter(self) -> Self::IntoIter {
        TensorElementIterator::new(self)
    }
}
474
475// ===== FromIterator Implementation =====
476
477impl FromIterator<Tensor> for Tensor {
478    /// Collect element view tensors back into a single tensor
479    ///
480    /// This method reconstructs a tensor from an iterator of element view tensors.
481    /// It includes optimizations for common patterns and maintains gradient tracking
482    /// when appropriate.
483    ///
484    /// The collection process automatically detects whether all elements are scalar
485    /// views (shape `[1]`) and uses optimized collection strategies accordingly.
486    /// Gradient tracking is preserved when any input element requires gradients.
487    ///
488    /// # Performance
489    ///
490    /// - **Optimized Collection**: Specialized paths for scalar and mixed views
491    /// - **Memory Efficient**: Direct memory copying without intermediate allocations
492    /// - **Gradient Preservation**: Maintains gradtrack functionality when enabled
493    /// - **Shape Detection**: Automatic detection of element shapes for optimization
494    ///
495    /// # Implementation Details
496    ///
497    /// The method performs the following steps:
498    /// 1. **Element Collection**: Gathers all element tensors from the iterator
499    /// 2. **Shape Analysis**: Determines if all elements are scalar views
500    /// 3. **Optimized Path**: Uses specialized collection for scalar views
501    /// 4. **General Path**: Handles mixed shapes by flattening into 1D tensor
502    /// 5. **Gradient Setup**: Preserves gradient tracking when appropriate
503    ///
504    /// # Examples
505    ///
506    /// ## Basic Collection
507    ///
508    /// ```
509    /// use train_station::Tensor;
510    ///
511    /// let original = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
512    /// let doubled: Tensor = original.iter()
513    ///     .map(|elem| elem.mul_scalar(2.0))
514    ///     .collect();
515    ///
516    /// assert_eq!(doubled.data(), &[2.0, 4.0, 6.0]);
517    /// ```
518    ///
519    /// ## Collection with Gradient Tracking
520    ///
521    /// ```
522    /// use train_station::Tensor;
523    ///
524    /// let original = Tensor::from_slice(&[1.0, 2.0], vec![2])
525    ///     .unwrap()
526    ///     .with_requires_grad();
527    ///
528    /// let result: Tensor = original.iter()
529    ///     .map(|elem| elem.mul_scalar(2.0))
530    ///     .collect();
531    ///
532    /// assert!(result.requires_grad());
533    /// assert_eq!(result.data(), &[2.0, 4.0]);
534    /// ```
535    ///
536    /// ## Empty Iterator Handling
537    ///
538    /// ```
539    /// use train_station::Tensor;
540    ///
541    /// let empty: Tensor = Vec::<Tensor>::new().into_iter().collect();
542    /// assert_eq!(empty.size(), 0);
543    /// assert_eq!(empty.shape().dims, vec![0]);
544    /// ```
545    fn from_iter<I: IntoIterator<Item = Tensor>>(iter: I) -> Self {
546        let elements: Vec<Tensor> = iter.into_iter().collect();
547
548        if elements.is_empty() {
549            return Tensor::new(vec![0]);
550        }
551
552        // Check if all elements are scalar views (shape [1])
553        let all_scalars = elements.iter().all(|e| e.shape().dims == vec![1]);
554
555        if all_scalars {
556            // Optimized path for scalar element views
557            Self::collect_scalar_views(elements)
558        } else {
559            // General path for mixed shapes
560            Self::collect_mixed_views(elements)
561        }
562    }
563}
564
565impl Tensor {
566    /// Optimized collection for scalar element views
567    ///
568    /// This method efficiently reconstructs a tensor from scalar element views,
569    /// preserving gradient tracking and using optimized memory operations.
570    ///
571    /// This is the fast path for collection when all elements are scalar views
572    /// (shape `[1]`). It performs direct memory copying and sets up gradient
573    /// tracking when any input element requires gradients.
574    ///
575    /// # Arguments
576    ///
577    /// * `elements` - Vector of scalar element view tensors
578    ///
579    /// # Returns
580    ///
581    /// A new tensor containing all element values in a 1D layout
582    ///
583    /// # Performance
584    ///
585    /// - **Direct Memory Copy**: Single-pass copying without intermediate allocations
586    /// - **Gradient Optimization**: Efficient gradient tracking setup
587    /// - **Memory Efficient**: Minimal overhead for collection process
588    /// - **SIMD Compatible**: Result tensor supports all optimizations
589    ///
590    /// # Implementation Details
591    ///
592    /// The method performs the following steps:
593    /// 1. **Allocation**: Creates uninitialized tensor with correct size
594    /// 2. **Gradient Check**: Determines if any element requires gradients
595    /// 3. **Memory Copy**: Direct copying from element views to result
596    /// 4. **Gradient Setup**: Configures gradient tracking when needed
597    /// 5. **Operation Registration**: Registers with gradtrack engine
    fn collect_scalar_views(elements: Vec<Tensor>) -> Self {
        let len = elements.len();
        let mut result = Self::new_uninitialized(vec![len]);

        // Determine if we can track gradients
        let requires_grad = elements.iter().any(|e| e.requires_grad());

        // Copy data from element views into the freshly allocated result.
        // SAFETY: `result` was allocated above with exactly `len` elements, so
        // `dst.add(i)` is in bounds for every `i < len`. Each `element` is a
        // scalar view of shape `[1]` (verified by the caller, `from_iter`),
        // so reading one value from `element.as_ptr()` is valid. `result` is a
        // fresh allocation, so the writes cannot alias the source reads.
        unsafe {
            let dst = result.as_mut_ptr();
            for (i, element) in elements.iter().enumerate() {
                *dst.add(i) = *element.as_ptr();
            }
        }

        // Set up gradient tracking: register the collection as one operation
        // over all source element ids so gradients flow back per element.
        if requires_grad && is_grad_enabled() {
            result.set_requires_grad_internal(true);
            let element_ids: Vec<usize> = elements.iter().map(|e| e.id()).collect();
            let grad_fn = GradFn::ElementCollection {
                element_ids: element_ids.clone(),
                result_shape: vec![len],
            };
            result.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(result.id(), element_ids, grad_fn);
        }

        result
    }
627
628    /// General collection for mixed element shapes
629    ///
630    /// This method handles collection when elements have different shapes,
631    /// flattening all elements into a 1D tensor.
632    ///
633    /// This is the general path for collection when elements have varying shapes.
634    /// It flattens all elements into a single 1D tensor and preserves gradient
635    /// tracking when any input element requires gradients.
636    ///
637    /// # Arguments
638    ///
639    /// * `elements` - Vector of element tensors with potentially different shapes
640    ///
641    /// # Returns
642    ///
643    /// A new 1D tensor containing all flattened element values
644    ///
645    /// # Performance
646    ///
647    /// - **Flattening**: Converts all elements to 1D layout
648    /// - **Memory Copy**: Efficient copying with size calculation
649    /// - **Gradient Preservation**: Maintains gradtrack functionality
650    /// - **Mixed Shapes**: Handles elements with different dimensions
651    ///
652    /// # Implementation Details
653    ///
654    /// The method performs the following steps:
655    /// 1. **Size Calculation**: Sums sizes of all elements for total size
656    /// 2. **Allocation**: Creates uninitialized tensor with total size
657    /// 3. **Sequential Copy**: Copies each element's data sequentially
658    /// 4. **Gradient Setup**: Configures gradient tracking when needed
659    /// 5. **Operation Registration**: Registers with gradtrack engine
    fn collect_mixed_views(elements: Vec<Tensor>) -> Self {
        // For mixed shapes, flatten all elements into a 1D tensor
        let total_size: usize = elements.iter().map(|e| e.size()).sum();
        let mut result = Self::new_uninitialized(vec![total_size]);

        let requires_grad = elements.iter().any(|e| e.requires_grad());
        let mut offset = 0;

        // SAFETY: `result` holds exactly `total_size` elements (the sum of all
        // element sizes), and `offset` advances by each element's size, so every
        // `copy_nonoverlapping` writes within bounds. `result` is a fresh
        // allocation, so destination and source buffers cannot overlap.
        // NOTE(review): this assumes each element's data is contiguous for
        // `size()` values starting at `as_ptr()` — confirm this holds for
        // non-contiguous view tensors before relying on the mixed-shape path.
        unsafe {
            let dst = result.as_mut_ptr();
            for element in &elements {
                let src = element.as_ptr();
                let size = element.size();
                std::ptr::copy_nonoverlapping(src, dst.add(offset), size);
                offset += size;
            }
        }

        // Register gradient tracking over all source element ids when any
        // input requires gradients and grad recording is enabled.
        if requires_grad && is_grad_enabled() {
            result.set_requires_grad_internal(true);
            let element_ids: Vec<usize> = elements.iter().map(|e| e.id()).collect();
            let grad_fn = GradFn::ElementCollection {
                element_ids: element_ids.clone(),
                result_shape: vec![total_size],
            };
            result.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(result.id(), element_ids, grad_fn);
        }

        result
    }
691
692    /// Create an iterator over tensor elements as view tensors
693    ///
694    /// Each element becomes a `Tensor` of shape `[1]` that supports all
695    /// tensor operations and gradient tracking. This is the main entry point
696    /// for element-wise iteration with full tensor operation support.
697    ///
698    /// The iterator provides zero-copy access to tensor elements through view
699    /// tensors, enabling efficient element-wise operations while maintaining
700    /// full compatibility with Rust's standard library iterator methods.
701    ///
702    /// # Returns
703    ///
704    /// An iterator that yields view tensors for each element
705    ///
706    /// # Performance
707    ///
708    /// - **Zero-Copy Views**: Each element is a view sharing memory with source
709    /// - **O(1) Element Access**: Constant-time view creation for each element
710    /// - **Memory Efficient**: ~64 bytes overhead per element view
711    /// - **SIMD Compatible**: All tensor operations use existing optimizations
712    /// - **Gradient Tracking**: Full gradtrack support through element operations
713    ///
714    /// # Examples
715    ///
716    /// ## Basic Element Operations
717    ///
718    /// ```
719    /// use train_station::Tensor;
720    ///
721    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
722    ///
723    /// // Use any std iterator method
724    /// let result: Tensor = tensor.iter()
725    ///     .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0)) // 2x + 1
726    ///     .filter(|elem| elem.value() > 3.0)                // Keep values > 3
727    ///     .collect();
728    ///
729    /// assert_eq!(result.data(), &[5.0, 7.0]);
730    /// ```
731    ///
732    /// ## Advanced Iterator Chains
733    ///
734    /// ```
735    /// use train_station::Tensor;
736    ///
737    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
738    ///
739    /// // Chain with enumerate, zip, etc.
740    /// let indexed: Tensor = tensor.iter()
741    ///     .enumerate()
742    ///     .map(|(i, elem)| elem.add_scalar(i as f32))
743    ///     .collect();
744    ///
745    /// assert_eq!(indexed.data(), &[1.0, 3.0, 5.0, 7.0, 9.0]);
746    /// ```
747    ///
748    /// ## Double-Ended Iteration
749    ///
750    /// ```
751    /// use train_station::Tensor;
752    ///
753    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
754    ///
755    /// // Use double-ended iterator
756    /// let reversed: Tensor = tensor.iter()
757    ///     .rev()
758    ///     .collect();
759    ///
760    /// assert_eq!(reversed.data(), &[4.0, 3.0, 2.0, 1.0]);
761    /// ```
762    ///
763    /// ## Gradient Tracking
764    ///
765    /// ```
766    /// use train_station::Tensor;
767    ///
768    /// let tensor = Tensor::from_slice(&[1.0, 2.0], vec![2])
769    ///     .unwrap()
770    ///     .with_requires_grad();
771    ///
772    /// let result: Tensor = tensor.iter()
773    ///     .map(|elem| elem.mul_scalar(2.0))
774    ///     .collect();
775    ///
776    /// assert!(result.requires_grad());
777    /// assert_eq!(result.data(), &[2.0, 4.0]);
778    /// ```
779    pub fn iter(&self) -> TensorElementIterator {
780        TensorElementIterator::new(self)
781    }
782
783    /// Create an iterator over a range of elements
784    ///
785    /// Creates an iterator that yields view tensors for elements in the specified
786    /// range. The range is automatically clamped to valid tensor bounds for safety.
787    ///
788    /// # Arguments
789    ///
790    /// * `start` - Starting index (inclusive)
791    /// * `end` - Ending index (exclusive)
792    ///
793    /// # Returns
794    ///
795    /// An iterator that yields view tensors for elements in the specified range
796    ///
797    /// # Safety
798    ///
799    /// The range is automatically clamped to valid tensor bounds:
800    /// - `start` is clamped to `[0, tensor.size()]`
801    /// - `end` is clamped to `[start, tensor.size()]`
802    /// - Empty ranges (start >= end) are handled gracefully
803    ///
804    /// # Performance
805    ///
806    /// - **O(1) Creation**: Constant-time iterator initialization
807    /// - **Bounds Checking**: Automatic range validation and clamping
808    /// - **Zero-Copy Views**: Each element is a view sharing memory with source
809    /// - **Memory Efficient**: Minimal overhead for range iteration
810    ///
811    /// # Examples
812    ///
813    /// ## Basic Range Iteration
814    ///
815    /// ```
816    /// use train_station::Tensor;
817    ///
818    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
819    /// let middle: Tensor = tensor.iter_range(1, 4)
820    ///     .map(|elem| elem.mul_scalar(2.0))
821    ///     .collect();
822    ///
823    /// assert_eq!(middle.data(), &[4.0, 6.0, 8.0]);
824    /// ```
825    ///
826    /// ## Range with Operations
827    ///
828    /// ```
829    /// use train_station::Tensor;
830    ///
831    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
832    ///
833    /// // Apply complex operations to range
834    /// let result: Tensor = tensor.iter_range(0, 3)
835    ///     .enumerate()
836    ///     .map(|(i, elem)| elem.add_scalar(i as f32))
837    ///     .collect();
838    ///
839    /// assert_eq!(result.data(), &[1.0, 3.0, 5.0]);
840    /// ```
841    ///
842    /// ## Out of Bounds Handling
843    ///
844    /// ```
845    /// use train_station::Tensor;
846    ///
847    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
848    ///
849    /// // Out of bounds range is clamped
850    /// let empty: Tensor = tensor.iter_range(5, 10).collect();
851    /// assert_eq!(empty.size(), 0);
852    ///
853    /// // Partial out of bounds
854    /// let partial: Tensor = tensor.iter_range(1, 10).collect();
855    /// assert_eq!(partial.data(), &[2.0, 3.0]);
856    /// ```
    pub fn iter_range(&self, start: usize, end: usize) -> TensorElementIterator {
        // `with_range` clamps `start`/`end` to the tensor bounds, so no
        // validation is needed here; oversized ranges yield fewer elements.
        TensorElementIterator::with_range(self, start, end)
    }
860}
861
#[cfg(test)]
mod tests {
    //! Test suite for tensor element iteration
    //!
    //! Coverage:
    //! - Core iteration behavior over `[1]`-shaped element views
    //! - Standard library trait conformance (Iterator, ExactSizeIterator,
    //!   DoubleEndedIterator, IntoIterator, FromIterator)
    //! - Gradient-enabled tensors flowing through iterator pipelines
    //! - Edge cases (empty tensors, empty and out-of-bounds ranges)
    //! - A rough performance smoke test

    use super::*;

    /// Elements come back in order as `[1]`-shaped single-value tensors
    #[test]
    fn test_basic_iteration() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();

        let views: Vec<Tensor> = source.iter().collect();
        assert_eq!(views.len(), 4);

        // Every element view must be a scalar tensor holding the right number
        for (idx, view) in views.iter().enumerate() {
            assert_eq!(view.shape().dims, vec![1]);
            assert_eq!(view.size(), 1);
            assert_eq!(view.value(), (idx + 1) as f32);
        }
    }

    /// Core `Iterator` methods: next, size_hint, count, nth, last
    #[test]
    fn test_iterator_trait_methods() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
        let mut it = source.iter();

        // next() yields the first element
        assert_eq!(it.next().unwrap().value(), 1.0);

        // size_hint() reflects the four remaining elements
        assert_eq!(it.size_hint(), (4, Some(4)));

        // count() consumes whatever is left
        assert_eq!(it.count(), 4);

        // nth() skips ahead to a later element
        let mut it = source.iter();
        assert_eq!(it.nth(2).unwrap().value(), 3.0);

        // last() drains the iterator down to its final element
        assert_eq!(source.iter().last().unwrap().value(), 5.0);
    }

    /// `ExactSizeIterator::len` tracks the remaining element count
    #[test]
    fn test_exact_size_iterator() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();

        assert_eq!(source.iter().len(), 3);

        // len() shrinks by exactly one per consumed element
        let mut it = source.iter();
        for remaining in (1..=3).rev() {
            assert_eq!(it.len(), remaining);
            it.next();
        }
        assert_eq!(it.len(), 0);
    }

    /// `DoubleEndedIterator`: next_back, nth_back, and mixed consumption
    #[test]
    fn test_double_ended_iterator() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let mut it = source.iter();

        // Pull one element from the back, then one from the front
        assert_eq!(it.next_back().unwrap().value(), 4.0);
        assert_eq!(it.next().unwrap().value(), 1.0);

        // nth_back(1) skips one element counting from the rear
        let mut it = source.iter();
        assert_eq!(it.nth_back(1).unwrap().value(), 3.0);

        // Alternate front/back consumption until the iterator is exhausted
        let mut it = source.iter();
        assert_eq!(it.next().unwrap().value(), 1.0);
        assert_eq!(it.next_back().unwrap().value(), 4.0);
        assert_eq!(it.next().unwrap().value(), 2.0);
        assert_eq!(it.next_back().unwrap().value(), 3.0);
        assert!(it.next().is_none());
    }

    /// `IntoIterator` for `&Tensor`: for-loops and explicit into_iter()
    #[test]
    fn test_into_iterator() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();

        // Implicitly via a for loop over a tensor reference
        let mut seen = Vec::new();
        for view in &source {
            seen.push(view.value());
        }
        assert_eq!(seen, vec![1.0, 2.0, 3.0]);

        // Explicitly via into_iter() on the reference
        let seen: Vec<f32> = (&source).into_iter().map(|view| view.value()).collect();
        assert_eq!(seen, vec![1.0, 2.0, 3.0]);
    }

    /// `FromIterator`: collecting element views back into a tensor
    #[test]
    fn test_from_iterator() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();

        // Round-trip: iterating and collecting reproduces shape and data
        let rebuilt: Tensor = source.iter().collect();
        assert_eq!(rebuilt.shape().dims, vec![4]);
        assert_eq!(rebuilt.data(), source.data());

        // Collect after mapping each view to a freshly built scalar tensor
        let doubled: Tensor = source
            .iter()
            .map(|view| Tensor::from_slice(&[view.value() * 2.0], vec![1]).unwrap())
            .collect();
        assert_eq!(doubled.data(), &[2.0, 4.0, 6.0, 8.0]);
    }

    /// Standard adaptors and terminals behave as expected over element views
    #[test]
    fn test_std_iterator_methods() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();

        // map
        let twice: Vec<f32> = source.iter().map(|view| view.value() * 2.0).collect();
        assert_eq!(twice, vec![2.0, 4.0, 6.0, 8.0, 10.0]);

        // filter (extract values first, then keep the large ones)
        let big: Vec<f32> = source
            .iter()
            .map(|view| view.value())
            .filter(|&v| v > 3.0)
            .collect();
        assert_eq!(big, vec![4.0, 5.0]);

        // enumerate
        let indexed: Vec<(usize, f32)> = source
            .iter()
            .map(|view| view.value())
            .enumerate()
            .collect();
        assert_eq!(
            indexed,
            vec![(0, 1.0), (1, 2.0), (2, 3.0), (3, 4.0), (4, 5.0)]
        );

        // fold
        let total: f32 = source.iter().fold(0.0, |acc, view| acc + view.value());
        assert_eq!(total, 15.0);

        // find
        let hit = source.iter().find(|view| view.value() == 3.0);
        assert!(hit.is_some());
        assert_eq!(hit.unwrap().value(), 3.0);

        // all / any
        assert!(source.iter().all(|view| view.value() > 0.0));
        assert!(source.iter().any(|view| view.value() > 4.0));
    }

    /// Tensor scalar operations applied to individual element views
    #[test]
    fn test_element_tensor_operations() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();

        // Single scalar op per element
        let doubled: Tensor = source.iter().map(|view| view.mul_scalar(2.0)).collect();
        assert_eq!(doubled.data(), &[2.0, 4.0, 6.0]);

        let shifted: Tensor = source.iter().map(|view| view.add_scalar(10.0)).collect();
        assert_eq!(shifted.data(), &[11.0, 12.0, 13.0]);

        // Chained ops: 2x + 1 per element
        let affine: Tensor = source
            .iter()
            .map(|view| view.mul_scalar(2.0).add_scalar(1.0))
            .collect();
        assert_eq!(affine.data(), &[3.0, 5.0, 7.0]);
    }

    /// Gradient-enabled tensors survive element-wise iterator pipelines
    #[test]
    fn test_gradient_tracking() {
        let source = Tensor::from_slice(&[1.0, 2.0], vec![2])
            .unwrap()
            .with_requires_grad();

        let result: Tensor = source.iter().map(|view| view.mul_scalar(2.0)).collect();

        // The collected tensor keeps the gradient requirement.
        // NOTE: element views may be copies, so gradient propagation back to
        // `source` is not asserted here — only the forward pass and the flag.
        assert!(result.requires_grad());
        assert_eq!(result.data(), &[2.0, 4.0]);
    }

    /// Empty tensors yield empty iterators and collect to empty tensors
    #[test]
    fn test_zero_sized_tensor() {
        let empty = Tensor::new(vec![0]);
        let it = empty.iter();

        assert_eq!(it.len(), 0);
        assert_eq!(it.size_hint(), (0, Some(0)));

        let rebuilt: Tensor = it.collect();
        assert_eq!(rebuilt.size(), 0);
    }

    /// `iter_range` honors bounds, clamps overruns, and handles empty ranges
    #[test]
    fn test_range_iteration() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();

        // Interior range [1, 4)
        let inner: Vec<f32> = source.iter_range(1, 4).map(|view| view.value()).collect();
        assert_eq!(inner, vec![2.0, 3.0, 4.0]);

        // An end past the tensor length is clamped to the last element
        let tail: Vec<f32> = source.iter_range(3, 10).map(|view| view.value()).collect();
        assert_eq!(tail, vec![4.0, 5.0]);

        // start == end produces nothing
        let none: Vec<f32> = source.iter_range(2, 2).map(|view| view.value()).collect();
        assert_eq!(none, Vec::<f32>::new());
    }

    /// Long adaptor chains (enumerate/filter/map, rev/take) compose correctly
    #[test]
    fn test_complex_chains() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();

        // Keep even positions, then add the position to each value:
        // [1.0 + 0, 3.0 + 2, 5.0 + 4] = [1.0, 5.0, 9.0]
        let result: Tensor = source
            .iter()
            .enumerate()
            .filter(|(pos, _)| pos % 2 == 0)
            .map(|(pos, view)| view.add_scalar(pos as f32))
            .collect();
        assert_eq!(result.data(), &[1.0, 5.0, 9.0]);

        // Last three elements, taken from the back via rev()
        let tail_reversed: Tensor = source.iter().rev().take(3).collect();
        assert_eq!(tail_reversed.data(), &[6.0, 5.0, 4.0]);
    }

    /// Smoke test: iterating 1000 elements completes with correct results
    #[test]
    fn test_performance() {
        let values: Vec<f32> = (0..1000).map(|i| i as f32).collect();
        let big = Tensor::from_slice(&values, vec![1000]).unwrap();

        let clock = std::time::Instant::now();
        let result: Tensor = big.iter().map(|view| view.mul_scalar(2.0)).collect();
        let elapsed = clock.elapsed();
        println!("Iterator performance test took: {:?}", elapsed);

        // Spot-check the first and last doubled values
        assert_eq!(result.size(), 1000);
        assert_eq!(result.data()[0], 0.0);
        assert_eq!(result.data()[999], 1998.0);
    }
}