Skip to main content

zer_compare/
scoring.rs

1use zer_core::{
2    comparison::{ComparisonBatch, ComparisonLevel, ComparisonVector},
3    scoring::{MatchBand, ModelParams, ScoredPair},
4    traits::Scorer,
5};
6
/// Byte value that encodes `ComparisonLevel::Null` in the columnar `ComparisonBatch::levels`
/// store; batch scoring skips entries equal to it so they add no weight.
const NULL_LEVEL_BYTE: u8 = ComparisonLevel::Null as u8;
8
/// Fellegi-Sunter scorer.
///
/// A unit struct with no fields: every scoring call receives the model parameters
/// (`m`/`u` tables, prior log-odds, band thresholds) as an argument.
pub struct FellegiSunterScorer;

impl FellegiSunterScorer {
    /// Logistic function `1 / (1 + e^(-x))`, mapping log-odds `x` to a probability.
    fn sigmoid(x: f32) -> f32 {
        let denominator = 1.0 + f32::exp(-x);
        denominator.recip()
    }
}
17
18#[inline]
19fn classify(prob: f32, params: &ModelParams) -> MatchBand {
20    if prob >= params.upper_threshold {
21        MatchBand::AutoMatch
22    } else if prob < params.lower_threshold {
23        MatchBand::AutoReject
24    } else {
25        MatchBand::Borderline
26    }
27}
28
29impl Scorer for FellegiSunterScorer {
30    fn score(&self, vector: &ComparisonVector, params: &ModelParams) -> ScoredPair {
31        let match_weight: f32 = vector
32            .levels
33            .iter()
34            .enumerate()
35            .map(|(i, &level)| {
36                if level == ComparisonLevel::Null {
37                    return 0.0_f32;
38                }
39                let l = level as usize;
40                let m = params.m[i][l].max(1e-9);
41                let u = params.u[i][l].max(1e-9);
42                (m / u).ln()
43            })
44            .sum();
45
46        let match_probability = Self::sigmoid(match_weight + params.log_prior_odds);
47        let band = classify(match_probability, params);
48
49        ScoredPair {
50            record_a: vector.record_a,
51            record_b: vector.record_b,
52            match_weight,
53            match_probability,
54            vector: vector.clone(),
55            band,
56        }
57    }
58
59    /// Batch scoring over all pairs in the `ComparisonBatch`.
60    fn score_batch(&self, batch: &ComparisonBatch, params: &ModelParams) -> Vec<ScoredPair> {
61        let n_pairs = batch.n_pairs;
62        let n_fields = batch.n_fields;
63
64        // Flatten the weight table: weight_flat[f*4 + l] = ln(m[f][l] / u[f][l]).
65        // Storing it flat lets the inner loop use a single indexed load.
66        let mut weight_flat = vec![0.0f32; n_fields * 4];
67        for f in 0..n_fields {
68            for l in 0..4 {
69                let m = params.m[f][l].max(1e-9_f32);
70                let u = params.u[f][l].max(1e-9_f32);
71                weight_flat[f * 4 + l] = (m / u).ln();
72            }
73        }
74
75        let mut match_weights = vec![0.0f32; n_pairs];
76
77        // Field-outer / pair-inner, sequential reads, auto-vectorizable inner loop.
78        for f in 0..n_fields {
79            let field_levels = &batch.levels[f * n_pairs..(f + 1) * n_pairs];
80            let field_weights = &weight_flat[f * 4..(f + 1) * 4];
81            for p in 0..n_pairs {
82                let l = field_levels[p];
83                if l != NULL_LEVEL_BYTE {
84                    match_weights[p] += field_weights[l as usize];
85                }
86            }
87        }
88
89        (0..n_pairs)
90            .map(|p| {
91                let (a, b) = batch.pair_ids[p];
92                let match_weight = match_weights[p];
93                let match_probability = Self::sigmoid(match_weight + params.log_prior_odds);
94                let band = classify(match_probability, params);
95                ScoredPair {
96                    record_a: a,
97                    record_b: b,
98                    match_weight,
99                    match_probability,
100                    vector: batch.pair_as_vector(p),
101                    band,
102                }
103            })
104            .collect()
105    }
106
107    fn estimate_params(
108        &self,
109        batch: &ComparisonBatch,
110        init: Option<ModelParams>,
111        max_iter: usize,
112    ) -> zer_core::traits::Result<ModelParams> {
113        let mut params = crate::em::run_em(batch, init, max_iter)?;
114
115        let scores: Vec<f32> = self
116            .score_batch(batch, &params)
117            .into_iter()
118            .map(|sp| sp.match_probability)
119            .collect();
120        let (upper, lower) = crate::em::auto_calibrate_thresholds(&scores);
121        params.upper_threshold = upper;
122        params.lower_threshold = lower;
123        tracing::info!(upper, lower, "auto-calibrated thresholds");
124
125        Ok(params)
126    }
127}
128
#[cfg(test)]
mod tests {
    use zer_core::comparison::{ComparisonBatch, ComparisonLevel, ComparisonVector};

    use super::*;

    /// Identical four-level m/u tables for every field, a 10% match prior and
    /// 0.1 / 0.9 band thresholds.
    fn default_params(n_fields: usize) -> ModelParams {
        ModelParams {
            m: vec![vec![0.02, 0.06, 0.12, 0.80]; n_fields],
            u: vec![vec![0.70, 0.15, 0.10, 0.05]; n_fields],
            log_prior_odds: (0.1_f32 / 0.9_f32).ln(),
            upper_threshold: 0.9,
            lower_threshold: 0.1,
        }
    }

