Skip to main content

zer_adapters/
bench_writer.rs

//! Benchmark result writer for zer accuracy runs.
//!
//! Writes up to three output files per run, each produced by its own method:
//! - `<run_id>_pairs.ndjson` — one JSON object per line (streaming pairs)
//! - `<run_id>_summary.csv` — single-row CSV in the shared cross-library format
//! - `<run_id>_scored_pairs.csv` — `score,is_match` rows sorted by score descending
//!
//! The shared CSV format makes side-by-side comparison with splink
//! trivial via the `zer-bench compare` subcommand.
9use std::{
10    fs::{self, File},
11    io::{BufWriter, Write as IoWrite},
12    path::{Path, PathBuf},
13};
14
15use zer_core::{error::ZerError, scoring::MatchBand};
16
/// Aggregate accuracy metrics computed against a ground-truth labels file.
///
/// Build with [`AccuracyMetrics::from_counts`] so the derived ratios stay
/// consistent with the raw confusion counts.
#[derive(Debug, Clone)]
pub struct AccuracyMetrics {
    /// True positives: predicted matches that are labelled matches.
    pub true_pos: usize,
    /// False positives: predicted matches that are not labelled matches.
    pub false_pos: usize,
    /// False negatives: labelled matches that were not predicted.
    pub false_neg: usize,
    /// `true_pos / (true_pos + false_pos)`, or `0.0` when that denominator is zero.
    pub precision: f32,
    /// `true_pos / (true_pos + false_neg)`, or `0.0` when that denominator is zero.
    pub recall: f32,
    /// Harmonic mean of precision and recall, or `0.0` when `true_pos == 0`.
    pub f1: f32,
}

