Skip to main content

zer_cluster/
graph.rs

1use std::collections::{HashMap, HashSet, VecDeque};
2
3use petgraph::{
4    graph::{NodeIndex, UnGraph},
5    visit::EdgeRef,
6};
7use zer_core::{record::RecordId, scoring::ScoredPair};
8
9/// Parameters controlling cluster shape after graph construction.
10#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
11pub struct ClusterConfig {
12    /// Clusters larger than this are subjected to star pruning.
13    pub max_cluster_size: usize,
14    /// Edges with weight below this threshold are removed before extracting
15    /// components (weak-edge removal / chain-breaking).
16    pub within_cluster_min: f32,
17}
18
19impl Default for ClusterConfig {
20    fn default() -> Self {
21        Self {
22            max_cluster_size: 50,
23            within_cluster_min: 0.85,
24        }
25    }
26}
27
28/// Undirected similarity graph over records.
29///
30/// Each node is a `RecordId`; each edge weight is the `match_probability` of
31/// the `AutoMatch` pair that connected those two records.
32pub struct ClusterGraph {
33    graph: UnGraph<RecordId, f32>,
34    node_map: HashMap<RecordId, NodeIndex>,
35}
36
37impl ClusterGraph {
38    pub fn new() -> Self {
39        Self {
40            graph: UnGraph::new_undirected(),
41            node_map: HashMap::new(),
42        }
43    }
44
45    /// Add `AutoMatch` pairs to the graph. Non-AutoMatch pairs are ignored.
46    pub fn add_pairs(&mut self, pairs: &[ScoredPair]) {
47        for pair in pairs {
48            let a = self.get_or_insert(pair.record_a);
49            let b = self.get_or_insert(pair.record_b);
50            // Avoid duplicate edges, keep the higher-weight one.
51            if let Some(edge) = self.graph.find_edge(a, b) {
52                let w = self.graph.edge_weight_mut(edge).unwrap();
53                if pair.match_probability > *w {
54                    *w = pair.match_probability;
55                }
56            } else {
57                self.graph.add_edge(a, b, pair.match_probability);
58            }
59        }
60    }
61
62    /// Compute clusters using the two-phase chain-breaking algorithm:
63    ///
64    /// 1. **Weak-edge removal**: remove all edges with weight <
65    ///    `config.within_cluster_min` then extract connected components.
66    /// 2. **Star pruning**: for any component whose size exceeds
67    ///    `config.max_cluster_size`, find the hub (highest-degree node in the
68    ///    original graph), remove all non-hub edges below the min threshold,
69    ///    and re-extract components from that sub-graph.
70    ///
71    /// Returns only non-trivial components (size ≥ 2).
72    pub fn compute_clusters(&self, config: &ClusterConfig) -> Vec<Vec<RecordId>> {
73        let pruned = weak_edge_removal(&self.graph, config.within_cluster_min);
74        let mut components = extract_components(&pruned);
75
76        // Star pruning for oversized components.
77        let mut result = Vec::new();
78        for comp in components.drain(..) {
79            if comp.len() <= config.max_cluster_size {
80                if comp.len() >= 2 {
81                    result.push(comp);
82                }
83            } else {
84                let sub = star_prune(&self.graph, &comp, config.within_cluster_min);
85                result.extend(sub.into_iter().filter(|c| c.len() >= 2));
86            }
87        }
88        result
89    }
90
91    fn get_or_insert(&mut self, id: RecordId) -> NodeIndex {
92        if let Some(&idx) = self.node_map.get(&id) {
93            return idx;
94        }
95        let idx = self.graph.add_node(id);
96        self.node_map.insert(id, idx);
97        idx
98    }
99}
100
101impl Default for ClusterGraph {
102    fn default() -> Self {
103        Self::new()
104    }
105}
106
107// ── Graph algorithms ──────────────────────────────────────────────────────────
108
109/// Clone the graph, remove all edges below `min_weight`, and return the result.
110///
111/// Edge indices are removed in descending order to avoid the petgraph `Graph`
112/// index-swap issue: removing edge `i` moves the last edge into slot `i`, so
113/// removing from highest to lowest keeps all lower indices stable.
114fn weak_edge_removal(graph: &UnGraph<RecordId, f32>, min_weight: f32) -> UnGraph<RecordId, f32> {
115    let mut g = graph.clone();
116    let mut weak: Vec<_> = g
117        .edge_indices()
118        .filter(|&e| *g.edge_weight(e).unwrap() < min_weight)
119        .collect();
120    weak.sort_by_key(|e| std::cmp::Reverse(e.index()));
121    for e in weak {
122        g.remove_edge(e);
123    }
124    g
125}
126
127/// BFS-based connected-component extraction.
128///
129/// `petgraph::algo::connected_components()` returns only a count, this
130/// function also yields the actual groups.
131pub(crate) fn extract_components(graph: &UnGraph<RecordId, f32>) -> Vec<Vec<RecordId>> {
132    let mut visited = HashSet::new();
133    let mut components = Vec::new();
134
135    for start in graph.node_indices() {
136        if !visited.insert(start) {
137            continue;
138        }
139        let mut comp = vec![graph[start]];
140        let mut queue = VecDeque::from([start]);
141
142        while let Some(node) = queue.pop_front() {
143            for nb in graph.neighbors(node) {
144                if visited.insert(nb) {
145                    comp.push(graph[nb]);
146                    queue.push_back(nb);
147                }
148            }
149        }
150        components.push(comp);
151    }
152    components
153}
154
155/// Star pruning for a single oversized component.
156///
157/// Finds the hub (highest-degree node in the original graph restricted to
158/// `comp`), builds a sub-graph containing only hub-edges with weight ≥
159/// `min_weight`, and returns the resulting sub-components.
160fn star_prune(
161    graph: &UnGraph<RecordId, f32>,
162    comp: &[RecordId],
163    min_weight: f32,
164) -> Vec<Vec<RecordId>> {
165    let comp_set: HashSet<RecordId> = comp.iter().copied().collect();
166
167    // Identify node indices in the original graph for this component.
168    let node_indices: Vec<NodeIndex> = graph
169        .node_indices()
170        .filter(|&n| comp_set.contains(&graph[n]))
171        .collect();
172
173    // Find hub: node with most edges to other comp members with weight >= min.
174    let hub = node_indices.iter().max_by_key(|&&n| {
175        graph
176            .edges(n)
177            .filter(|e| {
178                let other = if e.source() == n {
179                    e.target()
180                } else {
181                    e.source()
182                };
183                comp_set.contains(&graph[other]) && *e.weight() >= min_weight
184            })
185            .count()
186    });
187
188    let Some(&hub_idx) = hub else {
189        return vec![];
190    };
191
192    // Build sub-graph: hub + its qualifying neighbors.
193    let mut sub: UnGraph<RecordId, f32> = UnGraph::new_undirected();
194    let mut sub_map: HashMap<NodeIndex, NodeIndex> = HashMap::new();
195
196    let hub_sub = sub.add_node(graph[hub_idx]);
197    sub_map.insert(hub_idx, hub_sub);
198
199    for edge in graph.edges(hub_idx) {
200        let other = if edge.source() == hub_idx {
201            edge.target()
202        } else {
203            edge.source()
204        };
205        if !comp_set.contains(&graph[other]) || *edge.weight() < min_weight {
206            continue;
207        }
208        let other_sub = *sub_map
209            .entry(other)
210            .or_insert_with(|| sub.add_node(graph[other]));
211        sub.add_edge(hub_sub, other_sub, *edge.weight());
212    }
213
214    extract_components(&sub)
215}
216
217// ── Unit tests ────────────────────────────────────────────────────────────────
218
#[cfg(test)]
mod tests {
    use super::*;
    use zer_core::scoring::ScoredPair;
    use zer_core::{comparison::ComparisonVector, scoring::MatchBand};

