1#![allow(dead_code)]
17use std::alloc::{GlobalAlloc, Layout, System};
18use std::collections::{BTreeMap, HashMap, VecDeque};
19use std::mem::{align_of, size_of};
20use std::ptr::NonNull;
21use std::sync::{Arc, Mutex, RwLock};
22use std::time::{Duration, Instant};
23
24use torsh_core::{
26 dtype::TensorElement,
27 error::{Result, TorshError},
28};
29
/// Tuning knobs for [`AdvancedMemoryPool`] and the global optimizer.
#[derive(Debug, Clone)]
pub struct MemoryConfig {
    /// Enable size-class caching of freed blocks for later reuse.
    pub enable_pooling: bool,
    /// Overall pool budget; also bounds which sizes are cacheable
    /// (see `should_cache_allocation`).
    pub pool_size: usize,
    /// Maximum number of cached free blocks kept per size class.
    pub max_cached_per_size: usize,
    /// Route large requests through the compression manager.
    pub enable_compression: bool,
    /// Request size above which the compression path is taken.
    /// NOTE(review): compared against an *element count* in `allocate`,
    /// but the default reads like bytes — confirm intended units.
    pub compression_threshold: usize,
    /// Spread fresh allocations across NUMA-node allocators.
    pub enable_numa_awareness: bool,
    /// Alignment/rounding granularity for size classes, in bytes.
    pub cache_line_size: usize,
    /// Pre-allocate sizes the predictor expects to be requested soon.
    pub enable_predictive_allocation: bool,
    /// Pool hit-rate threshold used as a (proxy) memory-pressure signal.
    pub memory_pressure_threshold: f64,
}
52
53impl Default for MemoryConfig {
54 fn default() -> Self {
55 Self {
56 enable_pooling: true,
57 pool_size: 1024 * 1024 * 1024, max_cached_per_size: 64,
59 enable_compression: true,
60 compression_threshold: 100 * 1024 * 1024, enable_numa_awareness: false, cache_line_size: 64,
63 enable_predictive_allocation: true,
64 memory_pressure_threshold: 0.8,
65 }
66 }
67}
68
/// Pooled allocator for buffers of `T` with optional size-class caching,
/// a (simulated) compression path for large requests, NUMA-aware
/// allocation, and allocation-pattern prediction.
pub struct AdvancedMemoryPool<T: TensorElement> {
    // Behavior tunables; cloned per pool.
    config: MemoryConfig,
    // Cached free blocks keyed by aligned element count (the size class).
    size_class_pools: RwLock<BTreeMap<usize, VecDeque<NonNull<T>>>>,
    // Event counters; see MemoryStats.
    stats: RwLock<MemoryStats>,
    // Rolling window of recent allocations (bounded to 10_000 entries).
    allocation_history: Mutex<VecDeque<AllocationRecord>>,
    // Lazily created on first predictive allocation.
    predictor: Mutex<Option<AllocationPredictor>>,
    // Tracks allocations made through the "compressed" path.
    compression_manager: Arc<CompressionManager>,
    // One allocator per NUMA node (a single entry when NUMA is disabled).
    numa_allocators: Vec<Arc<Mutex<NumaAllocator>>>,
}
85
86impl<T: TensorElement> AdvancedMemoryPool<T> {
87 pub fn new() -> Self {
89 Self::with_config(MemoryConfig::default())
90 }
91
92 pub fn with_config(config: MemoryConfig) -> Self {
94 let numa_nodes = if config.enable_numa_awareness {
95 detect_numa_nodes()
96 } else {
97 1
98 };
99
100 let numa_allocators = (0..numa_nodes)
101 .map(|node_id| Arc::new(Mutex::new(NumaAllocator::new(node_id))))
102 .collect();
103
104 Self {
105 config,
106 size_class_pools: RwLock::new(BTreeMap::new()),
107 stats: RwLock::new(MemoryStats::default()),
108 allocation_history: Mutex::new(VecDeque::with_capacity(10000)),
109 predictor: Mutex::new(None),
110 compression_manager: Arc::new(CompressionManager::new()),
111 numa_allocators,
112 }
113 }
114
115 pub fn allocate(&self, size: usize) -> Result<NonNull<T>> {
117 #[cfg(feature = "profiling")]
118 {
119 }
121 let aligned_size = self.align_size(size);
122
123 if self.config.enable_compression && size > self.config.compression_threshold {
125 return self.allocate_compressed(aligned_size);
126 }
127
128 if let Some(ptr) = self.try_reuse_from_pool(aligned_size)? {
130 self.record_allocation(aligned_size, true);
131 return Ok(ptr);
132 }
133
134 if self.config.enable_predictive_allocation {
136 self.maybe_predictive_allocate(aligned_size)?;
137 }
138
139 let ptr = self.allocate_new(aligned_size)?;
141 self.record_allocation(aligned_size, false);
142
143 Ok(ptr)
144 }
145
146 pub fn deallocate(&self, ptr: NonNull<T>, size: usize) -> Result<()> {
148 #[cfg(feature = "profiling")]
149 {
150 }
152 let aligned_size = self.align_size(size);
153
154 if self.compression_manager.is_compressed(ptr) {
156 return self.compression_manager.deallocate(ptr);
157 }
158
159 if self.should_cache_allocation(aligned_size) {
161 let mut pools = self
162 .size_class_pools
163 .write()
164 .expect("lock should not be poisoned");
165 let pool = pools.entry(aligned_size).or_insert_with(VecDeque::new);
166
167 if pool.len() < self.config.max_cached_per_size {
168 pool.push_back(ptr);
169 self.update_stats(|stats| stats.pooled_allocations += 1);
170 return Ok(());
171 }
172 }
173
174 self.free_allocation(ptr, aligned_size)?;
176 Ok(())
177 }
178
179 fn try_reuse_from_pool(&self, size: usize) -> Result<Option<NonNull<T>>> {
181 if !self.config.enable_pooling {
182 return Ok(None);
183 }
184
185 let mut pools = self
186 .size_class_pools
187 .write()
188 .expect("lock should not be poisoned");
189
190 if let Some(pool) = pools.get_mut(&size) {
192 if let Some(ptr) = pool.pop_front() {
193 self.update_stats(|stats| stats.pool_hits += 1);
194 return Ok(Some(ptr));
195 }
196 }
197
198 let max_oversized = size * 2; for (&pool_size, pool) in pools.range_mut(size..).take(5) {
202 if pool_size > max_oversized {
203 break;
204 }
205
206 if let Some(ptr) = pool.pop_front() {
207 self.update_stats(|stats| {
208 stats.pool_hits += 1;
209 stats.oversized_reuse += 1;
210 });
211 return Ok(Some(ptr));
212 }
213 }
214
215 self.update_stats(|stats| stats.pool_misses += 1);
216 Ok(None)
217 }
218
219 fn allocate_new(&self, size: usize) -> Result<NonNull<T>> {
221 let layout = Layout::from_size_align(
222 size * size_of::<T>(),
223 align_of::<T>().max(self.config.cache_line_size),
224 )
225 .map_err(|_| TorshError::InvalidArgument("Invalid memory layout".to_string()))?;
226
227 if self.config.enable_numa_awareness && !self.numa_allocators.is_empty() {
229 let numa_node = self.select_numa_node();
230 let allocator = &self.numa_allocators[numa_node];
231 let mut allocator = allocator.lock().expect("lock should not be poisoned");
232 return allocator.allocate(layout);
233 }
234
235 unsafe {
237 let ptr = System.alloc(layout);
238 if ptr.is_null() {
239 return Err(TorshError::AllocationError(
240 "Failed to allocate memory".to_string(),
241 ));
242 }
243
244 self.prefault_pages(ptr, layout.size());
246
247 Ok(NonNull::new_unchecked(ptr as *mut T))
248 }
249 }
250
251 fn allocate_compressed(&self, size: usize) -> Result<NonNull<T>> {
253 self.compression_manager.allocate_compressed(size)
254 }
255
256 fn free_allocation(&self, ptr: NonNull<T>, size: usize) -> Result<()> {
258 let layout = Layout::from_size_align(
259 size * size_of::<T>(),
260 align_of::<T>().max(self.config.cache_line_size),
261 )
262 .map_err(|_| TorshError::InvalidArgument("Invalid memory layout".to_string()))?;
263
264 unsafe {
265 System.dealloc(ptr.as_ptr() as *mut u8, layout);
266 }
267
268 self.update_stats(|stats| stats.direct_deallocations += 1);
269 Ok(())
270 }
271
272 fn maybe_predictive_allocate(&self, size: usize) -> Result<()> {
274 let mut predictor_guard = self.predictor.lock().expect("lock should not be poisoned");
275
276 if predictor_guard.is_none() {
277 *predictor_guard = Some(AllocationPredictor::new());
278 }
279
280 if let Some(predictor) = predictor_guard.as_mut() {
281 if let Some(predicted_sizes) = predictor.predict_next_allocations(size) {
282 for predicted_size in predicted_sizes {
284 if predicted_size != size && predicted_size > 0 {
285 let _ = self.allocate_new(predicted_size);
287 }
288 }
289 }
290 }
291
292 Ok(())
293 }
294
295 fn align_size(&self, size: usize) -> usize {
297 let cache_line = self.config.cache_line_size;
298 ((size + cache_line - 1) / cache_line) * cache_line
299 }
300
301 fn should_cache_allocation(&self, size: usize) -> bool {
303 self.config.enable_pooling &&
304 size <= self.config.pool_size / 100 && !self.is_memory_pressure_high()
306 }
307
308 fn is_memory_pressure_high(&self) -> bool {
310 let stats = self.stats.read().expect("lock should not be poisoned");
312 let total_allocations = stats.pool_hits + stats.pool_misses + stats.direct_allocations;
313
314 if total_allocations == 0 {
315 return false;
316 }
317
318 let cache_hit_rate = stats.pool_hits as f64 / total_allocations as f64;
319 cache_hit_rate < (1.0 - self.config.memory_pressure_threshold)
320 }
321
322 fn prefault_pages(&self, ptr: *mut u8, size: usize) {
324 const PAGE_SIZE: usize = 4096;
325 let page_count = (size + PAGE_SIZE - 1) / PAGE_SIZE;
326
327 unsafe {
328 for i in 0..page_count {
329 let page_ptr = ptr.add(i * PAGE_SIZE);
330 std::ptr::write_volatile(page_ptr, 0);
331 }
332 }
333 }
334
335 fn select_numa_node(&self) -> usize {
337 let stats = self.stats.read().expect("lock should not be poisoned");
339 (stats.total_allocations % self.numa_allocators.len()) as usize
340 }
341
342 fn record_allocation(&self, size: usize, was_reused: bool) {
344 let record = AllocationRecord {
345 size,
346 timestamp: Instant::now(),
347 was_reused,
348 };
349
350 let mut history = self
351 .allocation_history
352 .lock()
353 .expect("lock should not be poisoned");
354 history.push_back(record);
355
356 if history.len() > 10000 {
358 history.pop_front();
359 }
360
361 self.update_stats(|stats| {
362 stats.total_allocations += 1;
363 if was_reused {
364 stats.reused_allocations += 1;
365 } else {
366 stats.direct_allocations += 1;
367 }
368 });
369 }
370
371 fn update_stats<F>(&self, f: F)
373 where
374 F: FnOnce(&mut MemoryStats),
375 {
376 let mut stats = self.stats.write().expect("lock should not be poisoned");
377 f(&mut *stats);
378 }
379
380 pub fn get_stats(&self) -> MemoryStats {
382 self.stats
383 .read()
384 .expect("lock should not be poisoned")
385 .clone()
386 }
387
388 pub fn defragment(&self) -> Result<DefragmentationReport> {
390 #[cfg(feature = "profiling")]
391 {
392 }
394 let start_time = Instant::now();
395 let mut report = DefragmentationReport::default();
396
397 {
399 let mut pools = self
400 .size_class_pools
401 .write()
402 .expect("lock should not be poisoned");
403 let initial_pools = pools.len();
404 pools.retain(|_, pool| !pool.is_empty());
405 report.pools_cleaned = initial_pools - pools.len();
406 }
407
408 if self.config.enable_compression {
410 report.compression_stats = self.compression_manager.compress_fragmented()?;
411 }
412
413 report.duration = start_time.elapsed();
415 report.memory_freed = self.estimate_memory_freed();
416
417 Ok(report)
418 }
419
420 fn estimate_memory_freed(&self) -> usize {
422 let stats = self.stats.read().expect("lock should not be poisoned");
424 stats
425 .total_allocations
426 .saturating_sub(stats.reused_allocations)
427 * 1024 }
429}
430
431impl<T: TensorElement> Default for AdvancedMemoryPool<T> {
432 fn default() -> Self {
433 Self::new()
434 }
435}
436
/// Monotonically increasing event counters describing pool behavior.
#[derive(Debug, Clone, Default)]
pub struct MemoryStats {
    /// Every allocation served, reused or fresh.
    pub total_allocations: usize,
    /// Allocations that went to the system allocator.
    pub direct_allocations: usize,
    /// Allocations satisfied by a cached block.
    pub reused_allocations: usize,
    /// Deallocations parked in a size-class pool instead of freed.
    pub pooled_allocations: usize,
    /// Pool lookups that found a cached block.
    pub pool_hits: usize,
    /// Pool lookups that found nothing.
    pub pool_misses: usize,
    /// Reuses served from a larger size class than requested.
    pub oversized_reuse: usize,
    /// Deallocations released straight to the system allocator.
    pub direct_deallocations: usize,
    /// Memory saved by compression.
    pub compression_saves: usize,
    /// Allocations routed through a NUMA-node allocator.
    pub numa_allocations: usize,
}
451
452impl MemoryStats {
453 pub fn hit_rate(&self) -> f64 {
455 let total_pool_requests = self.pool_hits + self.pool_misses;
456 if total_pool_requests == 0 {
457 0.0
458 } else {
459 self.pool_hits as f64 / total_pool_requests as f64
460 }
461 }
462
463 pub fn reuse_rate(&self) -> f64 {
465 if self.total_allocations == 0 {
466 0.0
467 } else {
468 self.reused_allocations as f64 / self.total_allocations as f64
469 }
470 }
471}
472
/// One entry in the pool's bounded allocation history.
#[derive(Debug, Clone)]
struct AllocationRecord {
    // Aligned element count that was requested.
    size: usize,
    // When the allocation was served.
    timestamp: Instant,
    // Whether it came from the size-class cache.
    was_reused: bool,
}
480
/// Learns which allocation sizes tend to follow which, to drive
/// speculative pre-allocation.
struct AllocationPredictor {
    // For each observed size, the sizes that followed it.
    size_patterns: HashMap<usize, Vec<usize>>,
    // Rolling (time, size) window of recent requests.
    temporal_patterns: VecDeque<(Instant, usize)>,
    // Upper bound on `temporal_patterns` length.
    max_history: usize,
}
487
impl AllocationPredictor {
    /// Creates an empty predictor with a bounded temporal history.
    fn new() -> Self {
        Self {
            size_patterns: HashMap::new(),
            temporal_patterns: VecDeque::new(),
            max_history: 1000,
        }
    }

