1use std::collections::HashMap;
5use std::fmt::Write as _;
6use std::path::Path;
7
8use serde::{Deserialize, Serialize};
9
10use crate::{BenchError, BenchRun};
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct ScenarioDelta {
32 pub scenario_id: String,
34 pub score_with_memory: f64,
36 pub score_without_memory: f64,
38 pub delta: f64,
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct BaselineComparison {
80 pub dataset: String,
82 pub model: String,
84 pub run_id_memory_on: String,
86 pub run_id_memory_off: String,
88 pub deltas: Vec<ScenarioDelta>,
92 pub aggregate_delta: f64,
94}
95
96impl BaselineComparison {
97 #[must_use]
103 pub fn compute(memory_on: &BenchRun, memory_off: &BenchRun) -> Self {
104 let off_scores: HashMap<&str, f64> = memory_off
105 .results
106 .iter()
107 .map(|r| (r.scenario_id.as_str(), r.score))
108 .collect();
109
110 let mut deltas: Vec<ScenarioDelta> = memory_on
111 .results
112 .iter()
113 .filter_map(|r| {
114 let score_off = *off_scores.get(r.scenario_id.as_str())?;
115 Some(ScenarioDelta {
116 scenario_id: r.scenario_id.clone(),
117 score_with_memory: r.score,
118 score_without_memory: score_off,
119 delta: r.score - score_off,
120 })
121 })
122 .collect();
123
124 deltas.sort_by(|a, b| a.scenario_id.cmp(&b.scenario_id));
125
126 #[allow(clippy::cast_precision_loss)]
127 let aggregate_delta = if deltas.is_empty() {
128 0.0
129 } else {
130 deltas.iter().map(|d| d.delta).sum::<f64>() / deltas.len() as f64
131 };
132
133 Self {
134 dataset: memory_on.dataset.clone(),
135 model: memory_on.model.clone(),
136 run_id_memory_on: memory_on.run_id.clone(),
137 run_id_memory_off: memory_off.run_id.clone(),
138 deltas,
139 aggregate_delta,
140 }
141 }
142
143 pub fn write_comparison_json(&self, output_dir: &Path) -> Result<(), BenchError> {
153 let json = serde_json::to_string_pretty(self)
154 .map_err(|e| BenchError::InvalidFormat(e.to_string()))?;
155 write_atomic(&output_dir.join("comparison.json"), json.as_bytes())?;
156 Ok(())
157 }
158
159 pub fn write_delta_table(&self, summary_path: &Path) -> Result<(), BenchError> {
169 use std::fs::OpenOptions;
170 use std::io::Write as _;
171
172 let mut section = String::new();
173 let _ = writeln!(section);
174 let _ = writeln!(section, "## Baseline Comparison (Memory On vs Off)");
175 let _ = writeln!(section);
176 let _ = writeln!(section, "| scenario_id | memory_on | memory_off | delta |");
177 let _ = writeln!(section, "|-------------|-----------|------------|-------|");
178 for d in &self.deltas {
179 let sign = if d.delta >= 0.0 { "+" } else { "" };
180 let _ = writeln!(
181 section,
182 "| {} | {:.4} | {:.4} | {sign}{:.4} |",
183 d.scenario_id, d.score_with_memory, d.score_without_memory, d.delta
184 );
185 }
186 let sign = if self.aggregate_delta >= 0.0 { "+" } else { "" };
187 let _ = writeln!(
188 section,
189 "\n**Aggregate delta**: {sign}{:.4} (mean score improvement with memory)",
190 self.aggregate_delta
191 );
192
193 let mut file = OpenOptions::new()
194 .append(true)
195 .create(true)
196 .open(summary_path)?;
197 file.write_all(section.as_bytes())?;
198 Ok(())
199 }
200}
201
/// Writes `data` to `path` by staging it in a sibling temporary file and renaming it into place.
///
/// The temporary file is named `<file name>.tmp` and lives in the same directory as `path`.
/// The suffix is appended to the complete file name rather than replacing the extension, so
/// `comparison.json` stages through `comparison.json.tmp` and never touches an unrelated
/// `comparison.tmp`. The data is synced to disk before the rename, and the temporary file is
/// removed again if any step fails.
///
/// # Errors
///
/// Returns `ErrorKind::InvalidInput` when `path` has no final file-name component, and
/// otherwise any I/O error raised while creating, writing, syncing or renaming the file.
fn write_atomic(path: &Path, data: &[u8]) -> Result<(), std::io::Error> {
    /// Writes and syncs `data` into `tmp`, then renames `tmp` onto `dest`.
    fn write_then_rename(tmp: &Path, dest: &Path, data: &[u8]) -> std::io::Result<()> {
        use std::io::Write as _;

        let mut file = std::fs::File::create(tmp)?;
        file.write_all(data)?;
        // Flush to disk first so the rename never publishes a partially written file.
        file.sync_all()?;
        drop(file);
        std::fs::rename(tmp, dest)
    }

    let file_name = path.file_name().ok_or_else(|| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "write_atomic: path has no file name",
        )
    })?;

    let mut tmp_name = file_name.to_os_string();
    tmp_name.push(".tmp");
    let tmp = path.with_file_name(tmp_name);

    let outcome = write_then_rename(&tmp, path, data);
    if outcome.is_err() {
        // Best-effort cleanup; the original error is the one worth reporting.
        let _ = std::fs::remove_file(&tmp);
    }
    outcome
}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Aggregate, RunStatus, ScenarioResult};

    /// Builds a completed run with fixed metadata and one result per `(scenario_id, score)`.
    fn make_run(run_id: &str, scores: &[(&str, f64)]) -> BenchRun {
        let results = scores
            .iter()
            .map(|&(scenario, score)| ScenarioResult {
                scenario_id: scenario.to_owned(),
                score,
                response_excerpt: String::new(),
                error: None,
                elapsed_ms: 0,
            })
            .collect();

        BenchRun {
            dataset: "test-dataset".to_owned(),
            model: "test-model".to_owned(),
            run_id: run_id.to_owned(),
            started_at: "2026-01-01T00:00:00Z".into(),
            finished_at: "2026-01-01T00:01:00Z".into(),
            status: RunStatus::Completed,
            results,
            aggregate: Aggregate::default(),
        }
    }

    /// Compares run "r1" (memory on) against run "r2" (memory off) built from the given scores.
    fn compare(with_memory: &[(&str, f64)], without_memory: &[(&str, f64)]) -> BaselineComparison {
        let on = make_run("r1", with_memory);
        let off = make_run("r2", without_memory);
        BaselineComparison::compute(&on, &off)
    }

    /// Collects the scenario ids of a comparison in output order.
    fn scenario_ids(cmp: &BaselineComparison) -> Vec<&str> {
        cmp.deltas.iter().map(|d| d.scenario_id.as_str()).collect()
    }

    #[test]
    fn compute_correct_aggregate_delta() {
        let cmp = compare(&[("s1", 1.0), ("s2", 0.5)], &[("s1", 0.5), ("s2", 0.0)]);

        assert_eq!(cmp.deltas.len(), 2);
        // Both scenarios improve by 0.5, so the mean is 0.5 as well.
        assert!((cmp.aggregate_delta - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn compute_handles_missing_scenarios_gracefully() {
        // "s2" has no memory-off counterpart and must be skipped.
        let cmp = compare(&[("s1", 1.0), ("s2", 0.5)], &[("s1", 0.5)]);

        assert_eq!(scenario_ids(&cmp), ["s1"]);
    }

    #[test]
    fn compute_empty_overlap_returns_zero_aggregate() {
        let cmp = compare(&[("s1", 1.0)], &[("s2", 0.5)]);

        assert!(cmp.deltas.is_empty());
        assert!(cmp.aggregate_delta.abs() < f64::EPSILON);
    }

    #[test]
    fn compute_sorts_deltas_by_scenario_id() {
        let cmp = compare(
            &[("z_last", 1.0), ("a_first", 0.5)],
            &[("z_last", 0.5), ("a_first", 0.0)],
        );

        assert_eq!(scenario_ids(&cmp), ["a_first", "z_last"]);
    }

    #[test]
    fn json_round_trip() {
        let original = compare(&[("s1", 1.0)], &[("s1", 0.5)]);

        let encoded = serde_json::to_string_pretty(&original).unwrap();
        let decoded: BaselineComparison = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.dataset, original.dataset);
        assert_eq!(decoded.deltas.len(), 1);
        assert!((decoded.aggregate_delta - original.aggregate_delta).abs() < f64::EPSILON);
    }

    #[test]
    fn write_delta_table_appends_section() {
        let dir = tempfile::tempdir().unwrap();
        let summary = dir.path().join("summary.md");
        std::fs::write(&summary, "# Header\n").unwrap();

        compare(&[("s1", 1.0)], &[("s1", 0.5)])
            .write_delta_table(&summary)
            .unwrap();

        let content = std::fs::read_to_string(&summary).unwrap();
        // The pre-existing header must survive and the new section must follow it.
        for expected in ["# Header", "## Baseline Comparison", "s1"] {
            assert!(content.contains(expected));
        }
    }

    #[test]
    fn write_delta_table_creates_file_if_absent() {
        let dir = tempfile::tempdir().unwrap();
        let summary = dir.path().join("new_summary.md");

        compare(&[("s1", 1.0)], &[("s1", 0.5)])
            .write_delta_table(&summary)
            .unwrap();

        assert!(summary.exists());
        let content = std::fs::read_to_string(&summary).unwrap();
        assert!(content.contains("## Baseline Comparison"));
    }

    #[test]
    fn write_comparison_json_round_trip() {
        let dir = tempfile::tempdir().unwrap();

        compare(&[("s1", 1.0)], &[("s1", 0.5)])
            .write_comparison_json(dir.path())
            .unwrap();

        let raw = std::fs::read_to_string(dir.path().join("comparison.json")).unwrap();
        let decoded: BaselineComparison = serde_json::from_str(&raw).unwrap();
        assert_eq!(decoded.run_id_memory_on, "r1");
        assert_eq!(decoded.run_id_memory_off, "r2");
    }
}