1use core::panic;
4use std::{
5 cmp::Reverse,
6 ops::{Deref, DerefMut},
7};
8
9use crate::{
10 computation_mode::*,
11 datastructures::{AdjacencyMatrix, Edge, Graph, NAMatrix},
12};
13
14use delegate::delegate;
15use nalgebra::{Dyn, U1};
16use ordered_float::OrderedFloat;
17use priority_queue::PriorityQueue;
18use rayon::prelude::*;
19
20pub fn prim<const MODE: usize>(graph: &NAMatrix) -> Graph {
26 match MODE {
27 SEQ_COMPUTATION => prim_with_excluded_node_single_threaded(graph, &[]),
28 PAR_COMPUTATION => prim_with_excluded_node_multi_threaded(graph, &[]),
29 #[cfg(feature = "mpi")]
30 MPI_COMPUTATION => {
31 eprintln!("Warning: defaulting to sequential implementation of prims algorithm");
32 prim::<SEQ_COMPUTATION>(graph)
33 }
34 _ => panic_on_invaid_mode::<MODE>(),
35 }
36}
37
38pub fn prim_with_excluded_node_multi_threaded(
43 graph: &NAMatrix,
44 excluded_vertices: &[usize],
45) -> Graph {
46 prim_with_excluded_node::<MultiThreadedVecWrapper>(graph, excluded_vertices)
47}
48
49pub fn prim_with_excluded_node_single_threaded(
62 graph: &NAMatrix,
63 excluded_vertices: &[usize],
64) -> Graph {
65 prim_with_excluded_node::<Vec<(Edge, bool)>>(graph, excluded_vertices)
66}
67
68pub fn prim_with_excluded_node_priority_queue(
71 graph: &NAMatrix,
72 excluded_vertices: &[usize],
73) -> Graph {
74 prim_with_excluded_node::<VerticesInPriorityQueue>(graph, excluded_vertices)
75}
76
77fn prim_with_excluded_node<D: FindMinCostEdge>(
84 graph: &NAMatrix,
85 excluded_vertices: &[usize],
86) -> Graph {
87 let num_vertices = graph.dim();
88 let unconnected_node = num_vertices;
89
90 let mut mst_adj_list: Vec<Vec<Edge>> = vec![Vec::new(); num_vertices + 1];
92
93 let mut dist_from_mst = D::from_default_value(
96 Edge {
98 cost: f64::INFINITY,
99 to: unconnected_node,
100 },
101 num_vertices + 1,
102 );
103
104 let start_index = {
110 let mut idx = 0;
111 while excluded_vertices.contains(&idx) {
112 idx += 1;
113 }
114 if idx >= num_vertices {
115 return vec![].into();
117 }
118 idx
119 };
120
121 dist_from_mst.set_cost(
122 start_index,
123 Edge {
124 to: start_index,
125 cost: 0.,
126 },
127 );
128 for &vertex in excluded_vertices {
129 dist_from_mst.set_excluded_vertex(vertex)
130 }
131
132 for _ in 0..=num_vertices {
134 let (next_vertex, next_edge) = dist_from_mst.find_edge_with_minimal_cost();
135
136 if next_edge.cost == f64::INFINITY {
139 break;
140 }
141
142 dist_from_mst.mark_vertex_as_used(next_vertex);
144 if next_vertex != start_index {
145 let reverse_edge = Edge {
147 to: next_vertex,
148 cost: next_edge.cost,
149 };
150 let connection_from = next_edge.to;
151 let connection_to = next_vertex;
152 mst_adj_list[connection_to].push(next_edge);
153 mst_adj_list[connection_from].push(reverse_edge);
154 }
155
156 dist_from_mst.update_minimal_cost(next_vertex, graph.row(next_vertex))
162 }
163
164 mst_adj_list.pop();
166 Graph::from(mst_adj_list)
167}
168
/// Single-row `f64` matrix view: one row (`U1`), a dynamic number of columns, unit row stride
/// and a dynamic column stride. This is the type of `graph.row(..)` as passed to
/// `FindMinCostEdge::update_minimal_cost`.
type NAMatrixRowView<'a> =
    nalgebra::Matrix<f64, U1, Dyn, nalgebra::ViewStorage<'a, f64, U1, Dyn, U1, Dyn>>;
171
/// Bookkeeping strategy used by `prim_with_excluded_node`. For every vertex not yet in the
/// spanning tree, it tracks the cheapest known edge connecting that vertex to the tree.
trait FindMinCostEdge {
    /// Creates storage for `size` vertices. Each starts with `default_val` as its best edge
    /// and is not marked as used.
    fn from_default_value(default_val: Edge, size: usize) -> Self;

