Skip to main content

tenflowers_dataset/
memory_pool.rs

1//! Memory pooling utilities for dataset operations
2//!
3//! This module provides efficient memory allocation and reuse mechanisms
4//! to reduce allocation overhead during dataset iteration and batch processing.
5
6#![allow(unsafe_code)]
7
use std::alloc::{alloc, alloc_zeroed, dealloc, Layout};
use std::collections::VecDeque;
use std::mem;
use std::ptr::NonNull;
use std::sync::{Arc, Mutex};
13
/// Memory pool statistics for monitoring
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    pub allocations: u64,
    pub deallocations: u64,
    pub cache_hits: u64,
    pub cache_misses: u64,
    pub current_size: usize,
    pub peak_size: usize,
}

impl PoolStats {
    /// Fraction of pool lookups served from the cache.
    ///
    /// Returns 0.0 when no lookups have been recorded yet.
    pub fn hit_ratio(&self) -> f64 {
        let lookups = self.cache_hits + self.cache_misses;
        if lookups == 0 {
            0.0
        } else {
            self.cache_hits as f64 / lookups as f64
        }
    }

    /// Fraction of allocations satisfied by reusing a cached block.
    ///
    /// Returns 0.0 when nothing has been allocated yet.
    pub fn efficiency(&self) -> f64 {
        match self.allocations {
            0 => 0.0,
            n => self.cache_hits as f64 / n as f64,
        }
    }
}
44
/// A memory block that can be reused
///
/// Owns a zero-initialized heap buffer of `size` bytes, freed on drop.
#[derive(Debug)]
struct MemoryBlock {
    ptr: NonNull<u8>,
    size: usize,
    layout: Layout,
}

impl MemoryBlock {
    /// Allocate a new zero-initialized block of `size` bytes.
    ///
    /// # Errors
    /// Returns an error for `size == 0` (the global allocator's contract
    /// forbids zero-sized layouts), for an invalid layout, or when the
    /// allocator returns null.
    fn new(size: usize) -> Result<Self, String> {
        // Calling the global allocator with a zero-sized layout is undefined
        // behavior, so reject it up front instead of trusting the caller.
        if size == 0 {
            return Err("Cannot allocate a zero-sized memory block".to_string());
        }

        let layout = Layout::from_size_align(size, mem::align_of::<u8>())
            .map_err(|e| format!("Layout error: {e:?}"))?;

        // SAFETY: `layout` has non-zero size (checked above). Zero-filling
        // means every byte is initialized, so handing out `&mut [u8]` views
        // of this buffer (and Vec conversions) is sound.
        let ptr = NonNull::new(unsafe { alloc_zeroed(layout) })
            .ok_or_else(|| "Memory allocation failed".to_string())?;

        Ok(Self { ptr, size, layout })
    }

    /// Get a slice view of the memory block
    pub fn as_slice_mut(&mut self) -> &mut [u8] {
        // SAFETY: `ptr` is valid for `size` initialized bytes for the
        // lifetime of `self`, and `&mut self` guarantees exclusive access.
        unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.size) }
    }

    /// Get the raw pointer
    pub fn as_ptr(&self) -> *mut u8 {
        self.ptr.as_ptr()
    }
}

impl Drop for MemoryBlock {
    fn drop(&mut self) {
        // SAFETY: `ptr` was allocated with exactly this `layout` and is
        // freed only here, once.
        unsafe {
            dealloc(self.ptr.as_ptr(), self.layout);
        }
    }
}