    /// Records `size` in the temporal history and returns up to three sizes
    /// that historically followed it, most frequent first.
    ///
    /// NOTE(review): nothing in this file ever writes to `size_patterns`,
    /// so the lookup below always misses and this returns `None` — either
    /// the training step lives elsewhere or it is missing. Confirm before
    /// relying on predictive allocation.
    fn predict_next_allocations(&mut self, size: usize) -> Option<Vec<usize>> {
        // Append the observation to the rolling window...
        self.temporal_patterns.push_back((Instant::now(), size));

        // ...and keep the window bounded.
        if self.temporal_patterns.len() > self.max_history {
            self.temporal_patterns.pop_front();
        }

        if let Some(following_sizes) = self.size_patterns.get(&size) {
            // Count how often each follow-up size occurred.
            let mut counts: HashMap<usize, usize> = HashMap::new();
            for &following_size in following_sizes {
                *counts.entry(following_size).or_insert(0) += 1;
            }

            // Rank by frequency, descending.
            let mut sorted: Vec<_> = counts.into_iter().collect();
            sorted.sort_by(|a, b| b.1.cmp(&a.1));

            Some(sorted.into_iter().take(3).map(|(size, _)| size).collect())
        } else {
            None
        }
    }
}
524
/// Tracks allocations made through the "compressed" path, keyed by the
/// block's address, so they can be freed with the correct layout.
struct CompressionManager {
    // Address (usize) -> bookkeeping for each live compressed allocation.
    compressed_allocations: RwLock<HashMap<usize, CompressedAllocation>>,
}
529
530impl CompressionManager {
531 fn new() -> Self {
532 Self {
533 compressed_allocations: RwLock::new(HashMap::new()),
534 }
535 }
536
537 fn allocate_compressed<T: TensorElement>(&self, size: usize) -> Result<NonNull<T>> {
538 let compressed_size = size / 2; let layout = Layout::from_size_align(compressed_size, align_of::<T>())
542 .map_err(|_| TorshError::InvalidArgument("Invalid layout".to_string()))?;
543
544 unsafe {
545 let ptr = System.alloc(layout);
546 if ptr.is_null() {
547 return Err(TorshError::AllocationError(
548 "Compression allocation failed".to_string(),
549 ));
550 }
551
552 let allocation = CompressedAllocation {
553 original_size: size,
554 compressed_size,
555 compression_ratio: 0.5,
556 };
557
558 self.compressed_allocations
559 .write()
560 .expect("rwlock should not be poisoned")
561 .insert(ptr as usize, allocation);
562 Ok(NonNull::new_unchecked(ptr as *mut T))
563 }
564 }
565
566 fn is_compressed<T: TensorElement>(&self, ptr: NonNull<T>) -> bool {
567 self.compressed_allocations
568 .read()
569 .expect("rwlock should not be poisoned")
570 .contains_key(&(ptr.as_ptr() as usize))
571 }
572
573 fn deallocate<T: TensorElement>(&self, ptr: NonNull<T>) -> Result<()> {
574 let ptr_key = ptr.as_ptr() as usize;
575 let mut allocations = self
576 .compressed_allocations
577 .write()
578 .expect("lock should not be poisoned");
579
580 if let Some(allocation) = allocations.remove(&ptr_key) {
581 let layout = Layout::from_size_align(allocation.compressed_size, align_of::<T>())
582 .map_err(|_| TorshError::InvalidArgument("Invalid layout".to_string()))?;
583
584 unsafe {
585 System.dealloc(ptr_key as *mut u8, layout);
586 }
587 Ok(())
588 } else {
589 Err(TorshError::InvalidArgument(
590 "Allocation not found".to_string(),
591 ))
592 }
593 }
594
595 fn compress_fragmented(&self) -> Result<CompressionStats> {
596 Ok(CompressionStats {
598 allocations_compressed: 0,
599 memory_saved: 0,
600 average_compression_ratio: 0.0,
601 })
602 }
603}
604
/// Bookkeeping for one live allocation made through the compressed path.
#[derive(Debug, Clone)]
struct CompressedAllocation {
    // Size originally requested (as passed to `allocate_compressed`).
    original_size: usize,
    // Byte size used to rebuild the `Layout` on deallocation.
    compressed_size: usize,
    // Notional compression ratio recorded for reporting.
    compression_ratio: f64,
}
612
/// Per-NUMA-node allocator wrapper with a simple allocation counter.
struct NumaAllocator {
    // Node this allocator is nominally associated with.
    node_id: usize,
    // Number of successful allocations made through this allocator.
    allocations: usize,
}
618
619impl NumaAllocator {
620 fn new(node_id: usize) -> Self {
621 Self {
622 node_id,
623 allocations: 0,
624 }
625 }
626
627 fn allocate<T: TensorElement>(&mut self, layout: Layout) -> Result<NonNull<T>> {
628 unsafe {
630 let ptr = System.alloc(layout);
631 if ptr.is_null() {
632 return Err(TorshError::AllocationError(
633 "NUMA allocation failed".to_string(),
634 ));
635 }
636 self.allocations += 1;
637 Ok(NonNull::new_unchecked(ptr as *mut T))
638 }
639 }
640}
641
/// Summary of one `defragment` pass over a pool.
#[derive(Debug, Default)]
pub struct DefragmentationReport {
    /// Wall-clock time the pass took.
    pub duration: Duration,
    /// Number of empty size classes removed.
    pub pools_cleaned: usize,
    /// Estimated bytes released (heuristic, not measured).
    pub memory_freed: usize,
    /// Results of the compression-compaction step.
    pub compression_stats: CompressionStats,
}
650
/// Results of a compression-compaction step.
#[derive(Debug, Default)]
pub struct CompressionStats {
    /// Number of allocations that were compressed.
    pub allocations_compressed: usize,
    /// Total bytes saved by compression.
    pub memory_saved: usize,
    /// Mean compression ratio across compressed allocations.
    pub average_compression_ratio: f64,
}
658
/// Returns the number of NUMA nodes to target.
///
/// Placeholder implementation: always reports a single node. A real
/// implementation would query the OS topology (e.g. libnuma / sysfs).
fn detect_numa_nodes() -> usize {
    const FALLBACK_NODE_COUNT: usize = 1;
    FALLBACK_NODE_COUNT
}
664
/// Convenience wrapper owning one [`AdvancedMemoryPool`] per common element
/// type, sharing a single configuration across all of them.
pub struct GlobalMemoryOptimizer {
    // One pool per supported element type.
    f32_pool: AdvancedMemoryPool<f32>,
    f64_pool: AdvancedMemoryPool<f64>,
    i32_pool: AdvancedMemoryPool<i32>,
    i64_pool: AdvancedMemoryPool<i64>,
    // Configuration used to build the pools (each pool holds its own clone).
    config: MemoryConfig,
}
673
674impl GlobalMemoryOptimizer {
675 pub fn new() -> Self {
677 let config = MemoryConfig::default();
678 Self::with_config(config)
679 }
680
681 pub fn with_config(config: MemoryConfig) -> Self {
683 Self {
684 f32_pool: AdvancedMemoryPool::with_config(config.clone()),
685 f64_pool: AdvancedMemoryPool::with_config(config.clone()),
686 i32_pool: AdvancedMemoryPool::with_config(config.clone()),
687 i64_pool: AdvancedMemoryPool::with_config(config.clone()),
688 config,
689 }
690 }
691
692 pub fn get_pool<T: TensorElement>(&self) -> Option<&AdvancedMemoryPool<T>> {
694 None }
697
698 pub fn global_defragmentation(&self) -> Result<Vec<DefragmentationReport>> {
700 let mut reports = Vec::new();
701
702 reports.push(self.f32_pool.defragment()?);
703 reports.push(self.f64_pool.defragment()?);
704 Ok(reports)
707 }
708
709 pub fn get_aggregate_stats(&self) -> AggregateMemoryStats {
711 AggregateMemoryStats {
712 f32_stats: self.f32_pool.get_stats(),
713 f64_stats: self.f64_pool.get_stats(),
714 i32_stats: self.i32_pool.get_stats(),
715 i64_stats: self.i64_pool.get_stats(),
716 }
717 }
718}
719
720impl Default for GlobalMemoryOptimizer {
721 fn default() -> Self {
722 Self::new()
723 }
724}
725
/// Combined statistics snapshot across all per-type pools.
#[derive(Debug)]
pub struct AggregateMemoryStats {
    /// Stats for the f32 pool.
    pub f32_stats: MemoryStats,
    /// Stats for the f64 pool.
    pub f64_stats: MemoryStats,
    /// Stats for the i32 pool.
    pub i32_stats: MemoryStats,
    /// Stats for the i64 pool.
    pub i64_stats: MemoryStats,
}
734
735impl AggregateMemoryStats {
736 pub fn overall_hit_rate(&self) -> f64 {
738 let total_hits = self.f32_stats.pool_hits
739 + self.f64_stats.pool_hits
740 + self.i32_stats.pool_hits
741 + self.i64_stats.pool_hits;
742 let total_misses = self.f32_stats.pool_misses
743 + self.f64_stats.pool_misses
744 + self.i32_stats.pool_misses
745 + self.i64_stats.pool_misses;
746
747 let total_requests = total_hits + total_misses;
748 if total_requests == 0 {
749 0.0
750 } else {
751 total_hits as f64 / total_requests as f64
752 }
753 }
754
755 pub fn total_allocations(&self) -> usize {
757 self.f32_stats.total_allocations
758 + self.f64_stats.total_allocations
759 + self.i32_stats.total_allocations
760 + self.i64_stats.total_allocations
761 }
762}
763
#[cfg(test)]
mod tests {
    use super::*;
    use std::ptr;

