// scirs2_autograd/optimization/memory_optimization.rs
//! Memory optimization for computation graphs
//!
//! This module provides memory optimization techniques including gradient
//! checkpointing, memory pooling, and tensor lifetime analysis.

use super::OptimizationError;
use crate::graph::Graph;
use crate::Float;
use std::collections::HashMap;

11/// Memory optimizer for computation graphs
12pub struct MemoryOptimizer<F: Float> {
13    /// Configuration for memory optimization
14    config: MemoryOptimizationConfig,
15    /// Analysis results
16    analysis: Option<MemoryAnalysis>,
17    _phantom: std::marker::PhantomData<F>,
18}
19
20impl<F: Float> MemoryOptimizer<F> {
21    /// Create a new memory optimizer
22    pub fn new() -> Self {
23        Self {
24            config: MemoryOptimizationConfig::default(),
25            analysis: None,
26            _phantom: std::marker::PhantomData,
27        }
28    }
29
30    /// Create a memory optimizer with custom configuration
31    pub fn with_config(config: MemoryOptimizationConfig) -> Self {
32        Self {
33            config,
34            analysis: None,
35            _phantom: std::marker::PhantomData,
36        }
37    }
38
39    /// Optimize memory usage in a computation graph
40    pub fn optimize(
41        &mut self,
42        graph: &mut Graph<F>,
43    ) -> Result<MemoryOptimizationReport, OptimizationError> {
44        let mut report = MemoryOptimizationReport::new();
45
46        // Analyze memory usage patterns
47        self.analysis = Some(self.analyze_memory_usage(graph)?);
48
49        if self.config.enable_gradient_checkpointing {
50            let checkpoints = self.apply_gradient_checkpointing(graph)?;
51            report.gradient_checkpoints_added = checkpoints;
52        }
53
54        if self.config.enable_memory_pooling {
55            let pools = self.setup_memory_pooling(graph)?;
56            report.memory_pools_created = pools;
57        }
58
59        if self.config.enable_in_place_operations {
60            let in_place_ops = self.apply_in_place_operations(graph)?;
61            report.in_place_operations_applied = in_place_ops;
62        }
63
64        if self.config.enable_tensor_reuse {
65            let reused = self.apply_tensor_reuse(graph)?;
66            report.tensors_reused = reused;
67        }
68
69        if self.config.enable_lifetime_optimization {
70            let optimized = self.optimize_tensor_lifetimes(graph)?;
71            report.lifetime_optimizations = optimized;
72        }
73
74        Ok(report)
75    }
76
77    /// Analyze memory usage patterns in the graph
78    fn analyze_memory_usage(&self, graph: &Graph<F>) -> Result<MemoryAnalysis, OptimizationError> {
79        let mut analysis = MemoryAnalysis::new();
80
81        // Analyze:
82        // - Tensor sizes and lifetimes
83        // - Memory allocation patterns
84        // - Peak memory usage
85        // - Opportunities for optimization
86
87        analysis.total_memory_allocated = 1024 * 1024; // Placeholder
88        analysis.peak_memory_usage = 512 * 1024; // Placeholder
89        analysis.num_allocations = 100; // Placeholder
90        analysis.num_deallocations = 90; // Placeholder
91
92        Ok(analysis)
93    }
94
95    /// Apply gradient checkpointing
96    fn apply_gradient_checkpointing(
97        &self,
98        graph: &mut Graph<F>,
99    ) -> Result<usize, OptimizationError> {
100        let mut checkpoints_added = 0;
101
102        // Strategy: Insert checkpoints at points where:
103        // 1. Memory usage is high
104        // 2. Recomputation cost is relatively low
105        // 3. It provides significant memory savings
106
107        let candidates = self.find_checkpoint_candidates(graph)?;
108
109        for candidate in candidates {
110            if self.should_checkpoint(&candidate) {
111                self.insert_checkpoint(graph, &candidate)?;
112                checkpoints_added += 1;
113            }
114        }
115
116        Ok(checkpoints_added)
117    }
118
119    /// Find candidates for gradient checkpointing
120    fn find_checkpoint_candidates(
121        &self,
122        graph: &Graph<F>,
123    ) -> Result<Vec<CheckpointCandidate<F>>, OptimizationError> {
124        let candidates = Vec::new();
125
126        // Look for:
127        // - Nodes with large memory footprint
128        // - Nodes in long computation chains
129        // - Nodes where recomputation is cheaper than storage
130
131        Ok(candidates)
132    }
133
134    /// Check if a node should be checkpointed
135    fn should_checkpoint(&self, candidate: &CheckpointCandidate<F>) -> bool {
136        // Decision criteria:
137        // - Memory savings > threshold
138        // - Recomputation cost < threshold
139        // - Not already checkpointed
140
141        candidate.memory_savings > self.config.checkpoint_memory_threshold
142            && candidate.recomputation_cost < self.config.checkpoint_compute_threshold
143    }
144
145    /// Insert a checkpoint at a specific location
146    fn insert_checkpoint(
147        &self,
148        graph: &mut Graph<F>,
149        _candidate: &CheckpointCandidate<F>,
150    ) -> Result<(), OptimizationError> {
151        // Insert a checkpoint operation that:
152        // 1. Saves the forward pass result
153        // 2. Releases intermediate computations
154        // 3. Recomputes on backward pass when needed
155
156        Ok(())
157    }
158
159    /// Setup memory pooling
160    fn setup_memory_pooling(&self, graph: &mut Graph<F>) -> Result<usize, OptimizationError> {
161        let mut pools_created = 0;
162
163        // Analyze tensor size patterns
164        let size_patterns = self.analyze_tensor_sizes(graph)?;
165
166        // Create pools for common sizes
167        for (size, frequency) in size_patterns {
168            if frequency >= self.config.pool_frequency_threshold {
169                MemoryOptimizer::<F>::create_memory_pool(size)?;
170                pools_created += 1;
171            }
172        }
173
174        Ok(pools_created)
175    }
176
177    /// Analyze tensor size patterns
178    fn analyze_tensor_sizes(
179        &self,
180        graph: &Graph<F>,
181    ) -> Result<HashMap<usize, usize>, OptimizationError> {
182        let size_frequency = HashMap::new();
183
184        // Count frequency of different tensor sizes
185        // This would traverse the graph and collect size information
186
187        Ok(size_frequency)
188    }
189
190    /// Create a memory pool for a specific size
191    fn create_memory_pool(size: usize) -> Result<(), OptimizationError> {
192        // Create a memory pool that can reuse buffers of the specified _size
193        Ok(())
194    }
195
196    /// Apply in-place operations where safe
197    fn apply_in_place_operations(&self, graph: &mut Graph<F>) -> Result<usize, OptimizationError> {
198        let mut in_place_applied = 0;
199
200        // Find operations that can be done in-place:
201        // - Element-wise operations where the input won't be used again
202        // - Operations where the output has the same shape as input
203        // - No aliasing issues
204
205        let candidates = self.find_in_place_candidates(graph)?;
206
207        for candidate in candidates {
208            if MemoryOptimizer::<F>::can_apply_in_place(&candidate) {
209                self.convert_to_in_place(graph, &candidate)?;
210                in_place_applied += 1;
211            }
212        }
213
214        Ok(in_place_applied)
215    }
216
217    /// Find candidates for in-place operations
218    fn find_in_place_candidates(
219        &self,
220        graph: &Graph<F>,
221    ) -> Result<Vec<InPlaceCandidate<F>>, OptimizationError> {
222        // Look for operations like:
223        // - Element-wise arithmetic
224        // - Activation functions
225        // - Normalization operations
226        // where the input tensor is not used elsewhere
227
228        Ok(Vec::new())
229    }
230
231    /// Check if an operation can be safely converted to in-place
232    fn can_apply_in_place(candidate: &InPlaceCandidate<F>) -> bool {
233        // Safety checks:
234        // - Input tensor is not used by other operations
235        // - No gradient computation conflicts
236        // - Compatible tensor layouts
237        // - No aliasing issues
238
239        true
240    }
241
242    /// Convert an operation to in-place
243    fn convert_to_in_place(
244        &self,
245        graph: &mut Graph<F>,
246        _candidate: &InPlaceCandidate<F>,
247    ) -> Result<(), OptimizationError> {
248        // Replace the operation with an in-place version
249        Ok(())
250    }
251
252    /// Apply tensor reuse optimization
253    fn apply_tensor_reuse(&self, graph: &mut Graph<F>) -> Result<usize, OptimizationError> {
254        let mut reused_count = 0;
255
256        // Find tensors that can be reused:
257        // - Tensors with non-overlapping lifetimes
258        // - Compatible shapes and types
259        // - No aliasing conflicts
260
261        let reuse_groups = self.find_tensor_reuse_opportunities(graph)?;
262
263        for group in reuse_groups {
264            self.apply_tensor_reuse_group(graph, &group)?;
265            reused_count += group.tensors.len() - 1; // All but one reuse the same memory
266        }
267
268        Ok(reused_count)
269    }
270
271    /// Find opportunities for tensor reuse
272    fn find_tensor_reuse_opportunities(
273        &self,
274        graph: &Graph<F>,
275    ) -> Result<Vec<TensorReuseGroup<F>>, OptimizationError> {
276        // Analyze tensor lifetimes and find non-overlapping tensors
277        // that can share the same memory
278
279        Ok(Vec::new())
280    }
281
282    /// Apply tensor reuse for a group of tensors
283    fn apply_tensor_reuse_group(
284        &self,
285        graph: &mut Graph<F>,
286        _group: &TensorReuseGroup<F>,
287    ) -> Result<(), OptimizationError> {
288        // Modify the graph to reuse memory for tensors in the _group
289        Ok(())
290    }
291
292    /// Optimize tensor lifetimes
293    fn optimize_tensor_lifetimes(&self, graph: &mut Graph<F>) -> Result<usize, OptimizationError> {
294        let mut optimizations = 0;
295
296        // Strategies:
297        // - Release tensors as early as possible
298        // - Defer allocations as late as possible
299        // - Reorder operations to minimize peak memory
300
301        optimizations += self.apply_early_release(graph)?;
302        optimizations += self.apply_late_allocation(graph)?;
303        optimizations += self.reorder_for_memory(graph)?;
304
305        Ok(optimizations)
306    }
307
308    /// Apply early release of tensors
309    fn apply_early_release(&self, graph: &mut Graph<F>) -> Result<usize, OptimizationError> {
310        // Insert explicit release operations as soon as tensors are no longer needed
311        Ok(0)
312    }
313
314    /// Apply late allocation of tensors
315    fn apply_late_allocation(&self, graph: &mut Graph<F>) -> Result<usize, OptimizationError> {
316        // Delay tensor allocation until just before they're needed
317        Ok(0)
318    }
319
320    /// Reorder operations to minimize peak memory usage
321    fn reorder_for_memory(&self, graph: &mut Graph<F>) -> Result<usize, OptimizationError> {
322        // Reorder operations (where dependencies allow) to reduce peak memory
323        Ok(0)
324    }
325
326    /// Get the current memory analysis
327    pub fn get_analysis(&self) -> Option<&MemoryAnalysis> {
328        self.analysis.as_ref()
329    }
330}
331
332impl<F: Float> Default for MemoryOptimizer<F> {
333    fn default() -> Self {
334        Self::new()
335    }
336}
337
/// Configuration for memory optimization.
///
/// Each `enable_*` flag toggles one pass of [`MemoryOptimizer::optimize`];
/// the thresholds tune when individual optimizations are worth applying.
#[derive(Debug, Clone)]
pub struct MemoryOptimizationConfig {
    /// Enable gradient checkpointing
    pub enable_gradient_checkpointing: bool,
    /// Enable memory pooling
    pub enable_memory_pooling: bool,
    /// Enable in-place operations
    pub enable_in_place_operations: bool,
    /// Enable tensor reuse
    pub enable_tensor_reuse: bool,
    /// Enable tensor lifetime optimization
    pub enable_lifetime_optimization: bool,
    /// Minimum memory savings (bytes) for a checkpoint to be worthwhile
    pub checkpoint_memory_threshold: usize,
    /// Maximum relative recomputation cost allowed for a checkpoint
    pub checkpoint_compute_threshold: f32,
    /// Minimum number of same-size tensors before a pool is created
    pub pool_frequency_threshold: usize,
    /// Optional hard target for maximum memory usage (bytes)
    pub max_memory_usage: Option<usize>,
}

