1use crate::error::{ScirsError, ScirsResult};
7
8use scirs2_core::gpu::{GpuBuffer, GpuContext};
11pub type OptimGpuArray<T> = GpuBuffer<T>;
12pub type OptimGpuBuffer<T> = GpuBuffer<T>;
13use std::collections::{HashMap, VecDeque};
14use std::sync::{Arc, Mutex};
15
/// Snapshot of device memory occupancy, all values in bytes.
#[derive(Debug, Clone)]
pub struct GpuMemoryInfo {
    /// Total device memory.
    pub total: usize,
    /// Memory currently free.
    pub free: usize,
    /// Memory currently in use.
    pub used: usize,
}
23
/// A size-bucketed pool of reusable GPU memory blocks.
///
/// Freed blocks are cached per exact byte size in `pools` and handed back
/// out on later requests of the same size, avoiding repeated device
/// allocations.
pub struct GpuMemoryPool {
    // GPU context used to create new device buffers.
    context: Arc<GpuContext>,
    // Free lists of reusable blocks, keyed by block size in bytes.
    pools: Arc<Mutex<HashMap<usize, VecDeque<GpuMemoryBlock>>>>,
    // NOTE(review): never populated by the code visible in this file —
    // candidate for removal or wiring up; confirm against other callers.
    allocated_blocks: Arc<Mutex<Vec<AllocatedBlock>>>,
    // Optional cap on total pool memory, in bytes. None = unlimited.
    memory_limit: Option<usize>,
    // Bytes currently accounted as allocated by this pool.
    current_usage: Arc<Mutex<usize>>,
    // Counters for hits, misses and garbage collections.
    allocation_stats: Arc<Mutex<AllocationStats>>,
}
33
34impl GpuMemoryPool {
35 pub fn new(context: Arc<GpuContext>, memory_limit: Option<usize>) -> ScirsResult<Self> {
37 Ok(Self {
38 context,
39 pools: Arc::new(Mutex::new(HashMap::new())),
40 allocated_blocks: Arc::new(Mutex::new(Vec::new())),
41 memory_limit,
42 current_usage: Arc::new(Mutex::new(0)),
43 allocation_stats: Arc::new(Mutex::new(AllocationStats::new())),
44 })
45 }
46
47 pub fn new_stub() -> Self {
49 use scirs2_core::gpu::GpuBackend;
50 let context = GpuContext::new(GpuBackend::Cpu).expect("CPU backend should always work");
51 Self {
52 context: Arc::new(context),
53 pools: Arc::new(Mutex::new(HashMap::new())),
54 allocated_blocks: Arc::new(Mutex::new(Vec::new())),
55 memory_limit: None,
56 current_usage: Arc::new(Mutex::new(0)),
57 allocation_stats: Arc::new(Mutex::new(AllocationStats::new())),
58 }
59 }
60
61 pub fn allocate_workspace(&mut self, size: usize) -> ScirsResult<GpuWorkspace> {
63 let block = self.allocate_block(size)?;
64 Ok(GpuWorkspace::new(block, Arc::clone(&self.pools)))
65 }
66
67 fn allocate_block(&mut self, size: usize) -> ScirsResult<GpuMemoryBlock> {
69 let mut stats = self.allocation_stats.lock().unwrap();
70 stats.total_allocations += 1;
71
72 if let Some(limit) = self.memory_limit {
74 let current = *self.current_usage.lock().unwrap();
75 if current + size > limit {
76 drop(stats);
78 self.garbage_collect()?;
80 stats = self.allocation_stats.lock().unwrap();
82 let current = *self.current_usage.lock().unwrap();
83 if current + size > limit {
84 return Err(ScirsError::MemoryError(
85 scirs2_core::error::ErrorContext::new(format!(
86 "Would exceed memory limit: {} + {} > {}",
87 current, size, limit
88 ))
89 .with_location(scirs2_core::error::ErrorLocation::new(file!(), line!())),
90 ));
91 }
92 }
93 }
94
95 let mut pools = self.pools.lock().unwrap();
97 if let Some(pool) = pools.get_mut(&size) {
98 if let Some(block) = pool.pop_front() {
99 stats.pool_hits += 1;
100 return Ok(block);
101 }
102 }
103
104 stats.new_allocations += 1;
106 let gpu_buffer = self.context.create_buffer::<u8>(size);
107 let ptr = std::ptr::null_mut();
108 let block = GpuMemoryBlock {
109 size,
110 ptr,
111 gpu_buffer: Some(gpu_buffer),
112 };
113
114 *self.current_usage.lock().unwrap() += size;
116
117 Ok(block)
118 }
119
120 fn return_block(&self, block: GpuMemoryBlock) {
122 let mut pools = self.pools.lock().unwrap();
123 pools
124 .entry(block.size)
125 .or_insert_with(VecDeque::new)
126 .push_back(block);
127 }
128
129 fn garbage_collect(&mut self) -> ScirsResult<()> {
131 let mut pools = self.pools.lock().unwrap();
132 let mut freed_memory = 0;
133
134 for (size, pool) in pools.iter_mut() {
136 let count = pool.len();
137 freed_memory += size * count;
138 pool.clear();
139 }
140
141 *self.current_usage.lock().unwrap() = self
143 .current_usage
144 .lock()
145 .unwrap()
146 .saturating_sub(freed_memory);
147
148 let mut stats = self.allocation_stats.lock().unwrap();
150 stats.garbage_collections += 1;
151 stats.total_freed_memory += freed_memory;
152
153 Ok(())
154 }
155
156 pub fn memory_stats(&self) -> MemoryStats {
158 let current_usage = *self.current_usage.lock().unwrap();
159 let allocation_stats = self.allocation_stats.lock().unwrap().clone();
160 let pool_sizes: HashMap<usize, usize> = self
161 .pools
162 .lock()
163 .unwrap()
164 .iter()
165 .map(|(&size, pool)| (size, pool.len()))
166 .collect();
167
168 MemoryStats {
169 current_usage,
170 memory_limit: self.memory_limit,
171 allocation_stats,
172 pool_sizes,
173 }
174 }
175}
176
/// A single allocation tracked by the pool.
pub struct GpuMemoryBlock {
    // Block size in bytes.
    size: usize,
    // Raw host pointer; always set to null by the code visible here — the
    // real storage lives in `gpu_buffer`.
    ptr: *mut u8,
    // Backing device buffer, if one was created.
    gpu_buffer: Option<OptimGpuBuffer<u8>>,
}
183
184impl std::fmt::Debug for GpuMemoryBlock {
185 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
186 f.debug_struct("GpuMemoryBlock")
187 .field("size", &self.size)
188 .field("ptr", &self.ptr)
189 .field("gpu_buffer", &self.gpu_buffer.is_some())
190 .finish()
191 }
192}
193
// SAFETY: NOTE(review) — `ptr` is only ever set to null in the code visible
// here and is never dereferenced, and blocks are passed around inside
// mutex-guarded structures. These impls additionally assume the backing
// GPU buffer is safe to move/share across threads — confirm that before
// relying on them.
unsafe impl Send for GpuMemoryBlock {}
unsafe impl Sync for GpuMemoryBlock {}
196
197impl GpuMemoryBlock {
198 pub fn size(&self) -> usize {
200 self.size
201 }
202
203 pub fn ptr(&self) -> *mut u8 {
205 self.ptr
206 }
207
208 pub fn as_typed<T: scirs2_core::GpuDataType>(&self) -> ScirsResult<&OptimGpuBuffer<T>> {
210 if let Some(ref buffer) = self.gpu_buffer {
211 Err(ScirsError::ComputationError(
214 scirs2_core::error::ErrorContext::new("Type casting not supported".to_string()),
215 ))
216 } else {
217 Err(ScirsError::InvalidInput(
218 scirs2_core::error::ErrorContext::new("Memory block not available".to_string()),
219 ))
220 }
221 }
222}
223
/// A temporary set of memory blocks checked out from a `GpuMemoryPool`.
///
/// All held blocks are returned to the shared pool when the workspace is
/// dropped (see the `Drop` impl below).
pub struct GpuWorkspace {
    // Blocks currently owned by this workspace.
    blocks: Vec<GpuMemoryBlock>,
    // Shared free lists the blocks are returned to on drop.
    pool: Arc<Mutex<HashMap<usize, VecDeque<GpuMemoryBlock>>>>,
}
229
230impl GpuWorkspace {
231 fn new(
232 initial_block: GpuMemoryBlock,
233 pool: Arc<Mutex<HashMap<usize, VecDeque<GpuMemoryBlock>>>>,
234 ) -> Self {
235 Self {
236 blocks: vec![initial_block],
237 pool,
238 }
239 }
240
241 pub fn get_block(&mut self, size: usize) -> ScirsResult<&GpuMemoryBlock> {
243 for block in &self.blocks {
245 if block.size >= size {
246 return Ok(block);
247 }
248 }
249
250 Err(ScirsError::MemoryError(
254 scirs2_core::error::ErrorContext::new("No suitable block available".to_string()),
255 ))
256 }
257
258 pub fn get_buffer<T: scirs2_core::GpuDataType>(
260 &mut self,
261 size: usize,
262 ) -> ScirsResult<&OptimGpuBuffer<T>> {
263 let size_bytes = size * std::mem::size_of::<T>();
264 let block = self.get_block(size_bytes)?;
265 block.as_typed::<T>()
266 }
267
268 pub fn create_array<T>(&mut self, dimensions: &[usize]) -> ScirsResult<OptimGpuArray<T>>
270 where
271 T: Clone + Default + 'static + scirs2_core::GpuDataType,
272 {
273 let total_elements: usize = dimensions.iter().product();
274 let buffer = self.get_buffer::<T>(total_elements)?;
275
276 Err(ScirsError::ComputationError(
279 scirs2_core::error::ErrorContext::new("Array creation not supported".to_string()),
280 ))
281 }
282
283 pub fn total_size(&self) -> usize {
285 self.blocks.iter().map(|b| b.size).sum()
286 }
287}
288
289impl Drop for GpuWorkspace {
290 fn drop(&mut self) {
291 let mut pool = self.pool.lock().unwrap();
293 for block in self.blocks.drain(..) {
294 pool.entry(block.size)
295 .or_insert_with(VecDeque::new)
296 .push_back(block);
297 }
298 }
299}
300
/// Bookkeeping record for a live allocation.
// NOTE(review): never constructed in the code visible here, and
// `allocated_at` is never read — dead code candidate; confirm before removal.
#[derive(Debug)]
struct AllocatedBlock {
    // Size of the allocation in bytes.
    size: usize,
    // When the allocation was made.
    allocated_at: std::time::Instant,
}
307
/// Counters describing how the pool has served allocation requests.
#[derive(Debug, Clone)]
pub struct AllocationStats {
    /// Total number of allocation requests seen by the pool.
    pub total_allocations: u64,
    /// Requests served by reusing a cached block.
    pub pool_hits: u64,
    /// Requests that required a fresh device allocation.
    pub new_allocations: u64,
    /// Number of garbage-collection passes performed.
    pub garbage_collections: u64,
    /// Total bytes released by garbage collection.
    pub total_freed_memory: usize,
}

