1use crate::error::{ScirsError, ScirsResult};
7
8use scirs2_core::gpu::{GpuBuffer, GpuContext};
/// GPU-resident array storage; currently a plain alias of the core buffer type.
pub type OptimGpuArray<T> = GpuBuffer<T>;
/// GPU buffer type used by the optimizer memory pool (alias of the core type).
pub type OptimGpuBuffer<T> = GpuBuffer<T>;
13use std::collections::{HashMap, VecDeque};
14use std::sync::{Arc, Mutex};
15
/// Snapshot of device memory occupancy, in bytes.
#[derive(Debug, Clone)]
pub struct GpuMemoryInfo {
    /// Total device memory.
    pub total: usize,
    /// Memory currently available for allocation.
    pub free: usize,
    /// Memory currently in use.
    pub used: usize,
}
23
/// Size-class pooling allocator for GPU memory.
///
/// Freed blocks are cached per exact byte size and reused for later requests
/// of the same size; an optional global byte limit triggers garbage
/// collection of cached blocks when it would be exceeded.
pub struct GpuMemoryPool {
    /// GPU context used to create fresh device buffers.
    context: Arc<GpuContext>,
    /// Free lists of cached blocks, keyed by exact block size in bytes.
    pools: Arc<Mutex<HashMap<usize, VecDeque<GpuMemoryBlock>>>>,
    /// Bookkeeping for live allocations.
    /// NOTE(review): never written to anywhere in this file — verify callers
    /// elsewhere before relying on its contents.
    allocated_blocks: Arc<Mutex<Vec<AllocatedBlock>>>,
    /// Optional cap (bytes) on total allocated memory; `None` = unlimited.
    memory_limit: Option<usize>,
    /// Total bytes currently attributed to this pool (live + cached blocks).
    current_usage: Arc<Mutex<usize>>,
    /// Allocation / reuse / garbage-collection counters.
    allocation_stats: Arc<Mutex<AllocationStats>>,
}
33
34impl GpuMemoryPool {
35 pub fn new(context: Arc<GpuContext>, memory_limit: Option<usize>) -> ScirsResult<Self> {
37 Ok(Self {
38 context,
39 pools: Arc::new(Mutex::new(HashMap::new())),
40 allocated_blocks: Arc::new(Mutex::new(Vec::new())),
41 memory_limit,
42 current_usage: Arc::new(Mutex::new(0)),
43 allocation_stats: Arc::new(Mutex::new(AllocationStats::new())),
44 })
45 }
46
47 pub fn new_stub() -> Self {
49 use scirs2_core::gpu::GpuBackend;
50 let context = GpuContext::new(GpuBackend::Cpu).expect("CPU backend should always work");
51 Self {
52 context: Arc::new(context),
53 pools: Arc::new(Mutex::new(HashMap::new())),
54 allocated_blocks: Arc::new(Mutex::new(Vec::new())),
55 memory_limit: None,
56 current_usage: Arc::new(Mutex::new(0)),
57 allocation_stats: Arc::new(Mutex::new(AllocationStats::new())),
58 }
59 }
60
61 pub fn allocate_workspace(&mut self, size: usize) -> ScirsResult<GpuWorkspace> {
63 let block = self.allocate_block(size)?;
64 Ok(GpuWorkspace::new(block, Arc::clone(&self.pools)))
65 }
66
67 fn allocate_block(&mut self, size: usize) -> ScirsResult<GpuMemoryBlock> {
69 let mut stats = self.allocation_stats.lock().unwrap();
70 stats.total_allocations += 1;
71
72 if let Some(limit) = self.memory_limit {
74 let current = *self.current_usage.lock().unwrap();
75 if current + size > limit {
76 drop(stats);
78 self.garbage_collect()?;
80 stats = self.allocation_stats.lock().unwrap();
82 let current = *self.current_usage.lock().unwrap();
83 if current + size > limit {
84 return Err(ScirsError::MemoryError(
85 scirs2_core::error::ErrorContext::new(format!(
86 "Would exceed memory limit: {} + {} > {}",
87 current, size, limit
88 ))
89 .with_location(scirs2_core::error::ErrorLocation::new(file!(), line!())),
90 ));
91 }
92 }
93 }
94
95 let mut pools = self.pools.lock().unwrap();
97 if let Some(pool) = pools.get_mut(&size) {
98 if let Some(block) = pool.pop_front() {
99 stats.pool_hits += 1;
100 return Ok(block);
101 }
102 }
103
104 stats.new_allocations += 1;
106 let gpu_buffer = self.context.create_buffer::<u8>(size);
107 let ptr = std::ptr::null_mut();
108 let block = GpuMemoryBlock {
109 size,
110 ptr,
111 gpu_buffer: Some(gpu_buffer),
112 };
113
114 *self.current_usage.lock().unwrap() += size;
116
117 Ok(block)
118 }
119
120 fn return_block(&self, block: GpuMemoryBlock) {
122 let mut pools = self.pools.lock().unwrap();
123 pools.entry(block.size).or_default().push_back(block);
124 }
125
126 fn garbage_collect(&mut self) -> ScirsResult<()> {
128 let mut pools = self.pools.lock().unwrap();
129 let mut freed_memory = 0;
130
131 for (size, pool) in pools.iter_mut() {
133 let count = pool.len();
134 freed_memory += size * count;
135 pool.clear();
136 }
137
138 *self.current_usage.lock().unwrap() = self
140 .current_usage
141 .lock()
142 .unwrap()
143 .saturating_sub(freed_memory);
144
145 let mut stats = self.allocation_stats.lock().unwrap();
147 stats.garbage_collections += 1;
148 stats.total_freed_memory += freed_memory;
149
150 Ok(())
151 }
152
153 pub fn memory_stats(&self) -> MemoryStats {
155 let current_usage = *self.current_usage.lock().unwrap();
156 let allocation_stats = self.allocation_stats.lock().unwrap().clone();
157 let pool_sizes: HashMap<usize, usize> = self
158 .pools
159 .lock()
160 .unwrap()
161 .iter()
162 .map(|(&size, pool)| (size, pool.len()))
163 .collect();
164
165 MemoryStats {
166 current_usage,
167 memory_limit: self.memory_limit,
168 allocation_stats,
169 pool_sizes,
170 }
171 }
172}
173
/// A single pooled allocation: its byte size, a raw pointer slot, and the
/// backing GPU buffer (if one was created).
pub struct GpuMemoryBlock {
    /// Block size in bytes.
    size: usize,
    /// Raw host pointer slot.
    /// NOTE(review): only ever set to null in this file and never
    /// dereferenced here — confirm intended use before relying on it.
    ptr: *mut u8,
    /// Backing device buffer, when one exists for this block.
    gpu_buffer: Option<OptimGpuBuffer<u8>>,
}
180
181impl std::fmt::Debug for GpuMemoryBlock {
182 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183 f.debug_struct("GpuMemoryBlock")
184 .field("size", &self.size)
185 .field("ptr", &self.ptr)
186 .field("gpu_buffer", &self.gpu_buffer.is_some())
187 .finish()
188 }
189}
190
// SAFETY: NOTE(review) — these impls assert that `GpuMemoryBlock` may be sent
// and shared across threads even though it holds a raw `*mut u8` and a GPU
// buffer. Within this file `ptr` is always null and never dereferenced, but
// whether `GpuBuffer<u8>` itself is thread-safe cannot be confirmed from this
// file — verify against `scirs2_core::gpu` before relying on these impls.
unsafe impl Send for GpuMemoryBlock {}
unsafe impl Sync for GpuMemoryBlock {}
193
194impl GpuMemoryBlock {
195 pub fn size(&self) -> usize {
197 self.size
198 }
199
200 pub fn ptr(&self) -> *mut u8 {
202 self.ptr
203 }
204
205 pub fn as_typed<T: scirs2_core::GpuDataType>(&self) -> ScirsResult<&OptimGpuBuffer<T>> {
207 if let Some(ref buffer) = self.gpu_buffer {
208 Err(ScirsError::ComputationError(
211 scirs2_core::error::ErrorContext::new("Type casting not supported".to_string()),
212 ))
213 } else {
214 Err(ScirsError::InvalidInput(
215 scirs2_core::error::ErrorContext::new("Memory block not available".to_string()),
216 ))
217 }
218 }
219}
220
/// A set of memory blocks checked out from a [`GpuMemoryPool`].
///
/// On drop, every block is pushed back into the originating pool's free
/// lists so it can be reused by later allocations.
pub struct GpuWorkspace {
    /// Blocks currently owned by this workspace.
    blocks: Vec<GpuMemoryBlock>,
    /// Shared free lists of the pool that created this workspace.
    pool: Arc<Mutex<HashMap<usize, VecDeque<GpuMemoryBlock>>>>,
}
226
227impl GpuWorkspace {
228 fn new(
229 initial_block: GpuMemoryBlock,
230 pool: Arc<Mutex<HashMap<usize, VecDeque<GpuMemoryBlock>>>>,
231 ) -> Self {
232 Self {
233 blocks: vec![initial_block],
234 pool,
235 }
236 }
237
238 pub fn get_block(&mut self, size: usize) -> ScirsResult<&GpuMemoryBlock> {
240 for block in &self.blocks {
242 if block.size >= size {
243 return Ok(block);
244 }
245 }
246
247 Err(ScirsError::MemoryError(
251 scirs2_core::error::ErrorContext::new("No suitable block available".to_string()),
252 ))
253 }
254
255 pub fn get_buffer<T: scirs2_core::GpuDataType>(
257 &mut self,
258 size: usize,
259 ) -> ScirsResult<&OptimGpuBuffer<T>> {
260 let size_bytes = size * std::mem::size_of::<T>();
261 let block = self.get_block(size_bytes)?;
262 block.as_typed::<T>()
263 }
264
265 pub fn create_array<T>(&mut self, dimensions: &[usize]) -> ScirsResult<OptimGpuArray<T>>
267 where
268 T: Clone + Default + 'static + scirs2_core::GpuDataType,
269 {
270 let total_elements: usize = dimensions.iter().product();
271 let buffer = self.get_buffer::<T>(total_elements)?;
272
273 Err(ScirsError::ComputationError(
276 scirs2_core::error::ErrorContext::new("Array creation not supported".to_string()),
277 ))
278 }
279
280 pub fn total_size(&self) -> usize {
282 self.blocks.iter().map(|b| b.size).sum()
283 }
284}
285
286impl Drop for GpuWorkspace {
287 fn drop(&mut self) {
288 let mut pool = self.pool.lock().unwrap();
290 for block in self.blocks.drain(..) {
291 pool.entry(block.size).or_default().push_back(block);
292 }
293 }
294}
295
/// Bookkeeping record for a live allocation.
/// NOTE(review): stored in `GpuMemoryPool::allocated_blocks`, but no code in
/// this file ever constructs one — confirm whether it is used elsewhere
/// before removing.
#[derive(Debug)]
struct AllocatedBlock {
    /// Allocation size in bytes.
    size: usize,
    /// When the block was handed out.
    allocated_at: std::time::Instant,
}
302
/// Counters describing how the pool has served allocation requests.
#[derive(Debug, Clone)]
pub struct AllocationStats {
    /// Total number of allocation requests received.
    pub total_allocations: u64,
    /// Requests satisfied by reusing a cached block.
    pub pool_hits: u64,
    /// Requests that required a fresh GPU allocation.
    pub new_allocations: u64,
    /// Garbage-collection passes performed.
    pub garbage_collections: u64,
    /// Cumulative bytes released by garbage collection.
    pub total_freed_memory: usize,
}

