train_station/tensor/iterator/mod.rs
1//! Iterator module for tensor iteration
2//!
3//! This module provides high-performance iterators over tensor elements, where each
4//! element is represented as a view tensor of shape `[1]`. This design allows for
5//! seamless integration with Rust's standard library iterator methods while
6//! leveraging the existing tensor operation framework and gradient tracking.
7//!
8//! Implicit performance routing: iterator constructors decide at creation time whether
9//! to use a no-grad fast path (borrowed contiguous traversal or one-time materialization)
10//! or a grad-preserving view path based on `requires_grad` and `gradtrack::is_grad_enabled()`.
11//!
12//! # Key Features
13//!
14//! - **Standard Library Compatibility**: Full implementation of Iterator, ExactSizeIterator,
15//! DoubleEndedIterator, FusedIterator, IntoIterator, and FromIterator traits
16//! - **Gradient Tracking**: Automatic gradient propagation through element operations
17//! - **Performance Optimized**: True zero-copy views with shared memory
18//! - **SIMD Compatible**: All operations use existing optimized tensor implementations
19//! - **Memory Efficient**: Adaptive view creation based on tensor size
20//! - **Zero-Copy Operations**: Element views share memory with source tensor
21//! - **Full Tensor Operations**: Each element supports all tensor methods
22//!
23//! # Performance Characteristics
24//!
25//! - **View Creation**: O(1) per element with true zero-copy views
26//! - **Memory Overhead**: ~64 bytes per view tensor (no data copying)
27//! - **SIMD Operations**: Full utilization of existing optimizations
28//! - **Gradient Tracking**: True gradient flow with element-level accumulation
29//! - **Iterator Overhead**: Minimal performance impact for element access
30//! - **Collection Optimization**: Efficient reconstruction from element views
31//!
32//! # Examples
33//!
34//! ## Basic Element Iteration (1D tensors)
35//!
36//! ```
37//! use train_station::Tensor;
38//!
39//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
40//!
41//! // Basic iteration over elements
42//! // For 1D tensors, `iter()` yields scalar views
43//! for element in tensor.iter() {
44//! println!("Element value: {}", element.value());
45//! }
46//!
47//! // Collect elements into a new tensor
48//! let collected: Tensor = tensor.iter().collect();
49//! assert_eq!(collected.data(), tensor.data());
50//! ```
51//!
52//! ## Element-Wise Transformations (1D convenience)
53//!
54//! ```
55//! use train_station::Tensor;
56//!
57//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
58//!
59//! // Apply tensor operations to each element
60//! let doubled: Tensor = tensor.iter()
61//! .map(|elem| elem.mul_scalar(2.0))
62//! .collect();
63//!
64//! assert_eq!(doubled.data(), &[2.0, 4.0, 6.0]);
65//!
66//! // Chain multiple operations
67//! let transformed: Tensor = tensor.iter()
68//! .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0)) // 2x + 1
69//! .collect();
70//!
71//! assert_eq!(transformed.data(), &[3.0, 5.0, 7.0]);
72//! ```
73//!
74//! ## Advanced Iterator Operations (1D convenience)
75//!
76//! ```
77//! use train_station::Tensor;
78//!
79//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
80//!
81//! // Filter elements based on values
82//! let large_values: Tensor = tensor.iter()
83//! .filter(|elem| elem.value() > 3.0)
84//! .collect();
85//!
86//! assert_eq!(large_values.data(), &[4.0, 5.0]);
87//!
88//! // Use enumerate for indexed operations
89//! let indexed: Tensor = tensor.iter()
90//! .enumerate()
91//! .map(|(i, elem)| elem.add_scalar(i as f32))
92//! .collect();
93//!
94//! assert_eq!(indexed.data(), &[1.0, 3.0, 5.0, 7.0, 9.0]);
95//! ```
96//!
97//! ## Range Iteration (1D convenience)
98//!
99//! ```
100//! use train_station::Tensor;
101//!
102//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
103//!
104//! // Iterate over a specific range
105//! let middle: Tensor = tensor.iter_range(1, 4)
106//! .map(|elem| elem.mul_scalar(2.0))
107//! .collect();
108//!
109//! assert_eq!(middle.data(), &[4.0, 6.0, 8.0]);
110//! ```
111//!
112//! ## Double-Ended Iteration (1D convenience)
113//!
114//! ```
115//! use train_station::Tensor;
116//!
117//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
118//!
119//! // Reverse iteration
120//! let reversed: Tensor = tensor.iter().rev().collect();
121//! assert_eq!(reversed.data(), &[4.0, 3.0, 2.0, 1.0]);
122//!
123//! // Iterate from both ends
124//! let mut iter = tensor.iter();
125//! assert_eq!(iter.next().unwrap().value(), 1.0);
126//! assert_eq!(iter.next_back().unwrap().value(), 4.0);
127//! ```
128//!
129//! ## Gradient Tracking
130//!
131//! ```
132//! use train_station::Tensor;
133//!
134//! let tensor = Tensor::from_slice(&[1.0, 2.0], vec![2])
135//! .unwrap()
136//! .with_requires_grad();
137//!
138//! // Element operations maintain gradient tracking
139//! let result: Tensor = tensor.iter()
140//! .map(|elem| elem.mul_scalar(2.0))
141//! .collect();
142//!
143//! assert!(result.requires_grad());
144//! assert_eq!(result.data(), &[2.0, 4.0]);
145//! ```
146//!
147//! # Design Principles and API Overview
148//!
149//! - Use `iter()` or `outer_iter()` to iterate outermost-dimension sub-tensors (Vec-like semantics)
150//! - Use `iter_dim(dim)` to iterate sub-tensors along an arbitrary dimension
151//! - Use `iter_flat()` to iterate scalar views in row-major order
152//! - Use `chunks()` / `chunks_exact()` for linear chunk views
153//! - Use `windows()` / `windows_step()` for overlapping linear window views
154//!
155//! Deprecated aliases (will be removed pre-1.0): `iter_chunks`, `iter_chunks_exact`,
156//! `iter_windows`, `iter_windows_step`.
157//!
158//! - **Zero-Copy Views**: Element views share memory with source tensor
159//! - **Full Tensor Operations**: Each element supports all tensor methods
160//! - **Standard Library Integration**: Complete compatibility with Rust iterators
161//! - **Performance First**: Optimized for high-performance element access
162//! - **Gradient Preservation**: Maintains gradtrack functionality through operations
163//! - **Memory Efficiency**: Minimal overhead for element iteration
164//! - **Type Safety**: Compile-time guarantees for iterator operations
165
166pub mod chunks;
167pub mod collect;
168pub mod element;
169pub mod viewdim;
170pub mod windows;
171
172use crate::gradtrack::is_grad_enabled;
173use crate::tensor::core::Tensor;
174pub use collect::{TensorCollectExt, ValuesCollectExt};
175use std::iter::FromIterator;
176
177/// High-performance iterator over tensor elements as view tensors
178///
179/// Each element becomes a proper `Tensor` view of shape `[1]` that can use
180/// all existing tensor operations and gradient tracking. Implements all
181/// standard iterator traits for maximum compatibility with Rust's ecosystem.
182///
183/// This iterator provides zero-copy access to tensor elements through view
184/// tensors, enabling efficient element-wise operations while maintaining
185/// full compatibility with Rust's standard library iterator methods.
186///
187/// # Performance
188///
189/// - **Zero-Copy Views**: Each element is a view tensor sharing memory with source
190/// - **O(1) Element Access**: Constant-time view creation for each element
191/// - **Memory Efficient**: ~64 bytes overhead per element view
192/// - **SIMD Compatible**: All tensor operations use existing optimizations
193/// - **Gradient Tracking**: Full gradtrack support through element operations
194///
195/// # Implementation Details
196///
197/// The iterator creates lightweight view tensors on-demand, sharing the same
198/// memory allocation as the source tensor. This ensures zero-copy semantics
199/// while maintaining full tensor operation compatibility.
200///
201/// Each element view is created using `Tensor::element_view()`, which provides
202/// a true view of the underlying data without any copying. The view tensors
203/// support all standard tensor operations including gradient tracking.
204///
205/// # Standard Library Compatibility
206///
207/// This iterator implements all standard iterator traits:
208/// - `Iterator`: Basic iteration with `next()` and `size_hint()`
209/// - `ExactSizeIterator`: Precise size information with `len()`
210/// - `DoubleEndedIterator`: Reverse iteration with `next_back()`
211/// - `FusedIterator`: Fused iteration for better performance
212/// - `IntoIterator`: Automatic conversion for `for` loops
213///
214/// # Examples
215///
216/// ## Basic Iteration
217///
218/// ```
219/// use train_station::Tensor;
220///
221/// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
222///
223/// // Basic iteration
224/// for element in tensor.iter() {
225/// println!("Element value: {}", element.value());
226/// }
227///
228/// // Standard library methods
229/// let sum: f32 = tensor.iter()
230/// .map(|elem| elem.value())
231/// .sum();
232///
233/// assert_eq!(sum, 6.0);
234/// ```
235///
236/// ## Element Operations
237///
238/// ```
239/// use train_station::Tensor;
240///
241/// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
242///
243/// // Tensor operations on elements
244/// let transformed: Tensor = tensor.iter()
245/// .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0)) // 2x + 1
246/// .collect();
247///
248/// assert_eq!(transformed.data(), &[3.0, 5.0, 7.0]);
249/// ```
250///
251/// ## Advanced Iterator Methods
252///
253/// ```
254/// use train_station::Tensor;
255///
256/// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
257///
258/// // Filter and transform
259/// let result: Tensor = tensor.iter()
260/// .filter(|elem| elem.value() > 2.0)
261/// .map(|elem| elem.mul_scalar(10.0))
262/// .collect();
263///
264/// assert_eq!(result.data(), &[30.0, 40.0, 50.0]);
265///
266/// // Reverse iteration
267/// let reversed: Tensor = tensor.iter().rev().collect();
268/// assert_eq!(reversed.data(), &[5.0, 4.0, 3.0, 2.0, 1.0]);
269/// ```
270// Re-export iterator types from submodules for public API
271// ===== IntoIterator Implementation =====
/// `IntoIterator` for `&Tensor`: iterates the outermost dimension, yielding
/// one sub-tensor view per index along dim 0 (Vec-like semantics).
///
/// Enables `for sub in &tensor { ... }` without consuming the tensor.
impl<'a> IntoIterator for &'a Tensor {
    type Item = Tensor;
    type IntoIter = crate::tensor::iterator::viewdim::TensorDimIterator<'a>;

