train_station/tensor/iterator/
mod.rs

1//! Iterator module for tensor iteration
2//!
3//! This module provides high-performance iterators over tensor elements, where each
4//! element is represented as a view tensor of shape `[1]`. This design allows for
5//! seamless integration with Rust's standard library iterator methods while
6//! leveraging the existing tensor operation framework and gradient tracking.
7//!
8//! Implicit performance routing: iterator constructors decide at creation time whether
9//! to use a no-grad fast path (borrowed contiguous traversal or one-time materialization)
10//! or a grad-preserving view path based on `requires_grad` and `gradtrack::is_grad_enabled()`.
11//!
12//! # Key Features
13//!
14//! - **Standard Library Compatibility**: Full implementation of Iterator, ExactSizeIterator,
15//!   DoubleEndedIterator, FusedIterator, IntoIterator, and FromIterator traits
16//! - **Gradient Tracking**: Automatic gradient propagation through element operations
17//! - **Performance Optimized**: True zero-copy views with shared memory
18//! - **SIMD Compatible**: All operations use existing optimized tensor implementations
19//! - **Memory Efficient**: Adaptive view creation based on tensor size
20//! - **Zero-Copy Operations**: Element views share memory with source tensor
21//! - **Full Tensor Operations**: Each element supports all tensor methods
22//!
23//! # Performance Characteristics
24//!
25//! - **View Creation**: O(1) per element with true zero-copy views
26//! - **Memory Overhead**: ~64 bytes per view tensor (no data copying)
27//! - **SIMD Operations**: Full utilization of existing optimizations
28//! - **Gradient Tracking**: True gradient flow with element-level accumulation
29//! - **Iterator Overhead**: Minimal performance impact for element access
30//! - **Collection Optimization**: Efficient reconstruction from element views
31//!
32//! # Examples
33//!
34//! ## Basic Element Iteration (1D tensors)
35//!
36//! ```
37//! use train_station::Tensor;
38//!
39//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
40//!
41//! // Basic iteration over elements
42//! // For 1D tensors, `iter()` yields scalar views
43//! for element in tensor.iter() {
44//!     println!("Element value: {}", element.value());
45//! }
46//!
47//! // Collect elements into a new tensor
48//! let collected: Tensor = tensor.iter().collect();
49//! assert_eq!(collected.data(), tensor.data());
50//! ```
51//!
52//! ## Element-Wise Transformations (1D convenience)
53//!
54//! ```
55//! use train_station::Tensor;
56//!
57//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
58//!
59//! // Apply tensor operations to each element
60//! let doubled: Tensor = tensor.iter()
61//!     .map(|elem| elem.mul_scalar(2.0))
62//!     .collect();
63//!
64//! assert_eq!(doubled.data(), &[2.0, 4.0, 6.0]);
65//!
66//! // Chain multiple operations
67//! let transformed: Tensor = tensor.iter()
68//!     .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0)) // 2x + 1
69//!     .collect();
70//!
71//! assert_eq!(transformed.data(), &[3.0, 5.0, 7.0]);
72//! ```
73//!
74//! ## Advanced Iterator Operations (1D convenience)
75//!
76//! ```
77//! use train_station::Tensor;
78//!
79//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
80//!
81//! // Filter elements based on values
82//! let large_values: Tensor = tensor.iter()
83//!     .filter(|elem| elem.value() > 3.0)
84//!     .collect();
85//!
86//! assert_eq!(large_values.data(), &[4.0, 5.0]);
87//!
88//! // Use enumerate for indexed operations
89//! let indexed: Tensor = tensor.iter()
90//!     .enumerate()
91//!     .map(|(i, elem)| elem.add_scalar(i as f32))
92//!     .collect();
93//!
94//! assert_eq!(indexed.data(), &[1.0, 3.0, 5.0, 7.0, 9.0]);
95//! ```
96//!
97//! ## Range Iteration (1D convenience)
98//!
99//! ```
100//! use train_station::Tensor;
101//!
102//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
103//!
104//! // Iterate over a specific range
105//! let middle: Tensor = tensor.iter_range(1, 4)
106//!     .map(|elem| elem.mul_scalar(2.0))
107//!     .collect();
108//!
109//! assert_eq!(middle.data(), &[4.0, 6.0, 8.0]);
110//! ```
111//!
112//! ## Double-Ended Iteration (1D convenience)
113//!
114//! ```
115//! use train_station::Tensor;
116//!
117//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
118//!
119//! // Reverse iteration
120//! let reversed: Tensor = tensor.iter().rev().collect();
121//! assert_eq!(reversed.data(), &[4.0, 3.0, 2.0, 1.0]);
122//!
123//! // Iterate from both ends
124//! let mut iter = tensor.iter();
125//! assert_eq!(iter.next().unwrap().value(), 1.0);
126//! assert_eq!(iter.next_back().unwrap().value(), 4.0);
127//! ```
128//!
129//! ## Gradient Tracking
130//!
131//! ```
132//! use train_station::Tensor;
133//!
134//! let tensor = Tensor::from_slice(&[1.0, 2.0], vec![2])
135//!     .unwrap()
136//!     .with_requires_grad();
137//!
138//! // Element operations maintain gradient tracking
139//! let result: Tensor = tensor.iter()
140//!     .map(|elem| elem.mul_scalar(2.0))
141//!     .collect();
142//!
143//! assert!(result.requires_grad());
144//! assert_eq!(result.data(), &[2.0, 4.0]);
145//! ```
146//!
147//! # Design Principles and API Overview
148//!
149//! - Use `iter()` or `outer_iter()` to iterate outermost-dimension sub-tensors (Vec-like semantics)
150//! - Use `iter_dim(dim)` to iterate sub-tensors along an arbitrary dimension
151//! - Use `iter_flat()` to iterate scalar views in row-major order
152//! - Use `chunks()` / `chunks_exact()` for linear chunk views
153//! - Use `windows()` / `windows_step()` for overlapping linear window views
154//!
155//! Deprecated aliases (will be removed pre-1.0): `iter_chunks`, `iter_chunks_exact`,
156//! `iter_windows`, `iter_windows_step`.
157//!
158//! - **Zero-Copy Views**: Element views share memory with source tensor
159//! - **Full Tensor Operations**: Each element supports all tensor methods
160//! - **Standard Library Integration**: Complete compatibility with Rust iterators
161//! - **Performance First**: Optimized for high-performance element access
162//! - **Gradient Preservation**: Maintains gradtrack functionality through operations
163//! - **Memory Efficiency**: Minimal overhead for element iteration
164//! - **Type Safety**: Compile-time guarantees for iterator operations
165
166pub mod chunks;
167pub mod collect;
168pub mod element;
169pub mod viewdim;
170pub mod windows;
171
172use crate::gradtrack::is_grad_enabled;
173use crate::tensor::core::Tensor;
174pub use collect::{TensorCollectExt, ValuesCollectExt};
175use std::iter::FromIterator;
176
177/// High-performance iterator over tensor elements as view tensors
178///
179/// Each element becomes a proper `Tensor` view of shape `[1]` that can use
180/// all existing tensor operations and gradient tracking. Implements all
181/// standard iterator traits for maximum compatibility with Rust's ecosystem.
182///
183/// This iterator provides zero-copy access to tensor elements through view
184/// tensors, enabling efficient element-wise operations while maintaining
185/// full compatibility with Rust's standard library iterator methods.
186///
187/// # Performance
188///
189/// - **Zero-Copy Views**: Each element is a view tensor sharing memory with source
190/// - **O(1) Element Access**: Constant-time view creation for each element
191/// - **Memory Efficient**: ~64 bytes overhead per element view
192/// - **SIMD Compatible**: All tensor operations use existing optimizations
193/// - **Gradient Tracking**: Full gradtrack support through element operations
194///
195/// # Implementation Details
196///
197/// The iterator creates lightweight view tensors on-demand, sharing the same
198/// memory allocation as the source tensor. This ensures zero-copy semantics
199/// while maintaining full tensor operation compatibility.
200///
201/// Each element view is created using `Tensor::element_view()`, which provides
202/// a true view of the underlying data without any copying. The view tensors
203/// support all standard tensor operations including gradient tracking.
204///
205/// # Standard Library Compatibility
206///
207/// This iterator implements all standard iterator traits:
208/// - `Iterator`: Basic iteration with `next()` and `size_hint()`
209/// - `ExactSizeIterator`: Precise size information with `len()`
210/// - `DoubleEndedIterator`: Reverse iteration with `next_back()`
211/// - `FusedIterator`: Fused iteration for better performance
212/// - `IntoIterator`: Automatic conversion for `for` loops
213///
214/// # Examples
215///
216/// ## Basic Iteration
217///
218/// ```
219/// use train_station::Tensor;
220///
221/// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
222///
223/// // Basic iteration
224/// for element in tensor.iter() {
225///     println!("Element value: {}", element.value());
226/// }
227///
228/// // Standard library methods
229/// let sum: f32 = tensor.iter()
230///     .map(|elem| elem.value())
231///     .sum();
232///
233/// assert_eq!(sum, 6.0);
234/// ```
235///
236/// ## Element Operations
237///
238/// ```
239/// use train_station::Tensor;
240///
241/// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
242///
243/// // Tensor operations on elements
244/// let transformed: Tensor = tensor.iter()
245///     .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0)) // 2x + 1
246///     .collect();
247///
248/// assert_eq!(transformed.data(), &[3.0, 5.0, 7.0]);
249/// ```
250///
251/// ## Advanced Iterator Methods
252///
253/// ```
254/// use train_station::Tensor;
255///
256/// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
257///
258/// // Filter and transform
259/// let result: Tensor = tensor.iter()
260///     .filter(|elem| elem.value() > 2.0)
261///     .map(|elem| elem.mul_scalar(10.0))
262///     .collect();
263///
264/// assert_eq!(result.data(), &[30.0, 40.0, 50.0]);
265///
266/// // Reverse iteration
267/// let reversed: Tensor = tensor.iter().rev().collect();
268/// assert_eq!(reversed.data(), &[5.0, 4.0, 3.0, 2.0, 1.0]);
269/// ```
// Note: public re-exports of iterator types are declared above (see the `pub use collect::...` line).
// ===== IntoIterator Implementations =====
272/// IntoIterator for &Tensor now iterates outermost dimension, yielding sub-tensors (views)
273impl<'a> IntoIterator for &'a Tensor {
274    type Item = Tensor;
275    type IntoIter = crate::tensor::iterator::viewdim::TensorDimIterator<'a>;
276
277    fn into_iter(self) -> Self::IntoIter {
278        // Iterate outermost dim by default
279        self.iter_dim(0)
280    }
281}
282
283/// IntoIterator for owned Tensor: iterate outermost dimension producing sub-tensors.
284/// Enables `.into_iter().flatten()` patterns on owned tensors.
285impl IntoIterator for Tensor {
286    type Item = Tensor;
287    type IntoIter = crate::tensor::iterator::viewdim::TensorDimOwnedIterator;
288
289    fn into_iter(self) -> Self::IntoIter {
290        crate::tensor::iterator::viewdim::TensorDimOwnedIterator::new(self, 0)
291    }
292}
293
294// ===== FromIterator Implementation =====
295
296impl FromIterator<Tensor> for Tensor {
297    /// Collect element view tensors back into a single tensor
298    ///
299    /// This method reconstructs a tensor from an iterator of element view tensors.
300    /// It includes optimizations for common patterns and maintains gradient tracking
301    /// when appropriate.
302    ///
303    /// The collection process automatically detects whether all elements are scalar
304    /// views (shape `[1]`) and uses optimized collection strategies accordingly.
305    /// Gradient tracking is preserved when any input element requires gradients.
306    ///
307    /// # Performance
308    ///
309    /// - **Optimized Collection**: Specialized paths for scalar and mixed views
310    /// - **Memory Efficient**: Direct memory copying without intermediate allocations
311    /// - **Gradient Preservation**: Maintains gradtrack functionality when enabled
312    /// - **Shape Detection**: Automatic detection of element shapes for optimization
313    ///
314    /// # Implementation Details
315    ///
316    /// The method performs the following steps:
317    /// 1. **Element Collection**: Gathers all element tensors from the iterator
318    /// 2. **Shape Analysis**: Determines if all elements are scalar views
319    /// 3. **Optimized Path**: Uses specialized collection for scalar views
320    /// 4. **General Path**: Handles mixed shapes by flattening into 1D tensor
321    /// 5. **Gradient Setup**: Preserves gradient tracking when appropriate
322    ///
323    /// # Examples
324    ///
325    /// ## Basic Collection
326    ///
327    /// ```
328    /// use train_station::Tensor;
329    ///
330    /// let original = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
331    /// let doubled: Tensor = original.iter()
332    ///     .map(|elem| elem.mul_scalar(2.0))
333    ///     .collect();
334    ///
335    /// assert_eq!(doubled.data(), &[2.0, 4.0, 6.0]);
336    /// ```
337    ///
338    /// ## Collection with Gradient Tracking
339    ///
340    /// ```
341    /// use train_station::Tensor;
342    ///
343    /// let original = Tensor::from_slice(&[1.0, 2.0], vec![2])
344    ///     .unwrap()
345    ///     .with_requires_grad();
346    ///
347    /// let result: Tensor = original.iter()
348    ///     .map(|elem| elem.mul_scalar(2.0))
349    ///     .collect();
350    ///
351    /// assert!(result.requires_grad());
352    /// assert_eq!(result.data(), &[2.0, 4.0]);
353    /// ```
354    ///
355    /// ## Empty Iterator Handling
356    ///
357    /// ```
358    /// use train_station::Tensor;
359    ///
360    /// let empty: Tensor = Vec::<Tensor>::new().into_iter().collect();
361    /// assert_eq!(empty.size(), 0);
362    /// assert_eq!(empty.shape().dims(), vec![0]);
363    /// ```
364    fn from_iter<I: IntoIterator<Item = Tensor>>(iter: I) -> Self {
365        let elements: Vec<Tensor> = iter.into_iter().collect();
366
367        if elements.is_empty() {
368            return Tensor::new(vec![0]);
369        }
370
371        // Check if all elements are scalars (size == 1). Supports both [1] and 0-D [] shapes
372        let all_scalars = elements.iter().all(|e| e.size() == 1);
373
374        if all_scalars {
375            // Optimized path for scalar element views
376            Self::collect_scalar_views(elements)
377        } else {
378            // General path for mixed shapes
379            Self::collect_mixed_views(elements)
380        }
381    }
382}
383
384impl Tensor {
385    /// Optimized collection for scalar element views
386    ///
387    /// This method efficiently reconstructs a tensor from scalar element views,
388    /// preserving gradient tracking and using optimized memory operations.
389    ///
390    /// This is the fast path for collection when all elements are scalar views
391    /// (shape `[1]`). It performs direct memory copying and sets up gradient
392    /// tracking when any input element requires gradients.
393    ///
394    /// # Arguments
395    ///
396    /// * `elements` - Vector of scalar element view tensors
397    ///
398    /// # Returns
399    ///
400    /// A new tensor containing all element values in a 1D layout
401    ///
402    /// # Performance
403    ///
404    /// - **Direct Memory Copy**: Single-pass copying without intermediate allocations
405    /// - **Gradient Optimization**: Efficient gradient tracking setup
406    /// - **Memory Efficient**: Minimal overhead for collection process
407    /// - **SIMD Compatible**: Result tensor supports all optimizations
408    ///
409    /// # Implementation Details
410    ///
411    /// The method performs the following steps:
412    /// 1. **Allocation**: Creates uninitialized tensor with correct size
413    /// 2. **Gradient Check**: Determines if any element requires gradients
414    /// 3. **Memory Copy**: Direct copying from element views to result
415    /// 4. **Gradient Setup**: Configures gradient tracking when needed
416    /// 5. **Operation Registration**: Registers with gradtrack engine
417    fn collect_scalar_views(elements: Vec<Tensor>) -> Self {
418        if elements.is_empty() {
419            return Tensor::new(vec![0]);
420        }
421        // Fast path: if no element requires grad or gradients are disabled, copy directly
422        let any_requires = elements.iter().any(|t| t.requires_grad());
423        if !any_requires || !is_grad_enabled() {
424            let n = elements.len();
425            let mut out = Tensor::new_uninitialized(vec![n]);
426            unsafe {
427                let dst = out.as_mut_ptr();
428                for (i, t) in elements.iter().enumerate() {
429                    debug_assert_eq!(t.size(), 1);
430                    std::ptr::copy_nonoverlapping(t.as_ptr(), dst.add(i), 1);
431                }
432            }
433            return out;
434        }
435
436        // Grad-preserving path: concat along dim 0 then flatten
437        let mut prepped: Vec<Tensor> = Vec::with_capacity(elements.len());
438        for t in elements.into_iter() {
439            if t.shape().rank() == 0 {
440                prepped.push(t.unsqueeze(0)); // [] -> [1]
441            } else {
442                prepped.push(t);
443            }
444        }
445        let concatenated = Tensor::cat(&prepped, 0); // shape: [N, 1]
446        concatenated.flatten() // shape: [N]
447    }
448
449    /// General collection for mixed element shapes
450    ///
451    /// This method handles collection when elements have different shapes,
452    /// flattening all elements into a 1D tensor.
453    ///
454    /// This is the general path for collection when elements have varying shapes.
455    /// It flattens all elements into a single 1D tensor and preserves gradient
456    /// tracking when any input element requires gradients.
457    ///
458    /// # Arguments
459    ///
460    /// * `elements` - Vector of element tensors with potentially different shapes
461    ///
462    /// # Returns
463    ///
464    /// A new 1D tensor containing all flattened element values
465    ///
466    /// # Performance
467    ///
468    /// - **Flattening**: Converts all elements to 1D layout
469    /// - **Memory Copy**: Efficient copying with size calculation
470    /// - **Gradient Preservation**: Maintains gradtrack functionality
471    /// - **Mixed Shapes**: Handles elements with different dimensions
472    ///
473    /// # Implementation Details
474    ///
475    /// The method performs the following steps:
476    /// 1. **Size Calculation**: Sums sizes of all elements for total size
477    /// 2. **Allocation**: Creates uninitialized tensor with total size
478    /// 3. **Sequential Copy**: Copies each element's data sequentially
479    /// 4. **Gradient Setup**: Configures gradient tracking when needed
480    /// 5. **Operation Registration**: Registers with gradtrack engine
481    fn collect_mixed_views(elements: Vec<Tensor>) -> Self {
482        let requires_grad = elements.iter().any(|e| e.requires_grad());
483        // Concatenate then flatten to preserve gradient connections
484        let concatenated = Tensor::cat(&elements, 0);
485        let flattened = concatenated.flatten();
486        if requires_grad && is_grad_enabled() {
487            // Flags are handled by ops; return as-is
488        }
489        flattened
490    }
491
492    // Iterator entry points are implemented in iterator/element.rs
493}
494
495// Redundant iterator type and collection trait/impls have been moved to dedicated files.
496
497#[cfg(test)]
498mod tests {
499    //! Comprehensive tests for tensor element iterator functionality
500    //!
501    //! These tests cover all aspects of the iterator implementation:
502    //! - Basic iteration functionality
503    //! - Standard library trait compliance
504    //! - Gradient tracking through element operations
505    //! - Performance characteristics
506    //! - Edge cases and error conditions
507
508    use super::*;
509
510    /// Test basic iterator functionality
511    #[test]
512    fn test_basic_iteration() {
513        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
514
515        let elements: Vec<Tensor> = tensor.iter_elements().collect();
516        assert_eq!(elements.len(), 4);
517
518        // Check that each element is a scalar tensor with correct value
519        for (i, elem) in elements.iter().enumerate() {
520            assert_eq!(elem.shape().dims(), vec![1]);
521            assert_eq!(elem.size(), 1);
522            assert_eq!(elem.value(), (i + 1) as f32);
523        }
524    }
525
526    /// Test Iterator trait methods
527    #[test]
528    fn test_iterator_trait_methods() {
529        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
530        let mut iter = tensor.iter();
531
532        // Test next()
533        let _first = iter.next().unwrap();
534        assert_eq!(_first.value(), 1.0);
535
536        // Test size_hint() after consuming one element
537        assert_eq!(iter.size_hint(), (4, Some(4)));
538
539        // Test count()
540        assert_eq!(iter.count(), 4);
541
542        // Test nth()
543        let mut iter = tensor.iter();
544        let third = iter.nth(2).unwrap();
545        assert_eq!(third.value(), 3.0);
546
547        // Test last()
548        let mut iter = tensor.iter();
549        let last = iter.next_back().unwrap();
550        assert_eq!(last.value(), 5.0);
551    }
552
553    /// Test ExactSizeIterator
554    #[test]
555    fn test_exact_size_iterator() {
556        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
557        let iter = tensor.iter();
558
559        assert_eq!(iter.len(), 3);
560
561        // Test that len() decreases as we consume the iterator
562        let mut iter = tensor.iter();
563        assert_eq!(iter.len(), 3);
564        iter.next();
565        assert_eq!(iter.len(), 2);
566        iter.next();
567        assert_eq!(iter.len(), 1);
568        iter.next();
569        assert_eq!(iter.len(), 0);
570    }
571
572    /// Test DoubleEndedIterator
573    #[test]
574    fn test_double_ended_iterator() {
575        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
576        let mut iter = tensor.iter();
577
578        // Test next_back()
579        let last = iter.next_back().unwrap();
580        assert_eq!(last.value(), 4.0);
581
582        let first = iter.next().unwrap();
583        assert_eq!(first.value(), 1.0);
584
585        // Test nth_back()
586        let mut iter = tensor.iter();
587        let second_to_last = iter.nth_back(1).unwrap();
588        assert_eq!(second_to_last.value(), 3.0);
589
590        // Test consuming from both ends
591        let mut iter = tensor.iter();
592        assert_eq!(iter.next().unwrap().value(), 1.0);
593        assert_eq!(iter.next_back().unwrap().value(), 4.0);
594        assert_eq!(iter.next().unwrap().value(), 2.0);
595        assert_eq!(iter.next_back().unwrap().value(), 3.0);
596        assert!(iter.next().is_none());
597    }
598
599    /// Test IntoIterator trait
600    #[test]
601    fn test_into_iterator() {
602        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
603
604        // Test with for loop
605        let mut values = Vec::new();
606        for element in &tensor {
607            values.push(element.value());
608        }
609        assert_eq!(values, vec![1.0, 2.0, 3.0]);
610
611        // Test with into_iter() explicitly
612        let values: Vec<f32> = (&tensor).into_iter().map(|elem| elem.value()).collect();
613        assert_eq!(values, vec![1.0, 2.0, 3.0]);
614    }
615
616    /// Test FromIterator trait (collect)
617    #[test]
618    fn test_from_iterator() {
619        let original = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
620
621        // Test collecting back to tensor
622        let collected: Tensor = original.iter().collect();
623        assert_eq!(collected.shape().dims(), vec![4]);
624        assert_eq!(collected.data(), original.data());
625
626        // Test collecting with transformations
627        let doubled: Tensor = original
628            .iter()
629            .map(|elem| {
630                let val = elem.value();
631                Tensor::from_slice(&[val * 2.0], vec![1]).unwrap()
632            })
633            .collect();
634
635        assert_eq!(doubled.data(), &[2.0, 4.0, 6.0, 8.0]);
636    }
637
638    /// Test standard library iterator methods
639    #[test]
640    fn test_std_iterator_methods() {
641        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
642
643        // Test map
644        let doubled: Vec<f32> = tensor.iter().map(|elem| elem.value() * 2.0).collect();
645        assert_eq!(doubled, vec![2.0, 4.0, 6.0, 8.0, 10.0]);
646
647        // Test filter
648        let large_values: Vec<f32> = tensor
649            .iter()
650            .filter(|elem| elem.value() > 3.0)
651            .map(|elem| elem.value())
652            .collect();
653        assert_eq!(large_values, vec![4.0, 5.0]);
654
655        // Test enumerate
656        let with_indices: Vec<(usize, f32)> = tensor
657            .iter()
658            .enumerate()
659            .map(|(i, elem)| (i, elem.value()))
660            .collect();
661        assert_eq!(
662            with_indices,
663            vec![(0, 1.0), (1, 2.0), (2, 3.0), (3, 4.0), (4, 5.0)]
664        );
665
666        // Test fold
667        let sum: f32 = tensor.iter().fold(0.0, |acc, elem| acc + elem.value());
668        assert_eq!(sum, 15.0);
669
670        // Test find
671        let found = tensor.iter().find(|elem| elem.value() == 3.0);
672        assert!(found.is_some());
673        assert_eq!(found.unwrap().value(), 3.0);
674
675        // Test any/all
676        let all_positive = tensor.iter().all(|elem| elem.value() > 0.0);
677        assert!(all_positive);
678
679        let any_large = tensor.iter().any(|elem| elem.value() > 4.0);
680        assert!(any_large);
681    }
682
683    /// Test element operations with tensor methods
684    #[test]
685    fn test_element_tensor_operations() {
686        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
687
688        // Test scalar operations on elements
689        let scaled: Tensor = tensor.iter().map(|elem| elem.mul_scalar(2.0)).collect();
690        assert_eq!(scaled.data(), &[2.0, 4.0, 6.0]);
691
692        let offset: Tensor = tensor.iter().map(|elem| elem.add_scalar(10.0)).collect();
693        assert_eq!(offset.data(), &[11.0, 12.0, 13.0]);
694
695        // Test chaining operations
696        let complex: Tensor = tensor
697            .iter()
698            .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0)) // 2x + 1
699            .collect();
700        assert_eq!(complex.data(), &[3.0, 5.0, 7.0]);
701    }
702
703    /// Test gradient tracking through element operations
704    #[test]
705    fn test_gradient_tracking() {
706        let tensor = Tensor::from_slice(&[1.0, 2.0], vec![2])
707            .unwrap()
708            .with_requires_grad();
709
710        // Perform element-wise operations
711        let result: Tensor = tensor.iter().map(|elem| elem.mul_scalar(2.0)).collect();
712
713        // The result should require gradients if any element requires gradients
714        // Note: Current implementation creates copies, so gradient tracking is
715        // implemented but may not propagate back to original tensor
716        assert!(result.requires_grad());
717
718        // For now, just verify the forward pass works with gradient-enabled tensors
719        // Full gradient propagation would require true view implementation
720        assert_eq!(result.data(), &[2.0, 4.0]);
721    }
722
723    /// Test with zero-sized tensors
724    #[test]
725    fn test_zero_sized_tensor() {
726        let empty = Tensor::new(vec![0]);
727        let iter = empty.iter();
728
729        assert_eq!(iter.len(), 0);
730        assert_eq!(iter.size_hint(), (0, Some(0)));
731
732        let collected: Tensor = iter.collect();
733        assert_eq!(collected.size(), 0);
734    }
735
736    /// Test range iteration
737    #[test]
738    fn test_range_iteration() {
739        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
740
741        // Test middle range
742        let middle: Vec<f32> = tensor.iter_range(1, 4).map(|elem| elem.value()).collect();
743        assert_eq!(middle, vec![2.0, 3.0, 4.0]);
744
745        // Test out of bounds (should be clamped)
746        let clamped: Vec<f32> = tensor.iter_range(3, 10).map(|elem| elem.value()).collect();
747        assert_eq!(clamped, vec![4.0, 5.0]);
748
749        // Test empty range
750        let empty: Vec<f32> = tensor.iter_range(2, 2).map(|elem| elem.value()).collect();
751        assert_eq!(empty, Vec::<f32>::new());
752    }
753
754    /// Test complex iterator chains
755    #[test]
756    fn test_complex_chains() {
757        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
758
759        // Complex chain: enumerate -> filter -> map -> collect
760        let result: Tensor = tensor
761            .iter()
762            .enumerate()
763            .filter(|(i, _)| i % 2 == 0) // Take even indices
764            .map(|(i, elem)| elem.add_scalar(i as f32)) // Add index to value
765            .collect();
766
767        // Should have elements [1.0 + 0, 3.0 + 2, 5.0 + 4] = [1.0, 5.0, 9.0]
768        assert_eq!(result.data(), &[1.0, 5.0, 9.0]);
769
770        // Test with rev()
771        let reversed: Tensor = tensor.iter().rev().take(3).collect();
772
773        assert_eq!(reversed.data(), &[6.0, 5.0, 4.0]);
774    }
775
776    /// Performance test for iterator overhead
777    #[test]
778    fn test_performance() {
779        let large_tensor =
780            Tensor::from_slice(&(0..1000).map(|i| i as f32).collect::<Vec<_>>(), vec![1000])
781                .unwrap();
782
783        let start = std::time::Instant::now();
784
785        let result: Tensor = large_tensor
786            .iter()
787            .map(|elem| elem.mul_scalar(2.0))
788            .collect();
789
790        let duration = start.elapsed();
791        println!("Iterator performance test took: {:?}", duration);
792
793        // Verify correctness
794        assert_eq!(result.size(), 1000);
795        assert_eq!(result.data()[0], 0.0);
796        assert_eq!(result.data()[999], 1998.0);
797    }
798
799    /// Test chunks iterator basic behavior
800    #[test]
801    fn test_chunks_basic() {
802        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
803        let chunks: Vec<Tensor> = t.chunks(2).collect();
804        assert_eq!(chunks.len(), 3);
805        assert_eq!(chunks[0].data(), &[1.0, 2.0]);
806        assert_eq!(chunks[1].data(), &[3.0, 4.0]);
807        assert_eq!(chunks[2].data(), &[5.0]);
808    }
809
810    /// Test chunks_exact with remainder
811    #[test]
812    fn test_chunks_exact_with_remainder() {
813        let t = Tensor::from_slice(&[10.0, 20.0, 30.0, 40.0, 50.0], vec![5]).unwrap();
814        let mut it = t.chunks_exact(2);
815        let v0 = it.next().unwrap();
816        let v1 = it.next().unwrap();
817        assert!(it.next().is_none());
818        assert_eq!(v0.data(), &[10.0, 20.0]);
819        assert_eq!(v1.data(), &[30.0, 40.0]);
820        let r = it.remainder();
821        assert_eq!(r.data(), &[50.0]);
822    }
823
824    /// Test windows iterator with step 1
825    #[test]
826    fn test_windows_basic() {
827        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
828        let wins: Vec<Tensor> = t.windows(3).collect();
829        assert_eq!(wins.len(), 2);
830        assert_eq!(wins[0].data(), &[1.0, 2.0, 3.0]);
831        assert_eq!(wins[1].data(), &[2.0, 3.0, 4.0]);
832    }
833
834    /// Test windows iterator with custom step
835    #[test]
836    fn test_windows_step() {
837        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
838        let wins: Vec<Tensor> = t.windows_step(2, 2).collect();
839        assert_eq!(wins.len(), 2);
840        assert_eq!(wins[0].data(), &[1.0, 2.0]);
841        assert_eq!(wins[1].data(), &[3.0, 4.0]);
842    }
843
844    /// Test collect_shape utility
845    #[test]
846    fn test_collect_shape() {
847        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
848        let mat = t.chunks(2).collect_shape(vec![3, 2]);
849        assert_eq!(mat.shape().dims(), &[3, 2]);
850        assert_eq!(mat.data(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
851    }
852
    /// Performance comparison: Tensor iterator/view system vs Vec iteration
    ///
    /// This test compares end-to-end pipelines (creation → iteration → ops → collection)
    /// across multiple sizes and loop styles, and prints a concise summary.
    ///
    /// Each pipeline applies the same `2x + 1` transform and is run
    /// `iterations + 1` times, with the first (warmup) run excluded from the
    /// average. The pipelines timed per size are: element-wise tensor
    /// iteration, element-wise `Vec` iteration, chunked tensor iteration,
    /// chunked `Vec` iteration, tensor value extraction, and in-place
    /// mutation via `data_mut`. This test only prints timings; it makes no
    /// timing assertions, so it cannot fail on slow hardware.
    #[test]
    fn test_iterator_vs_vec_performance_summary() {
        use std::time::Instant;

        // Problem sizes to sweep, number of timed runs, and the chunk width
        // used by the chunked pipelines.
        let sizes: [usize; 3] = [100, 1000, 10_000];
        let iterations: usize = 3; // exclude 1 warmup
        let chunk_size: usize = 8192;

        println!(
            "Iterator/View vs Vec performance ({} runs avg, chunk_size={})",
            iterations, chunk_size
        );

        for &n in &sizes {
            // -------- Element-wise iterator pipeline (Tensor) --------
            // Tensor creation is inside the timed region so the comparison
            // covers the full pipeline, not just iteration.
            let mut total_elem_tensor = std::time::Duration::ZERO;
            for run in 0..(iterations + 1) {
                let t0 = Instant::now();
                let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
                let t = Tensor::from_slice(&data, vec![n]).unwrap();
                let out: Tensor = t
                    .iter_elements()
                    .map(|e| e.mul_scalar(2.0).add_scalar(1.0))
                    .collect();
                // Touch a value to avoid any dead-code elimination concerns
                let _ = out.get(&[0]);
                let dt = t0.elapsed();
                // run 0 is the warmup; only later runs count toward the average
                if run > 0 {
                    total_elem_tensor += dt;
                }
            }
            let avg_elem_tensor = total_elem_tensor / iterations as u32;

            // -------- Element-wise iterator pipeline (Vec) --------
            let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
            let mut total_elem_vec = std::time::Duration::ZERO;
            for run in 0..(iterations + 1) {
                let t0 = Instant::now();
                let _v_out: Vec<f32> = data.iter().map(|&x| 2.0 * x + 1.0).collect();
                let dt = t0.elapsed();
                if run > 0 {
                    total_elem_vec += dt;
                }
            }
            let avg_elem_vec = total_elem_vec / iterations as u32;

            // -------- Chunked iterator pipeline (Tensor) --------
            // Chunks are transformed independently and re-joined with `cat`.
            let mut total_chunks_tensor = std::time::Duration::ZERO;
            for run in 0..(iterations + 1) {
                let t0 = Instant::now();
                let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
                let t = Tensor::from_slice(&data, vec![n]).unwrap();
                let parts: Vec<Tensor> = t
                    .chunks(chunk_size)
                    .map(|c| c.mul_scalar(2.0).add_scalar(1.0))
                    .collect();
                let out = Tensor::cat(&parts, 0);
                let _ = out.get(&[out.size().saturating_sub(1)]);
                let dt = t0.elapsed();
                if run > 0 {
                    total_chunks_tensor += dt;
                }
            }
            let avg_chunks_tensor = total_chunks_tensor / iterations as u32;

            // -------- Chunked iterator pipeline (Vec) --------
            let mut total_chunks_vec = std::time::Duration::ZERO;
            for run in 0..(iterations + 1) {
                let t0 = Instant::now();
                let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
                let mut out: Vec<f32> = Vec::with_capacity(n);
                for chunk in data.chunks(chunk_size) {
                    for &x in chunk.iter() {
                        out.push(2.0 * x + 1.0);
                    }
                }
                let _ = out.get(out.len().saturating_sub(1)).copied().unwrap_or(0.0);
                let dt = t0.elapsed();
                if run > 0 {
                    total_chunks_vec += dt;
                }
            }
            let avg_chunks_vec = total_chunks_vec / iterations as u32;

            // -------- Value iterator (Tensor) --------
            // NOTE: unlike the pipelines above, the source tensor is built
            // once outside the timed loop, so this measures iteration only.
            let mut total_values_tensor = std::time::Duration::ZERO;
            let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
            let t = Tensor::from_slice(&data, vec![n]).unwrap();

            for run in 0..(iterations + 1) {
                let t0 = Instant::now();
                let _v_out: Tensor = t.iter().map(|e| 2.0 * e.value() + 1.0).collect();
                let dt = t0.elapsed();
                if run > 0 {
                    total_values_tensor += dt;
                }
            }
            let avg_values_tensor = total_values_tensor / iterations as u32;

            // -------- Mutable value iterator (Tensor) --------
            // In-place update through `data_mut`; creation is timed again here.
            let mut total_values_mut_tensor = std::time::Duration::ZERO;
            for run in 0..(iterations + 1) {
                let t0 = Instant::now();
                let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
                let mut out = Tensor::from_slice(&data, vec![n]).unwrap();
                {
                    let d = out.data_mut();
                    for v in d.iter_mut() {
                        *v = 2.0 * *v + 1.0;
                    }
                }
                let _ = out.get(&[out.size().saturating_sub(1)]);
                let dt = t0.elapsed();
                if run > 0 {
                    total_values_mut_tensor += dt;
                }
            }
            let avg_values_mut_tensor = total_values_mut_tensor / iterations as u32;

            // -------- Summary per size --------
            // Ratios are Vec time divided by Tensor time, so values > 1.0
            // mean the tensor pipeline was faster. The `values` and
            // `values_mut` ratios use the element-wise Vec pipeline as their
            // baseline, matching the printed labels below.
            let s_elem = avg_elem_vec.as_secs_f64() / avg_elem_tensor.as_secs_f64();
            let s_chunks = avg_chunks_vec.as_secs_f64() / avg_chunks_tensor.as_secs_f64();
            let s_values = avg_elem_vec.as_secs_f64() / avg_values_tensor.as_secs_f64();
            let s_values_mut = avg_elem_vec.as_secs_f64() / avg_values_mut_tensor.as_secs_f64();

            println!(
                "\n[Size: {:>9} elements]\n  - Tensor (element): {:>8.3} ms\n  - Vec   (element): {:>8.3} ms\n    Speedup (Tensor/Vec): {:>6.2}x\n  - Tensor (chunks):  {:>8.3} ms\n  - Vec   (chunks):  {:>8.3} ms\n    Speedup (Tensor/Vec): {:>6.2}x\n  - Tensor (values):  {:>8.3} ms\n    Speedup (values vs Vec element): {:>6.2}x\n  - Tensor (values_mut):  {:>8.3} ms\n    Speedup (values_mut vs Vec element): {:>6.2}x",
                n,
                avg_elem_tensor.as_secs_f64() * 1e3,
                avg_elem_vec.as_secs_f64() * 1e3,
                s_elem,
                avg_chunks_tensor.as_secs_f64() * 1e3,
                avg_chunks_vec.as_secs_f64() * 1e3,
                s_chunks,
                avg_values_tensor.as_secs_f64() * 1e3,
                s_values,
                avg_values_mut_tensor.as_secs_f64() * 1e3,
                s_values_mut,
            );
        }

        println!("\nNote: timings include creation, iteration, ops (2x+1), and collection.");
    }
1000
1001    /// Replacement for previous values-only iteration: derive values via scalar views
1002    #[test]
1003    fn test_values_via_views() {
1004        let t =
1005            Tensor::from_slice(&(0..16).map(|i| i as f32).collect::<Vec<_>>(), vec![16]).unwrap();
1006        let vals: Vec<f32> = t.iter().map(|e| e.value()).collect();
1007        assert_eq!(vals, (0..16).map(|i| i as f32).collect::<Vec<_>>());
1008    }
1009
1010    /// Replacement for previous iter_values_mut: mutate via data_mut
1011    #[test]
1012    fn test_mutation_via_data_mut() {
1013        let mut t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
1014        {
1015            let d = t.data_mut();
1016            for v in d.iter_mut() {
1017                *v += 1.0;
1018            }
1019        }
1020        assert_eq!(t.data(), &[2.0, 3.0, 4.0, 5.0]);
1021    }
1022}