impl AllocationStats {
    /// Returns a record with every counter at zero.
    fn new() -> Self {
        Self {
            total_allocations: 0,
            pool_hits: 0,
            new_allocations: 0,
            garbage_collections: 0,
            total_freed_memory: 0,
        }
    }

    /// Fraction of requests served from the pool, in `[0.0, 1.0]`.
    /// Returns 0.0 when no allocations have been recorded yet.
    pub fn hit_rate(&self) -> f64 {
        match self.total_allocations {
            0 => 0.0,
            total => self.pool_hits as f64 / total as f64,
        }
    }
}
338
/// Point-in-time snapshot of a pool's memory accounting.
#[derive(Debug, Clone)]
pub struct MemoryStats {
    /// Bytes currently attributed to the pool.
    pub current_usage: usize,
    /// Configured cap in bytes, if any.
    pub memory_limit: Option<usize>,
    /// Allocation / reuse / GC counters at snapshot time.
    pub allocation_stats: AllocationStats,
    /// Number of cached blocks per size class (bytes -> block count).
    pub pool_sizes: HashMap<usize, usize>,
}
351
impl MemoryStats {
    /// Fraction of the configured limit currently in use, or `None` when no
    /// limit is set. A zero limit yields 0.0 rather than dividing by zero.
    pub fn utilization(&self) -> Option<f64> {
        self.memory_limit.map(|limit| {
            if limit == 0 {
                0.0
            } else {
                self.current_usage as f64 / limit as f64
            }
        })
    }

