1use super::{num_vertices, to_adjacency_list, validate_graph};
7use crate::csr_array::CsrArray;
8use crate::error::{SparseError, SparseResult};
9use crate::sparray::SparseArray;
10use ndarray::Array1;
11use num_traits::Float;
12use std::cmp::Ordering;
13use std::collections::BinaryHeap;
14use std::fmt::Debug;
15
16#[derive(Debug, Clone)]
18struct Edge<T>
19where
20 T: Float + PartialOrd,
21{
22 weight: T,
23 u: usize,
24 v: usize,
25}
26
27impl<T> PartialEq for Edge<T>
28where
29 T: Float + PartialOrd,
30{
31 fn eq(&self, other: &Self) -> bool {
32 self.weight == other.weight
33 }
34}
35
36impl<T> Eq for Edge<T> where T: Float + PartialOrd {}
37
38impl<T> PartialOrd for Edge<T>
39where
40 T: Float + PartialOrd,
41{
42 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
43 Some(self.cmp(other))
44 }
45}
46
47impl<T> Ord for Edge<T>
48where
49 T: Float + PartialOrd,
50{
51 fn cmp(&self, other: &Self) -> Ordering {
52 other
54 .weight
55 .partial_cmp(&self.weight)
56 .unwrap_or(Ordering::Equal)
57 }
58}
59
/// Disjoint-set forest over the elements `0..n`, using union by rank and
/// full path compression.
#[derive(Debug)]
struct UnionFind {
    parent: Vec<usize>,
    rank: Vec<usize>,
}

impl UnionFind {
    /// Creates `n` singleton sets; every element starts as its own root with rank 0.
    fn new(n: usize) -> Self {
        let parent: Vec<usize> = (0..n).collect();
        let rank = vec![0; n];
        UnionFind { parent, rank }
    }

    /// Returns the root of the set containing `x`.
    ///
    /// Every node visited on the way is re-pointed directly at the root.
    ///
    /// # Panics
    /// Panics if `x` is out of range.
    fn find(&mut self, x: usize) -> usize {
        // Pass 1: climb to the root.
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }

        // Pass 2: compress the path just walked.
        let mut node = x;
        while self.parent[node] != root {
            let next = self.parent[node];
            self.parent[node] = root;
            node = next;
        }