    fn into_iter(self) -> Self::IntoIter {
        // Delegate to the dimension iterator over the outermost dimension (dim 0).
        self.iter_dim(0)
    }
}
282
/// `IntoIterator` for an owned `Tensor`: iterates the outermost dimension,
/// producing sub-tensors. Enables `.into_iter().flatten()` patterns on owned
/// tensors in iterator pipelines.
impl IntoIterator for Tensor {
    type Item = Tensor;
    type IntoIter = crate::tensor::iterator::viewdim::TensorDimOwnedIterator;

    fn into_iter(self) -> Self::IntoIter {
        // The owned iterator takes ownership of `self` and walks dim 0.
        crate::tensor::iterator::viewdim::TensorDimOwnedIterator::new(self, 0)
    }
}
293
294// ===== FromIterator Implementation =====
295
296impl FromIterator<Tensor> for Tensor {
297 /// Collect element view tensors back into a single tensor
298 ///
299 /// This method reconstructs a tensor from an iterator of element view tensors.
300 /// It includes optimizations for common patterns and maintains gradient tracking
301 /// when appropriate.
302 ///
303 /// The collection process automatically detects whether all elements are scalar
304 /// views (shape `[1]`) and uses optimized collection strategies accordingly.
305 /// Gradient tracking is preserved when any input element requires gradients.
306 ///
307 /// # Performance
308 ///
309 /// - **Optimized Collection**: Specialized paths for scalar and mixed views
310 /// - **Memory Efficient**: Direct memory copying without intermediate allocations
311 /// - **Gradient Preservation**: Maintains gradtrack functionality when enabled
312 /// - **Shape Detection**: Automatic detection of element shapes for optimization
313 ///
314 /// # Implementation Details
315 ///
316 /// The method performs the following steps:
317 /// 1. **Element Collection**: Gathers all element tensors from the iterator
318 /// 2. **Shape Analysis**: Determines if all elements are scalar views
319 /// 3. **Optimized Path**: Uses specialized collection for scalar views
320 /// 4. **General Path**: Handles mixed shapes by flattening into 1D tensor
321 /// 5. **Gradient Setup**: Preserves gradient tracking when appropriate
322 ///
323 /// # Examples
324 ///
325 /// ## Basic Collection
326 ///
327 /// ```
328 /// use train_station::Tensor;
329 ///
330 /// let original = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
331 /// let doubled: Tensor = original.iter()
332 /// .map(|elem| elem.mul_scalar(2.0))
333 /// .collect();
334 ///
335 /// assert_eq!(doubled.data(), &[2.0, 4.0, 6.0]);
336 /// ```
337 ///
338 /// ## Collection with Gradient Tracking
339 ///
340 /// ```
341 /// use train_station::Tensor;
342 ///
343 /// let original = Tensor::from_slice(&[1.0, 2.0], vec![2])
344 /// .unwrap()
345 /// .with_requires_grad();
346 ///
347 /// let result: Tensor = original.iter()
348 /// .map(|elem| elem.mul_scalar(2.0))
349 /// .collect();
350 ///
351 /// assert!(result.requires_grad());
352 /// assert_eq!(result.data(), &[2.0, 4.0]);
353 /// ```
354 ///
355 /// ## Empty Iterator Handling
356 ///
357 /// ```
358 /// use train_station::Tensor;
359 ///
360 /// let empty: Tensor = Vec::<Tensor>::new().into_iter().collect();
361 /// assert_eq!(empty.size(), 0);
362 /// assert_eq!(empty.shape().dims(), vec![0]);
363 /// ```
364 fn from_iter<I: IntoIterator<Item = Tensor>>(iter: I) -> Self {
365 let elements: Vec<Tensor> = iter.into_iter().collect();
366
367 if elements.is_empty() {
368 return Tensor::new(vec![0]);
369 }
370
371 // Check if all elements are scalars (size == 1). Supports both [1] and 0-D [] shapes
372 let all_scalars = elements.iter().all(|e| e.size() == 1);
373
374 if all_scalars {
375 // Optimized path for scalar element views
376 Self::collect_scalar_views(elements)
377 } else {
378 // General path for mixed shapes
379 Self::collect_mixed_views(elements)
380 }
381 }
382}
383
384impl Tensor {
385 /// Optimized collection for scalar element views
386 ///
387 /// This method efficiently reconstructs a tensor from scalar element views,
388 /// preserving gradient tracking and using optimized memory operations.
389 ///
390 /// This is the fast path for collection when all elements are scalar views
391 /// (shape `[1]`). It performs direct memory copying and sets up gradient
392 /// tracking when any input element requires gradients.
393 ///
394 /// # Arguments
395 ///
396 /// * `elements` - Vector of scalar element view tensors
397 ///
398 /// # Returns
399 ///
400 /// A new tensor containing all element values in a 1D layout
401 ///
402 /// # Performance
403 ///
404 /// - **Direct Memory Copy**: Single-pass copying without intermediate allocations
405 /// - **Gradient Optimization**: Efficient gradient tracking setup
406 /// - **Memory Efficient**: Minimal overhead for collection process
407 /// - **SIMD Compatible**: Result tensor supports all optimizations
408 ///
409 /// # Implementation Details
410 ///
411 /// The method performs the following steps:
412 /// 1. **Allocation**: Creates uninitialized tensor with correct size
413 /// 2. **Gradient Check**: Determines if any element requires gradients
414 /// 3. **Memory Copy**: Direct copying from element views to result
415 /// 4. **Gradient Setup**: Configures gradient tracking when needed
416 /// 5. **Operation Registration**: Registers with gradtrack engine
417 fn collect_scalar_views(elements: Vec<Tensor>) -> Self {
418 if elements.is_empty() {
419 return Tensor::new(vec![0]);
420 }
421 // Fast path: if no element requires grad or gradients are disabled, copy directly
422 let any_requires = elements.iter().any(|t| t.requires_grad());
423 if !any_requires || !is_grad_enabled() {
424 let n = elements.len();
425 let mut out = Tensor::new_uninitialized(vec![n]);
426 unsafe {
427 let dst = out.as_mut_ptr();
428 for (i, t) in elements.iter().enumerate() {
429 debug_assert_eq!(t.size(), 1);
430 std::ptr::copy_nonoverlapping(t.as_ptr(), dst.add(i), 1);
431 }
432 }
433 return out;
434 }
435
436 // Grad-preserving path: concat along dim 0 then flatten
437 let mut prepped: Vec<Tensor> = Vec::with_capacity(elements.len());
438 for t in elements.into_iter() {
439 if t.shape().rank() == 0 {
440 prepped.push(t.unsqueeze(0)); // [] -> [1]
441 } else {
442 prepped.push(t);
443 }
444 }
445 let concatenated = Tensor::cat(&prepped, 0); // shape: [N, 1]
446 concatenated.flatten() // shape: [N]
447 }
448
449 /// General collection for mixed element shapes
450 ///
451 /// This method handles collection when elements have different shapes,
452 /// flattening all elements into a 1D tensor.
453 ///
454 /// This is the general path for collection when elements have varying shapes.
455 /// It flattens all elements into a single 1D tensor and preserves gradient
456 /// tracking when any input element requires gradients.
457 ///
458 /// # Arguments
459 ///
460 /// * `elements` - Vector of element tensors with potentially different shapes
461 ///
462 /// # Returns
463 ///
464 /// A new 1D tensor containing all flattened element values
465 ///
466 /// # Performance
467 ///
468 /// - **Flattening**: Converts all elements to 1D layout
469 /// - **Memory Copy**: Efficient copying with size calculation
470 /// - **Gradient Preservation**: Maintains gradtrack functionality
471 /// - **Mixed Shapes**: Handles elements with different dimensions
472 ///
473 /// # Implementation Details
474 ///
475 /// The method performs the following steps:
476 /// 1. **Size Calculation**: Sums sizes of all elements for total size
477 /// 2. **Allocation**: Creates uninitialized tensor with total size
478 /// 3. **Sequential Copy**: Copies each element's data sequentially
479 /// 4. **Gradient Setup**: Configures gradient tracking when needed
480 /// 5. **Operation Registration**: Registers with gradtrack engine
481 fn collect_mixed_views(elements: Vec<Tensor>) -> Self {
482 let requires_grad = elements.iter().any(|e| e.requires_grad());
483 // Concatenate then flatten to preserve gradient connections
484 let concatenated = Tensor::cat(&elements, 0);
485 let flattened = concatenated.flatten();
486 if requires_grad && is_grad_enabled() {
487 // Flags are handled by ops; return as-is
488 }
489 flattened
490 }
491
492 // Iterator entry points are implemented in iterator/element.rs
493}
494
495// Redundant iterator type and collection trait/impls have been moved to dedicated files.
496
497#[cfg(test)]
498mod tests {
499 //! Comprehensive tests for tensor element iterator functionality
500 //!
501 //! These tests cover all aspects of the iterator implementation:
502 //! - Basic iteration functionality
503 //! - Standard library trait compliance
504 //! - Gradient tracking through element operations
505 //! - Performance characteristics
506 //! - Edge cases and error conditions
507
508 use super::*;
509
    /// Test basic iterator functionality: `iter_elements` yields one `[1]`-shaped
    /// scalar view per source element, in order.
    #[test]
    fn test_basic_iteration() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();

        let elements: Vec<Tensor> = tensor.iter_elements().collect();
        assert_eq!(elements.len(), 4);

        // Check that each element is a scalar tensor with correct value
        for (i, elem) in elements.iter().enumerate() {
            assert_eq!(elem.shape().dims(), vec![1]);
            assert_eq!(elem.size(), 1);
            // Source data is [1.0, 2.0, 3.0, 4.0], so element i holds i + 1.
            assert_eq!(elem.value(), (i + 1) as f32);
        }
    }