    #[test]
    fn test_memory_config_default() {
        let config = MemoryConfig::default();
        assert!(config.enable_pooling);
        assert!(config.pool_size > 0);
        assert!(config.cache_line_size > 0);
    }

    #[test]
    fn test_advanced_memory_pool_creation() {
        let pool: AdvancedMemoryPool<f32> = AdvancedMemoryPool::new();
        let stats = pool.get_stats();

        assert_eq!(stats.total_allocations, 0);
        assert_eq!(stats.pool_hits, 0);
        assert_eq!(stats.pool_misses, 0);
    }

    #[test]
    fn test_memory_allocation_and_deallocation() {
        let pool: AdvancedMemoryPool<f32> = AdvancedMemoryPool::new();

        let ptr = pool.allocate(1024).expect("allocation should succeed");
        pool.deallocate(ptr, 1024)
            .expect("deallocation should succeed");

        let stats = pool.get_stats();
        assert_eq!(stats.total_allocations, 1);
    }

    #[test]
    fn test_memory_pool_reuse() {
        let pool: AdvancedMemoryPool<f32> = AdvancedMemoryPool::new();

        let ptr1 = pool.allocate(1024).expect("allocation should succeed");
        pool.deallocate(ptr1, 1024)
            .expect("deallocation should succeed");

        let ptr2 = pool.allocate(1024).expect("allocation should succeed");
        pool.deallocate(ptr2, 1024)
            .expect("deallocation should succeed");

        let stats = pool.get_stats();
        assert_eq!(stats.total_allocations, 2);
    }

