Skip to main content

sapient_backends_cpu/
pool.rs

1//! Tensor memory pool with LRU eviction.
2//!
3//! The pool reduces allocation pressure for intermediate tensors that are
4//! created and freed on every forward pass.
5
6use std::collections::HashMap;
7
8use parking_lot::Mutex;
9
10use sapient_core::buffer::BufferHandle;
11use sapient_core::DType;
12
13// ── Entry ────────────────────────────────────────────────────────────────────
14
15struct PoolEntry {
16    handle: BufferHandle,
17    last_used: std::time::Instant,
18    capacity: usize,
19}
20
21// ── PoolAllocator ─────────────────────────────────────────────────────────────
22
23/// LRU memory pool for CPU tensor buffers.
24///
25/// Buffers are keyed by `(numel, dtype)`.  When a caller returns a buffer, it
26/// can be re-acquired on the next allocation of the same size, avoiding heap
27/// `malloc`/`free` on the hot path.
28pub struct PoolAllocator {
29    inner: Mutex<PoolInner>,
30}
31
32struct PoolInner {
33    /// Available buffers, grouped by byte capacity.
34    free: HashMap<usize, Vec<PoolEntry>>,
35    /// Total bytes currently held in the pool.
36    used_bytes: usize,
37    /// Maximum bytes the pool will hold.
38    capacity: usize,
39}
40
41impl PoolAllocator {
42    pub fn new(capacity_bytes: usize) -> Self {
43        Self {
44            inner: Mutex::new(PoolInner {
45                free: HashMap::new(),
46                used_bytes: 0,
47                capacity: capacity_bytes,
48            }),
49        }
50    }
51
52    /// Try to acquire a buffer for `numel` elements of `dtype`.
53    /// Returns `None` if the pool has no suitable entry — caller should
54    /// allocate fresh.
55    pub fn acquire(&self, numel: usize, dtype: DType) -> Option<BufferHandle> {
56        let byte_size = dtype.byte_count(numel);
57        let mut inner = self.inner.lock();
58
59        // Look for an exact or larger buffer.
60        if let Some(entries) = inner.free.get_mut(&byte_size) {
61            if let Some(entry) = entries.pop() {
62                inner.used_bytes = inner.used_bytes.saturating_sub(entry.capacity);
63                return Some(entry.handle);
64            }
65        }
66        None
67    }
68
69    /// Return a buffer to the pool after use.
70    ///
71    /// If the pool is over capacity, evict the least-recently-used entries.
72    pub fn release(&self, handle: BufferHandle, numel: usize, dtype: DType) {
73        let byte_size = dtype.byte_count(numel);
74        let mut inner = self.inner.lock();
75
76        // Evict LRU entries if needed.
77        while inner.used_bytes + byte_size > inner.capacity {
78            if !Self::evict_lru(&mut inner) {
79                break;
80            }
81        }
82
83        if inner.used_bytes + byte_size <= inner.capacity {
84            inner.used_bytes += byte_size;
85            inner.free.entry(byte_size).or_default().push(PoolEntry {
86                handle,
87                last_used: std::time::Instant::now(),
88                capacity: byte_size,
89            });
90        }
91        // If still over capacity: discard (buffer drops).
92    }
93
94    fn evict_lru(inner: &mut PoolInner) -> bool {
95        // Find the oldest entry across all buckets.
96        let mut oldest_key: Option<usize> = None;
97        let mut oldest_time = std::time::Instant::now();
98
99        for (&key, entries) in &inner.free {
100            for entry in entries {
101                if entry.last_used < oldest_time {
102                    oldest_time = entry.last_used;
103                    oldest_key = Some(key);
104                }
105            }
106        }
107
108        if let Some(key) = oldest_key {
109            if let Some(entries) = inner.free.get_mut(&key) {
110                if let Some(entry) = entries.pop() {
111                    inner.used_bytes = inner.used_bytes.saturating_sub(entry.capacity);
112                    return true;
113                }
114            }
115        }
116        false
117    }
118
119    /// Total bytes currently pooled.
120    pub fn used_bytes(&self) -> usize {
121        self.inner.lock().used_bytes
122    }
123
124    /// Pool capacity in bytes.
125    pub fn capacity(&self) -> usize {
126        self.inner.lock().capacity
127    }
128}
129
130impl std::fmt::Debug for PoolAllocator {
131    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132        let inner = self.inner.lock();
133        f.debug_struct("PoolAllocator")
134            .field("used_bytes", &inner.used_bytes)
135            .field("capacity", &inner.capacity)
136            .finish()
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use sapient_core::buffer::CpuBuffer;

    /// Round-trips one buffer through the pool and checks the byte accounting
    /// at every step.
    #[test]
    fn acquire_release_cycle() {
        let allocator = PoolAllocator::new(1024 * 1024);

        // A freshly created pool has nothing to hand out.
        assert!(allocator.acquire(16, DType::F32).is_none());

        // Park one 16-element F32 buffer in the pool.
        let parked = BufferHandle::new(CpuBuffer::zeros(16, DType::F32).unwrap());
        allocator.release(parked, 16, DType::F32);
        assert_eq!(allocator.used_bytes(), 64); // 16 * 4 bytes

        // The same request is now served from the pool, which empties it.
        let reacquired = allocator.acquire(16, DType::F32);
        assert!(reacquired.is_some());
        assert_eq!(allocator.used_bytes(), 0);
    }
}
161}