525
    /// Test core Iterator trait methods: next, size_hint, count, nth, and
    /// back-end access via next_back.
    #[test]
    fn test_iterator_trait_methods() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
        let mut iter = tensor.iter();

        // Test next()
        let _first = iter.next().unwrap();
        assert_eq!(_first.value(), 1.0);

        // Test size_hint() after consuming one element
        assert_eq!(iter.size_hint(), (4, Some(4)));

        // Test count()
        assert_eq!(iter.count(), 4);

        // Test nth()
        let mut iter = tensor.iter();
        let third = iter.nth(2).unwrap();
        assert_eq!(third.value(), 3.0);

        // Access the final element via next_back() on a fresh iterator
        let mut iter = tensor.iter();
        let last = iter.next_back().unwrap();
        assert_eq!(last.value(), 5.0);
    }
552
    /// Test ExactSizeIterator: `len()` reports the exact number of remaining
    /// elements and shrinks by one with every `next()` call.
    #[test]
    fn test_exact_size_iterator() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let iter = tensor.iter();

        assert_eq!(iter.len(), 3);

        // Test that len() decreases as we consume the iterator
        let mut iter = tensor.iter();
        assert_eq!(iter.len(), 3);
        iter.next();
        assert_eq!(iter.len(), 2);
        iter.next();
        assert_eq!(iter.len(), 1);
        iter.next();
        assert_eq!(iter.len(), 0);
    }