        root
    }

    /// Merges the sets containing `x` and `y`.
    ///
    /// Returns `true` if two distinct sets were joined, or `false` if `x` and
    /// `y` were already in the same set.
    fn union(&mut self, x: usize, y: usize) -> bool {
        let (root_a, root_b) = (self.find(x), self.find(y));
        if root_a == root_b {
            return false;
        }

        let (rank_a, rank_b) = (self.rank[root_a], self.rank[root_b]);
        if rank_a < rank_b {
            // Shallower tree hangs under the deeper one.
            self.parent[root_a] = root_b;
        } else {
            self.parent[root_b] = root_a;
            // Equal heights: the merged tree grows by one level.
            if rank_a == rank_b {
                self.rank[root_a] += 1;
            }
        }

        true
    }
}
103
104#[derive(Debug, Clone, Copy, PartialEq)]
106pub enum MSTAlgorithm {
107 Kruskal,
109 Prim,
111 Auto,
113}
114
115impl MSTAlgorithm {
116 #[allow(clippy::should_implement_trait)]
117 pub fn from_str(s: &str) -> SparseResult<Self> {
118 match s.to_lowercase().as_str() {
119 "kruskal" => Ok(Self::Kruskal),
120 "prim" => Ok(Self::Prim),
121 "auto" => Ok(Self::Auto),
122 _ => Err(SparseError::ValueError(format!(
123 "Unknown MST algorithm: {s}. Use 'kruskal', 'prim', or 'auto'"
124 ))),
125 }
126 }
127}
128
129#[allow(dead_code)]
159pub fn minimum_spanning_tree<T, S>(
160 graph: &S,
161 algorithm: &str,
162 return_tree: bool,
163) -> SparseResult<(T, Option<CsrArray<T>>, Array1<isize>)>
164where
165 T: Float + Debug + Copy + 'static,
166 S: SparseArray<T>,
167{
168 validate_graph(graph, false)?; let n = num_vertices(graph);
170
171 if n == 0 {
172 return Err(SparseError::ValueError(
173 "Cannot compute MST of empty graph".to_string(),
174 ));
175 }
176
177 let mst_algorithm = MSTAlgorithm::from_str(algorithm)?;
178
179 let actual_algorithm = match mst_algorithm {
180 MSTAlgorithm::Auto => {
181 let nnz = graph.nnz();
184 if nnz <= n * n / 4 {
185 MSTAlgorithm::Kruskal
186 } else {
187 MSTAlgorithm::Prim
188 }
189 }
190 alg => alg,
191 };
192
193 match actual_algorithm {
194 MSTAlgorithm::Kruskal => kruskal_mst(graph, return_tree),
195 MSTAlgorithm::Prim => {
196 prim_mst(graph, 0, return_tree) }
198 MSTAlgorithm::Auto => unreachable!(),
199 }
200}
201
202#[allow(dead_code)]
204pub fn kruskal_mst<T, S>(
205 graph: &S,
206 return_tree: bool,
207) -> SparseResult<(T, Option<CsrArray<T>>, Array1<isize>)>
208where
209 T: Float + Debug + Copy + 'static,
210 S: SparseArray<T>,
211{
212 let n = num_vertices(graph);
213 let (row_indices, col_indices, values) = graph.find();
214
215 let mut edges = Vec::new();
217 for (i, (&u, &v)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
218 if u <= v && !values[i].is_zero() {
219 edges.push(Edge {
221 weight: values[i],
222 u,
223 v,
224 });
225 }
226 }
227
228 edges.sort_by(|a, b| a.weight.partial_cmp(&b.weight).unwrap_or(Ordering::Equal));
229
230 let mut union_find = UnionFind::new(n);
231 let mut mst_edges = Vec::new();
232 let mut total_weight = T::zero();
233 let mut parent = Array1::from_elem(n, -1isize);
234
235 for edge in edges {
236 if union_find.union(edge.u, edge.v) {
237 mst_edges.push(edge.clone());
238 total_weight = total_weight + edge.weight;
239
240 if parent[edge.v] == -1 {
242 parent[edge.v] = edge.u as isize;
243 } else if parent[edge.u] == -1 {
244 parent[edge.u] = edge.v as isize;
245 }
246
247 if mst_edges.len() == n - 1 {
249 break;
250 }
251 }
252 }
253
254 if mst_edges.len() != n - 1 {
256 return Err(SparseError::ValueError(
257 "Graph is not connected - cannot compute spanning tree".to_string(),
258 ));
259 }
260
261 let mst_matrix = if return_tree {
262 Some(build_mst_matrix(&mst_edges, n)?)
263 } else {
264 None
265 };
266
267 Ok((total_weight, mst_matrix, parent))
268}
269
270#[allow(dead_code)]
272pub fn prim_mst<T, S>(
273 graph: &S,
274 start: usize,
275 return_tree: bool,
276) -> SparseResult<(T, Option<CsrArray<T>>, Array1<isize>)>
277where
278 T: Float + Debug + Copy + 'static,
279 S: SparseArray<T>,
280{
281 let n = num_vertices(graph);
282 let adj_list = to_adjacency_list(graph, false)?; if start >= n {
285 return Err(SparseError::ValueError(format!(
286 "Start vertex {start} out of bounds for graph with {n} vertices"
287 )));
288 }
289
290 let mut in_mst = vec![false; n];
291 let mut min_weight = vec![T::infinity(); n];
292 let mut parent = Array1::from_elem(n, -1isize);
293 let mut total_weight = T::zero();
294 let mut mst_edges = Vec::new();
295
296 let mut heap = BinaryHeap::new();
298
299 min_weight[start] = T::zero();
301 heap.push(Edge {
302 weight: T::zero(),
303 u: start,
304 v: start,
305 });
306
307 while let Some(Edge { weight, u: _, v }) = heap.pop() {
308 if in_mst[v] {
309 continue;
310 }
311
312 in_mst[v] = true;
313 total_weight = total_weight + weight;
314
315 if weight > T::zero() {
316 mst_edges.push(Edge {
318 weight,
319 u: parent[v] as usize,
320 v,
321 });
322 }
323
324 for &(neighbor, edge_weight) in &adj_list[v] {
326 if !in_mst[neighbor] && edge_weight < min_weight[neighbor] {
327 min_weight[neighbor] = edge_weight;
328 parent[neighbor] = v as isize;
329
330 heap.push(Edge {
331 weight: edge_weight,
332 u: v,
333 v: neighbor,
334 });
335 }
336 }
337 }
338
339 let vertices_in_mst = in_mst.iter().filter(|&&x| x).count();
341 if vertices_in_mst != n {
342 return Err(SparseError::ValueError(
343 "Graph is not connected - cannot compute spanning tree".to_string(),
344 ));
345 }
346
347 let mst_matrix = if return_tree {
348 Some(build_mst_matrix(&mst_edges, n)?)
349 } else {
350 None
351 };
352
353 Ok((total_weight, mst_matrix, parent))
354}
355
356#[allow(dead_code)]
358fn build_mst_matrix<T>(edges: &[Edge<T>], n: usize) -> SparseResult<CsrArray<T>>
359where
360 T: Float + Debug + Copy + 'static,
361{
362 let mut rows = Vec::new();
363 let mut cols = Vec::new();
364 let mut values = Vec::new();
365
366 for edge in edges {
367 rows.push(edge.u);
369 cols.push(edge.v);
370 values.push(edge.weight);
371
372 rows.push(edge.v);
373 cols.push(edge.u);
374 values.push(edge.weight);
375 }
376
377 CsrArray::from_triplets(&rows, &cols, &values, (n, n), false)
378}
379
380#[allow(dead_code)]
392pub fn is_spanning_tree<T, S1, S2>(graph: &S1, tree: &S2, tol: T) -> SparseResult<bool>
393where
394 T: Float + Debug + Copy + 'static,
395 S1: SparseArray<T>,
396 S2: SparseArray<T>,
397{
398 let n = num_vertices(graph);
399 let m = num_vertices(tree);
400
401 if n != m {
403 return Ok(false);
404 }
405
406 let tree_edges = tree.nnz() / 2; if tree_edges != n - 1 {
409 return Ok(false);
410 }
411
412 let (tree_rows, tree_cols, tree_values) = tree.find();
414
415 for (i, (&u, &v)) in tree_rows.iter().zip(tree_cols.iter()).enumerate() {
416 if u < v {
417 let graph_weight = graph.get(u, v);
419 let tree_weight = tree_values[i];
420
421 if (graph_weight - tree_weight).abs() > tol {
422 return Ok(false);
423 }
424 }
425 }
426
427 Ok(true)
431}
432
433#[allow(dead_code)]
443pub fn spanning_tree_weight<T, S>(tree: &S) -> SparseResult<T>
444where
445 T: Float + Debug + Copy + 'static,
446 S: SparseArray<T>,
447{
448 let (row_indices, col_indices, values) = tree.find();
449 let mut total_weight = T::zero();
450
451 for (i, (&u, &v)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
453 if u <= v {
454 total_weight = total_weight + values[i];
455 }
456 }
457
458 Ok(total_weight)
459}
460
461#[allow(dead_code)]
479pub fn all_minimum_spanning_trees<T, S>(
480 graph: &S,
481 algorithm: &str,
482) -> SparseResult<(CsrArray<T>, bool, T)>
483where
484 T: Float + Debug + Copy + 'static,
485 S: SparseArray<T>,
486{
487 let (total_weight, mst_, _) = minimum_spanning_tree(graph, algorithm, true)?;
488 let mst = mst_.unwrap();
489
490 let (_, _, values) = graph.find();
492 let mut weights: Vec<_> = values.iter().copied().collect();
493 weights.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
494
495 let has_duplicates = weights
496 .windows(2)
497 .any(|w| (w[0] - w[1]).abs() < T::from(1e-10).unwrap());
498
499 Ok((mst, has_duplicates, total_weight))
500}
501
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;
    use approx::assert_relative_eq;

    /// Symmetric 4-vertex graph with edges 0-1 (1), 0-2 (2), 1-2 (1), 1-3 (3)
    /// and 2-3 (4). Its MST is {0-1, 1-2, 1-3}, with total weight 5.
    fn create_test_graph() -> CsrArray<f64> {
        let rows = vec![0, 0, 1, 1, 1, 2, 2, 2, 3, 3];
        let cols = vec![1, 2, 0, 2, 3, 0, 1, 3, 1, 2];
        let data = vec![1.0, 2.0, 1.0, 1.0, 3.0, 2.0, 1.0, 4.0, 3.0, 4.0];

        CsrArray::from_triplets(&rows, &cols, &data, (4, 4), false).unwrap()
    }

    #[test]
    fn test_union_find() {
        let mut uf = UnionFind::new(4);

        assert_ne!(uf.find(0), uf.find(1));
        assert_ne!(uf.find(1), uf.find(2));

        assert!(uf.union(0, 1));
        assert_eq!(uf.find(0), uf.find(1));

        assert!(uf.union(1, 2));
        assert_eq!(uf.find(0), uf.find(2));

        // Already connected: no merge happens.
        assert!(!uf.union(0, 2));
    }

    #[test]
    fn test_kruskal_mst() {
        let graph = create_test_graph();
        let (total_weight, mst_, _) = kruskal_mst(&graph, true).unwrap();

        assert_relative_eq!(total_weight, 5.0);

        let mst = mst_.unwrap();

        // 3 tree edges, each stored symmetrically.
        assert_eq!(mst.nnz(), 6);
        let calculated_weight = spanning_tree_weight(&mst).unwrap();
        assert_relative_eq!(calculated_weight, total_weight);

        assert!(is_spanning_tree(&graph, &mst, 1e-10).unwrap());
    }

    #[test]
    fn test_prim_mst() {
        let graph = create_test_graph();
        let (total_weight, mst_, _mst_parents) = prim_mst(&graph, 0, true).unwrap();

        assert_relative_eq!(total_weight, 5.0);

        let mst = mst_.unwrap();
        assert_eq!(mst.nnz(), 6);
        assert!(is_spanning_tree(&graph, &mst, 1e-10).unwrap());
    }

    #[test]
    fn test_minimum_spanning_tree_api() {
        let graph = create_test_graph();

        let (weight_k_, _, _) = minimum_spanning_tree(&graph, "kruskal", false).unwrap();
        assert_relative_eq!(weight_k_, 5.0);

        let (weight_p_, _, _) = minimum_spanning_tree(&graph, "prim", false).unwrap();
        assert_relative_eq!(weight_p_, 5.0);

        let (weight_a_, _, _) = minimum_spanning_tree(&graph, "auto", false).unwrap();
        assert_relative_eq!(weight_a_, 5.0);
    }

    #[test]
    fn test_disconnected_graph() {
        // Two components: {0, 1} and {2, 3}.
        let rows = vec![0, 1, 2, 3];
        let cols = vec![1, 0, 3, 2];
        let data = vec![1.0, 1.0, 1.0, 1.0];
        let graph = CsrArray::from_triplets(&rows, &cols, &data, (4, 4), false).unwrap();

        assert!(minimum_spanning_tree(&graph, "kruskal", false).is_err());
        assert!(minimum_spanning_tree(&graph, "prim", false).is_err());
    }

    #[test]
    fn test_single_vertex() {
        let graph: CsrArray<f64> = CsrArray::from_triplets(&[], &[], &[], (1, 1), false).unwrap();

        let (total_weight, mst_, _) = minimum_spanning_tree(&graph, "kruskal", true).unwrap();
        assert_relative_eq!(total_weight, 0.0);

        let mst = mst_.unwrap();
        assert_eq!(mst.nnz(), 0);
    }

    #[test]
    fn test_two_vertices() {
        let rows = vec![0, 1];
        let cols = vec![1, 0];
        let data = vec![5.0, 5.0];
        let graph = CsrArray::from_triplets(&rows, &cols, &data, (2, 2), false).unwrap();

        let (total_weight, mst_, _mst_parents) =
            minimum_spanning_tree(&graph, "prim", true).unwrap();
        assert_relative_eq!(total_weight, 5.0);

        let mst = mst_.unwrap();
        assert_eq!(mst.nnz(), 2);
    }

    #[test]
    fn test_complete_graph() {
        // Upper triangle of K4; the MST is {0-1 (1), 1-2 (2), 0-3 (3)}, weight 6.
        let rows = vec![0, 0, 0, 1, 1, 2];
        let cols = vec![1, 2, 3, 2, 3, 3];
        let data = vec![1.0, 4.0, 3.0, 2.0, 5.0, 6.0];

        let mut all_rows = rows.clone();
        let mut all_cols = cols.clone();
        let mut all_data = data.clone();

        // Mirror every edge to make the matrix symmetric.
        for (i, (&r, &c)) in rows.iter().zip(cols.iter()).enumerate() {
            all_rows.push(c);
            all_cols.push(r);
            all_data.push(data[i]);
        }

        let graph =
            CsrArray::from_triplets(&all_rows, &all_cols, &all_data, (4, 4), false).unwrap();

        let (total_weight_, _, _) = minimum_spanning_tree(&graph, "kruskal", false).unwrap();

        assert_relative_eq!(total_weight_, 6.0);
    }

    #[test]
    fn test_spanning_tree_validation() {
        let graph = create_test_graph();
        let (_, mst_, _) = minimum_spanning_tree(&graph, "kruskal", true).unwrap();
        let mst = mst_.unwrap();

        assert!(is_spanning_tree(&graph, &mst, 1e-10).unwrap());

        // A single edge cannot span 4 vertices.
        let rows = vec![0, 1];
        let cols = vec![1, 0];
        let data = vec![1.0, 1.0];
        let invalid_tree = CsrArray::from_triplets(&rows, &cols, &data, (4, 4), false).unwrap();

        assert!(!is_spanning_tree(&graph, &invalid_tree, 1e-10).unwrap());
    }

    #[test]
    fn test_algorithm_selection() {
        assert!(matches!(
            MSTAlgorithm::from_str("kruskal"),
            Ok(MSTAlgorithm::Kruskal)
        ));
        assert!(matches!(
            MSTAlgorithm::from_str("prim"),
            Ok(MSTAlgorithm::Prim)
        ));
        assert!(matches!(
            MSTAlgorithm::from_str("auto"),
            Ok(MSTAlgorithm::Auto)
        ));
        // Names are matched case-insensitively.
        assert!(matches!(
            MSTAlgorithm::from_str("PRIM"),
            Ok(MSTAlgorithm::Prim)
        ));
        assert!(MSTAlgorithm::from_str("invalid").is_err());
    }
}