Skip to main content

zer_compare/
scoring.rs

1use zer_core::{
2    comparison::{ComparisonBatch, ComparisonLevel, ComparisonVector},
3    scoring::{MatchBand, ModelParams, ScoredPair},
4    traits::Scorer,
5};
6
7const NULL_LEVEL_BYTE: u8 = ComparisonLevel::Null as u8;
8
/// Fellegi-Sunter probabilistic record-linkage scorer.
///
/// The scorer itself carries no state: every model quantity (`m`/`u` tables,
/// prior log-odds, decision thresholds) is supplied through the `ModelParams`
/// argument of each scoring call. Being a unit struct, it is free to copy and
/// default-construct.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct FellegiSunterScorer;
11
12impl FellegiSunterScorer {
13    fn sigmoid(x: f32) -> f32 {
14        1.0 / (1.0 + (-x).exp())
15    }
16}
17
18#[inline]
19fn classify(prob: f32, params: &ModelParams) -> MatchBand {
20    if prob >= params.upper_threshold      { MatchBand::AutoMatch }
21    else if prob < params.lower_threshold  { MatchBand::AutoReject }
22    else                                   { MatchBand::Borderline }
23}
24
25impl Scorer for FellegiSunterScorer {
26    fn score(&self, vector: &ComparisonVector, params: &ModelParams) -> ScoredPair {
27        let match_weight: f32 = vector.levels.iter().enumerate()
28            .map(|(i, &level)| {
29                if level == ComparisonLevel::Null { return 0.0_f32; }
30                let l = level as usize;
31                let m = params.m[i][l].max(1e-9);
32                let u = params.u[i][l].max(1e-9);
33                (m / u).ln()
34            })
35            .sum();
36
37        let match_probability = Self::sigmoid(match_weight + params.log_prior_odds);
38        let band = classify(match_probability, params);
39
40        ScoredPair {
41            record_a:          vector.record_a,
42            record_b:          vector.record_b,
43            match_weight,
44            match_probability,
45            vector:            vector.clone(),
46            band,
47        }
48    }
49
50    /// Batch scoring over all pairs in the `ComparisonBatch`.
51    fn score_batch(&self, batch: &ComparisonBatch, params: &ModelParams) -> Vec<ScoredPair> {
52        let n_pairs  = batch.n_pairs;
53        let n_fields = batch.n_fields;
54
55        // Flatten the weight table: weight_flat[f*4 + l] = ln(m[f][l] / u[f][l]).
56        // Storing it flat lets the inner loop use a single indexed load.
57        let mut weight_flat = vec![0.0f32; n_fields * 4];
58        for f in 0..n_fields {
59            for l in 0..4 {
60                let m = params.m[f][l].max(1e-9_f32);
61                let u = params.u[f][l].max(1e-9_f32);
62                weight_flat[f * 4 + l] = (m / u).ln();
63            }
64        }
65
66        let mut match_weights = vec![0.0f32; n_pairs];
67
68        // Field-outer / pair-inner, sequential reads, auto-vectorizable inner loop.
69        for f in 0..n_fields {
70            let field_levels  = &batch.levels[f * n_pairs..(f + 1) * n_pairs];
71            let field_weights = &weight_flat[f * 4..(f + 1) * 4];
72            for p in 0..n_pairs {
73                let l = field_levels[p];
74                if l != NULL_LEVEL_BYTE {
75                    match_weights[p] += field_weights[l as usize];
76                }
77            }
78        }
79
80        (0..n_pairs)
81            .map(|p| {
82                let (a, b)            = batch.pair_ids[p];
83                let match_weight      = match_weights[p];
84                let match_probability = Self::sigmoid(match_weight + params.log_prior_odds);
85                let band              = classify(match_probability, params);
86                ScoredPair {
87                    record_a: a,
88                    record_b: b,
89                    match_weight,
90                    match_probability,
91                    vector: batch.pair_as_vector(p),
92                    band,
93                }
94            })
95            .collect()
96    }
97
98    fn estimate_params(
99        &self,
100        batch:    &ComparisonBatch,
101        init:     Option<ModelParams>,
102        max_iter: usize,
103    ) -> zer_core::traits::Result<ModelParams> {
104        let mut params = crate::em::run_em(batch, init, max_iter)?;
105
106        let scores: Vec<f32> = self.score_batch(batch, &params)
107            .into_iter()
108            .map(|sp| sp.match_probability)
109            .collect();
110        let (upper, lower) = crate::em::auto_calibrate_thresholds(&scores);
111        params.upper_threshold = upper;
112        params.lower_threshold = lower;
113        tracing::info!(upper, lower, "auto-calibrated thresholds");
114
115        Ok(params)
116    }
117}
118
#[cfg(test)]
mod tests {
    // `ComparisonBatch`, `ComparisonLevel`, `ComparisonVector`, `ModelParams`,
    // `MatchBand` and `ScoredPair` all come in via the parent's imports.
    use super::*;

    /// Identical per-field parameters that strongly favour `Exact` for matches.
    fn default_params(n_fields: usize) -> ModelParams {
        ModelParams {
            m: vec![vec![0.02, 0.06, 0.12, 0.80]; n_fields],
            u: vec![vec![0.70, 0.15, 0.10, 0.05]; n_fields],
            log_prior_odds: (0.1_f32 / 0.9_f32).ln(),
            upper_threshold: 0.9,
            lower_threshold: 0.1,
        }
    }

