1use zer_core::{
2 comparison::{ComparisonBatch, ComparisonLevel, ComparisonVector},
3 scoring::{MatchBand, ModelParams, ScoredPair},
4 traits::Scorer,
5};
6
7const NULL_LEVEL_BYTE: u8 = ComparisonLevel::Null as u8;
8
9pub struct FellegiSunterScorer;
11
12impl FellegiSunterScorer {
13 fn sigmoid(x: f32) -> f32 {
14 1.0 / (1.0 + (-x).exp())
15 }
16}
17
18#[inline]
19fn classify(prob: f32, params: &ModelParams) -> MatchBand {
20 if prob >= params.upper_threshold {
21 MatchBand::AutoMatch
22 } else if prob < params.lower_threshold {
23 MatchBand::AutoReject
24 } else {
25 MatchBand::Borderline
26 }
27}
28
29impl Scorer for FellegiSunterScorer {
30 fn score(&self, vector: &ComparisonVector, params: &ModelParams) -> ScoredPair {
31 let match_weight: f32 = vector
32 .levels
33 .iter()
34 .enumerate()
35 .map(|(i, &level)| {
36 if level == ComparisonLevel::Null {
37 return 0.0_f32;
38 }
39 let l = level as usize;
40 let m = params.m[i][l].max(1e-9);
41 let u = params.u[i][l].max(1e-9);
42 (m / u).ln()
43 })
44 .sum();
45
46 let match_probability = Self::sigmoid(match_weight + params.log_prior_odds);
47 let band = classify(match_probability, params);
48
49 ScoredPair {
50 record_a: vector.record_a,
51 record_b: vector.record_b,
52 match_weight,
53 match_probability,
54 vector: vector.clone(),
55 band,
56 }
57 }
58
59 fn score_batch(&self, batch: &ComparisonBatch, params: &ModelParams) -> Vec<ScoredPair> {
61 let n_pairs = batch.n_pairs;
62 let n_fields = batch.n_fields;
63
64 let mut weight_flat = vec![0.0f32; n_fields * 4];
67 for f in 0..n_fields {
68 for l in 0..4 {
69 let m = params.m[f][l].max(1e-9_f32);
70 let u = params.u[f][l].max(1e-9_f32);
71 weight_flat[f * 4 + l] = (m / u).ln();
72 }
73 }
74
75 let mut match_weights = vec![0.0f32; n_pairs];
76
77 for f in 0..n_fields {
79 let field_levels = &batch.levels[f * n_pairs..(f + 1) * n_pairs];
80 let field_weights = &weight_flat[f * 4..(f + 1) * 4];
81 for p in 0..n_pairs {
82 let l = field_levels[p];
83 if l != NULL_LEVEL_BYTE {
84 match_weights[p] += field_weights[l as usize];
85 }
86 }
87 }
88
89 (0..n_pairs)
90 .map(|p| {
91 let (a, b) = batch.pair_ids[p];
92 let match_weight = match_weights[p];
93 let match_probability = Self::sigmoid(match_weight + params.log_prior_odds);
94 let band = classify(match_probability, params);
95 ScoredPair {
96 record_a: a,
97 record_b: b,
98 match_weight,
99 match_probability,
100 vector: batch.pair_as_vector(p),
101 band,
102 }
103 })
104 .collect()
105 }
106
107 fn estimate_params(
108 &self,
109 batch: &ComparisonBatch,
110 init: Option<ModelParams>,
111 max_iter: usize,
112 ) -> zer_core::traits::Result<ModelParams> {
113 let mut params = crate::em::run_em(batch, init, max_iter)?;
114
115 let scores: Vec<f32> = self
116 .score_batch(batch, ¶ms)
117 .into_iter()
118 .map(|sp| sp.match_probability)
119 .collect();
120 let (upper, lower) = crate::em::auto_calibrate_thresholds(&scores);
121 params.upper_threshold = upper;
122 params.lower_threshold = lower;
123 tracing::info!(upper, lower, "auto-calibrated thresholds");
124
125 Ok(params)
126 }
127}
128
#[cfg(test)]
mod tests {
    use zer_core::comparison::{ComparisonBatch, ComparisonLevel, ComparisonVector};

    use super::*;

    /// Shared test model for `n_fields` fields.
    ///
    /// Every field gets the same four-level `m`/`u` rows, the prior match rate is 10% (log-odds
    /// `ln(0.1 / 0.9)`), and the band thresholds are 0.9 / 0.1.
    fn default_params(n_fields: usize) -> ModelParams {
        ModelParams {
            m: vec![vec![0.02, 0.06, 0.12, 0.80]; n_fields],
            u: vec![vec![0.70, 0.15, 0.10, 0.05]; n_fields],
            log_prior_odds: (0.1_f32 / 0.9_f32).ln(),
            upper_threshold: 0.9,
            lower_threshold: 0.1,
        }
    }

    /// Pair (1, 2) with every field at `ComparisonLevel::Exact`.
    fn all_exact_vector(n_fields: usize) -> ComparisonVector {
        ComparisonVector::new(1, 2, vec![ComparisonLevel::Exact; n_fields])
    }

