// tenflowers_core/memory/pools.rs

1//! Memory pool management for efficient GPU memory allocation
2//!
3//! This module provides memory pool allocators with reference counting,
4//! block management, and automatic defragmentation capabilities.
5
6use crate::{Device, Result, TensorError};
7use std::collections::{HashMap, VecDeque};
8use std::sync::{Arc, Mutex, RwLock};
9use std::time::{Duration, Instant};
10
/// Memory pool statistics for monitoring
///
/// Point-in-time snapshot of the pool's bookkeeping, retrievable via
/// `MemoryPool::stats()`.
#[derive(Debug, Clone)]
pub struct MemoryPoolStats {
    /// Total bytes currently handed out to live allocations.
    pub total_allocated: usize,
    /// Total bytes held in free blocks.
    pub total_free: usize,
    /// Number of blocks currently allocated.
    pub blocks_allocated: usize,
    /// Number of blocks on the free list.
    pub blocks_free: usize,
    /// Fragmentation metric: free blocks per KiB of free space
    /// (more, smaller fragments => higher ratio).
    pub fragmentation_ratio: f32,
    /// High-water mark of `total_allocated`.
    pub peak_allocated: usize,
    /// Cumulative count of allocations served.
    pub allocation_count: u64,
    /// Cumulative count of deallocations.
    pub deallocation_count: u64,
    /// Number of defragmentation passes performed.
    pub defragmentation_count: u64,
    /// Size of the largest single free block — the biggest contiguous
    /// allocation the pool can currently satisfy.
    pub largest_free_block: usize,
    /// Mean size over all tracked blocks (free and allocated).
    pub average_block_size: f32,
    /// Fraction of the pool in use (total_allocated / pool_size);
    /// feeds `MemoryPressureLevel`.
    pub memory_pressure: f32,
}
27
/// Allocation tracking for analytics
///
/// One record per live allocation, keyed by block index in
/// `MemoryPool::allocation_history`.
#[derive(Debug, Clone)]
pub struct AllocationTracker {
    /// When the allocation was made.
    pub timestamp: Instant,
    /// Allocated size in bytes (after alignment rounding).
    pub size: usize,
    /// Index of the backing block in the pool's block table.
    pub block_idx: usize,
    /// Allocation lifetime in microseconds.
    /// NOTE(review): never written by the visible code — confirm intended use.
    pub lifetime_us: Option<u64>,
    /// When the allocation was released.
    /// NOTE(review): never written by the visible code — confirm intended use.
    pub deallocated_at: Option<Instant>,
}
37
/// Memory pressure levels
///
/// Derived from `MemoryPoolStats::memory_pressure` by
/// `MemoryPool::memory_pressure_level()`.
#[derive(Debug, Clone, PartialEq)]
pub enum MemoryPressureLevel {
    Low,      // < 50% usage
    Medium,   // 50-80% usage
    High,     // 80-95% usage
    Critical, // > 95% usage
}
46
/// Metadata for one region of the pooled buffer.
///
/// A block is either on the free list (`is_free == true`, `ref_count == 0`)
/// or allocated with one or more owners tracked in `ref_count`.
#[derive(Debug, Clone)]
pub(crate) struct MemoryBlock {
    #[allow(dead_code)] // Used in GPU-feature-gated methods
    pub offset: usize,
    pub size: usize,
    pub is_free: bool,
    #[allow(dead_code)] // Used in GPU-feature-gated methods
    pub ref_count: usize, // Reference count for shared buffer management
}

impl MemoryBlock {
    /// Build a block that sits on the free list (no owners).
    #[allow(dead_code)] // Used in GPU-feature-gated methods
    pub fn new_free(offset: usize, size: usize) -> Self {
        MemoryBlock {
            offset,
            size,
            is_free: false || true, // free blocks carry no references
            ref_count: 0,
        }
    }

    /// Build a block that has just been handed to its first owner.
    #[allow(dead_code)] // Used in GPU-feature-gated methods
    pub fn new_allocated(offset: usize, size: usize) -> Self {
        MemoryBlock {
            offset,
            size,
            is_free: false,
            ref_count: 1, // the caller holds the initial reference
        }
    }

