1use std::collections::{HashMap, HashSet, VecDeque};
2
3use petgraph::{
4 graph::{NodeIndex, UnGraph},
5 visit::EdgeRef,
6};
7use zer_core::{record::RecordId, scoring::ScoredPair};
8
9#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
11pub struct ClusterConfig {
12 pub max_cluster_size: usize,
14 pub within_cluster_min: f32,
17}
18
19impl Default for ClusterConfig {
20 fn default() -> Self {
21 Self {
22 max_cluster_size: 50,
23 within_cluster_min: 0.85,
24 }
25 }
26}
27
28pub struct ClusterGraph {
33 graph: UnGraph<RecordId, f32>,
34 node_map: HashMap<RecordId, NodeIndex>,
35}
36
37impl ClusterGraph {
38 pub fn new() -> Self {
39 Self {
40 graph: UnGraph::new_undirected(),
41 node_map: HashMap::new(),
42 }
43 }
44
45 pub fn add_pairs(&mut self, pairs: &[ScoredPair]) {
47 for pair in pairs {
48 let a = self.get_or_insert(pair.record_a);
49 let b = self.get_or_insert(pair.record_b);
50 if let Some(edge) = self.graph.find_edge(a, b) {
52 let w = self.graph.edge_weight_mut(edge).unwrap();
53 if pair.match_probability > *w {
54 *w = pair.match_probability;
55 }
56 } else {
57 self.graph.add_edge(a, b, pair.match_probability);
58 }
59 }
60 }
61
62 pub fn compute_clusters(&self, config: &ClusterConfig) -> Vec<Vec<RecordId>> {
73 let pruned = weak_edge_removal(&self.graph, config.within_cluster_min);
74 let mut components = extract_components(&pruned);
75
76 let mut result = Vec::new();
78 for comp in components.drain(..) {
79 if comp.len() <= config.max_cluster_size {
80 if comp.len() >= 2 {
81 result.push(comp);
82 }
83 } else {
84 let sub = star_prune(&self.graph, &comp, config.within_cluster_min);
85 result.extend(sub.into_iter().filter(|c| c.len() >= 2));
86 }
87 }
88 result
89 }
90
91 fn get_or_insert(&mut self, id: RecordId) -> NodeIndex {
92 if let Some(&idx) = self.node_map.get(&id) {
93 return idx;
94 }
95 let idx = self.graph.add_node(id);
96 self.node_map.insert(id, idx);
97 idx
98 }
99}
100
101impl Default for ClusterGraph {
102 fn default() -> Self {
103 Self::new()
104 }
105}
106
107fn weak_edge_removal(graph: &UnGraph<RecordId, f32>, min_weight: f32) -> UnGraph<RecordId, f32> {
115 let mut g = graph.clone();
116 let mut weak: Vec<_> = g
117 .edge_indices()
118 .filter(|&e| *g.edge_weight(e).unwrap() < min_weight)
119 .collect();
120 weak.sort_by_key(|e| std::cmp::Reverse(e.index()));
121 for e in weak {
122 g.remove_edge(e);
123 }
124 g
125}
126
127pub(crate) fn extract_components(graph: &UnGraph<RecordId, f32>) -> Vec<Vec<RecordId>> {
132 let mut visited = HashSet::new();
133 let mut components = Vec::new();
134
135 for start in graph.node_indices() {
136 if !visited.insert(start) {
137 continue;
138 }
139 let mut comp = vec![graph[start]];
140 let mut queue = VecDeque::from([start]);
141
142 while let Some(node) = queue.pop_front() {
143 for nb in graph.neighbors(node) {
144 if visited.insert(nb) {
145 comp.push(graph[nb]);
146 queue.push_back(nb);
147 }
148 }
149 }
150 components.push(comp);
151 }
152 components
153}
154
155fn star_prune(
161 graph: &UnGraph<RecordId, f32>,
162 comp: &[RecordId],
163 min_weight: f32,
164) -> Vec<Vec<RecordId>> {
165 let comp_set: HashSet<RecordId> = comp.iter().copied().collect();
166
167 let node_indices: Vec<NodeIndex> = graph
169 .node_indices()
170 .filter(|&n| comp_set.contains(&graph[n]))
171 .collect();
172
173 let hub = node_indices.iter().max_by_key(|&&n| {
175 graph
176 .edges(n)
177 .filter(|e| {
178 let other = if e.source() == n { e.target() } else { e.source() };
179 comp_set.contains(&graph[other]) && *e.weight() >= min_weight
180 })
181 .count()
182 });
183
184 let Some(&hub_idx) = hub else {
185 return vec![];
186 };
187
188 let mut sub: UnGraph<RecordId, f32> = UnGraph::new_undirected();
190 let mut sub_map: HashMap<NodeIndex, NodeIndex> = HashMap::new();
191
192 let hub_sub = sub.add_node(graph[hub_idx]);
193 sub_map.insert(hub_idx, hub_sub);
194
195 for edge in graph.edges(hub_idx) {
196 let other = if edge.source() == hub_idx { edge.target() } else { edge.source() };
197 if !comp_set.contains(&graph[other]) || *edge.weight() < min_weight {
198 continue;
199 }
200 let other_sub = *sub_map.entry(other).or_insert_with(|| sub.add_node(graph[other]));
201 sub.add_edge(hub_sub, other_sub, *edge.weight());
202 }
203
204 extract_components(&sub)
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use zer_core::scoring::ScoredPair;
    use zer_core::{comparison::ComparisonVector, scoring::MatchBand};

    /// Builds an auto-match pair between records `a` and `b` with the given
    /// match probability and an empty comparison vector.
    fn auto_match_pair(a: u64, b: u64, prob: f32) -> ScoredPair {
        let vector = ComparisonVector {
            record_a: a,
            record_b: b,
            levels: vec![],
        };
        ScoredPair {
            record_a: a,
            record_b: b,
            match_weight: 0.0,
            match_probability: prob,
            vector,
            band: MatchBand::AutoMatch,
        }
    }

    /// Configuration shared by the tests: size limit 50, weight threshold 0.85.
    fn config() -> ClusterConfig {
        ClusterConfig {
            max_cluster_size: 50,
            within_cluster_min: 0.85,
        }
    }

    /// Loads `pairs` into a fresh graph and clusters it with [`config`].
    fn clusters_for(pairs: &[ScoredPair]) -> Vec<Vec<RecordId>> {
        let mut graph = ClusterGraph::new();
        graph.add_pairs(pairs);
        graph.compute_clusters(&config())
    }

    #[test]
    fn basic_connected_components() {
        // 1-2-3 is a chain of strong edges, so all three records cluster together.
        let clusters = clusters_for(&[
            auto_match_pair(1, 2, 0.95),
            auto_match_pair(2, 3, 0.95),
        ]);
        assert_eq!(clusters.len(), 1);
        assert_eq!(clusters[0].len(), 3);
    }

    #[test]
    fn single_pair_one_cluster() {
        let clusters = clusters_for(&[auto_match_pair(1, 2, 0.95)]);
        assert_eq!(clusters.len(), 1);
        assert_eq!(clusters[0].len(), 2);
    }

    #[test]
    fn weak_bridge_splits_chain() {
        // The 2-3 edge is below the threshold and must not join the two pairs.
        let mut clusters = clusters_for(&[
            auto_match_pair(1, 2, 0.95),
            auto_match_pair(2, 3, 0.28),
            auto_match_pair(3, 4, 0.95),
        ]);
        clusters.sort_by_key(|c| *c.iter().min().unwrap());
        assert_eq!(clusters.len(), 2, "weak bridge must split chain into 2 clusters");
        assert_eq!(clusters[0].len(), 2);
        assert_eq!(clusters[1].len(), 2);

        let sorted: Vec<Vec<_>> = clusters
            .iter()
            .map(|cluster| {
                let mut ids = cluster.clone();
                ids.sort();
                ids
            })
            .collect();
        assert_eq!(sorted[0], vec![1, 2]);
        assert_eq!(sorted[1], vec![3, 4]);
    }

    #[test]
    fn star_pruning_splits_oversized_cluster() {
        // Record 0 is linked to 60 others, giving a 61-member component that
        // exceeds the size limit of 50 and therefore goes through star pruning.
        let spokes: Vec<_> = (1u64..=60).map(|i| auto_match_pair(0, i, 0.95)).collect();

        let clusters = clusters_for(&spokes);
        assert!(!clusters.is_empty());
        let total_members: usize = clusters.iter().map(Vec::len).sum();
        assert!(total_members >= 2);
    }

    #[test]
    fn two_disconnected_pairs_two_clusters() {
        let clusters = clusters_for(&[
            auto_match_pair(1, 2, 0.95),
            auto_match_pair(3, 4, 0.95),
        ]);
        assert_eq!(clusters.len(), 2);
    }
}