// scirs2_neural/data/memory_pool.rs

//! GPU memory pooling for efficient buffer reuse
//!
//! This module provides memory pool management for GPU buffers, reducing
//! allocation/deallocation overhead and improving training performance.

use crate::error::{NeuralError, Result};
#[cfg(feature = "gpu")]
use scirs2_core::gpu::{GpuBuffer, GpuContext, GpuDataType};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
/// Buffer size class for memory pooling.
///
/// Buffers are bucketed by element count so that a freed buffer can be
/// reused for any later request that rounds to the same class. Ordering and
/// hashing derive from the single `size` field.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct SizeClass {
    /// Size in elements (not bytes)
    pub size: usize,
}

20impl SizeClass {
21    /// Create a new size class
22    pub fn new(size: usize) -> Self {
23        Self { size }
24    }
25
26    /// Round up size to the nearest power of 2 for efficient pooling
27    pub fn from_size_rounded(size: usize) -> Self {
28        let rounded = if size == 0 {
29            1
30        } else {
31            size.next_power_of_two()
32        };
33        Self { size: rounded }
34    }
35
36    /// Get the actual size in elements
37    pub fn size(&self) -> usize {
38        self.size
39    }
40}
41
/// Memory pool for GPU buffers (GPU feature required).
///
/// Freed buffers are cached per power-of-two `SizeClass` and reused by later
/// allocations, avoiding repeated device allocation/deallocation. All
/// statistics counters are atomics and the buffer map is behind a `Mutex`,
/// so the pool can be shared across threads via `Arc`.
#[cfg(feature = "gpu")]
pub struct GpuMemoryPool<T: GpuDataType> {
    /// GPU context used to create new buffers on cache misses
    gpu_context: Arc<GpuContext>,
    /// Available (idle) buffers organized by size class
    available_buffers: Arc<Mutex<HashMap<SizeClass, Vec<GpuBuffer<T>>>>>,
    /// Total allocated memory in bytes (in-use + cached)
    total_allocated: Arc<AtomicU64>,
    /// Memory currently handed out to callers, in bytes
    total_in_use: Arc<AtomicU64>,
    /// High-water mark of `total_in_use`, in bytes
    peak_usage: Arc<AtomicU64>,
    /// Number of allocation requests served (hits + misses)
    num_allocations: Arc<AtomicU64>,
    /// Number of `PooledBuffer` drops observed
    num_deallocations: Arc<AtomicU64>,
    /// Number of requests satisfied from the cache
    cache_hits: Arc<AtomicU64>,
    /// Number of requests that required a fresh device allocation
    cache_misses: Arc<AtomicU64>,
    /// Maximum pool size in bytes (0 = unlimited)
    max_pool_size: u64,
    /// Whether to evict cached buffers automatically when the limit is hit
    auto_cleanup: bool,
}

#[cfg(feature = "gpu")]
impl<T: GpuDataType> GpuMemoryPool<T> {
    /// Create a new GPU memory pool.
    ///
    /// `max_pool_size` caps total allocated bytes (0 = unlimited). Automatic
    /// cleanup of cached buffers on limit pressure is enabled by default.
    pub fn new(gpu_context: Arc<GpuContext>, max_pool_size: u64) -> Self {
        Self {
            gpu_context,
            available_buffers: Arc::new(Mutex::new(HashMap::new())),
            total_allocated: Arc::new(AtomicU64::new(0)),
            total_in_use: Arc::new(AtomicU64::new(0)),
            peak_usage: Arc::new(AtomicU64::new(0)),
            num_allocations: Arc::new(AtomicU64::new(0)),
            num_deallocations: Arc::new(AtomicU64::new(0)),
            cache_hits: Arc::new(AtomicU64::new(0)),
            cache_misses: Arc::new(AtomicU64::new(0)),
            max_pool_size,
            auto_cleanup: true,
        }
    }