impl AllocationStats {
    /// Creates a zeroed statistics record.
    fn new() -> Self {
        Self {
            total_allocations: 0,
            pool_hits: 0,
            new_allocations: 0,
            garbage_collections: 0,
            total_freed_memory: 0,
        }
    }

    /// Fraction of requests served from the pool; 0.0 before any request.
    pub fn hit_rate(&self) -> f64 {
        match self.total_allocations {
            0 => 0.0,
            total => self.pool_hits as f64 / total as f64,
        }
    }
}
343
344#[derive(Debug, Clone)]
346pub struct MemoryStats {
347 pub current_usage: usize,
349 pub memory_limit: Option<usize>,
351 pub allocation_stats: AllocationStats,
353 pub pool_sizes: HashMap<usize, usize>,
355}
356
357impl MemoryStats {
358 pub fn utilization(&self) -> Option<f64> {
360 self.memory_limit.map(|limit| {
361 if limit == 0 {
362 0.0
363 } else {
364 self.current_usage as f64 / limit as f64
365 }
366 })
367 }
368
369 pub fn generate_report(&self) -> String {
371 let mut report = String::from("GPU Memory Usage Report\n");
372 report.push_str("=======================\n\n");
373
374 report.push_str(&format!(
375 "Current Usage: {} bytes ({:.2} MB)\n",
376 self.current_usage,
377 self.current_usage as f64 / 1024.0 / 1024.0
378 ));
379
380 if let Some(limit) = self.memory_limit {
381 report.push_str(&format!(
382 "Memory Limit: {} bytes ({:.2} MB)\n",
383 limit,
384 limit as f64 / 1024.0 / 1024.0
385 ));
386
387 if let Some(util) = self.utilization() {
388 report.push_str(&format!("Utilization: {:.1}%\n", util * 100.0));
389 }
390 }
391
392 report.push('\n');
393 report.push_str("Allocation Statistics:\n");
394 report.push_str(&format!(
395 " Total Allocations: {}\n",
396 self.allocation_stats.total_allocations
397 ));
398 report.push_str(&format!(
399 " Pool Hits: {} ({:.1}%)\n",
400 self.allocation_stats.pool_hits,
401 self.allocation_stats.hit_rate() * 100.0
402 ));
403 report.push_str(&format!(
404 " New Allocations: {}\n",
405 self.allocation_stats.new_allocations
406 ));
407 report.push_str(&format!(
408 " Garbage Collections: {}\n",
409 self.allocation_stats.garbage_collections
410 ));
411 report.push_str(&format!(
412 " Total Freed: {} bytes\n",
413 self.allocation_stats.total_freed_memory
414 ));
415
416 if !self.pool_sizes.is_empty() {
417 report.push('\n');
418 report.push_str("Memory Pools:\n");
419 let mut pools: Vec<_> = self.pool_sizes.iter().collect();
420 pools.sort_by_key(|&(size_, _)| size_);
421 for (&size, &count) in pools {
422 report.push_str(&format!(" {} bytes: {} blocks\n", size, count));
423 }
424 }
425
426 report
427 }
428}
429
430pub mod optimization {
432 use super::*;
433
434 #[derive(Debug, Clone)]
436 pub struct MemoryOptimizationConfig {
437 pub target_utilization: f64,
439 pub max_pool_size: usize,
441 pub gc_threshold: f64,
443 pub use_prefetching: bool,
445 }
446
447 impl Default for MemoryOptimizationConfig {
448 fn default() -> Self {
449 Self {
450 target_utilization: 0.8,
451 max_pool_size: 100,
452 gc_threshold: 0.9,
453 use_prefetching: true,
454 }
455 }
456 }
457
458 pub struct MemoryOptimizer {
460 config: MemoryOptimizationConfig,
461 pool: Arc<GpuMemoryPool>,
462 optimization_stats: OptimizationStats,
463 }
464
465 impl MemoryOptimizer {
466 pub fn new(config: MemoryOptimizationConfig, pool: Arc<GpuMemoryPool>) -> Self {
468 Self {
469 config,
470 pool,
471 optimization_stats: OptimizationStats::new(),
472 }
473 }
474
475 pub fn optimize(&mut self) -> ScirsResult<()> {
477 let stats = self.pool.memory_stats();
478
479 if let Some(utilization) = stats.utilization() {
481 if utilization > self.config.gc_threshold {
482 self.perform_garbage_collection()?;
483 self.optimization_stats.gc_triggered += 1;
484 }
485 }
486
487 self.optimize_pool_sizes(&stats)?;
489
490 Ok(())
491 }
492
493 fn perform_garbage_collection(&mut self) -> ScirsResult<()> {
495 self.optimization_stats.gc_operations += 1;
498 Ok(())
499 }
500
501 fn optimize_pool_sizes(&mut self, stats: &MemoryStats) -> ScirsResult<()> {
503 for (&_size, &count) in &stats.pool_sizes {
505 if count > self.config.max_pool_size {
506 self.optimization_stats.pool_optimizations += 1;
508 }
509 }
510 Ok(())
511 }
512
513 pub fn stats(&self) -> &OptimizationStats {
515 &self.optimization_stats
516 }
517 }
518
519 #[derive(Debug, Clone)]
521 pub struct OptimizationStats {
522 pub gc_triggered: u64,
524 pub gc_operations: u64,
526 pub pool_optimizations: u64,
528 }
529
530 impl OptimizationStats {
531 fn new() -> Self {
532 Self {
533 gc_triggered: 0,
534 gc_operations: 0,
535 pool_optimizations: 0,
536 }
537 }
538 }
539}
540
541pub mod utils {
543 use super::*;
544
545 pub fn calculate_allocation_strategy(
547 problem_size: usize,
548 batch_size: usize,
549 available_memory: usize,
550 ) -> AllocationStrategy {
551 let estimated_usage = estimate_memory_usage(problem_size, batch_size);
552
553 if estimated_usage > available_memory {
554 AllocationStrategy::Chunked {
555 chunk_size: available_memory / 2,
556 overlap: true,
557 }
558 } else if estimated_usage > available_memory / 2 {
559 AllocationStrategy::Conservative {
560 pool_size_limit: available_memory / 4,
561 }
562 } else {
563 AllocationStrategy::Aggressive {
564 prefetch_size: estimated_usage * 2,
565 }
566 }
567 }
568
569 pub fn estimate_memory_usage(_problem_size: usize, batch_size: usize) -> usize {
571 let input_size = batch_size * _problem_size * 8; let output_size = batch_size * 8; let temp_size = input_size; input_size + output_size + temp_size
577 }
578
579 #[derive(Debug, Clone)]
581 pub enum AllocationStrategy {
582 Chunked { chunk_size: usize, overlap: bool },
584 Conservative { pool_size_limit: usize },
586 Aggressive { prefetch_size: usize },
588 }
589
590 pub fn check_memory_availability(
592 required_memory: usize,
593 memory_info: &GpuMemoryInfo,
594 ) -> ScirsResult<bool> {
595 let available = memory_info.free;
596 let safety_margin = 0.1; let usable = (available as f64 * (1.0 - safety_margin)) as usize;
598
599 Ok(required_memory <= usable)
600 }
601}
602
#[cfg(test)]
mod tests {
    use super::*;