    #[test]
    fn test_memory_stats_calculations() {
        // Struct literal instead of default-then-mutate
        // (clippy::field_reassign_with_default).
        let stats = MemoryStats {
            pool_hits: 80,
            pool_misses: 20,
            total_allocations: 100,
            reused_allocations: 80,
            ..Default::default()
        };

        assert_eq!(stats.hit_rate(), 0.8);
        assert_eq!(stats.reuse_rate(), 0.8);
    }

    #[test]
    fn test_size_alignment() {
        let pool: AdvancedMemoryPool<f32> = AdvancedMemoryPool::with_config(MemoryConfig {
            cache_line_size: 64,
            ..Default::default()
        });

        assert_eq!(pool.align_size(1), 64);
        assert_eq!(pool.align_size(65), 128);
        assert_eq!(pool.align_size(128), 128);
    }

    #[test]
    fn test_defragmentation() {
        let pool: AdvancedMemoryPool<f32> = AdvancedMemoryPool::new();

        // Cycle a range of sizes through the pool to populate size classes.
        for i in 0..10 {
            let ptr = pool
                .allocate(1024 * (i + 1))
                .expect("allocation should succeed");
            pool.deallocate(ptr, 1024 * (i + 1))
                .expect("deallocation should succeed");
        }

        let report = pool.defragment().expect("defragmentation should succeed");
        let _ = report.duration;
    }

    #[test]
    fn test_global_memory_optimizer() {
        let optimizer = GlobalMemoryOptimizer::new();
        let stats = optimizer.get_aggregate_stats();

        assert_eq!(stats.total_allocations(), 0);
        assert_eq!(stats.overall_hit_rate(), 0.0);
    }

    #[test]
    fn test_compression_manager() {
        let manager = CompressionManager::new();
        // A dangling (never-dereferenced) pointer is fine for a lookup check.
        let ptr = NonNull::new(ptr::null_mut::<f32>().wrapping_add(0x1000))
            .expect("pointer should be non-null");

        assert!(!manager.is_compressed(ptr));
    }

    #[test]
    fn test_allocation_predictor() {
        let mut predictor = AllocationPredictor::new();

        // No history yet, so no prediction is possible.
        let predictions = predictor.predict_next_allocations(1024);
        assert!(predictions.is_none());
    }

    #[test]
    fn test_memory_pressure_detection() {
        let pool: AdvancedMemoryPool<f32> = AdvancedMemoryPool::with_config(MemoryConfig {
            memory_pressure_threshold: 0.5,
            ..Default::default()
        });

        // Fresh pool has no allocations, so pressure must read low.
        assert!(!pool.is_memory_pressure_high());
    }
}