Skip to main content

trustformers_debug/profiler/
memory.rs

1//! Memory tracking and allocation analysis
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::time::SystemTime;
6use uuid::Uuid;
7
/// Memory allocation tracking
///
/// A single allocation record as stored by [`MemoryTracker`]. The fields are
/// plain data supplied by the caller; nothing in this type enforces
/// consistency between them (e.g. between `freed` and `free_timestamp`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryAllocation {
    /// Key under which [`MemoryTracker`] stores this record and by which it is
    /// later deallocated.
    pub allocation_id: Uuid,
    /// Size of the allocation in bytes; added to the tracker's running total
    /// when tracked and subtracted when deallocated.
    pub size_bytes: usize,
    /// Category of memory the allocation was made from.
    pub allocation_type: MemoryAllocationType,
    /// Device index for the allocation, if any.
    // NOTE(review): presumably `None` for host-side allocations — confirm with callers.
    pub device_id: Option<i32>,
    /// Time the allocation was recorded; set by whoever builds the record.
    pub timestamp: SystemTime,
    /// Captured call-stack frames as strings.
    // TODO confirm frame ordering (innermost-first vs outermost-first) with the producer.
    pub stack_trace: Vec<String>,
    /// Whether the allocation has been released.
    ///
    /// `MemoryTracker::track_deallocation` removes the record from the tracker
    /// rather than keeping a marked copy, so records held by the tracker are
    /// always live.
    pub freed: bool,
    /// Time the allocation was released, if known.
    pub free_timestamp: Option<SystemTime>,
}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub enum MemoryAllocationType {
23    Host,
24    Device,
25    Unified,
26    Pinned,
27    Mapped,
28}
29
/// Memory allocation tracker
///
/// Holds the currently tracked allocation records, keyed by id, together with
/// running byte and event counters. Use `get_memory_stats` for a snapshot.
#[derive(Debug)]
pub struct MemoryTracker {
    /// Records that have been tracked and not yet deallocated, keyed by
    /// `MemoryAllocation::allocation_id`.
    pub(crate) allocations: HashMap<Uuid, MemoryAllocation>,
    /// Running byte total: increased by each tracked allocation and decreased
    /// (saturating at zero) by each matched deallocation.
    pub(crate) total_allocated: usize,
    /// Highest value `total_allocated` has reached.
    pub(crate) peak_allocated: usize,
    /// Number of `track_allocation` calls.
    pub(crate) allocation_count: usize,
    /// Number of deallocations that matched a tracked record; unknown ids are
    /// not counted.
    pub(crate) deallocation_count: usize,
}
39
40impl Default for MemoryTracker {
41    fn default() -> Self {
42        Self::new()
43    }
44}
45
46impl MemoryTracker {
47    pub fn new() -> Self {
48        Self {
49            allocations: HashMap::new(),
50            total_allocated: 0,
51            peak_allocated: 0,
52            allocation_count: 0,
53            deallocation_count: 0,
54        }
55    }
56
57    pub fn track_allocation(&mut self, allocation: MemoryAllocation) {
58        self.total_allocated += allocation.size_bytes;
59        self.allocation_count += 1;
60
61        if self.total_allocated > self.peak_allocated {
62            self.peak_allocated = self.total_allocated;
63        }
64
65        self.allocations.insert(allocation.allocation_id, allocation);
66    }
67
68    pub fn track_deallocation(&mut self, allocation_id: Uuid) {
69        if let Some(mut allocation) = self.allocations.remove(&allocation_id) {
70            allocation.freed = true;
71            allocation.free_timestamp = Some(SystemTime::now());
72            self.total_allocated = self.total_allocated.saturating_sub(allocation.size_bytes);
73            self.deallocation_count += 1;
74        }
75    }
76
77    pub fn get_memory_stats(&self) -> MemoryStats {
78        MemoryStats {
79            total_allocated: self.total_allocated,
80            peak_allocated: self.peak_allocated,
81            active_allocations: self.allocations.len(),
82            allocation_count: self.allocation_count,
83            deallocation_count: self.deallocation_count,
84            memory_efficiency: if self.allocation_count > 0 {
85                self.deallocation_count as f64 / self.allocation_count as f64
86            } else {
87                1.0
88            },
89        }
90    }
91}
92
/// Point-in-time snapshot of a [`MemoryTracker`]'s counters, as returned by
/// `MemoryTracker::get_memory_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryStats {
    /// Tracker's running byte total at snapshot time.
    pub total_allocated: usize,
    /// Highest running byte total the tracker has reached.
    pub peak_allocated: usize,
    /// Number of records the tracker still holds (tracked, not yet deallocated).
    pub active_allocations: usize,
    /// Number of `track_allocation` calls.
    pub allocation_count: usize,
    /// Number of deallocations that matched a tracked record.
    pub deallocation_count: usize,
    /// `deallocation_count / allocation_count`, or `1.0` when no allocation has
    /// been tracked. Despite the name, this is a release ratio, not a measure of
    /// memory utilisation.
    pub memory_efficiency: f64,
}
102
/// Memory efficiency analysis results
///
/// Summary statistics over a series of memory measurements. This type only
/// carries the values; the code that computes them is not shown here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEfficiencyAnalysis {
    /// Largest observed memory usage.
    // NOTE(review): units presumably megabytes per the `_mb` suffix (MB vs MiB unconfirmed).
    pub peak_memory_mb: f64,
    /// Smallest observed memory usage (same units as `peak_memory_mb`).
    pub min_memory_mb: f64,
    /// Mean observed memory usage (same units as `peak_memory_mb`).
    pub avg_memory_mb: f64,
    /// Variance of the observed memory usage.
    // TODO confirm whether this is population or sample variance, and its units.
    pub memory_variance: f64,
    /// Overall efficiency score; `Default` sets it to `100.0`.
    // NOTE(review): presumably on a 0–100 scale — confirm with the producer.
    pub efficiency_score: f64,
}
112
113impl Default for MemoryEfficiencyAnalysis {
114    fn default() -> Self {
115        Self {
116            peak_memory_mb: 0.0,
117            min_memory_mb: 0.0,
118            avg_memory_mb: 0.0,
119            memory_variance: 0.0,
120            efficiency_score: 100.0,
121        }
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a not-yet-freed allocation record stamped with the current time.
    fn record(
        allocation_id: Uuid,
        size_bytes: usize,
        allocation_type: MemoryAllocationType,
        device_id: Option<i32>,
        stack_trace: Vec<String>,
    ) -> MemoryAllocation {
        MemoryAllocation {
            allocation_id,
            size_bytes,
            allocation_type,
            device_id,
            timestamp: SystemTime::now(),
            stack_trace,
            freed: false,
            free_timestamp: None,
        }
    }

    /// Shorthand for a host allocation with no device and an empty stack trace.
    fn host_record(allocation_id: Uuid, size_bytes: usize) -> MemoryAllocation {
        record(allocation_id, size_bytes, MemoryAllocationType::Host, None, Vec::new())
    }

    #[test]
    fn test_memory_tracker_new() {
        let stats = MemoryTracker::new().get_memory_stats();
        assert_eq!(stats.total_allocated, 0);
        assert_eq!(stats.peak_allocated, 0);
        assert_eq!(stats.active_allocations, 0);
        assert_eq!(stats.allocation_count, 0);
        assert_eq!(stats.deallocation_count, 0);
        assert!((stats.memory_efficiency - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_memory_tracker_track_allocation() {
        let mut tracker = MemoryTracker::new();
        tracker.track_allocation(record(
            Uuid::new_v4(),
            1024,
            MemoryAllocationType::Host,
            None,
            vec!["frame1".to_string()],
        ));

        let stats = tracker.get_memory_stats();
        assert_eq!(stats.total_allocated, 1024);
        assert_eq!(stats.peak_allocated, 1024);
        assert_eq!(stats.active_allocations, 1);
        assert_eq!(stats.allocation_count, 1);
    }

    #[test]
    fn test_memory_tracker_multiple_allocations_peak() {
        let sizes: [usize; 3] = [512, 1024, 256];
        let mut tracker = MemoryTracker::new();
        for &size in sizes.iter() {
            tracker.track_allocation(host_record(Uuid::new_v4(), size));
        }

        let expected_total: usize = sizes.iter().sum();
        let stats = tracker.get_memory_stats();
        assert_eq!(stats.total_allocated, expected_total);
        assert_eq!(stats.peak_allocated, expected_total);
        assert_eq!(stats.allocation_count, sizes.len());
    }

    #[test]
    fn test_memory_tracker_deallocation() {
        let mut tracker = MemoryTracker::new();
        let alloc_id = Uuid::new_v4();
        tracker.track_allocation(record(
            alloc_id,
            2048,
            MemoryAllocationType::Device,
            Some(0),
            Vec::new(),
        ));
        tracker.track_deallocation(alloc_id);

        let stats = tracker.get_memory_stats();
        assert_eq!(stats.total_allocated, 0);
        assert_eq!(stats.peak_allocated, 2048);
        assert_eq!(stats.deallocation_count, 1);
        assert_eq!(stats.active_allocations, 0);
    }

    #[test]
    fn test_memory_tracker_dealloc_nonexistent() {
        let mut tracker = MemoryTracker::new();
        tracker.track_deallocation(Uuid::new_v4());

        let stats = tracker.get_memory_stats();
        assert_eq!(stats.deallocation_count, 0);
        assert_eq!(stats.total_allocated, 0);
    }

    #[test]
    fn test_memory_tracker_efficiency_with_allocations() {
        let mut tracker = MemoryTracker::new();
        let ids: Vec<Uuid> = (0..4).map(|_| Uuid::new_v4()).collect();
        for &id in &ids {
            tracker.track_allocation(host_record(id, 100));
        }
        // Release exactly half of the four allocations.
        for &id in &ids[..2] {
            tracker.track_deallocation(id);
        }

        let stats = tracker.get_memory_stats();
        assert!((stats.memory_efficiency - 0.5).abs() < 1e-9);
    }
}