rustworkx_core/
dag_algo.rs

1// Licensed under the Apache License, Version 2.0 (the "License"); you may
2// not use this file except in compliance with the License. You may obtain
3// a copy of the License at
4//
5//     http://www.apache.org/licenses/LICENSE-2.0
6//
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
9// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
10// License for the specific language governing permissions and limitations
11// under the License.
12
13use std::cmp::{Eq, Ordering};
14use std::collections::BinaryHeap;
15use std::error::Error;
16use std::fmt::{Display, Formatter};
17use std::hash::Hash;
18
19use hashbrown::hash_set::Entry;
20use hashbrown::{HashMap, HashSet};
21use std::fmt::Debug;
22use std::mem::swap;
23
24use petgraph::algo;
25use petgraph::data::DataMap;
26use petgraph::visit::{
27    EdgeRef, GraphBase, GraphProp, IntoEdgesDirected, IntoNeighborsDirected, IntoNodeIdentifiers,
28    NodeCount, NodeIndexable, Visitable,
29};
30use petgraph::Directed;
31
32use num_traits::{Num, Zero};
33
34use crate::err::LayersError;
35
36/// Return a pair of [`petgraph::Direction`] values corresponding to the "forwards" and "backwards"
37/// direction of graph traversal, based on whether the graph is being traversed forwards (following
38/// the edges) or backward (reversing along edges).  The order of returns is (forwards, backwards).
39#[inline(always)]
40pub fn traversal_directions(reverse: bool) -> (petgraph::Direction, petgraph::Direction) {
41    if reverse {
42        (petgraph::Direction::Outgoing, petgraph::Direction::Incoming)
43    } else {
44        (petgraph::Direction::Incoming, petgraph::Direction::Outgoing)
45    }
46}
47
/// Error returned by the topological sort functions in this module.
#[derive(Debug, PartialEq, Eq)]
pub enum TopologicalSortError<E: Error> {
    /// At least one initial node is reachable from another initial node (possibly from
    /// itself, through a cycle), so the requested starting set is not valid.
    CycleOrBadInitialState,
    /// The user-supplied key callback failed; the wrapped value is the error it returned.
    KeyError(E),
}

impl<E: Error> Display for TopologicalSortError<E> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::CycleOrBadInitialState => {
                f.write_str("At least one initial node is reachable from another")
            }
            // The wrapped callback error is rendered through its `Debug` representation.
            Self::KeyError(inner) => write!(f, "The key callback failed with: {inner:?}"),
        }
    }
}

