scirs2_optimize/gpu/memory_management.rs

//! GPU memory management for optimization workloads
//!
//! This module provides efficient memory management for GPU-accelerated optimization,
//! including memory pools, workspace allocation, and automatic memory optimization.

use crate::error::{ScirsError, ScirsResult};

// Note: Error conversion is handled through the scirs2_core::error system;
// GPU errors are automatically converted via the CoreError type alias.
use scirs2_core::gpu::{GpuBuffer, GpuContext};
pub type OptimGpuArray<T> = GpuBuffer<T>;
pub type OptimGpuBuffer<T> = GpuBuffer<T>;
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};

/// GPU memory information structure
#[derive(Debug, Clone)]
pub struct GpuMemoryInfo {
    pub total: usize,
    pub free: usize,
    pub used: usize,
}

/// GPU memory pool for efficient allocation and reuse
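///
/// Freed blocks are kept in per-size free lists (the `pools` map), so a later
/// request for the same size can be served without a new GPU allocation;
/// `allocation_stats` records how often that fast path is hit.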
pub struct GpuMemoryPool {
    context: Arc<GpuContext>,
    pools: Arc<Mutex<HashMap<usize, VecDeque<GpuMemoryBlock>>>>,
    allocated_blocks: Arc<Mutex<Vec<AllocatedBlock>>>,
    memory_limit: Option<usize>,
    current_usage: Arc<Mutex<usize>>,
    allocation_stats: Arc<Mutex<AllocationStats>>,
}

impl GpuMemoryPool {
    /// Create a new GPU memory pool
    pub fn new(context: Arc<GpuContext>, memory_limit: Option<usize>) -> ScirsResult<Self> {
        Ok(Self {
            context,
            pools: Arc::new(Mutex::new(HashMap::new())),
            allocated_blocks: Arc::new(Mutex::new(Vec::new())),
            memory_limit,
            current_usage: Arc::new(Mutex::new(0)),
            allocation_stats: Arc::new(Mutex::new(AllocationStats::new())),
        })
    }

    /// Create a stub GPU memory pool (fallback for incomplete implementations)
    pub fn new_stub() -> Self {
        use scirs2_core::gpu::GpuBackend;
        let context = GpuContext::new(GpuBackend::Cpu).expect("CPU backend should always work");
        Self {
            context: Arc::new(context),
            pools: Arc::new(Mutex::new(HashMap::new())),
            allocated_blocks: Arc::new(Mutex::new(Vec::new())),
            memory_limit: None,
            current_usage: Arc::new(Mutex::new(0)),
            allocation_stats: Arc::new(Mutex::new(AllocationStats::new())),
        }
    }

    /// Allocate a workspace for optimization operations
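    ///
    /// A minimal usage sketch (illustrative only; it assumes `GpuContext::new`
    /// succeeds, and uses the CPU backend and an arbitrary 1 GiB limit):
    ///
    /// ```ignore
    /// use scirs2_core::gpu::{GpuBackend, GpuContext};
    /// use std::sync::Arc;
    ///
    /// let context = Arc::new(GpuContext::new(GpuBackend::Cpu)?);
    /// let mut pool = GpuMemoryPool::new(context, Some(1 << 30))?;
    /// let workspace = pool.allocate_workspace(1024)?; // 1 KiB block
    /// assert!(workspace.total_size() >= 1024);
    /// ```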
    pub fn allocate_workspace(&mut self, size: usize) -> ScirsResult<GpuWorkspace> {
        let block = self.allocate_block(size)?;
        Ok(GpuWorkspace::new(block, Arc::clone(&self.pools)))
    }

    /// Allocate a memory block of the specified size
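    ///
    /// The allocation policy, in order: enforce the memory limit (running one
    /// garbage-collection pass if the request would exceed it), then try to
    /// reuse a pooled block of exactly this size, and only then create a
    /// fresh GPU buffer.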
    fn allocate_block(&mut self, size: usize) -> ScirsResult<GpuMemoryBlock> {
        let mut stats = self.allocation_stats.lock().unwrap();
        stats.total_allocations += 1;

        // Check the memory limit
        if let Some(limit) = self.memory_limit {
            let current = *self.current_usage.lock().unwrap();
            if current + size > limit {
                // Drop the stats lock before garbage collection
                drop(stats);
                // Try to free some memory
                self.garbage_collect()?;
                // Reacquire the lock
                stats = self.allocation_stats.lock().unwrap();
                let current = *self.current_usage.lock().unwrap();
                if current + size > limit {
                    return Err(ScirsError::MemoryError(
                        scirs2_core::error::ErrorContext::new(format!(
                            "Would exceed memory limit: {} + {} > {}",
                            current, size, limit
                        ))
                        .with_location(scirs2_core::error::ErrorLocation::new(file!(), line!())),
                    ));
                }
            }
        }

        // Try to reuse an existing block from the pool
        let mut pools = self.pools.lock().unwrap();
        if let Some(pool) = pools.get_mut(&size) {
            if let Some(block) = pool.pop_front() {
                stats.pool_hits += 1;
                return Ok(block);
            }
        }

        // Allocate a new block
        stats.new_allocations += 1;
        let gpu_buffer = self.context.create_buffer::<u8>(size);
        // The raw pointer is a placeholder; the memory itself is owned by `gpu_buffer`.
        let ptr = std::ptr::null_mut();
        let block = GpuMemoryBlock {
            size,
            ptr,
            gpu_buffer: Some(gpu_buffer),
        };

        // Update current usage
        *self.current_usage.lock().unwrap() += size;

        Ok(block)
    }

    /// Return a block to the pool for reuse
    fn return_block(&self, block: GpuMemoryBlock) {
        let mut pools = self.pools.lock().unwrap();
        pools
            .entry(block.size)
            .or_insert_with(VecDeque::new)
            .push_back(block);
    }

    /// Perform garbage collection to free unused memory
    fn garbage_collect(&mut self) -> ScirsResult<()> {
        let mut pools = self.pools.lock().unwrap();
        let mut freed_memory = 0;

        // Clear all pools
        for (size, pool) in pools.iter_mut() {
            let count = pool.len();
            freed_memory += size * count;
            pool.clear();
        }

        // Update current usage. Lock once and update in place: locking the
        // same mutex twice within one statement would self-deadlock.
        let mut usage = self.current_usage.lock().unwrap();
        *usage = usage.saturating_sub(freed_memory);
        drop(usage);

        // Update stats
        let mut stats = self.allocation_stats.lock().unwrap();
        stats.garbage_collections += 1;
        stats.total_freed_memory += freed_memory;

        Ok(())
    }

    /// Get current memory usage statistics
    pub fn memory_stats(&self) -> MemoryStats {
        let current_usage = *self.current_usage.lock().unwrap();
        let allocation_stats = self.allocation_stats.lock().unwrap().clone();
        let pool_sizes: HashMap<usize, usize> = self
            .pools
            .lock()
            .unwrap()
            .iter()
            .map(|(&size, pool)| (size, pool.len()))
            .collect();

        MemoryStats {
            current_usage,
            memory_limit: self.memory_limit,
            allocation_stats,
            pool_sizes,
        }
    }
}

/// A block of GPU memory
pub struct GpuMemoryBlock {
    size: usize,
    ptr: *mut u8,
    gpu_buffer: Option<OptimGpuBuffer<u8>>,
}

impl std::fmt::Debug for GpuMemoryBlock {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GpuMemoryBlock")
            .field("size", &self.size)
            .field("ptr", &self.ptr)
            .field("gpu_buffer", &self.gpu_buffer.is_some())
            .finish()
    }
}

unsafe impl Send for GpuMemoryBlock {}
unsafe impl Sync for GpuMemoryBlock {}

impl GpuMemoryBlock {
    /// Get the size of this memory block
    pub fn size(&self) -> usize {
        self.size
    }

    /// Get the raw pointer to GPU memory
    pub fn ptr(&self) -> *mut u8 {
        self.ptr
    }

    /// Cast to a specific type
    pub fn as_typed<T: scirs2_core::GpuDataType>(&self) -> ScirsResult<&OptimGpuBuffer<T>> {
        if self.gpu_buffer.is_some() {
            // scirs2_core does not currently expose a safe cast between buffer
            // element types, so report the operation as unsupported for now.
            Err(ScirsError::ComputationError(
                scirs2_core::error::ErrorContext::new("Type casting not supported".to_string()),
            ))
        } else {
            Err(ScirsError::InvalidInput(
                scirs2_core::error::ErrorContext::new("Memory block not available".to_string()),
            ))
        }
    }
}

/// Workspace for GPU optimization operations
pub struct GpuWorkspace {
    blocks: Vec<GpuMemoryBlock>,
    pool: Arc<Mutex<HashMap<usize, VecDeque<GpuMemoryBlock>>>>,
}

impl GpuWorkspace {
    fn new(
        initial_block: GpuMemoryBlock,
        pool: Arc<Mutex<HashMap<usize, VecDeque<GpuMemoryBlock>>>>,
    ) -> Self {
        Self {
            blocks: vec![initial_block],
            pool,
        }
    }

    /// Get a memory block of the specified size
    pub fn get_block(&mut self, size: usize) -> ScirsResult<&GpuMemoryBlock> {
        // Try to find an existing block of sufficient size
        for block in &self.blocks {
            if block.size >= size {
                return Ok(block);
            }
        }

        // A new block would be needed here; for simplicity this returns an
        // error. A full implementation would allocate from the pool.
        Err(ScirsError::MemoryError(
            scirs2_core::error::ErrorContext::new("No suitable block available".to_string()),
        ))
    }

    /// Get a typed buffer view of the specified size
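    ///
    /// `size` is an element count, not a byte count: `get_buffer::<f64>(16)`
    /// looks for a block of at least 16 * 8 = 128 bytes.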
    pub fn get_buffer<T: scirs2_core::GpuDataType>(
        &mut self,
        size: usize,
    ) -> ScirsResult<&OptimGpuBuffer<T>> {
        let size_bytes = size * std::mem::size_of::<T>();
        let block = self.get_block(size_bytes)?;
        block.as_typed::<T>()
    }

    /// Create a GPU array view from the workspace
    pub fn create_array<T>(&mut self, dimensions: &[usize]) -> ScirsResult<OptimGpuArray<T>>
    where
        T: Clone + Default + 'static + scirs2_core::GpuDataType,
    {
        let total_elements: usize = dimensions.iter().product();
        let _buffer = self.get_buffer::<T>(total_elements)?;

        // scirs2_core does not currently expose a way to reshape an existing
        // buffer into an array, so report the operation as unsupported for now.
        Err(ScirsError::ComputationError(
            scirs2_core::error::ErrorContext::new("Array creation not supported".to_string()),
        ))
    }

    /// Get the total workspace size
    pub fn total_size(&self) -> usize {
        self.blocks.iter().map(|b| b.size).sum()
    }
}

impl Drop for GpuWorkspace {
    fn drop(&mut self) {
        // Return all blocks to the pool
        let mut pool = self.pool.lock().unwrap();
        for block in self.blocks.drain(..) {
            pool.entry(block.size)
                .or_insert_with(VecDeque::new)
                .push_back(block);
        }
    }
}

/// Tracks allocated memory blocks
#[derive(Debug)]
struct AllocatedBlock {
    size: usize,
    allocated_at: std::time::Instant,
}

/// Statistics for memory allocations
#[derive(Debug, Clone)]
pub struct AllocationStats {
    /// Total number of allocation requests
    pub total_allocations: u64,
    /// Number of allocations served from the pool
    pub pool_hits: u64,
    /// Number of new allocations
    pub new_allocations: u64,
    /// Number of garbage collections performed
    pub garbage_collections: u64,
    /// Total memory freed by garbage collection
    pub total_freed_memory: usize,
}

impl AllocationStats {
    fn new() -> Self {
        Self {
            total_allocations: 0,
            pool_hits: 0,
            new_allocations: 0,
            garbage_collections: 0,
            total_freed_memory: 0,
        }
    }

    /// Calculate the pool hit rate
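    ///
    /// For example, 70 pool hits out of 100 total allocation requests gives a
    /// hit rate of 0.7; with no requests recorded, the rate is defined as 0.0.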
    pub fn hit_rate(&self) -> f64 {
        if self.total_allocations == 0 {
            0.0
        } else {
            self.pool_hits as f64 / self.total_allocations as f64
        }
    }
}

/// Overall memory usage statistics
#[derive(Debug, Clone)]
pub struct MemoryStats {
    /// Current memory usage in bytes
    pub current_usage: usize,
    /// Memory limit (if set)
    pub memory_limit: Option<usize>,
    /// Allocation statistics
    pub allocation_stats: AllocationStats,
    /// Number of cached blocks per block size
    pub pool_sizes: HashMap<usize, usize>,
}

impl MemoryStats {
    /// Get memory utilization as a fraction of the limit (if a limit is set)
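    ///
    /// For example, 800 bytes in use against a 1000-byte limit yields
    /// `Some(0.8)`; without a configured limit, `None` is returned.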
    pub fn utilization(&self) -> Option<f64> {
        self.memory_limit.map(|limit| {
            if limit == 0 {
                0.0
            } else {
                self.current_usage as f64 / limit as f64
            }
        })
    }

    /// Generate a memory usage report
    pub fn generate_report(&self) -> String {
        let mut report = String::from("GPU Memory Usage Report\n");
        report.push_str("=======================\n\n");

        report.push_str(&format!(
            "Current Usage: {} bytes ({:.2} MB)\n",
            self.current_usage,
            self.current_usage as f64 / 1024.0 / 1024.0
        ));

        if let Some(limit) = self.memory_limit {
            report.push_str(&format!(
                "Memory Limit: {} bytes ({:.2} MB)\n",
                limit,
                limit as f64 / 1024.0 / 1024.0
            ));

            if let Some(util) = self.utilization() {
                report.push_str(&format!("Utilization: {:.1}%\n", util * 100.0));
            }
        }

        report.push('\n');
        report.push_str("Allocation Statistics:\n");
        report.push_str(&format!(
            "  Total Allocations: {}\n",
            self.allocation_stats.total_allocations
        ));
        report.push_str(&format!(
            "  Pool Hits: {} ({:.1}%)\n",
            self.allocation_stats.pool_hits,
            self.allocation_stats.hit_rate() * 100.0
        ));
        report.push_str(&format!(
            "  New Allocations: {}\n",
            self.allocation_stats.new_allocations
        ));
        report.push_str(&format!(
            "  Garbage Collections: {}\n",
            self.allocation_stats.garbage_collections
        ));
        report.push_str(&format!(
            "  Total Freed: {} bytes\n",
            self.allocation_stats.total_freed_memory
        ));

        if !self.pool_sizes.is_empty() {
            report.push('\n');
            report.push_str("Memory Pools:\n");
            let mut pools: Vec<_> = self.pool_sizes.iter().collect();
            pools.sort_by_key(|&(&size, _)| size);
            for (&size, &count) in pools {
                report.push_str(&format!("  {} bytes: {} blocks\n", size, count));
            }
        }

        report
    }
}

/// Memory optimization strategies
pub mod optimization {
    use super::*;

    /// Automatic memory optimization configuration
    #[derive(Debug, Clone)]
    pub struct MemoryOptimizationConfig {
        /// Target memory utilization (0.0 to 1.0)
        pub target_utilization: f64,
        /// Maximum pool size per block size
        pub max_pool_size: usize,
        /// Garbage collection threshold (utilization fraction, 0.0 to 1.0)
        pub gc_threshold: f64,
        /// Whether to use memory prefetching
        pub use_prefetching: bool,
    }

    impl Default for MemoryOptimizationConfig {
        fn default() -> Self {
            Self {
                target_utilization: 0.8,
                max_pool_size: 100,
                gc_threshold: 0.9,
                use_prefetching: true,
            }
        }
    }

    /// Memory optimizer for automatic memory management
    pub struct MemoryOptimizer {
        config: MemoryOptimizationConfig,
        pool: Arc<GpuMemoryPool>,
        optimization_stats: OptimizationStats,
    }

    impl MemoryOptimizer {
        /// Create a new memory optimizer
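        ///
        /// A minimal sketch of the intended usage (illustrative only; the stub
        /// pool and default config stand in for a real setup):
        ///
        /// ```ignore
        /// use std::sync::Arc;
        ///
        /// let pool = Arc::new(GpuMemoryPool::new_stub());
        /// let mut optimizer = MemoryOptimizer::new(MemoryOptimizationConfig::default(), pool);
        /// optimizer.optimize()?; // runs GC when utilization exceeds `gc_threshold`
        /// ```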
        pub fn new(config: MemoryOptimizationConfig, pool: Arc<GpuMemoryPool>) -> Self {
            Self {
                config,
                pool,
                optimization_stats: OptimizationStats::new(),
            }
        }

        /// Optimize memory usage based on current statistics
        pub fn optimize(&mut self) -> ScirsResult<()> {
            let stats = self.pool.memory_stats();

            // Check whether garbage collection is needed
            if let Some(utilization) = stats.utilization() {
                if utilization > self.config.gc_threshold {
                    self.perform_garbage_collection()?;
                    self.optimization_stats.gc_triggered += 1;
                }
            }

            // Optimize pool sizes
            self.optimize_pool_sizes(&stats)?;

            Ok(())
        }

        /// Perform targeted garbage collection
        fn perform_garbage_collection(&mut self) -> ScirsResult<()> {
            // This would implement smart garbage collection; for now it only
            // records that a GC operation took place.
            self.optimization_stats.gc_operations += 1;
            Ok(())
        }

        /// Optimize memory pool sizes based on usage patterns
        fn optimize_pool_sizes(&mut self, stats: &MemoryStats) -> ScirsResult<()> {
            // Analyze usage patterns and adjust pool sizes
            for (&_size, &count) in &stats.pool_sizes {
                if count > self.config.max_pool_size {
                    // Pool is too large; consider shrinking it
                    self.optimization_stats.pool_optimizations += 1;
                }
            }
            Ok(())
        }

        /// Get optimization statistics
        pub fn stats(&self) -> &OptimizationStats {
            &self.optimization_stats
        }
    }

    /// Statistics for memory optimization
    #[derive(Debug, Clone)]
    pub struct OptimizationStats {
        /// Number of times GC was triggered by the optimizer
        pub gc_triggered: u64,
        /// Total GC operations performed
        pub gc_operations: u64,
        /// Pool size optimizations performed
        pub pool_optimizations: u64,
    }

    impl OptimizationStats {
        fn new() -> Self {
            Self {
                gc_triggered: 0,
                gc_operations: 0,
                pool_optimizations: 0,
            }
        }
    }
}

/// Utilities for memory management
pub mod utils {
    use super::*;

    /// Calculate optimal memory allocation strategy
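    ///
    /// The choice is driven by `estimate_memory_usage`: an estimate above the
    /// available memory selects `Chunked`, above half of it `Conservative`,
    /// and anything smaller `Aggressive`. For example, `problem_size = 1000`,
    /// `batch_size = 1000`, and 500_000 bytes available gives a ~16 MB
    /// estimate, so `Chunked` is returned.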
    pub fn calculate_allocation_strategy(
        problem_size: usize,
        batch_size: usize,
        available_memory: usize,
    ) -> AllocationStrategy {
        let estimated_usage = estimate_memory_usage(problem_size, batch_size);

        if estimated_usage > available_memory {
            AllocationStrategy::Chunked {
                chunk_size: available_memory / 2,
                overlap: true,
            }
        } else if estimated_usage > available_memory / 2 {
            AllocationStrategy::Conservative {
                pool_size_limit: available_memory / 4,
            }
        } else {
            AllocationStrategy::Aggressive {
                prefetch_size: estimated_usage * 2,
            }
        }
    }

    /// Estimate memory usage for a given problem
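    ///
    /// Worked example: `estimate_memory_usage(10, 100)` gives
    /// 100 * 10 * 8 (input) + 100 * 8 (output) + 100 * 10 * 8 (temporaries)
    /// = 8000 + 800 + 8000 = 16_800 bytes.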
    pub fn estimate_memory_usage(problem_size: usize, batch_size: usize) -> usize {
        // Rough estimation: input data + output data + temporary buffers
        let input_size = batch_size * problem_size * 8; // f64 elements
        let output_size = batch_size * 8; // one f64 per batch entry
        let temp_size = input_size; // temporary arrays

        input_size + output_size + temp_size
    }

    /// Memory allocation strategies
    #[derive(Debug, Clone)]
    pub enum AllocationStrategy {
        /// Use chunks with optional overlap for large problems
        Chunked { chunk_size: usize, overlap: bool },
        /// Conservative allocation with limited pool sizes
        Conservative { pool_size_limit: usize },
        /// Aggressive allocation with prefetching
        Aggressive { prefetch_size: usize },
    }

    /// Check if the system has sufficient memory for an operation
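    ///
    /// A 10% safety margin is kept free: with 1000 bytes reported free, 900
    /// bytes are considered usable, so a 900-byte request passes and a
    /// 901-byte request does not.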
    pub fn check_memory_availability(
        required_memory: usize,
        memory_info: &GpuMemoryInfo,
    ) -> ScirsResult<bool> {
        let available = memory_info.free;
        let safety_margin = 0.1; // Keep 10% free
        let usable = (available as f64 * (1.0 - safety_margin)) as usize;

        Ok(required_memory <= usable)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_allocation_stats() {
        let mut stats = AllocationStats::new();
        stats.total_allocations = 100;
        stats.pool_hits = 70;

        assert_eq!(stats.hit_rate(), 0.7);
    }

    #[test]
    fn test_memory_stats_utilization() {
        let stats = MemoryStats {
            current_usage: 800,
            memory_limit: Some(1000),
            allocation_stats: AllocationStats::new(),
            pool_sizes: HashMap::new(),
        };

        assert_eq!(stats.utilization(), Some(0.8));
    }

    #[test]
    fn test_memory_usage_estimation() {
        let usage = utils::estimate_memory_usage(10, 100);
        assert!(usage > 0);

        // Should scale with problem size and batch size
        let larger_usage = utils::estimate_memory_usage(20, 200);
        assert!(larger_usage > usage);
    }

    #[test]
    fn test_allocation_strategy() {
        let strategy = utils::calculate_allocation_strategy(
            1000,    // Large problem
            1000,    // Large batch
            500_000, // Limited memory
        );

        match strategy {
            utils::AllocationStrategy::Chunked { .. } => {
                // Expected for large problems with limited memory
            }
            _ => panic!("Expected chunked strategy for large problem"),
        }
    }
}
653}