1use crate::Device;
11use std::collections::{HashMap, VecDeque};
12use std::time::{Duration, Instant};
13
14#[derive(Debug, Clone)]
19pub struct MemoryAllocation {
20 pub ptr: usize,
22
23 pub size: usize,
25
26 pub allocated_at: Instant,
28
29 pub source: AllocationSource,
31
32 pub memory_type: MemoryType,
34
35 pub device: Option<Device>,
37
38 pub usage_stats: AllocationUsageStats,
40
41 pub lifetime_events: Vec<LifetimeEvent>,
43
44 pub performance_hints: Vec<PerformanceHint>,
46}
47
48#[derive(Debug, Clone)]
53pub struct AllocationSource {
54 pub function: String,
56
57 pub location: Option<(String, u32)>,
59
60 pub stack_depth: usize,
62
63 pub thread_id: u64,
65
66 pub context: AllocationContext,
68}
69
/// Purpose an allocation was requested for, with free-form descriptive fields.
///
/// Every field is caller-supplied text (or a shape vector); nothing in this
/// type validates the values.
#[derive(Debug, Clone)]
pub enum AllocationContext {
    /// Memory requested on behalf of a tensor operation.
    TensorOperation {
        /// Name of the operation.
        operation_name: String,
        /// Shape of the tensor involved.
        tensor_shape: Vec<usize>,
        /// Element type, as text.
        data_type: String,
    },

    /// Scratch memory requested for a kernel.
    KernelScratch {
        /// Name of the kernel.
        kernel_name: String,
        /// Kind of scratch space, as text.
        scratch_type: String,
    },

    /// Buffer holding an intermediate result.
    IntermediateBuffer {
        /// Identifier of the owning computation graph.
        computation_graph_id: String,
        /// What the buffer is used for.
        buffer_purpose: String,
    },

    /// Storage for model parameters.
    ModelParameters {
        /// Name of the model.
        model_name: String,
        /// Name of the parameter.
        parameter_name: String,
    },

    /// Allocation made for a user request.
    UserAllocation {
        /// Identifier of the request.
        request_id: String,
    },

    /// Allocation made for internal bookkeeping.
    InternalAllocation {
        /// What the allocation is used for.
        purpose: String,
    },
}
107
/// Category of memory an allocation lives in.
///
/// The variants are plain tags; this file attaches no behavior to them beyond
/// comparison, hashing and formatting.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemoryType {
    /// Device memory.
    Device,

    /// Host memory.
    Host,

    /// Unified memory.
    Unified,

    /// Pinned memory.
    Pinned,

    /// Texture memory.
    Texture,

    /// Constant memory.
    Constant,

    /// Shared memory.
    Shared,

    /// Memory-mapped memory.
    MemoryMapped,
}
137
138#[derive(Debug, Clone, Default)]
143pub struct AllocationUsageStats {
144 pub access_count: u64,
146
147 pub bytes_read: u64,
149
150 pub bytes_written: u64,
152
153 pub last_accessed: Option<Instant>,
155
156 pub access_frequency: f64,
158
159 pub bandwidth_utilization: f64,
161
162 pub cache_stats: CacheStats,
164}
165
/// Hit/miss counters for three cache levels; `Default` yields all zeros.
///
/// No code visible in this file updates these counters.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// L1 hit count.
    pub l1_hits: u64,

    /// L1 miss count.
    pub l1_misses: u64,

    /// L2 hit count.
    pub l2_hits: u64,

    /// L2 miss count.
    pub l2_misses: u64,

    /// TLB hit count.
    pub tlb_hits: u64,

    /// TLB miss count.
    pub tlb_misses: u64,
}
189
190#[derive(Debug, Clone)]
195pub struct LifetimeEvent {
196 pub timestamp: Instant,
198
199 pub event_type: LifetimeEventType,
201
202 pub details: String,
204}
205
206#[derive(Debug, Clone)]
210pub enum LifetimeEventType {
211 Allocated,
213
214 Accessed { read: bool, write: bool },
216
217 Copied { source: bool, destination: bool },
219
220 Resized { old_size: usize, new_size: usize },
222
223 Deallocated,
225
226 MemoryPressure { pressure_level: PressureLevel },
228
229 Defragmented,
231}
232
/// Memory-pressure level.
///
/// The derived ordering follows declaration order, so
/// `None < Low < Medium < High < Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum PressureLevel {
    /// No pressure.
    None,
    /// Low pressure.
    Low,
    /// Medium pressure.
    Medium,
    /// High pressure.
    High,
    /// Critical pressure; the greatest value in the ordering.
    Critical,
}

