rustworkx_core/
dag_algo.rs

1// Licensed under the Apache License, Version 2.0 (the "License"); you may
2// not use this file except in compliance with the License. You may obtain
3// a copy of the License at
4//
5//     http://www.apache.org/licenses/LICENSE-2.0
6//
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
9// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
10// License for the specific language governing permissions and limitations
11// under the License.
12
13use std::cmp::{Eq, Ordering};
14use std::collections::BinaryHeap;
15use std::error::Error;
16use std::fmt::{Display, Formatter};
17use std::hash::Hash;
18
19use hashbrown::{HashMap, HashSet};
20use std::fmt::Debug;
21use std::mem::swap;
22
23use petgraph::algo;
24use petgraph::data::DataMap;
25use petgraph::visit::{
26    EdgeRef, GraphBase, GraphProp, IntoEdgesDirected, IntoNeighborsDirected, IntoNodeIdentifiers,
27    NodeCount, NodeIndexable, Visitable,
28};
29use petgraph::Directed;
30
31use num_traits::{Num, Zero};
32
33use crate::err::LayersError;
34
35/// Return a pair of [`petgraph::Direction`] values corresponding to the "forwards" and "backwards"
36/// direction of graph traversal, based on whether the graph is being traversed forwards (following
37/// the edges) or backward (reversing along edges).  The order of returns is (forwards, backwards).
38#[inline(always)]
39pub fn traversal_directions(reverse: bool) -> (petgraph::Direction, petgraph::Direction) {
40    if reverse {
41        (petgraph::Direction::Outgoing, petgraph::Direction::Incoming)
42    } else {
43        (petgraph::Direction::Incoming, petgraph::Direction::Outgoing)
44    }
45}
46
/// The ways in which the topological sort functions of this module can fail.
#[derive(Debug, PartialEq, Eq)]
pub enum TopologicalSortError<E: Error> {
    /// A node that was expected to have no outstanding predecessors was reached again from
    /// another node, i.e. the graph has a cycle or the supplied initial nodes are ordered
    /// relative to each other.
    CycleOrBadInitialState,
    /// The user-supplied key callback failed; the callback's own error is carried along.
    KeyError(E),
}

impl<E: Error> Display for TopologicalSortError<E> {
    /// Render a human-readable description of the failure.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::CycleOrBadInitialState => {
                f.write_str("At least one initial node is reachable from another")
            }
            // The wrapped error is shown through its `Debug` representation.
            Self::KeyError(e) => write!(f, "The key callback failed with: {:?}", e),
        }
    }
}

