Skip to main content

ruvector_profiler/
memory.rs

1use std::time::{SystemTime, UNIX_EPOCH};
2
/// A single point-in-time memory sample.
///
/// [`capture_memory`] fills in the RSS and timestamp fields and leaves the
/// three per-category byte counters at zero.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct MemorySnapshot {
    /// Resident set size recorded for this sample, in bytes.
    ///
    /// NOTE(review): despite "peak" in the name, `capture_memory` stores the
    /// *current* `VmRSS` reading here on Linux (and 0 on other targets), not a
    /// high-water mark.
    pub peak_rss_bytes: u64,
    /// KV-cache byte count; `capture_memory` always records 0 here.
    pub kv_cache_bytes: u64,
    /// Activation byte count; `capture_memory` always records 0 here.
    pub activation_bytes: u64,
    /// Temporary-buffer byte count; `capture_memory` always records 0 here.
    pub temp_buffer_bytes: u64,
    /// Wall-clock time of the sample, in microseconds since the Unix epoch.
    pub timestamp_us: u64,
}
11
/// Aggregate summary of a series of [`MemorySnapshot`]s, as produced by
/// [`MemoryTracker::report`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct MemoryReport {
    /// Copy of the tracker's label.
    pub label: String,
    /// Largest `peak_rss_bytes` among the snapshots (0 when there are none).
    pub peak_rss: u64,
    /// Integer (truncating) mean of `peak_rss_bytes` over the snapshots
    /// (0 when there are none).
    pub mean_rss: u64,
    /// Sum of `kv_cache_bytes` over all snapshots.
    pub kv_cache_total: u64,
    /// Sum of `activation_bytes` over all snapshots.
    pub activation_total: u64,
}
20
21/// Capture current memory via /proc/self/status (Linux) or zero fallback.
22pub fn capture_memory() -> MemorySnapshot {
23    let ts = SystemTime::now()
24        .duration_since(UNIX_EPOCH)
25        .unwrap_or_default()
26        .as_micros() as u64;
27    MemorySnapshot {
28        peak_rss_bytes: read_vm_rss(),
29        kv_cache_bytes: 0,
30        activation_bytes: 0,
31        temp_buffer_bytes: 0,
32        timestamp_us: ts,
33    }
34}
35
/// Returns the resident set size of the current process in bytes, read from the
/// `VmRSS:` line of `/proc/self/status`.
///
/// This is a best-effort probe: it yields 0 when the file cannot be read or
/// when the line is missing or does not parse as an unsigned integer.
#[cfg(target_os = "linux")]
fn read_vm_rss() -> u64 {
    std::fs::read_to_string("/proc/self/status")
        .ok()
        .and_then(|status| parse_vm_rss_bytes(&status))
        .unwrap_or(0)
}

/// Extracts the `VmRSS:` value (given in kB) from procfs status text and
/// converts it to bytes; `None` if the line is absent or malformed.
#[cfg(target_os = "linux")]
fn parse_vm_rss_bytes(status: &str) -> Option<u64> {
    // `strip_prefix` both selects the line and removes the key in one step.
    let value = status
        .lines()
        .find_map(|line| line.strip_prefix("VmRSS:"))?;
    let kib = value
        .trim()
        .trim_end_matches("kB")
        .trim()
        .parse::<u64>()
        .ok()?;
    // Saturate so an absurd reading cannot overflow (and panic in debug builds).
    Some(kib.saturating_mul(1024))
}

/// Fallback for targets without `/proc/self/status`: always reports 0.
#[cfg(not(target_os = "linux"))]
fn read_vm_rss() -> u64 {
    0
}
60
61pub struct MemoryTracker {
62    pub snapshots: Vec<MemorySnapshot>,
63    pub label: String,
64}
65
66impl MemoryTracker {
67    pub fn new(label: &str) -> Self {
68        Self {
69            snapshots: Vec::new(),
70            label: label.to_string(),
71        }
72    }
73
74    pub fn snapshot(&mut self) {
75        self.snapshots.push(capture_memory());
76    }
77
78    pub fn peak(&self) -> u64 {
79        self.snapshots
80            .iter()
81            .map(|s| s.peak_rss_bytes)
82            .max()
83            .unwrap_or(0)
84    }
85
86    pub fn report(&self) -> MemoryReport {
87        let n = self.snapshots.len().max(1) as u64;
88        MemoryReport {
89            label: self.label.clone(),
90            peak_rss: self.peak(),
91            mean_rss: self.snapshots.iter().map(|s| s.peak_rss_bytes).sum::<u64>() / n,
92            kv_cache_total: self.snapshots.iter().map(|s| s.kv_cache_bytes).sum(),
93            activation_total: self.snapshots.iter().map(|s| s.activation_bytes).sum(),
94        }
95    }
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a snapshot with the given RSS, KV-cache and activation byte
    /// counts; the remaining fields use fixed filler values.
    fn sample(rss: u64, kv: u64, act: u64) -> MemorySnapshot {
        MemorySnapshot {
            peak_rss_bytes: rss,
            kv_cache_bytes: kv,
            activation_bytes: act,
            temp_buffer_bytes: 0,
            timestamp_us: 1,
        }
    }

    /// A freshly captured snapshot carries a non-zero timestamp.
    #[test]
    fn capture_returns_nonzero_timestamp() {
        let snap = capture_memory();
        assert!(snap.timestamp_us > 0);
    }

    /// A tracker without snapshots reports a peak of 0.
    #[test]
    fn tracker_peak_empty() {
        let tracker = MemoryTracker::new("x");
        assert_eq!(tracker.peak(), 0);
    }

    /// The report takes the max and mean of RSS and sums the other counters.
    #[test]
    fn tracker_report_aggregates() {
        let mut tracker = MemoryTracker::new("test");
        tracker
            .snapshots
            .extend(vec![sample(100, 10, 20), sample(200, 30, 40)]);

        let report = tracker.report();
        let observed = (
            report.peak_rss,
            report.mean_rss,
            report.kv_cache_total,
            report.activation_total,
        );
        assert_eq!(observed, (200, 150, 40, 60));
    }
}
130}