1use zer_core::{
2 comparison::{ComparisonBatch, ComparisonLevel, ComparisonVector},
3 scoring::{MatchBand, ModelParams, ScoredPair},
4 traits::Scorer,
5};
6
7const NULL_LEVEL_BYTE: u8 = ComparisonLevel::Null as u8;
8
/// Fellegi–Sunter probabilistic record-linkage scorer.
///
/// This is a unit struct with no state of its own; all model parameters are
/// supplied to each scoring call.
#[derive(Debug, Clone, Copy, Default)]
pub struct FellegiSunterScorer;

impl FellegiSunterScorer {
    /// Logistic function `1 / (1 + e^(-x))`, mapping log-odds to a probability.
    ///
    /// In `f32`, a very negative `x` overflows `e^(-x)` to `+inf`, which gives
    /// exactly `0.0` rather than NaN. A very positive `x` underflows `e^(-x)`
    /// to `0.0` and gives exactly `1.0`.
    fn sigmoid(x: f32) -> f32 {
        1.0 / (1.0 + (-x).exp())
    }
}
17
18#[inline]
19fn classify(prob: f32, params: &ModelParams) -> MatchBand {
20 if prob >= params.upper_threshold { MatchBand::AutoMatch }
21 else if prob < params.lower_threshold { MatchBand::AutoReject }
22 else { MatchBand::Borderline }
23}
24
25impl Scorer for FellegiSunterScorer {
26 fn score(&self, vector: &ComparisonVector, params: &ModelParams) -> ScoredPair {
27 let match_weight: f32 = vector.levels.iter().enumerate()
28 .map(|(i, &level)| {
29 if level == ComparisonLevel::Null { return 0.0_f32; }
30 let l = level as usize;
31 let m = params.m[i][l].max(1e-9);
32 let u = params.u[i][l].max(1e-9);
33 (m / u).ln()
34 })
35 .sum();
36
37 let match_probability = Self::sigmoid(match_weight + params.log_prior_odds);
38 let band = classify(match_probability, params);
39
40 ScoredPair {
41 record_a: vector.record_a,
42 record_b: vector.record_b,
43 match_weight,
44 match_probability,
45 vector: vector.clone(),
46 band,
47 }
48 }
49
50 fn score_batch(&self, batch: &ComparisonBatch, params: &ModelParams) -> Vec<ScoredPair> {
52 let n_pairs = batch.n_pairs;
53 let n_fields = batch.n_fields;
54
55 let mut weight_flat = vec![0.0f32; n_fields * 4];
58 for f in 0..n_fields {
59 for l in 0..4 {
60 let m = params.m[f][l].max(1e-9_f32);
61 let u = params.u[f][l].max(1e-9_f32);
62 weight_flat[f * 4 + l] = (m / u).ln();
63 }
64 }
65
66 let mut match_weights = vec![0.0f32; n_pairs];
67
68 for f in 0..n_fields {
70 let field_levels = &batch.levels[f * n_pairs..(f + 1) * n_pairs];
71 let field_weights = &weight_flat[f * 4..(f + 1) * 4];
72 for p in 0..n_pairs {
73 let l = field_levels[p];
74 if l != NULL_LEVEL_BYTE {
75 match_weights[p] += field_weights[l as usize];
76 }
77 }
78 }
79
80 (0..n_pairs)
81 .map(|p| {
82 let (a, b) = batch.pair_ids[p];
83 let match_weight = match_weights[p];
84 let match_probability = Self::sigmoid(match_weight + params.log_prior_odds);
85 let band = classify(match_probability, params);
86 ScoredPair {
87 record_a: a,
88 record_b: b,
89 match_weight,
90 match_probability,
91 vector: batch.pair_as_vector(p),
92 band,
93 }
94 })
95 .collect()
96 }
97
98 fn estimate_params(
99 &self,
100 batch: &ComparisonBatch,
101 init: Option<ModelParams>,
102 max_iter: usize,
103 ) -> zer_core::traits::Result<ModelParams> {
104 let mut params = crate::em::run_em(batch, init, max_iter)?;
105
106 let scores: Vec<f32> = self.score_batch(batch, ¶ms)
107 .into_iter()
108 .map(|sp| sp.match_probability)
109 .collect();
110 let (upper, lower) = crate::em::auto_calibrate_thresholds(&scores);
111 params.upper_threshold = upper;
112 params.lower_threshold = lower;
113 tracing::info!(upper, lower, "auto-calibrated thresholds");
114
115 Ok(params)
116 }
117}
118
#[cfg(test)]
mod tests {
    use zer_core::comparison::{ComparisonBatch, ComparisonLevel, ComparisonVector};

    use super::*;

    /// Identical m/u rows for every field. `Exact` is strongly indicative of a
    /// match and `None` of a non-match. The prior match rate is 10%, with
    /// 0.1 / 0.9 decision thresholds.
    fn default_params(n_fields: usize) -> ModelParams {
        ModelParams {
            m: vec![vec![0.02, 0.06, 0.12, 0.80]; n_fields],
            u: vec![vec![0.70, 0.15, 0.10, 0.05]; n_fields],
            log_prior_odds: (0.1_f32 / 0.9_f32).ln(),
            upper_threshold: 0.9,
            lower_threshold: 0.1,
        }
    }

