1use std::collections::HashSet;
2
3use petgraph::{
4 dot::Dot,
5 stable_graph::{EdgeReference, Edges, Neighbors, NodeIndex, StableGraph},
6 visit::{EdgeRef, IntoNodeIdentifiers},
7 Directed, Direction,
8};
9use static_assertions::const_assert;
10use thiserror::Error;
11
12use crate::{EdgeInfo, NodeInfo, Operation, Render};
13
14pub struct GraphQuery<'a, N, E>(&'a StableGraph<N, E>);
20
21impl<'a, N, E> From<&'a StableGraph<N, E>> for GraphQuery<'a, N, E> {
22 fn from(x: &'a StableGraph<N, E>) -> Self {
23 Self(x)
24 }
25}
26
27impl<'a, N, E> GraphQuery<'a, N, E> {
28 pub fn new(ir: &'a StableGraph<N, E>) -> Self {
33 Self(ir)
34 }
35
36 pub fn get_node(&self, x: NodeIndex) -> Option<&N> {
40 self.0.node_weight(x)
41 }
42
43 pub fn neighbors_directed(&self, x: NodeIndex, direction: Direction) -> Neighbors<E> {
52 self.0.neighbors_directed(x, direction)
53 }
54
55 pub fn edges_directed(&self, x: NodeIndex, direction: Direction) -> Edges<E, Directed> {
64 self.0.edges_directed(x, direction)
65 }
66}
67
68pub trait TransformList<N, E>
72where
73 N: Clone,
74 E: Clone,
75{
76 fn apply(self, graph: &mut StableGraph<N, E>) -> Vec<NodeIndex>;
83}
84
85impl<N, E> TransformList<N, E> for ()
88where
89 N: Clone,
90 E: Clone,
91{
92 fn apply(self, _graph: &mut StableGraph<N, E>) -> Vec<NodeIndex> {
93 vec![]
94 }
95}
96
97pub fn forward_traverse<N, E, F, Err>(graph: &StableGraph<N, E>, callback: F) -> Result<(), Err>
102where
103 N: Clone,
104 E: Clone,
105 F: FnMut(GraphQuery<N, E>, NodeIndex) -> Result<(), Err>,
106{
107 let graph: *const StableGraph<N, E> = graph;
108
109 unsafe { traverse(graph as *mut StableGraph<N, E>, true, callback) }
111}
112
113pub fn reverse_traverse<N, E, F, Err>(graph: &StableGraph<N, E>, callback: F) -> Result<(), Err>
118where
119 N: Clone,
120 E: Clone,
121 F: FnMut(GraphQuery<N, E>, NodeIndex) -> Result<(), Err>,
122{
123 let graph: *const StableGraph<N, E> = graph;
124
125 unsafe { traverse(graph as *mut StableGraph<N, E>, false, callback) }
127}
128
129pub fn forward_traverse_mut<N, E, F, T, Err>(
146 graph: &mut StableGraph<N, E>,
147 callback: F,
148) -> Result<(), Err>
149where
150 N: Clone,
151 E: Clone,
152 T: TransformList<N, E>,
153 F: FnMut(GraphQuery<N, E>, NodeIndex) -> Result<T, Err>,
154{
155 unsafe { traverse(graph, true, callback) }
156}
157
158pub fn reverse_traverse_mut<N, E, F, T, Err>(
175 graph: &mut StableGraph<N, E>,
176 callback: F,
177) -> Result<(), Err>
178where
179 N: Clone,
180 E: Clone,
181 T: TransformList<N, E>,
182 F: FnMut(GraphQuery<N, E>, NodeIndex) -> Result<T, Err>,
183{
184 unsafe { traverse(graph, false, callback) }
185}
186
187unsafe fn traverse<N, E, T, F, Err>(
193 graph: *mut StableGraph<N, E>,
194 forward: bool,
195 mut callback: F,
196) -> Result<(), Err>
197where
198 N: Clone,
199 E: Clone,
200 F: FnMut(GraphQuery<N, E>, NodeIndex) -> Result<T, Err>,
201 T: TransformList<N, E>,
202{
203 let graph = &mut *graph;
205 let mut ready: HashSet<NodeIndex> = HashSet::new();
206 let mut visited: HashSet<NodeIndex> = HashSet::new();
207 let prev_direction = if forward {
208 Direction::Incoming
209 } else {
210 Direction::Outgoing
211 };
212 let next_direction = if forward {
213 Direction::Outgoing
214 } else {
215 Direction::Incoming
216 };
217
218 let mut ready_nodes: Vec<NodeIndex> = graph
219 .node_identifiers()
220 .filter(|&x| graph.neighbors_directed(x, prev_direction).next().is_none())
221 .collect();
222
223 ready.extend(ready_nodes.iter());
224
225 while let Some(n) = ready_nodes.pop() {
226 visited.insert(n);
227
228 let next_nodes: Vec<NodeIndex> = graph.neighbors_directed(n, next_direction).collect();
230
231 let transforms = callback(GraphQuery(graph), n)?;
232
233 let added_nodes = transforms.apply(graph);
235
236 let node_ready = |n: NodeIndex| {
237 graph
238 .neighbors_directed(n, prev_direction)
239 .all(|m| visited.contains(&m))
240 };
241
242 if graph.contains_node(n) {
244 for i in graph.neighbors_directed(n, next_direction) {
245 if !ready.contains(&i) && node_ready(i) {
246 ready.insert(i);
247 ready_nodes.push(i);
248 }
249 }
250 }
251
252 for i in next_nodes {
254 if !ready.contains(&i) && node_ready(i) {
255 ready.insert(i);
256 ready_nodes.push(i);
257 }
258 }
259
260 for i in added_nodes {
262 if graph.neighbors_directed(i, prev_direction).next().is_none() {
263 ready.insert(i);
264 ready_nodes.push(i);
265 }
266 }
267 }
268
269 Ok(())
270}
271
272impl<N, E> Render for StableGraph<N, E>
273where
274 N: Render + std::fmt::Debug,
275 E: Render + std::fmt::Debug,
276{
277 fn render(&self) -> String {
278 let data = Dot::with_attr_getters(
279 self,
280 &[
281 petgraph::dot::Config::NodeNoLabel,
282 petgraph::dot::Config::EdgeNoLabel,
283 ],
284 &|_, e| format!("label=\"{}\"", e.weight().render()),
285 &|_, n| {
286 let (index, info) = n;
287
288 format!("label=\"{}: {}\"", index.index(), info.render())
289 },
290 );
291
292 format!("{data:?}")
293 }
294}
295
296#[derive(Clone, Copy, Debug, Error, PartialEq, Eq)]
297pub enum GraphQueryError {
302 #[error("The given graph node wasn't a binary operation")]
303 NotBinaryOperation,
307
308 #[error("The given graph node wasn't a unary operation")]
309 NotUnaryOperation,
313
314 #[error("The given graph node wasn't an unordered operation")]
315 NotUnorderedOperation,
319
320 #[error("The given graph node wasn't an ordered operation")]
321 NotOrderedOperation,
325
326 #[error("No node exists at the given index")]
327 NoSuchNode,
331
332 #[error("The given node doesn't have 1 left and 1 right edge")]
333 IncorrectBinaryOperandEdges,
337
338 #[error("The given node doesn't have exactly 1 unary edge")]
339 IncorrectUnaryOperandEdge,
343
344 #[error("The given node has a non-unordered edge")]
345 IncorrectUnorderedOperandEdge,
349
350 #[error("The given node has a non-ordered edge")]
351 IncorrectOrderedOperandEdge,
355}
356
357const_assert!(std::mem::size_of::<GraphQueryError>() <= 8);
358
359impl<'a, O> GraphQuery<'a, NodeInfo<O>, EdgeInfo>
360where
361 O: Operation,
362{
363 pub fn get_binary_operands(
372 &self,
373 index: NodeIndex,
374 ) -> Result<(NodeIndex, NodeIndex), GraphQueryError> {
375 let node = self.get_node(index).ok_or(GraphQueryError::NoSuchNode)?;
376
377 if !node.operation.is_binary() {
378 return Err(GraphQueryError::NotBinaryOperation);
379 }
380
381 let parent_edges = self
382 .edges_directed(index, Direction::Incoming)
383 .collect::<Vec<EdgeReference<EdgeInfo>>>();
384
385 if parent_edges.len() != 2 {
386 return Err(GraphQueryError::IncorrectBinaryOperandEdges);
387 }
388
389 let left = parent_edges.iter().find_map(|e| {
390 if matches!(e.weight(), EdgeInfo::Left) {
391 Some(e.source())
392 } else {
393 None
394 }
395 });
396
397 let right = parent_edges.iter().find_map(|e| {
398 if matches!(e.weight(), EdgeInfo::Right) {
399 Some(e.source())
400 } else {
401 None
402 }
403 });
404
405 left.zip(right)
406 .ok_or(GraphQueryError::IncorrectBinaryOperandEdges)
407 }
408
409 pub fn get_unary_operand(&self, index: NodeIndex) -> Result<NodeIndex, GraphQueryError> {
418 let node = self.get_node(index).ok_or(GraphQueryError::NoSuchNode)?;
419
420 if !node.operation.is_unary() {
421 return Err(GraphQueryError::NotUnaryOperation);
422 }
423
424 let parent_edges = self
425 .edges_directed(index, Direction::Incoming)
426 .collect::<Vec<EdgeReference<EdgeInfo>>>();
427
428 if parent_edges.len() != 1 || !matches!(&parent_edges[0].weight(), EdgeInfo::Unary) {
429 return Err(GraphQueryError::IncorrectBinaryOperandEdges);
430 }
431
432 let left = parent_edges.first();
433
434 Ok(left
435 .ok_or(GraphQueryError::IncorrectUnaryOperandEdge)?
436 .source())
437 }
438
439 pub fn get_unordered_operands(
453 &self,
454 index: NodeIndex,
455 ) -> Result<Vec<NodeIndex>, GraphQueryError> {
456 let node = self.get_node(index).ok_or(GraphQueryError::NoSuchNode)?;
457
458 if !node.operation.is_unordered() {
459 return Err(GraphQueryError::NotUnorderedOperation);
460 }
461
462 let parent_edges = self
463 .edges_directed(index, Direction::Incoming)
464 .collect::<Vec<EdgeReference<EdgeInfo>>>();
465
466 if !parent_edges
467 .iter()
468 .all(|e| matches!(e.weight(), EdgeInfo::Unordered))
469 {
470 return Err(GraphQueryError::IncorrectUnorderedOperandEdge);
471 }
472
473 Ok(parent_edges.iter().map(|x| x.source()).collect())
474 }
475
476 pub fn get_ordered_operands(
488 &self,
489 index: NodeIndex,
490 ) -> Result<Vec<NodeIndex>, GraphQueryError> {
491 let node = self.get_node(index).ok_or(GraphQueryError::NoSuchNode)?;
492
493 if !node.operation.is_ordered() {
494 return Err(GraphQueryError::NotOrderedOperation);
495 }
496
497 let mut parent_edges = self
498 .edges_directed(index, Direction::Incoming)
499 .map(|x| match x.weight() {
500 EdgeInfo::Ordered(arg_id) => Ok(SortableEdge(x.source(), *arg_id)),
501 _ => Err(GraphQueryError::IncorrectOrderedOperandEdge),
502 })
503 .collect::<Result<Vec<SortableEdge>, _>>()?;
504
505 #[derive(Eq)]
506 struct SortableEdge(NodeIndex, usize);
507
508 impl PartialEq for SortableEdge {
509 fn eq(&self, other: &Self) -> bool {
510 self.1 == other.1
511 }
512 }
513
514 impl PartialOrd for SortableEdge {
515 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
516 self.1.partial_cmp(&other.1)
517 }
518 }
519
520 impl Ord for SortableEdge {
521 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
522 self.partial_cmp(other).unwrap()
525 }
526 }
527
528 parent_edges.sort();
530
531 for (i, e) in parent_edges.iter().enumerate() {
533 if e.1 != i {
534 return Err(GraphQueryError::IncorrectOrderedOperandEdge);
535 }
536 }
537
538 Ok(parent_edges.iter().map(|x| x.0).collect())
541 }
542}
543
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use super::*;
    use crate::{
        transforms::{GraphTransforms, Transform},
        Context, Operation as OperationTrait,
    };

    /// Minimal operation set used to build the test graphs.
    #[derive(Clone, Debug, Hash, PartialEq, Eq)]
    enum Operation {
        Add,
        Mul,
        In,
    }

    impl OperationTrait for Operation {
        fn is_binary(&self) -> bool {
            matches!(self, Self::Add | Self::Mul)
        }

        fn is_commutative(&self) -> bool {
            matches!(self, Self::Add | Self::Mul)
        }

        fn is_unary(&self) -> bool {
            false
        }

        fn is_unordered(&self) -> bool {
            false
        }

        fn is_ordered(&self) -> bool {
            false
        }
    }

    type TestGraph = Context<Operation, ()>;

    /// Converts raw indices into `NodeIndex` values for concise expectations.
    fn node_indices(raw: &[u32]) -> Vec<NodeIndex> {
        raw.iter().copied().map(NodeIndex::from).collect()
    }

    /// Builds `(in0 + in1) * in3`; nodes are numbered 0..=4 in insertion order.
    fn create_simple_dag() -> TestGraph {
        let mut dag = TestGraph::new(());

        let lhs = dag.add_node(Operation::In);
        let rhs = dag.add_node(Operation::In);
        let sum = dag.add_binary_operation(Operation::Add, lhs, rhs);
        let factor = dag.add_node(Operation::In);
        dag.add_binary_operation(Operation::Mul, sum, factor);

        dag
    }

    #[test]
    fn can_forward_traverse() {
        let ir = create_simple_dag();
        let mut order = vec![];

        forward_traverse(&ir.graph, |_, n| {
            order.push(n);
            Ok::<_, Infallible>(())
        })
        .unwrap();

        assert_eq!(order, node_indices(&[3, 1, 0, 2, 4]));
    }

    #[test]
    fn can_build_simple_dag() {
        let ir = create_simple_dag();

        assert_eq!(ir.graph.node_count(), 5);

        let nodes: Vec<(NodeIndex, &NodeInfo<Operation>)> = ir
            .graph
            .node_identifiers()
            .map(|i| (i, &ir.graph[i]))
            .collect();

        let expected_operations = [
            Operation::In,
            Operation::In,
            Operation::Add,
            Operation::In,
            Operation::Mul,
        ];
        for (position, expected) in expected_operations.iter().enumerate() {
            assert_eq!(&nodes[position].1.operation, expected);
        }

        // The first outgoing neighbour of each non-output node.
        for (from, to) in [(0, 2), (1, 2), (2, 4), (3, 4)] {
            let first_consumer = ir
                .graph
                .neighbors_directed(nodes[from].0, Direction::Outgoing)
                .next()
                .unwrap();
            assert_eq!(first_consumer, nodes[to].0);
        }

        // The final multiplication is a sink.
        assert_eq!(
            ir.graph
                .neighbors_directed(nodes[4].0, Direction::Outgoing)
                .next(),
            None
        );
    }

    #[test]
    fn can_reverse_traverse() {
        let ir = create_simple_dag();
        let mut order = vec![];

        reverse_traverse(&ir.graph, |_, n| {
            order.push(n);
            Ok::<_, Infallible>(())
        })
        .unwrap();

        assert_eq!(order, node_indices(&[4, 2, 0, 1, 3]));
    }

    #[test]
    fn can_delete_during_traversal() {
        let mut ir = create_simple_dag();
        let mut order = vec![];

        reverse_traverse_mut(&mut ir.graph, |_, n| {
            order.push(n);

            if n.index() != 2 {
                Ok::<_, Infallible>(GraphTransforms::default())
            } else {
                // Remove the `Add` node as soon as it is visited.
                let mut edits = GraphTransforms::new();
                edits.push(Transform::RemoveNode(n.into()));

                Ok::<_, Infallible>(edits)
            }
        })
        .unwrap();

        assert_eq!(order, node_indices(&[4, 2, 0, 1, 3]));
    }

    #[test]
    fn can_append_during_traversal() {
        let mut ir = create_simple_dag();
        let mut order = vec![];

        forward_traverse_mut(&mut ir.graph, |_, n| {
            order.push(n);

            if n.index() != 2 {
                return Ok::<_, Infallible>(GraphTransforms::default());
            }

            // Append `Mul(add, in1)` as a new node hanging off the `Add` node.
            let mut edits: GraphTransforms<NodeInfo<Operation>, EdgeInfo> =
                GraphTransforms::new();
            let product = edits.push(Transform::AddNode(NodeInfo {
                operation: Operation::Mul,
            }));
            edits.push(Transform::AddEdge(n.into(), product.into(), EdgeInfo::Left));
            edits.push(Transform::AddEdge(
                NodeIndex::from(1).into(),
                product.into(),
                EdgeInfo::Right,
            ));

            let returned = edits.clone();

            edits.apply(&mut create_simple_dag().graph.0);

            Ok::<_, Infallible>(returned)
        })
        .unwrap();

        assert_eq!(order, node_indices(&[3, 1, 0, 2, 4, 5]));
    }
}