element_iteration/
element_iteration.rs

//! # Element Iteration - Basic Tensor Element Processing
//!
//! ## Overview
//!
//! This example demonstrates the fundamental tensor iterator functionality in Train Station,
//! showing how to iterate over tensor elements as individual view tensors. Each element
//! becomes a proper Tensor of shape [1] that supports all existing tensor operations
//! and gradient tracking.
//!
//! ## Learning Objectives
//!
//! - Understand basic tensor element iteration
//! - Learn standard iterator trait methods
//! - Master element-wise transformations
//! - Explore gradient tracking through iterations
//!
//! ## Prerequisites
//!
//! - Basic Rust knowledge and iterator concepts
//! - Understanding of tensor basics (see getting_started/tensor_basics.rs)
//! - Familiarity with functional programming patterns
//!
//! ## Key Concepts Demonstrated
//!
//! - **Element Views**: Each element becomes a true tensor view of shape [1]
//! - **Standard Library Integration**: Full compatibility with Rust's iterator traits
//! - **Gradient Tracking**: Automatic gradient propagation through element operations
//! - **Zero-Copy Semantics**: True views with shared memory allocation
//! - **NoGrad Fast Paths**: Use `with_no_grad` and `data().iter()` to stream values directly for
//!   maximum throughput when gradients are not needed
//!
//! ## Example Code Structure
//!
//! 1. **Basic Iteration**: Simple element access and transformation
//! 2. **Standard Methods**: Using Iterator trait methods (map, filter, collect)
//! 3. **Gradient Tracking**: Demonstrating autograd through element operations
//! 4. **Advanced Patterns**: Complex iterator chains and transformations
//! 5. **NoGrad & Streaming**: Raw data iteration + accelerated collection for inference
//!
//! When to use which iterator path:
//! - Use `iter()` for multi-dim outer iteration (row-major slices). Combine with
//!   `collect_shape([..])` to preserve shape after per-slice transforms.
//! - Use `iter_flat()` for scalar element transforms requiring GradTrack. Combine with
//!   `collect_shape` to reshape the output efficiently instead of manual concatenation.
//! - Use `chunks()` or `iter_fast_chunks()` to process large 1D (or flattened) tensors
//!   in cache-friendly blocks; `collect_shape` to reassemble.
//! - For inference-only value pipelines, use `with_no_grad` + `data().iter().copied()` and
//!   `collect_shape` to stream directly into the destination tensor.
//!
//! ## Expected Output
//!
//! The example will demonstrate various iteration patterns, showing element-wise
//! transformations, gradient tracking, and performance characteristics of the
//! tensor iterator system.
//!
//! ## Performance Notes
//!
//! - View creation is O(1) per element with true zero-copy semantics
//! - Memory overhead is ~64 bytes per view tensor (no data copying)
//! - All operations leverage existing SIMD-optimized tensor implementations
//!
//! ## Next Steps
//!
//! - Explore advanced_patterns.rs for complex iterator chains
//! - Study performance_optimization.rs for large-scale processing
//! - Review tensor operations for element-wise mathematical functions

