1use std::alloc::{GlobalAlloc, Layout};
7use std::collections::HashMap;
8use std::sync::atomic::{AtomicUsize, Ordering};
9use std::time::{Duration, Instant};
10
// The counters are independent statistics rather than synchronisation points,
// so every access in this module uses `Ordering::Relaxed`.

/// Net number of bytes currently allocated through [`TrackingAllocator`].
static ALLOCATED_BYTES: AtomicUsize = AtomicUsize::new(0);
/// Number of allocations recorded by [`TrackingAllocator`].
static TOTAL_ALLOCATIONS: AtomicUsize = AtomicUsize::new(0);
/// Highest value `ALLOCATED_BYTES` has reached since start-up or the last reset.
static PEAK_MEMORY: AtomicUsize = AtomicUsize::new(0);

/// Snapshot of the global allocation counters.
#[derive(Debug, Clone)]
pub struct MemoryStats {
    /// Bytes currently allocated (allocations minus deallocations).
    pub current_bytes: usize,
    /// Number of allocations recorded.
    pub total_allocations: usize,
    /// Highest observed value of the allocated-bytes counter.
    pub peak_bytes: usize,
    /// `(sample time, bytes)` pairs; [`MemoryStats::current`] leaves this empty.
    pub timeline: Vec<(Instant, usize)>,
}

impl MemoryStats {
    /// Reads the global counters into a new snapshot with an empty `timeline`.
    ///
    /// The three counters are loaded one after another, so under concurrent
    /// allocation the values are not guaranteed to be mutually consistent.
    pub fn current() -> Self {
        Self {
            current_bytes: ALLOCATED_BYTES.load(Ordering::Relaxed),
            total_allocations: TOTAL_ALLOCATIONS.load(Ordering::Relaxed),
            peak_bytes: PEAK_MEMORY.load(Ordering::Relaxed),
            timeline: Vec::new(),
        }
    }

    /// Sets the current-bytes, total-allocations and peak counters to zero.
    ///
    /// Allocations that are still live at this point are forgotten by the
    /// current-bytes counter, but their later deallocation is still subtracted
    /// by [`TrackingAllocator`]; reset only at points where that is acceptable.
    pub fn reset() {
        ALLOCATED_BYTES.store(0, Ordering::Relaxed);
        TOTAL_ALLOCATIONS.store(0, Ordering::Relaxed);
        PEAK_MEMORY.store(0, Ordering::Relaxed);
    }

    /// Formats a byte count with a binary (1024-based) unit.
    ///
    /// Values below 1024 are printed exactly (`"500 B"`); larger values use two
    /// decimals in the largest unit up to `TB` (`"1.50 KB"`, `"2.00 GB"`).
    /// A value that would round up to `1024.00` of a unit is shown as `1.00` of
    /// the next unit instead.
    pub fn format_bytes(bytes: usize) -> String {
        const UNITS: &[&str] = &["B", "KB", "MB", "GB", "TB"];
        const STEP: f64 = 1024.0;
        let last_unit = UNITS.len() - 1;

        if bytes < 1024 {
            return format!("{} {}", bytes, UNITS[0]);
        }

        // Precision loss for astronomically large counts is irrelevant at two decimals.
        let mut size = bytes as f64;
        let mut unit_idx = 0;
        while size >= STEP && unit_idx < last_unit {
            size /= STEP;
            unit_idx += 1;
        }

        // Below the last unit `size < 1024` here, but two-decimal rounding can
        // still render e.g. 1023.999 KB as "1024.00"; promote it to the next unit.
        let mut rendered = format!("{:.2}", size);
        if unit_idx < last_unit && rendered == "1024.00" {
            unit_idx += 1;
            rendered = format!("{:.2}", size / STEP);
        }

        format!("{} {}", rendered, UNITS[unit_idx])
    }

