// rustkernel_core/memory/mod.rs
//! Memory Management Infrastructure
//!
//! Provides GPU memory pooling, pressure handling, and analytics context management
//! for high-performance kernel execution.
//!
//! # Features
//!
//! - **Memory Pools**: Size-stratified pools for efficient GPU allocation
//! - **Pressure Handling**: Automatic memory pressure detection and mitigation
//! - **Analytics Context**: Reusable buffers for analytics workloads
//! - **Reduction Buffers**: Cached buffers for GPU reduction operations
//! - **Multi-phase Reductions**: Synchronization primitives for iterative algorithms
//!
//! # Example
//!
//! ```rust,ignore
//! use rustkernel_core::memory::{KernelMemoryManager, MemoryConfig};
//!
//! let config = MemoryConfig::production();
//! let manager = KernelMemoryManager::new(config);
//!
//! // Allocate from pool (the allocation API is async)
//! let buffer = manager.allocate(1024 * 1024).await?; // 1MB
//!
//! // Return to pool
//! manager.deallocate(buffer).await?;
//! ```
pub mod reduction;

pub use reduction::{
    CooperativeBarrier, GlobalReduction, InterPhaseReduction, PhaseState, ReductionBuilder,
    ReductionConfig, ReductionError, ReductionOp, SyncMode,
};

// Re-export ringkernel-core 0.4.2 memory primitives for deep integration.
pub use ringkernel_core::analytics_context as ring_analytics_context;
pub use ringkernel_core::memory as ring_memory;
pub use ringkernel_core::reduction as ring_reduction;

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use tokio::sync::RwLock;
46
/// Memory configuration
///
/// Tunable limits and pool settings consumed by `KernelMemoryManager`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryConfig {
    /// Maximum total GPU memory to use (bytes)
    pub max_gpu_memory: u64,
    /// Maximum total CPU staging memory (bytes)
    pub max_staging_memory: u64,
    /// Enable memory pooling
    pub pooling_enabled: bool,
    /// Pool bucket sizes (bytes); pool lookup picks the first bucket whose
    /// size fits a request, so ascending order is assumed
    pub bucket_sizes: Vec<u64>,
    /// Pressure threshold (0.0-1.0)
    pub pressure_threshold: f64,
    /// Enable automatic defragmentation
    pub auto_defrag: bool,
    /// Defrag threshold (fragmentation ratio)
    pub defrag_threshold: f64,
}
65
66impl Default for MemoryConfig {
67    fn default() -> Self {
68        Self {
69            max_gpu_memory: 4 * 1024 * 1024 * 1024, // 4GB
70            max_staging_memory: 1024 * 1024 * 1024, // 1GB
71            pooling_enabled: true,
72            bucket_sizes: vec![
73                64 * 1024,        // 64KB
74                256 * 1024,       // 256KB
75                1024 * 1024,      // 1MB
76                4 * 1024 * 1024,  // 4MB
77                16 * 1024 * 1024, // 16MB
78                64 * 1024 * 1024, // 64MB
79            ],
80            pressure_threshold: 0.85,
81            auto_defrag: true,
82            defrag_threshold: 0.3,
83        }
84    }
85}
86
87impl MemoryConfig {
88    /// Create development configuration (smaller limits)
89    pub fn development() -> Self {
90        Self {
91            max_gpu_memory: 512 * 1024 * 1024,     // 512MB
92            max_staging_memory: 256 * 1024 * 1024, // 256MB
93            pooling_enabled: false,
94            ..Default::default()
95        }
96    }
97
98    /// Create production configuration
99    pub fn production() -> Self {
100        Self::default()
101    }
102
103    /// Create high-performance configuration
104    pub fn high_performance() -> Self {
105        Self {
106            max_gpu_memory: 16 * 1024 * 1024 * 1024,    // 16GB
107            max_staging_memory: 4 * 1024 * 1024 * 1024, // 4GB
108            pooling_enabled: true,
109            auto_defrag: true,
110            defrag_threshold: 0.2,
111            ..Default::default()
112        }
113    }
114}
115
/// Size bucket for memory pool
///
/// Tracks allocation counters for a single pool size class. All counters are
/// atomic so a bucket can be shared across threads without locking.
#[derive(Debug)]
pub struct SizeBucket {
    /// Bucket size in bytes
    pub size: u64,
    /// Number of available buffers
    pub available: AtomicUsize,
    /// Number of allocated buffers
    pub allocated: AtomicUsize,
    /// Peak allocation (high-water mark of `allocated`)
    pub peak: AtomicUsize,
}