68use train_station::tensor::{TensorCollectExt, ValuesCollectExt};
69use train_station::{gradtrack::with_no_grad, Tensor};
70
71/// Main example function demonstrating basic element iteration
72///
73/// This function serves as the primary educational entry point,
74/// with extensive inline comments explaining each step.
75fn main() -> Result<(), Box<dyn std::error::Error>> {
76    println!("Starting Element Iteration Example");
77
78    demonstrate_basic_iteration()?;
79    demonstrate_standard_methods()?;
80    demonstrate_gradient_tracking()?;
81    demonstrate_advanced_patterns()?;
82    demonstrate_row_wise_collect_shape()?;
83    demonstrate_nograd_and_streaming()?;
84
85    println!("Element Iteration Example completed successfully!");
86    Ok(())
87}
88
89/// Demonstrate basic tensor element iteration
90///
91/// Shows how to create iterators over tensor elements and perform
92/// simple element-wise operations.
93fn demonstrate_basic_iteration() -> Result<(), Box<dyn std::error::Error>> {
94    println!("\n--- Basic Element Iteration ---");
95
96    // Create a simple tensor for demonstration
97    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
98    println!("Original tensor: {:?}", tensor.data());
99
100    // Basic iteration with for loop
101    println!("\nBasic iteration with for loop:");
102    for (i, element) in tensor.iter().enumerate() {
103        println!(
104            "  Element {}: value = {:.1}, shape = {:?}",
105            i,
106            element.value(),
107            element.shape().dims()
108        );
109    }
110
111    // Element-wise transformation
112    println!("\nElement-wise transformation (2x + 1):");
113    let transformed: Tensor = tensor
114        .iter()
115        .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0))
116        .collect();
117    println!("  Result: {:?}", transformed.data());
118
119    // Filtering elements
120    println!("\nFiltering elements (values > 3.0):");
121    let filtered: Tensor = tensor.iter().filter(|elem| elem.value() > 3.0).collect();
122    println!("  Filtered: {:?}", filtered.data());
123
124    Ok(())
125}
126
127/// Demonstrate standard iterator trait methods
128///
129/// Shows compatibility with Rust's standard library iterator methods
130/// and demonstrates various functional programming patterns.
131fn demonstrate_standard_methods() -> Result<(), Box<dyn std::error::Error>> {
132    println!("\n--- Standard Iterator Methods ---");
133
134    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
135
136    // Using map for transformations
137    println!("\nMap transformation (square each element):");
138    let squared: Tensor = tensor.iter().map(|elem| elem.pow_scalar(2.0)).collect();
139    println!("  Squared: {:?}", squared.data());
140
141    // Using enumerate for indexed operations
142    println!("\nEnumerate with indexed operations:");
143    let indexed: Tensor = tensor
144        .iter()
145        .enumerate()
146        .map(|(i, elem)| elem.add_scalar(i as f32))
147        .collect();
148    println!("  Indexed: {:?}", indexed.data());
149
150    // Using fold for reduction
151    println!("\nFold for sum calculation:");
152    let sum: f32 = tensor.iter().fold(0.0, |acc, elem| acc + elem.value());
153    println!("  Sum: {:.1}", sum);
154
155    // Using find for element search
156    println!("\nFind specific element:");
157    if let Some(found) = tensor.iter().find(|elem| elem.value() == 3.0) {
158        println!("  Found element with value 3.0: {:.1}", found.value());
159    }
160
161    // Using any/all for condition checking
162    println!("\nCondition checking:");
163    let all_positive = tensor.iter().all(|elem| elem.value() > 0.0);
164    let any_large = tensor.iter().any(|elem| elem.value() > 4.0);
165    println!("  All positive: {}", all_positive);
166    println!("  Any > 4.0: {}", any_large);
167
168    Ok(())
169}
170
171/// Demonstrate gradient tracking through element operations
172///
173/// Shows how gradient tracking works seamlessly through iterator
174/// operations, maintaining the computational graph for backpropagation.
175fn demonstrate_gradient_tracking() -> Result<(), Box<dyn std::error::Error>> {
176    println!("\n--- Gradient Tracking ---");
177
178    // Create a tensor with gradient tracking enabled
179    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?.with_requires_grad();
180    println!("Input tensor (requires_grad): {:?}", tensor.data());
181
182    // Perform element-wise operations through iteration
183    let result: Tensor = tensor
184        .iter()
185        .map(|elem| {
186            // Apply a complex transformation: (x^2 + 1) * 2
187            elem.pow_scalar(2.0).add_scalar(1.0).mul_scalar(2.0)
188        })
189        .collect();
190
191    println!("Result tensor: {:?}", result.data());
192    println!("Result requires_grad: {}", result.requires_grad());
193
194    // Compute gradients
195    let mut loss = result.sum();
196    loss.backward(None);
197
198    println!("Loss: {:.6}", loss.value());
199    println!("Input gradients: {:?}", tensor.grad().map(|g| g.data()));
200
201    Ok(())
202}
203
204/// Demonstrate advanced iterator patterns
205///
206/// Shows complex iterator chains and advanced functional programming
207/// patterns for sophisticated data processing workflows.
208fn demonstrate_advanced_patterns() -> Result<(), Box<dyn std::error::Error>> {
209    println!("\n--- Advanced Iterator Patterns ---");
210
211    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6])?;
212    println!("Input tensor: {:?}", tensor.data());
213
214    // Complex chain: enumerate -> filter -> map -> collect
215    println!("\nComplex chain (even indices only, add index to value):");
216    let result: Tensor = tensor
217        .iter()
218        .enumerate()
219        .filter(|(i, _)| i % 2 == 0) // Take even indices
220        .map(|(i, elem)| elem.add_scalar(i as f32)) // Add index to value
221        .collect();
222    println!("  Result: {:?}", result.data());
223
224    // Using take and skip for windowing
225    println!("\nWindowing with take and skip:");
226    let window1: Tensor = tensor.iter().take(3).collect();
227    let window2: Tensor = tensor.iter().skip(2).take(3).collect();
228    println!("  Window 1 (first 3): {:?}", window1.data());
229    println!("  Window 2 (middle 3): {:?}", window2.data());
230
231    // Using rev() for reverse iteration
232    println!("\nReverse iteration:");
233    let reversed: Tensor = tensor.iter().rev().collect();
234    println!("  Reversed: {:?}", reversed.data());
235
236    // Chaining with mathematical operations
237    println!("\nMathematical operation chain:");
238    let math_result: Tensor = tensor
239        .iter()
240        .map(|elem| elem.exp()) // e^x
241        .filter(|elem| elem.value() < 50.0) // Filter large values
242        .map(|elem| elem.log()) // ln(x)
243        .collect();
244    println!("  Math chain result: {:?}", math_result.data());
245
246    // Using zip for element-wise combinations
247    println!("\nElement-wise combination with zip:");
248    let tensor2 = Tensor::from_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], vec![6])?;
249    let combined: Tensor = tensor
250        .iter()
251        .zip(tensor2.iter())
252        .map(|(a, b)| a.mul_tensor(&b)) // Element-wise multiplication
253        .collect();
254    println!("  Combined: {:?}", combined.data());
255
256    Ok(())
257}
258
259/// Demonstrate per-row transforms with shape-preserving collection
260///
261/// Shows how to use `iter()` over the outer dimension on a 2D tensor and
262/// `collect_shape([..])` to maintain the original shape after mapping.
263fn demonstrate_row_wise_collect_shape() -> Result<(), Box<dyn std::error::Error>> {
264    println!("\n--- Row-wise iteration with collect_shape ---");
265    let mat = Tensor::from_slice(&(1..=12).map(|x| x as f32).collect::<Vec<_>>(), vec![3, 4])?;
266    println!("Input shape: {:?}", mat.shape().dims());
267
268    // Map each row: 1.1*x + 0.5, then collect back to [3,4]
269    let out: Tensor = mat
270        .iter()
271        .map(|row| row.mul_scalar(1.1).add_scalar(0.5))
272        .collect_shape(vec![3, 4]);
273    println!("  Output shape: {:?}", out.shape().dims());
274
275    Ok(())
276}
277
278/// Demonstrate NoGrad fast paths and raw data streaming
279///
280/// Highlights how to get maximum performance in inference by:
281/// - Disabling gradient tracking with `with_no_grad`
282/// - Iterating raw values via `tensor.data().iter().copied()`
283/// - Using `collect_shape` to stream directly into destination tensors
284fn demonstrate_nograd_and_streaming() -> Result<(), Box<dyn std::error::Error>> {
285    println!("\n--- NoGrad & Streaming (Inference Fast Paths) ---");
286
287    let input = Tensor::from_slice(
288        &(0..24).map(|i| i as f32 * 0.25).collect::<Vec<_>>(),
289        vec![4, 6],
290    )?;
291    println!("Input shape: {:?}", input.shape().dims());
292
293    // NoGrad: stream values directly and reshape
294    let out = with_no_grad(|| {
295        input
296            .data()
297            .iter()
298            .copied()
299            .map(|x| 1.2 * x - 0.3)
300            .collect_shape(vec![4, 6])
301    });
302    println!(
303        "  NoGrad streamed map (1.2x-0.3) -> shape {:?}",
304        out.shape().dims()
305    );
306
307    // Compare to view-based element iteration in NoGrad
308    let out_view: Tensor = with_no_grad(|| {
309        input
310            .iter()
311            .map(|e| e.mul_scalar(1.2).add_scalar(-0.3))
312            .collect_shape(vec![4, 6])
313    });
314    println!(
315        "  NoGrad view-based map shape {:?}",
316        out_view.shape().dims()
317    );
318
319    // Quick parity check
320    assert_eq!(out.data(), out_view.data());
321    println!("  Parity check passed.");
322
323    // Show simple flatten + collect back to a different shape
324    let reshaped = with_no_grad(|| input.data().iter().copied().collect_shape(vec![6, 4]));
325    println!(
326        "  Reshaped via streaming collect_shape: {:?}",
327        reshaped.shape().dims()
328    );
329
330    Ok(())
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    /// Elements yielded by `iter()` must match the source values in order.
    #[test]
    fn test_basic_iteration() {
        let source = [1.0_f32, 2.0, 3.0];
        let tensor = Tensor::from_slice(&source, vec![3]).unwrap();
        let elements: Vec<Tensor> = tensor.iter().collect();

        assert_eq!(elements.len(), source.len());
        for (element, expected) in elements.iter().zip(source.iter()) {
            assert_eq!(element.value(), *expected);
        }
    }

    /// Mapping with `mul_scalar(2.0)` should double every element.
    #[test]
    fn test_element_transformation() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let doubled: Tensor = tensor
            .iter()
            .map(|element| element.mul_scalar(2.0))
            .collect();

        assert_eq!(doubled.data(), &[2.0, 4.0, 6.0]);
    }

    /// Gradient tracking must survive an iterator map + collect round trip.
    #[test]
    fn test_gradient_tracking() {
        let tensor = Tensor::from_slice(&[1.0, 2.0], vec![2])
            .unwrap()
            .with_requires_grad();
        let result: Tensor = tensor
            .iter()
            .map(|element| element.mul_scalar(2.0))
            .collect();

        assert!(result.requires_grad());
        assert_eq!(result.data(), &[2.0, 4.0]);
    }
}