    /// Allocate a buffer of at least `size` elements from the pool.
    ///
    /// The request is rounded up to the next power-of-two size class; a
    /// cached buffer of that class is reused when available, otherwise a new
    /// GPU buffer is allocated. The returned [`PooledBuffer`] hands itself
    /// back to the pool when dropped.
    ///
    /// # Errors
    /// Returns `NeuralError::TrainingError` if the pool mutex is poisoned, or
    /// if the pool size limit would be exceeded while auto-cleanup is off.
    pub fn allocate(&self, size: usize) -> Result<PooledBuffer<T>> {
        self.num_allocations.fetch_add(1, Ordering::Relaxed);

        let size_class = SizeClass::from_size_rounded(size);
        let actual_size = size_class.size();

        // Try to reuse a cached buffer of the same size class.
        let mut buffers = self.available_buffers.lock().map_err(|_| {
            NeuralError::TrainingError("Failed to lock available buffers".to_string())
        })?;
        let cached = buffers.get_mut(&size_class).and_then(|pool| pool.pop());
        // Release the lock before any GPU allocation so a slow device call
        // never blocks other threads using the pool.
        drop(buffers);

        let buffer = match cached {
            Some(buf) => {
                self.cache_hits.fetch_add(1, Ordering::Relaxed);
                buf
            }
            None => {
                self.cache_misses.fetch_add(1, Ordering::Relaxed);
                self.allocate_new_buffer(actual_size)?
            }
        };

        // Update usage statistics (tracked in bytes).
        let buffer_size = actual_size * std::mem::size_of::<T>();
        let current_usage = self
            .total_in_use
            .fetch_add(buffer_size as u64, Ordering::Relaxed)
            + buffer_size as u64;

        // fetch_max replaces the manual compare-exchange loop for the peak.
        self.peak_usage.fetch_max(current_usage, Ordering::Relaxed);

        Ok(PooledBuffer {
            buffer,
            size_class,
            pool: Arc::downgrade(&self.available_buffers),
            total_in_use: Arc::clone(&self.total_in_use),
            num_deallocations: Arc::clone(&self.num_deallocations),
        })
    }

    /// Allocate a brand-new GPU buffer of `size` elements, enforcing the
    /// pool size limit (evicting cached buffers first when auto-cleanup is on).
    fn allocate_new_buffer(&self, size: usize) -> Result<GpuBuffer<T>> {
        let buffer_size = size * std::mem::size_of::<T>();

        // Check pool size limit (0 means unlimited).
        if self.max_pool_size > 0 {
            let total_allocated = self.total_allocated.load(Ordering::Relaxed);
            if total_allocated + buffer_size as u64 > self.max_pool_size {
                if self.auto_cleanup {
                    // Try to free cached buffers to make room.
                    self.cleanup_oldest_buffers(buffer_size as u64)?;
                } else {
                    return Err(NeuralError::TrainingError(format!(
                        "Memory pool size limit exceeded: {} + {} > {}",
                        total_allocated, buffer_size, self.max_pool_size
                    )));
                }
            }
        }

        let buffer = self.gpu_context.create_buffer::<T>(size);
        self.total_allocated
            .fetch_add(buffer_size as u64, Ordering::Relaxed);

        Ok(buffer)
    }

    /// Return a buffer to the pool so it can be reused for `size_class`.
    ///
    /// NOTE(review): currently unused — `PooledBuffer::drop` recycles buffers
    /// directly via the weak map reference. Kept for explicit-return callers.
    fn return_buffer(&self, buffer: GpuBuffer<T>, size_class: SizeClass) -> Result<()> {
        let mut buffers = self.available_buffers.lock().map_err(|_| {
            NeuralError::TrainingError("Failed to lock available buffers".to_string())
        })?;

        buffers.entry(size_class).or_default().push(buffer);

        Ok(())
    }

    /// Evict cached buffers until at least `required_space` bytes are freed.
    ///
    /// Despite the name, eviction is by size (largest classes first), not by
    /// age — large buffers free the required space in the fewest evictions.
    fn cleanup_oldest_buffers(&self, required_space: u64) -> Result<()> {
        let mut buffers = self.available_buffers.lock().map_err(|_| {
            NeuralError::TrainingError("Failed to lock available buffers".to_string())
        })?;

        let mut freed_space = 0u64;

        // Visit size classes largest-first for efficient cleanup.
        let mut size_classes: Vec<_> = buffers.keys().cloned().collect();
        size_classes.sort_by_key(|sc| std::cmp::Reverse(sc.size));

        for size_class in size_classes {
            if freed_space >= required_space {
                break;
            }

            if let Some(pool) = buffers.get_mut(&size_class) {
                while let Some(buffer) = pool.pop() {
                    let buffer_size = (buffer.len() * std::mem::size_of::<T>()) as u64;
                    freed_space += buffer_size;
                    self.total_allocated
                        .fetch_sub(buffer_size, Ordering::Relaxed);

                    if freed_space >= required_space {
                        break;
                    }
                }
            }
        }

        Ok(())
    }

    /// Clear all cached (idle) buffers, releasing their device memory.
    ///
    /// Buffers currently handed out are unaffected; they will repopulate the
    /// pool when dropped.
    pub fn clear(&self) -> Result<()> {
        let mut buffers = self.available_buffers.lock().map_err(|_| {
            NeuralError::TrainingError("Failed to lock available buffers".to_string())
        })?;

        for (size_class, pool) in buffers.drain() {
            let buffer_size = (size_class.size * std::mem::size_of::<T>()) as u64;
            let count = pool.len() as u64;
            self.total_allocated
                .fetch_sub(buffer_size * count, Ordering::Relaxed);
        }

        Ok(())
    }