impl SizeBucket {
    /// Create a new size bucket with all counters at zero.
    pub fn new(size: u64) -> Self {
        Self {
            size,
            available: AtomicUsize::new(0),
            allocated: AtomicUsize::new(0),
            peak: AtomicUsize::new(0),
        }
    }

    /// Record an allocation, bumping the live count and the peak watermark.
    pub fn record_alloc(&self) {
        let count = self.allocated.fetch_add(1, Ordering::Relaxed) + 1;
        // fetch_max replaces the previous hand-rolled compare-exchange loop:
        // it atomically stores the larger of the current peak and `count`.
        self.peak.fetch_max(count, Ordering::Relaxed);
    }

    /// Record a deallocation, decrementing the live count.
    pub fn record_dealloc(&self) {
        self.allocated.fetch_sub(1, Ordering::Relaxed);
    }

    /// Get a point-in-time snapshot of bucket statistics.
    pub fn stats(&self) -> BucketStats {
        BucketStats {
            size: self.size,
            available: self.available.load(Ordering::Relaxed),
            allocated: self.allocated.load(Ordering::Relaxed),
            peak: self.peak.load(Ordering::Relaxed),
        }
    }
}

/// Statistics for a size bucket
#[derive(Debug, Clone)]
pub struct BucketStats {
    /// Bucket size in bytes
    pub size: u64,
    /// Number of available buffers
    pub available: usize,
    /// Number of allocated buffers
    pub allocated: usize,
    /// Peak allocation
    pub peak: usize,
}
183
/// Memory buffer handle
///
/// Lightweight record describing one tracked allocation; returned by the
/// manager's allocate methods and consumed by its deallocate methods.
#[derive(Debug)]
pub struct MemoryBuffer {
    /// Buffer ID (unique per manager instance)
    pub id: u64,
    /// Size in bytes
    pub size: u64,
    /// Bucket index (if from pool)
    pub bucket_index: Option<usize>,
    /// Whether buffer is GPU memory (`false` = CPU staging memory)
    pub is_gpu: bool,
}
196
/// Memory allocation result, with [`MemoryError`] as the failure type.
pub type AllocResult<T> = std::result::Result<T, MemoryError>;
199
/// Memory errors
///
/// Failure modes surfaced by the memory-manager allocation APIs.
#[derive(Debug, thiserror::Error)]
pub enum MemoryError {
    /// Out of memory: the request would exceed the configured budget
    #[error("Out of memory: requested {requested} bytes, available {available} bytes")]
    OutOfMemory {
        /// Requested size in bytes
        requested: u64,
        /// Remaining budget in bytes
        available: u64,
    },

    /// Memory pressure exceeded: allocation refused under critical pressure
    #[error("Memory pressure exceeded: {usage_percent:.1}% usage")]
    PressureExceeded {
        /// Current usage percentage (0-100)
        usage_percent: f64,
    },

    /// Invalid buffer: the handle is not (or no longer) registered
    #[error("Invalid buffer: {id}")]
    InvalidBuffer {
        /// Buffer ID
        id: u64,
    },

