1use crate::{Device, Result, TensorError};
7use std::collections::{HashMap, VecDeque};
8use std::sync::{Arc, Mutex, RwLock};
9use std::time::{Duration, Instant};
10
/// Snapshot of pool usage and fragmentation, maintained by `MemoryPool`
/// and returned by `MemoryPool::stats`.
#[derive(Debug, Clone)]
pub struct MemoryPoolStats {
    /// Bytes currently handed out to live allocations.
    pub total_allocated: usize,
    /// Bytes currently available across all free blocks.
    pub total_free: usize,
    /// Number of allocated blocks.
    pub blocks_allocated: usize,
    /// Number of free blocks.
    pub blocks_free: usize,
    /// Free-block count per KiB of free space; higher means more fragmented.
    pub fragmentation_ratio: f32,
    /// High-water mark of `total_allocated`.
    pub peak_allocated: usize,
    /// Running count of allocation events.
    pub allocation_count: u64,
    /// Running count of deallocation events.
    pub deallocation_count: u64,
    /// Number of defragmentation passes performed.
    pub defragmentation_count: u64,
    /// Size in bytes of the largest single free block.
    pub largest_free_block: usize,
    /// Mean block size (free and allocated) in bytes.
    pub average_block_size: f32,
    /// `total_allocated / pool_size`, in `[0, 1]`.
    pub memory_pressure: f32,
}
27
/// Diagnostic record for a single allocation, stored in
/// `MemoryPool::allocation_history` keyed by block index.
#[derive(Debug, Clone)]
pub struct AllocationTracker {
    /// When the allocation was made.
    pub timestamp: Instant,
    /// Allocation size in bytes (after alignment).
    pub size: usize,
    /// Index of the backing block in the pool's `blocks` vector.
    pub block_idx: usize,
    /// Allocation lifetime in microseconds once released.
    /// NOTE(review): never populated by the visible code — confirm intent.
    pub lifetime_us: Option<u64>,
    /// When the allocation was released.
    /// NOTE(review): never populated by the visible code — confirm intent.
    pub deallocated_at: Option<Instant>,
}
37
/// Coarse buckets describing how close the pool is to exhaustion;
/// see `MemoryPool::memory_pressure_level` for the thresholds.
#[derive(Debug, Clone, PartialEq)]
pub enum MemoryPressureLevel {
    /// Usage below 50% of the pool.
    Low,
    /// Usage between 50% and 80%.
    Medium,
    /// Usage between 80% and 95%.
    High,
    /// Usage at or above 95%.
    Critical,
}
46
/// A contiguous region of the pool buffer, either free or handed out.
///
/// Blocks are addressed by their index in `MemoryPool::blocks`; allocated
/// blocks carry a manual reference count so a region can be shared.
#[derive(Debug, Clone)]
pub(crate) struct MemoryBlock {
    /// Byte offset of the region within the pool buffer.
    #[allow(dead_code)]
    pub offset: usize,
    /// Length of the region in bytes.
    pub size: usize,
    /// Whether the region is currently available for allocation.
    pub is_free: bool,
    /// Number of live handles to this region (0 while free).
    #[allow(dead_code)]
    pub ref_count: usize,
}

impl MemoryBlock {
    /// Builds a free region with no outstanding references.
    #[allow(dead_code)]
    pub fn new_free(offset: usize, size: usize) -> Self {
        Self { offset, size, is_free: true, ref_count: 0 }
    }

    /// Builds an allocated region with exactly one outstanding reference.
    #[allow(dead_code)]
    pub fn new_allocated(offset: usize, size: usize) -> Self {
        Self { offset, size, is_free: false, ref_count: 1 }
    }

    /// Registers one additional handle to an allocated region.
    ///
    /// # Panics
    /// Panics if the block is free.
    #[allow(dead_code)]
    pub fn add_ref(&mut self) {
        assert!(!self.is_free, "Cannot add reference to free block");
        self.ref_count += 1;
    }

    /// Drops one handle; returns `true` when no handles remain.
    ///
    /// # Panics
    /// Panics if the block is free or the count is already zero.
    #[allow(dead_code)]
    pub fn release_ref(&mut self) -> bool {
        assert!(!self.is_free, "Cannot release reference from free block");
        assert!(self.ref_count > 0, "Reference count underflow");
        self.ref_count -= 1;
        self.ref_count == 0
    }

