performance_optimization/
performance_optimization.rs1use std::time::Instant;
53use train_station::Tensor;
54
55fn main() -> Result<(), Box<dyn std::error::Error>> {
60 println!("Starting Performance Optimization Example");
61
62 demonstrate_performance_benchmarking()?;
63 demonstrate_memory_optimization()?;
64 demonstrate_large_scale_processing()?;
65 demonstrate_optimization_techniques()?;
66
67 println!("Performance Optimization Example completed successfully!");
68 Ok(())
69}
70
71fn demonstrate_performance_benchmarking() -> Result<(), Box<dyn std::error::Error>> {
76 println!("\n--- Performance Benchmarking ---");
77
78 let sizes = vec![100, 1000, 10000];
80
81 for size in sizes {
82 println!("\nBenchmarking with tensor size: {}", size);
83
84 let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
86 let tensor = Tensor::from_slice(&data, vec![size])?;
87
88 let start = Instant::now();
90 let direct_result = tensor.mul_scalar(2.0).add_scalar(1.0);
91 let direct_time = start.elapsed();
92
93 let start = Instant::now();
95 let iterator_result: Tensor = tensor
96 .iter()
97 .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0))
98 .collect();
99 let iterator_time = start.elapsed();
100
101 let start = Instant::now();
103 let _chained_result: Tensor = tensor
104 .iter()
105 .map(|elem| elem.mul_scalar(2.0))
106 .filter(|elem| elem.value() > size as f32)
107 .map(|elem| elem.add_scalar(1.0))
108 .collect();
109 let chained_time = start.elapsed();
110
111 println!(" Direct operations: {:?}", direct_time);
113 println!(" Iterator operations: {:?}", iterator_time);
114 println!(" Chained operations: {:?}", chained_time);
115
116 assert_eq!(direct_result.data(), iterator_result.data());
118 println!(
119 " Results match: {}",
120 direct_result.data() == iterator_result.data()
121 );
122
123 let ratio = iterator_time.as_nanos() as f64 / direct_time.as_nanos() as f64;
125 println!(" Iterator/Direct ratio: {:.2}x", ratio);
126 }
127
128 Ok(())
129}
130
131fn demonstrate_memory_optimization() -> Result<(), Box<dyn std::error::Error>> {
136 println!("\n--- Memory Optimization ---");
137
138 let size = 10000;
140 let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
141 let tensor = Tensor::from_slice(&data, vec![size])?;
142
143 println!("Processing tensor of size: {}", size);
144
145 println!("\nPattern 1: Streaming Processing");
147 let chunk_size = 1000;
148 let start = Instant::now();
149
150 let mut streamed_result = Vec::new();
151 for chunk_start in (0..size).step_by(chunk_size) {
152 let chunk_end = (chunk_start + chunk_size).min(size);
153 let chunk: Tensor = tensor
154 .iter_range(chunk_start, chunk_end)
155 .map(|elem| elem.pow_scalar(2.0).sqrt())
156 .collect();
157 streamed_result.extend(chunk.data().iter().cloned());
158 }
159 let streamed_time = start.elapsed();
160
161 let start = Instant::now();
163 let _full_result: Tensor = tensor
164 .iter()
165 .map(|elem| elem.pow_scalar(2.0).sqrt())
166 .collect();
167 let full_time = start.elapsed();
168
169 println!(" Streaming time: {:?}", streamed_time);
170 println!(" Full processing time: {:?}", full_time);
171 println!(
172 " Memory efficiency ratio: {:.2}x",
173 full_time.as_nanos() as f64 / streamed_time.as_nanos() as f64
174 );
175
176 println!("\nPattern 2: Lazy Evaluation");
178 let start = Instant::now();
179 let lazy_result: Tensor = tensor
180 .iter()
181 .take(1000) .map(|elem| elem.pow_scalar(2.0).sqrt())
183 .collect();
184 let lazy_time = start.elapsed();
185
186 println!(" Lazy processing (1000 elements): {:?}", lazy_time);
187 println!(" Lazy result size: {}", lazy_result.size());
188
189 println!("\nPattern 3: Memory-Efficient Filtering");
191 let start = Instant::now();
192 let filtered_result: Tensor = tensor
193 .iter()
194 .filter(|elem| elem.value() > size as f32 / 2.0) .map(|elem| elem.mul_scalar(2.0))
196 .collect();
197 let filtered_time = start.elapsed();
198
199 println!(" Filtered processing: {:?}", filtered_time);
200 println!(
201 " Filtered result size: {} (reduced from {})",
202 filtered_result.size(),
203 size
204 );
205
206 Ok(())
207}
208
209fn demonstrate_large_scale_processing() -> Result<(), Box<dyn std::error::Error>> {
214 println!("\n--- Large-Scale Processing ---");
215
216 let sizes = vec![10000, 50000, 100000];
218
219 for size in sizes {
220 println!("\nProcessing dataset of size: {}", size);
221
222 let data: Vec<f32> = (0..size)
224 .map(|i| {
225 let x = i as f32 / size as f32;
226 x * x + 0.1 * (i % 10) as f32 })
228 .collect();
229
230 let tensor = Tensor::from_slice(&data, vec![size])?;
231
232 let batch_size = 1000;
234 let start = Instant::now();
235
236 let mut batch_results = Vec::new();
237 for batch_start in (0..size).step_by(batch_size) {
238 let batch_end = (batch_start + batch_size).min(size);
239 let batch: Tensor = tensor
240 .iter_range(batch_start, batch_end)
241 .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
242 .collect();
243 batch_results.push(batch);
244 }
245 let batch_time = start.elapsed();
246
247 let start = Instant::now();
249 let stride = 4;
250 let strided_result: Tensor = tensor
251 .iter()
252 .enumerate()
253 .filter(|(i, _)| i % stride == 0)
254 .map(|(_, elem)| elem.pow_scalar(2.0).add_scalar(1.0))
255 .collect();
256 let strided_time = start.elapsed();
257
258 let start = Instant::now();
260 let coarse: Tensor = tensor
261 .iter()
262 .enumerate()
263 .filter(|(i, _)| i % 10 == 0) .map(|(_, elem)| elem.pow_scalar(2.0).add_scalar(1.0))
265 .collect();
266 let fine: Tensor = tensor
267 .iter()
268 .enumerate()
269 .filter(|(i, _)| i % 10 != 0) .map(|(_, elem)| elem.pow_scalar(1.5).add_scalar(0.5))
271 .collect();
272 let hierarchical_time = start.elapsed();
273
274 println!(" Batch processing: {:?}", batch_time);
276 println!(" Strided processing: {:?}", strided_time);
277 println!(" Hierarchical processing: {:?}", hierarchical_time);
278
279 let total_batches = (size + batch_size - 1) / batch_size;
281 println!(" Batch count: {}", total_batches);
282 println!(" Strided result size: {}", strided_result.size());
283 println!(
284 " Hierarchical: coarse={}, fine={}",
285 coarse.size(),
286 fine.size()
287 );
288 }
289
290 Ok(())
291}
292
293fn demonstrate_optimization_techniques() -> Result<(), Box<dyn std::error::Error>> {
298 println!("\n--- Optimization Techniques ---");
299
300 let size = 50000;
301 let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
302 let tensor = Tensor::from_slice(&data, vec![size])?;
303
304 println!("Optimizing processing for size: {}", size);
305
306 println!("\nTechnique 1: Operation Fusion");
308 let start = Instant::now();
309 let fused_result: Tensor = tensor
310 .iter()
311 .map(|elem| {
312 elem.mul_scalar(2.0).add_scalar(1.0).pow_scalar(2.0).sqrt()
314 })
315 .collect();
316 let fused_time = start.elapsed();
317
318 println!("\nTechnique 2: Conditional Optimization");
320 let start = Instant::now();
321 let conditional_result: Tensor = tensor
322 .iter()
323 .map(|elem| {
324 let val = elem.value();
325 if val < size as f32 / 2.0 {
326 elem.mul_scalar(2.0) } else {
328 elem.pow_scalar(2.0).sqrt() }
330 })
331 .collect();
332 let conditional_time = start.elapsed();
333
334 println!("\nTechnique 3: Cache-Friendly Processing");
336 let start = Instant::now();
337 let cache_friendly_result: Tensor = tensor
338 .iter()
339 .take(1000) .map(|elem| elem.mul_scalar(2.0))
341 .collect();
342 let cache_friendly_time = start.elapsed();
343
344 println!("\nTechnique 4: Memory Pooling Simulation");
346 let start = Instant::now();
347 let pooled_result: Tensor = tensor
348 .iter()
349 .enumerate()
350 .filter(|(i, _)| i % 100 == 0) .map(|(_, elem)| elem.pow_scalar(2.0))
352 .collect();
353 let pooled_time = start.elapsed();
354
355 println!(" Fused operations: {:?}", fused_time);
357 println!(" Conditional optimization: {:?}", conditional_time);
358 println!(" Cache-friendly processing: {:?}", cache_friendly_time);
359 println!(" Memory pooling simulation: {:?}", pooled_time);
360
361 let fastest = fused_time
363 .min(conditional_time)
364 .min(cache_friendly_time)
365 .min(pooled_time);
366 println!(" Fastest technique: {:?}", fastest);
367
368 println!(" Fused result size: {}", fused_result.size());
370 println!(" Conditional result size: {}", conditional_result.size());
371 println!(
372 " Cache-friendly result size: {}",
373 cache_friendly_result.size()
374 );
375 println!(" Pooled result size: {}", pooled_result.size());
376
377 println!("\nTechnique 5: Gradient Optimization");
379 let grad_tensor = tensor.with_requires_grad();
380 let start = Instant::now();
381
382 let grad_result: Tensor = grad_tensor
383 .iter()
384 .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
385 .collect();
386
387 let mut loss = grad_result.sum();
388 loss.backward(None);
389 let grad_time = start.elapsed();
390
391 println!(" Gradient computation: {:?}", grad_time);
392 println!(
393 " Gradient tracking enabled: {}",
394 grad_result.requires_grad()
395 );
396
397 Ok(())
398}
399
#[cfg(test)]
mod tests {
    use super::*;

