1use zer_core::{
2 entity::{Entity, EntityId, EntityMember, ResolutionMethod},
3 record::RecordId,
4 scoring::{ModelParams, ScoredPair},
5 traits::Clusterer,
6};
7
8use crate::{
9 graph::{ClusterConfig, ClusterGraph},
10 threshold::partition_by_band,
11};
12
13#[derive(Default)]
22pub struct ConnectedComponentsClusterer {
23 pub config: ClusterConfig,
24}
25
26impl Clusterer for ConnectedComponentsClusterer {
27 fn cluster(&self, pairs: &[ScoredPair], params: &ModelParams) -> Vec<Entity> {
28 let banded = partition_by_band(pairs.to_vec(), params);
29
30 let mut graph = ClusterGraph::new();
31 graph.add_pairs(&banded.auto_match);
32
33 let components = graph.compute_clusters(&self.config);
34
35 components
36 .into_iter()
37 .enumerate()
38 .map(|(idx, members)| {
39 let entity_members = members
40 .iter()
41 .map(|&rid| EntityMember {
42 record_id: rid,
43 record_key: rid.to_string(),
47 score: best_score_in_cluster(rid, &banded.auto_match),
48 method: ResolutionMethod::AutoMatch,
49 source: None,
50 })
51 .collect();
52
53 Entity {
54 id: idx as EntityId + 1,
57 members: entity_members,
58 }
59 })
60 .collect()
61 }
62}
63
64fn best_score_in_cluster(record_id: RecordId, pairs: &[ScoredPair]) -> f32 {
67 pairs
68 .iter()
69 .filter(|p| p.record_a == record_id || p.record_b == record_id)
70 .map(|p| p.match_probability)
71 .fold(0.0_f32, f32::max)
72}
73
#[cfg(test)]
mod tests {
    use super::*;
    use zer_core::{comparison::ComparisonVector, scoring::MatchBand};

    /// Model parameters fixture: empty `m`/`u` vectors, zero log prior odds,
    /// thresholds of 0.8 (upper) and 0.2 (lower).
    fn params() -> ModelParams {
        ModelParams {
            m: vec![],
            u: vec![],
            log_prior_odds: 0.0,
            upper_threshold: 0.8,
            lower_threshold: 0.2,
        }
    }

    /// Builds a scored pair between records `a` and `b` with the given
    /// probability and band; weight is zero and the comparison vector is empty.
    fn pair(a: u64, b: u64, prob: f32, band: MatchBand) -> ScoredPair {
        ScoredPair {
            record_a: a,
            record_b: b,
            match_weight: 0.0,
            match_probability: prob,
            vector: ComparisonVector {
                record_a: a,
                record_b: b,
                levels: vec![],
            },
            band,
        }
    }

    #[test]
    fn empty_pairs_returns_empty() {
        let clusterer = ConnectedComponentsClusterer::default();
        let entities = clusterer.cluster(&[], &params());
        assert!(entities.is_empty());
    }

    #[test]
    fn two_matched_pairs_form_one_entity() {
        let clusterer = ConnectedComponentsClusterer::default();
        // 1-2 and 2-3 share record 2, so all three records end up together.
        let pairs = vec![
            pair(1, 2, 0.95, MatchBand::AutoMatch),
            pair(2, 3, 0.95, MatchBand::AutoMatch),
        ];
        let entities = clusterer.cluster(&pairs, &params());
        assert_eq!(entities.len(), 1);
        assert_eq!(entities[0].members.len(), 3);
    }

    #[test]
    fn auto_rejected_pairs_ignored() {
        let clusterer = ConnectedComponentsClusterer::default();
        let pairs = vec![
            pair(1, 2, 0.95, MatchBand::AutoMatch),
            pair(3, 4, 0.05, MatchBand::AutoReject),
        ];
        let entities = clusterer.cluster(&pairs, &params());
        assert_eq!(entities.len(), 1);
        let rids: Vec<_> = entities[0].members.iter().map(|m| m.record_id).collect();
        assert!(rids.contains(&1));
        assert!(rids.contains(&2));
        assert!(!rids.contains(&3));
        assert!(!rids.contains(&4));
    }

    #[test]
    fn members_get_correct_scores() {
        let clusterer = ConnectedComponentsClusterer::default();
        // Record 1 appears in both pairs; its score is the larger probability.
        let pairs = vec![
            pair(1, 2, 0.92, MatchBand::AutoMatch),
            pair(1, 3, 0.88, MatchBand::AutoMatch),
        ];
        let entities = clusterer.cluster(&pairs, &params());
        assert_eq!(entities.len(), 1);

        let member_1 = entities[0]
            .members
            .iter()
            .find(|m| m.record_id == 1)
            .unwrap();
        assert!(
            (member_1.score - 0.92).abs() < 1e-5,
            "record 1 best score is 0.92"
        );
    }
}