    /// `true` for an allocated region whose last handle has been released.
    #[allow(dead_code)]
    pub fn can_free(&self) -> bool {
        self.ref_count == 0 && !self.is_free
    }
}
104
/// A slab allocator over a single GPU buffer: carves `pool_buffer` into
/// `MemoryBlock`s tracked by index, with best-fit allocation, free-block
/// coalescing, and fragmentation statistics.
#[derive(Debug)]
pub struct MemoryPool {
    /// Logical device this pool belongs to.
    #[allow(dead_code)]
    device: Device,
    /// Underlying wgpu device handle.
    #[cfg(feature = "gpu")]
    gpu_device: Arc<wgpu::Device>,
    /// Queue used for work against the pool's buffer.
    #[cfg(feature = "gpu")]
    gpu_queue: Arc<wgpu::Queue>,

    /// Total capacity of the pool in bytes.
    #[allow(dead_code)]
    pool_size: usize,
    /// Single backing buffer; allocations are byte offsets into it.
    #[cfg(feature = "gpu")]
    pool_buffer: wgpu::Buffer,

    /// All blocks (free and allocated), addressed by index.
    #[allow(dead_code)]
    blocks: Arc<RwLock<Vec<MemoryBlock>>>,
    /// Indices into `blocks` that are currently free.
    #[allow(dead_code)]
    free_blocks: Arc<Mutex<VecDeque<usize>>>,
    /// Aggregated usage/fragmentation statistics.
    stats: Arc<Mutex<MemoryPoolStats>>,
    /// Per-allocation diagnostic records, keyed by block index.
    #[allow(dead_code)]
    allocation_history: Arc<Mutex<HashMap<usize, AllocationTracker>>>,