impl Default for MemoryOptimizationConfig {
    /// All passes enabled, with moderate thresholds.
    fn default() -> Self {
        MemoryOptimizationConfig {
            enable_gradient_checkpointing: true,
            enable_memory_pooling: true,
            enable_in_place_operations: true,
            enable_tensor_reuse: true,
            enable_lifetime_optimization: true,
            // Only checkpoint nodes saving at least 1 MiB.
            checkpoint_memory_threshold: 1024 * 1024,
            // Tolerate up to 2x recomputation cost.
            checkpoint_compute_threshold: 2.0,
            // Pool a size only after at least 5 uses.
            pool_frequency_threshold: 5,
            max_memory_usage: None,
        }
    }
}
376
/// Results of memory analysis over a computation graph.
#[derive(Debug, Clone, Default)]
pub struct MemoryAnalysis {
    /// Total memory allocated (bytes)
    pub total_memory_allocated: usize,
    /// Peak memory usage (bytes)
    pub peak_memory_usage: usize,
    /// Number of allocations
    pub num_allocations: usize,
    /// Number of deallocations
    pub num_deallocations: usize,
    /// Average tensor size
    pub average_tensor_size: usize,
    /// Largest tensor size
    pub largest_tensor_size: usize,
    /// Memory fragmentation ratio
    pub fragmentation_ratio: f32,
    /// Human-readable descriptions of optimization opportunities
    pub optimization_opportunities: Vec<String>,
}