    fn all_exact_vector(n_fields: usize) -> ComparisonVector {
        ComparisonVector::new(1, 2, vec![ComparisonLevel::Exact; n_fields])
    }

    fn all_none_vector(n_fields: usize) -> ComparisonVector {
        ComparisonVector::new(3, 4, vec![ComparisonLevel::None; n_fields])
    }

    #[test]
    fn all_exact_produces_high_probability() {
        let scorer = FellegiSunterScorer;
        let params = default_params(4);
        let scored = scorer.score(&all_exact_vector(4), &params);
        assert!(
            scored.match_probability > 0.9,
            "all-Exact vector should score > 0.9, got {}",
            scored.match_probability
        );
        assert_eq!(scored.band, MatchBand::AutoMatch);
    }

    #[test]
    fn all_none_produces_low_probability() {
        let scorer = FellegiSunterScorer;
        let params = default_params(4);
        let scored = scorer.score(&all_none_vector(4), &params);
        assert!(
            scored.match_probability < 0.1,
            "all-None vector should score < 0.1, got {}",
            scored.match_probability
        );
        assert_eq!(scored.band, MatchBand::AutoReject);
    }

    #[test]
    fn mixed_vector_scores_between_extremes() {
        let scorer = FellegiSunterScorer;
        let params = default_params(4);
        let all_exact = scorer.score(&all_exact_vector(4), &params);
        let all_none = scorer.score(&all_none_vector(4), &params);

        let cv = ComparisonVector::new(
            5,
            6,
            vec![
                ComparisonLevel::Exact,
                ComparisonLevel::Exact,
                ComparisonLevel::None,
                ComparisonLevel::None,
            ],
        );
        let scored = scorer.score(&cv, &params);

        assert!(
            scored.match_probability > all_none.match_probability
                && scored.match_probability < all_exact.match_probability,
            "mixed vector ({}) should be between all-None ({}) and all-Exact ({})",
            scored.match_probability,
            all_none.match_probability,
            all_exact.match_probability
        );
        assert_ne!(scored.band, MatchBand::AutoMatch, "mixed vector should not be AutoMatch");
    }

    #[test]
    fn null_levels_contribute_no_weight() {
        let scorer = FellegiSunterScorer;
        let params = default_params(3);
        let cv = ComparisonVector::new(9, 10, vec![ComparisonLevel::Null; 3]);
        let scored = scorer.score(&cv, &params);

        assert_eq!(scored.match_weight, 0.0);
        // With no evidence the posterior equals the prior match rate (0.1).
        assert!(
            (scored.match_probability - 0.1).abs() < 1e-5,
            "all-Null vector should score the prior 0.1, got {}",
            scored.match_probability
        );
    }

    #[test]
    fn score_batch_matches_individual_scores() {
        let scorer = FellegiSunterScorer;
        let params = default_params(3);
        let vectors = vec![
            all_exact_vector(3),
            all_none_vector(3),
            ComparisonVector::new(
                7,
                8,
                vec![ComparisonLevel::Exact, ComparisonLevel::Null, ComparisonLevel::None],
            ),
        ];
        let batch = ComparisonBatch::from_vectors(&vectors);
        let batched = scorer.score_batch(&batch, &params);
        let individual: Vec<ScoredPair> =
            vectors.iter().map(|v| scorer.score(v, &params)).collect();

        // `zip` stops at the shorter side; without this check a short or empty
        // batch result would pass vacuously.
        assert_eq!(batched.len(), individual.len());
        for (b, i) in batched.iter().zip(&individual) {
            assert_eq!((b.record_a, b.record_b), (i.record_a, i.record_b));
            assert_eq!(b.band, i.band);
            assert!((b.match_weight - i.match_weight).abs() < 1e-5);
            assert!((b.match_probability - i.match_probability).abs() < 1e-6);
        }
    }

    #[test]
    fn estimate_params_converges_from_mixed_data() {
        let scorer = FellegiSunterScorer;
        let mut vectors = vec![];
        for i in 0..100u64 {
            vectors.push(ComparisonVector::new(i, i + 10000, vec![ComparisonLevel::Exact; 3]));
        }
        for i in 0..400u64 {
            vectors.push(ComparisonVector::new(
                i + 20000,
                i + 30000,
                vec![ComparisonLevel::None; 3],
            ));
        }
        let batch = ComparisonBatch::from_vectors(&vectors);

        let params = scorer
            .estimate_params(&batch, None, 100)
            .expect("estimate_params should succeed");

        for f in 0..3 {
            assert!(
                params.m[f][ComparisonLevel::Exact as usize]
                    > params.u[f][ComparisonLevel::Exact as usize],
                "after EM, m[Exact] should exceed u[Exact] for field {f}"
            );
        }
    }
}