    /// Build an `AutoMatch` pair between records `a` and `b` with the given
    /// match probability.
    fn auto_match_pair(a: u64, b: u64, prob: f32) -> ScoredPair {
        ScoredPair {
            record_a: a,
            record_b: b,
            match_weight: 0.0,
            match_probability: prob,
            vector: ComparisonVector {
                record_a: a,
                record_b: b,
                levels: vec![],
            },
            band: MatchBand::AutoMatch,
        }
    }

    /// Explicit config, so the tests do not silently change if
    /// `ClusterConfig::default()` is retuned.
    fn config() -> ClusterConfig {
        ClusterConfig {
            max_cluster_size: 50,
            within_cluster_min: 0.85,
        }
    }

    #[test]
    fn basic_connected_components() {
        // A-B, B-C → one component of 3
        let mut g = ClusterGraph::new();
        g.add_pairs(&[auto_match_pair(1, 2, 0.95), auto_match_pair(2, 3, 0.95)]);
        let clusters = g.compute_clusters(&config());
        assert_eq!(clusters.len(), 1);
        assert_eq!(clusters[0].len(), 3);
    }

    #[test]
    fn single_pair_one_cluster() {
        let mut g = ClusterGraph::new();
        g.add_pairs(&[auto_match_pair(1, 2, 0.95)]);
        let clusters = g.compute_clusters(&config());
        assert_eq!(clusters.len(), 1);
        assert_eq!(clusters[0].len(), 2);
    }