/// The default level is "no pressure".
impl Default for PressureLevel {
    fn default() -> Self {
        Self::None
    }
}
250
251#[derive(Debug, Clone)]
256pub struct PerformanceHint {
257 pub hint_type: PerformanceHintType,
259
260 pub severity: HintSeverity,
262
263 pub description: String,
265
266 pub suggested_action: String,
268
269 pub impact_estimate: f64,
271}
272
/// Category of problem a performance hint describes.
///
/// The enum is fieldless, so it derives `Copy`, `PartialEq`, `Eq` and `Hash`
/// in addition to `Debug` and `Clone`. This lets callers compare hint kinds
/// with `==` and use them as map or set keys, matching the other tag enums
/// in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PerformanceHintType {
    /// Accesses follow a pattern judged suboptimal.
    SuboptimalAccessPattern,

    /// The allocation size is judged inefficient.
    InefficientSize,

    /// The allocation lives in a memory type judged suboptimal.
    SuboptimalMemoryType,

    /// Too many allocations are being made.
    ExcessiveAllocations,

    /// Memory is fragmented.
    Fragmentation,

    /// Memory was allocated but not used.
    UnusedMemory,

    /// Accesses show poor cache locality.
    PoorCacheLocality,

    /// Available bandwidth is underused.
    BandwidthUnderutilization,
}
302
/// Severity of a performance hint.
///
/// The derived ordering follows declaration order, so
/// `Info < Warning < Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum HintSeverity {
    /// Informational only.
    Info,
    /// Worth attention.
    Warning,
    /// Most severe level.
    Critical,
}
312
313#[derive(Debug, Clone)]
318pub struct AccessPattern {
319 pub access_times: VecDeque<Instant>,
321
322 pub access_sizes: VecDeque<usize>,
324
325 pub access_types: VecDeque<AccessType>,
327
328 pub sequential_score: f64,
330
331 pub random_score: f64,
333
334 pub temporal_locality: f64,
336
337 pub spatial_locality: f64,
339
340 pub frequency: f64,
342
343 pub last_analysis: Option<Instant>,
345}
346
/// Direction of a memory access.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccessType {
    /// The access only read.
    Read,
    /// The access only wrote.
    Write,
    /// The access both read and wrote.
    ReadWrite,
}
356
357#[derive(Debug)]
362pub struct AllocationTracker {
363 allocations: HashMap<usize, MemoryAllocation>,
365
366 access_patterns: HashMap<usize, AccessPattern>,
368
369 total_allocations: u64,
371
372 total_bytes: u64,
374
375 performance_hints: Vec<PerformanceHint>,
377}
378
379impl MemoryAllocation {
380 pub fn new(
382 ptr: usize,
383 size: usize,
384 source: AllocationSource,
385 memory_type: MemoryType,
386 device: Option<Device>,
387 ) -> Self {
388 Self {
389 ptr,
390 size,
391 allocated_at: Instant::now(),
392 source,
393 memory_type,
394 device,
395 usage_stats: AllocationUsageStats::default(),
396 lifetime_events: vec![LifetimeEvent {
397 timestamp: Instant::now(),
398 event_type: LifetimeEventType::Allocated,
399 details: format!("Allocated {} bytes at {:p}", size, ptr as *const u8),
400 }],
401 performance_hints: Vec::new(),
402 }
403 }
404
405 pub fn record_access(&mut self, access_type: AccessType, bytes: usize) {
407 let now = Instant::now();
408
409 self.usage_stats.access_count += 1;
410 self.usage_stats.last_accessed = Some(now);
411
412 match access_type {
413 AccessType::Read => self.usage_stats.bytes_read += bytes as u64,
414 AccessType::Write => self.usage_stats.bytes_written += bytes as u64,
415 AccessType::ReadWrite => {
416 self.usage_stats.bytes_read += bytes as u64;
417 self.usage_stats.bytes_written += bytes as u64;
418 }
419 }
420
421 self.lifetime_events.push(LifetimeEvent {
423 timestamp: now,
424 event_type: LifetimeEventType::Accessed {
425 read: matches!(access_type, AccessType::Read | AccessType::ReadWrite),
426 write: matches!(access_type, AccessType::Write | AccessType::ReadWrite),
427 },
428 details: format!("Accessed {} bytes ({:?})", bytes, access_type),
429 });
430
431 self.update_access_frequency();
433 }
434
435 pub fn add_performance_hint(&mut self, hint: PerformanceHint) {
437 self.performance_hints.push(hint);
438 }
439
440 pub fn age(&self) -> Duration {
442 Instant::now().duration_since(self.allocated_at)
443 }
444
445 pub fn is_active(&self) -> bool {
447 !self
448 .lifetime_events
449 .iter()
450 .any(|event| matches!(event.event_type, LifetimeEventType::Deallocated))
451 }
452
453 pub fn total_bytes_accessed(&self) -> u64 {
455 self.usage_stats.bytes_read + self.usage_stats.bytes_written
456 }
457
458 fn update_access_frequency(&mut self) {
460 let now = Instant::now();
461 let age = now.duration_since(self.allocated_at).as_secs_f64();
462
463 if age > 0.0 {
464 self.usage_stats.access_frequency = self.usage_stats.access_count as f64 / age;
465 }
466 }
467}
468
469impl AllocationSource {
470 pub fn new(
472 function: String,
473 location: Option<(String, u32)>,
474 thread_id: u64,
475 context: AllocationContext,
476 ) -> Self {
477 Self {
478 function,
479 location,
480 stack_depth: 0, thread_id,
482 context,
483 }
484 }
485
486 pub fn description(&self) -> String {
488 let location_str = if let Some((file, line)) = &self.location {
489 format!(" at {}:{}", file, line)
490 } else {
491 String::new()
492 };
493
494 format!("{}{}", self.function, location_str)
495 }
496}
497
498impl AccessPattern {
499 pub fn new() -> Self {
501 Self {
502 access_times: VecDeque::new(),
503 access_sizes: VecDeque::new(),
504 access_types: VecDeque::new(),
505 sequential_score: 0.0,
506 random_score: 0.0,
507 temporal_locality: 0.0,
508 spatial_locality: 0.0,
509 frequency: 0.0,
510 last_analysis: None,
511 }
512 }
513
514 pub fn record_access(&mut self, access_type: AccessType, size: usize) {
516 let now = Instant::now();
517
518 self.access_times.push_back(now);
519 self.access_sizes.push_back(size);
520 self.access_types.push_back(access_type);
521
522 const MAX_TRACKED_ACCESSES: usize = 1000;
524 if self.access_times.len() > MAX_TRACKED_ACCESSES {
525 self.access_times.pop_front();
526 self.access_sizes.pop_front();
527 self.access_types.pop_front();
528 }
529
530 self.update_frequency();
532 }
533
534 pub fn analyze_patterns(&mut self) {
536 if self.access_times.len() < 2 {
537 return;
538 }
539
540 self.analyze_sequentiality();
541 self.analyze_locality();
542 self.last_analysis = Some(Instant::now());
543 }
544
545 pub fn needs_analysis(&self) -> bool {
547 const ANALYSIS_INTERVAL: Duration = Duration::from_secs(30);
548
549 match self.last_analysis {
550 Some(last) => Instant::now().duration_since(last) > ANALYSIS_INTERVAL,
551 None => self.access_times.len() >= 10,
552 }
553 }
554
555 fn update_frequency(&mut self) {
557 if self.access_times.len() < 2 {
558 return;
559 }
560
561 let first = self
562 .access_times
563 .front()
564 .expect("access_times should not be empty after guard");
565 let last = self
566 .access_times
567 .back()
568 .expect("access_times should not be empty after guard");
569 let duration = last.duration_since(*first).as_secs_f64();
570
571 let effective_duration = duration.max(1e-6);
575 self.frequency = self.access_times.len() as f64 / effective_duration;
576 }
577
578 fn analyze_sequentiality(&mut self) {
580 if self.access_sizes.len() < 3 {
581 return;
582 }
583
584 let mut sequential_count = 0;
585 let mut total_comparisons = 0;
586
587 let access_sizes_vec: Vec<_> = self.access_sizes.iter().collect();
588 for window in access_sizes_vec.windows(2) {
589 let diff = if *window[1] > *window[0] {
590 *window[1] - *window[0]
591 } else {
592 *window[0] - *window[1]
593 };
594
595 if diff <= *window[0] {
597 sequential_count += 1;
598 }
599 total_comparisons += 1;
600 }
601
602 if total_comparisons > 0 {
603 self.sequential_score = sequential_count as f64 / total_comparisons as f64;
604 self.random_score = 1.0 - self.sequential_score;
605 }
606 }
607
608 fn analyze_locality(&mut self) {
610 if self.access_times.len() < 3 {
611 return;
612 }
613
614 let recent_window = Duration::from_secs(1);
616 let now = Instant::now();
617 let recent_accesses = self
618 .access_times
619 .iter()
620 .filter(|&time| now.duration_since(*time) < recent_window)
621 .count();
622
623 self.temporal_locality = recent_accesses as f64 / self.access_times.len() as f64;
624
625 let mut locality_score = 0.0;
627 let mut comparisons = 0;
628
629 let access_sizes_vec2: Vec<_> = self.access_sizes.iter().collect();
630 for window in access_sizes_vec2.windows(3) {
631 let size_var = ((*window[0] as f64 - *window[1] as f64).powi(2)
632 + (*window[1] as f64 - *window[2] as f64).powi(2))
633 / 2.0;
634
635 locality_score += 1.0 / (1.0 + size_var);
637 comparisons += 1;
638 }
639
640 if comparisons > 0 {
641 self.spatial_locality = locality_score / comparisons as f64;
642 }
643 }
644}
645
646impl AllocationTracker {
647 pub fn new() -> Self {
649 Self {
650 allocations: HashMap::new(),
651 access_patterns: HashMap::new(),
652 total_allocations: 0,
653 total_bytes: 0,
654 performance_hints: Vec::new(),
655 }
656 }
657
658 pub fn track_allocation(&mut self, allocation: MemoryAllocation) {
660 self.total_allocations += 1;
661 self.total_bytes += allocation.size as u64;
662
663 self.access_patterns
665 .insert(allocation.ptr, AccessPattern::new());
666
667 self.allocations.insert(allocation.ptr, allocation);
668 }
669
670 pub fn record_access(&mut self, ptr: usize, access_type: AccessType, bytes: usize) {
672 if let Some(allocation) = self.allocations.get_mut(&ptr) {
674 allocation.record_access(access_type, bytes);
675 }
676
677 if let Some(pattern) = self.access_patterns.get_mut(&ptr) {
679 pattern.record_access(access_type, bytes);
680
681 if pattern.needs_analysis() {
682 pattern.analyze_patterns();
683 self.generate_performance_hints(ptr);
684 }
685 }
686 }
687
688 pub fn untrack_allocation(&mut self, ptr: usize) {
690 if let Some(mut allocation) = self.allocations.remove(&ptr) {
691 allocation.lifetime_events.push(LifetimeEvent {
693 timestamp: Instant::now(),
694 event_type: LifetimeEventType::Deallocated,
695 details: "Memory deallocated".to_string(),
696 });
697
698 self.total_bytes = self.total_bytes.saturating_sub(allocation.size as u64);
699 }
700
701 self.access_patterns.remove(&ptr);
702 }
703
704 pub fn get_allocation(&self, ptr: usize) -> Option<&MemoryAllocation> {
706 self.allocations.get(&ptr)
707 }
708
709 pub fn get_access_pattern(&self, ptr: usize) -> Option<&AccessPattern> {
711 self.access_patterns.get(&ptr)
712 }
713
714 pub fn active_allocations(&self) -> impl Iterator<Item = &MemoryAllocation> {
716 self.allocations.values().filter(|alloc| alloc.is_active())
717 }
718
719 pub fn total_memory_usage(&self) -> usize {
721 self.allocations
722 .values()
723 .filter(|alloc| alloc.is_active())
724 .map(|alloc| alloc.size)
725 .sum()
726 }
727
728 fn generate_performance_hints(&mut self, ptr: usize) {
730 let (allocation, pattern) =
731 match (self.allocations.get(&ptr), self.access_patterns.get(&ptr)) {
732 (Some(alloc), Some(pat)) => (alloc, pat),
733 _ => return,
734 };
735
736 let mut hints = Vec::new();
737
738 if pattern.spatial_locality < 0.3 {
740 hints.push(PerformanceHint {
741 hint_type: PerformanceHintType::PoorCacheLocality,
742 severity: HintSeverity::Warning,
743 description: "Poor spatial locality detected in memory accesses".to_string(),
744 suggested_action: "Consider reorganizing data layout or access patterns"
745 .to_string(),
746 impact_estimate: 0.2,
747 });
748 }
749
750 if pattern.random_score > 0.7 {
752 hints.push(PerformanceHint {
753 hint_type: PerformanceHintType::SuboptimalAccessPattern,
754 severity: HintSeverity::Info,
755 description: "Random access pattern detected".to_string(),
756 suggested_action: "Consider prefetching or data reorganization".to_string(),
757 impact_estimate: 0.15,
758 });
759 }
760
761 if allocation.usage_stats.access_count == 0 && allocation.age() > Duration::from_secs(60) {
763 hints.push(PerformanceHint {
764 hint_type: PerformanceHintType::UnusedMemory,
765 severity: HintSeverity::Warning,
766 description: "Memory allocated but never accessed".to_string(),
767 suggested_action: "Consider deallocating unused memory".to_string(),
768 impact_estimate: 0.1,
769 });
770 }
771
772 self.performance_hints.extend(hints);
774 }
775
776 pub fn performance_hints(&self) -> &[PerformanceHint] {
778 &self.performance_hints
779 }
780
781 pub fn clear_old_hints(&mut self) {
783 self.performance_hints.clear();
785 }
786}
787
788impl Default for AllocationTracker {
789 fn default() -> Self {
790 Self::new()
791 }
792}
793
794impl Default for AccessPattern {
795 fn default() -> Self {
796 Self::new()
797 }
798}
799
800impl std::fmt::Display for MemoryType {
801 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
802 match self {
803 MemoryType::Device => write!(f, "Device"),
804 MemoryType::Host => write!(f, "Host"),
805 MemoryType::Unified => write!(f, "Unified"),
806 MemoryType::Pinned => write!(f, "Pinned"),
807 MemoryType::Texture => write!(f, "Texture"),
808 MemoryType::Constant => write!(f, "Constant"),
809 MemoryType::Shared => write!(f, "Shared"),
810 MemoryType::MemoryMapped => write!(f, "MemoryMapped"),
811 }
812 }
813}
814
815impl std::fmt::Display for AccessType {
816 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
817 match self {
818 AccessType::Read => write!(f, "Read"),
819 AccessType::Write => write!(f, "Write"),
820 AccessType::ReadWrite => write!(f, "ReadWrite"),
821 }
822 }
823}
824
825impl std::fmt::Display for PressureLevel {
826 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
827 match self {
828 PressureLevel::None => write!(f, "None"),
829 PressureLevel::Low => write!(f, "Low"),
830 PressureLevel::Medium => write!(f, "Medium"),
831 PressureLevel::High => write!(f, "High"),
832 PressureLevel::Critical => write!(f, "Critical"),
833 }
834 }
835}
836
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the allocation source shared by these tests: function
    /// `test_function`, thread 123, a user-allocation context with request id
    /// `test`, and the given optional location.
    fn user_source(location: Option<(String, u32)>) -> AllocationSource {
        let context = AllocationContext::UserAllocation {
            request_id: "test".to_string(),
        };
        AllocationSource::new("test_function".to_string(), location, 123, context)
    }

    /// Builds a 1024-byte host allocation at address `0x1000` with no device.
    fn host_allocation(location: Option<(String, u32)>) -> MemoryAllocation {
        MemoryAllocation::new(0x1000, 1024, user_source(location), MemoryType::Host, None)
    }

    /// A freshly created allocation keeps its inputs and starts out active.
    #[test]
    fn test_memory_allocation_creation() {
        let fresh = host_allocation(Some(("test.rs".to_string(), 42)));

        assert_eq!(fresh.ptr, 0x1000);
        assert_eq!(fresh.size, 1024);
        assert_eq!(fresh.memory_type, MemoryType::Host);
        assert!(fresh.is_active());
    }

    /// Reads and writes are counted separately and summed on request.
    #[test]
    fn test_allocation_access_tracking() {
        let mut tracked = host_allocation(None);

        tracked.record_access(AccessType::Read, 512);
        tracked.record_access(AccessType::Write, 256);

        let stats = &tracked.usage_stats;
        assert_eq!(stats.access_count, 2);
        assert_eq!(stats.bytes_read, 512);
        assert_eq!(stats.bytes_written, 256);
        assert_eq!(tracked.total_bytes_accessed(), 768);
    }

    /// Three recorded accesses are all remembered and yield a positive frequency.
    #[test]
    fn test_access_pattern_tracking() {
        let mut history = AccessPattern::new();

        let accesses = [
            (AccessType::Read, 1024),
            (AccessType::Read, 1024),
            (AccessType::Write, 2048),
        ];
        for &(kind, size) in accesses.iter() {
            history.record_access(kind, size);
        }

        assert_eq!(history.access_types.len(), 3);
        assert!(history.frequency > 0.0);
    }

    /// Tracking an allocation makes its record and pattern retrievable.
    #[test]
    fn test_allocation_tracker() {
        let mut registry = AllocationTracker::new();

        registry.track_allocation(host_allocation(None));
        registry.record_access(0x1000, AccessType::Read, 512);

        assert_eq!(registry.total_memory_usage(), 1024);
        assert!(registry.get_allocation(0x1000).is_some());
        assert!(registry.get_access_pattern(0x1000).is_some());
    }

    /// `Display` for `MemoryType` prints the variant name.
    #[test]
    fn test_memory_type_display() {
        assert_eq!(format!("{}", MemoryType::Device), "Device");
        assert_eq!(format!("{}", MemoryType::Host), "Host");
        assert_eq!(format!("{}", MemoryType::Unified), "Unified");
    }

    /// The description mentions both the function and the `file:line` location.
    #[test]
    fn test_allocation_source_description() {
        let text = user_source(Some(("test.rs".to_string(), 42))).description();

        assert!(text.contains("test_function"));
        assert!(text.contains("test.rs:42"));
    }
}