Skip to main content

rust_igraph/algorithms/
matching.rs

// The algorithms below mix `usize` indices, `u32` vertex/edge ids, `i64`
// match/label values (with `-1` meaning "unmatched") and `f64` weights, so
// numeric `as` casts are pervasive; short graph-theory names (`u`, `v`, `w`)
// and the long Hungarian routine are likewise deliberate.
#![allow(
    clippy::cast_possible_truncation,
    clippy::cast_possible_wrap,
    clippy::cast_precision_loss,
    clippy::cast_sign_loss,
    clippy::many_single_char_names,
    clippy::similar_names,
    clippy::too_many_lines,
    clippy::module_name_repetitions
)]

//! Matching algorithms, focused on bipartite graphs.
//!
//! Matchings are represented as one entry per vertex: `Some(partner)` when the
//! vertex is matched, `None` otherwise.
//!
//! Provides:
//! - [`is_matching`] — validate a matching vector against a graph, optionally
//!   against a bipartite partition
//! - [`is_maximal_matching`] — check whether a valid matching is maximal
//! - [`maximum_bipartite_matching`] — maximum cardinality bipartite matching
//!   (push-relabel, unweighted)
//! - [`maximum_bipartite_matching_weighted`] — maximum weight bipartite
//!   matching (Hungarian / Kuhn-Munkres)
//!
//! Reference: `igraph/src/misc/matching.c` (1013 lines).
23
24use std::collections::VecDeque;
25
26use crate::core::error::{IgraphError, IgraphResult};
27use crate::core::graph::Graph;
28
/// Result of [`maximum_bipartite_matching`] or
/// [`maximum_bipartite_matching_weighted`].
///
/// `MatchingResult::default()` is the empty matching: size 0, weight 0 and an
/// empty per-vertex vector.
///
/// ```
/// use rust_igraph::{create, maximum_bipartite_matching, MatchingResult};
///
/// // K_{2,2}: 0-2, 0-3, 1-2, 1-3
/// let g = create(&[(0, 2), (0, 3), (1, 2), (1, 3)], 4, false).unwrap();
/// let types = vec![false, false, true, true];
/// let r: MatchingResult = maximum_bipartite_matching(&g, &types).unwrap();
/// assert_eq!(r.matching_size, 2);
/// ```
#[derive(Debug, Clone, PartialEq, Default)]
pub struct MatchingResult {
    /// Number of matched vertex pairs.
    pub matching_size: usize,
    /// Total weight of the matched edges; equals `matching_size` for the
    /// unweighted algorithm.
    pub matching_weight: f64,
    /// Per-vertex match: `matching[v]` is the partner of `v`, or `None` if `v`
    /// is unmatched.
    pub matching: Vec<Option<u32>>,
}
55
56/// Check whether `matching` is a valid matching for `graph`.
57///
58/// A matching vector has length `vcount`; entry `i` is `Some(j)` when vertex
59/// `i` is matched to `j`, or `None` if unmatched. The function verifies:
60/// 1. Length equals `vcount`.
61/// 2. Matched pairs are mutual (`matching[i] == Some(j)` ⟹ `matching[j] == Some(i)`).
62/// 3. Every matched pair is connected by an edge (ignoring direction).
63/// 4. If `types` is provided, matched vertices have different types.
64///
65/// ```
66/// use rust_igraph::{create, is_matching};
67///
68/// let g = create(&[(0, 1), (1, 2)], 3, false).unwrap();
69/// let m = vec![Some(1), Some(0), None];
70/// assert!(is_matching(&g, None, &m).unwrap());
71/// ```
72pub fn is_matching(
73    graph: &Graph,
74    types: Option<&[bool]>,
75    matching: &[Option<u32>],
76) -> IgraphResult<bool> {
77    let n = graph.vcount() as usize;
78    if matching.len() != n {
79        return Ok(false);
80    }
81
82    let adj = build_undirected_adj(graph);
83
84    for (i, &mi) in matching.iter().enumerate() {
85        let Some(j) = mi else { continue };
86        let j_usize = j as usize;
87        if j_usize >= n {
88            return Ok(false);
89        }
90        if matching[j_usize] != Some(i as u32) {
91            return Ok(false);
92        }
93        if !adj[i].contains(&j) {
94            return Ok(false);
95        }
96    }
97
98    if let Some(t) = types {
99        if t.len() < n {
100            return Err(IgraphError::InvalidArgument(
101                "types vector too short".into(),
102            ));
103        }
104        for (i, &mi) in matching.iter().enumerate() {
105            let Some(j) = mi else { continue };
106            if t[i] == t[j as usize] {
107                return Ok(false);
108            }
109        }
110    }
111
112    Ok(true)
113}
114
115/// Check whether `matching` is a *maximal* matching for `graph`.
116///
117/// A matching is maximal if no unmatched vertex has an unmatched neighbor
118/// (respecting bipartite types if given).
119///
120/// ```
121/// use rust_igraph::{create, is_maximal_matching};
122///
123/// let g = create(&[(0, 1), (1, 2)], 3, false).unwrap();
124/// // Only 0-1 matched; vertex 2 is unmatched but has no unmatched neighbor → maximal
125/// let m = vec![Some(1), Some(0), None];
126/// assert!(is_maximal_matching(&g, None, &m).unwrap());
127/// ```
128pub fn is_maximal_matching(
129    graph: &Graph,
130    types: Option<&[bool]>,
131    matching: &[Option<u32>],
132) -> IgraphResult<bool> {
133    if !is_matching(graph, types, matching)? {
134        return Ok(false);
135    }
136
137    let n = graph.vcount() as usize;
138    let adj = build_undirected_adj(graph);
139
140    for i in 0..n {
141        if matching[i].is_some() {
142            continue;
143        }
144        for &nb in &adj[i] {
145            if matching[nb as usize].is_none() {
146                if let Some(t) = types {
147                    if t[i] == t[nb as usize] {
148                        continue;
149                    }
150                }
151                return Ok(false);
152            }
153        }
154    }
155
156    Ok(true)
157}
158
159/// Compute a maximum cardinality matching in an unweighted bipartite graph.
160///
161/// Uses a push-relabel algorithm with greedy initialization and global
162/// relabeling every `n/2` steps. Returns a [`MatchingResult`] where
163/// `matching_weight == matching_size` (since all edges have unit weight).
164///
165/// `types` must be a bipartite partition: `types[v]` is `false` for one side
166/// and `true` for the other. The function validates that every edge connects
167/// vertices of different types.
168///
169/// ```
170/// use rust_igraph::{create, maximum_bipartite_matching};
171///
172/// let g = create(&[(0, 2), (0, 3), (1, 2), (1, 3)], 4, false).unwrap();
173/// let types = vec![false, false, true, true];
174/// let r = maximum_bipartite_matching(&g, &types).unwrap();
175/// assert_eq!(r.matching_size, 2);
176/// ```
177pub fn maximum_bipartite_matching(graph: &Graph, types: &[bool]) -> IgraphResult<MatchingResult> {
178    let n = graph.vcount() as usize;
179    if types.len() < n {
180        return Err(IgraphError::InvalidArgument(
181            "types vector too short".into(),
182        ));
183    }
184
185    let adj = build_undirected_adj(graph);
186    let (matching, num_matched) = push_relabel_unweighted(graph, &adj, types, n)?;
187
188    Ok(MatchingResult {
189        matching_size: num_matched,
190        matching_weight: num_matched as f64,
191        matching,
192    })
193}
194
195/// Compute a maximum weight matching in a weighted bipartite graph.
196///
197/// Uses the Hungarian algorithm (Kuhn-Munkres) with push-relabel
198/// initialization on tight edges. `weights[e]` gives the weight of edge `e`
199/// (by edge id). `eps` controls floating-point tolerance for "tight" edge
200/// detection; pass `0.0` for integer weights.
201///
202/// ```
203/// use rust_igraph::{create, maximum_bipartite_matching_weighted};
204///
205/// let g = create(&[(0, 2), (0, 3), (1, 2), (1, 3)], 4, false).unwrap();
206/// let types = vec![false, false, true, true];
207/// let weights = vec![1.0, 10.0, 10.0, 1.0];
208/// let r = maximum_bipartite_matching_weighted(&g, &types, &weights, 0.0).unwrap();
209/// assert_eq!(r.matching_size, 2);
210/// assert!((r.matching_weight - 20.0).abs() < 1e-9);
211/// ```
212pub fn maximum_bipartite_matching_weighted(
213    graph: &Graph,
214    types: &[bool],
215    weights: &[f64],
216    eps: f64,
217) -> IgraphResult<MatchingResult> {
218    let n = graph.vcount() as usize;
219    let ne = graph.ecount();
220    if types.len() < n {
221        return Err(IgraphError::InvalidArgument(
222            "types vector too short".into(),
223        ));
224    }
225    if weights.len() < ne {
226        return Err(IgraphError::InvalidArgument(
227            "weights vector too short".into(),
228        ));
229    }
230    let eps = if eps < 0.0 { 0.0 } else { eps };
231
232    hungarian(graph, types, weights, eps, n, ne)
233}
234
235// ── helpers ──────────────────────────────────────────────────────────
236
237fn build_undirected_adj(graph: &Graph) -> Vec<Vec<u32>> {
238    let n = graph.vcount() as usize;
239    let mut adj: Vec<Vec<u32>> = vec![Vec::new(); n];
240    for eid in 0..graph.ecount() {
241        if let Ok((u, v)) = graph.edge(eid as u32) {
242            adj[u as usize].push(v);
243            if u != v {
244                adj[v as usize].push(u);
245            }
246        }
247    }
248    adj
249}
250
251// ── push-relabel unweighted ─────────────────────────────────────────
252
253fn push_relabel_unweighted(
254    _graph: &Graph,
255    adj: &[Vec<u32>],
256    types: &[bool],
257    n: usize,
258) -> IgraphResult<(Vec<Option<u32>>, usize)> {
259    let mut matching: Vec<i64> = vec![-1; n];
260    let mut labels: Vec<i64> = vec![0; n];
261
262    // Determine smaller set and greedy init
263    let count_true = types[..n].iter().filter(|&&t| t).count();
264    let smaller_set = count_true <= n / 2;
265
266    let mut num_matched: usize = 0;
267    for i in 0..n {
268        if matching[i] != -1 {
269            continue;
270        }
271        for &nb in &adj[i] {
272            let nb_usize = nb as usize;
273            if types[nb_usize] == types[i] {
274                return Err(IgraphError::InvalidArgument(
275                    "Graph is not bipartite with supplied types vector".into(),
276                ));
277            }
278            if matching[nb_usize] == -1 {
279                matching[nb_usize] = i64::from(i as u32);
280                matching[i] = i64::from(nb);
281                num_matched += 1;
282                break;
283            }
284        }
285    }
286
287    // Global relabeling
288    global_relabel(adj, &mut labels, &matching, types, smaller_set, n);
289
290    // Fill push queue with unmatched vertices from smaller set
291    let mut q: VecDeque<usize> = VecDeque::new();
292    for i in 0..n {
293        if matching[i] == -1 && types[i] == smaller_set {
294            q.push_back(i);
295        }
296    }
297
298    let relabeling_freq = (n / 2).max(1);
299    let mut label_changed: usize = 0;
300
301    while let Some(v) = q.pop_front() {
302        if label_changed >= relabeling_freq {
303            global_relabel(adj, &mut labels, &matching, types, smaller_set, n);
304            label_changed = 0;
305        }
306
307        let mut best_u: i64 = -1;
308        let mut best_label: i64 = 2 * n as i64;
309
310        for &nb in &adj[v] {
311            let nb_usize = nb as usize;
312            if labels[nb_usize] < best_label {
313                best_u = i64::from(nb);
314                best_label = labels[nb_usize];
315                label_changed += 1;
316            }
317        }
318
319        if best_label < n as i64 {
320            let u = best_u as usize;
321            labels[v] = labels[u] + 1;
322            if matching[u] != -1 {
323                let w = matching[u] as usize;
324                if w != v {
325                    matching[u] = -1;
326                    matching[w] = -1;
327                    q.push_back(w);
328                    num_matched -= 1;
329                }
330            }
331            matching[u] = v as i64;
332            matching[v] = u as i64;
333            num_matched += 1;
334            labels[u] += 2;
335            label_changed += 1;
336        }
337    }
338
339    let result: Vec<Option<u32>> = matching
340        .iter()
341        .map(|&m| if m < 0 { None } else { Some(m as u32) })
342        .collect();
343
344    Ok((result, num_matched))
345}
346
/// Recompute push-relabel labels as alternating BFS distances.
///
/// All labels are first reset to `n`, which means "unreached". The BFS roots
/// are free vertices whose type differs from `smaller_set`; they get label 0.
/// When an unreached vertex `w` is found next to `v`, it gets
/// `labels[v] + 1`. If `w` is matched and its partner is still unreached, the
/// partner gets `labels[w] + 1` and is enqueued. Only partners are enqueued,
/// so the search alternates between edge steps and matching steps.
///
/// `matching` uses `-1` for free vertices.
fn global_relabel(
    adj: &[Vec<u32>],
    labels: &mut [i64],
    matching: &[i64],
    types: &[bool],
    smaller_set: bool,
    n: usize,
) {
    let unreached = n as i64;
    labels.fill(unreached);

    let mut frontier: VecDeque<usize> = VecDeque::new();
    for root in (0..n).filter(|&i| types[i] != smaller_set && matching[i] == -1) {
        labels[root] = 0;
        frontier.push_back(root);
    }

    while let Some(v) = frontier.pop_front() {
        for w in adj[v].iter().map(|&nb| nb as usize) {
            if labels[w] != unreached {
                continue;
            }
            labels[w] = labels[v] + 1;

            let partner = matching[w];
            if partner == -1 {
                continue;
            }
            let partner = partner as usize;
            if labels[partner] == unreached {
                labels[partner] = labels[w] + 1;
                frontier.push_back(partner);
            }
        }
    }
}
382
383// ── Hungarian (weighted) ────────────────────────────────────────────
384
385fn hungarian(
386    graph: &Graph,
387    types: &[bool],
388    weights: &[f64],
389    eps: f64,
390    n: usize,
391    ne: usize,
392) -> IgraphResult<MatchingResult> {
393    // Build incidence list: for each vertex, list of (edge_id, other_vertex)
394    let mut incidence: Vec<Vec<(u32, u32)>> = vec![Vec::new(); n];
395    let mut edges: Vec<(u32, u32)> = Vec::with_capacity(ne);
396    for eid in 0..ne {
397        let (u, v) = graph.edge(eid as u32)?;
398        edges.push((u, v));
399        incidence[u as usize].push((eid as u32, v));
400        if u != v {
401            incidence[v as usize].push((eid as u32, u));
402        }
403    }
404
405    // Find smaller and larger sets
406    let count_false = types[..n].iter().filter(|&&t| !t).count();
407    let smaller_set_type = count_false > n / 2;
408    let smaller_set_size = if smaller_set_type {
409        n - count_false
410    } else {
411        count_false
412    };
413
414    let mut smaller_set: Vec<usize> = Vec::with_capacity(smaller_set_size);
415    let mut larger_set: Vec<usize> = Vec::with_capacity(n - smaller_set_size);
416    for (i, &tp) in types[..n].iter().enumerate() {
417        if tp == smaller_set_type {
418            smaller_set.push(i);
419        } else {
420            larger_set.push(i);
421        }
422    }
423
424    // Initial labeling: for each vertex in smaller set, label = max incident weight
425    let mut labels: Vec<f64> = vec![0.0; n];
426    for (i, &tp) in types[..n].iter().enumerate() {
427        if tp != smaller_set_type {
428            continue;
429        }
430        let mut max_w: f64 = 0.0;
431        for &(eid, other) in &incidence[i] {
432            if types[other as usize] == types[i] {
433                return Err(IgraphError::InvalidArgument(
434                    "Graph is not bipartite with supplied types vector".into(),
435                ));
436            }
437            if weights[eid as usize] > max_w {
438                max_w = weights[eid as usize];
439            }
440        }
441        labels[i] = max_w;
442    }
443
444    // Compute initial slack and tight edges
445    let mut slack: Vec<f64> = vec![0.0; ne];
446    let mut tight_edges: Vec<(u32, u32)> = Vec::new();
447    for eid in 0..ne {
448        let (u, v) = edges[eid];
449        slack[eid] = labels[u as usize] + labels[v as usize] - weights[eid];
450        if slack[eid] <= eps {
451            tight_edges.push((u, v));
452        }
453    }
454
455    // Build initial matching on tight edges using push-relabel
456    let tight_graph = crate::algorithms::constructors::create::create(
457        &tight_edges.iter().map(|&(a, b)| (a, b)).collect::<Vec<_>>(),
458        n as u32,
459        false,
460    )?;
461    let tight_adj = build_undirected_adj(&tight_graph);
462    let (init_match_opt, mut msize) = push_relabel_unweighted(&tight_graph, &tight_adj, types, n)?;
463    let mut matching: Vec<i64> = init_match_opt
464        .iter()
465        .map(|o| match o {
466            Some(v) => i64::from(*v),
467            None => -1,
468        })
469        .collect();
470
471    // Tight phantom edges adjacency (sorted for binary search)
472    let mut tight_phantom: Vec<Vec<usize>> = vec![Vec::new(); n];
473
474    // Main Hungarian loop
475    while msize < smaller_set_size {
476        let mut parent: Vec<i64> = vec![-1; n];
477        let mut reachable_smaller: Vec<usize> = Vec::new();
478        let mut reachable_larger: Vec<usize> = Vec::new();
479
480        // Fill queue with unmatched vertices from smaller set
481        let mut q: VecDeque<usize> = VecDeque::new();
482        for &s in &smaller_set {
483            if matching[s] == -1 {
484                q.push_back(s);
485                parent[s] = s as i64;
486                reachable_smaller.push(s);
487            }
488        }
489
490        // BFS along tight edges
491        let mut alternating_path_endpoint: i64 = -1;
492        'bfs: while let Some(v) = q.pop_front() {
493            // Real tight edges
494            for &(eid, other) in &incidence[v] {
495                let u = other as usize;
496                if slack[eid as usize] > eps {
497                    continue;
498                }
499                if parent[u] >= 0 {
500                    continue;
501                }
502                parent[u] = v as i64;
503                reachable_larger.push(u);
504                let w = matching[u];
505                if w == -1 {
506                    alternating_path_endpoint = u as i64;
507                    break 'bfs;
508                }
509                let w_usize = w as usize;
510                q.push_back(w_usize);
511                parent[w_usize] = u as i64;
512                reachable_smaller.push(w_usize);
513            }
514
515            // Tight phantom edges
516            for &u in &tight_phantom[v] {
517                if parent[u] >= 0 {
518                    continue;
519                }
520                if (labels[v] + labels[u]).abs() > eps {
521                    continue;
522                }
523                parent[u] = v as i64;
524                reachable_larger.push(u);
525                let w = matching[u];
526                if w == -1 {
527                    alternating_path_endpoint = u as i64;
528                    break 'bfs;
529                }
530                let w_usize = w as usize;
531                q.push_back(w_usize);
532                parent[w_usize] = u as i64;
533                reachable_smaller.push(w_usize);
534            }
535        }
536
537        if alternating_path_endpoint != -1 {
538            // Augment along alternating path
539            let mut v = alternating_path_endpoint as usize;
540            let mut u = parent[v] as usize;
541            while u != v {
542                let w = matching[v];
543                if w != -1 {
544                    matching[w as usize] = -1;
545                }
546                matching[v] = u as i64;
547                let w2 = matching[u];
548                if w2 != -1 {
549                    matching[w2 as usize] = -1;
550                }
551                matching[u] = v as i64;
552
553                v = parent[u] as usize;
554                u = parent[v] as usize;
555            }
556            msize += 1;
557            continue;
558        }
559
560        // No augmenting path found — update labels
561        // Find minimum slack between reachable smaller-set and unreachable larger-set
562
563        // Upper bound from phantom edges
564        let mut min_label_larger = f64::INFINITY;
565        for &l in &larger_set {
566            if labels[l] < min_label_larger {
567                min_label_larger = labels[l];
568            }
569        }
570        let mut min_label_reachable_smaller = f64::INFINITY;
571        for &s in &reachable_smaller {
572            if parent[s] >= 0 && labels[s] < min_label_reachable_smaller {
573                min_label_reachable_smaller = labels[s];
574            }
575        }
576        let mut min_slack = min_label_larger + min_label_reachable_smaller;
577
578        // Check real edges
579        for &u in &reachable_smaller {
580            for &(eid, other) in &incidence[u] {
581                let v_node = other as usize;
582                if parent[v_node] >= 0 {
583                    continue;
584                }
585                if slack[eid as usize] < min_slack {
586                    min_slack = slack[eid as usize];
587                }
588            }
589        }
590
591        if min_slack > 0.0 {
592            // Update labels and slack
593            for &u in &reachable_smaller {
594                labels[u] -= min_slack;
595                for &(eid, _) in &incidence[u] {
596                    slack[eid as usize] -= min_slack;
597                }
598            }
599            for &u in &reachable_larger {
600                labels[u] += min_slack;
601                for &(eid, _) in &incidence[u] {
602                    slack[eid as usize] += min_slack;
603                }
604            }
605        }
606
607        // Update tight phantom edges
608        for &u in &smaller_set {
609            for &v in &larger_set {
610                if (labels[u] + labels[v]).abs() <= eps {
611                    let phantoms = &mut tight_phantom[u];
612                    match phantoms.binary_search(&v) {
613                        Ok(_) => {} // already present
614                        Err(pos) => phantoms.insert(pos, v),
615                    }
616                }
617            }
618        }
619    }
620
621    // Remove phantom matches
622    for &u in &smaller_set {
623        let v = matching[u];
624        if v != -1 {
625            let v_usize = v as usize;
626            if tight_phantom[u].binary_search(&v_usize).is_ok() {
627                // Check if this is a real edge or phantom
628                let is_real = incidence[u]
629                    .iter()
630                    .any(|&(_, other)| other as usize == v_usize);
631                if !is_real {
632                    matching[u] = -1;
633                    matching[v_usize] = -1;
634                    msize -= 1;
635                }
636            }
637        }
638    }
639
640    // Compute matching weight
641    let mut total_weight: f64 = 0.0;
642    for eid in 0..ne {
643        if slack[eid] <= eps {
644            let (u, v) = edges[eid];
645            if matching[u as usize] == i64::from(v) {
646                total_weight += weights[eid];
647            }
648        }
649    }
650
651    let result_matching: Vec<Option<u32>> = matching
652        .iter()
653        .map(|&m| if m < 0 { None } else { Some(m as u32) })
654        .collect();
655
656    Ok(MatchingResult {
657        matching_size: msize,
658        matching_weight: total_weight,
659        matching: result_matching,
660    })
661}
662
663#[cfg(test)]
664mod tests {
665    use super::*;
666    use crate::algorithms::constructors::create::create;
667
668    fn make_k22() -> (Graph, Vec<bool>) {
669        let g = create(&[(0, 2), (0, 3), (1, 2), (1, 3)], 4, false).expect("K22");
670        let types = vec![false, false, true, true];
671        (g, types)
672    }
673
674    // ── is_matching ─────────────────────────────────────────────
675
676    #[test]
677    fn is_matching_valid() {
678        let (g, types) = make_k22();
679        let m = vec![Some(2), Some(3), Some(0), Some(1)];
680        assert!(is_matching(&g, Some(&types), &m).expect("ok"));
681    }
682
683    #[test]
684    fn is_matching_wrong_length() {
685        let (g, _) = make_k22();
686        let m = vec![Some(1), Some(0)];
687        assert!(!is_matching(&g, None, &m).expect("ok"));
688    }
689
690    #[test]
691    fn is_matching_non_mutual() {
692        let (g, _) = make_k22();
693        let m = vec![Some(2), None, None, None];
694        assert!(!is_matching(&g, None, &m).expect("ok"));
695    }
696
697    #[test]
698    fn is_matching_no_edge() {
699        let g = create(&[(0, 1)], 3, false).expect("ok");
700        // vertices 0 and 2 are not connected
701        let m = vec![Some(2), None, Some(0)];
702        assert!(!is_matching(&g, None, &m).expect("ok"));
703    }
704
705    #[test]
706    fn is_matching_all_unmatched_with_types() {
707        let g = create(&[(0, 1), (1, 2), (2, 3)], 4, false).expect("ok");
708        let types = vec![false, true, false, true];
709        let m = vec![None, None, None, None];
710        assert!(is_matching(&g, Some(&types), &m).expect("ok"));
711    }
712
713    #[test]
714    fn is_matching_types_same_partition() {
715        let g = create(&[(0, 1), (1, 2)], 3, false).expect("ok");
716        let types = vec![false, false, true]; // 0 and 1 same type but matched
717        let m = vec![Some(1), Some(0), None];
718        assert!(!is_matching(&g, Some(&types), &m).expect("ok"));
719    }
720
721    // ── is_maximal_matching ─────────────────────────────────────
722
723    #[test]
724    fn is_maximal_matching_true() {
725        let (g, types) = make_k22();
726        let m = vec![Some(2), Some(3), Some(0), Some(1)];
727        assert!(is_maximal_matching(&g, Some(&types), &m).expect("ok"));
728    }
729
730    #[test]
731    fn is_maximal_matching_false() {
732        let (g, types) = make_k22();
733        // Only 0-2 matched, but 1 and 3 are both unmatched and connected → not maximal
734        let m = vec![Some(2), None, Some(0), None];
735        assert!(!is_maximal_matching(&g, Some(&types), &m).expect("ok"));
736    }
737
738    #[test]
739    fn is_maximal_all_unmatched_no_edges() {
740        let g = Graph::new(3, false).expect("ok");
741        let m = vec![None, None, None];
742        assert!(is_maximal_matching(&g, None, &m).expect("ok"));
743    }
744
745    // ── maximum_bipartite_matching ──────────────────────────────
746
747    #[test]
748    fn max_matching_k22() {
749        let (g, types) = make_k22();
750        let r = maximum_bipartite_matching(&g, &types).expect("ok");
751        assert_eq!(r.matching_size, 2);
752        assert!(is_maximal_matching(&g, Some(&types), &r.matching).expect("ok"));
753    }
754
755    #[test]
756    fn max_matching_empty() {
757        let g = Graph::new(0, false).expect("ok");
758        let types: Vec<bool> = vec![];
759        let r = maximum_bipartite_matching(&g, &types).expect("ok");
760        assert_eq!(r.matching_size, 0);
761    }
762
763    #[test]
764    fn max_matching_singleton() {
765        let g = Graph::new(1, false).expect("ok");
766        let types = vec![false];
767        let r = maximum_bipartite_matching(&g, &types).expect("ok");
768        assert_eq!(r.matching_size, 0);
769    }
770
771    #[test]
772    fn max_matching_path_4() {
773        // 0-1-2-3, bipartite with types [F,T,F,T]
774        let g = create(&[(0, 1), (1, 2), (2, 3)], 4, false).expect("ok");
775        let types = vec![false, true, false, true];
776        let r = maximum_bipartite_matching(&g, &types).expect("ok");
777        assert_eq!(r.matching_size, 2);
778        assert!(is_maximal_matching(&g, Some(&types), &r.matching).expect("ok"));
779    }
780
781    #[test]
782    fn max_matching_star() {
783        // Star: 0 connected to 1,2,3,4
784        let g = create(&[(0, 1), (0, 2), (0, 3), (0, 4)], 5, false).expect("ok");
785        let types = vec![false, true, true, true, true];
786        let r = maximum_bipartite_matching(&g, &types).expect("ok");
787        assert_eq!(r.matching_size, 1);
788        assert!(is_maximal_matching(&g, Some(&types), &r.matching).expect("ok"));
789    }
790
791    #[test]
792    fn max_matching_complete_bipartite_k33() {
793        let g = create(
794            &[
795                (0, 3),
796                (0, 4),
797                (0, 5),
798                (1, 3),
799                (1, 4),
800                (1, 5),
801                (2, 3),
802                (2, 4),
803                (2, 5),
804            ],
805            6,
806            false,
807        )
808        .expect("ok");
809        let types = vec![false, false, false, true, true, true];
810        let r = maximum_bipartite_matching(&g, &types).expect("ok");
811        assert_eq!(r.matching_size, 3);
812        assert!(is_maximal_matching(&g, Some(&types), &r.matching).expect("ok"));
813    }
814
815    #[test]
816    fn max_matching_not_bipartite_error() {
817        // Triangle: not bipartite
818        let g = create(&[(0, 1), (1, 2), (2, 0)], 3, false).expect("ok");
819        let types = vec![false, true, false]; // 0 and 2 connected but same type
820        let r = maximum_bipartite_matching(&g, &types);
821        assert!(r.is_err());
822    }
823
824    #[test]
825    fn max_matching_disconnected() {
826        // Two disconnected edges
827        let g = create(&[(0, 1), (2, 3)], 4, false).expect("ok");
828        let types = vec![false, true, false, true];
829        let r = maximum_bipartite_matching(&g, &types).expect("ok");
830        assert_eq!(r.matching_size, 2);
831    }
832
833    #[test]
834    fn max_matching_types_too_short() {
835        let g = create(&[(0, 1)], 2, false).expect("ok");
836        let types = vec![false];
837        let r = maximum_bipartite_matching(&g, &types);
838        assert!(r.is_err());
839    }
840
841    // ── maximum_bipartite_matching_weighted ──────────────────────
842
843    #[test]
844    fn weighted_matching_simple() {
845        // K2,2 with weights: prefer 0-3 and 1-2
846        let g = create(&[(0, 2), (0, 3), (1, 2), (1, 3)], 4, false).expect("ok");
847        let types = vec![false, false, true, true];
848        let weights = vec![1.0, 10.0, 10.0, 1.0];
849        let r = maximum_bipartite_matching_weighted(&g, &types, &weights, 0.0).expect("ok");
850        assert_eq!(r.matching_size, 2);
851        assert!((r.matching_weight - 20.0).abs() < 1e-9);
852    }
853
854    #[test]
855    fn weighted_matching_mit_notes() {
856        // Test graph from MIT lecture notes on matching
857        // 10 vertices: 0-4 in set A, 5-9 in set B
858        let g = create(
859            &[
860                (0, 6),
861                (0, 7),
862                (0, 8),
863                (0, 9),
864                (1, 5),
865                (1, 6),
866                (1, 7),
867                (1, 8),
868                (1, 9),
869                (2, 5),
870                (2, 6),
871                (2, 7),
872                (2, 8),
873                (2, 9),
874                (3, 5),
875                (3, 7),
876                (3, 9),
877                (4, 7),
878            ],
879            10,
880            false,
881        )
882        .expect("ok");
883        let types: Vec<bool> = (0..10).map(|i| i >= 5).collect();
884        let weights = vec![
885            2.0, 7.0, 2.0, 3.0, // edges from 0
886            1.0, 3.0, 9.0, 3.0, 3.0, // edges from 1
887            1.0, 3.0, 3.0, 1.0, 2.0, // edges from 2
888            4.0, 1.0, 2.0, // edges from 3
889            3.0, // edge from 4
890        ];
891        let r = maximum_bipartite_matching_weighted(&g, &types, &weights, 0.0).expect("ok");
892        assert_eq!(r.matching_size, 4);
893        assert!((r.matching_weight - 19.0).abs() < 1e-9);
894        assert!(is_maximal_matching(&g, Some(&types), &r.matching).expect("ok"));
895    }
896
897    #[test]
898    fn weighted_matching_generated_case1() {
899        let g = create(&[(0, 8), (2, 7), (3, 7), (3, 8), (4, 5), (4, 9)], 10, false).expect("ok");
900        let types: Vec<bool> = (0..10).map(|i| i >= 5).collect();
901        let weights = vec![8.0, 5.0, 9.0, 18.0, 20.0, 13.0];
902        let r = maximum_bipartite_matching_weighted(&g, &types, &weights, 0.0).expect("ok");
903        assert!((r.matching_weight - 43.0).abs() < 1e-9);
904    }
905
906    #[test]
907    fn weighted_matching_generated_case2() {
908        let g = create(&[(0, 5), (0, 6), (1, 7), (2, 5), (3, 5), (3, 9)], 10, false).expect("ok");
909        let types: Vec<bool> = (0..10).map(|i| i >= 5).collect();
910        let weights = vec![20.0, 4.0, 20.0, 3.0, 13.0, 1.0];
911        let r = maximum_bipartite_matching_weighted(&g, &types, &weights, 0.0).expect("ok");
912        assert!((r.matching_weight - 41.0).abs() < 1e-9);
913    }
914
915    #[test]
916    fn weighted_matching_empty() {
917        let g = Graph::new(0, false).expect("ok");
918        let r = maximum_bipartite_matching_weighted(&g, &[], &[], 0.0).expect("ok");
919        assert_eq!(r.matching_size, 0);
920    }
921
922    #[test]
923    fn weighted_matching_no_edges() {
924        let g = Graph::new(4, false).expect("ok");
925        let types = vec![false, false, true, true];
926        let r = maximum_bipartite_matching_weighted(&g, &types, &[], 0.0).expect("ok");
927        assert_eq!(r.matching_size, 0);
928    }
929
930    // ── proptest ────────────────────────────────────────────────
931
932    #[cfg(all(test, feature = "proptest-harness"))]
933    mod proptests {
934        use super::*;
935        use proptest::prelude::*;
936
937        fn arb_bipartite_graph(
938            max_a: u32,
939            max_b: u32,
940        ) -> impl Strategy<Value = (Graph, Vec<bool>)> {
941            (1..=max_a, 1..=max_b).prop_flat_map(move |(a, b)| {
942                let pool = (a as usize) * (b as usize);
943                let mask_len = pool.min(20);
944                proptest::collection::vec(proptest::bool::ANY, mask_len).prop_map(move |mask| {
945                    let n = a + b;
946                    let mut edges = Vec::new();
947                    for (idx, &present) in mask.iter().enumerate() {
948                        if present {
949                            let u = (idx as u32) / b;
950                            let v = a + (idx as u32) % b;
951                            edges.push((u, v));
952                        }
953                    }
954                    let g = create(&edges, n, false).expect("bipartite graph");
955                    let types: Vec<bool> = (0..n).map(|i| i >= a).collect();
956                    (g, types)
957                })
958            })
959        }
960
961        proptest! {
962            #[test]
963            fn matching_is_valid((g, types) in arb_bipartite_graph(6, 6)) {
964                let r = maximum_bipartite_matching(&g, &types).expect("ok");
965                prop_assert!(is_matching(&g, Some(&types), &r.matching).expect("ok"));
966                prop_assert!(is_maximal_matching(&g, Some(&types), &r.matching).expect("ok"));
967            }
968
969            #[test]
970            fn matching_size_leq_min_partition(
971                (g, types) in arb_bipartite_graph(6, 6)
972            ) {
973                let r = maximum_bipartite_matching(&g, &types).expect("ok");
974                let a_size = types.iter().filter(|&&t| !t).count();
975                let b_size = types.iter().filter(|&&t| t).count();
976                prop_assert!(r.matching_size <= a_size.min(b_size));
977            }
978
979            #[test]
980            fn matching_size_leq_ecount(
981                (g, types) in arb_bipartite_graph(6, 6)
982            ) {
983                let r = maximum_bipartite_matching(&g, &types).expect("ok");
984                prop_assert!(r.matching_size <= g.ecount());
985            }
986
987            #[test]
988            fn weighted_matching_is_valid((g, types) in arb_bipartite_graph(5, 5)) {
989                let ne = g.ecount();
990                let weights: Vec<f64> = (0..ne).map(|i| (i as f64) + 1.0).collect();
991                if ne > 0 {
992                    let r = maximum_bipartite_matching_weighted(&g, &types, &weights, 0.0).expect("ok");
993                    prop_assert!(is_matching(&g, Some(&types), &r.matching).expect("ok"));
994                }
995            }
996
997            #[test]
998            fn weighted_geq_unweighted_unit(
999                (g, types) in arb_bipartite_graph(5, 5)
1000            ) {
1001                let unw = maximum_bipartite_matching(&g, &types).expect("ok");
1002                let ne = g.ecount();
1003                let weights: Vec<f64> = vec![1.0; ne];
1004                if ne > 0 {
1005                    let w = maximum_bipartite_matching_weighted(&g, &types, &weights, 0.0).expect("ok");
1006                    // With unit weights, weighted should find same size or better
1007                    prop_assert!(w.matching_size >= unw.matching_size.saturating_sub(1),
1008                        "weighted: {}, unweighted: {}", w.matching_size, unw.matching_size);
1009                }
1010            }
1011        }
1012    }
1013}