// Uses the default `Error` methods (no `source`).
impl<E: Error> Error for TopologicalSortError<E> {}
69
70/// Get the lexicographical topological sorted nodes from the provided DAG.
71///
72/// This function returns a list of nodes in a graph lexicographically
73/// topologically sorted, using the provided key function as a tie-breaker.
74/// A topological sort is a linear ordering of vertices such that for every
75/// directed edge from node :math:`u` to node :math:`v`, :math:`u` comes before
76/// :math:`v` in the ordering.  If ``reverse`` is set to ``true``, the edges
77/// are treated as if they pointed in the opposite direction.
78///
79/// # Arguments:
80///
81/// * `dag`: The DAG to get the topological sorted nodes from
82/// * `key`: A function that gets passed a single argument, the node id from
83///   `dag` and is expected to return a key which will be used for
84///   resolving ties in the sorting order.
85/// * `reverse`: If `false`, perform a regular topological ordering.  If `true`,
86///   return the lexicographical topological order that would have been found
87///   if all the edges in the graph were reversed.  This does not affect the
88///   comparisons from the `key`.
89/// * `initial`: If given, the initial node indices to start the topological
90///   ordering from.  If not given, the topological ordering will certainly contain every node in
91///   the graph.  If given, only the `initial` nodes and nodes that are dominated by the
92///   `initial` set will be in the ordering.  Notably, any node that has a natural in degree of
93///   zero will not be in the output ordering if `initial` is given and the zero-in-degree node
94///   is not in it.  It is not supported to give an `initial` set where the nodes have even
95///   a partial topological order between themselves and `None` will be returned in this case
96///
97/// # Returns
98///
99/// `Vec<G::NodeId>` representing the topological ordering of nodes or a
100/// `TopologicalSortError` if there is an error.
101///
102/// # Example
103///
104/// ```rust
105/// use std::convert::Infallible;
106///
107/// use rustworkx_core::dag_algo::lexicographical_topological_sort;
108/// use rustworkx_core::petgraph::stable_graph::{StableDiGraph, NodeIndex};
109///
110/// let mut graph: StableDiGraph<u8, ()> = StableDiGraph::new();
111/// let mut nodes: Vec<NodeIndex> = Vec::new();
112/// for weight in 0..9 {
113///     nodes.push(graph.add_node(weight));
114/// }
115/// let edges = [
116///         (nodes[0], nodes[1]),
117///         (nodes[0], nodes[2]),
118///         (nodes[1], nodes[3]),
119///         (nodes[2], nodes[4]),
120///         (nodes[3], nodes[4]),
121///         (nodes[4], nodes[5]),
122///         (nodes[5], nodes[6]),
123///         (nodes[4], nodes[7]),
124///         (nodes[6], nodes[8]),
125///         (nodes[7], nodes[8]),
126/// ];
127/// for (source, target) in edges {
128///     graph.add_edge(source, target, ());
129/// }
130/// let sort_fn = |index: NodeIndex| -> Result<String, Infallible> { Ok(graph[index].to_string()) };
131/// let initial = [nodes[6], nodes[7]];
132/// let result = lexicographical_topological_sort(&graph, sort_fn, true, Some(&initial));
133/// let expected = vec![
134///     nodes[6],
135///     nodes[5],
136///     nodes[7],
137///     nodes[4],
138///     nodes[2],
139///     nodes[3],
140///     nodes[1],
141///     nodes[0]
142/// ];
143/// assert_eq!(result, Ok(expected));
144///
145/// ```
146pub fn lexicographical_topological_sort<G, F, K, E>(
147    dag: G,
148    mut key: F,
149    reverse: bool,
150    initial: Option<&[G::NodeId]>,
151) -> Result<Vec<G::NodeId>, TopologicalSortError<E>>
152where
153    G: GraphProp<EdgeType = Directed>
154        + IntoNodeIdentifiers
155        + IntoNeighborsDirected
156        + IntoEdgesDirected
157        + NodeCount,
158    <G as GraphBase>::NodeId: Hash + Eq + Ord,
159    F: FnMut(G::NodeId) -> Result<K, E>,
160    K: Ord,
161    E: Error,
162{
163    // HashMap of node_index indegree
164    let node_count = dag.node_count();
165    let (in_dir, out_dir) = traversal_directions(reverse);
166
167    #[derive(Clone, Eq, PartialEq)]
168    struct State<K: Ord, N: Eq + PartialOrd> {
169        key: K,
170        node: N,
171    }
172
173    impl<K: Ord, N: Eq + Ord> Ord for State<K, N> {
174        fn cmp(&self, other: &State<K, N>) -> Ordering {
175            // Notice that we flip the ordering on costs.
176            // In case of a tie, we compare positions - this step is necessary
177            // to make implementations of `PartialEq` and `Ord` consistent.
178            other
179                .key
180                .cmp(&self.key)
181                .then_with(|| other.node.cmp(&self.node))
182        }
183    }
184
185    // `PartialOrd` needs to be implemented as well.
186    impl<K: Ord, N: Eq + Ord> PartialOrd for State<K, N> {
187        fn partial_cmp(&self, other: &State<K, N>) -> Option<Ordering> {
188            Some(self.cmp(other))
189        }
190    }
191
192    let mut in_degree_map: HashMap<G::NodeId, usize> = HashMap::with_capacity(node_count);
193    if let Some(initial) = initial {
194        // In this case, we don't iterate through all the nodes in the graph, and most nodes aren't
195        // in `in_degree_map`; we'll fill in the relevant edge counts lazily.
196        for node in initial {
197            in_degree_map.insert(*node, 0);
198        }
199    } else {
200        for node in dag.node_identifiers() {
201            in_degree_map.insert(node, dag.edges_directed(node, in_dir).count());
202        }
203    }
204
205    let mut zero_indegree = BinaryHeap::with_capacity(node_count);
206    for (node, degree) in in_degree_map.iter() {
207        if *degree == 0 {
208            let map_key = key(*node).map_err(|e| TopologicalSortError::KeyError(e))?;
209            zero_indegree.push(State {
210                key: map_key,
211                node: *node,
212            });
213        }
214    }
215    let mut out_list: Vec<G::NodeId> = Vec::with_capacity(node_count);
216    while let Some(State { node, .. }) = zero_indegree.pop() {
217        let neighbors = dag.neighbors_directed(node, out_dir);
218        for child in neighbors {
219            let child_degree = in_degree_map
220                .entry(child)
221                .or_insert_with(|| dag.edges_directed(child, in_dir).count());
222            if *child_degree == 0 {
223                return Err(TopologicalSortError::CycleOrBadInitialState);
224            } else if *child_degree == 1 {
225                let map_key = key(child).map_err(|e| TopologicalSortError::KeyError(e))?;
226                zero_indegree.push(State {
227                    key: map_key,
228                    node: child,
229                });
230                in_degree_map.remove(&child);
231            } else {
232                *child_degree -= 1;
233            }
234        }
235        out_list.push(node)
236    }
237    Ok(out_list)
238}
239
240// Type aliases for readability
241type NodeId<G> = <G as GraphBase>::NodeId;
242type LongestPathResult<G, T, E> = Result<Option<(Vec<NodeId<G>>, T)>, E>;
243
244/// Calculates the longest path in a directed acyclic graph (DAG).
245///
246/// This function computes the longest path by weight in a given DAG. It will return the longest path
247/// along with its total weight, or `None` if the graph contains cycles which make the longest path
248/// computation undefined.
249///
250/// # Arguments
251/// * `graph`: Reference to a directed graph.
252/// * `weight_fn` - An input callable that will be passed the `EdgeRef` for each edge in the graph.
253///   The callable should return the weight of the edge as `Result<T, E>`. The weight must be a type that implements
254///   `Num`, `Zero`, `PartialOrd`, and `Copy`.
255///
256/// # Type Parameters
257/// * `G`: Type of the graph. Must be a directed graph.
258/// * `F`: Type of the weight function.
259/// * `T`: The type of the edge weight. Must implement `Num`, `Zero`, `PartialOrd`, and `Copy`.
260/// * `E`: The type of the error that the weight function can return.
261///
262/// # Returns
263/// * `None` if the graph contains a cycle.
264/// * `Some((Vec<NodeId<G>>, T))` representing the longest path as a sequence of nodes and its total weight.
265/// * `Err(E)` if there is an error computing the weight of any edge.
266///
267/// # Example
268/// ```
269/// use petgraph::graph::DiGraph;
270/// use petgraph::Directed;
271/// use rustworkx_core::dag_algo::longest_path;
272///
273/// let mut graph: DiGraph<(), i32> = DiGraph::new();
274/// let n0 = graph.add_node(());
275/// let n1 = graph.add_node(());
276/// let n2 = graph.add_node(());
277/// graph.add_edge(n0, n1, 1);
278/// graph.add_edge(n0, n2, 3);
279/// graph.add_edge(n1, n2, 1);
280///
281/// let weight_fn = |edge: petgraph::graph::EdgeReference<i32>| Ok::<i32, &str>(*edge.weight());
282/// let result = longest_path(&graph, weight_fn).unwrap();
283/// assert_eq!(result, Some((vec![n0, n2], 3)));
284/// ```
285pub fn longest_path<G, F, T, E>(graph: G, mut weight_fn: F) -> LongestPathResult<G, T, E>
286where
287    G: GraphProp<EdgeType = Directed> + IntoNodeIdentifiers + IntoEdgesDirected + Visitable,
288    F: FnMut(G::EdgeRef) -> Result<T, E>,
289    T: Num + Zero + PartialOrd + Copy,
290    <G as GraphBase>::NodeId: Hash + Eq + PartialOrd,
291{
292    let mut path: Vec<NodeId<G>> = Vec::new();
293    let nodes = match algo::toposort(graph, None) {
294        Ok(nodes) => nodes,
295        Err(_) => return Ok(None), // Return None if the graph contains a cycle
296    };
297
298    if nodes.is_empty() {
299        return Ok(Some((path, T::zero())));
300    }
301
302    let mut dist: HashMap<G::NodeId, (T, G::NodeId)> = HashMap::with_capacity(nodes.len()); // Stores the distance and the previous node
303
304    // Iterate over nodes in topological order
305    for node in nodes {
306        let parents = graph.edges_directed(node, petgraph::Direction::Incoming);
307        let mut incoming_path: Vec<(T, G::NodeId)> = Vec::new(); // Stores the distance and the previous node for each parent
308        for p_edge in parents {
309            let p_node = p_edge.source();
310            let weight: T = weight_fn(p_edge)?;
311            let length = dist[&p_node].0 + weight;
312            incoming_path.push((length, p_node));
313        }
314        // Determine the maximum distance and corresponding parent node
315        let max_path: (T, G::NodeId) = incoming_path
316            .into_iter()
317            .max_by(|a, b| a.0.partial_cmp(&b.0).unwrap())
318            .unwrap_or((T::zero(), node)); // If there are no incoming edges, the distance is zero
319
320        // Store the maximum distance and the corresponding parent node for the current node
321        dist.insert(node, max_path);
322    }
323    let (first, _) = dist
324        .iter()
325        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
326        .unwrap();
327    let mut v = *first;
328    let mut u: Option<G::NodeId> = None;
329    // Backtrack from this node to find the path
330    #[allow(clippy::unnecessary_map_or)]
331    while u.map_or(true, |u| u != v) {
332        path.push(v);
333        u = Some(v);
334        v = dist[&v].1;
335    }
336    path.reverse(); // Reverse the path to get the correct order
337    let path_weight = dist[first].0; // The total weight of the longest path
338
339    Ok(Some((path, path_weight)))
340}
341
342/// Return an iterator of graph layers
343///
344/// A layer is a subgraph whose nodes are disjoint, i.e.,
345/// a layer has depth 1. The layers are constructed using a greedy algorithm.
346///
347/// Arguments:
348///
349/// * `graph` - The graph to get the layers from
350/// * `first_layer` - A list of node ids for the first layer. This
351///   will be the first layer in the output
352///
353/// Will `panic!` if a provided node is not in the graph.
354/// ```
355/// use rustworkx_core::petgraph::prelude::*;
356/// use rustworkx_core::dag_algo::layers;
357/// use rustworkx_core::dictmap::*;
358///
359/// let edge_list = vec![
360///  (0, 1),
361///  (1, 2),
362///  (2, 3),
363///  (3, 4),
364/// ];
365///
366/// let graph = DiGraph::<u32, u32>::from_edges(&edge_list);
367/// let layers: Vec<Vec<NodeIndex>> = layers(&graph, vec![0.into(),]).map(|layer| layer.unwrap()).collect();
368/// let expected_layers: Vec<Vec<NodeIndex>> = vec![
369///     vec![0.into(),],
370///     vec![1.into(),],
371///     vec![2.into(),],
372///     vec![3.into(),],
373///     vec![4.into()]
374/// ];
375/// assert_eq!(layers, expected_layers)
376/// ```
377pub fn layers<G>(
378    graph: G,
379    first_layer: Vec<G::NodeId>,
380) -> impl Iterator<Item = Result<Vec<G::NodeId>, LayersError>>
381where
382    G: NodeIndexable // Used in from_index and to_index.
383        + IntoNodeIdentifiers // Used for .node_identifiers
384        + IntoNeighborsDirected // Used for .neighbors_directed
385        + IntoEdgesDirected, // Used for .edged_directed
386    <G as GraphBase>::NodeId: Debug + Copy + Eq + Hash,
387{
388    LayersIter {
389        graph,
390        cur_layer: first_layer,
391        next_layer: vec![],
392        predecessor_count: HashMap::new(),
393        first_iter: true,
394        cycle_check: HashSet::default(),
395    }
396}
397
398#[derive(Debug, Clone)]
399struct LayersIter<G, N> {
400    graph: G,
401    cur_layer: Vec<N>,
402    next_layer: Vec<N>,
403    predecessor_count: HashMap<N, usize>,
404    first_iter: bool,
405    cycle_check: HashSet<N>, // TODO: Figure out why some cycles cannot be detected
406}
407
408impl<G, N> Iterator for LayersIter<G, N>
409where
410    G: NodeIndexable // Used in from_index and to_index.
411        + IntoNodeIdentifiers // Used for .node_identifiers
412        + IntoNeighborsDirected // Used for .neighbors_directed
413        + IntoEdgesDirected // Used for .edged_directed
414        + GraphBase<NodeId = N>,
415    N: Debug + Copy + Eq + Hash,
416{
417    type Item = Result<Vec<N>, LayersError>;
418    fn next(&mut self) -> Option<Self::Item> {
419        if self.first_iter {
420            self.first_iter = false;
421            for node in &self.cur_layer {
422                if self.graph.to_index(*node) >= self.graph.node_bound() {
423                    panic!("Node {node:#?} is not present in the graph.");
424                }
425                if self.cycle_check.contains(node) {
426                    return Some(Err(LayersError(format!(
427                        "An invalid first layer was provided: {node:#?} appears more than once."
428                    ))));
429                }
430                self.cycle_check.insert(*node);
431            }
432            Some(Ok(self.cur_layer.clone()))
433        } else if self.cur_layer.is_empty() {
434            None
435        } else {
436            for node in &self.cur_layer {
437                if self.graph.to_index(*node) >= self.graph.node_bound() {
438                    panic!("Node {node:#?} is not present in the graph.");
439                }
440                let children = self
441                    .graph
442                    .neighbors_directed(*node, petgraph::Direction::Outgoing);
443                let mut used_indices: HashSet<G::NodeId> = HashSet::new();
444                for succ in children {
445                    match used_indices.entry(succ) {
446                        Entry::Occupied(_) => {
447                            // Skip duplicate successors
448                            continue;
449                        }
450                        Entry::Vacant(used_indices_entry) => {
451                            used_indices_entry.insert();
452                        }
453                    }
454                    let mut multiplicity: usize = 0;
455                    let raw_edges: G::EdgesDirected = self
456                        .graph
457                        .edges_directed(*node, petgraph::Direction::Outgoing);
458                    for edge in raw_edges {
459                        if edge.target() == succ {
460                            multiplicity += 1;
461                        }
462                    }
463                    self.predecessor_count
464                        .entry(succ)
465                        .and_modify(|e| *e -= multiplicity)
466                        .or_insert(
467                            // Get the number of incoming edges to the successor
468                            self.graph
469                                .edges_directed(succ, petgraph::Direction::Incoming)
470                                .count()
471                                - multiplicity,
472                        );
473                    if *self.predecessor_count.get(&succ).unwrap() == 0 {
474                        if self.cycle_check.contains(&succ) {
475                            return Some(Err(LayersError("The provided graph contains a cycle or an invalid first layer was provided.".to_string())));
476                        }
477                        self.next_layer.push(succ);
478                        self.cycle_check.insert(succ);
479                        self.predecessor_count.remove(&succ);
480                    }
481                }
482            }
483            swap(&mut self.cur_layer, &mut self.next_layer);
484            self.next_layer.clear();
485            if self.cur_layer.is_empty() {
486                None
487            } else {
488                Some(Ok(self.cur_layer.clone()))
489            }
490        }
491    }
492}
493
494/// Collect runs that match a filter function given edge colors.
495///
496/// A bicolor run is a list of groups of nodes connected by edges of exactly
497/// two colors. In addition, all nodes in the group must match the given
498/// condition. Each node in the graph can appear in only a single group
499/// in the bicolor run.
500///
501/// # Arguments:
502///
503/// * `graph`: The DAG to find bicolor runs in
504/// * `filter_fn`: The filter function to use for matching nodes. It takes
505///   in one argument, the node data payload/weight object, and will return a
506///   boolean whether the node matches the conditions or not.
507///   If it returns ``true``, it will continue the bicolor chain.
508///   If it returns ``false``, it will stop the bicolor chain.
509///   If it returns ``None`` it will skip that node.
510/// * `color_fn`: The function that gives the color of the edge. It takes
511///   in one argument, the edge data payload/weight object, and will
512///   return a non-negative integer, the edge color. If the color is None,
513///   the edge is ignored.
514///
515/// # Returns:
516///
517/// * `Vec<Vec<G::NodeId>>`: a list of groups with exactly two edge colors, where each group
518///   is a list of node data payload/weight for the nodes in the bicolor run
519/// * `None` if a cycle is found in the graph
520/// * Raises an error if found computing the bicolor runs
521///
522/// # Example:
523///
524/// ```rust
525/// use rustworkx_core::dag_algo::collect_bicolor_runs;
526/// use petgraph::graph::{DiGraph, NodeIndex};
527/// use std::convert::Infallible;
528/// use std::error::Error;
529///
530/// let mut graph = DiGraph::new();
531/// let n0 = graph.add_node(0);
532/// let n1 = graph.add_node(0);
533/// let n2 = graph.add_node(1);
534/// let n3 = graph.add_node(1);
535/// let n4 = graph.add_node(0);
536/// let n5 = graph.add_node(0);
537/// graph.add_edge(n0, n2, 0);
538/// graph.add_edge(n1, n2, 1);
539/// graph.add_edge(n2, n3, 0);
540/// graph.add_edge(n2, n3, 1);
541/// graph.add_edge(n3, n4, 0);
542/// graph.add_edge(n3, n5, 1);
543///
544/// let filter_fn = |node_id| -> Result<Option<bool>, Infallible> {
545///     Ok(Some(*graph.node_weight(node_id).unwrap() > 0))
546/// };
547///
548/// let color_fn = |edge_id| -> Result<Option<usize>, Infallible> {
549///     Ok(Some(*graph.edge_weight(edge_id).unwrap() as usize))
550/// };
551///
552/// let result = collect_bicolor_runs(&graph, filter_fn, color_fn).unwrap();
553/// let expected: Vec<Vec<NodeIndex>> = vec![vec![n2, n3]];
554/// assert_eq!(result, Some(expected))
555/// ```
556pub fn collect_bicolor_runs<G, F, C, E>(
557    graph: G,
558    filter_fn: F,
559    color_fn: C,
560) -> Result<Option<Vec<Vec<G::NodeId>>>, E>
561where
562    F: Fn(<G as GraphBase>::NodeId) -> Result<Option<bool>, E>,
563    C: Fn(<G as GraphBase>::EdgeId) -> Result<Option<usize>, E>,
564    G: IntoNodeIdentifiers // Used in toposort
565        + IntoNeighborsDirected // Used in toposort
566        + IntoEdgesDirected // Used for .edges_directed
567        + Visitable // Used in toposort
568        + DataMap, // Used for .node_weight
569    <G as GraphBase>::NodeId: Eq + Hash,
570    <G as GraphBase>::EdgeId: Eq + Hash,
571{
572    let mut pending_list: Vec<Vec<G::NodeId>> = Vec::new();
573    let mut block_id: Vec<Option<usize>> = Vec::new();
574    let mut block_list: Vec<Vec<G::NodeId>> = Vec::new();
575
576    let nodes = match algo::toposort(graph, None) {
577        Ok(nodes) => nodes,
578        Err(_) => return Ok(None), // Return None if the graph contains a cycle
579    };
580
581    // Utility for ensuring pending_list has the color index
582    macro_rules! ensure_vector_has_index {
583        ($pending_list: expr, $block_id: expr, $color: expr) => {
584            if $color >= $pending_list.len() {
585                $pending_list.resize($color + 1, Vec::new());
586                $block_id.resize($color + 1, None);
587            }
588        };
589    }
590
591    for node in nodes {
592        if let Some(is_match) = filter_fn(node)? {
593            let raw_edges = graph.edges_directed(node, petgraph::Direction::Outgoing);
594
595            // Remove all edges that yield errors from color_fn
596            let colors = raw_edges
597                .map(|edge| color_fn(edge.id()))
598                .collect::<Result<Vec<Option<usize>>, _>>()?;
599
600            // Remove null edges from color_fn
601            let colors = colors.into_iter().flatten().collect::<Vec<usize>>();
602
603            match (colors.len(), is_match) {
604                (1, true) => {
605                    let c0 = colors[0];
606                    ensure_vector_has_index!(pending_list, block_id, c0);
607                    if let Some(c0_block_id) = block_id[c0] {
608                        block_list[c0_block_id].push(node);
609                    } else {
610                        pending_list[c0].push(node);
611                    }
612                }
613                (2, true) => {
614                    let c0 = colors[0];
615                    let c1 = colors[1];
616                    ensure_vector_has_index!(pending_list, block_id, c0);
617                    ensure_vector_has_index!(pending_list, block_id, c1);
618
619                    if block_id[c0].is_some()
620                        && block_id[c1].is_some()
621                        && block_id[c0] == block_id[c1]
622                    {
623                        block_list[block_id[c0].unwrap_or_default()].push(node);
624                    } else {
625                        let mut new_block: Vec<G::NodeId> =
626                            Vec::with_capacity(pending_list[c0].len() + pending_list[c1].len() + 1);
627
628                        // Clears pending list and add to new block
629                        new_block.append(&mut pending_list[c0]);
630                        new_block.append(&mut pending_list[c1]);
631
632                        new_block.push(node);
633
634                        // Create new block, assign its id to color pair
635                        block_id[c0] = Some(block_list.len());
636                        block_id[c1] = Some(block_list.len());
637                        block_list.push(new_block);
638                    }
639                }
640                _ => {
641                    for color in colors {
642                        ensure_vector_has_index!(pending_list, block_id, color);
643                        if let Some(color_block_id) = block_id[color] {
644                            block_list[color_block_id].append(&mut pending_list[color]);
645                        }
646                        block_id[color] = None;
647                        pending_list[color].clear();
648                    }
649                }
650            }
651        }
652    }
653
654    Ok(Some(block_list))
655}
656
657/// Collect runs that match a filter function
658///
659/// A run is a path of nodes where there is only a single successor and all
660/// nodes in the path match the given condition. Each node in the graph can
661/// appear in only a single run.
662///
663/// # Arguments:
664///
665/// * `graph`: The DAG to collect runs from
666/// * `include_node_fn`: A filter function used for matching nodes. It takes
667///   in one argument, the node data payload/weight object, and returns a
668///   boolean whether the node matches the conditions or not.
669///   If it returns ``false``, the node will be skipped, cutting the run it's part of.
670///
671/// # Returns:
672///
673/// * An Iterator object for extracting the runs one by one. Each run is of type `Result<Vec<G::NodeId>>`.
674/// * `None` if a cycle is found in the graph.
675///
676/// # Example
677///
678/// ```rust
679/// use petgraph::graph::DiGraph;
680/// use rustworkx_core::dag_algo::collect_runs;
681///
682/// let mut graph: DiGraph<i32, ()> = DiGraph::new();
683/// let n1 = graph.add_node(-1);
684/// let n2 = graph.add_node(2);
685/// let n3 = graph.add_node(3);
686/// graph.add_edge(n1, n2, ());
687/// graph.add_edge(n1, n3, ());
688///
689/// let positive_payload = |n| -> Result<bool, ()> {Ok(*graph.node_weight(n).expect("i32") > 0)};
690/// let mut runs = collect_runs(&graph, positive_payload).expect("Some");
691///
692/// assert_eq!(runs.next(), Some(Ok(vec![n3])));
693/// assert_eq!(runs.next(), Some(Ok(vec![n2])));
694/// assert_eq!(runs.next(), None);
695/// ```
696///
697pub fn collect_runs<G, F, E>(
698    graph: G,
699    include_node_fn: F,
700) -> Option<impl Iterator<Item = Result<Vec<G::NodeId>, E>>>
701where
702    G: GraphProp<EdgeType = Directed>
703        + IntoNeighborsDirected
704        + IntoNodeIdentifiers
705        + Visitable
706        + NodeCount,
707    F: Fn(G::NodeId) -> Result<bool, E>,
708    <G as GraphBase>::NodeId: Hash + Eq,
709{
710    let mut nodes = match algo::toposort(graph, None) {
711        Ok(nodes) => nodes,
712        Err(_) => return None,
713    };
714
715    nodes.reverse(); // reversing so that pop() in Runs::next obeys the topological order
716
717    let runs = Runs {
718        graph,
719        seen: HashSet::with_capacity(nodes.len()),
720        sorted_nodes: nodes,
721        include_node_fn,
722    };
723
724    Some(runs)
725}
726
727/// Auxiliary struct to make the output of [`collect_runs`] iterable
728///
729/// If the filtering function passed to [`collect_runs`] returns an error, it is propagated
730/// through `next` as `Err`. In this case the run in which the error occurred will be skipped
731/// but the iterator can be used further until consumed.
732///
733struct Runs<G, F, E>
734where
735    G: GraphProp<EdgeType = Directed>
736        + IntoNeighborsDirected
737        + IntoNodeIdentifiers
738        + Visitable
739        + NodeCount,
740    F: Fn(G::NodeId) -> Result<bool, E>,
741{
742    graph: G,
743    sorted_nodes: Vec<G::NodeId>, // topologically-sorted nodes
744    seen: HashSet<G::NodeId>,
745    include_node_fn: F, // filtering function of the nodes
746}
747
748impl<G, F, E> Iterator for Runs<G, F, E>
749where
750    G: GraphProp<EdgeType = Directed>
751        + IntoNeighborsDirected
752        + IntoNodeIdentifiers
753        + Visitable
754        + NodeCount,
755    F: Fn(G::NodeId) -> Result<bool, E>,
756    <G as GraphBase>::NodeId: Hash + Eq,
757{
758    // This is a run, wrapped in Result for catching filter function errors
759    type Item = Result<Vec<G::NodeId>, E>;
760
761    fn next(&mut self) -> Option<Self::Item> {
762        while let Some(node) = self.sorted_nodes.pop() {
763            if self.seen.contains(&node) {
764                continue;
765            }
766            self.seen.insert(node);
767
768            match (self.include_node_fn)(node) {
769                Ok(false) => continue,
770                Err(e) => return Some(Err(e)),
771                _ => (),
772            }
773
774            let mut run: Vec<G::NodeId> = vec![node];
775            loop {
776                let mut successors: Vec<G::NodeId> = self
777                    .graph
778                    .neighbors_directed(*run.last().unwrap(), petgraph::Direction::Outgoing)
779                    .collect();
780                successors.dedup();
781
782                if successors.len() != 1 || self.seen.contains(&successors[0]) {
783                    break;
784                }
785
786                self.seen.insert(successors[0]);
787
788                match (self.include_node_fn)(successors[0]) {
789                    Ok(false) => continue,
790                    Err(e) => return Some(Err(e)),
791                    _ => (),
792                }
793
794                run.push(successors[0]);
795            }
796
797            if !run.is_empty() {
798                return Some(Ok(run));
799            }
800        }
801
802        None
803    }
804
805    fn size_hint(&self) -> (usize, Option<usize>) {
806        // Lower bound is 0 in case all remaining nodes are filtered out
807        // Upper bound is the remaining unprocessed nodes (some of which may be seen already), potentially all resulting with singleton runs
808        (0, Some(self.sorted_nodes.len()))
809    }
810}
811
812// Tests for longest_path
#[cfg(test)]
mod test_longest_path {
    use super::*;
    use petgraph::graph::DiGraph;
    use petgraph::stable_graph::StableDiGraph;

