scirs2_graph/algorithms/
matching.rs

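//! Graph matching algorithms
//!
//! This module contains algorithms for finding matchings in graphs: maximum and
//! minimum-weight matchings in bipartite graphs, matchings in general graphs, and a
//! Gale-Shapley solver for the stable marriage problem.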
5
6use crate::algorithms::connectivity::is_bipartite;
7use crate::base::{EdgeWeight, Graph, IndexType, Node};
8use crate::error::{GraphError, Result};
9use std::collections::{HashMap, HashSet};
10use std::hash::Hash;
11
12/// Maximum bipartite matching result
13#[derive(Debug, Clone)]
14pub struct BipartiteMatching<N: Node> {
15    /// The matching as a map from left nodes to right nodes
16    pub matching: HashMap<N, N>,
17    /// The size of the matching
18    pub size: usize,
19}
20
21/// Finds a maximum bipartite matching using the Hungarian algorithm
22///
23/// Assumes the graph is bipartite with nodes already colored.
24///
25/// # Arguments
26/// * `graph` - The bipartite graph
27/// * `coloring` - The bipartite coloring (0 or 1 for each node)
28///
29/// # Returns
30/// * A maximum bipartite matching
31#[allow(dead_code)]
32pub fn maximum_bipartite_matching<N, E, Ix>(
33    graph: &Graph<N, E, Ix>,
34    coloring: &HashMap<N, u8>,
35) -> BipartiteMatching<N>
36where
37    N: Node + std::fmt::Debug,
38    E: EdgeWeight,
39    Ix: petgraph::graph::IndexType,
40{
41    // Create a mapping from nodes to indices
42    let mut node_to_idx: HashMap<N, petgraph::graph::NodeIndex<Ix>> = HashMap::new();
43    for node_idx in graph.inner().node_indices() {
44        node_to_idx.insert(graph.inner()[node_idx].clone(), node_idx);
45    }
46
47    // Separate nodes into left and right sets based on coloring
48    let mut left_nodes = Vec::new();
49    let mut right_nodes = Vec::new();
50
51    for (node, &color) in coloring {
52        if color == 0 {
53            left_nodes.push(node.clone());
54        } else {
55            right_nodes.push(node.clone());
56        }
57    }
58
59    // Build matching using augmenting paths
60    let mut matching: HashMap<N, N> = HashMap::new();
61    let mut reverse_matching: HashMap<N, N> = HashMap::new();
62
63    // For each unmatched left node, try to find an augmenting path
64    for left_node in &left_nodes {
65        if !matching.contains_key(left_node) {
66            let mut visited = HashSet::new();
67            augment_path(
68                graph,
69                left_node,
70                &mut matching,
71                &mut reverse_matching,
72                &mut visited,
73                coloring,
74            );
75        }
76    }
77
78    BipartiteMatching {
79        size: matching.len(),
80        matching,
81    }
82}
83
84/// Try to find an augmenting path from an unmatched left node
85#[allow(dead_code)]
86fn augment_path<N, E, Ix>(
87    graph: &Graph<N, E, Ix>,
88    node: &N,
89    matching: &mut HashMap<N, N>,
90    reverse_matching: &mut HashMap<N, N>,
91    visited: &mut HashSet<N>,
92    coloring: &HashMap<N, u8>,
93) -> bool
94where
95    N: Node + std::fmt::Debug,
96    E: EdgeWeight,
97    Ix: petgraph::graph::IndexType,
98{
99    // Mark as visited
100    visited.insert(node.clone());
101
102    // Try all neighbors
103    if let Ok(neighbors) = graph.neighbors(node) {
104        for neighbor in neighbors {
105            // Skip if same color (not bipartite edge)
106            if coloring.get(node) == coloring.get(&neighbor) {
107                continue;
108            }
109
110            // If neighbor is unmatched, we found an augmenting path
111            if let std::collections::hash_map::Entry::Vacant(e) =
112                reverse_matching.entry(neighbor.clone())
113            {
114                matching.insert(node.clone(), neighbor.clone());
115                e.insert(node.clone());
116                return true;
117            }
118
119            // Otherwise, try to augment through the matched node
120            let matched_node = reverse_matching[&neighbor].clone();
121            if !visited.contains(&matched_node)
122                && augment_path(
123                    graph,
124                    &matched_node,
125                    matching,
126                    reverse_matching,
127                    visited,
128                    coloring,
129                )
130            {
131                matching.insert(node.clone(), neighbor.clone());
132                reverse_matching.insert(neighbor, node.clone());
133                return true;
134            }
135        }
136    }
137
138    false
139}
140
141/// Minimum weight bipartite matching using a simplified Hungarian algorithm
142///
143/// Finds the minimum weight perfect matching in a bipartite graph.
144/// Returns the total weight and the matching as a vector of (left_node, right_node) pairs.
145#[allow(dead_code)]
146pub fn minimum_weight_bipartite_matching<N, E, Ix>(
147    graph: &Graph<N, E, Ix>,
148) -> Result<(f64, Vec<(N, N)>)>
149where
150    N: Node + Clone + Hash + Eq + std::fmt::Debug,
151    E: EdgeWeight + Into<f64> + Clone,
152    Ix: IndexType,
153{
154    // First check if the graph is bipartite
155    let bipartite_result = is_bipartite(graph);
156
157    if !bipartite_result.is_bipartite {
158        return Err(GraphError::InvalidGraph(
159            "Graph is not bipartite".to_string(),
160        ));
161    }
162
163    let coloring = bipartite_result.coloring;
164
165    // Separate nodes by color
166    let mut left_nodes = Vec::new();
167    let mut right_nodes = Vec::new();
168
169    for (node, &color) in &coloring {
170        if color == 0 {
171            left_nodes.push(node.clone());
172        } else {
173            right_nodes.push(node.clone());
174        }
175    }
176
177    let n_left = left_nodes.len();
178    let n_right = right_nodes.len();
179
180    if n_left != n_right {
181        return Err(GraphError::InvalidGraph(
182            "Bipartite graph must have equal number of nodes in each partition for perfect matching".to_string()
183        ));
184    }
185
186    if n_left == 0 {
187        return Ok((0.0, vec![]));
188    }
189
190    // Create cost matrix
191    let mut cost_matrix = vec![vec![f64::INFINITY; n_right]; n_left];
192
193    for (i, left_node) in left_nodes.iter().enumerate() {
194        for (j, right_node) in right_nodes.iter().enumerate() {
195            if let Ok(weight) = graph.edge_weight(left_node, right_node) {
196                cost_matrix[i][j] = weight.into();
197            }
198        }
199    }
200
201    // Use a simplified version of Hungarian algorithm
202    // For small graphs, we can use a brute force approach
203    if n_left <= 6 {
204        minimum_weight_matching_bruteforce(&left_nodes, &right_nodes, &cost_matrix)
205    } else {
206        // For larger graphs, use a greedy approximation
207        minimum_weight_matching_greedy(&left_nodes, &right_nodes, &cost_matrix)
208    }
209}
210
211#[allow(dead_code)]
212fn minimum_weight_matching_bruteforce<N>(
213    left_nodes: &[N],
214    right_nodes: &[N],
215    cost_matrix: &[Vec<f64>],
216) -> Result<(f64, Vec<(N, N)>)>
217where
218    N: Node + Clone + std::fmt::Debug,
219{
220    let n = left_nodes.len();
221    let mut best_cost = f64::INFINITY;
222    let mut best_matching = Vec::new();
223
224    // Generate all permutations
225    let mut perm: Vec<usize> = (0..n).collect();
226
227    loop {
228        // Calculate cost for this permutation
229        let mut cost = 0.0;
230        for i in 0..n {
231            cost += cost_matrix[i][perm[i]];
232        }
233
234        if cost < best_cost {
235            best_cost = cost;
236            best_matching = (0..n)
237                .map(|i| (left_nodes[i].clone(), right_nodes[perm[i]].clone()))
238                .collect();
239        }
240
241        // Next permutation
242        if !next_permutation(&mut perm) {
243            break;
244        }
245    }
246
247    Ok((best_cost, best_matching))
248}
249
250#[allow(dead_code)]
251fn minimum_weight_matching_greedy<N>(
252    left_nodes: &[N],
253    right_nodes: &[N],
254    cost_matrix: &[Vec<f64>],
255) -> Result<(f64, Vec<(N, N)>)>
256where
257    N: Node + Clone + std::fmt::Debug,
258{
259    let n = left_nodes.len();
260    let mut matching = Vec::new();
261    let mut used_right = vec![false; n];
262    let mut total_cost = 0.0;
263
264    // Greedily assign each left node to the cheapest available right node
265    for i in 0..n {
266        let mut best_j = None;
267        let mut best_cost = f64::INFINITY;
268
269        for (j, &used) in used_right.iter().enumerate().take(n) {
270            if !used && cost_matrix[i][j] < best_cost {
271                best_cost = cost_matrix[i][j];
272                best_j = Some(j);
273            }
274        }
275
276        if let Some(j) = best_j {
277            used_right[j] = true;
278            total_cost += best_cost;
279            matching.push((left_nodes[i].clone(), right_nodes[j].clone()));
280        }
281    }
282
283    Ok((total_cost, matching))
284}
285
/// Rearranges `perm` into the next lexicographically greater permutation.
///
/// Returns `false` and leaves `perm` untouched when it is already the last permutation in
/// lexicographic order. This includes empty and single-element slices. Repeated values are
/// supported; each distinct arrangement is produced once.
#[allow(dead_code)]
fn next_permutation(perm: &mut [usize]) -> bool {
    // Pivot: the rightmost k with perm[k] < perm[k + 1]. `windows(2)` yields nothing for
    // slices shorter than two, so this cannot underflow on empty input.
    let k = match perm.windows(2).rposition(|w| w[0] < w[1]) {
        Some(k) => k,
        None => return false, // Last permutation
    };

    // Successor: the rightmost element larger than the pivot. It lies after k because the
    // suffix is non-increasing, and it exists because perm[k + 1] > perm[k].
    let pivot = perm[k];
    let l = perm
        .iter()
        .rposition(|&x| x > pivot)
        .expect("perm[k + 1] is larger than the pivot");

    perm.swap(k, l);
    // The suffix is still non-increasing; reversing it makes it the smallest arrangement.
    perm[k + 1..].reverse();

    true
}
319
320/// Maximum cardinality matching result
321#[derive(Debug, Clone)]
322pub struct MaximumMatching<N: Node> {
323    /// The matching as a vector of edge pairs
324    pub matching: Vec<(N, N)>,
325    /// The size of the matching
326    pub size: usize,
327}
328
329/// Finds a maximum cardinality matching in a general graph using Edmonds' blossom algorithm
330///
331/// This is a simplified implementation of the blossom algorithm for general graphs.
332/// For better performance on bipartite graphs, use `maximum_bipartite_matching`.
333///
334/// # Arguments
335/// * `graph` - The input graph
336///
337/// # Returns
338/// * A maximum cardinality matching
339#[allow(dead_code)]
340pub fn maximum_cardinality_matching<N, E, Ix>(graph: &Graph<N, E, Ix>) -> MaximumMatching<N>
341where
342    N: Node + Clone + std::fmt::Debug,
343    E: EdgeWeight,
344    Ix: IndexType,
345{
346    let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
347    let n = nodes.len();
348
349    if n == 0 {
350        return MaximumMatching {
351            matching: Vec::new(),
352            size: 0,
353        };
354    }
355
356    // Use a greedy approach for simplicity
357    // A full implementation would use Edmonds' blossom algorithm
358    let mut matching = Vec::new();
359    let mut matched = vec![false; n];
360    let node_to_idx: HashMap<N, usize> = nodes
361        .iter()
362        .enumerate()
363        .map(|(i, n)| (n.clone(), i))
364        .collect();
365
366    // Greedy matching: find augmenting paths
367    for (i, node) in nodes.iter().enumerate() {
368        if matched[i] {
369            continue;
370        }
371
372        if let Ok(neighbors) = graph.neighbors(node) {
373            for neighbor in neighbors {
374                if let Some(&j) = node_to_idx.get(&neighbor) {
375                    if !matched[j] {
376                        // Found an augmenting path of length 1
377                        matching.push((node.clone(), neighbor));
378                        matched[i] = true;
379                        matched[j] = true;
380                        break;
381                    }
382                }
383            }
384        }
385    }
386
387    MaximumMatching {
388        size: matching.len(),
389        matching,
390    }
391}
392
393/// Finds a maximal matching using a greedy algorithm
394///
395/// A maximal matching is one where no more edges can be added.
396/// This is simpler than maximum matching but provides a 2-approximation.
397///
398/// # Arguments
399/// * `graph` - The input graph
400///
401/// # Returns
402/// * A maximal matching
403#[allow(dead_code)]
404pub fn maximal_matching<N, E, Ix>(graph: &Graph<N, E, Ix>) -> MaximumMatching<N>
405where
406    N: Node + Clone + std::fmt::Debug,
407    E: EdgeWeight,
408    Ix: IndexType,
409{
410    let mut matching = Vec::new();
411    let mut matched_nodes = HashSet::new();
412
413    // Get all edges
414    let edges = graph.edges();
415
416    // Greedily add edges that don't conflict with existing matching
417    for edge in edges {
418        if !matched_nodes.contains(&edge.source) && !matched_nodes.contains(&edge.target) {
419            matching.push((edge.source.clone(), edge.target.clone()));
420            matched_nodes.insert(edge.source);
421            matched_nodes.insert(edge.target);
422        }
423    }
424
425    MaximumMatching {
426        size: matching.len(),
427        matching,
428    }
429}
430
431/// Stable marriage problem solver using the Gale-Shapley algorithm
432///
433/// Finds a stable matching between two sets of equal size where each element
434/// has a preference order over the other set.
435///
436/// # Arguments
437/// * `left_prefs` - Preference lists for left set (each list is ordered from most to least preferred)
438/// * `right_prefs` - Preference lists for right set
439///
440/// # Returns
441/// * A stable matching as pairs (left_index, right_index)
442#[allow(dead_code)]
443pub fn stable_marriage(
444    left_prefs: &[Vec<usize>],
445    right_prefs: &[Vec<usize>],
446) -> Result<Vec<(usize, usize)>> {
447    let n = left_prefs.len();
448
449    if n != right_prefs.len() {
450        return Err(GraphError::InvalidGraph(
451            "Left and right sets must have equal size".to_string(),
452        ));
453    }
454
455    if n == 0 {
456        return Ok(Vec::new());
457    }
458
459    // Validate preference lists
460    for (i, prefs) in left_prefs.iter().enumerate() {
461        if prefs.len() != n {
462            return Err(GraphError::InvalidGraph(format!(
463                "Left preference list {i} has wrong length"
464            )));
465        }
466        let mut sorted_prefs = prefs.clone();
467        sorted_prefs.sort_unstable();
468        if sorted_prefs != (0..n).collect::<Vec<_>>() {
469            return Err(GraphError::InvalidGraph(format!(
470                "Left preference list {i} is not a valid permutation"
471            )));
472        }
473    }
474
475    for (i, prefs) in right_prefs.iter().enumerate() {
476        if prefs.len() != n {
477            return Err(GraphError::InvalidGraph(format!(
478                "Right preference list {i} has wrong length"
479            )));
480        }
481        let mut sorted_prefs = prefs.clone();
482        sorted_prefs.sort_unstable();
483        if sorted_prefs != (0..n).collect::<Vec<_>>() {
484            return Err(GraphError::InvalidGraph(format!(
485                "Right preference list {i} is not a valid permutation"
486            )));
487        }
488    }
489
490    // Create inverse preference mappings for right set for efficiency
491    let mut right_inv_prefs = vec![vec![0; n]; n];
492    for (i, prefs) in right_prefs.iter().enumerate() {
493        for (rank, &person) in prefs.iter().enumerate() {
494            right_inv_prefs[i][person] = rank;
495        }
496    }
497
498    // Gale-Shapley algorithm
499    let mut left_partner = vec![None; n];
500    let mut right_partner = vec![None; n];
501    let mut left_next_proposal = vec![0; n];
502    let mut free_left: std::collections::VecDeque<usize> = (0..n).collect();
503
504    while let Some(left) = free_left.pop_front() {
505        if left_next_proposal[left] >= n {
506            continue; // This left person has proposed to everyone
507        }
508
509        let right = left_prefs[left][left_next_proposal[left]];
510        left_next_proposal[left] += 1;
511
512        match right_partner[right] {
513            None => {
514                // Right person is free, form engagement
515                left_partner[left] = Some(right);
516                right_partner[right] = Some(left);
517            }
518            Some(current_left) => {
519                // Right person is engaged, check if they prefer the new proposal
520                if right_inv_prefs[right][left] < right_inv_prefs[right][current_left] {
521                    // Right person prefers the new proposal
522                    left_partner[left] = Some(right);
523                    right_partner[right] = Some(left);
524                    left_partner[current_left] = None;
525                    free_left.push_back(current_left);
526                } else {
527                    // Right person prefers their current partner
528                    free_left.push_back(left);
529                }
530            }
531        }
532    }
533
534    // Convert to result format
535    let mut result = Vec::new();
536    for (left, partner) in left_partner.iter().enumerate() {
537        if let Some(right) = partner {
538            result.push((left, *right));
539        }
540    }
541
542    Ok(result)
543}
544
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Result as GraphResult;
    use crate::generators::create_graph;

    #[test]
    fn test_maximum_bipartite_matching() -> GraphResult<()> {
        let mut graph = create_graph::<&str, ()>();

        // Create a bipartite graph
        graph.add_edge("A", "1", ())?;
        graph.add_edge("A", "2", ())?;
        graph.add_edge("B", "2", ())?;
        graph.add_edge("B", "3", ())?;
        graph.add_edge("C", "3", ())?;

        // Create coloring
        let mut coloring = HashMap::new();
        coloring.insert("A", 0);
        coloring.insert("B", 0);
        coloring.insert("C", 0);
        coloring.insert("1", 1);
        coloring.insert("2", 1);
        coloring.insert("3", 1);

        let matching = maximum_bipartite_matching(&graph, &coloring);

        // Should find a perfect matching of size 3
        assert_eq!(matching.size, 3);

        // The perfect matching is unique: C-3 is forced, hence B-2, hence A-1.
        assert_eq!(matching.matching.get("A"), Some(&"1"));
        assert_eq!(matching.matching.get("B"), Some(&"2"));
        assert_eq!(matching.matching.get("C"), Some(&"3"));

        // Verify it's a valid matching: each pair runs from colour 0 to colour 1 and no
        // right node is used twice.
        let mut used_right = HashSet::new();
        for (left, right) in &matching.matching {
            assert_eq!(coloring[left], 0);
            assert_eq!(coloring[right], 1);
            assert!(used_right.insert(*right), "right node {right} matched twice");
        }

        Ok(())
    }

    #[test]
    fn test_minimum_weight_bipartite_matching() -> GraphResult<()> {
        let mut graph = create_graph::<&str, f64>();

        // Create a complete bipartite graph K2,2
        graph.add_edge("A", "1", 1.0)?;
        graph.add_edge("A", "2", 3.0)?;
        graph.add_edge("B", "1", 2.0)?;
        graph.add_edge("B", "2", 1.0)?;

        let (total_weight, matching) = minimum_weight_bipartite_matching(&graph)?;

        // Optimal matching: A-1 (1.0) and B-2 (1.0)
        assert_eq!(total_weight, 2.0);
        assert_eq!(matching.len(), 2);

        Ok(())
    }

    #[test]
    fn test_minimum_weight_bipartite_matching_prefers_global_optimum() -> GraphResult<()> {
        let mut graph = create_graph::<&str, f64>();

        // "3" is only adjacent to "C", so C-3 is forced. Taking the locally cheapest edge A-1
        // would then force B-2 (cost 100). The optimum is A-2, B-1, C-3 = 2 + 2 + 5.
        graph.add_edge("A", "1", 1.0)?;
        graph.add_edge("A", "2", 2.0)?;
        graph.add_edge("B", "1", 2.0)?;
        graph.add_edge("B", "2", 100.0)?;
        graph.add_edge("C", "1", 50.0)?;
        graph.add_edge("C", "3", 5.0)?;

        let (total_weight, matching) = minimum_weight_bipartite_matching(&graph)?;

        assert_eq!(total_weight, 9.0);
        assert_eq!(matching.len(), 3);

        Ok(())
    }

    #[test]
    fn test_maximum_cardinality_matching() {
        let mut graph = create_graph::<&str, ()>();

        // Create a simple graph
        graph.add_edge("A", "B", ()).unwrap();
        graph.add_edge("C", "D", ()).unwrap();
        graph.add_edge("E", "F", ()).unwrap();

        let matching = maximum_cardinality_matching(&graph);

        // Should find a matching of size 3
        assert_eq!(matching.size, 3);
        assert_eq!(matching.matching.len(), 3);

        // Verify no node is matched twice
        let mut matched_nodes = HashSet::new();
        for (u, v) in &matching.matching {
            assert!(!matched_nodes.contains(u));
            assert!(!matched_nodes.contains(v));
            matched_nodes.insert(u);
            matched_nodes.insert(v);
        }
    }

    #[test]
    fn test_maximal_matching() {
        let mut graph = create_graph::<i32, ()>();

        // Create a triangle
        graph.add_edge(1, 2, ()).unwrap();
        graph.add_edge(2, 3, ()).unwrap();
        graph.add_edge(3, 1, ()).unwrap();

        let matching = maximal_matching(&graph);

        // Should find at least one edge (maximal for triangle is 1)
        assert_eq!(matching.size, 1);
        assert_eq!(matching.matching.len(), 1);

        // Verify it's a valid matching
        let mut matched_nodes = HashSet::new();
        for (u, v) in &matching.matching {
            assert!(!matched_nodes.contains(u));
            assert!(!matched_nodes.contains(v));
            matched_nodes.insert(u);
            matched_nodes.insert(v);
        }
    }

    #[test]
    fn test_stable_marriage() -> GraphResult<()> {
        // Position of `who` in a preference list (lower is more preferred).
        fn rank(prefs: &[usize], who: usize) -> usize {
            prefs.iter().position(|&p| p == who).unwrap()
        }

        // Example: 3 people on each side
        let left_prefs = vec![
            vec![0, 1, 2], // Person 0 prefers 0, then 1, then 2
            vec![1, 0, 2], // Person 1 prefers 1, then 0, then 2
            vec![0, 1, 2], // Person 2 prefers 0, then 1, then 2
        ];

        let right_prefs = vec![
            vec![2, 1, 0], // Person 0 prefers 2, then 1, then 0
            vec![0, 2, 1], // Person 1 prefers 0, then 2, then 1
            vec![0, 1, 2], // Person 2 prefers 0, then 1, then 2
        ];

        let matching = stable_marriage(&left_prefs, &right_prefs)?;

        // Should have 3 pairs
        assert_eq!(matching.len(), 3);

        // Left-proposing Gale-Shapley yields the unique left-optimal stable matching.
        assert_eq!(matching, vec![(0, 1), (1, 2), (2, 0)]);

        // Verify it's a complete matching
        let mut matched_left = HashSet::new();
        let mut matched_right = HashSet::new();
        for (left, right) in &matching {
            assert!(!matched_left.contains(left));
            assert!(!matched_right.contains(right));
            matched_left.insert(*left);
            matched_right.insert(*right);
        }

        // Verify stability: no left/right pair would both rather be with each other.
        let partner_of_right: HashMap<usize, usize> =
            matching.iter().map(|&(l, r)| (r, l)).collect();
        for &(left, current_right) in &matching {
            for right in 0..right_prefs.len() {
                let left_prefers =
                    rank(&left_prefs[left], right) < rank(&left_prefs[left], current_right);
                let current_left = partner_of_right[&right];
                let right_prefers =
                    rank(&right_prefs[right], left) < rank(&right_prefs[right], current_left);
                assert!(
                    !(left_prefers && right_prefers),
                    "blocking pair ({left}, {right})"
                );
            }
        }

        Ok(())
    }

    #[test]
    fn test_stable_marriage_empty() -> GraphResult<()> {
        let left_prefs: Vec<Vec<usize>> = vec![];
        let right_prefs: Vec<Vec<usize>> = vec![];

        let matching = stable_marriage(&left_prefs, &right_prefs)?;
        assert_eq!(matching.len(), 0);

        Ok(())
    }

    #[test]
    fn test_stable_marriage_invalid_input() {
        // Mismatched sizes
        let left_prefs = vec![vec![0]];
        let right_prefs = vec![vec![0], vec![1]];

        assert!(stable_marriage(&left_prefs, &right_prefs).is_err());

        // Invalid preference list
        let left_prefs = vec![vec![0, 0]]; // Duplicate
        let right_prefs = vec![vec![0, 1]];

        assert!(stable_marriage(&left_prefs, &right_prefs).is_err());
    }

    #[test]
    fn test_next_permutation_enumerates_in_lexicographic_order() {
        let mut perm: Vec<usize> = vec![0, 1, 2];
        let mut seen = vec![perm.clone()];
        while next_permutation(&mut perm) {
            seen.push(perm.clone());
        }
        assert_eq!(
            seen,
            vec![
                vec![0, 1, 2],
                vec![0, 2, 1],
                vec![1, 0, 2],
                vec![1, 2, 0],
                vec![2, 0, 1],
                vec![2, 1, 0],
            ]
        );
    }
}