// uni_algo/algo/algorithms/max_matching.rs

// SPDX-License-Identifier: Apache-2.0
// Copyright 2024-2026 Dragonscale Team

//! Maximum bipartite matching (Hopcroft–Karp algorithm).
//!
//! Finds a maximum-cardinality matching in a bipartite graph and returns the
//! matched edges together with their count. The graph must be bipartite;
//! otherwise an error is returned.

10use crate::algo::GraphProjection;
11use crate::algo::algorithms::{Algorithm, BipartiteCheck, BipartiteCheckConfig};
12use uni_common::core::id::Vid;
13
/// Maximum bipartite matching computed with the Hopcroft–Karp algorithm.
///
/// Stateless marker type; invoke it through [`Algorithm::run`].
#[derive(Debug, Clone, Copy, Default)]
pub struct MaximumMatching;

/// Configuration for [`MaximumMatching`].
///
/// Carries no fields; it is the `Config` associated type of the matching
/// algorithm, and `Default::default()` is the only way to build it.
#[derive(Default, Clone, Debug)]
pub struct MaximumMatchingConfig {}

19pub struct MaximumMatchingResult {
20    pub match_count: usize,
21    pub matching: Vec<(Vid, Vid)>,
22}
23
24impl Algorithm for MaximumMatching {
25    type Config = MaximumMatchingConfig;
26    type Result = Result<MaximumMatchingResult, String>;
27
28    fn name() -> &'static str {
29        "max_matching"
30    }
31
32    fn run(graph: &GraphProjection, _config: Self::Config) -> Self::Result {
33        let n = graph.vertex_count();
34        if n == 0 {
35            return Ok(MaximumMatchingResult {
36                match_count: 0,
37                matching: Vec::new(),
38            });
39        }
40
41        // 1. Check Bipartite
42        let check = BipartiteCheck::run(graph, BipartiteCheckConfig::default());
43        if !check.is_bipartite {
44            return Err("Graph is not bipartite".to_string());
45        }
46
47        // Split into U (color 0) and V (color 1)
48        // partition is Vec<(Vid, u8)> but we need slots.
49        // We can reconstruct color map by slot.
50        // BipartiteCheck actually returns `partition: Vec<(Vid, u8)>`.
51        // We should probably modify BipartiteCheck to return slot map or re-run logic.
52        // Or map Vid back to slot.
53        let mut color = vec![0u8; n];
54        for (vid, c) in check.partition {
55            if let Some(slot) = graph.to_slot(vid) {
56                color[slot as usize] = c; // 0 or 1
57            }
58        }
59
60        let mut pair_u = vec![None; n]; // Pair for u in U (stores v in V)
61        let mut pair_v = vec![None; n]; // Pair for v in V (stores u in U)
62        let mut dist = vec![u32::MAX; n];
63
64        let u_nodes: Vec<usize> = (0..n).filter(|&i| color[i] == 0).collect();
65
66        let mut matching_size = 0;
67
68        loop {
69            // BFS
70            let mut queue = std::collections::VecDeque::new();
71            for &u in &u_nodes {
72                if pair_u[u].is_none() {
73                    dist[u] = 0;
74                    queue.push_back(u);
75                } else {
76                    dist[u] = u32::MAX;
77                }
78            }
79            let mut dist_null = u32::MAX;
80
81            while let Some(u) = queue.pop_front() {
82                if dist[u] < dist_null {
83                    for &v_u32 in graph.out_neighbors(u as u32) {
84                        let v = v_u32 as usize;
85                        // Since we treat graph as undirected for bipartite matching,
86                        // we must ensure we only traverse edges between partition sets.
87                        // Bipartite check ensures edges are only between 0 and 1.
88                        // But `out_neighbors` might be directed.
89                        // If graph is directed, do we treat as undirected?
90                        // Standard matching is on undirected edges.
91                        // If U->V, OK. If V->U?
92                        // Hopcroft-Karp usually formulated on U->V.
93                        // If we have edges in both directions, we might double count or traverse wrong.
94                        // We should only consider edges from U to V?
95                        // If `out_neighbors` contains V->U, we should ignore?
96                        // Since we iterate `u` in `u_nodes` (Set U), `out_neighbors` are neighbors of U.
97                        // They must be in V.
98
99                        if let Some(next_u) = pair_v[v] {
100                            if dist[next_u] == u32::MAX {
101                                dist[next_u] = dist[u] + 1;
102                                queue.push_back(next_u);
103                            }
104                        } else {
105                            dist_null = dist[u] + 1;
106                        }
107                    }
108
109                    // Also check in_neighbors if undirected?
110                    // GraphProjection might have `include_reverse`.
111                    // If we assume undirected connectivity, we need to check both.
112                    // If `u` in U, neighbors are in V.
113                    if graph.has_reverse() {
114                        for &v_u32 in graph.in_neighbors(u as u32) {
115                            let v = v_u32 as usize;
116                            if let Some(next_u) = pair_v[v] {
117                                if dist[next_u] == u32::MAX {
118                                    dist[next_u] = dist[u] + 1;
119                                    queue.push_back(next_u);
120                                }
121                            } else {
122                                dist_null = dist[u] + 1;
123                            }
124                        }
125                    }
126                }
127            }
128
129            if dist_null == u32::MAX {
130                break;
131            }
132
133            // DFS
134            for &u in &u_nodes {
135                if pair_u[u].is_none() && dfs(u, graph, &mut pair_u, &mut pair_v, &dist) {
136                    matching_size += 1;
137                }
138            }
139        }
140
141        let mut matching = Vec::new();
142        for u in u_nodes {
143            if let Some(v) = pair_u[u] {
144                matching.push((graph.to_vid(u as u32), graph.to_vid(v as u32)));
145            }
146        }
147
148        Ok(MaximumMatchingResult {
149            match_count: matching_size,
150            matching,
151        })
152    }
153}
154
155fn dfs(
156    u: usize,
157    graph: &GraphProjection,
158    pair_u: &mut [Option<usize>],
159    pair_v: &mut [Option<usize>],
160    dist: &[u32],
161) -> bool {
162    if dist[u] == u32::MAX {
163        return false;
164    }
165
166    let mut neighbors = Vec::new();
167    neighbors.extend_from_slice(graph.out_neighbors(u as u32));
168    if graph.has_reverse() {
169        neighbors.extend_from_slice(graph.in_neighbors(u as u32));
170    }
171
172    for &v_u32 in &neighbors {
173        let v = v_u32 as usize;
174        // Check if dist logic holds
175        let _next_dist = if let Some(next_u) = pair_v[v] {
176            dist[next_u]
177        } else {
178            u32::MAX // null node
179        };
180
181        // Target condition: next_dist == dist[u] + 1
182        // If pair_v[v] is None, dist[null] is conceptually dist[u]+1 if we reached free node.
183        // Wait, standard DFS logic:
184
185        let proceed = if let Some(next_u) = pair_v[v] {
186            dist[next_u] == dist[u] + 1 && dfs(next_u, graph, pair_u, pair_v, dist)
187        } else {
188            true // Found free vertex, augmenting path found
189        };
190
191        if proceed {
192            pair_v[v] = Some(u);
193            pair_u[u] = Some(v);
194            return true;
195        }
196    }
197
198    // Mark as visited/useless for this phase
199    // In standard HK, we reset dist to infinity? No, just don't visit again.
200    // Usually dist is not modified in DFS.
201
202    false
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algo::test_utils::build_test_graph;

    #[test]
    fn test_matching_simple() {
        // 0 -> 1, 2 -> 3
        // Bipartite: {0, 2} and {1, 3}
        // Matching: (0,1), (2,3) -> size 2
        let vids = vec![Vid::from(0), Vid::from(1), Vid::from(2), Vid::from(3)];
        let edges = vec![(Vid::from(0), Vid::from(1)), (Vid::from(2), Vid::from(3))];
        let graph = build_test_graph(vids, edges);

        let result = MaximumMatching::run(&graph, MaximumMatchingConfig::default()).unwrap();
        assert_eq!(result.match_count, 2);
        assert_eq!(result.matching.len(), result.match_count);
    }

    #[test]
    fn test_matching_requires_augmenting_path() {
        // 0 -> 1, 0 -> 3, 2 -> 1
        // Vertex 2 can only be matched to 1, so if 0 first takes 1 the
        // algorithm must augment along 2-1-0-3.
        // Unique maximum matching: {(0,3), (2,1)}.
        let vids = vec![Vid::from(0), Vid::from(1), Vid::from(2), Vid::from(3)];
        let edges = vec![
            (Vid::from(0), Vid::from(1)),
            (Vid::from(0), Vid::from(3)),
            (Vid::from(2), Vid::from(1)),
        ];
        let graph = build_test_graph(vids, edges);

        let result = MaximumMatching::run(&graph, MaximumMatchingConfig::default()).unwrap();
        assert_eq!(result.match_count, 2);
        assert_eq!(result.matching.len(), 2);

        let has_pair = |a: Vid, b: Vid| {
            result
                .matching
                .iter()
                .any(|(x, y)| (*x == a && *y == b) || (*x == b && *y == a))
        };
        assert!(has_pair(Vid::from(2), Vid::from(1)));
        assert!(has_pair(Vid::from(0), Vid::from(3)));
    }

    #[test]
    fn test_matching_rejects_odd_cycle() {
        // Triangle 0 -> 1 -> 2 -> 0 is an odd cycle, hence not bipartite.
        let vids = vec![Vid::from(0), Vid::from(1), Vid::from(2)];
        let edges = vec![
            (Vid::from(0), Vid::from(1)),
            (Vid::from(1), Vid::from(2)),
            (Vid::from(2), Vid::from(0)),
        ];
        let graph = build_test_graph(vids, edges);

        let result = MaximumMatching::run(&graph, MaximumMatchingConfig::default());
        assert_eq!(result.err().as_deref(), Some("Graph is not bipartite"));
    }
}
222}