    fn all_exact_vector(n_fields: usize) -> ComparisonVector {
        ComparisonVector::new(1, 2, vec![ComparisonLevel::Exact; n_fields])
    }

    fn all_none_vector(n_fields: usize) -> ComparisonVector {
        ComparisonVector::new(3, 4, vec![ComparisonLevel::None; n_fields])
    }

    #[test]
    fn all_exact_produces_high_probability() {
        let scorer = FellegiSunterScorer;
        let params = default_params(4);
        let cv = all_exact_vector(4);
        let scored = scorer.score(&cv, &params);
        assert!(
            scored.match_probability > 0.9,
            "all-Exact vector should score > 0.9, got {}",
            scored.match_probability
        );
        assert_eq!(scored.band, MatchBand::AutoMatch);
    }

    #[test]
    fn all_none_produces_low_probability() {
        let scorer = FellegiSunterScorer;
        let params = default_params(4);
        let cv = all_none_vector(4);
        let scored = scorer.score(&cv, &params);
        assert!(
            scored.match_probability < 0.1,
            "all-None vector should score < 0.1, got {}",
            scored.match_probability
        );
        assert_eq!(scored.band, MatchBand::AutoReject);
    }

    #[test]
    fn mixed_vector_scores_between_extremes() {
        let scorer = FellegiSunterScorer;
        let params = default_params(4);
        let all_exact = scorer.score(&all_exact_vector(4), &params);
        let all_none = scorer.score(&all_none_vector(4), &params);

        let cv = ComparisonVector::new(5, 6, vec![
            ComparisonLevel::Exact,
            ComparisonLevel::Exact,
            ComparisonLevel::None,
            ComparisonLevel::None,
        ]);
        let scored = scorer.score(&cv, &params);

        assert!(
            scored.match_probability > all_none.match_probability
                && scored.match_probability < all_exact.match_probability,
            "mixed vector ({}) should be between all-None ({}) and all-Exact ({})",
            scored.match_probability,
            all_none.match_probability,
            all_exact.match_probability
        );
        assert_ne!(scored.band, MatchBand::AutoMatch, "mixed vector should not be AutoMatch");
    }

    #[test]
    fn score_batch_matches_individual_scores() {
        let scorer = FellegiSunterScorer;
        let params = default_params(3);
        let vectors = vec![all_exact_vector(3), all_none_vector(3)];
        let batch = ComparisonBatch::from_vectors(&vectors);
        let scored = scorer.score_batch(&batch, &params);
        let ind: Vec<ScoredPair> = vectors.iter().map(|v| scorer.score(v, &params)).collect();

        // Without this, `zip` would silently pass if the batch dropped pairs.
        assert_eq!(
            scored.len(),
            ind.len(),
            "score_batch must return one ScoredPair per input vector"
        );
        for (b, i) in scored.iter().zip(&ind) {
            assert_eq!((b.record_a, b.record_b), (i.record_a, i.record_b));
            assert_eq!(b.band, i.band);
            assert!((b.match_weight - i.match_weight).abs() < 1e-5);
            assert!((b.match_probability - i.match_probability).abs() < 1e-6);
        }
    }

    #[test]
    fn null_levels_contribute_no_weight() {
        let scorer = FellegiSunterScorer;
        let params = default_params(3);
        let cv = ComparisonVector::new(7, 8, vec![
            ComparisonLevel::Exact,
            ComparisonLevel::Null,
            ComparisonLevel::None,
        ]);

        // Only fields 0 (Exact) and 2 (None) should contribute.
        let exact = ComparisonLevel::Exact as usize;
        let none = ComparisonLevel::None as usize;
        let expected = (params.m[0][exact] / params.u[0][exact]).ln()
            + (params.m[2][none] / params.u[2][none]).ln();

        let single = scorer.score(&cv, &params);
        assert!(
            (single.match_weight - expected).abs() < 1e-5,
            "scalar: Null field should add no weight (got {}, expected {})",
            single.match_weight,
            expected
        );

        let vectors = vec![cv];
        let batch = ComparisonBatch::from_vectors(&vectors);
        let scored = scorer.score_batch(&batch, &params);
        assert_eq!(scored.len(), 1);
        assert!(
            (scored[0].match_weight - expected).abs() < 1e-5,
            "batch: Null field should add no weight (got {}, expected {})",
            scored[0].match_weight,
            expected
        );
    }

    #[test]
    fn estimate_params_converges_from_mixed_data() {
        let scorer = FellegiSunterScorer;
        let mut vectors = vec![];
        for i in 0..100u64 {
            vectors.push(ComparisonVector::new(i, i + 10000, vec![ComparisonLevel::Exact; 3]));
        }
        for i in 0..400u64 {
            vectors.push(ComparisonVector::new(
                i + 20000,
                i + 30000,
                vec![ComparisonLevel::None; 3],
            ));
        }
        let batch = ComparisonBatch::from_vectors(&vectors);

        let params = scorer
            .estimate_params(&batch, None, 100)
            .expect("estimate_params should succeed");

        for f in 0..3 {
            assert!(
                params.m[f][ComparisonLevel::Exact as usize]
                    > params.u[f][ComparisonLevel::Exact as usize],
                "after EM, m[Exact] should exceed u[Exact] for field {f}"
            );
        }
    }
}