impl<E: Error> Error for TopologicalSortError<E> {}
68
69/// Get the lexicographical topological sorted nodes from the provided DAG.
70///
71/// This function returns a list of nodes in a graph lexicographically
72/// topologically sorted, using the provided key function as a tie-breaker.
73/// A topological sort is a linear ordering of vertices such that for every
74/// directed edge from node :math:`u` to node :math:`v`, :math:`u` comes before
75/// :math:`v` in the ordering.  If ``reverse`` is set to ``true``, the edges
76/// are treated as if they pointed in the opposite direction.
77///
78/// # Arguments:
79///
80/// * `dag`: The DAG to get the topological sorted nodes from
81/// * `key`: A function that gets passed a single argument, the node id from
82///     `dag` and is expected to return a key which will be used for
83///     resolving ties in the sorting order.
84/// * `reverse`: If `false`, perform a regular topological ordering.  If `true`,
85///     return the lexicographical topological order that would have been found
86///     if all the edges in the graph were reversed.  This does not affect the
87///     comparisons from the `key`.
88/// * `initial`: If given, the initial node indices to start the topological
89///     ordering from.  If not given, the topological ordering will certainly contain every node in
90///     the graph.  If given, only the `initial` nodes and nodes that are dominated by the
91///     `initial` set will be in the ordering.  Notably, any node that has a natural in degree of
92///     zero will not be in the output ordering if `initial` is given and the zero-in-degree node
93///     is not in it.  It is not supported to give an `initial` set where the nodes have even
94///     a partial topological order between themselves and `None` will be returned in this case
95///
96/// # Returns
97///
98/// `Vec<G::NodeId>` representing the topological ordering of nodes or a
99/// `TopologicalSortError` if there is an error.
100///
101/// # Example
102///
103/// ```rust
104/// use std::convert::Infallible;
105///
106/// use rustworkx_core::dag_algo::lexicographical_topological_sort;
107/// use rustworkx_core::petgraph::stable_graph::{StableDiGraph, NodeIndex};
108///
109/// let mut graph: StableDiGraph<u8, ()> = StableDiGraph::new();
110/// let mut nodes: Vec<NodeIndex> = Vec::new();
111/// for weight in 0..9 {
112///     nodes.push(graph.add_node(weight));
113/// }
114/// let edges = [
115///         (nodes[0], nodes[1]),
116///         (nodes[0], nodes[2]),
117///         (nodes[1], nodes[3]),
118///         (nodes[2], nodes[4]),
119///         (nodes[3], nodes[4]),
120///         (nodes[4], nodes[5]),
121///         (nodes[5], nodes[6]),
122///         (nodes[4], nodes[7]),
123///         (nodes[6], nodes[8]),
124///         (nodes[7], nodes[8]),
125/// ];
126/// for (source, target) in edges {
127///     graph.add_edge(source, target, ());
128/// }
129/// let sort_fn = |index: NodeIndex| -> Result<String, Infallible> { Ok(graph[index].to_string()) };
130/// let initial = [nodes[6], nodes[7]];
131/// let result = lexicographical_topological_sort(&graph, sort_fn, true, Some(&initial));
132/// let expected = vec![
133///     nodes[6],
134///     nodes[5],
135///     nodes[7],
136///     nodes[4],
137///     nodes[2],
138///     nodes[3],
139///     nodes[1],
140///     nodes[0]
141/// ];
142/// assert_eq!(result, Ok(expected));
143///
144/// ```
145pub fn lexicographical_topological_sort<G, F, K, E>(
146    dag: G,
147    mut key: F,
148    reverse: bool,
149    initial: Option<&[G::NodeId]>,
150) -> Result<Vec<G::NodeId>, TopologicalSortError<E>>
151where
152    G: GraphProp<EdgeType = Directed>
153        + IntoNodeIdentifiers
154        + IntoNeighborsDirected
155        + IntoEdgesDirected
156        + NodeCount,
157    <G as GraphBase>::NodeId: Hash + Eq + Ord,
158    F: FnMut(G::NodeId) -> Result<K, E>,
159    K: Ord,
160    E: Error,
161{
162    // HashMap of node_index indegree
163    let node_count = dag.node_count();
164    let (in_dir, out_dir) = traversal_directions(reverse);
165
166    #[derive(Clone, Eq, PartialEq)]
167    struct State<K: Ord, N: Eq + PartialOrd> {
168        key: K,
169        node: N,
170    }
171
172    impl<K: Ord, N: Eq + Ord> Ord for State<K, N> {
173        fn cmp(&self, other: &State<K, N>) -> Ordering {
174            // Notice that we flip the ordering on costs.
175            // In case of a tie, we compare positions - this step is necessary
176            // to make implementations of `PartialEq` and `Ord` consistent.
177            other
178                .key
179                .cmp(&self.key)
180                .then_with(|| other.node.cmp(&self.node))
181        }
182    }
183
184    // `PartialOrd` needs to be implemented as well.
185    impl<K: Ord, N: Eq + Ord> PartialOrd for State<K, N> {
186        fn partial_cmp(&self, other: &State<K, N>) -> Option<Ordering> {
187            Some(self.cmp(other))
188        }
189    }
190
191    let mut in_degree_map: HashMap<G::NodeId, usize> = HashMap::with_capacity(node_count);
192    if let Some(initial) = initial {
193        // In this case, we don't iterate through all the nodes in the graph, and most nodes aren't
194        // in `in_degree_map`; we'll fill in the relevant edge counts lazily.
195        for node in initial {
196            in_degree_map.insert(*node, 0);
197        }
198    } else {
199        for node in dag.node_identifiers() {
200            in_degree_map.insert(node, dag.edges_directed(node, in_dir).count());
201        }
202    }
203
204    let mut zero_indegree = BinaryHeap::with_capacity(node_count);
205    for (node, degree) in in_degree_map.iter() {
206        if *degree == 0 {
207            let map_key = key(*node).map_err(|e| TopologicalSortError::KeyError(e))?;
208            zero_indegree.push(State {
209                key: map_key,
210                node: *node,
211            });
212        }
213    }
214    let mut out_list: Vec<G::NodeId> = Vec::with_capacity(node_count);
215    while let Some(State { node, .. }) = zero_indegree.pop() {
216        let neighbors = dag.neighbors_directed(node, out_dir);
217        for child in neighbors {
218            let child_degree = in_degree_map
219                .entry(child)
220                .or_insert_with(|| dag.edges_directed(child, in_dir).count());
221            if *child_degree == 0 {
222                return Err(TopologicalSortError::CycleOrBadInitialState);
223            } else if *child_degree == 1 {
224                let map_key = key(child).map_err(|e| TopologicalSortError::KeyError(e))?;
225                zero_indegree.push(State {
226                    key: map_key,
227                    node: child,
228                });
229                in_degree_map.remove(&child);
230            } else {
231                *child_degree -= 1;
232            }
233        }
234        out_list.push(node)
235    }
236    Ok(out_list)
237}
238
239// Type aliases for readability
240type NodeId<G> = <G as GraphBase>::NodeId;
241type LongestPathResult<G, T, E> = Result<Option<(Vec<NodeId<G>>, T)>, E>;
242
243/// Calculates the longest path in a directed acyclic graph (DAG).
244///
245/// This function computes the longest path by weight in a given DAG. It will return the longest path
246/// along with its total weight, or `None` if the graph contains cycles which make the longest path
247/// computation undefined.
248///
249/// # Arguments
250/// * `graph`: Reference to a directed graph.
251/// * `weight_fn` - An input callable that will be passed the `EdgeRef` for each edge in the graph.
252///   The callable should return the weight of the edge as `Result<T, E>`. The weight must be a type that implements
253///   `Num`, `Zero`, `PartialOrd`, and `Copy`.
254///
255/// # Type Parameters
256/// * `G`: Type of the graph. Must be a directed graph.
257/// * `F`: Type of the weight function.
258/// * `T`: The type of the edge weight. Must implement `Num`, `Zero`, `PartialOrd`, and `Copy`.
259/// * `E`: The type of the error that the weight function can return.
260///
261/// # Returns
262/// * `None` if the graph contains a cycle.
263/// * `Some((Vec<NodeId<G>>, T))` representing the longest path as a sequence of nodes and its total weight.
264/// * `Err(E)` if there is an error computing the weight of any edge.
265///
266/// # Example
267/// ```
268/// use petgraph::graph::DiGraph;
269/// use petgraph::Directed;
270/// use rustworkx_core::dag_algo::longest_path;
271///
272/// let mut graph: DiGraph<(), i32> = DiGraph::new();
273/// let n0 = graph.add_node(());
274/// let n1 = graph.add_node(());
275/// let n2 = graph.add_node(());
276/// graph.add_edge(n0, n1, 1);
277/// graph.add_edge(n0, n2, 3);
278/// graph.add_edge(n1, n2, 1);
279///
280/// let weight_fn = |edge: petgraph::graph::EdgeReference<i32>| Ok::<i32, &str>(*edge.weight());
281/// let result = longest_path(&graph, weight_fn).unwrap();
282/// assert_eq!(result, Some((vec![n0, n2], 3)));
283/// ```
284pub fn longest_path<G, F, T, E>(graph: G, mut weight_fn: F) -> LongestPathResult<G, T, E>
285where
286    G: GraphProp<EdgeType = Directed> + IntoNodeIdentifiers + IntoEdgesDirected + Visitable,
287    F: FnMut(G::EdgeRef) -> Result<T, E>,
288    T: Num + Zero + PartialOrd + Copy,
289    <G as GraphBase>::NodeId: Hash + Eq + PartialOrd,
290{
291    let mut path: Vec<NodeId<G>> = Vec::new();
292    let nodes = match algo::toposort(graph, None) {
293        Ok(nodes) => nodes,
294        Err(_) => return Ok(None), // Return None if the graph contains a cycle
295    };
296
297    if nodes.is_empty() {
298        return Ok(Some((path, T::zero())));
299    }
300
301    let mut dist: HashMap<G::NodeId, (T, G::NodeId)> = HashMap::with_capacity(nodes.len()); // Stores the distance and the previous node
302
303    // Iterate over nodes in topological order
304    for node in nodes {
305        let parents = graph.edges_directed(node, petgraph::Direction::Incoming);
306        let mut incoming_path: Vec<(T, G::NodeId)> = Vec::new(); // Stores the distance and the previous node for each parent
307        for p_edge in parents {
308            let p_node = p_edge.source();
309            let weight: T = weight_fn(p_edge)?;
310            let length = dist[&p_node].0 + weight;
311            incoming_path.push((length, p_node));
312        }
313        // Determine the maximum distance and corresponding parent node
314        let max_path: (T, G::NodeId) = incoming_path
315            .into_iter()
316            .max_by(|a, b| a.0.partial_cmp(&b.0).unwrap())
317            .unwrap_or((T::zero(), node)); // If there are no incoming edges, the distance is zero
318
319        // Store the maximum distance and the corresponding parent node for the current node
320        dist.insert(node, max_path);
321    }
322    let (first, _) = dist
323        .iter()
324        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
325        .unwrap();
326    let mut v = *first;
327    let mut u: Option<G::NodeId> = None;
328    // Backtrack from this node to find the path
329    #[allow(clippy::unnecessary_map_or)]
330    while u.map_or(true, |u| u != v) {
331        path.push(v);
332        u = Some(v);
333        v = dist[&v].1;
334    }
335    path.reverse(); // Reverse the path to get the correct order
336    let path_weight = dist[first].0; // The total weight of the longest path
337
338    Ok(Some((path, path_weight)))
339}
340
341/// Return an iterator of graph layers
342///
343/// A layer is a subgraph whose nodes are disjoint, i.e.,
344/// a layer has depth 1. The layers are constructed using a greedy algorithm.
345///
346/// Arguments:
347///
348/// * `graph` - The graph to get the layers from
349/// * `first_layer` - A list of node ids for the first layer. This
350///     will be the first layer in the output
351///
352/// Will `panic!` if a provided node is not in the graph.
353/// ```
354/// use rustworkx_core::petgraph::prelude::*;
355/// use rustworkx_core::dag_algo::layers;
356/// use rustworkx_core::dictmap::*;
357///
358/// let edge_list = vec![
359///  (0, 1),
360///  (1, 2),
361///  (2, 3),
362///  (3, 4),
363/// ];
364///
365/// let graph = DiGraph::<u32, u32>::from_edges(&edge_list);
366/// let layers: Vec<Vec<NodeIndex>> = layers(&graph, vec![0.into(),]).map(|layer| layer.unwrap()).collect();
367/// let expected_layers: Vec<Vec<NodeIndex>> = vec![
368///     vec![0.into(),],
369///     vec![1.into(),],
370///     vec![2.into(),],
371///     vec![3.into(),],
372///     vec![4.into()]
373/// ];
374/// assert_eq!(layers, expected_layers)
375/// ```
376pub fn layers<G>(
377    graph: G,
378    first_layer: Vec<G::NodeId>,
379) -> impl Iterator<Item = Result<Vec<G::NodeId>, LayersError>>
380where
381    G: NodeIndexable // Used in from_index and to_index.
382        + IntoNodeIdentifiers // Used for .node_identifiers
383        + IntoNeighborsDirected // Used for .neighbors_directed
384        + IntoEdgesDirected, // Used for .edged_directed
385    <G as GraphBase>::NodeId: Debug + Copy + Eq + Hash,
386{
387    LayersIter {
388        graph,
389        cur_layer: first_layer,
390        next_layer: vec![],
391        predecessor_count: HashMap::new(),
392        first_iter: true,
393        cycle_check: HashSet::default(),
394    }
395}
396
397#[derive(Debug, Clone)]
398struct LayersIter<G, N> {
399    graph: G,
400    cur_layer: Vec<N>,
401    next_layer: Vec<N>,
402    predecessor_count: HashMap<N, usize>,
403    first_iter: bool,
404    cycle_check: HashSet<N>, // TODO: Figure out why some cycles cannot be detected
405}
406
407impl<G, N> Iterator for LayersIter<G, N>
408where
409    G: NodeIndexable // Used in from_index and to_index.
410        + IntoNodeIdentifiers // Used for .node_identifiers
411        + IntoNeighborsDirected // Used for .neighbors_directed
412        + IntoEdgesDirected // Used for .edged_directed
413        + GraphBase<NodeId = N>,
414    N: Debug + Copy + Eq + Hash,
415{
416    type Item = Result<Vec<N>, LayersError>;
417    fn next(&mut self) -> Option<Self::Item> {
418        if self.first_iter {
419            self.first_iter = false;
420            for node in &self.cur_layer {
421                if self.graph.to_index(*node) >= self.graph.node_bound() {
422                    panic!("Node {:#?} is not present in the graph.", node);
423                }
424                if self.cycle_check.contains(node) {
425                    return Some(Err(LayersError(format!(
426                        "An invalid first layer was provided: {:#?} appears more than once.",
427                        node
428                    ))));
429                }
430                self.cycle_check.insert(*node);
431            }
432            Some(Ok(self.cur_layer.clone()))
433        } else if self.cur_layer.is_empty() {
434            None
435        } else {
436            for node in &self.cur_layer {
437                if self.graph.to_index(*node) >= self.graph.node_bound() {
438                    panic!("Node {:#?} is not present in the graph.", node);
439                }
440                let children = self
441                    .graph
442                    .neighbors_directed(*node, petgraph::Direction::Outgoing);
443                let mut used_indices: HashSet<G::NodeId> = HashSet::new();
444                for succ in children {
445                    // Skip duplicate successors
446                    if used_indices.contains(&succ) {
447                        continue;
448                    }
449                    used_indices.insert(succ);
450                    let mut multiplicity: usize = 0;
451                    let raw_edges: G::EdgesDirected = self
452                        .graph
453                        .edges_directed(*node, petgraph::Direction::Outgoing);
454                    for edge in raw_edges {
455                        if edge.target() == succ {
456                            multiplicity += 1;
457                        }
458                    }
459                    self.predecessor_count
460                        .entry(succ)
461                        .and_modify(|e| *e -= multiplicity)
462                        .or_insert(
463                            // Get the number of incoming edges to the successor
464                            self.graph
465                                .edges_directed(succ, petgraph::Direction::Incoming)
466                                .count()
467                                - multiplicity,
468                        );
469                    if *self.predecessor_count.get(&succ).unwrap() == 0 {
470                        if self.cycle_check.contains(&succ) {
471                            return Some(Err(LayersError("The provided graph contains a cycle or an invalid first layer was provided.".to_string())));
472                        }
473                        self.next_layer.push(succ);
474                        self.cycle_check.insert(succ);
475                        self.predecessor_count.remove(&succ);
476                    }
477                }
478            }
479            swap(&mut self.cur_layer, &mut self.next_layer);
480            self.next_layer.clear();
481            if self.cur_layer.is_empty() {
482                None
483            } else {
484                Some(Ok(self.cur_layer.clone()))
485            }
486        }
487    }
488}
489
490/// Collect runs that match a filter function given edge colors.
491///
492/// A bicolor run is a list of groups of nodes connected by edges of exactly
493/// two colors. In addition, all nodes in the group must match the given
494/// condition. Each node in the graph can appear in only a single group
495/// in the bicolor run.
496///
497/// # Arguments:
498///
499/// * `graph`: The DAG to find bicolor runs in
500/// * `filter_fn`: The filter function to use for matching nodes. It takes
501///     in one argument, the node data payload/weight object, and will return a
502///     boolean whether the node matches the conditions or not.
503///     If it returns ``true``, it will continue the bicolor chain.
504///     If it returns ``false``, it will stop the bicolor chain.
505///     If it returns ``None`` it will skip that node.
506/// * `color_fn`: The function that gives the color of the edge. It takes
507///     in one argument, the edge data payload/weight object, and will
508///     return a non-negative integer, the edge color. If the color is None,
509///     the edge is ignored.
510///
511/// # Returns:
512///
513/// * `Vec<Vec<G::NodeId>>`: a list of groups with exactly two edge colors, where each group
514///     is a list of node data payload/weight for the nodes in the bicolor run
515/// * `None` if a cycle is found in the graph
516/// * Raises an error if found computing the bicolor runs
517///
518/// # Example:
519///
520/// ```rust
521/// use rustworkx_core::dag_algo::collect_bicolor_runs;
522/// use petgraph::graph::{DiGraph, NodeIndex};
523/// use std::convert::Infallible;
524/// use std::error::Error;
525///
526/// let mut graph = DiGraph::new();
527/// let n0 = graph.add_node(0);
528/// let n1 = graph.add_node(0);
529/// let n2 = graph.add_node(1);
530/// let n3 = graph.add_node(1);
531/// let n4 = graph.add_node(0);
532/// let n5 = graph.add_node(0);
533/// graph.add_edge(n0, n2, 0);
534/// graph.add_edge(n1, n2, 1);
535/// graph.add_edge(n2, n3, 0);
536/// graph.add_edge(n2, n3, 1);
537/// graph.add_edge(n3, n4, 0);
538/// graph.add_edge(n3, n5, 1);
539///
540/// let filter_fn = |node_id| -> Result<Option<bool>, Infallible> {
541///     Ok(Some(*graph.node_weight(node_id).unwrap() > 0))
542/// };
543///
544/// let color_fn = |edge_id| -> Result<Option<usize>, Infallible> {
545///     Ok(Some(*graph.edge_weight(edge_id).unwrap() as usize))
546/// };
547///
548/// let result = collect_bicolor_runs(&graph, filter_fn, color_fn).unwrap();
549/// let expected: Vec<Vec<NodeIndex>> = vec![vec![n2, n3]];
550/// assert_eq!(result, Some(expected))
551/// ```
552pub fn collect_bicolor_runs<G, F, C, E>(
553    graph: G,
554    filter_fn: F,
555    color_fn: C,
556) -> Result<Option<Vec<Vec<G::NodeId>>>, E>
557where
558    F: Fn(<G as GraphBase>::NodeId) -> Result<Option<bool>, E>,
559    C: Fn(<G as GraphBase>::EdgeId) -> Result<Option<usize>, E>,
560    G: IntoNodeIdentifiers // Used in toposort
561        + IntoNeighborsDirected // Used in toposort
562        + IntoEdgesDirected // Used for .edges_directed
563        + Visitable // Used in toposort
564        + DataMap, // Used for .node_weight
565    <G as GraphBase>::NodeId: Eq + Hash,
566    <G as GraphBase>::EdgeId: Eq + Hash,
567{
568    let mut pending_list: Vec<Vec<G::NodeId>> = Vec::new();
569    let mut block_id: Vec<Option<usize>> = Vec::new();
570    let mut block_list: Vec<Vec<G::NodeId>> = Vec::new();
571
572    let nodes = match algo::toposort(graph, None) {
573        Ok(nodes) => nodes,
574        Err(_) => return Ok(None), // Return None if the graph contains a cycle
575    };
576
577    // Utility for ensuring pending_list has the color index
578    macro_rules! ensure_vector_has_index {
579        ($pending_list: expr, $block_id: expr, $color: expr) => {
580            if $color >= $pending_list.len() {
581                $pending_list.resize($color + 1, Vec::new());
582                $block_id.resize($color + 1, None);
583            }
584        };
585    }
586
587    for node in nodes {
588        if let Some(is_match) = filter_fn(node)? {
589            let raw_edges = graph.edges_directed(node, petgraph::Direction::Outgoing);
590
591            // Remove all edges that yield errors from color_fn
592            let colors = raw_edges
593                .map(|edge| color_fn(edge.id()))
594                .collect::<Result<Vec<Option<usize>>, _>>()?;
595
596            // Remove null edges from color_fn
597            let colors = colors.into_iter().flatten().collect::<Vec<usize>>();
598
599            match (colors.len(), is_match) {
600                (1, true) => {
601                    let c0 = colors[0];
602                    ensure_vector_has_index!(pending_list, block_id, c0);
603                    if let Some(c0_block_id) = block_id[c0] {
604                        block_list[c0_block_id].push(node);
605                    } else {
606                        pending_list[c0].push(node);
607                    }
608                }
609                (2, true) => {
610                    let c0 = colors[0];
611                    let c1 = colors[1];
612                    ensure_vector_has_index!(pending_list, block_id, c0);
613                    ensure_vector_has_index!(pending_list, block_id, c1);
614
615                    if block_id[c0].is_some()
616                        && block_id[c1].is_some()
617                        && block_id[c0] == block_id[c1]
618                    {
619                        block_list[block_id[c0].unwrap_or_default()].push(node);
620                    } else {
621                        let mut new_block: Vec<G::NodeId> =
622                            Vec::with_capacity(pending_list[c0].len() + pending_list[c1].len() + 1);
623
624                        // Clears pending list and add to new block
625                        new_block.append(&mut pending_list[c0]);
626                        new_block.append(&mut pending_list[c1]);
627
628                        new_block.push(node);
629
630                        // Create new block, assign its id to color pair
631                        block_id[c0] = Some(block_list.len());
632                        block_id[c1] = Some(block_list.len());
633                        block_list.push(new_block);
634                    }
635                }
636                _ => {
637                    for color in colors {
638                        ensure_vector_has_index!(pending_list, block_id, color);
639                        if let Some(color_block_id) = block_id[color] {
640                            block_list[color_block_id].append(&mut pending_list[color]);
641                        }
642                        block_id[color] = None;
643                        pending_list[color].clear();
644                    }
645                }
646            }
647        }
648    }
649
650    Ok(Some(block_list))
651}
652
653/// Collect runs that match a filter function
654///
655/// A run is a path of nodes where there is only a single successor and all
656/// nodes in the path match the given condition. Each node in the graph can
657/// appear in only a single run.
658///
659/// # Arguments:
660///
661/// * `graph`: The DAG to collect runs from
662/// * `include_node_fn`: A filter function used for matching nodes. It takes
663///     in one argument, the node data payload/weight object, and returns a
664///     boolean whether the node matches the conditions or not.
665///     If it returns ``false``, the node will be skipped, cutting the run it's part of.
666///
667/// # Returns:
668///
669/// * An Iterator object for extracting the runs one by one. Each run is of type `Result<Vec<G::NodeId>>`.
670/// * `None` if a cycle is found in the graph.
671///
672/// # Example
673///
674/// ```rust
675/// use petgraph::graph::DiGraph;
676/// use rustworkx_core::dag_algo::collect_runs;
677///
678/// let mut graph: DiGraph<i32, ()> = DiGraph::new();
679/// let n1 = graph.add_node(-1);
680/// let n2 = graph.add_node(2);
681/// let n3 = graph.add_node(3);
682/// graph.add_edge(n1, n2, ());
683/// graph.add_edge(n1, n3, ());
684///
685/// let positive_payload = |n| -> Result<bool, ()> {Ok(*graph.node_weight(n).expect("i32") > 0)};
686/// let mut runs = collect_runs(&graph, positive_payload).expect("Some");
687///
688/// assert_eq!(runs.next(), Some(Ok(vec![n3])));
689/// assert_eq!(runs.next(), Some(Ok(vec![n2])));
690/// assert_eq!(runs.next(), None);
691/// ```
692///
693pub fn collect_runs<G, F, E>(
694    graph: G,
695    include_node_fn: F,
696) -> Option<impl Iterator<Item = Result<Vec<G::NodeId>, E>>>
697where
698    G: GraphProp<EdgeType = Directed>
699        + IntoNeighborsDirected
700        + IntoNodeIdentifiers
701        + Visitable
702        + NodeCount,
703    F: Fn(G::NodeId) -> Result<bool, E>,
704    <G as GraphBase>::NodeId: Hash + Eq,
705{
706    let mut nodes = match algo::toposort(graph, None) {
707        Ok(nodes) => nodes,
708        Err(_) => return None,
709    };
710
711    nodes.reverse(); // reversing so that pop() in Runs::next obeys the topological order
712
713    let runs = Runs {
714        graph,
715        seen: HashSet::with_capacity(nodes.len()),
716        sorted_nodes: nodes,
717        include_node_fn,
718    };
719
720    Some(runs)
721}
722
723/// Auxiliary struct to make the output of [`collect_runs`] iterable
724///
725/// If the filtering function passed to [`collect_runs`] returns an error, it is propagated
726/// through `next` as `Err`. In this case the run in which the error occurred will be skipped
727/// but the iterator can be used further until consumed.
728///
729struct Runs<G, F, E>
730where
731    G: GraphProp<EdgeType = Directed>
732        + IntoNeighborsDirected
733        + IntoNodeIdentifiers
734        + Visitable
735        + NodeCount,
736    F: Fn(G::NodeId) -> Result<bool, E>,
737{
738    graph: G,
739    sorted_nodes: Vec<G::NodeId>, // topologically-sorted nodes
740    seen: HashSet<G::NodeId>,
741    include_node_fn: F, // filtering function of the nodes
742}
743
744impl<G, F, E> Iterator for Runs<G, F, E>
745where
746    G: GraphProp<EdgeType = Directed>
747        + IntoNeighborsDirected
748        + IntoNodeIdentifiers
749        + Visitable
750        + NodeCount,
751    F: Fn(G::NodeId) -> Result<bool, E>,
752    <G as GraphBase>::NodeId: Hash + Eq,
753{
754    // This is a run, wrapped in Result for catching filter function errors
755    type Item = Result<Vec<G::NodeId>, E>;
756
757    fn next(&mut self) -> Option<Self::Item> {
758        while let Some(node) = self.sorted_nodes.pop() {
759            if self.seen.contains(&node) {
760                continue;
761            }
762            self.seen.insert(node);
763
764            match (self.include_node_fn)(node) {
765                Ok(false) => continue,
766                Err(e) => return Some(Err(e)),
767                _ => (),
768            }
769
770            let mut run: Vec<G::NodeId> = vec![node];
771            loop {
772                let mut successors: Vec<G::NodeId> = self
773                    .graph
774                    .neighbors_directed(*run.last().unwrap(), petgraph::Direction::Outgoing)
775                    .collect();
776                successors.dedup();
777
778                if successors.len() != 1 || self.seen.contains(&successors[0]) {
779                    break;
780                }
781
782                self.seen.insert(successors[0]);
783
784                match (self.include_node_fn)(successors[0]) {
785                    Ok(false) => continue,
786                    Err(e) => return Some(Err(e)),
787                    _ => (),
788                }
789
790                run.push(successors[0]);
791            }
792
793            if !run.is_empty() {
794                return Some(Ok(run));
795            }
796        }
797
798        None
799    }
800
801    fn size_hint(&self) -> (usize, Option<usize>) {
802        // Lower bound is 0 in case all remaining nodes are filtered out
803        // Upper bound is the remaining unprocessed nodes (some of which may be seen already), potentially all resulting with singleton runs
804        (0, Some(self.sorted_nodes.len()))
805    }
806}
807
// Tests for longest_path
#[cfg(test)]
mod test_longest_path {
    use super::*;
    use petgraph::graph::{DiGraph, EdgeReference};
    use petgraph::stable_graph::StableDiGraph;

