rustworkx_core/steiner_tree.rs
1// Licensed under the Apache License, Version 2.0 (the "License"); you may
2// not use this file except in compliance with the License. You may obtain
3// a copy of the License at
4//
5// http://www.apache.org/licenses/LICENSE-2.0
6//
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
9// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
10// License for the specific language governing permissions and limitations
11// under the License.
12
13use std::cmp::{Eq, Ordering};
14use std::convert::Infallible;
15use std::hash::Hash;
16
17use hashbrown::{HashMap, HashSet};
18use rayon::prelude::*;
19
20use petgraph::stable_graph::{EdgeIndex, NodeIndex, StableGraph};
21use petgraph::unionfind::UnionFind;
22use petgraph::visit::{
23 EdgeCount, EdgeIndexable, EdgeRef, GraphProp, IntoEdgeReferences, IntoEdges,
24 IntoNodeIdentifiers, IntoNodeReferences, NodeCount, NodeIndexable, NodeRef, Visitable,
25};
26use petgraph::Undirected;
27
28use crate::dictmap::*;
29use crate::shortest_path::dijkstra;
30use crate::utils::pairwise;
31
/// Result of an all-pairs shortest path computation, keyed by the index of
/// the source node.
///
/// Each value is a `(paths, distances)` pair. Both maps are keyed by the index
/// of a destination node: `paths` holds the shortest path to it as a list of
/// node indices and `distances` holds the total weight of that path.
type AllPairsDijkstraReturn = HashMap<usize, (DictMap<usize, Vec<usize>>, DictMap<usize, f64>)>;
33
34fn all_pairs_dijkstra_shortest_paths<G, F, E>(
35 graph: G,
36 mut weight_fn: F,
37) -> Result<AllPairsDijkstraReturn, E>
38where
39 G: NodeIndexable
40 + IntoNodeIdentifiers
41 + EdgeCount
42 + NodeCount
43 + EdgeIndexable
44 + Visitable
45 + Sync
46 + IntoEdges,
47 G::NodeId: Eq + Hash + Send,
48 G::EdgeId: Eq + Hash + Send,
49 F: FnMut(G::EdgeRef) -> Result<f64, E>,
50{
51 if graph.node_count() == 0 {
52 return Ok(HashMap::new());
53 } else if graph.edge_count() == 0 {
54 return Ok(graph
55 .node_identifiers()
56 .map(|x| {
57 (
58 NodeIndexable::to_index(&graph, x),
59 (DictMap::new(), DictMap::new()),
60 )
61 })
62 .collect());
63 }
64
65 let mut edge_weights: Vec<Option<f64>> = vec![None; graph.edge_bound()];
66 for edge in graph.edge_references() {
67 let index = EdgeIndexable::to_index(&graph, edge.id());
68 edge_weights[index] = Some(weight_fn(edge)?);
69 }
70 let edge_cost = |e: G::EdgeRef| -> Result<f64, Infallible> {
71 Ok(edge_weights[EdgeIndexable::to_index(&graph, e.id())].unwrap())
72 };
73
74 let node_indices: Vec<usize> = graph
75 .node_identifiers()
76 .map(|n| NodeIndexable::to_index(&graph, n))
77 .collect();
78 Ok(node_indices
79 .into_par_iter()
80 .map(|x| {
81 let mut paths: DictMap<G::NodeId, Vec<G::NodeId>> =
82 DictMap::with_capacity(graph.node_count());
83 let distances: DictMap<G::NodeId, f64> = dijkstra(
84 graph,
85 NodeIndexable::from_index(&graph, x),
86 None,
87 edge_cost,
88 Some(&mut paths),
89 )
90 .unwrap();
91 (
92 x,
93 (
94 paths
95 .into_iter()
96 .map(|(k, v)| {
97 (
98 NodeIndexable::to_index(&graph, k),
99 v.into_iter()
100 .map(|n| NodeIndexable::to_index(&graph, n))
101 .collect(),
102 )
103 })
104 .collect(),
105 distances
106 .into_iter()
107 .map(|(k, v)| (NodeIndexable::to_index(&graph, k), v))
108 .collect(),
109 ),
110 )
111 })
112 .collect())
113}
114
/// A candidate edge of a metric closure: a path in the original graph
/// collapsed into a single weighted edge between its two endpoints.
struct MetricClosureEdge {
    /// Index (in the original graph) of one endpoint.
    source: usize,
    /// Index (in the original graph) of the other endpoint.
    target: usize,
    /// Total weight of the underlying path.
    distance: f64,
    /// Indices of the nodes visited by the underlying path. The path is not
    /// necessarily oriented from `source` to `target`.
    path: Vec<usize>,
}
121
122/// Return the metric closure of a graph
123///
124/// The metric closure of a graph is the complete graph in which each edge is
125/// weighted by the shortest path distance between the nodes in the graph.
126///
127/// Arguments:
128/// `graph`: The input graph to compute the metric closure for
129/// `weight_fn`: A callable weight function that will be passed an edge reference
130/// for each edge in the graph and it is expected to return a `Result<f64>`
131/// which if it doesn't error represents the weight of that edge.
132/// `default_weight`: A blind callable that returns a default weight to use for
133/// edges added to the output
134///
135/// Returns a `StableGraph` with the input graph node ids for node weights and edge weights with a
136/// tuple of the numeric weight (found via `weight_fn`) and the path. The output will be `None`
137/// if `graph` is disconnected.
138///
139/// # Example
140/// ```rust
141/// use std::convert::Infallible;
142///
143/// use rustworkx_core::petgraph::Graph;
144/// use rustworkx_core::petgraph::Undirected;
145/// use rustworkx_core::petgraph::graph::EdgeReference;
146/// use rustworkx_core::petgraph::visit::{IntoEdgeReferences, EdgeRef};
147///
148/// use rustworkx_core::steiner_tree::metric_closure;
149///
150/// let input_graph = Graph::<(), u8, Undirected>::from_edges(&[
151/// (0, 1, 10),
152/// (1, 2, 10),
153/// (2, 3, 10),
154/// (3, 4, 10),
155/// (4, 5, 10),
156/// (1, 6, 1),
157/// (6, 4, 1),
158/// ]);
159///
160/// let weight_fn = |e: EdgeReference<u8>| -> Result<f64, Infallible> {
161/// Ok(*e.weight() as f64)
162/// };
163///
164/// let closure = metric_closure(&input_graph, weight_fn).unwrap().unwrap();
165/// let mut output_edge_list: Vec<(usize, usize, (f64, Vec<usize>))> = closure.edge_references().map(|edge| (edge.source().index(), edge.target().index(), edge.weight().clone())).collect();
166/// let mut expected_edges: Vec<(usize, usize, (f64, Vec<usize>))> = vec![
167/// (0, 1, (10.0, vec![0, 1])),
168/// (0, 2, (20.0, vec![0, 1, 2])),
169/// (0, 3, (22.0, vec![0, 1, 6, 4, 3])),
170/// (0, 4, (12.0, vec![0, 1, 6, 4])),
171/// (0, 5, (22.0, vec![0, 1, 6, 4, 5])),
172/// (0, 6, (11.0, vec![0, 1, 6])),
173/// (1, 2, (10.0, vec![1, 2])),
174/// (1, 3, (12.0, vec![1, 6, 4, 3])),
175/// (1, 4, (2.0, vec![1, 6, 4])),
176/// (1, 5, (12.0, vec![1, 6, 4, 5])),
177/// (1, 6, (1.0, vec![1, 6])),
178/// (2, 3, (10.0, vec![2, 3])),
179/// (2, 4, (12.0, vec![2, 1, 6, 4])),
180/// (2, 5, (22.0, vec![2, 1, 6, 4, 5])),
181/// (2, 6, (11.0, vec![2, 1, 6])),
182/// (3, 4, (10.0, vec![3, 4])),
183/// (3, 5, (20.0, vec![3, 4, 5])),
184/// (3, 6, (11.0, vec![3, 4, 6])),
185/// (4, 5, (10.0, vec![4, 5])),
186/// (4, 6, (1.0, vec![4, 6])),
187/// (5, 6, (11.0, vec![5, 4, 6])),
188/// ];
189/// output_edge_list.sort_by_key(|x| [x.0, x.1]);
190/// expected_edges.sort_by_key(|x| [x.0, x.1]);
191/// assert_eq!(output_edge_list, expected_edges);
192///
193/// ```
194#[allow(clippy::type_complexity)]
195pub fn metric_closure<G, F, E>(
196 graph: G,
197 weight_fn: F,
198) -> Result<Option<StableGraph<G::NodeId, (f64, Vec<usize>), Undirected>>, E>
199where
200 G: NodeIndexable
201 + EdgeIndexable
202 + Sync
203 + EdgeCount
204 + NodeCount
205 + Visitable
206 + IntoNodeReferences
207 + IntoEdges
208 + Visitable
209 + GraphProp<EdgeType = Undirected>,
210 G::NodeId: Eq + Hash + NodeRef + Send,
211 G::EdgeId: Eq + Hash + Send,
212 G::NodeWeight: Clone,
213 F: FnMut(G::EdgeRef) -> Result<f64, E>,
214{
215 let mut out_graph: StableGraph<G::NodeId, (f64, Vec<usize>), Undirected> =
216 StableGraph::with_capacity(graph.node_count(), graph.edge_count());
217 let node_map: HashMap<usize, NodeIndex> = graph
218 .node_references()
219 .map(|node| {
220 (
221 NodeIndexable::to_index(&graph, node.id()),
222 out_graph.add_node(node.id()),
223 )
224 })
225 .collect();
226 let edges = metric_closure_edges(graph, weight_fn)?;
227 if edges.is_none() {
228 return Ok(None);
229 }
230 for edge in edges.unwrap() {
231 out_graph.add_edge(
232 node_map[&edge.source],
233 node_map[&edge.target],
234 (edge.distance, edge.path),
235 );
236 }
237 Ok(Some(out_graph))
238}
239
240fn metric_closure_edges<G, F, E>(
241 graph: G,
242 weight_fn: F,
243) -> Result<Option<Vec<MetricClosureEdge>>, E>
244where
245 G: NodeIndexable
246 + Sync
247 + Visitable
248 + IntoNodeReferences
249 + IntoEdges
250 + Visitable
251 + NodeIndexable
252 + NodeCount
253 + EdgeCount
254 + EdgeIndexable,
255 G::NodeId: Eq + Hash + Send,
256 G::EdgeId: Eq + Hash + Send,
257 F: FnMut(G::EdgeRef) -> Result<f64, E>,
258{
259 let node_count = graph.node_count();
260 if node_count == 0 {
261 return Ok(Some(Vec::new()));
262 }
263 let mut out_vec = Vec::with_capacity(node_count * (node_count - 1) / 2);
264 let paths = all_pairs_dijkstra_shortest_paths(graph, weight_fn)?;
265 let mut nodes: HashSet<usize> = graph
266 .node_identifiers()
267 .map(|x| NodeIndexable::to_index(&graph, x))
268 .collect();
269 let first_node = graph
270 .node_identifiers()
271 .map(|x| NodeIndexable::to_index(&graph, x))
272 .next()
273 .unwrap();
274 let path_keys: HashSet<usize> = paths[&first_node].0.keys().copied().collect();
275 // first_node will always be missing from path_keys so if the difference
276 // is > 1 with nodes that means there is another node in the graph that
277 // first_node doesn't have a path to.
278 if nodes.difference(&path_keys).count() > 1 {
279 return Ok(None);
280 }
281 // Iterate over node indices for a deterministic order
282 for node in graph
283 .node_identifiers()
284 .map(|x| NodeIndexable::to_index(&graph, x))
285 {
286 let path_map = &paths[&node].0;
287 nodes.remove(&node);
288 let distance = &paths[&node].1;
289 for v in &nodes {
290 out_vec.push(MetricClosureEdge {
291 source: node,
292 target: *v,
293 distance: distance[v],
294 path: path_map[v].clone(),
295 });
296 }
297 }
298 Ok(Some(out_vec))
299}
300
301/// Computes the shortest path between all pairs `(s, t)` of the given `terminal_nodes`
302/// *provided* that:
303/// - there is an edge `(u, v)` in the graph and path pass through this edge.
304/// - node `s` is the closest node to `u` among all `terminal_nodes`
305/// - node `t` is the closest node to `v` among all `terminal_nodes`
306/// and wraps the result inside a `MetricClosureEdge`
307///
308/// For example, if all vertices are terminals, it returns the original edges of the graph.
309fn fast_metric_edges<G, F, E>(
310 in_graph: G,
311 terminal_nodes: &[G::NodeId],
312 mut weight_fn: F,
313) -> Result<Vec<MetricClosureEdge>, E>
314where
315 G: IntoEdges
316 + NodeIndexable
317 + EdgeIndexable
318 + Sync
319 + EdgeCount
320 + Visitable
321 + IntoNodeReferences
322 + NodeCount,
323 G::NodeId: Eq + Hash + Send,
324 G::EdgeId: Eq + Hash + Send,
325 F: FnMut(G::EdgeRef) -> Result<f64, E>,
326{
327 let mut graph: StableGraph<(), (), Undirected> = StableGraph::with_capacity(
328 in_graph.node_count() + 1,
329 in_graph.edge_count() + terminal_nodes.len(),
330 );
331 let node_map: HashMap<G::NodeId, NodeIndex> = in_graph
332 .node_references()
333 .map(|n| (n.id(), graph.add_node(())))
334 .collect();
335 let reverse_node_map: HashMap<NodeIndex, G::NodeId> =
336 node_map.iter().map(|(k, v)| (*v, *k)).collect();
337 let edge_map: HashMap<EdgeIndex, G::EdgeRef> = in_graph
338 .edge_references()
339 .map(|e| {
340 (
341 graph.add_edge(node_map[&e.source()], node_map[&e.target()], ()),
342 e,
343 )
344 })
345 .collect();
346
347 // temporarily add a ``dummy`` node, connect it with
348 // all the terminal nodes and find all the shortest paths
349 // starting from ``dummy`` node.
350 let dummy = graph.add_node(());
351 for node in terminal_nodes {
352 graph.add_edge(dummy, node_map[node], ());
353 }
354
355 let mut paths = DictMap::with_capacity(graph.node_count());
356
357 let mut wrapped_weight_fn =
358 |e: <&StableGraph<(), ()> as IntoEdgeReferences>::EdgeRef| -> Result<f64, E> {
359 if let Some(edge_ref) = edge_map.get(&e.id()) {
360 weight_fn(*edge_ref)
361 } else {
362 Ok(0.0)
363 }
364 };
365
366 let mut distance: DictMap<NodeIndex, f64> = dijkstra(
367 &graph,
368 dummy,
369 None,
370 &mut wrapped_weight_fn,
371 Some(&mut paths),
372 )?;
373 paths.swap_remove(&dummy);
374 distance.swap_remove(&dummy);
375
376 // ``partition[u]`` holds the terminal node closest to node ``u``.
377 let mut partition: Vec<usize> = vec![usize::MAX; graph.node_bound()];
378 for (u, path) in paths.iter() {
379 let u = NodeIndexable::to_index(&in_graph, reverse_node_map[u]);
380 partition[u] = NodeIndexable::to_index(&in_graph, reverse_node_map[&path[1]]);
381 }
382
383 let mut out_edges: Vec<MetricClosureEdge> = Vec::with_capacity(graph.edge_count());
384
385 for edge in graph.edge_references() {
386 let source = edge.source();
387 let target = edge.target();
388 // assert that ``source`` is reachable from a terminal node.
389 if distance.contains_key(&source) {
390 let weight = distance[&source] + wrapped_weight_fn(edge)? + distance[&target];
391 let mut path: Vec<usize> = paths[&source]
392 .iter()
393 .skip(1)
394 .map(|x| NodeIndexable::to_index(&in_graph, reverse_node_map[x]))
395 .collect();
396 path.append(
397 &mut paths[&target]
398 .iter()
399 .skip(1)
400 .rev()
401 .map(|x| NodeIndexable::to_index(&in_graph, reverse_node_map[x]))
402 .collect(),
403 );
404
405 let source = NodeIndexable::to_index(&in_graph, reverse_node_map[&source]);
406 let target = NodeIndexable::to_index(&in_graph, reverse_node_map[&target]);
407
408 let mut source = partition[source];
409 let mut target = partition[target];
410
411 match source.cmp(&target) {
412 Ordering::Equal => continue,
413 Ordering::Greater => std::mem::swap(&mut source, &mut target),
414 _ => {}
415 }
416
417 out_edges.push(MetricClosureEdge {
418 source,
419 target,
420 distance: weight,
421 path,
422 });
423 }
424 }
425
426 // if parallel edges, keep the edge with minimum distance.
427 out_edges.par_sort_unstable_by(|a, b| {
428 let weight_a = (a.source, a.target, a.distance);
429 let weight_b = (b.source, b.target, b.distance);
430 weight_a.partial_cmp(&weight_b).unwrap_or(Ordering::Less)
431 });
432
433 out_edges.dedup_by(|edge_a, edge_b| {
434 edge_a.source == edge_b.source && edge_a.target == edge_b.target
435 });
436
437 Ok(out_edges)
438}
439
440/// Solution to a minimum Steiner tree problem.
441///
442/// This `struct` is created by the [steiner_tree] function.
443pub struct SteinerTreeResult {
444 pub used_node_indices: HashSet<usize>,
445 pub used_edge_endpoints: HashSet<(usize, usize)>,
446}
447
448/// Return an approximation to the minimum Steiner tree of a graph.
449///
450/// The minimum tree of ``graph`` with regard to a set of ``terminal_nodes``
451/// is a tree within ``graph`` that spans those nodes and has a minimum size
452/// (measured as the sum of edge weights) among all such trees.
453///
454/// The minimum steiner tree can be approximated by computing the minimum
455/// spanning tree of the subgraph of the metric closure of ``graph`` induced
456/// by the terminal nodes, where the metric closure of ``graph`` is the
457/// complete graph in which each edge is weighted by the shortest path distance
458/// between nodes in ``graph``.
459///
460/// This algorithm by Kou, Markowsky, and Berman[^KouMarkowskyBerman1981]
461/// produces a tree whose weight is within a `(2 - (2 / t))` factor of
462/// the weight of the optimal Steiner tree where `t` is the number of
463/// terminal nodes.
464/// The algorithm implemented here is due to Mehlhorn[^Mehlhorn1987]. It avoids
465/// computing all pairs shortest paths but rather reduces the problem to a
466/// single source shortest path and a minimum spanning tree problem.
467///
468/// # Arguments
469///
470/// - `graph` - The input graph to compute the Steiner tree of
471/// - `terminal_nodes` - The terminal nodes of the Steiner tree
472/// - `weight_fn` - A callable weight function that will be passed an edge reference
473/// for each edge in the graph and it is expected to return a [`Result<f64>`]
474/// which if it doesn't error represents the weight of that edge.
475///
476/// # Returns
477///
478/// A custom struct that contains a set of nodes and edges and `None`
479/// if the graph is disconnected relative to the terminal nodes.
480///
481/// # Example
482///
483/// ```rust
484/// use std::convert::Infallible;
485///
486/// use rustworkx_core::petgraph::Graph;
487/// use rustworkx_core::petgraph::graph::NodeIndex;
488/// use rustworkx_core::petgraph::Undirected;
489/// use rustworkx_core::petgraph::graph::EdgeReference;
490/// use rustworkx_core::petgraph::visit::{IntoEdgeReferences, EdgeRef};
491///
492/// use rustworkx_core::steiner_tree::steiner_tree;
493///
494/// let input_graph = Graph::<(), u8, Undirected>::from_edges(&[
495/// (0, 1, 10),
496/// (1, 2, 10),
497/// (2, 3, 10),
498/// (3, 4, 10),
499/// (4, 5, 10),
500/// (1, 6, 1),
501/// (6, 4, 1),
502/// ]);
503///
504/// let weight_fn = |e: EdgeReference<u8>| -> Result<f64, Infallible> {
505/// Ok(*e.weight() as f64)
506/// };
507/// let terminal_nodes = vec![
508/// NodeIndex::new(0),
509/// NodeIndex::new(1),
510/// NodeIndex::new(2),
511/// NodeIndex::new(3),
512/// NodeIndex::new(4),
513/// NodeIndex::new(5),
514/// ];
515///
516/// let tree = steiner_tree(&input_graph, &terminal_nodes, weight_fn).unwrap().unwrap();
517/// ```
518///
519/// [^KouMarkowskyBerman1981]: Kou, Markowsky & Berman,
520/// "A fast algorithm for Steiner trees"
521/// Acta Informatica 15, 141–145 (1981)
522/// <https://link.springer.com/article/10.1007/BF00288961>
523/// [^Mehlhorn1987]: Kurt Mehlhorn,
524/// "A faster approximation algorithm for the Steiner problem in graphs"
525/// Information Processing Letters 27(3), 125-128 (1987)
526/// <https://doi.org/10.1016/0020-0190(88)90066-X>
527pub fn steiner_tree<G, F, E>(
528 graph: G,
529 terminal_nodes: &[G::NodeId],
530 weight_fn: F,
531) -> Result<Option<SteinerTreeResult>, E>
532where
533 G: IntoEdges
534 + NodeIndexable
535 + Sync
536 + EdgeCount
537 + IntoNodeReferences
538 + EdgeIndexable
539 + Visitable
540 + NodeCount,
541 G::NodeId: Eq + Hash + Send,
542 G::EdgeId: Eq + Hash + Send,
543 F: FnMut(G::EdgeRef) -> Result<f64, E>,
544{
545 let node_bound = graph.node_bound();
546 let mut edge_list = fast_metric_edges(graph, terminal_nodes, weight_fn)?;
547 let mut subgraphs = UnionFind::<usize>::new(node_bound);
548 edge_list.par_sort_unstable_by(|a, b| {
549 let weight_a = (a.distance, a.source, a.target);
550 let weight_b = (b.distance, b.source, b.target);
551 weight_a.partial_cmp(&weight_b).unwrap_or(Ordering::Less)
552 });
553 let mut mst_edges: Vec<MetricClosureEdge> = Vec::new();
554 for float_edge_pair in edge_list {
555 let u = float_edge_pair.source;
556 let v = float_edge_pair.target;
557 if subgraphs.union(u, v) {
558 mst_edges.push(float_edge_pair);
559 }
560 }
561 // assert that the terminal nodes are connected.
562 if !terminal_nodes.is_empty() && mst_edges.len() != terminal_nodes.len() - 1 {
563 return Ok(None);
564 }
565 // Generate the output graph from the MST
566 let out_edge_list: Vec<[usize; 2]> = mst_edges
567 .into_iter()
568 .flat_map(|edge| pairwise(edge.path))
569 .filter_map(|x| x.0.map(|a| [a, x.1]))
570 .collect();
571 let out_edges: HashSet<(usize, usize)> = out_edge_list.iter().map(|x| (x[0], x[1])).collect();
572 let out_nodes: HashSet<usize> = out_edge_list
573 .iter()
574 .flat_map(|x| x.iter())
575 .copied()
576 .collect();
577 Ok(Some(SteinerTreeResult {
578 used_node_indices: out_nodes,
579 used_edge_endpoints: out_edges,
580 }))
581}