memory_efficient_example/
memory_efficient_example.rs

//! Memory-efficient neural network operations example
//!
//! This example demonstrates several memory optimization techniques:
//! - Memory pool for tensor reuse
//! - Gradient checkpointing for reduced memory usage
//! - In-place operations to minimize allocations
//! - Memory-aware batch processing
//! - A memory-efficient layer with a configurable chunk size
//! - Memory usage tracking and monitoring
9
10use ndarray::{Array2, Array3};
11use scirs2_neural::error::Result;
12use scirs2_neural::memory_efficient::{
13    BatchProcessorStats, GradientCheckpointing, InPlaceOperations, MemoryAwareBatchProcessor,
14    MemoryEfficientLayer, MemoryPool, MemoryUsage, PoolStatistics,
15};
16use std::time::Instant;
17
18fn main() -> Result<()> {
19    println!("Memory-Efficient Neural Network Operations Demo");
20    println!("===============================================");
21
22    // Demo 1: Memory Pool Usage
23    demo_memory_pool()?;
24
25    // Demo 2: Gradient Checkpointing
26    demo_gradient_checkpointing()?;
27
28    // Demo 3: In-place Operations
29    demo_in_place_operations()?;
30
31    // Demo 4: Memory-aware Batch Processing
32    demo_memory_aware_batch_processing()?;
33
34    // Demo 5: Memory-efficient Layer
35    demo_memory_efficient_layer()?;
36
37    // Demo 6: Memory Usage Tracking
38    demo_memory_usage_tracking()?;
39
40    Ok(())
41}
42
43fn demo_memory_pool() -> Result<()> {
44    println!("\nšŸ”„ Memory Pool Demo");
45    println!("------------------");
46
47    let mut pool = MemoryPool::<f32>::new(50); // 50MB max pool size
48
49    // Allocate several tensors
50    println!("Allocating tensors...");
51    let tensor1 = pool.allocate(&[1000, 500]); // ~2MB
52    let tensor2 = pool.allocate(&[500, 200]); // ~400KB
53    let tensor3 = pool.allocate(&[100, 100]); // ~40KB
54
55    let stats = pool.get_pool_stats();
56    println!("Pool stats after allocation:");
57    print_pool_stats(&stats);
58
59    // Return tensors to pool
60    println!("Returning tensors to pool...");
61    pool.deallocate(tensor1);
62    pool.deallocate(tensor2);
63    pool.deallocate(tensor3);
64
65    let stats = pool.get_pool_stats();
66    println!("Pool stats after deallocation:");
67    print_pool_stats(&stats);
68
69    // Reuse tensors
70    println!("Reusing tensors (should be faster)...");
71    let start = Instant::now();
72    let _reused1 = pool.allocate(&[1000, 500]);
73    let _reused2 = pool.allocate(&[500, 200]);
74    let reuse_time = start.elapsed();
75    println!("Reuse time: {:?}", reuse_time);
76
77    let stats = pool.get_pool_stats();
78    println!("Final pool stats:");
79    print_pool_stats(&stats);
80
81    Ok(())
82}
83
84fn demo_gradient_checkpointing() -> Result<()> {
85    println!("\nšŸ“Š Gradient Checkpointing Demo");
86    println!("------------------------------");
87
88    let mut checkpointing = GradientCheckpointing::<f64>::new(100.0); // 100MB threshold
89
90    // Set up checkpoint layers
91    checkpointing.add_checkpoint_layer("conv1".to_string());
92    checkpointing.add_checkpoint_layer("conv3".to_string());
93    checkpointing.add_checkpoint_layer("fc1".to_string());
94
95    println!("Storing activations at checkpoints...");
96
97    // Simulate storing activations during forward pass
98    let conv1_activation = Array3::from_elem((32, 64, 64), 0.5).into_dyn(); // Batch=32, 64x64 feature maps
99    let conv3_activation = Array3::from_elem((32, 128, 32), 0.3).into_dyn(); // Reduced spatial size
100    let fc1_activation = Array2::from_elem((32, 512), 0.2).into_dyn(); // Fully connected
101
102    checkpointing.store_checkpoint("conv1", conv1_activation)?;
103    checkpointing.store_checkpoint("conv3", conv3_activation)?;
104    checkpointing.store_checkpoint("fc1", fc1_activation)?;
105
106    let usage = checkpointing.get_memory_usage();
107    println!("Memory usage after checkpointing:");
108    print_memory_usage(&usage);
109
110    // Simulate retrieving checkpoints during backward pass
111    println!("Retrieving checkpoints for gradient computation...");
112    if let Some(checkpoint) = checkpointing.get_checkpoint("conv1") {
113        println!("Retrieved conv1 checkpoint: shape {:?}", checkpoint.shape());
114    }
115
116    // Clear checkpoints to free memory
117    println!("Clearing checkpoints...");
118    checkpointing.clear_checkpoints();
119
120    let usage = checkpointing.get_memory_usage();
121    println!("Memory usage after clearing:");
122    print_memory_usage(&usage);
123
124    Ok(())
125}
126
127fn demo_in_place_operations() -> Result<()> {
128    println!("\n⚔ In-place Operations Demo");
129    println!("--------------------------");
130
131    // Create test arrays
132    let mut relu_test = Array2::from_shape_vec(
133        (3, 4),
134        vec![
135            -1.0, 2.0, -3.0, 4.0, 0.5, -0.5, 1.5, -2.5, 3.0, -1.0, 0.0, 2.0,
136        ],
137    )?
138    .into_dyn();
139
140    let mut sigmoid_test =
141        Array2::from_shape_vec((2, 3), vec![-2.0, 0.0, 2.0, -1.0, 1.0, 3.0])?.into_dyn();
142
143    let mut add_test = Array2::from_elem((2, 2), 1.0).into_dyn();
144    let add_source = Array2::from_elem((2, 2), 0.5).into_dyn();
145
146    let mut norm_test =
147        Array2::from_shape_vec((2, 3), vec![1.0, 4.0, 7.0, 2.0, 5.0, 8.0])?.into_dyn();
148
149    println!("Before operations:");
150    println!("ReLU input (should clip negatives): {:?}", relu_test);
151    println!("Sigmoid input: {:?}", sigmoid_test);
152    println!("Addition target: {:?}", add_test);
153    println!("Normalization input: {:?}", norm_test);
154
155    // Apply in-place operations
156    println!("\nApplying in-place operations...");
157    InPlaceOperations::relu_inplace(&mut relu_test);
158    InPlaceOperations::sigmoid_inplace(&mut sigmoid_test);
159    InPlaceOperations::add_inplace(&mut add_test, &add_source)?;
160    InPlaceOperations::normalize_inplace(&mut norm_test)?;
161
162    println!("\nAfter operations:");
163    println!("ReLU result: {:?}", relu_test);
164    println!("Sigmoid result: {:?}", sigmoid_test);
165    println!("Addition result: {:?}", add_test);
166    println!("Normalized result: {:?}", norm_test);
167
168    // Test scaling
169    let mut scale_test = Array2::from_elem((2, 2), 2.0).into_dyn();
170    println!("\nScaling test - before: {:?}", scale_test);
171    InPlaceOperations::scale_inplace(&mut scale_test, 3.0);
172    println!("Scaling test - after: {:?}", scale_test);
173
174    Ok(())
175}
176
177fn demo_memory_aware_batch_processing() -> Result<()> {
178    println!("\nšŸ”€ Memory-Aware Batch Processing Demo");
179    println!("------------------------------------");
180
181    let mut processor = MemoryAwareBatchProcessor::<f32>::new(
182        200,   // 200MB max memory
183        150.0, // 150MB threshold
184        50,    // 50MB pool size
185    );
186
187    // Create a large dataset that needs to be processed in chunks
188    println!("Creating large dataset (1000 samples x 784 features)...");
189    let large_dataset = Array2::from_shape_fn((1000, 784), |(i, j)| {
190        (i as f32 * 0.01 + j as f32 * 0.001).sin()
191    })
192    .into_dyn();
193
194    println!("Dataset shape: {:?}", large_dataset.shape());
195    println!(
196        "Estimated memory: {:.2} MB",
197        (large_dataset.len() * std::mem::size_of::<f32>()) as f64 / (1024.0 * 1024.0)
198    );
199
200    // Process in memory-aware batches
201    println!("Processing with automatic batch size adjustment...");
202    let start = Instant::now();
203
204    let results = processor.process_batches(&large_dataset, |batch| {
205        // Simulate some processing (e.g., forward pass through a layer)
206        let processed = batch.mapv(|x| x.tanh()); // Apply activation
207        Ok(processed.to_owned())
208    })?;
209
210    let processing_time = start.elapsed();
211
212    println!("Processing completed in {:?}", processing_time);
213    println!("Number of result batches: {}", results.len());
214
215    // Print statistics
216    let stats = processor.get_stats();
217    println!("Batch processor statistics:");
218    print_batch_processor_stats(&stats);
219
220    Ok(())
221}
222
223fn demo_memory_efficient_layer() -> Result<()> {
224    println!("\n🧠 Memory-Efficient Layer Demo");
225    println!("------------------------------");
226
227    // Create a memory-efficient layer
228    let layer = MemoryEfficientLayer::new(
229        784,      // Input size (e.g., 28x28 MNIST)
230        128,      // Output size
231        Some(64), // Chunk size
232    )?;
233
234    println!("Created memory-efficient layer: 784 -> 128");
235
236    // Create input data
237    let input =
238        Array2::from_shape_fn((256, 784), |(i, j)| ((i + j) as f32 * 0.01).sin()).into_dyn();
239
240    println!("Input shape: {:?}", input.shape());
241
242    // Forward pass
243    println!("Performing forward pass...");
244    let start = Instant::now();
245    let output = layer.forward(&input)?;
246    let forward_time = start.elapsed();
247
248    println!("Forward pass completed in {:?}", forward_time);
249    println!("Output shape: {:?}", output.shape());
250
251    // Verify output statistics
252    let mean = output.mean().unwrap_or(0.0);
253    let std = {
254        let variance = output.mapv(|x| (x - mean).powi(2)).mean().unwrap_or(0.0);
255        variance.sqrt()
256    };
257
258    println!("Output statistics:");
259    println!("  Mean: {:.6}", mean);
260    println!("  Std: {:.6}", std);
261    println!(
262        "  Min: {:.6}",
263        output.iter().cloned().fold(f32::INFINITY, f32::min)
264    );
265    println!(
266        "  Max: {:.6}",
267        output.iter().cloned().fold(f32::NEG_INFINITY, f32::max)
268    );
269
270    Ok(())
271}
272
273fn demo_memory_usage_tracking() -> Result<()> {
274    println!("\nšŸ“ˆ Memory Usage Tracking Demo");
275    println!("-----------------------------");
276
277    let mut usage = MemoryUsage::new();
278
279    println!("Initial state:");
280    print_memory_usage(&usage);
281
282    // Simulate various allocation patterns
283    println!("\nSimulating allocation patterns...");
284
285    // Large allocation
286    usage.allocate(50 * 1024 * 1024); // 50MB
287    println!("After 50MB allocation:");
288    print_memory_usage(&usage);
289
290    // Multiple small allocations
291    for i in 1..=10 {
292        usage.allocate(1024 * 1024); // 1MB each
293        if i % 3 == 0 {
294            println!("After {} small allocations:", i);
295            print_memory_usage(&usage);
296        }
297    }
298
299    // Peak usage reached
300    println!("Peak memory usage reached:");
301    print_memory_usage(&usage);
302
303    // Simulate deallocations
304    println!("\nSimulating deallocations...");
305    for i in 1..=8 {
306        usage.deallocate(5 * 1024 * 1024); // 5MB each
307        if i % 2 == 0 {
308            println!("After {} deallocations:", i);
309            print_memory_usage(&usage);
310        }
311    }
312
313    println!("\nFinal state (note peak is preserved):");
314    print_memory_usage(&usage);
315
316    Ok(())
317}
318
319// Helper functions for pretty printing
320
321fn print_pool_stats(stats: &PoolStatistics) {
322    println!("  Cached tensors: {}", stats.total_cached_tensors);
323    println!("  Unique shapes: {}", stats.unique_shapes);
324    println!(
325        "  Pool size: {:.2}/{:.2} MB",
326        stats.current_pool_size_mb, stats.max_pool_size_mb
327    );
328}
329
330fn print_memory_usage(usage: &MemoryUsage) {
331    println!("  Current: {:.2} MB", usage.current_mb());
332    println!("  Peak: {:.2} MB", usage.peak_mb());
333    println!("  Active allocations: {}", usage.active_allocations);
334    println!("  Total allocations: {}", usage.total_allocations);
335}
336
337fn print_batch_processor_stats(stats: &BatchProcessorStats) {
338    println!("  Max batch size: {}", stats.max_batch_size);
339    println!("  Current memory: {:.2} MB", stats.current_memory_mb);
340    println!("  Peak memory: {:.2} MB", stats.peak_memory_mb);
341    println!("  Memory threshold: {:.2} MB", stats.memory_threshold_mb);
342    println!("  Pool stats:");
343    print_pool_stats(&stats.pool_stats);
344}