sunscreen_compiler_common/
graph.rs

1use std::collections::HashSet;
2
3use petgraph::{
4    dot::Dot,
5    stable_graph::{EdgeReference, Edges, Neighbors, NodeIndex, StableGraph},
6    visit::{EdgeRef, IntoNodeIdentifiers},
7    Directed, Direction,
8};
9use static_assertions::const_assert;
10use thiserror::Error;
11
12use crate::{EdgeInfo, NodeInfo, Operation, Render};
13
14/**
15 * A wrapper for ascertaining the structure of the underlying graph.
16 * This type is used in [`forward_traverse`] and
17 * [`reverse_traverse`] callbacks.
18 */
19pub struct GraphQuery<'a, N, E>(&'a StableGraph<N, E>);
20
21impl<'a, N, E> From<&'a StableGraph<N, E>> for GraphQuery<'a, N, E> {
22    fn from(x: &'a StableGraph<N, E>) -> Self {
23        Self(x)
24    }
25}
26
27impl<'a, N, E> GraphQuery<'a, N, E> {
28    /**
29     * Creates a new [`GraphQuery`] from a reference to a
30     * [`StableGraph`].
31     */
32    pub fn new(ir: &'a StableGraph<N, E>) -> Self {
33        Self(ir)
34    }
35
36    /**
37     * Gets a node from its index.
38     */
39    pub fn get_node(&self, x: NodeIndex) -> Option<&N> {
40        self.0.node_weight(x)
41    }
42
43    /**
44     * Gets information about the immediate parent or child nodes of
45     * the node at the given index.
46     *
47     * # Remarks
48     * [`Direction::Outgoing`] gives children, while
49     * [`Direction::Incoming`] gives parents.
50     */
51    pub fn neighbors_directed(&self, x: NodeIndex, direction: Direction) -> Neighbors<E> {
52        self.0.neighbors_directed(x, direction)
53    }
54
55    /**
56     * Gets edges pointing at the parent or child nodes of the node at
57     * the given index.
58     *
59     * # Remarks
60     * [`Direction::Outgoing`] gives children, while
61     * [`Direction::Incoming`] gives parents.
62     */
63    pub fn edges_directed(&self, x: NodeIndex, direction: Direction) -> Edges<E, Directed> {
64        self.0.edges_directed(x, direction)
65    }
66}
67
68/**
69 * A list of transformations that should be applied to the graph.
70 */
71pub trait TransformList<N, E>
72where
73    N: Clone,
74    E: Clone,
75{
76    /**
77     * Apply the transformations and return any added nodes.
78     *
79     * # Remarks
80     * This consumes the transform list.
81     */
82    fn apply(self, graph: &mut StableGraph<N, E>) -> Vec<NodeIndex>;
83}
84
85// Make a surrogate implementation of the trait for traversal functions
86// that don't mutate the graph.
87impl<N, E> TransformList<N, E> for ()
88where
89    N: Clone,
90    E: Clone,
91{
92    fn apply(self, _graph: &mut StableGraph<N, E>) -> Vec<NodeIndex> {
93        vec![]
94    }
95}
96
97/**
98 * Call the supplied callback for each node in the given graph in
99 * topological order.
100 */
101pub fn forward_traverse<N, E, F, Err>(graph: &StableGraph<N, E>, callback: F) -> Result<(), Err>
102where
103    N: Clone,
104    E: Clone,
105    F: FnMut(GraphQuery<N, E>, NodeIndex) -> Result<(), Err>,
106{
107    let graph: *const StableGraph<N, E> = graph;
108
109    // Traverse won't mutate the graph since F returns ().
110    unsafe { traverse(graph as *mut StableGraph<N, E>, true, callback) }
111}
112
113/**
114 * Call the supplied callback for each node in the given graph in
115 * reverse topological order.
116 */
117pub fn reverse_traverse<N, E, F, Err>(graph: &StableGraph<N, E>, callback: F) -> Result<(), Err>
118where
119    N: Clone,
120    E: Clone,
121    F: FnMut(GraphQuery<N, E>, NodeIndex) -> Result<(), Err>,
122{
123    let graph: *const StableGraph<N, E> = graph;
124
125    // Traverse won't mutate the graph since F returns ().
126    unsafe { traverse(graph as *mut StableGraph<N, E>, false, callback) }
127}
128
129/**
130 * A specialized topological DAG traversal that allows the following graph
131 * mutations during traversal:
132 * * Delete the current node
133 * * Insert nodes after current node
134 * * Add new nodes with no dependencies
135 *
136 * Any other graph mutation will likely result in unvisited nodes.
137 *
138 * * `callback`: A closure that receives the current node index and an
139 *   object allowing you to make graph queries. This closure returns a    
140 *   transform list or an error.
141 *   On success, [`reverse_traverse`] will apply these transformations
142 *   before continuing the traversal. Errors will be propagated to the
143 *   caller.
144 */
145pub fn forward_traverse_mut<N, E, F, T, Err>(
146    graph: &mut StableGraph<N, E>,
147    callback: F,
148) -> Result<(), Err>
149where
150    N: Clone,
151    E: Clone,
152    T: TransformList<N, E>,
153    F: FnMut(GraphQuery<N, E>, NodeIndex) -> Result<T, Err>,
154{
155    unsafe { traverse(graph, true, callback) }
156}
157
158/**
159 * A specialized reverse topological DAG traversal that allows the following graph
160 * mutations during traversal:
161 * * Delete the current node
162 * * Insert nodes after current node
163 * * Add new nodes with no dependencies
164 *
165 * Any other graph mutation will likely result in unvisited nodes.
166 *
167 * * `callback`: A closure that receives the current node index and an
168 *   object allowing you to make graph queries. This closure returns a    
169 *   transform list or an error.
170 *   On success, [`reverse_traverse`] will apply these transformations
171 *   before continuing the traversal. Errors will be propagated to the
172 *   caller.
173 */
174pub fn reverse_traverse_mut<N, E, F, T, Err>(
175    graph: &mut StableGraph<N, E>,
176    callback: F,
177) -> Result<(), Err>
178where
179    N: Clone,
180    E: Clone,
181    T: TransformList<N, E>,
182    F: FnMut(GraphQuery<N, E>, NodeIndex) -> Result<T, Err>,
183{
184    unsafe { traverse(graph, false, callback) }
185}
186
187/**
188 * Internal traversal implementation that allows for mutable traversal.
189 * If the callback always returns an empty transform list or (), then
190 * graph won't be mutated.
191 */
192unsafe fn traverse<N, E, T, F, Err>(
193    graph: *mut StableGraph<N, E>,
194    forward: bool,
195    mut callback: F,
196) -> Result<(), Err>
197where
198    N: Clone,
199    E: Clone,
200    F: FnMut(GraphQuery<N, E>, NodeIndex) -> Result<T, Err>,
201    T: TransformList<N, E>,
202{
203    // The one unsafe line in the function...
204    let graph = &mut *graph;
205    let mut ready: HashSet<NodeIndex> = HashSet::new();
206    let mut visited: HashSet<NodeIndex> = HashSet::new();
207    let prev_direction = if forward {
208        Direction::Incoming
209    } else {
210        Direction::Outgoing
211    };
212    let next_direction = if forward {
213        Direction::Outgoing
214    } else {
215        Direction::Incoming
216    };
217
218    let mut ready_nodes: Vec<NodeIndex> = graph
219        .node_identifiers()
220        .filter(|&x| graph.neighbors_directed(x, prev_direction).next().is_none())
221        .collect();
222
223    ready.extend(ready_nodes.iter());
224
225    while let Some(n) = ready_nodes.pop() {
226        visited.insert(n);
227
228        // Remember the next nodes from the current node in case it gets deleted.
229        let next_nodes: Vec<NodeIndex> = graph.neighbors_directed(n, next_direction).collect();
230
231        let transforms = callback(GraphQuery(graph), n)?;
232
233        // Apply the transforms the callback produced
234        let added_nodes = transforms.apply(graph);
235
236        let node_ready = |n: NodeIndex| {
237            graph
238                .neighbors_directed(n, prev_direction)
239                .all(|m| visited.contains(&m))
240        };
241
242        // If the node still exists, push all its ready dependents
243        if graph.contains_node(n) {
244            for i in graph.neighbors_directed(n, next_direction) {
245                if !ready.contains(&i) && node_ready(i) {
246                    ready.insert(i);
247                    ready_nodes.push(i);
248                }
249            }
250        }
251
252        // Iterate through the next nodes that existed before visiting this node.
253        for i in next_nodes {
254            if !ready.contains(&i) && node_ready(i) {
255                ready.insert(i);
256                ready_nodes.push(i);
257            }
258        }
259
260        // Check for and sources/sinks the callback may have added.
261        for i in added_nodes {
262            if graph.neighbors_directed(i, prev_direction).next().is_none() {
263                ready.insert(i);
264                ready_nodes.push(i);
265            }
266        }
267    }
268
269    Ok(())
270}
271
272impl<N, E> Render for StableGraph<N, E>
273where
274    N: Render + std::fmt::Debug,
275    E: Render + std::fmt::Debug,
276{
277    fn render(&self) -> String {
278        let data = Dot::with_attr_getters(
279            self,
280            &[
281                petgraph::dot::Config::NodeNoLabel,
282                petgraph::dot::Config::EdgeNoLabel,
283            ],
284            &|_, e| format!("label=\"{}\"", e.weight().render()),
285            &|_, n| {
286                let (index, info) = n;
287
288                format!("label=\"{}: {}\"", index.index(), info.render())
289            },
290        );
291
292        format!("{data:?}")
293    }
294}
295
296#[derive(Clone, Copy, Debug, Error, PartialEq, Eq)]
297/**
298 * An error that can occur when querying various aspects about an
299 * operation graph.
300 */
301pub enum GraphQueryError {
302    #[error("The given graph node wasn't a binary operation")]
303    /**
304     * The given operation is not a binary operation.
305     */
306    NotBinaryOperation,
307
308    #[error("The given graph node wasn't a unary operation")]
309    /**
310     * The given graph node wasn't a unary operation.
311     */
312    NotUnaryOperation,
313
314    #[error("The given graph node wasn't an unordered operation")]
315    /**
316     * The given graph node wasn't an unordered operation.
317     */
318    NotUnorderedOperation,
319
320    #[error("The given graph node wasn't an ordered operation")]
321    /**
322     * The given graph node wasn't an ordered operation.
323     */
324    NotOrderedOperation,
325
326    #[error("No node exists at the given index")]
327    /**
328     * No node exists at the given index.
329     */
330    NoSuchNode,
331
332    #[error("The given node doesn't have 1 left and 1 right edge")]
333    /**
334     * The given node doesn't have 1 left and 1 right edge.
335     */
336    IncorrectBinaryOperandEdges,
337
338    #[error("The given node doesn't have exactly 1 unary edge")]
339    /**
340     * The given node doesn't have exactly 1 unary edge.
341     */
342    IncorrectUnaryOperandEdge,
343
344    #[error("The given node has a non-unordered edge")]
345    /**
346     * The given node has a non-unordered edge.
347     */
348    IncorrectUnorderedOperandEdge,
349
350    #[error("The given node has a non-ordered edge")]
351    /**
352     * The given node has a non-ordered edge.
353     */
354    IncorrectOrderedOperandEdge,
355}
356
357const_assert!(std::mem::size_of::<GraphQueryError>() <= 8);
358
359impl<'a, O> GraphQuery<'a, NodeInfo<O>, EdgeInfo>
360where
361    O: Operation,
362{
363    /**
364     * Returns the left and right node indices to a binary operation.
365     *
366     * # Errors
367     * - No node exists at the given index.
368     * - The node at the given index isn't a binary operation.
369     * - The node at the given index doesn't have a 1 left and 1 right parent
370     */
371    pub fn get_binary_operands(
372        &self,
373        index: NodeIndex,
374    ) -> Result<(NodeIndex, NodeIndex), GraphQueryError> {
375        let node = self.get_node(index).ok_or(GraphQueryError::NoSuchNode)?;
376
377        if !node.operation.is_binary() {
378            return Err(GraphQueryError::NotBinaryOperation);
379        }
380
381        let parent_edges = self
382            .edges_directed(index, Direction::Incoming)
383            .collect::<Vec<EdgeReference<EdgeInfo>>>();
384
385        if parent_edges.len() != 2 {
386            return Err(GraphQueryError::IncorrectBinaryOperandEdges);
387        }
388
389        let left = parent_edges.iter().find_map(|e| {
390            if matches!(e.weight(), EdgeInfo::Left) {
391                Some(e.source())
392            } else {
393                None
394            }
395        });
396
397        let right = parent_edges.iter().find_map(|e| {
398            if matches!(e.weight(), EdgeInfo::Right) {
399                Some(e.source())
400            } else {
401                None
402            }
403        });
404
405        left.zip(right)
406            .ok_or(GraphQueryError::IncorrectBinaryOperandEdges)
407    }
408
409    /**
410     * Returns the unary operand node index to a unary operation.
411     *
412     * # Errors
413     * - No node exists at the given index.
414     * - The node at the given index isn't a unary operation.
415     * - The node at the given index doesn't have a single unary operand.
416     */
417    pub fn get_unary_operand(&self, index: NodeIndex) -> Result<NodeIndex, GraphQueryError> {
418        let node = self.get_node(index).ok_or(GraphQueryError::NoSuchNode)?;
419
420        if !node.operation.is_unary() {
421            return Err(GraphQueryError::NotUnaryOperation);
422        }
423
424        let parent_edges = self
425            .edges_directed(index, Direction::Incoming)
426            .collect::<Vec<EdgeReference<EdgeInfo>>>();
427
428        if parent_edges.len() != 1 || !matches!(&parent_edges[0].weight(), EdgeInfo::Unary) {
429            return Err(GraphQueryError::IncorrectBinaryOperandEdges);
430        }
431
432        let left = parent_edges.first();
433
434        Ok(left
435            .ok_or(GraphQueryError::IncorrectUnaryOperandEdge)?
436            .source())
437    }
438
439    /**
440     * Returns the unordered operands to the given operation.
441     *
442     * # Remarks
443     * As these operands are unordered, their order is undefined. Use
444     * [`EdgeInfo::Ordered`] and call
445     * [`GraphQuery::get_ordered_operands`] if you need a defined order.
446     *
447     * * # Errors
448     * - No node exists at the given index.
449     * - The node at the given index isn't a unary operation.
450     * - The node at the given index doesn't have a single unary operand.
451     */
452    pub fn get_unordered_operands(
453        &self,
454        index: NodeIndex,
455    ) -> Result<Vec<NodeIndex>, GraphQueryError> {
456        let node = self.get_node(index).ok_or(GraphQueryError::NoSuchNode)?;
457
458        if !node.operation.is_unordered() {
459            return Err(GraphQueryError::NotUnorderedOperation);
460        }
461
462        let parent_edges = self
463            .edges_directed(index, Direction::Incoming)
464            .collect::<Vec<EdgeReference<EdgeInfo>>>();
465
466        if !parent_edges
467            .iter()
468            .all(|e| matches!(e.weight(), EdgeInfo::Unordered))
469        {
470            return Err(GraphQueryError::IncorrectUnorderedOperandEdge);
471        }
472
473        Ok(parent_edges.iter().map(|x| x.source()).collect())
474    }
475
476    /**
477     * Returns the unordered operands to the given operation.
478     *
479     * # Remarks
480     * The operands node indices are returned in order.
481     *
482     * * # Errors
483     * - No node exists at the given index.
484     * - The node at the given index isn't a unary operation.
485     * - The node at the given index doesn't have a single unary operand.
486     */
487    pub fn get_ordered_operands(
488        &self,
489        index: NodeIndex,
490    ) -> Result<Vec<NodeIndex>, GraphQueryError> {
491        let node = self.get_node(index).ok_or(GraphQueryError::NoSuchNode)?;
492
493        if !node.operation.is_ordered() {
494            return Err(GraphQueryError::NotOrderedOperation);
495        }
496
497        let mut parent_edges = self
498            .edges_directed(index, Direction::Incoming)
499            .map(|x| match x.weight() {
500                EdgeInfo::Ordered(arg_id) => Ok(SortableEdge(x.source(), *arg_id)),
501                _ => Err(GraphQueryError::IncorrectOrderedOperandEdge),
502            })
503            .collect::<Result<Vec<SortableEdge>, _>>()?;
504
505        #[derive(Eq)]
506        struct SortableEdge(NodeIndex, usize);
507
508        impl PartialEq for SortableEdge {
509            fn eq(&self, other: &Self) -> bool {
510                self.1 == other.1
511            }
512        }
513
514        impl PartialOrd for SortableEdge {
515            fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
516                self.1.partial_cmp(&other.1)
517            }
518        }
519
520        impl Ord for SortableEdge {
521            fn cmp(&self, other: &Self) -> std::cmp::Ordering {
522                // PartialCmp will always return Some(_) for usize,
523                // which is the thing we're comparing.
524                self.partial_cmp(other).unwrap()
525            }
526        }
527
528        // Sort the edges by the argument index.
529        parent_edges.sort();
530
531        // Check that the argument indices form a range 0..N
532        for (i, e) in parent_edges.iter().enumerate() {
533            if e.1 != i {
534                return Err(GraphQueryError::IncorrectOrderedOperandEdge);
535            }
536        }
537
538        // Finally, return the parent node indices sorted by their
539        // argument index
540        Ok(parent_edges.iter().map(|x| x.0).collect())
541    }
542}
543
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use super::*;
    use crate::{
        transforms::{GraphTransforms, Transform},
        Context, Operation as OperationTrait,
    };

    #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    enum Operation {
        Add,
        Mul,
        In,
    }

    impl OperationTrait for Operation {
        fn is_binary(&self) -> bool {
            matches!(self, Self::Add | Self::Mul)
        }

        fn is_commutative(&self) -> bool {
            matches!(self, Self::Add | Self::Mul)
        }

        fn is_unary(&self) -> bool {
            false
        }

        fn is_unordered(&self) -> bool {
            false
        }

        fn is_ordered(&self) -> bool {
            false
        }
    }

    type TestGraph = Context<Operation, ()>;

    /**
     * Builds the DAG `(n0 + n1) * n3` with nodes
     * 0: In, 1: In, 2: Add(0, 1), 3: In, 4: Mul(2, 3).
     */
    fn create_simple_dag() -> TestGraph {
        let mut graph = TestGraph::new(());

        let in_1 = graph.add_node(Operation::In);
        let in_2 = graph.add_node(Operation::In);
        let add = graph.add_binary_operation(Operation::Add, in_1, in_2);
        let in_3 = graph.add_node(Operation::In);
        graph.add_binary_operation(Operation::Mul, add, in_3);

        graph
    }

    #[test]
    fn can_forward_traverse() {
        let ir = create_simple_dag();

        let mut visited = vec![];

        forward_traverse(&ir.graph, |_, n| {
            visited.push(n);

            Ok::<_, Infallible>(())
        })
        .unwrap();

        assert_eq!(
            visited,
            vec![
                NodeIndex::from(3),
                NodeIndex::from(1),
                NodeIndex::from(0),
                NodeIndex::from(2),
                NodeIndex::from(4)
            ]
        );
    }

    #[test]
    fn can_build_simple_dag() {
        let ir = create_simple_dag();

        assert_eq!(ir.graph.node_count(), 5);

        let nodes = ir
            .graph
            .node_identifiers()
            .map(|i| (i, &ir.graph[i]))
            .collect::<Vec<(NodeIndex, &NodeInfo<Operation>)>>();

        assert_eq!(nodes[0].1.operation, Operation::In);
        assert_eq!(nodes[1].1.operation, Operation::In);
        assert_eq!(nodes[2].1.operation, Operation::Add);
        assert_eq!(nodes[3].1.operation, Operation::In);
        assert_eq!(nodes[4].1.operation, Operation::Mul);

        assert_eq!(
            ir.graph
                .neighbors_directed(nodes[0].0, Direction::Outgoing)
                .next()
                .unwrap(),
            nodes[2].0
        );
        assert_eq!(
            ir.graph
                .neighbors_directed(nodes[1].0, Direction::Outgoing)
                .next()
                .unwrap(),
            nodes[2].0
        );
        assert_eq!(
            ir.graph
                .neighbors_directed(nodes[2].0, Direction::Outgoing)
                .next()
                .unwrap(),
            nodes[4].0
        );
        assert_eq!(
            ir.graph
                .neighbors_directed(nodes[3].0, Direction::Outgoing)
                .next()
                .unwrap(),
            nodes[4].0
        );
        assert_eq!(
            ir.graph
                .neighbors_directed(nodes[4].0, Direction::Outgoing)
                .next(),
            None
        );
    }

    #[test]
    fn can_reverse_traverse() {
        let ir = create_simple_dag();

        let mut visited = vec![];

        reverse_traverse(&ir.graph, |_, n| {
            visited.push(n);
            Ok::<_, Infallible>(())
        })
        .unwrap();

        assert_eq!(
            visited,
            vec![
                NodeIndex::from(4),
                NodeIndex::from(2),
                NodeIndex::from(0),
                NodeIndex::from(1),
                NodeIndex::from(3)
            ]
        );
    }

    #[test]
    fn can_delete_during_traversal() {
        let mut ir = create_simple_dag();

        let mut visited = vec![];

        reverse_traverse_mut(&mut ir.graph, |_, n| {
            visited.push(n);
            // Delete the addition
            if n.index() == 2 {
                let mut transforms = GraphTransforms::new();
                transforms.push(Transform::RemoveNode(n.into()));

                Ok::<_, Infallible>(transforms)
            } else {
                Ok::<_, Infallible>(GraphTransforms::default())
            }
        })
        .unwrap();

        // The inputs of the deleted addition are still reached.
        assert_eq!(
            visited,
            vec![
                NodeIndex::from(4),
                NodeIndex::from(2),
                NodeIndex::from(0),
                NodeIndex::from(1),
                NodeIndex::from(3)
            ]
        );

        // The addition (node 2) is gone; every other node remains.
        assert_eq!(ir.graph.node_count(), 4);
        assert!(!ir.graph.contains_node(NodeIndex::from(2)));
    }

    #[test]
    fn can_append_during_traversal() {
        let mut ir = create_simple_dag();

        let mut visited = vec![];

        forward_traverse_mut(&mut ir.graph, |_, n| {
            visited.push(n);

            // Append a multiplication after the addition: mul = add * n1
            if n.index() == 2 {
                let mut transforms: GraphTransforms<NodeInfo<Operation>, EdgeInfo> =
                    GraphTransforms::new();
                let mul = transforms.push(Transform::AddNode(NodeInfo {
                    operation: Operation::Mul,
                }));
                transforms.push(Transform::AddEdge(n.into(), mul.into(), EdgeInfo::Left));
                transforms.push(Transform::AddEdge(
                    NodeIndex::from(1).into(),
                    mul.into(),
                    EdgeInfo::Right,
                ));

                let ret = transforms.clone();

                // NOTE(review): this applies a copy of the transforms to a
                // separate, freshly built DAG and discards the result; it
                // does not affect `ir`.
                transforms.apply(&mut create_simple_dag().graph.0);

                Ok::<_, Infallible>(ret)
            } else {
                Ok::<_, Infallible>(GraphTransforms::default())
            }
        })
        .unwrap();

        // The appended node (index 5) is visited after its parents.
        assert_eq!(
            visited,
            vec![
                NodeIndex::from(3),
                NodeIndex::from(1),
                NodeIndex::from(0),
                NodeIndex::from(2),
                NodeIndex::from(4),
                NodeIndex::from(5),
            ]
        );

        // The appended multiplication was added to the traversed graph.
        assert_eq!(ir.graph.node_count(), 6);
    }
}