    /// Register one more owner of this (allocated) block.
    ///
    /// Panics when called on a free block — that would be a pool bug.
    #[allow(dead_code)] // Used in GPU-feature-gated methods
    pub fn add_ref(&mut self) {
        assert!(!self.is_free, "Cannot add reference to free block");
        self.ref_count += 1;
    }

    /// Drop one owner; returns `true` exactly when the count hits zero
    /// and the block may be returned to the free list.
    #[allow(dead_code)] // Used in GPU-feature-gated methods
    pub fn release_ref(&mut self) -> bool {
        assert!(!self.is_free, "Cannot release reference from free block");
        assert!(self.ref_count > 0, "Reference count underflow");

        self.ref_count -= 1;
        matches!(self.ref_count, 0)
    }

    /// An allocated block with zero remaining owners is eligible for freeing.
    #[allow(dead_code)] // Used in GPU-feature-gated methods
    pub fn can_free(&self) -> bool {
        !self.is_free && self.ref_count == 0
    }
}
104
/// Memory pool allocator for efficient GPU memory management
///
/// Sub-allocates a single large `wgpu` buffer into blocks tracked in
/// `blocks`; free block indices are kept in `free_blocks`.
#[derive(Debug)]
pub struct MemoryPool {
    // Logical device this pool belongs to.
    #[allow(dead_code)]
    device: Device,
    #[cfg(feature = "gpu")]
    gpu_device: Arc<wgpu::Device>,
    #[cfg(feature = "gpu")]
    gpu_queue: Arc<wgpu::Queue>,

    // Memory pool data
    // Total capacity of the backing buffer in bytes.
    #[allow(dead_code)]
    pool_size: usize,
    // Single large buffer from which all allocations are carved.
    #[cfg(feature = "gpu")]
    pool_buffer: wgpu::Buffer,

    // Block management
    // Table of all blocks (allocated and free); indices into this Vec are
    // handed out as allocation handles (`PooledBuffer::block_idx`).
    #[allow(dead_code)]
    blocks: Arc<RwLock<Vec<MemoryBlock>>>,
    #[allow(dead_code)]
    free_blocks: Arc<Mutex<VecDeque<usize>>>, // Indices of free blocks

    // Statistics and analytics
    stats: Arc<Mutex<MemoryPoolStats>>,
    // Live allocations keyed by block index, for lifetime analytics.
    #[allow(dead_code)]
    allocation_history: Arc<Mutex<HashMap<usize, AllocationTracker>>>,