    /// Weight callback for `DiGraph<(), i32>`: every edge weighs exactly its payload.
    fn payload_weight(edge: petgraph::graph::EdgeReference<'_, i32>) -> Result<i32, &'static str> {
        Ok(*edge.weight())
    }

    /// Build a `DiGraph` with `node_count` unit nodes and the given weighted edges.
    ///
    /// Edge endpoints are node positions. Edges are inserted in the order given.
    /// Returns the graph together with the indices of its nodes.
    fn weighted_graph(
        node_count: usize,
        edges: &[(usize, usize, i32)],
    ) -> (DiGraph<(), i32>, Vec<petgraph::graph::NodeIndex>) {
        let mut graph = DiGraph::new();
        let nodes: Vec<petgraph::graph::NodeIndex> =
            (0..node_count).map(|_| graph.add_node(())).collect();
        for &(source, target, weight) in edges {
            graph.add_edge(nodes[source], nodes[target], weight);
        }
        (graph, nodes)
    }

    #[test]
    fn test_empty_graph() {
        let graph: DiGraph<(), ()> = DiGraph::new();
        let zero = |_: petgraph::graph::EdgeReference<()>| Ok::<i32, &str>(0);
        assert_eq!(longest_path(&graph, zero), Ok(Some((vec![], 0))));
    }

    #[test]
    fn test_single_node_graph() {
        let mut graph: DiGraph<(), ()> = DiGraph::new();
        let only = graph.add_node(());
        let zero = |_: petgraph::graph::EdgeReference<()>| Ok::<i32, &str>(0);
        assert_eq!(longest_path(&graph, zero), Ok(Some((vec![only], 0))));
    }

    #[test]
    fn test_dag_with_multiple_paths() {
        let (graph, n) = weighted_graph(
            6,
            &[
                (0, 1, 3),
                (0, 2, 2),
                (1, 2, 1),
                (1, 3, 4),
                (2, 3, 2),
                (3, 4, 2),
                (2, 5, 1),
                (4, 5, 3),
            ],
        );
        // 0 -> 1 -> 3 -> 4 -> 5 weighs 3 + 4 + 2 + 3 = 12.
        let expected_path = vec![n[0], n[1], n[3], n[4], n[5]];
        assert_eq!(
            longest_path(&graph, payload_weight),
            Ok(Some((expected_path, 12)))
        );
    }

    #[test]
    fn test_graph_with_cycle() {
        // 0 -> 1 -> 0 is a cycle, so no longest path exists.
        let (graph, _) = weighted_graph(2, &[(0, 1, 1), (1, 0, 1)]);
        assert_eq!(longest_path(&graph, payload_weight), Ok(None));
    }

    #[test]
    fn test_negative_weights() {
        let (graph, n) = weighted_graph(3, &[(0, 1, -1), (0, 2, 2), (1, 2, -2)]);
        // The direct edge 0 -> 2 (weight 2) beats 0 -> 1 -> 2 (weight -3).
        assert_eq!(
            longest_path(&graph, payload_weight),
            Ok(Some((vec![n[0], n[2]], 2)))
        );
    }

    #[test]
    fn test_longest_path_in_stable_digraph() {
        let mut graph: StableDiGraph<(), i32> = StableDiGraph::new();
        let nodes: Vec<_> = (0..3).map(|_| graph.add_node(())).collect();
        for (source, target, weight) in [(0, 1, 1), (0, 2, 3), (1, 2, 1)] {
            graph.add_edge(nodes[source], nodes[target], weight);
        }
        let stable_weight =
            |edge: petgraph::stable_graph::EdgeReference<'_, i32>| Ok::<i32, &str>(*edge.weight());
        assert_eq!(
            longest_path(&graph, stable_weight),
            Ok(Some((vec![nodes[0], nodes[2]], 3)))
        );
    }

    #[test]
    fn test_error_handling() {
        let (graph, _) = weighted_graph(3, &[(0, 1, 1), (0, 2, 2), (1, 2, 1)]);
        // Any edge of weight 2 makes the callback fail; the error must be propagated.
        let reject_two = |edge: petgraph::graph::EdgeReference<i32>| match *edge.weight() {
            2 => Err("Error: edge weight is 2"),
            weight => Ok::<i32, &str>(weight),
        };
        assert_eq!(
            longest_path(&graph, reject_two),
            Err("Error: edge weight is 2")
        );
    }
}
920
921// Tests for lexicographical_topological_sort
#[cfg(test)]
mod test_lexicographical_topological_sort {
    use super::*;
    use petgraph::graph::{DiGraph, NodeIndex};
    use petgraph::stable_graph::StableDiGraph;
    use std::convert::Infallible;