    /// Returns the not-yet-used vertex with the cheapest known connection, together with that
    /// edge (`edge.to` is the endpoint already in the tree). The caller treats a cost of
    /// `f64::INFINITY` as "no vertex can be connected any more".
    fn find_edge_with_minimal_cost(&self) -> (usize, Edge);
    /// Relaxes the stored connections using the edges leaving `from`. Entry `to` of
    /// `new_neighbours` is the cost of the edge `from -> to`.
    fn update_minimal_cost(&mut self, from: usize, new_neighbours: NAMatrixRowView);

    /// Overwrites the stored best edge of vertex `from` with `edge_to`.
    fn set_cost(&mut self, from: usize, edge_to: Edge);

    /// Takes `excluded_vertex` out of consideration so it never joins the tree.
    fn set_excluded_vertex(&mut self, excluded_vertex: usize);

    /// Records that `used_vertex` has been added to the tree, so it is no longer returned by
    /// `find_edge_with_minimal_cost`.
    fn mark_vertex_as_used(&mut self, used_vertex: usize);
}
198
/// Priority-queue strategy. Vertices not yet in the tree live in a priority queue keyed by
/// the cost of their cheapest known connection, so the next vertex is read with `peek`
/// instead of a linear scan.
#[derive(Clone, Debug, PartialEq)]
struct VerticesInPriorityQueue {
    // Not-yet-used vertices, prioritised by `Reverse(cost)`. The queue yields the highest
    // priority first, so the cheapest vertex comes out on top. Used vertices are removed.
    cost_queue: PriorityQueue<usize, Reverse<OrderedFloat<f64>>>,
    // For every vertex, the tree vertex at the other end of its cheapest known connection.
    connection_to_mst: Vec<usize>,
    // Whether a vertex has already joined the tree or was excluded.
    used: Vec<bool>,
}
211impl FindMinCostEdge for VerticesInPriorityQueue {
212 fn from_default_value(default_val: Edge, size: usize) -> Self {
213 VerticesInPriorityQueue {
214 cost_queue: PriorityQueue::from(
215 (0..size)
216 .map(|i| (i, Reverse(OrderedFloat(default_val.cost))))
217 .collect::<Vec<(usize, Reverse<OrderedFloat<f64>>)>>(),
218 ),
219 connection_to_mst: vec![default_val.to; size],
220 used: vec![false; size],
221 }
222 }
223
224 fn find_edge_with_minimal_cost(&self) -> (usize, Edge) {
225 let base_case = Edge {
226 to: self.connection_to_mst.len(),
227 cost: f64::INFINITY,
228 };
229 let (&next_vertex, &Reverse(OrderedFloat(cost))) = self
230 .cost_queue
231 .peek()
232 .unwrap_or((&base_case.to, &Reverse(OrderedFloat(base_case.cost))));
233 let to = self.connection_to_mst[next_vertex];
234
235 (next_vertex, Edge { to, cost })
236 }
237
238 fn update_minimal_cost(&mut self, from: usize, new_neighbours: NAMatrixRowView) {
239 for (to, &cost) in new_neighbours.iter().enumerate() {
240 if self.used[to] {
241 continue;
242 }
243 let Reverse(OrderedFloat(old_cost)) = self.cost_queue
244 .push_increase(to, Reverse(OrderedFloat(cost)))
245 .unwrap_or_else(|| panic!("Every unused unused vertex shall be contained in the queue from the beginning. Missing vertex: {}", to));
246 if cost <= old_cost {
247 self.connection_to_mst[to] = from;
248 }
249 }
250 }
251
252 fn set_cost(&mut self, from: usize, edge_to: Edge) {
253 self.cost_queue
254 .change_priority(&from, Reverse(OrderedFloat(edge_to.cost)));
255
256 self.connection_to_mst[from] = edge_to.to;
257 }
258
259 fn set_excluded_vertex(&mut self, excluded_vertex: usize) {
260 self.mark_vertex_as_used(excluded_vertex);
261 }
262
263 fn mark_vertex_as_used(&mut self, used_vertex: usize) {
264 self.cost_queue.remove(&used_vertex);
265 self.used[used_vertex] = true;
266 }
267}
268
269impl FindMinCostEdge for Vec<(Edge, bool)> {
274 fn from_default_value(default_val: Edge, size: usize) -> Self {
275 vec![(default_val, false); size]
276 }
277
278 fn find_edge_with_minimal_cost(&self) -> (usize, Edge) {
279 let base_case = Edge {
280 to: self.len(),
281 cost: f64::INFINITY,
282 };
283 let (next_vertex, reverse_edge) = self
284 .iter()
285 .enumerate()
286 .filter_map(
288 |(i, &(edge, used_in_mst))| if used_in_mst { None } else { Some((i, edge)) },
289 )
290 .min_by(|&(_, edg_i), &(_, edg_j)| {
292 OrderedFloat(edg_i.cost).cmp(&OrderedFloat(edg_j.cost))
293 })
294 .unwrap_or((base_case.to, base_case));
296 (next_vertex, reverse_edge)
297 }
298
299 fn update_minimal_cost(&mut self, from: usize, new_neighbours: NAMatrixRowView) {
300 for (to, &cost) in new_neighbours.iter().enumerate() {
302 if cost < self[to].0.cost {
303 self[to].0 = Edge { to: from, cost };
304 }
305 }
306 }
307
308 fn set_cost(&mut self, from: usize, edge_to: Edge) {
309 self[from].0 = edge_to;
310 }
311
312 fn mark_vertex_as_used(&mut self, used_vertex: usize) {
313 self[used_vertex].1 = true;
314 }
315
316 fn set_excluded_vertex(&mut self, excluded_vertex: usize) {
317 self.mark_vertex_as_used(excluded_vertex);
318 }
319}
320
/// Same per-vertex `(Edge, bool)` layout as the plain `Vec` strategy, wrapped in a newtype so
/// it can implement `FindMinCostEdge` with rayon-parallel relaxation and minimum search.
#[derive(Debug, PartialEq)]
struct MultiThreadedVecWrapper(Vec<(Edge, bool)>);
323
324impl Deref for MultiThreadedVecWrapper {
325 type Target = Vec<(Edge, bool)>;
326 fn deref(&self) -> &Self::Target {
327 &self.0
328 }
329}
330impl DerefMut for MultiThreadedVecWrapper {
331 fn deref_mut(&mut self) -> &mut Self::Target {
332 &mut self.0
333 }
334}
335
336impl FindMinCostEdge for MultiThreadedVecWrapper {
337 fn from_default_value(default_val: Edge, size: usize) -> Self {
338 MultiThreadedVecWrapper(Vec::from_default_value(default_val, size))
339 }
340 delegate! {
341 to self.0 {
342 fn set_cost(&mut self, from: usize, edge_to: Edge);
343 fn set_excluded_vertex(&mut self, excluded_vertex: usize);
344 fn mark_vertex_as_used(&mut self, used_vertex: usize);
345 }
346 }
347
348 fn update_minimal_cost(&mut self, from: usize, new_neighbours: NAMatrixRowView) {
349 let dim = new_neighbours.shape().1;
351 (0..dim).into_par_iter().for_each(|to| {
353 let neighbour_prt = new_neighbours.as_ptr() as *mut f64;
354 let cost = unsafe { *neighbour_prt.add(dim * to) };
359 let to_dist_ptr = self.as_ptr() as *mut (Edge, bool);
360 if cost < self[to].0.cost {
361 unsafe {
366 (*to_dist_ptr.add(to)).0 = Edge { to: from, cost };
367 }
368 }
369 });
370 }
371
372 fn find_edge_with_minimal_cost(&self) -> (usize, Edge) {
373 let base_case = Edge {
374 to: self.0.len(),
375 cost: f64::INFINITY,
376 };
377 let (next_vertex, reverse_edge) = self
378 .0
379 .par_iter()
380 .enumerate()
381 .filter_map(
383 |(i, &(edge, used_in_mst))| if used_in_mst { None } else { Some((i, edge)) },
384 )
385 .min_by(|&(_, edg_i), &(_, edg_j)| {
387 OrderedFloat(edg_i.cost).cmp(&OrderedFloat(edg_j.cost))
388 })
389 .unwrap_or((base_case.to, base_case));
391 (next_vertex, reverse_edge)
392 }
393}
394
#[cfg(test)]
mod test {
    use std::assert_eq;

