Skip to main content

zeph_bench/
baseline.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::collections::HashMap;
5use std::fmt::Write as _;
6use std::path::Path;
7
8use serde::{Deserialize, Serialize};
9
10use crate::{BenchError, BenchRun};
11
12/// Score delta for a single scenario between memory-on and memory-off runs.
13///
14/// Produced by [`BaselineComparison::compute`] for each scenario that appears
15/// in both runs.
16///
17/// # Examples
18///
19/// ```
20/// use zeph_bench::baseline::ScenarioDelta;
21///
22/// let delta = ScenarioDelta {
23///     scenario_id: "q_001".into(),
24///     score_with_memory: 1.0,
25///     score_without_memory: 0.5,
26///     delta: 0.5,
27/// };
28/// assert!(delta.delta > 0.0, "positive delta means memory helped");
29/// ```
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct ScenarioDelta {
32    /// Scenario identifier (matches [`crate::Scenario::id`]).
33    pub scenario_id: String,
34    /// Score from the memory-on run.
35    pub score_with_memory: f64,
36    /// Score from the memory-off run.
37    pub score_without_memory: f64,
38    /// `score_with_memory - score_without_memory`. Positive = memory helped.
39    pub delta: f64,
40}
41
42/// Comparison between two benchmark runs (memory-on vs memory-off).
43///
44/// Use [`BaselineComparison::compute`] to join two [`BenchRun`]s by scenario ID
45/// and compute per-scenario deltas and an aggregate mean delta.
46///
47/// # Examples
48///
49/// ```
50/// use zeph_bench::{BenchRun, RunStatus, ScenarioResult, Aggregate};
51/// use zeph_bench::baseline::BaselineComparison;
52///
53/// fn make_run(run_id: &str, scores: &[(&str, f64)]) -> BenchRun {
54///     BenchRun {
55///         dataset: "test".into(),
56///         model: "model".into(),
57///         run_id: run_id.into(),
58///         started_at: "2026-01-01T00:00:00Z".into(),
59///         finished_at: "2026-01-01T00:01:00Z".into(),
60///         status: RunStatus::Completed,
61///         results: scores.iter().map(|(id, score)| ScenarioResult {
62///             scenario_id: id.to_string(),
63///             score: *score,
64///             response_excerpt: String::new(),
65///             error: None,
66///             elapsed_ms: 0,
67///         }).collect(),
68///         aggregate: Aggregate::default(),
69///     }
70/// }
71///
72/// let on = make_run("r1", &[("s1", 1.0), ("s2", 0.5)]);
73/// let off = make_run("r2", &[("s1", 0.5), ("s2", 0.0)]);
74/// let cmp = BaselineComparison::compute(&on, &off);
75/// assert_eq!(cmp.deltas.len(), 2);
76/// assert!((cmp.aggregate_delta - 0.5).abs() < f64::EPSILON);
77/// ```
78#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct BaselineComparison {
80    /// Dataset name (from the memory-on run).
81    pub dataset: String,
82    /// Model identifier (from the memory-on run).
83    pub model: String,
84    /// Run ID of the memory-on run.
85    pub run_id_memory_on: String,
86    /// Run ID of the memory-off run.
87    pub run_id_memory_off: String,
88    /// Per-scenario deltas, sorted by `scenario_id`.
89    ///
90    /// Only scenarios present in **both** runs are included (inner join).
91    pub deltas: Vec<ScenarioDelta>,
92    /// Arithmetic mean of all `delta` values. `0.0` if no scenarios overlap.
93    pub aggregate_delta: f64,
94}
95
96impl BaselineComparison {
97    /// Compute deltas by joining `memory_on` and `memory_off` runs on `scenario_id`.
98    ///
99    /// Only scenarios present in **both** runs are included. Non-overlapping
100    /// scenarios are silently dropped. `aggregate_delta` is the arithmetic mean
101    /// of all per-scenario deltas; `0.0` when there are no overlapping scenarios.
102    #[must_use]
103    pub fn compute(memory_on: &BenchRun, memory_off: &BenchRun) -> Self {
104        let off_scores: HashMap<&str, f64> = memory_off
105            .results
106            .iter()
107            .map(|r| (r.scenario_id.as_str(), r.score))
108            .collect();
109
110        let mut deltas: Vec<ScenarioDelta> = memory_on
111            .results
112            .iter()
113            .filter_map(|r| {
114                let score_off = *off_scores.get(r.scenario_id.as_str())?;
115                Some(ScenarioDelta {
116                    scenario_id: r.scenario_id.clone(),
117                    score_with_memory: r.score,
118                    score_without_memory: score_off,
119                    delta: r.score - score_off,
120                })
121            })
122            .collect();
123
124        deltas.sort_by(|a, b| a.scenario_id.cmp(&b.scenario_id));
125
126        #[allow(clippy::cast_precision_loss)]
127        let aggregate_delta = if deltas.is_empty() {
128            0.0
129        } else {
130            deltas.iter().map(|d| d.delta).sum::<f64>() / deltas.len() as f64
131        };
132
133        Self {
134            dataset: memory_on.dataset.clone(),
135            model: memory_on.model.clone(),
136            run_id_memory_on: memory_on.run_id.clone(),
137            run_id_memory_off: memory_off.run_id.clone(),
138            deltas,
139            aggregate_delta,
140        }
141    }
142
143    /// Write this comparison as pretty-printed JSON to `{output_dir}/comparison.json`.
144    ///
145    /// The file is written atomically via a `.tmp` sibling + rename, so a concurrent
146    /// SIGINT cannot leave a half-written file.
147    ///
148    /// # Errors
149    ///
150    /// Returns [`BenchError::InvalidFormat`] on serialization failure and
151    /// [`BenchError::Io`] on write failure.
152    pub fn write_comparison_json(&self, output_dir: &Path) -> Result<(), BenchError> {
153        let json = serde_json::to_string_pretty(self)
154            .map_err(|e| BenchError::InvalidFormat(e.to_string()))?;
155        write_atomic(&output_dir.join("comparison.json"), json.as_bytes())?;
156        Ok(())
157    }
158
159    /// Append a delta table section to the Markdown file at `summary_path`.
160    ///
161    /// Creates the file if it does not exist. The section header is
162    /// `## Baseline Comparison (Memory On vs Off)` followed by a Markdown table
163    /// of per-scenario deltas and a final aggregate delta line.
164    ///
165    /// # Errors
166    ///
167    /// Returns [`BenchError::Io`] on read/write failure.
168    pub fn write_delta_table(&self, summary_path: &Path) -> Result<(), BenchError> {
169        use std::fs::OpenOptions;
170        use std::io::Write as _;
171
172        let mut section = String::new();
173        let _ = writeln!(section);
174        let _ = writeln!(section, "## Baseline Comparison (Memory On vs Off)");
175        let _ = writeln!(section);
176        let _ = writeln!(section, "| scenario_id | memory_on | memory_off | delta |");
177        let _ = writeln!(section, "|-------------|-----------|------------|-------|");
178        for d in &self.deltas {
179            let sign = if d.delta >= 0.0 { "+" } else { "" };
180            let _ = writeln!(
181                section,
182                "| {} | {:.4} | {:.4} | {sign}{:.4} |",
183                d.scenario_id, d.score_with_memory, d.score_without_memory, d.delta
184            );
185        }
186        let sign = if self.aggregate_delta >= 0.0 { "+" } else { "" };
187        let _ = writeln!(
188            section,
189            "\n**Aggregate delta**: {sign}{:.4} (mean score improvement with memory)",
190            self.aggregate_delta
191        );
192
193        let mut file = OpenOptions::new()
194            .append(true)
195            .create(true)
196            .open(summary_path)?;
197        file.write_all(section.as_bytes())?;
198        Ok(())
199    }
200}
201
/// Write `data` to `path` atomically via a `.tmp` sibling + rename.
///
/// The steps are:
///
/// 1. Write the bytes to `path` with its extension replaced by `tmp`
///    (e.g. `comparison.json` -> `comparison.tmp`).
/// 2. Flush them to disk with `sync_all`.
/// 3. Rename the temporary file over `path`.
///
/// Readers therefore observe either the previous file or the complete new one.
/// If any step fails, the temporary file is removed (best effort) and the
/// original error is returned.
///
/// # Errors
///
/// Returns the underlying [`std::io::Error`] if creating, writing, syncing, or
/// renaming the temporary file fails.
fn write_atomic(path: &Path, data: &[u8]) -> Result<(), std::io::Error> {
    /// Create (or truncate) `file_path`, write `bytes`, and sync them to disk.
    fn write_synced(file_path: &Path, bytes: &[u8]) -> Result<(), std::io::Error> {
        use std::io::Write as _;

        let mut file = std::fs::File::create(file_path)?;
        file.write_all(bytes)?;
        // Persist the contents before the rename so a crash cannot leave `path`
        // pointing at data that never reached the disk.
        file.sync_all()
    }

    let tmp = path.with_extension("tmp");
    let result = write_synced(&tmp, data).and_then(|()| std::fs::rename(&tmp, path));
    if result.is_err() {
        // Don't leave a stale temp file behind; the original error is what matters.
        let _ = std::fs::remove_file(&tmp);
    }
    result
}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Aggregate, RunStatus, ScenarioResult};

    /// Build a completed [`BenchRun`] whose results carry the given
    /// `(scenario_id, score)` pairs, in order.
    fn make_run(run_id: &str, scores: &[(&str, f64)]) -> BenchRun {
        BenchRun {
            dataset: "test-dataset".into(),
            model: "test-model".into(),
            run_id: run_id.into(),
            started_at: "2026-01-01T00:00:00Z".into(),
            finished_at: "2026-01-01T00:01:00Z".into(),
            status: RunStatus::Completed,
            results: scores
                .iter()
                .map(|(id, score)| ScenarioResult {
                    scenario_id: id.to_string(),
                    score: *score,
                    response_excerpt: String::new(),
                    error: None,
                    elapsed_ms: 0,
                })
                .collect(),
            aggregate: Aggregate::default(),
        }
    }

    #[test]
    fn compute_correct_aggregate_delta() {
        let on = make_run("r1", &[("s1", 1.0), ("s2", 0.5)]);
        let off = make_run("r2", &[("s1", 0.5), ("s2", 0.0)]);
        let cmp = BaselineComparison::compute(&on, &off);
        assert_eq!(cmp.deltas.len(), 2);
        // mean delta = (0.5 + 0.5) / 2 = 0.5
        assert!((cmp.aggregate_delta - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn compute_handles_missing_scenarios_gracefully() {
        // off run has s1 but not s2 — s2 is excluded from deltas
        let on = make_run("r1", &[("s1", 1.0), ("s2", 0.5)]);
        let off = make_run("r2", &[("s1", 0.5)]);
        let cmp = BaselineComparison::compute(&on, &off);
        assert_eq!(cmp.deltas.len(), 1);
        assert_eq!(cmp.deltas[0].scenario_id, "s1");
    }

    #[test]
    fn compute_empty_overlap_returns_zero_aggregate() {
        let on = make_run("r1", &[("s1", 1.0)]);
        let off = make_run("r2", &[("s2", 0.5)]);
        let cmp = BaselineComparison::compute(&on, &off);
        assert!(cmp.deltas.is_empty());
        assert!(cmp.aggregate_delta.abs() < f64::EPSILON);
    }

    #[test]
    fn compute_sorts_deltas_by_scenario_id() {
        let on = make_run("r1", &[("z_last", 1.0), ("a_first", 0.5)]);
        let off = make_run("r2", &[("z_last", 0.5), ("a_first", 0.0)]);
        let cmp = BaselineComparison::compute(&on, &off);
        assert_eq!(cmp.deltas[0].scenario_id, "a_first");
        assert_eq!(cmp.deltas[1].scenario_id, "z_last");
    }

    #[test]
    fn json_round_trip() {
        let on = make_run("r1", &[("s1", 1.0)]);
        let off = make_run("r2", &[("s1", 0.5)]);
        let cmp = BaselineComparison::compute(&on, &off);
        let json = serde_json::to_string_pretty(&cmp).unwrap();
        let decoded: BaselineComparison = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.dataset, cmp.dataset);
        assert_eq!(decoded.deltas.len(), 1);
        assert!((decoded.aggregate_delta - cmp.aggregate_delta).abs() < f64::EPSILON);
    }

    #[test]
    fn write_delta_table_appends_section() {
        let dir = tempfile::tempdir().unwrap();
        let summary = dir.path().join("summary.md");
        std::fs::write(&summary, "# Header\n").unwrap();
        let on = make_run("r1", &[("s1", 1.0)]);
        let off = make_run("r2", &[("s1", 0.5)]);
        let cmp = BaselineComparison::compute(&on, &off);
        cmp.write_delta_table(&summary).unwrap();
        let content = std::fs::read_to_string(&summary).unwrap();
        assert!(content.contains("# Header"));
        assert!(content.contains("## Baseline Comparison"));
        assert!(content.contains("s1"));
    }

    #[test]
    fn write_delta_table_renders_signed_rows_and_aggregate() {
        // s1: 0.5 - 1.0 = -0.5; s2: 1.0 - 0.25 = +0.75; mean = 0.125.
        let dir = tempfile::tempdir().unwrap();
        let summary = dir.path().join("summary.md");
        let on = make_run("r1", &[("s1", 0.5), ("s2", 1.0)]);
        let off = make_run("r2", &[("s1", 1.0), ("s2", 0.25)]);
        let cmp = BaselineComparison::compute(&on, &off);
        cmp.write_delta_table(&summary).unwrap();
        let content = std::fs::read_to_string(&summary).unwrap();
        assert!(content.contains("| s1 | 0.5000 | 1.0000 | -0.5000 |"));
        assert!(content.contains("| s2 | 1.0000 | 0.2500 | +0.7500 |"));
        assert!(content.contains("**Aggregate delta**: +0.1250"));
    }

    #[test]
    fn write_delta_table_creates_file_if_absent() {
        let dir = tempfile::tempdir().unwrap();
        let summary = dir.path().join("new_summary.md");
        let on = make_run("r1", &[("s1", 1.0)]);
        let off = make_run("r2", &[("s1", 0.5)]);
        let cmp = BaselineComparison::compute(&on, &off);
        cmp.write_delta_table(&summary).unwrap();
        assert!(summary.exists());
        let content = std::fs::read_to_string(&summary).unwrap();
        assert!(content.contains("## Baseline Comparison"));
    }

    #[test]
    fn write_comparison_json_round_trip() {
        let dir = tempfile::tempdir().unwrap();
        let on = make_run("r1", &[("s1", 1.0)]);
        let off = make_run("r2", &[("s1", 0.5)]);
        let cmp = BaselineComparison::compute(&on, &off);
        cmp.write_comparison_json(dir.path()).unwrap();
        let json = std::fs::read_to_string(dir.path().join("comparison.json")).unwrap();
        let decoded: BaselineComparison = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.run_id_memory_on, "r1");
        assert_eq!(decoded.run_id_memory_off, "r2");
    }

    #[test]
    fn write_comparison_json_leaves_no_temp_file() {
        // The atomic write goes through a `.tmp` sibling; only the final file may remain.
        let dir = tempfile::tempdir().unwrap();
        let on = make_run("r1", &[("s1", 1.0)]);
        let off = make_run("r2", &[("s1", 0.5)]);
        let cmp = BaselineComparison::compute(&on, &off);
        cmp.write_comparison_json(dir.path()).unwrap();
        let names: Vec<std::ffi::OsString> = std::fs::read_dir(dir.path())
            .unwrap()
            .map(|entry| entry.unwrap().file_name())
            .collect();
        assert_eq!(names, vec![std::ffi::OsString::from("comparison.json")]);
    }
}