Skip to main content

zer_compare/
em.rs

1use zer_core::{
2    comparison::{ComparisonBatch, ComparisonLevel, ComparisonVector},
3    error::ZerError,
4    scoring::ModelParams,
5};
6
/// Number of non-null comparison levels per field. Rows of the `m`/`u`
/// tables are indexed by level: None = 0, Partial = 1, Close = 2, Exact = 3.
const N_LEVELS: usize = 4;
8
9
10// ── E-step ────────────────────────────────────────────────────────────────────
11
12/// Compute P(match | comparison_vector) for a single pair given current params.
13pub fn e_step(vector: &ComparisonVector, params: &ModelParams) -> f32 {
14    let log_odds: f32 = params.log_prior_odds
15        + vector.levels.iter().enumerate()
16            .map(|(i, &level)| {
17                if level == ComparisonLevel::Null { return 0.0_f32; }
18                let l = level as usize;
19                let m = params.m[i][l].max(1e-9_f32);
20                let u = params.u[i][l].max(1e-9_f32);
21                (m / u).ln()
22            })
23            .sum::<f32>();
24    1.0 / (1.0 + (-log_odds).exp())
25}
26
27#[inline]
28fn e_step_p(batch: &ComparisonBatch, p: usize, params: &ModelParams) -> f32 {
29    let n_pairs = batch.n_pairs;
30    let log_odds: f32 = params.log_prior_odds
31        + (0..batch.n_fields)
32            .map(|f| {
33                let l_u8 = batch.levels[f * n_pairs + p];
34                if l_u8 == 255 { return 0.0_f32; } // ComparisonLevel::Null, skip
35                let l = l_u8 as usize;
36                let m = params.m[f][l].max(1e-9_f32);
37                let u = params.u[f][l].max(1e-9_f32);
38                (m / u).ln()
39            })
40            .sum::<f32>();
41    1.0 / (1.0 + (-log_odds).exp())
42}
43
44// ── M-step ────────────────────────────────────────────────────────────────────
45
46fn m_step(
47    batch:     &ComparisonBatch,
48    posteriors: &[f32],
49    prev:      &ModelParams,
50) -> ModelParams {
51    let n_fields = batch.n_fields;
52    let n_pairs  = batch.n_pairs;
53
54    let mut m_num = vec![vec![0.0f32; N_LEVELS]; n_fields];
55    let mut u_num = vec![vec![0.0f32; N_LEVELS]; n_fields];
56
57    let mut total_match    = 0.0f32;
58    let mut total_nonmatch = 0.0f32;
59
60    for p in 0..n_pairs {
61        total_match    += posteriors[p];
62        total_nonmatch += 1.0 - posteriors[p];
63    }
64
65    // Field-outer, pair-inner: sequential reads of levels[f*n_pairs+p].
66    // This layout lets the compiler auto-vectorize the inner accumulation.
67    // Null (255) fields are skipped, they carry no m/u evidence.
68    for f in 0..n_fields {
69        let field_slice = &batch.levels[f * n_pairs..(f + 1) * n_pairs];
70        for p in 0..n_pairs {
71            let l_u8 = field_slice[p];
72            if l_u8 == 255 { continue; } // ComparisonLevel::Null
73            let l = l_u8 as usize;
74            m_num[f][l] += posteriors[p];
75            u_num[f][l] += 1.0 - posteriors[p];
76        }
77    }
78
79    let total_match    = total_match.max(1e-9);
80    let total_nonmatch = total_nonmatch.max(1e-9);
81
82    let mut m = vec![vec![1e-9f32; N_LEVELS]; n_fields];
83    let mut u = vec![vec![1e-9f32; N_LEVELS]; n_fields];
84
85    for f in 0..n_fields {
86        for l in 0..N_LEVELS {
87            m[f][l] = (m_num[f][l] / total_match).max(1e-9);
88            u[f][l] = (u_num[f][l] / total_nonmatch).max(1e-9);
89        }
90        let m_sum: f32 = m[f].iter().sum();
91        let u_sum: f32 = u[f].iter().sum();
92        for l in 0..N_LEVELS {
93            m[f][l] /= m_sum;
94            u[f][l] /= u_sum;
95        }
96    }
97
98    let lambda    = (total_match / n_pairs as f32).max(0.001).min(0.999);
99    let log_prior = (lambda / (1.0 - lambda)).ln();
100
101    ModelParams {
102        m,
103        u,
104        log_prior_odds:  log_prior,
105        upper_threshold: prev.upper_threshold,
106        lower_threshold: prev.lower_threshold,
107    }
108}
109
110// ── Delta ─────────────────────────────────────────────────────────────────────
111
112fn params_delta(a: &ModelParams, b: &ModelParams) -> f32 {
113    let mut max_delta = 0.0f32;
114    for (am, bm) in a.m.iter().zip(b.m.iter()) {
115        for (&av, &bv) in am.iter().zip(bm.iter()) {
116            max_delta = max_delta.max((av - bv).abs());
117        }
118    }
119    for (au, bu) in a.u.iter().zip(b.u.iter()) {
120        for (&av, &bv) in au.iter().zip(bu.iter()) {
121            max_delta = max_delta.max((av - bv).abs());
122        }
123    }
124    max_delta
125}
126
127// ── Initialization ────────────────────────────────────────────────────────────
128
129fn init_from_priors(n_fields: usize) -> ModelParams {
130    let m = vec![vec![0.02, 0.06, 0.12, 0.80]; n_fields];
131    let u = vec![vec![0.70, 0.15, 0.10, 0.05]; n_fields];
132    ModelParams {
133        m,
134        u,
135        log_prior_odds:  0.0,
136        upper_threshold: 0.9,
137        lower_threshold: 0.1,
138    }
139}
140
141// ── Public API ────────────────────────────────────────────────────────────────
142
143/// Estimate the prior match rate λ = P(true match in candidate set).
144pub fn estimate_lambda(batch: &ComparisonBatch) -> f32 {
145    if batch.n_pairs == 0 { return 0.01; }
146    let exact = ComparisonLevel::Exact as u8;
147    let n_pairs = batch.n_pairs;
148    let high_sim_count = (0..n_pairs)
149        .filter(|&p| {
150            (0..batch.n_fields).any(|f| batch.levels[f * n_pairs + p] == exact)
151        })
152        .count();
153    let raw = high_sim_count as f32 / n_pairs as f32;
154    raw.max(0.001).min(0.5)
155}
156
/// Auto-calibrate upper/lower match thresholds from posterior scores, typically
/// the ones produced after EM converges.
///
/// * **Upper**: among scores `>= 0.7`, the score at the 5th percentile (nearest
///   rank, floor), never below `0.85`. Falls back to `0.9` when fewer than 10
///   scores are `>= 0.7`.
/// * **Lower**: among scores `<= 0.3`, the score at the 95th percentile, never
///   above `0.15`. Falls back to `0.1` when fewer than 10 scores are `<= 0.3`.
///
/// NaN scores fail both comparisons and are ignored. Returns `(upper, lower)`,
/// and `(0.9, 0.1)` for an empty input.
pub fn auto_calibrate_thresholds(scores: &[f32]) -> (f32, f32) {
    // Minimum tail size before its percentile is trusted over the default.
    const MIN_TAIL: usize = 10;

    if scores.is_empty() {
        return (0.9, 0.1);
    }

    let mut high: Vec<f32> = scores.iter().copied().filter(|&s| s >= 0.7).collect();
    let mut low: Vec<f32> = scores.iter().copied().filter(|&s| s <= 0.3).collect();

    let upper = if high.len() >= MIN_TAIL {
        percentile_in_place(&mut high, 0.05).max(0.85)
    } else {
        0.9
    };

    let lower = if low.len() >= MIN_TAIL {
        percentile_in_place(&mut low, 0.95).min(0.15)
    } else {
        0.1
    };

    (upper, lower)
}