impl MemoryAnalysis {
    /// Create a new, zeroed memory analysis.
    pub fn new() -> Self {
        Self::default()
    }

    /// Ratio of peak memory usage to total memory allocated.
    /// Returns 1.0 when nothing was allocated (vacuously efficient).
    pub fn memory_efficiency(&self) -> f32 {
        match self.total_memory_allocated {
            0 => 1.0,
            total => self.peak_memory_usage as f32 / total as f32,
        }
    }

    /// Allocations minus deallocations; a positive balance suggests memory
    /// still outstanding at the end of the analyzed run.
    pub fn allocation_balance(&self) -> i32 {
        self.num_allocations as i32 - self.num_deallocations as i32
    }
}
417
/// Report of memory optimization results, one counter per pass.
#[derive(Debug, Clone, Default)]
pub struct MemoryOptimizationReport {
    /// Number of gradient checkpoints added
    pub gradient_checkpoints_added: usize,
    /// Number of memory pools created
    pub memory_pools_created: usize,
    /// Number of in-place operations applied
    pub in_place_operations_applied: usize,
    /// Number of tensors reused
    pub tensors_reused: usize,
    /// Number of lifetime optimizations
    pub lifetime_optimizations: usize,
    /// Estimated memory savings (bytes)
    pub estimated_memory_savings: usize,
}