571
    /// Test DoubleEndedIterator: next_back, nth_back, and interleaved
    /// consumption from both ends until exhaustion.
    #[test]
    fn test_double_ended_iterator() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let mut iter = tensor.iter();

        // Test next_back()
        let last = iter.next_back().unwrap();
        assert_eq!(last.value(), 4.0);

        // The front end is independent: it still starts at the first element
        let first = iter.next().unwrap();
        assert_eq!(first.value(), 1.0);

        // Test nth_back(): nth_back(1) skips the last element and yields 3.0
        let mut iter = tensor.iter();
        let second_to_last = iter.nth_back(1).unwrap();
        assert_eq!(second_to_last.value(), 3.0);

        // Test consuming from both ends until the ends meet
        let mut iter = tensor.iter();
        assert_eq!(iter.next().unwrap().value(), 1.0);
        assert_eq!(iter.next_back().unwrap().value(), 4.0);
        assert_eq!(iter.next().unwrap().value(), 2.0);
        assert_eq!(iter.next_back().unwrap().value(), 3.0);
        assert!(iter.next().is_none());
    }
598
    /// Test IntoIterator trait: `&Tensor` works directly in `for` loops and
    /// with an explicit `into_iter()` call.
    #[test]
    fn test_into_iterator() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();

        // Test with for loop (exercises `IntoIterator for &Tensor`)
        let mut values = Vec::new();
        for element in &tensor {
            values.push(element.value());
        }
        assert_eq!(values, vec![1.0, 2.0, 3.0]);

        // Test with into_iter() explicitly
        let values: Vec<f32> = (&tensor).into_iter().map(|elem| elem.value()).collect();
        assert_eq!(values, vec![1.0, 2.0, 3.0]);
    }
