train_station/tensor/iterator/mod.rs
//! Iterator module for tensor element-wise operations
//!
//! This module provides high-performance iterators over tensor elements, where each
//! element is represented as a view tensor of shape `[1]`. This design allows for
//! seamless integration with Rust's standard library iterator methods while
//! leveraging the existing tensor operation framework and gradient tracking.
//!
//! # Key Features
//!
//! - **Standard Library Compatibility**: Full implementation of the `Iterator`,
//!   `ExactSizeIterator`, `DoubleEndedIterator`, `FusedIterator`, `IntoIterator`,
//!   and `FromIterator` traits
//! - **Gradient Tracking**: Automatic gradient propagation through element operations
//! - **Zero-Copy Views**: Element views share memory with the source tensor; no data is copied
//! - **SIMD Compatible**: All operations use existing optimized tensor implementations
//! - **Memory Efficient**: ~64 bytes of view metadata per element; no per-element data copies
//! - **Full Tensor Operations**: Each element supports all tensor methods
//!
//! # Performance Characteristics
//!
//! - **View Creation**: O(1) per element with true zero-copy views
//! - **Memory Overhead**: ~64 bytes per view tensor (no data copying)
//! - **SIMD Operations**: Full utilization of existing optimizations
//! - **Gradient Tracking**: True gradient flow with element-level accumulation
//! - **Iterator Overhead**: Minimal performance impact for element access
//! - **Collection Optimization**: Efficient reconstruction from element views
//!
//! # Examples
//!
//! ## Basic Element Iteration
//!
//! ```
//! use train_station::Tensor;
//!
//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
//!
//! // Basic iteration over elements
//! for element in tensor.iter() {
//!     println!("Element value: {}", element.value());
//! }
//!
//! // Collect elements into a new tensor
//! let collected: Tensor = tensor.iter().collect();
//! assert_eq!(collected.data(), tensor.data());
//! ```
//!
//! ## Element-Wise Transformations
//!
//! ```
//! use train_station::Tensor;
//!
//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
//!
//! // Apply tensor operations to each element
//! let doubled: Tensor = tensor.iter()
//!     .map(|elem| elem.mul_scalar(2.0))
//!     .collect();
//!
//! assert_eq!(doubled.data(), &[2.0, 4.0, 6.0]);
//!
//! // Chain multiple operations
//! let transformed: Tensor = tensor.iter()
//!     .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0)) // 2x + 1
//!     .collect();
//!
//! assert_eq!(transformed.data(), &[3.0, 5.0, 7.0]);
//! ```
//!
//! ## Advanced Iterator Operations
//!
//! ```
//! use train_station::Tensor;
//!
//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
//!
//! // Filter elements based on values
//! let large_values: Tensor = tensor.iter()
//!     .filter(|elem| elem.value() > 3.0)
//!     .collect();
//!
//! assert_eq!(large_values.data(), &[4.0, 5.0]);
//!
//! // Use enumerate for indexed operations
//! let indexed: Tensor = tensor.iter()
//!     .enumerate()
//!     .map(|(i, elem)| elem.add_scalar(i as f32))
//!     .collect();
//!
//! assert_eq!(indexed.data(), &[1.0, 3.0, 5.0, 7.0, 9.0]);
//! ```
//!
//! ## Range Iteration
//!
//! ```
//! use train_station::Tensor;
//!
//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
//!
//! // Iterate over a specific range
//! let middle: Tensor = tensor.iter_range(1, 4)
//!     .map(|elem| elem.mul_scalar(2.0))
//!     .collect();
//!
//! assert_eq!(middle.data(), &[4.0, 6.0, 8.0]);
//! ```
//!
//! ## Double-Ended Iteration
//!
//! ```
//! use train_station::Tensor;
//!
//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
//!
//! // Reverse iteration
//! let reversed: Tensor = tensor.iter().rev().collect();
//! assert_eq!(reversed.data(), &[4.0, 3.0, 2.0, 1.0]);
//!
//! // Iterate from both ends
//! let mut iter = tensor.iter();
//! assert_eq!(iter.next().unwrap().value(), 1.0);
//! assert_eq!(iter.next_back().unwrap().value(), 4.0);
//! ```
//!
//! ## Gradient Tracking
//!
//! ```
//! use train_station::Tensor;
//!
//! let tensor = Tensor::from_slice(&[1.0, 2.0], vec![2])
//!     .unwrap()
//!     .with_requires_grad();
//!
//! // Element operations maintain gradient tracking
//! let result: Tensor = tensor.iter()
//!     .map(|elem| elem.mul_scalar(2.0))
//!     .collect();
//!
//! assert!(result.requires_grad());
//! assert_eq!(result.data(), &[2.0, 4.0]);
//! ```
//!
//! # Design Principles
//!
//! - **Zero-Copy Views**: Element views share memory with the source tensor
//! - **Full Tensor Operations**: Each element supports all tensor methods
//! - **Standard Library Integration**: Complete compatibility with Rust iterators
//! - **Performance First**: Optimized for high-performance element access
//! - **Gradient Preservation**: Maintains gradtrack functionality through operations
//! - **Memory Efficiency**: Minimal overhead for element iteration
//! - **Type Safety**: Compile-time guarantees for iterator operations

use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
use crate::tensor::core::Tensor;
use std::iter::{FromIterator, FusedIterator};

/// High-performance iterator over tensor elements as view tensors
///
/// Each element becomes a proper `Tensor` view of shape `[1]` that can use
/// all existing tensor operations and gradient tracking. Implements all
/// standard iterator traits for maximum compatibility with Rust's ecosystem.
///
/// This iterator provides zero-copy access to tensor elements through view
/// tensors, enabling efficient element-wise operations while maintaining
/// full compatibility with Rust's standard library iterator methods.
///
/// # Performance
///
/// - **Zero-Copy Views**: Each element is a view tensor sharing memory with the source
/// - **O(1) Element Access**: Constant-time view creation for each element
/// - **Memory Efficient**: ~64 bytes overhead per element view
/// - **SIMD Compatible**: All tensor operations use existing optimizations
/// - **Gradient Tracking**: Full gradtrack support through element operations
///
/// # Implementation Details
///
/// The iterator creates lightweight view tensors on demand, sharing the same
/// memory allocation as the source tensor. This ensures zero-copy semantics
/// while maintaining full tensor operation compatibility.
///
/// Each element view is created using `Tensor::element_view()`, which provides
/// a true view of the underlying data without any copying. The view tensors
/// support all standard tensor operations including gradient tracking.
///
/// # Standard Library Compatibility
///
/// This iterator implements all standard iterator traits:
///
/// - `Iterator`: Basic iteration with `next()` and `size_hint()`
/// - `ExactSizeIterator`: Precise size information with `len()`
/// - `DoubleEndedIterator`: Reverse iteration with `next_back()`
/// - `FusedIterator`: Guarantees `None` after exhaustion, enabling downstream optimizations
/// - `IntoIterator`: Automatic conversion for `for` loops
///
/// # Examples
///
/// ## Basic Iteration
///
/// ```
/// use train_station::Tensor;
///
/// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
///
/// // Basic iteration
/// for element in tensor.iter() {
///     println!("Element value: {}", element.value());
/// }
///
/// // Standard library methods
/// let sum: f32 = tensor.iter()
///     .map(|elem| elem.value())
///     .sum();
///
/// assert_eq!(sum, 6.0);
/// ```
///
/// ## Element Operations
///
/// ```
/// use train_station::Tensor;
///
/// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
///
/// // Tensor operations on elements
/// let transformed: Tensor = tensor.iter()
///     .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0)) // 2x + 1
///     .collect();
///
/// assert_eq!(transformed.data(), &[3.0, 5.0, 7.0]);
/// ```
///
/// ## Advanced Iterator Methods
///
/// ```
/// use train_station::Tensor;
///
/// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
///
/// // Filter and transform
/// let result: Tensor = tensor.iter()
///     .filter(|elem| elem.value() > 2.0)
///     .map(|elem| elem.mul_scalar(10.0))
///     .collect();
///
/// assert_eq!(result.data(), &[30.0, 40.0, 50.0]);
///
/// // Reverse iteration
/// let reversed: Tensor = tensor.iter().rev().collect();
/// assert_eq!(reversed.data(), &[5.0, 4.0, 3.0, 2.0, 1.0]);
/// ```
pub struct TensorElementIterator<'a> {
    /// Reference to the source tensor
    source: &'a Tensor,
    /// Current position in iteration
    position: usize,
    /// End position (exclusive)
    end: usize,
}

impl<'a> TensorElementIterator<'a> {
    /// Create a new iterator over all tensor elements
    ///
    /// Creates an iterator that yields view tensors for each element in the
    /// source tensor. Each element becomes a `Tensor` of shape `[1]` that
    /// supports all tensor operations and gradient tracking.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The source tensor to iterate over
    ///
    /// # Returns
    ///
    /// An iterator that yields view tensors for each element
    ///
    /// # Performance
    ///
    /// - **O(1) Creation**: Constant-time iterator initialization
    /// - **Zero-Copy Views**: Each element is a view sharing memory with the source
    /// - **Memory Efficient**: Minimal overhead for iterator state
    ///
    /// # Implementation Details
    ///
    /// The iterator holds only a reference to the source tensor plus a pair of
    /// cursor indices; element views are created lazily on each call to `next`.
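    ///
    /// # Examples
    ///
    /// A minimal sketch via the public `Tensor::iter` entry point, which
    /// delegates to this constructor:
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
    /// let values: Vec<f32> = tensor.iter().map(|elem| elem.value()).collect();
    /// assert_eq!(values, vec![1.0, 2.0, 3.0]);
    /// ```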
    #[track_caller]
    pub fn new(tensor: &'a Tensor) -> Self {
        Self {
            source: tensor,
            position: 0,
            end: tensor.size(),
        }
    }

    /// Create an iterator over a specific range of elements
    ///
    /// Creates an iterator that yields view tensors for elements in the specified
    /// range. The range is automatically clamped to valid tensor bounds for safety.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The source tensor to iterate over
    /// * `start` - Starting index (inclusive)
    /// * `end` - Ending index (exclusive)
    ///
    /// # Returns
    ///
    /// An iterator that yields view tensors for elements in the specified range
    ///
    /// # Safety
    ///
    /// The range is automatically clamped to valid tensor bounds:
    ///
    /// - `end` is clamped to `[0, tensor.size()]`
    /// - `start` is then clamped to `[0, end]`
    /// - Empty ranges (`start >= end`) are handled gracefully
    ///
    /// # Performance
    ///
    /// - **O(1) Creation**: Constant-time iterator initialization
    /// - **Bounds Checking**: Automatic range validation and clamping
    /// - **Zero-Copy Views**: Each element is a view sharing memory with the source
    ///
    /// # Implementation Details
    ///
    /// Out-of-bounds ranges are clamped rather than panicking, so callers may
    /// pass arbitrary indices safely. The iterator then behaves exactly like
    /// one produced by [`TensorElementIterator::new`], restricted to `[start, end)`.
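    ///
    /// # Examples
    ///
    /// A minimal sketch via the public `Tensor::iter_range` entry point, which
    /// delegates to this constructor:
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
    /// let values: Vec<f32> = tensor.iter_range(1, 3).map(|elem| elem.value()).collect();
    /// assert_eq!(values, vec![2.0, 3.0]);
    /// ```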
    #[track_caller]
    pub fn with_range(tensor: &'a Tensor, start: usize, end: usize) -> Self {
        let end = end.min(tensor.size());
        let start = start.min(end);
        Self {
            source: tensor,
            position: start,
            end,
        }
    }

    /// Create an optimized element view for the given position
    ///
    /// This method creates a true view tensor of shape `[1]` that shares memory
    /// with the element at the specified index in the source tensor. The view
    /// enables zero-copy element access with full gradient tracking.
    ///
    /// # Arguments
    ///
    /// * `index` - Index of the element to create a view for
    ///
    /// # Returns
    ///
    /// A view tensor of shape `[1]` representing the element at the specified index
    ///
    /// # Safety
    ///
    /// The caller must ensure that `index < self.source.size()`.
    ///
    /// # Performance
    ///
    /// - **O(1) View Creation**: Constant-time view tensor creation
    /// - **Zero-Copy**: View shares memory with source tensor
    /// - **Memory Efficient**: ~64 bytes overhead for view metadata
    /// - **Gradient Tracking**: Full gradtrack support through view operations
    ///
    /// # Implementation Details
    ///
    /// This method delegates to `Tensor::element_view()`, which creates a true
    /// view of the underlying data without any copying. The view tensor supports
    /// all standard tensor operations including gradient tracking and SIMD
    /// optimizations.
    fn create_element_view(&self, index: usize) -> Tensor {
        debug_assert!(index < self.source.size());

        self.source.element_view(index)
    }
}

// ===== Core Iterator Implementation =====

impl<'a> Iterator for TensorElementIterator<'a> {
    type Item = Tensor;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        if self.position < self.end {
            let view = self.create_element_view(self.position);
            self.position += 1;
            Some(view)
        } else {
            None
        }
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.end - self.position;
        (remaining, Some(remaining))
    }

    #[inline]
    fn count(self) -> usize {
        self.end - self.position
    }

    #[inline]
    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        let new_pos = self.position.saturating_add(n);
        if new_pos < self.end {
            self.position = new_pos + 1;
            Some(self.create_element_view(new_pos))
        } else {
            self.position = self.end;
            None
        }
    }

    #[inline]
    fn last(self) -> Option<Self::Item> {
        if self.position < self.end {
            let last_idx = self.end - 1;
            Some(self.create_element_view(last_idx))
        } else {
            None
        }
    }
}

impl<'a> ExactSizeIterator for TensorElementIterator<'a> {
    #[inline]
    fn len(&self) -> usize {
        self.end - self.position
    }
}

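/// `next` returns `None` forever once `position == end`, satisfying the
/// `FusedIterator` contract without any extra state.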
impl<'a> FusedIterator for TensorElementIterator<'a> {}

impl<'a> DoubleEndedIterator for TensorElementIterator<'a> {
    #[inline]
    fn next_back(&mut self) -> Option<Self::Item> {
        if self.position < self.end {
            self.end -= 1;
            Some(self.create_element_view(self.end))
        } else {
            None
        }
    }

    #[inline]
    fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
        // Bounds-check against the remaining length so that skipping more
        // elements than remain yields `None` instead of wrapping past the
        // front of the range.
        let remaining = self.end - self.position;
        if n < remaining {
            self.end -= n + 1;
            Some(self.create_element_view(self.end))
        } else {
            self.position = self.end;
            None
        }
    }
}

// ===== IntoIterator Implementation =====

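/// Enables direct iteration over `&Tensor` in `for` loops by delegating to
/// [`TensorElementIterator::new`]; each yielded item is a shape-`[1]` view tensor.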
impl<'a> IntoIterator for &'a Tensor {
    type Item = Tensor;
    type IntoIter = TensorElementIterator<'a>;

    fn into_iter(self) -> Self::IntoIter {
        TensorElementIterator::new(self)
    }
}

// ===== FromIterator Implementation =====

impl FromIterator<Tensor> for Tensor {
    /// Collect element view tensors back into a single tensor
    ///
    /// This method reconstructs a tensor from an iterator of element view tensors.
    /// It includes optimizations for common patterns and maintains gradient tracking
    /// when appropriate.
    ///
    /// The collection process automatically detects whether all elements are scalar
    /// views (shape `[1]`) and uses optimized collection strategies accordingly.
    /// Gradient tracking is preserved when any input element requires gradients.
    ///
    /// # Performance
    ///
    /// - **Optimized Collection**: Specialized paths for scalar and mixed views
    /// - **Memory Efficient**: Direct memory copying without intermediate allocations
    /// - **Gradient Preservation**: Maintains gradtrack functionality when enabled
    /// - **Shape Detection**: Automatic detection of element shapes for optimization
    ///
    /// # Implementation Details
    ///
    /// The method performs the following steps:
    ///
    /// 1. **Element Collection**: Gathers all element tensors from the iterator
    /// 2. **Shape Analysis**: Determines if all elements are scalar views
    /// 3. **Optimized Path**: Uses specialized collection for scalar views
    /// 4. **General Path**: Handles mixed shapes by flattening into a 1D tensor
    /// 5. **Gradient Setup**: Preserves gradient tracking when appropriate
    ///
    /// # Examples
    ///
    /// ## Basic Collection
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let original = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
    /// let doubled: Tensor = original.iter()
    ///     .map(|elem| elem.mul_scalar(2.0))
    ///     .collect();
    ///
    /// assert_eq!(doubled.data(), &[2.0, 4.0, 6.0]);
    /// ```
    ///
    /// ## Collection with Gradient Tracking
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let original = Tensor::from_slice(&[1.0, 2.0], vec![2])
    ///     .unwrap()
    ///     .with_requires_grad();
    ///
    /// let result: Tensor = original.iter()
    ///     .map(|elem| elem.mul_scalar(2.0))
    ///     .collect();
    ///
    /// assert!(result.requires_grad());
    /// assert_eq!(result.data(), &[2.0, 4.0]);
    /// ```
    ///
    /// ## Empty Iterator Handling
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let empty: Tensor = Vec::<Tensor>::new().into_iter().collect();
    /// assert_eq!(empty.size(), 0);
    /// assert_eq!(empty.shape().dims, vec![0]);
    /// ```
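    ///
    /// ## Mixed-Shape Collection
    ///
    /// A sketch of the general (flattening) path; the element tensors here are
    /// ordinary owned tensors rather than iterator-produced views:
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
    /// let b = Tensor::from_slice(&[3.0], vec![1]).unwrap();
    ///
    /// let flat: Tensor = vec![a, b].into_iter().collect();
    /// assert_eq!(flat.shape().dims, vec![3]);
    /// assert_eq!(flat.data(), &[1.0, 2.0, 3.0]);
    /// ```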
    fn from_iter<I: IntoIterator<Item = Tensor>>(iter: I) -> Self {
        let elements: Vec<Tensor> = iter.into_iter().collect();

        if elements.is_empty() {
            return Tensor::new(vec![0]);
        }

        // Check if all elements are scalar views (shape [1])
        let all_scalars = elements.iter().all(|e| e.shape().dims == vec![1]);

        if all_scalars {
            // Optimized path for scalar element views
            Self::collect_scalar_views(elements)
        } else {
            // General path for mixed shapes
            Self::collect_mixed_views(elements)
        }
    }
}

impl Tensor {
    /// Optimized collection for scalar element views
    ///
    /// This method efficiently reconstructs a tensor from scalar element views,
    /// preserving gradient tracking and using optimized memory operations.
    ///
    /// This is the fast path for collection when all elements are scalar views
    /// (shape `[1]`). It performs direct memory copying and sets up gradient
    /// tracking when any input element requires gradients.
    ///
    /// # Arguments
    ///
    /// * `elements` - Vector of scalar element view tensors
    ///
    /// # Returns
    ///
    /// A new tensor containing all element values in a 1D layout
    ///
    /// # Performance
    ///
    /// - **Direct Memory Copy**: Single-pass copying without intermediate allocations
    /// - **Gradient Optimization**: Efficient gradient tracking setup
    /// - **Memory Efficient**: Minimal overhead for collection process
    /// - **SIMD Compatible**: Result tensor supports all optimizations
    ///
    /// # Implementation Details
    ///
    /// The method performs the following steps:
    ///
    /// 1. **Allocation**: Creates an uninitialized tensor with the correct size
    /// 2. **Gradient Check**: Determines if any element requires gradients
    /// 3. **Memory Copy**: Direct copying from element views to result
    /// 4. **Gradient Setup**: Configures gradient tracking when needed
    /// 5. **Operation Registration**: Registers with the gradtrack engine
    fn collect_scalar_views(elements: Vec<Tensor>) -> Self {
        let len = elements.len();
        let mut result = Self::new_uninitialized(vec![len]);

        // Determine if we can track gradients
        let requires_grad = elements.iter().any(|e| e.requires_grad());

        // Copy data from element views
        unsafe {
            let dst = result.as_mut_ptr();
            for (i, element) in elements.iter().enumerate() {
                *dst.add(i) = *element.as_ptr();
            }
        }

        // Set up gradient tracking
        if requires_grad && is_grad_enabled() {
            result.set_requires_grad_internal(true);
            let element_ids: Vec<usize> = elements.iter().map(|e| e.id()).collect();
            let grad_fn = GradFn::ElementCollection {
                element_ids: element_ids.clone(),
                result_shape: vec![len],
            };
            result.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(result.id(), element_ids, grad_fn);
        }

        result
    }

    /// General collection for mixed element shapes
    ///
    /// This method handles collection when elements have different shapes,
    /// flattening all elements into a 1D tensor.
    ///
    /// This is the general path for collection when elements have varying shapes.
    /// It flattens all elements into a single 1D tensor and preserves gradient
    /// tracking when any input element requires gradients.
    ///
    /// # Arguments
    ///
    /// * `elements` - Vector of element tensors with potentially different shapes
    ///
    /// # Returns
    ///
    /// A new 1D tensor containing all flattened element values
    ///
    /// # Performance
    ///
    /// - **Flattening**: Converts all elements to a 1D layout
    /// - **Memory Copy**: Efficient copying with size calculation
    /// - **Gradient Preservation**: Maintains gradtrack functionality
    /// - **Mixed Shapes**: Handles elements with different dimensions
    ///
    /// # Implementation Details
    ///
    /// The method performs the following steps:
    ///
    /// 1. **Size Calculation**: Sums the sizes of all elements for the total size
    /// 2. **Allocation**: Creates an uninitialized tensor with the total size
    /// 3. **Sequential Copy**: Copies each element's data sequentially
    /// 4. **Gradient Setup**: Configures gradient tracking when needed
    /// 5. **Operation Registration**: Registers with the gradtrack engine
    fn collect_mixed_views(elements: Vec<Tensor>) -> Self {
        // For mixed shapes, flatten all elements into a 1D tensor
        let total_size: usize = elements.iter().map(|e| e.size()).sum();
        let mut result = Self::new_uninitialized(vec![total_size]);

        let requires_grad = elements.iter().any(|e| e.requires_grad());
        let mut offset = 0;

        unsafe {
            let dst = result.as_mut_ptr();
            for element in &elements {
                let src = element.as_ptr();
                let size = element.size();
                std::ptr::copy_nonoverlapping(src, dst.add(offset), size);
                offset += size;
            }
        }

        if requires_grad && is_grad_enabled() {
            result.set_requires_grad_internal(true);
            let element_ids: Vec<usize> = elements.iter().map(|e| e.id()).collect();
            let grad_fn = GradFn::ElementCollection {
                element_ids: element_ids.clone(),
                result_shape: vec![total_size],
            };
            result.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(result.id(), element_ids, grad_fn);
        }

        result
    }

    /// Create an iterator over tensor elements as view tensors
    ///
    /// Each element becomes a `Tensor` of shape `[1]` that supports all
    /// tensor operations and gradient tracking. This is the main entry point
    /// for element-wise iteration with full tensor operation support.
    ///
    /// The iterator provides zero-copy access to tensor elements through view
    /// tensors, enabling efficient element-wise operations while maintaining
    /// full compatibility with Rust's standard library iterator methods.
    ///
    /// # Returns
    ///
    /// An iterator that yields view tensors for each element
    ///
    /// # Performance
    ///
    /// - **Zero-Copy Views**: Each element is a view sharing memory with the source
    /// - **O(1) Element Access**: Constant-time view creation for each element
    /// - **Memory Efficient**: ~64 bytes overhead per element view
    /// - **SIMD Compatible**: All tensor operations use existing optimizations
    /// - **Gradient Tracking**: Full gradtrack support through element operations
    ///
    /// # Examples
    ///
    /// ## Basic Element Operations
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
    ///
    /// // Use any std iterator method
    /// let result: Tensor = tensor.iter()
    ///     .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0)) // 2x + 1
    ///     .filter(|elem| elem.value() > 3.0) // Keep values > 3
    ///     .collect();
    ///
    /// assert_eq!(result.data(), &[5.0, 7.0]);
    /// ```
    ///
    /// ## Advanced Iterator Chains
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
    ///
    /// // Chain with enumerate, zip, etc.
    /// let indexed: Tensor = tensor.iter()
    ///     .enumerate()
    ///     .map(|(i, elem)| elem.add_scalar(i as f32))
    ///     .collect();
    ///
    /// assert_eq!(indexed.data(), &[1.0, 3.0, 5.0, 7.0, 9.0]);
    /// ```
    ///
    /// ## Double-Ended Iteration
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
    ///
    /// // Use the double-ended iterator
    /// let reversed: Tensor = tensor.iter()
    ///     .rev()
    ///     .collect();
    ///
    /// assert_eq!(reversed.data(), &[4.0, 3.0, 2.0, 1.0]);
    /// ```
    ///
    /// ## Gradient Tracking
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::from_slice(&[1.0, 2.0], vec![2])
    ///     .unwrap()
    ///     .with_requires_grad();
    ///
    /// let result: Tensor = tensor.iter()
    ///     .map(|elem| elem.mul_scalar(2.0))
    ///     .collect();
    ///
    /// assert!(result.requires_grad());
    /// assert_eq!(result.data(), &[2.0, 4.0]);
    /// ```
    #[track_caller]
    pub fn iter(&self) -> TensorElementIterator<'_> {
        TensorElementIterator::new(self)
    }

    /// Create an iterator over a range of elements
    ///
    /// Creates an iterator that yields view tensors for elements in the specified
    /// range. The range is automatically clamped to valid tensor bounds for safety.
    ///
    /// # Arguments
    ///
    /// * `start` - Starting index (inclusive)
    /// * `end` - Ending index (exclusive)
    ///
    /// # Returns
    ///
    /// An iterator that yields view tensors for elements in the specified range
    ///
    /// # Safety
    ///
    /// The range is automatically clamped to valid tensor bounds:
    ///
    /// - `end` is clamped to `[0, tensor.size()]`
    /// - `start` is then clamped to `[0, end]`
    /// - Empty ranges (`start >= end`) are handled gracefully
    ///
    /// # Performance
    ///
    /// - **O(1) Creation**: Constant-time iterator initialization
    /// - **Bounds Checking**: Automatic range validation and clamping
    /// - **Zero-Copy Views**: Each element is a view sharing memory with the source
    /// - **Memory Efficient**: Minimal overhead for range iteration
    ///
    /// # Examples
    ///
    /// ## Basic Range Iteration
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
    /// let middle: Tensor = tensor.iter_range(1, 4)
    ///     .map(|elem| elem.mul_scalar(2.0))
    ///     .collect();
    ///
    /// assert_eq!(middle.data(), &[4.0, 6.0, 8.0]);
    /// ```
    ///
    /// ## Range with Operations
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
    ///
    /// // Apply complex operations to the range
    /// let result: Tensor = tensor.iter_range(0, 3)
    ///     .enumerate()
    ///     .map(|(i, elem)| elem.add_scalar(i as f32))
    ///     .collect();
    ///
    /// assert_eq!(result.data(), &[1.0, 3.0, 5.0]);
    /// ```
    ///
    /// ## Out-of-Bounds Handling
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
    ///
    /// // A fully out-of-bounds range is clamped to empty
    /// let empty: Tensor = tensor.iter_range(5, 10).collect();
    /// assert_eq!(empty.size(), 0);
    ///
    /// // A partially out-of-bounds range is clamped to valid elements
    /// let partial: Tensor = tensor.iter_range(1, 10).collect();
    /// assert_eq!(partial.data(), &[2.0, 3.0]);
    /// ```
    #[track_caller]
    pub fn iter_range(&self, start: usize, end: usize) -> TensorElementIterator<'_> {
        TensorElementIterator::with_range(self, start, end)
    }
}

#[cfg(test)]
mod tests {
    //! Comprehensive tests for tensor element iterator functionality
    //!
    //! These tests cover all aspects of the iterator implementation:
    //!
    //! - Basic iteration functionality
    //! - Standard library trait compliance
    //! - Gradient tracking through element operations
    //! - Performance characteristics
    //! - Edge cases and error conditions

    use super::*;

    /// Test basic iterator functionality
    #[test]
    fn test_basic_iteration() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();

        let elements: Vec<Tensor> = tensor.iter().collect();
        assert_eq!(elements.len(), 4);

        // Check that each element is a scalar tensor with the correct value
        for (i, elem) in elements.iter().enumerate() {
            assert_eq!(elem.shape().dims, vec![1]);
            assert_eq!(elem.size(), 1);
            assert_eq!(elem.value(), (i + 1) as f32);
        }
    }

    /// Test Iterator trait methods
    #[test]
    fn test_iterator_trait_methods() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
        let mut iter = tensor.iter();

        // Test next()
        let first = iter.next().unwrap();
        assert_eq!(first.value(), 1.0);

        // Test size_hint()
        assert_eq!(iter.size_hint(), (4, Some(4)));

        // Test count()
        assert_eq!(iter.count(), 4);

        // Test nth()
        let mut iter = tensor.iter();
        let third = iter.nth(2).unwrap();
        assert_eq!(third.value(), 3.0);

        // Test last()
        let iter = tensor.iter();
        let last = iter.last().unwrap();
        assert_eq!(last.value(), 5.0);
    }

    /// Test ExactSizeIterator
    #[test]
    fn test_exact_size_iterator() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let iter = tensor.iter();

        assert_eq!(iter.len(), 3);

        // Test that len() decreases as we consume the iterator
        let mut iter = tensor.iter();
        assert_eq!(iter.len(), 3);
        iter.next();
        assert_eq!(iter.len(), 2);
        iter.next();
        assert_eq!(iter.len(), 1);
        iter.next();
        assert_eq!(iter.len(), 0);
    }

    /// Test DoubleEndedIterator
    #[test]
    fn test_double_ended_iterator() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let mut iter = tensor.iter();

        // Test next_back()
        let last = iter.next_back().unwrap();
        assert_eq!(last.value(), 4.0);

        let first = iter.next().unwrap();
        assert_eq!(first.value(), 1.0);

        // Test nth_back()
        let mut iter = tensor.iter();
        let second_to_last = iter.nth_back(1).unwrap();
        assert_eq!(second_to_last.value(), 3.0);

        // Test consuming from both ends
        let mut iter = tensor.iter();
        assert_eq!(iter.next().unwrap().value(), 1.0);
        assert_eq!(iter.next_back().unwrap().value(), 4.0);
        assert_eq!(iter.next().unwrap().value(), 2.0);
        assert_eq!(iter.next_back().unwrap().value(), 3.0);
        assert!(iter.next().is_none());
    }

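    /// Test nth_back bounds handling
    ///
    /// A regression sketch using only APIs exercised above: skipping more
    /// elements than remain must exhaust the iterator rather than wrap
    /// around to the front.
    #[test]
    fn test_nth_back_out_of_bounds() {
        let tensor = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();

        // Skipping past the front yields None and exhausts the iterator
        let mut iter = tensor.iter();
        assert!(iter.nth_back(5).is_none());
        assert!(iter.next().is_none());

        // nth_back(len - 1) returns the front element; nothing remains after
        let mut iter = tensor.iter();
        assert_eq!(iter.nth_back(1).unwrap().value(), 1.0);
        assert!(iter.next().is_none());
    }
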
    /// Test IntoIterator trait
    #[test]
    fn test_into_iterator() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();

        // Test with a for loop
        let mut values = Vec::new();
        for element in &tensor {
            values.push(element.value());
        }
        assert_eq!(values, vec![1.0, 2.0, 3.0]);

        // Test with into_iter() explicitly
        let values: Vec<f32> = (&tensor).into_iter().map(|elem| elem.value()).collect();
        assert_eq!(values, vec![1.0, 2.0, 3.0]);
    }

    /// Test FromIterator trait (collect)
    #[test]
    fn test_from_iterator() {
        let original = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();

        // Test collecting back to a tensor
        let collected: Tensor = original.iter().collect();
        assert_eq!(collected.shape().dims, vec![4]);
        assert_eq!(collected.data(), original.data());

        // Test collecting with transformations
        let doubled: Tensor = original
            .iter()
            .map(|elem| {
                let val = elem.value();
                Tensor::from_slice(&[val * 2.0], vec![1]).unwrap()
            })
            .collect();

        assert_eq!(doubled.data(), &[2.0, 4.0, 6.0, 8.0]);
    }

    /// Test standard library iterator methods
    #[test]
    fn test_std_iterator_methods() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();

        // Test map
        let doubled: Vec<f32> = tensor.iter().map(|elem| elem.value() * 2.0).collect();
        assert_eq!(doubled, vec![2.0, 4.0, 6.0, 8.0, 10.0]);

        // Test filter
        let large_values: Vec<f32> = tensor
            .iter()
            .filter(|elem| elem.value() > 3.0)
            .map(|elem| elem.value())
            .collect();
        assert_eq!(large_values, vec![4.0, 5.0]);

        // Test enumerate
        let with_indices: Vec<(usize, f32)> = tensor
            .iter()
            .enumerate()
            .map(|(i, elem)| (i, elem.value()))
            .collect();
        assert_eq!(
            with_indices,
            vec![(0, 1.0), (1, 2.0), (2, 3.0), (3, 4.0), (4, 5.0)]
        );

        // Test fold
        let sum: f32 = tensor.iter().fold(0.0, |acc, elem| acc + elem.value());
        assert_eq!(sum, 15.0);

        // Test find
        let found = tensor.iter().find(|elem| elem.value() == 3.0);
        assert!(found.is_some());
        assert_eq!(found.unwrap().value(), 3.0);

        // Test any/all
        let all_positive = tensor.iter().all(|elem| elem.value() > 0.0);
        assert!(all_positive);

        let any_large = tensor.iter().any(|elem| elem.value() > 4.0);
        assert!(any_large);
    }

    /// Test element operations with tensor methods
    #[test]
    fn test_element_tensor_operations() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();

        // Test scalar operations on elements
        let scaled: Tensor = tensor.iter().map(|elem| elem.mul_scalar(2.0)).collect();
        assert_eq!(scaled.data(), &[2.0, 4.0, 6.0]);

        let offset: Tensor = tensor.iter().map(|elem| elem.add_scalar(10.0)).collect();
        assert_eq!(offset.data(), &[11.0, 12.0, 13.0]);

        // Test chaining operations
        let complex: Tensor = tensor
            .iter()
            .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0)) // 2x + 1
            .collect();
        assert_eq!(complex.data(), &[3.0, 5.0, 7.0]);
    }

    /// Test gradient tracking through element operations
    #[test]
    fn test_gradient_tracking() {
        let tensor = Tensor::from_slice(&[1.0, 2.0], vec![2])
            .unwrap()
            .with_requires_grad();

        // Perform element-wise operations
        let result: Tensor = tensor.iter().map(|elem| elem.mul_scalar(2.0)).collect();

        // The result should require gradients if any element requires gradients
        assert!(result.requires_grad());

        // Verify the forward pass produces correct values with gradient
        // tracking enabled
        assert_eq!(result.data(), &[2.0, 4.0]);
    }

    /// Test with zero-sized tensors
    #[test]
    fn test_zero_sized_tensor() {
        let empty = Tensor::new(vec![0]);
        let iter = empty.iter();

        assert_eq!(iter.len(), 0);
        assert_eq!(iter.size_hint(), (0, Some(0)));

        let collected: Tensor = iter.collect();
        assert_eq!(collected.size(), 0);
    }

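    /// Test collecting tensors with mixed shapes
    ///
    /// A sketch of the general collection path: non-scalar elements are
    /// flattened into a single 1D tensor, per `collect_mixed_views`.
    #[test]
    fn test_mixed_shape_collection() {
        let a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let b = Tensor::from_slice(&[3.0], vec![1]).unwrap();

        let flat: Tensor = vec![a, b].into_iter().collect();
        assert_eq!(flat.shape().dims, vec![3]);
        assert_eq!(flat.data(), &[1.0, 2.0, 3.0]);
    }
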
    /// Test range iteration
    #[test]
    fn test_range_iteration() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();

        // Test a middle range
        let middle: Vec<f32> = tensor.iter_range(1, 4).map(|elem| elem.value()).collect();
        assert_eq!(middle, vec![2.0, 3.0, 4.0]);

        // Test out of bounds (should be clamped)
        let clamped: Vec<f32> = tensor.iter_range(3, 10).map(|elem| elem.value()).collect();
        assert_eq!(clamped, vec![4.0, 5.0]);

        // Test an empty range
        let empty: Vec<f32> = tensor.iter_range(2, 2).map(|elem| elem.value()).collect();
        assert_eq!(empty, Vec::<f32>::new());
    }

    /// Test complex iterator chains
    #[test]
    fn test_complex_chains() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();

        // Complex chain: enumerate -> filter -> map -> collect
        let result: Tensor = tensor
            .iter()
            .enumerate()
            .filter(|(i, _)| i % 2 == 0) // Take even indices
            .map(|(i, elem)| elem.add_scalar(i as f32)) // Add index to value
            .collect();

        // Should contain [1.0 + 0, 3.0 + 2, 5.0 + 4] = [1.0, 5.0, 9.0]
        assert_eq!(result.data(), &[1.0, 5.0, 9.0]);

        // Test with rev()
        let reversed: Tensor = tensor.iter().rev().take(3).collect();

        assert_eq!(reversed.data(), &[6.0, 5.0, 4.0]);
    }

    /// Performance test for iterator overhead
    #[test]
    fn test_performance() {
        let large_tensor =
            Tensor::from_slice(&(0..1000).map(|i| i as f32).collect::<Vec<_>>(), vec![1000])
                .unwrap();

        let start = std::time::Instant::now();

        let result: Tensor = large_tensor
            .iter()
            .map(|elem| elem.mul_scalar(2.0))
            .collect();

        let duration = start.elapsed();
        println!("Iterator performance test took: {:?}", duration);

        // Verify correctness
        assert_eq!(result.size(), 1000);
        assert_eq!(result.data()[0], 0.0);
        assert_eq!(result.data()[999], 1998.0);
    }
}
1167}