    /// Allocation failed for a backend-specific reason
    #[error("Allocation failed: {reason}")]
    AllocationFailed {
        /// Failure reason
        reason: String,
    },
}
233
/// Memory pressure level
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PressureLevel {
    /// Normal operation
    #[default]
    Normal,
    /// Elevated usage, start cleanup
    Warning,
    /// High usage, defer allocations
    High,
    /// Critical usage, emergency cleanup
    Critical,
}

impl PressureLevel {
    /// Map a usage ratio (used / total) onto a pressure level.
    ///
    /// Thresholds: below 0.70 is `Normal`, below 0.85 `Warning`,
    /// below 0.95 `High`, and anything at or above 0.95 `Critical`.
    pub fn from_ratio(ratio: f64) -> Self {
        match ratio {
            r if r < 0.70 => Self::Normal,
            r if r < 0.85 => Self::Warning,
            r if r < 0.95 => Self::High,
            _ => Self::Critical,
        }
    }
}
262
/// Memory statistics
///
/// Point-in-time snapshot produced by the manager's `stats()` method.
#[derive(Debug, Clone, Default)]
pub struct MemoryStats {
    /// Total GPU memory (bytes) — the configured budget
    pub gpu_total: u64,
    /// Used GPU memory (bytes)
    pub gpu_used: u64,
    /// Peak GPU memory (bytes)
    pub gpu_peak: u64,
    /// Total staging memory (bytes) — the configured budget
    pub staging_total: u64,
    /// Used staging memory (bytes)
    pub staging_used: u64,
    /// Number of allocations (GPU and staging combined)
    pub allocations: u64,
    /// Number of deallocations (GPU and staging combined)
    pub deallocations: u64,
    /// Pool hit rate in [0.0, 1.0]; 0.0 when no pool activity was recorded
    pub pool_hit_rate: f64,
    /// Current pressure level derived from GPU usage
    pub pressure_level: PressureLevel,
}
285
/// Kernel memory manager
///
/// Tracks GPU and staging allocations against configured budgets, keeps
/// per-size-class pool counters, and answers pressure/statistics queries.
pub struct KernelMemoryManager {
    // Immutable configuration captured at construction.
    config: MemoryConfig,
    // One bucket per configured size class, in `config.bucket_sizes` order.
    buckets: Vec<SizeBucket>,
    // Shared atomic counters backing `stats()`.
    stats: Arc<MemoryStatsInner>,
    // Registry of live buffers keyed by buffer ID.
    buffers: Arc<RwLock<HashMap<u64, MemoryBuffer>>>,
    // Monotonic ID source for buffer handles (starts at 1).
    next_id: AtomicU64,
}
294
/// Internal atomic counters behind `KernelMemoryManager::stats`.
#[derive(Debug, Default)]
struct MemoryStatsInner {
    // Bytes of GPU memory currently accounted as in use.
    gpu_used: AtomicU64,
    // High-water mark of `gpu_used`.
    gpu_peak: AtomicU64,
    // Bytes of CPU staging memory currently in use.
    staging_used: AtomicU64,
    // Total successful allocations (GPU + staging).
    allocations: AtomicU64,
    // Total deallocations (GPU + staging).
    deallocations: AtomicU64,
    // Allocations satisfied by a pool bucket.
    pool_hits: AtomicU64,
    // Pool-enabled allocations with no fitting bucket.
    pool_misses: AtomicU64,
}
305
306impl KernelMemoryManager {
307    /// Create a new memory manager
308    pub fn new(config: MemoryConfig) -> Self {
309        let buckets = config
310            .bucket_sizes
311            .iter()
312            .map(|&size| SizeBucket::new(size))
313            .collect();
314
315        Self {
316            config,
317            buckets,
318            stats: Arc::new(MemoryStatsInner::default()),
319            buffers: Arc::new(RwLock::new(HashMap::new())),
320            next_id: AtomicU64::new(1),
321        }
322    }
323
324    /// Get configuration
325    pub fn config(&self) -> &MemoryConfig {
326        &self.config
327    }
328
329    /// Allocate GPU memory
330    pub async fn allocate(&self, size: u64) -> AllocResult<MemoryBuffer> {
331        // Check pressure
332        let pressure = self.pressure_level();
333        if pressure == PressureLevel::Critical {
334            return Err(MemoryError::PressureExceeded {
335                usage_percent: self.gpu_usage_percent(),
336            });
337        }
338
339        // Check limits
340        let current_used = self.stats.gpu_used.load(Ordering::Relaxed);
341        if current_used + size > self.config.max_gpu_memory {
342            return Err(MemoryError::OutOfMemory {
343                requested: size,
344                available: self.config.max_gpu_memory - current_used,
345            });
346        }
347
348        // Try pool allocation
349        let bucket_index = if self.config.pooling_enabled {
350            self.find_bucket(size)
351        } else {
352            None
353        };
354
355        if let Some(idx) = bucket_index {
356            self.stats.pool_hits.fetch_add(1, Ordering::Relaxed);
357            self.buckets[idx].record_alloc();
358        } else if self.config.pooling_enabled {
359            self.stats.pool_misses.fetch_add(1, Ordering::Relaxed);
360        }
361
362        // Update stats
363        self.stats.gpu_used.fetch_add(size, Ordering::Relaxed);
364        self.stats.allocations.fetch_add(1, Ordering::Relaxed);
365
366        // Update peak
367        let new_used = self.stats.gpu_used.load(Ordering::Relaxed);
368        let mut peak = self.stats.gpu_peak.load(Ordering::Relaxed);
369        while new_used > peak {
370            match self.stats.gpu_peak.compare_exchange_weak(
371                peak,
372                new_used,
373                Ordering::Relaxed,
374                Ordering::Relaxed,
375            ) {
376                Ok(_) => break,
377                Err(p) => peak = p,
378            }
379        }
380
381        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
382        let buffer = MemoryBuffer {
383            id,
384            size,
385            bucket_index,
386            is_gpu: true,
387        };
388
389        self.buffers.write().await.insert(
390            id,
391            MemoryBuffer {
392                id,
393                size,
394                bucket_index,
395                is_gpu: true,
396            },
397        );
398
399        Ok(buffer)
400    }
401
402    /// Deallocate GPU memory
403    pub async fn deallocate(&self, buffer: MemoryBuffer) -> AllocResult<()> {
404        let removed = self.buffers.write().await.remove(&buffer.id);
405        if removed.is_none() {
406            return Err(MemoryError::InvalidBuffer { id: buffer.id });
407        }
408
409        if let Some(idx) = buffer.bucket_index {
410            self.buckets[idx].record_dealloc();
411        }
412
413        self.stats
414            .gpu_used
415            .fetch_sub(buffer.size, Ordering::Relaxed);
416        self.stats.deallocations.fetch_add(1, Ordering::Relaxed);
417
418        Ok(())
419    }
420
421    /// Allocate staging (CPU) memory
422    pub async fn allocate_staging(&self, size: u64) -> AllocResult<MemoryBuffer> {
423        let current_used = self.stats.staging_used.load(Ordering::Relaxed);
424        if current_used + size > self.config.max_staging_memory {
425            return Err(MemoryError::OutOfMemory {
426                requested: size,
427                available: self.config.max_staging_memory - current_used,
428            });
429        }
430
431        self.stats.staging_used.fetch_add(size, Ordering::Relaxed);
432        self.stats.allocations.fetch_add(1, Ordering::Relaxed);
433
434        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
435        let buffer = MemoryBuffer {
436            id,
437            size,
438            bucket_index: None,
439            is_gpu: false,
440        };
441
442        self.buffers.write().await.insert(
443            id,
444            MemoryBuffer {
445                id,
446                size,
447                bucket_index: None,
448                is_gpu: false,
449            },
450        );
451
452        Ok(buffer)
453    }
454
455    /// Deallocate staging memory
456    pub async fn deallocate_staging(&self, buffer: MemoryBuffer) -> AllocResult<()> {
457        let removed = self.buffers.write().await.remove(&buffer.id);
458        if removed.is_none() {
459            return Err(MemoryError::InvalidBuffer { id: buffer.id });
460        }
461
462        self.stats
463            .staging_used
464            .fetch_sub(buffer.size, Ordering::Relaxed);
465        self.stats.deallocations.fetch_add(1, Ordering::Relaxed);
466
467        Ok(())
468    }
469
470    /// Get current memory statistics
471    pub fn stats(&self) -> MemoryStats {
472        let gpu_used = self.stats.gpu_used.load(Ordering::Relaxed);
473        let pool_hits = self.stats.pool_hits.load(Ordering::Relaxed);
474        let pool_misses = self.stats.pool_misses.load(Ordering::Relaxed);
475        let total_pool = pool_hits + pool_misses;
476
477        MemoryStats {
478            gpu_total: self.config.max_gpu_memory,
479            gpu_used,
480            gpu_peak: self.stats.gpu_peak.load(Ordering::Relaxed),
481            staging_total: self.config.max_staging_memory,
482            staging_used: self.stats.staging_used.load(Ordering::Relaxed),
483            allocations: self.stats.allocations.load(Ordering::Relaxed),
484            deallocations: self.stats.deallocations.load(Ordering::Relaxed),
485            pool_hit_rate: if total_pool > 0 {
486                pool_hits as f64 / total_pool as f64
487            } else {
488                0.0
489            },
490            pressure_level: self.pressure_level(),
491        }
492    }
493
494    /// Get bucket statistics
495    pub fn bucket_stats(&self) -> Vec<BucketStats> {
496        self.buckets.iter().map(|b| b.stats()).collect()
497    }
498
499    /// Get current pressure level
500    pub fn pressure_level(&self) -> PressureLevel {
501        PressureLevel::from_ratio(self.gpu_usage_percent() / 100.0)
502    }
503
504    /// Get GPU usage percentage
505    pub fn gpu_usage_percent(&self) -> f64 {
506        let used = self.stats.gpu_used.load(Ordering::Relaxed) as f64;
507        let total = self.config.max_gpu_memory as f64;
508        (used / total) * 100.0
509    }
510
511    /// Request garbage collection
512    pub async fn request_gc(&self) {
513        // Clear unused pool buffers
514        tracing::info!(
515            "Memory GC requested, pressure level: {:?}",
516            self.pressure_level()
517        );
518    }
519
520    /// Find appropriate bucket for size
521    fn find_bucket(&self, size: u64) -> Option<usize> {
522        self.buckets.iter().position(|b| b.size >= size)
523    }
524}
525
impl Default for KernelMemoryManager {
    /// Build a manager with the default (production-equivalent) configuration.
    fn default() -> Self {
        Self::new(MemoryConfig::default())
    }
}
531
/// Reduction buffer for GPU reduction operations
///
/// A fixed-capacity scratch buffer of `T` values, pre-filled with defaults.
#[derive(Debug)]
pub struct ReductionBuffer<T> {
    /// Buffer data
    data: Vec<T>,
    /// Capacity
    capacity: usize,
}

