Skip to main content

zer_cluster/
clusterer.rs

1use zer_core::{
2    entity::{Entity, EntityId, EntityMember, ResolutionMethod},
3    record::RecordId,
4    scoring::{ModelParams, ScoredPair},
5    traits::Clusterer,
6};
7
8use crate::{
9    graph::{ClusterConfig, ClusterGraph},
10    threshold::partition_by_band,
11};
12
13/// Connected-components clusterer with weak-edge removal and star pruning.
14///
15/// Algorithm:
16/// 1. Partition pairs into bands using the supplied `ModelParams`.
17/// 2. Build an undirected graph from `AutoMatch` pairs only.
18/// 3. Remove edges below `config.within_cluster_min` (chain-breaking).
19/// 4. Split oversized components via star pruning.
20/// 5. Emit one `Entity` per non-trivial component (≥ 2 members).
21pub struct ConnectedComponentsClusterer {
22    pub config: ClusterConfig,
23}
24
25impl Default for ConnectedComponentsClusterer {
26    fn default() -> Self {
27        Self { config: ClusterConfig::default() }
28    }
29}
30
31impl Clusterer for ConnectedComponentsClusterer {
32    fn cluster(&self, pairs: &[ScoredPair], params: &ModelParams) -> Vec<Entity> {
33        let banded = partition_by_band(pairs.to_vec(), params);
34
35        let mut graph = ClusterGraph::new();
36        graph.add_pairs(&banded.auto_match);
37
38        let components = graph.compute_clusters(&self.config);
39
40        components
41            .into_iter()
42            .enumerate()
43            .map(|(idx, members)| {
44                let entity_members = members
45                    .iter()
46                    .map(|&rid| EntityMember {
47                        record_id: rid,
48                        score:     best_score_in_cluster(rid, &banded.auto_match),
49                        method:    ResolutionMethod::AutoMatch,
50                        source:    None,
51                    })
52                    .collect();
53
54                Entity {
55                    // Temporary sequential ids, caller should persist through
56                    // EntityStore.upsert_entity() to get stable database ids.
57                    id:      idx as EntityId + 1,
58                    members: entity_members,
59                }
60            })
61            .collect()
62    }
63}
64
65/// Returns the highest `match_probability` of any `AutoMatch` pair that
66/// involves `record_id`.
67fn best_score_in_cluster(record_id: RecordId, pairs: &[ScoredPair]) -> f32 {
68    pairs
69        .iter()
70        .filter(|p| p.record_a == record_id || p.record_b == record_id)
71        .map(|p| p.match_probability)
72        .fold(0.0_f32, f32::max)
73}
74
75// ── Unit tests ────────────────────────────────────────────────────────────────
76
#[cfg(test)]
mod tests {
    use super::*;
    use zer_core::{comparison::ComparisonVector, scoring::MatchBand};

    /// Model parameters shared by every test: empty `m`/`u` tables, zero
    /// prior log-odds, and 0.8 / 0.2 as the upper / lower thresholds.
    fn params() -> ModelParams {
        ModelParams {
            m: Vec::new(),
            u: Vec::new(),
            log_prior_odds: 0.0,
            upper_threshold: 0.8,
            lower_threshold: 0.2,
        }
    }

    /// Builds a scored pair between records `a` and `b` with the given
    /// probability and band, a zero match weight, and an empty comparison
    /// vector.
    fn pair(a: u64, b: u64, prob: f32, band: MatchBand) -> ScoredPair {
        let vector = ComparisonVector {
            record_a: a,
            record_b: b,
            levels: Vec::new(),
        };

        ScoredPair {
            record_a: a,
            record_b: b,
            match_weight: 0.0,
            match_probability: prob,
            vector,
            band,
        }
    }

    /// No input pairs means no entities.
    #[test]
    fn empty_pairs_returns_empty() {
        let result = ConnectedComponentsClusterer::default().cluster(&[], &params());
        assert!(result.is_empty());
    }

    /// Two auto-match pairs sharing record 2 collapse into one entity with
    /// three members.
    #[test]
    fn two_matched_pairs_form_one_entity() {
        let input = [
            pair(1, 2, 0.95, MatchBand::AutoMatch),
            pair(2, 3, 0.95, MatchBand::AutoMatch),
        ];

        let result = ConnectedComponentsClusterer::default().cluster(&input, &params());

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].members.len(), 3);
    }

    /// Records that only appear in an auto-rejected pair are left out of the
    /// output.
    #[test]
    fn auto_rejected_pairs_ignored() {
        let input = [
            pair(1, 2, 0.95, MatchBand::AutoMatch),
            pair(3, 4, 0.05, MatchBand::AutoReject),
        ];

        let result = ConnectedComponentsClusterer::default().cluster(&input, &params());
        assert_eq!(result.len(), 1);

        let member_ids: Vec<_> = result[0]
            .members
            .iter()
            .map(|member| member.record_id)
            .collect();
        assert!(member_ids.contains(&1));
        assert!(member_ids.contains(&2));
        assert!(!member_ids.contains(&3));
        assert!(!member_ids.contains(&4));
    }

    /// Record 1 takes part in two pairs; its member score is the larger of
    /// the two probabilities.
    #[test]
    fn members_get_correct_scores() {
        let input = [
            pair(1, 2, 0.92, MatchBand::AutoMatch),
            pair(1, 3, 0.88, MatchBand::AutoMatch),
        ];

        let result = ConnectedComponentsClusterer::default().cluster(&input, &params());
        assert_eq!(result.len(), 1);

        let first = result[0]
            .members
            .iter()
            .find(|member| member.record_id == 1)
            .unwrap();
        let delta = (first.score - 0.92).abs();
        assert!(delta < 1e-5, "record 1 best score is 0.92");
    }
}