// scirs2_optimize/gpu/memory_management.rs

//! GPU memory management for optimization workloads
//!
//! This module provides efficient memory management for GPU-accelerated optimization,
//! including memory pools, workspace allocation, and automatic memory optimization.
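//!
//! A minimal end-to-end sketch (not compiled as a doctest; it assumes this file
//! lives at the module path `scirs2_optimize::gpu::memory_management` and that a
//! context can be created via the CPU fallback backend, as `new_stub` does below):
//!
//! ```ignore
//! use scirs2_core::gpu::{GpuBackend, GpuContext};
//! use scirs2_optimize::gpu::memory_management::GpuMemoryPool;
//! use std::sync::Arc;
//!
//! let context = Arc::new(GpuContext::new(GpuBackend::Cpu)?);
//! // Cap the pool at 256 MiB; `None` would disable the limit.
//! let mut pool = GpuMemoryPool::new(context, Some(256 * 1024 * 1024))?;
//! let workspace = pool.allocate_workspace(1 << 20)?; // 1 MiB workspace
//! println!("{}", pool.memory_stats().generate_report());
//! ```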

use crate::error::{ScirsError, ScirsResult};

// Note: Error conversion handled through scirs2_core::error system
// GPU errors are automatically converted via CoreError type alias
use scirs2_core::gpu::{GpuBuffer, GpuContext};
pub type OptimGpuArray<T> = GpuBuffer<T>;
pub type OptimGpuBuffer<T> = GpuBuffer<T>;
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};

/// GPU memory information structure
#[derive(Debug, Clone)]
pub struct GpuMemoryInfo {
    pub total: usize,
    pub free: usize,
    pub used: usize,
}

/// GPU memory pool for efficient allocation and reuse
pub struct GpuMemoryPool {
    context: Arc<GpuContext>,
    pools: Arc<Mutex<HashMap<usize, VecDeque<GpuMemoryBlock>>>>,
    allocated_blocks: Arc<Mutex<Vec<AllocatedBlock>>>,
    memory_limit: Option<usize>,
    current_usage: Arc<Mutex<usize>>,
    allocation_stats: Arc<Mutex<AllocationStats>>,
}

impl GpuMemoryPool {
    /// Create a new GPU memory pool
    pub fn new(context: Arc<GpuContext>, memory_limit: Option<usize>) -> ScirsResult<Self> {
        Ok(Self {
            context,
            pools: Arc::new(Mutex::new(HashMap::new())),
            allocated_blocks: Arc::new(Mutex::new(Vec::new())),
            memory_limit,
            current_usage: Arc::new(Mutex::new(0)),
            allocation_stats: Arc::new(Mutex::new(AllocationStats::new())),
        })
    }

    /// Create a stub GPU memory pool (fallback for incomplete implementations)
    pub fn new_stub() -> Self {
        use scirs2_core::gpu::GpuBackend;
        let context = GpuContext::new(GpuBackend::Cpu).expect("CPU backend should always work");
        Self {
            context: Arc::new(context),
            pools: Arc::new(Mutex::new(HashMap::new())),
            allocated_blocks: Arc::new(Mutex::new(Vec::new())),
            memory_limit: None,
            current_usage: Arc::new(Mutex::new(0)),
            allocation_stats: Arc::new(Mutex::new(AllocationStats::new())),
        }
    }

    /// Allocate a workspace for optimization operations
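    ///
    /// A usage sketch (not compiled as a doctest): dropping the workspace
    /// returns its blocks to the pool, so a same-sized re-allocation is a
    /// pool hit rather than a fresh GPU allocation.
    ///
    /// ```ignore
    /// let mut pool = GpuMemoryPool::new_stub();
    /// {
    ///     let workspace = pool.allocate_workspace(1024)?;
    ///     assert_eq!(workspace.total_size(), 1024);
    /// } // workspace dropped here; its block goes back into the pool
    /// let reused = pool.allocate_workspace(1024)?; // served from the pool
    /// assert_eq!(pool.memory_stats().allocation_stats.pool_hits, 1);
    /// ```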
    pub fn allocate_workspace(&mut self, size: usize) -> ScirsResult<GpuWorkspace> {
        let block = self.allocate_block(size)?;
        Ok(GpuWorkspace::new(block, Arc::clone(&self.pools)))
    }

    /// Allocate a memory block of the specified size
    fn allocate_block(&mut self, size: usize) -> ScirsResult<GpuMemoryBlock> {
        let mut stats = self.allocation_stats.lock().unwrap();
        stats.total_allocations += 1;

        // Check memory limit
        if let Some(limit) = self.memory_limit {
            let current = *self.current_usage.lock().unwrap();
            if current + size > limit {
                // Drop the stats lock before garbage collection
                drop(stats);
                // Try to free some memory
                self.garbage_collect()?;
                // Reacquire the lock
                stats = self.allocation_stats.lock().unwrap();
                let current = *self.current_usage.lock().unwrap();
                if current + size > limit {
                    return Err(ScirsError::MemoryError(
                        scirs2_core::error::ErrorContext::new(format!(
                            "Would exceed memory limit: {} + {} > {}",
                            current, size, limit
                        ))
                        .with_location(scirs2_core::error::ErrorLocation::new(file!(), line!())),
                    ));
                }
            }
        }

        // Try to reuse an existing block from the pool
        let mut pools = self.pools.lock().unwrap();
        if let Some(pool) = pools.get_mut(&size) {
            if let Some(block) = pool.pop_front() {
                stats.pool_hits += 1;
                return Ok(block);
            }
        }

        // Allocate a new block. The raw pointer is a placeholder: the memory is
        // owned by the `GpuBuffer` handle, not accessed through `ptr` directly.
        stats.new_allocations += 1;
        let gpu_buffer = self.context.create_buffer::<u8>(size);
        let ptr = std::ptr::null_mut();
        let block = GpuMemoryBlock {
            size,
            ptr,
            gpu_buffer: Some(gpu_buffer),
        };

        // Update current usage
        *self.current_usage.lock().unwrap() += size;

        Ok(block)
    }

    /// Return a block to the pool for reuse
    fn return_block(&self, block: GpuMemoryBlock) {
        let mut pools = self.pools.lock().unwrap();
        pools.entry(block.size).or_default().push_back(block);
    }

    /// Perform garbage collection to free unused memory
    fn garbage_collect(&mut self) -> ScirsResult<()> {
        let mut pools = self.pools.lock().unwrap();
        let mut freed_memory = 0;

        // Clear all pools
        for (size, pool) in pools.iter_mut() {
            let count = pool.len();
            freed_memory += size * count;
            pool.clear();
        }

        // Update current usage with a single lock acquisition; locking the
        // mutex on both sides of one assignment would self-deadlock, since the
        // right-hand guard lives until the end of the statement.
        {
            let mut usage = self.current_usage.lock().unwrap();
            *usage = usage.saturating_sub(freed_memory);
        }

        // Update stats
        let mut stats = self.allocation_stats.lock().unwrap();
        stats.garbage_collections += 1;
        stats.total_freed_memory += freed_memory;

        Ok(())
    }

    /// Get current memory usage statistics
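    ///
    /// A reporting sketch (not compiled as a doctest):
    ///
    /// ```ignore
    /// let pool = GpuMemoryPool::new_stub();
    /// let stats = pool.memory_stats();
    /// println!("{}", stats.generate_report());
    /// println!("pool hit rate: {:.1}%", stats.allocation_stats.hit_rate() * 100.0);
    /// ```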
    pub fn memory_stats(&self) -> MemoryStats {
        let current_usage = *self.current_usage.lock().unwrap();
        let allocation_stats = self.allocation_stats.lock().unwrap().clone();
        let pool_sizes: HashMap<usize, usize> = self
            .pools
            .lock()
            .unwrap()
            .iter()
            .map(|(&size, pool)| (size, pool.len()))
            .collect();

        MemoryStats {
            current_usage,
            memory_limit: self.memory_limit,
            allocation_stats,
            pool_sizes,
        }
    }
}

/// A block of GPU memory
pub struct GpuMemoryBlock {
    size: usize,
    ptr: *mut u8,
    gpu_buffer: Option<OptimGpuBuffer<u8>>,
}

impl std::fmt::Debug for GpuMemoryBlock {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GpuMemoryBlock")
            .field("size", &self.size)
            .field("ptr", &self.ptr)
            .field("gpu_buffer", &self.gpu_buffer.is_some())
            .finish()
    }
}

unsafe impl Send for GpuMemoryBlock {}
unsafe impl Sync for GpuMemoryBlock {}

impl GpuMemoryBlock {
    /// Get the size of this memory block
    pub fn size(&self) -> usize {
        self.size
    }

    /// Get the raw pointer to GPU memory
    pub fn ptr(&self) -> *mut u8 {
        self.ptr
    }

    /// Cast to a specific type
    pub fn as_typed<T: scirs2_core::GpuDataType>(&self) -> ScirsResult<&OptimGpuBuffer<T>> {
        if self.gpu_buffer.is_some() {
            // Safe casting would go through scirs2_core's type system, but no
            // cast API exists yet, so report the operation as unsupported.
            Err(ScirsError::ComputationError(
                scirs2_core::error::ErrorContext::new("Type casting not supported".to_string()),
            ))
        } else {
            Err(ScirsError::InvalidInput(
                scirs2_core::error::ErrorContext::new("Memory block not available".to_string()),
            ))
        }
    }
}

/// Workspace for GPU optimization operations
pub struct GpuWorkspace {
    blocks: Vec<GpuMemoryBlock>,
    pool: Arc<Mutex<HashMap<usize, VecDeque<GpuMemoryBlock>>>>,
}

impl GpuWorkspace {
    fn new(
        initial_block: GpuMemoryBlock,
        pool: Arc<Mutex<HashMap<usize, VecDeque<GpuMemoryBlock>>>>,
    ) -> Self {
        Self {
            blocks: vec![initial_block],
            pool,
        }
    }

    /// Get a memory block of at least the specified size
    pub fn get_block(&mut self, size: usize) -> ScirsResult<&GpuMemoryBlock> {
        // Try to find an existing block of sufficient size
        for block in &self.blocks {
            if block.size >= size {
                return Ok(block);
            }
        }

        // A full implementation would allocate a new block from the pool here;
        // for now, report that no suitable block is available.
        Err(ScirsError::MemoryError(
            scirs2_core::error::ErrorContext::new("No suitable block available".to_string()),
        ))
    }

    /// Get a typed buffer view of the specified size
    pub fn get_buffer<T: scirs2_core::GpuDataType>(
        &mut self,
        size: usize,
    ) -> ScirsResult<&OptimGpuBuffer<T>> {
        let size_bytes = size * std::mem::size_of::<T>();
        let block = self.get_block(size_bytes)?;
        block.as_typed::<T>()
    }

    /// Create a GPU array view from the workspace
    pub fn create_array<T>(&mut self, dimensions: &[usize]) -> ScirsResult<OptimGpuArray<T>>
    where
        T: Clone + Default + 'static + scirs2_core::GpuDataType,
    {
        let total_elements: usize = dimensions.iter().product();
        let _buffer = self.get_buffer::<T>(total_elements)?;

        // Converting the buffer into an array would use scirs2_core's reshape
        // functionality; no such API exists yet, so report it as unsupported.
        Err(ScirsError::ComputationError(
            scirs2_core::error::ErrorContext::new("Array creation not supported".to_string()),
        ))
    }

    /// Get total workspace size
    pub fn total_size(&self) -> usize {
        self.blocks.iter().map(|b| b.size).sum()
    }
}

impl Drop for GpuWorkspace {
    fn drop(&mut self) {
        // Return all blocks to the pool
        let mut pool = self.pool.lock().unwrap();
        for block in self.blocks.drain(..) {
            pool.entry(block.size).or_default().push_back(block);
        }
    }
}

/// Tracks allocated memory blocks
#[derive(Debug)]
struct AllocatedBlock {
    size: usize,
    allocated_at: std::time::Instant,
}

/// Statistics for memory allocations
#[derive(Debug, Clone)]
pub struct AllocationStats {
    /// Total number of allocation requests
    pub total_allocations: u64,
    /// Number of allocations served from pool
    pub pool_hits: u64,
    /// Number of new allocations
    pub new_allocations: u64,
    /// Number of garbage collections performed
    pub garbage_collections: u64,
    /// Total memory freed by garbage collection
    pub total_freed_memory: usize,
}

impl AllocationStats {
    fn new() -> Self {
        Self {
            total_allocations: 0,
            pool_hits: 0,
            new_allocations: 0,
            garbage_collections: 0,
            total_freed_memory: 0,
        }
    }

    /// Calculate pool hit rate
    pub fn hit_rate(&self) -> f64 {
        if self.total_allocations == 0 {
            0.0
        } else {
            self.pool_hits as f64 / self.total_allocations as f64
        }
    }
}

/// Overall memory usage statistics
#[derive(Debug, Clone)]
pub struct MemoryStats {
    /// Current memory usage in bytes
    pub current_usage: usize,
    /// Memory limit (if set)
    pub memory_limit: Option<usize>,
    /// Allocation statistics
    pub allocation_stats: AllocationStats,
    /// Size of each memory pool
    pub pool_sizes: HashMap<usize, usize>,
}

impl MemoryStats {
    /// Get memory utilization as a fraction of the limit, 0.0 to 1.0 (if a limit is set)
    pub fn utilization(&self) -> Option<f64> {
        self.memory_limit.map(|limit| {
            if limit == 0 {
                0.0
            } else {
                self.current_usage as f64 / limit as f64
            }
        })
    }

    /// Generate a memory usage report
    pub fn generate_report(&self) -> String {
        let mut report = String::from("GPU Memory Usage Report\n");
        report.push_str("=======================\n\n");

        report.push_str(&format!(
            "Current Usage: {} bytes ({:.2} MB)\n",
            self.current_usage,
            self.current_usage as f64 / 1024.0 / 1024.0
        ));

        if let Some(limit) = self.memory_limit {
            report.push_str(&format!(
                "Memory Limit: {} bytes ({:.2} MB)\n",
                limit,
                limit as f64 / 1024.0 / 1024.0
            ));

            if let Some(util) = self.utilization() {
                report.push_str(&format!("Utilization: {:.1}%\n", util * 100.0));
            }
        }

        report.push('\n');
        report.push_str("Allocation Statistics:\n");
        report.push_str(&format!(
            "  Total Allocations: {}\n",
            self.allocation_stats.total_allocations
        ));
        report.push_str(&format!(
            "  Pool Hits: {} ({:.1}%)\n",
            self.allocation_stats.pool_hits,
            self.allocation_stats.hit_rate() * 100.0
        ));
        report.push_str(&format!(
            "  New Allocations: {}\n",
            self.allocation_stats.new_allocations
        ));
        report.push_str(&format!(
            "  Garbage Collections: {}\n",
            self.allocation_stats.garbage_collections
        ));
        report.push_str(&format!(
            "  Total Freed: {} bytes\n",
            self.allocation_stats.total_freed_memory
        ));

        if !self.pool_sizes.is_empty() {
            report.push('\n');
            report.push_str("Memory Pools:\n");
            let mut pools: Vec<_> = self.pool_sizes.iter().collect();
            pools.sort_by_key(|&(size, _)| size);
            for (&size, &count) in pools {
                report.push_str(&format!("  {} bytes: {} blocks\n", size, count));
            }
        }

        report
    }
}

/// Memory optimization strategies
pub mod optimization {
    use super::*;

    /// Automatic memory optimization configuration
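    ///
    /// A configuration sketch (not compiled as a doctest); unspecified fields
    /// keep their defaults via struct-update syntax:
    ///
    /// ```ignore
    /// let config = MemoryOptimizationConfig {
    ///     gc_threshold: 0.85, // collect a bit earlier than the 0.9 default
    ///     ..Default::default()
    /// };
    /// ```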
    #[derive(Debug, Clone)]
    pub struct MemoryOptimizationConfig {
        /// Target memory utilization (0.0 to 1.0)
        pub target_utilization: f64,
        /// Maximum pool size per block size
        pub max_pool_size: usize,
        /// Garbage collection threshold (utilization fraction, 0.0 to 1.0)
        pub gc_threshold: f64,
        /// Whether to use memory prefetching
        pub use_prefetching: bool,
    }

    impl Default for MemoryOptimizationConfig {
        fn default() -> Self {
            Self {
                target_utilization: 0.8,
                max_pool_size: 100,
                gc_threshold: 0.9,
                use_prefetching: true,
            }
        }
    }

    /// Memory optimizer for automatic memory management
    pub struct MemoryOptimizer {
        config: MemoryOptimizationConfig,
        pool: Arc<GpuMemoryPool>,
        optimization_stats: OptimizationStats,
    }

    impl MemoryOptimizer {
        /// Create a new memory optimizer
        pub fn new(config: MemoryOptimizationConfig, pool: Arc<GpuMemoryPool>) -> Self {
            Self {
                config,
                pool,
                optimization_stats: OptimizationStats::new(),
            }
        }

        /// Optimize memory usage based on current statistics
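        ///
        /// A usage sketch (not compiled as a doctest), e.g. called once per
        /// optimizer iteration:
        ///
        /// ```ignore
        /// let pool = Arc::new(GpuMemoryPool::new_stub());
        /// let mut optimizer = MemoryOptimizer::new(MemoryOptimizationConfig::default(), pool);
        /// optimizer.optimize()?; // runs GC when utilization exceeds `gc_threshold`
        /// println!("GC runs triggered: {}", optimizer.stats().gc_triggered);
        /// ```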
        pub fn optimize(&mut self) -> ScirsResult<()> {
            let stats = self.pool.memory_stats();

            // Check if we need garbage collection
            if let Some(utilization) = stats.utilization() {
                if utilization > self.config.gc_threshold {
                    self.perform_garbage_collection()?;
                    self.optimization_stats.gc_triggered += 1;
                }
            }

            // Optimize pool sizes
            self.optimize_pool_sizes(&stats)?;

            Ok(())
        }

        /// Perform targeted garbage collection
        fn perform_garbage_collection(&mut self) -> ScirsResult<()> {
            // A full implementation would perform smart garbage collection;
            // for now, only record that a GC pass was requested.
            self.optimization_stats.gc_operations += 1;
            Ok(())
        }

        /// Optimize memory pool sizes based on usage patterns
        fn optimize_pool_sizes(&mut self, stats: &MemoryStats) -> ScirsResult<()> {
            // Analyze usage patterns and adjust pool sizes
            for (&_size, &count) in &stats.pool_sizes {
                if count > self.config.max_pool_size {
                    // Pool is too large; record it as a candidate for shrinking
                    self.optimization_stats.pool_optimizations += 1;
                }
            }
            Ok(())
        }

        /// Get optimization statistics
        pub fn stats(&self) -> &OptimizationStats {
            &self.optimization_stats
        }
    }

    /// Statistics for memory optimization
    #[derive(Debug, Clone)]
    pub struct OptimizationStats {
        /// Number of times GC was triggered by the optimizer
        pub gc_triggered: u64,
        /// Total GC operations performed
        pub gc_operations: u64,
        /// Pool size optimizations performed
        pub pool_optimizations: u64,
    }

    impl OptimizationStats {
        fn new() -> Self {
            Self {
                gc_triggered: 0,
                gc_operations: 0,
                pool_optimizations: 0,
            }
        }
    }
}

/// Utilities for memory management
pub mod utils {
    use super::*;

    /// Calculate an allocation strategy from estimated usage and available memory
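    ///
    /// A selection sketch (not compiled as a doctest):
    ///
    /// ```ignore
    /// match calculate_allocation_strategy(1_000, 64, 512 * 1024 * 1024) {
    ///     AllocationStrategy::Chunked { chunk_size, .. } => println!("chunk at {chunk_size} bytes"),
    ///     AllocationStrategy::Conservative { pool_size_limit } => println!("cap pools at {pool_size_limit}"),
    ///     AllocationStrategy::Aggressive { prefetch_size } => println!("prefetch {prefetch_size} bytes"),
    /// }
    /// ```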
    pub fn calculate_allocation_strategy(
        problem_size: usize,
        batch_size: usize,
        available_memory: usize,
    ) -> AllocationStrategy {
        let estimated_usage = estimate_memory_usage(problem_size, batch_size);

        if estimated_usage > available_memory {
            AllocationStrategy::Chunked {
                chunk_size: available_memory / 2,
                overlap: true,
            }
        } else if estimated_usage > available_memory / 2 {
            AllocationStrategy::Conservative {
                pool_size_limit: available_memory / 4,
            }
        } else {
            AllocationStrategy::Aggressive {
                prefetch_size: estimated_usage * 2,
            }
        }
    }

    /// Estimate memory usage in bytes for a given problem
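    ///
    /// A worked example of the estimate below (f64 elements, 8 bytes each):
    /// for `problem_size = 10` and `batch_size = 100`, the input is
    /// `100 * 10 * 8 = 8_000` bytes, the output `100 * 8 = 800` bytes, and the
    /// temporaries mirror the input, giving `16_800` bytes in total.
    ///
    /// ```ignore
    /// assert_eq!(estimate_memory_usage(10, 100), 16_800);
    /// ```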
    pub fn estimate_memory_usage(problem_size: usize, batch_size: usize) -> usize {
        // Rough estimation: input data + output data + temporary buffers
        let input_size = batch_size * problem_size * 8; // f64 elements
        let output_size = batch_size * 8; // one f64 per batch entry
        let temp_size = input_size; // temporary arrays mirror the input

        input_size + output_size + temp_size
    }

    /// Memory allocation strategies
    #[derive(Debug, Clone)]
    pub enum AllocationStrategy {
        /// Use chunks with optional overlap for large problems
        Chunked { chunk_size: usize, overlap: bool },
        /// Conservative allocation with limited pool sizes
        Conservative { pool_size_limit: usize },
        /// Aggressive allocation with prefetching
        Aggressive { prefetch_size: usize },
    }

    /// Check if the system has sufficient memory for an operation
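    ///
    /// A usage sketch (not compiled as a doctest); 10% of the free memory is
    /// held back as a safety margin:
    ///
    /// ```ignore
    /// let info = GpuMemoryInfo { total: 1_000, free: 1_000, used: 0 };
    /// assert!(check_memory_availability(900, &info)?);  // within the usable 90%
    /// assert!(!check_memory_availability(901, &info)?); // exceeds the margin
    /// ```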
    pub fn check_memory_availability(
        required_memory: usize,
        memory_info: &GpuMemoryInfo,
    ) -> ScirsResult<bool> {
        let available = memory_info.free;
        let safety_margin = 0.1; // Keep 10% free
        let usable = (available as f64 * (1.0 - safety_margin)) as usize;

        Ok(required_memory <= usable)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_allocation_stats() {
        let mut stats = AllocationStats::new();
        stats.total_allocations = 100;
        stats.pool_hits = 70;

        assert_eq!(stats.hit_rate(), 0.7);
    }

    #[test]
    fn test_memory_stats_utilization() {
        let stats = MemoryStats {
            current_usage: 800,
            memory_limit: Some(1000),
            allocation_stats: AllocationStats::new(),
            pool_sizes: HashMap::new(),
        };

        assert_eq!(stats.utilization(), Some(0.8));
    }

    #[test]
    fn test_memory_usage_estimation() {
        let usage = utils::estimate_memory_usage(10, 100);
        assert!(usage > 0);

        // Should scale with problem size and batch size
        let larger_usage = utils::estimate_memory_usage(20, 200);
        assert!(larger_usage > usage);
    }

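    // An added check (not part of the original suite) exercising the 10%
    // safety margin in `utils::check_memory_availability`.
    #[test]
    fn test_memory_availability_safety_margin() {
        let info = GpuMemoryInfo {
            total: 1_000,
            free: 1_000,
            used: 0,
        };
        // 90% of the free memory (900 bytes) is usable
        assert!(utils::check_memory_availability(900, &info).unwrap());
        assert!(!utils::check_memory_availability(901, &info).unwrap());
    }
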
    #[test]
    fn test_allocation_strategy() {
        let strategy = utils::calculate_allocation_strategy(
            1000,    // Large problem
            1000,    // Large batch
            500_000, // Limited memory
        );

        match strategy {
            utils::AllocationStrategy::Chunked { .. } => {
                // Expected for large problems with limited memory
            }
            _ => panic!("Expected chunked strategy for large problem"),
        }
    }
}