    /// Node labels of the "simple layer" graph, indexed by node position.
    const SIMPLE_LABELS: [&str; 7] = ["a", "0", "1", "2", "3", "4", "A parent"];
    /// Edges of the "simple layer" graph as `(source, target)` positions: node 0
    /// points at nodes 1..=5, and node 6 is an extra parent of node 3.
    const SIMPLE_EDGES: [(usize, usize); 6] = [(0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (6, 3)];
    /// Node labels of the graph used by the reverse-direction tests.
    const REVERSE_LABELS: [&str; 6] = ["a", "b", "c", "d", "e", "f"];
    /// Edges of the graph used by the reverse-direction tests.
    const REVERSE_EDGES: [(usize, usize); 6] = [(0, 1), (0, 2), (1, 3), (2, 3), (1, 4), (2, 5)];

    /// Expected order for the "simple layer" graph. Node 6 ("A parent") comes before
    /// node 0 ("a"), then the children of node 0 follow in position order.
    fn simple_expected() -> Vec<NodeIndex> {
        [6, 0, 1, 2, 3, 4, 5]
            .iter()
            .map(|&position| NodeIndex::new(position))
            .collect()
    }

    #[test]
    fn test_empty_graph() {
        let graph: DiGraph<(), ()> = DiGraph::new();
        let constant_key = |_: NodeIndex| -> Result<String, Infallible> { Ok("a".to_string()) };
        assert_eq!(
            lexicographical_topological_sort(&graph, constant_key, false, None),
            Ok(vec![])
        );
    }

    #[test]
    fn test_empty_stable_graph() {
        let graph: StableDiGraph<(), ()> = StableDiGraph::new();
        let constant_key = |_: NodeIndex| -> Result<String, Infallible> { Ok("a".to_string()) };
        assert_eq!(
            lexicographical_topological_sort(&graph, constant_key, false, None),
            Ok(vec![])
        );
    }

    #[test]
    fn test_simple_layer() {
        let mut graph: DiGraph<String, ()> = DiGraph::new();
        let nodes: Vec<NodeIndex> = SIMPLE_LABELS
            .iter()
            .map(|label| graph.add_node(label.to_string()))
            .collect();
        for (source, target) in SIMPLE_EDGES {
            graph.add_edge(nodes[source], nodes[target], ());
        }
        let label_of = |index: NodeIndex| -> Result<String, Infallible> { Ok(graph[index].clone()) };
        assert_eq!(
            lexicographical_topological_sort(&graph, label_of, false, None),
            Ok(simple_expected())
        )
    }

    #[test]
    fn test_simple_layer_stable() {
        let mut graph: StableDiGraph<String, ()> = StableDiGraph::new();
        let nodes: Vec<NodeIndex> = SIMPLE_LABELS
            .iter()
            .map(|label| graph.add_node(label.to_string()))
            .collect();
        for (source, target) in SIMPLE_EDGES {
            graph.add_edge(nodes[source], nodes[target], ());
        }
        let label_of = |index: NodeIndex| -> Result<String, Infallible> { Ok(graph[index].clone()) };
        assert_eq!(
            lexicographical_topological_sort(&graph, label_of, false, None),
            Ok(simple_expected())
        )
    }

    #[test]
    fn test_reverse_graph() {
        let mut graph: DiGraph<String, ()> = DiGraph::new();
        let nodes: Vec<NodeIndex> = REVERSE_LABELS
            .iter()
            .map(|label| graph.add_node(label.to_string()))
            .collect();
        for (source, target) in REVERSE_EDGES {
            graph.add_edge(nodes[source], nodes[target], ());
        }

        // Sorting in reverse must match a forward sort of the edge-reversed graph.
        let label_of = |index: NodeIndex| -> Result<String, Infallible> { Ok(graph[index].clone()) };
        let reverse_sorted = lexicographical_topological_sort(&graph, label_of, true, None);
        graph.reverse();
        let label_of = |index: NodeIndex| -> Result<String, Infallible> { Ok(graph[index].clone()) };
        let forward_sorted_of_reversed =
            lexicographical_topological_sort(&graph, label_of, false, None);
        assert_eq!(reverse_sorted, forward_sorted_of_reversed)
    }

    #[test]
    fn test_reverse_graph_stable() {
        let mut graph: StableDiGraph<String, ()> = StableDiGraph::new();
        let nodes: Vec<NodeIndex> = REVERSE_LABELS
            .iter()
            .map(|label| graph.add_node(label.to_string()))
            .collect();
        for (source, target) in REVERSE_EDGES {
            graph.add_edge(nodes[source], nodes[target], ());
        }

        // Sorting in reverse must match a forward sort of the edge-reversed graph.
        let label_of = |index: NodeIndex| -> Result<String, Infallible> { Ok(graph[index].clone()) };
        let reverse_sorted = lexicographical_topological_sort(&graph, label_of, true, None);
        graph.reverse();
        let label_of = |index: NodeIndex| -> Result<String, Infallible> { Ok(graph[index].clone()) };
        let forward_sorted_of_reversed =
            lexicographical_topological_sort(&graph, label_of, false, None);
        assert_eq!(reverse_sorted, forward_sorted_of_reversed)
    }

    #[test]
    fn test_initial() {
        let mut graph: StableDiGraph<u8, ()> = StableDiGraph::new();
        let nodes: Vec<NodeIndex> = (0..9).map(|weight| graph.add_node(weight)).collect();
        let edges = [
            (0, 1),
            (0, 2),
            (1, 3),
            (2, 4),
            (3, 4),
            (4, 5),
            (5, 6),
            (4, 7),
            (6, 8),
            (7, 8),
        ];
        for (source, target) in edges {
            graph.add_edge(nodes[source], nodes[target], ());
        }
        let key = |index: NodeIndex| -> Result<String, Infallible> { Ok(graph[index].to_string()) };

        // Starting from nodes 6 and 7 reaches only their common successor, node 8.
        let from_six_and_seven = [nodes[6], nodes[7]];
        assert_eq!(
            lexicographical_topological_sort(&graph, key, false, Some(&from_six_and_seven)),
            Ok(vec![nodes[6], nodes[7], nodes[8]])
        );

        // Node 0 is the only node without predecessors, so starting from it matches
        // passing no initial set at all.
        let from_root = [nodes[0]];
        assert_eq!(
            lexicographical_topological_sort(&graph, key, false, Some(&from_root)),
            lexicographical_topological_sort(&graph, key, false, None)
        );

        // Starting from node 7 alone yields only node 7 (node 8 also has node 6 as a
        // predecessor).
        let from_seven = [nodes[7]];
        assert_eq!(
            lexicographical_topological_sort(&graph, key, false, Some(&from_seven)),
            Ok(vec![nodes[7]])
        );
    }
}
1093
#[cfg(test)]
mod test_layers {
    use super::*;
    use petgraph::{
        graph::{DiGraph, NodeIndex},
        stable_graph::StableDiGraph,
    };

