performance_optimization/performance_optimization.rs

//! # Performance Optimization - Memory-Efficient Tensor Processing
//!
//! ## Overview
//!
//! This example demonstrates performance optimization techniques and memory-efficient
//! processing patterns using Train Station's tensor iterator system. It showcases
//! how to process large datasets efficiently while maintaining gradient tracking
//! and leveraging SIMD optimizations.
//!
//! ## Learning Objectives
//!
//! - Understand performance characteristics of tensor iterators
//! - Learn memory-efficient processing patterns
//! - Master optimization techniques for large-scale processing
//! - Explore benchmarking and performance analysis
//!
//! ## Prerequisites
//!
//! - Understanding of basic and advanced iterator concepts
//! - Knowledge of performance optimization principles
//! - Familiarity with memory management patterns
//! - Experience with large-scale data processing
//!
//! ## Key Concepts Demonstrated
//!
//! - **Memory Efficiency**: Zero-copy views and shared memory allocation
//! - **Performance Optimization**: SIMD utilization and batch processing
//! - **Benchmarking**: Performance measurement and analysis
//! - **Scalability**: Processing patterns for large datasets
//! - **Resource Management**: Efficient memory and computation usage
//! Recommended patterns (a minimal sketch follows this list):
//! - For element-wise transforms across a tensor, prefer `iter_flat()` + `collect_shape([..])`,
//!   which transforms and reshapes in one pass; use GradTrack when needed.
//! - For multi-dim per-row/per-slice logic, use `iter()` and assemble with `collect_shape([..])`.
//! - For large tensors, `chunks()` or `iter_fast_chunks()` can improve locality; collect with shape.
//! - For inference-only pipelines, `with_no_grad` + value streaming and `collect_shape` gives the
//!   highest throughput by avoiding view creation and GradTrack overhead.
//!
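//! As a quick illustration, here is a minimal sketch of the first and last
//! patterns (a sketch only: it mirrors how `iter_elements`, `collect_shape`,
//! and `with_no_grad` are used in the example body below, and is not compiled
//! as a doctest):
//!
//! ```ignore
//! use train_station::{
//!     gradtrack::with_no_grad,
//!     tensor::{TensorCollectExt, ValuesCollectExt},
//!     Tensor,
//! };
//!
//! // Sketch only: mirrors the API usage in this example, not a guaranteed
//! // public API surface.
//! let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
//!
//! // Grad-enabled element-wise transform: element views keep gradient tracking.
//! let doubled: Tensor = t
//!     .iter_elements()
//!     .map(|e| e.mul_scalar(2.0))
//!     .collect_shape(vec![4]);
//!
//! // Inference-only streaming: raw f32 values, no views, no GradTrack overhead.
//! let fast: Tensor = with_no_grad(|| {
//!     t.data().iter().copied().map(|x| 2.0 * x).collect_shape(vec![4])
//! });
//! ```
//!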
//! ## Example Code Structure
//!
//! 1. **Performance Benchmarking**: Measuring iterator performance characteristics
//! 2. **Memory Optimization**: Efficient memory usage patterns
//! 3. **Large-Scale Processing**: Handling big datasets efficiently
//! 4. **Optimization Techniques**: Advanced performance optimization strategies
//!
//! ## Expected Output
//!
//! The example will demonstrate performance characteristics and optimization
//! techniques, showing how to efficiently process large datasets while
//! maintaining memory efficiency and leveraging SIMD optimizations.
//!
//! ## Performance Notes
//!
//! - View creation overhead: ~64 bytes per element view
//! - SIMD operations leverage existing optimized implementations
//! - Memory sharing eliminates data copying overhead
//! - Batch processing improves cache locality

use std::time::Instant;
use train_station::{
    gradtrack::with_no_grad,
    tensor::{TensorCollectExt, ValuesCollectExt},
    Tensor,
};

/// Main example function demonstrating performance optimization
///
/// This function showcases performance optimization techniques and
/// memory-efficient processing patterns for large-scale tensor operations.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("Starting Performance Optimization Example");

    demonstrate_performance_benchmarking()?;
    demonstrate_memory_optimization()?;
    demonstrate_large_scale_processing()?;
    demonstrate_optimization_techniques()?;

    println!("Performance Optimization Example completed successfully!");
    Ok(())
}

/// Demonstrate performance benchmarking and analysis
///
/// Shows how to measure and analyze the performance characteristics
/// of tensor iterator operations and compare different approaches.
fn demonstrate_performance_benchmarking() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Performance Benchmarking ---");
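
    // NOTE: these are single-run wall-clock timings and will be noisy; for
    // reliable numbers, repeat each measurement and compare medians (or use a
    // benchmarking harness such as criterion).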

    // Create test data of different sizes
    let sizes = vec![100, 1000, 10000];

    for size in sizes {
        println!("\nBenchmarking with tensor size: {}", size);

        // Generate test data
        let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
        let tensor = Tensor::from_slice(&data, vec![size])?;

        // Benchmark 1: Direct tensor operations
        let start = Instant::now();
        let direct_result = tensor.mul_scalar(2.0).add_scalar(1.0);
        let direct_time = start.elapsed();

        // Benchmark 2: Iterator-based operations (grad-enabled views, flatten + collect_shape)
        let start = Instant::now();
        let iterator_result: Tensor = tensor
            .iter_elements()
            .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0))
            .collect_shape(vec![size]);
        let iterator_time = start.elapsed();

        // Benchmark 3: Chained iterator operations
        let start = Instant::now();
        let _chained_result: Tensor = tensor
            .iter_elements()
            .map(|elem| elem.mul_scalar(2.0))
            .filter(|elem| elem.value() > size as f32)
            .map(|elem| elem.add_scalar(1.0))
            .collect();
        let chained_time = start.elapsed();
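
        // The chained pipeline filters elements away, so its output is smaller
        // than the input; its timing is reported for reference only and its
        // result is not compared against the direct path below.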

        // Benchmark 4: NoGrad raw data streaming + collect_shape
        let start = Instant::now();
        let _streamed: Tensor = with_no_grad(|| {
            tensor
                .data()
                .iter()
                .copied()
                .map(|x| 2.0 * x + 1.0)
                .collect_shape(vec![size])
        });
        let streaming_time = start.elapsed();
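
        // Streaming raw values under `with_no_grad` skips per-element view
        // creation and GradTrack bookkeeping, which is why it is typically the
        // fastest path for inference-only work (see the module notes above).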

        // Report results
        println!("  Direct operations: {:?}", direct_time);
        println!("  Iterator operations: {:?}", iterator_time);
        println!("  Chained operations: {:?}", chained_time);
        println!("  NoGrad streaming (data.iter): {:?}", streaming_time);

        // Verify correctness
        assert_eq!(direct_result.data(), iterator_result.data());
        println!(
            "  Results match: {}",
            direct_result.data() == iterator_result.data()
        );

        // Performance ratios
        let ratio = iterator_time.as_nanos() as f64 / direct_time.as_nanos() as f64;
        let ratio_stream = streaming_time.as_nanos() as f64 / direct_time.as_nanos() as f64;
        println!("  Iterator/Direct ratio: {:.2}x", ratio);
        println!("  Streaming/Direct ratio: {:.2}x", ratio_stream);
    }

    Ok(())
}

/// Demonstrate memory optimization patterns
///
/// Shows memory-efficient processing patterns and techniques
/// for minimizing memory usage while maintaining performance.
fn demonstrate_memory_optimization() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Memory Optimization ---");

    // Create a large tensor for memory testing
    let size = 10000;
    let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
    let tensor = Tensor::from_slice(&data, vec![size])?;

    println!("Processing tensor of size: {}", size);

    // Pattern 1: Streaming processing with iterator chunks (process in blocks, collect with shape)
    println!("\nPattern 1: Streaming Processing");
    let chunk_size = 1000;
    let start = Instant::now();
    let flattened = tensor.view(vec![size as i32]);
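    // Processing block-by-block keeps each chunk's working set small, which
    // improves cache locality on large tensors (see the module notes above).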
    let _streamed_result: Tensor = flattened
        .chunks(chunk_size)
        .map(|c| c.pow_scalar(2.0).sqrt())
        .collect_shape(vec![size]);
    let streamed_time = start.elapsed();

    // Pattern 2: Full processing (baseline for comparison)
    let start = Instant::now();
    let _full_result: Tensor = tensor
        .iter_elements()
        .map(|elem| elem.pow_scalar(2.0).sqrt())
        .collect_shape(vec![size]);
    let full_time = start.elapsed();

    println!("  Streaming time: {:?}", streamed_time);
    println!("  Full processing time: {:?}", full_time);
    println!(
        "  Streaming speedup vs full processing: {:.2}x",
        full_time.as_nanos() as f64 / streamed_time.as_nanos() as f64
    );

    // Pattern 3: Lazy evaluation with take
    println!("\nPattern 3: Lazy Evaluation");
    let start = Instant::now();
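    // `take(n)` short-circuits the iterator, so only the first n element views
    // are ever created; the rest of the tensor is never touched.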
    let lazy_result: Tensor = tensor
        .iter_elements()
        .take(1000) // Only process first 1000 elements
        .map(|elem| elem.pow_scalar(2.0).sqrt())
        .collect_shape(vec![1000]);
    let lazy_time = start.elapsed();

    println!("  Lazy processing (1000 elements): {:?}", lazy_time);
    println!("  Lazy result size: {}", lazy_result.size());

    // Pattern 4: Memory-efficient filtering
    println!("\nPattern 4: Memory-Efficient Filtering");
    let start = Instant::now();
    let filtered_result: Tensor = tensor
        .iter_elements()
        .filter(|elem| elem.value() > size as f32 / 2.0) // Keep only large values
        .map(|elem| elem.mul_scalar(2.0))
        .collect();
    let filtered_time = start.elapsed();

    println!("  Filtered processing: {:?}", filtered_time);
    println!(
        "  Filtered result size: {} (reduced from {})",
        filtered_result.size(),
        size
    );

    Ok(())
}

/// Demonstrate large-scale processing techniques
///
/// Shows how to efficiently process very large datasets using
/// iterator patterns and optimization strategies.
fn demonstrate_large_scale_processing() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Large-Scale Processing ---");

    // Simulate large dataset processing
    let sizes = vec![10000, 50000, 100000];

    for size in sizes {
        println!("\nProcessing dataset of size: {}", size);

        // Generate large dataset
        let data: Vec<f32> = (0..size)
            .map(|i| {
                let x = i as f32 / size as f32;
                x * x + 0.1 * (i % 10) as f32 // Quadratic with noise
            })
            .collect();

        let tensor = Tensor::from_slice(&data, vec![size])?;

        // Technique 1: Batch processing
        let batch_size = 1000;
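        // Each batch is collected into its own tensor; a real pipeline would
        // typically reduce each batch (e.g. sum it) or concatenate the batch
        // outputs into a single result afterwards.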
        let start = Instant::now();

        let mut batch_results = Vec::new();
        for batch_start in (0..size).step_by(batch_size) {
            let batch_end = (batch_start + batch_size).min(size);
            let batch: Tensor = tensor
                .iter_range(batch_start, batch_end)
                .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
                .collect();
            batch_results.push(batch);
        }
        let batch_time = start.elapsed();

        // Technique 2: Parallel-like processing with stride
        let start = Instant::now();
        let stride = 4;
        let strided_result: Tensor = tensor
            .iter()
            .enumerate()
            .filter(|(i, _)| i % stride == 0)
            .map(|(_, elem)| elem.pow_scalar(2.0).add_scalar(1.0))
            .collect();
        let strided_time = start.elapsed();

        // Technique 3: Hierarchical processing
        let start = Instant::now();
        let coarse: Tensor = tensor
            .iter()
            .enumerate()
            .filter(|(i, _)| i % 10 == 0) // Every 10th element
            .map(|(_, elem)| elem.pow_scalar(2.0).add_scalar(1.0))
            .collect();
        let fine: Tensor = tensor
            .iter()
            .enumerate()
            .filter(|(i, _)| i % 10 != 0) // Rest of elements
            .map(|(_, elem)| elem.pow_scalar(1.5).add_scalar(0.5))
            .collect();
        let hierarchical_time = start.elapsed();

        // Report performance
        println!("  Batch processing: {:?}", batch_time);
        println!("  Strided processing: {:?}", strided_time);
        println!("  Hierarchical processing: {:?}", hierarchical_time);

        // Memory usage analysis
        let total_batches = size.div_ceil(batch_size);
        println!("  Batch count: {}", total_batches);
        println!("  Strided result size: {}", strided_result.size());
        println!(
            "  Hierarchical: coarse={}, fine={}",
            coarse.size(),
            fine.size()
        );
    }

    Ok(())
}

/// Demonstrate advanced optimization techniques
///
/// Shows sophisticated optimization strategies and techniques
/// for maximizing performance in tensor iterator operations.
fn demonstrate_optimization_techniques() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Optimization Techniques ---");

    let size = 50000;
    let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
    let tensor = Tensor::from_slice(&data, vec![size])?;

    println!("Optimizing processing for size: {}", size);

    // Technique 1: Operation fusion
    println!("\nTechnique 1: Operation Fusion");
    let start = Instant::now();
    let fused_result: Tensor = tensor
        .iter()
        .map(|elem| {
            // Fuse multiple operations into a single chain
            elem.mul_scalar(2.0).add_scalar(1.0).pow_scalar(2.0).sqrt()
        })
        .collect();
    let fused_time = start.elapsed();

    // Technique 2: Conditional optimization
    println!("\nTechnique 2: Conditional Optimization");
    let start = Instant::now();
    let conditional_result: Tensor = tensor
        .iter()
        .map(|elem| {
            let val = elem.value();
            if val < size as f32 / 2.0 {
                elem.mul_scalar(2.0) // Simple operation for small values
            } else {
                elem.pow_scalar(2.0).sqrt() // Complex operation for large values
            }
        })
        .collect();
    let conditional_time = start.elapsed();

    // Technique 3: Cache-friendly processing
    println!("\nTechnique 3: Cache-Friendly Processing");
    let start = Instant::now();
    let cache_friendly_result: Tensor = tensor
        .iter()
        .take(1000) // Process only a cache-sized prefix (a full pipeline would repeat this per block)
        .map(|elem| elem.mul_scalar(2.0))
        .collect();
    let cache_friendly_time = start.elapsed();

    // Technique 4: Memory pooling simulation
    println!("\nTechnique 4: Memory Pooling Simulation");
    let start = Instant::now();
    let pooled_result: Tensor = tensor
        .iter()
        .enumerate()
        .filter(|(i, _)| i % 100 == 0) // Process every 100th element
        .map(|(_, elem)| elem.pow_scalar(2.0))
        .collect();
    let pooled_time = start.elapsed();

    // Report optimization results
    println!("  Fused operations: {:?}", fused_time);
    println!("  Conditional optimization: {:?}", conditional_time);
    println!("  Cache-friendly processing: {:?}", cache_friendly_time);
    println!("  Memory pooling simulation: {:?}", pooled_time);

    // Performance analysis
    let fastest = fused_time
        .min(conditional_time)
        .min(cache_friendly_time)
        .min(pooled_time);
    println!("  Fastest technique: {:?}", fastest);
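
    // NOTE: these techniques process different numbers of elements, so the
    // timings above are illustrative rather than a like-for-like comparison.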

    // Memory efficiency analysis
    println!("  Fused result size: {}", fused_result.size());
    println!("  Conditional result size: {}", conditional_result.size());
    println!(
        "  Cache-friendly result size: {}",
        cache_friendly_result.size()
    );
    println!("  Pooled result size: {}", pooled_result.size());

    // Technique 5: Gradient optimization
    println!("\nTechnique 5: Gradient Optimization");
    let grad_tensor = tensor.with_requires_grad();
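    // Element views collected into a new tensor retain their GradTrack links
    // to the source, so backward on the summed loss propagates gradients all
    // the way back to `grad_tensor`.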
    let start = Instant::now();

    let grad_result: Tensor = grad_tensor
        .iter()
        .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
        .collect();

    let mut loss = grad_result.sum();
    loss.backward(None);
    let grad_time = start.elapsed();

    println!("  Gradient computation: {:?}", grad_time);
    println!(
        "  Gradient tracking enabled: {}",
        grad_result.requires_grad()
    );

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Test performance benchmarking
    #[test]
    fn test_performance_benchmarking() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let direct = tensor.mul_scalar(2.0);
        let iterator: Tensor = tensor.iter().map(|elem| elem.mul_scalar(2.0)).collect();

        assert_eq!(direct.data(), iterator.data());
    }

    /// Test memory optimization
    #[test]
    fn test_memory_optimization() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let streamed: Tensor = tensor
            .iter_range(0, 2)
            .map(|elem| elem.mul_scalar(2.0))
            .collect();

        assert_eq!(streamed.data(), &[2.0, 4.0]);
    }

    /// Test large-scale processing
    #[test]
    fn test_large_scale_processing() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let strided: Tensor = tensor
            .iter()
            .enumerate()
            .filter(|(i, _)| i % 2 == 0)
            .map(|(_, elem)| elem)
            .collect();

        assert_eq!(strided.data(), &[1.0, 3.0]);
    }

    /// Test optimization techniques
    #[test]
    fn test_optimization_techniques() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let fused: Tensor = tensor
            .iter()
            .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0))
            .collect();

        assert_eq!(fused.data(), &[3.0, 5.0, 7.0]);
    }
}