615
    /// Test FromIterator (collect): round-trips element views back into a
    /// tensor, with and without per-element transformation.
    #[test]
    fn test_from_iterator() {
        let original = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();

        // Test collecting back to tensor (identity round-trip)
        let collected: Tensor = original.iter().collect();
        assert_eq!(collected.shape().dims(), vec![4]);
        assert_eq!(collected.data(), original.data());

        // Test collecting freshly-built [1] tensors (not views of the source)
        let doubled: Tensor = original
            .iter()
            .map(|elem| {
                let val = elem.value();
                Tensor::from_slice(&[val * 2.0], vec![1]).unwrap()
            })
            .collect();

        assert_eq!(doubled.data(), &[2.0, 4.0, 6.0, 8.0]);
    }
637
    /// Test standard library iterator adaptors and terminals (map, filter,
    /// enumerate, fold, find, any/all) over element views.
    #[test]
    fn test_std_iterator_methods() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();

        // Test map
        let doubled: Vec<f32> = tensor.iter().map(|elem| elem.value() * 2.0).collect();
        assert_eq!(doubled, vec![2.0, 4.0, 6.0, 8.0, 10.0]);

        // Test filter
        let large_values: Vec<f32> = tensor
            .iter()
            .filter(|elem| elem.value() > 3.0)
            .map(|elem| elem.value())
            .collect();
        assert_eq!(large_values, vec![4.0, 5.0]);

        // Test enumerate
        let with_indices: Vec<(usize, f32)> = tensor
            .iter()
            .enumerate()
            .map(|(i, elem)| (i, elem.value()))
            .collect();
        assert_eq!(
            with_indices,
            vec![(0, 1.0), (1, 2.0), (2, 3.0), (3, 4.0), (4, 5.0)]
        );

        // Test fold
        let sum: f32 = tensor.iter().fold(0.0, |acc, elem| acc + elem.value());
        assert_eq!(sum, 15.0);

        // Test find
        let found = tensor.iter().find(|elem| elem.value() == 3.0);
        assert!(found.is_some());
        assert_eq!(found.unwrap().value(), 3.0);

        // Test any/all
        let all_positive = tensor.iter().all(|elem| elem.value() > 0.0);
        assert!(all_positive);

        let any_large = tensor.iter().any(|elem| elem.value() > 4.0);
        assert!(any_large);
    }
