use crate::error::{ScirsError, ScirsResult};

use scirs2_core::gpu::{GpuBuffer, GpuContext};

pub type OptimGpuArray<T> = GpuBuffer<T>;
pub type OptimGpuBuffer<T> = GpuBuffer<T>;

use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};

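/// Snapshot of GPU memory occupancy, in bytes.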
#[derive(Debug, Clone)]
pub struct GpuMemoryInfo {
    pub total: usize,
    pub free: usize,
    pub used: usize,
}

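/// Size-bucketed pool of reusable GPU memory blocks with an optional byte limit.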
pub struct GpuMemoryPool {
    context: Arc<GpuContext>,
    pools: Arc<Mutex<HashMap<usize, VecDeque<GpuMemoryBlock>>>>,
    allocated_blocks: Arc<Mutex<Vec<AllocatedBlock>>>,
    memory_limit: Option<usize>,
    current_usage: Arc<Mutex<usize>>,
    allocation_stats: Arc<Mutex<AllocationStats>>,
}

impl GpuMemoryPool {
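    /// Create a pool bound to `context`, optionally capped at `memory_limit` bytes.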
    pub fn new(context: Arc<GpuContext>, memory_limit: Option<usize>) -> ScirsResult<Self> {
        Ok(Self {
            context,
            pools: Arc::new(Mutex::new(HashMap::new())),
            allocated_blocks: Arc::new(Mutex::new(Vec::new())),
            memory_limit,
            current_usage: Arc::new(Mutex::new(0)),
            allocation_stats: Arc::new(Mutex::new(AllocationStats::new())),
        })
    }

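    /// Create a stand-alone pool backed by the CPU backend, with no memory limit.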
    pub fn new_stub() -> Self {
        use scirs2_core::gpu::GpuBackend;
        let context = GpuContext::new(GpuBackend::Cpu).expect("CPU backend should always work");
        Self {
            context: Arc::new(context),
            pools: Arc::new(Mutex::new(HashMap::new())),
            allocated_blocks: Arc::new(Mutex::new(Vec::new())),
            memory_limit: None,
            current_usage: Arc::new(Mutex::new(0)),
            allocation_stats: Arc::new(Mutex::new(AllocationStats::new())),
        }
    }

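    /// Allocate a workspace holding a single block of `size` bytes.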
    pub fn allocate_workspace(&mut self, size: usize) -> ScirsResult<GpuWorkspace> {
        let block = self.allocate_block(size)?;
        Ok(GpuWorkspace::new(block, Arc::clone(&self.pools)))
    }

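    /// Obtain a block of exactly `size` bytes, reusing a pooled block when one is
    /// cached and garbage collecting once if the allocation would exceed the limit.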
    fn allocate_block(&mut self, size: usize) -> ScirsResult<GpuMemoryBlock> {
        let mut stats = self
            .allocation_stats
            .lock()
            .expect("allocation stats mutex poisoned");
        stats.total_allocations += 1;

        // Enforce the optional memory limit, garbage collecting once before failing.
        if let Some(limit) = self.memory_limit {
            let current = *self.current_usage.lock().expect("usage counter mutex poisoned");
            if current + size > limit {
                drop(stats);
                self.garbage_collect()?;
                stats = self
                    .allocation_stats
                    .lock()
                    .expect("allocation stats mutex poisoned");
                let current = *self.current_usage.lock().expect("usage counter mutex poisoned");
                if current + size > limit {
                    return Err(ScirsError::MemoryError(
                        scirs2_core::error::ErrorContext::new(format!(
                            "Would exceed memory limit: {} + {} > {}",
                            current, size, limit
                        ))
                        .with_location(scirs2_core::error::ErrorLocation::new(file!(), line!())),
                    ));
                }
            }
        }

        // Reuse a pooled block of the exact size if one is cached.
        let mut pools = self.pools.lock().expect("pool mutex poisoned");
        if let Some(pool) = pools.get_mut(&size) {
            if let Some(block) = pool.pop_front() {
                stats.pool_hits += 1;
                return Ok(block);
            }
        }

        // Otherwise allocate a fresh device buffer of the requested size.
        stats.new_allocations += 1;
        let gpu_buffer = self.context.create_buffer::<u8>(size);
        let ptr = std::ptr::null_mut();
        let block = GpuMemoryBlock {
            size,
            ptr,
            gpu_buffer: Some(gpu_buffer),
        };

        *self.current_usage.lock().expect("usage counter mutex poisoned") += size;

        Ok(block)
    }

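    /// Return a block to the size-keyed free list for later reuse.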
    fn return_block(&self, block: GpuMemoryBlock) {
        let mut pools = self.pools.lock().expect("pool mutex poisoned");
        pools.entry(block.size).or_default().push_back(block);
    }

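    /// Free every cached block, then update the usage counter and GC statistics.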
    fn garbage_collect(&mut self) -> ScirsResult<()> {
        let mut pools = self.pools.lock().expect("pool mutex poisoned");
        let mut freed_memory = 0;

        // Drop every cached free block and tally the bytes released.
        for (size, pool) in pools.iter_mut() {
            let count = pool.len();
            freed_memory += size * count;
            pool.clear();
        }

        // Lock the usage counter once; locking it twice within a single statement
        // (once to read, once to write) would deadlock on a non-reentrant mutex.
        let mut usage = self.current_usage.lock().expect("usage counter mutex poisoned");
        let new_usage = usage.saturating_sub(freed_memory);
        *usage = new_usage;
        drop(usage);

        let mut stats = self
            .allocation_stats
            .lock()
            .expect("allocation stats mutex poisoned");
        stats.garbage_collections += 1;
        stats.total_freed_memory += freed_memory;

        Ok(())
    }

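    /// Snapshot current usage, the configured limit, and per-size pool occupancy.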
    pub fn memory_stats(&self) -> MemoryStats {
        let current_usage = *self.current_usage.lock().expect("usage counter mutex poisoned");
        let allocation_stats = self
            .allocation_stats
            .lock()
            .expect("allocation stats mutex poisoned")
            .clone();
        let pool_sizes: HashMap<usize, usize> = self
            .pools
            .lock()
            .expect("pool mutex poisoned")
            .iter()
            .map(|(&size, pool)| (size, pool.len()))
            .collect();

        MemoryStats {
            current_usage,
            memory_limit: self.memory_limit,
            allocation_stats,
            pool_sizes,
        }
    }
}

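/// A fixed-size region of GPU memory that can be cached by the pool or held by a workspace.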
pub struct GpuMemoryBlock {
    size: usize,
    ptr: *mut u8,
    gpu_buffer: Option<OptimGpuBuffer<u8>>,
}

impl std::fmt::Debug for GpuMemoryBlock {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GpuMemoryBlock")
            .field("size", &self.size)
            .field("ptr", &self.ptr)
            .field("gpu_buffer", &self.gpu_buffer.is_some())
            .finish()
    }
}

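// SAFETY: `ptr` is never dereferenced in this module; blocks are only moved
// between the pool's queues and the workspaces that own them.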
unsafe impl Send for GpuMemoryBlock {}
unsafe impl Sync for GpuMemoryBlock {}

impl GpuMemoryBlock {
    /// Size of the block in bytes.
    pub fn size(&self) -> usize {
        self.size
    }

    /// Raw pointer associated with the block (always null in the current implementation).
    pub fn ptr(&self) -> *mut u8 {
        self.ptr
    }

    /// Reinterpret the backing buffer as elements of type `T`.
    ///
    /// Currently unsupported: this always returns an error.
    pub fn as_typed<T: scirs2_core::GpuDataType>(&self) -> ScirsResult<&OptimGpuBuffer<T>> {
        if self.gpu_buffer.is_some() {
            Err(ScirsError::ComputationError(
                scirs2_core::error::ErrorContext::new(
                    "Type casting of GPU memory blocks is not supported".to_string(),
                ),
            ))
        } else {
            Err(ScirsError::InvalidInput(
                scirs2_core::error::ErrorContext::new(
                    "Memory block has no backing GPU buffer".to_string(),
                ),
            ))
        }
    }
}

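/// A set of memory blocks checked out from a [`GpuMemoryPool`]; every block is
/// returned to the pool when the workspace is dropped.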
pub struct GpuWorkspace {
    blocks: Vec<GpuMemoryBlock>,
    pool: Arc<Mutex<HashMap<usize, VecDeque<GpuMemoryBlock>>>>,
}

impl GpuWorkspace {
    fn new(
        initial_block: GpuMemoryBlock,
        pool: Arc<Mutex<HashMap<usize, VecDeque<GpuMemoryBlock>>>>,
    ) -> Self {
        Self {
            blocks: vec![initial_block],
            pool,
        }
    }

    /// Borrow the first owned block that is at least `size` bytes.
    pub fn get_block(&mut self, size: usize) -> ScirsResult<&GpuMemoryBlock> {
        for block in &self.blocks {
            if block.size >= size {
                return Ok(block);
            }
        }

        Err(ScirsError::MemoryError(
            scirs2_core::error::ErrorContext::new("No suitable block available".to_string()),
        ))
    }

    /// Borrow a typed buffer of `size` elements from a suitable block.
    pub fn get_buffer<T: scirs2_core::GpuDataType>(
        &mut self,
        size: usize,
    ) -> ScirsResult<&OptimGpuBuffer<T>> {
        let size_bytes = size * std::mem::size_of::<T>();
        let block = self.get_block(size_bytes)?;
        block.as_typed::<T>()
    }

    /// Create a GPU array with the given dimensions from workspace memory.
    ///
    /// Currently unsupported: this always returns an error.
    pub fn create_array<T>(&mut self, dimensions: &[usize]) -> ScirsResult<OptimGpuArray<T>>
    where
        T: Clone + Default + 'static + scirs2_core::GpuDataType,
    {
        let total_elements: usize = dimensions.iter().product();
        let _buffer = self.get_buffer::<T>(total_elements)?;

        Err(ScirsError::ComputationError(
            scirs2_core::error::ErrorContext::new(
                "Array creation from workspace memory is not supported".to_string(),
            ),
        ))
    }

    /// Total size in bytes of all blocks held by the workspace.
    pub fn total_size(&self) -> usize {
        self.blocks.iter().map(|b| b.size).sum()
    }
}

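// Blocks are handed back to the shared pool when the workspace goes out of scope.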
impl Drop for GpuWorkspace {
    fn drop(&mut self) {
        let mut pool = self.pool.lock().expect("pool mutex poisoned");
        for block in self.blocks.drain(..) {
            pool.entry(block.size).or_default().push_back(block);
        }
    }
}

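/// Bookkeeping record for a single outstanding allocation.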
#[derive(Debug)]
struct AllocatedBlock {
    size: usize,
    allocated_at: std::time::Instant,
}

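/// Counters describing how the pool has serviced allocation requests.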
#[derive(Debug, Clone)]
pub struct AllocationStats {
    /// Total number of allocation requests.
    pub total_allocations: u64,
    /// Requests served from a cached pool block.
    pub pool_hits: u64,
    /// Requests that required a new device allocation.
    pub new_allocations: u64,
    /// Number of garbage collection passes performed.
    pub garbage_collections: u64,
    /// Total bytes released by garbage collection.
    pub total_freed_memory: usize,
}

impl AllocationStats {
    fn new() -> Self {
        Self {
            total_allocations: 0,
            pool_hits: 0,
            new_allocations: 0,
            garbage_collections: 0,
            total_freed_memory: 0,
        }
    }

    /// Fraction of allocation requests served from the pool.
    pub fn hit_rate(&self) -> f64 {
        if self.total_allocations == 0 {
            0.0
        } else {
            self.pool_hits as f64 / self.total_allocations as f64
        }
    }
}

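/// Point-in-time view of pool memory usage returned by [`GpuMemoryPool::memory_stats`].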
#[derive(Debug, Clone)]
pub struct MemoryStats {
    /// Bytes currently allocated from the device.
    pub current_usage: usize,
    /// Optional upper bound on total usage, in bytes.
    pub memory_limit: Option<usize>,
    /// Allocation counters accumulated since the pool was created.
    pub allocation_stats: AllocationStats,
    /// Number of cached free blocks per block size.
    pub pool_sizes: HashMap<usize, usize>,
}

impl MemoryStats {
    /// Current usage as a fraction of the memory limit, if a limit is set.
    pub fn utilization(&self) -> Option<f64> {
        self.memory_limit.map(|limit| {
            if limit == 0 {
                0.0
            } else {
                self.current_usage as f64 / limit as f64
            }
        })
    }

    /// Render a human-readable usage report.
    pub fn generate_report(&self) -> String {
        let mut report = String::from("GPU Memory Usage Report\n");
        report.push_str("=======================\n\n");

        report.push_str(&format!(
            "Current Usage: {} bytes ({:.2} MB)\n",
            self.current_usage,
            self.current_usage as f64 / 1024.0 / 1024.0
        ));

        if let Some(limit) = self.memory_limit {
            report.push_str(&format!(
                "Memory Limit: {} bytes ({:.2} MB)\n",
                limit,
                limit as f64 / 1024.0 / 1024.0
            ));

            if let Some(util) = self.utilization() {
                report.push_str(&format!("Utilization: {:.1}%\n", util * 100.0));
            }
        }

        report.push('\n');
        report.push_str("Allocation Statistics:\n");
        report.push_str(&format!(
            " Total Allocations: {}\n",
            self.allocation_stats.total_allocations
        ));
        report.push_str(&format!(
            " Pool Hits: {} ({:.1}%)\n",
            self.allocation_stats.pool_hits,
            self.allocation_stats.hit_rate() * 100.0
        ));
        report.push_str(&format!(
            " New Allocations: {}\n",
            self.allocation_stats.new_allocations
        ));
        report.push_str(&format!(
            " Garbage Collections: {}\n",
            self.allocation_stats.garbage_collections
        ));
        report.push_str(&format!(
            " Total Freed: {} bytes\n",
            self.allocation_stats.total_freed_memory
        ));

        if !self.pool_sizes.is_empty() {
            report.push('\n');
            report.push_str("Memory Pools:\n");
            let mut pools: Vec<_> = self.pool_sizes.iter().collect();
            pools.sort_by_key(|&(size, _)| size);
            for (&size, &count) in pools {
                report.push_str(&format!(" {} bytes: {} blocks\n", size, count));
            }
        }

        report
    }
}

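/// Heuristics for tuning pool behavior based on observed memory statistics.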
pub mod optimization {
    use super::*;

    /// Configuration for [`MemoryOptimizer`].
    #[derive(Debug, Clone)]
    pub struct MemoryOptimizationConfig {
        /// Target fraction of the memory limit to keep in use.
        pub target_utilization: f64,
        /// Maximum number of cached blocks allowed per size bucket.
        pub max_pool_size: usize,
        /// Utilization above which a garbage collection pass is triggered.
        pub gc_threshold: f64,
        /// Whether prefetching of buffers is enabled.
        pub use_prefetching: bool,
    }

    impl Default for MemoryOptimizationConfig {
        fn default() -> Self {
            Self {
                target_utilization: 0.8,
                max_pool_size: 100,
                gc_threshold: 0.9,
                use_prefetching: true,
            }
        }
    }

    /// Applies [`MemoryOptimizationConfig`] policies against a shared [`GpuMemoryPool`].
    pub struct MemoryOptimizer {
        config: MemoryOptimizationConfig,
        pool: Arc<GpuMemoryPool>,
        optimization_stats: OptimizationStats,
    }

    impl MemoryOptimizer {
        pub fn new(config: MemoryOptimizationConfig, pool: Arc<GpuMemoryPool>) -> Self {
            Self {
                config,
                pool,
                optimization_stats: OptimizationStats::new(),
            }
        }

        /// Run one optimization pass over the pool's current statistics.
        pub fn optimize(&mut self) -> ScirsResult<()> {
            let stats = self.pool.memory_stats();

            // Trigger garbage collection when utilization exceeds the configured threshold.
            if let Some(utilization) = stats.utilization() {
                if utilization > self.config.gc_threshold {
                    self.perform_garbage_collection()?;
                    self.optimization_stats.gc_triggered += 1;
                }
            }

            self.optimize_pool_sizes(&stats)?;

            Ok(())
        }

        fn perform_garbage_collection(&mut self) -> ScirsResult<()> {
            // Only records the operation: the pool is behind a shared `Arc`, so it
            // cannot be mutated from here.
            self.optimization_stats.gc_operations += 1;
            Ok(())
        }

        fn optimize_pool_sizes(&mut self, stats: &MemoryStats) -> ScirsResult<()> {
            // Count pools that have grown past the configured per-size limit.
            for (&_size, &count) in &stats.pool_sizes {
                if count > self.config.max_pool_size {
                    self.optimization_stats.pool_optimizations += 1;
                }
            }
            Ok(())
        }

        /// Counters describing what the optimizer has done so far.
        pub fn stats(&self) -> &OptimizationStats {
            &self.optimization_stats
        }
    }

    /// Counters for optimizer activity.
    #[derive(Debug, Clone)]
    pub struct OptimizationStats {
        /// GC passes triggered by the utilization threshold.
        pub gc_triggered: u64,
        /// GC operations recorded by the optimizer.
        pub gc_operations: u64,
        /// Pool-size adjustments that were flagged.
        pub pool_optimizations: u64,
    }

    impl OptimizationStats {
        fn new() -> Self {
            Self {
                gc_triggered: 0,
                gc_operations: 0,
                pool_optimizations: 0,
            }
        }
    }
}

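/// Helper routines for planning GPU memory allocation.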
pub mod utils {
    use super::*;

    /// Choose an allocation strategy based on the estimated memory footprint of the
    /// workload and the memory currently available on the device.
    pub fn calculate_allocation_strategy(
        problem_size: usize,
        batch_size: usize,
        available_memory: usize,
    ) -> AllocationStrategy {
        let estimated_usage = estimate_memory_usage(problem_size, batch_size);

        if estimated_usage > available_memory {
            AllocationStrategy::Chunked {
                chunk_size: available_memory / 2,
                overlap: true,
            }
        } else if estimated_usage > available_memory / 2 {
            AllocationStrategy::Conservative {
                pool_size_limit: available_memory / 4,
            }
        } else {
            AllocationStrategy::Aggressive {
                prefetch_size: estimated_usage * 2,
            }
        }
    }

    /// Rough memory estimate in bytes, assuming 8-byte (f64) elements: the inputs,
    /// the outputs, and a temporary buffer the same size as the inputs.
    pub fn estimate_memory_usage(problem_size: usize, batch_size: usize) -> usize {
        let input_size = batch_size * problem_size * 8;
        let output_size = batch_size * 8;
        let temp_size = input_size;
        input_size + output_size + temp_size
    }

    /// How GPU memory should be allocated for an upcoming workload.
    #[derive(Debug, Clone)]
    pub enum AllocationStrategy {
        /// Split the workload into chunks that fit in available memory.
        Chunked { chunk_size: usize, overlap: bool },
        /// Allocate cautiously and cap the pool size.
        Conservative { pool_size_limit: usize },
        /// Allocate eagerly and prefetch ahead of use.
        Aggressive { prefetch_size: usize },
    }

    /// Check whether `required_memory` bytes fit within the device's free memory,
    /// keeping a 10% safety margin.
    pub fn check_memory_availability(
        required_memory: usize,
        memory_info: &GpuMemoryInfo,
    ) -> ScirsResult<bool> {
        let available = memory_info.free;
        let safety_margin = 0.1;
        let usable = (available as f64 * (1.0 - safety_margin)) as usize;

        Ok(required_memory <= usable)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_allocation_stats() {
        let mut stats = AllocationStats::new();
        stats.total_allocations = 100;
        stats.pool_hits = 70;

        assert_eq!(stats.hit_rate(), 0.7);
    }

    #[test]
    fn test_memory_stats_utilization() {
        let stats = MemoryStats {
            current_usage: 800,
            memory_limit: Some(1000),
            allocation_stats: AllocationStats::new(),
            pool_sizes: HashMap::new(),
        };

        assert_eq!(stats.utilization(), Some(0.8));
    }

    #[test]
    fn test_memory_usage_estimation() {
        let usage = utils::estimate_memory_usage(10, 100);
        assert!(usage > 0);

        let larger_usage = utils::estimate_memory_usage(20, 200);
        assert!(larger_usage > usage);
    }

    #[test]
    fn test_allocation_strategy() {
        // A 1000 x 1000 problem against 500 KB of available memory must be chunked.
        let strategy = utils::calculate_allocation_strategy(1000, 1000, 500_000);

        match strategy {
            utils::AllocationStrategy::Chunked { .. } => {}
            _ => panic!("Expected chunked strategy for large problem"),
        }
    }
}