//! Partitioning of scored record pairs into match bands (`zer_cluster/threshold.rs`).

use zer_core::scoring::{MatchBand, ModelParams, ScoredPair};

3/// Pairs partitioned by their match band.
4pub struct BandedPairs {
5    pub auto_match: Vec<ScoredPair>,
6    pub borderline: Vec<ScoredPair>,
7    pub auto_reject: Vec<ScoredPair>,
8}
9
10/// Classify each pair by `match_probability` vs the upper/lower thresholds in
11/// `params`. A pair is `AutoMatch` if `prob >= upper_threshold`, `AutoReject`
12/// if `prob < lower_threshold`, and `Borderline` otherwise.
13///
14/// The band already stored in `ScoredPair::band` is used directly, it must
15/// have been assigned by the same `ModelParams` that are passed here. If the
16/// stored band disagrees with the thresholds (e.g., params were updated after
17/// scoring), the stored band takes precedence so that provenance is preserved.
18pub fn partition_by_band(pairs: Vec<ScoredPair>, _params: &ModelParams) -> BandedPairs {
19    let mut auto_match = Vec::new();
20    let mut borderline = Vec::new();
21    let mut auto_reject = Vec::new();
22
23    for pair in pairs {
24        match pair.band {
25            MatchBand::AutoMatch => auto_match.push(pair),
26            MatchBand::Borderline => borderline.push(pair),
27            MatchBand::AutoReject => auto_reject.push(pair),
28        }
29    }
30
31    BandedPairs {
32        auto_match,
33        borderline,
34        auto_reject,
35    }
36}
37
// ── Unit tests ────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    // `super::*` already brings in `MatchBand`, `ModelParams` and `ScoredPair`.
    use super::*;
    use zer_core::comparison::ComparisonVector;

    /// Params with thresholds at 0.8 (upper) and 0.2 (lower).
    fn params() -> ModelParams {
        ModelParams {
            m: vec![],
            u: vec![],
            log_prior_odds: 0.0,
            upper_threshold: 0.8,
            lower_threshold: 0.2,
        }
    }

    /// Build a `ScoredPair` with the given ids, probability and stored band.
    fn pair(a: u64, b: u64, prob: f32, band: MatchBand) -> ScoredPair {
        ScoredPair {
            record_a: a,
            record_b: b,
            match_weight: 0.0,
            match_probability: prob,
            vector: ComparisonVector {
                record_a: a,
                record_b: b,
                levels: vec![],
            },
            band,
        }
    }

    /// `record_a` of each pair, in order — used to check bucket contents and ordering.
    fn ids(pairs: &[ScoredPair]) -> Vec<u64> {
        pairs.iter().map(|p| p.record_a).collect()
    }

    #[test]
    fn empty_input_returns_empty_partitions() {
        let result = partition_by_band(vec![], &params());
        assert!(result.auto_match.is_empty());
        assert!(result.borderline.is_empty());
        assert!(result.auto_reject.is_empty());
    }

    #[test]
    fn all_auto_match() {
        let pairs = vec![
            pair(1, 2, 0.95, MatchBand::AutoMatch),
            pair(3, 4, 0.90, MatchBand::AutoMatch),
        ];
        let result = partition_by_band(pairs, &params());
        assert_eq!(result.auto_match.len(), 2);
        assert!(result.borderline.is_empty());
        assert!(result.auto_reject.is_empty());
    }

    #[test]
    fn all_auto_reject() {
        let pairs = vec![
            pair(1, 2, 0.05, MatchBand::AutoReject),
            pair(3, 4, 0.10, MatchBand::AutoReject),
        ];
        let result = partition_by_band(pairs, &params());
        assert!(result.auto_match.is_empty());
        assert!(result.borderline.is_empty());
        assert_eq!(result.auto_reject.len(), 2);
    }

    #[test]
    fn mixed_bands() {
        let pairs = vec![
            pair(1, 2, 0.95, MatchBand::AutoMatch),
            pair(2, 3, 0.50, MatchBand::Borderline),
            pair(4, 5, 0.05, MatchBand::AutoReject),
        ];
        let result = partition_by_band(pairs, &params());
        assert_eq!(result.auto_match.len(), 1);
        assert_eq!(result.borderline.len(), 1);
        assert_eq!(result.auto_reject.len(), 1);
    }

    #[test]
    fn stored_band_takes_precedence_over_thresholds() {
        // Probabilities deliberately disagree with the 0.8 / 0.2 thresholds in `params()`.
        let pairs = vec![
            pair(1, 2, 0.99, MatchBand::AutoReject),
            pair(3, 4, 0.01, MatchBand::AutoMatch),
            pair(5, 6, 0.95, MatchBand::Borderline),
        ];
        let result = partition_by_band(pairs, &params());
        assert_eq!(ids(&result.auto_match), vec![3]);
        assert_eq!(ids(&result.borderline), vec![5]);
        assert_eq!(ids(&result.auto_reject), vec![1]);
    }

    #[test]
    fn preserves_input_order_within_band() {
        let pairs = vec![
            pair(7, 8, 0.90, MatchBand::AutoMatch),
            pair(1, 2, 0.50, MatchBand::Borderline),
            pair(3, 4, 0.95, MatchBand::AutoMatch),
            pair(5, 6, 0.40, MatchBand::Borderline),
        ];
        let result = partition_by_band(pairs, &params());
        assert_eq!(ids(&result.auto_match), vec![7, 3]);
        assert_eq!(ids(&result.borderline), vec![1, 5]);
        assert!(result.auto_reject.is_empty());
    }
}