    /// Fragmentation ratio above which an automatic defrag may run.
    #[allow(dead_code)]
    auto_defrag_threshold: f32,
    /// When the last defragmentation pass ran.
    #[allow(dead_code)]
    defrag_last_run: Arc<Mutex<Instant>>,
    /// Minimum time between automatic defragmentation passes.
    #[allow(dead_code)]
    defrag_min_interval: Duration,
}
140
141impl MemoryPool {
142 #[cfg(feature = "gpu")]
144 pub fn new(device_id: usize, pool_size: usize) -> Result<Self> {
145 let gpu_ctx = crate::device::context::get_gpu_context(device_id)?;
146
147 let pool_buffer = gpu_ctx.device.create_buffer(&wgpu::BufferDescriptor {
149 label: Some("memory_pool_buffer"),
150 size: pool_size as u64,
151 usage: wgpu::BufferUsages::STORAGE
152 | wgpu::BufferUsages::COPY_SRC
153 | wgpu::BufferUsages::COPY_DST,
154 mapped_at_creation: false,
155 });
156
157 let blocks = vec![MemoryBlock::new_free(0, pool_size)];
159
160 let mut free_blocks = VecDeque::new();
161 free_blocks.push_back(0);
162
163 let stats = MemoryPoolStats {
164 total_allocated: 0,
165 total_free: pool_size,
166 blocks_allocated: 0,
167 blocks_free: 1,
168 fragmentation_ratio: 0.0,
169 peak_allocated: 0,
170 allocation_count: 0,
171 deallocation_count: 0,
172 defragmentation_count: 0,
173 largest_free_block: pool_size,
174 average_block_size: pool_size as f32,
175 memory_pressure: 0.0,
176 };
177
178 Ok(Self {
179 device: Device::Gpu(device_id),
180 gpu_device: gpu_ctx.device.clone(),
181 gpu_queue: gpu_ctx.queue.clone(),
182 pool_size,
183 pool_buffer,
184 blocks: Arc::new(RwLock::new(blocks)),
185 free_blocks: Arc::new(Mutex::new(free_blocks)),
186 stats: Arc::new(Mutex::new(stats)),
187 allocation_history: Arc::new(Mutex::new(HashMap::new())),
188 auto_defrag_threshold: 0.25, defrag_last_run: Arc::new(Mutex::new(Instant::now())),
190 defrag_min_interval: Duration::from_secs(30), })
192 }
193
194 #[cfg(feature = "gpu")]
196 pub fn allocate(&self, size: usize, alignment: usize) -> Result<PooledBuffer<'_>> {
197 let aligned_size = align_size(size, alignment);
198
199 let mut free_blocks = self
200 .free_blocks
201 .lock()
202 .expect("lock should not be poisoned");
203 let mut blocks = self
204 .blocks
205 .write()
206 .expect("write lock should not be poisoned");
207
208 let mut best_block_idx = None;
210 let mut best_size = usize::MAX;
211
212 for &block_idx in free_blocks.iter() {
213 let block = &blocks[block_idx];
214 if block.is_free && block.size >= aligned_size && block.size < best_size {
215 best_block_idx = Some(block_idx);
216 best_size = block.size;
217 }
218 }
219
220 if let Some(block_idx) = best_block_idx {
221 let (offset, block_size) = {
223 let block = &blocks[block_idx];
224 (block.offset, block.size)
225 };
226
227 if block_size > aligned_size {
229 let new_block =
231 MemoryBlock::new_free(offset + aligned_size, block_size - aligned_size);
232 blocks.push(new_block);
233 free_blocks.push_back(blocks.len() - 1);
234 }
235
236 blocks[block_idx] = MemoryBlock::new_allocated(offset, aligned_size);
238
239 free_blocks.retain(|&idx| idx != block_idx);
241
242 let mut history = self
244 .allocation_history
245 .lock()
246 .expect("lock should not be poisoned");
247 history.insert(
248 block_idx,
249 AllocationTracker {
250 timestamp: Instant::now(),
251 size: aligned_size,
252 block_idx,
253 lifetime_us: None,
254 deallocated_at: None,
255 },
256 );
257
258 self.update_enhanced_stats(&blocks);
260
261 #[cfg(feature = "gpu")]
263 self.maybe_auto_defrag();
264
265 Ok(PooledBuffer {
266 pool: self,
267 block_idx,
268 offset,
269 size: aligned_size,
270 })
271 } else {
272 Err(TensorError::allocation_error_simple(format!(
273 "Cannot allocate {} bytes from memory pool",
274 aligned_size
275 )))
276 }
277 }
278
279 #[cfg(feature = "gpu")]
281 pub(crate) fn deallocate(&self, block_idx: usize) -> Result<()> {
282 let mut blocks = self
283 .blocks
284 .write()
285 .expect("write lock should not be poisoned");
286 let mut free_blocks = self
287 .free_blocks
288 .lock()
289 .expect("lock should not be poisoned");
290
291 let block = &mut blocks[block_idx];
292 if block.is_free {
293 return Err(TensorError::invalid_argument(
294 "Attempting to deallocate already free block".to_string(),
295 ));
296 }
297
298 if block.release_ref() {
300 block.is_free = true;
302 free_blocks.push_back(block_idx);
303 }
304
305 let mut history = self
307 .allocation_history
308 .lock()
309 .expect("lock should not be poisoned");
310 if let Some(_tracker) = history.remove(&block_idx) {
311 }
313
314 self.coalesce_blocks(&mut blocks, &mut free_blocks);
316
317 self.update_enhanced_stats(&blocks);
319
320 Ok(())
321 }
322
323 #[cfg(feature = "gpu")]
326 pub fn share_buffer(&self, block_idx: usize) -> Result<bool> {
327 let mut blocks = self
328 .blocks
329 .write()
330 .expect("write lock should not be poisoned");
331
332 if block_idx >= blocks.len() {
333 return Err(TensorError::invalid_argument(format!(
334 "Invalid block index: {}",
335 block_idx
336 )));
337 }
338
339 let block = &mut blocks[block_idx];
340 if block.is_free {
341 return Err(TensorError::invalid_argument(
342 "Cannot share a free block".to_string(),
343 ));
344 }
345
346 block.add_ref();
347 Ok(true)
348 }
349
350 #[cfg(feature = "gpu")]
353 pub fn release_buffer(&self, block_idx: usize) -> Result<bool> {
354 let mut blocks = self
355 .blocks
356 .write()
357 .expect("write lock should not be poisoned");
358 let mut free_blocks = self
359 .free_blocks
360 .lock()
361 .expect("lock should not be poisoned");
362
363 if block_idx >= blocks.len() {
364 return Err(TensorError::invalid_argument(format!(
365 "Invalid block index: {}",
366 block_idx
367 )));
368 }
369
370 let block = &mut blocks[block_idx];
371 if block.is_free {
372 return Err(TensorError::invalid_argument(
373 "Cannot release reference to already free block".to_string(),
374 ));
375 }
376
377 if block.release_ref() {
378 block.is_free = true;
380 free_blocks.push_back(block_idx);
381
382 let mut history = self
384 .allocation_history
385 .lock()
386 .expect("lock should not be poisoned");
387 if let Some(_tracker) = history.remove(&block_idx) {
388 }
390
391 self.update_enhanced_stats(&blocks);
393
394 Ok(true) } else {
396 Ok(false) }
398 }
399
400 #[cfg(feature = "gpu")]
402 pub fn get_buffer_ref_count(&self, block_idx: usize) -> Result<usize> {
403 let blocks = self
404 .blocks
405 .read()
406 .expect("read lock should not be poisoned");
407
408 if block_idx >= blocks.len() {
409 return Err(TensorError::invalid_argument(format!(
410 "Invalid block index: {}",
411 block_idx
412 )));
413 }
414
415 let block = &blocks[block_idx];
416 if block.is_free {
417 Ok(0)
418 } else {
419 Ok(block.ref_count)
420 }
421 }
422
423 #[cfg(feature = "gpu")]
425 fn coalesce_blocks(&self, blocks: &mut [MemoryBlock], free_blocks: &mut VecDeque<usize>) {
426 let mut free_indices: Vec<_> = free_blocks.iter().copied().collect();
428 free_indices.sort_by_key(|&idx| blocks[idx].offset);
429
430 let mut coalesced = Vec::new();
431 let mut i = 0;
432
433 while i < free_indices.len() {
434 let mut current_idx = free_indices[i];
435 let mut current_block = blocks[current_idx].clone();
436
437 while i + 1 < free_indices.len() {
439 let next_idx = free_indices[i + 1];
440 let next_block = &blocks[next_idx];
441
442 if current_block.offset + current_block.size == next_block.offset {
444 current_block.size += next_block.size;
446 i += 1; } else {
448 break;
449 }
450 }
451
452 blocks[current_idx] = current_block;
454 coalesced.push(current_idx);
455 i += 1;
456 }
457
458 free_blocks.clear();
460 for idx in coalesced {
461 free_blocks.push_back(idx);
462 }
463 }
464
465 #[allow(dead_code)]
467 fn update_enhanced_stats(&self, blocks: &[MemoryBlock]) {
468 let mut stats = self.stats.lock().expect("lock should not be poisoned");
469 stats.blocks_allocated = 0;
470 stats.blocks_free = 0;
471 stats.total_allocated = 0;
472 stats.total_free = 0;
473 stats.largest_free_block = 0;
474
475 let mut block_sizes = Vec::new();
476
477 for block in blocks {
478 block_sizes.push(block.size);
479 if block.is_free {
480 stats.blocks_free += 1;
481 stats.total_free += block.size;
482 stats.largest_free_block = stats.largest_free_block.max(block.size);
483 } else {
484 stats.blocks_allocated += 1;
485 stats.total_allocated += block.size;
486 }
487 }
488
489 stats.peak_allocated = stats.peak_allocated.max(stats.total_allocated);
491
492 stats.allocation_count += 1;
494
495 if stats.total_free > 0 {
497 stats.fragmentation_ratio =
498 stats.blocks_free as f32 / (stats.total_free as f32 / 1024.0);
499 } else {
500 stats.fragmentation_ratio = 0.0;
501 }
502
503 if !block_sizes.is_empty() {
505 stats.average_block_size =
506 block_sizes.iter().sum::<usize>() as f32 / block_sizes.len() as f32;
507 }
508
509 let usage_ratio = stats.total_allocated as f32 / self.pool_size as f32;
511 stats.memory_pressure = usage_ratio;
512 }
513
514 #[cfg(feature = "gpu")]
516 #[allow(dead_code)]
517 fn maybe_auto_defrag(&self) {
518 let stats = self.stats.lock().expect("lock should not be poisoned");
519 if stats.fragmentation_ratio > self.auto_defrag_threshold {
520 let mut last_run = self
521 .defrag_last_run
522 .lock()
523 .expect("lock should not be poisoned");
524 if last_run.elapsed() >= self.defrag_min_interval {
525 drop(stats); self.defragment();
527 *last_run = Instant::now();
528 }
529 }
530 }
531
532 #[cfg(feature = "gpu")]
534 #[allow(dead_code)]
535 pub fn defragment(&self) {
536 let mut blocks = self
537 .blocks
538 .write()
539 .expect("write lock should not be poisoned");
540 let mut free_blocks = self
541 .free_blocks
542 .lock()
543 .expect("lock should not be poisoned");
544
545 blocks.sort_by_key(|block| block.offset);
547
548 free_blocks.clear();
550 for (idx, block) in blocks.iter().enumerate() {
551 if block.is_free {
552 free_blocks.push_back(idx);
553 }
554 }
555
556 self.coalesce_blocks(&mut blocks, &mut free_blocks);
558
559 self.update_enhanced_stats(&blocks);
561 let mut stats = self.stats.lock().expect("lock should not be poisoned");
562 stats.defragmentation_count += 1;
563 }
564
565 #[allow(dead_code)]
567 pub fn memory_pressure_level(&self) -> MemoryPressureLevel {
568 let stats = self.stats.lock().expect("lock should not be poisoned");
569 match stats.memory_pressure {
570 p if p < 0.5 => MemoryPressureLevel::Low,
571 p if p < 0.8 => MemoryPressureLevel::Medium,
572 p if p < 0.95 => MemoryPressureLevel::High,
573 _ => MemoryPressureLevel::Critical,
574 }
575 }
576
577 #[cfg(feature = "gpu")]
579 #[allow(dead_code)]
580 pub fn aggressive_cleanup(&self, min_block_size: usize) -> Result<usize> {
581 let mut blocks = self
582 .blocks
583 .write()
584 .expect("write lock should not be poisoned");
585 let mut free_blocks = self
586 .free_blocks
587 .lock()
588 .expect("lock should not be poisoned");
589
590 let mut removed_count = 0;
591
592 let mut i = 0;
594 while i < blocks.len() {
595 if blocks[i].is_free && blocks[i].size < min_block_size {
596 blocks.remove(i);
597 removed_count += 1;
598 } else {
599 i += 1;
600 }
601 }
602
603 free_blocks.clear();
605 for (idx, block) in blocks.iter().enumerate() {
606 if block.is_free {
607 free_blocks.push_back(idx);
608 }
609 }
610
611 self.coalesce_blocks(&mut blocks, &mut free_blocks);
613
614 self.update_enhanced_stats(&blocks);
616
617 Ok(removed_count)
618 }
619
620 pub fn stats(&self) -> MemoryPoolStats {
622 self.stats
623 .lock()
624 .expect("lock should not be poisoned")
625 .clone()
626 }
627
628 #[cfg(feature = "gpu")]
630 pub fn buffer(&self) -> &wgpu::Buffer {
631 &self.pool_buffer
632 }
633
634 #[cfg(feature = "gpu")]
636 pub fn device(&self) -> &wgpu::Device {
637 &self.gpu_device
638 }
639
640 #[cfg(feature = "gpu")]
642 pub fn queue(&self) -> &wgpu::Queue {
643 &self.gpu_queue
644 }
645}
646
/// Handle to one allocation inside a `MemoryPool`; on drop (gpu builds)
/// the backing block is returned to the pool.
#[derive(Debug)]
pub struct PooledBuffer<'a> {
    /// Owning pool, used for deallocation and buffer access.
    #[allow(dead_code)]
    pool: &'a MemoryPool,
    /// Index of the backing block in the pool's block list.
    #[allow(dead_code)]
    block_idx: usize,
    /// Byte offset of this allocation within the pool buffer.
    offset: usize,
    /// Size of this allocation in bytes.
    size: usize,
}
657
658impl<'a> PooledBuffer<'a> {
659 pub fn offset(&self) -> usize {
661 self.offset
662 }
663
664 pub fn size(&self) -> usize {
666 self.size
667 }
668
669 #[cfg(feature = "gpu")]
671 pub fn buffer(&self) -> &wgpu::Buffer {
672 self.pool.buffer()
673 }
674
675 pub fn view(&'a self, offset: usize, size: usize) -> Result<BufferView<'a>> {
677 if offset + size > self.size {
678 return Err(TensorError::invalid_argument(format!(
679 "View out of bounds: offset={}, size={}, buffer_size={}",
680 offset, size, self.size
681 )));
682 }
683
684 Ok(BufferView {
685 buffer: self,
686 view_offset: offset,
687 view_size: size,
688 })
689 }
690}
691
#[cfg(feature = "gpu")]
impl<'a> Drop for PooledBuffer<'a> {
    fn drop(&mut self) {
        // Best-effort return of the block to the pool; an error (e.g. a
        // double free) cannot be propagated out of `drop`, so it is ignored.
        let _ = self.pool.deallocate(self.block_idx);
    }
}
699
/// A byte sub-range of a `PooledBuffer`, expressed relative to that
/// allocation's own offset.
pub struct BufferView<'a> {
    /// Allocation this view was carved from.
    buffer: &'a PooledBuffer<'a>,
    /// Offset of the view relative to the start of `buffer`.
    view_offset: usize,
    /// Length of the view in bytes.
    view_size: usize,
}
706
impl<'a> BufferView<'a> {
    /// Offset of this view from the start of the pool buffer
    /// (allocation offset + view offset).
    pub fn absolute_offset(&self) -> usize {
        self.buffer.offset() + self.view_offset
    }