    #[test]
    fn duplicate_pair_keeps_higher_weight() {
        // The same two records scored twice (in reverse order): the edge must
        // keep 0.95, not 0.30, so it survives weak-edge removal.
        let mut g = ClusterGraph::new();
        g.add_pairs(&[auto_match_pair(1, 2, 0.30), auto_match_pair(2, 1, 0.95)]);
        let clusters = g.compute_clusters(&config());
        assert_eq!(clusters.len(), 1);
        let mut c = clusters[0].clone();
        c.sort();
        assert_eq!(c, vec![1, 2]);
    }

    #[test]
    fn weak_bridge_splits_chain() {
        // A -[0.95]- B -[0.28]- C -[0.95]- D
        // with within_cluster_min = 0.85, the B-C edge is removed
        // → {A,B} and {C,D}
        let mut g = ClusterGraph::new();
        g.add_pairs(&[
            auto_match_pair(1, 2, 0.95), // A-B strong
            auto_match_pair(2, 3, 0.28), // B-C weak bridge
            auto_match_pair(3, 4, 0.95), // C-D strong
        ]);
        let mut clusters = g.compute_clusters(&config());
        clusters.sort_by_key(|c| *c.iter().min().unwrap());
        assert_eq!(
            clusters.len(),
            2,
            "weak bridge must split chain into 2 clusters"
        );
        assert_eq!(clusters[0].len(), 2);
        assert_eq!(clusters[1].len(), 2);

        let mut c0 = clusters[0].clone();
        c0.sort();
        let mut c1 = clusters[1].clone();
        c1.sort();
        assert_eq!(c0, vec![1, 2]);
        assert_eq!(c1, vec![3, 4]);
    }

    #[test]
    fn star_pruning_keeps_full_star() {
        // Hub (id=0) connected to 60 satellites with prob 0.95.
        // The component (61) exceeds max_cluster_size (50), so star pruning
        // runs. Every satellite is linked to the hub by a qualifying edge, so
        // all of them are kept: exactly one cluster of 61 records.
        let mut g = ClusterGraph::new();
        let pairs: Vec<_> = (1u64..=60).map(|i| auto_match_pair(0, i, 0.95)).collect();
        g.add_pairs(&pairs);

        let clusters = g.compute_clusters(&config());
        assert_eq!(clusters.len(), 1);
        assert_eq!(clusters[0].len(), 61);
    }

    #[test]
    fn star_pruning_drops_members_not_linked_to_hub() {
        // Two stars joined hub-to-hub:
        //   hub 0    — satellites 1..=40   (degree 41 incl. hub 1000)
        //   hub 1000 — satellites 1001..=1030 (degree 31 incl. hub 0)
        // The whole component (72) is oversized. Hub 0 wins, and only its
        // direct neighbours survive: 0, 1..=40 and 1000 → 42 records.
        let mut pairs: Vec<ScoredPair> =
            (1u64..=40).map(|i| auto_match_pair(0, i, 0.95)).collect();
        pairs.extend((1001u64..=1030).map(|i| auto_match_pair(1000, i, 0.95)));
        pairs.push(auto_match_pair(0, 1000, 0.95));

        let mut g = ClusterGraph::new();
        g.add_pairs(&pairs);

        let clusters = g.compute_clusters(&config());
        assert_eq!(clusters.len(), 1);
        assert_eq!(clusters[0].len(), 42);
        assert!(clusters[0].contains(&0));
        assert!(clusters[0].contains(&1000));
        assert!(
            !clusters[0].iter().any(|&id| id > 1000),
            "satellites of the losing hub are not adjacent to the winning hub"
        );
    }

    #[test]
    fn two_disconnected_pairs_two_clusters() {
        let mut g = ClusterGraph::new();
        g.add_pairs(&[auto_match_pair(1, 2, 0.95), auto_match_pair(3, 4, 0.95)]);
        let clusters = g.compute_clusters(&config());
        assert_eq!(clusters.len(), 2);
    }
}