Skip to main content

ruvector_profiler/
memory.rs

1use std::time::{SystemTime, UNIX_EPOCH};
2
3#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
4pub struct MemorySnapshot {
5    pub peak_rss_bytes: u64,
6    pub kv_cache_bytes: u64,
7    pub activation_bytes: u64,
8    pub temp_buffer_bytes: u64,
9    pub timestamp_us: u64,
10}
11
12#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
13pub struct MemoryReport {
14    pub label: String,
15    pub peak_rss: u64,
16    pub mean_rss: u64,
17    pub kv_cache_total: u64,
18    pub activation_total: u64,
19}
20
21/// Capture current memory via /proc/self/status (Linux) or zero fallback.
22pub fn capture_memory() -> MemorySnapshot {
23    let ts = SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_micros() as u64;
24    MemorySnapshot {
25        peak_rss_bytes: read_vm_rss(),
26        kv_cache_bytes: 0,
27        activation_bytes: 0,
28        temp_buffer_bytes: 0,
29        timestamp_us: ts,
30    }
31}
32
/// Extracts the `VmRSS:` value from the text of `/proc/self/status` and
/// converts it from kB to bytes.
///
/// Only the first line starting with `VmRSS:` is considered. Returns `None`
/// when no such line exists, when the value is not an unsigned integer, or
/// when the kB-to-bytes conversion would overflow `u64`.
fn parse_vm_rss_bytes(status: &str) -> Option<u64> {
    status
        .lines()
        .find_map(|line| line.strip_prefix("VmRSS:"))
        .and_then(|rest| {
            // The kernel writes e.g. "VmRSS:\t    1234 kB": drop the padding
            // and the unit suffix before parsing the number.
            rest.trim()
                .trim_end_matches("kB")
                .trim_end()
                .parse::<u64>()
                .ok()
        })
        // Checked multiply: an absurd value falls back to "unknown" rather
        // than panicking (debug) or wrapping (release).
        .and_then(|kb| kb.checked_mul(1024))
}

/// Returns the resident set size of the current process in bytes.
///
/// On Linux the value is read from `/proc/self/status`; on every other
/// platform, or when the file cannot be read or parsed, the result is `0`.
fn read_vm_rss() -> u64 {
    // `cfg!` is a compile-time constant, so non-Linux builds never touch /proc.
    if !cfg!(target_os = "linux") {
        return 0;
    }
    std::fs::read_to_string("/proc/self/status")
        .ok()
        .and_then(|status| parse_vm_rss_bytes(&status))
        .unwrap_or(0)
}
45
46pub struct MemoryTracker {
47    pub snapshots: Vec<MemorySnapshot>,
48    pub label: String,
49}
50
51impl MemoryTracker {
52    pub fn new(label: &str) -> Self {
53        Self { snapshots: Vec::new(), label: label.to_string() }
54    }
55
56    pub fn snapshot(&mut self) { self.snapshots.push(capture_memory()); }
57
58    pub fn peak(&self) -> u64 {
59        self.snapshots.iter().map(|s| s.peak_rss_bytes).max().unwrap_or(0)
60    }
61
62    pub fn report(&self) -> MemoryReport {
63        let n = self.snapshots.len().max(1) as u64;
64        MemoryReport {
65            label: self.label.clone(),
66            peak_rss: self.peak(),
67            mean_rss: self.snapshots.iter().map(|s| s.peak_rss_bytes).sum::<u64>() / n,
68            kv_cache_total: self.snapshots.iter().map(|s| s.kv_cache_bytes).sum(),
69            activation_total: self.snapshots.iter().map(|s| s.activation_bytes).sum(),
70        }
71    }
72}
73
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a snapshot with the given RSS / KV-cache / activation sizes,
    /// no temp-buffer usage and a fixed timestamp of 1.
    fn sample(rss: u64, kv: u64, act: u64) -> MemorySnapshot {
        MemorySnapshot {
            peak_rss_bytes: rss,
            kv_cache_bytes: kv,
            activation_bytes: act,
            temp_buffer_bytes: 0,
            timestamp_us: 1,
        }
    }

    #[test]
    fn capture_returns_nonzero_timestamp() {
        let snap = capture_memory();
        assert!(snap.timestamp_us > 0);
    }

    #[test]
    fn tracker_peak_empty() {
        let tracker = MemoryTracker::new("x");
        assert_eq!(tracker.peak(), 0);
    }

    #[test]
    fn tracker_report_aggregates() {
        let mut tracker = MemoryTracker::new("test");
        tracker.snapshots.push(sample(100, 10, 20));
        tracker.snapshots.push(sample(200, 30, 40));

        let report = tracker.report();

        assert_eq!(report.peak_rss, 200);
        assert_eq!(report.mean_rss, 150);
        assert_eq!(report.kv_cache_total, 40);
        assert_eq!(report.activation_total, 60);
    }
}