    /// Length of the view in bytes.
    pub fn size(&self) -> usize {
        self.view_size
    }

    /// The shared pool-wide `wgpu::Buffer` backing this view.
    #[cfg(feature = "gpu")]
    pub fn buffer(&self) -> &wgpu::Buffer {
        self.buffer.buffer()
    }
}
724
/// Rounds `size` up to the next multiple of `alignment`.
///
/// `alignment` must be a non-zero power of two: the mask trick below is
/// only correct under that precondition (for other values it silently
/// returns a wrong result, or underflows for `alignment == 0`).
///
/// # Panics
/// In debug builds, panics when `alignment` is zero or not a power of two.
#[allow(dead_code)]
pub fn align_size(size: usize, alignment: usize) -> usize {
    debug_assert!(
        alignment != 0 && alignment.is_power_of_two(),
        "alignment must be a non-zero power of two, got {}",
        alignment
    );
    (size + alignment - 1) & !(alignment - 1)
}
730
#[cfg(test)]
mod tests {
    use super::*;

    // `align_size` rounds up to the next multiple of a power-of-two alignment.
    #[test]
    fn test_align_size() {
        assert_eq!(align_size(13, 8), 16);
        assert_eq!(align_size(16, 8), 16);
        assert_eq!(align_size(17, 8), 24);
    }

    // Exercises the free/allocated constructors and manual ref-counting.
    #[test]
    fn test_memory_block() {
        let block = MemoryBlock::new_free(0, 1024);
        assert!(block.is_free);
        assert_eq!(block.size, 1024);
        assert_eq!(block.ref_count, 0);

        // Allocated blocks start with exactly one reference.
        let mut allocated_block = MemoryBlock::new_allocated(1024, 512);
        assert!(!allocated_block.is_free);
        assert_eq!(allocated_block.ref_count, 1);

        allocated_block.add_ref();
        assert_eq!(allocated_block.ref_count, 2);

        // release_ref returns true only when the count reaches zero.
        assert!(!allocated_block.release_ref());
        assert_eq!(allocated_block.ref_count, 1);

        assert!(allocated_block.release_ref());
        assert_eq!(allocated_block.ref_count, 0);
    }

    // Smoke test for the pressure-level enum's derived PartialEq.
    #[test]
    fn test_memory_pressure_level() {
        let pressure = MemoryPressureLevel::Low;
        assert_eq!(pressure, MemoryPressureLevel::Low);

        let high_pressure = MemoryPressureLevel::High;
        assert_eq!(high_pressure, MemoryPressureLevel::High);
    }
}