uni_algo/algo/algorithms/mst.rs
1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Minimum Spanning Tree (MST) Algorithm.
5//!
//! Uses Kruskal's algorithm to find a minimum spanning tree of a weighted graph,
//! treating its edges as undirected. If the graph is disconnected, the result is a
//! minimum spanning forest (one tree per connected component). Unweighted
//! projections use a weight of 1.0 for every edge.
//! Returns the selected edges and their total weight.
9
10use crate::algo::GraphProjection;
11use crate::algo::algorithms::Algorithm;
12use uni_common::core::id::Vid;
13
/// Kruskal-based minimum spanning tree (or forest, for disconnected graphs)
/// algorithm over a [`GraphProjection`]. Configured by [`MstConfig`] and
/// producing an [`MstResult`].
pub struct MinimumSpanningTree;
15
/// Configuration for [`MinimumSpanningTree`].
///
/// There are currently no options. The type exists so that options can be
/// added later without changing the [`Algorithm`] interface. Construct it with
/// `MstConfig::default()`.
#[derive(Debug, Clone, Default)]
pub struct MstConfig {}
22
23pub struct MstResult {
24 pub edges: Vec<(Vid, Vid, f64)>, // (u, v, weight)
25 pub total_weight: f64,
26}
27
28impl Algorithm for MinimumSpanningTree {
29 type Config = MstConfig;
30 type Result = MstResult;
31
32 fn name() -> &'static str {
33 "mst"
34 }
35
36 fn run(graph: &GraphProjection, _config: Self::Config) -> Self::Result {
37 let n = graph.vertex_count();
38 if n == 0 {
39 return MstResult {
40 edges: Vec::new(),
41 total_weight: 0.0,
42 };
43 }
44
45 // Collect all edges
46 // If we want to treat graph as undirected, we should dedup (u, v) and (v, u).
47 // Standard approach: normalize (min, max).
48 let mut edges = Vec::new();
49 for u in 0..n as u32 {
50 for (i, &v) in graph.out_neighbors(u).iter().enumerate() {
51 // If treating as undirected, only add if u < v to avoid duplicates
52 // assuming symmetry. If not symmetric, Kruskal's on directed graph
53 // produces Minimum Spanning Forest (Arborescence is different).
54 // Let's assume undirected MST on the underlying graph structure.
55 if u < v {
56 let weight = if graph.has_weights() {
57 graph.out_weight(u, i)
58 } else {
59 1.0
60 };
61 edges.push((u, v, weight));
62 }
63 }
64 }
65
66 // Sort by weight
67 edges.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap());
68
69 // Union-Find
70 let mut parent: Vec<u32> = (0..n as u32).collect();
71 let mut rank: Vec<u8> = vec![0; n];
72
73 fn find(parent: &mut [u32], mut x: u32) -> u32 {
74 while parent[x as usize] != x {
75 parent[x as usize] = parent[parent[x as usize] as usize];
76 x = parent[x as usize];
77 }
78 x
79 }
80
81 fn union(parent: &mut [u32], rank: &mut [u8], x: u32, y: u32) -> bool {
82 let px = find(parent, x);
83 let py = find(parent, y);
84 if px == py {
85 return false;
86 }
87 match rank[px as usize].cmp(&rank[py as usize]) {
88 std::cmp::Ordering::Less => parent[px as usize] = py,
89 std::cmp::Ordering::Greater => parent[py as usize] = px,
90 std::cmp::Ordering::Equal => {
91 parent[py as usize] = px;
92 rank[px as usize] += 1;
93 }
94 }
95 true
96 }
97
98 let mut mst_edges = Vec::new();
99 let mut total_weight = 0.0;
100
101 for (u, v, w) in edges {
102 if union(&mut parent, &mut rank, u, v) {
103 mst_edges.push((graph.to_vid(u), graph.to_vid(v), w));
104 total_weight += w;
105 }
106 }
107
108 MstResult {
109 edges: mst_edges,
110 total_weight,
111 }
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algo::test_utils::build_test_graph;

    #[test]
    fn test_mst_simple() {
        // Triangle: 0-1 (1.0), 1-2 (2.0), 0-2 (10.0).
        // The MST keeps (0,1) and (1,2) for a total weight of 3.0.
        let vids = vec![Vid::from(0), Vid::from(1), Vid::from(2)];
        let edges = vec![
            (Vid::from(0), Vid::from(1)),
            (Vid::from(1), Vid::from(2)),
            (Vid::from(0), Vid::from(2)),
        ];
        let mut graph = build_test_graph(vids, edges);

        // `build_test_graph` does not take weights, so `out_weights` (visible
        // within the crate) is set directly. Weights are flattened per source
        // vertex, in the adjacency order produced from `edges`:
        //   node 0: out_neighbors [1, 2] -> [1.0, 10.0]
        //   node 1: out_neighbors [2]    -> [2.0]
        //   node 2: out_neighbors []     -> []
        graph.out_weights = Some(vec![1.0, 10.0, 2.0]);

        let result = MinimumSpanningTree::run(&graph, MstConfig::default());
        assert_eq!(result.total_weight, 3.0);
        assert_eq!(result.edges.len(), 2);
        // Edges are reported in the order Kruskal accepted them (ascending weight).
        let weights: Vec<f64> = result.edges.iter().map(|e| e.2).collect();
        assert_eq!(weights, vec![1.0, 2.0]);
    }

    #[test]
    fn test_mst_disconnected_forest() {
        // Two components {0,1} and {2,3}, unweighted: one edge per component.
        let vids = vec![Vid::from(0), Vid::from(1), Vid::from(2), Vid::from(3)];
        let edges = vec![
            (Vid::from(0), Vid::from(1)),
            (Vid::from(2), Vid::from(3)),
        ];
        let graph = build_test_graph(vids, edges);

        let result = MinimumSpanningTree::run(&graph, MstConfig::default());
        assert_eq!(result.edges.len(), 2);
        assert_eq!(result.total_weight, 2.0);
    }
}