    /// Get a snapshot of memory pool statistics.
    ///
    /// Recovers from a poisoned mutex instead of panicking (the map is always
    /// left consistent by the other methods), keeping this read-only accessor
    /// infallible without an `expect`.
    pub fn get_statistics(&self) -> PoolStatistics {
        let buffers = self
            .available_buffers
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        let cached_buffers: usize = buffers.values().map(Vec::len).sum();

        let cache_hits = self.cache_hits.load(Ordering::Relaxed);
        let cache_misses = self.cache_misses.load(Ordering::Relaxed);
        let total_requests = cache_hits + cache_misses;
        let cache_hit_rate = if total_requests > 0 {
            cache_hits as f64 / total_requests as f64
        } else {
            0.0
        };

        PoolStatistics {
            total_allocated: self.total_allocated.load(Ordering::Relaxed),
            total_in_use: self.total_in_use.load(Ordering::Relaxed),
            peak_usage: self.peak_usage.load(Ordering::Relaxed),
            cached_buffers,
            num_allocations: self.num_allocations.load(Ordering::Relaxed),
            num_deallocations: self.num_deallocations.load(Ordering::Relaxed),
            cache_hits,
            cache_misses,
            cache_hit_rate,
        }
    }

    /// Enable or disable automatic cleanup on size-limit pressure.
    pub fn set_auto_cleanup(&mut self, enabled: bool) {
        self.auto_cleanup = enabled;
    }

    /// Get maximum pool size in bytes (0 = unlimited).
    pub fn max_pool_size(&self) -> u64 {
        self.max_pool_size
    }

    /// Set maximum pool size in bytes (0 = unlimited).
    pub fn set_max_pool_size(&mut self, size: u64) {
        self.max_pool_size = size;
    }
}

/// Point-in-time statistics snapshot for a memory pool.
#[derive(Debug, Clone)]
pub struct PoolStatistics {
    /// Total allocated memory in bytes (in-use + cached)
    pub total_allocated: u64,
    /// Memory currently handed out to callers, in bytes
    pub total_in_use: u64,
    /// High-water mark of in-use memory, in bytes
    pub peak_usage: u64,
    /// Number of idle buffers currently cached in the pool
    pub cached_buffers: usize,
    /// Number of allocation requests served
    pub num_allocations: u64,
    /// Number of pooled-buffer drops observed
    pub num_deallocations: u64,
    /// Number of requests satisfied from the cache
    pub cache_hits: u64,
    /// Number of requests requiring a fresh device allocation
    pub cache_misses: u64,
    /// Cache hit rate (0.0 to 1.0); 0.0 when no requests have been made
    pub cache_hit_rate: f64,
}

/// A pooled GPU buffer that automatically returns to the pool when dropped
/// (GPU feature required).
///
/// Holds only a `Weak` reference to the pool's buffer map, so dropping the
/// pool itself is never blocked by outstanding buffers; a buffer dropped
/// after its pool is gone is simply released.
#[cfg(feature = "gpu")]
pub struct PooledBuffer<T: GpuDataType> {
    /// The actual GPU buffer
    buffer: GpuBuffer<T>,
    /// Size class used to file this buffer back into the pool
    size_class: SizeClass,
    /// Weak reference to the pool's cached-buffer map
    pool: std::sync::Weak<Mutex<HashMap<SizeClass, Vec<GpuBuffer<T>>>>>,
    /// Shared in-use byte counter, decremented on drop
    total_in_use: Arc<AtomicU64>,
    /// Shared deallocation counter, incremented on drop
    num_deallocations: Arc<AtomicU64>,
}

#[cfg(feature = "gpu")]
impl<T: GpuDataType> PooledBuffer<T> {
    /// Borrow the underlying GPU buffer (also available via `Deref`).
    pub fn buffer(&self) -> &GpuBuffer<T> {
        &self.buffer
    }

    /// The power-of-two size class this buffer is filed under when it is
    /// returned to the pool.
    pub fn size_class(&self) -> SizeClass {
        self.size_class
    }
}