impl AccuracyMetrics {
    /// Compute precision, recall and F1 from confusion counts.
    ///
    /// The three counts are stored unchanged. Any ratio whose denominator is
    /// zero is reported as `0.0` instead of NaN, so e.g. `from_counts(0, 0, 7)`
    /// keeps `false_neg == 7` while precision, recall and F1 are all `0.0`.
    ///
    /// Ratios are evaluated in `f64` and narrowed to `f32` once at the end.
    /// Converting counts above 2^24 straight to `f32` would round them before
    /// the division, and summing in `f64` also rules out `usize` overflow.
    pub fn from_counts(true_pos: usize, false_pos: usize, false_neg: usize) -> Self {
        let tp = true_pos as f64;
        let fp = false_pos as f64;
        let fn_ = false_neg as f64;

        // num / den narrowed to f32, with 0.0 for an empty denominator.
        let ratio = |num: f64, den: f64| -> f32 {
            if den > 0.0 {
                (num / den) as f32
            } else {
                0.0
            }
        };

        let precision = ratio(tp, tp + fp);
        let recall = ratio(tp, tp + fn_);
        // Algebraically equal to 2PR / (P + R), but computed from the exact
        // counts rather than from the already-rounded precision and recall.
        let f1 = ratio(2.0 * tp, 2.0 * tp + fp + fn_);

        Self {
            true_pos,
            false_pos,
            false_neg,
            precision,
            recall,
            f1,
        }
    }
}
56
/// A single scored pair as written to the NDJSON pairs file.
///
/// Serialised with `serde`, so each field name below becomes a JSON key and
/// the declaration order is the key order on every output line.
#[derive(Debug, Clone, serde::Serialize)]
pub struct PairRecord {
    /// Identifier of the benchmark run that produced this pair.
    pub run_id: String,
    /// Key of the first record in the pair.
    pub record_key_a: String,
    /// Source of the first record, when known.
    pub source_a: Option<String>,
    /// Key of the second record in the pair.
    pub record_key_b: String,
    /// Source of the second record, when known.
    pub source_b: Option<String>,
    /// Match probability assigned to the pair.
    pub match_probability: f32,
    /// Whether the pair was predicted to be a match (see `band_to_match`).
    pub predicted_match: bool,
}
68
/// Summary row shared with all benchmark libraries.
///
/// Serialised with the `csv` crate: field names become the header row and the
/// declaration order is the column order. Because the format is shared across
/// libraries, keep names and order in sync with the other summary writers.
/// The accuracy columns are `Option`s and are left empty when no ground truth
/// is supplied.
#[derive(Debug, Clone, serde::Serialize)]
struct SummaryRow {
    /// Library / operating-point label, e.g. `"zer"` or `"zer+judge"`.
    library: String,
    /// Link mode of the run, lower-cased by the writer.
    mode: String,
    /// Dataset name.
    dataset: String,
    /// Benchmark run identifier.
    run_id: String,
    /// Write time, from `crate::time::utc_timestamp_iso`.
    timestamp: String,
    total_records: usize,
    candidate_pairs: usize,
    auto_matched: usize,
    borderline: usize,
    auto_rejected: usize,
    /// Elapsed time in milliseconds, as reported by the caller.
    elapsed_ms: u64,
    // Accuracy columns: `None` when the run had no ground-truth file.
    true_pos: Option<usize>,
    false_pos: Option<usize>,
    false_neg: Option<usize>,
    precision: Option<f32>,
    recall: Option<f32>,
    f1: Option<f32>,
}
90
/// A lightweight `BatchReport`-like view used by `BenchResultWriter`.
///
/// This avoids a direct dependency on `zer-pipeline` from `zer-adapters`.
/// Each field is copied into the identically named column of the summary CSV,
/// except `link_mode`, which is lower-cased into the `mode` column.
#[derive(Debug, Clone)]
pub struct BenchBatchSummary {
    /// Total record count for the run.
    pub total_records: usize,
    /// Number of candidate pairs considered.
    pub candidate_pairs: usize,
    /// Pairs placed in the automatic-match band.
    pub auto_matched: usize,
    /// Pairs placed in the borderline band.
    pub borderline: usize,
    /// Pairs placed in the automatic-reject band.
    pub auto_rejected: usize,
    /// Elapsed time in milliseconds, as measured by the caller.
    pub elapsed_ms: u64,
    /// Link mode label; written lower-cased to the `mode` column.
    pub link_mode: String,
    /// Dataset name written to the `dataset` column.
    pub dataset: String,
}
104
/// Writes benchmark artefacts for a single run into one output directory.
///
/// Every file it produces is named `<run_id>_<suffix>` inside `out_dir`.
/// Construct it with [`BenchResultWriter::new`], which creates the directory.
#[derive(Debug, Clone)]
pub struct BenchResultWriter {
    /// Prefix shared by every output file name.
    run_id: String,
    /// Directory that receives the output files.
    out_dir: PathBuf,
}
109
110impl BenchResultWriter {
111    /// Create a new writer.  `out_dir` is created if it does not yet exist.
112    pub fn new(out_dir: &Path, run_id: &str) -> Result<Self, ZerError> {
113        fs::create_dir_all(out_dir)
114            .map_err(|e| ZerError::Store(format!("cannot create output dir: {e}")))?;
115        Ok(Self {
116            run_id: run_id.to_owned(),
117            out_dir: out_dir.to_path_buf(),
118        })
119    }
120
121    /// Write a streaming NDJSON pairs file.  One JSON object per line.
122    pub fn write_pairs(&self, pairs: &[PairRecord]) -> Result<(), ZerError> {
123        let path = self.out_dir.join(format!("{}_pairs.ndjson", self.run_id));
124        let file = File::create(&path)
125            .map_err(|e| ZerError::Store(format!("cannot create pairs file: {e}")))?;
126        let mut w = BufWriter::new(file);
127        for pair in pairs {
128            let line = serde_json::to_string(pair)
129                .map_err(|e| ZerError::Store(format!("JSON serialise error: {e}")))?;
130            writeln!(w, "{line}").map_err(|e| ZerError::Store(format!("write error: {e}")))?;
131        }
132        Ok(())
133    }
134
135    /// Write a single-row summary CSV in the shared cross-library format.
136    ///
137    /// Accuracy columns (`true_pos`, `false_pos`, etc.) are left empty when
138    /// `accuracy` is `None`, suitable for runs without a ground-truth file.
139    /// Uses `"zer"` as the library name.  Call [`Self::write_summary_with_library`]
140    /// when a different name is needed (e.g. `"zer+judge"`).
141    pub fn write_summary(
142        &self,
143        summary: &BenchBatchSummary,
144        accuracy: Option<&AccuracyMetrics>,
145    ) -> Result<(), ZerError> {
146        self.write_summary_with_library(summary, accuracy, "zer")
147    }
148
149    /// Like [`Self::write_summary`] but lets the caller set the `library` column.
150    ///
151    /// Use `"zer"` for the FS-only pipeline and `"zer+judge"` when the
152    /// MiniLM neural judge is enabled, so the comparison table distinguishes
153    /// both operating points.
154    pub fn write_summary_with_library(
155        &self,
156        summary: &BenchBatchSummary,
157        accuracy: Option<&AccuracyMetrics>,
158        library: &str,
159    ) -> Result<(), ZerError> {
160        let path = self.out_dir.join(format!("{}_summary.csv", self.run_id));
161        let file = File::create(&path)
162            .map_err(|e| ZerError::Store(format!("cannot create summary file: {e}")))?;
163
164        let timestamp = crate::time::utc_timestamp_iso();
165        let row = SummaryRow {
166            library: library.to_owned(),
167            mode: summary.link_mode.to_lowercase(),
168            dataset: summary.dataset.clone(),
169            run_id: self.run_id.clone(),
170            timestamp,
171            total_records: summary.total_records,
172            candidate_pairs: summary.candidate_pairs,
173            auto_matched: summary.auto_matched,
174            borderline: summary.borderline,
175            auto_rejected: summary.auto_rejected,
176            elapsed_ms: summary.elapsed_ms,
177            true_pos: accuracy.map(|a| a.true_pos),
178            false_pos: accuracy.map(|a| a.false_pos),
179            false_neg: accuracy.map(|a| a.false_neg),
180            precision: accuracy.map(|a| a.precision),
181            recall: accuracy.map(|a| a.recall),
182            f1: accuracy.map(|a| a.f1),
183        };
184
185        let mut wtr = csv::Writer::from_writer(file);
186        wtr.serialize(&row)
187            .map_err(|e| ZerError::Store(format!("CSV write error: {e}")))?;
188        wtr.flush()
189            .map_err(|e| ZerError::Store(format!("CSV flush error: {e}")))?;
190        Ok(())
191    }
192
193    /// Write a scored-pairs CSV file sorted by score descending.
194    ///
195    /// The file is named `<run_id>_scored_pairs.csv` and contains two columns:
196    /// `score` (f32) and `is_match` (0 or 1).  Separating this from the benchmark
197    /// JSON keeps the JSON small and allows millions of rows without memory cost.
198    pub fn write_scored_pairs_csv(&self, pairs: &[(f32, bool)]) -> Result<(), ZerError> {
199        let path = self
200            .out_dir
201            .join(format!("{}_scored_pairs.csv", self.run_id));
202        let file = File::create(&path)
203            .map_err(|e| ZerError::Store(format!("cannot create scored pairs file: {e}")))?;
204        let mut w = csv::Writer::from_writer(file);
205        w.write_record(["score", "is_match"])
206            .map_err(|e| ZerError::Store(format!("CSV write error: {e}")))?;
207        let mut sorted: Vec<(f32, bool)> = pairs.to_vec();
208        sorted.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
209        for (score, is_match) in &sorted {
210            w.write_record(&[score.to_string(), (*is_match as u8).to_string()])
211                .map_err(|e| ZerError::Store(format!("CSV write error: {e}")))?;
212        }
213        w.flush()
214            .map_err(|e| ZerError::Store(format!("CSV flush error: {e}")))?;
215        Ok(())
216    }
217
218    pub fn run_id(&self) -> &str {
219        &self.run_id
220    }
221
222    pub fn out_dir(&self) -> &Path {
223        &self.out_dir
224    }
225}
226
227/// Convert a `MatchBand` to a bool for the `predicted_match` column.
228pub fn band_to_match(band: MatchBand) -> bool {
229    matches!(band, MatchBand::AutoMatch)
230}
231
232// ── Unit tests ────────────────────────────────────────────────────────────────
233
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use tempfile::TempDir;

    /// Canonical batch summary shared by the summary-CSV tests.
    fn sample_summary() -> BenchBatchSummary {
        BenchBatchSummary {
            total_records: 100,
            candidate_pairs: 500,
            auto_matched: 400,
            borderline: 50,
            auto_rejected: 50,
            elapsed_ms: 1200,
            link_mode: "deduplicate".into(),
            dataset: "test_dataset".into(),
        }
    }

    /// Parse a summary CSV into `column -> value`, asserting it has exactly one data row.
    ///
    /// Checking named columns instead of `content.contains(..)` stops assertions
    /// from passing by accident (e.g. `"96"` also occurs in a precision of `0.96`).
    fn read_summary_row(path: &std::path::Path) -> HashMap<String, String> {
        let mut rdr = csv::Reader::from_path(path).unwrap();
        let headers = rdr.headers().unwrap().clone();
        let mut records = rdr.records();
        let row = records
            .next()
            .expect("summary CSV must contain a data row")
            .unwrap();
        assert!(
            records.next().is_none(),
            "summary CSV must contain exactly one data row"
        );
        headers
            .iter()
            .map(String::from)
            .zip(row.iter().map(String::from))
            .collect()
    }

    #[test]
    fn write_pairs_ndjson_line_count() {
        let dir = TempDir::new().unwrap();
        let writer = BenchResultWriter::new(dir.path(), "test_run").unwrap();

        let pairs: Vec<PairRecord> = (0..5)
            .map(|i| PairRecord {
                run_id: "test_run".into(),
                record_key_a: i.to_string(),
                source_a: Some("brp".into()),
                record_key_b: (i + 100).to_string(),
                source_b: Some("kvk".into()),
                match_probability: 0.9,
                predicted_match: true,
            })
            .collect();

        writer.write_pairs(&pairs).unwrap();

        let path = dir.path().join("test_run_pairs.ndjson");
        let content = std::fs::read_to_string(&path).unwrap();
        let lines: Vec<&str> = content.lines().collect();
        assert_eq!(lines.len(), 5, "NDJSON file must have exactly N lines");

        // Each line must be valid JSON carrying the record it was written from.
        for (i, line) in lines.iter().enumerate() {
            let v: serde_json::Value = serde_json::from_str(line).unwrap();
            assert_eq!(v["run_id"], "test_run");
            assert_eq!(v["record_key_a"], i.to_string());
            assert_eq!(v["source_b"], "kvk");
            assert!(v.get("match_probability").is_some());
        }
    }

    #[test]
    fn write_summary_csv_no_accuracy() {
        let dir = TempDir::new().unwrap();
        let writer = BenchResultWriter::new(dir.path(), "run_no_acc").unwrap();

        writer.write_summary(&sample_summary(), None).unwrap();

        let row = read_summary_row(&dir.path().join("run_no_acc_summary.csv"));
        assert_eq!(row["library"], "zer");
        assert_eq!(row["mode"], "deduplicate");
        assert_eq!(row["dataset"], "test_dataset");
        assert_eq!(row["run_id"], "run_no_acc");
        assert_eq!(row["total_records"], "100");
        // Accuracy columns exist but are empty without ground truth.
        for col in ["true_pos", "false_pos", "false_neg", "precision", "recall", "f1"] {
            assert_eq!(row[col], "", "column {col} must be empty without accuracy");
        }
    }

    #[test]
    fn write_summary_csv_with_accuracy() {
        let dir = TempDir::new().unwrap();
        let writer = BenchResultWriter::new(dir.path(), "run_acc").unwrap();
        let acc = AccuracyMetrics::from_counts(96, 4, 2);

        writer.write_summary(&sample_summary(), Some(&acc)).unwrap();

        let row = read_summary_row(&dir.path().join("run_acc_summary.csv"));
        assert_eq!(row["true_pos"], "96");
        assert_eq!(row["false_pos"], "4");
        assert_eq!(row["false_neg"], "2");
        let precision: f32 = row["precision"].parse().unwrap();
        assert!((precision - 0.96).abs() < 1e-6);
    }

    #[test]
    fn write_summary_with_library_sets_library_and_lowercases_mode() {
        let dir = TempDir::new().unwrap();
        let writer = BenchResultWriter::new(dir.path(), "run_judge").unwrap();
        let summary = BenchBatchSummary {
            link_mode: "Link_Only".into(),
            ..sample_summary()
        };

        writer
            .write_summary_with_library(&summary, None, "zer+judge")
            .unwrap();

        let row = read_summary_row(&dir.path().join("run_judge_summary.csv"));
        assert_eq!(row["library"], "zer+judge");
        assert_eq!(row["mode"], "link_only");
    }

    #[test]
    fn write_scored_pairs_csv_sorted_descending_with_stable_ties() {
        let dir = TempDir::new().unwrap();
        let writer = BenchResultWriter::new(dir.path(), "run_scored").unwrap();

        writer
            .write_scored_pairs_csv(&[(0.2, false), (0.9, true), (0.5, false), (0.5, true)])
            .unwrap();

        let content =
            std::fs::read_to_string(dir.path().join("run_scored_scored_pairs.csv")).unwrap();
        let lines: Vec<&str> = content.lines().collect();
        assert_eq!(
            lines,
            ["score,is_match", "0.9,1", "0.5,0", "0.5,1", "0.2,0"]
        );
    }

    #[test]
    fn accuracy_metrics_from_counts() {
        let acc = AccuracyMetrics::from_counts(90, 10, 5);
        assert!((acc.precision - 0.9).abs() < 0.001);
        assert!((acc.recall - (90.0 / 95.0)).abs() < 0.001);
        // F1 = 2·TP / (2·TP + FP + FN) = 180 / 195
        assert!((acc.f1 - 180.0 / 195.0).abs() < 0.001);
    }

    #[test]
    fn accuracy_metrics_zero_denominator() {
        let acc = AccuracyMetrics::from_counts(0, 0, 0);
        assert_eq!(acc.precision, 0.0);
        assert_eq!(acc.recall, 0.0);
        assert_eq!(acc.f1, 0.0);
    }

    #[test]
    fn accuracy_metrics_no_predictions_keeps_counts() {
        let acc = AccuracyMetrics::from_counts(0, 0, 7);
        assert_eq!(acc.false_neg, 7);
        assert_eq!(acc.precision, 0.0);
        assert_eq!(acc.recall, 0.0);
        assert_eq!(acc.f1, 0.0);
    }
}