impl MemoryOptimizationReport {
    /// Create a report with all counters at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sum of all per-pass optimization counters.
    pub fn total_optimizations(&self) -> usize {
        [
            self.gradient_checkpoints_added,
            self.memory_pools_created,
            self.in_place_operations_applied,
            self.tensors_reused,
            self.lifetime_optimizations,
        ]
        .iter()
        .sum()
    }

    /// Print a human-readable summary to stdout, omitting zero counters.
    pub fn print_summary(&self) {
        println!("Memory Optimization Report:");
        println!("==========================");
        println!("Total optimizations: {}", self.total_optimizations());

        if self.gradient_checkpoints_added > 0 {
            println!("  Gradient checkpoints: {}", self.gradient_checkpoints_added);
        }
        if self.memory_pools_created > 0 {
            println!("  Memory pools created: {}", self.memory_pools_created);
        }
        if self.in_place_operations_applied > 0 {
            println!("  In-place operations: {}", self.in_place_operations_applied);
        }
        if self.tensors_reused > 0 {
            println!("  Tensors reused: {}", self.tensors_reused);
        }
        if self.lifetime_optimizations > 0 {
            println!("  Lifetime optimizations: {}", self.lifetime_optimizations);
        }
        if self.estimated_memory_savings > 0 {
            println!(
                "  Estimated memory savings: {} bytes",
                self.estimated_memory_savings
            );
        }
    }
}
485
486/// Candidate for gradient checkpointing
487#[derive(Debug)]
488pub(crate) struct CheckpointCandidate<F: Float> {
489    /// Node to potentially checkpoint
490    #[allow(dead_code)]
491    pub node: *const crate::tensor::TensorInternal<F>,
492    /// Estimated memory savings
493    pub memory_savings: usize,
494    /// Estimated recomputation cost
495    pub recomputation_cost: f32,
496    /// Priority for checkpointing
497    #[allow(dead_code)]
498    pub priority: f32,
499}
500
501/// Candidate for in-place operation
502#[derive(Debug)]
503pub(crate) struct InPlaceCandidate<F: Float> {
504    /// Node to convert to in-place
505    #[allow(dead_code)]
506    pub node: *const crate::tensor::TensorInternal<F>,
507    /// Estimated memory savings
508    #[allow(dead_code)]
509    pub memory_savings: usize,
510    /// Safety score (higher is safer)
511    #[allow(dead_code)]
512    pub safety_score: f32,
513}
514
515/// Group of tensors that can reuse memory
516#[derive(Debug)]
517pub(crate) struct TensorReuseGroup<F: Float> {
518    /// Tensors that can share memory
519    pub tensors: Vec<*const crate::tensor::TensorInternal<F>>,
520    /// Total memory that can be saved
521    #[allow(dead_code)]
522    pub memory_savings: usize,
523}
524
525/// Tensor lifetime analyzer
526pub struct TensorLifetimeAnalyzer<F: Float> {
527    _phantom: std::marker::PhantomData<F>,
528}
529
530impl<F: Float> TensorLifetimeAnalyzer<F> {
531    /// Create a new tensor lifetime analyzer
532    pub fn new() -> Self {
533        Self {
534            _phantom: std::marker::PhantomData,
535        }
536    }
537
538    /// Analyze tensor lifetimes in a graph
539    #[allow(dead_code)]
540    pub(crate) fn analyze(
541        &self,
542        graph: &Graph<F>,
543    ) -> Result<HashMap<*const crate::tensor::TensorInternal<F>, TensorLifetime>, OptimizationError>
544    {
545        let lifetimes = HashMap::new();
546
547        // For each tensor, determine:
548        // - When it's first created/allocated
549        // - When it's last used
550        // - Peak memory usage contribution
551        // - Overlap with other tensors
552
553        Ok(lifetimes)
554    }
555
556    /// Find overlapping tensor lifetimes
557    #[allow(dead_code)]
558    pub(crate) fn find_overlapping_lifetimes(
559        self_lifetimes: &HashMap<*const crate::tensor::TensorInternal<F>, TensorLifetime>,
560    ) -> Vec<Vec<*const crate::tensor::TensorInternal<F>>> {
561        // Group tensors with overlapping _lifetimes
562        // These cannot share memory
563        Vec::new()
564    }
565
566    /// Find non-overlapping tensor groups
567    #[allow(dead_code)]
568    pub(crate) fn find_reusable_groups(
569        self_lifetimes: &HashMap<*const crate::tensor::TensorInternal<F>, TensorLifetime>,
570    ) -> Vec<Vec<*const crate::tensor::TensorInternal<F>>> {
571        // Group tensors with non-overlapping _lifetimes
572        // These can potentially share memory
573        Vec::new()
574    }
575}
576
577impl<F: Float> Default for TensorLifetimeAnalyzer<F> {
578    fn default() -> Self {
579        Self::new()
580    }
581}
582
/// Information about a tensor's lifetime within a graph execution,
/// expressed in abstract time steps.
#[derive(Debug, Clone)]
pub struct TensorLifetime {
    /// Time step at which the tensor is allocated
    pub allocation_time: usize,
    /// Time step at which the tensor is last used
    pub deallocation_time: usize,
    /// Size of the tensor
    pub size: usize,
    /// Peak usage during the lifetime
    pub peak_usage: usize,
}

