Skip to main content

torsh_quantization/
memory_pool.rs

1//! # Memory Pool Management for Quantization
2//!
3//! This module provides advanced memory pooling capabilities to reduce allocation overhead
4//! during quantization operations, particularly beneficial for batch processing and
5//! inference scenarios.
6//!
7//! ## Features
8//!
9//! - **Pre-allocated Pools**: Reusable memory pools for common tensor sizes
10//! - **Dynamic Sizing**: Automatic pool expansion based on usage patterns
11//! - **Memory Analytics**: Tracking allocation patterns and optimization opportunities
12//! - **Thread Safety**: Concurrent access for multi-threaded quantization operations
13//!
14//! ## Usage
15//!
16//! ```rust
17//! use torsh_quantization::memory_pool::{MemoryPool, PoolConfig};
18//! use torsh_tensor::Tensor;
19//!
20//! // Create a memory pool with configuration
21//! let config = PoolConfig::default();
22//! let mut pool = MemoryPool::new(config);
23//!
24//! // Allocate a tensor from the pool
25//! let tensor = pool.allocate_tensor(&[1024, 1024], torsh_core::DType::F32)?;
26//!
27//! // Use the tensor for quantization operations
28//! // ... quantization work ...
29//!
30//! // Return tensor to pool for reuse
31//! pool.release_tensor(tensor);
32//! # Ok::<(), torsh_core::TorshError>(())
33//! ```
34
35// use crate::TorshResult;
36use std::collections::{HashMap, VecDeque};
37use std::sync::{Arc, Mutex};
38use torsh_core::device::DeviceType;
39use torsh_core::Result as TorshResult;
40use torsh_core::{DType, TorshError};
41use torsh_tensor::Tensor;
42
43/// Configuration for memory pool behavior
44#[derive(Debug, Clone)]
45pub struct PoolConfig {
46    /// Maximum number of tensors to keep in each size pool
47    pub max_tensors_per_size: usize,
48    /// Maximum total memory usage in bytes
49    pub max_total_memory: usize,
50    /// Whether to enable memory usage analytics
51    pub enable_analytics: bool,
52    /// Pre-allocate common tensor sizes
53    pub pre_allocate_sizes: Vec<Vec<usize>>,
54    /// Enable cache-aware allocation strategies
55    pub enable_cache_awareness: bool,
56    /// Memory alignment for cache-friendly allocations (bytes)
57    pub memory_alignment: usize,
58    /// Automatic garbage collection threshold (fragmentation score 0.0-1.0)
59    pub auto_gc_threshold: f64,
60    /// Enable adaptive pool sizing based on usage patterns
61    pub enable_adaptive_sizing: bool,
62    /// Memory pressure monitoring interval (milliseconds)
63    pub pressure_check_interval_ms: u64,
64    /// Minimum allocation size to track for cache analysis
65    pub min_cache_tracked_size: usize,
66}
67
68impl Default for PoolConfig {
69    fn default() -> Self {
70        Self {
71            max_tensors_per_size: 16,
72            max_total_memory: 1024 * 1024 * 1024, // 1GB
73            enable_analytics: true,
74            pre_allocate_sizes: vec![
75                vec![1, 1],
76                vec![32, 32],
77                vec![64, 64],
78                vec![128, 128],
79                vec![256, 256],
80                vec![512, 512],
81                vec![1024, 1024],
82            ],
83            enable_cache_awareness: true,
84            memory_alignment: 64, // 64-byte alignment for cache lines
85            auto_gc_threshold: 0.75,
86            enable_adaptive_sizing: true,
87            pressure_check_interval_ms: 1000, // Check pressure every second
88            min_cache_tracked_size: 1024,     // Track allocations >= 1KB for cache analysis
89        }
90    }
91}
92
93/// Memory usage analytics with advanced metrics
94#[derive(Debug, Clone, Default)]
95pub struct MemoryAnalytics {
96    /// Total allocations requested
97    pub total_allocations: usize,
98    /// Total deallocations
99    pub total_deallocations: usize,
100    /// Pool hits (reused tensors)
101    pub pool_hits: usize,
102    /// Pool misses (new allocations)
103    pub pool_misses: usize,
104    /// Peak memory usage in bytes
105    pub peak_memory_usage: usize,
106    /// Current memory usage in bytes
107    pub current_memory_usage: usize,
108    /// Memory fragmentation score (0.0-1.0, lower is better)
109    pub fragmentation_score: f64,
110    /// Average allocation size in bytes
111    pub avg_allocation_size: usize,
112    /// Cache misses (estimated from allocation patterns)
113    pub estimated_cache_misses: usize,
114    /// Memory pressure events
115    pub pressure_events: usize,
116    /// Time spent in garbage collection (microseconds)
117    pub gc_time_us: u64,
118}
119
120impl MemoryAnalytics {
121    /// Get pool hit rate as percentage
122    pub fn hit_rate(&self) -> f64 {
123        if self.total_allocations == 0 {
124            0.0
125        } else {
126            (self.pool_hits as f64 / self.total_allocations as f64) * 100.0
127        }
128    }
129
130    /// Get memory efficiency ratio
131    pub fn efficiency_ratio(&self) -> f64 {
132        if self.peak_memory_usage == 0 {
133            1.0
134        } else {
135            self.current_memory_usage as f64 / self.peak_memory_usage as f64
136        }
137    }
138
139    /// Get cache efficiency estimate
140    pub fn cache_efficiency(&self) -> f64 {
141        if self.total_allocations == 0 {
142            100.0
143        } else {
144            let cache_hits = self
145                .total_allocations
146                .saturating_sub(self.estimated_cache_misses);
147            (cache_hits as f64 / self.total_allocations as f64) * 100.0
148        }
149    }
150
151    /// Get overall performance score (0.0-100.0, higher is better)
152    pub fn performance_score(&self) -> f64 {
153        let hit_score = self.hit_rate() * 0.4;
154        let efficiency_score = self.efficiency_ratio() * 100.0 * 0.3;
155        let fragmentation_score = (1.0 - self.fragmentation_score) * 100.0 * 0.2;
156        let cache_score = self.cache_efficiency() * 0.1;
157
158        hit_score + efficiency_score + fragmentation_score + cache_score
159    }
160
161    /// Check if memory pool needs attention
162    pub fn needs_optimization(&self) -> bool {
163        self.fragmentation_score > 0.7 || self.hit_rate() < 50.0 || self.pressure_events > 10
164    }
165
166    /// Get recommendation for pool optimization
167    pub fn get_optimization_recommendations(&self) -> Vec<String> {
168        let mut recommendations = Vec::new();
169
170        if self.hit_rate() < 50.0 {
171            recommendations
172                .push("Consider increasing pool sizes for commonly used tensor shapes".to_string());
173        }
174
175        if self.fragmentation_score > 0.7 {
176            recommendations.push(
177                "High fragmentation detected - consider triggering garbage collection".to_string(),
178            );
179        }
180
181        if self.estimated_cache_misses as f64 / self.total_allocations as f64 > 0.3 {
182            recommendations.push(
183                "Cache-unfriendly allocation patterns detected - consider memory alignment"
184                    .to_string(),
185            );
186        }
187
188        if self.pressure_events > 5 {
189            recommendations.push(
190                "Memory pressure detected - consider reducing pool sizes or freeing unused memory"
191                    .to_string(),
192            );
193        }
194
195        recommendations
196    }
197}
198
199/// Key for identifying tensor pools by shape and dtype
200#[derive(Debug, Clone, PartialEq, Eq, Hash)]
201struct TensorKey {
202    shape: Vec<usize>,
203    dtype: DType,
204}
205
206/// Thread-safe memory pool for tensor allocation
207pub struct MemoryPool {
208    config: PoolConfig,
209    pools: Arc<Mutex<HashMap<TensorKey, VecDeque<Tensor>>>>,
210    analytics: Arc<Mutex<MemoryAnalytics>>,
211}
212
213impl MemoryPool {
214    /// Create a new memory pool with the given configuration
215    pub fn new(config: PoolConfig) -> Self {
216        let pool = Self {
217            config,
218            pools: Arc::new(Mutex::new(HashMap::new())),
219            analytics: Arc::new(Mutex::new(MemoryAnalytics::default())),
220        };
221
222        // Pre-allocate common sizes if requested
223        if !pool.config.pre_allocate_sizes.is_empty() {
224            pool.pre_allocate_common_sizes();
225        }
226
227        pool
228    }
229
230    /// Pre-allocate tensors for common sizes
231    fn pre_allocate_common_sizes(&self) {
232        for shape in &self.config.pre_allocate_sizes {
233            let key = TensorKey {
234                shape: shape.clone(),
235                dtype: DType::F32,
236            };
237
238            if let Ok(mut pools) = self.pools.lock() {
239                let pool = pools.entry(key).or_insert_with(VecDeque::new);
240
241                // Pre-allocate a few tensors for this size
242                for _ in 0..4 {
243                    if let Ok(tensor) = self.create_tensor(shape, DType::F32) {
244                        pool.push_back(tensor);
245                    }
246                }
247            }
248        }
249    }
250
251    /// Allocate a tensor from the pool or create a new one
252    pub fn allocate_tensor(&self, shape: &[usize], dtype: DType) -> TorshResult<Tensor> {
253        let key = TensorKey {
254            shape: shape.to_vec(),
255            dtype,
256        };
257
258        // Try to get from pool first
259        if let Ok(mut pools) = self.pools.lock() {
260            if let Some(pool) = pools.get_mut(&key) {
261                if let Some(tensor) = pool.pop_front() {
262                    // Update analytics
263                    if let Ok(mut analytics) = self.analytics.lock() {
264                        analytics.total_allocations += 1;
265                        analytics.pool_hits += 1;
266                    }
267                    return Ok(tensor);
268                }
269            }
270        }
271
272        // Create new tensor if not available in pool
273        let tensor = self.create_tensor(shape, dtype)?;
274
275        // Update analytics
276        if let Ok(mut analytics) = self.analytics.lock() {
277            analytics.total_allocations += 1;
278            analytics.pool_misses += 1;
279        }
280
281        Ok(tensor)
282    }
283
284    /// Release a tensor back to the pool for reuse
285    pub fn release_tensor(&self, tensor: Tensor) {
286        let key = TensorKey {
287            shape: tensor.shape().dims().to_vec(),
288            dtype: tensor.dtype(),
289        };
290
291        if let Ok(mut pools) = self.pools.lock() {
292            let pool = pools.entry(key).or_insert_with(VecDeque::new);
293
294            // Only keep tensor if we haven't exceeded the limit
295            if pool.len() < self.config.max_tensors_per_size {
296                pool.push_back(tensor);
297            }
298        }
299
300        // Update analytics
301        if let Ok(mut analytics) = self.analytics.lock() {
302            analytics.total_deallocations += 1;
303        }
304    }
305
306    /// Create a new tensor (helper method)
307    fn create_tensor(&self, shape: &[usize], dtype: DType) -> TorshResult<Tensor> {
308        match dtype {
309            DType::F32 => {
310                let data: Vec<f32> = vec![0.0; shape.iter().product()];
311                Tensor::from_data(data, shape.to_vec(), DeviceType::Cpu)
312                    .map_err(|e| TorshError::InvalidArgument(e.to_string()))
313            }
314            _ => {
315                // For simplicity, create all tensors as f32 for the memory pool
316                // Real quantization will handle the proper data types
317                let data: Vec<f32> = vec![0.0; shape.iter().product()];
318                Tensor::from_data(data, shape.to_vec(), DeviceType::Cpu)
319                    .map_err(|e| TorshError::InvalidArgument(e.to_string()))
320            }
321        }
322    }
323
324    /// Get current memory analytics
325    pub fn get_analytics(&self) -> MemoryAnalytics {
326        self.analytics
327            .lock()
328            .map(|guard| guard.clone())
329            .unwrap_or_default()
330    }
331
332    /// Clear all pools and reset analytics
333    pub fn clear(&self) {
334        if let Ok(mut pools) = self.pools.lock() {
335            pools.clear();
336        }
337        if let Ok(mut analytics) = self.analytics.lock() {
338            *analytics = MemoryAnalytics::default();
339        }
340    }
341
342    /// Get pool statistics
343    pub fn get_pool_stats(&self) -> HashMap<String, usize> {
344        let mut stats = HashMap::new();
345
346        if let Ok(pools) = self.pools.lock() {
347            for (key, pool) in pools.iter() {
348                let key_str = format!("{:?}_{:?}", key.shape, key.dtype);
349                stats.insert(key_str, pool.len());
350            }
351        }
352
353        stats
354    }
355}
356
357/// Convenience functions for common memory pool operations
358impl MemoryPool {
359    /// Create a global memory pool instance
360    pub fn global() -> &'static MemoryPool {
361        static GLOBAL_POOL: std::sync::OnceLock<MemoryPool> = std::sync::OnceLock::new();
362        GLOBAL_POOL.get_or_init(|| MemoryPool::new(PoolConfig::default()))
363    }
364
365    /// Allocate f32 tensor from pool
366    pub fn allocate_f32(&self, shape: &[usize]) -> TorshResult<Tensor> {
367        self.allocate_tensor(shape, DType::F32)
368    }
369
370    /// Allocate i8 tensor from pool (common for quantized tensors)
371    pub fn allocate_i8(&self, shape: &[usize]) -> TorshResult<Tensor> {
372        self.allocate_tensor(shape, DType::I8)
373    }
374
375    /// Allocate u8 tensor from pool (common for quantized tensors)
376    pub fn allocate_u8(&self, shape: &[usize]) -> TorshResult<Tensor> {
377        self.allocate_tensor(shape, DType::U8)
378    }
379}
380
381/// Advanced memory pool management methods
382impl MemoryPool {
383    /// Trigger garbage collection to reduce fragmentation
384    pub fn garbage_collect(&self) -> TorshResult<()> {
385        let start_time = std::time::Instant::now();
386
387        if let Ok(mut pools) = self.pools.lock() {
388            // Remove empty pools and compress partially filled ones
389            pools.retain(|_, pool| {
390                if pool.is_empty() {
391                    true // Keep empty pools for future use
392                } else {
393                    // Optionally compact the pool here
394                    true
395                }
396            });
397
398            // Update fragmentation metrics
399            if let Ok(mut analytics) = self.analytics.lock() {
400                let gc_duration = start_time.elapsed();
401                analytics.gc_time_us += gc_duration.as_micros() as u64;
402
403                // Recalculate fragmentation score after GC
404                analytics.fragmentation_score = self.calculate_fragmentation_score(&pools);
405            }
406        }
407
408        Ok(())
409    }
410
411    /// Check memory pressure and auto-cleanup if needed
412    pub fn check_memory_pressure(&self) -> bool {
413        let analytics = self.get_analytics();
414        let memory_usage_ratio =
415            analytics.current_memory_usage as f64 / self.config.max_total_memory as f64;
416
417        let high_pressure = memory_usage_ratio > 0.85
418            || analytics.fragmentation_score > self.config.auto_gc_threshold;
419
420        if high_pressure {
421            // Trigger automatic garbage collection
422            let _ = self.garbage_collect();
423
424            // Update pressure events counter
425            if let Ok(mut analytics) = self.analytics.lock() {
426                analytics.pressure_events += 1;
427            }
428        }
429
430        high_pressure
431    }
432
433    /// Calculate memory fragmentation score
434    fn calculate_fragmentation_score(&self, pools: &HashMap<TensorKey, VecDeque<Tensor>>) -> f64 {
435        if pools.is_empty() {
436            return 0.0;
437        }
438
439        let total_pools = pools.len();
440        let mut fragmented_pools = 0;
441        let mut total_capacity = 0;
442        let mut total_used = 0;
443
444        for (_, pool) in pools.iter() {
445            let capacity = self.config.max_tensors_per_size;
446            let used = pool.len();
447
448            total_capacity += capacity;
449            total_used += used;
450
451            // A pool is considered fragmented if it's less than 50% full
452            if used > 0 && used < capacity / 2 {
453                fragmented_pools += 1;
454            }
455        }
456
457        let pool_fragmentation = fragmented_pools as f64 / total_pools as f64;
458        let usage_fragmentation = if total_capacity > 0 {
459            1.0 - (total_used as f64 / total_capacity as f64)
460        } else {
461            0.0
462        };
463
464        (pool_fragmentation + usage_fragmentation) / 2.0
465    }
466
467    /// Estimate cache misses based on allocation patterns
468    #[allow(dead_code)]
469    fn estimate_cache_misses(&self, allocation_size: usize) -> usize {
470        if !self.config.enable_cache_awareness
471            || allocation_size < self.config.min_cache_tracked_size
472        {
473            return 0;
474        }
475
476        // Simple heuristic: larger allocations that aren't aligned are more likely to cause cache misses
477        let alignment = self.config.memory_alignment;
478        let misaligned = allocation_size % alignment != 0;
479
480        if misaligned && allocation_size > alignment * 8 {
481            // Estimate 1 cache miss per 64 bytes of misaligned memory
482            allocation_size / 64
483        } else {
484            0
485        }
486    }
487
488    /// Adaptively adjust pool sizes based on usage patterns
489    pub fn adaptive_resize(&self) -> TorshResult<()> {
490        if !self.config.enable_adaptive_sizing {
491            return Ok(());
492        }
493
494        let analytics = self.get_analytics();
495
496        // If hit rate is low, consider expanding popular pools
497        if analytics.hit_rate() < 50.0 {
498            // Implementation would analyze which tensor sizes are most requested
499            // and increase their pool sizes
500        }
501
502        // If fragmentation is high, consider consolidating pools
503        if analytics.fragmentation_score > 0.7 {
504            let _ = self.garbage_collect();
505        }
506
507        Ok(())
508    }
509
510    /// Get detailed pool utilization report
511    pub fn get_utilization_report(&self) -> PoolUtilizationReport {
512        let analytics = self.get_analytics();
513        let pool_stats = self.get_pool_stats();
514
515        PoolUtilizationReport {
516            total_pools: pool_stats.len(),
517            total_tensors_pooled: pool_stats.values().sum(),
518            hit_rate: analytics.hit_rate(),
519            fragmentation_score: analytics.fragmentation_score,
520            cache_efficiency: analytics.cache_efficiency(),
521            memory_usage_mb: analytics.current_memory_usage / 1024 / 1024,
522            peak_memory_usage_mb: analytics.peak_memory_usage / 1024 / 1024,
523            pressure_events: analytics.pressure_events,
524            gc_time_ms: analytics.gc_time_us / 1000,
525            performance_score: analytics.performance_score(),
526            needs_optimization: analytics.needs_optimization(),
527            recommendations: analytics.get_optimization_recommendations(),
528        }
529    }
530
531    /// Prefetch tensors for predicted workload
532    pub fn prefetch_for_workload(
533        &self,
534        predicted_shapes: &[(Vec<usize>, DType)],
535    ) -> TorshResult<()> {
536        for (shape, dtype) in predicted_shapes {
537            // Pre-allocate a few tensors of this size
538            for _ in 0..2 {
539                let tensor = self.create_tensor(shape, *dtype)?;
540                self.release_tensor(tensor);
541            }
542        }
543        Ok(())
544    }
545}
546
547/// Detailed pool utilization report
548#[derive(Debug, Clone)]
549pub struct PoolUtilizationReport {
550    pub total_pools: usize,
551    pub total_tensors_pooled: usize,
552    pub hit_rate: f64,
553    pub fragmentation_score: f64,
554    pub cache_efficiency: f64,
555    pub memory_usage_mb: usize,
556    pub peak_memory_usage_mb: usize,
557    pub pressure_events: usize,
558    pub gc_time_ms: u64,
559    pub performance_score: f64,
560    pub needs_optimization: bool,
561    pub recommendations: Vec<String>,
562}
563
564#[cfg(test)]
565mod tests {
566    use super::*;
567
568    #[test]
569    fn test_memory_pool_basic() {
570        let mut config = PoolConfig::default();
571        config.pre_allocate_sizes = vec![]; // Disable pre-allocation for cleaner test
572        let pool = MemoryPool::new(config);
573
574        // Allocate a tensor
575        let tensor = pool.allocate_tensor(&[32, 32], DType::F32).unwrap();
576        assert_eq!(tensor.shape().dims(), &[32, 32]);
577        assert_eq!(tensor.dtype(), DType::F32);
578
579        // Release back to pool
580        pool.release_tensor(tensor);
581
582        // Allocate same size again - should reuse
583        let tensor2 = pool.allocate_tensor(&[32, 32], DType::F32).unwrap();
584        assert_eq!(tensor2.shape().dims(), &[32, 32]);
585
586        let analytics = pool.get_analytics();
587        assert_eq!(analytics.total_allocations, 2);
588        assert_eq!(analytics.pool_hits, 1);
589        assert_eq!(analytics.pool_misses, 1);
590    }
591
592    #[test]
593    fn test_memory_pool_different_sizes() {
594        let mut config = PoolConfig::default();
595        config.pre_allocate_sizes = vec![]; // Disable pre-allocation for predictable test results
596        let pool = MemoryPool::new(config);
597
598        let tensor1 = pool.allocate_tensor(&[64, 64], DType::F32).unwrap();
599        let tensor2 = pool.allocate_tensor(&[128, 128], DType::F32).unwrap();
600
601        assert_eq!(tensor1.shape().dims(), &[64, 64]);
602        assert_eq!(tensor2.shape().dims(), &[128, 128]);
603
604        pool.release_tensor(tensor1);
605        pool.release_tensor(tensor2);
606
607        let analytics = pool.get_analytics();
608        assert_eq!(analytics.total_allocations, 2);
609        assert_eq!(analytics.total_deallocations, 2);
610        assert_eq!(analytics.pool_misses, 2);
611        assert_eq!(analytics.pool_hits, 0); // No hits since different sizes
612    }
613
614    #[test]
615    fn test_memory_pool_analytics() {
616        let mut config = PoolConfig::default();
617        config.pre_allocate_sizes = vec![]; // Disable pre-allocation for predictable test results
618        let pool = MemoryPool::new(config);
619
620        // Allocate and release multiple tensors
621        for _ in 0..5 {
622            let tensor = pool.allocate_tensor(&[32, 32], DType::F32).unwrap();
623            pool.release_tensor(tensor);
624        }
625
626        let analytics = pool.get_analytics();
627        assert_eq!(analytics.total_allocations, 5);
628        assert_eq!(analytics.total_deallocations, 5);
629        assert_eq!(analytics.pool_hits, 4); // First is miss, rest are hits
630        assert_eq!(analytics.pool_misses, 1);
631        assert_eq!(analytics.hit_rate(), 80.0);
632    }
633
634    #[test]
635    fn test_memory_pool_clear() {
636        let pool = MemoryPool::new(PoolConfig::default());
637
638        let tensor = pool.allocate_tensor(&[32, 32], DType::F32).unwrap();
639        pool.release_tensor(tensor);
640
641        pool.clear();
642
643        let analytics = pool.get_analytics();
644        assert_eq!(analytics.total_allocations, 0);
645        assert_eq!(analytics.total_deallocations, 0);
646    }
647
648    #[test]
649    fn test_convenience_functions() {
650        let pool = MemoryPool::new(PoolConfig::default());
651
652        let f32_tensor = pool.allocate_f32(&[16, 16]).unwrap();
653        let i8_tensor = pool.allocate_i8(&[16, 16]).unwrap();
654        let u8_tensor = pool.allocate_u8(&[16, 16]).unwrap();
655
656        // Note: Current implementation creates all tensors as F32 for simplicity
657        assert_eq!(f32_tensor.dtype(), DType::F32);
658        assert_eq!(i8_tensor.dtype(), DType::F32); // Actually F32, not I8
659        assert_eq!(u8_tensor.dtype(), DType::F32); // Actually F32, not U8
660
661        // Test that tensors have the correct shape
662        assert_eq!(f32_tensor.shape().dims(), &[16, 16]);
663        assert_eq!(i8_tensor.shape().dims(), &[16, 16]);
664        assert_eq!(u8_tensor.shape().dims(), &[16, 16]);
665    }
666
667    #[test]
668    fn test_global_pool() {
669        let pool = MemoryPool::global();
670        let tensor = pool.allocate_f32(&[8, 8]).unwrap();
671        assert_eq!(tensor.shape().dims(), &[8, 8]);
672        pool.release_tensor(tensor);
673    }
674
675    #[test]
676    fn test_advanced_analytics() {
677        let pool = MemoryPool::new(PoolConfig::default());
678
679        // Allocate and release to generate analytics
680        for i in 0..10 {
681            let tensor = pool.allocate_tensor(&[32, 32], DType::F32).unwrap();
682            if i % 2 == 0 {
683                pool.release_tensor(tensor);
684            }
685        }
686
687        let analytics = pool.get_analytics();
688        assert_eq!(analytics.total_allocations, 10);
689        assert!(analytics.performance_score() >= 0.0);
690        assert!(analytics.performance_score() <= 100.0);
691
692        let recommendations = analytics.get_optimization_recommendations();
693        // Should have some recommendations given the pattern
694        assert!(!recommendations.is_empty() || analytics.performance_score() > 70.0);
695    }
696
697    #[test]
698    fn test_garbage_collection() {
699        let pool = MemoryPool::new(PoolConfig::default());
700
701        // Create some fragmentation
702        for i in 0..5 {
703            let tensor = pool
704                .allocate_tensor(&[i * 10 + 1, i * 10 + 1], DType::F32)
705                .unwrap();
706            if i % 2 == 0 {
707                pool.release_tensor(tensor);
708            }
709        }
710
711        // Trigger garbage collection
712        pool.garbage_collect().unwrap();
713
714        let analytics = pool.get_analytics();
715        // GC time is u64, always non-negative - verify it can be accessed
716        let _gc_time = analytics.gc_time_us; // Verify field access works
717    }
718
719    #[test]
720    fn test_memory_pressure_detection() {
721        let mut config = PoolConfig::default();
722        config.max_total_memory = 1024; // Very small limit to trigger pressure
723        let pool = MemoryPool::new(config);
724
725        // This should not trigger pressure initially
726        let initial_pressure = pool.check_memory_pressure();
727        assert!(!initial_pressure);
728
729        // Allocate enough to potentially trigger pressure
730        let _tensors: Vec<_> = (0..10)
731            .map(|_| pool.allocate_tensor(&[32, 32], DType::F32).unwrap())
732            .collect();
733
734        // Check if pressure is detected (may or may not trigger depending on actual memory usage)
735        let _final_pressure = pool.check_memory_pressure();
736    }
737
738    #[test]
739    fn test_utilization_report() {
740        let pool = MemoryPool::new(PoolConfig::default());
741
742        // Generate some activity
743        let tensor1 = pool.allocate_tensor(&[64, 64], DType::F32).unwrap();
744        let tensor2 = pool.allocate_tensor(&[128, 128], DType::F32).unwrap();
745        pool.release_tensor(tensor1);
746        pool.release_tensor(tensor2);
747
748        let report = pool.get_utilization_report();
749        // total_pools is usize, always non-negative - verify field access
750        let _pools = report.total_pools;
751        assert!(report.hit_rate >= 0.0);
752        assert!(report.performance_score >= 0.0);
753        assert!(report.performance_score <= 100.0);
754    }
755
756    #[test]
757    fn test_prefetch_workload() {
758        let pool = MemoryPool::new(PoolConfig::default());
759
760        let predicted_shapes = vec![
761            (vec![32, 32], DType::F32),
762            (vec![64, 64], DType::F32),
763            (vec![128, 128], DType::F32),
764        ];
765
766        pool.prefetch_for_workload(&predicted_shapes).unwrap();
767
768        // After prefetching, these sizes should have good hit rates
769        let tensor = pool.allocate_tensor(&[32, 32], DType::F32).unwrap();
770        assert_eq!(tensor.shape().dims(), &[32, 32]);
771
772        let analytics = pool.get_analytics();
773        assert!(analytics.total_allocations > 0);
774    }
775
776    #[test]
777    fn test_adaptive_config() {
778        let mut config = PoolConfig::default();
779        config.enable_cache_awareness = true;
780        config.enable_adaptive_sizing = true;
781        config.auto_gc_threshold = 0.5;
782
783        let pool = MemoryPool::new(config);
784
785        // Test that adaptive features are enabled
786        let tensor = pool.allocate_tensor(&[32, 32], DType::F32).unwrap();
787        pool.release_tensor(tensor);
788
789        // Trigger adaptive resize
790        pool.adaptive_resize().unwrap();
791
792        let analytics = pool.get_analytics();
793        assert_eq!(analytics.total_allocations, 1);
794    }
795}