1use crate::*;
7use std::collections::{HashMap, VecDeque};
8use std::sync::{Arc, RwLock};
9use std::time::{Duration, Instant};
10use torsh_core::{DType, Result as TorshResult, Shape};
11
/// Process-wide sparse memory pool, lazily initialized on first use.
static MEMORY_POOL: std::sync::OnceLock<Arc<RwLock<SparseMemoryPool>>> = std::sync::OnceLock::new();

/// Returns the global memory pool, creating it with the default
/// configuration on the first call.
fn get_memory_pool() -> &'static Arc<RwLock<SparseMemoryPool>> {
    MEMORY_POOL.get_or_init(|| Arc::new(RwLock::new(SparseMemoryPool::new())))
}
19
/// Configuration knobs for the sparse memory pool.
#[derive(Debug, Clone)]
pub struct MemoryPoolConfig {
    /// Upper bound on total bytes the pool may have outstanding at once.
    pub max_total_memory: usize,
    /// Largest single allocation the pool will accept, in bytes.
    pub max_allocation_size: usize,
    /// Maximum number of size-class buckets to create.
    pub num_size_buckets: usize,
    /// Idle time after which a pooled block becomes eligible for GC.
    pub memory_timeout: Duration,
    /// Enables allocation tracking. NOTE(review): not consulted anywhere in
    /// the code visible here — TODO confirm intended use.
    pub enable_tracking: bool,
    /// Minimum time between automatic garbage collections.
    pub gc_interval: Duration,
}
36
37impl Default for MemoryPoolConfig {
38 fn default() -> Self {
39 Self {
40 max_total_memory: 1024 * 1024 * 1024, max_allocation_size: 128 * 1024 * 1024, num_size_buckets: 16,
43 memory_timeout: Duration::from_secs(300), enable_tracking: true,
45 gc_interval: Duration::from_secs(60), }
47 }
48}
49
/// A size-class bucket holding reusable memory blocks for one byte range.
#[derive(Debug)]
struct MemoryBucket {
    /// Inclusive (min, max) request size this bucket serves, in bytes.
    size_range: (usize, usize),
    /// Freed blocks waiting to be reused.
    available_blocks: VecDeque<MemoryBlock>,
    /// Bytes currently backing this bucket's blocks (active plus pooled);
    /// reduced only when expired pooled blocks are dropped.
    total_allocated: usize,
    /// Blocks currently handed out (i.e. not in `available_blocks`).
    active_allocations: usize,
}
62
/// A chunk of zero-initialized heap memory managed by the pool.
#[derive(Debug)]
struct MemoryBlock {
    /// Backing storage.
    memory: Vec<u8>,
    /// When the block was first created.
    allocated_at: Instant,
    /// Last time the block was handed out or mutably accessed via a handle.
    last_accessed: Instant,
    /// Reference count; set to 1 whenever the block is handed out and not
    /// otherwise updated in the code visible here.
    ref_count: usize,
    /// Unique id assigned by the pool at allocation time.
    #[allow(dead_code)]
    id: u64,
}
78
/// Bucketed memory pool for sparse-tensor scratch allocations.
#[derive(Debug)]
pub struct SparseMemoryPool {
    /// Pool configuration.
    config: MemoryPoolConfig,
    /// Size-class buckets (power-of-two ranges starting at 1 KiB).
    buckets: Vec<MemoryBucket>,
    /// Bytes currently held by outstanding handles.
    total_allocated: usize,
    /// Running usage statistics.
    stats: MemoryStatistics,
    /// Time of the most recent garbage collection.
    last_gc: Instant,
    /// Monotonic id for the next block to allocate.
    next_block_id: u64,
    /// Block ids grouped by allocation-type tag. NOTE(review): ids are
    /// pushed on allocation but never removed, so this map only grows.
    allocation_cache: HashMap<String, Vec<u64>>,
}
97
/// Counters describing pool activity over its lifetime.
#[derive(Debug, Clone)]
pub struct MemoryStatistics {
    /// Total bytes ever handed out by the pool.
    pub total_allocated_bytes: usize,
    /// Total bytes returned via deallocation.
    pub total_deallocated_bytes: usize,
    /// Handles currently outstanding.
    pub active_allocations: usize,
    /// Bytes currently held by outstanding handles.
    pub current_memory_usage: usize,
    /// Highest value `current_memory_usage` has reached.
    pub peak_memory_usage: usize,
    /// Number of successful allocation requests.
    pub allocation_requests: usize,
    /// Pool-reuse events recorded by the pool (see `SparseMemoryPool` for
    /// when this is incremented).
    pub pool_reuses: usize,
    /// Number of garbage-collection passes performed.
    pub garbage_collections: usize,
    /// Mean size of an allocation, in bytes.
    pub average_allocation_size: f64,
}
120
121impl Default for MemoryStatistics {
122 fn default() -> Self {
123 Self {
124 total_allocated_bytes: 0,
125 total_deallocated_bytes: 0,
126 active_allocations: 0,
127 current_memory_usage: 0,
128 peak_memory_usage: 0,
129 allocation_requests: 0,
130 pool_reuses: 0,
131 garbage_collections: 0,
132 average_allocation_size: 0.0,
133 }
134 }
135}
136
137impl MemoryBucket {
138 fn new(size_range: (usize, usize)) -> Self {
139 Self {
140 size_range,
141 available_blocks: VecDeque::new(),
142 total_allocated: 0,
143 active_allocations: 0,
144 }
145 }
146
147 fn can_handle(&self, size: usize) -> bool {
148 size >= self.size_range.0 && size <= self.size_range.1
149 }
150
151 fn allocate(&mut self, size: usize, block_id: u64) -> Option<MemoryBlock> {
152 if !self.can_handle(size) {
153 return None;
154 }
155
156 if let Some(mut block) = self.available_blocks.pop_front() {
158 block.last_accessed = Instant::now();
159 block.ref_count = 1;
160 return Some(block);
161 }
162
163 let actual_size = self.size_range.1; let memory = vec![0u8; actual_size];
166 let now = Instant::now();
167
168 self.total_allocated += actual_size;
169 self.active_allocations += 1;
170
171 Some(MemoryBlock {
172 memory,
173 allocated_at: now,
174 last_accessed: now,
175 ref_count: 1,
176 id: block_id,
177 })
178 }
179
180 fn deallocate(&mut self, block: MemoryBlock) {
181 self.active_allocations = self.active_allocations.saturating_sub(1);
182 self.available_blocks.push_back(block);
183 }
184
185 fn cleanup_expired(&mut self, timeout: Duration) -> usize {
186 let now = Instant::now();
187 let initial_count = self.available_blocks.len();
188
189 self.available_blocks.retain(|block| {
190 let should_keep = now.duration_since(block.last_accessed) < timeout;
191 if !should_keep {
192 self.total_allocated = self.total_allocated.saturating_sub(block.memory.len());
193 }
194 should_keep
195 });
196
197 initial_count - self.available_blocks.len()
198 }
199}
200
201impl Default for SparseMemoryPool {
202 fn default() -> Self {
203 Self::new()
204 }
205}
206
207impl SparseMemoryPool {
208 pub fn new() -> Self {
210 Self::with_config(MemoryPoolConfig::default())
211 }
212
213 pub fn with_config(config: MemoryPoolConfig) -> Self {
215 let buckets = Self::create_buckets(&config);
216
217 Self {
218 config,
219 buckets,
220 total_allocated: 0,
221 stats: MemoryStatistics::default(),
222 last_gc: Instant::now(),
223 next_block_id: 1,
224 allocation_cache: HashMap::new(),
225 }
226 }
227
228 fn create_buckets(config: &MemoryPoolConfig) -> Vec<MemoryBucket> {
230 let mut buckets = Vec::new();
231 let max_size = config.max_allocation_size;
232 let num_buckets = config.num_size_buckets;
233
234 let mut size = 1024; for _ in 0..num_buckets {
237 let next_size = (size * 2).min(max_size);
238 buckets.push(MemoryBucket::new((size, next_size)));
239 size = next_size;
240 if size >= max_size {
241 break;
242 }
243 }
244
245 buckets
246 }
247
248 pub fn allocate(
250 &mut self,
251 size: usize,
252 allocation_type: &str,
253 ) -> TorshResult<SparseMemoryHandle> {
254 if size > self.config.max_allocation_size {
255 return Err(torsh_core::TorshError::Other(format!(
256 "Allocation size {} exceeds maximum {}",
257 size, self.config.max_allocation_size
258 )));
259 }
260
261 if self.total_allocated + size > self.config.max_total_memory {
262 self.garbage_collect();
264
265 if self.total_allocated + size > self.config.max_total_memory {
266 return Err(torsh_core::TorshError::Other(
267 "Memory pool capacity exceeded".to_string(),
268 ));
269 }
270 }
271
272 let bucket_idx = self.find_bucket_for_size(size);
274 let block_id = self.next_block_id;
275 self.next_block_id += 1;
276
277 let block = if let Some(idx) = bucket_idx {
278 self.buckets[idx].allocate(size, block_id)
279 } else {
280 let memory = vec![0u8; size];
282 let now = Instant::now();
283 Some(MemoryBlock {
284 memory,
285 allocated_at: now,
286 last_accessed: now,
287 ref_count: 1,
288 id: block_id,
289 })
290 };
291
292 if let Some(block) = block {
293 self.total_allocated += block.memory.len();
294 self.stats.total_allocated_bytes += block.memory.len();
295 self.stats.allocation_requests += 1;
296 self.stats.active_allocations += 1;
297 self.stats.current_memory_usage = self.total_allocated;
298
299 if self.total_allocated > self.stats.peak_memory_usage {
300 self.stats.peak_memory_usage = self.total_allocated;
301 }
302
303 self.stats.average_allocation_size =
304 self.stats.total_allocated_bytes as f64 / self.stats.allocation_requests as f64;
305
306 self.allocation_cache
308 .entry(allocation_type.to_string())
309 .or_default()
310 .push(block_id);
311
312 Ok(SparseMemoryHandle::new(block, bucket_idx))
313 } else {
314 Err(torsh_core::TorshError::Other(
315 "Failed to allocate memory".to_string(),
316 ))
317 }
318 }
319
320 fn find_bucket_for_size(&self, size: usize) -> Option<usize> {
322 self.buckets
323 .iter()
324 .position(|bucket| bucket.can_handle(size))
325 }
326
327 pub fn deallocate(&mut self, handle: SparseMemoryHandle) {
329 let (block, bucket_idx) = handle.into_parts();
330
331 self.total_allocated = self.total_allocated.saturating_sub(block.memory.len());
332 self.stats.total_deallocated_bytes += block.memory.len();
333 self.stats.active_allocations = self.stats.active_allocations.saturating_sub(1);
334 self.stats.current_memory_usage = self.total_allocated;
335
336 if let Some(idx) = bucket_idx {
337 if idx < self.buckets.len() {
338 self.buckets[idx].deallocate(block);
339 self.stats.pool_reuses += 1;
340 }
341 }
342 }
344
345 pub fn garbage_collect(&mut self) {
347 let now = Instant::now();
348
349 if now.duration_since(self.last_gc) < self.config.gc_interval {
350 return; }
352
353 let mut _total_freed = 0;
354 for bucket in &mut self.buckets {
355 _total_freed += bucket.cleanup_expired(self.config.memory_timeout);
356 }
357
358 self.allocation_cache.retain(|_, ids| {
360 ids.retain(|_| true); !ids.is_empty()
362 });
363
364 self.stats.garbage_collections += 1;
365 self.last_gc = now;
366 }
367
368 pub fn force_garbage_collect(&mut self) {
370 self.last_gc = Instant::now() - self.config.gc_interval;
371 self.garbage_collect();
372 }
373
374 pub fn statistics(&self) -> MemoryStatistics {
376 self.stats.clone()
377 }
378
379 pub fn usage_by_type(&self) -> HashMap<String, usize> {
381 let mut usage = HashMap::new();
384 usage.insert("sparse_matrices".to_string(), self.total_allocated);
385 usage
386 }
387
388 pub fn is_healthy(&self) -> bool {
390 let usage_ratio = self.total_allocated as f64 / self.config.max_total_memory as f64;
391 usage_ratio < 0.8 }
393
394 pub fn efficiency_score(&self) -> f64 {
396 if self.stats.allocation_requests == 0 {
397 return 1.0;
398 }
399
400 let reuse_ratio = self.stats.pool_reuses as f64 / self.stats.allocation_requests as f64;
401 let fragmentation_ratio =
402 1.0 - (self.stats.current_memory_usage as f64 / self.stats.peak_memory_usage as f64);
403
404 (reuse_ratio + (1.0 - fragmentation_ratio)) / 2.0
405 }
406}
407
/// RAII handle to a pooled memory block; returns the block to the global
/// pool when dropped.
pub struct SparseMemoryHandle {
    /// The owned block; `None` once the handle has been consumed.
    block: Option<MemoryBlock>,
    /// Index of the bucket that served this block, or `None` for one-off
    /// exact-size allocations.
    bucket_idx: Option<usize>,
}
413
414impl SparseMemoryHandle {
415 fn new(block: MemoryBlock, bucket_idx: Option<usize>) -> Self {
416 Self {
417 block: Some(block),
418 bucket_idx,
419 }
420 }
421
422 pub fn as_mut_slice(&mut self) -> &mut [u8] {
424 if let Some(ref mut block) = self.block {
425 block.last_accessed = Instant::now();
426 &mut block.memory
427 } else {
428 &mut []
429 }
430 }
431
432 pub fn as_slice(&self) -> &[u8] {
434 if let Some(ref block) = self.block {
435 &block.memory
436 } else {
437 &[]
438 }
439 }
440
441 pub fn size(&self) -> usize {
443 self.block.as_ref().map_or(0, |b| b.memory.len())
444 }
445
446 pub fn age(&self) -> Duration {
448 self.block.as_ref().map_or(Duration::ZERO, |b| {
449 Instant::now().duration_since(b.allocated_at)
450 })
451 }
452
453 pub fn is_valid(&self) -> bool {
455 self.block.is_some()
456 }
457
458 fn into_parts(mut self) -> (MemoryBlock, Option<usize>) {
460 let block = self.block.take().expect("Handle should have a block");
461 (block, self.bucket_idx)
462 }
463}
464
impl Drop for SparseMemoryHandle {
    fn drop(&mut self) {
        // Take the block so this handle becomes inert, then re-wrap it in a
        // temporary handle for `deallocate` (whose `into_parts` empties that
        // temporary, so its own Drop is a no-op — no recursion).
        if let Some(block) = self.block.take() {
            // NOTE(review): blocks are always returned to the GLOBAL pool,
            // even when the handle came from a locally constructed
            // SparseMemoryPool — that skews the global pool's statistics.
            // A poisoned lock silently leaks the block. TODO confirm intent.
            if let Ok(mut pool) = get_memory_pool().write() {
                pool.deallocate(SparseMemoryHandle {
                    block: Some(block),
                    bucket_idx: self.bucket_idx,
                });
            }
        }
    }
}
477
/// Stateless facade over the process-wide pool behind `MEMORY_POOL`.
pub struct SparseMemoryManager;
480
481impl SparseMemoryManager {
482 pub fn allocate(size: usize, allocation_type: &str) -> TorshResult<SparseMemoryHandle> {
484 get_memory_pool()
485 .write()
486 .expect("lock should not be poisoned")
487 .allocate(size, allocation_type)
488 }
489
490 pub fn global_statistics() -> MemoryStatistics {
492 get_memory_pool()
493 .read()
494 .expect("lock should not be poisoned")
495 .statistics()
496 }
497
498 pub fn force_garbage_collect() {
500 get_memory_pool()
501 .write()
502 .expect("lock should not be poisoned")
503 .force_garbage_collect();
504 }
505
506 pub fn is_healthy() -> bool {
508 get_memory_pool()
509 .read()
510 .expect("lock should not be poisoned")
511 .is_healthy()
512 }
513
514 pub fn efficiency_score() -> f64 {
516 get_memory_pool()
517 .read()
518 .expect("lock should not be poisoned")
519 .efficiency_score()
520 }
521
522 pub fn configure(config: MemoryPoolConfig) {
524 let mut pool = get_memory_pool()
525 .write()
526 .expect("lock should not be poisoned");
527 *pool = SparseMemoryPool::with_config(config);
528 }
529
530 pub fn generate_report() -> MemoryReport {
532 let pool = get_memory_pool()
533 .read()
534 .expect("lock should not be poisoned");
535 let stats = pool.statistics();
536 let usage_by_type = pool.usage_by_type();
537 let is_healthy = pool.is_healthy();
538 let efficiency = pool.efficiency_score();
539
540 MemoryReport {
541 statistics: stats,
542 usage_by_type,
543 is_healthy,
544 efficiency_score: efficiency,
545 recommendations: Self::generate_recommendations(&pool),
546 }
547 }
548
549 fn generate_recommendations(pool: &SparseMemoryPool) -> Vec<String> {
551 let mut recommendations = Vec::new();
552 let stats = &pool.stats;
553
554 let usage_ratio = pool.total_allocated as f64 / pool.config.max_total_memory as f64;
556 if usage_ratio > 0.9 {
557 recommendations.push("Memory usage is very high. Consider increasing pool size or optimizing allocations.".to_string());
558 } else if usage_ratio > 0.8 {
559 recommendations.push(
560 "Memory usage is high. Monitor closely and consider optimization.".to_string(),
561 );
562 }
563
564 let reuse_ratio = if stats.allocation_requests > 0 {
566 stats.pool_reuses as f64 / stats.allocation_requests as f64
567 } else {
568 0.0
569 };
570
571 if reuse_ratio < 0.3 {
572 recommendations.push("Low memory reuse detected. Consider adjusting bucket sizes or allocation patterns.".to_string());
573 }
574
575 if stats.garbage_collections == 0 && stats.allocation_requests > 100 {
577 recommendations.push(
578 "No garbage collections performed. Consider enabling automatic GC.".to_string(),
579 );
580 }
581
582 let fragmentation = if stats.peak_memory_usage > 0 {
584 1.0 - (stats.current_memory_usage as f64 / stats.peak_memory_usage as f64)
585 } else {
586 0.0
587 };
588
589 if fragmentation > 0.5 {
590 recommendations.push(
591 "High memory fragmentation detected. Consider more frequent garbage collection."
592 .to_string(),
593 );
594 }
595
596 if recommendations.is_empty() {
597 recommendations.push("Memory management appears optimal.".to_string());
598 }
599
600 recommendations
601 }
602}
603
/// Aggregated snapshot of pool state plus tuning recommendations.
#[derive(Debug, Clone)]
pub struct MemoryReport {
    /// Pool statistics at report time.
    pub statistics: MemoryStatistics,
    /// Bytes in use, keyed by allocation-type tag.
    pub usage_by_type: HashMap<String, usize>,
    /// Whether usage is below the pool's health threshold.
    pub is_healthy: bool,
    /// Combined reuse/utilization score in [0, 1].
    pub efficiency_score: f64,
    /// Human-readable tuning advice.
    pub recommendations: Vec<String>,
}
613
/// Builder that pre-reserves pool memory before assembling a sparse tensor.
pub struct MemoryAwareSparseBuilder {
    /// Target sparse storage format.
    format: SparseFormat,
    /// Expected number of non-zero entries, used to size pre-allocation.
    estimated_nnz: usize,
    /// Handles from `pre_allocate`, held for the builder's lifetime.
    memory_handles: Vec<SparseMemoryHandle>,
    /// Hints gathered during building. NOTE(review): nothing in the visible
    /// code populates these.
    optimization_hints: Vec<String>,
}
621
622impl MemoryAwareSparseBuilder {
623 pub fn new(format: SparseFormat, estimated_nnz: usize) -> Self {
625 Self {
626 format,
627 estimated_nnz,
628 memory_handles: Vec::new(),
629 optimization_hints: Vec::new(),
630 }
631 }
632
633 pub fn pre_allocate(&mut self) -> TorshResult<()> {
635 let memory_needed = self.estimate_memory_requirements();
636
637 let chunk_size = 1024 * 1024; let num_chunks = memory_needed.div_ceil(chunk_size);
640
641 for i in 0..num_chunks {
642 let size = if i == num_chunks - 1 {
643 memory_needed - i * chunk_size
644 } else {
645 chunk_size
646 };
647
648 let handle = SparseMemoryManager::allocate(size, "sparse_builder")?;
649 self.memory_handles.push(handle);
650 }
651
652 Ok(())
653 }
654
655 fn estimate_memory_requirements(&self) -> usize {
657 match self.format {
658 SparseFormat::Coo => {
659 self.estimated_nnz * (2 * std::mem::size_of::<usize>() + std::mem::size_of::<f32>())
660 }
661 SparseFormat::Csr | SparseFormat::Csc => {
662 self.estimated_nnz * (std::mem::size_of::<usize>() + std::mem::size_of::<f32>())
663 + 1000 * std::mem::size_of::<usize>() }
665 _ => self.estimated_nnz * 3 * std::mem::size_of::<f32>(), }
667 }
668
669 pub fn build(
671 self,
672 data: &[(usize, usize, f32)],
673 shape: Shape,
674 ) -> TorshResult<Box<dyn SparseTensor>> {
675 match self.format {
678 SparseFormat::Coo => {
679 let mut coo = CooTensor::empty(shape, DType::F32)?;
680 for &(row, col, val) in data {
681 coo.insert(row, col, val)?;
682 }
683 Ok(Box::new(coo))
684 }
685 SparseFormat::Csr => {
686 let coo = {
688 let mut coo = CooTensor::empty(shape, DType::F32)?;
689 for &(row, col, val) in data {
690 coo.insert(row, col, val)?;
691 }
692 coo
693 };
694 Ok(Box::new(coo.to_csr()?))
695 }
696 _ => {
697 let mut coo = CooTensor::empty(shape, DType::F32)?;
699 for &(row, col, val) in data {
700 coo.insert(row, col, val)?;
701 }
702 Ok(Box::new(coo))
703 }
704 }
705 }
706
707 pub fn optimization_hints(&self) -> &[String] {
709 &self.optimization_hints
710 }
711}
712
/// Convenience helper: pre-allocates pool memory sized from `data.len()`
/// and builds a sparse tensor of the requested `format`.
///
/// # Errors
/// Propagates allocation failures from the memory pool and insertion or
/// conversion errors from tensor construction.
pub fn create_sparse_with_memory_management(
    data: &[(usize, usize, f32)],
    shape: Shape,
    format: SparseFormat,
) -> TorshResult<Box<dyn SparseTensor>> {
    let mut builder = MemoryAwareSparseBuilder::new(format, data.len());
    builder.pre_allocate()?;
    builder.build(data, shape)
}
723
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh pool starts healthy with nothing allocated.
    #[test]
    fn test_memory_pool_creation() {
        let pool = SparseMemoryPool::new();
        assert!(pool.is_healthy());
        assert_eq!(pool.total_allocated, 0);
    }

    /// Allocation yields a valid handle at least as large as requested
    /// (bucketed blocks round up to the bucket maximum).
    #[test]
    fn test_memory_allocation() {
        let mut pool = SparseMemoryPool::new();
        let handle = pool.allocate(1024, "test").unwrap();

        assert!(handle.is_valid());
        assert!(handle.size() >= 1024);
        assert!(pool.total_allocated > 0);
    }

    /// Statistics track request and active-allocation counts.
    #[test]
    fn test_memory_statistics() {
        let mut pool = SparseMemoryPool::new();
        let _handle1 = pool.allocate(1024, "test").unwrap();
        let _handle2 = pool.allocate(2048, "test").unwrap();

        let stats = pool.statistics();
        assert_eq!(stats.allocation_requests, 2);
        assert_eq!(stats.active_allocations, 2);
        assert!(stats.current_memory_usage > 0);
    }

    /// A forced GC after the block timeout is recorded in the stats.
    #[test]
    fn test_garbage_collection() {
        let mut pool = SparseMemoryPool::with_config(MemoryPoolConfig {
            memory_timeout: Duration::from_millis(1),
            ..Default::default()
        });

        // Drop the handle so the block returns to the pool...
        {
            let _handle = pool.allocate(1024, "test").unwrap();
        }
        // ...then let it expire before collecting.
        std::thread::sleep(Duration::from_millis(10));
        pool.force_garbage_collect();

        let stats = pool.statistics();
        assert!(stats.garbage_collections > 0);
    }

    /// The global manager facade allocates and reports on the shared pool.
    #[test]
    fn test_global_memory_manager() {
        let handle = SparseMemoryManager::allocate(1024, "test").unwrap();
        assert!(handle.is_valid());

        let stats = SparseMemoryManager::global_statistics();
        assert!(stats.allocation_requests > 0);

        assert!(SparseMemoryManager::is_healthy());
    }

    /// Reports always include at least one recommendation.
    #[test]
    fn test_memory_report() {
        let _handle = SparseMemoryManager::allocate(1024, "test").unwrap();
        let report = SparseMemoryManager::generate_report();

        assert!(report.statistics.allocation_requests > 0);
        assert!(!report.recommendations.is_empty());
    }

    /// The builder pre-allocates and then assembles a COO tensor.
    #[test]
    fn test_memory_aware_builder() {
        let mut builder = MemoryAwareSparseBuilder::new(SparseFormat::Coo, 10);
        builder.pre_allocate().unwrap();

        let data = vec![(0, 0, 1.0), (1, 1, 2.0), (2, 2, 3.0)];
        let shape = Shape::new(vec![3, 3]);

        let sparse = builder.build(&data, shape).unwrap();
        assert_eq!(sparse.nnz(), 3);
    }

    /// Bytes written through the mutable view are visible via the read view.
    #[test]
    fn test_memory_handle_operations() {
        let mut handle = SparseMemoryManager::allocate(1024, "test").unwrap();

        {
            let slice = handle.as_mut_slice();
            slice[0] = 42;
            slice[1] = 24;
        }

        let slice = handle.as_slice();
        assert_eq!(slice[0], 42);
        assert_eq!(slice[1], 24);

        assert!(handle.age() >= Duration::ZERO);
        assert!(handle.size() >= 1024);
    }
}