impl<T: Default + Clone> ReductionBuffer<T> {
    /// Create a new reduction buffer holding `capacity` default values.
    pub fn new(capacity: usize) -> Self {
        let data = vec![T::default(); capacity];
        Self { data, capacity }
    }

    /// Get capacity
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Borrow the contents as an immutable slice.
    pub fn as_slice(&self) -> &[T] {
        self.data.as_slice()
    }

    /// Borrow the contents as a mutable slice.
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        self.data.as_mut_slice()
    }

    /// Reset every element back to `T::default()`.
    pub fn reset(&mut self) {
        self.data.fill(T::default());
    }
}
572
573/// Reduction buffer cache
574pub struct ReductionBufferCache {
575    max_buffers: usize,
576    buffers: Arc<RwLock<Vec<Vec<u8>>>>,
577}
578
579impl ReductionBufferCache {
580    /// Create a new cache
581    pub fn new(max_buffers: usize) -> Self {
582        Self {
583            max_buffers,
584            buffers: Arc::new(RwLock::new(Vec::new())),
585        }
586    }
587
588    /// Get or create a buffer of the given size
589    pub async fn get(&self, size: usize) -> Vec<u8> {
590        let mut buffers = self.buffers.write().await;
591
592        // Try to find a buffer of adequate size
593        if let Some(pos) = buffers.iter().position(|b| b.capacity() >= size) {
594            let mut buf = buffers.remove(pos);
595            buf.resize(size, 0);
596            return buf;
597        }
598
599        // Create new buffer
600        vec![0u8; size]
601    }
602
603    /// Return a buffer to the cache
604    pub async fn return_buffer(&self, buffer: Vec<u8>) {
605        let mut buffers = self.buffers.write().await;
606        if buffers.len() < self.max_buffers {
607            buffers.push(buffer);
608        }
609        // Otherwise let it drop
610    }
611
612    /// Clear the cache
613    pub async fn clear(&self) {
614        self.buffers.write().await.clear();
615    }
616}
617
618impl Default for ReductionBufferCache {
619    fn default() -> Self {
620        Self::new(16)
621    }
622}
623
/// Analytics context for reusable buffers
///
/// Tracks a bounded working set for one analytics workload; the counter is
/// atomic so a context can be shared across tasks.
#[derive(Debug)]
pub struct AnalyticsContext {
    /// Context ID
    pub id: u64,
    /// Maximum working set size (bytes)
    pub max_working_set: u64,
    // Bytes currently allocated against the working-set budget.
    allocations: AtomicU64,
}