    /// Pair (3, 4) with every field at `ComparisonLevel::None`.
    fn all_none_vector(n_fields: usize) -> ComparisonVector {
        ComparisonVector::new(3, 4, vec![ComparisonLevel::None; n_fields])
    }

    #[test]
    fn all_exact_produces_high_probability() {
        let scorer = FellegiSunterScorer;
        let params = default_params(4);
        let cv = all_exact_vector(4);
        let scored = scorer.score(&cv, &params);
        assert!(
            scored.match_probability > 0.9,
            "all-Exact vector should score > 0.9, got {}",
            scored.match_probability
        );
        assert_eq!(scored.band, MatchBand::AutoMatch);
    }

    #[test]
    fn all_none_produces_low_probability() {
        let scorer = FellegiSunterScorer;
        let params = default_params(4);
        let cv = all_none_vector(4);
        let scored = scorer.score(&cv, &params);
        assert!(
            scored.match_probability < 0.1,
            "all-None vector should score < 0.1, got {}",
            scored.match_probability
        );
        assert_eq!(scored.band, MatchBand::AutoReject);
    }

    #[test]
    fn mixed_vector_scores_between_extremes() {
        let scorer = FellegiSunterScorer;
        let params = default_params(4);
        let all_exact = scorer.score(&all_exact_vector(4), &params);
        let all_none = scorer.score(&all_none_vector(4), &params);

        let cv = ComparisonVector::new(
            5,
            6,
            vec![
                ComparisonLevel::Exact,
                ComparisonLevel::Exact,
                ComparisonLevel::None,
                ComparisonLevel::None,
            ],
        );
        let scored = scorer.score(&cv, &params);

        assert!(
            scored.match_probability > all_none.match_probability
                && scored.match_probability < all_exact.match_probability,
            "mixed vector ({}) should be between all-None ({}) and all-Exact ({})",
            scored.match_probability,
            all_none.match_probability,
            all_exact.match_probability
        );
        assert_ne!(
            scored.band,
            MatchBand::AutoMatch,
            "mixed vector should not be AutoMatch"
        );
    }

    #[test]
    fn null_levels_contribute_no_weight() {
        let scorer = FellegiSunterScorer;
        let with_nulls = ComparisonVector::new(
            7,
            8,
            vec![
                ComparisonLevel::Exact,
                ComparisonLevel::Null,
                ComparisonLevel::Null,
            ],
        );
        let without_nulls = ComparisonVector::new(7, 8, vec![ComparisonLevel::Exact]);

        let a = scorer.score(&with_nulls, &default_params(3));
        let b = scorer.score(&without_nulls, &default_params(1));

        assert!(
            (a.match_weight - b.match_weight).abs() < 1e-6,
            "Null fields should add no weight: {} vs {}",
            a.match_weight,
            b.match_weight
        );
    }

    #[test]
    fn score_batch_matches_individual_scores() {
        let scorer = FellegiSunterScorer;
        let params = default_params(3);
        let vectors = vec![
            all_exact_vector(3),
            all_none_vector(3),
            // Exercises the batch path's null-byte skipping.
            ComparisonVector::new(
                5,
                6,
                vec![
                    ComparisonLevel::Exact,
                    ComparisonLevel::Null,
                    ComparisonLevel::None,
                ],
            ),
        ];
        let batch = ComparisonBatch::from_vectors(&vectors);
        let batched = scorer.score_batch(&batch, &params);
        let individual: Vec<ScoredPair> =
            vectors.iter().map(|v| scorer.score(v, &params)).collect();

        // `zip` stops at the shorter side, so without this a truncated batch result would pass.
        assert_eq!(batched.len(), individual.len());
        for (b, i) in batched.iter().zip(&individual) {
            assert_eq!((b.record_a, b.record_b), (i.record_a, i.record_b));
            assert_eq!(b.band, i.band);
            assert!((b.match_weight - i.match_weight).abs() < 1e-5);
            assert!((b.match_probability - i.match_probability).abs() < 1e-6);
        }
    }

    #[test]
    fn estimate_params_converges_from_mixed_data() {
        let scorer = FellegiSunterScorer;
        let mut vectors = vec![];
        for i in 0..100u64 {
            vectors.push(ComparisonVector::new(
                i,
                i + 10000,
                vec![ComparisonLevel::Exact; 3],
            ));
        }
        for i in 0..400u64 {
            vectors.push(ComparisonVector::new(
                i + 20000,
                i + 30000,
                vec![ComparisonLevel::None; 3],
            ));
        }
        let batch = ComparisonBatch::from_vectors(&vectors);

        let params = scorer
            .estimate_params(&batch, None, 100)
            .expect("estimate_params should succeed");

        for f in 0..3 {
            assert!(
                params.m[f][ComparisonLevel::Exact as usize]
                    > params.u[f][ComparisonLevel::Exact as usize],
                "after EM, m[Exact] should exceed u[Exact] for field {f}"
            );
        }
    }
}
262}