    /// Turn raw node positions into `NodeIndex` values.
    fn indices(positions: &[usize]) -> Vec<NodeIndex> {
        positions.iter().copied().map(NodeIndex::new).collect()
    }

    #[test]
    fn test_empty_graph() {
        let graph = DiGraph::<(), ()>::new();
        let found: Vec<Vec<NodeIndex>> = layers(&graph, Vec::new()).flatten().collect();
        assert_eq!(found, vec![Vec::<NodeIndex>::new()]);
    }

    #[test]
    fn test_empty_stable_graph() {
        let graph = StableDiGraph::<(), ()>::new();
        let found: Vec<Vec<NodeIndex>> = layers(&graph, Vec::new()).flatten().collect();
        assert_eq!(found, vec![Vec::<NodeIndex>::new()]);
    }

    #[test]
    fn test_simple_layer() {
        let mut graph: DiGraph<String, ()> = DiGraph::new();
        let labels = ["a", "0", "1", "2", "3", "4", "A parent"];
        let nodes: Vec<NodeIndex> = labels
            .iter()
            .map(|label| graph.add_node(label.to_string()))
            .collect();
        for (source, target) in [(0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (6, 3)] {
            graph.add_edge(nodes[source], nodes[target], ());
        }

        let found: Vec<Vec<NodeIndex>> = layers(&graph, indices(&[0, 6])).flatten().collect();
        assert_eq!(found, vec![indices(&[0, 6]), indices(&[5, 4, 2, 1, 3])]);
    }

    #[test]
    #[should_panic]
    fn test_missing_node() {
        // The edge list only mentions nodes 0 through 4, so node 5 in the first layer
        // is not part of the graph.
        let edge_list = vec![(0, 1), (1, 2), (2, 3), (3, 4)];
        let graph = DiGraph::<u32, u32>::from_edges(&edge_list);
        for layer in layers(&graph, indices(&[4, 5])) {
            if let Err(e) = layer {
                panic!("{}", e.0);
            }
        }
    }

    #[test]
    fn test_dag_with_multiple_paths() {
        let mut graph: DiGraph<(), ()> = DiGraph::new();
        let nodes: Vec<NodeIndex> = (0..6).map(|_| graph.add_node(())).collect();
        for (source, target) in [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (3, 4), (2, 5), (4, 5)] {
            graph.add_edge(nodes[source], nodes[target], ());
        }

        // Every node ends up in its own layer, in insertion order.
        let found: Vec<Vec<NodeIndex>> = layers(&graph, vec![nodes[0]]).flatten().collect();
        let expected: Vec<Vec<NodeIndex>> = nodes.iter().map(|&node| vec![node]).collect();
        assert_eq!(found, expected);
    }

    #[test]
    #[should_panic]
    fn test_graph_with_cycle() {
        // 0 -> 1 -> 0 is a cycle. The test expects draining `layers` to panic, which
        // the loop below does on the first error it yields.
        let mut graph: DiGraph<(), i32> = DiGraph::new();
        let first = graph.add_node(());
        let second = graph.add_node(());
        graph.add_edge(first, second, 1);
        graph.add_edge(second, first, 1);

        for layer in layers(&graph, indices(&[0])) {
            if let Err(e) = layer {
                panic!("{}", e.0);
            }
        }
    }
}
1188
1189// Tests for collect_bicolor_runs
#[cfg(test)]
mod test_collect_bicolor_runs {