// SAFETY: MemoryBlock owns its memory exclusively and the pointer stays
// valid for the block's lifetime, so it may be moved or shared across threads.
unsafe impl Send for MemoryBlock {}
unsafe impl Sync for MemoryBlock {}
86
/// Memory pool for efficient allocation and reuse
pub struct MemoryPool {
    // One FIFO queue of cached blocks per power-of-two size class,
    // starting at `min_block_size` (class i holds min_block_size << i bytes).
    pools: Vec<Mutex<VecDeque<MemoryBlock>>>,
    // Cap on cached blocks per size class; excess returns are freed instead.
    max_blocks_per_size: usize,
    // Smallest pooled block size; smaller requests bypass the cache.
    min_block_size: usize,
    // Largest pooled block size; larger requests bypass the cache.
    max_block_size: usize,
    // Shared counters; `stats()` hands out cloned snapshots.
    stats: Arc<Mutex<PoolStats>>,
}
95
96impl MemoryPool {
97    /// Create a new memory pool
98    pub fn new() -> Self {
99        Self::with_config(64, 1024, 1024 * 1024 * 16) // 1KB to 16MB
100    }
101
102    /// Create a memory pool with custom configuration
103    pub fn with_config(
104        max_blocks_per_size: usize,
105        min_block_size: usize,
106        max_block_size: usize,
107    ) -> Self {
108        // Create pools for different size classes (powers of 2)
109        let mut size = min_block_size;
110        let mut pools = Vec::new();
111
112        while size <= max_block_size {
113            pools.push(Mutex::new(VecDeque::new()));
114            size *= 2;
115        }
116
117        Self {
118            pools,
119            max_blocks_per_size,
120            min_block_size,
121            max_block_size,
122            stats: Arc::new(Mutex::new(PoolStats::default())),
123        }
124    }
125
126    /// Find the appropriate size class for a requested size
127    fn find_size_class(&self, size: usize) -> Option<usize> {
128        if size < self.min_block_size || size > self.max_block_size {
129            return None;
130        }
131
132        let mut class_size = self.min_block_size;
133        let mut class_index = 0;
134
135        while class_size < size && class_index < self.pools.len() {
136            class_size *= 2;
137            class_index += 1;
138        }
139
140        if class_index < self.pools.len() {
141            Some(class_index)
142        } else {
143            None
144        }
145    }
146
147    /// Allocate a memory block from the pool
148    pub fn allocate(self: &Arc<Self>, size: usize) -> Result<PooledMemory, String> {
149        let mut stats = self
150            .stats
151            .lock()
152            .map_err(|e| format!("Failed to acquire stats lock: {e}"))?;
153        stats.allocations += 1;
154
155        if let Some(class_index) = self.find_size_class(size) {
156            let mut pool = self.pools[class_index]
157                .lock()
158                .map_err(|e| format!("Failed to acquire pool lock: {e}"))?;
159
160            if let Some(block) = pool.pop_front() {
161                stats.cache_hits += 1;
162                stats.current_size -= block.size;
163                drop(stats);
164                drop(pool);
165
166                return Ok(PooledMemory {
167                    block: Some(block),
168                    pool: Arc::downgrade(self),
169                    class_index: Some(class_index),
170                });
171            } else {
172                stats.cache_misses += 1;
173                drop(pool);
174            }
175        } else {
176            stats.cache_misses += 1;
177        }
178
179        // Allocate new block
180        let actual_size = if let Some(class_index) = self.find_size_class(size) {
181            self.min_block_size << class_index
182        } else {
183            size
184        };
185
186        let block = MemoryBlock::new(actual_size)?;
187        stats.current_size += block.size;
188        stats.peak_size = stats.peak_size.max(stats.current_size);
189        drop(stats);
190
191        Ok(PooledMemory {
192            block: Some(block),
193            pool: Arc::downgrade(self),
194            class_index: self.find_size_class(size),
195        })
196    }
197
198    /// Return a memory block to the pool
199    fn deallocate(&self, block: MemoryBlock, class_index: Option<usize>) {
200        let mut stats = self.stats.lock().expect("lock should not be poisoned");
201        stats.deallocations += 1;
202
203        if let Some(class_index) = class_index {
204            if class_index < self.pools.len() {
205                let mut pool = self.pools[class_index]
206                    .lock()
207                    .expect("lock should not be poisoned");
208
209                if pool.len() < self.max_blocks_per_size {
210                    stats.current_size += block.size;
211                    pool.push_back(block);
212                    return;
213                }
214            }
215        }
216
217        // Block will be dropped automatically if not returned to pool
218        stats.current_size = stats.current_size.saturating_sub(block.size);
219    }
220
221    /// Get current pool statistics
222    pub fn stats(&self) -> PoolStats {
223        self.stats
224            .lock()
225            .expect("lock should not be poisoned")
226            .clone()
227    }
228
229    /// Clear all cached blocks
230    pub fn clear(&self) {
231        for pool in &self.pools {
232            pool.lock().expect("lock should not be poisoned").clear();
233        }
234
235        let mut stats = self.stats.lock().expect("lock should not be poisoned");
236        stats.current_size = 0;
237    }
238}
239
240impl Clone for MemoryPool {
241    fn clone(&self) -> Self {
242        // Note: This creates a new pool with the same configuration
243        // The actual cached blocks are not cloned
244        Self::with_config(
245            self.max_blocks_per_size,
246            self.min_block_size,
247            self.max_block_size,
248        )
249    }
250}
251
impl Default for MemoryPool {
    /// Equivalent to [`MemoryPool::new`] (default 1KB–16MB configuration).
    fn default() -> Self {
        Self::new()
    }
}
257
/// A memory allocation from the pool that automatically returns to the pool on drop
pub struct PooledMemory {
    // Always `Some` while the handle is live; taken on drop or `into_vec`.
    block: Option<MemoryBlock>,
    // Weak reference so an outstanding allocation does not keep the pool alive.
    pool: std::sync::Weak<MemoryPool>,
    // Size class the block belongs to; `None` for out-of-range sizes
    // (such blocks are freed rather than cached on drop).
    class_index: Option<usize>,
}
264
impl PooledMemory {
    /// Get the size of the allocated memory
    ///
    /// May exceed the requested size when the block was rounded up to its
    /// size class. Returns 0 only if the block has already been consumed.
    pub fn size(&self) -> usize {
        self.block.as_ref().map(|b| b.size).unwrap_or(0)
    }