    /// Renders a human-readable multi-line report covering usage, the limit
    /// (when set), allocation counters and per-size-class pool occupancy.
    pub fn generate_report(&self) -> String {
        let mut report = String::from("GPU Memory Usage Report\n");
        report.push_str("=======================\n\n");

        report.push_str(&format!(
            "Current Usage: {} bytes ({:.2} MB)\n",
            self.current_usage,
            self.current_usage as f64 / 1024.0 / 1024.0
        ));

        if let Some(limit) = self.memory_limit {
            report.push_str(&format!(
                "Memory Limit: {} bytes ({:.2} MB)\n",
                limit,
                limit as f64 / 1024.0 / 1024.0
            ));

            // Utilization is only meaningful when a limit exists.
            if let Some(util) = self.utilization() {
                report.push_str(&format!("Utilization: {:.1}%\n", util * 100.0));
            }
        }

        report.push('\n');
        report.push_str("Allocation Statistics:\n");
        report.push_str(&format!(
            " Total Allocations: {}\n",
            self.allocation_stats.total_allocations
        ));
        report.push_str(&format!(
            " Pool Hits: {} ({:.1}%)\n",
            self.allocation_stats.pool_hits,
            self.allocation_stats.hit_rate() * 100.0
        ));
        report.push_str(&format!(
            " New Allocations: {}\n",
            self.allocation_stats.new_allocations
        ));
        report.push_str(&format!(
            " Garbage Collections: {}\n",
            self.allocation_stats.garbage_collections
        ));
        report.push_str(&format!(
            " Total Freed: {} bytes\n",
            self.allocation_stats.total_freed_memory
        ));

        if !self.pool_sizes.is_empty() {
            report.push('\n');
            report.push_str("Memory Pools:\n");
            // Sort size classes ascending for a stable, readable listing.
            let mut pools: Vec<_> = self.pool_sizes.iter().collect();
            pools.sort_by_key(|&(size_, _)| size_);
            for (&size, &count) in pools {
                report.push_str(&format!(" {} bytes: {} blocks\n", size, count));
            }
        }

        report
    }
}
424
425pub mod optimization {
427 use super::*;
428
429 #[derive(Debug, Clone)]
431 pub struct MemoryOptimizationConfig {
432 pub target_utilization: f64,
434 pub max_pool_size: usize,
436 pub gc_threshold: f64,
438 pub use_prefetching: bool,
440 }
441
442 impl Default for MemoryOptimizationConfig {
443 fn default() -> Self {
444 Self {
445 target_utilization: 0.8,
446 max_pool_size: 100,
447 gc_threshold: 0.9,
448 use_prefetching: true,
449 }
450 }
451 }
452
    /// Periodically inspects a shared [`GpuMemoryPool`] and records
    /// optimization decisions based on the configured thresholds.
    pub struct MemoryOptimizer {
        /// Thresholds controlling when optimization actions fire.
        config: MemoryOptimizationConfig,
        /// Pool under observation (shared, read-only access via `&self` methods).
        pool: Arc<GpuMemoryPool>,
        /// Counters for actions taken by this optimizer.
        optimization_stats: OptimizationStats,
    }
459
460 impl MemoryOptimizer {
461 pub fn new(config: MemoryOptimizationConfig, pool: Arc<GpuMemoryPool>) -> Self {
463 Self {
464 config,
465 pool,
466 optimization_stats: OptimizationStats::new(),
467 }
468 }
469
470 pub fn optimize(&mut self) -> ScirsResult<()> {
472 let stats = self.pool.memory_stats();
473
474 if let Some(utilization) = stats.utilization() {
476 if utilization > self.config.gc_threshold {
477 self.perform_garbage_collection()?;
478 self.optimization_stats.gc_triggered += 1;
479 }
480 }
481
482 self.optimize_pool_sizes(&stats)?;
484
485 Ok(())
486 }
487
488 fn perform_garbage_collection(&mut self) -> ScirsResult<()> {
490 self.optimization_stats.gc_operations += 1;
493 Ok(())
494 }
495
496 fn optimize_pool_sizes(&mut self, stats: &MemoryStats) -> ScirsResult<()> {
498 for (&_size, &count) in &stats.pool_sizes {
500 if count > self.config.max_pool_size {
501 self.optimization_stats.pool_optimizations += 1;
503 }
504 }
505 Ok(())
506 }
507
508 pub fn stats(&self) -> &OptimizationStats {
510 &self.optimization_stats
511 }
512 }
513
514 #[derive(Debug, Clone)]
516 pub struct OptimizationStats {
517 pub gc_triggered: u64,
519 pub gc_operations: u64,
521 pub pool_optimizations: u64,
523 }
524
525 impl OptimizationStats {
526 fn new() -> Self {
527 Self {
528 gc_triggered: 0,
529 gc_operations: 0,
530 pool_optimizations: 0,
531 }
532 }
533 }
534}
535
536pub mod utils {
538 use super::*;
539
540 pub fn calculate_allocation_strategy(
542 problem_size: usize,
543 batch_size: usize,
544 available_memory: usize,
545 ) -> AllocationStrategy {
546 let estimated_usage = estimate_memory_usage(problem_size, batch_size);
547
548 if estimated_usage > available_memory {
549 AllocationStrategy::Chunked {
550 chunk_size: available_memory / 2,
551 overlap: true,
552 }
553 } else if estimated_usage > available_memory / 2 {
554 AllocationStrategy::Conservative {
555 pool_size_limit: available_memory / 4,
556 }
557 } else {
558 AllocationStrategy::Aggressive {
559 prefetch_size: estimated_usage * 2,
560 }
561 }
562 }
563
564 pub fn estimate_memory_usage(_problem_size: usize, batch_size: usize) -> usize {
566 let input_size = batch_size * _problem_size * 8; let output_size = batch_size * 8; let temp_size = input_size; input_size + output_size + temp_size
572 }
573
574 #[derive(Debug, Clone)]
576 pub enum AllocationStrategy {
577 Chunked { chunk_size: usize, overlap: bool },
579 Conservative { pool_size_limit: usize },
581 Aggressive { prefetch_size: usize },
583 }
584
585 pub fn check_memory_availability(
587 required_memory: usize,
588 memory_info: &GpuMemoryInfo,
589 ) -> ScirsResult<bool> {
590 let available = memory_info.free;
591 let safety_margin = 0.1; let usable = (available as f64 * (1.0 - safety_margin)) as usize;
593
594 Ok(required_memory <= usable)
595 }
596}
597
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_allocation_stats() {
        // 70 hits out of 100 requests -> 0.7 hit rate.
        let mut recorded = AllocationStats::new();
        recorded.total_allocations = 100;
        recorded.pool_hits = 70;
        assert_eq!(recorded.hit_rate(), 0.7);
    }

    #[test]
    fn test_memory_stats_utilization() {
        // 800 of 1000 bytes used -> 80% utilization.
        let snapshot = MemoryStats {
            current_usage: 800,
            memory_limit: Some(1000),
            allocation_stats: AllocationStats::new(),
            pool_sizes: HashMap::new(),
        };
        assert_eq!(snapshot.utilization(), Some(0.8));
    }

    #[test]
    fn test_memory_usage_estimation() {
        // Estimates are positive and grow with problem/batch size.
        let small = utils::estimate_memory_usage(10, 100);
        let large = utils::estimate_memory_usage(20, 200);
        assert!(small > 0);
        assert!(large > small);
    }

    #[test]
    fn test_allocation_strategy() {
        // A footprint far beyond available memory must select chunking.
        let strategy = utils::calculate_allocation_strategy(1000, 1000, 500_000);
        assert!(
            matches!(strategy, utils::AllocationStrategy::Chunked { .. }),
            "Expected chunked strategy for large problem"
        );
    }
}