scirs2_optimize/gpu/memory_management.rs

//! GPU memory management for optimization workloads
//!
//! This module provides efficient memory management for GPU-accelerated optimization,
//! including memory pools, workspace allocation, and automatic memory optimization.
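//!
//! # Example
//!
//! A minimal usage sketch (not a compiled doctest); it assumes the CPU
//! fallback backend from `scirs2_core::gpu` is available.
//!
//! ```ignore
//! use std::sync::Arc;
//! use scirs2_core::gpu::{GpuBackend, GpuContext};
//!
//! let context = Arc::new(GpuContext::new(GpuBackend::Cpu)?);
//! // Pool with a 64 MiB limit; pass `None` for no limit.
//! let mut pool = GpuMemoryPool::new(context, Some(64 * 1024 * 1024))?;
//!
//! // Workspace blocks are returned to the pool when the workspace is dropped.
//! let workspace = pool.allocate_workspace(1024)?;
//! assert_eq!(workspace.total_size(), 1024);
//!
//! println!("{}", pool.memory_stats().generate_report());
//! ```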

use crate::error::{ScirsError, ScirsResult};

// Note: Error conversion is handled through the scirs2_core::error system;
// GPU errors are converted automatically via the CoreError type alias.
use scirs2_core::gpu::{GpuBuffer, GpuContext};
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};

// Convenience aliases for the optimizer's GPU buffers and arrays.
pub type OptimGpuArray<T> = GpuBuffer<T>;
pub type OptimGpuBuffer<T> = GpuBuffer<T>;

/// GPU memory information structure
#[derive(Debug, Clone)]
pub struct GpuMemoryInfo {
    /// Total device memory in bytes
    pub total: usize,
    /// Free device memory in bytes
    pub free: usize,
    /// Used device memory in bytes
    pub used: usize,
}

/// GPU memory pool for efficient allocation and reuse
pub struct GpuMemoryPool {
    context: Arc<GpuContext>,
    pools: Arc<Mutex<HashMap<usize, VecDeque<GpuMemoryBlock>>>>,
    allocated_blocks: Arc<Mutex<Vec<AllocatedBlock>>>,
    memory_limit: Option<usize>,
    current_usage: Arc<Mutex<usize>>,
    allocation_stats: Arc<Mutex<AllocationStats>>,
}

impl GpuMemoryPool {
    /// Create a new GPU memory pool
    pub fn new(context: Arc<GpuContext>, memory_limit: Option<usize>) -> ScirsResult<Self> {
        Ok(Self {
            context,
            pools: Arc::new(Mutex::new(HashMap::new())),
            allocated_blocks: Arc::new(Mutex::new(Vec::new())),
            memory_limit,
            current_usage: Arc::new(Mutex::new(0)),
            allocation_stats: Arc::new(Mutex::new(AllocationStats::new())),
        })
    }

    /// Create a stub GPU memory pool backed by the CPU fallback backend
    /// (useful while the GPU implementations are incomplete)
    pub fn new_stub() -> Self {
        use scirs2_core::gpu::GpuBackend;
        let context = GpuContext::new(GpuBackend::Cpu).expect("CPU backend should always work");
        Self {
            context: Arc::new(context),
            pools: Arc::new(Mutex::new(HashMap::new())),
            allocated_blocks: Arc::new(Mutex::new(Vec::new())),
            memory_limit: None,
            current_usage: Arc::new(Mutex::new(0)),
            allocation_stats: Arc::new(Mutex::new(AllocationStats::new())),
        }
    }

    /// Allocate a workspace for optimization operations
    pub fn allocate_workspace(&mut self, size: usize) -> ScirsResult<GpuWorkspace> {
        let block = self.allocate_block(size)?;
        Ok(GpuWorkspace::new(block, Arc::clone(&self.pools)))
    }

    /// Allocate a memory block of the specified size
    fn allocate_block(&mut self, size: usize) -> ScirsResult<GpuMemoryBlock> {
        let mut stats = self
            .allocation_stats
            .lock()
            .expect("allocation stats lock poisoned");
        stats.total_allocations += 1;

        // Check the memory limit before allocating
        if let Some(limit) = self.memory_limit {
            let current = *self.current_usage.lock().expect("usage counter lock poisoned");
            if current + size > limit {
                // Drop the stats lock before garbage collection
                drop(stats);
                // Try to free some memory
                self.garbage_collect()?;
                // Reacquire the lock and re-check the limit
                stats = self
                    .allocation_stats
                    .lock()
                    .expect("allocation stats lock poisoned");
                let current = *self.current_usage.lock().expect("usage counter lock poisoned");
                if current + size > limit {
                    return Err(ScirsError::MemoryError(
                        scirs2_core::error::ErrorContext::new(format!(
                            "Would exceed memory limit: {} + {} > {}",
                            current, size, limit
                        ))
                        .with_location(scirs2_core::error::ErrorLocation::new(file!(), line!())),
                    ));
                }
            }
        }

        // Try to reuse an existing block from the pool
        let mut pools = self.pools.lock().expect("pool map lock poisoned");
        if let Some(pool) = pools.get_mut(&size) {
            if let Some(block) = pool.pop_front() {
                stats.pool_hits += 1;
                return Ok(block);
            }
        }

        // Allocate a new block
        stats.new_allocations += 1;
        let gpu_buffer = self.context.create_buffer::<u8>(size);
        // The raw pointer is a placeholder; the allocation itself is owned and
        // managed by the `GpuBuffer`.
        let ptr = std::ptr::null_mut();
        let block = GpuMemoryBlock {
            size,
            ptr,
            gpu_buffer: Some(gpu_buffer),
        };

        // Update current usage
        *self.current_usage.lock().expect("usage counter lock poisoned") += size;

        Ok(block)
    }

    /// Return a block to the pool for reuse
    fn return_block(&self, block: GpuMemoryBlock) {
        let mut pools = self.pools.lock().expect("pool map lock poisoned");
        pools.entry(block.size).or_default().push_back(block);
    }

    /// Perform garbage collection to free unused memory
    fn garbage_collect(&mut self) -> ScirsResult<()> {
        let mut pools = self.pools.lock().expect("pool map lock poisoned");
        let mut freed_memory = 0;

        // Clear all pools
        for (size, pool) in pools.iter_mut() {
            let count = pool.len();
            freed_memory += size * count;
            pool.clear();
        }

        // Update current usage. Lock once: locking the same mutex on both
        // sides of an assignment would deadlock, since `std::sync::Mutex`
        // is not reentrant.
        let mut usage = self.current_usage.lock().expect("usage counter lock poisoned");
        *usage = usage.saturating_sub(freed_memory);
        drop(usage);

        // Update stats
        let mut stats = self
            .allocation_stats
            .lock()
            .expect("allocation stats lock poisoned");
        stats.garbage_collections += 1;
        stats.total_freed_memory += freed_memory;

        Ok(())
    }

    /// Get current memory usage statistics
    pub fn memory_stats(&self) -> MemoryStats {
        let current_usage = *self.current_usage.lock().expect("usage counter lock poisoned");
        let allocation_stats = self
            .allocation_stats
            .lock()
            .expect("allocation stats lock poisoned")
            .clone();
        let pool_sizes: HashMap<usize, usize> = self
            .pools
            .lock()
            .expect("pool map lock poisoned")
            .iter()
            .map(|(&size, pool)| (size, pool.len()))
            .collect();

        MemoryStats {
            current_usage,
            memory_limit: self.memory_limit,
            allocation_stats,
            pool_sizes,
        }
    }
}

/// A block of GPU memory
pub struct GpuMemoryBlock {
    size: usize,
    ptr: *mut u8,
    gpu_buffer: Option<OptimGpuBuffer<u8>>,
}

impl std::fmt::Debug for GpuMemoryBlock {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GpuMemoryBlock")
            .field("size", &self.size)
            .field("ptr", &self.ptr)
            .field("gpu_buffer", &self.gpu_buffer.is_some())
            .finish()
    }
}

// SAFETY: the raw pointer is only a placeholder and is never dereferenced;
// the block is treated as an opaque handle to a GPU-side allocation.
unsafe impl Send for GpuMemoryBlock {}
unsafe impl Sync for GpuMemoryBlock {}

impl GpuMemoryBlock {
    /// Get the size of this memory block
    pub fn size(&self) -> usize {
        self.size
    }

    /// Get the raw pointer to GPU memory
    pub fn ptr(&self) -> *mut u8 {
        self.ptr
    }

    /// Cast to a specific type
    pub fn as_typed<T: scirs2_core::GpuDataType>(&self) -> ScirsResult<&OptimGpuBuffer<T>> {
        if self.gpu_buffer.is_some() {
            // Safe casting would go through scirs2_core's type system, but no
            // `cast_type` API exists yet, so report the operation as unsupported.
            Err(ScirsError::ComputationError(
                scirs2_core::error::ErrorContext::new("Type casting not supported".to_string()),
            ))
        } else {
            Err(ScirsError::InvalidInput(
                scirs2_core::error::ErrorContext::new("Memory block not available".to_string()),
            ))
        }
    }
}

/// Workspace for GPU optimization operations
pub struct GpuWorkspace {
    blocks: Vec<GpuMemoryBlock>,
    pool: Arc<Mutex<HashMap<usize, VecDeque<GpuMemoryBlock>>>>,
}

impl GpuWorkspace {
    fn new(
        initial_block: GpuMemoryBlock,
        pool: Arc<Mutex<HashMap<usize, VecDeque<GpuMemoryBlock>>>>,
    ) -> Self {
        Self {
            blocks: vec![initial_block],
            pool,
        }
    }

    /// Get a memory block of at least the specified size
    pub fn get_block(&mut self, size: usize) -> ScirsResult<&GpuMemoryBlock> {
        // Try to find an existing block of sufficient size
        for block in &self.blocks {
            if block.size >= size {
                return Ok(block);
            }
        }

        // A new block would have to be allocated from the pool; that path is
        // not implemented yet, so report the allocation failure.
        Err(ScirsError::MemoryError(
            scirs2_core::error::ErrorContext::new("No suitable block available".to_string()),
        ))
    }

    /// Get a typed buffer view of the specified size
    pub fn get_buffer<T: scirs2_core::GpuDataType>(
        &mut self,
        size: usize,
    ) -> ScirsResult<&OptimGpuBuffer<T>> {
        let size_bytes = size * std::mem::size_of::<T>();
        let block = self.get_block(size_bytes)?;
        block.as_typed::<T>()
    }

    /// Create a GPU array view from the workspace
    pub fn create_array<T>(&mut self, dimensions: &[usize]) -> ScirsResult<OptimGpuArray<T>>
    where
        T: Clone + Default + 'static + scirs2_core::GpuDataType,
    {
        let total_elements: usize = dimensions.iter().product();
        let _buffer = self.get_buffer::<T>(total_elements)?;

        // Converting the buffer to an array would use scirs2_core's reshape
        // functionality, but no `from_buffer` API exists yet, so report the
        // operation as unsupported.
        Err(ScirsError::ComputationError(
            scirs2_core::error::ErrorContext::new("Array creation not supported".to_string()),
        ))
    }

    /// Get total workspace size in bytes
    pub fn total_size(&self) -> usize {
        self.blocks.iter().map(|b| b.size).sum()
    }
}

impl Drop for GpuWorkspace {
    fn drop(&mut self) {
        // Return all blocks to the pool for reuse
        let mut pool = self.pool.lock().expect("pool map lock poisoned");
        for block in self.blocks.drain(..) {
            pool.entry(block.size).or_default().push_back(block);
        }
    }
}

/// Tracks an allocated memory block (bookkeeping only; not yet populated)
#[derive(Debug)]
#[allow(dead_code)]
struct AllocatedBlock {
    size: usize,
    allocated_at: std::time::Instant,
}

/// Statistics for memory allocations
#[derive(Debug, Clone)]
pub struct AllocationStats {
    /// Total number of allocation requests
    pub total_allocations: u64,
    /// Number of allocations served from the pool
    pub pool_hits: u64,
    /// Number of new allocations
    pub new_allocations: u64,
    /// Number of garbage collections performed
    pub garbage_collections: u64,
    /// Total memory freed by garbage collection, in bytes
    pub total_freed_memory: usize,
}

impl AllocationStats {
    fn new() -> Self {
        Self {
            total_allocations: 0,
            pool_hits: 0,
            new_allocations: 0,
            garbage_collections: 0,
            total_freed_memory: 0,
        }
    }

    /// Calculate the pool hit rate (0.0 when no allocations have been made)
    pub fn hit_rate(&self) -> f64 {
        if self.total_allocations == 0 {
            0.0
        } else {
            self.pool_hits as f64 / self.total_allocations as f64
        }
    }
}

/// Overall memory usage statistics
#[derive(Debug, Clone)]
pub struct MemoryStats {
    /// Current memory usage in bytes
    pub current_usage: usize,
    /// Memory limit in bytes (if set)
    pub memory_limit: Option<usize>,
    /// Allocation statistics
    pub allocation_stats: AllocationStats,
    /// Number of pooled blocks per block size
    pub pool_sizes: HashMap<usize, usize>,
}

impl MemoryStats {
    /// Get memory utilization as a fraction of the limit (if a limit is set)
    pub fn utilization(&self) -> Option<f64> {
        self.memory_limit.map(|limit| {
            if limit == 0 {
                0.0
            } else {
                self.current_usage as f64 / limit as f64
            }
        })
    }

    /// Generate a human-readable memory usage report
    pub fn generate_report(&self) -> String {
        let mut report = String::from("GPU Memory Usage Report\n");
        report.push_str("=======================\n\n");

        report.push_str(&format!(
            "Current Usage: {} bytes ({:.2} MiB)\n",
            self.current_usage,
            self.current_usage as f64 / 1024.0 / 1024.0
        ));

        if let Some(limit) = self.memory_limit {
            report.push_str(&format!(
                "Memory Limit: {} bytes ({:.2} MiB)\n",
                limit,
                limit as f64 / 1024.0 / 1024.0
            ));

            if let Some(util) = self.utilization() {
                report.push_str(&format!("Utilization: {:.1}%\n", util * 100.0));
            }
        }

        report.push('\n');
        report.push_str("Allocation Statistics:\n");
        report.push_str(&format!(
            "  Total Allocations: {}\n",
            self.allocation_stats.total_allocations
        ));
        report.push_str(&format!(
            "  Pool Hits: {} ({:.1}%)\n",
            self.allocation_stats.pool_hits,
            self.allocation_stats.hit_rate() * 100.0
        ));
        report.push_str(&format!(
            "  New Allocations: {}\n",
            self.allocation_stats.new_allocations
        ));
        report.push_str(&format!(
            "  Garbage Collections: {}\n",
            self.allocation_stats.garbage_collections
        ));
        report.push_str(&format!(
            "  Total Freed: {} bytes\n",
            self.allocation_stats.total_freed_memory
        ));

        if !self.pool_sizes.is_empty() {
            report.push('\n');
            report.push_str("Memory Pools:\n");
            let mut pools: Vec<_> = self.pool_sizes.iter().collect();
            pools.sort_by_key(|&(size, _)| size);
            for (&size, &count) in pools {
                report.push_str(&format!("  {} bytes: {} blocks\n", size, count));
            }
        }

        report
    }
}

/// Memory optimization strategies
pub mod optimization {
    use super::*;

    /// Automatic memory optimization configuration
    #[derive(Debug, Clone)]
    pub struct MemoryOptimizationConfig {
        /// Target memory utilization (0.0 to 1.0)
        pub target_utilization: f64,
        /// Maximum pool size per block size
        pub max_pool_size: usize,
        /// Garbage collection threshold (utilization fraction, 0.0 to 1.0)
        pub gc_threshold: f64,
        /// Whether to use memory prefetching
        pub use_prefetching: bool,
    }

    impl Default for MemoryOptimizationConfig {
        fn default() -> Self {
            Self {
                target_utilization: 0.8,
                max_pool_size: 100,
                gc_threshold: 0.9,
                use_prefetching: true,
            }
        }
    }

    /// Memory optimizer for automatic memory management
    pub struct MemoryOptimizer {
        config: MemoryOptimizationConfig,
        pool: Arc<GpuMemoryPool>,
        optimization_stats: OptimizationStats,
    }

    impl MemoryOptimizer {
        /// Create a new memory optimizer
        pub fn new(config: MemoryOptimizationConfig, pool: Arc<GpuMemoryPool>) -> Self {
            Self {
                config,
                pool,
                optimization_stats: OptimizationStats::new(),
            }
        }

        /// Optimize memory usage based on current statistics
        pub fn optimize(&mut self) -> ScirsResult<()> {
            let stats = self.pool.memory_stats();

            // Trigger garbage collection when utilization exceeds the threshold
            if let Some(utilization) = stats.utilization() {
                if utilization > self.config.gc_threshold {
                    self.perform_garbage_collection()?;
                    self.optimization_stats.gc_triggered += 1;
                }
            }

            // Optimize pool sizes
            self.optimize_pool_sizes(&stats)?;

            Ok(())
        }

        /// Perform targeted garbage collection
        fn perform_garbage_collection(&mut self) -> ScirsResult<()> {
            // A full implementation would collect selectively based on block
            // age and size; for now, only the operation count is recorded.
            self.optimization_stats.gc_operations += 1;
            Ok(())
        }

        /// Optimize memory pool sizes based on usage patterns
        fn optimize_pool_sizes(&mut self, stats: &MemoryStats) -> ScirsResult<()> {
            // Analyze usage patterns and flag oversized pools
            for (&_size, &count) in &stats.pool_sizes {
                if count > self.config.max_pool_size {
                    // Pool is too large; a full implementation would shrink it
                    self.optimization_stats.pool_optimizations += 1;
                }
            }
            Ok(())
        }

        /// Get optimization statistics
        pub fn stats(&self) -> &OptimizationStats {
            &self.optimization_stats
        }
    }

    /// Statistics for memory optimization
    #[derive(Debug, Clone)]
    pub struct OptimizationStats {
        /// Number of times GC was triggered by the optimizer
        pub gc_triggered: u64,
        /// Total GC operations performed
        pub gc_operations: u64,
        /// Pool size optimizations performed
        pub pool_optimizations: u64,
    }

    impl OptimizationStats {
        fn new() -> Self {
            Self {
                gc_triggered: 0,
                gc_operations: 0,
                pool_optimizations: 0,
            }
        }
    }
}

/// Utilities for memory management
pub mod utils {
    use super::*;

    /// Calculate an allocation strategy from the estimated memory footprint
    pub fn calculate_allocation_strategy(
        problem_size: usize,
        batch_size: usize,
        available_memory: usize,
    ) -> AllocationStrategy {
        let estimated_usage = estimate_memory_usage(problem_size, batch_size);

        if estimated_usage > available_memory {
            AllocationStrategy::Chunked {
                chunk_size: available_memory / 2,
                overlap: true,
            }
        } else if estimated_usage > available_memory / 2 {
            AllocationStrategy::Conservative {
                pool_size_limit: available_memory / 4,
            }
        } else {
            AllocationStrategy::Aggressive {
                prefetch_size: estimated_usage * 2,
            }
        }
    }

    /// Estimate memory usage in bytes for a given problem
    pub fn estimate_memory_usage(problem_size: usize, batch_size: usize) -> usize {
        // Rough estimate: input data + output data + temporary buffers
        let input_size = batch_size * problem_size * 8; // f64 elements
        let output_size = batch_size * 8; // one f64 per batch entry
        let temp_size = input_size; // temporary arrays mirror the input

        input_size + output_size + temp_size
    }

    /// Memory allocation strategies
    #[derive(Debug, Clone)]
    pub enum AllocationStrategy {
        /// Use chunks with optional overlap for large problems
        Chunked { chunk_size: usize, overlap: bool },
        /// Conservative allocation with limited pool sizes
        Conservative { pool_size_limit: usize },
        /// Aggressive allocation with prefetching
        Aggressive { prefetch_size: usize },
    }

    /// Check whether the system has sufficient memory for an operation
    pub fn check_memory_availability(
        required_memory: usize,
        memory_info: &GpuMemoryInfo,
    ) -> ScirsResult<bool> {
        let available = memory_info.free;
        let safety_margin = 0.1; // keep 10% of free memory in reserve
        let usable = (available as f64 * (1.0 - safety_margin)) as usize;

        Ok(required_memory <= usable)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_allocation_stats() {
        let mut stats = AllocationStats::new();
        stats.total_allocations = 100;
        stats.pool_hits = 70;

        assert_eq!(stats.hit_rate(), 0.7);
    }
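
    // Added edge-case check: `hit_rate` should return 0.0, not NaN, when no
    // allocations have been recorded.
    #[test]
    fn test_hit_rate_with_no_allocations() {
        let stats = AllocationStats::new();
        assert_eq!(stats.hit_rate(), 0.0);
    }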

    #[test]
    fn test_memory_stats_utilization() {
        let stats = MemoryStats {
            current_usage: 800,
            memory_limit: Some(1000),
            allocation_stats: AllocationStats::new(),
            pool_sizes: HashMap::new(),
        };

        assert_eq!(stats.utilization(), Some(0.8));
    }
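
    // Added edge-case check: a zero-byte limit should report 0.0 utilization
    // rather than dividing by zero.
    #[test]
    fn test_memory_stats_utilization_zero_limit() {
        let stats = MemoryStats {
            current_usage: 100,
            memory_limit: Some(0),
            allocation_stats: AllocationStats::new(),
            pool_sizes: HashMap::new(),
        };

        assert_eq!(stats.utilization(), Some(0.0));
    }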

    #[test]
    fn test_memory_usage_estimation() {
        let usage = utils::estimate_memory_usage(10, 100);
        assert!(usage > 0);

        // Usage should scale with problem size and batch size
        let larger_usage = utils::estimate_memory_usage(20, 200);
        assert!(larger_usage > usage);
    }
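
    // Added check pinning down the documented estimate formula:
    // input (batch * problem * 8) + output (batch * 8) + temp (= input).
    #[test]
    fn test_memory_usage_estimate_formula() {
        // 100 * 10 * 8 = 8000 input, 100 * 8 = 800 output, 8000 temp.
        assert_eq!(utils::estimate_memory_usage(10, 100), 16_800);
    }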

    #[test]
    fn test_allocation_strategy() {
        let strategy = utils::calculate_allocation_strategy(
            1000,    // large problem
            1000,    // large batch
            500_000, // limited memory
        );

        match strategy {
            utils::AllocationStrategy::Chunked { .. } => {
                // Expected for large problems with limited memory
            }
            _ => panic!("Expected chunked strategy for large problem"),
        }
    }
652}