    /// Weight callback for `DiGraph<(), i32>` fixtures: the length of an edge is the
    /// weight stored on it, and the lookup never fails.
    fn stored_weight(edge: EdgeReference<'_, i32>) -> Result<i32, &'static str> {
        Ok(*edge.weight())
    }

    /// A graph without nodes has an empty longest path of length zero.
    #[test]
    fn test_empty_graph() {
        let graph = DiGraph::<(), ()>::new();
        let result = longest_path(&graph, |_: EdgeReference<'_, ()>| Ok::<i32, &str>(0));
        assert_eq!(result, Ok(Some((vec![], 0))));
    }

    /// A lone node is its own longest path.
    #[test]
    fn test_single_node_graph() {
        let mut graph = DiGraph::<(), ()>::new();
        let only = graph.add_node(());
        let result = longest_path(&graph, |_: EdgeReference<'_, ()>| Ok::<i32, &str>(0));
        assert_eq!(result, Ok(Some((vec![only], 0))));
    }

    /// With several competing routes the heaviest one is reported together with its length.
    #[test]
    fn test_dag_with_multiple_paths() {
        let mut graph = DiGraph::<(), i32>::new();
        let n: Vec<_> = (0..6).map(|_| graph.add_node(())).collect();
        // (source, target, weight), in insertion order
        for (source, target, weight) in [
            (0, 1, 3),
            (0, 2, 2),
            (1, 2, 1),
            (1, 3, 4),
            (2, 3, 2),
            (3, 4, 2),
            (2, 5, 1),
            (4, 5, 3),
        ] {
            graph.add_edge(n[source], n[target], weight);
        }
        let result = longest_path(&graph, stored_weight);
        assert_eq!(result, Ok(Some((vec![n[0], n[1], n[3], n[4], n[5]], 12))));
    }

    /// A cyclic graph has no longest path, which is reported as `Ok(None)`.
    #[test]
    fn test_graph_with_cycle() {
        let mut graph = DiGraph::<(), i32>::new();
        let a = graph.add_node(());
        let b = graph.add_node(());
        graph.add_edge(a, b, 1);
        graph.add_edge(b, a, 1); // closes the cycle

        let result = longest_path(&graph, stored_weight);
        assert_eq!(result, Ok(None));
    }

    /// Negative edge weights are honoured when comparing paths.
    #[test]
    fn test_negative_weights() {
        let mut graph = DiGraph::<(), i32>::new();
        let n: Vec<_> = (0..3).map(|_| graph.add_node(())).collect();
        for (source, target, weight) in [(0, 1, -1), (0, 2, 2), (1, 2, -2)] {
            graph.add_edge(n[source], n[target], weight);
        }
        let result = longest_path(&graph, stored_weight);
        assert_eq!(result, Ok(Some((vec![n[0], n[2]], 2))));
    }

    /// The algorithm also accepts a `StableDiGraph`.
    #[test]
    fn test_longest_path_in_stable_digraph() {
        let mut graph = StableDiGraph::<(), i32>::new();
        let n: Vec<_> = (0..3).map(|_| graph.add_node(())).collect();
        for (source, target, weight) in [(0, 1, 1), (0, 2, 3), (1, 2, 1)] {
            graph.add_edge(n[source], n[target], weight);
        }
        let result = longest_path(
            &graph,
            |edge: petgraph::stable_graph::EdgeReference<'_, i32>| Ok::<i32, &str>(*edge.weight()),
        );
        assert_eq!(result, Ok(Some((vec![n[0], n[2]], 3))));
    }

    /// An error returned by the weight callback is handed back to the caller.
    #[test]
    fn test_error_handling() {
        let mut graph = DiGraph::<(), i32>::new();
        let n: Vec<_> = (0..3).map(|_| graph.add_node(())).collect();
        for (source, target, weight) in [(0, 1, 1), (0, 2, 2), (1, 2, 1)] {
            graph.add_edge(n[source], n[target], weight);
        }
        // Fails on the single edge whose weight is 2.
        let failing_weight = |edge: EdgeReference<'_, i32>| match *edge.weight() {
            2 => Err("Error: edge weight is 2"),
            weight => Ok::<i32, &str>(weight),
        };
        let result = longest_path(&graph, failing_weight);
        assert_eq!(result, Err("Error: edge weight is 2"));
    }
}
916
// Tests for lexicographical_topological_sort
#[cfg(test)]
mod test_lexicographical_topological_sort {
    use super::*;
    use petgraph::graph::{DiGraph, NodeIndex};
    use petgraph::stable_graph::StableDiGraph;
    use std::convert::Infallible;

    /// Edges of the "simple layer" fixture, as positions in node creation order.
    const SIMPLE_LAYER_EDGES: [(usize, usize); 6] =
        [(0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (6, 3)];

    /// Order in which the "simple layer" fixture is expected to be sorted.
    const SIMPLE_LAYER_ORDER: [usize; 7] = [6, 0, 1, 2, 3, 4, 5];

    /// Edges of the fixture used by the reverse-sort tests.
    const REVERSE_EDGES: [(usize, usize); 6] = [(0, 1), (0, 2), (1, 3), (2, 3), (1, 4), (2, 5)];

    /// Node labels of the "simple layer" fixture, in creation order:
    /// `"a"`, `"0"` through `"4"`, then `"A parent"`.
    fn simple_layer_labels() -> Vec<String> {
        let mut labels = vec!["a".to_string()];
        labels.extend((0..5).map(|i| i.to_string()));
        labels.push("A parent".to_string());
        labels
    }

    /// Sorting an empty `DiGraph` yields an empty order.
    #[test]
    fn test_empty_graph() {
        let graph = DiGraph::<(), ()>::new();
        let result = lexicographical_topological_sort(
            &graph,
            |_: NodeIndex| -> Result<String, Infallible> { Ok("a".to_string()) },
            false,
            None,
        );
        assert_eq!(result, Ok(vec![]));
    }

    /// Sorting an empty `StableDiGraph` yields an empty order.
    #[test]
    fn test_empty_stable_graph() {
        let graph = StableDiGraph::<(), ()>::new();
        let result = lexicographical_topological_sort(
            &graph,
            |_: NodeIndex| -> Result<String, Infallible> { Ok("a".to_string()) },
            false,
            None,
        );
        assert_eq!(result, Ok(vec![]));
    }

    /// Two roots feeding one layer of children, keyed by node label, on a `DiGraph`.
    #[test]
    fn test_simple_layer() {
        let mut graph: DiGraph<String, ()> = DiGraph::new();
        let nodes: Vec<NodeIndex> = simple_layer_labels()
            .into_iter()
            .map(|label| graph.add_node(label))
            .collect();
        for (source, target) in SIMPLE_LAYER_EDGES {
            graph.add_edge(nodes[source], nodes[target], ());
        }

        let sort_fn = |index: NodeIndex| -> Result<String, Infallible> { Ok(graph[index].clone()) };
        let result = lexicographical_topological_sort(&graph, sort_fn, false, None);

        let expected: Vec<NodeIndex> = SIMPLE_LAYER_ORDER
            .iter()
            .map(|&position| NodeIndex::new(position))
            .collect();
        assert_eq!(result, Ok(expected))
    }

    /// Same fixture as `test_simple_layer`, on a `StableDiGraph`.
    #[test]
    fn test_simple_layer_stable() {
        let mut graph: StableDiGraph<String, ()> = StableDiGraph::new();
        let nodes: Vec<NodeIndex> = simple_layer_labels()
            .into_iter()
            .map(|label| graph.add_node(label))
            .collect();
        for (source, target) in SIMPLE_LAYER_EDGES {
            graph.add_edge(nodes[source], nodes[target], ());
        }

        let sort_fn = |index: NodeIndex| -> Result<String, Infallible> { Ok(graph[index].clone()) };
        let result = lexicographical_topological_sort(&graph, sort_fn, false, None);

        let expected: Vec<NodeIndex> = SIMPLE_LAYER_ORDER
            .iter()
            .map(|&position| NodeIndex::new(position))
            .collect();
        assert_eq!(result, Ok(expected))
    }

    /// A reverse sort must agree with a forward sort of the edge-reversed `DiGraph`.
    #[test]
    fn test_reverse_graph() {
        let mut graph: DiGraph<String, ()> = DiGraph::new();
        let nodes: Vec<NodeIndex> = ["a", "b", "c", "d", "e", "f"]
            .iter()
            .map(|weight| graph.add_node(weight.to_string()))
            .collect();
        for (source, target) in REVERSE_EDGES {
            graph.add_edge(nodes[source], nodes[target], ());
        }

        let sort_fn = |index: NodeIndex| -> Result<String, Infallible> { Ok(graph[index].clone()) };
        let result = lexicographical_topological_sort(&graph, sort_fn, true, None);

        // Flip every edge, then sort forwards to obtain the reference order.
        graph.reverse();
        let sort_fn = |index: NodeIndex| -> Result<String, Infallible> { Ok(graph[index].clone()) };
        let expected = lexicographical_topological_sort(&graph, sort_fn, false, None);
        assert_eq!(result, expected)
    }

    /// Same check as `test_reverse_graph`, on a `StableDiGraph`.
    #[test]
    fn test_reverse_graph_stable() {
        let mut graph: StableDiGraph<String, ()> = StableDiGraph::new();
        let nodes: Vec<NodeIndex> = ["a", "b", "c", "d", "e", "f"]
            .iter()
            .map(|weight| graph.add_node(weight.to_string()))
            .collect();
        for (source, target) in REVERSE_EDGES {
            graph.add_edge(nodes[source], nodes[target], ());
        }

        let sort_fn = |index: NodeIndex| -> Result<String, Infallible> { Ok(graph[index].clone()) };
        let result = lexicographical_topological_sort(&graph, sort_fn, true, None);

        // Flip every edge, then sort forwards to obtain the reference order.
        graph.reverse();
        let sort_fn = |index: NodeIndex| -> Result<String, Infallible> { Ok(graph[index].clone()) };
        let expected = lexicographical_topological_sort(&graph, sort_fn, false, None);
        assert_eq!(result, expected)
    }

    /// Exercises the optional `initial` node set with three different starting points.
    #[test]
    fn test_initial() {
        let mut graph: StableDiGraph<u8, ()> = StableDiGraph::new();
        let nodes: Vec<NodeIndex> = (0..9).map(|weight| graph.add_node(weight)).collect();
        for (source, target) in [
            (0, 1),
            (0, 2),
            (1, 3),
            (2, 4),
            (3, 4),
            (4, 5),
            (5, 6),
            (4, 7),
            (6, 8),
            (7, 8),
        ] {
            graph.add_edge(nodes[source], nodes[target], ());
        }
        let sort_fn =
            |index: NodeIndex| -> Result<String, Infallible> { Ok(graph[index].to_string()) };

        // Starting from both parents of the last node reaches that node as well.
        let initial = [nodes[6], nodes[7]];
        let result = lexicographical_topological_sort(&graph, sort_fn, false, Some(&initial[..]));
        assert_eq!(result, Ok(vec![nodes[6], nodes[7], nodes[8]]));

        // Starting from the only root matches a sort without an explicit start set.
        let initial = [nodes[0]];
        let result = lexicographical_topological_sort(&graph, sort_fn, false, Some(&initial[..]));
        assert_eq!(
            result,
            lexicographical_topological_sort(&graph, sort_fn, false, None)
        );

        // Starting from only one of the two parents does not reach the last node.
        let initial = [nodes[7]];
        let result = lexicographical_topological_sort(&graph, sort_fn, false, Some(&initial[..]));
        assert_eq!(result, Ok(vec![nodes[7]]));
    }
}
1089
// Tests for layers
#[cfg(test)]
mod test_layers {
    use super::*;
    use petgraph::{
        graph::{DiGraph, NodeIndex},
        stable_graph::StableDiGraph,
    };

    /// An empty `DiGraph` with an empty first layer yields a single empty layer.
    #[test]
    fn test_empty_graph() {
        let graph = DiGraph::<(), ()>::new();
        let expected: Vec<Vec<NodeIndex>> = vec![Vec::new()];
        let result: Vec<Vec<NodeIndex>> = layers(&graph, Vec::new()).flatten().collect();
        assert_eq!(result, expected);
    }

    /// An empty `StableDiGraph` with an empty first layer yields a single empty layer.
    #[test]
    fn test_empty_stable_graph() {
        let graph = StableDiGraph::<(), ()>::new();
        let expected: Vec<Vec<NodeIndex>> = vec![Vec::new()];
        let result: Vec<Vec<NodeIndex>> = layers(&graph, Vec::new()).flatten().collect();
        assert_eq!(result, expected);
    }

    /// Two roots followed by a single layer holding all of their children.
    #[test]
    fn test_simple_layer() {
        let mut graph: DiGraph<String, ()> = DiGraph::new();
        let mut labels = vec!["a".to_string()];
        labels.extend((0..5).map(|i| i.to_string()));
        labels.push("A parent".to_string());
        let nodes: Vec<NodeIndex> = labels
            .into_iter()
            .map(|label| graph.add_node(label))
            .collect();
        for (source, target) in [(0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (6, 3)] {
            graph.add_edge(nodes[source], nodes[target], ());
        }

        let first_layer = vec![nodes[0], nodes[6]];
        let second_layer = [5, 4, 2, 1, 3].map(|position| nodes[position]).to_vec();
        let expected: Vec<Vec<NodeIndex>> = vec![first_layer.clone(), second_layer];

        let result: Vec<Vec<NodeIndex>> = layers(&graph, first_layer).flatten().collect();
        assert_eq!(result, expected);
    }

    /// A first layer naming a node that is not in the graph must end in a panic.
    #[test]
    #[should_panic]
    fn test_missing_node() {
        let edge_list = [(0, 1), (1, 2), (2, 3), (3, 4)];
        let graph = DiGraph::<u32, u32>::from_edges(&edge_list);
        for layer in layers(&graph, vec![4.into(), 5.into()]) {
            if let Err(e) = layer {
                panic!("{}", e.0);
            }
        }
    }

    /// Every node of this fixture ends up alone in its own layer.
    #[test]
    fn test_dag_with_multiple_paths() {
        let mut graph = DiGraph::<(), ()>::new();
        let n: Vec<NodeIndex> = (0..6).map(|_| graph.add_node(())).collect();
        for (source, target) in [
            (0, 1),
            (0, 2),
            (1, 2),
            (1, 3),
            (2, 3),
            (3, 4),
            (2, 5),
            (4, 5),
        ] {
            graph.add_edge(n[source], n[target], ());
        }

        let expected: Vec<Vec<NodeIndex>> = n.iter().map(|&node| vec![node]).collect();
        let result: Vec<Vec<NodeIndex>> = layers(&graph, vec![n[0]]).flatten().collect();
        assert_eq!(result, expected);
    }

    /// Layering a cyclic graph must end in a panic.
    #[test]
    #[should_panic]
    fn test_graph_with_cycle() {
        let mut graph = DiGraph::<(), i32>::new();
        let a = graph.add_node(());
        let b = graph.add_node(());
        graph.add_edge(a, b, 1);
        graph.add_edge(b, a, 1);

        for layer in layers(&graph, vec![a]) {
            if let Err(e) = layer {
                panic!("{}", e.0);
            }
        }
    }
}
1184
// Tests for collect_bicolor_runs
#[cfg(test)]
mod test_collect_bicolor_runs {

    use super::*;
    use petgraph::graph::{DiGraph, EdgeIndex, NodeIndex};
    use std::error::Error;

    /// Name which of the three possible outcomes a `collect_bicolor_runs` call produced.
    fn outcome_label<T, E>(outcome: Result<Option<T>, E>) -> &'static str {
        match outcome {
            Ok(Some(_)) => "Not None",
            Ok(None) => "None",
            Err(_) => "Error",
        }
    }

    /// A cyclic graph is reported as `Ok(None)`.
    #[test]
    fn test_cycle() {
        let mut graph = DiGraph::<i32, i32>::new();
        let a = graph.add_node(0);
        let b = graph.add_node(0);
        let c = graph.add_node(0);
        for (source, target) in [(a, b), (b, c), (c, a)] {
            graph.add_edge(source, target, 1);
        }

        let accept_all =
            |_node_id: NodeIndex| -> Result<Option<bool>, Box<dyn Error>> { Ok(Some(true)) };
        let single_color =
            |_edge_id: EdgeIndex| -> Result<Option<usize>, Box<dyn Error>> { Ok(Some(1)) };

        let result = outcome_label(collect_bicolor_runs(&graph, accept_all, single_color));
        assert_eq!(result, "None")
    }

    /// An error raised by the filter callback surfaces as `Err`.
    #[test]
    fn test_filter_function_inner_exception() {
        let mut graph = DiGraph::<i32, i32>::new();
        graph.add_node(0);

        // Fails for the only node, whose weight is not positive.
        let fail_function = |node_id: NodeIndex| -> Result<Option<bool>, Box<dyn Error>> {
            let node_weight: &i32 = graph.node_weight(node_id).expect("Invalid NodeId");
            if *node_weight > 0 {
                Ok(Some(true))
            } else {
                Err(Box::from("Failed!"))
            }
        };
        let color_by_weight = |edge_id: EdgeIndex| -> Result<Option<usize>, Box<dyn Error>> {
            let edge_weight: &i32 = graph.edge_weight(edge_id).expect("Invalid Edge");
            Ok(Some(*edge_weight as usize))
        };

        let result = outcome_label(collect_bicolor_runs(&graph, fail_function, color_by_weight));
        assert_eq!(result, "Error")
    }

    /// An empty graph produces an empty list of runs.
    #[test]
    fn test_empty() {
        let graph = DiGraph::<i32, i32>::new();
        let weight_above_one = |node_id: NodeIndex| -> Result<Option<bool>, Box<dyn Error>> {
            let node_weight: &i32 = graph.node_weight(node_id).expect("Invalid NodeId");
            Ok(Some(*node_weight > 1))
        };
        let color_by_weight = |edge_id: EdgeIndex| -> Result<Option<usize>, Box<dyn Error>> {
            let edge_weight: &i32 = graph.edge_weight(edge_id).expect("Invalid Edge");
            Ok(Some(*edge_weight as usize))
        };

        let result = collect_bicolor_runs(&graph, weight_above_one, color_by_weight).unwrap();
        assert_eq!(result, Some(Vec::<Vec<NodeIndex>>::new()))
    }

    #[test]
    fn test_two_colors() {
        /* Based on the following graph from the Python unit tests:
        Input:
                ┌─────────────┐                 ┌─────────────┐
                │             │                 │             │
                │    q0       │                 │    q1       │
                │             │                 │             │
                └───┬─────────┘                 └──────┬──────┘
                    │          ┌─────────────┐         │
                q0  │          │             │         │  q1
                    │          │             │         │
                    └─────────►│     cx      │◄────────┘
                    ┌──────────┤             ├─────────┐
                    │          │             │         │
                q0  │          └─────────────┘         │  q1
                    │                                  │
                    │          ┌─────────────┐         │
                    │          │             │         │
                    └─────────►│      cz     │◄────────┘
                     ┌─────────┤             ├─────────┐
                     │         └─────────────┘         │
                 q0  │                                 │ q1
                     │                                 │
                 ┌───▼─────────┐                ┌──────▼──────┐
                 │             │                │             │
                 │    q0       │                │    q1       │
                 │             │                │             │
                 └─────────────┘                └─────────────┘
        Expected: [[cx, cz]]
        */
        let mut graph = DiGraph::<i32, i32>::new();
        // Node weight 0 marks a wire endpoint, 1 marks a gate.
        let q0 = graph.add_node(0);
        let q1 = graph.add_node(0);
        let cx = graph.add_node(1);
        let cz = graph.add_node(1);
        let q0_out = graph.add_node(0);
        let q1_out = graph.add_node(0);
        // Edge weight 0 is the q0 wire, 1 is the q1 wire.
        for (source, target, wire) in [
            (q0, cx, 0),
            (q1, cx, 1),
            (cx, cz, 0),
            (cx, cz, 1),
            (cz, q0_out, 0),
            (cz, q1_out, 1),
        ] {
            graph.add_edge(source, target, wire);
        }

        // Filter out q0, q1, q0_1 and q1_1
        let is_gate = |node_id: NodeIndex| -> Result<Option<bool>, Box<dyn Error>> {
            let node_weight: &i32 = graph.node_weight(node_id).expect("Invalid NodeId");
            Ok(Some(*node_weight > 0))
        };
        // The edge color will match its weight
        let color_by_weight = |edge_id: EdgeIndex| -> Result<Option<usize>, Box<dyn Error>> {
            let edge_weight: &i32 = graph.edge_weight(edge_id).expect("Invalid Edge");
            Ok(Some(*edge_weight as usize))
        };

        let result = collect_bicolor_runs(&graph, is_gate, color_by_weight).unwrap();
        let expected: Vec<Vec<NodeIndex>> = vec![vec![cx, cz]]; //[[cx, cz]]
        assert_eq!(result, Some(expected))
    }

    #[test]
    fn test_two_colors_with_pending() {
        /* Based on the following graph from the Python unit tests:
        Input:
                ┌─────────────┐
                │             │
                │    q0       │
                │             │
                └───┬─────────┘
                    │ q0
                    │
                ┌───▼─────────┐
                │             │
                │    h        │
                │             │
                └───┬─────────┘
                    │ q0
                    │                           ┌─────────────┐
                    │                           │             │
                    │                           │    q1       │
                    │                           │             │
                    │                           └──────┬──────┘
                    │          ┌─────────────┐         │
                q0  │          │             │         │  q1
                    │          │             │         │
                    └─────────►│     cx      │◄────────┘
                    ┌──────────┤             ├─────────┐
                    │          │             │         │
                q0  │          └─────────────┘         │  q1
                    │                                  │
                    │          ┌─────────────┐         │
                    │          │             │         │
                    └─────────►│      cz     │◄────────┘
                     ┌─────────┤             ├─────────┐
                     │         └─────────────┘         │
                 q0  │                                 │ q1
                     │                                 │
                 ┌───▼─────────┐                ┌──────▼──────┐
                 │             │                │             │
                 │    q0       │                │    y        │
                 │             │                │             │
                 └─────────────┘                └──────┬──────┘
                                                       │ q1
                                                       │
                                                ┌──────▼──────┐
                                                │             │
                                                │    q1       │
                                                │             │
                                                └─────────────┘
        Expected: [[h, cx, cz, y]]
        */
        let mut graph = DiGraph::<i32, i32>::new();
        // Node weight 0 marks a wire endpoint, 1 marks a gate.
        let q0 = graph.add_node(0);
        let q1 = graph.add_node(0);
        let h = graph.add_node(1);
        let cx = graph.add_node(1);
        let cz = graph.add_node(1);
        let y = graph.add_node(1);
        let q0_out = graph.add_node(0);
        let q1_out = graph.add_node(0);
        // Edge weight 0 is the q0 wire, 1 is the q1 wire.
        for (source, target, wire) in [
            (q0, h, 0),
            (h, cx, 0),
            (q1, cx, 1),
            (cx, cz, 0),
            (cx, cz, 1),
            (cz, q0_out, 0),
            (cz, y, 1),
            (y, q1_out, 1),
        ] {
            graph.add_edge(source, target, wire);
        }

        // Filter out q0, q1, q0_1 and q1_1
        let is_gate = |node_id: NodeIndex| -> Result<Option<bool>, Box<dyn Error>> {
            let node_weight: &i32 = graph.node_weight(node_id).expect("Invalid NodeId");
            Ok(Some(*node_weight > 0))
        };
        // The edge color will match its weight
        let color_by_weight = |edge_id: EdgeIndex| -> Result<Option<usize>, Box<dyn Error>> {
            let edge_weight: &i32 = graph.edge_weight(edge_id).expect("Invalid Edge");
            Ok(Some(*edge_weight as usize))
        };

        let result = collect_bicolor_runs(&graph, is_gate, color_by_weight).unwrap();
        let expected: Vec<Vec<NodeIndex>> = vec![vec![h, cx, cz, y]]; //[[h, cx, cz, y]]
        assert_eq!(result, Some(expected))
    }
}
1400
// Tests for collect_runs
#[cfg(test)]
mod test_collect_runs {
    use super::collect_runs;
    use petgraph::{graph::DiGraph, visit::GraphBase};

    type BareDiGraph = DiGraph<(), ()>;
    type BareNodeId = <BareDiGraph as GraphBase>::NodeId;
    type RunResult = Result<Vec<BareNodeId>, ()>;

    /// Filter callback that accepts every node.
    fn keep_all(_: BareNodeId) -> Result<bool, ()> {
        Ok(true)
    }

    /// An empty graph yields no runs, whether polled once or collected.
    #[test]
    fn test_empty_graph() {
        let graph: BareDiGraph = DiGraph::new();

        let mut runs = collect_runs(&graph, keep_all).expect("Some");
        assert!(runs.next().is_none());

        let collected: Vec<RunResult> = collect_runs(&graph, keep_all).expect("Some").collect();
        assert_eq!(collected, Vec::<RunResult>::new());
    }

    /// A three-node chain is one run; filtering its middle node splits it in two.
    #[test]
    fn test_simple_run_w_filter() {
        let mut graph: BareDiGraph = DiGraph::new();
        let [n1, n2, n3] = [(); 3].map(|_| graph.add_node(()));
        for (source, target) in [(n1, n2), (n2, n3)] {
            graph.add_edge(source, target, ());
        }

        let mut runs = collect_runs(&graph, keep_all).expect("Some");

        let the_run = runs.next().expect("Some").expect("3 nodes");
        assert_eq!(the_run.len(), 3);
        assert_eq!(runs.next(), None);

        assert_eq!(the_run, vec![n1, n2, n3]);

        // Now with some filters
        let mut runs = collect_runs(&graph, |_| -> Result<bool, ()> { Ok(false) }).expect("Some");

        assert_eq!(runs.next(), None);

        let mut runs = collect_runs(&graph, |n| -> Result<bool, ()> { Ok(n != n2) }).expect("Some");

        assert_eq!(runs.next(), Some(Ok(vec![n1])));
        assert_eq!(runs.next(), Some(Ok(vec![n3])));
    }

    /// Two chains that merge into a shared tail, with and without filtered nodes.
    #[test]
    fn test_multiple_runs_w_filter() {
        let mut graph: BareDiGraph = DiGraph::new();
        let [n1, n2, n3, n4, n5, n6, n7] = [(); 7].map(|_| graph.add_node(()));
        for (source, target) in [
            (n1, n2),
            (n2, n3),
            (n3, n7),
            (n4, n3),
            (n4, n7),
            (n5, n4),
            (n6, n5),
        ] {
            graph.add_edge(source, target, ());
        }

        let unfiltered: Vec<RunResult> = collect_runs(&graph, keep_all).expect("Some").collect();

        assert_eq!(
            unfiltered,
            vec![Ok(vec![n6, n5, n4]), Ok(vec![n1, n2, n3, n7])]
        );

        // And now with some filter
        let filtered: Vec<RunResult> =
            collect_runs(&graph, |n| -> Result<bool, ()> { Ok(n != n4 && n != n2) })
                .expect("Some")
                .collect();

        assert_eq!(
            filtered,
            vec![Ok(vec![n6, n5]), Ok(vec![n1]), Ok(vec![n3, n7])]
        );
    }

    /// A node with two successors cannot be extended, so every run is a singleton.
    #[test]
    fn test_singleton_runs_w_filter() {
        let mut graph: BareDiGraph = DiGraph::new();
        let [n1, n2, n3] = [(); 3].map(|_| graph.add_node(()));
        for (source, target) in [(n1, n2), (n1, n3)] {
            graph.add_edge(source, target, ());
        }

        let mut runs = collect_runs(&graph, keep_all).expect("Some");

        assert_eq!(runs.next().expect("n1"), Ok(vec![n1]));
        assert_eq!(runs.next().expect("n3"), Ok(vec![n3]));
        assert_eq!(runs.next().expect("n2"), Ok(vec![n2]));

        // And now with some filter
        let filtered: Vec<RunResult> = collect_runs(&graph, |n| -> Result<bool, ()> { Ok(n != n1) })
            .expect("Some")
            .collect();

        assert_eq!(filtered, vec![Ok(vec![n3]), Ok(vec![n2])]);
    }

    /// A failing filter yields its error as an item, after which the iterator is exhausted.
    #[test]
    fn test_error_propagation() {
        let mut graph: BareDiGraph = DiGraph::new();
        graph.add_node(());

        let mut runs = collect_runs(&graph, |_| -> Result<bool, ()> { Err(()) }).expect("Some");

        assert!(runs.next().expect("Some").is_err());
        assert_eq!(runs.next(), None);
    }
}
1521}