    /// Get a mutable slice view of the memory
    ///
    /// # Panics
    /// Panics if the block has already been taken (internal invariant broken).
    pub fn as_slice_mut(&mut self) -> &mut [u8] {
        self.block
            .as_mut()
            .expect("block should exist for valid PooledMemory")
            .as_slice_mut()
    }

    /// Get the raw pointer
    ///
    /// # Panics
    /// Panics if the block has already been taken (internal invariant broken).
    pub fn as_ptr(&self) -> *mut u8 {
        self.block
            .as_ref()
            .expect("block should exist for valid PooledMemory")
            .as_ptr()
    }

    /// Convert to a `Vec<u8>` (consumes the pooled memory)
    ///
    /// The block bypasses the pool on this path: it is neither returned to
    /// the cache nor recorded as a deallocation in the statistics.
    ///
    /// NOTE(review): the Vec is created with `len == size`, which claims
    /// every byte is initialized. That is only sound if `MemoryBlock`'s
    /// allocation is fully initialized (e.g. zero-filled) — confirm;
    /// otherwise reading untouched bytes is undefined behavior.
    pub fn into_vec(mut self) -> Vec<u8> {
        let block = self
            .block
            .take()
            .expect("block should exist for valid PooledMemory");
        let size = block.size;
        let ptr = block.as_ptr();

        // Prevent the block from being deallocated — ownership of the buffer
        // transfers to the Vec below (skips MemoryBlock::drop).
        mem::forget(block);

        // Create a Vec from the raw pointer. The block was allocated from the
        // global allocator with `size` bytes at alignment 1, matching the
        // layout `Vec<u8>` will use to free a capacity of `size`.
        unsafe { Vec::from_raw_parts(ptr, size, size) }
    }
}
303
304impl Drop for PooledMemory {
305    fn drop(&mut self) {
306        if let Some(block) = self.block.take() {
307            if let Some(pool) = self.pool.upgrade() {
308                pool.deallocate(block, self.class_index);
309            }
310            // If pool is dropped, block will be automatically deallocated
311        }
312    }
313}
314
315/// Thread-safe global memory pool
316pub struct GlobalMemoryPool {
317    pool: Arc<MemoryPool>,
318}
319
320impl GlobalMemoryPool {
321    /// Get the global memory pool instance
322    pub fn instance() -> &'static GlobalMemoryPool {
323        static INSTANCE: std::sync::OnceLock<GlobalMemoryPool> = std::sync::OnceLock::new();
324        INSTANCE.get_or_init(|| GlobalMemoryPool {
325            pool: Arc::new(MemoryPool::new()),
326        })
327    }
328
329    /// Allocate memory from the global pool
330    pub fn allocate(size: usize) -> Result<PooledMemory, String> {
331        Self::instance().pool.allocate(size)
332    }
333
334    /// Get global pool statistics
335    pub fn stats() -> PoolStats {
336        Self::instance().pool.stats()
337    }
338
339    /// Clear the global pool
340    pub fn clear() {
341        Self::instance().pool.clear()
342    }
343}
344
/// Extension trait for easy memory pool allocation
pub trait MemoryPoolExt<T> {
    /// Allocate a vector using the memory pool
    ///
    /// Returns an empty `Vec<T>` whose backing storage can hold at least
    /// `capacity` elements.
    ///
    /// # Errors
    /// Propagates allocation failures from the underlying pool.
    fn with_pool_capacity(capacity: usize) -> Result<Vec<T>, String>;
}
350
351impl<T> MemoryPoolExt<T> for Vec<T> {
352    fn with_pool_capacity(capacity: usize) -> Result<Vec<T>, String> {
353        let size = capacity * mem::size_of::<T>();
354        let pooled = GlobalMemoryPool::allocate(size)?;
355
356        // Convert to Vec
357        let vec = pooled.into_vec();
358
359        // Cast to the proper type (this is safe since we allocated the right amount)
360        let ptr = vec.as_ptr() as *mut T;
361        let len = 0;
362
363        mem::forget(vec); // Prevent deallocation
364
365        Ok(unsafe { Vec::from_raw_parts(ptr, len, capacity) })
366    }
367}
368
#[cfg(test)]
mod tests {
    //! Unit tests for the memory pool: basic reuse, multiple size classes,
    //! the global singleton, the `Vec` extension trait, and statistics.

