use crate::error::{NeuralError, Result};

#[cfg(feature = "gpu")]
use scirs2_core::gpu::{GpuBuffer, GpuContext, GpuDataType};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};

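/// A size bucket for pooled GPU buffers. Buffers are cached per size class
/// so that a freed buffer can satisfy any later request in the same class.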
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct SizeClass {
    pub size: usize,
}

impl SizeClass {
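    /// Creates a size class with an exact element count (no rounding).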
    pub fn new(size: usize) -> Self {
        Self { size }
    }

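    /// Creates a size class by rounding `size` up to the next power of two;
    /// a zero size maps to 1. Power-of-two buckets cap the wasted space per
    /// allocation at just under 2x while keeping the number of buckets small.
    ///
    /// A minimal sketch of the rounding (marked `ignore` since the crate
    /// path for importing `SizeClass` is not fixed here):
    ///
    /// ```ignore
    /// assert_eq!(SizeClass::from_size_rounded(0).size(), 1);
    /// assert_eq!(SizeClass::from_size_rounded(100).size(), 128);
    /// assert_eq!(SizeClass::from_size_rounded(128).size(), 128);
    /// ```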
    pub fn from_size_rounded(size: usize) -> Self {
        let rounded = if size == 0 {
            1
        } else {
            size.next_power_of_two()
        };
        Self { size: rounded }
    }

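    /// Returns the element count represented by this size class.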
    pub fn size(&self) -> usize {
        self.size
    }
}

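/// A size-class-based memory pool for GPU buffers.
///
/// Freed buffers are cached per [`SizeClass`] and handed back out on later
/// allocations, avoiding repeated device allocations. The pool also tracks
/// usage statistics: bytes allocated, bytes in use, peak usage, and cache
/// hit rate.
///
/// A minimal usage sketch, mirroring the tests at the bottom of this file
/// (marked `ignore` since the crate path is not fixed here):
///
/// ```ignore
/// let context = Arc::new(GpuContext::new(GpuBackend::Cpu)?);
/// let pool = GpuMemoryPool::<f32>::new(context, 1024 * 1024 * 1024);
///
/// let buffer = pool.allocate(1000)?; // rounds up to 1024 elements
/// assert_eq!(buffer.size_class().size(), 1024);
/// drop(buffer); // recycled into the pool
///
/// let stats = pool.get_statistics();
/// assert_eq!(stats.cached_buffers, 1);
/// ```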
#[cfg(feature = "gpu")]
pub struct GpuMemoryPool<T: GpuDataType> {
    gpu_context: Arc<GpuContext>,
    available_buffers: Arc<Mutex<HashMap<SizeClass, Vec<GpuBuffer<T>>>>>,
    total_allocated: Arc<AtomicU64>,
    total_in_use: Arc<AtomicU64>,
    peak_usage: Arc<AtomicU64>,
    num_allocations: Arc<AtomicU64>,
    num_deallocations: Arc<AtomicU64>,
    cache_hits: Arc<AtomicU64>,
    cache_misses: Arc<AtomicU64>,
    max_pool_size: u64,
    auto_cleanup: bool,
}

#[cfg(feature = "gpu")]
impl<T: GpuDataType> GpuMemoryPool<T> {
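    /// Creates a pool bound to `gpu_context`. `max_pool_size` is a limit in
    /// bytes on total allocations (0 disables the limit); automatic cleanup
    /// of cached buffers is enabled by default.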
    pub fn new(gpu_context: Arc<GpuContext>, max_pool_size: u64) -> Self {
        Self {
            gpu_context,
            available_buffers: Arc::new(Mutex::new(HashMap::new())),
            total_allocated: Arc::new(AtomicU64::new(0)),
            total_in_use: Arc::new(AtomicU64::new(0)),
            peak_usage: Arc::new(AtomicU64::new(0)),
            num_allocations: Arc::new(AtomicU64::new(0)),
            num_deallocations: Arc::new(AtomicU64::new(0)),
            cache_hits: Arc::new(AtomicU64::new(0)),
            cache_misses: Arc::new(AtomicU64::new(0)),
            max_pool_size,
            auto_cleanup: true,
        }
    }

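    /// Allocates a buffer holding at least `size` elements, reusing a cached
    /// buffer of the matching size class when one is available. The returned
    /// [`PooledBuffer`] recycles its underlying buffer back into the pool on
    /// drop.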
    pub fn allocate(&self, size: usize) -> Result<PooledBuffer<T>> {
        self.num_allocations.fetch_add(1, Ordering::Relaxed);

        let size_class = SizeClass::from_size_rounded(size);
        let actual_size = size_class.size();

        let mut buffers = self.available_buffers.lock().map_err(|_| {
            NeuralError::TrainingError("Failed to lock available buffers".to_string())
        })?;

        // Reuse a cached buffer of the same size class if one is available;
        // otherwise release the lock and allocate a fresh buffer.
        let buffer = if let Some(pool) = buffers.get_mut(&size_class) {
            if let Some(buf) = pool.pop() {
                self.cache_hits.fetch_add(1, Ordering::Relaxed);
                buf
            } else {
                self.cache_misses.fetch_add(1, Ordering::Relaxed);
                drop(buffers);
                self.allocate_new_buffer(actual_size)?
            }
        } else {
            self.cache_misses.fetch_add(1, Ordering::Relaxed);
            drop(buffers);
            self.allocate_new_buffer(actual_size)?
        };

        let buffer_size = actual_size * std::mem::size_of::<T>();
        let current_usage = self
            .total_in_use
            .fetch_add(buffer_size as u64, Ordering::Relaxed)
            + buffer_size as u64;

        // Update the peak-usage high-water mark with a CAS loop so that
        // concurrent allocations cannot lose an update.
        let mut peak = self.peak_usage.load(Ordering::Relaxed);
        while current_usage > peak {
            match self.peak_usage.compare_exchange_weak(
                peak,
                current_usage,
                Ordering::Relaxed,
                Ordering::Relaxed,
            ) {
                Ok(_) => break,
                Err(x) => peak = x,
            }
        }

        Ok(PooledBuffer {
            buffer,
            size_class,
            pool: Arc::downgrade(&self.available_buffers),
            total_in_use: Arc::clone(&self.total_in_use),
            num_deallocations: Arc::clone(&self.num_deallocations),
        })
    }

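    /// Allocates a fresh buffer from the GPU context, first enforcing
    /// `max_pool_size` by evicting cached buffers (when `auto_cleanup` is
    /// enabled) or returning an error (when it is not).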
    fn allocate_new_buffer(&self, size: usize) -> Result<GpuBuffer<T>> {
        let buffer_size = size * std::mem::size_of::<T>();

        // Enforce the pool size limit (0 means unlimited) before allocating.
        if self.max_pool_size > 0 {
            let total_allocated = self.total_allocated.load(Ordering::Relaxed);
            if total_allocated + buffer_size as u64 > self.max_pool_size {
                if self.auto_cleanup {
                    self.cleanup_oldest_buffers(buffer_size as u64)?;
                } else {
                    return Err(NeuralError::TrainingError(format!(
                        "Memory pool size limit exceeded: {} + {} > {}",
                        total_allocated, buffer_size, self.max_pool_size
                    )));
                }
            }
        }

        let buffer = self.gpu_context.create_buffer::<T>(size);
        self.total_allocated
            .fetch_add(buffer_size as u64, Ordering::Relaxed);

        Ok(buffer)
    }

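    /// Returns a buffer to the cache for its size class. Currently unused:
    /// [`PooledBuffer`]'s `Drop` impl pushes buffers back into the cache
    /// directly.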
    #[allow(dead_code)]
    fn return_buffer(&self, buffer: GpuBuffer<T>, size_class: SizeClass) -> Result<()> {
        let mut buffers = self.available_buffers.lock().map_err(|_| {
            NeuralError::TrainingError("Failed to lock available buffers".to_string())
        })?;

        buffers
            .entry(size_class)
            .or_insert_with(Vec::new)
            .push(buffer);

        Ok(())
    }

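    /// Evicts cached buffers until at least `required_space` bytes have been
    /// freed. Despite the name, eviction proceeds by size class from largest
    /// to smallest, since large buffers free the most space per eviction.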
    fn cleanup_oldest_buffers(&self, required_space: u64) -> Result<()> {
        let mut buffers = self.available_buffers.lock().map_err(|_| {
            NeuralError::TrainingError("Failed to lock available buffers".to_string())
        })?;

        let mut freed_space = 0u64;

        // Evict the largest size classes first to free space quickly.
        let mut size_classes: Vec<_> = buffers.keys().cloned().collect();
        size_classes.sort_by_key(|sc| std::cmp::Reverse(sc.size));

        for size_class in size_classes {
            if freed_space >= required_space {
                break;
            }

            if let Some(pool) = buffers.get_mut(&size_class) {
                while let Some(buffer) = pool.pop() {
                    let buffer_size = (buffer.len() * std::mem::size_of::<T>()) as u64;
                    freed_space += buffer_size;
                    self.total_allocated
                        .fetch_sub(buffer_size, Ordering::Relaxed);

                    if freed_space >= required_space {
                        break;
                    }
                }
            }
        }

        Ok(())
    }

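    /// Drops every cached buffer and subtracts the freed bytes from the
    /// allocation total. Buffers currently in use are unaffected.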
    pub fn clear(&self) -> Result<()> {
        let mut buffers = self.available_buffers.lock().map_err(|_| {
            NeuralError::TrainingError("Failed to lock available buffers".to_string())
        })?;

        for (size_class, pool) in buffers.drain() {
            let buffer_size = (size_class.size * std::mem::size_of::<T>()) as u64;
            let count = pool.len() as u64;
            self.total_allocated
                .fetch_sub(buffer_size * count, Ordering::Relaxed);
        }

        Ok(())
    }

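    /// Returns a snapshot of the pool's counters. Counters are read with
    /// relaxed ordering, so the snapshot is not atomic across fields.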
    pub fn get_statistics(&self) -> PoolStatistics {
        let buffers = self
            .available_buffers
            .lock()
            .expect("Failed to lock buffers");

        let cached_buffers: usize = buffers.values().map(|pool| pool.len()).sum();

        PoolStatistics {
            total_allocated: self.total_allocated.load(Ordering::Relaxed),
            total_in_use: self.total_in_use.load(Ordering::Relaxed),
            peak_usage: self.peak_usage.load(Ordering::Relaxed),
            cached_buffers,
            num_allocations: self.num_allocations.load(Ordering::Relaxed),
            num_deallocations: self.num_deallocations.load(Ordering::Relaxed),
            cache_hits: self.cache_hits.load(Ordering::Relaxed),
            cache_misses: self.cache_misses.load(Ordering::Relaxed),
            cache_hit_rate: {
                let hits = self.cache_hits.load(Ordering::Relaxed) as f64;
                let total = hits + self.cache_misses.load(Ordering::Relaxed) as f64;
                if total > 0.0 {
                    hits / total
                } else {
                    0.0
                }
            },
        }
    }

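    /// Enables or disables automatic eviction of cached buffers when an
    /// allocation would exceed `max_pool_size`.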
    pub fn set_auto_cleanup(&mut self, enabled: bool) {
        self.auto_cleanup = enabled;
    }

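    /// Returns the pool size limit in bytes (0 means unlimited).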
    pub fn max_pool_size(&self) -> u64 {
        self.max_pool_size
    }

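    /// Sets the pool size limit in bytes (0 disables the limit). The new
    /// limit applies to subsequent allocations only.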
    pub fn set_max_pool_size(&mut self, size: u64) {
        self.max_pool_size = size;
    }
}

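/// A point-in-time snapshot of [`GpuMemoryPool`] counters. All sizes are in
/// bytes; `cache_hit_rate` is `hits / (hits + misses)`, or 0.0 before any
/// allocation.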
#[derive(Debug, Clone)]
pub struct PoolStatistics {
    pub total_allocated: u64,
    pub total_in_use: u64,
    pub peak_usage: u64,
    pub cached_buffers: usize,
    pub num_allocations: u64,
    pub num_deallocations: u64,
    pub cache_hits: u64,
    pub cache_misses: u64,
    pub cache_hit_rate: f64,
}

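/// An RAII handle to a pooled GPU buffer. The handle holds a weak reference
/// to the pool's cache: on drop, the buffer is recycled into the cache if
/// the pool is still alive, and simply released otherwise.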
#[cfg(feature = "gpu")]
pub struct PooledBuffer<T: GpuDataType> {
    buffer: GpuBuffer<T>,
    size_class: SizeClass,
    pool: std::sync::Weak<Mutex<HashMap<SizeClass, Vec<GpuBuffer<T>>>>>,
    total_in_use: Arc<AtomicU64>,
    num_deallocations: Arc<AtomicU64>,
}

#[cfg(feature = "gpu")]
impl<T: GpuDataType> PooledBuffer<T> {
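    /// Returns a reference to the underlying GPU buffer.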
    pub fn buffer(&self) -> &GpuBuffer<T> {
        &self.buffer
    }

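    /// Returns the size class this buffer was allocated under.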
    pub fn size_class(&self) -> SizeClass {
        self.size_class
    }
}

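// Dropping the handle updates the usage counters and recycles the buffer
// into the pool's cache, if the pool still exists.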
#[cfg(feature = "gpu")]
impl<T: GpuDataType> Drop for PooledBuffer<T> {
    fn drop(&mut self) {
        self.num_deallocations.fetch_add(1, Ordering::Relaxed);

        let buffer_size = (self.buffer.len() * std::mem::size_of::<T>()) as u64;
        self.total_in_use.fetch_sub(buffer_size, Ordering::Relaxed);

        if let Some(pool) = self.pool.upgrade() {
            if let Ok(mut buffers) = pool.lock() {
                // `Drop` cannot move out of `self`, so the buffer is cloned;
                // this assumes `GpuBuffer` clones are cheap handle copies that
                // share the same device allocation.
                let recycled = self.buffer.clone();
                buffers
                    .entry(self.size_class)
                    .or_insert_with(Vec::new)
                    .push(recycled);
            }
        }
    }
}

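// Deref lets a `PooledBuffer<T>` be used wherever a `&GpuBuffer<T>` is
// expected, e.g. `buffer.len()` in the tests below.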
#[cfg(feature = "gpu")]
impl<T: GpuDataType> std::ops::Deref for PooledBuffer<T> {
    type Target = GpuBuffer<T>;

    fn deref(&self) -> &Self::Target {
        &self.buffer
    }
}

#[cfg(all(test, feature = "gpu"))]
mod tests {
    use super::*;
    use scirs2_core::gpu::GpuBackend;

    #[test]
    fn test_size_class() {
        let size_class = SizeClass::new(100);
        assert_eq!(size_class.size(), 100);

        let rounded = SizeClass::from_size_rounded(100);
        assert_eq!(rounded.size(), 128);
    }

    #[test]
    fn test_memory_pool_creation() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create context");
        let pool = GpuMemoryPool::<f32>::new(Arc::new(context), 1024 * 1024 * 1024);

        let stats = pool.get_statistics();
        assert_eq!(stats.total_allocated, 0);
        assert_eq!(stats.cached_buffers, 0);
    }

    #[test]
    fn test_buffer_allocation() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create context");
        let pool = GpuMemoryPool::<f32>::new(Arc::new(context), 1024 * 1024 * 1024);

        let buffer = pool.allocate(1000).expect("Failed to allocate");
        assert_eq!(buffer.size_class().size(), 1024);

        let stats = pool.get_statistics();
        assert_eq!(stats.num_allocations, 1);
        assert!(stats.total_in_use > 0);
    }

    #[test]
    fn test_buffer_reuse() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create context");
        let pool = Arc::new(GpuMemoryPool::<f32>::new(
            Arc::new(context),
            1024 * 1024 * 1024,
        ));

        // Allocate a buffer and drop it so it is returned to the pool.
        {
            let _buffer = pool.allocate(1000).expect("Failed to allocate");
        }

        // A second allocation of the same size class should be a cache hit.
        let _buffer2 = pool.allocate(1000).expect("Failed to allocate");

        let stats = pool.get_statistics();
        assert_eq!(stats.num_allocations, 2);
        assert_eq!(stats.cache_hits, 1);
    }

    #[test]
    fn test_pool_statistics() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create context");
        let pool = GpuMemoryPool::<f32>::new(Arc::new(context), 1024 * 1024 * 1024);

        let _buffer1 = pool.allocate(1000).expect("Failed to allocate");
        let _buffer2 = pool.allocate(2000).expect("Failed to allocate");

        let stats = pool.get_statistics();
        assert_eq!(stats.num_allocations, 2);
        assert!(stats.total_in_use > 0);
        assert!(stats.peak_usage >= stats.total_in_use);
    }

    #[test]
    fn test_pool_clear() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create context");
        let pool = GpuMemoryPool::<f32>::new(Arc::new(context), 1024 * 1024 * 1024);

        {
            let _buffer = pool.allocate(1000).expect("Failed to allocate");
        }

        pool.clear().expect("Failed to clear pool");

        let stats = pool.get_statistics();
        assert_eq!(stats.cached_buffers, 0);
    }

    #[test]
    fn test_pooled_buffer_deref() {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create context");
        let pool = GpuMemoryPool::<f32>::new(Arc::new(context), 1024 * 1024 * 1024);

        let buffer = pool.allocate(1000).expect("Failed to allocate");

        // Deref exposes GpuBuffer methods directly on the pooled handle.
        assert_eq!(buffer.len(), 1024);
        assert!(!buffer.is_empty());
    }
}