train_station/tensor/iterator/
mod.rs

//! Iterator module for tensor element-wise operations
//!
//! This module provides high-performance iterators over tensor elements, where each
//! element is represented as a view tensor of shape `[1]`. This design allows for
//! seamless integration with Rust's standard library iterator methods while
//! leveraging the existing tensor operation framework and gradient tracking.
//!
//! # Key Features
//!
//! - **Standard Library Compatibility**: Full implementation of Iterator, ExactSizeIterator,
//!   DoubleEndedIterator, FusedIterator, IntoIterator, and FromIterator traits
//! - **Gradient Tracking**: Automatic gradient propagation through element operations
//! - **Performance Optimized**: True zero-copy views with shared memory
//! - **SIMD Compatible**: All operations use existing optimized tensor implementations
//! - **Memory Efficient**: Adaptive view creation based on tensor size
//! - **Zero-Copy Operations**: Element views share memory with source tensor
//! - **Full Tensor Operations**: Each element supports all tensor methods
//!
//! # Performance Characteristics
//!
//! - **View Creation**: O(1) per element with true zero-copy views
//! - **Memory Overhead**: ~64 bytes per view tensor (no data copying)
//! - **SIMD Operations**: Full utilization of existing optimizations
//! - **Gradient Tracking**: True gradient flow with element-level accumulation
//! - **Iterator Overhead**: Minimal performance impact for element access
//! - **Collection Optimization**: Efficient reconstruction from element views
//!
//! # Examples
//!
//! ## Basic Element Iteration
//!
//! ```
//! use train_station::Tensor;
//!
//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
//!
//! // Basic iteration over elements
//! for element in tensor.iter() {
//!     println!("Element value: {}", element.value());
//! }
//!
//! // Collect elements into a new tensor
//! let collected: Tensor = tensor.iter().collect();
//! assert_eq!(collected.data(), tensor.data());
//! ```
//!
//! ## Element-Wise Transformations
//!
//! ```
//! use train_station::Tensor;
//!
//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
//!
//! // Apply tensor operations to each element
//! let doubled: Tensor = tensor.iter()
//!     .map(|elem| elem.mul_scalar(2.0))
//!     .collect();
//!
//! assert_eq!(doubled.data(), &[2.0, 4.0, 6.0]);
//!
//! // Chain multiple operations
//! let transformed: Tensor = tensor.iter()
//!     .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0)) // 2x + 1
//!     .collect();
//!
//! assert_eq!(transformed.data(), &[3.0, 5.0, 7.0]);
//! ```
//!
//! ## Advanced Iterator Operations
//!
//! ```
//! use train_station::Tensor;
//!
//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
//!
//! // Filter elements based on values
//! let large_values: Tensor = tensor.iter()
//!     .filter(|elem| elem.value() > 3.0)
//!     .collect();
//!
//! assert_eq!(large_values.data(), &[4.0, 5.0]);
//!
//! // Use enumerate for indexed operations
//! let indexed: Tensor = tensor.iter()
//!     .enumerate()
//!     .map(|(i, elem)| elem.add_scalar(i as f32))
//!     .collect();
//!
//! assert_eq!(indexed.data(), &[1.0, 3.0, 5.0, 7.0, 9.0]);
//! ```
//!
//! ## Range Iteration
//!
//! ```
//! use train_station::Tensor;
//!
//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
//!
//! // Iterate over a specific range
//! let middle: Tensor = tensor.iter_range(1, 4)
//!     .map(|elem| elem.mul_scalar(2.0))
//!     .collect();
//!
//! assert_eq!(middle.data(), &[4.0, 6.0, 8.0]);
//! ```
//!
//! ## Double-Ended Iteration
//!
//! ```
//! use train_station::Tensor;
//!
//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
//!
//! // Reverse iteration
//! let reversed: Tensor = tensor.iter().rev().collect();
//! assert_eq!(reversed.data(), &[4.0, 3.0, 2.0, 1.0]);
//!
//! // Iterate from both ends
//! let mut iter = tensor.iter();
//! assert_eq!(iter.next().unwrap().value(), 1.0);
//! assert_eq!(iter.next_back().unwrap().value(), 4.0);
//! ```
//!
//! ## Gradient Tracking
//!
//! ```
//! use train_station::Tensor;
//!
//! let tensor = Tensor::from_slice(&[1.0, 2.0], vec![2])
//!     .unwrap()
//!     .with_requires_grad();
//!
//! // Element operations maintain gradient tracking
//! let result: Tensor = tensor.iter()
//!     .map(|elem| elem.mul_scalar(2.0))
//!     .collect();
//!
//! assert!(result.requires_grad());
//! assert_eq!(result.data(), &[2.0, 4.0]);
//! ```
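//!
//! ## Dimension Iteration
//!
//! `&Tensor` implements `IntoIterator` by iterating the outermost dimension
//! (see the `IntoIterator` implementation below), yielding sub-tensor views.
//! A minimal sketch, assuming each row view of a `[2, 2]` tensor has size 2:
//!
//! ```
//! use train_station::Tensor;
//!
//! let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
//! for row in &matrix {
//!     // Each yielded sub-tensor is a view sharing memory with `matrix`
//!     assert_eq!(row.size(), 2);
//! }
//! ```
//!
//! ## Value Iteration
//!
//! For raw `f32` access without per-element view tensors, the `value`
//! submodule provides `iter_values` and `iter_values_mut` (exercised in the
//! tests at the bottom of this file). A sketch, assuming contiguous tensors:
//!
//! ```
//! use train_station::Tensor;
//!
//! let t = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
//! let sum: f32 = t.iter_values().sum();
//! assert_eq!(sum, 6.0);
//!
//! let mut m = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
//! for v in m.iter_values_mut() {
//!     *v *= 2.0; // in-place update through the mutable value iterator
//! }
//! assert_eq!(m.data(), &[2.0, 4.0]);
//! ```
//!
//! ## Chunked Iteration
//!
//! The `chunks` submodule yields contiguous sub-views of a fixed size, with a
//! shorter trailing chunk when the length is not an exact multiple. A sketch
//! mirroring the chunk tests below:
//!
//! ```
//! use train_station::Tensor;
//!
//! let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
//! let chunks: Vec<Tensor> = t.iter_chunks(2).collect();
//! assert_eq!(chunks.len(), 3);
//! assert_eq!(chunks[2].data(), &[5.0]); // trailing remainder chunk
//! ```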
//!
//! # Design Principles
//!
//! - **Zero-Copy Views**: Element views share memory with source tensor
//! - **Full Tensor Operations**: Each element supports all tensor methods
//! - **Standard Library Integration**: Complete compatibility with Rust iterators
//! - **Performance First**: Optimized for high-performance element access
//! - **Gradient Preservation**: Maintains gradtrack functionality through operations
//! - **Memory Efficiency**: Minimal overhead for element iteration
//! - **Type Safety**: Compile-time guarantees for iterator operations

pub mod chunks;
pub mod collect;
pub mod element;
pub mod value;
pub mod viewdim;
pub mod windows;

use crate::gradtrack::is_grad_enabled;
use crate::tensor::core::Tensor;
pub use collect::{TensorCollectExt, ValuesCollectExt};
use std::iter::FromIterator;

/// High-performance iterator over tensor elements as view tensors
///
/// Each element becomes a proper `Tensor` view of shape `[1]` that can use
/// all existing tensor operations and gradient tracking. Implements all
/// standard iterator traits for maximum compatibility with Rust's ecosystem.
///
/// This iterator provides zero-copy access to tensor elements through view
/// tensors, enabling efficient element-wise operations while maintaining
/// full compatibility with Rust's standard library iterator methods.
///
/// # Performance
///
/// - **Zero-Copy Views**: Each element is a view tensor sharing memory with source
/// - **O(1) Element Access**: Constant-time view creation for each element
/// - **Memory Efficient**: ~64 bytes overhead per element view
/// - **SIMD Compatible**: All tensor operations use existing optimizations
/// - **Gradient Tracking**: Full gradtrack support through element operations
///
/// # Implementation Details
///
/// The iterator creates lightweight view tensors on-demand, sharing the same
/// memory allocation as the source tensor. This ensures zero-copy semantics
/// while maintaining full tensor operation compatibility.
///
/// Each element view is created using `Tensor::element_view()`, which provides
/// a true view of the underlying data without any copying. The view tensors
/// support all standard tensor operations including gradient tracking.
///
/// # Standard Library Compatibility
///
/// This iterator implements all standard iterator traits:
/// - `Iterator`: Basic iteration with `next()` and `size_hint()`
/// - `ExactSizeIterator`: Precise size information with `len()`
/// - `DoubleEndedIterator`: Reverse iteration with `next_back()`
/// - `FusedIterator`: Fused iteration for better performance
/// - `IntoIterator`: Automatic conversion for `for` loops
///
/// # Examples
///
/// ## Basic Iteration
///
/// ```
/// use train_station::Tensor;
///
/// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
///
/// // Basic iteration
/// for element in tensor.iter() {
///     println!("Element value: {}", element.value());
/// }
///
/// // Standard library methods
/// let sum: f32 = tensor.iter()
///     .map(|elem| elem.value())
///     .sum();
///
/// assert_eq!(sum, 6.0);
/// ```
///
/// ## Element Operations
///
/// ```
/// use train_station::Tensor;
///
/// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
///
/// // Tensor operations on elements
/// let transformed: Tensor = tensor.iter()
///     .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0)) // 2x + 1
///     .collect();
///
/// assert_eq!(transformed.data(), &[3.0, 5.0, 7.0]);
/// ```
///
/// ## Advanced Iterator Methods
///
/// ```
/// use train_station::Tensor;
///
/// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
///
/// // Filter and transform
/// let result: Tensor = tensor.iter()
///     .filter(|elem| elem.value() > 2.0)
///     .map(|elem| elem.mul_scalar(10.0))
///     .collect();
///
/// assert_eq!(result.data(), &[30.0, 40.0, 50.0]);
///
/// // Reverse iteration
/// let reversed: Tensor = tensor.iter().rev().collect();
/// assert_eq!(reversed.data(), &[5.0, 4.0, 3.0, 2.0, 1.0]);
/// ```
// ===== IntoIterator Implementation =====
/// IntoIterator for `&Tensor` iterates the outermost dimension, yielding sub-tensors (views)
impl<'a> IntoIterator for &'a Tensor {
    type Item = Tensor;
    type IntoIter = crate::tensor::iterator::viewdim::TensorDimIterator<'a>;

    fn into_iter(self) -> Self::IntoIter {
        // Iterate over the outermost dimension by default
        self.iter_dim(0)
    }
}

// ===== FromIterator Implementation =====

impl FromIterator<Tensor> for Tensor {
    /// Collect element view tensors back into a single tensor
    ///
    /// This method reconstructs a tensor from an iterator of element view tensors.
    /// It includes optimizations for common patterns and maintains gradient tracking
    /// when appropriate.
    ///
    /// The collection process automatically detects whether all elements are scalar
    /// views (shape `[1]`) and uses optimized collection strategies accordingly.
    /// Gradient tracking is preserved when any input element requires gradients.
    ///
    /// # Performance
    ///
    /// - **Optimized Collection**: Specialized paths for scalar and mixed views
    /// - **Memory Efficient**: Direct memory copying without intermediate allocations
    /// - **Gradient Preservation**: Maintains gradtrack functionality when enabled
    /// - **Shape Detection**: Automatic detection of element shapes for optimization
    ///
    /// # Implementation Details
    ///
    /// The method performs the following steps:
    /// 1. **Element Collection**: Gathers all element tensors from the iterator
    /// 2. **Shape Analysis**: Determines if all elements are scalar views
    /// 3. **Optimized Path**: Uses specialized collection for scalar views
    /// 4. **General Path**: Handles mixed shapes by flattening into 1D tensor
    /// 5. **Gradient Setup**: Preserves gradient tracking when appropriate
    ///
    /// # Examples
    ///
    /// ## Basic Collection
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let original = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
    /// let doubled: Tensor = original.iter()
    ///     .map(|elem| elem.mul_scalar(2.0))
    ///     .collect();
    ///
    /// assert_eq!(doubled.data(), &[2.0, 4.0, 6.0]);
    /// ```
    ///
    /// ## Collection with Gradient Tracking
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let original = Tensor::from_slice(&[1.0, 2.0], vec![2])
    ///     .unwrap()
    ///     .with_requires_grad();
    ///
    /// let result: Tensor = original.iter()
    ///     .map(|elem| elem.mul_scalar(2.0))
    ///     .collect();
    ///
    /// assert!(result.requires_grad());
    /// assert_eq!(result.data(), &[2.0, 4.0]);
    /// ```
    ///
    /// ## Empty Iterator Handling
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let empty: Tensor = Vec::<Tensor>::new().into_iter().collect();
    /// assert_eq!(empty.size(), 0);
    /// assert_eq!(empty.shape().dims(), vec![0]);
    /// ```
    fn from_iter<I: IntoIterator<Item = Tensor>>(iter: I) -> Self {
        let elements: Vec<Tensor> = iter.into_iter().collect();

        if elements.is_empty() {
            return Tensor::new(vec![0]);
        }

        // Check if all elements are scalars (size == 1). Supports both [1] and 0-D [] shapes
        let all_scalars = elements.iter().all(|e| e.size() == 1);

        if all_scalars {
            // Optimized path for scalar element views
            Self::collect_scalar_views(elements)
        } else {
            // General path for mixed shapes
            Self::collect_mixed_views(elements)
        }
    }
}

impl Tensor {
    /// Optimized collection for scalar element views
    ///
    /// This method efficiently reconstructs a tensor from scalar element views,
    /// preserving gradient tracking and using optimized memory operations.
    ///
    /// This is the fast path for collection when all elements are scalar views
    /// (shape `[1]` or 0-D `[]`). It copies values directly when no gradients are
    /// needed, and switches to a gradient-preserving path when any input element
    /// requires gradients.
    ///
    /// # Arguments
    ///
    /// * `elements` - Vector of scalar element view tensors
    ///
    /// # Returns
    ///
    /// A new tensor containing all element values in a 1D layout
    ///
    /// # Performance
    ///
    /// - **Direct Memory Copy**: Single-pass copying without intermediate allocations
    /// - **Gradient Optimization**: Efficient gradient tracking setup
    /// - **Memory Efficient**: Minimal overhead for collection process
    /// - **SIMD Compatible**: Result tensor supports all optimizations
    ///
    /// # Implementation Details
    ///
    /// The method takes one of two paths:
    /// 1. **Fast Path**: When no element requires gradients (or gradients are
    ///    disabled), allocates an uninitialized tensor and copies each value directly
    /// 2. **Gradient Path**: Otherwise, promotes 0-D elements to shape `[1]`,
    ///    concatenates along dimension 0, and flattens; the `cat` and `flatten`
    ///    ops register with the gradtrack engine, preserving gradient flow
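    ///
    /// # Examples
    ///
    /// Exercised indirectly through `collect()`; a sketch of both paths:
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Fast path: no element requires gradients, values are copied directly
    /// let t = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
    /// let out: Tensor = t.iter().collect();
    /// assert_eq!(out.data(), &[1.0, 2.0]);
    ///
    /// // Gradient path: concatenation keeps the gradtrack graph connected
    /// let g = Tensor::from_slice(&[1.0, 2.0], vec![2])
    ///     .unwrap()
    ///     .with_requires_grad();
    /// let out: Tensor = g.iter().collect();
    /// assert!(out.requires_grad());
    /// ```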
    fn collect_scalar_views(elements: Vec<Tensor>) -> Self {
        if elements.is_empty() {
            return Tensor::new(vec![0]);
        }
        // Fast path: if no element requires grad or gradients are disabled, copy directly
        let any_requires = elements.iter().any(|t| t.requires_grad());
        if !any_requires || !is_grad_enabled() {
            let n = elements.len();
            let mut out = Tensor::new_uninitialized(vec![n]);
            unsafe {
                let dst = out.as_mut_ptr();
                for (i, t) in elements.iter().enumerate() {
                    debug_assert_eq!(t.size(), 1);
                    std::ptr::copy_nonoverlapping(t.as_ptr(), dst.add(i), 1);
                }
            }
            return out;
        }

        // Grad-preserving path: concat along dim 0 then flatten
        let mut prepped: Vec<Tensor> = Vec::with_capacity(elements.len());
        for t in elements.into_iter() {
            if t.shape().rank() == 0 {
                prepped.push(t.unsqueeze(0)); // [] -> [1]
            } else {
                prepped.push(t);
            }
        }
        let concatenated = Tensor::cat(&prepped, 0); // shape: [N]
        concatenated.flatten() // shape: [N]
    }

    /// General collection for mixed element shapes
    ///
    /// This method handles collection when elements have different shapes,
    /// flattening all elements into a 1D tensor.
    ///
    /// This is the general path for collection when elements have varying shapes.
    /// It flattens all elements into a single 1D tensor and preserves gradient
    /// tracking when any input element requires gradients.
    ///
    /// # Arguments
    ///
    /// * `elements` - Vector of element tensors with potentially different shapes
    ///
    /// # Returns
    ///
    /// A new 1D tensor containing all flattened element values
    ///
    /// # Performance
    ///
    /// - **Flattening**: Converts all elements to a 1D layout
    /// - **Single Concatenation**: One `cat` pass copies all element data
    /// - **Gradient Preservation**: Maintains gradtrack functionality
    /// - **Mixed Shapes**: Handles elements with different dimensions
    ///
    /// # Implementation Details
    ///
    /// The method concatenates all elements along dimension 0 and flattens the
    /// result into a 1D tensor. Both `cat` and `flatten` register with the
    /// gradtrack engine, so gradient connections are preserved automatically
    /// when any input element requires gradients.
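    ///
    /// # Examples
    ///
    /// Exercised indirectly through `collect()` on non-scalar views; a sketch
    /// using unequal chunk views (sizes 2 and 1), which take this path:
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let t = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
    /// let out: Tensor = t.iter_chunks(2).collect();
    /// assert_eq!(out.data(), &[1.0, 2.0, 3.0]);
    /// ```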
    fn collect_mixed_views(elements: Vec<Tensor>) -> Self {
        // Concatenate then flatten to preserve gradient connections.
        // Gradient flags and graph registration are handled by the `cat` and
        // `flatten` ops themselves, so no extra setup is needed here.
        let concatenated = Tensor::cat(&elements, 0);
        concatenated.flatten()
    }

    // Iterator entry points are implemented in iterator/element.rs
}

// Redundant iterator type and collection trait/impls have been moved to dedicated files.

#[cfg(test)]
mod tests {
    //! Comprehensive tests for tensor element iterator functionality
    //!
    //! These tests cover all aspects of the iterator implementation:
    //! - Basic iteration functionality
    //! - Standard library trait compliance
    //! - Gradient tracking through element operations
    //! - Performance characteristics
    //! - Edge cases and error conditions

    use super::*;

    /// Test basic iterator functionality
    #[test]
    fn test_basic_iteration() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();

        let elements: Vec<Tensor> = tensor.iter_elements().collect();
        assert_eq!(elements.len(), 4);

        // Check that each element is a scalar tensor with correct value
        for (i, elem) in elements.iter().enumerate() {
            assert_eq!(elem.shape().dims(), vec![1]);
            assert_eq!(elem.size(), 1);
            assert_eq!(elem.value(), (i + 1) as f32);
        }
    }

    /// Test Iterator trait methods
    #[test]
    fn test_iterator_trait_methods() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
        let mut iter = tensor.iter();

        // Test next()
        let first = iter.next().unwrap();
        assert_eq!(first.value(), 1.0);

        // Test size_hint() after consuming one element
        assert_eq!(iter.size_hint(), (4, Some(4)));

        // Test count()
        assert_eq!(iter.count(), 4);

        // Test nth()
        let mut iter = tensor.iter();
        let third = iter.nth(2).unwrap();
        assert_eq!(third.value(), 3.0);

        // Test next_back() to fetch the last element
        let mut iter = tensor.iter();
        let last = iter.next_back().unwrap();
        assert_eq!(last.value(), 5.0);
    }

    /// Test ExactSizeIterator
    #[test]
    fn test_exact_size_iterator() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let iter = tensor.iter();

        assert_eq!(iter.len(), 3);

        // Test that len() decreases as we consume the iterator
        let mut iter = tensor.iter();
        assert_eq!(iter.len(), 3);
        iter.next();
        assert_eq!(iter.len(), 2);
        iter.next();
        assert_eq!(iter.len(), 1);
        iter.next();
        assert_eq!(iter.len(), 0);
    }

    /// Test DoubleEndedIterator
    #[test]
    fn test_double_ended_iterator() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let mut iter = tensor.iter();

        // Test next_back()
        let last = iter.next_back().unwrap();
        assert_eq!(last.value(), 4.0);

        let first = iter.next().unwrap();
        assert_eq!(first.value(), 1.0);

        // Test nth_back()
        let mut iter = tensor.iter();
        let second_to_last = iter.nth_back(1).unwrap();
        assert_eq!(second_to_last.value(), 3.0);

        // Test consuming from both ends
        let mut iter = tensor.iter();
        assert_eq!(iter.next().unwrap().value(), 1.0);
        assert_eq!(iter.next_back().unwrap().value(), 4.0);
        assert_eq!(iter.next().unwrap().value(), 2.0);
        assert_eq!(iter.next_back().unwrap().value(), 3.0);
        assert!(iter.next().is_none());
    }

    /// Test IntoIterator trait
    #[test]
    fn test_into_iterator() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();

        // Test with for loop
        let mut values = Vec::new();
        for element in &tensor {
            values.push(element.value());
        }
        assert_eq!(values, vec![1.0, 2.0, 3.0]);

        // Test with into_iter() explicitly
        let values: Vec<f32> = (&tensor).into_iter().map(|elem| elem.value()).collect();
        assert_eq!(values, vec![1.0, 2.0, 3.0]);
    }

    /// Test FromIterator trait (collect)
    #[test]
    fn test_from_iterator() {
        let original = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();

        // Test collecting back to tensor
        let collected: Tensor = original.iter().collect();
        assert_eq!(collected.shape().dims(), vec![4]);
        assert_eq!(collected.data(), original.data());

        // Test collecting with transformations
        let doubled: Tensor = original
            .iter()
            .map(|elem| {
                let val = elem.value();
                Tensor::from_slice(&[val * 2.0], vec![1]).unwrap()
            })
            .collect();

        assert_eq!(doubled.data(), &[2.0, 4.0, 6.0, 8.0]);
    }

    /// Test standard library iterator methods
    #[test]
    fn test_std_iterator_methods() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();

        // Test map
        let doubled: Vec<f32> = tensor.iter().map(|elem| elem.value() * 2.0).collect();
        assert_eq!(doubled, vec![2.0, 4.0, 6.0, 8.0, 10.0]);

        // Test filter
        let large_values: Vec<f32> = tensor
            .iter()
            .filter(|elem| elem.value() > 3.0)
            .map(|elem| elem.value())
            .collect();
        assert_eq!(large_values, vec![4.0, 5.0]);

        // Test enumerate
        let with_indices: Vec<(usize, f32)> = tensor
            .iter()
            .enumerate()
            .map(|(i, elem)| (i, elem.value()))
            .collect();
        assert_eq!(
            with_indices,
            vec![(0, 1.0), (1, 2.0), (2, 3.0), (3, 4.0), (4, 5.0)]
        );

        // Test fold
        let sum: f32 = tensor.iter().fold(0.0, |acc, elem| acc + elem.value());
        assert_eq!(sum, 15.0);

        // Test find
        let found = tensor.iter().find(|elem| elem.value() == 3.0);
        assert!(found.is_some());
        assert_eq!(found.unwrap().value(), 3.0);

        // Test any/all
        let all_positive = tensor.iter().all(|elem| elem.value() > 0.0);
        assert!(all_positive);

        let any_large = tensor.iter().any(|elem| elem.value() > 4.0);
        assert!(any_large);
    }

    /// Test element operations with tensor methods
    #[test]
    fn test_element_tensor_operations() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();

        // Test scalar operations on elements
        let scaled: Tensor = tensor.iter().map(|elem| elem.mul_scalar(2.0)).collect();
        assert_eq!(scaled.data(), &[2.0, 4.0, 6.0]);

        let offset: Tensor = tensor.iter().map(|elem| elem.add_scalar(10.0)).collect();
        assert_eq!(offset.data(), &[11.0, 12.0, 13.0]);

        // Test chaining operations
        let complex: Tensor = tensor
            .iter()
            .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0)) // 2x + 1
            .collect();
        assert_eq!(complex.data(), &[3.0, 5.0, 7.0]);
    }

    /// Test gradient tracking through element operations
    #[test]
    fn test_gradient_tracking() {
        let tensor = Tensor::from_slice(&[1.0, 2.0], vec![2])
            .unwrap()
            .with_requires_grad();

        // Perform element-wise operations
        let result: Tensor = tensor.iter().map(|elem| elem.mul_scalar(2.0)).collect();

        // The result should require gradients if any element requires gradients
        // Note: Current implementation creates copies, so gradient tracking is
        // implemented but may not propagate back to original tensor
        assert!(result.requires_grad());

        // For now, just verify the forward pass works with gradient-enabled tensors
        // Full gradient propagation would require true view implementation
        assert_eq!(result.data(), &[2.0, 4.0]);
    }

    /// Test with zero-sized tensors
    #[test]
    fn test_zero_sized_tensor() {
        let empty = Tensor::new(vec![0]);
        let iter = empty.iter();

        assert_eq!(iter.len(), 0);
        assert_eq!(iter.size_hint(), (0, Some(0)));

        let collected: Tensor = iter.collect();
        assert_eq!(collected.size(), 0);
    }

    /// Test range iteration
    #[test]
    fn test_range_iteration() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();

        // Test middle range
        let middle: Vec<f32> = tensor.iter_range(1, 4).map(|elem| elem.value()).collect();
        assert_eq!(middle, vec![2.0, 3.0, 4.0]);

        // Test out of bounds (should be clamped)
        let clamped: Vec<f32> = tensor.iter_range(3, 10).map(|elem| elem.value()).collect();
        assert_eq!(clamped, vec![4.0, 5.0]);

        // Test empty range
        let empty: Vec<f32> = tensor.iter_range(2, 2).map(|elem| elem.value()).collect();
        assert_eq!(empty, Vec::<f32>::new());
    }

    /// Test complex iterator chains
    #[test]
    fn test_complex_chains() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();

        // Complex chain: enumerate -> filter -> map -> collect
        let result: Tensor = tensor
            .iter()
            .enumerate()
            .filter(|(i, _)| i % 2 == 0) // Take even indices
            .map(|(i, elem)| elem.add_scalar(i as f32)) // Add index to value
            .collect();

        // Should have elements [1.0 + 0, 3.0 + 2, 5.0 + 4] = [1.0, 5.0, 9.0]
        assert_eq!(result.data(), &[1.0, 5.0, 9.0]);

        // Test with rev()
        let reversed: Tensor = tensor.iter().rev().take(3).collect();

        assert_eq!(reversed.data(), &[6.0, 5.0, 4.0]);
    }

    /// Performance test for iterator overhead
    #[test]
    fn test_performance() {
        let large_tensor =
            Tensor::from_slice(&(0..1000).map(|i| i as f32).collect::<Vec<_>>(), vec![1000])
                .unwrap();

        let start = std::time::Instant::now();

        let result: Tensor = large_tensor
            .iter()
            .map(|elem| elem.mul_scalar(2.0))
            .collect();

        let duration = start.elapsed();
        println!("Iterator performance test took: {:?}", duration);

        // Verify correctness
        assert_eq!(result.size(), 1000);
        assert_eq!(result.data()[0], 0.0);
        assert_eq!(result.data()[999], 1998.0);
    }

    /// Test chunks iterator basic behavior
    #[test]
    fn test_chunks_basic() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
        let chunks: Vec<Tensor> = t.iter_chunks(2).collect();
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].data(), &[1.0, 2.0]);
        assert_eq!(chunks[1].data(), &[3.0, 4.0]);
        assert_eq!(chunks[2].data(), &[5.0]);
    }

    /// Test chunks_exact with remainder
    #[test]
    fn test_chunks_exact_with_remainder() {
        let t = Tensor::from_slice(&[10.0, 20.0, 30.0, 40.0, 50.0], vec![5]).unwrap();
        let mut it = t.iter_chunks_exact(2);
        let v0 = it.next().unwrap();
        let v1 = it.next().unwrap();
        assert!(it.next().is_none());
        assert_eq!(v0.data(), &[10.0, 20.0]);
        assert_eq!(v1.data(), &[30.0, 40.0]);
        let r = it.remainder();
        assert_eq!(r.data(), &[50.0]);
    }

    /// Test windows iterator with step 1
    #[test]
    fn test_windows_basic() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let wins: Vec<Tensor> = t.iter_windows(3).collect();
        assert_eq!(wins.len(), 2);
        assert_eq!(wins[0].data(), &[1.0, 2.0, 3.0]);
        assert_eq!(wins[1].data(), &[2.0, 3.0, 4.0]);
    }

    /// Test windows iterator with custom step
    #[test]
    fn test_windows_step() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
        let wins: Vec<Tensor> = t.iter_windows_step(2, 2).collect();
        assert_eq!(wins.len(), 2);
        assert_eq!(wins[0].data(), &[1.0, 2.0]);
        assert_eq!(wins[1].data(), &[3.0, 4.0]);
    }

    /// Test collect_shape utility
    #[test]
    fn test_collect_shape() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
        let mat = t.iter_chunks(2).collect_shape(vec![3, 2]);
        assert_eq!(mat.shape().dims(), &[3, 2]);
        assert_eq!(mat.data(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    /// Performance comparison: Tensor iterator/view system vs Vec iteration
    ///
    /// This test compares end-to-end pipelines (creation → iteration → ops → collection)
    /// across multiple sizes and loop styles, and prints a concise summary.
    #[test]
    fn test_iterator_vs_vec_performance_summary() {
        use std::time::Instant;

        let sizes: [usize; 3] = [100, 1000, 10_000];
        let iterations: usize = 3; // exclude 1 warmup
        let chunk_size: usize = 8192;

        println!(
            "Iterator/View vs Vec performance ({} runs avg, chunk_size={})",
            iterations, chunk_size
        );

        for &n in &sizes {
            // -------- Element-wise iterator pipeline (Tensor) --------
            let mut total_elem_tensor = std::time::Duration::ZERO;
            for run in 0..(iterations + 1) {
                let t0 = Instant::now();
                let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
                let t = Tensor::from_slice(&data, vec![n]).unwrap();
                let out: Tensor = t
                    .iter_elements()
                    .map(|e| e.mul_scalar(2.0).add_scalar(1.0))
                    .collect();
                // Touch a value to avoid any dead-code elimination concerns
                let _ = out.get(&[0]);
                let dt = t0.elapsed();
                if run > 0 {
                    total_elem_tensor += dt;
                }
            }
            let avg_elem_tensor = total_elem_tensor / iterations as u32;

            // -------- Element-wise iterator pipeline (Vec) --------
            let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
            let mut total_elem_vec = std::time::Duration::ZERO;
            for run in 0..(iterations + 1) {
                let t0 = Instant::now();
                let _v_out: Vec<f32> = data.iter().map(|&x| 2.0 * x + 1.0).collect();
                let dt = t0.elapsed();
                if run > 0 {
                    total_elem_vec += dt;
                }
            }
            let avg_elem_vec = total_elem_vec / iterations as u32;

            // -------- Chunked iterator pipeline (Tensor) --------
            let mut total_chunks_tensor = std::time::Duration::ZERO;
            for run in 0..(iterations + 1) {
                let t0 = Instant::now();
                let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
                let t = Tensor::from_slice(&data, vec![n]).unwrap();
                let parts: Vec<Tensor> = t
                    .iter_chunks(chunk_size)
                    .map(|c| c.mul_scalar(2.0).add_scalar(1.0))
                    .collect();
                let out = Tensor::cat(&parts, 0);
                let _ = out.get(&[out.size().saturating_sub(1)]);
                let dt = t0.elapsed();
                if run > 0 {
                    total_chunks_tensor += dt;
                }
            }
            let avg_chunks_tensor = total_chunks_tensor / iterations as u32;

            // -------- Chunked iterator pipeline (Vec) --------
            let mut total_chunks_vec = std::time::Duration::ZERO;
            for run in 0..(iterations + 1) {
                let t0 = Instant::now();
                let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
                let mut out: Vec<f32> = Vec::with_capacity(n);
                for chunk in data.chunks(chunk_size) {
                    for &x in chunk.iter() {
                        out.push(2.0 * x + 1.0);
                    }
                }
                let _ = out.get(out.len().saturating_sub(1)).copied().unwrap_or(0.0);
                let dt = t0.elapsed();
                if run > 0 {
                    total_chunks_vec += dt;
                }
            }
            let avg_chunks_vec = total_chunks_vec / iterations as u32;

            // -------- Auto-tuned fast chunks (Tensor) --------
            let mut total_fast_chunks_tensor = std::time::Duration::ZERO;
            for run in 0..(iterations + 1) {
                let t0 = Instant::now();
                let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
                let t = Tensor::from_slice(&data, vec![n]).unwrap();
                let parts: Vec<Tensor> = t
                    .iter_fast_chunks()
                    .map(|c| c.mul_scalar(2.0).add_scalar(1.0))
                    .collect();
                let out = Tensor::cat(&parts, 0);
                let _ = out.get(&[out.size().saturating_sub(1)]);
                let dt = t0.elapsed();
                if run > 0 {
                    total_fast_chunks_tensor += dt;
                }
            }
            let avg_fast_chunks_tensor = total_fast_chunks_tensor / iterations as u32;

            // -------- Value iterator (Tensor) --------
            let mut total_values_tensor = std::time::Duration::ZERO;
            let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
            let t = Tensor::from_slice(&data, vec![n]).unwrap();

            for run in 0..(iterations + 1) {
                let t0 = Instant::now();
                let _v_out: Tensor = t.iter_values().map(|x| 2.0 * x + 1.0).collect();
                let dt = t0.elapsed();
                if run > 0 {
                    total_values_tensor += dt;
                }
            }
            let avg_values_tensor = total_values_tensor / iterations as u32;

            // -------- Mutable value iterator (Tensor) --------
            let mut total_values_mut_tensor = std::time::Duration::ZERO;
            for run in 0..(iterations + 1) {
                let t0 = Instant::now();
                let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
                let mut out = Tensor::from_slice(&data, vec![n]).unwrap();
                for v in out.iter_values_mut() {
                    *v = 2.0 * *v + 1.0;
                }
                let _ = out.get(&[out.size().saturating_sub(1)]);
                let dt = t0.elapsed();
                if run > 0 {
                    total_values_mut_tensor += dt;
                }
            }
            let avg_values_mut_tensor = total_values_mut_tensor / iterations as u32;

            // -------- Summary per size --------
            let s_elem = avg_elem_vec.as_secs_f64() / avg_elem_tensor.as_secs_f64();
            let s_chunks = avg_chunks_vec.as_secs_f64() / avg_chunks_tensor.as_secs_f64();
            let s_fast_chunks = avg_chunks_vec.as_secs_f64() / avg_fast_chunks_tensor.as_secs_f64();
            let s_values = avg_elem_vec.as_secs_f64() / avg_values_tensor.as_secs_f64();
            let s_values_mut = avg_elem_vec.as_secs_f64() / avg_values_mut_tensor.as_secs_f64();

            println!(
                "\n[Size: {:>9} elements]\n  - Tensor (element): {:>8.3} ms\n  - Vec   (element): {:>8.3} ms\n    Speedup (Tensor/Vec): {:>6.2}x\n  - Tensor (chunks):  {:>8.3} ms\n  - Vec   (chunks):  {:>8.3} ms\n    Speedup (Tensor/Vec): {:>6.2}x\n  - Tensor (fast_chunks): {:>8.3} ms\n    Speedup (fast_chunks vs Vec chunks): {:>6.2}x\n  - Tensor (values):  {:>8.3} ms\n    Speedup (values vs Vec element): {:>6.2}x\n  - Tensor (values_mut):  {:>8.3} ms\n    Speedup (values_mut vs Vec element): {:>6.2}x",
                n,
                avg_elem_tensor.as_secs_f64() * 1e3,
                avg_elem_vec.as_secs_f64() * 1e3,
                s_elem,
                avg_chunks_tensor.as_secs_f64() * 1e3,
                avg_chunks_vec.as_secs_f64() * 1e3,
                s_chunks,
                avg_fast_chunks_tensor.as_secs_f64() * 1e3,
                s_fast_chunks,
                avg_values_tensor.as_secs_f64() * 1e3,
                s_values,
                avg_values_mut_tensor.as_secs_f64() * 1e3,
                s_values_mut,
            );
        }

        println!("\nNote: timings include creation, iteration, ops (2x+1), and collection.");
    }

    /// Test iter_values over contiguous tensors
    #[test]
    fn test_iter_values_contiguous() {
        let t =
            Tensor::from_slice(&(0..16).map(|i| i as f32).collect::<Vec<_>>(), vec![16]).unwrap();
        let vals: Vec<f32> = t.iter_values().collect();
        assert_eq!(vals, (0..16).map(|i| i as f32).collect::<Vec<_>>());
    }

    /// Test iter_values_mut requires contiguous and mutates in place
    #[test]
    fn test_iter_values_mut_contiguous() {
        let mut t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        for v in t.iter_values_mut() {
            *v += 1.0;
        }
        assert_eq!(t.data(), &[2.0, 3.0, 4.0, 5.0]);
    }

    /// Test iter_fast_chunks heuristic produces reasonable chunking
    #[test]
    fn test_iter_fast_chunks_basic() {
        let t = Tensor::from_slice(
            &(0..100_000).map(|i| i as f32).collect::<Vec<_>>(),
            vec![100_000],
        )
        .unwrap();
        let mut total = 0usize;
        for c in t.iter_fast_chunks() {
            total += c.size();
        }
        assert_eq!(total, 100_000);
    }
}
1029}