use ndarray::{Array2, Array3};
use scirs2_neural::error::Result;
use scirs2_neural::memory_efficient::{
    BatchProcessorStats, GradientCheckpointing, InPlaceOperations, MemoryAwareBatchProcessor,
    MemoryEfficientLayer, MemoryPool, MemoryUsage, PoolStatistics,
};
use std::time::Instant;

18fn main() -> Result<()> {
19 println!("Memory-Efficient Neural Network Operations Demo");
20 println!("===============================================");
21
22 demo_memory_pool()?;
24
25 demo_gradient_checkpointing()?;
27
28 demo_in_place_operations()?;
30
31 demo_memory_aware_batch_processing()?;
33
34 demo_memory_efficient_layer()?;
36
37 demo_memory_usage_tracking()?;
39
40 Ok(())
41}
42
43fn demo_memory_pool() -> Result<()> {
44 println!("\nš Memory Pool Demo");
45 println!("------------------");
46
47 let mut pool = MemoryPool::<f32>::new(50); println!("Allocating tensors...");
51 let tensor1 = pool.allocate(&[1000, 500]); let tensor2 = pool.allocate(&[500, 200]); let tensor3 = pool.allocate(&[100, 100]); let stats = pool.get_pool_stats();
56 println!("Pool stats after allocation:");
57 print_pool_stats(&stats);
58
59 println!("Returning tensors to pool...");
61 pool.deallocate(tensor1);
62 pool.deallocate(tensor2);
63 pool.deallocate(tensor3);
64
65 let stats = pool.get_pool_stats();
66 println!("Pool stats after deallocation:");
67 print_pool_stats(&stats);
68
69 println!("Reusing tensors (should be faster)...");
71 let start = Instant::now();
72 let _reused1 = pool.allocate(&[1000, 500]);
73 let _reused2 = pool.allocate(&[500, 200]);
74 let reuse_time = start.elapsed();
75 println!("Reuse time: {:?}", reuse_time);
76
77 let stats = pool.get_pool_stats();
78 println!("Final pool stats:");
79 print_pool_stats(&stats);
80
81 Ok(())
82}
83
84fn demo_gradient_checkpointing() -> Result<()> {
85 println!("\nš Gradient Checkpointing Demo");
86 println!("------------------------------");
87
88 let mut checkpointing = GradientCheckpointing::<f64>::new(100.0); checkpointing.add_checkpoint_layer("conv1".to_string());
92 checkpointing.add_checkpoint_layer("conv3".to_string());
93 checkpointing.add_checkpoint_layer("fc1".to_string());
94
95 println!("Storing activations at checkpoints...");
96
97 let conv1_activation = Array3::from_elem((32, 64, 64), 0.5).into_dyn(); let conv3_activation = Array3::from_elem((32, 128, 32), 0.3).into_dyn(); let fc1_activation = Array2::from_elem((32, 512), 0.2).into_dyn(); checkpointing.store_checkpoint("conv1", conv1_activation)?;
103 checkpointing.store_checkpoint("conv3", conv3_activation)?;
104 checkpointing.store_checkpoint("fc1", fc1_activation)?;
105
106 let usage = checkpointing.get_memory_usage();
107 println!("Memory usage after checkpointing:");
108 print_memory_usage(&usage);
109
110 println!("Retrieving checkpoints for gradient computation...");
112 if let Some(checkpoint) = checkpointing.get_checkpoint("conv1") {
113 println!("Retrieved conv1 checkpoint: shape {:?}", checkpoint.shape());
114 }
115
116 println!("Clearing checkpoints...");
118 checkpointing.clear_checkpoints();
119
120 let usage = checkpointing.get_memory_usage();
121 println!("Memory usage after clearing:");
122 print_memory_usage(&usage);
123
124 Ok(())
125}
126
127fn demo_in_place_operations() -> Result<()> {
128 println!("\nā” In-place Operations Demo");
129 println!("--------------------------");
130
131 let mut relu_test = Array2::from_shape_vec(
133 (3, 4),
134 vec![
135 -1.0, 2.0, -3.0, 4.0, 0.5, -0.5, 1.5, -2.5, 3.0, -1.0, 0.0, 2.0,
136 ],
137 )?
138 .into_dyn();
139
140 let mut sigmoid_test =
141 Array2::from_shape_vec((2, 3), vec![-2.0, 0.0, 2.0, -1.0, 1.0, 3.0])?.into_dyn();
142
143 let mut add_test = Array2::from_elem((2, 2), 1.0).into_dyn();
144 let add_source = Array2::from_elem((2, 2), 0.5).into_dyn();
145
146 let mut norm_test =
147 Array2::from_shape_vec((2, 3), vec![1.0, 4.0, 7.0, 2.0, 5.0, 8.0])?.into_dyn();
148
149 println!("Before operations:");
150 println!("ReLU input (should clip negatives): {:?}", relu_test);
151 println!("Sigmoid input: {:?}", sigmoid_test);
152 println!("Addition target: {:?}", add_test);
153 println!("Normalization input: {:?}", norm_test);
154
155 println!("\nApplying in-place operations...");
157 InPlaceOperations::relu_inplace(&mut relu_test);
158 InPlaceOperations::sigmoid_inplace(&mut sigmoid_test);
159 InPlaceOperations::add_inplace(&mut add_test, &add_source)?;
160 InPlaceOperations::normalize_inplace(&mut norm_test)?;
161
162 println!("\nAfter operations:");
163 println!("ReLU result: {:?}", relu_test);
164 println!("Sigmoid result: {:?}", sigmoid_test);
165 println!("Addition result: {:?}", add_test);
166 println!("Normalized result: {:?}", norm_test);
167
168 let mut scale_test = Array2::from_elem((2, 2), 2.0).into_dyn();
170 println!("\nScaling test - before: {:?}", scale_test);
171 InPlaceOperations::scale_inplace(&mut scale_test, 3.0);
172 println!("Scaling test - after: {:?}", scale_test);
173
174 Ok(())
175}
176
177fn demo_memory_aware_batch_processing() -> Result<()> {
178 println!("\nš Memory-Aware Batch Processing Demo");
179 println!("------------------------------------");
180
181 let mut processor = MemoryAwareBatchProcessor::<f32>::new(
182 200, 150.0, 50, );
186
187 println!("Creating large dataset (1000 samples x 784 features)...");
189 let large_dataset = Array2::from_shape_fn((1000, 784), |(i, j)| {
190 (i as f32 * 0.01 + j as f32 * 0.001).sin()
191 })
192 .into_dyn();
193
194 println!("Dataset shape: {:?}", large_dataset.shape());
195 println!(
196 "Estimated memory: {:.2} MB",
197 (large_dataset.len() * std::mem::size_of::<f32>()) as f64 / (1024.0 * 1024.0)
198 );
199
200 println!("Processing with automatic batch size adjustment...");
202 let start = Instant::now();
203
204 let results = processor.process_batches(&large_dataset, |batch| {
205 let processed = batch.mapv(|x| x.tanh()); Ok(processed.to_owned())
208 })?;
209
210 let processing_time = start.elapsed();
211
212 println!("Processing completed in {:?}", processing_time);
213 println!("Number of result batches: {}", results.len());
214
215 let stats = processor.get_stats();
217 println!("Batch processor statistics:");
218 print_batch_processor_stats(&stats);
219
220 Ok(())
221}
222
223fn demo_memory_efficient_layer() -> Result<()> {
224 println!("\nš§ Memory-Efficient Layer Demo");
225 println!("------------------------------");
226
227 let layer = MemoryEfficientLayer::new(
229 784, 128, Some(64), )?;
233
234 println!("Created memory-efficient layer: 784 -> 128");
235
236 let input =
238 Array2::from_shape_fn((256, 784), |(i, j)| ((i + j) as f32 * 0.01).sin()).into_dyn();
239
240 println!("Input shape: {:?}", input.shape());
241
242 println!("Performing forward pass...");
244 let start = Instant::now();
245 let output = layer.forward(&input)?;
246 let forward_time = start.elapsed();
247
248 println!("Forward pass completed in {:?}", forward_time);
249 println!("Output shape: {:?}", output.shape());
250
251 let mean = output.mean().unwrap_or(0.0);
253 let std = {
254 let variance = output.mapv(|x| (x - mean).powi(2)).mean().unwrap_or(0.0);
255 variance.sqrt()
256 };
257
258 println!("Output statistics:");
259 println!(" Mean: {:.6}", mean);
260 println!(" Std: {:.6}", std);
261 println!(
262 " Min: {:.6}",
263 output.iter().cloned().fold(f32::INFINITY, f32::min)
264 );
265 println!(
266 " Max: {:.6}",
267 output.iter().cloned().fold(f32::NEG_INFINITY, f32::max)
268 );
269
270 Ok(())
271}
272
273fn demo_memory_usage_tracking() -> Result<()> {
274 println!("\nš Memory Usage Tracking Demo");
275 println!("-----------------------------");
276
277 let mut usage = MemoryUsage::new();
278
279 println!("Initial state:");
280 print_memory_usage(&usage);
281
282 println!("\nSimulating allocation patterns...");
284
285 usage.allocate(50 * 1024 * 1024); println!("After 50MB allocation:");
288 print_memory_usage(&usage);
289
290 for i in 1..=10 {
292 usage.allocate(1024 * 1024); if i % 3 == 0 {
294 println!("After {} small allocations:", i);
295 print_memory_usage(&usage);
296 }
297 }
298
299 println!("Peak memory usage reached:");
301 print_memory_usage(&usage);
302
303 println!("\nSimulating deallocations...");
305 for i in 1..=8 {
306 usage.deallocate(5 * 1024 * 1024); if i % 2 == 0 {
308 println!("After {} deallocations:", i);
309 print_memory_usage(&usage);
310 }
311 }
312
313 println!("\nFinal state (note peak is preserved):");
314 print_memory_usage(&usage);
315
316 Ok(())
317}
318
319fn print_pool_stats(stats: &PoolStatistics) {
322 println!(" Cached tensors: {}", stats.total_cached_tensors);
323 println!(" Unique shapes: {}", stats.unique_shapes);
324 println!(
325 " Pool size: {:.2}/{:.2} MB",
326 stats.current_pool_size_mb, stats.max_pool_size_mb
327 );
328}
329
330fn print_memory_usage(usage: &MemoryUsage) {
331 println!(" Current: {:.2} MB", usage.current_mb());
332 println!(" Peak: {:.2} MB", usage.peak_mb());
333 println!(" Active allocations: {}", usage.active_allocations);
334 println!(" Total allocations: {}", usage.total_allocations);
335}
336
337fn print_batch_processor_stats(stats: &BatchProcessorStats) {
338 println!(" Max batch size: {}", stats.max_batch_size);
339 println!(" Current memory: {:.2} MB", stats.current_memory_mb);
340 println!(" Peak memory: {:.2} MB", stats.peak_memory_mb);
341 println!(" Memory threshold: {:.2} MB", stats.memory_threshold_mb);
342 println!(" Pool stats:");
343 print_pool_stats(&stats.pool_stats);
344}