// train_station/tensor/iterator/mod.rs
1//! Iterator module for tensor element-wise operations
2//!
3//! This module provides high-performance iterators over tensor elements, where each
4//! element is represented as a view tensor of shape `[1]`. This design allows for
5//! seamless integration with Rust's standard library iterator methods while
6//! leveraging the existing tensor operation framework and gradient tracking.
7//!
8//! # Key Features
9//!
10//! - **Standard Library Compatibility**: Full implementation of Iterator, ExactSizeIterator,
11//! DoubleEndedIterator, FusedIterator, IntoIterator, and FromIterator traits
12//! - **Gradient Tracking**: Automatic gradient propagation through element operations
13//! - **Performance Optimized**: True zero-copy views with shared memory
14//! - **SIMD Compatible**: All operations use existing optimized tensor implementations
15//! - **Memory Efficient**: Adaptive view creation based on tensor size
16//! - **Zero-Copy Operations**: Element views share memory with source tensor
17//! - **Full Tensor Operations**: Each element supports all tensor methods
18//!
19//! # Performance Characteristics
20//!
21//! - **View Creation**: O(1) per element with true zero-copy views
22//! - **Memory Overhead**: ~64 bytes per view tensor (no data copying)
23//! - **SIMD Operations**: Full utilization of existing optimizations
24//! - **Gradient Tracking**: True gradient flow with element-level accumulation
25//! - **Iterator Overhead**: Minimal performance impact for element access
26//! - **Collection Optimization**: Efficient reconstruction from element views
27//!
28//! # Examples
29//!
30//! ## Basic Element Iteration
31//!
32//! ```
33//! use train_station::Tensor;
34//!
35//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
36//!
37//! // Basic iteration over elements
38//! for element in tensor.iter() {
39//! println!("Element value: {}", element.value());
40//! }
41//!
42//! // Collect elements into a new tensor
43//! let collected: Tensor = tensor.iter().collect();
44//! assert_eq!(collected.data(), tensor.data());
45//! ```
46//!
47//! ## Element-Wise Transformations
48//!
49//! ```
50//! use train_station::Tensor;
51//!
52//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
53//!
54//! // Apply tensor operations to each element
55//! let doubled: Tensor = tensor.iter()
56//! .map(|elem| elem.mul_scalar(2.0))
57//! .collect();
58//!
59//! assert_eq!(doubled.data(), &[2.0, 4.0, 6.0]);
60//!
61//! // Chain multiple operations
62//! let transformed: Tensor = tensor.iter()
63//! .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0)) // 2x + 1
64//! .collect();
65//!
66//! assert_eq!(transformed.data(), &[3.0, 5.0, 7.0]);
67//! ```
68//!
69//! ## Advanced Iterator Operations
70//!
71//! ```
72//! use train_station::Tensor;
73//!
74//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
75//!
76//! // Filter elements based on values
77//! let large_values: Tensor = tensor.iter()
78//! .filter(|elem| elem.value() > 3.0)
79//! .collect();
80//!
81//! assert_eq!(large_values.data(), &[4.0, 5.0]);
82//!
83//! // Use enumerate for indexed operations
84//! let indexed: Tensor = tensor.iter()
85//! .enumerate()
86//! .map(|(i, elem)| elem.add_scalar(i as f32))
87//! .collect();
88//!
89//! assert_eq!(indexed.data(), &[1.0, 3.0, 5.0, 7.0, 9.0]);
90//! ```
91//!
92//! ## Range Iteration
93//!
94//! ```
95//! use train_station::Tensor;
96//!
97//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
98//!
99//! // Iterate over a specific range
100//! let middle: Tensor = tensor.iter_range(1, 4)
101//! .map(|elem| elem.mul_scalar(2.0))
102//! .collect();
103//!
104//! assert_eq!(middle.data(), &[4.0, 6.0, 8.0]);
105//! ```
106//!
107//! ## Double-Ended Iteration
108//!
109//! ```
110//! use train_station::Tensor;
111//!
112//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
113//!
114//! // Reverse iteration
115//! let reversed: Tensor = tensor.iter().rev().collect();
116//! assert_eq!(reversed.data(), &[4.0, 3.0, 2.0, 1.0]);
117//!
118//! // Iterate from both ends
119//! let mut iter = tensor.iter();
120//! assert_eq!(iter.next().unwrap().value(), 1.0);
121//! assert_eq!(iter.next_back().unwrap().value(), 4.0);
122//! ```
123//!
124//! ## Gradient Tracking
125//!
126//! ```
127//! use train_station::Tensor;
128//!
129//! let tensor = Tensor::from_slice(&[1.0, 2.0], vec![2])
130//! .unwrap()
131//! .with_requires_grad();
132//!
133//! // Element operations maintain gradient tracking
134//! let result: Tensor = tensor.iter()
135//! .map(|elem| elem.mul_scalar(2.0))
136//! .collect();
137//!
138//! assert!(result.requires_grad());
139//! assert_eq!(result.data(), &[2.0, 4.0]);
140//! ```
141//!
142//! # Design Principles
143//!
144//! - **Zero-Copy Views**: Element views share memory with source tensor
145//! - **Full Tensor Operations**: Each element supports all tensor methods
146//! - **Standard Library Integration**: Complete compatibility with Rust iterators
147//! - **Performance First**: Optimized for high-performance element access
148//! - **Gradient Preservation**: Maintains gradtrack functionality through operations
149//! - **Memory Efficiency**: Minimal overhead for element iteration
150//! - **Type Safety**: Compile-time guarantees for iterator operations
151
152pub mod chunks;
153pub mod collect;
154pub mod element;
155pub mod value;
156pub mod viewdim;
157pub mod windows;
158
159use crate::gradtrack::is_grad_enabled;
160use crate::tensor::core::Tensor;
161pub use collect::{TensorCollectExt, ValuesCollectExt};
162use std::iter::FromIterator;
163
164/// High-performance iterator over tensor elements as view tensors
165///
166/// Each element becomes a proper `Tensor` view of shape `[1]` that can use
167/// all existing tensor operations and gradient tracking. Implements all
168/// standard iterator traits for maximum compatibility with Rust's ecosystem.
169///
170/// This iterator provides zero-copy access to tensor elements through view
171/// tensors, enabling efficient element-wise operations while maintaining
172/// full compatibility with Rust's standard library iterator methods.
173///
174/// # Performance
175///
176/// - **Zero-Copy Views**: Each element is a view tensor sharing memory with source
177/// - **O(1) Element Access**: Constant-time view creation for each element
178/// - **Memory Efficient**: ~64 bytes overhead per element view
179/// - **SIMD Compatible**: All tensor operations use existing optimizations
180/// - **Gradient Tracking**: Full gradtrack support through element operations
181///
182/// # Implementation Details
183///
184/// The iterator creates lightweight view tensors on-demand, sharing the same
185/// memory allocation as the source tensor. This ensures zero-copy semantics
186/// while maintaining full tensor operation compatibility.
187///
188/// Each element view is created using `Tensor::element_view()`, which provides
189/// a true view of the underlying data without any copying. The view tensors
190/// support all standard tensor operations including gradient tracking.
191///
192/// # Standard Library Compatibility
193///
194/// This iterator implements all standard iterator traits:
195/// - `Iterator`: Basic iteration with `next()` and `size_hint()`
196/// - `ExactSizeIterator`: Precise size information with `len()`
197/// - `DoubleEndedIterator`: Reverse iteration with `next_back()`
198/// - `FusedIterator`: Fused iteration for better performance
199/// - `IntoIterator`: Automatic conversion for `for` loops
200///
201/// # Examples
202///
203/// ## Basic Iteration
204///
205/// ```
206/// use train_station::Tensor;
207///
208/// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
209///
210/// // Basic iteration
211/// for element in tensor.iter() {
212/// println!("Element value: {}", element.value());
213/// }
214///
215/// // Standard library methods
216/// let sum: f32 = tensor.iter()
217/// .map(|elem| elem.value())
218/// .sum();
219///
220/// assert_eq!(sum, 6.0);
221/// ```
222///
223/// ## Element Operations
224///
225/// ```
226/// use train_station::Tensor;
227///
228/// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
229///
230/// // Tensor operations on elements
231/// let transformed: Tensor = tensor.iter()
232/// .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0)) // 2x + 1
233/// .collect();
234///
235/// assert_eq!(transformed.data(), &[3.0, 5.0, 7.0]);
236/// ```
237///
238/// ## Advanced Iterator Methods
239///
240/// ```
241/// use train_station::Tensor;
242///
243/// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
244///
245/// // Filter and transform
246/// let result: Tensor = tensor.iter()
247/// .filter(|elem| elem.value() > 2.0)
248/// .map(|elem| elem.mul_scalar(10.0))
249/// .collect();
250///
251/// assert_eq!(result.data(), &[30.0, 40.0, 50.0]);
252///
253/// // Reverse iteration
254/// let reversed: Tensor = tensor.iter().rev().collect();
255/// assert_eq!(reversed.data(), &[5.0, 4.0, 3.0, 2.0, 1.0]);
256/// ```
// Iterator types are re-exported from the submodules above (see the `pub use collect::…`
// at the top of this file).
// ===== IntoIterator Implementation =====
259/// IntoIterator for &Tensor now iterates outermost dimension, yielding sub-tensors (views)
260impl<'a> IntoIterator for &'a Tensor {
261 type Item = Tensor;
262 type IntoIter = crate::tensor::iterator::viewdim::TensorDimIterator<'a>;
263
264 fn into_iter(self) -> Self::IntoIter {
265 // Iterate outermost dim by default
266 self.iter_dim(0)
267 }
268}
269
270// ===== FromIterator Implementation =====
271
272impl FromIterator<Tensor> for Tensor {
273 /// Collect element view tensors back into a single tensor
274 ///
275 /// This method reconstructs a tensor from an iterator of element view tensors.
276 /// It includes optimizations for common patterns and maintains gradient tracking
277 /// when appropriate.
278 ///
279 /// The collection process automatically detects whether all elements are scalar
280 /// views (shape `[1]`) and uses optimized collection strategies accordingly.
281 /// Gradient tracking is preserved when any input element requires gradients.
282 ///
283 /// # Performance
284 ///
285 /// - **Optimized Collection**: Specialized paths for scalar and mixed views
286 /// - **Memory Efficient**: Direct memory copying without intermediate allocations
287 /// - **Gradient Preservation**: Maintains gradtrack functionality when enabled
288 /// - **Shape Detection**: Automatic detection of element shapes for optimization
289 ///
290 /// # Implementation Details
291 ///
292 /// The method performs the following steps:
293 /// 1. **Element Collection**: Gathers all element tensors from the iterator
294 /// 2. **Shape Analysis**: Determines if all elements are scalar views
295 /// 3. **Optimized Path**: Uses specialized collection for scalar views
296 /// 4. **General Path**: Handles mixed shapes by flattening into 1D tensor
297 /// 5. **Gradient Setup**: Preserves gradient tracking when appropriate
298 ///
299 /// # Examples
300 ///
301 /// ## Basic Collection
302 ///
303 /// ```
304 /// use train_station::Tensor;
305 ///
306 /// let original = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
307 /// let doubled: Tensor = original.iter()
308 /// .map(|elem| elem.mul_scalar(2.0))
309 /// .collect();
310 ///
311 /// assert_eq!(doubled.data(), &[2.0, 4.0, 6.0]);
312 /// ```
313 ///
314 /// ## Collection with Gradient Tracking
315 ///
316 /// ```
317 /// use train_station::Tensor;
318 ///
319 /// let original = Tensor::from_slice(&[1.0, 2.0], vec![2])
320 /// .unwrap()
321 /// .with_requires_grad();
322 ///
323 /// let result: Tensor = original.iter()
324 /// .map(|elem| elem.mul_scalar(2.0))
325 /// .collect();
326 ///
327 /// assert!(result.requires_grad());
328 /// assert_eq!(result.data(), &[2.0, 4.0]);
329 /// ```
330 ///
331 /// ## Empty Iterator Handling
332 ///
333 /// ```
334 /// use train_station::Tensor;
335 ///
336 /// let empty: Tensor = Vec::<Tensor>::new().into_iter().collect();
337 /// assert_eq!(empty.size(), 0);
338 /// assert_eq!(empty.shape().dims(), vec![0]);
339 /// ```
340 fn from_iter<I: IntoIterator<Item = Tensor>>(iter: I) -> Self {
341 let elements: Vec<Tensor> = iter.into_iter().collect();
342
343 if elements.is_empty() {
344 return Tensor::new(vec![0]);
345 }
346
347 // Check if all elements are scalars (size == 1). Supports both [1] and 0-D [] shapes
348 let all_scalars = elements.iter().all(|e| e.size() == 1);
349
350 if all_scalars {
351 // Optimized path for scalar element views
352 Self::collect_scalar_views(elements)
353 } else {
354 // General path for mixed shapes
355 Self::collect_mixed_views(elements)
356 }
357 }
358}
359
360impl Tensor {
361 /// Optimized collection for scalar element views
362 ///
363 /// This method efficiently reconstructs a tensor from scalar element views,
364 /// preserving gradient tracking and using optimized memory operations.
365 ///
366 /// This is the fast path for collection when all elements are scalar views
367 /// (shape `[1]`). It performs direct memory copying and sets up gradient
368 /// tracking when any input element requires gradients.
369 ///
370 /// # Arguments
371 ///
372 /// * `elements` - Vector of scalar element view tensors
373 ///
374 /// # Returns
375 ///
376 /// A new tensor containing all element values in a 1D layout
377 ///
378 /// # Performance
379 ///
380 /// - **Direct Memory Copy**: Single-pass copying without intermediate allocations
381 /// - **Gradient Optimization**: Efficient gradient tracking setup
382 /// - **Memory Efficient**: Minimal overhead for collection process
383 /// - **SIMD Compatible**: Result tensor supports all optimizations
384 ///
385 /// # Implementation Details
386 ///
387 /// The method performs the following steps:
388 /// 1. **Allocation**: Creates uninitialized tensor with correct size
389 /// 2. **Gradient Check**: Determines if any element requires gradients
390 /// 3. **Memory Copy**: Direct copying from element views to result
391 /// 4. **Gradient Setup**: Configures gradient tracking when needed
392 /// 5. **Operation Registration**: Registers with gradtrack engine
393 fn collect_scalar_views(elements: Vec<Tensor>) -> Self {
394 if elements.is_empty() {
395 return Tensor::new(vec![0]);
396 }
397 // Fast path: if no element requires grad or gradients are disabled, copy directly
398 let any_requires = elements.iter().any(|t| t.requires_grad());
399 if !any_requires || !is_grad_enabled() {
400 let n = elements.len();
401 let mut out = Tensor::new_uninitialized(vec![n]);
402 unsafe {
403 let dst = out.as_mut_ptr();
404 for (i, t) in elements.iter().enumerate() {
405 debug_assert_eq!(t.size(), 1);
406 std::ptr::copy_nonoverlapping(t.as_ptr(), dst.add(i), 1);
407 }
408 }
409 return out;
410 }
411
412 // Grad-preserving path: concat along dim 0 then flatten
413 let mut prepped: Vec<Tensor> = Vec::with_capacity(elements.len());
414 for t in elements.into_iter() {
415 if t.shape().rank() == 0 {
416 prepped.push(t.unsqueeze(0)); // [] -> [1]
417 } else {
418 prepped.push(t);
419 }
420 }
421 let concatenated = Tensor::cat(&prepped, 0); // shape: [N, 1]
422 concatenated.flatten() // shape: [N]
423 }
424
425 /// General collection for mixed element shapes
426 ///
427 /// This method handles collection when elements have different shapes,
428 /// flattening all elements into a 1D tensor.
429 ///
430 /// This is the general path for collection when elements have varying shapes.
431 /// It flattens all elements into a single 1D tensor and preserves gradient
432 /// tracking when any input element requires gradients.
433 ///
434 /// # Arguments
435 ///
436 /// * `elements` - Vector of element tensors with potentially different shapes
437 ///
438 /// # Returns
439 ///
440 /// A new 1D tensor containing all flattened element values
441 ///
442 /// # Performance
443 ///
444 /// - **Flattening**: Converts all elements to 1D layout
445 /// - **Memory Copy**: Efficient copying with size calculation
446 /// - **Gradient Preservation**: Maintains gradtrack functionality
447 /// - **Mixed Shapes**: Handles elements with different dimensions
448 ///
449 /// # Implementation Details
450 ///
451 /// The method performs the following steps:
452 /// 1. **Size Calculation**: Sums sizes of all elements for total size
453 /// 2. **Allocation**: Creates uninitialized tensor with total size
454 /// 3. **Sequential Copy**: Copies each element's data sequentially
455 /// 4. **Gradient Setup**: Configures gradient tracking when needed
456 /// 5. **Operation Registration**: Registers with gradtrack engine
457 fn collect_mixed_views(elements: Vec<Tensor>) -> Self {
458 let requires_grad = elements.iter().any(|e| e.requires_grad());
459 // Concatenate then flatten to preserve gradient connections
460 let concatenated = Tensor::cat(&elements, 0);
461 let flattened = concatenated.flatten();
462 if requires_grad && is_grad_enabled() {
463 // Flags are handled by ops; return as-is
464 }
465 flattened
466 }
467
468 // Iterator entry points are implemented in iterator/element.rs
469}
470
471// Redundant iterator type and collection trait/impls have been moved to dedicated files.
472
473#[cfg(test)]
474mod tests {
475 //! Comprehensive tests for tensor element iterator functionality
476 //!
477 //! These tests cover all aspects of the iterator implementation:
478 //! - Basic iteration functionality
479 //! - Standard library trait compliance
480 //! - Gradient tracking through element operations
481 //! - Performance characteristics
482 //! - Edge cases and error conditions
483
484 use super::*;
485
486 /// Test basic iterator functionality
487 #[test]
488 fn test_basic_iteration() {
489 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
490
491 let elements: Vec<Tensor> = tensor.iter_elements().collect();
492 assert_eq!(elements.len(), 4);
493
494 // Check that each element is a scalar tensor with correct value
495 for (i, elem) in elements.iter().enumerate() {
496 assert_eq!(elem.shape().dims(), vec![1]);
497 assert_eq!(elem.size(), 1);
498 assert_eq!(elem.value(), (i + 1) as f32);
499 }
500 }
501
502 /// Test Iterator trait methods
503 #[test]
504 fn test_iterator_trait_methods() {
505 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
506 let mut iter = tensor.iter();
507
508 // Test next()
509 let _first = iter.next().unwrap();
510 assert_eq!(_first.value(), 1.0);
511
512 // Test size_hint() after consuming one element
513 assert_eq!(iter.size_hint(), (4, Some(4)));
514
515 // Test count()
516 assert_eq!(iter.count(), 4);
517
518 // Test nth()
519 let mut iter = tensor.iter();
520 let third = iter.nth(2).unwrap();
521 assert_eq!(third.value(), 3.0);
522
523 // Test last()
524 let mut iter = tensor.iter();
525 let last = iter.next_back().unwrap();
526 assert_eq!(last.value(), 5.0);
527 }
528
529 /// Test ExactSizeIterator
530 #[test]
531 fn test_exact_size_iterator() {
532 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
533 let iter = tensor.iter();
534
535 assert_eq!(iter.len(), 3);
536
537 // Test that len() decreases as we consume the iterator
538 let mut iter = tensor.iter();
539 assert_eq!(iter.len(), 3);
540 iter.next();
541 assert_eq!(iter.len(), 2);
542 iter.next();
543 assert_eq!(iter.len(), 1);
544 iter.next();
545 assert_eq!(iter.len(), 0);
546 }
547
548 /// Test DoubleEndedIterator
549 #[test]
550 fn test_double_ended_iterator() {
551 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
552 let mut iter = tensor.iter();
553
554 // Test next_back()
555 let last = iter.next_back().unwrap();
556 assert_eq!(last.value(), 4.0);
557
558 let first = iter.next().unwrap();
559 assert_eq!(first.value(), 1.0);
560
561 // Test nth_back()
562 let mut iter = tensor.iter();
563 let second_to_last = iter.nth_back(1).unwrap();
564 assert_eq!(second_to_last.value(), 3.0);
565
566 // Test consuming from both ends
567 let mut iter = tensor.iter();
568 assert_eq!(iter.next().unwrap().value(), 1.0);
569 assert_eq!(iter.next_back().unwrap().value(), 4.0);
570 assert_eq!(iter.next().unwrap().value(), 2.0);
571 assert_eq!(iter.next_back().unwrap().value(), 3.0);
572 assert!(iter.next().is_none());
573 }
574
575 /// Test IntoIterator trait
576 #[test]
577 fn test_into_iterator() {
578 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
579
580 // Test with for loop
581 let mut values = Vec::new();
582 for element in &tensor {
583 values.push(element.value());
584 }
585 assert_eq!(values, vec![1.0, 2.0, 3.0]);
586
587 // Test with into_iter() explicitly
588 let values: Vec<f32> = (&tensor).into_iter().map(|elem| elem.value()).collect();
589 assert_eq!(values, vec![1.0, 2.0, 3.0]);
590 }
591
592 /// Test FromIterator trait (collect)
593 #[test]
594 fn test_from_iterator() {
595 let original = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
596
597 // Test collecting back to tensor
598 let collected: Tensor = original.iter().collect();
599 assert_eq!(collected.shape().dims(), vec![4]);
600 assert_eq!(collected.data(), original.data());
601
602 // Test collecting with transformations
603 let doubled: Tensor = original
604 .iter()
605 .map(|elem| {
606 let val = elem.value();
607 Tensor::from_slice(&[val * 2.0], vec![1]).unwrap()
608 })
609 .collect();
610
611 assert_eq!(doubled.data(), &[2.0, 4.0, 6.0, 8.0]);
612 }
613
614 /// Test standard library iterator methods
615 #[test]
616 fn test_std_iterator_methods() {
617 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
618
619 // Test map
620 let doubled: Vec<f32> = tensor.iter().map(|elem| elem.value() * 2.0).collect();
621 assert_eq!(doubled, vec![2.0, 4.0, 6.0, 8.0, 10.0]);
622
623 // Test filter
624 let large_values: Vec<f32> = tensor
625 .iter()
626 .filter(|elem| elem.value() > 3.0)
627 .map(|elem| elem.value())
628 .collect();
629 assert_eq!(large_values, vec![4.0, 5.0]);
630
631 // Test enumerate
632 let with_indices: Vec<(usize, f32)> = tensor
633 .iter()
634 .enumerate()
635 .map(|(i, elem)| (i, elem.value()))
636 .collect();
637 assert_eq!(
638 with_indices,
639 vec![(0, 1.0), (1, 2.0), (2, 3.0), (3, 4.0), (4, 5.0)]
640 );
641
642 // Test fold
643 let sum: f32 = tensor.iter().fold(0.0, |acc, elem| acc + elem.value());
644 assert_eq!(sum, 15.0);
645
646 // Test find
647 let found = tensor.iter().find(|elem| elem.value() == 3.0);
648 assert!(found.is_some());
649 assert_eq!(found.unwrap().value(), 3.0);
650
651 // Test any/all
652 let all_positive = tensor.iter().all(|elem| elem.value() > 0.0);
653 assert!(all_positive);
654
655 let any_large = tensor.iter().any(|elem| elem.value() > 4.0);
656 assert!(any_large);
657 }
658
659 /// Test element operations with tensor methods
660 #[test]
661 fn test_element_tensor_operations() {
662 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
663
664 // Test scalar operations on elements
665 let scaled: Tensor = tensor.iter().map(|elem| elem.mul_scalar(2.0)).collect();
666 assert_eq!(scaled.data(), &[2.0, 4.0, 6.0]);
667
668 let offset: Tensor = tensor.iter().map(|elem| elem.add_scalar(10.0)).collect();
669 assert_eq!(offset.data(), &[11.0, 12.0, 13.0]);
670
671 // Test chaining operations
672 let complex: Tensor = tensor
673 .iter()
674 .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0)) // 2x + 1
675 .collect();
676 assert_eq!(complex.data(), &[3.0, 5.0, 7.0]);
677 }
678
679 /// Test gradient tracking through element operations
680 #[test]
681 fn test_gradient_tracking() {
682 let tensor = Tensor::from_slice(&[1.0, 2.0], vec![2])
683 .unwrap()
684 .with_requires_grad();
685
686 // Perform element-wise operations
687 let result: Tensor = tensor.iter().map(|elem| elem.mul_scalar(2.0)).collect();
688
689 // The result should require gradients if any element requires gradients
690 // Note: Current implementation creates copies, so gradient tracking is
691 // implemented but may not propagate back to original tensor
692 assert!(result.requires_grad());
693
694 // For now, just verify the forward pass works with gradient-enabled tensors
695 // Full gradient propagation would require true view implementation
696 assert_eq!(result.data(), &[2.0, 4.0]);
697 }
698
699 /// Test with zero-sized tensors
700 #[test]
701 fn test_zero_sized_tensor() {
702 let empty = Tensor::new(vec![0]);
703 let iter = empty.iter();
704
705 assert_eq!(iter.len(), 0);
706 assert_eq!(iter.size_hint(), (0, Some(0)));
707
708 let collected: Tensor = iter.collect();
709 assert_eq!(collected.size(), 0);
710 }
711
712 /// Test range iteration
713 #[test]
714 fn test_range_iteration() {
715 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
716
717 // Test middle range
718 let middle: Vec<f32> = tensor.iter_range(1, 4).map(|elem| elem.value()).collect();
719 assert_eq!(middle, vec![2.0, 3.0, 4.0]);
720
721 // Test out of bounds (should be clamped)
722 let clamped: Vec<f32> = tensor.iter_range(3, 10).map(|elem| elem.value()).collect();
723 assert_eq!(clamped, vec![4.0, 5.0]);
724
725 // Test empty range
726 let empty: Vec<f32> = tensor.iter_range(2, 2).map(|elem| elem.value()).collect();
727 assert_eq!(empty, Vec::<f32>::new());
728 }
729
730 /// Test complex iterator chains
731 #[test]
732 fn test_complex_chains() {
733 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
734
735 // Complex chain: enumerate -> filter -> map -> collect
736 let result: Tensor = tensor
737 .iter()
738 .enumerate()
739 .filter(|(i, _)| i % 2 == 0) // Take even indices
740 .map(|(i, elem)| elem.add_scalar(i as f32)) // Add index to value
741 .collect();
742
743 // Should have elements [1.0 + 0, 3.0 + 2, 5.0 + 4] = [1.0, 5.0, 9.0]
744 assert_eq!(result.data(), &[1.0, 5.0, 9.0]);
745
746 // Test with rev()
747 let reversed: Tensor = tensor.iter().rev().take(3).collect();
748
749 assert_eq!(reversed.data(), &[6.0, 5.0, 4.0]);
750 }
751
752 /// Performance test for iterator overhead
753 #[test]
754 fn test_performance() {
755 let large_tensor =
756 Tensor::from_slice(&(0..1000).map(|i| i as f32).collect::<Vec<_>>(), vec![1000])
757 .unwrap();
758
759 let start = std::time::Instant::now();
760
761 let result: Tensor = large_tensor
762 .iter()
763 .map(|elem| elem.mul_scalar(2.0))
764 .collect();
765
766 let duration = start.elapsed();
767 println!("Iterator performance test took: {:?}", duration);
768
769 // Verify correctness
770 assert_eq!(result.size(), 1000);
771 assert_eq!(result.data()[0], 0.0);
772 assert_eq!(result.data()[999], 1998.0);
773 }
774
775 /// Test chunks iterator basic behavior
776 #[test]
777 fn test_chunks_basic() {
778 let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
779 let chunks: Vec<Tensor> = t.iter_chunks(2).collect();
780 assert_eq!(chunks.len(), 3);
781 assert_eq!(chunks[0].data(), &[1.0, 2.0]);
782 assert_eq!(chunks[1].data(), &[3.0, 4.0]);
783 assert_eq!(chunks[2].data(), &[5.0]);
784 }
785
786 /// Test chunks_exact with remainder
787 #[test]
788 fn test_chunks_exact_with_remainder() {
789 let t = Tensor::from_slice(&[10.0, 20.0, 30.0, 40.0, 50.0], vec![5]).unwrap();
790 let mut it = t.iter_chunks_exact(2);
791 let v0 = it.next().unwrap();
792 let v1 = it.next().unwrap();
793 assert!(it.next().is_none());
794 assert_eq!(v0.data(), &[10.0, 20.0]);
795 assert_eq!(v1.data(), &[30.0, 40.0]);
796 let r = it.remainder();
797 assert_eq!(r.data(), &[50.0]);
798 }
799
800 /// Test windows iterator with step 1
801 #[test]
802 fn test_windows_basic() {
803 let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
804 let wins: Vec<Tensor> = t.iter_windows(3).collect();
805 assert_eq!(wins.len(), 2);
806 assert_eq!(wins[0].data(), &[1.0, 2.0, 3.0]);
807 assert_eq!(wins[1].data(), &[2.0, 3.0, 4.0]);
808 }
809
810 /// Test windows iterator with custom step
811 #[test]
812 fn test_windows_step() {
813 let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
814 let wins: Vec<Tensor> = t.iter_windows_step(2, 2).collect();
815 assert_eq!(wins.len(), 2);
816 assert_eq!(wins[0].data(), &[1.0, 2.0]);
817 assert_eq!(wins[1].data(), &[3.0, 4.0]);
818 }
819
820 /// Test collect_shape utility
821 #[test]
822 fn test_collect_shape() {
823 let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
824 let mat = t.iter_chunks(2).collect_shape(vec![3, 2]);
825 assert_eq!(mat.shape().dims(), &[3, 2]);
826 assert_eq!(mat.data(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
827 }
828
829 /// Performance comparison: Tensor iterator/view system vs Vec iteration
830 ///
831 /// This test compares end-to-end pipelines (creation → iteration → ops → collection)
832 /// across multiple sizes and loop styles, and prints a concise summary.
833 #[test]
834 fn test_iterator_vs_vec_performance_summary() {
835 use std::time::Instant;
836
837 let sizes: [usize; 3] = [100, 1000, 10_000];
838 let iterations: usize = 3; // exclude 1 warmup
839 let chunk_size: usize = 8192;
840
841 println!(
842 "Iterator/View vs Vec performance ({} runs avg, chunk_size={})",
843 iterations, chunk_size
844 );
845
846 for &n in &sizes {
847 // -------- Element-wise iterator pipeline (Tensor) --------
848 let mut total_elem_tensor = std::time::Duration::ZERO;
849 for run in 0..(iterations + 1) {
850 let t0 = Instant::now();
851 let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
852 let t = Tensor::from_slice(&data, vec![n]).unwrap();
853 let out: Tensor = t
854 .iter_elements()
855 .map(|e| e.mul_scalar(2.0).add_scalar(1.0))
856 .collect();
857 // Touch a value to avoid any dead-code elimination concerns
858 let _ = out.get(&[0]);
859 let dt = t0.elapsed();
860 if run > 0 {
861 total_elem_tensor += dt;
862 }
863 }
864 let avg_elem_tensor = total_elem_tensor / iterations as u32;
865
866 // -------- Element-wise iterator pipeline (Vec) --------
867 let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
868 let mut total_elem_vec = std::time::Duration::ZERO;
869 for run in 0..(iterations + 1) {
870 let t0 = Instant::now();
871 let _v_out: Vec<f32> = data.iter().map(|&x| 2.0 * x + 1.0).collect();
872 let dt = t0.elapsed();
873 if run > 0 {
874 total_elem_vec += dt;
875 }
876 }
877 let avg_elem_vec = total_elem_vec / iterations as u32;
878
879 // -------- Chunked iterator pipeline (Tensor) --------
880 let mut total_chunks_tensor = std::time::Duration::ZERO;
881 for run in 0..(iterations + 1) {
882 let t0 = Instant::now();
883 let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
884 let t = Tensor::from_slice(&data, vec![n]).unwrap();
885 let parts: Vec<Tensor> = t
886 .iter_chunks(chunk_size)
887 .map(|c| c.mul_scalar(2.0).add_scalar(1.0))
888 .collect();
889 let out = Tensor::cat(&parts, 0);
890 let _ = out.get(&[out.size().saturating_sub(1)]);
891 let dt = t0.elapsed();
892 if run > 0 {
893 total_chunks_tensor += dt;
894 }
895 }
896 let avg_chunks_tensor = total_chunks_tensor / iterations as u32;
897
898 // -------- Chunked iterator pipeline (Vec) --------
899 let mut total_chunks_vec = std::time::Duration::ZERO;
900 for run in 0..(iterations + 1) {
901 let t0 = Instant::now();
902 let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
903 let mut out: Vec<f32> = Vec::with_capacity(n);
904 for chunk in data.chunks(chunk_size) {
905 for &x in chunk.iter() {
906 out.push(2.0 * x + 1.0);
907 }
908 }
909 let _ = out.get(out.len().saturating_sub(1)).copied().unwrap_or(0.0);
910 let dt = t0.elapsed();
911 if run > 0 {
912 total_chunks_vec += dt;
913 }
914 }
915 let avg_chunks_vec = total_chunks_vec / iterations as u32;
916
917 // -------- Auto-tuned fast chunks (Tensor) --------
918 let mut total_fast_chunks_tensor = std::time::Duration::ZERO;
919 for run in 0..(iterations + 1) {
920 let t0 = Instant::now();
921 let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
922 let t = Tensor::from_slice(&data, vec![n]).unwrap();
923 let parts: Vec<Tensor> = t
924 .iter_fast_chunks()
925 .map(|c| c.mul_scalar(2.0).add_scalar(1.0))
926 .collect();
927 let out = Tensor::cat(&parts, 0);
928 let _ = out.get(&[out.size().saturating_sub(1)]);
929 let dt = t0.elapsed();
930 if run > 0 {
931 total_fast_chunks_tensor += dt;
932 }
933 }
934 let avg_fast_chunks_tensor = total_fast_chunks_tensor / iterations as u32;
935
936 // -------- Value iterator (Tensor) --------
937 let mut total_values_tensor = std::time::Duration::ZERO;
938 let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
939 let t = Tensor::from_slice(&data, vec![n]).unwrap();
940
941 for run in 0..(iterations + 1) {
942 let t0 = Instant::now();
943 let _v_out: Tensor = t.iter_values().map(|x| 2.0 * x + 1.0).collect();
944 let dt = t0.elapsed();
945 if run > 0 {
946 total_values_tensor += dt;
947 }
948 }
949 let avg_values_tensor = total_values_tensor / iterations as u32;
950
951 // -------- Mutable value iterator (Tensor) --------
952 let mut total_values_mut_tensor = std::time::Duration::ZERO;
953 for run in 0..(iterations + 1) {
954 let t0 = Instant::now();
955 let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
956 let mut out = Tensor::from_slice(&data, vec![n]).unwrap();
957 for v in out.iter_values_mut() {
958 *v = 2.0 * *v + 1.0;
959 }
960 let _ = out.get(&[out.size().saturating_sub(1)]);
961 let dt = t0.elapsed();
962 if run > 0 {
963 total_values_mut_tensor += dt;
964 }
965 }
966 let avg_values_mut_tensor = total_values_mut_tensor / iterations as u32;
967
968 // -------- Summary per size --------
969 let s_elem = avg_elem_vec.as_secs_f64() / avg_elem_tensor.as_secs_f64();
970 let s_chunks = avg_chunks_vec.as_secs_f64() / avg_chunks_tensor.as_secs_f64();
971 let s_fast_chunks = avg_chunks_vec.as_secs_f64() / avg_fast_chunks_tensor.as_secs_f64();
972 let s_values = avg_elem_vec.as_secs_f64() / avg_values_tensor.as_secs_f64();
973 let s_values_mut = avg_elem_vec.as_secs_f64() / avg_values_mut_tensor.as_secs_f64();
974
975 println!(
976 "\n[Size: {:>9} elements]\n - Tensor (element): {:>8.3} ms\n - Vec (element): {:>8.3} ms\n Speedup (Tensor/Vec): {:>6.2}x\n - Tensor (chunks): {:>8.3} ms\n - Vec (chunks): {:>8.3} ms\n Speedup (Tensor/Vec): {:>6.2}x\n - Tensor (fast_chunks): {:>8.3} ms\n Speedup (fast_chunks vs Vec chunks): {:>6.2}x\n - Tensor (values): {:>8.3} ms\n Speedup (values vs Vec element): {:>6.2}x\n - Tensor (values_mut): {:>8.3} ms\n Speedup (values_mut vs Vec element): {:>6.2}x",
977 n,
978 avg_elem_tensor.as_secs_f64() * 1e3,
979 avg_elem_vec.as_secs_f64() * 1e3,
980 s_elem,
981 avg_chunks_tensor.as_secs_f64() * 1e3,
982 avg_chunks_vec.as_secs_f64() * 1e3,
983 s_chunks,
984 avg_fast_chunks_tensor.as_secs_f64() * 1e3,
985 s_fast_chunks,
986 avg_values_tensor.as_secs_f64() * 1e3,
987 s_values,
988 avg_values_mut_tensor.as_secs_f64() * 1e3,
989 s_values_mut,
990 );
991 }
992
993 println!("\nNote: timings include creation, iteration, ops (2x+1), and collection.");
994 }
995
996 /// Test iter_values over contiguous tensors
997 #[test]
998 fn test_iter_values_contiguous() {
999 let t =
1000 Tensor::from_slice(&(0..16).map(|i| i as f32).collect::<Vec<_>>(), vec![16]).unwrap();
1001 let vals: Vec<f32> = t.iter_values().collect();
1002 assert_eq!(vals, (0..16).map(|i| i as f32).collect::<Vec<_>>());
1003 }
1004
1005 /// Test iter_values_mut requires contiguous and mutates in place
1006 #[test]
1007 fn test_iter_values_mut_contiguous() {
1008 let mut t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
1009 for v in t.iter_values_mut() {
1010 *v += 1.0;
1011 }
1012 assert_eq!(t.data(), &[2.0, 3.0, 4.0, 5.0]);
1013 }
1014
1015 /// Test iter_fast_chunks heuristic produces reasonable chunking
1016 #[test]
1017 fn test_iter_fast_chunks_basic() {
1018 let t = Tensor::from_slice(
1019 &(0..100_000).map(|i| i as f32).collect::<Vec<_>>(),
1020 vec![100_000],
1021 )
1022 .unwrap();
1023 let mut total = 0usize;
1024 for c in t.iter_fast_chunks() {
1025 total += c.size();
1026 }
1027 assert_eq!(total, 100_000);
1028 }
1029}