    // hit_rate must equal pool_hits / total_allocations.
    #[test]
    fn test_allocation_stats() {
        let mut stats = AllocationStats::new();
        stats.total_allocations = 100;
        stats.pool_hits = 70;

        assert_eq!(stats.hit_rate(), 0.7);
    }

    // utilization = current_usage / memory_limit when a limit is set.
    #[test]
    fn test_memory_stats_utilization() {
        let stats = MemoryStats {
            current_usage: 800,
            memory_limit: Some(1000),
            allocation_stats: AllocationStats::new(),
            pool_sizes: HashMap::new(),
        };

        assert_eq!(stats.utilization(), Some(0.8));
    }

    // Estimates are positive and grow with problem/batch size.
    #[test]
    fn test_memory_usage_estimation() {
        let usage = utils::estimate_memory_usage(10, 100);
        assert!(usage > 0);

        let larger_usage = utils::estimate_memory_usage(20, 200);
        assert!(larger_usage > usage);
    }

    // A footprint far larger than available memory (1000x1000x8 bytes of
    // input alone vs 500 KB available) must yield the chunked strategy.
    #[test]
    fn test_allocation_strategy() {
        let strategy = utils::calculate_allocation_strategy(
            1000, 1000, 500_000, );

        match strategy {
            utils::AllocationStrategy::Chunked { .. } => {
                // expected
            }
            _ => panic!("Expected chunked strategy for large problem"),
        }
    }
}