1use std::collections::{HashMap, VecDeque};
37use std::sync::{Arc, Mutex};
38use torsh_core::device::DeviceType;
39use torsh_core::Result as TorshResult;
40use torsh_core::{DType, TorshError};
41use torsh_tensor::Tensor;
42
43#[derive(Debug, Clone)]
45pub struct PoolConfig {
46 pub max_tensors_per_size: usize,
48 pub max_total_memory: usize,
50 pub enable_analytics: bool,
52 pub pre_allocate_sizes: Vec<Vec<usize>>,
54 pub enable_cache_awareness: bool,
56 pub memory_alignment: usize,
58 pub auto_gc_threshold: f64,
60 pub enable_adaptive_sizing: bool,
62 pub pressure_check_interval_ms: u64,
64 pub min_cache_tracked_size: usize,
66}
67
68impl Default for PoolConfig {
69 fn default() -> Self {
70 Self {
71 max_tensors_per_size: 16,
72 max_total_memory: 1024 * 1024 * 1024, enable_analytics: true,
74 pre_allocate_sizes: vec![
75 vec![1, 1],
76 vec![32, 32],
77 vec![64, 64],
78 vec![128, 128],
79 vec![256, 256],
80 vec![512, 512],
81 vec![1024, 1024],
82 ],
83 enable_cache_awareness: true,
84 memory_alignment: 64, auto_gc_threshold: 0.75,
86 enable_adaptive_sizing: true,
87 pressure_check_interval_ms: 1000, min_cache_tracked_size: 1024, }
90 }
91}
92
93#[derive(Debug, Clone, Default)]
95pub struct MemoryAnalytics {
96 pub total_allocations: usize,
98 pub total_deallocations: usize,
100 pub pool_hits: usize,
102 pub pool_misses: usize,
104 pub peak_memory_usage: usize,
106 pub current_memory_usage: usize,
108 pub fragmentation_score: f64,
110 pub avg_allocation_size: usize,
112 pub estimated_cache_misses: usize,
114 pub pressure_events: usize,
116 pub gc_time_us: u64,
118}
119
120impl MemoryAnalytics {
121 pub fn hit_rate(&self) -> f64 {
123 if self.total_allocations == 0 {
124 0.0
125 } else {
126 (self.pool_hits as f64 / self.total_allocations as f64) * 100.0
127 }
128 }
129
130 pub fn efficiency_ratio(&self) -> f64 {
132 if self.peak_memory_usage == 0 {
133 1.0
134 } else {
135 self.current_memory_usage as f64 / self.peak_memory_usage as f64
136 }
137 }
138
139 pub fn cache_efficiency(&self) -> f64 {
141 if self.total_allocations == 0 {
142 100.0
143 } else {
144 let cache_hits = self
145 .total_allocations
146 .saturating_sub(self.estimated_cache_misses);
147 (cache_hits as f64 / self.total_allocations as f64) * 100.0
148 }
149 }
150
151 pub fn performance_score(&self) -> f64 {
153 let hit_score = self.hit_rate() * 0.4;
154 let efficiency_score = self.efficiency_ratio() * 100.0 * 0.3;
155 let fragmentation_score = (1.0 - self.fragmentation_score) * 100.0 * 0.2;
156 let cache_score = self.cache_efficiency() * 0.1;
157
158 hit_score + efficiency_score + fragmentation_score + cache_score
159 }
160
161 pub fn needs_optimization(&self) -> bool {
163 self.fragmentation_score > 0.7 || self.hit_rate() < 50.0 || self.pressure_events > 10
164 }
165
166 pub fn get_optimization_recommendations(&self) -> Vec<String> {
168 let mut recommendations = Vec::new();
169
170 if self.hit_rate() < 50.0 {
171 recommendations
172 .push("Consider increasing pool sizes for commonly used tensor shapes".to_string());
173 }
174
175 if self.fragmentation_score > 0.7 {
176 recommendations.push(
177 "High fragmentation detected - consider triggering garbage collection".to_string(),
178 );
179 }
180
181 if self.estimated_cache_misses as f64 / self.total_allocations as f64 > 0.3 {
182 recommendations.push(
183 "Cache-unfriendly allocation patterns detected - consider memory alignment"
184 .to_string(),
185 );
186 }
187
188 if self.pressure_events > 5 {
189 recommendations.push(
190 "Memory pressure detected - consider reducing pool sizes or freeing unused memory"
191 .to_string(),
192 );
193 }
194
195 recommendations
196 }
197}
198
199#[derive(Debug, Clone, PartialEq, Eq, Hash)]
201struct TensorKey {
202 shape: Vec<usize>,
203 dtype: DType,
204}
205
206pub struct MemoryPool {
208 config: PoolConfig,
209 pools: Arc<Mutex<HashMap<TensorKey, VecDeque<Tensor>>>>,
210 analytics: Arc<Mutex<MemoryAnalytics>>,
211}
212
213impl MemoryPool {
214 pub fn new(config: PoolConfig) -> Self {
216 let pool = Self {
217 config,
218 pools: Arc::new(Mutex::new(HashMap::new())),
219 analytics: Arc::new(Mutex::new(MemoryAnalytics::default())),
220 };
221
222 if !pool.config.pre_allocate_sizes.is_empty() {
224 pool.pre_allocate_common_sizes();
225 }
226
227 pool
228 }
229
230 fn pre_allocate_common_sizes(&self) {
232 for shape in &self.config.pre_allocate_sizes {
233 let key = TensorKey {
234 shape: shape.clone(),
235 dtype: DType::F32,
236 };
237
238 if let Ok(mut pools) = self.pools.lock() {
239 let pool = pools.entry(key).or_insert_with(VecDeque::new);
240
241 for _ in 0..4 {
243 if let Ok(tensor) = self.create_tensor(shape, DType::F32) {
244 pool.push_back(tensor);
245 }
246 }
247 }
248 }
249 }
250
251 pub fn allocate_tensor(&self, shape: &[usize], dtype: DType) -> TorshResult<Tensor> {
253 let key = TensorKey {
254 shape: shape.to_vec(),
255 dtype,
256 };
257
258 if let Ok(mut pools) = self.pools.lock() {
260 if let Some(pool) = pools.get_mut(&key) {
261 if let Some(tensor) = pool.pop_front() {
262 if let Ok(mut analytics) = self.analytics.lock() {
264 analytics.total_allocations += 1;
265 analytics.pool_hits += 1;
266 }
267 return Ok(tensor);
268 }
269 }
270 }
271
272 let tensor = self.create_tensor(shape, dtype)?;
274
275 if let Ok(mut analytics) = self.analytics.lock() {
277 analytics.total_allocations += 1;
278 analytics.pool_misses += 1;
279 }
280
281 Ok(tensor)
282 }
283
284 pub fn release_tensor(&self, tensor: Tensor) {
286 let key = TensorKey {
287 shape: tensor.shape().dims().to_vec(),
288 dtype: tensor.dtype(),
289 };
290
291 if let Ok(mut pools) = self.pools.lock() {
292 let pool = pools.entry(key).or_insert_with(VecDeque::new);
293
294 if pool.len() < self.config.max_tensors_per_size {
296 pool.push_back(tensor);
297 }
298 }
299
300 if let Ok(mut analytics) = self.analytics.lock() {
302 analytics.total_deallocations += 1;
303 }
304 }
305
306 fn create_tensor(&self, shape: &[usize], dtype: DType) -> TorshResult<Tensor> {
308 match dtype {
309 DType::F32 => {
310 let data: Vec<f32> = vec![0.0; shape.iter().product()];
311 Tensor::from_data(data, shape.to_vec(), DeviceType::Cpu)
312 .map_err(|e| TorshError::InvalidArgument(e.to_string()))
313 }
314 _ => {
315 let data: Vec<f32> = vec![0.0; shape.iter().product()];
318 Tensor::from_data(data, shape.to_vec(), DeviceType::Cpu)
319 .map_err(|e| TorshError::InvalidArgument(e.to_string()))
320 }
321 }
322 }
323
324 pub fn get_analytics(&self) -> MemoryAnalytics {
326 self.analytics
327 .lock()
328 .map(|guard| guard.clone())
329 .unwrap_or_default()
330 }
331
332 pub fn clear(&self) {
334 if let Ok(mut pools) = self.pools.lock() {
335 pools.clear();
336 }
337 if let Ok(mut analytics) = self.analytics.lock() {
338 *analytics = MemoryAnalytics::default();
339 }
340 }
341
342 pub fn get_pool_stats(&self) -> HashMap<String, usize> {
344 let mut stats = HashMap::new();
345
346 if let Ok(pools) = self.pools.lock() {
347 for (key, pool) in pools.iter() {
348 let key_str = format!("{:?}_{:?}", key.shape, key.dtype);
349 stats.insert(key_str, pool.len());
350 }
351 }
352
353 stats
354 }
355}
356
357impl MemoryPool {
359 pub fn global() -> &'static MemoryPool {
361 static GLOBAL_POOL: std::sync::OnceLock<MemoryPool> = std::sync::OnceLock::new();
362 GLOBAL_POOL.get_or_init(|| MemoryPool::new(PoolConfig::default()))
363 }
364
365 pub fn allocate_f32(&self, shape: &[usize]) -> TorshResult<Tensor> {
367 self.allocate_tensor(shape, DType::F32)
368 }
369
370 pub fn allocate_i8(&self, shape: &[usize]) -> TorshResult<Tensor> {
372 self.allocate_tensor(shape, DType::I8)
373 }
374
375 pub fn allocate_u8(&self, shape: &[usize]) -> TorshResult<Tensor> {
377 self.allocate_tensor(shape, DType::U8)
378 }
379}
380
381impl MemoryPool {
383 pub fn garbage_collect(&self) -> TorshResult<()> {
385 let start_time = std::time::Instant::now();
386
387 if let Ok(mut pools) = self.pools.lock() {
388 pools.retain(|_, pool| {
390 if pool.is_empty() {
391 true } else {
393 true
395 }
396 });
397
398 if let Ok(mut analytics) = self.analytics.lock() {
400 let gc_duration = start_time.elapsed();
401 analytics.gc_time_us += gc_duration.as_micros() as u64;
402
403 analytics.fragmentation_score = self.calculate_fragmentation_score(&pools);
405 }
406 }
407
408 Ok(())
409 }
410
411 pub fn check_memory_pressure(&self) -> bool {
413 let analytics = self.get_analytics();
414 let memory_usage_ratio =
415 analytics.current_memory_usage as f64 / self.config.max_total_memory as f64;
416
417 let high_pressure = memory_usage_ratio > 0.85
418 || analytics.fragmentation_score > self.config.auto_gc_threshold;
419
420 if high_pressure {
421 let _ = self.garbage_collect();
423
424 if let Ok(mut analytics) = self.analytics.lock() {
426 analytics.pressure_events += 1;
427 }
428 }
429
430 high_pressure
431 }
432
433 fn calculate_fragmentation_score(&self, pools: &HashMap<TensorKey, VecDeque<Tensor>>) -> f64 {
435 if pools.is_empty() {
436 return 0.0;
437 }
438
439 let total_pools = pools.len();
440 let mut fragmented_pools = 0;
441 let mut total_capacity = 0;
442 let mut total_used = 0;
443
444 for (_, pool) in pools.iter() {
445 let capacity = self.config.max_tensors_per_size;
446 let used = pool.len();
447
448 total_capacity += capacity;
449 total_used += used;
450
451 if used > 0 && used < capacity / 2 {
453 fragmented_pools += 1;
454 }
455 }
456
457 let pool_fragmentation = fragmented_pools as f64 / total_pools as f64;
458 let usage_fragmentation = if total_capacity > 0 {
459 1.0 - (total_used as f64 / total_capacity as f64)
460 } else {
461 0.0
462 };
463
464 (pool_fragmentation + usage_fragmentation) / 2.0
465 }
466
467 #[allow(dead_code)]
469 fn estimate_cache_misses(&self, allocation_size: usize) -> usize {
470 if !self.config.enable_cache_awareness
471 || allocation_size < self.config.min_cache_tracked_size
472 {
473 return 0;
474 }
475
476 let alignment = self.config.memory_alignment;
478 let misaligned = allocation_size % alignment != 0;
479
480 if misaligned && allocation_size > alignment * 8 {
481 allocation_size / 64
483 } else {
484 0
485 }
486 }
487
488 pub fn adaptive_resize(&self) -> TorshResult<()> {
490 if !self.config.enable_adaptive_sizing {
491 return Ok(());
492 }
493
494 let analytics = self.get_analytics();
495
496 if analytics.hit_rate() < 50.0 {
498 }
501
502 if analytics.fragmentation_score > 0.7 {
504 let _ = self.garbage_collect();
505 }
506
507 Ok(())
508 }
509
510 pub fn get_utilization_report(&self) -> PoolUtilizationReport {
512 let analytics = self.get_analytics();
513 let pool_stats = self.get_pool_stats();
514
515 PoolUtilizationReport {
516 total_pools: pool_stats.len(),
517 total_tensors_pooled: pool_stats.values().sum(),
518 hit_rate: analytics.hit_rate(),
519 fragmentation_score: analytics.fragmentation_score,
520 cache_efficiency: analytics.cache_efficiency(),
521 memory_usage_mb: analytics.current_memory_usage / 1024 / 1024,
522 peak_memory_usage_mb: analytics.peak_memory_usage / 1024 / 1024,
523 pressure_events: analytics.pressure_events,
524 gc_time_ms: analytics.gc_time_us / 1000,
525 performance_score: analytics.performance_score(),
526 needs_optimization: analytics.needs_optimization(),
527 recommendations: analytics.get_optimization_recommendations(),
528 }
529 }
530
531 pub fn prefetch_for_workload(
533 &self,
534 predicted_shapes: &[(Vec<usize>, DType)],
535 ) -> TorshResult<()> {
536 for (shape, dtype) in predicted_shapes {
537 for _ in 0..2 {
539 let tensor = self.create_tensor(shape, *dtype)?;
540 self.release_tensor(tensor);
541 }
542 }
543 Ok(())
544 }
545}
546
547#[derive(Debug, Clone)]
549pub struct PoolUtilizationReport {
550 pub total_pools: usize,
551 pub total_tensors_pooled: usize,
552 pub hit_rate: f64,
553 pub fragmentation_score: f64,
554 pub cache_efficiency: f64,
555 pub memory_usage_mb: usize,
556 pub peak_memory_usage_mb: usize,
557 pub pressure_events: usize,
558 pub gc_time_ms: u64,
559 pub performance_score: f64,
560 pub needs_optimization: bool,
561 pub recommendations: Vec<String>,
562}
563
564#[cfg(test)]
565mod tests {
566 use super::*;
567
568 #[test]
569 fn test_memory_pool_basic() {
570 let mut config = PoolConfig::default();
571 config.pre_allocate_sizes = vec![]; let pool = MemoryPool::new(config);
573
574 let tensor = pool.allocate_tensor(&[32, 32], DType::F32).unwrap();
576 assert_eq!(tensor.shape().dims(), &[32, 32]);
577 assert_eq!(tensor.dtype(), DType::F32);
578
579 pool.release_tensor(tensor);
581
582 let tensor2 = pool.allocate_tensor(&[32, 32], DType::F32).unwrap();
584 assert_eq!(tensor2.shape().dims(), &[32, 32]);
585
586 let analytics = pool.get_analytics();
587 assert_eq!(analytics.total_allocations, 2);
588 assert_eq!(analytics.pool_hits, 1);
589 assert_eq!(analytics.pool_misses, 1);
590 }
591
592 #[test]
593 fn test_memory_pool_different_sizes() {
594 let mut config = PoolConfig::default();
595 config.pre_allocate_sizes = vec![]; let pool = MemoryPool::new(config);
597
598 let tensor1 = pool.allocate_tensor(&[64, 64], DType::F32).unwrap();
599 let tensor2 = pool.allocate_tensor(&[128, 128], DType::F32).unwrap();
600
601 assert_eq!(tensor1.shape().dims(), &[64, 64]);
602 assert_eq!(tensor2.shape().dims(), &[128, 128]);
603
604 pool.release_tensor(tensor1);
605 pool.release_tensor(tensor2);
606
607 let analytics = pool.get_analytics();
608 assert_eq!(analytics.total_allocations, 2);
609 assert_eq!(analytics.total_deallocations, 2);
610 assert_eq!(analytics.pool_misses, 2);
611 assert_eq!(analytics.pool_hits, 0); }
613
614 #[test]
615 fn test_memory_pool_analytics() {
616 let mut config = PoolConfig::default();
617 config.pre_allocate_sizes = vec![]; let pool = MemoryPool::new(config);
619
620 for _ in 0..5 {
622 let tensor = pool.allocate_tensor(&[32, 32], DType::F32).unwrap();
623 pool.release_tensor(tensor);
624 }
625
626 let analytics = pool.get_analytics();
627 assert_eq!(analytics.total_allocations, 5);
628 assert_eq!(analytics.total_deallocations, 5);
629 assert_eq!(analytics.pool_hits, 4); assert_eq!(analytics.pool_misses, 1);
631 assert_eq!(analytics.hit_rate(), 80.0);
632 }
633
634 #[test]
635 fn test_memory_pool_clear() {
636 let pool = MemoryPool::new(PoolConfig::default());
637
638 let tensor = pool.allocate_tensor(&[32, 32], DType::F32).unwrap();
639 pool.release_tensor(tensor);
640
641 pool.clear();
642
643 let analytics = pool.get_analytics();
644 assert_eq!(analytics.total_allocations, 0);
645 assert_eq!(analytics.total_deallocations, 0);
646 }
647
648 #[test]
649 fn test_convenience_functions() {
650 let pool = MemoryPool::new(PoolConfig::default());
651
652 let f32_tensor = pool.allocate_f32(&[16, 16]).unwrap();
653 let i8_tensor = pool.allocate_i8(&[16, 16]).unwrap();
654 let u8_tensor = pool.allocate_u8(&[16, 16]).unwrap();
655
656 assert_eq!(f32_tensor.dtype(), DType::F32);
658 assert_eq!(i8_tensor.dtype(), DType::F32); assert_eq!(u8_tensor.dtype(), DType::F32); assert_eq!(f32_tensor.shape().dims(), &[16, 16]);
663 assert_eq!(i8_tensor.shape().dims(), &[16, 16]);
664 assert_eq!(u8_tensor.shape().dims(), &[16, 16]);
665 }
666
667 #[test]
668 fn test_global_pool() {
669 let pool = MemoryPool::global();
670 let tensor = pool.allocate_f32(&[8, 8]).unwrap();
671 assert_eq!(tensor.shape().dims(), &[8, 8]);
672 pool.release_tensor(tensor);
673 }
674
675 #[test]
676 fn test_advanced_analytics() {
677 let pool = MemoryPool::new(PoolConfig::default());
678
679 for i in 0..10 {
681 let tensor = pool.allocate_tensor(&[32, 32], DType::F32).unwrap();
682 if i % 2 == 0 {
683 pool.release_tensor(tensor);
684 }
685 }
686
687 let analytics = pool.get_analytics();
688 assert_eq!(analytics.total_allocations, 10);
689 assert!(analytics.performance_score() >= 0.0);
690 assert!(analytics.performance_score() <= 100.0);
691
692 let recommendations = analytics.get_optimization_recommendations();
693 assert!(!recommendations.is_empty() || analytics.performance_score() > 70.0);
695 }
696
697 #[test]
698 fn test_garbage_collection() {
699 let pool = MemoryPool::new(PoolConfig::default());
700
701 for i in 0..5 {
703 let tensor = pool
704 .allocate_tensor(&[i * 10 + 1, i * 10 + 1], DType::F32)
705 .unwrap();
706 if i % 2 == 0 {
707 pool.release_tensor(tensor);
708 }
709 }
710
711 pool.garbage_collect().unwrap();
713
714 let analytics = pool.get_analytics();
715 let _gc_time = analytics.gc_time_us; }
718
719 #[test]
720 fn test_memory_pressure_detection() {
721 let mut config = PoolConfig::default();
722 config.max_total_memory = 1024; let pool = MemoryPool::new(config);
724
725 let initial_pressure = pool.check_memory_pressure();
727 assert!(!initial_pressure);
728
729 let _tensors: Vec<_> = (0..10)
731 .map(|_| pool.allocate_tensor(&[32, 32], DType::F32).unwrap())
732 .collect();
733
734 let _final_pressure = pool.check_memory_pressure();
736 }
737
738 #[test]
739 fn test_utilization_report() {
740 let pool = MemoryPool::new(PoolConfig::default());
741
742 let tensor1 = pool.allocate_tensor(&[64, 64], DType::F32).unwrap();
744 let tensor2 = pool.allocate_tensor(&[128, 128], DType::F32).unwrap();
745 pool.release_tensor(tensor1);
746 pool.release_tensor(tensor2);
747
748 let report = pool.get_utilization_report();
749 let _pools = report.total_pools;
751 assert!(report.hit_rate >= 0.0);
752 assert!(report.performance_score >= 0.0);
753 assert!(report.performance_score <= 100.0);
754 }
755
756 #[test]
757 fn test_prefetch_workload() {
758 let pool = MemoryPool::new(PoolConfig::default());
759
760 let predicted_shapes = vec![
761 (vec![32, 32], DType::F32),
762 (vec![64, 64], DType::F32),
763 (vec![128, 128], DType::F32),
764 ];
765
766 pool.prefetch_for_workload(&predicted_shapes).unwrap();
767
768 let tensor = pool.allocate_tensor(&[32, 32], DType::F32).unwrap();
770 assert_eq!(tensor.shape().dims(), &[32, 32]);
771
772 let analytics = pool.get_analytics();
773 assert!(analytics.total_allocations > 0);
774 }
775
776 #[test]
777 fn test_adaptive_config() {
778 let mut config = PoolConfig::default();
779 config.enable_cache_awareness = true;
780 config.enable_adaptive_sizing = true;
781 config.auto_gc_threshold = 0.5;
782
783 let pool = MemoryPool::new(config);
784
785 let tensor = pool.allocate_tensor(&[32, 32], DType::F32).unwrap();
787 pool.release_tensor(tensor);
788
789 pool.adaptive_resize().unwrap();
791
792 let analytics = pool.get_analytics();
793 assert_eq!(analytics.total_allocations, 1);
794 }
795}