    use super::*;
    use petgraph::graph::{DiGraph, EdgeIndex, NodeIndex};
    use std::error::Error;

    /// Result type shared by the filter and color callbacks in these tests.
    type CallbackResult<T> = Result<T, Box<dyn Error>>;

    /// Label which of the three possible outcomes `collect_bicolor_runs` produced.
    fn outcome<T, E>(result: Result<Option<T>, E>) -> &'static str {
        match result {
            Err(_) => "Error",
            Ok(None) => "None",
            Ok(Some(_)) => "Not None",
        }
    }

    /// Filter callback: keep nodes whose weight is strictly greater than `threshold`.
    fn heavier_than(
        graph: &DiGraph<i32, i32>,
        node: NodeIndex,
        threshold: i32,
    ) -> CallbackResult<Option<bool>> {
        let weight = graph.node_weight(node).expect("Invalid NodeId");
        Ok(Some(*weight > threshold))
    }

    /// Color callback: an edge's color is its weight.
    fn weight_as_color(graph: &DiGraph<i32, i32>, edge: EdgeIndex) -> CallbackResult<Option<usize>> {
        let weight = graph.edge_weight(edge).expect("Invalid Edge");
        Ok(Some(*weight as usize))
    }

    #[test]
    fn test_cycle() {
        let mut graph: DiGraph<i32, i32> = DiGraph::new();
        let nodes: Vec<NodeIndex> = (0..3).map(|_| graph.add_node(0)).collect();
        for (source, target) in [(0, 1), (1, 2), (2, 0)] {
            graph.add_edge(nodes[source], nodes[target], 1);
        }

        let keep_all = |_node_id: NodeIndex| -> CallbackResult<Option<bool>> { Ok(Some(true)) };
        let single_color = |_edge_id: EdgeIndex| -> CallbackResult<Option<usize>> { Ok(Some(1)) };
        assert_eq!(
            outcome(collect_bicolor_runs(&graph, keep_all, single_color)),
            "None"
        )
    }

