Skip to main content

zer_cluster/
clusterer.rs

use std::collections::HashMap;

use zer_core::{
    entity::{Entity, EntityId, EntityMember, ResolutionMethod},
    record::RecordId,
    scoring::{ModelParams, ScoredPair},
    traits::Clusterer,
};

use crate::{
    graph::{ClusterConfig, ClusterGraph},
    threshold::partition_by_band,
};
12
13/// Connected-components clusterer with weak-edge removal and star pruning.
14///
15/// Algorithm:
16/// 1. Partition pairs into bands using the supplied `ModelParams`.
17/// 2. Build an undirected graph from `AutoMatch` pairs only.
18/// 3. Remove edges below `config.within_cluster_min` (chain-breaking).
19/// 4. Split oversized components via star pruning.
20/// 5. Emit one `Entity` per non-trivial component (≥ 2 members).
21#[derive(Default)]
22pub struct ConnectedComponentsClusterer {
23    pub config: ClusterConfig,
24}
25
26impl Clusterer for ConnectedComponentsClusterer {
27    fn cluster(&self, pairs: &[ScoredPair], params: &ModelParams) -> Vec<Entity> {
28        let banded = partition_by_band(pairs.to_vec(), params);
29
30        let mut graph = ClusterGraph::new();
31        graph.add_pairs(&banded.auto_match);
32
33        let components = graph.compute_clusters(&self.config);
34
35        components
36            .into_iter()
37            .enumerate()
38            .map(|(idx, members)| {
39                let entity_members = members
40                    .iter()
41                    .map(|&rid| EntityMember {
42                        record_id: rid,
43                        // Defaults to the numeric ID string; the pipeline's
44                        // batch processor overwrites this with the real natural
45                        // key from the input records after clustering.
46                        record_key: rid.to_string(),
47                        score: best_score_in_cluster(rid, &banded.auto_match),
48                        method: ResolutionMethod::AutoMatch,
49                        source: None,
50                    })
51                    .collect();
52
53                Entity {
54                    // Temporary sequential ids, caller should persist through
55                    // EntityStore.upsert_entity() to get stable database ids.
56                    id: idx as EntityId + 1,
57                    members: entity_members,
58                }
59            })
60            .collect()
61    }
62}
63
64/// Returns the highest `match_probability` of any `AutoMatch` pair that
65/// involves `record_id`.
66fn best_score_in_cluster(record_id: RecordId, pairs: &[ScoredPair]) -> f32 {
67    pairs
68        .iter()
69        .filter(|p| p.record_a == record_id || p.record_b == record_id)
70        .map(|p| p.match_probability)
71        .fold(0.0_f32, f32::max)
72}
73
// ── Unit tests ────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use zer_core::{comparison::ComparisonVector, scoring::MatchBand};

    /// Model parameters with an upper threshold of 0.8 and a lower threshold of
    /// 0.2. The m/u vectors are empty because these tests supply pre-scored pairs.
    fn params() -> ModelParams {
        ModelParams {
            m: vec![],
            u: vec![],
            log_prior_odds: 0.0,
            upper_threshold: 0.8,
            lower_threshold: 0.2,
        }
    }

    /// Builds a pre-scored pair with the given probability and band, a zero
    /// match weight and an empty comparison vector.
    fn pair(a: u64, b: u64, prob: f32, band: MatchBand) -> ScoredPair {
        ScoredPair {
            record_a: a,
            record_b: b,
            match_weight: 0.0,
            match_probability: prob,
            vector: ComparisonVector {
                record_a: a,
                record_b: b,
                levels: vec![],
            },
            band,
        }
    }

    #[test]
    fn empty_pairs_returns_empty() {
        let clusterer = ConnectedComponentsClusterer::default();
        let entities = clusterer.cluster(&[], &params());
        assert!(entities.is_empty());
    }

    #[test]
    fn two_matched_pairs_form_one_entity() {
        let clusterer = ConnectedComponentsClusterer::default();
        let pairs = vec![
            pair(1, 2, 0.95, MatchBand::AutoMatch),
            pair(2, 3, 0.95, MatchBand::AutoMatch),
        ];
        let entities = clusterer.cluster(&pairs, &params());
        assert_eq!(entities.len(), 1);
        assert_eq!(entities[0].members.len(), 3);
    }

    #[test]
    fn auto_rejected_pairs_ignored() {
        let clusterer = ConnectedComponentsClusterer::default();
        let pairs = vec![
            pair(1, 2, 0.95, MatchBand::AutoMatch),
            pair(3, 4, 0.05, MatchBand::AutoReject),
        ];
        let entities = clusterer.cluster(&pairs, &params());
        assert_eq!(entities.len(), 1);
        let rids: Vec<_> = entities[0].members.iter().map(|m| m.record_id).collect();
        assert!(rids.contains(&1));
        assert!(rids.contains(&2));
        assert!(!rids.contains(&3));
        assert!(!rids.contains(&4));
    }

    #[test]
    fn members_get_correct_scores() {
        let clusterer = ConnectedComponentsClusterer::default();
        let pairs = vec![
            pair(1, 2, 0.92, MatchBand::AutoMatch),
            pair(1, 3, 0.88, MatchBand::AutoMatch),
        ];
        let entities = clusterer.cluster(&pairs, &params());
        assert_eq!(entities.len(), 1);

        let score_of = |rid| {
            entities[0]
                .members
                .iter()
                .find(|m| m.record_id == rid)
                .expect("record should be a member of the entity")
                .score
        };

        // Every member, not just the hub, must carry its best pair probability.
        for &(rid, expected) in &[(1, 0.92), (2, 0.92), (3, 0.88)] {
            let got = score_of(rid);
            assert!(
                (got - expected).abs() < 1e-5,
                "record {} best score should be {}, got {}",
                rid,
                expected,
                got
            );
        }
    }

    #[test]
    fn disjoint_matches_form_separate_entities_with_sequential_ids() {
        let clusterer = ConnectedComponentsClusterer::default();
        let pairs = vec![
            pair(1, 2, 0.95, MatchBand::AutoMatch),
            pair(3, 4, 0.95, MatchBand::AutoMatch),
        ];
        let entities = clusterer.cluster(&pairs, &params());
        assert_eq!(entities.len(), 2);

        // Temporary ids are 1-based and sequential; component order is not fixed.
        let mut ids: Vec<_> = entities.iter().map(|e| e.id).collect();
        ids.sort_unstable();
        assert_eq!(ids, vec![1, 2]);

        // Records must not leak between the two entities.
        let mut groups: Vec<Vec<_>> = entities
            .iter()
            .map(|e| {
                let mut rids: Vec<_> = e.members.iter().map(|m| m.record_id).collect();
                rids.sort_unstable();
                rids
            })
            .collect();
        groups.sort();
        assert_eq!(groups, vec![vec![1, 2], vec![3, 4]]);

        for member in entities.iter().flat_map(|e| e.members.iter()) {
            assert!(matches!(member.method, ResolutionMethod::AutoMatch));
            // Before the pipeline rewrites it, the key defaults to the numeric id.
            assert_eq!(member.record_key, member.record_id.to_string());
        }
    }
}