Skip to main content

zer_adapters/
bench_writer.rs

//! Benchmark result writer for zer accuracy runs.
//!
//! Writes two output files per run:
//! - `<run_id>_pairs.ndjson`: one JSON object per line (streaming pairs)
//! - `<run_id>_summary.csv`: single-row CSV in the shared cross-library format
//!
//! A third file, `<run_id>_scored_pairs.csv` (score / is_match rows sorted by
//! score descending), can be written on request.
//!
//! The shared CSV format makes side-by-side comparison with splink
//! trivial via the `zer-bench compare` subcommand.
9use std::{
10    fs::{self, File},
11    io::{BufWriter, Write as IoWrite},
12    path::{Path, PathBuf},
13};
14
15use zer_core::{error::ZerError, record::RecordId, scoring::MatchBand};
16
/// Aggregate accuracy metrics computed against a ground-truth labels file.
#[derive(Debug, Clone)]
pub struct AccuracyMetrics {
    /// Number of true positives.
    pub true_pos: usize,
    /// Number of false positives.
    pub false_pos: usize,
    /// Number of false negatives.
    pub false_neg: usize,
    /// `true_pos / (true_pos + false_pos)`, or `0.0` when that denominator is zero.
    pub precision: f32,
    /// `true_pos / (true_pos + false_neg)`, or `0.0` when that denominator is zero.
    pub recall: f32,
    /// Harmonic mean of precision and recall, or `0.0` when both are zero.
    pub f1: f32,
}