682
    /// Test tensor operations on element views: scalar arithmetic and
    /// chained operations collected back into a tensor.
    #[test]
    fn test_element_tensor_operations() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();

        // Test scalar operations on elements
        let scaled: Tensor = tensor.iter().map(|elem| elem.mul_scalar(2.0)).collect();
        assert_eq!(scaled.data(), &[2.0, 4.0, 6.0]);

        let offset: Tensor = tensor.iter().map(|elem| elem.add_scalar(10.0)).collect();
        assert_eq!(offset.data(), &[11.0, 12.0, 13.0]);

        // Test chaining operations
        let complex: Tensor = tensor
            .iter()
            .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0)) // 2x + 1
            .collect();
        assert_eq!(complex.data(), &[3.0, 5.0, 7.0]);
    }
702
    /// Test that the `requires_grad` flag survives element-wise iteration
    /// followed by collection.
    #[test]
    fn test_gradient_tracking() {
        let tensor = Tensor::from_slice(&[1.0, 2.0], vec![2])
            .unwrap()
            .with_requires_grad();

        // Perform element-wise operations
        let result: Tensor = tensor.iter().map(|elem| elem.mul_scalar(2.0)).collect();

        // The result should require gradients if any element requires gradients
        // Note: Current implementation creates copies, so gradient tracking is
        // implemented but may not propagate back to original tensor
        assert!(result.requires_grad());

        // For now, just verify the forward pass works with gradient-enabled tensors
        // Full gradient propagation would require true view implementation
        assert_eq!(result.data(), &[2.0, 4.0]);
    }
722
    /// Test iteration over a zero-sized tensor: empty iterator, zero length,
    /// and collection back into an empty tensor.
    #[test]
    fn test_zero_sized_tensor() {
        let empty = Tensor::new(vec![0]);
        let iter = empty.iter();

        assert_eq!(iter.len(), 0);
        assert_eq!(iter.size_hint(), (0, Some(0)));

        // Collecting an exhausted/empty iterator yields a size-0 tensor
        let collected: Tensor = iter.collect();
        assert_eq!(collected.size(), 0);
    }
735
    /// Test `iter_range`: half-open [start, end) semantics, clamping of
    /// out-of-bounds ends, and empty ranges.
    #[test]
    fn test_range_iteration() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();

        // Test middle range: indices 1..4 yield the 2nd through 4th values
        let middle: Vec<f32> = tensor.iter_range(1, 4).map(|elem| elem.value()).collect();
        assert_eq!(middle, vec![2.0, 3.0, 4.0]);

        // Test out of bounds (should be clamped to the tensor length)
        let clamped: Vec<f32> = tensor.iter_range(3, 10).map(|elem| elem.value()).collect();
        assert_eq!(clamped, vec![4.0, 5.0]);

        // Test empty range (start == end yields nothing)
        let empty: Vec<f32> = tensor.iter_range(2, 2).map(|elem| elem.value()).collect();
        assert_eq!(empty, Vec::<f32>::new());
    }
753
    /// Test longer iterator chains: enumerate → filter → map → collect, and
    /// rev() combined with take().
    #[test]
    fn test_complex_chains() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();

        // Complex chain: enumerate -> filter -> map -> collect
        let result: Tensor = tensor
            .iter()
            .enumerate()
            .filter(|(i, _)| i % 2 == 0) // Take even indices
            .map(|(i, elem)| elem.add_scalar(i as f32)) // Add index to value
            .collect();

        // Should have elements [1.0 + 0, 3.0 + 2, 5.0 + 4] = [1.0, 5.0, 9.0]
        assert_eq!(result.data(), &[1.0, 5.0, 9.0]);

        // Test with rev(): last three elements in reverse order
        let reversed: Tensor = tensor.iter().rev().take(3).collect();

        assert_eq!(reversed.data(), &[6.0, 5.0, 4.0]);
    }
775
    /// Smoke-test iterator overhead on a 1000-element tensor; asserts only
    /// correctness of the result, timing is printed for manual inspection.
    #[test]
    fn test_performance() {
        let large_tensor =
            Tensor::from_slice(&(0..1000).map(|i| i as f32).collect::<Vec<_>>(), vec![1000])
                .unwrap();

        let start = std::time::Instant::now();

        let result: Tensor = large_tensor
            .iter()
            .map(|elem| elem.mul_scalar(2.0))
            .collect();

        let duration = start.elapsed();
        println!("Iterator performance test took: {:?}", duration);

        // Verify correctness: element i becomes 2 * i
        assert_eq!(result.size(), 1000);
        assert_eq!(result.data()[0], 0.0);
        assert_eq!(result.data()[999], 1998.0);
    }