    use nalgebra::DMatrix;

    use super::*;

    /// Four-vertex fixture shared by several tests. The edge between vertices 1 and 2 is
    /// asymmetric: it costs 5.0 from vertex 1 but 1.1 from vertex 2.
    fn four_vertex_graph() -> Graph {
        Graph::from(vec![
            vec![
                Edge { to: 1, cost: 1.0 },
                Edge { to: 2, cost: 0.1 },
                Edge { to: 3, cost: 2.0 },
            ],
            vec![
                Edge { to: 0, cost: 1.0 },
                Edge { to: 2, cost: 5.0 },
                Edge { to: 3, cost: 0.1 },
            ],
            vec![
                Edge { to: 0, cost: 0.1 },
                Edge { to: 1, cost: 1.1 },
                Edge { to: 3, cost: 0.1 },
            ],
            vec![
                Edge { to: 0, cost: 2.0 },
                Edge { to: 1, cost: 0.1 },
                Edge { to: 2, cost: 0.1 },
            ],
        ])
    }

    /// Two vertices joined by a single edge are their own MST.
    #[test]
    fn easy_prim() {
        let graph = Graph::from(vec![
            vec![Edge { to: 1, cost: 1.0 }],
            vec![Edge { to: 0, cost: 1.0 }],
        ]);

        let mst = prim::<SEQ_COMPUTATION>(&(&graph).into());
        assert_eq!(graph, mst);
    }

