#[cfg(feature = "opencl")]
use std::cell::UnsafeCell;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Duration;

use crate::gpu::{GpuBufferImpl, GpuCompilerImpl, GpuContextImpl, GpuError, GpuKernelImpl};

#[cfg(feature = "opencl")]
use opencl3::command_queue::{CommandQueue, CL_QUEUE_PROFILING_ENABLE};
#[cfg(feature = "opencl")]
use opencl3::context::Context;
#[cfg(feature = "opencl")]
use opencl3::device::{get_all_devices, Device, CL_DEVICE_TYPE_GPU};
#[cfg(feature = "opencl")]
use opencl3::kernel::{ExecuteKernel, Kernel};
#[cfg(feature = "opencl")]
use opencl3::memory::{Buffer, ClMem, CL_MEM_READ_WRITE};
#[cfg(feature = "opencl")]
use opencl3::platform::get_platforms;
#[cfg(feature = "opencl")]
use opencl3::program::Program;
#[cfg(feature = "opencl")]
use opencl3::types::CL_BLOCKING;

// Opaque stand-ins for the OpenCL handle types, used when the "opencl"
// feature is disabled so the rest of the module still type-checks.
#[cfg(not(feature = "opencl"))]
type CLPlatformId = *mut std::ffi::c_void;
#[cfg(not(feature = "opencl"))]
type CLDeviceId = *mut std::ffi::c_void;
#[cfg(not(feature = "opencl"))]
type CLContext = *mut std::ffi::c_void;
#[cfg(not(feature = "opencl"))]
type CLCommandQueue = *mut std::ffi::c_void;
#[cfg(not(feature = "opencl"))]
type CLProgram = *mut std::ffi::c_void;
#[cfg(not(feature = "opencl"))]
type CLKernel = *mut std::ffi::c_void;
#[cfg(not(feature = "opencl"))]
type CLMem = *mut std::ffi::c_void;

#[allow(dead_code)]
const ADAM_KERNEL_OPENCL: &str = r#"
__kernel void adam_update_f32(
    __global float* params, __global const float* grads, __global float* m, __global float* v,
    const float lr,
    const float beta1,
    const float beta2,
    const float eps,
    const float weight_decay,
    const float bias_correction1,
    const float bias_correction2,
    const int n
) {
    const int idx = get_global_id(0);

    if (idx < n) {
        float grad = grads[idx];

        // Apply weight decay
        if (weight_decay > 0.0f) {
            grad += weight_decay * params[idx];
        }

        // Update biased first moment estimate
        m[idx] = beta1 * m[idx] + (1.0f - beta1) * grad;

        // Update biased second raw moment estimate
        v[idx] = beta2 * v[idx] + (1.0f - beta2) * grad * grad;

        // Compute bias-corrected moment estimates
        float m_hat = m[idx] / bias_correction1;
        float v_hat = v[idx] / bias_correction2;

        // Update parameters
        params[idx] -= lr * m_hat / (sqrt(v_hat) + eps);
    }
}
"#;

#[allow(dead_code)]
const GEMM_KERNEL_OPENCL: &str = r#"
__kernel void gemm_f32(
    __global const float* A, __global const float* B, __global float* C,
    const int M,
    const int N,
    const int K,
    const float alpha,
    const float beta
) {
    const int row = get_global_id(0);
    const int col = get_global_id(1);

    if (row < M && col < N) {
        float sum = 0.0f;
        for (int k = 0; k < K; k++) {
            sum += A[row * K + k] * B[k * N + col];
        }
        C[row * N + col] = alpha * sum + beta * C[row * N + col];
    }
}
"#;

/// OpenCL-backed GPU context. When the "opencl" feature is disabled, the
/// handle fields are inert dummy pointers and every operation is simulated.
pub struct OpenCLContext {
    #[cfg(feature = "opencl")]
    device: Arc<Device>,
    #[cfg(feature = "opencl")]
    context: Arc<Context>,
    #[cfg(feature = "opencl")]
    queue: Arc<CommandQueue>,
    #[cfg(not(feature = "opencl"))]
    device: CLDeviceId,
    #[cfg(not(feature = "opencl"))]
    context: CLContext,
    #[cfg(not(feature = "opencl"))]
    queue: CLCommandQueue,
    compiled_kernels: Arc<Mutex<HashMap<String, OpenCLKernel>>>,
    memory_pool: Arc<Mutex<OpenCLMemoryPool>>,
}

// SAFETY: shared state (kernel cache, memory pool) is guarded by Mutexes;
// the remaining handles are Arc-wrapped opencl3 types or dummy pointers.
unsafe impl Send for OpenCLContext {}
unsafe impl Sync for OpenCLContext {}
impl OpenCLContext {
    pub fn new() -> Result<Self, GpuError> {
        #[cfg(feature = "opencl")]
        {
            let platforms = get_platforms()
                .map_err(|e| GpuError::Other(format!("Failed to get OpenCL platforms: {e}")))?;

            if platforms.is_empty() {
                return Err(GpuError::Other("No OpenCL platforms found".to_string()));
            }

            let device_ids = get_all_devices(CL_DEVICE_TYPE_GPU)
                .map_err(|e| GpuError::Other(format!("Failed to get OpenCL GPU devices: {e}")))?;

            if device_ids.is_empty() {
                return Err(GpuError::Other("No OpenCL GPU devices found".to_string()));
            }

            let device = Device::new(device_ids[0]);
            let context = Context::from_device(&device)
                .map_err(|e| GpuError::Other(format!("Failed to create OpenCL context: {e}")))?;

            let queue =
                CommandQueue::create_default(&context, CL_QUEUE_PROFILING_ENABLE).map_err(|e| {
                    GpuError::Other(format!("Failed to create OpenCL command queue: {e}"))
                })?;

            Ok(Self {
                device: Arc::new(device),
                context: Arc::new(context),
                queue: Arc::new(queue),
                compiled_kernels: Arc::new(Mutex::new(HashMap::new())),
                // Default pool budget: 1 GiB of device memory.
                memory_pool: Arc::new(Mutex::new(OpenCLMemoryPool::new(1024 * 1024 * 1024))),
            })
        }
        #[cfg(not(feature = "opencl"))]
        {
            let device = Self::initialize_opencl()?;
            let context = Self::create_opencl_context(device)?;
            let queue = Self::create_command_queue(context, device)?;

            Ok(Self {
                device,
                context,
                queue,
                compiled_kernels: Arc::new(Mutex::new(HashMap::new())),
                memory_pool: Arc::new(Mutex::new(OpenCLMemoryPool::new(1024 * 1024 * 1024))),
            })
        }
    }

    pub fn is_available() -> bool {
        #[cfg(feature = "opencl")]
        {
            match get_platforms() {
                Ok(platforms) if !platforms.is_empty() => {
                    match get_all_devices(CL_DEVICE_TYPE_GPU) {
                        Ok(devices) => !devices.is_empty(),
                        Err(_) => false,
                    }
                }
                _ => false,
            }
        }
        #[cfg(not(feature = "opencl"))]
        {
            false
        }
    }

    fn compile_kernel_internal(&self, source: &str, name: &str) -> Result<OpenCLKernel, GpuError> {
        #[cfg(feature = "opencl")]
        {
            let program = Program::create_and_build_from_source(&self.context, source, "")
                .map_err(|e| {
                    GpuError::Other(format!("OpenCL kernel compilation failed for {name}: {e}"))
                })?;

            let kernel = Kernel::create(&program, name).map_err(|e| {
                GpuError::Other(format!("Failed to create OpenCL kernel {name}: {e}"))
            })?;

            Ok(OpenCLKernel {
                kernel,
                queue: Arc::clone(&self.queue),
                name: name.to_string(),
            })
        }
        #[cfg(not(feature = "opencl"))]
        {
            let program = Self::compile_opencl_source(source, name)?;
            let kernel = Self::create_kernel_from_program(program, name)?;

            Ok(OpenCLKernel {
                program,
                kernel,
                queue: self.queue,
                name: name.to_string(),
            })
        }
    }

    /// Allocates a raw device buffer of `size` bytes.
    #[cfg(feature = "opencl")]
    pub fn allocate_device_memory(&self, size: usize) -> Result<Buffer<u8>, GpuError> {
        unsafe {
            Buffer::<u8>::create(&self.context, CL_MEM_READ_WRITE, size, std::ptr::null_mut())
                .map_err(|e| GpuError::Other(format!("OpenCL memory allocation failed: {e}")))
        }
    }

    /// Simulated allocation: returns a dummy, size-derived handle.
    #[cfg(not(feature = "opencl"))]
    pub fn allocate_device_memory(&self, size: usize) -> Result<CLMem, GpuError> {
        Ok((0x1000 + size) as CLMem)
    }

    // The helpers below are feature-off stubs that hand back fixed dummy
    // handles so the surrounding code paths stay exercisable without OpenCL.
    #[cfg(not(feature = "opencl"))]
    fn initialize_opencl() -> Result<CLDeviceId, GpuError> {
        Ok(0x1 as CLDeviceId)
    }

    #[cfg(not(feature = "opencl"))]
    fn create_opencl_context(_device: CLDeviceId) -> Result<CLContext, GpuError> {
        Ok(0x2 as CLContext)
    }

    #[cfg(not(feature = "opencl"))]
    fn create_command_queue(
        _context: CLContext,
        _device: CLDeviceId,
    ) -> Result<CLCommandQueue, GpuError> {
        Ok(0x3 as CLCommandQueue)
    }

    #[cfg(not(feature = "opencl"))]
    fn compile_opencl_source(_source: &str, _name: &str) -> Result<CLProgram, GpuError> {
        Ok(0x4 as CLProgram)
    }

    #[cfg(not(feature = "opencl"))]
    fn create_kernel_from_program(_program: CLProgram, _name: &str) -> Result<CLKernel, GpuError> {
        Ok(0x5 as CLKernel)
    }
}
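
// A minimal availability-gated usage sketch (not part of the public API):
// callers probe `is_available` before constructing a context, so machines
// without an OpenCL GPU fall through quietly instead of erroring.
#[allow(dead_code)]
fn try_create_context_example() -> Option<OpenCLContext> {
    if !OpenCLContext::is_available() {
        return None;
    }
    OpenCLContext::new().ok()
}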

impl GpuContextImpl for OpenCLContext {
    fn create_buffer(&self, size: usize) -> Arc<dyn GpuBufferImpl> {
        // Try the memory pool first to reuse a previously freed buffer.
        if let Ok(mut pool) = self.memory_pool.lock() {
            if let Some(device_buffer) = pool.allocate(size) {
                return Arc::new(OpenCLBuffer {
                    #[cfg(feature = "opencl")]
                    device_buffer: UnsafeCell::new(device_buffer),
                    #[cfg(not(feature = "opencl"))]
                    device_buffer,
                    #[cfg(feature = "opencl")]
                    queue: Arc::clone(&self.queue),
                    #[cfg(not(feature = "opencl"))]
                    queue: self.queue,
                    size,
                    memory_pool: Arc::clone(&self.memory_pool),
                });
            }
        }

        // Pool miss: allocate fresh device memory, falling back to a
        // host-side buffer if the device allocation fails.
        let device_buffer = match self.allocate_device_memory(size) {
            Ok(buffer) => buffer,
            Err(e) => {
                eprintln!(
                    "Warning: OpenCL buffer allocation failed ({e}), creating CPU fallback buffer"
                );

                #[cfg(feature = "opencl")]
                {
                    return Arc::new(OpenCLCpuFallbackBuffer {
                        data: vec![0u8; size],
                        size,
                        memory_pool: Arc::clone(&self.memory_pool),
                    });
                }
                #[cfg(not(feature = "opencl"))]
                {
                    (0x2000 + size) as CLMem
                }
            }
        };

        Arc::new(OpenCLBuffer {
            #[cfg(feature = "opencl")]
            device_buffer: UnsafeCell::new(device_buffer),
            #[cfg(not(feature = "opencl"))]
            device_buffer,
            #[cfg(feature = "opencl")]
            queue: Arc::clone(&self.queue),
            #[cfg(not(feature = "opencl"))]
            queue: self.queue,
            size,
            memory_pool: Arc::clone(&self.memory_pool),
        })
    }

    fn create_compiler(&self) -> Arc<dyn GpuCompilerImpl> {
        Arc::new(OpenCLCompiler {
            context: Arc::new(OpenCLContext {
                memory_pool: Arc::clone(&self.memory_pool),
                compiled_kernels: Arc::clone(&self.compiled_kernels),
                #[cfg(feature = "opencl")]
                context: Arc::clone(&self.context),
                #[cfg(feature = "opencl")]
                device: Arc::clone(&self.device),
                #[cfg(feature = "opencl")]
                queue: Arc::clone(&self.queue),
                #[cfg(not(feature = "opencl"))]
                context: self.context,
                #[cfg(not(feature = "opencl"))]
                device: self.device,
                #[cfg(not(feature = "opencl"))]
                queue: self.queue,
            }),
        })
    }
}
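
// Sketch of the intended buffer round trip through `create_buffer`. The
// `ctx` parameter is assumed to be a successfully constructed context; the
// helper name is illustrative, not part of the module's API.
#[allow(dead_code)]
fn buffer_round_trip_example(ctx: &OpenCLContext) -> Vec<f32> {
    let host: Vec<f32> = (0..16).map(|i| i as f32).collect();
    let bytes = std::mem::size_of_val(host.as_slice());
    let buffer = ctx.create_buffer(bytes);
    let mut out = vec![0.0f32; host.len()];
    unsafe {
        // Blocking writes/reads, so the data is valid once each call returns.
        buffer.copy_from_host(host.as_ptr() as *const u8, bytes);
        buffer.copy_to_host(out.as_mut_ptr() as *mut u8, bytes);
    }
    out
}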

struct OpenCLKernel {
    #[cfg(feature = "opencl")]
    kernel: Kernel,
    #[cfg(feature = "opencl")]
    queue: Arc<CommandQueue>,
    #[cfg(not(feature = "opencl"))]
    program: CLProgram,
    #[cfg(not(feature = "opencl"))]
    kernel: CLKernel,
    #[cfg(not(feature = "opencl"))]
    queue: CLCommandQueue,
    #[allow(dead_code)]
    name: String,
}

// SAFETY: kernel access is serialized through the context's Mutex-guarded
// kernel map; the feature-off fields are inert dummy pointers.
unsafe impl Send for OpenCLKernel {}
unsafe impl Sync for OpenCLKernel {}

struct OpenCLCompiler {
    context: Arc<OpenCLContext>,
}

impl GpuCompilerImpl for OpenCLCompiler {
    fn compile(&self, source: &str) -> Result<Arc<dyn GpuKernelImpl>, GpuError> {
        // NOTE: the entry point is assumed to be named "kernel".
        let kernel = self.context.compile_kernel_internal(source, "kernel")?;
        let kernel_name = kernel.name.clone();

        // Cache the compiled kernel so `dispatch` can look it up by name;
        // without this, the handle below would never find its kernel.
        if let Ok(mut kernels) = self.context.compiled_kernels.lock() {
            kernels.insert(kernel_name.clone(), kernel);
        }

        Ok(Arc::new(OpenCLKernelHandle {
            kernel_name,
            compiled_kernels: Arc::clone(&self.context.compiled_kernels),
            params: Arc::new(Mutex::new(HashMap::new())),
        }))
    }

    fn compile_typed(
        &self,
        name: &str,
        _input_type: std::any::TypeId,
        _output_type: std::any::TypeId,
    ) -> Arc<dyn GpuKernelImpl> {
        Arc::new(OpenCLKernelHandle {
            kernel_name: name.to_string(),
            compiled_kernels: Arc::clone(&self.context.compiled_kernels),
            params: Arc::new(Mutex::new(HashMap::new())),
        })
    }
}

struct OpenCLKernelHandle {
    kernel_name: String,
    compiled_kernels: Arc<Mutex<HashMap<String, OpenCLKernel>>>,
    params: Arc<Mutex<HashMap<String, KernelParam>>>,
}

/// A kernel argument staged by name before dispatch.
enum KernelParam {
    Buffer(Arc<dyn GpuBufferImpl>),
    U32(u32),
    I32(i32),
    F32(f32),
    F64(f64),
}

impl GpuKernelImpl for OpenCLKernelHandle {
    fn set_buffer(&self, name: &str, buffer: &Arc<dyn GpuBufferImpl>) {
        let mut params = self.params.lock().expect("kernel params mutex poisoned");
        params.insert(name.to_string(), KernelParam::Buffer(Arc::clone(buffer)));
    }

    fn set_u32(&self, name: &str, value: u32) {
        let mut params = self.params.lock().expect("kernel params mutex poisoned");
        params.insert(name.to_string(), KernelParam::U32(value));
    }

    fn set_i32(&self, name: &str, value: i32) {
        let mut params = self.params.lock().expect("kernel params mutex poisoned");
        params.insert(name.to_string(), KernelParam::I32(value));
    }

    fn set_f32(&self, name: &str, value: f32) {
        let mut params = self.params.lock().expect("kernel params mutex poisoned");
        params.insert(name.to_string(), KernelParam::F32(value));
    }

    fn set_f64(&self, name: &str, value: f64) {
        let mut params = self.params.lock().expect("kernel params mutex poisoned");
        params.insert(name.to_string(), KernelParam::F64(value));
    }

    fn dispatch(&self, workgroups: [u32; 3]) {
        #[cfg(feature = "opencl")]
        {
            let kernels = self
                .compiled_kernels
                .lock()
                .expect("kernel map mutex poisoned");
            if let Some(kernel) = kernels.get(&self.kernel_name) {
                let params = self.params.lock().expect("kernel params mutex poisoned");

                let mut execute_kernel = ExecuteKernel::new(&kernel.kernel);
                // CAUTION: arguments are applied in HashMap iteration order,
                // which is not guaranteed to match the kernel signature.
                for param in params.values() {
                    match param {
                        KernelParam::Buffer(_buffer) => {
                            // Binding a buffer requires downcasting to the
                            // concrete OpenCLBuffer to reach its cl_mem;
                            // left as a stub here.
                        }
                        KernelParam::U32(val) => {
                            unsafe { execute_kernel.set_arg(val) };
                        }
                        KernelParam::I32(val) => {
                            unsafe { execute_kernel.set_arg(val) };
                        }
                        KernelParam::F32(val) => {
                            unsafe { execute_kernel.set_arg(val) };
                        }
                        KernelParam::F64(val) => {
                            unsafe { execute_kernel.set_arg(val) };
                        }
                    }
                }

                if let Err(e) = unsafe {
                    execute_kernel
                        .set_global_work_size(workgroups[0] as usize)
                        .set_local_work_size(64)
                        .enqueue_nd_range(&kernel.queue)
                } {
                    eprintln!("OpenCL kernel dispatch failed for {}: {e}", self.kernel_name);
                }
            }
        }
        #[cfg(not(feature = "opencl"))]
        {
            eprintln!("Executing OpenCL kernel {} (simulated)", self.kernel_name);
            eprintln!("Work groups: {workgroups:?}");
        }
    }
}
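
// Sketch of driving a compiled kernel through the handle API. The parameter
// names ("lr", "n") and the 64-wide rounding are illustrative assumptions
// that mirror the local work size used in `dispatch` above; buffer binding
// is still a stub there (see the note in the Buffer match arm).
#[allow(dead_code)]
fn dispatch_example(kernel: &Arc<dyn GpuKernelImpl>, n: i32) {
    kernel.set_f32("lr", 1e-3);
    kernel.set_i32("n", n);
    // One thread per element, rounded up to the 64-wide local size.
    let global = ((n as u32) + 63) / 64 * 64;
    kernel.dispatch([global, 1, 1]);
}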

struct OpenCLBuffer {
    #[cfg(feature = "opencl")]
    device_buffer: UnsafeCell<Buffer<u8>>,
    #[cfg(feature = "opencl")]
    queue: Arc<CommandQueue>,
    #[cfg(not(feature = "opencl"))]
    device_buffer: CLMem,
    #[cfg(not(feature = "opencl"))]
    queue: CLCommandQueue,
    size: usize,
    memory_pool: Arc<Mutex<OpenCLMemoryPool>>,
}

// SAFETY: the UnsafeCell is only touched through blocking queue operations,
// which serialize device access; feature-off fields are dummy pointers.
unsafe impl Send for OpenCLBuffer {}
unsafe impl Sync for OpenCLBuffer {}

impl GpuBufferImpl for OpenCLBuffer {
    fn size(&self) -> usize {
        self.size
    }

    unsafe fn copy_from_host(&self, data: *const u8, size: usize) {
        #[cfg(feature = "opencl")]
        {
            if size > self.size {
                eprintln!(
                    "Warning: copy_from_host size {size} exceeds buffer size {}",
                    self.size
                );
                return;
            }

            let data_slice = std::slice::from_raw_parts(data, size);

            if let Err(e) = self.queue.enqueue_write_buffer(
                unsafe { &mut *self.device_buffer.get() },
                CL_BLOCKING,
                0,
                data_slice,
                &[],
            ) {
                eprintln!("Warning: OpenCL host-to-device copy failed: {e}");
            }
        }
        #[cfg(not(feature = "opencl"))]
        {
            let _ = (data, size);
        }
    }

    unsafe fn copy_to_host(&self, data: *mut u8, size: usize) {
        #[cfg(feature = "opencl")]
        {
            if size > self.size {
                eprintln!(
                    "Warning: copy_to_host size {size} exceeds buffer size {}",
                    self.size
                );
                return;
            }

            let data_slice = std::slice::from_raw_parts_mut(data, size);

            if let Err(e) = self.queue.enqueue_read_buffer(
                unsafe { &*self.device_buffer.get() },
                CL_BLOCKING,
                0,
                data_slice,
                &[],
            ) {
                eprintln!("Warning: OpenCL device-to-host copy failed: {e}");
            }
        }
        #[cfg(not(feature = "opencl"))]
        {
            let _ = (data, size);
        }
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}

impl Drop for OpenCLBuffer {
    fn drop(&mut self) {
        if let Ok(mut pool) = self.memory_pool.lock() {
            #[cfg(feature = "opencl")]
            {
                // `Buffer<u8>` releases its cl_mem automatically on drop, so
                // there is nothing to return to the pool here.
                let _ = &mut pool;
            }
            #[cfg(not(feature = "opencl"))]
            {
                pool.deallocate(self.device_buffer);
            }
        }
    }
}

/// Host-memory stand-in used when device allocation fails.
struct OpenCLCpuFallbackBuffer {
    data: Vec<u8>,
    size: usize,
    #[allow(dead_code)]
    memory_pool: Arc<Mutex<OpenCLMemoryPool>>,
}

impl GpuBufferImpl for OpenCLCpuFallbackBuffer {
    fn size(&self) -> usize {
        self.size
    }

    unsafe fn copy_from_host(&self, data: *const u8, size: usize) {
        if size > self.size {
            eprintln!("Warning: OpenCL CPU fallback buffer copy_from_host size mismatch");
            return;
        }

        // Without interior mutability the host bytes cannot be stored through
        // `&self`; this path only validates the pointer and reports the call.
        let _ = std::slice::from_raw_parts(data, size);
        eprintln!("Warning: CPU fallback buffer copy_from_host called (size: {size})");
    }

    unsafe fn copy_to_host(&self, data: *mut u8, size: usize) {
        if size > self.size {
            eprintln!("Warning: OpenCL CPU fallback buffer copy_to_host size mismatch");
            return;
        }

        let data_slice = std::slice::from_raw_parts_mut(data, size);
        let copy_size = size.min(self.data.len());
        data_slice[..copy_size].copy_from_slice(&self.data[..copy_size]);

        eprintln!("Warning: CPU fallback buffer copy_to_host called (size: {size})");
    }

    fn device_ptr(&self) -> u64 {
        self.data.as_ptr() as u64
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}

// SAFETY: the fallback buffer is plain host memory with no device handles;
// sharing it across threads is sound for the read-only access done here.
unsafe impl Send for OpenCLCpuFallbackBuffer {}
unsafe impl Sync for OpenCLCpuFallbackBuffer {}

/// Size-class based pool of reusable device buffers with simple
/// pressure-driven cleanup and per-class allocation statistics.
struct OpenCLMemoryPool {
    #[cfg(feature = "opencl")]
    available_buffers: HashMap<usize, Vec<Buffer<u8>>>,
    #[cfg(not(feature = "opencl"))]
    available_buffers: HashMap<usize, Vec<CLMem>>,

    /// Power-of-two size classes, capped by the pool budget.
    size_classes: Vec<usize>,
    allocation_stats: HashMap<usize, AllocationStats>,
    memory_pressure_threshold: f64,
    total_size: usize,
    used_size: usize,
    peak_used_size: usize,
    fragmentation_ratio: f64,

    /// Sliding window of (size, timestamp) pairs for hot-size tracking.
    recent_allocations: std::collections::VecDeque<(usize, std::time::Instant)>,
    hot_sizes: std::collections::BTreeSet<usize>,
}

#[derive(Debug, Clone)]
pub struct AllocationStats {
    total_allocations: u64,
    total_deallocations: u64,
    total_bytes_allocated: u64,
    #[allow(dead_code)]
    average_lifetime: Duration,
    peak_concurrent_allocations: u64,
    current_allocations: u64,
}

#[derive(Debug, Clone)]
pub struct PoolStatistics {
    pub total_size: usize,
    pub used_size: usize,
    pub peak_used_size: usize,
    pub fragmentation_ratio: f64,
    pub available_buffer_count: usize,
    pub hot_size_classes: usize,
    pub allocation_stats: HashMap<usize, AllocationStats>,
}

impl OpenCLMemoryPool {
    fn new(total_size: usize) -> Self {
        // Power-of-two size classes (1 B .. 2 GiB), capped by the pool budget.
        let size_classes = (0..32)
            .map(|i| 1usize << i)
            .filter(|&size| size <= total_size)
            .collect();

        Self {
            available_buffers: HashMap::new(),
            size_classes,
            allocation_stats: HashMap::new(),
            memory_pressure_threshold: 0.85,
            total_size,
            used_size: 0,
            peak_used_size: 0,
            fragmentation_ratio: 0.0,
            recent_allocations: std::collections::VecDeque::with_capacity(1000),
            hot_sizes: std::collections::BTreeSet::new(),
        }
    }

    /// Rounds a request up to the smallest size class that fits, falling
    /// back to 4 KiB page rounding for oversized requests.
    fn get_size_class(&self, requested_size: usize) -> usize {
        self.size_classes
            .iter()
            .find(|&&class_size| class_size >= requested_size)
            .copied()
            .unwrap_or_else(|| ((requested_size + 4095) / 4096) * 4096)
    }

    /// Under pressure, drops cold cached buffers and defragments the pool.
    fn update_memory_pressure(&mut self) {
        let pressure = self.used_size as f64 / self.total_size as f64;

        if pressure > self.memory_pressure_threshold {
            self.cleanup_cold_buffers();
            self.defragment_if_needed();
        }

        let total_available = self
            .available_buffers
            .values()
            .map(|buffers| buffers.len())
            .sum::<usize>();

        if total_available > 0 {
            // Rough fragmentation estimate: treats each cached buffer as
            // roughly 1 KiB of unusable space.
            self.fragmentation_ratio =
                1.0 - (self.used_size as f64 / (self.used_size + total_available * 1024) as f64);
        }
    }

    fn cleanup_cold_buffers(&mut self) {
        let now = std::time::Instant::now();
        let cold_threshold = Duration::from_secs(30);

        // Expire allocation records older than the cold threshold.
        while let Some(&(_, timestamp)) = self.recent_allocations.front() {
            if now.duration_since(timestamp) > cold_threshold {
                self.recent_allocations.pop_front();
            } else {
                break;
            }
        }

        // Recompute the hot set from the remaining recent allocations.
        self.hot_sizes.clear();
        for &(size, _) in &self.recent_allocations {
            self.hot_sizes.insert(self.get_size_class(size));
        }

        // Trim cold size classes down to at most two cached buffers each.
        for (size_class, buffers) in &mut self.available_buffers {
            if !self.hot_sizes.contains(size_class) && buffers.len() > 2 {
                let excess = buffers.len() - 2;
                for _ in 0..excess {
                    buffers.pop();
                }
            }
        }
    }

    fn defragment_if_needed(&mut self) {
        if self.fragmentation_ratio > 0.3 {
            // Halve oversized free lists to release cached memory.
            for buffers in self.available_buffers.values_mut() {
                if buffers.len() > 4 {
                    buffers.truncate(buffers.len() / 2);
                }
            }
        }
    }

    fn update_allocation_stats(&mut self, size_class: usize, allocated: bool) {
        let stats = self
            .allocation_stats
            .entry(size_class)
            .or_insert_with(|| AllocationStats {
                total_allocations: 0,
                total_deallocations: 0,
                total_bytes_allocated: 0,
                average_lifetime: Duration::new(0, 0),
                peak_concurrent_allocations: 0,
                current_allocations: 0,
            });

        if allocated {
            stats.total_allocations += 1;
            stats.total_bytes_allocated += size_class as u64;
            stats.current_allocations += 1;
            stats.peak_concurrent_allocations = stats
                .peak_concurrent_allocations
                .max(stats.current_allocations);
        } else {
            stats.total_deallocations += 1;
            stats.current_allocations = stats.current_allocations.saturating_sub(1);
        }
    }

    #[allow(dead_code)]
    fn get_pool_statistics(&self) -> PoolStatistics {
        PoolStatistics {
            total_size: self.total_size,
            used_size: self.used_size,
            peak_used_size: self.peak_used_size,
            fragmentation_ratio: self.fragmentation_ratio,
            available_buffer_count: self.available_buffers.values().map(|v| v.len()).sum(),
            hot_size_classes: self.hot_sizes.len(),
            allocation_stats: self.allocation_stats.clone(),
        }
    }

    #[cfg(feature = "opencl")]
    fn allocate(&mut self, size: usize) -> Option<Buffer<u8>> {
        let size_class = self.get_size_class(size);

        if let Some(buffers) = self.available_buffers.get_mut(&size_class) {
            if let Some(buffer) = buffers.pop() {
                self.used_size += size_class;
                self.peak_used_size = self.peak_used_size.max(self.used_size);

                // Record the allocation for hot-size tracking.
                self.recent_allocations
                    .push_back((size, std::time::Instant::now()));
                if self.recent_allocations.len() > 1000 {
                    self.recent_allocations.pop_front();
                }
                self.hot_sizes.insert(size_class);

                self.update_allocation_stats(size_class, true);
                self.update_memory_pressure();

                return Some(buffer);
            }
        }
        None
    }

    #[cfg(not(feature = "opencl"))]
    fn allocate(&mut self, size: usize) -> Option<CLMem> {
        if let Some(buffers) = self.available_buffers.get_mut(&size) {
            if let Some(buffer) = buffers.pop() {
                self.used_size += size;
                return Some(buffer);
            }
        }
        None
    }

    #[cfg(feature = "opencl")]
    #[allow(dead_code)]
    fn deallocate(&mut self, buffer: Buffer<u8>) {
        let size = buffer.size().unwrap_or(0);
        self.available_buffers.entry(size).or_default().push(buffer);
        self.used_size = self.used_size.saturating_sub(size);
    }

    #[cfg(not(feature = "opencl"))]
    #[allow(dead_code)]
    fn deallocate(&mut self, buffer: CLMem) {
        // Dummy handles carry no size information; assume a 1 KiB class.
        let size = 1024;
        self.available_buffers.entry(size).or_default().push(buffer);
        self.used_size = self.used_size.saturating_sub(size);
    }

    #[allow(dead_code)]
    fn get_memory_usage(&self) -> (usize, usize) {
        (self.used_size, self.total_size)
    }
}
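
#[cfg(test)]
mod pool_tests {
    use super::*;

    // Sketch: size classes are powers of two capped by the pool budget, and
    // out-of-range requests fall back to 4 KiB page rounding (see
    // `get_size_class` above). Runs without any OpenCL device present.
    #[test]
    fn size_classes_round_up() {
        let pool = OpenCLMemoryPool::new(1024 * 1024);
        assert_eq!(pool.get_size_class(1), 1);
        assert_eq!(pool.get_size_class(3), 4);
        assert_eq!(pool.get_size_class(1000), 1024);
        // Larger than every power-of-two class in a 1 MiB pool: falls back
        // to rounding up to a 4096-byte boundary.
        assert_eq!(pool.get_size_class(2 * 1024 * 1024), 2 * 1024 * 1024);
    }
}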