pub mod reduction;

pub use reduction::{
    CooperativeBarrier, GlobalReduction, InterPhaseReduction, PhaseState, ReductionBuilder,
    ReductionConfig, ReductionError, ReductionOp, SyncMode,
};

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use tokio::sync::RwLock;

/// Configuration for the kernel memory manager.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryConfig {
    /// Maximum GPU memory budget in bytes.
    pub max_gpu_memory: u64,
    /// Maximum staging memory budget in bytes.
    pub max_staging_memory: u64,
    /// Whether size-bucketed pooling is enabled.
    pub pooling_enabled: bool,
    /// Pool bucket sizes in bytes, in ascending order.
    pub bucket_sizes: Vec<u64>,
    /// GPU usage ratio above which memory pressure is reported.
    pub pressure_threshold: f64,
    /// Whether defragmentation runs automatically.
    pub auto_defrag: bool,
    /// Fragmentation ratio that triggers a defragmentation pass.
    pub defrag_threshold: f64,
}

impl Default for MemoryConfig {
    fn default() -> Self {
        Self {
            max_gpu_memory: 4 * 1024 * 1024 * 1024, // 4 GiB
            max_staging_memory: 1024 * 1024 * 1024, // 1 GiB
            pooling_enabled: true,
            bucket_sizes: vec![
                64 * 1024,        // 64 KiB
                256 * 1024,       // 256 KiB
                1024 * 1024,      // 1 MiB
                4 * 1024 * 1024,  // 4 MiB
                16 * 1024 * 1024, // 16 MiB
                64 * 1024 * 1024, // 64 MiB
            ],
            pressure_threshold: 0.85,
            auto_defrag: true,
            defrag_threshold: 0.3,
        }
    }
}

impl MemoryConfig {
    /// Small budgets with pooling disabled, for local development.
    pub fn development() -> Self {
        Self {
            max_gpu_memory: 512 * 1024 * 1024,     // 512 MiB
            max_staging_memory: 256 * 1024 * 1024, // 256 MiB
            pooling_enabled: false,
            ..Default::default()
        }
    }

    /// Production defaults; identical to `Default::default()`.
    pub fn production() -> Self {
        Self::default()
    }

    /// Large budgets with pooling and aggressive defragmentation.
    pub fn high_performance() -> Self {
        Self {
            max_gpu_memory: 16 * 1024 * 1024 * 1024,    // 16 GiB
            max_staging_memory: 4 * 1024 * 1024 * 1024, // 4 GiB
            pooling_enabled: true,
            auto_defrag: true,
            defrag_threshold: 0.2,
            ..Default::default()
        }
    }
}

/// A size class in the buffer pool, with lock-free usage counters.
#[derive(Debug)]
pub struct SizeBucket {
    /// Buffer size served by this bucket, in bytes.
    pub size: u64,
    /// Buffers currently sitting in the pool, ready for reuse.
    pub available: AtomicUsize,
    /// Buffers currently handed out.
    pub allocated: AtomicUsize,
    /// High-water mark of concurrent allocations.
    pub peak: AtomicUsize,
}

impl SizeBucket {
    pub fn new(size: u64) -> Self {
        Self {
            size,
            available: AtomicUsize::new(0),
            allocated: AtomicUsize::new(0),
            peak: AtomicUsize::new(0),
        }
    }

    /// Records an allocation, raising the peak with a CAS loop so
    /// concurrent updates never lose a high-water mark.
    pub fn record_alloc(&self) {
        let count = self.allocated.fetch_add(1, Ordering::Relaxed) + 1;
        let mut peak = self.peak.load(Ordering::Relaxed);
        while count > peak {
            match self
                .peak
                .compare_exchange_weak(peak, count, Ordering::Relaxed, Ordering::Relaxed)
            {
                Ok(_) => break,
                Err(p) => peak = p,
            }
        }
    }

    pub fn record_dealloc(&self) {
        self.allocated.fetch_sub(1, Ordering::Relaxed);
    }

    /// Snapshots the counters into a plain `BucketStats`.
    pub fn stats(&self) -> BucketStats {
        BucketStats {
            size: self.size,
            available: self.available.load(Ordering::Relaxed),
            allocated: self.allocated.load(Ordering::Relaxed),
            peak: self.peak.load(Ordering::Relaxed),
        }
    }
}

/// Point-in-time snapshot of a `SizeBucket`.
#[derive(Debug, Clone)]
pub struct BucketStats {
    pub size: u64,
    pub available: usize,
    pub allocated: usize,
    pub peak: usize,
}

/// Handle to an allocation issued by the memory manager.
#[derive(Debug)]
pub struct MemoryBuffer {
    /// Unique buffer id.
    pub id: u64,
    /// Allocation size in bytes.
    pub size: u64,
    /// Pool bucket this buffer was served from, if any.
    pub bucket_index: Option<usize>,
    /// `true` for GPU memory, `false` for staging memory.
    pub is_gpu: bool,
}

pub type AllocResult<T> = std::result::Result<T, MemoryError>;

#[derive(Debug, thiserror::Error)]
pub enum MemoryError {
    #[error("Out of memory: requested {requested} bytes, available {available} bytes")]
    OutOfMemory { requested: u64, available: u64 },

    #[error("Memory pressure exceeded: {usage_percent:.1}% usage")]
    PressureExceeded { usage_percent: f64 },

    #[error("Invalid buffer: {id}")]
    InvalidBuffer { id: u64 },

    #[error("Allocation failed: {reason}")]
    AllocationFailed { reason: String },
}

/// Coarse classification of GPU memory pressure.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PressureLevel {
    #[default]
    Normal,
    Warning,
    High,
    Critical,
}

impl PressureLevel {
    /// Maps a usage ratio in `[0, 1]` to a pressure level.
    pub fn from_ratio(ratio: f64) -> Self {
        if ratio < 0.70 {
            Self::Normal
        } else if ratio < 0.85 {
            Self::Warning
        } else if ratio < 0.95 {
            Self::High
        } else {
            Self::Critical
        }
    }
}

/// Aggregate statistics reported by `KernelMemoryManager::stats`.
#[derive(Debug, Clone, Default)]
pub struct MemoryStats {
    pub gpu_total: u64,
    pub gpu_used: u64,
    pub gpu_peak: u64,
    pub staging_total: u64,
    pub staging_used: u64,
    pub allocations: u64,
    pub deallocations: u64,
    pub pool_hit_rate: f64,
    pub pressure_level: PressureLevel,
}

/// Tracks GPU and staging allocations against the configured budgets.
pub struct KernelMemoryManager {
    config: MemoryConfig,
    buckets: Vec<SizeBucket>,
    stats: Arc<MemoryStatsInner>,
    buffers: Arc<RwLock<HashMap<u64, MemoryBuffer>>>,
    next_id: AtomicU64,
}

#[derive(Debug, Default)]
struct MemoryStatsInner {
    gpu_used: AtomicU64,
    gpu_peak: AtomicU64,
    staging_used: AtomicU64,
    allocations: AtomicU64,
    deallocations: AtomicU64,
    pool_hits: AtomicU64,
    pool_misses: AtomicU64,
}

impl KernelMemoryManager {
    pub fn new(config: MemoryConfig) -> Self {
        let buckets = config
            .bucket_sizes
            .iter()
            .map(|&size| SizeBucket::new(size))
            .collect();

        Self {
            config,
            buckets,
            stats: Arc::new(MemoryStatsInner::default()),
            buffers: Arc::new(RwLock::new(HashMap::new())),
            next_id: AtomicU64::new(1),
        }
    }

    pub fn config(&self) -> &MemoryConfig {
        &self.config
    }

    /// Allocates a GPU buffer, rejecting the request under critical
    /// pressure or when it would exceed the configured budget.
    pub async fn allocate(&self, size: u64) -> AllocResult<MemoryBuffer> {
        let pressure = self.pressure_level();
        if pressure == PressureLevel::Critical {
            return Err(MemoryError::PressureExceeded {
                usage_percent: self.gpu_usage_percent(),
            });
        }

        let current_used = self.stats.gpu_used.load(Ordering::Relaxed);
        if current_used + size > self.config.max_gpu_memory {
            return Err(MemoryError::OutOfMemory {
                requested: size,
                available: self.config.max_gpu_memory - current_used,
            });
        }

        // With pooling enabled, serve from the smallest bucket that fits.
        let bucket_index = if self.config.pooling_enabled {
            self.find_bucket(size)
        } else {
            None
        };

        if let Some(idx) = bucket_index {
            self.stats.pool_hits.fetch_add(1, Ordering::Relaxed);
            self.buckets[idx].record_alloc();
        } else if self.config.pooling_enabled {
            self.stats.pool_misses.fetch_add(1, Ordering::Relaxed);
        }

        self.stats.gpu_used.fetch_add(size, Ordering::Relaxed);
        self.stats.allocations.fetch_add(1, Ordering::Relaxed);

        // Raise the peak with a CAS loop so concurrent allocations
        // never lose a high-water mark.
        let new_used = self.stats.gpu_used.load(Ordering::Relaxed);
        let mut peak = self.stats.gpu_peak.load(Ordering::Relaxed);
        while new_used > peak {
            match self.stats.gpu_peak.compare_exchange_weak(
                peak,
                new_used,
                Ordering::Relaxed,
                Ordering::Relaxed,
            ) {
                Ok(_) => break,
                Err(p) => peak = p,
            }
        }

        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
        let buffer = MemoryBuffer {
            id,
            size,
            bucket_index,
            is_gpu: true,
        };

        // Keep a bookkeeping copy so deallocation can validate the id.
        self.buffers.write().await.insert(
            id,
            MemoryBuffer {
                id,
                size,
                bucket_index,
                is_gpu: true,
            },
        );

        Ok(buffer)
    }

    /// Releases a GPU buffer previously returned by `allocate`.
    pub async fn deallocate(&self, buffer: MemoryBuffer) -> AllocResult<()> {
        let removed = self.buffers.write().await.remove(&buffer.id);
        if removed.is_none() {
            return Err(MemoryError::InvalidBuffer { id: buffer.id });
        }

        if let Some(idx) = buffer.bucket_index {
            self.buckets[idx].record_dealloc();
        }

        self.stats
            .gpu_used
            .fetch_sub(buffer.size, Ordering::Relaxed);
        self.stats.deallocations.fetch_add(1, Ordering::Relaxed);

        Ok(())
    }

    /// Allocates a staging buffer against the staging budget. Staging
    /// allocations bypass the bucket pool.
    pub async fn allocate_staging(&self, size: u64) -> AllocResult<MemoryBuffer> {
        let current_used = self.stats.staging_used.load(Ordering::Relaxed);
        if current_used + size > self.config.max_staging_memory {
            return Err(MemoryError::OutOfMemory {
                requested: size,
                available: self.config.max_staging_memory - current_used,
            });
        }

        self.stats.staging_used.fetch_add(size, Ordering::Relaxed);
        self.stats.allocations.fetch_add(1, Ordering::Relaxed);

        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
        let buffer = MemoryBuffer {
            id,
            size,
            bucket_index: None,
            is_gpu: false,
        };

        // Keep a bookkeeping copy so deallocation can validate the id.
        self.buffers.write().await.insert(
            id,
            MemoryBuffer {
                id,
                size,
                bucket_index: None,
                is_gpu: false,
            },
        );

        Ok(buffer)
    }

    /// Releases a staging buffer previously returned by `allocate_staging`.
    pub async fn deallocate_staging(&self, buffer: MemoryBuffer) -> AllocResult<()> {
        let removed = self.buffers.write().await.remove(&buffer.id);
        if removed.is_none() {
            return Err(MemoryError::InvalidBuffer { id: buffer.id });
        }

        self.stats
            .staging_used
            .fetch_sub(buffer.size, Ordering::Relaxed);
        self.stats.deallocations.fetch_add(1, Ordering::Relaxed);

        Ok(())
    }

    /// Snapshots the aggregate memory statistics.
    pub fn stats(&self) -> MemoryStats {
        let gpu_used = self.stats.gpu_used.load(Ordering::Relaxed);
        let pool_hits = self.stats.pool_hits.load(Ordering::Relaxed);
        let pool_misses = self.stats.pool_misses.load(Ordering::Relaxed);
        let total_pool = pool_hits + pool_misses;

        MemoryStats {
            gpu_total: self.config.max_gpu_memory,
            gpu_used,
            gpu_peak: self.stats.gpu_peak.load(Ordering::Relaxed),
            staging_total: self.config.max_staging_memory,
            staging_used: self.stats.staging_used.load(Ordering::Relaxed),
            allocations: self.stats.allocations.load(Ordering::Relaxed),
            deallocations: self.stats.deallocations.load(Ordering::Relaxed),
            pool_hit_rate: if total_pool > 0 {
                pool_hits as f64 / total_pool as f64
            } else {
                0.0
            },
            pressure_level: self.pressure_level(),
        }
    }

    /// Per-bucket snapshots for the whole pool.
    pub fn bucket_stats(&self) -> Vec<BucketStats> {
        self.buckets.iter().map(|b| b.stats()).collect()
    }

    /// Current pressure level derived from GPU usage.
    pub fn pressure_level(&self) -> PressureLevel {
        PressureLevel::from_ratio(self.gpu_usage_percent() / 100.0)
    }

    /// GPU usage as a percentage of the configured budget.
    pub fn gpu_usage_percent(&self) -> f64 {
        let used = self.stats.gpu_used.load(Ordering::Relaxed) as f64;
        let total = self.config.max_gpu_memory as f64;
        (used / total) * 100.0
    }

    /// Requests a garbage-collection pass; this logs the request at
    /// the current pressure level.
    pub async fn request_gc(&self) {
        tracing::info!(
            "Memory GC requested, pressure level: {:?}",
            self.pressure_level()
        );
    }

    /// Finds the smallest bucket that can hold `size`, relying on
    /// `bucket_sizes` being sorted in ascending order.
    fn find_bucket(&self, size: u64) -> Option<usize> {
        self.buckets.iter().position(|b| b.size >= size)
    }
}

impl Default for KernelMemoryManager {
    fn default() -> Self {
        Self::new(MemoryConfig::default())
    }
}

/// Fixed-capacity, reusable buffer for holding reduction results.
#[derive(Debug)]
pub struct ReductionBuffer<T> {
    data: Vec<T>,
    capacity: usize,
}

impl<T: Default + Clone> ReductionBuffer<T> {
    pub fn new(capacity: usize) -> Self {
        Self {
            data: vec![T::default(); capacity],
            capacity,
        }
    }

    pub fn capacity(&self) -> usize {
        self.capacity
    }

    pub fn as_slice(&self) -> &[T] {
        &self.data
    }

    pub fn as_mut_slice(&mut self) -> &mut [T] {
        &mut self.data
    }

    /// Resets every element to `T::default()` so the buffer can be reused.
    pub fn reset(&mut self) {
        for item in &mut self.data {
            *item = T::default();
        }
    }
}

/// Bounded cache of byte buffers reused across reductions to avoid
/// repeated allocation.
pub struct ReductionBufferCache {
    max_buffers: usize,
    buffers: Arc<RwLock<Vec<Vec<u8>>>>,
}

impl ReductionBufferCache {
    pub fn new(max_buffers: usize) -> Self {
        Self {
            max_buffers,
            buffers: Arc::new(RwLock::new(Vec::new())),
        }
    }

    /// Returns a buffer of `size` bytes, reusing a cached allocation
    /// when one has enough capacity. Note that a reused buffer's
    /// existing contents are not re-zeroed; only freshly created
    /// buffers are.
    pub async fn get(&self, size: usize) -> Vec<u8> {
        let mut buffers = self.buffers.write().await;

        if let Some(pos) = buffers.iter().position(|b| b.capacity() >= size) {
            let mut buf = buffers.remove(pos);
            buf.resize(size, 0);
            return buf;
        }

        vec![0u8; size]
    }

    /// Returns a buffer to the cache, dropping it if the cache is full.
    pub async fn return_buffer(&self, buffer: Vec<u8>) {
        let mut buffers = self.buffers.write().await;
        if buffers.len() < self.max_buffers {
            buffers.push(buffer);
        }
    }

    pub async fn clear(&self) {
        self.buffers.write().await.clear();
    }
}

impl Default for ReductionBufferCache {
    fn default() -> Self {
        Self::new(16)
    }
}

/// Per-workload working-set accounting for analytics queries.
#[derive(Debug)]
pub struct AnalyticsContext {
    /// Unique context id.
    pub id: u64,
    /// Maximum working-set size in bytes.
    pub max_working_set: u64,
    allocations: AtomicU64,
}

impl AnalyticsContext {
    pub fn new(id: u64, max_working_set: u64) -> Self {
        Self {
            id,
            max_working_set,
            allocations: AtomicU64::new(0),
        }
    }

    /// Records an allocation, returning `false` if it would exceed the
    /// working-set limit. The check-then-add is best-effort under
    /// concurrent callers.
    pub fn record_allocation(&self, size: u64) -> bool {
        let current = self.allocations.load(Ordering::Relaxed);
        if current + size > self.max_working_set {
            return false;
        }
        self.allocations.fetch_add(size, Ordering::Relaxed);
        true
    }

    pub fn record_deallocation(&self, size: u64) {
        self.allocations.fetch_sub(size, Ordering::Relaxed);
    }

    pub fn current_usage(&self) -> u64 {
        self.allocations.load(Ordering::Relaxed)
    }

    pub fn usage_percent(&self) -> f64 {
        (self.current_usage() as f64 / self.max_working_set as f64) * 100.0
    }
}

/// Creates, looks up, and removes `AnalyticsContext` instances.
pub struct AnalyticsContextManager {
    contexts: Arc<RwLock<HashMap<u64, Arc<AnalyticsContext>>>>,
    default_working_set: u64,
    next_id: AtomicU64,
}

impl AnalyticsContextManager {
    pub fn new(default_working_set: u64) -> Self {
        Self {
            contexts: Arc::new(RwLock::new(HashMap::new())),
            default_working_set,
            next_id: AtomicU64::new(1),
        }
    }

    pub async fn create_context(&self) -> Arc<AnalyticsContext> {
        self.create_context_with_size(self.default_working_set)
            .await
    }

    pub async fn create_context_with_size(&self, max_working_set: u64) -> Arc<AnalyticsContext> {
        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
        let ctx = Arc::new(AnalyticsContext::new(id, max_working_set));
        self.contexts.write().await.insert(id, ctx.clone());
        ctx
    }

    pub async fn get_context(&self, id: u64) -> Option<Arc<AnalyticsContext>> {
        self.contexts.read().await.get(&id).cloned()
    }

    pub async fn remove_context(&self, id: u64) {
        self.contexts.write().await.remove(&id);
    }

    pub async fn active_contexts(&self) -> usize {
        self.contexts.read().await.len()
    }
}

impl Default for AnalyticsContextManager {
    fn default() -> Self {
        Self::new(256 * 1024 * 1024) // 256 MiB
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_memory_allocation() {
        let manager = KernelMemoryManager::new(MemoryConfig::development());

        let buffer = manager.allocate(1024).await.unwrap();
        assert_eq!(buffer.size, 1024);
        assert!(buffer.is_gpu);

        let stats = manager.stats();
        assert_eq!(stats.gpu_used, 1024);
        assert_eq!(stats.allocations, 1);

        manager.deallocate(buffer).await.unwrap();

        let stats = manager.stats();
        assert_eq!(stats.gpu_used, 0);
        assert_eq!(stats.deallocations, 1);
    }
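
    // Added test sketches: the staging path mirrors the GPU path but
    // draws from the staging budget and bypasses the pool, and pooled
    // GPU allocations should land in the smallest bucket that fits.
    #[tokio::test]
    async fn test_staging_allocation() {
        let manager = KernelMemoryManager::new(MemoryConfig::development());

        let buffer = manager.allocate_staging(2048).await.unwrap();
        assert!(!buffer.is_gpu);
        assert_eq!(buffer.bucket_index, None);
        assert_eq!(manager.stats().staging_used, 2048);

        manager.deallocate_staging(buffer).await.unwrap();
        assert_eq!(manager.stats().staging_used, 0);
    }

    #[tokio::test]
    async fn test_pool_bucketing() {
        // production() keeps pooling enabled with the default bucket ladder,
        // so a 1 KiB request should be served from the 64 KiB bucket.
        let manager = KernelMemoryManager::new(MemoryConfig::production());

        let buffer = manager.allocate(1024).await.unwrap();
        assert_eq!(buffer.bucket_index, Some(0));
        assert_eq!(manager.stats().pool_hit_rate, 1.0);

        manager.deallocate(buffer).await.unwrap();
    }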

    #[tokio::test]
    async fn test_out_of_memory() {
        let config = MemoryConfig {
            max_gpu_memory: 1024,
            ..MemoryConfig::development()
        };
        let manager = KernelMemoryManager::new(config);

        let result = manager.allocate(2048).await;
        assert!(matches!(result, Err(MemoryError::OutOfMemory { .. })));
    }
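
    // Added test sketch: deallocating a handle the manager never issued
    // (constructed directly here, since the fields are public) should
    // surface InvalidBuffer and leave the usage counters untouched.
    #[tokio::test]
    async fn test_invalid_buffer() {
        let manager = KernelMemoryManager::new(MemoryConfig::development());

        let bogus = MemoryBuffer {
            id: 9999,
            size: 64,
            bucket_index: None,
            is_gpu: true,
        };
        let result = manager.deallocate(bogus).await;
        assert!(matches!(result, Err(MemoryError::InvalidBuffer { id: 9999 })));
        assert_eq!(manager.stats().gpu_used, 0);
    }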

    #[tokio::test]
    async fn test_pressure_levels() {
        let config = MemoryConfig {
            max_gpu_memory: 1000,
            ..MemoryConfig::development()
        };
        let manager = KernelMemoryManager::new(config);

        assert_eq!(manager.pressure_level(), PressureLevel::Normal);

        let _buf = manager.allocate(700).await.unwrap();
        assert_eq!(manager.pressure_level(), PressureLevel::Warning);
    }
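
    // Added test sketch: from_ratio band boundaries; each threshold
    // value belongs to the higher band.
    #[test]
    fn test_pressure_level_from_ratio() {
        assert_eq!(PressureLevel::from_ratio(0.0), PressureLevel::Normal);
        assert_eq!(PressureLevel::from_ratio(0.70), PressureLevel::Warning);
        assert_eq!(PressureLevel::from_ratio(0.85), PressureLevel::High);
        assert_eq!(PressureLevel::from_ratio(0.95), PressureLevel::Critical);
    }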

    #[tokio::test]
    async fn test_reduction_buffer_cache() {
        let cache = ReductionBufferCache::new(4);

        let buf1 = cache.get(1024).await;
        assert_eq!(buf1.len(), 1024);

        cache.return_buffer(buf1).await;

        let buf2 = cache.get(512).await;
        assert_eq!(buf2.len(), 512);
    }
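
    // Added test sketch: ReductionBuffer can be reused after reset()
    // returns every element to its default value.
    #[test]
    fn test_reduction_buffer_reset() {
        let mut buf: ReductionBuffer<u32> = ReductionBuffer::new(4);
        assert_eq!(buf.capacity(), 4);

        buf.as_mut_slice()[0] = 42;
        assert_eq!(buf.as_slice()[0], 42);

        buf.reset();
        assert_eq!(buf.as_slice(), &[0u32; 4]);
    }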

    #[tokio::test]
    async fn test_analytics_context() {
        let manager = AnalyticsContextManager::new(1024);

        let ctx = manager.create_context().await;
        assert!(ctx.record_allocation(512));
        assert_eq!(ctx.current_usage(), 512);

        ctx.record_deallocation(256);
        assert_eq!(ctx.current_usage(), 256);
    }
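
    // Added test sketch: allocations past the working-set limit are
    // rejected without changing usage, and contexts can be removed.
    #[tokio::test]
    async fn test_analytics_context_limit() {
        let manager = AnalyticsContextManager::new(1024);
        let ctx = manager.create_context().await;

        assert!(ctx.record_allocation(1024));
        assert!(!ctx.record_allocation(1));
        assert_eq!(ctx.current_usage(), 1024);

        manager.remove_context(ctx.id).await;
        assert_eq!(manager.active_contexts().await, 0);
    }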
}