#[cfg(feature = "gpu")]
impl<T: GpuDataType> Drop for PooledBuffer<T> {
    /// Update usage accounting and hand the buffer back to the pool.
    fn drop(&mut self) {
        self.num_deallocations.fetch_add(1, Ordering::Relaxed);

        // NOTE(review): assumes `buffer.len()` is the element count of the
        // power-of-two-rounded allocation, matching `allocate`'s byte
        // accounting — confirm against GpuBuffer's `len` semantics.
        let buffer_size = (self.buffer.len() * std::mem::size_of::<T>()) as u64;
        self.total_in_use.fetch_sub(buffer_size, Ordering::Relaxed);

        // Return buffer to pool if pool still exists; if the pool is gone
        // (or its mutex is poisoned), the buffer is simply released here.
        if let Some(pool) = self.pool.upgrade() {
            if let Ok(mut buffers) = pool.lock() {
                // Clone the buffer and return it to the pool for reuse.
                // NOTE(review): this presumes `GpuBuffer::clone` is a cheap
                // handle clone referring to the same device memory; if it is
                // a deep copy, recycling here duplicates the allocation —
                // verify against scirs2_core's GpuBuffer implementation.
                let recycled = self.buffer.clone();
                buffers
                    .entry(self.size_class)
                    .or_insert_with(Vec::new)
                    .push(recycled);
            }
        }
    }
}

#[cfg(feature = "gpu")]
impl<T: GpuDataType> std::ops::Deref for PooledBuffer<T> {
    type Target = GpuBuffer<T>;

    /// Deref to the wrapped GPU buffer so a `PooledBuffer` can be used
    /// anywhere a `&GpuBuffer<T>` is expected.
    fn deref(&self) -> &Self::Target {
        &self.buffer
    }
}

#[cfg(all(test, feature = "gpu"))]
mod tests {
    use super::*;
    use scirs2_core::gpu::GpuBackend;

    #[test]
    fn test_size_class() {
        let size_class = SizeClass::new(100);
        assert_eq!(size_class.size(), 100);

        let rounded = SizeClass::from_size_rounded(100);
        assert_eq!(rounded.size(), 128); // Next power of 2

        // Zero-element requests map to the minimum 1-element class.
        assert_eq!(SizeClass::from_size_rounded(0).size(), 1);
    }

    #[test]
    fn test_memory_pool_creation() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create context");
        let pool = GpuMemoryPool::<f32>::new(Arc::new(context), 1024 * 1024 * 1024);

        let stats = pool.get_statistics();
        assert_eq!(stats.total_allocated, 0);
        assert_eq!(stats.cached_buffers, 0);
    }

    #[test]
    fn test_buffer_allocation() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create context");
        let pool = GpuMemoryPool::<f32>::new(Arc::new(context), 1024 * 1024 * 1024);

        let buffer = pool.allocate(1000).expect("Failed to allocate");
        assert_eq!(buffer.size_class().size(), 1024); // Rounded to power of 2

        let stats = pool.get_statistics();
        assert_eq!(stats.num_allocations, 1);
        assert!(stats.total_in_use > 0);
    }

    #[test]
    fn test_buffer_reuse() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create context");
        let pool = Arc::new(GpuMemoryPool::<f32>::new(
            Arc::new(context),
            1024 * 1024 * 1024,
        ));

        // Allocate and drop buffer so it is recycled into the pool.
        {
            let _buffer = pool.allocate(1000).expect("Failed to allocate");
        }

        // Allocate again - the same size class should be served from cache.
        let _buffer2 = pool.allocate(1000).expect("Failed to allocate");

        let stats = pool.get_statistics();
        assert_eq!(stats.num_allocations, 2);
        assert_eq!(stats.cache_hits, 1);
    }

    #[test]
    fn test_pool_statistics() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create context");
        let pool = GpuMemoryPool::<f32>::new(Arc::new(context), 1024 * 1024 * 1024);

        let _buffer1 = pool.allocate(1000).expect("Failed to allocate");
        let _buffer2 = pool.allocate(2000).expect("Failed to allocate");

        let stats = pool.get_statistics();
        assert_eq!(stats.num_allocations, 2);
        assert!(stats.total_in_use > 0);
        assert!(stats.peak_usage >= stats.total_in_use);
    }

    #[test]
    fn test_pool_clear() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create context");
        let pool = GpuMemoryPool::<f32>::new(Arc::new(context), 1024 * 1024 * 1024);

        {
            let _buffer = pool.allocate(1000).expect("Failed to allocate");
        }

        pool.clear().expect("Failed to clear pool");

        let stats = pool.get_statistics();
        assert_eq!(stats.cached_buffers, 0);
    }

    #[test]
    fn test_pooled_buffer_deref() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create context");
        let pool = GpuMemoryPool::<f32>::new(Arc::new(context), 1024 * 1024 * 1024);

        let buffer = pool.allocate(1000).expect("Failed to allocate");

        // Deref exposes GpuBuffer methods directly.
        assert!(!buffer.is_empty());
        // The underlying allocation is rounded up to the 1024-element class.
        assert_eq!(buffer.len(), 1024);
    }
}
465}