798
    /// Test chunks iterator: full chunks of the requested size plus a shorter
    /// trailing chunk when the length is not divisible.
    #[test]
    fn test_chunks_basic() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
        let chunks: Vec<Tensor> = t.chunks(2).collect();
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].data(), &[1.0, 2.0]);
        assert_eq!(chunks[1].data(), &[3.0, 4.0]);
        // 5 elements / chunk size 2 leaves a final chunk of one element
        assert_eq!(chunks[2].data(), &[5.0]);
    }
809
    /// Test chunks_exact: only full-size chunks are yielded and the leftover
    /// tail is available via `remainder()`.
    #[test]
    fn test_chunks_exact_with_remainder() {
        let t = Tensor::from_slice(&[10.0, 20.0, 30.0, 40.0, 50.0], vec![5]).unwrap();
        let mut it = t.chunks_exact(2);
        let v0 = it.next().unwrap();
        let v1 = it.next().unwrap();
        // The trailing single element is not yielded as a chunk
        assert!(it.next().is_none());
        assert_eq!(v0.data(), &[10.0, 20.0]);
        assert_eq!(v1.data(), &[30.0, 40.0]);
        let r = it.remainder();
        assert_eq!(r.data(), &[50.0]);
    }
823
    /// Test windows iterator with the default step of 1: overlapping views of
    /// the requested width.
    #[test]
    fn test_windows_basic() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let wins: Vec<Tensor> = t.windows(3).collect();
        // 4 elements, width 3 => 2 overlapping windows
        assert_eq!(wins.len(), 2);
        assert_eq!(wins[0].data(), &[1.0, 2.0, 3.0]);
        assert_eq!(wins[1].data(), &[2.0, 3.0, 4.0]);
    }
833
    /// Test windows_step: non-overlapping windows when step equals width;
    /// an incomplete trailing window is dropped.
    #[test]
    fn test_windows_step() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
        let wins: Vec<Tensor> = t.windows_step(2, 2).collect();
        // 5 elements, width 2, step 2 => windows at offsets 0 and 2 only
        assert_eq!(wins.len(), 2);
        assert_eq!(wins[0].data(), &[1.0, 2.0]);
        assert_eq!(wins[1].data(), &[3.0, 4.0]);
    }
