1use crate::algorithms::connectivity::is_bipartite;
7use crate::base::{EdgeWeight, Graph, IndexType, Node};
8use crate::error::{GraphError, Result};
9use std::collections::{HashMap, HashSet};
10use std::hash::Hash;
11
/// Result of a bipartite matching computation such as
/// `maximum_bipartite_matching`.
#[derive(Debug, Clone)]
pub struct BipartiteMatching<N: Node> {
    /// Matched pairs. When built by `maximum_bipartite_matching`, keys are the
    /// nodes the augmenting search is run from (the color-0 side) and values
    /// are their partners.
    pub matching: HashMap<N, N>,
    /// Number of matched pairs; `maximum_bipartite_matching` sets this to
    /// `matching.len()`.
    pub size: usize,
}
20
21#[allow(dead_code)]
32pub fn maximum_bipartite_matching<N, E, Ix>(
33 graph: &Graph<N, E, Ix>,
34 coloring: &HashMap<N, u8>,
35) -> BipartiteMatching<N>
36where
37 N: Node + std::fmt::Debug,
38 E: EdgeWeight,
39 Ix: petgraph::graph::IndexType,
40{
41 let mut node_to_idx: HashMap<N, petgraph::graph::NodeIndex<Ix>> = HashMap::new();
43 for node_idx in graph.inner().node_indices() {
44 node_to_idx.insert(graph.inner()[node_idx].clone(), node_idx);
45 }
46
47 let mut left_nodes = Vec::new();
49 let mut right_nodes = Vec::new();
50
51 for (node, &color) in coloring {
52 if color == 0 {
53 left_nodes.push(node.clone());
54 } else {
55 right_nodes.push(node.clone());
56 }
57 }
58
59 let mut matching: HashMap<N, N> = HashMap::new();
61 let mut reverse_matching: HashMap<N, N> = HashMap::new();
62
63 for left_node in &left_nodes {
65 if !matching.contains_key(left_node) {
66 let mut visited = HashSet::new();
67 augment_path(
68 graph,
69 left_node,
70 &mut matching,
71 &mut reverse_matching,
72 &mut visited,
73 coloring,
74 );
75 }
76 }
77
78 BipartiteMatching {
79 size: matching.len(),
80 matching,
81 }
82}
83
84#[allow(dead_code)]
86fn augment_path<N, E, Ix>(
87 graph: &Graph<N, E, Ix>,
88 node: &N,
89 matching: &mut HashMap<N, N>,
90 reverse_matching: &mut HashMap<N, N>,
91 visited: &mut HashSet<N>,
92 coloring: &HashMap<N, u8>,
93) -> bool
94where
95 N: Node + std::fmt::Debug,
96 E: EdgeWeight,
97 Ix: petgraph::graph::IndexType,
98{
99 visited.insert(node.clone());
101
102 if let Ok(neighbors) = graph.neighbors(node) {
104 for neighbor in neighbors {
105 if coloring.get(node) == coloring.get(&neighbor) {
107 continue;
108 }
109
110 if let std::collections::hash_map::Entry::Vacant(e) =
112 reverse_matching.entry(neighbor.clone())
113 {
114 matching.insert(node.clone(), neighbor.clone());
115 e.insert(node.clone());
116 return true;
117 }
118
119 let matched_node = reverse_matching[&neighbor].clone();
121 if !visited.contains(&matched_node)
122 && augment_path(
123 graph,
124 &matched_node,
125 matching,
126 reverse_matching,
127 visited,
128 coloring,
129 )
130 {
131 matching.insert(node.clone(), neighbor.clone());
132 reverse_matching.insert(neighbor, node.clone());
133 return true;
134 }
135 }
136 }
137
138 false
139}
140
141#[allow(dead_code)]
146pub fn minimum_weight_bipartite_matching<N, E, Ix>(
147 graph: &Graph<N, E, Ix>,
148) -> Result<(f64, Vec<(N, N)>)>
149where
150 N: Node + Clone + Hash + Eq + std::fmt::Debug,
151 E: EdgeWeight + Into<f64> + Clone,
152 Ix: IndexType,
153{
154 let bipartite_result = is_bipartite(graph);
156
157 if !bipartite_result.is_bipartite {
158 return Err(GraphError::InvalidGraph(
159 "Graph is not bipartite".to_string(),
160 ));
161 }
162
163 let coloring = bipartite_result.coloring;
164
165 let mut left_nodes = Vec::new();
167 let mut right_nodes = Vec::new();
168
169 for (node, &color) in &coloring {
170 if color == 0 {
171 left_nodes.push(node.clone());
172 } else {
173 right_nodes.push(node.clone());
174 }
175 }
176
177 let n_left = left_nodes.len();
178 let n_right = right_nodes.len();
179
180 if n_left != n_right {
181 return Err(GraphError::InvalidGraph(
182 "Bipartite graph must have equal number of nodes in each partition for perfect matching".to_string()
183 ));
184 }
185
186 if n_left == 0 {
187 return Ok((0.0, vec![]));
188 }
189
190 let mut cost_matrix = vec![vec![f64::INFINITY; n_right]; n_left];
192
193 for (i, left_node) in left_nodes.iter().enumerate() {
194 for (j, right_node) in right_nodes.iter().enumerate() {
195 if let Ok(weight) = graph.edge_weight(left_node, right_node) {
196 cost_matrix[i][j] = weight.into();
197 }
198 }
199 }
200
201 if n_left <= 6 {
204 minimum_weight_matching_bruteforce(&left_nodes, &right_nodes, &cost_matrix)
205 } else {
206 minimum_weight_matching_greedy(&left_nodes, &right_nodes, &cost_matrix)
208 }
209}
210
211#[allow(dead_code)]
212fn minimum_weight_matching_bruteforce<N>(
213 left_nodes: &[N],
214 right_nodes: &[N],
215 cost_matrix: &[Vec<f64>],
216) -> Result<(f64, Vec<(N, N)>)>
217where
218 N: Node + Clone + std::fmt::Debug,
219{
220 let n = left_nodes.len();
221 let mut best_cost = f64::INFINITY;
222 let mut best_matching = Vec::new();
223
224 let mut perm: Vec<usize> = (0..n).collect();
226
227 loop {
228 let mut cost = 0.0;
230 for i in 0..n {
231 cost += cost_matrix[i][perm[i]];
232 }
233
234 if cost < best_cost {
235 best_cost = cost;
236 best_matching = (0..n)
237 .map(|i| (left_nodes[i].clone(), right_nodes[perm[i]].clone()))
238 .collect();
239 }
240
241 if !next_permutation(&mut perm) {
243 break;
244 }
245 }
246
247 Ok((best_cost, best_matching))
248}
249
250#[allow(dead_code)]
251fn minimum_weight_matching_greedy<N>(
252 left_nodes: &[N],
253 right_nodes: &[N],
254 cost_matrix: &[Vec<f64>],
255) -> Result<(f64, Vec<(N, N)>)>
256where
257 N: Node + Clone + std::fmt::Debug,
258{
259 let n = left_nodes.len();
260 let mut matching = Vec::new();
261 let mut used_right = vec![false; n];
262 let mut total_cost = 0.0;
263
264 for i in 0..n {
266 let mut best_j = None;
267 let mut best_cost = f64::INFINITY;
268
269 for (j, &used) in used_right.iter().enumerate().take(n) {
270 if !used && cost_matrix[i][j] < best_cost {
271 best_cost = cost_matrix[i][j];
272 best_j = Some(j);
273 }
274 }
275
276 if let Some(j) = best_j {
277 used_right[j] = true;
278 total_cost += best_cost;
279 matching.push((left_nodes[i].clone(), right_nodes[j].clone()));
280 }
281 }
282
283 Ok((total_cost, matching))
284}
285
/// Rearranges `perm` into the next lexicographically greater permutation.
///
/// Returns `false` and leaves `perm` unchanged when it is already the last
/// permutation, which includes empty and single-element slices. Otherwise
/// returns `true`. Repeated values are handled, so each distinct arrangement
/// is produced once.
#[allow(dead_code)]
fn next_permutation(perm: &mut [usize]) -> bool {
    // Pivot: rightmost position whose element is smaller than its successor.
    // `windows(2)` yields nothing for slices shorter than 2, which avoids the
    // `len - 1` underflow an empty slice would otherwise trigger.
    let pivot = match perm.windows(2).rposition(|pair| pair[0] < pair[1]) {
        Some(k) => k,
        None => return false,
    };

    // The suffix after the pivot is non-increasing. Swap the pivot with the
    // rightmost suffix element larger than it, then reverse the suffix to make
    // it the smallest possible tail.
    let pivot_value = perm[pivot];
    let successor = perm[pivot + 1..]
        .iter()
        .rposition(|&x| x > pivot_value)
        .map_or(pivot + 1, |offset| pivot + 1 + offset);

    perm.swap(pivot, successor);
    perm[pivot + 1..].reverse();

    true
}
319
/// A matching in a general graph, as returned by
/// `maximum_cardinality_matching` and `maximal_matching`.
#[derive(Debug, Clone)]
pub struct MaximumMatching<N: Node> {
    /// The matched node pairs.
    pub matching: Vec<(N, N)>,
    /// Number of pairs; both functions above set it to `matching.len()`.
    pub size: usize,
}
328
329#[allow(dead_code)]
340pub fn maximum_cardinality_matching<N, E, Ix>(graph: &Graph<N, E, Ix>) -> MaximumMatching<N>
341where
342 N: Node + Clone + std::fmt::Debug,
343 E: EdgeWeight,
344 Ix: IndexType,
345{
346 let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
347 let n = nodes.len();
348
349 if n == 0 {
350 return MaximumMatching {
351 matching: Vec::new(),
352 size: 0,
353 };
354 }
355
356 let mut matching = Vec::new();
359 let mut matched = vec![false; n];
360 let node_to_idx: HashMap<N, usize> = nodes
361 .iter()
362 .enumerate()
363 .map(|(i, n)| (n.clone(), i))
364 .collect();
365
366 for (i, node) in nodes.iter().enumerate() {
368 if matched[i] {
369 continue;
370 }
371
372 if let Ok(neighbors) = graph.neighbors(node) {
373 for neighbor in neighbors {
374 if let Some(&j) = node_to_idx.get(&neighbor) {
375 if !matched[j] {
376 matching.push((node.clone(), neighbor));
378 matched[i] = true;
379 matched[j] = true;
380 break;
381 }
382 }
383 }
384 }
385 }
386
387 MaximumMatching {
388 size: matching.len(),
389 matching,
390 }
391}
392
393#[allow(dead_code)]
404pub fn maximal_matching<N, E, Ix>(graph: &Graph<N, E, Ix>) -> MaximumMatching<N>
405where
406 N: Node + Clone + std::fmt::Debug,
407 E: EdgeWeight,
408 Ix: IndexType,
409{
410 let mut matching = Vec::new();
411 let mut matched_nodes = HashSet::new();
412
413 let edges = graph.edges();
415
416 for edge in edges {
418 if !matched_nodes.contains(&edge.source) && !matched_nodes.contains(&edge.target) {
419 matching.push((edge.source.clone(), edge.target.clone()));
420 matched_nodes.insert(edge.source);
421 matched_nodes.insert(edge.target);
422 }
423 }
424
425 MaximumMatching {
426 size: matching.len(),
427 matching,
428 }
429}
430
431#[allow(dead_code)]
443pub fn stable_marriage(
444 left_prefs: &[Vec<usize>],
445 right_prefs: &[Vec<usize>],
446) -> Result<Vec<(usize, usize)>> {
447 let n = left_prefs.len();
448
449 if n != right_prefs.len() {
450 return Err(GraphError::InvalidGraph(
451 "Left and right sets must have equal size".to_string(),
452 ));
453 }
454
455 if n == 0 {
456 return Ok(Vec::new());
457 }
458
459 for (i, prefs) in left_prefs.iter().enumerate() {
461 if prefs.len() != n {
462 return Err(GraphError::InvalidGraph(format!(
463 "Left preference list {i} has wrong length"
464 )));
465 }
466 let mut sorted_prefs = prefs.clone();
467 sorted_prefs.sort_unstable();
468 if sorted_prefs != (0..n).collect::<Vec<_>>() {
469 return Err(GraphError::InvalidGraph(format!(
470 "Left preference list {i} is not a valid permutation"
471 )));
472 }
473 }
474
475 for (i, prefs) in right_prefs.iter().enumerate() {
476 if prefs.len() != n {
477 return Err(GraphError::InvalidGraph(format!(
478 "Right preference list {i} has wrong length"
479 )));
480 }
481 let mut sorted_prefs = prefs.clone();
482 sorted_prefs.sort_unstable();
483 if sorted_prefs != (0..n).collect::<Vec<_>>() {
484 return Err(GraphError::InvalidGraph(format!(
485 "Right preference list {i} is not a valid permutation"
486 )));
487 }
488 }
489
490 let mut right_inv_prefs = vec![vec![0; n]; n];
492 for (i, prefs) in right_prefs.iter().enumerate() {
493 for (rank, &person) in prefs.iter().enumerate() {
494 right_inv_prefs[i][person] = rank;
495 }
496 }
497
498 let mut left_partner = vec![None; n];
500 let mut right_partner = vec![None; n];
501 let mut left_next_proposal = vec![0; n];
502 let mut free_left: std::collections::VecDeque<usize> = (0..n).collect();
503
504 while let Some(left) = free_left.pop_front() {
505 if left_next_proposal[left] >= n {
506 continue; }
508
509 let right = left_prefs[left][left_next_proposal[left]];
510 left_next_proposal[left] += 1;
511
512 match right_partner[right] {
513 None => {
514 left_partner[left] = Some(right);
516 right_partner[right] = Some(left);
517 }
518 Some(current_left) => {
519 if right_inv_prefs[right][left] < right_inv_prefs[right][current_left] {
521 left_partner[left] = Some(right);
523 right_partner[right] = Some(left);
524 left_partner[current_left] = None;
525 free_left.push_back(current_left);
526 } else {
527 free_left.push_back(left);
529 }
530 }
531 }
532 }
533
534 let mut result = Vec::new();
536 for (left, partner) in left_partner.iter().enumerate() {
537 if let Some(right) = partner {
538 result.push((left, *right));
539 }
540 }
541
542 Ok(result)
543}
544
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Result as GraphResult;
    use crate::generators::create_graph;

    /// Panics if any node occurs in more than one of the given pairs.
    fn assert_pairs_disjoint<T>(pairs: &[(T, T)])
    where
        T: std::hash::Hash + Eq,
    {
        let mut seen = HashSet::new();
        for (first, second) in pairs {
            assert!(!seen.contains(first));
            assert!(!seen.contains(second));
            seen.insert(first);
            seen.insert(second);
        }
    }

    #[test]
    fn test_maximum_bipartite_matching() -> GraphResult<()> {
        let mut graph = create_graph::<&str, ()>();
        for (left, right) in [("A", "1"), ("A", "2"), ("B", "2"), ("B", "3"), ("C", "3")] {
            graph.add_edge(left, right, ())?;
        }

        let coloring: HashMap<&str, u8> = ["A", "B", "C"]
            .iter()
            .map(|&name| (name, 0))
            .chain(["1", "2", "3"].iter().map(|&name| (name, 1)))
            .collect();

        let result = maximum_bipartite_matching(&graph, &coloring);
        assert_eq!(result.size, 3);

        // Every right node is used at most once.
        let distinct_right: HashSet<_> = result.matching.values().collect();
        assert_eq!(distinct_right.len(), result.matching.len());

        Ok(())
    }

    #[test]
    fn test_minimum_weight_bipartite_matching() -> GraphResult<()> {
        let mut graph = create_graph::<&str, f64>();
        for (left, right, weight) in [
            ("A", "1", 1.0),
            ("A", "2", 3.0),
            ("B", "1", 2.0),
            ("B", "2", 1.0),
        ] {
            graph.add_edge(left, right, weight)?;
        }

        let (total_weight, pairs) = minimum_weight_bipartite_matching(&graph)?;
        assert_eq!(total_weight, 2.0);
        assert_eq!(pairs.len(), 2);

        Ok(())
    }

    #[test]
    fn test_maximum_cardinality_matching() {
        let mut graph = create_graph::<&str, ()>();
        for (u, v) in [("A", "B"), ("C", "D"), ("E", "F")] {
            graph.add_edge(u, v, ()).unwrap();
        }

        let result = maximum_cardinality_matching(&graph);
        assert_eq!(result.size, 3);
        assert_eq!(result.matching.len(), 3);
        assert_pairs_disjoint(&result.matching);
    }

    #[test]
    fn test_maximal_matching() {
        let mut graph = create_graph::<i32, ()>();
        for (u, v) in [(1, 2), (2, 3), (3, 1)] {
            graph.add_edge(u, v, ()).unwrap();
        }

        let result = maximal_matching(&graph);
        assert_eq!(result.size, 1);
        assert_eq!(result.matching.len(), 1);
        assert_pairs_disjoint(&result.matching);
    }

    #[test]
    fn test_stable_marriage() -> GraphResult<()> {
        let left_prefs = vec![vec![0, 1, 2], vec![1, 0, 2], vec![0, 1, 2]];
        let right_prefs = vec![vec![2, 1, 0], vec![0, 2, 1], vec![0, 1, 2]];

        let pairs = stable_marriage(&left_prefs, &right_prefs)?;
        assert_eq!(pairs.len(), 3);

        // Left and right indices share the range 0..n, so check each side separately.
        let lefts: HashSet<usize> = pairs.iter().map(|&(left, _)| left).collect();
        let rights: HashSet<usize> = pairs.iter().map(|&(_, right)| right).collect();
        assert_eq!(lefts.len(), pairs.len());
        assert_eq!(rights.len(), pairs.len());

        Ok(())
    }

    #[test]
    fn test_stable_marriage_empty() -> GraphResult<()> {
        let no_prefs: Vec<Vec<usize>> = Vec::new();
        let pairs = stable_marriage(&no_prefs, &no_prefs)?;
        assert_eq!(pairs.len(), 0);

        Ok(())
    }

    #[test]
    fn test_stable_marriage_invalid_input() {
        // Mismatched side sizes.
        assert!(stable_marriage(&[vec![0]], &[vec![0], vec![1]]).is_err());

        // Preference list of the wrong length.
        assert!(stable_marriage(&[vec![0, 0]], &[vec![0, 1]]).is_err());
    }
}