    use super::*;

    #[test]
    fn test_memory_pool_basic() {
        let pool = Arc::new(MemoryPool::new());

        // Allocate some memory
        let mut mem1 = pool.allocate(1024).expect("test: operation should succeed");
        assert_eq!(mem1.size(), 1024);

        // Write some data
        let slice = mem1.as_slice_mut();
        slice[0] = 42;
        slice[1023] = 99;

        // Dropping returns the block to the pool's 1024-byte size class.
        drop(mem1);

        // Allocate again
        let mut mem2 = pool.allocate(1024).expect("test: operation should succeed");
        let slice2 = mem2.as_slice_mut();

        // Verify we can write to the new allocation
        slice2[0] = 100;
        assert_eq!(slice2[0], 100);

        let stats = pool.stats();
        assert_eq!(stats.allocations, 2);
        // With the Arc optimization, we should now see cache hits
        assert_eq!(stats.cache_misses, 1);
        assert_eq!(stats.cache_hits, 1);
    }

    #[test]
    fn test_memory_pool_different_sizes() {
        let pool = Arc::new(MemoryPool::new());

        // 512 is below the default minimum block size; 1024 and 2048 map to
        // two distinct size classes.
        let mem1 = pool.allocate(512).expect("test: operation should succeed");
        let mem2 = pool.allocate(1024).expect("test: operation should succeed");
        let mem3 = pool.allocate(2048).expect("test: operation should succeed");

        assert!(mem1.size() >= 512);
        assert!(mem2.size() >= 1024);
        assert!(mem3.size() >= 2048);

        drop(mem1);
        drop(mem2);
        drop(mem3);

        let stats = pool.stats();
        assert_eq!(stats.allocations, 3);
    }

    #[test]
    fn test_global_memory_pool() {
        let mem1 = GlobalMemoryPool::allocate(1024).expect("test: operation should succeed");
        assert_eq!(mem1.size(), 1024);

        let mem2 = GlobalMemoryPool::allocate(2048).expect("test: operation should succeed");
        assert!(mem2.size() >= 2048);

        // Basic functionality test - just verify allocations work
        drop(mem1);
        drop(mem2);

        // The global pool is shared across tests, so only lower bounds can
        // be asserted here.
        let stats = GlobalMemoryPool::stats();
        assert!(stats.allocations >= 2);
    }

    #[test]
    fn test_vec_with_pool_capacity() {
        GlobalMemoryPool::clear();

        let mut vec: Vec<i32> =
            Vec::with_pool_capacity(100).expect("test: operation should succeed");
        vec.push(42);
        vec.push(99);

        assert_eq!(vec.len(), 2);
        assert_eq!(vec.capacity(), 100);
        assert_eq!(vec[0], 42);
        assert_eq!(vec[1], 99);
    }

    #[test]
    fn test_pool_stats() {
        let pool = Arc::new(MemoryPool::new());

        // A fresh pool reports all-zero counters and ratios.
        let stats = pool.stats();
        assert_eq!(stats.allocations, 0);
        assert_eq!(stats.deallocations, 0);
        assert_eq!(stats.cache_hits, 0);
        assert_eq!(stats.cache_misses, 0);
        assert_eq!(stats.hit_ratio(), 0.0);
        assert_eq!(stats.efficiency(), 0.0);

        let _mem = pool.allocate(1024).expect("test: operation should succeed");
        let stats = pool.stats();
        assert_eq!(stats.allocations, 1);
        assert_eq!(stats.cache_misses, 1);
    }
}
471}