    #[test]
    fn test_filter_function_inner_exception() {
        let mut graph: DiGraph<i32, i32> = DiGraph::new();
        graph.add_node(0);

        // The only node has weight 0, so this filter fails on it.
        let failing_filter = |node_id: NodeIndex| -> CallbackResult<Option<bool>> {
            let weight = *graph.node_weight(node_id).expect("Invalid NodeId");
            if weight <= 0 {
                return Err(Box::from("Failed!"));
            }
            Ok(Some(true))
        };
        let color = |edge_id: EdgeIndex| weight_as_color(&graph, edge_id);
        assert_eq!(
            outcome(collect_bicolor_runs(&graph, failing_filter, color)),
            "Error"
        )
    }

    #[test]
    fn test_empty() {
        let graph: DiGraph<i32, i32> = DiGraph::new();
        let filter = |node_id: NodeIndex| heavier_than(&graph, node_id, 1);
        let color = |edge_id: EdgeIndex| weight_as_color(&graph, edge_id);
        let runs = collect_bicolor_runs(&graph, filter, color).unwrap();
        assert_eq!(runs, Some(Vec::<Vec<NodeIndex>>::new()))
    }

    #[test]
    fn test_two_colors() {
        /* Based on the following graph from the Python unit tests. Edge labels name
           the wire each edge carries; the wire number is also the edge's color.

               q0 ──q0──► cx ═q0,q1═► cz ──q0──► q0
                          ▲           │
               q1 ──q1────┘           └──q1──► q1

           Expected: [[cx, cz]]
        */
        let mut graph: DiGraph<i32, i32> = DiGraph::new();
        // Weight 0 for the q0/q1 wire endpoints, weight 1 for gates.
        let [q0, q1, cx, cz, q0_out, q1_out] =
            [0, 0, 1, 1, 0, 0].map(|weight| graph.add_node(weight));
        for (source, target, wire) in [
            (q0, cx, 0),
            (q1, cx, 1),
            (cx, cz, 0),
            (cx, cz, 1),
            (cz, q0_out, 0),
            (cz, q1_out, 1),
        ] {
            graph.add_edge(source, target, wire);
        }

        // Filter out the wire endpoints; color each edge by its weight.
        let filter = |node_id: NodeIndex| heavier_than(&graph, node_id, 0);
        let color = |edge_id: EdgeIndex| weight_as_color(&graph, edge_id);
        let runs = collect_bicolor_runs(&graph, filter, color).unwrap();
        assert_eq!(runs, Some(vec![vec![cx, cz]]))
    }

