rustworkx_core/
steiner_tree.rs

1// Licensed under the Apache License, Version 2.0 (the "License"); you may
2// not use this file except in compliance with the License. You may obtain
3// a copy of the License at
4//
5//     http://www.apache.org/licenses/LICENSE-2.0
6//
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
9// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
10// License for the specific language governing permissions and limitations
11// under the License.
12
13use std::cmp::{Eq, Ordering};
14use std::convert::Infallible;
15use std::hash::Hash;
16
17use hashbrown::{HashMap, HashSet};
18use rayon::prelude::*;
19
20use petgraph::stable_graph::{EdgeIndex, NodeIndex, StableGraph};
21use petgraph::unionfind::UnionFind;
22use petgraph::visit::{
23    EdgeCount, EdgeIndexable, EdgeRef, GraphProp, IntoEdgeReferences, IntoEdges,
24    IntoNodeIdentifiers, IntoNodeReferences, NodeCount, NodeIndexable, NodeRef, Visitable,
25};
26use petgraph::Undirected;
27
28use crate::dictmap::*;
29use crate::shortest_path::dijkstra;
30use crate::utils::pairwise;
31
/// Result of running Dijkstra from every node, keyed by source node index.
///
/// Each value is a pair `(paths, distances)`: `paths` maps a destination node
/// index to the list of node indices along the path found to it, and
/// `distances` maps a destination node index to the `f64` length of that path.
type AllPairsDijkstraReturn = HashMap<usize, (DictMap<usize, Vec<usize>>, DictMap<usize, f64>)>;
33
34fn all_pairs_dijkstra_shortest_paths<G, F, E>(
35    graph: G,
36    mut weight_fn: F,
37) -> Result<AllPairsDijkstraReturn, E>
38where
39    G: NodeIndexable
40        + IntoNodeIdentifiers
41        + EdgeCount
42        + NodeCount
43        + EdgeIndexable
44        + Visitable
45        + Sync
46        + IntoEdges,
47    G::NodeId: Eq + Hash + Send,
48    G::EdgeId: Eq + Hash + Send,
49    F: FnMut(G::EdgeRef) -> Result<f64, E>,
50{
51    if graph.node_count() == 0 {
52        return Ok(HashMap::new());
53    } else if graph.edge_count() == 0 {
54        return Ok(graph
55            .node_identifiers()
56            .map(|x| {
57                (
58                    NodeIndexable::to_index(&graph, x),
59                    (DictMap::new(), DictMap::new()),
60                )
61            })
62            .collect());
63    }
64
65    let mut edge_weights: Vec<Option<f64>> = vec![None; graph.edge_bound()];
66    for edge in graph.edge_references() {
67        let index = EdgeIndexable::to_index(&graph, edge.id());
68        edge_weights[index] = Some(weight_fn(edge)?);
69    }
70    let edge_cost = |e: G::EdgeRef| -> Result<f64, Infallible> {
71        Ok(edge_weights[EdgeIndexable::to_index(&graph, e.id())].unwrap())
72    };
73
74    let node_indices: Vec<usize> = graph
75        .node_identifiers()
76        .map(|n| NodeIndexable::to_index(&graph, n))
77        .collect();
78    Ok(node_indices
79        .into_par_iter()
80        .map(|x| {
81            let mut paths: DictMap<G::NodeId, Vec<G::NodeId>> =
82                DictMap::with_capacity(graph.node_count());
83            let distances: DictMap<G::NodeId, f64> = dijkstra(
84                graph,
85                NodeIndexable::from_index(&graph, x),
86                None,
87                edge_cost,
88                Some(&mut paths),
89            )
90            .unwrap();
91            (
92                x,
93                (
94                    paths
95                        .into_iter()
96                        .map(|(k, v)| {
97                            (
98                                NodeIndexable::to_index(&graph, k),
99                                v.into_iter()
100                                    .map(|n| NodeIndexable::to_index(&graph, n))
101                                    .collect(),
102                            )
103                        })
104                        .collect(),
105                    distances
106                        .into_iter()
107                        .map(|(k, v)| (NodeIndexable::to_index(&graph, k), v))
108                        .collect(),
109                ),
110            )
111        })
112        .collect())
113}
114
/// A weighted edge between two nodes together with the path in the input
/// graph that realizes it.
struct MetricClosureEdge {
    /// Node index of one endpoint of the edge.
    source: usize,
    /// Node index of the other endpoint of the edge.
    target: usize,
    /// Total weight of `path`.
    distance: f64,
    /// Node indices of the input graph visited along the path.
    path: Vec<usize>,
}
121
122/// Return the metric closure of a graph
123///
124/// The metric closure of a graph is the complete graph in which each edge is
125/// weighted by the shortest path distance between the nodes in the graph.
126///
127/// Arguments:
128///     `graph`: The input graph to compute the metric closure for
129///     `weight_fn`: A callable weight function that will be passed an edge reference
130///         for each edge in the graph and it is expected to return a `Result<f64>`
131///         which if it doesn't error represents the weight of that edge.
132///     `default_weight`: A blind callable that returns a default weight to use for
133///         edges added to the output
134///
135/// Returns a `StableGraph` with the input graph node ids for node weights and edge weights with a
136///     tuple of the numeric weight (found via `weight_fn`) and the path. The output will be `None`
137///     if `graph` is disconnected.
138///
139/// # Example
140/// ```rust
141/// use std::convert::Infallible;
142///
143/// use rustworkx_core::petgraph::Graph;
144/// use rustworkx_core::petgraph::Undirected;
145/// use rustworkx_core::petgraph::graph::EdgeReference;
146/// use rustworkx_core::petgraph::visit::{IntoEdgeReferences, EdgeRef};
147///
148/// use rustworkx_core::steiner_tree::metric_closure;
149///
150/// let input_graph = Graph::<(), u8, Undirected>::from_edges(&[
151///     (0, 1, 10),
152///     (1, 2, 10),
153///     (2, 3, 10),
154///     (3, 4, 10),
155///     (4, 5, 10),
156///     (1, 6, 1),
157///     (6, 4, 1),
158/// ]);
159///
160/// let weight_fn = |e: EdgeReference<u8>| -> Result<f64, Infallible> {
161///    Ok(*e.weight() as f64)
162/// };
163///
164/// let closure = metric_closure(&input_graph, weight_fn).unwrap().unwrap();
165/// let mut output_edge_list: Vec<(usize, usize, (f64, Vec<usize>))> = closure.edge_references().map(|edge| (edge.source().index(), edge.target().index(), edge.weight().clone())).collect();
166/// let mut expected_edges: Vec<(usize, usize, (f64, Vec<usize>))> = vec![
167///     (0, 1, (10.0, vec![0, 1])),
168///     (0, 2, (20.0, vec![0, 1, 2])),
169///     (0, 3, (22.0, vec![0, 1, 6, 4, 3])),
170///     (0, 4, (12.0, vec![0, 1, 6, 4])),
171///     (0, 5, (22.0, vec![0, 1, 6, 4, 5])),
172///     (0, 6, (11.0, vec![0, 1, 6])),
173///     (1, 2, (10.0, vec![1, 2])),
174///     (1, 3, (12.0, vec![1, 6, 4, 3])),
175///     (1, 4, (2.0, vec![1, 6, 4])),
176///     (1, 5, (12.0, vec![1, 6, 4, 5])),
177///     (1, 6, (1.0, vec![1, 6])),
178///     (2, 3, (10.0, vec![2, 3])),
179///     (2, 4, (12.0, vec![2, 1, 6, 4])),
180///     (2, 5, (22.0, vec![2, 1, 6, 4, 5])),
181///     (2, 6, (11.0, vec![2, 1, 6])),
182///     (3, 4, (10.0, vec![3, 4])),
183///     (3, 5, (20.0, vec![3, 4, 5])),
184///     (3, 6, (11.0, vec![3, 4, 6])),
185///     (4, 5, (10.0, vec![4, 5])),
186///     (4, 6, (1.0, vec![4, 6])),
187///     (5, 6, (11.0, vec![5, 4, 6])),
188/// ];
189/// output_edge_list.sort_by_key(|x| [x.0, x.1]);
190/// expected_edges.sort_by_key(|x| [x.0, x.1]);
191/// assert_eq!(output_edge_list, expected_edges);
192///
193/// ```
194#[allow(clippy::type_complexity)]
195pub fn metric_closure<G, F, E>(
196    graph: G,
197    weight_fn: F,
198) -> Result<Option<StableGraph<G::NodeId, (f64, Vec<usize>), Undirected>>, E>
199where
200    G: NodeIndexable
201        + EdgeIndexable
202        + Sync
203        + EdgeCount
204        + NodeCount
205        + Visitable
206        + IntoNodeReferences
207        + IntoEdges
208        + Visitable
209        + GraphProp<EdgeType = Undirected>,
210    G::NodeId: Eq + Hash + NodeRef + Send,
211    G::EdgeId: Eq + Hash + Send,
212    G::NodeWeight: Clone,
213    F: FnMut(G::EdgeRef) -> Result<f64, E>,
214{
215    let mut out_graph: StableGraph<G::NodeId, (f64, Vec<usize>), Undirected> =
216        StableGraph::with_capacity(graph.node_count(), graph.edge_count());
217    let node_map: HashMap<usize, NodeIndex> = graph
218        .node_references()
219        .map(|node| {
220            (
221                NodeIndexable::to_index(&graph, node.id()),
222                out_graph.add_node(node.id()),
223            )
224        })
225        .collect();
226    let edges = metric_closure_edges(graph, weight_fn)?;
227    if edges.is_none() {
228        return Ok(None);
229    }
230    for edge in edges.unwrap() {
231        out_graph.add_edge(
232            node_map[&edge.source],
233            node_map[&edge.target],
234            (edge.distance, edge.path),
235        );
236    }
237    Ok(Some(out_graph))
238}
239
240fn metric_closure_edges<G, F, E>(
241    graph: G,
242    weight_fn: F,
243) -> Result<Option<Vec<MetricClosureEdge>>, E>
244where
245    G: NodeIndexable
246        + Sync
247        + Visitable
248        + IntoNodeReferences
249        + IntoEdges
250        + Visitable
251        + NodeIndexable
252        + NodeCount
253        + EdgeCount
254        + EdgeIndexable,
255    G::NodeId: Eq + Hash + Send,
256    G::EdgeId: Eq + Hash + Send,
257    F: FnMut(G::EdgeRef) -> Result<f64, E>,
258{
259    let node_count = graph.node_count();
260    if node_count == 0 {
261        return Ok(Some(Vec::new()));
262    }
263    let mut out_vec = Vec::with_capacity(node_count * (node_count - 1) / 2);
264    let paths = all_pairs_dijkstra_shortest_paths(graph, weight_fn)?;
265    let mut nodes: HashSet<usize> = graph
266        .node_identifiers()
267        .map(|x| NodeIndexable::to_index(&graph, x))
268        .collect();
269    let first_node = graph
270        .node_identifiers()
271        .map(|x| NodeIndexable::to_index(&graph, x))
272        .next()
273        .unwrap();
274    let path_keys: HashSet<usize> = paths[&first_node].0.keys().copied().collect();
275    // first_node will always be missing from path_keys so if the difference
276    // is > 1 with nodes that means there is another node in the graph that
277    // first_node doesn't have a path to.
278    if nodes.difference(&path_keys).count() > 1 {
279        return Ok(None);
280    }
281    // Iterate over node indices for a deterministic order
282    for node in graph
283        .node_identifiers()
284        .map(|x| NodeIndexable::to_index(&graph, x))
285    {
286        let path_map = &paths[&node].0;
287        nodes.remove(&node);
288        let distance = &paths[&node].1;
289        for v in &nodes {
290            out_vec.push(MetricClosureEdge {
291                source: node,
292                target: *v,
293                distance: distance[v],
294                path: path_map[v].clone(),
295            });
296        }
297    }
298    Ok(Some(out_vec))
299}
300
301/// Computes the shortest path between all pairs `(s, t)` of the given `terminal_nodes`
302/// *provided* that:
303///   - there is an edge `(u, v)` in the graph and path pass through this edge.
304///   - node `s` is the closest node to  `u` among all `terminal_nodes`
305///   - node `t` is the closest node to `v` among all `terminal_nodes`
306///     and wraps the result inside a `MetricClosureEdge`
307///
308/// For example, if all vertices are terminals, it returns the original edges of the graph.
309fn fast_metric_edges<G, F, E>(
310    in_graph: G,
311    terminal_nodes: &[G::NodeId],
312    mut weight_fn: F,
313) -> Result<Vec<MetricClosureEdge>, E>
314where
315    G: IntoEdges
316        + NodeIndexable
317        + EdgeIndexable
318        + Sync
319        + EdgeCount
320        + Visitable
321        + IntoNodeReferences
322        + NodeCount,
323    G::NodeId: Eq + Hash + Send,
324    G::EdgeId: Eq + Hash + Send,
325    F: FnMut(G::EdgeRef) -> Result<f64, E>,
326{
327    let mut graph: StableGraph<(), (), Undirected> = StableGraph::with_capacity(
328        in_graph.node_count() + 1,
329        in_graph.edge_count() + terminal_nodes.len(),
330    );
331    let node_map: HashMap<G::NodeId, NodeIndex> = in_graph
332        .node_references()
333        .map(|n| (n.id(), graph.add_node(())))
334        .collect();
335    let reverse_node_map: HashMap<NodeIndex, G::NodeId> =
336        node_map.iter().map(|(k, v)| (*v, *k)).collect();
337    let edge_map: HashMap<EdgeIndex, G::EdgeRef> = in_graph
338        .edge_references()
339        .map(|e| {
340            (
341                graph.add_edge(node_map[&e.source()], node_map[&e.target()], ()),
342                e,
343            )
344        })
345        .collect();
346
347    // temporarily add a ``dummy`` node, connect it with
348    // all the terminal nodes and find all the shortest paths
349    // starting from ``dummy`` node.
350    let dummy = graph.add_node(());
351    for node in terminal_nodes {
352        graph.add_edge(dummy, node_map[node], ());
353    }
354
355    let mut paths = DictMap::with_capacity(graph.node_count());
356
357    let mut wrapped_weight_fn =
358        |e: <&StableGraph<(), ()> as IntoEdgeReferences>::EdgeRef| -> Result<f64, E> {
359            if let Some(edge_ref) = edge_map.get(&e.id()) {
360                weight_fn(*edge_ref)
361            } else {
362                Ok(0.0)
363            }
364        };
365
366    let mut distance: DictMap<NodeIndex, f64> = dijkstra(
367        &graph,
368        dummy,
369        None,
370        &mut wrapped_weight_fn,
371        Some(&mut paths),
372    )?;
373    paths.swap_remove(&dummy);
374    distance.swap_remove(&dummy);
375
376    // ``partition[u]`` holds the terminal node closest to node ``u``.
377    let mut partition: Vec<usize> = vec![usize::MAX; graph.node_bound()];
378    for (u, path) in paths.iter() {
379        let u = NodeIndexable::to_index(&in_graph, reverse_node_map[u]);
380        partition[u] = NodeIndexable::to_index(&in_graph, reverse_node_map[&path[1]]);
381    }
382
383    let mut out_edges: Vec<MetricClosureEdge> = Vec::with_capacity(graph.edge_count());
384
385    for edge in graph.edge_references() {
386        let source = edge.source();
387        let target = edge.target();
388        // assert that ``source`` is reachable from a terminal node.
389        if distance.contains_key(&source) {
390            let weight = distance[&source] + wrapped_weight_fn(edge)? + distance[&target];
391            let mut path: Vec<usize> = paths[&source]
392                .iter()
393                .skip(1)
394                .map(|x| NodeIndexable::to_index(&in_graph, reverse_node_map[x]))
395                .collect();
396            path.append(
397                &mut paths[&target]
398                    .iter()
399                    .skip(1)
400                    .rev()
401                    .map(|x| NodeIndexable::to_index(&in_graph, reverse_node_map[x]))
402                    .collect(),
403            );
404
405            let source = NodeIndexable::to_index(&in_graph, reverse_node_map[&source]);
406            let target = NodeIndexable::to_index(&in_graph, reverse_node_map[&target]);
407
408            let mut source = partition[source];
409            let mut target = partition[target];
410
411            match source.cmp(&target) {
412                Ordering::Equal => continue,
413                Ordering::Greater => std::mem::swap(&mut source, &mut target),
414                _ => {}
415            }
416
417            out_edges.push(MetricClosureEdge {
418                source,
419                target,
420                distance: weight,
421                path,
422            });
423        }
424    }
425
426    // if parallel edges, keep the edge with minimum distance.
427    out_edges.par_sort_unstable_by(|a, b| {
428        let weight_a = (a.source, a.target, a.distance);
429        let weight_b = (b.source, b.target, b.distance);
430        weight_a.partial_cmp(&weight_b).unwrap_or(Ordering::Less)
431    });
432
433    out_edges.dedup_by(|edge_a, edge_b| {
434        edge_a.source == edge_b.source && edge_a.target == edge_b.target
435    });
436
437    Ok(out_edges)
438}
439
/// Solution to a minimum Steiner tree problem.
///
/// This `struct` is created by the [`steiner_tree`] function.
pub struct SteinerTreeResult {
    /// Node indices of the input graph that appear as an endpoint of at least
    /// one edge in `used_edge_endpoints`.
    pub used_node_indices: HashSet<usize>,
    /// Edges of the tree, each given as a pair of node indices of the input
    /// graph; a pair is two consecutive nodes on one of the selected paths,
    /// in the order they occur on that path.
    pub used_edge_endpoints: HashSet<(usize, usize)>,
}
447
448/// Return an approximation to the minimum Steiner tree of a graph.
449///
450/// The minimum tree of ``graph`` with regard to a set of ``terminal_nodes``
451/// is a tree within ``graph`` that spans those nodes and has a minimum size
452/// (measured as the sum of edge weights) among all such trees.
453///
454/// The minimum steiner tree can be approximated by computing the minimum
455/// spanning tree of the subgraph of the metric closure of ``graph`` induced
456/// by the terminal nodes, where the metric closure of ``graph`` is the
457/// complete graph in which each edge is weighted by the shortest path distance
458/// between nodes in ``graph``.
459///
460/// This algorithm by Kou, Markowsky, and Berman[^KouMarkowskyBerman1981]
461/// produces a tree whose weight is within a `(2 - (2 / t))` factor of
462/// the weight of the optimal Steiner tree where `t` is the number of
463/// terminal nodes.
464/// The algorithm implemented here is due to Mehlhorn[^Mehlhorn1987]. It avoids
465/// computing all pairs shortest paths but rather reduces the problem to a
466/// single source shortest path and a minimum spanning tree problem.
467///
468/// # Arguments
469///
470/// - `graph` -  The input graph to compute the Steiner tree of
471/// - `terminal_nodes` - The terminal nodes of the Steiner tree
472/// - `weight_fn` - A callable weight function that will be passed an edge reference
473///   for each edge in the graph and it is expected to return a [`Result<f64>`]
474///   which if it doesn't error represents the weight of that edge.
475///
476/// # Returns
477///
478/// A custom struct that contains a set of nodes and edges and `None`
479/// if the graph is disconnected relative to the terminal nodes.
480///
481/// # Example
482///
483/// ```rust
484/// use std::convert::Infallible;
485///
486/// use rustworkx_core::petgraph::Graph;
487/// use rustworkx_core::petgraph::graph::NodeIndex;
488/// use rustworkx_core::petgraph::Undirected;
489/// use rustworkx_core::petgraph::graph::EdgeReference;
490/// use rustworkx_core::petgraph::visit::{IntoEdgeReferences, EdgeRef};
491///
492/// use rustworkx_core::steiner_tree::steiner_tree;
493///
494/// let input_graph = Graph::<(), u8, Undirected>::from_edges(&[
495///     (0, 1, 10),
496///     (1, 2, 10),
497///     (2, 3, 10),
498///     (3, 4, 10),
499///     (4, 5, 10),
500///     (1, 6, 1),
501///     (6, 4, 1),
502/// ]);
503///
504/// let weight_fn = |e: EdgeReference<u8>| -> Result<f64, Infallible> {
505///    Ok(*e.weight() as f64)
506/// };
507/// let terminal_nodes = vec![
508///     NodeIndex::new(0),
509///     NodeIndex::new(1),
510///     NodeIndex::new(2),
511///     NodeIndex::new(3),
512///     NodeIndex::new(4),
513///     NodeIndex::new(5),
514/// ];
515///
516/// let tree = steiner_tree(&input_graph, &terminal_nodes, weight_fn).unwrap().unwrap();
517/// ```
518///
519/// [^KouMarkowskyBerman1981]: Kou, Markowsky & Berman,
520///    "A fast algorithm for Steiner trees"
521///    Acta Informatica 15, 141–145 (1981)
522///    <https://link.springer.com/article/10.1007/BF00288961>
523/// [^Mehlhorn1987]: Kurt Mehlhorn,
524///    "A faster approximation algorithm for the Steiner problem in graphs"
525///    Information Processing Letters 27(3), 125-128 (1987)
526///    <https://doi.org/10.1016/0020-0190(88)90066-X>
527pub fn steiner_tree<G, F, E>(
528    graph: G,
529    terminal_nodes: &[G::NodeId],
530    weight_fn: F,
531) -> Result<Option<SteinerTreeResult>, E>
532where
533    G: IntoEdges
534        + NodeIndexable
535        + Sync
536        + EdgeCount
537        + IntoNodeReferences
538        + EdgeIndexable
539        + Visitable
540        + NodeCount,
541    G::NodeId: Eq + Hash + Send,
542    G::EdgeId: Eq + Hash + Send,
543    F: FnMut(G::EdgeRef) -> Result<f64, E>,
544{
545    let node_bound = graph.node_bound();
546    let mut edge_list = fast_metric_edges(graph, terminal_nodes, weight_fn)?;
547    let mut subgraphs = UnionFind::<usize>::new(node_bound);
548    edge_list.par_sort_unstable_by(|a, b| {
549        let weight_a = (a.distance, a.source, a.target);
550        let weight_b = (b.distance, b.source, b.target);
551        weight_a.partial_cmp(&weight_b).unwrap_or(Ordering::Less)
552    });
553    let mut mst_edges: Vec<MetricClosureEdge> = Vec::new();
554    for float_edge_pair in edge_list {
555        let u = float_edge_pair.source;
556        let v = float_edge_pair.target;
557        if subgraphs.union(u, v) {
558            mst_edges.push(float_edge_pair);
559        }
560    }
561    // assert that the terminal nodes are connected.
562    if !terminal_nodes.is_empty() && mst_edges.len() != terminal_nodes.len() - 1 {
563        return Ok(None);
564    }
565    // Generate the output graph from the MST
566    let out_edge_list: Vec<[usize; 2]> = mst_edges
567        .into_iter()
568        .flat_map(|edge| pairwise(edge.path))
569        .filter_map(|x| x.0.map(|a| [a, x.1]))
570        .collect();
571    let out_edges: HashSet<(usize, usize)> = out_edge_list.iter().map(|x| (x[0], x[1])).collect();
572    let out_nodes: HashSet<usize> = out_edge_list
573        .iter()
574        .flat_map(|x| x.iter())
575        .copied()
576        .collect();
577    Ok(Some(SteinerTreeResult {
578        used_node_indices: out_nodes,
579        used_edge_endpoints: out_edges,
580    }))
581}