impl AnalyticsContext {
    /// Create a new analytics context with the given ID and budget.
    pub fn new(id: u64, max_working_set: u64) -> Self {
        Self {
            id,
            max_working_set,
            allocations: AtomicU64::new(0),
        }
    }

    /// Record an allocation.
    ///
    /// Returns `false` (recording nothing) if the allocation would exceed
    /// the working-set budget.
    pub fn record_allocation(&self, size: u64) -> bool {
        // fetch_update makes check-and-add a single atomic step; the previous
        // load-then-add version could over-commit under concurrent callers.
        // checked_add also rejects requests that would overflow u64, which
        // the old unchecked `current + size` could wrap past the budget.
        self.allocations
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                current
                    .checked_add(size)
                    .filter(|&next| next <= self.max_working_set)
            })
            .is_ok()
    }

    /// Record a deallocation of `size` bytes.
    pub fn record_deallocation(&self, size: u64) {
        self.allocations.fetch_sub(size, Ordering::Relaxed);
    }

    /// Get current usage in bytes.
    pub fn current_usage(&self) -> u64 {
        self.allocations.load(Ordering::Relaxed)
    }

    /// Get usage as a percentage of the working-set budget.
    pub fn usage_percent(&self) -> f64 {
        (self.current_usage() as f64 / self.max_working_set as f64) * 100.0
    }
}
670
/// Analytics context manager
///
/// Creates, registers, and looks up shared [`AnalyticsContext`] instances.
pub struct AnalyticsContextManager {
    // Registry of live contexts keyed by context ID.
    contexts: Arc<RwLock<HashMap<u64, Arc<AnalyticsContext>>>>,
    // Working-set size used when no explicit size is given.
    default_working_set: u64,
    // Monotonic ID source for new contexts (starts at 1).
    next_id: AtomicU64,
}
677
678impl AnalyticsContextManager {
679    /// Create a new context manager
680    pub fn new(default_working_set: u64) -> Self {
681        Self {
682            contexts: Arc::new(RwLock::new(HashMap::new())),
683            default_working_set,
684            next_id: AtomicU64::new(1),
685        }
686    }
687
688    /// Create a new analytics context
689    pub async fn create_context(&self) -> Arc<AnalyticsContext> {
690        self.create_context_with_size(self.default_working_set)
691            .await
692    }
693
694    /// Create a context with specific working set size
695    pub async fn create_context_with_size(&self, max_working_set: u64) -> Arc<AnalyticsContext> {
696        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
697        let ctx = Arc::new(AnalyticsContext::new(id, max_working_set));
698        self.contexts.write().await.insert(id, ctx.clone());
699        ctx
700    }
701
702    /// Get a context by ID
703    pub async fn get_context(&self, id: u64) -> Option<Arc<AnalyticsContext>> {
704        self.contexts.read().await.get(&id).cloned()
705    }
706
707    /// Remove a context
708    pub async fn remove_context(&self, id: u64) {
709        self.contexts.write().await.remove(&id);
710    }
711
712    /// Get number of active contexts
713    pub async fn active_contexts(&self) -> usize {
714        self.contexts.read().await.len()
715    }
716}
717
impl Default for AnalyticsContextManager {
    /// Manager whose contexts default to a 256MB working set.
    fn default() -> Self {
        Self::new(256 * 1024 * 1024) // 256MB default working set
    }
}
723
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trip a single GPU allocation and verify counter bookkeeping.
    #[tokio::test]
    async fn test_memory_allocation() {
        let manager = KernelMemoryManager::new(MemoryConfig::development());

        let buffer = manager.allocate(1024).await.unwrap();
        assert_eq!(buffer.size, 1024);
        assert!(buffer.is_gpu);

        let stats = manager.stats();
        assert_eq!(stats.gpu_used, 1024);
        assert_eq!(stats.allocations, 1);

        manager.deallocate(buffer).await.unwrap();

        let stats = manager.stats();
        assert_eq!(stats.gpu_used, 0);
        assert_eq!(stats.deallocations, 1);
    }

    /// Requests beyond max_gpu_memory must fail with OutOfMemory.
    #[tokio::test]
    async fn test_out_of_memory() {
        let config = MemoryConfig {
            max_gpu_memory: 1024,
            ..MemoryConfig::development()
        };
        let manager = KernelMemoryManager::new(config);

        let result = manager.allocate(2048).await;
        assert!(matches!(result, Err(MemoryError::OutOfMemory { .. })));
    }

    /// Pressure level transitions at the usage-ratio boundaries.
    #[tokio::test]
    async fn test_pressure_levels() {
        let config = MemoryConfig {
            max_gpu_memory: 1000,
            ..MemoryConfig::development()
        };
        let manager = KernelMemoryManager::new(config);

        assert_eq!(manager.pressure_level(), PressureLevel::Normal);

        // Allocate 70% — a ratio of exactly 0.70 maps to Warning,
        // since from_ratio uses `< 0.70` for Normal.
        let _buf = manager.allocate(700).await.unwrap();
        assert_eq!(manager.pressure_level(), PressureLevel::Warning);
    }

    /// Returned buffers are pooled and resized to fit the next request.
    #[tokio::test]
    async fn test_reduction_buffer_cache() {
        let cache = ReductionBufferCache::new(4);

        let buf1 = cache.get(1024).await;
        assert_eq!(buf1.len(), 1024);

        cache.return_buffer(buf1).await;

        // Should reuse the buffer (1024-capacity buffer satisfies 512)
        let buf2 = cache.get(512).await;
        assert_eq!(buf2.len(), 512);
    }

    /// Working-set accounting on a context created via the manager.
    #[tokio::test]
    async fn test_analytics_context() {
        let manager = AnalyticsContextManager::new(1024);

        let ctx = manager.create_context().await;
        assert!(ctx.record_allocation(512));
        assert_eq!(ctx.current_usage(), 512);

        ctx.record_deallocation(256);
        assert_eq!(ctx.current_usage(), 256);
    }
}
799}