impl TensorLifetime {
    /// Check whether this lifetime overlaps another. Lifetimes that merely
    /// touch at an endpoint (one ends exactly when the other starts) do NOT
    /// count as overlapping.
    pub fn overlaps_with(&self, other: &TensorLifetime) -> bool {
        // Equivalent (by De Morgan) to: neither interval ends before the
        // other begins.
        self.deallocation_time > other.allocation_time
            && other.deallocation_time > self.allocation_time
    }

    /// Duration of this lifetime in time steps; saturates to 0 if the
    /// deallocation time precedes the allocation time.
    pub fn duration(&self) -> usize {
        self.deallocation_time.saturating_sub(self.allocation_time)
    }
}
608
609/// Memory pool manager
610pub struct MemoryPoolManager<F: Float> {
611    /// Pools organized by size
612    pools: HashMap<usize, Vec<Vec<F>>>,
613    /// Pool usage statistics
614    stats: MemoryPoolStats,
615}
616
617impl<F: Float> MemoryPoolManager<F> {
618    /// Create a new memory pool manager
619    pub fn new() -> Self {
620        Self {
621            pools: HashMap::new(),
622            stats: MemoryPoolStats::default(),
623        }
624    }
625
626    /// Get a buffer from the pool
627    pub fn get_buffer(&mut self, size: usize) -> Vec<F> {
628        if let Some(pool) = self.pools.get_mut(&size) {
629            if let Some(buffer) = pool.pop() {
630                self.stats.pool_hits += 1;
631                return buffer;
632            }
633        }
634
635        self.stats.pool_misses += 1;
636        vec![F::zero(); size]
637    }
638
639    /// Return a buffer to the pool
640    pub fn return_buffer(&mut self, mut buffer: Vec<F>) {
641        let size = buffer.len();
642        buffer.clear();
643        buffer.resize(size, F::zero());
644
645        self.pools.entry(size).or_default().push(buffer);
646        self.stats.buffers_returned += 1;
647    }
648
649    /// Get pool statistics
650    pub fn get_stats(&self) -> &MemoryPoolStats {
651        &self.stats
652    }
653
654    /// Clear all pools
655    pub fn clear(&mut self) {
656        self.pools.clear();
657        self.stats = MemoryPoolStats::default();
658    }
659}
660
661impl<F: Float> Default for MemoryPoolManager<F> {
662    fn default() -> Self {
663        Self::new()
664    }
665}
666
/// Statistics for memory pools.
#[derive(Debug, Clone, Default)]
pub struct MemoryPoolStats {
    /// Requests served from a pool
    pub pool_hits: usize,
    /// Requests that required a fresh allocation
    pub pool_misses: usize,
    /// Buffers handed back to a pool
    pub buffers_returned: usize,
    /// Total memory currently held in pools
    pub total_pooled_memory: usize,
}