    /// Direct scalar multiplication and the iterator equivalent must
    /// produce identical data.
    #[test]
    fn test_performance_benchmarking() {
        let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let via_method = input.mul_scalar(2.0);
        let via_iter: Tensor = input.iter().map(|elem| elem.mul_scalar(2.0)).collect();
        assert_eq!(via_method.data(), via_iter.data());
    }

    /// iter_range restricts processing to the requested sub-range.
    #[test]
    fn test_memory_optimization() {
        let input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let prefix: Tensor = input
            .iter_range(0, 2)
            .map(|elem| elem.mul_scalar(2.0))
            .collect();
        assert_eq!(prefix.data(), &[2.0, 4.0]);
    }

    /// Strided selection via enumerate + filter keeps every other element.
    #[test]
    fn test_large_scale_processing() {
        let input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let every_other: Tensor = input
            .iter()
            .enumerate()
            .filter(|(idx, _)| idx % 2 == 0)
            .map(|(_, elem)| elem)
            .collect();
        assert_eq!(every_other.data(), &[1.0, 3.0]);
    }

    /// Fused multiply-add inside a single map produces 2x + 1 per element.
    #[test]
    fn test_optimization_techniques() {
        let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let fused: Tensor = input
            .iter()
            .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0))
            .collect();
        assert_eq!(fused.data(), &[3.0, 5.0, 7.0]);
    }
}