scirs2_sparse/csgraph/
minimum_spanning_tree.rs

//! Minimum spanning tree algorithms for sparse graphs
//!
//! This module provides efficient implementations of minimum spanning tree (MST)
//! algorithms (Kruskal's and Prim's) for sparse matrices representing weighted,
//! undirected graphs.

6use super::{num_vertices, to_adjacency_list, validate_graph};
7use crate::csr_array::CsrArray;
8use crate::error::{SparseError, SparseResult};
9use crate::sparray::SparseArray;
10use ndarray::Array1;
11use num_traits::Float;
12use std::cmp::Ordering;
13use std::collections::BinaryHeap;
14use std::fmt::Debug;
15
16/// Edge representation for MST algorithms
17#[derive(Debug, Clone)]
18struct Edge<T>
19where
20    T: Float + PartialOrd,
21{
22    weight: T,
23    u: usize,
24    v: usize,
25}
26
27impl<T> PartialEq for Edge<T>
28where
29    T: Float + PartialOrd,
30{
31    fn eq(&self, other: &Self) -> bool {
32        self.weight == other.weight
33    }
34}
35
36impl<T> Eq for Edge<T> where T: Float + PartialOrd {}
37
38impl<T> PartialOrd for Edge<T>
39where
40    T: Float + PartialOrd,
41{
42    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
43        Some(self.cmp(other))
44    }
45}
46
47impl<T> Ord for Edge<T>
48where
49    T: Float + PartialOrd,
50{
51    fn cmp(&self, other: &Self) -> Ordering {
52        // Reverse ordering for min-heap behavior
53        other
54            .weight
55            .partial_cmp(&self.weight)
56            .unwrap_or(Ordering::Equal)
57    }
58}
59
/// Disjoint-set forest (union-find) with union by rank and path compression.
#[derive(Debug)]
struct UnionFind {
    /// `parent[i]` is the parent of `i` in its set's tree; a root is its own parent.
    parent: Vec<usize>,
    /// Rank of each tree (only meaningful at roots); used to keep trees shallow.
    rank: Vec<usize>,
}
66
67impl UnionFind {
68    fn new(n: usize) -> Self {
69        Self {
70            parent: (0..n).collect(),
71            rank: vec![0; n],
72        }
73    }
74
75    fn find(&mut self, x: usize) -> usize {
76        if self.parent[x] != x {
77            self.parent[x] = self.find(self.parent[x]); // Path compression
78        }
79        self.parent[x]
80    }
81
82    fn union(&mut self, x: usize, y: usize) -> bool {
83        let root_x = self.find(x);
84        let root_y = self.find(y);
85
86        if root_x == root_y {
87            return false; // Already in the same set
88        }
89
90        // Union by rank
91        match self.rank[root_x].cmp(&self.rank[root_y]) {
92            Ordering::Less => self.parent[root_x] = root_y,
93            Ordering::Greater => self.parent[root_y] = root_x,
94            Ordering::Equal => {
95                self.parent[root_y] = root_x;
96                self.rank[root_x] += 1;
97            }
98        }
99
100        true
101    }
102}
103
/// Which minimum spanning tree algorithm to run.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MSTAlgorithm {
    /// Kruskal's algorithm: sort the edges, then join components with a union-find.
    Kruskal,
    /// Prim's algorithm: grow one tree from a start vertex using a priority queue.
    Prim,
    /// Pick Kruskal or Prim from the graph's density.
    Auto,
}
114
115impl MSTAlgorithm {
116    #[allow(clippy::should_implement_trait)]
117    pub fn from_str(s: &str) -> SparseResult<Self> {
118        match s.to_lowercase().as_str() {
119            "kruskal" => Ok(Self::Kruskal),
120            "prim" => Ok(Self::Prim),
121            "auto" => Ok(Self::Auto),
122            _ => Err(SparseError::ValueError(format!(
123                "Unknown MST algorithm: {s}. Use 'kruskal', 'prim', or 'auto'"
124            ))),
125        }
126    }
127}
128
129/// Compute the minimum spanning tree of a graph
130///
131/// # Arguments
132///
133/// * `graph` - The graph as a sparse matrix (must be undirected and connected)
134/// * `algorithm` - MST algorithm to use
135/// * `return_tree` - Whether to return the MST as a sparse matrix
136///
137/// # Returns
138///
139/// A tuple containing:
140/// - Total weight of the MST
141/// - Optional MST as a sparse matrix (if requested)
142/// - Array of parent indices in the MST
143///
144/// # Examples
145///
146/// ```
147/// use scirs2_sparse::csgraph::minimum_spanning_tree;
148/// use scirs2_sparse::csr_array::CsrArray;
149///
150/// // Create a weighted symmetric graph
151/// let rows = vec![0, 0, 1, 1, 2, 2];
152/// let cols = vec![1, 2, 0, 2, 0, 1];
153/// let data = vec![2.0, 3.0, 2.0, 1.0, 3.0, 1.0];
154/// let graph = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
155///
156/// let (total_weight, mst, parents) = minimum_spanning_tree(&graph, "kruskal", true).unwrap();
157/// ```
158#[allow(dead_code)]
159pub fn minimum_spanning_tree<T, S>(
160    graph: &S,
161    algorithm: &str,
162    return_tree: bool,
163) -> SparseResult<(T, Option<CsrArray<T>>, Array1<isize>)>
164where
165    T: Float + Debug + Copy + 'static,
166    S: SparseArray<T>,
167{
168    validate_graph(graph, false)?; // Must be undirected
169    let n = num_vertices(graph);
170
171    if n == 0 {
172        return Err(SparseError::ValueError(
173            "Cannot compute MST of empty graph".to_string(),
174        ));
175    }
176
177    let mst_algorithm = MSTAlgorithm::from_str(algorithm)?;
178
179    let actual_algorithm = match mst_algorithm {
180        MSTAlgorithm::Auto => {
181            // For sparse graphs, Kruskal is often more efficient
182            // For dense graphs, Prim might be better
183            let nnz = graph.nnz();
184            if nnz <= n * n / 4 {
185                MSTAlgorithm::Kruskal
186            } else {
187                MSTAlgorithm::Prim
188            }
189        }
190        alg => alg,
191    };
192
193    match actual_algorithm {
194        MSTAlgorithm::Kruskal => kruskal_mst(graph, return_tree),
195        MSTAlgorithm::Prim => {
196            prim_mst(graph, 0, return_tree) // Start from vertex 0
197        }
198        MSTAlgorithm::Auto => unreachable!(),
199    }
200}
201
202/// Kruskal's algorithm for MST
203#[allow(dead_code)]
204pub fn kruskal_mst<T, S>(
205    graph: &S,
206    return_tree: bool,
207) -> SparseResult<(T, Option<CsrArray<T>>, Array1<isize>)>
208where
209    T: Float + Debug + Copy + 'static,
210    S: SparseArray<T>,
211{
212    let n = num_vertices(graph);
213    let (row_indices, col_indices, values) = graph.find();
214
215    // Create edges and sort them by weight
216    let mut edges = Vec::new();
217    for (i, (&u, &v)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
218        if u <= v && !values[i].is_zero() {
219            // Avoid duplicate edges for undirected graph
220            edges.push(Edge {
221                weight: values[i],
222                u,
223                v,
224            });
225        }
226    }
227
228    edges.sort_by(|a, b| a.weight.partial_cmp(&b.weight).unwrap_or(Ordering::Equal));
229
230    let mut union_find = UnionFind::new(n);
231    let mut mst_edges = Vec::new();
232    let mut total_weight = T::zero();
233    let mut parent = Array1::from_elem(n, -1isize);
234
235    for edge in edges {
236        if union_find.union(edge.u, edge.v) {
237            mst_edges.push(edge.clone());
238            total_weight = total_weight + edge.weight;
239
240            // Set parent relationship (arbitrary choice for undirected tree)
241            if parent[edge.v] == -1 {
242                parent[edge.v] = edge.u as isize;
243            } else if parent[edge.u] == -1 {
244                parent[edge.u] = edge.v as isize;
245            }
246
247            // MST has n-1 edges
248            if mst_edges.len() == n - 1 {
249                break;
250            }
251        }
252    }
253
254    // Check if graph is connected
255    if mst_edges.len() != n - 1 {
256        return Err(SparseError::ValueError(
257            "Graph is not connected - cannot compute spanning tree".to_string(),
258        ));
259    }
260
261    let mst_matrix = if return_tree {
262        Some(build_mst_matrix(&mst_edges, n)?)
263    } else {
264        None
265    };
266
267    Ok((total_weight, mst_matrix, parent))
268}
269
270/// Prim's algorithm for MST
271#[allow(dead_code)]
272pub fn prim_mst<T, S>(
273    graph: &S,
274    start: usize,
275    return_tree: bool,
276) -> SparseResult<(T, Option<CsrArray<T>>, Array1<isize>)>
277where
278    T: Float + Debug + Copy + 'static,
279    S: SparseArray<T>,
280{
281    let n = num_vertices(graph);
282    let adj_list = to_adjacency_list(graph, false)?; // Undirected
283
284    if start >= n {
285        return Err(SparseError::ValueError(format!(
286            "Start vertex {start} out of bounds for graph with {n} vertices"
287        )));
288    }
289
290    let mut in_mst = vec![false; n];
291    let mut min_weight = vec![T::infinity(); n];
292    let mut parent = Array1::from_elem(n, -1isize);
293    let mut total_weight = T::zero();
294    let mut mst_edges = Vec::new();
295
296    // Priority queue for edges (weight, vertex)
297    let mut heap = BinaryHeap::new();
298
299    // Start with the given vertex
300    min_weight[start] = T::zero();
301    heap.push(Edge {
302        weight: T::zero(),
303        u: start,
304        v: start,
305    });
306
307    while let Some(Edge { weight, u: _, v }) = heap.pop() {
308        if in_mst[v] {
309            continue;
310        }
311
312        in_mst[v] = true;
313        total_weight = total_weight + weight;
314
315        if weight > T::zero() {
316            // Add edge to MST (except for the first vertex)
317            mst_edges.push(Edge {
318                weight,
319                u: parent[v] as usize,
320                v,
321            });
322        }
323
324        // Update neighbors
325        for &(neighbor, edge_weight) in &adj_list[v] {
326            if !in_mst[neighbor] && edge_weight < min_weight[neighbor] {
327                min_weight[neighbor] = edge_weight;
328                parent[neighbor] = v as isize;
329
330                heap.push(Edge {
331                    weight: edge_weight,
332                    u: v,
333                    v: neighbor,
334                });
335            }
336        }
337    }
338
339    // Check if all vertices are reachable
340    let vertices_in_mst = in_mst.iter().filter(|&&x| x).count();
341    if vertices_in_mst != n {
342        return Err(SparseError::ValueError(
343            "Graph is not connected - cannot compute spanning tree".to_string(),
344        ));
345    }
346
347    let mst_matrix = if return_tree {
348        Some(build_mst_matrix(&mst_edges, n)?)
349    } else {
350        None
351    };
352
353    Ok((total_weight, mst_matrix, parent))
354}
355
356/// Build a sparse matrix representation of the MST from edges
357#[allow(dead_code)]
358fn build_mst_matrix<T>(edges: &[Edge<T>], n: usize) -> SparseResult<CsrArray<T>>
359where
360    T: Float + Debug + Copy + 'static,
361{
362    let mut rows = Vec::new();
363    let mut cols = Vec::new();
364    let mut values = Vec::new();
365
366    for edge in edges {
367        // Add both directions for undirected tree
368        rows.push(edge.u);
369        cols.push(edge.v);
370        values.push(edge.weight);
371
372        rows.push(edge.v);
373        cols.push(edge.u);
374        values.push(edge.weight);
375    }
376
377    CsrArray::from_triplets(&rows, &cols, &values, (n, n), false)
378}
379
380/// Check if a tree is a valid spanning tree of a graph
381///
382/// # Arguments
383///
384/// * `graph` - The original graph
385/// * `tree` - The potential spanning tree
386/// * `tol` - Tolerance for weight comparisons
387///
388/// # Returns
389///
390/// True if the tree is a valid spanning tree, false otherwise
391#[allow(dead_code)]
392pub fn is_spanning_tree<T, S1, S2>(graph: &S1, tree: &S2, tol: T) -> SparseResult<bool>
393where
394    T: Float + Debug + Copy + 'static,
395    S1: SparseArray<T>,
396    S2: SparseArray<T>,
397{
398    let n = num_vertices(graph);
399    let m = num_vertices(tree);
400
401    // Must have same number of vertices
402    if n != m {
403        return Ok(false);
404    }
405
406    // Tree must have exactly n-1 edges (counting each undirected edge once)
407    let tree_edges = tree.nnz() / 2; // Assuming undirected representation
408    if tree_edges != n - 1 {
409        return Ok(false);
410    }
411
412    // All edges in tree must exist in original graph with same weight
413    let (tree_rows, tree_cols, tree_values) = tree.find();
414
415    for (i, (&u, &v)) in tree_rows.iter().zip(tree_cols.iter()).enumerate() {
416        if u < v {
417            // Check each edge only once
418            let graph_weight = graph.get(u, v);
419            let tree_weight = tree_values[i];
420
421            if (graph_weight - tree_weight).abs() > tol {
422                return Ok(false);
423            }
424        }
425    }
426
427    // Check connectivity (tree should connect all vertices)
428    // This is implicitly checked by the n-1 edges condition for a tree
429
430    Ok(true)
431}
432
433/// Compute the weight of a spanning tree
434///
435/// # Arguments
436///
437/// * `tree` - The spanning tree as a sparse matrix
438///
439/// # Returns
440///
441/// Total weight of the spanning tree
442#[allow(dead_code)]
443pub fn spanning_tree_weight<T, S>(tree: &S) -> SparseResult<T>
444where
445    T: Float + Debug + Copy + 'static,
446    S: SparseArray<T>,
447{
448    let (row_indices, col_indices, values) = tree.find();
449    let mut total_weight = T::zero();
450
451    // Sum weights, counting each undirected edge only once
452    for (i, (&u, &v)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
453        if u <= v {
454            total_weight = total_weight + values[i];
455        }
456    }
457
458    Ok(total_weight)
459}
460
461/// Find all minimum spanning trees of a graph
462///
463/// # Note
464/// This is a computationally expensive operation for large graphs.
465/// It returns one MST and indicates if multiple MSTs exist.
466///
467/// # Arguments
468///
469/// * `graph` - The graph as a sparse matrix
470/// * `algorithm` - MST algorithm to use
471///
472/// # Returns
473///
474/// A tuple containing:
475/// - One minimum spanning tree
476/// - Boolean indicating if multiple MSTs exist
477/// - Total weight of any MST
478#[allow(dead_code)]
479pub fn all_minimum_spanning_trees<T, S>(
480    graph: &S,
481    algorithm: &str,
482) -> SparseResult<(CsrArray<T>, bool, T)>
483where
484    T: Float + Debug + Copy + 'static,
485    S: SparseArray<T>,
486{
487    let (total_weight, mst_, _) = minimum_spanning_tree(graph, algorithm, true)?;
488    let mst = mst_.unwrap();
489
490    // Simple heuristic: if there are edges with equal weights, multiple MSTs might exist
491    let (_, _, values) = graph.find();
492    let mut weights: Vec<_> = values.iter().copied().collect();
493    weights.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
494
495    let has_duplicates = weights
496        .windows(2)
497        .any(|w| (w[0] - w[1]).abs() < T::from(1e-10).unwrap());
498
499    Ok((mst, has_duplicates, total_weight))
500}
501
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;
    use approx::assert_relative_eq;

    /// Symmetric adjacency matrix of the 4-vertex graph below (MST weight 5):
    ///
    /// ```text
    ///      1
    ///   0 --- 1
    ///   |   / |
    ///  2|  /1 |3
    ///   | /   |
    ///   2 --- 3
    ///      4
    /// ```
    fn create_test_graph() -> CsrArray<f64> {
        // Row-major listing; every undirected edge appears as (u, v) and (v, u).
        let rows = vec![0, 0, 1, 1, 1, 2, 2, 2, 3, 3];
        let cols = vec![1, 2, 0, 2, 3, 0, 1, 3, 1, 2];
        let data = vec![1.0, 2.0, 1.0, 1.0, 3.0, 2.0, 1.0, 4.0, 3.0, 4.0];
        CsrArray::from_triplets(&rows, &cols, &data, (4, 4), false).unwrap()
    }

    #[test]
    fn test_union_find() {
        let mut sets = UnionFind::new(4);

        // Every element starts in its own singleton set.
        assert_ne!(sets.find(0), sets.find(1));
        assert_ne!(sets.find(1), sets.find(2));

        // Merging 0-1 and then 1-2 puts {0, 1, 2} into one set.
        assert!(sets.union(0, 1));
        assert_eq!(sets.find(0), sets.find(1));
        assert!(sets.union(1, 2));
        assert_eq!(sets.find(0), sets.find(2));

        // Merging two members of the same set reports that nothing changed.
        assert!(!sets.union(0, 2));
    }

    #[test]
    fn test_kruskal_mst() {
        let graph = create_test_graph();
        let (weight, tree, _) = kruskal_mst(&graph, true).unwrap();

        // Lightest spanning edges: 0-1 (1), 1-2 (1) and 1-3 (3).
        assert_relative_eq!(weight, 5.0);

        let tree = tree.unwrap();
        // 4 vertices -> 3 undirected edges -> 6 stored entries.
        assert_eq!(tree.nnz(), 6);
        assert_relative_eq!(spanning_tree_weight(&tree).unwrap(), weight);
        assert!(is_spanning_tree(&graph, &tree, 1e-10).unwrap());
    }

    #[test]
    fn test_prim_mst() {
        let graph = create_test_graph();
        let (weight, tree, _parents) = prim_mst(&graph, 0, true).unwrap();

        // Same total as Kruskal.
        assert_relative_eq!(weight, 5.0);

        let tree = tree.unwrap();
        assert_eq!(tree.nnz(), 6);
        assert!(is_spanning_tree(&graph, &tree, 1e-10).unwrap());
    }

    #[test]
    fn test_minimum_spanning_tree_api() {
        let graph = create_test_graph();

        for name in ["kruskal", "prim", "auto"] {
            let (weight, _, _) = minimum_spanning_tree(&graph, name, false).unwrap();
            assert_relative_eq!(weight, 5.0);
        }
    }

    #[test]
    fn test_disconnected_graph() {
        // Two separate components: {0, 1} and {2, 3}.
        let rows = vec![0, 1, 2, 3];
        let cols = vec![1, 0, 3, 2];
        let data = vec![1.0, 1.0, 1.0, 1.0];
        let graph = CsrArray::from_triplets(&rows, &cols, &data, (4, 4), false).unwrap();

        for name in ["kruskal", "prim"] {
            assert!(minimum_spanning_tree(&graph, name, false).is_err());
        }
    }

    #[test]
    fn test_single_vertex() {
        let graph: CsrArray<f64> = CsrArray::from_triplets(&[], &[], &[], (1, 1), false).unwrap();

        let (weight, tree, _) = minimum_spanning_tree(&graph, "kruskal", true).unwrap();
        assert_relative_eq!(weight, 0.0);

        // A one-vertex tree has no edges.
        assert_eq!(tree.unwrap().nnz(), 0);
    }

    #[test]
    fn test_two_vertices() {
        let rows = vec![0, 1];
        let cols = vec![1, 0];
        let data = vec![5.0, 5.0];
        let graph = CsrArray::from_triplets(&rows, &cols, &data, (2, 2), false).unwrap();

        let (weight, tree, _parents) = minimum_spanning_tree(&graph, "prim", true).unwrap();
        assert_relative_eq!(weight, 5.0);

        // One undirected edge, stored in both directions.
        assert_eq!(tree.unwrap().nnz(), 2);
    }

    #[test]
    fn test_complete_graph() {
        // Complete graph on 4 vertices, upper triangle as (u, v, weight).
        let upper = [
            (0, 1, 1.0),
            (0, 2, 4.0),
            (0, 3, 3.0),
            (1, 2, 2.0),
            (1, 3, 5.0),
            (2, 3, 6.0),
        ];

        // Upper-triangle entries first, then their mirror images.
        let rows: Vec<usize> = upper
            .iter()
            .map(|e| e.0)
            .chain(upper.iter().map(|e| e.1))
            .collect();
        let cols: Vec<usize> = upper
            .iter()
            .map(|e| e.1)
            .chain(upper.iter().map(|e| e.0))
            .collect();
        let data: Vec<f64> = upper.iter().chain(upper.iter()).map(|e| e.2).collect();

        let graph = CsrArray::from_triplets(&rows, &cols, &data, (4, 4), false).unwrap();
        let (weight, _, _) = minimum_spanning_tree(&graph, "kruskal", false).unwrap();

        // MST edges: 0-1 (1), 1-2 (2), 0-3 (3).
        assert_relative_eq!(weight, 6.0);
    }

    #[test]
    fn test_spanning_tree_validation() {
        let graph = create_test_graph();
        let (_, tree, _) = minimum_spanning_tree(&graph, "kruskal", true).unwrap();
        assert!(is_spanning_tree(&graph, &tree.unwrap(), 1e-10).unwrap());

        // A single edge cannot span 4 vertices.
        let rows = vec![0, 1];
        let cols = vec![1, 0];
        let data = vec![1.0, 1.0];
        let too_small = CsrArray::from_triplets(&rows, &cols, &data, (4, 4), false).unwrap();
        assert!(!is_spanning_tree(&graph, &too_small, 1e-10).unwrap());
    }

    #[test]
    fn test_algorithm_selection() {
        let cases = [
            ("kruskal", MSTAlgorithm::Kruskal),
            ("prim", MSTAlgorithm::Prim),
            ("auto", MSTAlgorithm::Auto),
        ];
        for (name, expected) in cases {
            assert_eq!(MSTAlgorithm::from_str(name).ok(), Some(expected));
        }

        assert!(MSTAlgorithm::from_str("invalid").is_err());
    }
}