impl MemoryPoolStats {
    /// Fraction of buffer requests served from a pool; 0.0 when there have
    /// been no requests at all.
    pub fn hit_ratio(&self) -> f32 {
        let total_requests = self.pool_hits + self.pool_misses;
        if total_requests > 0 {
            self.pool_hits as f32 / total_requests as f32
        } else {
            0.0
        }
    }
}
690
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `TensorLifetime` whose peak usage equals its size.
    fn lifetime(allocation_time: usize, deallocation_time: usize, size: usize) -> TensorLifetime {
        TensorLifetime {
            allocation_time,
            deallocation_time,
            size,
            peak_usage: size,
        }
    }

    #[test]
    fn test_memory_optimizer_creation() {
        // Both constructors must produce a usable optimizer.
        let _default_optimizer = MemoryOptimizer::<f32>::new();
        let _configured_optimizer =
            MemoryOptimizer::<f32>::with_config(MemoryOptimizationConfig::default());
    }

    #[test]
    fn test_memory_optimization_config() {
        // Every pass is enabled by default.
        let config = MemoryOptimizationConfig::default();
        assert!(config.enable_gradient_checkpointing);
        assert!(config.enable_memory_pooling);
        assert!(config.enable_in_place_operations);
        assert!(config.enable_tensor_reuse);
        assert!(config.enable_lifetime_optimization);
    }

    #[test]
    fn test_memory_analysis() {
        let analysis = MemoryAnalysis {
            total_memory_allocated: 1000,
            peak_memory_usage: 800,
            num_allocations: 10,
            num_deallocations: 8,
            ..MemoryAnalysis::default()
        };

        assert_eq!(analysis.memory_efficiency(), 0.8);
        assert_eq!(analysis.allocation_balance(), 2);
    }

    #[test]
    fn test_memory_optimization_report() {
        let report = MemoryOptimizationReport {
            gradient_checkpoints_added: 5,
            memory_pools_created: 3,
            in_place_operations_applied: 10,
            ..MemoryOptimizationReport::default()
        };

        assert_eq!(report.total_optimizations(), 18);
    }

    #[test]
    fn test_tensor_lifetime() {
        let first = lifetime(0, 10, 100);
        let second = lifetime(5, 15, 200);
        let third = lifetime(20, 30, 150);

        assert!(first.overlaps_with(&second));
        assert!(!first.overlaps_with(&third));
        assert_eq!(first.duration(), 10);
    }

    #[test]
    fn test_memory_pool_manager() {
        let mut manager = MemoryPoolManager::<f32>::new();

        // An empty pool cannot serve the first request.
        let buffer = manager.get_buffer(100);
        assert_eq!(buffer.len(), 100);
        assert_eq!(manager.get_stats().pool_misses, 1);

        // Returning the buffer makes it available again.
        manager.return_buffer(buffer);
        assert_eq!(manager.get_stats().buffers_returned, 1);

        // A same-size request is now served from the pool.
        let recycled = manager.get_buffer(100);
        assert_eq!(recycled.len(), 100);
        assert_eq!(manager.get_stats().pool_hits, 1);
    }

    #[test]
    fn test_memory_pool_stats() {
        let stats = MemoryPoolStats {
            pool_hits: 8,
            pool_misses: 2,
            ..Default::default()
        };

        assert_eq!(stats.hit_ratio(), 0.8);
    }

    #[test]
    fn test_tensor_lifetime_analyzer() {
        let _analyzer = TensorLifetimeAnalyzer::<f32>::new();
    }
}