impl AccuracyMetrics {
    /// Compute precision, recall and F1 from raw counts.
    ///
    /// The counts are stored unchanged.  Any ratio whose denominator is zero
    /// is reported as `0.0` rather than NaN, so `from_counts(0, 0, 0)` yields
    /// all-zero ratios.
    pub fn from_counts(true_pos: usize, false_pos: usize, false_neg: usize) -> Self {
        // Work in f64 and narrow once at the end: casting counts straight to
        // f32 rounds anything above 2^24 before the division happens, and
        // summing as floats also avoids `usize` overflow in the denominators.
        let tp = true_pos as f64;
        let ratio = |denominator: f64| if denominator > 0.0 { tp / denominator } else { 0.0 };

        let precision = ratio(tp + false_pos as f64);
        let recall = ratio(tp + false_neg as f64);
        let f1 = if precision + recall > 0.0 {
            2.0 * precision * recall / (precision + recall)
        } else {
            0.0
        };

        Self {
            true_pos,
            false_pos,
            false_neg,
            precision: precision as f32,
            recall: recall as f32,
            f1: f1 as f32,
        }
    }
}
49
50/// A single scored pair as written to the NDJSON pairs file.
51#[derive(Debug, Clone, serde::Serialize)]
52pub struct PairRecord {
53    pub run_id:            String,
54    pub record_id_a:       RecordId,
55    pub source_a:          Option<String>,
56    pub record_id_b:       RecordId,
57    pub source_b:          Option<String>,
58    pub match_probability: f32,
59    pub predicted_match:   bool,
60}
61
62/// Summary row shared with all benchmark libraries.
63#[derive(Debug, Clone, serde::Serialize)]
64struct SummaryRow {
65    library:         String,
66    mode:            String,
67    dataset:         String,
68    run_id:          String,
69    timestamp:       String,
70    total_records:   usize,
71    candidate_pairs: usize,
72    auto_matched:    usize,
73    borderline:      usize,
74    auto_rejected:   usize,
75    elapsed_ms:      u64,
76    true_pos:        Option<usize>,
77    false_pos:       Option<usize>,
78    false_neg:       Option<usize>,
79    precision:       Option<f32>,
80    recall:          Option<f32>,
81    f1:              Option<f32>,
82}
83
/// A lightweight `BatchReport`-like view used by `BenchResultWriter`.
///
/// This avoids a direct dependency on `zer-pipeline` from `zer-adapters`.
/// Every field is copied into the summary CSV row; `link_mode` is
/// lower-cased on the way out.
#[derive(Debug, Clone)]
pub struct BenchBatchSummary {
    /// Written to the `total_records` column.
    pub total_records: usize,
    /// Written to the `candidate_pairs` column.
    pub candidate_pairs: usize,
    /// Written to the `auto_matched` column.
    pub auto_matched: usize,
    /// Written to the `borderline` column.
    pub borderline: usize,
    /// Written to the `auto_rejected` column.
    pub auto_rejected: usize,
    /// Written to the `elapsed_ms` column.
    pub elapsed_ms: u64,
    /// Link mode of the run; written lower-cased to the `mode` column.
    pub link_mode: String,
    /// Dataset name; written to the `dataset` column.
    pub dataset: String,
}
97
98pub struct BenchResultWriter {
99    run_id:  String,
100    out_dir: PathBuf,
101}
102
103impl BenchResultWriter {
104    /// Create a new writer.  `out_dir` is created if it does not yet exist.
105    pub fn new(out_dir: &Path, run_id: &str) -> Result<Self, ZerError> {
106        fs::create_dir_all(out_dir)
107            .map_err(|e| ZerError::Store(format!("cannot create output dir: {e}")))?;
108        Ok(Self {
109            run_id:  run_id.to_owned(),
110            out_dir: out_dir.to_path_buf(),
111        })
112    }
113
114    /// Write a streaming NDJSON pairs file.  One JSON object per line.
115    pub fn write_pairs(&self, pairs: &[PairRecord]) -> Result<(), ZerError> {
116        let path = self.out_dir.join(format!("{}_pairs.ndjson", self.run_id));
117        let file = File::create(&path)
118            .map_err(|e| ZerError::Store(format!("cannot create pairs file: {e}")))?;
119        let mut w = BufWriter::new(file);
120        for pair in pairs {
121            let line = serde_json::to_string(pair)
122                .map_err(|e| ZerError::Store(format!("JSON serialise error: {e}")))?;
123            writeln!(w, "{line}")
124                .map_err(|e| ZerError::Store(format!("write error: {e}")))?;
125        }
126        Ok(())
127    }
128
129    /// Write a single-row summary CSV in the shared cross-library format.
130    ///
131    /// Accuracy columns (`true_pos`, `false_pos`, etc.) are left empty when
132    /// `accuracy` is `None`, suitable for runs without a ground-truth file.
133    /// Uses `"zer"` as the library name.  Call [`Self::write_summary_with_library`]
134    /// when a different name is needed (e.g. `"zer+judge"`).
135    pub fn write_summary(
136        &self,
137        summary: &BenchBatchSummary,
138        accuracy: Option<&AccuracyMetrics>,
139    ) -> Result<(), ZerError> {
140        self.write_summary_with_library(summary, accuracy, "zer")
141    }
142
143    /// Like [`Self::write_summary`] but lets the caller set the `library` column.
144    ///
145    /// Use `"zer"` for the FS-only pipeline and `"zer+judge"` when the
146    /// MiniLM neural judge is enabled, so the comparison table distinguishes
147    /// both operating points.
148    pub fn write_summary_with_library(
149        &self,
150        summary: &BenchBatchSummary,
151        accuracy: Option<&AccuracyMetrics>,
152        library: &str,
153    ) -> Result<(), ZerError> {
154        let path = self.out_dir.join(format!("{}_summary.csv", self.run_id));
155        let file = File::create(&path)
156            .map_err(|e| ZerError::Store(format!("cannot create summary file: {e}")))?;
157
158        let timestamp = crate::time::utc_timestamp_iso();
159        let row = SummaryRow {
160            library:         library.to_owned(),
161            mode:            summary.link_mode.to_lowercase(),
162            dataset:         summary.dataset.clone(),
163            run_id:          self.run_id.clone(),
164            timestamp,
165            total_records:   summary.total_records,
166            candidate_pairs: summary.candidate_pairs,
167            auto_matched:    summary.auto_matched,
168            borderline:      summary.borderline,
169            auto_rejected:   summary.auto_rejected,
170            elapsed_ms:      summary.elapsed_ms,
171            true_pos:  accuracy.map(|a| a.true_pos),
172            false_pos: accuracy.map(|a| a.false_pos),
173            false_neg: accuracy.map(|a| a.false_neg),
174            precision: accuracy.map(|a| a.precision),
175            recall:    accuracy.map(|a| a.recall),
176            f1:        accuracy.map(|a| a.f1),
177        };
178
179        let mut wtr = csv::Writer::from_writer(file);
180        wtr.serialize(&row)
181            .map_err(|e| ZerError::Store(format!("CSV write error: {e}")))?;
182        wtr.flush()
183            .map_err(|e| ZerError::Store(format!("CSV flush error: {e}")))?;
184        Ok(())
185    }
186
187    /// Write a scored-pairs CSV file sorted by score descending.
188    ///
189    /// The file is named `<run_id>_scored_pairs.csv` and contains two columns:
190    /// `score` (f32) and `is_match` (0 or 1).  Separating this from the benchmark
191    /// JSON keeps the JSON small and allows millions of rows without memory cost.
192    pub fn write_scored_pairs_csv(&self, pairs: &[(f32, bool)]) -> Result<(), ZerError> {
193        let path = self.out_dir.join(format!("{}_scored_pairs.csv", self.run_id));
194        let file = File::create(&path)
195            .map_err(|e| ZerError::Store(format!("cannot create scored pairs file: {e}")))?;
196        let mut w = csv::Writer::from_writer(file);
197        w.write_record(["score", "is_match"])
198            .map_err(|e| ZerError::Store(format!("CSV write error: {e}")))?;
199        let mut sorted: Vec<(f32, bool)> = pairs.to_vec();
200        sorted.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
201        for (score, is_match) in &sorted {
202            w.write_record(&[score.to_string(), (*is_match as u8).to_string()])
203                .map_err(|e| ZerError::Store(format!("CSV write error: {e}")))?;
204        }
205        w.flush().map_err(|e| ZerError::Store(format!("CSV flush error: {e}")))?;
206        Ok(())
207    }
208
209    pub fn run_id(&self) -> &str {
210        &self.run_id
211    }
212
213    pub fn out_dir(&self) -> &Path {
214        &self.out_dir
215    }
216}
217
218/// Convert a `MatchBand` to a bool for the `predicted_match` column.
219pub fn band_to_match(band: MatchBand) -> bool {
220    matches!(band, MatchBand::AutoMatch)
221}
222
223
224// ── Unit tests ────────────────────────────────────────────────────────────────
225
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Fixed batch summary shared by the summary-CSV tests.
    fn sample_summary() -> BenchBatchSummary {
        BenchBatchSummary {
            total_records: 100,
            candidate_pairs: 500,
            auto_matched: 400,
            borderline: 50,
            auto_rejected: 50,
            elapsed_ms: 1200,
            link_mode: "deduplicate".into(),
            dataset: "test_dataset".into(),
        }
    }

    #[test]
    fn write_pairs_ndjson_line_count() {
        let dir = TempDir::new().unwrap();
        let writer = BenchResultWriter::new(dir.path(), "test_run").unwrap();

        let pairs: Vec<PairRecord> = (0..5)
            .map(|i| PairRecord {
                run_id: "test_run".into(),
                record_id_a: i,
                source_a: Some("brp".into()),
                record_id_b: i + 100,
                source_b: Some("kvk".into()),
                match_probability: 0.9,
                predicted_match: true,
            })
            .collect();

        writer.write_pairs(&pairs).unwrap();

        let path = dir.path().join("test_run_pairs.ndjson");
        let content = std::fs::read_to_string(&path).unwrap();
        let lines: Vec<&str> = content.lines().collect();
        assert_eq!(lines.len(), 5, "NDJSON file must have exactly N lines");

        // Each line must be valid JSON
        for line in &lines {
            let v: serde_json::Value = serde_json::from_str(line).unwrap();
            assert!(v.get("run_id").is_some());
            assert!(v.get("match_probability").is_some());
        }
    }

    #[test]
    fn write_summary_csv_no_accuracy() {
        let dir = TempDir::new().unwrap();
        let writer = BenchResultWriter::new(dir.path(), "run_no_acc").unwrap();
        let summary = sample_summary();

        writer.write_summary(&summary, None).unwrap();

        let path = dir.path().join("run_no_acc_summary.csv");
        let content = std::fs::read_to_string(&path).unwrap();
        assert!(content.contains("zer"), "library field must be 'zer'");
        assert!(content.contains("test_dataset"));
        assert!(content.contains("100")); // total_records
    }

    #[test]
    fn write_summary_csv_with_accuracy() {
        let dir = TempDir::new().unwrap();
        let writer = BenchResultWriter::new(dir.path(), "run_acc").unwrap();
        let summary = sample_summary();
        let acc = AccuracyMetrics::from_counts(96, 4, 2);

        writer.write_summary(&summary, Some(&acc)).unwrap();

        let path = dir.path().join("run_acc_summary.csv");
        let content = std::fs::read_to_string(&path).unwrap();
        assert!(content.contains("96")); // true_pos
    }

    #[test]
    fn write_scored_pairs_csv_sorted_descending() {
        let dir = TempDir::new().unwrap();
        let writer = BenchResultWriter::new(dir.path(), "run_scored").unwrap();

        // Scores are exactly representable so their text form is unambiguous.
        writer
            .write_scored_pairs_csv(&[(0.25, false), (0.75, true), (0.5, false)])
            .unwrap();

        let path = dir.path().join("run_scored_scored_pairs.csv");
        let content = std::fs::read_to_string(&path).unwrap();
        let lines: Vec<&str> = content.lines().collect();
        assert_eq!(lines, ["score,is_match", "0.75,1", "0.5,0", "0.25,0"]);
    }

    #[test]
    fn accuracy_metrics_from_counts() {
        let acc = AccuracyMetrics::from_counts(90, 10, 5);
        assert!((acc.precision - 0.9).abs() < 0.001);
        assert!((acc.recall - (90.0 / 95.0)).abs() < 0.001);
        assert!(acc.f1 > 0.0 && acc.f1 < 1.0);
    }

    #[test]
    fn accuracy_metrics_zero_denominator() {
        let acc = AccuracyMetrics::from_counts(0, 0, 0);
        assert_eq!(acc.precision, 0.0);
        assert_eq!(acc.recall, 0.0);
        assert_eq!(acc.f1, 0.0);
    }
}