    /// Multi-line, human-readable rendering of the current, peak and total counters.
    pub fn summary(&self) -> String {
        format!(
            "Memory Stats:\n\
             Current: {}\n\
             Peak: {}\n\
             Total Allocations: {}",
            Self::format_bytes(self.current_bytes),
            Self::format_bytes(self.peak_bytes),
            self.total_allocations
        )
    }
}
78
79pub struct MemoryProfiler {
81 start_time: Instant,
82 timeline: Vec<(Instant, usize)>,
83 sampling_interval: Duration,
84 operation_name: String,
85}
86
87impl MemoryProfiler {
88 pub fn new(operation_name: impl Into<String>) -> Self {
90 Self {
91 start_time: Instant::now(),
92 timeline: Vec::new(),
93 sampling_interval: Duration::from_millis(100),
94 operation_name: operation_name.into(),
95 }
96 }
97
98 pub fn with_sampling_interval(mut self, interval: Duration) -> Self {
100 self.sampling_interval = interval;
101 self
102 }
103
104 pub fn start(&mut self) {
106 self.start_time = Instant::now();
107 self.timeline.clear();
108 self.record_sample();
109 }
110
111 pub fn record_sample(&mut self) {
113 let current_memory = ALLOCATED_BYTES.load(Ordering::Relaxed);
114 self.timeline.push((Instant::now(), current_memory));
115 }
116
117 pub fn stop(mut self) -> MemoryProfilingResult {
119 self.record_sample();
120
121 let duration = self.start_time.elapsed();
122 let peak_during_profiling = self
123 .timeline
124 .iter()
125 .map(|(_, bytes)| *bytes)
126 .max()
127 .unwrap_or(0);
128
129 let memory_over_time = self
130 .timeline
131 .into_iter()
132 .map(|(time, bytes)| (time.duration_since(self.start_time), bytes))
133 .collect();
134
135 MemoryProfilingResult {
136 operation_name: self.operation_name,
137 duration,
138 peak_memory: peak_during_profiling,
139 memory_timeline: memory_over_time,
140 final_stats: MemoryStats::current(),
141 }
142 }
143}
144
145#[derive(Debug)]
147pub struct MemoryProfilingResult {
148 pub operation_name: String,
150 pub duration: Duration,
152 pub peak_memory: usize,
154 pub memory_timeline: Vec<(Duration, usize)>,
156 pub final_stats: MemoryStats,
158}
159
160impl MemoryProfilingResult {
161 pub fn report(&self) -> String {
163 let avg_memory = if !self.memory_timeline.is_empty() {
164 self.memory_timeline
165 .iter()
166 .map(|(_, bytes)| *bytes)
167 .sum::<usize>()
168 / self.memory_timeline.len()
169 } else {
170 0
171 };
172
173 let memory_growth = if self.memory_timeline.len() >= 2 {
174 let start_memory = self.memory_timeline[0].1;
175 let end_memory = self.memory_timeline[self.memory_timeline.len() - 1].1;
176 end_memory.saturating_sub(start_memory)
177 } else {
178 0
179 };
180
181 format!(
182 "Memory Profile Report: {}\n\
183 Duration: {:.2}s\n\
184 Peak Memory: {}\n\
185 Average Memory: {}\n\
186 Memory Growth: {}\n\
187 Samples: {}\n\
188 {}",
189 self.operation_name,
190 self.duration.as_secs_f64(),
191 MemoryStats::format_bytes(self.peak_memory),
192 MemoryStats::format_bytes(avg_memory),
193 MemoryStats::format_bytes(memory_growth),
194 self.memory_timeline.len(),
195 self.final_stats.summary()
196 )
197 }
198
199 pub fn export_csv(&self) -> String {
201 let mut csv = String::from("timestamp_ms,memory_bytes\n");
202 for (duration, bytes) in &self.memory_timeline {
203 csv.push_str(&format!("{},{}\n", duration.as_millis(), bytes));
204 }
205 csv
206 }
207
208 pub fn check_memory_leaks(&self) -> Option<String> {
210 if self.memory_timeline.len() < 2 {
211 return None;
212 }
213
214 let start_memory = self.memory_timeline[0].1;
215 let end_memory = self.memory_timeline[self.memory_timeline.len() - 1].1;
216 let growth = end_memory.saturating_sub(start_memory);
217
218 if growth > 10 * 1024 * 1024 {
220 Some(format!(
221 "Potential memory leak detected: {} growth during {}",
222 MemoryStats::format_bytes(growth),
223 self.operation_name
224 ))
225 } else {
226 None
227 }
228 }
229}
230
/// Runs `$code` while a `MemoryProfiler` labelled `$operation_name` is active.
///
/// Expands to the tuple `(value_of_code, MemoryProfilingResult)`. The expansion
/// names `$crate::memory_profiler::MemoryProfiler`, so this file presumably lives
/// at the crate's `memory_profiler` module path — confirm when moving it.
#[macro_export]
macro_rules! profile_memory {
    ($operation_name:expr, $code:block) => {{
        let mut active_profiler =
            $crate::memory_profiler::MemoryProfiler::new($operation_name);
        active_profiler.start();
        let block_value = $code;
        // Tuple fields are evaluated left to right, so `stop` still runs after `$code`.
        (block_value, active_profiler.stop())
    }};
}
242
243pub struct TrackingAllocator<A: GlobalAlloc> {
245 inner: A,
246}
247
248impl<A: GlobalAlloc> TrackingAllocator<A> {
249 pub const fn new(inner: A) -> Self {
250 Self { inner }
251 }
252}
253
254unsafe impl<A: GlobalAlloc> GlobalAlloc for TrackingAllocator<A> {
255 unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
256 let ptr = self.inner.alloc(layout);
257 if !ptr.is_null() {
258 let size = layout.size();
259 let current = ALLOCATED_BYTES.fetch_add(size, Ordering::Relaxed) + size;
260 TOTAL_ALLOCATIONS.fetch_add(1, Ordering::Relaxed);
261
262 let mut peak = PEAK_MEMORY.load(Ordering::Relaxed);
264 while peak < current {
265 match PEAK_MEMORY.compare_exchange_weak(
266 peak,
267 current,
268 Ordering::Relaxed,
269 Ordering::Relaxed,
270 ) {
271 Ok(_) => break,
272 Err(x) => peak = x,
273 }
274 }
275 }
276 ptr
277 }
278
279 unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
280 ALLOCATED_BYTES.fetch_sub(layout.size(), Ordering::Relaxed);
281 self.inner.dealloc(ptr, layout);
282 }
283}
284
285pub struct ImputationMemoryBenchmark {
287 results: HashMap<String, MemoryProfilingResult>,
288}
289
290impl ImputationMemoryBenchmark {
291 pub fn new() -> Self {
293 Self {
294 results: HashMap::new(),
295 }
296 }
297
298 pub fn add_result(&mut self, name: String, result: MemoryProfilingResult) {
300 self.results.insert(name, result);
301 }
302
303 pub fn comparison_report(&self) -> String {
305 let mut report = String::from("Memory Usage Comparison:\n");
306 report.push_str("Method\tPeak Memory\tAvg Memory\tDuration\tGrowth\n");
307
308 for (name, result) in &self.results {
309 let avg_memory = if !result.memory_timeline.is_empty() {
310 result
311 .memory_timeline
312 .iter()
313 .map(|(_, bytes)| *bytes)
314 .sum::<usize>()
315 / result.memory_timeline.len()
316 } else {
317 0
318 };
319
320 let growth = if result.memory_timeline.len() >= 2 {
321 let start = result.memory_timeline[0].1;
322 let end = result.memory_timeline[result.memory_timeline.len() - 1].1;
323 end.saturating_sub(start)
324 } else {
325 0
326 };
327
328 report.push_str(&format!(
329 "{}\t{}\t{}\t{:.2}s\t{}\n",
330 name,
331 MemoryStats::format_bytes(result.peak_memory),
332 MemoryStats::format_bytes(avg_memory),
333 result.duration.as_secs_f64(),
334 MemoryStats::format_bytes(growth),
335 ));
336 }
337
338 report
339 }
340
341 pub fn most_efficient(&self) -> Option<(&String, &MemoryProfilingResult)> {
343 self.results
344 .iter()
345 .min_by_key(|(_, result)| result.peak_memory)
346 }
347}
348
349impl Default for ImputationMemoryBenchmark {
350 fn default() -> Self {
351 Self::new()
352 }
353}
354
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Builds a result with a fixed one-second duration and the given samples.
    fn result_with_timeline(
        operation_name: &str,
        peak_memory: usize,
        memory_timeline: Vec<(Duration, usize)>,
    ) -> MemoryProfilingResult {
        MemoryProfilingResult {
            operation_name: operation_name.to_string(),
            duration: Duration::from_secs(1),
            peak_memory,
            memory_timeline,
            final_stats: MemoryStats::current(),
        }
    }

    #[test]
    fn test_memory_stats_formatting() {
        assert_eq!(MemoryStats::format_bytes(0), "0 B");
        assert_eq!(MemoryStats::format_bytes(500), "500 B");
        assert_eq!(MemoryStats::format_bytes(1024), "1.00 KB");
        assert_eq!(MemoryStats::format_bytes(1536), "1.50 KB");
        assert_eq!(MemoryStats::format_bytes(1048576), "1.00 MB");
        assert_eq!(MemoryStats::format_bytes(2147483648), "2.00 GB");
    }

    #[test]
    fn test_memory_profiler() {
        let mut profiler = MemoryProfiler::new("test_operation");
        profiler.start();

        thread::sleep(Duration::from_millis(10));
        profiler.record_sample();

        let result = profiler.stop();

        assert_eq!(result.operation_name, "test_operation");
        assert!(result.duration.as_millis() >= 10);
        assert!(!result.memory_timeline.is_empty());
    }

    #[test]
    fn test_profiling_result_report() {
        let result = result_with_timeline(
            "test",
            1024,
            vec![
                (Duration::from_millis(0), 512),
                (Duration::from_millis(500), 1024),
                (Duration::from_millis(1000), 768),
            ],
        );

        let report = result.report();
        assert!(report.contains("test"));
        assert!(report.contains("1.00s"));
        assert!(report.contains("Peak Memory: 1.00 KB"));
        assert!(report.contains("Average Memory: 768 B"));
        assert!(report.contains("Memory Growth: 256 B"));
        assert!(report.contains("Samples: 3"));
    }

    #[test]
    fn test_csv_export() {
        let result = result_with_timeline(
            "test",
            1024,
            vec![
                (Duration::from_millis(0), 512),
                (Duration::from_millis(1000), 1024),
            ],
        );

        let csv = result.export_csv();
        assert!(csv.contains("timestamp_ms,memory_bytes"));
        assert!(csv.contains("0,512"));
        assert!(csv.contains("1000,1024"));
    }

    #[test]
    fn test_csv_export_empty_timeline() {
        let result = result_with_timeline("empty", 0, Vec::new());
        assert_eq!(result.export_csv(), "timestamp_ms,memory_bytes\n");
    }

    #[test]
    fn test_memory_leak_detection() {
        let result = result_with_timeline(
            "test",
            20 * 1024 * 1024,
            vec![
                (Duration::from_millis(0), 1024),
                (Duration::from_millis(1000), 20 * 1024 * 1024),
            ],
        );

        let leak_check = result.check_memory_leaks();
        assert!(leak_check.is_some());
        assert!(leak_check.unwrap().contains("Potential memory leak"));
    }

    #[test]
    fn test_no_leak_for_small_growth() {
        let result = result_with_timeline(
            "small",
            2048,
            vec![
                (Duration::from_millis(0), 1024),
                (Duration::from_millis(1000), 2048),
            ],
        );
        assert!(result.check_memory_leaks().is_none());
    }

    #[test]
    fn test_leak_check_needs_two_samples() {
        let result = result_with_timeline(
            "single",
            20 * 1024 * 1024,
            vec![(Duration::from_millis(0), 20 * 1024 * 1024)],
        );
        assert!(result.check_memory_leaks().is_none());
    }

    #[test]
    fn test_benchmark_comparison() {
        let mut bench = ImputationMemoryBenchmark::new();
        assert!(bench.most_efficient().is_none());

        bench.add_result(
            "heavy".to_string(),
            result_with_timeline("heavy", 2048, vec![(Duration::from_millis(0), 2048)]),
        );
        bench.add_result(
            "light".to_string(),
            result_with_timeline("light", 1024, vec![(Duration::from_millis(0), 1024)]),
        );

        let (name, result) = bench.most_efficient().expect("two results were added");
        assert_eq!(name.as_str(), "light");
        assert_eq!(result.peak_memory, 1024);

        let report = bench.comparison_report();
        assert!(report.starts_with(
            "Memory Usage Comparison:\nMethod\tPeak Memory\tAvg Memory\tDuration\tGrowth\n"
        ));
        assert!(report.contains("light\t1.00 KB"));
        assert!(report.contains("heavy\t2.00 KB"));
    }
}