843
    /// Test collect_shape utility: chunked views are reassembled into a tensor
    /// of the requested shape with data in row-major order.
    #[test]
    fn test_collect_shape() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
        let mat = t.chunks(2).collect_shape(vec![3, 2]);
        assert_eq!(mat.shape().dims(), &[3, 2]);
        assert_eq!(mat.data(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }
852
853 /// Performance comparison: Tensor iterator/view system vs Vec iteration
854 ///
855 /// This test compares end-to-end pipelines (creation → iteration → ops → collection)
856 /// across multiple sizes and loop styles, and prints a concise summary.
857 #[test]
858 fn test_iterator_vs_vec_performance_summary() {
859 use std::time::Instant;
860
861 let sizes: [usize; 3] = [100, 1000, 10_000];
862 let iterations: usize = 3; // exclude 1 warmup
863 let chunk_size: usize = 8192;
864
865 println!(
866 "Iterator/View vs Vec performance ({} runs avg, chunk_size={})",
867 iterations, chunk_size
868 );
869
870 for &n in &sizes {
871 // -------- Element-wise iterator pipeline (Tensor) --------
872 let mut total_elem_tensor = std::time::Duration::ZERO;
873 for run in 0..(iterations + 1) {
874 let t0 = Instant::now();
875 let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
876 let t = Tensor::from_slice(&data, vec![n]).unwrap();
877 let out: Tensor = t
878 .iter_elements()
879 .map(|e| e.mul_scalar(2.0).add_scalar(1.0))
880 .collect();
881 // Touch a value to avoid any dead-code elimination concerns
882 let _ = out.get(&[0]);
883 let dt = t0.elapsed();
884 if run > 0 {
885 total_elem_tensor += dt;
886 }
887 }
888 let avg_elem_tensor = total_elem_tensor / iterations as u32;
889
890 // -------- Element-wise iterator pipeline (Vec) --------
891 let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
892 let mut total_elem_vec = std::time::Duration::ZERO;
893 for run in 0..(iterations + 1) {
894 let t0 = Instant::now();
895 let _v_out: Vec<f32> = data.iter().map(|&x| 2.0 * x + 1.0).collect();
896 let dt = t0.elapsed();
897 if run > 0 {
898 total_elem_vec += dt;
899 }
900 }
901 let avg_elem_vec = total_elem_vec / iterations as u32;
902
903 // -------- Chunked iterator pipeline (Tensor) --------
904 let mut total_chunks_tensor = std::time::Duration::ZERO;
905 for run in 0..(iterations + 1) {
906 let t0 = Instant::now();
907 let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
908 let t = Tensor::from_slice(&data, vec![n]).unwrap();
909 let parts: Vec<Tensor> = t
910 .chunks(chunk_size)
911 .map(|c| c.mul_scalar(2.0).add_scalar(1.0))
912 .collect();
913 let out = Tensor::cat(&parts, 0);
914 let _ = out.get(&[out.size().saturating_sub(1)]);
915 let dt = t0.elapsed();
916 if run > 0 {
917 total_chunks_tensor += dt;
918 }
919 }
920 let avg_chunks_tensor = total_chunks_tensor / iterations as u32;
921
922 // -------- Chunked iterator pipeline (Vec) --------
923 let mut total_chunks_vec = std::time::Duration::ZERO;
924 for run in 0..(iterations + 1) {
925 let t0 = Instant::now();
926 let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
927 let mut out: Vec<f32> = Vec::with_capacity(n);
928 for chunk in data.chunks(chunk_size) {
929 for &x in chunk.iter() {
930 out.push(2.0 * x + 1.0);
931 }
932 }
933 let _ = out.get(out.len().saturating_sub(1)).copied().unwrap_or(0.0);
934 let dt = t0.elapsed();
935 if run > 0 {
936 total_chunks_vec += dt;
937 }
938 }
939 let avg_chunks_vec = total_chunks_vec / iterations as u32;
940
941 // -------- Value iterator (Tensor) --------
942 let mut total_values_tensor = std::time::Duration::ZERO;
943 let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
944 let t = Tensor::from_slice(&data, vec![n]).unwrap();
945
946 for run in 0..(iterations + 1) {
947 let t0 = Instant::now();
948 let _v_out: Tensor = t.iter().map(|e| 2.0 * e.value() + 1.0).collect();
949 let dt = t0.elapsed();
950 if run > 0 {
951 total_values_tensor += dt;
952 }
953 }
954 let avg_values_tensor = total_values_tensor / iterations as u32;
955
956 // -------- Mutable value iterator (Tensor) --------
957 let mut total_values_mut_tensor = std::time::Duration::ZERO;
958 for run in 0..(iterations + 1) {
959 let t0 = Instant::now();
960 let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
961 let mut out = Tensor::from_slice(&data, vec![n]).unwrap();
962 {
963 let d = out.data_mut();
964 for v in d.iter_mut() {
965 *v = 2.0 * *v + 1.0;
966 }
967 }
968 let _ = out.get(&[out.size().saturating_sub(1)]);
969 let dt = t0.elapsed();
970 if run > 0 {
971 total_values_mut_tensor += dt;
972 }
973 }
974 let avg_values_mut_tensor = total_values_mut_tensor / iterations as u32;
975
976 // -------- Summary per size --------
977 let s_elem = avg_elem_vec.as_secs_f64() / avg_elem_tensor.as_secs_f64();
978 let s_chunks = avg_chunks_vec.as_secs_f64() / avg_chunks_tensor.as_secs_f64();
979 let s_values = avg_elem_vec.as_secs_f64() / avg_values_tensor.as_secs_f64();
980 let s_values_mut = avg_elem_vec.as_secs_f64() / avg_values_mut_tensor.as_secs_f64();
981
982 println!(
983 "\n[Size: {:>9} elements]\n - Tensor (element): {:>8.3} ms\n - Vec (element): {:>8.3} ms\n Speedup (Tensor/Vec): {:>6.2}x\n - Tensor (chunks): {:>8.3} ms\n - Vec (chunks): {:>8.3} ms\n Speedup (Tensor/Vec): {:>6.2}x\n - Tensor (values): {:>8.3} ms\n Speedup (values vs Vec element): {:>6.2}x\n - Tensor (values_mut): {:>8.3} ms\n Speedup (values_mut vs Vec element): {:>6.2}x",
984 n,
985 avg_elem_tensor.as_secs_f64() * 1e3,
986 avg_elem_vec.as_secs_f64() * 1e3,
987 s_elem,
988 avg_chunks_tensor.as_secs_f64() * 1e3,
989 avg_chunks_vec.as_secs_f64() * 1e3,
990 s_chunks,
991 avg_values_tensor.as_secs_f64() * 1e3,
992 s_values,
993 avg_values_mut_tensor.as_secs_f64() * 1e3,
994 s_values_mut,
995 );
996 }
997
998 println!("\nNote: timings include creation, iteration, ops (2x+1), and collection.");
999 }
1000
1001 /// Replacement for previous values-only iteration: derive values via scalar views
1002 #[test]
1003 fn test_values_via_views() {
1004 let t =
1005 Tensor::from_slice(&(0..16).map(|i| i as f32).collect::<Vec<_>>(), vec![16]).unwrap();
1006 let vals: Vec<f32> = t.iter().map(|e| e.value()).collect();
1007 assert_eq!(vals, (0..16).map(|i| i as f32).collect::<Vec<_>>());
1008 }
1009
1010 /// Replacement for previous iter_values_mut: mutate via data_mut
1011 #[test]
1012 fn test_mutation_via_data_mut() {
1013 let mut t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
1014 {
1015 let d = t.data_mut();
1016 for v in d.iter_mut() {
1017 *v += 1.0;
1018 }
1019 }
1020 assert_eq!(t.data(), &[2.0, 3.0, 4.0, 5.0]);
1021 }
1022}