// sapient_backends_cpu/pool.rs

use std::collections::HashMap;
use std::time::Instant;

use parking_lot::Mutex;

use sapient_core::buffer::BufferHandle;
use sapient_core::DType;

13struct PoolEntry {
16 handle: BufferHandle,
17 last_used: std::time::Instant,
18 capacity: usize,
19}
20
21pub struct PoolAllocator {
29 inner: Mutex<PoolInner>,
30}
31
32struct PoolInner {
33 free: HashMap<usize, Vec<PoolEntry>>,
35 used_bytes: usize,
37 capacity: usize,
39}
40
41impl PoolAllocator {
42 pub fn new(capacity_bytes: usize) -> Self {
43 Self {
44 inner: Mutex::new(PoolInner {
45 free: HashMap::new(),
46 used_bytes: 0,
47 capacity: capacity_bytes,
48 }),
49 }
50 }
51
52 pub fn acquire(&self, numel: usize, dtype: DType) -> Option<BufferHandle> {
56 let byte_size = dtype.byte_count(numel);
57 let mut inner = self.inner.lock();
58
59 if let Some(entries) = inner.free.get_mut(&byte_size) {
61 if let Some(entry) = entries.pop() {
62 inner.used_bytes = inner.used_bytes.saturating_sub(entry.capacity);
63 return Some(entry.handle);
64 }
65 }
66 None
67 }
68
69 pub fn release(&self, handle: BufferHandle, numel: usize, dtype: DType) {
73 let byte_size = dtype.byte_count(numel);
74 let mut inner = self.inner.lock();
75
76 while inner.used_bytes + byte_size > inner.capacity {
78 if !Self::evict_lru(&mut inner) {
79 break;
80 }
81 }
82
83 if inner.used_bytes + byte_size <= inner.capacity {
84 inner.used_bytes += byte_size;
85 inner.free.entry(byte_size).or_default().push(PoolEntry {
86 handle,
87 last_used: std::time::Instant::now(),
88 capacity: byte_size,
89 });
90 }
91 }
93
94 fn evict_lru(inner: &mut PoolInner) -> bool {
95 let mut oldest_key: Option<usize> = None;
97 let mut oldest_time = std::time::Instant::now();
98
99 for (&key, entries) in &inner.free {
100 for entry in entries {
101 if entry.last_used < oldest_time {
102 oldest_time = entry.last_used;
103 oldest_key = Some(key);
104 }
105 }
106 }
107
108 if let Some(key) = oldest_key {
109 if let Some(entries) = inner.free.get_mut(&key) {
110 if let Some(entry) = entries.pop() {
111 inner.used_bytes = inner.used_bytes.saturating_sub(entry.capacity);
112 return true;
113 }
114 }
115 }
116 false
117 }
118
119 pub fn used_bytes(&self) -> usize {
121 self.inner.lock().used_bytes
122 }
123
124 pub fn capacity(&self) -> usize {
126 self.inner.lock().capacity
127 }
128}
129
130impl std::fmt::Debug for PoolAllocator {
131 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132 let inner = self.inner.lock();
133 f.debug_struct("PoolAllocator")
134 .field("used_bytes", &inner.used_bytes)
135 .field("capacity", &inner.capacity)
136 .finish()
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use sapient_core::buffer::CpuBuffer;

    /// A released buffer is handed back by a same-sized `acquire`, and the
    /// pooled-byte counter follows the buffer in and out of the pool.
    #[test]
    fn acquire_release_cycle() {
        const NUMEL: usize = 16;
        let pool = PoolAllocator::new(1024 * 1024);

        // Nothing has been released yet, so there is nothing to reuse.
        assert!(pool.acquire(NUMEL, DType::F32).is_none());

        let zeros = CpuBuffer::zeros(NUMEL, DType::F32).unwrap();
        pool.release(BufferHandle::new(zeros), NUMEL, DType::F32);
        // 16 f32 elements are filed under 64 bytes.
        assert_eq!(pool.used_bytes(), 64);

        let reused = pool.acquire(NUMEL, DType::F32);
        assert!(reused.is_some());
        assert_eq!(pool.used_bytes(), 0);
    }
}