    /// The MST of the fixture uses the three 0.1 edges 0-2, 2-3 and 3-1.
    #[test]
    fn four_vertices_mst_prim() {
        let graph = four_vertex_graph();

        let expected = Graph::from(vec![
            vec![Edge { to: 2, cost: 0.1 }],
            vec![Edge { to: 3, cost: 0.1 }],
            vec![Edge { to: 0, cost: 0.1 }, Edge { to: 3, cost: 0.1 }],
            vec![Edge { to: 2, cost: 0.1 }, Edge { to: 1, cost: 0.1 }],
        ]);

        assert_eq!(expected, prim::<SEQ_COMPUTATION>(&(&graph).into()));
    }

    /// Excluding vertex 0 leaves it without edges and spans the remaining three vertices.
    #[test]
    fn exclude_one_vertex_from_mst() {
        let graph = four_vertex_graph();

        let expected = Graph::from(vec![
            vec![],
            vec![Edge { to: 3, cost: 0.1 }],
            vec![Edge { to: 3, cost: 0.1 }],
            vec![Edge { to: 1, cost: 0.1 }, Edge { to: 2, cost: 0.1 }],
        ]);

        assert_eq!(
            expected,
            prim_with_excluded_node_multi_threaded(&(&graph).into(), &[0])
        );
    }

    /// All three bookkeeping strategies must produce the same tree.
    #[test]
    fn prim_all_versions_agree() {
        let graph = four_vertex_graph();
        let excluded_vertex = &[0];
        let res_st = prim_with_excluded_node_single_threaded(&(&graph).into(), excluded_vertex);
        let res_mt = prim_with_excluded_node_multi_threaded(&(&graph).into(), excluded_vertex);
        let res_prio = prim_with_excluded_node_priority_queue(&(&graph).into(), excluded_vertex);
        assert_eq!(
            res_st, res_mt,
            "single_threaded should agree with multi_threaded"
        );
        assert_eq!(
            res_st, res_prio,
            "single_threaded should agree with priority queue version"
        );
    }

    #[test]
    fn test_vertices_in_priority_queue_from_default_value() {
        let default_val = Edge {
            to: 3,
            cost: f64::INFINITY,
        };

        let size = 5;

        let vert = VerticesInPriorityQueue::from_default_value(default_val, size);

        let mut queue = PriorityQueue::new();
        for i in 0..size {
            queue.push(i, Reverse(OrderedFloat(f64::INFINITY)));
        }

        assert_eq!(vert.cost_queue, queue);
        assert_eq!(vert.cost_queue.into_vec(), vec![0, 1, 2, 3, 4]);
        assert_eq!(vert.connection_to_mst, vec![3; 5])
    }

    /// `push_increase` hands back the previous priority when it raises an item's priority.
    #[test]
    fn test_vertices_in_priority_queue_increase_priority() {
        let default_val = Edge {
            to: 4,
            cost: f64::INFINITY,
        };

        let size = 5;

        let mut vert = VerticesInPriorityQueue::from_default_value(default_val, size);

        let res = vert.cost_queue.push_increase(0, Reverse(OrderedFloat(1.0)));
        assert_eq!(res, Some(Reverse(OrderedFloat(f64::INFINITY))));
    }

    #[test]
    fn test_vertices_in_priority_queue_update_priority_does_not_panic() {
        let default_val = Edge {
            to: 4,
            cost: f64::INFINITY,
        };

        let size = 5;

        let mut vert = VerticesInPriorityQueue::from_default_value(default_val, size);
        let mat = DMatrix::from_row_slice(1, size, &[1.0; 5]);

        vert.update_minimal_cost(0, mat.row(0));
    }

    #[test]
    fn test_vertices_in_priority_queue_update_priority_works() {
        let default_val = Edge {
            to: 4,
            cost: f64::INFINITY,
        };

        let size = 5;

        let mut vert = VerticesInPriorityQueue::from_default_value(default_val, size);
        let mat = DMatrix::from_row_slice(1, size, &[0.0, 1.0, 0.0, 0.0, 0.0]);

        vert.update_minimal_cost(0, mat.row(0));
        assert_eq!(vert.connection_to_mst[1], 0);
        assert_eq!(
            vert.cost_queue.get_priority(&1),
            Some(&Reverse(OrderedFloat(1.0f64)))
        );
    }
}
659}