    #[test]
    fn test_two_colors_with_pending() {
        /* Based on the following graph from the Python unit tests. Edge labels name
           the wire each edge carries; the wire number is also the edge's color.

               q0 ──q0──► h ──q0──► cx ═q0,q1═► cz ──q0──► q0
                                    ▲           │
               q1 ──────q1──────────┘           └──q1──► y ──q1──► q1

           Expected: [[h, cx, cz, y]]
        */
        let mut graph: DiGraph<i32, i32> = DiGraph::new();
        // Weight 0 for the q0/q1 wire endpoints, weight 1 for gates.
        let [q0, q1, h, cx, cz, y, q0_out, q1_out] =
            [0, 0, 1, 1, 1, 1, 0, 0].map(|weight| graph.add_node(weight));
        for (source, target, wire) in [
            (q0, h, 0),
            (h, cx, 0),
            (q1, cx, 1),
            (cx, cz, 0),
            (cx, cz, 1),
            (cz, q0_out, 0),
            (cz, y, 1),
            (y, q1_out, 1),
        ] {
            graph.add_edge(source, target, wire);
        }

        // Filter out the wire endpoints; color each edge by its weight.
        let filter = |node_id: NodeIndex| heavier_than(&graph, node_id, 0);
        let color = |edge_id: EdgeIndex| weight_as_color(&graph, edge_id);
        let runs = collect_bicolor_runs(&graph, filter, color).unwrap();
        assert_eq!(runs, Some(vec![vec![h, cx, cz, y]]))
    }
}
1404
#[cfg(test)]
mod test_collect_runs {
    use super::collect_runs;
    use petgraph::{graph::DiGraph, visit::GraphBase};

    type BareDiGraph = DiGraph<(), ()>;
    type RunResult = Result<Vec<<BareDiGraph as GraphBase>::NodeId>, ()>;

    #[test]
    fn test_empty_graph() {
        let graph: BareDiGraph = DiGraph::new();

        let mut runs = collect_runs(&graph, |_| -> Result<bool, ()> { Ok(true) }).expect("Some");

        let run = runs.next();
        assert!(run.is_none());

        let runs = collect_runs(&graph, |_| -> Result<bool, ()> { Ok(true) }).expect("Some");

        let runs: Vec<RunResult> = runs.collect();

        assert_eq!(runs, Vec::<RunResult>::new());
    }

    #[test]
    fn test_simple_run_w_filter() {
        let mut graph: BareDiGraph = DiGraph::new();
        let n1 = graph.add_node(());
        let n2 = graph.add_node(());
        let n3 = graph.add_node(());
        graph.add_edge(n1, n2, ());
        graph.add_edge(n2, n3, ());

        let mut runs = collect_runs(&graph, |_| -> Result<bool, ()> { Ok(true) }).expect("Some");

        let the_run = runs.next().expect("Some").expect("3 nodes");
        assert_eq!(the_run.len(), 3);
        assert_eq!(runs.next(), None);

        assert_eq!(the_run, vec![n1, n2, n3]);

        // Now with some filters
        let mut runs = collect_runs(&graph, |_| -> Result<bool, ()> { Ok(false) }).expect("Some");

        assert_eq!(runs.next(), None);

        // Filtering out the middle node splits the chain into two singleton runs and
        // nothing else.
        let mut runs = collect_runs(&graph, |n| -> Result<bool, ()> { Ok(n != n2) }).expect("Some");

        assert_eq!(runs.next(), Some(Ok(vec![n1])));
        assert_eq!(runs.next(), Some(Ok(vec![n3])));
        assert_eq!(runs.next(), None);
    }

    #[test]
    fn test_multiple_runs_w_filter() {
        let mut graph: BareDiGraph = DiGraph::new();
        let n1 = graph.add_node(());
        let n2 = graph.add_node(());
        let n3 = graph.add_node(());
        let n4 = graph.add_node(());
        let n5 = graph.add_node(());
        let n6 = graph.add_node(());
        let n7 = graph.add_node(());

        graph.add_edge(n1, n2, ());
        graph.add_edge(n2, n3, ());
        graph.add_edge(n3, n7, ());
        graph.add_edge(n4, n3, ());
        graph.add_edge(n4, n7, ());
        graph.add_edge(n5, n4, ());
        graph.add_edge(n6, n5, ());

        let runs: Vec<RunResult> = collect_runs(&graph, |_| -> Result<bool, ()> { Ok(true) })
            .expect("Some")
            .collect();

        assert_eq!(runs, vec![Ok(vec![n6, n5, n4]), Ok(vec![n1, n2, n3, n7])]);

        // And now with some filter
        let runs: Vec<RunResult> =
            collect_runs(&graph, |n| -> Result<bool, ()> { Ok(n != n4 && n != n2) })
                .expect("Some")
                .collect();

        assert_eq!(runs, vec![Ok(vec![n6, n5]), Ok(vec![n1]), Ok(vec![n3, n7])]);
    }

    #[test]
    fn test_singleton_runs_w_filter() {
        let mut graph: BareDiGraph = DiGraph::new();
        let n1 = graph.add_node(());
        let n2 = graph.add_node(());
        let n3 = graph.add_node(());

        graph.add_edge(n1, n2, ());
        graph.add_edge(n1, n3, ());

        let mut runs = collect_runs(&graph, |_| -> Result<bool, ()> { Ok(true) }).expect("Some");

        // n1 has two distinct successors, so every node forms its own run.
        assert_eq!(runs.next().expect("n1"), Ok(vec![n1]));
        assert_eq!(runs.next().expect("n3"), Ok(vec![n3]));
        assert_eq!(runs.next().expect("n2"), Ok(vec![n2]));
        assert_eq!(runs.next(), None);

        // And now with some filter
        let runs: Vec<RunResult> = collect_runs(&graph, |n| -> Result<bool, ()> { Ok(n != n1) })
            .expect("Some")
            .collect();

        assert_eq!(runs, vec![Ok(vec![n3]), Ok(vec![n2])]);
    }

    #[test]
    fn test_parallel_edges_form_single_run() {
        // Parallel edges to the same target count as one successor, so they must not
        // break the run.
        let mut graph: BareDiGraph = DiGraph::new();
        let n1 = graph.add_node(());
        let n2 = graph.add_node(());
        let n3 = graph.add_node(());
        graph.add_edge(n1, n2, ());
        graph.add_edge(n1, n2, ());
        graph.add_edge(n2, n3, ());

        let runs: Vec<RunResult> = collect_runs(&graph, |_| -> Result<bool, ()> { Ok(true) })
            .expect("Some")
            .collect();

        assert_eq!(runs, vec![Ok(vec![n1, n2, n3])]);
    }

    #[test]
    fn test_error_propagation() {
        let mut graph: BareDiGraph = DiGraph::new();
        graph.add_node(());

        let mut runs = collect_runs(&graph, |_| -> Result<bool, ()> { Err(()) }).expect("Some");

        assert!(runs.next().expect("Some").is_err());
        assert_eq!(runs.next(), None);
    }
}
1525}