/// Nearest-rank percentile: the value that would sit at index
/// `floor(len * q)` if `values` were sorted with `f32::total_cmp`.
///
/// Uses an O(n) selection rather than a full sort, and reorders `values` in place.
/// `values` must be non-empty. The index is capped at `len - 1`.
fn percentile_in_place(values: &mut [f32], q: f32) -> f32 {
    let idx = ((values.len() as f32 * q) as usize).min(values.len() - 1);
    *values.select_nth_unstable_by(idx, f32::total_cmp).1
}
182
183/// Run the EM algorithm to learn m/u parameters without labels.
184pub fn run_em(
185    batch:    &ComparisonBatch,
186    init:     Option<ModelParams>,
187    max_iter: usize,
188) -> Result<ModelParams, ZerError> {
189    if batch.n_pairs == 0 {
190        return Err(ZerError::SchemaMismatch { expected: 1, got: 0 });
191    }
192
193    let n_fields = batch.n_fields;
194    if n_fields == 0 {
195        return Err(ZerError::EmptySchema);
196    }
197
198    let mut params = init.unwrap_or_else(|| {
199        let mut p = init_from_priors(n_fields);
200        let lambda = estimate_lambda(batch);
201        p.log_prior_odds = (lambda / (1.0 - lambda)).ln();
202        tracing::debug!(lambda, "auto-estimated prior match rate");
203        p
204    });
205
206    for iter in 0..max_iter {
207        let posteriors: Vec<f32> = (0..batch.n_pairs)
208            .map(|p| e_step_p(batch, p, &params))
209            .collect();
210
211        let new_params = m_step(batch, &posteriors, &params);
212        let delta      = params_delta(&params, &new_params);
213
214        params = new_params;
215        tracing::debug!(iter, delta, "EM iteration");
216
217        if delta < 1e-6 {
218            tracing::info!(iter, "EM converged");
219            break;
220        }
221    }
222
223    Ok(params)
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use zer_core::comparison::{ComparisonBatch, ComparisonLevel, ComparisonVector};

    /// Comparison vector with all `n_fields` fields at `level`.
    fn uniform_vector(id_a: u64, id_b: u64, n_fields: usize, level: ComparisonLevel) -> ComparisonVector {
        ComparisonVector::new(id_a, id_b, vec![level; n_fields])
    }

    /// Batch of `n_match` all-`Exact` pairs followed by `n_nonmatch` all-`None` pairs.
    fn synthetic_batch(n_match: usize, n_nonmatch: usize, n_fields: usize) -> ComparisonBatch {
        let mut vecs = Vec::with_capacity(n_match + n_nonmatch);
        for i in 0..n_match {
            vecs.push(uniform_vector(i as u64, (i + 1_000_000) as u64, n_fields, ComparisonLevel::Exact));
        }
        for i in 0..n_nonmatch {
            vecs.push(uniform_vector((i + 2_000_000) as u64, (i + 3_000_000) as u64, n_fields, ComparisonLevel::None));
        }
        ComparisonBatch::from_vectors(&vecs)
    }

    #[test]
    fn em_converges_on_synthetic_data() {
        let batch  = synthetic_batch(200, 800, 4);
        let params = run_em(&batch, None, 100).expect("EM should succeed");
        for f in 0..4 {
            let exact_idx = ComparisonLevel::Exact as usize;
            assert!(
                params.m[f][exact_idx] > params.u[f][exact_idx],
                "m[Exact] should exceed u[Exact] for field {f}: m={}, u={}",
                params.m[f][exact_idx], params.u[f][exact_idx]
            );
        }
    }

    /// Five warm-start iterations are enough for m[Exact] to exceed u[Exact].
    /// This checks the learned separation only; it does not compare iteration
    /// counts against a cold start.
    #[test]
    fn em_warm_start_separates_exact_level() {
        let batch = synthetic_batch(200, 800, 3);

        let warm = ModelParams {
            m: vec![vec![0.02, 0.06, 0.12, 0.78]; 3],
            u: vec![vec![0.75, 0.12, 0.08, 0.05]; 3],
            log_prior_odds:  (0.2_f32 / 0.8_f32).ln(),
            upper_threshold: 0.9,
            lower_threshold: 0.1,
        };

        let params = run_em(&batch, Some(warm), 5).expect("warm start EM should succeed");
        for f in 0..3 {
            let exact_idx = ComparisonLevel::Exact as usize;
            assert!(params.m[f][exact_idx] > params.u[f][exact_idx],
                "warm-start: m[Exact] should exceed u[Exact] for field {f}");
        }
    }

    #[test]
    fn em_empty_batch_returns_error() {
        let batch = ComparisonBatch::new(0, 0, vec![]);
        let result = run_em(&batch, None, 50);
        assert!(result.is_err(), "empty batch should return an error");
    }

    #[test]
    fn estimate_lambda_all_exact() {
        let batch  = synthetic_batch(100, 0, 2);
        let lambda = estimate_lambda(&batch);
        assert_eq!(lambda, 0.5);
    }

    #[test]
    fn estimate_lambda_all_none() {
        let batch  = synthetic_batch(0, 100, 2);
        let lambda = estimate_lambda(&batch);
        assert_eq!(lambda, 0.001);
    }

    #[test]
    fn auto_calibrate_bimodal_distribution() {
        let mut scores = vec![0.95_f32; 50];
        scores.extend(std::iter::repeat(0.05_f32).take(200));
        let (upper, lower) = auto_calibrate_thresholds(&scores);
        // Bounds such as `upper >= 0.85` hold by construction, so pin the exact
        // percentile values instead.
        assert_eq!(upper, 0.95, "upper threshold");
        assert_eq!(lower, 0.05, "lower threshold");
    }

    #[test]
    fn auto_calibrate_uses_tail_percentiles() {
        let high: Vec<f32> = (0..25).map(|i| 0.90_f32 + i as f32 * 0.001_f32).collect();
        let low: Vec<f32> = (0..25).map(|i| i as f32 * 0.005_f32).collect();
        let mut scores: Vec<f32> = high.iter().chain(low.iter()).copied().collect();
        scores.reverse();
        let (upper, lower) = auto_calibrate_thresholds(&scores);
        // floor(25 * 0.05) = 1, floor(25 * 0.95) = 23
        assert_eq!(upper, high[1]);
        assert_eq!(lower, low[23]);
    }

    #[test]
    fn auto_calibrate_clamps_to_bounds() {
        let mut scores = vec![0.75_f32; 20];
        scores.extend(std::iter::repeat(0.25_f32).take(20));
        assert_eq!(auto_calibrate_thresholds(&scores), (0.85, 0.15));
    }

    #[test]
    fn auto_calibrate_small_tails_use_defaults() {
        assert_eq!(auto_calibrate_thresholds(&[0.95; 9]), (0.9, 0.1));
        assert_eq!(auto_calibrate_thresholds(&[0.5; 100]), (0.9, 0.1));
    }

    #[test]
    fn auto_calibrate_empty_returns_defaults() {
        let (upper, lower) = auto_calibrate_thresholds(&[]);
        assert_eq!(upper, 0.9);
        assert_eq!(lower, 0.1);
    }
}