// train_station/tensor/iterator/mod.rs
1//! Iterator module for tensor element-wise operations
2//!
3//! This module provides high-performance iterators over tensor elements, where each
4//! element is represented as a view tensor of shape `[1]`. This design allows for
5//! seamless integration with Rust's standard library iterator methods while
6//! leveraging the existing tensor operation framework and gradient tracking.
7//!
8//! # Key Features
9//!
10//! - **Standard Library Compatibility**: Full implementation of Iterator, ExactSizeIterator,
11//! DoubleEndedIterator, FusedIterator, IntoIterator, and FromIterator traits
12//! - **Gradient Tracking**: Automatic gradient propagation through element operations
13//! - **Performance Optimized**: True zero-copy views with shared memory
14//! - **SIMD Compatible**: All operations use existing optimized tensor implementations
15//! - **Memory Efficient**: Adaptive view creation based on tensor size
16//! - **Zero-Copy Operations**: Element views share memory with source tensor
17//! - **Full Tensor Operations**: Each element supports all tensor methods
18//!
19//! # Performance Characteristics
20//!
21//! - **View Creation**: O(1) per element with true zero-copy views
22//! - **Memory Overhead**: ~64 bytes per view tensor (no data copying)
23//! - **SIMD Operations**: Full utilization of existing optimizations
24//! - **Gradient Tracking**: True gradient flow with element-level accumulation
25//! - **Iterator Overhead**: Minimal performance impact for element access
26//! - **Collection Optimization**: Efficient reconstruction from element views
27//!
28//! # Examples
29//!
30//! ## Basic Element Iteration
31//!
32//! ```
33//! use train_station::Tensor;
34//!
35//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
36//!
37//! // Basic iteration over elements
38//! for element in tensor.iter() {
39//! println!("Element value: {}", element.value());
40//! }
41//!
42//! // Collect elements into a new tensor
43//! let collected: Tensor = tensor.iter().collect();
44//! assert_eq!(collected.data(), tensor.data());
45//! ```
46//!
47//! ## Element-Wise Transformations
48//!
49//! ```
50//! use train_station::Tensor;
51//!
52//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
53//!
54//! // Apply tensor operations to each element
55//! let doubled: Tensor = tensor.iter()
56//! .map(|elem| elem.mul_scalar(2.0))
57//! .collect();
58//!
59//! assert_eq!(doubled.data(), &[2.0, 4.0, 6.0]);
60//!
61//! // Chain multiple operations
62//! let transformed: Tensor = tensor.iter()
63//! .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0)) // 2x + 1
64//! .collect();
65//!
66//! assert_eq!(transformed.data(), &[3.0, 5.0, 7.0]);
67//! ```
68//!
69//! ## Advanced Iterator Operations
70//!
71//! ```
72//! use train_station::Tensor;
73//!
74//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
75//!
76//! // Filter elements based on values
77//! let large_values: Tensor = tensor.iter()
78//! .filter(|elem| elem.value() > 3.0)
79//! .collect();
80//!
81//! assert_eq!(large_values.data(), &[4.0, 5.0]);
82//!
83//! // Use enumerate for indexed operations
84//! let indexed: Tensor = tensor.iter()
85//! .enumerate()
86//! .map(|(i, elem)| elem.add_scalar(i as f32))
87//! .collect();
88//!
89//! assert_eq!(indexed.data(), &[1.0, 3.0, 5.0, 7.0, 9.0]);
90//! ```
91//!
92//! ## Range Iteration
93//!
94//! ```
95//! use train_station::Tensor;
96//!
97//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
98//!
99//! // Iterate over a specific range
100//! let middle: Tensor = tensor.iter_range(1, 4)
101//! .map(|elem| elem.mul_scalar(2.0))
102//! .collect();
103//!
104//! assert_eq!(middle.data(), &[4.0, 6.0, 8.0]);
105//! ```
106//!
107//! ## Double-Ended Iteration
108//!
109//! ```
110//! use train_station::Tensor;
111//!
112//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
113//!
114//! // Reverse iteration
115//! let reversed: Tensor = tensor.iter().rev().collect();
116//! assert_eq!(reversed.data(), &[4.0, 3.0, 2.0, 1.0]);
117//!
118//! // Iterate from both ends
119//! let mut iter = tensor.iter();
120//! assert_eq!(iter.next().unwrap().value(), 1.0);
121//! assert_eq!(iter.next_back().unwrap().value(), 4.0);
122//! ```
123//!
124//! ## Gradient Tracking
125//!
126//! ```
127//! use train_station::Tensor;
128//!
129//! let tensor = Tensor::from_slice(&[1.0, 2.0], vec![2])
130//! .unwrap()
131//! .with_requires_grad();
132//!
133//! // Element operations maintain gradient tracking
134//! let result: Tensor = tensor.iter()
135//! .map(|elem| elem.mul_scalar(2.0))
136//! .collect();
137//!
138//! assert!(result.requires_grad());
139//! assert_eq!(result.data(), &[2.0, 4.0]);
140//! ```
141//!
142//! # Design Principles
143//!
144//! - **Zero-Copy Views**: Element views share memory with source tensor
145//! - **Full Tensor Operations**: Each element supports all tensor methods
146//! - **Standard Library Integration**: Complete compatibility with Rust iterators
147//! - **Performance First**: Optimized for high-performance element access
148//! - **Gradient Preservation**: Maintains gradtrack functionality through operations
149//! - **Memory Efficiency**: Minimal overhead for element iteration
150//! - **Type Safety**: Compile-time guarantees for iterator operations
151
152use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
153use crate::tensor::core::Tensor;
154use std::iter::{FromIterator, FusedIterator};
155
156/// High-performance iterator over tensor elements as view tensors
157///
158/// Each element becomes a proper `Tensor` view of shape `[1]` that can use
159/// all existing tensor operations and gradient tracking. Implements all
160/// standard iterator traits for maximum compatibility with Rust's ecosystem.
161///
162/// This iterator provides zero-copy access to tensor elements through view
163/// tensors, enabling efficient element-wise operations while maintaining
164/// full compatibility with Rust's standard library iterator methods.
165///
166/// # Performance
167///
168/// - **Zero-Copy Views**: Each element is a view tensor sharing memory with source
169/// - **O(1) Element Access**: Constant-time view creation for each element
170/// - **Memory Efficient**: ~64 bytes overhead per element view
171/// - **SIMD Compatible**: All tensor operations use existing optimizations
172/// - **Gradient Tracking**: Full gradtrack support through element operations
173///
174/// # Implementation Details
175///
176/// The iterator creates lightweight view tensors on-demand, sharing the same
177/// memory allocation as the source tensor. This ensures zero-copy semantics
178/// while maintaining full tensor operation compatibility.
179///
180/// Each element view is created using `Tensor::element_view()`, which provides
181/// a true view of the underlying data without any copying. The view tensors
182/// support all standard tensor operations including gradient tracking.
183///
184/// # Standard Library Compatibility
185///
186/// This iterator implements all standard iterator traits:
187/// - `Iterator`: Basic iteration with `next()` and `size_hint()`
188/// - `ExactSizeIterator`: Precise size information with `len()`
189/// - `DoubleEndedIterator`: Reverse iteration with `next_back()`
190/// - `FusedIterator`: Fused iteration for better performance
191/// - `IntoIterator`: Automatic conversion for `for` loops
192///
193/// # Examples
194///
195/// ## Basic Iteration
196///
197/// ```
198/// use train_station::Tensor;
199///
200/// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
201///
202/// // Basic iteration
203/// for element in tensor.iter() {
204/// println!("Element value: {}", element.value());
205/// }
206///
207/// // Standard library methods
208/// let sum: f32 = tensor.iter()
209/// .map(|elem| elem.value())
210/// .sum();
211///
212/// assert_eq!(sum, 6.0);
213/// ```
214///
215/// ## Element Operations
216///
217/// ```
218/// use train_station::Tensor;
219///
220/// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
221///
222/// // Tensor operations on elements
223/// let transformed: Tensor = tensor.iter()
224/// .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0)) // 2x + 1
225/// .collect();
226///
227/// assert_eq!(transformed.data(), &[3.0, 5.0, 7.0]);
228/// ```
229///
230/// ## Advanced Iterator Methods
231///
232/// ```
233/// use train_station::Tensor;
234///
235/// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
236///
237/// // Filter and transform
238/// let result: Tensor = tensor.iter()
239/// .filter(|elem| elem.value() > 2.0)
240/// .map(|elem| elem.mul_scalar(10.0))
241/// .collect();
242///
243/// assert_eq!(result.data(), &[30.0, 40.0, 50.0]);
244///
245/// // Reverse iteration
246/// let reversed: Tensor = tensor.iter().rev().collect();
247/// assert_eq!(reversed.data(), &[5.0, 4.0, 3.0, 2.0, 1.0]);
248/// ```
pub struct TensorElementIterator<'a> {
    /// Reference to the source tensor whose elements are iterated
    source: &'a Tensor,
    /// Current front position in iteration; advanced by `next()`/`nth()`
    /// (invariant: `position <= end`)
    position: usize,
    /// End position (exclusive); decremented by `next_back()`/`nth_back()`
    end: usize,
}
257
258impl<'a> TensorElementIterator<'a> {
259 /// Create a new iterator over all tensor elements
260 ///
261 /// Creates an iterator that yields view tensors for each element in the
262 /// source tensor. Each element becomes a `Tensor` of shape `[1]` that
263 /// supports all tensor operations and gradient tracking.
264 ///
265 /// # Arguments
266 ///
267 /// * `tensor` - The source tensor to iterate over
268 ///
269 /// # Returns
270 ///
271 /// An iterator that yields view tensors for each element
272 ///
273 /// # Performance
274 ///
275 /// - **O(1) Creation**: Constant-time iterator initialization
276 /// - **Zero-Copy Views**: Each element is a view sharing memory with source
277 /// - **Memory Efficient**: Minimal overhead for iterator state
278 ///
279 /// # Implementation Details
280 ///
281 /// This method creates an iterator that yields view tensors for each element
282 /// in the source tensor. Each element becomes a `Tensor` of shape `[1]` that
283 /// supports all tensor operations and gradient tracking.
284 ///
285 /// The iterator provides zero-copy access to tensor elements through view
286 /// tensors, enabling efficient element-wise operations while maintaining
287 /// full compatibility with Rust's standard library iterator methods.
288 pub fn new(tensor: &'a Tensor) -> Self {
289 Self {
290 source: tensor,
291 position: 0,
292 end: tensor.size(),
293 }
294 }
295
296 /// Create an iterator over a specific range of elements
297 ///
298 /// Creates an iterator that yields view tensors for elements in the specified
299 /// range. The range is automatically clamped to valid tensor bounds for safety.
300 ///
301 /// # Arguments
302 ///
303 /// * `tensor` - The source tensor to iterate over
304 /// * `start` - Starting index (inclusive)
305 /// * `end` - Ending index (exclusive)
306 ///
307 /// # Returns
308 ///
309 /// An iterator that yields view tensors for elements in the specified range
310 ///
311 /// # Safety
312 ///
313 /// The range is automatically clamped to valid tensor bounds:
314 /// - `start` is clamped to `[0, tensor.size()]`
315 /// - `end` is clamped to `[start, tensor.size()]`
316 /// - Empty ranges (start >= end) are handled gracefully
317 ///
318 /// # Performance
319 ///
320 /// - **O(1) Creation**: Constant-time iterator initialization
321 /// - **Bounds Checking**: Automatic range validation and clamping
322 /// - **Zero-Copy Views**: Each element is a view sharing memory with source
323 ///
324 /// # Implementation Details
325 ///
326 /// This method creates an iterator that yields view tensors for elements in
327 /// the specified range. The range is automatically clamped to valid tensor
328 /// bounds for safety, ensuring that out-of-bounds access is handled gracefully.
329 ///
330 /// The iterator provides zero-copy access to tensor elements through view
331 /// tensors, enabling efficient element-wise operations while maintaining
332 /// full compatibility with Rust's standard library iterator methods.
333 pub fn with_range(tensor: &'a Tensor, start: usize, end: usize) -> Self {
334 let end = end.min(tensor.size());
335 let start = start.min(end);
336 Self {
337 source: tensor,
338 position: start,
339 end,
340 }
341 }
342
343 /// Create an optimized element view for the given position
344 ///
345 /// This method creates a true view tensor of shape `[1]` that shares memory
346 /// with the element at the specified index in the source tensor. The view
347 /// enables zero-copy element access with full gradient tracking.
348 ///
349 /// # Arguments
350 ///
351 /// * `index` - Index of the element to create a view for
352 ///
353 /// # Returns
354 ///
355 /// A view tensor of shape `[1]` representing the element at the specified index
356 ///
357 /// # Safety
358 ///
359 /// The caller must ensure that `index < self.source.size()`.
360 ///
361 /// # Performance
362 ///
363 /// - **O(1) View Creation**: Constant-time view tensor creation
364 /// - **Zero-Copy**: View shares memory with source tensor
365 /// - **Memory Efficient**: ~64 bytes overhead for view metadata
366 /// - **Gradient Tracking**: Full gradtrack support through view operations
367 ///
368 /// # Implementation Details
369 ///
370 /// This method delegates to `Tensor::element_view()` which creates a true
371 /// view of the underlying data without any copying. The view tensor supports
372 /// all standard tensor operations including gradient tracking and SIMD
373 /// optimizations.
374 fn create_element_view(&self, index: usize) -> Tensor {
375 debug_assert!(index < self.source.size());
376
377 self.source.element_view(index)
378 }
379}
380
381// ===== Core Iterator Implementation =====
382
383impl<'a> Iterator for TensorElementIterator<'a> {
384 type Item = Tensor;
385
386 #[inline]
387 fn next(&mut self) -> Option<Self::Item> {
388 if self.position < self.end {
389 let view = self.create_element_view(self.position);
390 self.position += 1;
391 Some(view)
392 } else {
393 None
394 }
395 }
396
397 #[inline]
398 fn size_hint(&self) -> (usize, Option<usize>) {
399 let remaining = self.end - self.position;
400 (remaining, Some(remaining))
401 }
402
403 #[inline]
404 fn count(self) -> usize {
405 self.end - self.position
406 }
407
408 #[inline]
409 fn nth(&mut self, n: usize) -> Option<Self::Item> {
410 let new_pos = self.position.saturating_add(n);
411 if new_pos < self.end {
412 self.position = new_pos + 1;
413 Some(self.create_element_view(new_pos))
414 } else {
415 self.position = self.end;
416 None
417 }
418 }
419
420 #[inline]
421 fn last(self) -> Option<Self::Item> {
422 if self.position < self.end {
423 let last_idx = self.end - 1;
424 Some(self.create_element_view(last_idx))
425 } else {
426 None
427 }
428 }
429}
430
impl<'a> ExactSizeIterator for TensorElementIterator<'a> {
    /// Exact number of elements remaining (`end - position`); never
    /// underflows because `position <= end` is an iterator invariant.
    #[inline]
    fn len(&self) -> usize {
        self.end - self.position
    }
}
437
// Fused: once `next()` returns `None` it will keep returning `None`,
// which the range-based implementation (`position >= end`) guarantees.
impl<'a> FusedIterator for TensorElementIterator<'a> {}
439
440impl<'a> DoubleEndedIterator for TensorElementIterator<'a> {
441 #[inline]
442 fn next_back(&mut self) -> Option<Self::Item> {
443 if self.position < self.end {
444 self.end -= 1;
445 Some(self.create_element_view(self.end))
446 } else {
447 None
448 }
449 }
450
451 #[inline]
452 fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
453 let new_end = self.end.saturating_sub(n + 1);
454 if new_end >= self.position {
455 self.end = new_end;
456 Some(self.create_element_view(self.end))
457 } else {
458 self.position = self.end;
459 None
460 }
461 }
462}
463
464// ===== IntoIterator Implementation =====
465
466impl<'a> IntoIterator for &'a Tensor {
467 type Item = Tensor;
468 type IntoIter = TensorElementIterator<'a>;
469
470 fn into_iter(self) -> Self::IntoIter {
471 TensorElementIterator::new(self)
472 }
473}
474
475// ===== FromIterator Implementation =====
476
477impl FromIterator<Tensor> for Tensor {
478 /// Collect element view tensors back into a single tensor
479 ///
480 /// This method reconstructs a tensor from an iterator of element view tensors.
481 /// It includes optimizations for common patterns and maintains gradient tracking
482 /// when appropriate.
483 ///
484 /// The collection process automatically detects whether all elements are scalar
485 /// views (shape `[1]`) and uses optimized collection strategies accordingly.
486 /// Gradient tracking is preserved when any input element requires gradients.
487 ///
488 /// # Performance
489 ///
490 /// - **Optimized Collection**: Specialized paths for scalar and mixed views
491 /// - **Memory Efficient**: Direct memory copying without intermediate allocations
492 /// - **Gradient Preservation**: Maintains gradtrack functionality when enabled
493 /// - **Shape Detection**: Automatic detection of element shapes for optimization
494 ///
495 /// # Implementation Details
496 ///
497 /// The method performs the following steps:
498 /// 1. **Element Collection**: Gathers all element tensors from the iterator
499 /// 2. **Shape Analysis**: Determines if all elements are scalar views
500 /// 3. **Optimized Path**: Uses specialized collection for scalar views
501 /// 4. **General Path**: Handles mixed shapes by flattening into 1D tensor
502 /// 5. **Gradient Setup**: Preserves gradient tracking when appropriate
503 ///
504 /// # Examples
505 ///
506 /// ## Basic Collection
507 ///
508 /// ```
509 /// use train_station::Tensor;
510 ///
511 /// let original = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
512 /// let doubled: Tensor = original.iter()
513 /// .map(|elem| elem.mul_scalar(2.0))
514 /// .collect();
515 ///
516 /// assert_eq!(doubled.data(), &[2.0, 4.0, 6.0]);
517 /// ```
518 ///
519 /// ## Collection with Gradient Tracking
520 ///
521 /// ```
522 /// use train_station::Tensor;
523 ///
524 /// let original = Tensor::from_slice(&[1.0, 2.0], vec![2])
525 /// .unwrap()
526 /// .with_requires_grad();
527 ///
528 /// let result: Tensor = original.iter()
529 /// .map(|elem| elem.mul_scalar(2.0))
530 /// .collect();
531 ///
532 /// assert!(result.requires_grad());
533 /// assert_eq!(result.data(), &[2.0, 4.0]);
534 /// ```
535 ///
536 /// ## Empty Iterator Handling
537 ///
538 /// ```
539 /// use train_station::Tensor;
540 ///
541 /// let empty: Tensor = Vec::<Tensor>::new().into_iter().collect();
542 /// assert_eq!(empty.size(), 0);
543 /// assert_eq!(empty.shape().dims, vec![0]);
544 /// ```
545 fn from_iter<I: IntoIterator<Item = Tensor>>(iter: I) -> Self {
546 let elements: Vec<Tensor> = iter.into_iter().collect();
547
548 if elements.is_empty() {
549 return Tensor::new(vec![0]);
550 }
551
552 // Check if all elements are scalar views (shape [1])
553 let all_scalars = elements.iter().all(|e| e.shape().dims == vec![1]);
554
555 if all_scalars {
556 // Optimized path for scalar element views
557 Self::collect_scalar_views(elements)
558 } else {
559 // General path for mixed shapes
560 Self::collect_mixed_views(elements)
561 }
562 }
563}
564
565impl Tensor {
566 /// Optimized collection for scalar element views
567 ///
568 /// This method efficiently reconstructs a tensor from scalar element views,
569 /// preserving gradient tracking and using optimized memory operations.
570 ///
571 /// This is the fast path for collection when all elements are scalar views
572 /// (shape `[1]`). It performs direct memory copying and sets up gradient
573 /// tracking when any input element requires gradients.
574 ///
575 /// # Arguments
576 ///
577 /// * `elements` - Vector of scalar element view tensors
578 ///
579 /// # Returns
580 ///
581 /// A new tensor containing all element values in a 1D layout
582 ///
583 /// # Performance
584 ///
585 /// - **Direct Memory Copy**: Single-pass copying without intermediate allocations
586 /// - **Gradient Optimization**: Efficient gradient tracking setup
587 /// - **Memory Efficient**: Minimal overhead for collection process
588 /// - **SIMD Compatible**: Result tensor supports all optimizations
589 ///
590 /// # Implementation Details
591 ///
592 /// The method performs the following steps:
593 /// 1. **Allocation**: Creates uninitialized tensor with correct size
594 /// 2. **Gradient Check**: Determines if any element requires gradients
595 /// 3. **Memory Copy**: Direct copying from element views to result
596 /// 4. **Gradient Setup**: Configures gradient tracking when needed
597 /// 5. **Operation Registration**: Registers with gradtrack engine
598 fn collect_scalar_views(elements: Vec<Tensor>) -> Self {
599 let len = elements.len();
600 let mut result = Self::new_uninitialized(vec![len]);
601
602 // Determine if we can track gradients
603 let requires_grad = elements.iter().any(|e| e.requires_grad());
604
605 // Copy data from element views
606 unsafe {
607 let dst = result.as_mut_ptr();
608 for (i, element) in elements.iter().enumerate() {
609 *dst.add(i) = *element.as_ptr();
610 }
611 }
612
613 // Set up gradient tracking
614 if requires_grad && is_grad_enabled() {
615 result.set_requires_grad_internal(true);
616 let element_ids: Vec<usize> = elements.iter().map(|e| e.id()).collect();
617 let grad_fn = GradFn::ElementCollection {
618 element_ids: element_ids.clone(),
619 result_shape: vec![len],
620 };
621 result.set_grad_fn(grad_fn.clone());
622 GradEngine::register_operation(result.id(), element_ids, grad_fn);
623 }
624
625 result
626 }
627
628 /// General collection for mixed element shapes
629 ///
630 /// This method handles collection when elements have different shapes,
631 /// flattening all elements into a 1D tensor.
632 ///
633 /// This is the general path for collection when elements have varying shapes.
634 /// It flattens all elements into a single 1D tensor and preserves gradient
635 /// tracking when any input element requires gradients.
636 ///
637 /// # Arguments
638 ///
639 /// * `elements` - Vector of element tensors with potentially different shapes
640 ///
641 /// # Returns
642 ///
643 /// A new 1D tensor containing all flattened element values
644 ///
645 /// # Performance
646 ///
647 /// - **Flattening**: Converts all elements to 1D layout
648 /// - **Memory Copy**: Efficient copying with size calculation
649 /// - **Gradient Preservation**: Maintains gradtrack functionality
650 /// - **Mixed Shapes**: Handles elements with different dimensions
651 ///
652 /// # Implementation Details
653 ///
654 /// The method performs the following steps:
655 /// 1. **Size Calculation**: Sums sizes of all elements for total size
656 /// 2. **Allocation**: Creates uninitialized tensor with total size
657 /// 3. **Sequential Copy**: Copies each element's data sequentially
658 /// 4. **Gradient Setup**: Configures gradient tracking when needed
659 /// 5. **Operation Registration**: Registers with gradtrack engine
660 fn collect_mixed_views(elements: Vec<Tensor>) -> Self {
661 // For mixed shapes, flatten all elements into a 1D tensor
662 let total_size: usize = elements.iter().map(|e| e.size()).sum();
663 let mut result = Self::new_uninitialized(vec![total_size]);
664
665 let requires_grad = elements.iter().any(|e| e.requires_grad());
666 let mut offset = 0;
667
668 unsafe {
669 let dst = result.as_mut_ptr();
670 for element in &elements {
671 let src = element.as_ptr();
672 let size = element.size();
673 std::ptr::copy_nonoverlapping(src, dst.add(offset), size);
674 offset += size;
675 }
676 }
677
678 if requires_grad && is_grad_enabled() {
679 result.set_requires_grad_internal(true);
680 let element_ids: Vec<usize> = elements.iter().map(|e| e.id()).collect();
681 let grad_fn = GradFn::ElementCollection {
682 element_ids: element_ids.clone(),
683 result_shape: vec![total_size],
684 };
685 result.set_grad_fn(grad_fn.clone());
686 GradEngine::register_operation(result.id(), element_ids, grad_fn);
687 }
688
689 result
690 }
691
692 /// Create an iterator over tensor elements as view tensors
693 ///
694 /// Each element becomes a `Tensor` of shape `[1]` that supports all
695 /// tensor operations and gradient tracking. This is the main entry point
696 /// for element-wise iteration with full tensor operation support.
697 ///
698 /// The iterator provides zero-copy access to tensor elements through view
699 /// tensors, enabling efficient element-wise operations while maintaining
700 /// full compatibility with Rust's standard library iterator methods.
701 ///
702 /// # Returns
703 ///
704 /// An iterator that yields view tensors for each element
705 ///
706 /// # Performance
707 ///
708 /// - **Zero-Copy Views**: Each element is a view sharing memory with source
709 /// - **O(1) Element Access**: Constant-time view creation for each element
710 /// - **Memory Efficient**: ~64 bytes overhead per element view
711 /// - **SIMD Compatible**: All tensor operations use existing optimizations
712 /// - **Gradient Tracking**: Full gradtrack support through element operations
713 ///
714 /// # Examples
715 ///
716 /// ## Basic Element Operations
717 ///
718 /// ```
719 /// use train_station::Tensor;
720 ///
721 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
722 ///
723 /// // Use any std iterator method
724 /// let result: Tensor = tensor.iter()
725 /// .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0)) // 2x + 1
726 /// .filter(|elem| elem.value() > 3.0) // Keep values > 3
727 /// .collect();
728 ///
729 /// assert_eq!(result.data(), &[5.0, 7.0]);
730 /// ```
731 ///
732 /// ## Advanced Iterator Chains
733 ///
734 /// ```
735 /// use train_station::Tensor;
736 ///
737 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
738 ///
739 /// // Chain with enumerate, zip, etc.
740 /// let indexed: Tensor = tensor.iter()
741 /// .enumerate()
742 /// .map(|(i, elem)| elem.add_scalar(i as f32))
743 /// .collect();
744 ///
745 /// assert_eq!(indexed.data(), &[1.0, 3.0, 5.0, 7.0, 9.0]);
746 /// ```
747 ///
748 /// ## Double-Ended Iteration
749 ///
750 /// ```
751 /// use train_station::Tensor;
752 ///
753 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
754 ///
755 /// // Use double-ended iterator
756 /// let reversed: Tensor = tensor.iter()
757 /// .rev()
758 /// .collect();
759 ///
760 /// assert_eq!(reversed.data(), &[4.0, 3.0, 2.0, 1.0]);
761 /// ```
762 ///
763 /// ## Gradient Tracking
764 ///
765 /// ```
766 /// use train_station::Tensor;
767 ///
768 /// let tensor = Tensor::from_slice(&[1.0, 2.0], vec![2])
769 /// .unwrap()
770 /// .with_requires_grad();
771 ///
772 /// let result: Tensor = tensor.iter()
773 /// .map(|elem| elem.mul_scalar(2.0))
774 /// .collect();
775 ///
776 /// assert!(result.requires_grad());
777 /// assert_eq!(result.data(), &[2.0, 4.0]);
778 /// ```
779 pub fn iter(&self) -> TensorElementIterator {
780 TensorElementIterator::new(self)
781 }
782
783 /// Create an iterator over a range of elements
784 ///
785 /// Creates an iterator that yields view tensors for elements in the specified
786 /// range. The range is automatically clamped to valid tensor bounds for safety.
787 ///
788 /// # Arguments
789 ///
790 /// * `start` - Starting index (inclusive)
791 /// * `end` - Ending index (exclusive)
792 ///
793 /// # Returns
794 ///
795 /// An iterator that yields view tensors for elements in the specified range
796 ///
797 /// # Safety
798 ///
799 /// The range is automatically clamped to valid tensor bounds:
800 /// - `start` is clamped to `[0, tensor.size()]`
801 /// - `end` is clamped to `[start, tensor.size()]`
802 /// - Empty ranges (start >= end) are handled gracefully
803 ///
804 /// # Performance
805 ///
806 /// - **O(1) Creation**: Constant-time iterator initialization
807 /// - **Bounds Checking**: Automatic range validation and clamping
808 /// - **Zero-Copy Views**: Each element is a view sharing memory with source
809 /// - **Memory Efficient**: Minimal overhead for range iteration
810 ///
811 /// # Examples
812 ///
813 /// ## Basic Range Iteration
814 ///
815 /// ```
816 /// use train_station::Tensor;
817 ///
818 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
819 /// let middle: Tensor = tensor.iter_range(1, 4)
820 /// .map(|elem| elem.mul_scalar(2.0))
821 /// .collect();
822 ///
823 /// assert_eq!(middle.data(), &[4.0, 6.0, 8.0]);
824 /// ```
825 ///
826 /// ## Range with Operations
827 ///
828 /// ```
829 /// use train_station::Tensor;
830 ///
831 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
832 ///
833 /// // Apply complex operations to range
834 /// let result: Tensor = tensor.iter_range(0, 3)
835 /// .enumerate()
836 /// .map(|(i, elem)| elem.add_scalar(i as f32))
837 /// .collect();
838 ///
839 /// assert_eq!(result.data(), &[1.0, 3.0, 5.0]);
840 /// ```
841 ///
842 /// ## Out of Bounds Handling
843 ///
844 /// ```
845 /// use train_station::Tensor;
846 ///
847 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
848 ///
849 /// // Out of bounds range is clamped
850 /// let empty: Tensor = tensor.iter_range(5, 10).collect();
851 /// assert_eq!(empty.size(), 0);
852 ///
853 /// // Partial out of bounds
854 /// let partial: Tensor = tensor.iter_range(1, 10).collect();
855 /// assert_eq!(partial.data(), &[2.0, 3.0]);
856 /// ```
857 pub fn iter_range(&self, start: usize, end: usize) -> TensorElementIterator {
858 TensorElementIterator::with_range(self, start, end)
859 }
860}
861
862#[cfg(test)]
863mod tests {
864 //! Comprehensive tests for tensor element iterator functionality
865 //!
866 //! These tests cover all aspects of the iterator implementation:
867 //! - Basic iteration functionality
868 //! - Standard library trait compliance
869 //! - Gradient tracking through element operations
870 //! - Performance characteristics
871 //! - Edge cases and error conditions
872
873 use super::*;
874
875 /// Test basic iterator functionality
876 #[test]
877 fn test_basic_iteration() {
878 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
879
880 let elements: Vec<Tensor> = tensor.iter().collect();
881 assert_eq!(elements.len(), 4);
882
883 // Check that each element is a scalar tensor with correct value
884 for (i, elem) in elements.iter().enumerate() {
885 assert_eq!(elem.shape().dims, vec![1]);
886 assert_eq!(elem.size(), 1);
887 assert_eq!(elem.value(), (i + 1) as f32);
888 }
889 }
890
891 /// Test Iterator trait methods
892 #[test]
893 fn test_iterator_trait_methods() {
894 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
895 let mut iter = tensor.iter();
896
897 // Test next()
898 let first = iter.next().unwrap();
899 assert_eq!(first.value(), 1.0);
900
901 // Test size_hint()
902 assert_eq!(iter.size_hint(), (4, Some(4)));
903
904 // Test count()
905 assert_eq!(iter.count(), 4);
906
907 // Test nth()
908 let mut iter = tensor.iter();
909 let third = iter.nth(2).unwrap();
910 assert_eq!(third.value(), 3.0);
911
912 // Test last()
913 let iter = tensor.iter();
914 let last = iter.last().unwrap();
915 assert_eq!(last.value(), 5.0);
916 }
917
918 /// Test ExactSizeIterator
919 #[test]
920 fn test_exact_size_iterator() {
921 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
922 let iter = tensor.iter();
923
924 assert_eq!(iter.len(), 3);
925
926 // Test that len() decreases as we consume the iterator
927 let mut iter = tensor.iter();
928 assert_eq!(iter.len(), 3);
929 iter.next();
930 assert_eq!(iter.len(), 2);
931 iter.next();
932 assert_eq!(iter.len(), 1);
933 iter.next();
934 assert_eq!(iter.len(), 0);
935 }
936
937 /// Test DoubleEndedIterator
938 #[test]
939 fn test_double_ended_iterator() {
940 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
941 let mut iter = tensor.iter();
942
943 // Test next_back()
944 let last = iter.next_back().unwrap();
945 assert_eq!(last.value(), 4.0);
946
947 let first = iter.next().unwrap();
948 assert_eq!(first.value(), 1.0);
949
950 // Test nth_back()
951 let mut iter = tensor.iter();
952 let second_to_last = iter.nth_back(1).unwrap();
953 assert_eq!(second_to_last.value(), 3.0);
954
955 // Test consuming from both ends
956 let mut iter = tensor.iter();
957 assert_eq!(iter.next().unwrap().value(), 1.0);
958 assert_eq!(iter.next_back().unwrap().value(), 4.0);
959 assert_eq!(iter.next().unwrap().value(), 2.0);
960 assert_eq!(iter.next_back().unwrap().value(), 3.0);
961 assert!(iter.next().is_none());
962 }
963
964 /// Test IntoIterator trait
965 #[test]
966 fn test_into_iterator() {
967 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
968
969 // Test with for loop
970 let mut values = Vec::new();
971 for element in &tensor {
972 values.push(element.value());
973 }
974 assert_eq!(values, vec![1.0, 2.0, 3.0]);
975
976 // Test with into_iter() explicitly
977 let values: Vec<f32> = (&tensor).into_iter().map(|elem| elem.value()).collect();
978 assert_eq!(values, vec![1.0, 2.0, 3.0]);
979 }
980
981 /// Test FromIterator trait (collect)
982 #[test]
983 fn test_from_iterator() {
984 let original = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
985
986 // Test collecting back to tensor
987 let collected: Tensor = original.iter().collect();
988 assert_eq!(collected.shape().dims, vec![4]);
989 assert_eq!(collected.data(), original.data());
990
991 // Test collecting with transformations
992 let doubled: Tensor = original
993 .iter()
994 .map(|elem| {
995 let val = elem.value();
996 Tensor::from_slice(&[val * 2.0], vec![1]).unwrap()
997 })
998 .collect();
999
1000 assert_eq!(doubled.data(), &[2.0, 4.0, 6.0, 8.0]);
1001 }
1002
1003 /// Test standard library iterator methods
1004 #[test]
1005 fn test_std_iterator_methods() {
1006 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
1007
1008 // Test map
1009 let doubled: Vec<f32> = tensor.iter().map(|elem| elem.value() * 2.0).collect();
1010 assert_eq!(doubled, vec![2.0, 4.0, 6.0, 8.0, 10.0]);
1011
1012 // Test filter
1013 let large_values: Vec<f32> = tensor
1014 .iter()
1015 .filter(|elem| elem.value() > 3.0)
1016 .map(|elem| elem.value())
1017 .collect();
1018 assert_eq!(large_values, vec![4.0, 5.0]);
1019
1020 // Test enumerate
1021 let with_indices: Vec<(usize, f32)> = tensor
1022 .iter()
1023 .enumerate()
1024 .map(|(i, elem)| (i, elem.value()))
1025 .collect();
1026 assert_eq!(
1027 with_indices,
1028 vec![(0, 1.0), (1, 2.0), (2, 3.0), (3, 4.0), (4, 5.0)]
1029 );
1030
1031 // Test fold
1032 let sum: f32 = tensor.iter().fold(0.0, |acc, elem| acc + elem.value());
1033 assert_eq!(sum, 15.0);
1034
1035 // Test find
1036 let found = tensor.iter().find(|elem| elem.value() == 3.0);
1037 assert!(found.is_some());
1038 assert_eq!(found.unwrap().value(), 3.0);
1039
1040 // Test any/all
1041 let all_positive = tensor.iter().all(|elem| elem.value() > 0.0);
1042 assert!(all_positive);
1043
1044 let any_large = tensor.iter().any(|elem| elem.value() > 4.0);
1045 assert!(any_large);
1046 }
1047
1048 /// Test element operations with tensor methods
1049 #[test]
1050 fn test_element_tensor_operations() {
1051 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
1052
1053 // Test scalar operations on elements
1054 let scaled: Tensor = tensor.iter().map(|elem| elem.mul_scalar(2.0)).collect();
1055 assert_eq!(scaled.data(), &[2.0, 4.0, 6.0]);
1056
1057 let offset: Tensor = tensor.iter().map(|elem| elem.add_scalar(10.0)).collect();
1058 assert_eq!(offset.data(), &[11.0, 12.0, 13.0]);
1059
1060 // Test chaining operations
1061 let complex: Tensor = tensor
1062 .iter()
1063 .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0)) // 2x + 1
1064 .collect();
1065 assert_eq!(complex.data(), &[3.0, 5.0, 7.0]);
1066 }
1067
1068 /// Test gradient tracking through element operations
1069 #[test]
1070 fn test_gradient_tracking() {
1071 let tensor = Tensor::from_slice(&[1.0, 2.0], vec![2])
1072 .unwrap()
1073 .with_requires_grad();
1074
1075 // Perform element-wise operations
1076 let result: Tensor = tensor.iter().map(|elem| elem.mul_scalar(2.0)).collect();
1077
1078 // The result should require gradients if any element requires gradients
1079 // Note: Current implementation creates copies, so gradient tracking is
1080 // implemented but may not propagate back to original tensor
1081 assert!(result.requires_grad());
1082
1083 // For now, just verify the forward pass works with gradient-enabled tensors
1084 // Full gradient propagation would require true view implementation
1085 assert_eq!(result.data(), &[2.0, 4.0]);
1086 }
1087
1088 /// Test with zero-sized tensors
1089 #[test]
1090 fn test_zero_sized_tensor() {
1091 let empty = Tensor::new(vec![0]);
1092 let iter = empty.iter();
1093
1094 assert_eq!(iter.len(), 0);
1095 assert_eq!(iter.size_hint(), (0, Some(0)));
1096
1097 let collected: Tensor = iter.collect();
1098 assert_eq!(collected.size(), 0);
1099 }
1100
1101 /// Test range iteration
1102 #[test]
1103 fn test_range_iteration() {
1104 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
1105
1106 // Test middle range
1107 let middle: Vec<f32> = tensor.iter_range(1, 4).map(|elem| elem.value()).collect();
1108 assert_eq!(middle, vec![2.0, 3.0, 4.0]);
1109
1110 // Test out of bounds (should be clamped)
1111 let clamped: Vec<f32> = tensor.iter_range(3, 10).map(|elem| elem.value()).collect();
1112 assert_eq!(clamped, vec![4.0, 5.0]);
1113
1114 // Test empty range
1115 let empty: Vec<f32> = tensor.iter_range(2, 2).map(|elem| elem.value()).collect();
1116 assert_eq!(empty, Vec::<f32>::new());
1117 }
1118
1119 /// Test complex iterator chains
1120 #[test]
1121 fn test_complex_chains() {
1122 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
1123
1124 // Complex chain: enumerate -> filter -> map -> collect
1125 let result: Tensor = tensor
1126 .iter()
1127 .enumerate()
1128 .filter(|(i, _)| i % 2 == 0) // Take even indices
1129 .map(|(i, elem)| elem.add_scalar(i as f32)) // Add index to value
1130 .collect();
1131
1132 // Should have elements [1.0 + 0, 3.0 + 2, 5.0 + 4] = [1.0, 5.0, 9.0]
1133 assert_eq!(result.data(), &[1.0, 5.0, 9.0]);
1134
1135 // Test with rev()
1136 let reversed: Tensor = tensor.iter().rev().take(3).collect();
1137
1138 assert_eq!(reversed.data(), &[6.0, 5.0, 4.0]);
1139 }
1140
1141 /// Performance test for iterator overhead
1142 #[test]
1143 fn test_performance() {
1144 let large_tensor =
1145 Tensor::from_slice(&(0..1000).map(|i| i as f32).collect::<Vec<_>>(), vec![1000])
1146 .unwrap();
1147
1148 let start = std::time::Instant::now();
1149
1150 let result: Tensor = large_tensor
1151 .iter()
1152 .map(|elem| elem.mul_scalar(2.0))
1153 .collect();
1154
1155 let duration = start.elapsed();
1156 println!("Iterator performance test took: {:?}", duration);
1157
1158 // Verify correctness
1159 assert_eq!(result.size(), 1000);
1160 assert_eq!(result.data()[0], 0.0);
1161 assert_eq!(result.data()[999], 1998.0);
1162 }
1163}