    fn all_exact_vector(n_fields: usize) -> ComparisonVector {
        ComparisonVector::new(1, 2, vec![ComparisonLevel::Exact; n_fields])
    }

    fn all_none_vector(n_fields: usize) -> ComparisonVector {
        ComparisonVector::new(3, 4, vec![ComparisonLevel::None; n_fields])
    }

    /// Three fields where the middle one is missing (`Null`).
    fn vector_with_null() -> ComparisonVector {
        ComparisonVector::new(
            7,
            8,
            vec![ComparisonLevel::Exact, ComparisonLevel::Null, ComparisonLevel::None],
        )
    }

    #[test]
    fn all_exact_produces_high_probability() {
        let scorer = FellegiSunterScorer;
        let params = default_params(4);
        let cv = all_exact_vector(4);
        let scored = scorer.score(&cv, &params);
        assert!(
            scored.match_probability > 0.9,
            "all-Exact vector should score > 0.9, got {}",
            scored.match_probability
        );
        assert_eq!(scored.band, MatchBand::AutoMatch);
    }

    #[test]
    fn all_none_produces_low_probability() {
        let scorer = FellegiSunterScorer;
        let params = default_params(4);
        let cv = all_none_vector(4);
        let scored = scorer.score(&cv, &params);
        assert!(
            scored.match_probability < 0.1,
            "all-None vector should score < 0.1, got {}",
            scored.match_probability
        );
        assert_eq!(scored.band, MatchBand::AutoReject);
    }

    #[test]
    fn mixed_vector_scores_between_extremes() {
        let scorer = FellegiSunterScorer;
        let params = default_params(4);
        let all_exact = scorer.score(&all_exact_vector(4), &params);
        let all_none = scorer.score(&all_none_vector(4), &params);

        let cv = ComparisonVector::new(
            5,
            6,
            vec![
                ComparisonLevel::Exact,
                ComparisonLevel::Exact,
                ComparisonLevel::None,
                ComparisonLevel::None,
            ],
        );
        let scored = scorer.score(&cv, &params);

        assert!(
            scored.match_probability > all_none.match_probability
                && scored.match_probability < all_exact.match_probability,
            "mixed vector ({}) should be between all-None ({}) and all-Exact ({})",
            scored.match_probability,
            all_none.match_probability,
            all_exact.match_probability
        );
        assert_ne!(
            scored.band,
            MatchBand::AutoMatch,
            "mixed vector should not be AutoMatch"
        );
    }

    #[test]
    fn null_levels_contribute_no_weight() {
        let scorer = FellegiSunterScorer;
        let params = default_params(3);
        // Every field shares the same tables, so field 0 stands in for all of them.
        let field_weight = |level: ComparisonLevel| {
            let l = level as usize;
            (params.m[0][l] / params.u[0][l]).ln()
        };

        let scored = scorer.score(&vector_with_null(), &params);
        let expected = field_weight(ComparisonLevel::Exact) + field_weight(ComparisonLevel::None);
        assert!(
            (scored.match_weight - expected).abs() < 1e-5,
            "Null field should add no weight: got {}, expected {expected}",
            scored.match_weight
        );
    }

    #[test]
    fn score_batch_matches_individual_scores() {
        let scorer = FellegiSunterScorer;
        let params = default_params(3);
        // Includes a Null level so the batch path's null-byte skip is compared too.
        let vectors = vec![all_exact_vector(3), all_none_vector(3), vector_with_null()];
        let batch = ComparisonBatch::from_vectors(&vectors);
        let batched = scorer.score_batch(&batch, &params);
        let individual: Vec<ScoredPair> = vectors
            .iter()
            .map(|v| scorer.score(v, &params))
            .collect();

        // `zip` would silently stop at the shorter side, so pin the lengths first.
        assert_eq!(batched.len(), individual.len());
        for (b, i) in batched.iter().zip(&individual) {
            assert_eq!((b.record_a, b.record_b), (i.record_a, i.record_b));
            assert_eq!(b.band, i.band);
            assert!((b.match_weight - i.match_weight).abs() < 1e-5);
            assert!((b.match_probability - i.match_probability).abs() < 1e-6);
        }
    }

    #[test]
    fn estimate_params_converges_from_mixed_data() {
        let scorer = FellegiSunterScorer;
        let mut vectors = vec![];
        for i in 0..100u64 {
            vectors.push(ComparisonVector::new(
                i,
                i + 10000,
                vec![ComparisonLevel::Exact; 3],
            ));
        }
        for i in 0..400u64 {
            vectors.push(ComparisonVector::new(
                i + 20000,
                i + 30000,
                vec![ComparisonLevel::None; 3],
            ));
        }
        let batch = ComparisonBatch::from_vectors(&vectors);

        let params = scorer
            .estimate_params(&batch, None, 100)
            .expect("estimate_params should succeed");

        for f in 0..3 {
            assert!(
                params.m[f][ComparisonLevel::Exact as usize]
                    > params.u[f][ComparisonLevel::Exact as usize],
                "after EM, m[Exact] should exceed u[Exact] for field {f}"
            );
        }
    }
}