    // Defragmentation settings
    #[allow(dead_code)]
    auto_defrag_threshold: f32, // Trigger defragmentation when fragmentation > threshold
    #[allow(dead_code)]
    defrag_last_run: Arc<Mutex<Instant>>,
    #[allow(dead_code)]
    defrag_min_interval: Duration, // Minimum time between defragmentation runs
}
140
141impl MemoryPool {
142    /// Create a new memory pool with specified size in bytes
143    #[cfg(feature = "gpu")]
144    pub fn new(device_id: usize, pool_size: usize) -> Result<Self> {
145        let gpu_ctx = crate::device::context::get_gpu_context(device_id)?;
146
147        // Create large buffer for memory pool
148        let pool_buffer = gpu_ctx.device.create_buffer(&wgpu::BufferDescriptor {
149            label: Some("memory_pool_buffer"),
150            size: pool_size as u64,
151            usage: wgpu::BufferUsages::STORAGE
152                | wgpu::BufferUsages::COPY_SRC
153                | wgpu::BufferUsages::COPY_DST,
154            mapped_at_creation: false,
155        });
156
157        // Initialize with single large free block
158        let blocks = vec![MemoryBlock::new_free(0, pool_size)];
159
160        let mut free_blocks = VecDeque::new();
161        free_blocks.push_back(0);
162
163        let stats = MemoryPoolStats {
164            total_allocated: 0,
165            total_free: pool_size,
166            blocks_allocated: 0,
167            blocks_free: 1,
168            fragmentation_ratio: 0.0,
169            peak_allocated: 0,
170            allocation_count: 0,
171            deallocation_count: 0,
172            defragmentation_count: 0,
173            largest_free_block: pool_size,
174            average_block_size: pool_size as f32,
175            memory_pressure: 0.0,
176        };
177
178        Ok(Self {
179            device: Device::Gpu(device_id),
180            gpu_device: gpu_ctx.device.clone(),
181            gpu_queue: gpu_ctx.queue.clone(),
182            pool_size,
183            pool_buffer,
184            blocks: Arc::new(RwLock::new(blocks)),
185            free_blocks: Arc::new(Mutex::new(free_blocks)),
186            stats: Arc::new(Mutex::new(stats)),
187            allocation_history: Arc::new(Mutex::new(HashMap::new())),
188            auto_defrag_threshold: 0.25, // Auto-defrag when 25% fragmented
189            defrag_last_run: Arc::new(Mutex::new(Instant::now())),
190            defrag_min_interval: Duration::from_secs(30), // Min 30 seconds between defrags
191        })
192    }
193
194    /// Allocate memory from the pool
195    #[cfg(feature = "gpu")]
196    pub fn allocate(&self, size: usize, alignment: usize) -> Result<PooledBuffer<'_>> {
197        let aligned_size = align_size(size, alignment);
198
199        let mut free_blocks = self
200            .free_blocks
201            .lock()
202            .expect("lock should not be poisoned");
203        let mut blocks = self
204            .blocks
205            .write()
206            .expect("write lock should not be poisoned");
207
208        // Find suitable free block using best-fit strategy
209        let mut best_block_idx = None;
210        let mut best_size = usize::MAX;
211
212        for &block_idx in free_blocks.iter() {
213            let block = &blocks[block_idx];
214            if block.is_free && block.size >= aligned_size && block.size < best_size {
215                best_block_idx = Some(block_idx);
216                best_size = block.size;
217            }
218        }
219
220        if let Some(block_idx) = best_block_idx {
221            // Get information from the block before splitting
222            let (offset, block_size) = {
223                let block = &blocks[block_idx];
224                (block.offset, block.size)
225            };
226
227            // Split block if necessary
228            if block_size > aligned_size {
229                // Create new free block for remainder
230                let new_block =
231                    MemoryBlock::new_free(offset + aligned_size, block_size - aligned_size);
232                blocks.push(new_block);
233                free_blocks.push_back(blocks.len() - 1);
234            }
235
236            // Mark block as allocated using new constructor logic
237            blocks[block_idx] = MemoryBlock::new_allocated(offset, aligned_size);
238
239            // Remove from free blocks
240            free_blocks.retain(|&idx| idx != block_idx);
241
242            // Track allocation for analytics
243            let mut history = self
244                .allocation_history
245                .lock()
246                .expect("lock should not be poisoned");
247            history.insert(
248                block_idx,
249                AllocationTracker {
250                    timestamp: Instant::now(),
251                    size: aligned_size,
252                    block_idx,
253                    lifetime_us: None,
254                    deallocated_at: None,
255                },
256            );
257
258            // Update statistics
259            self.update_enhanced_stats(&blocks);
260
261            // Check if auto-defragmentation should be triggered
262            #[cfg(feature = "gpu")]
263            self.maybe_auto_defrag();
264
265            Ok(PooledBuffer {
266                pool: self,
267                block_idx,
268                offset,
269                size: aligned_size,
270            })
271        } else {
272            Err(TensorError::allocation_error_simple(format!(
273                "Cannot allocate {} bytes from memory pool",
274                aligned_size
275            )))
276        }
277    }
278
279    /// Deallocate memory back to the pool (reference counting aware)
280    #[cfg(feature = "gpu")]
281    pub(crate) fn deallocate(&self, block_idx: usize) -> Result<()> {
282        let mut blocks = self
283            .blocks
284            .write()
285            .expect("write lock should not be poisoned");
286        let mut free_blocks = self
287            .free_blocks
288            .lock()
289            .expect("lock should not be poisoned");
290
291        let block = &mut blocks[block_idx];
292        if block.is_free {
293            return Err(TensorError::invalid_argument(
294                "Attempting to deallocate already free block".to_string(),
295            ));
296        }
297
298        // Decrement reference count and only free if no references remain
299        if block.release_ref() {
300            // No more references, free the block
301            block.is_free = true;
302            free_blocks.push_back(block_idx);
303        }
304
305        // Update allocation tracking with lifetime
306        let mut history = self
307            .allocation_history
308            .lock()
309            .expect("lock should not be poisoned");
310        if let Some(_tracker) = history.remove(&block_idx) {
311            // Tracker removed from history - could store in a completed allocations log for further analysis
312        }
313
314        // Coalesce adjacent free blocks to reduce fragmentation
315        self.coalesce_blocks(&mut blocks, &mut free_blocks);
316
317        // Update statistics
318        self.update_enhanced_stats(&blocks);
319
320        Ok(())
321    }
322
323    /// Share a buffer by incrementing its reference count
324    /// Returns true if the buffer was successfully shared
325    #[cfg(feature = "gpu")]
326    pub fn share_buffer(&self, block_idx: usize) -> Result<bool> {
327        let mut blocks = self
328            .blocks
329            .write()
330            .expect("write lock should not be poisoned");
331
332        if block_idx >= blocks.len() {
333            return Err(TensorError::invalid_argument(format!(
334                "Invalid block index: {}",
335                block_idx
336            )));
337        }
338
339        let block = &mut blocks[block_idx];
340        if block.is_free {
341            return Err(TensorError::invalid_argument(
342                "Cannot share a free block".to_string(),
343            ));
344        }
345
346        block.add_ref();
347        Ok(true)
348    }
349
350    /// Release a reference to a shared buffer
351    /// Returns true if the buffer was actually freed (reference count reached 0)
352    #[cfg(feature = "gpu")]
353    pub fn release_buffer(&self, block_idx: usize) -> Result<bool> {
354        let mut blocks = self
355            .blocks
356            .write()
357            .expect("write lock should not be poisoned");
358        let mut free_blocks = self
359            .free_blocks
360            .lock()
361            .expect("lock should not be poisoned");
362
363        if block_idx >= blocks.len() {
364            return Err(TensorError::invalid_argument(format!(
365                "Invalid block index: {}",
366                block_idx
367            )));
368        }
369
370        let block = &mut blocks[block_idx];
371        if block.is_free {
372            return Err(TensorError::invalid_argument(
373                "Cannot release reference to already free block".to_string(),
374            ));
375        }
376
377        if block.release_ref() {
378            // No more references, free the block
379            block.is_free = true;
380            free_blocks.push_back(block_idx);
381
382            // Update allocation tracking
383            let mut history = self
384                .allocation_history
385                .lock()
386                .expect("lock should not be poisoned");
387            if let Some(_tracker) = history.remove(&block_idx) {
388                // Tracker removed from history - could store in a separate history if needed for analysis
389            }
390
391            // Update statistics
392            self.update_enhanced_stats(&blocks);
393
394            Ok(true) // Buffer was freed
395        } else {
396            Ok(false) // Buffer still has references
397        }
398    }
399
400    /// Get the current reference count for a buffer
401    #[cfg(feature = "gpu")]
402    pub fn get_buffer_ref_count(&self, block_idx: usize) -> Result<usize> {
403        let blocks = self
404            .blocks
405            .read()
406            .expect("read lock should not be poisoned");
407
408        if block_idx >= blocks.len() {
409            return Err(TensorError::invalid_argument(format!(
410                "Invalid block index: {}",
411                block_idx
412            )));
413        }
414
415        let block = &blocks[block_idx];
416        if block.is_free {
417            Ok(0)
418        } else {
419            Ok(block.ref_count)
420        }
421    }
422
423    /// Coalesce adjacent free blocks to reduce fragmentation
424    #[cfg(feature = "gpu")]
425    fn coalesce_blocks(&self, blocks: &mut [MemoryBlock], free_blocks: &mut VecDeque<usize>) {
426        // Sort free blocks by offset
427        let mut free_indices: Vec<_> = free_blocks.iter().copied().collect();
428        free_indices.sort_by_key(|&idx| blocks[idx].offset);
429
430        let mut coalesced = Vec::new();
431        let mut i = 0;
432
433        while i < free_indices.len() {
434            let mut current_idx = free_indices[i];
435            let mut current_block = blocks[current_idx].clone();
436
437            // Look for adjacent blocks to coalesce
438            while i + 1 < free_indices.len() {
439                let next_idx = free_indices[i + 1];
440                let next_block = &blocks[next_idx];
441
442                // Check if blocks are adjacent
443                if current_block.offset + current_block.size == next_block.offset {
444                    // Coalesce blocks
445                    current_block.size += next_block.size;
446                    i += 1; // Skip next block as it's now coalesced
447                } else {
448                    break;
449                }
450            }
451
452            // Update the block
453            blocks[current_idx] = current_block;
454            coalesced.push(current_idx);
455            i += 1;
456        }
457
458        // Update free blocks queue
459        free_blocks.clear();
460        for idx in coalesced {
461            free_blocks.push_back(idx);
462        }
463    }
464
465    /// Enhanced statistics update with advanced analytics
466    #[allow(dead_code)]
467    fn update_enhanced_stats(&self, blocks: &[MemoryBlock]) {
468        let mut stats = self.stats.lock().expect("lock should not be poisoned");
469        stats.blocks_allocated = 0;
470        stats.blocks_free = 0;
471        stats.total_allocated = 0;
472        stats.total_free = 0;
473        stats.largest_free_block = 0;
474
475        let mut block_sizes = Vec::new();
476
477        for block in blocks {
478            block_sizes.push(block.size);
479            if block.is_free {
480                stats.blocks_free += 1;
481                stats.total_free += block.size;
482                stats.largest_free_block = stats.largest_free_block.max(block.size);
483            } else {
484                stats.blocks_allocated += 1;
485                stats.total_allocated += block.size;
486            }
487        }
488
489        // Update peak allocated
490        stats.peak_allocated = stats.peak_allocated.max(stats.total_allocated);
491
492        // Update counters
493        stats.allocation_count += 1;
494
495        // Calculate fragmentation ratio
496        if stats.total_free > 0 {
497            stats.fragmentation_ratio =
498                stats.blocks_free as f32 / (stats.total_free as f32 / 1024.0);
499        } else {
500            stats.fragmentation_ratio = 0.0;
501        }
502
503        // Calculate average block size
504        if !block_sizes.is_empty() {
505            stats.average_block_size =
506                block_sizes.iter().sum::<usize>() as f32 / block_sizes.len() as f32;
507        }
508
509        // Calculate memory pressure
510        let usage_ratio = stats.total_allocated as f32 / self.pool_size as f32;
511        stats.memory_pressure = usage_ratio;
512    }
513
514    /// Check if auto-defragmentation should be triggered
515    #[cfg(feature = "gpu")]
516    #[allow(dead_code)]
517    fn maybe_auto_defrag(&self) {
518        let stats = self.stats.lock().expect("lock should not be poisoned");
519        if stats.fragmentation_ratio > self.auto_defrag_threshold {
520            let mut last_run = self
521                .defrag_last_run
522                .lock()
523                .expect("lock should not be poisoned");
524            if last_run.elapsed() >= self.defrag_min_interval {
525                drop(stats); // Release lock before defragmentation
526                self.defragment();
527                *last_run = Instant::now();
528            }
529        }
530    }
531
532    /// Perform active defragmentation of memory pool
533    #[cfg(feature = "gpu")]
534    #[allow(dead_code)]
535    pub fn defragment(&self) {
536        let mut blocks = self
537            .blocks
538            .write()
539            .expect("write lock should not be poisoned");
540        let mut free_blocks = self
541            .free_blocks
542            .lock()
543            .expect("lock should not be poisoned");
544
545        // Sort blocks by offset to enable merging
546        blocks.sort_by_key(|block| block.offset);
547
548        // Rebuild free blocks list based on sorted blocks
549        free_blocks.clear();
550        for (idx, block) in blocks.iter().enumerate() {
551            if block.is_free {
552                free_blocks.push_back(idx);
553            }
554        }
555
556        // Coalesce adjacent free blocks
557        self.coalesce_blocks(&mut blocks, &mut free_blocks);
558
559        // Update statistics
560        self.update_enhanced_stats(&blocks);
561        let mut stats = self.stats.lock().expect("lock should not be poisoned");
562        stats.defragmentation_count += 1;
563    }
564
565    /// Get current memory pressure level
566    #[allow(dead_code)]
567    pub fn memory_pressure_level(&self) -> MemoryPressureLevel {
568        let stats = self.stats.lock().expect("lock should not be poisoned");
569        match stats.memory_pressure {
570            p if p < 0.5 => MemoryPressureLevel::Low,
571            p if p < 0.8 => MemoryPressureLevel::Medium,
572            p if p < 0.95 => MemoryPressureLevel::High,
573            _ => MemoryPressureLevel::Critical,
574        }
575    }
576
577    /// Force cleanup of small free blocks (aggressive defragmentation)
578    #[cfg(feature = "gpu")]
579    #[allow(dead_code)]
580    pub fn aggressive_cleanup(&self, min_block_size: usize) -> Result<usize> {
581        let mut blocks = self
582            .blocks
583            .write()
584            .expect("write lock should not be poisoned");
585        let mut free_blocks = self
586            .free_blocks
587            .lock()
588            .expect("lock should not be poisoned");
589
590        let mut removed_count = 0;
591
592        // Remove small free blocks and merge their space
593        let mut i = 0;
594        while i < blocks.len() {
595            if blocks[i].is_free && blocks[i].size < min_block_size {
596                blocks.remove(i);
597                removed_count += 1;
598            } else {
599                i += 1;
600            }
601        }
602
603        // Rebuild free blocks list
604        free_blocks.clear();
605        for (idx, block) in blocks.iter().enumerate() {
606            if block.is_free {
607                free_blocks.push_back(idx);
608            }
609        }
610
611        // Coalesce remaining blocks
612        self.coalesce_blocks(&mut blocks, &mut free_blocks);
613
614        // Update statistics
615        self.update_enhanced_stats(&blocks);
616
617        Ok(removed_count)
618    }
619
620    /// Get memory pool statistics
621    pub fn stats(&self) -> MemoryPoolStats {
622        self.stats
623            .lock()
624            .expect("lock should not be poisoned")
625            .clone()
626    }
627
628    /// Get pool buffer reference
629    #[cfg(feature = "gpu")]
630    pub fn buffer(&self) -> &wgpu::Buffer {
631        &self.pool_buffer
632    }
633
634    /// Get GPU device reference
635    #[cfg(feature = "gpu")]
636    pub fn device(&self) -> &wgpu::Device {
637        &self.gpu_device
638    }
639
640    /// Get GPU queue reference
641    #[cfg(feature = "gpu")]
642    pub fn queue(&self) -> &wgpu::Queue {
643        &self.gpu_queue
644    }
645}
646
/// A buffer allocated from the memory pool
///
/// RAII handle: dropping it returns the block to the pool (see the
/// GPU-gated `Drop` impl below).
#[derive(Debug)]
pub struct PooledBuffer<'a> {
    // Owning pool; used for deallocation and buffer access.
    #[allow(dead_code)]
    pool: &'a MemoryPool,
    // Index of the backing block in the pool's block table.
    #[allow(dead_code)]
    block_idx: usize,
    // Byte offset of this allocation inside the pool's backing buffer.
    offset: usize,
    // Allocated size in bytes (after alignment rounding).
    size: usize,
}
657
658impl<'a> PooledBuffer<'a> {
659    /// Get the offset within the pool buffer
660    pub fn offset(&self) -> usize {
661        self.offset
662    }
663
664    /// Get the size of the allocated buffer
665    pub fn size(&self) -> usize {
666        self.size
667    }
668
669    /// Get reference to the underlying pool buffer
670    #[cfg(feature = "gpu")]
671    pub fn buffer(&self) -> &wgpu::Buffer {
672        self.pool.buffer()
673    }
674
675    /// Create a view of this buffer with offset and size
676    pub fn view(&'a self, offset: usize, size: usize) -> Result<BufferView<'a>> {
677        if offset + size > self.size {
678            return Err(TensorError::invalid_argument(format!(
679                "View out of bounds: offset={}, size={}, buffer_size={}",
680                offset, size, self.size
681            )));
682        }
683
684        Ok(BufferView {
685            buffer: self,
686            view_offset: offset,
687            view_size: size,
688        })
689    }
690}
691
#[cfg(feature = "gpu")]
impl<'a> Drop for PooledBuffer<'a> {
    // RAII: release our reference to the block when the handle goes away.
    fn drop(&mut self) {
        // Deallocate when buffer is dropped. The error is deliberately
        // ignored: Drop cannot propagate it, and a failed deallocation
        // leaks pool space rather than corrupting state.
        let _ = self.pool.deallocate(self.block_idx);
    }
}
699
/// A view into a pooled buffer for zero-copy operations
pub struct BufferView<'a> {
    // The allocation this view slices into.
    buffer: &'a PooledBuffer<'a>,
    // Offset of the view relative to the start of the allocation.
    view_offset: usize,
    // Length of the view in bytes.
    view_size: usize,
}
706
707impl<'a> BufferView<'a> {
708    /// Get the absolute offset within the pool buffer
709    pub fn absolute_offset(&self) -> usize {
710        self.buffer.offset() + self.view_offset
711    }
712
713    /// Get the size of the view
714    pub fn size(&self) -> usize {
715        self.view_size
716    }
717
718    /// Get reference to the underlying pool buffer
719    #[cfg(feature = "gpu")]
720    pub fn buffer(&self) -> &wgpu::Buffer {
721        self.buffer.buffer()
722    }
723}
724
/// Round `size` up to the next multiple of `alignment`.
///
/// `alignment` must be a non-zero power of two: the bit trick
/// `(size + a - 1) & !(a - 1)` is only valid for power-of-two alignments,
/// and `alignment == 0` would underflow. This precondition is checked in
/// debug builds. Note `size` near `usize::MAX` can overflow the addition;
/// callers pass pool-sized values where this cannot occur.
#[allow(dead_code)]
pub fn align_size(size: usize, alignment: usize) -> usize {
    debug_assert!(
        alignment.is_power_of_two(),
        "alignment must be a non-zero power of two, got {}",
        alignment
    );
    (size + alignment - 1) & !(alignment - 1)
}
730
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_align_size() {
        assert_eq!(align_size(13, 8), 16);
        assert_eq!(align_size(16, 8), 16);
        assert_eq!(align_size(17, 8), 24);
        // Edge cases: zero size stays zero, unit alignment is the identity.
        assert_eq!(align_size(0, 8), 0);
        assert_eq!(align_size(5, 1), 5);
    }

    #[test]
    fn test_memory_block() {
        let block = MemoryBlock::new_free(0, 1024);
        assert!(block.is_free);
        assert_eq!(block.size, 1024);
        assert_eq!(block.ref_count, 0);

        let mut allocated_block = MemoryBlock::new_allocated(1024, 512);
        assert!(!allocated_block.is_free);
        assert_eq!(allocated_block.ref_count, 1);
        assert!(!allocated_block.can_free()); // still referenced

        allocated_block.add_ref();
        assert_eq!(allocated_block.ref_count, 2);

        assert!(!allocated_block.release_ref());
        assert_eq!(allocated_block.ref_count, 1);

        assert!(allocated_block.release_ref());
        assert_eq!(allocated_block.ref_count, 0);
        assert!(allocated_block.can_free()); // last reference released
    }

    #[test]
    fn test_memory_pressure_level() {
        // Distinct variants must compare unequal and identical ones equal;
        // the earlier version only compared each variant with itself, which
        // could never fail regardless of the PartialEq derive.
        assert_eq!(MemoryPressureLevel::Low, MemoryPressureLevel::Low);
        assert_eq!(MemoryPressureLevel::High, MemoryPressureLevel::High);
        assert_ne!(MemoryPressureLevel::Low, MemoryPressureLevel::High);
        assert_ne!(MemoryPressureLevel::Medium, MemoryPressureLevel::Critical);
    }
}