Skip to main content

uni_algo/algo/algorithms/
mst.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Minimum Spanning Tree (MST) Algorithm.
5//!
6//! Uses Kruskal's algorithm to find the Minimum Spanning Tree of a weighted graph.
7//! Treating graph as undirected (if include_reverse is true, we dedup edges; if not, we treat directed as undirected structure).
8//! Returns the edges in the MST and total weight.
9
10use crate::algo::GraphProjection;
11use crate::algo::algorithms::Algorithm;
12use uni_common::core::id::Vid;
13
/// Minimum spanning tree (forest) algorithm based on Kruskal's method.
///
/// Stateless marker type: run it through the [`Algorithm`] trait with an
/// [`MstConfig`]; the output is an [`MstResult`].
#[derive(Debug, Clone, Copy, Default)]
pub struct MinimumSpanningTree;
15
/// Configuration for [`MinimumSpanningTree`].
///
/// Currently carries no options; it fills the [`Algorithm`] trait's `Config`
/// slot so options can presumably be added later without changing the `run`
/// signature. Construct it with `MstConfig::default()`.
#[derive(Debug, Clone, Default)]
pub struct MstConfig {}
22
23pub struct MstResult {
24    pub edges: Vec<(Vid, Vid, f64)>, // (u, v, weight)
25    pub total_weight: f64,
26}
27
28impl Algorithm for MinimumSpanningTree {
29    type Config = MstConfig;
30    type Result = MstResult;
31
32    fn name() -> &'static str {
33        "mst"
34    }
35
36    fn run(graph: &GraphProjection, _config: Self::Config) -> Self::Result {
37        let n = graph.vertex_count();
38        if n == 0 {
39            return MstResult {
40                edges: Vec::new(),
41                total_weight: 0.0,
42            };
43        }
44
45        // Collect all edges
46        // If we want to treat graph as undirected, we should dedup (u, v) and (v, u).
47        // Standard approach: normalize (min, max).
48        let mut edges = Vec::new();
49        for u in 0..n as u32 {
50            for (i, &v) in graph.out_neighbors(u).iter().enumerate() {
51                // If treating as undirected, only add if u < v to avoid duplicates
52                // assuming symmetry. If not symmetric, Kruskal's on directed graph
53                // produces Minimum Spanning Forest (Arborescence is different).
54                // Let's assume undirected MST on the underlying graph structure.
55                if u < v {
56                    let weight = if graph.has_weights() {
57                        graph.out_weight(u, i)
58                    } else {
59                        1.0
60                    };
61                    edges.push((u, v, weight));
62                }
63            }
64        }
65
66        // Sort by weight
67        edges.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap());
68
69        // Union-Find
70        let mut parent: Vec<u32> = (0..n as u32).collect();
71        let mut rank: Vec<u8> = vec![0; n];
72
73        fn find(parent: &mut [u32], mut x: u32) -> u32 {
74            while parent[x as usize] != x {
75                parent[x as usize] = parent[parent[x as usize] as usize];
76                x = parent[x as usize];
77            }
78            x
79        }
80
81        fn union(parent: &mut [u32], rank: &mut [u8], x: u32, y: u32) -> bool {
82            let px = find(parent, x);
83            let py = find(parent, y);
84            if px == py {
85                return false;
86            }
87            match rank[px as usize].cmp(&rank[py as usize]) {
88                std::cmp::Ordering::Less => parent[px as usize] = py,
89                std::cmp::Ordering::Greater => parent[py as usize] = px,
90                std::cmp::Ordering::Equal => {
91                    parent[py as usize] = px;
92                    rank[px as usize] += 1;
93                }
94            }
95            true
96        }
97
98        let mut mst_edges = Vec::new();
99        let mut total_weight = 0.0;
100
101        for (u, v, w) in edges {
102            if union(&mut parent, &mut rank, u, v) {
103                mst_edges.push((graph.to_vid(u), graph.to_vid(v), w));
104                total_weight += w;
105            }
106        }
107
108        MstResult {
109            edges: mst_edges,
110            total_weight,
111        }
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algo::test_utils::build_test_graph;

    #[test]
    fn test_mst_simple() {
        // Triangle 0-1 (1.0), 1-2 (2.0), 0-2 (10.0): the MST keeps 0-1 and 1-2,
        // for a total weight of 3.0.
        let vids = vec![Vid::from(0), Vid::from(1), Vid::from(2)];
        let edges = vec![
            (Vid::from(0), Vid::from(1)),
            (Vid::from(1), Vid::from(2)),
            (Vid::from(0), Vid::from(2)),
        ];
        let mut graph = build_test_graph(vids, edges);

        // build_test_graph has no weight support, so weights are injected
        // directly into the crate-visible projection. They are flattened per
        // source vertex in out-neighbor order, assumed to follow the order of
        // `edges` above:
        //   vertex 0: [->1, ->2] => [1.0, 10.0]
        //   vertex 1: [->2]      => [2.0]
        //   vertex 2: []
        graph.out_weights = Some(vec![1.0, 10.0, 2.0]);

        let result = MinimumSpanningTree::run(&graph, MstConfig::default());
        assert_eq!(result.total_weight, 3.0);
        assert_eq!(result.edges.len(), 2);
        // The expensive 0-2 edge must not be part of the tree.
        assert!(result.edges.iter().all(|e| e.2 != 10.0));
    }

    #[test]
    fn test_mst_empty_graph() {
        let graph = build_test_graph(Vec::new(), Vec::new());
        let result = MinimumSpanningTree::run(&graph, MstConfig::default());
        assert!(result.edges.is_empty());
        assert_eq!(result.total_weight, 0.0);
    }
}