1use super::portgraph::graph::{ConnectError, Direction};
3use crate::prelude::{TryFrom, TryInto};
4use crate::symbol::{FunctionName, Label, Location, SymbolError, TypeVar};
5use indexmap::{IndexMap, IndexSet};
6
7use std::collections::{BTreeMap, HashMap, HashSet};
8use std::convert::Infallible;
9use std::hash::{Hash, Hasher};
10use thiserror::Error;
11
12pub use super::portgraph::graph::{EdgeIndex, NodeIndex};
13
14#[derive(Clone, Debug)]
17pub enum Value {
18 Bool(bool),
20 Int(i64),
22 Str(String),
24 Float(f64),
26 Graph(Graph),
28 Pair(Box<(Value, Value)>),
30 Map(HashMap<Value, Value>),
33 Vec(Vec<Value>),
35 Struct(HashMap<Label, Value>),
37 Variant(Label, Box<Value>),
39}
40
41impl PartialEq for Value {
42 fn eq(&self, other: &Value) -> bool {
43 match (self, other) {
44 (Value::Bool(x), Value::Bool(y)) => x == y,
45 (Value::Int(x), Value::Int(y)) => x == y,
46 (Value::Str(x), Value::Str(y)) => x == y,
47 (Value::Float(x), Value::Float(y)) => x == y,
48 (Value::Pair(x), Value::Pair(y)) => x == y,
49 (Value::Map(x), Value::Map(y)) => x == y,
50 (Value::Struct(f1), Value::Struct(f2)) => f1 == f2,
51 (Value::Vec(ar1), Value::Vec(ar2)) => ar1 == ar2,
52 (Value::Graph(x), Value::Graph(y)) => x == y,
53 (Value::Variant(t1, v1), Value::Variant(t2, v2)) => (t1 == t2) && (v1 == v2),
54 _ => false,
55 }
56 }
57}
58
59impl Eq for Value {}
61
62impl Hash for Value {
63 fn hash<H: Hasher>(&self, state: &mut H) {
64 match self {
65 Value::Graph(_) | Value::Map(_) | Value::Struct(_) | Value::Float(_) => {
66 panic!("Value is not hashable: {:?}", self)
67 }
68 Value::Bool(x) => x.hash(state),
69 Value::Int(x) => x.hash(state),
70 Value::Str(x) => x.hash(state),
71 Value::Pair(x) => x.hash(state),
72 Value::Vec(x) => x.hash(state),
73 Value::Variant(tag, value) => {
74 tag.hash(state);
75 value.hash(state);
76 }
77 }
78 }
79}
80
81#[derive(Clone, PartialEq, Eq, Hash)]
84pub enum Type {
85 Bool,
87 Int,
89 Str,
91 Float,
93 Graph(GraphType),
96 Pair(Box<Type>, Box<Type>),
99 Vec(Box<Type>),
102 Var(TypeVar),
105 Row(RowType),
114 Map(Box<Type>, Box<Type>),
117 Struct(RowType, Option<String>),
121 Variant(RowType),
125}
126
127impl Type {
128 pub fn struct_from_row(row: RowType) -> Self {
130 Self::Struct(row, None)
131 }
132}
133
134impl std::fmt::Debug for Type {
135 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136 match self {
137 Type::Bool => f.debug_struct("Bool").finish(),
138 Type::Int => f.debug_struct("Int").finish(),
139 Type::Str => f.debug_struct("Str").finish(),
140 Type::Float => f.debug_struct("Float").finish(),
141 Type::Graph(graph) => std::fmt::Debug::fmt(graph, f),
142 Type::Pair(first, second) => f.debug_tuple("Pair").field(first).field(second).finish(),
143 Type::Vec(element) => f.debug_tuple("Vector").field(element).finish(),
144 Type::Var(var) => std::fmt::Debug::fmt(var, f),
145 Type::Row(row) => f.debug_tuple("Row").field(row).finish(),
146 Type::Map(key, value) => f.debug_tuple("Map").field(key).field(value).finish(),
147 Type::Struct(row, name) => match name {
148 Some(name) => f.debug_struct(name).finish(),
149 None => f.debug_tuple("Struct").field(row).finish(),
150 },
151 Type::Variant(row) => f.debug_tuple("Variant").field(row).finish(),
152 }
153 }
154}
155
156impl From<GraphType> for Type {
157 fn from(t: GraphType) -> Self {
158 Type::Graph(t)
159 }
160}
161
162impl From<TypeVar> for Type {
163 fn from(t: TypeVar) -> Self {
164 Type::Var(t)
165 }
166}
167
168impl Type {
169 pub fn type_vars(&self) -> impl Iterator<Item = TypeVar> + '_ {
172 let mut vars = IndexSet::new();
173 self.type_vars_impl(&mut vars);
174 vars.into_iter()
175 }
176
177 fn type_vars_impl(&self, vars: &mut IndexSet<TypeVar>) {
178 match self {
179 Type::Bool => {}
180 Type::Int => {}
181 Type::Str => {}
182 Type::Float => {}
183 Type::Graph(graph) => {
184 graph.type_vars_impl(vars);
185 }
186 Type::Pair(left, right) => {
187 left.type_vars_impl(vars);
188 right.type_vars_impl(vars);
189 }
190 Type::Vec(element) => {
191 element.type_vars_impl(vars);
192 }
193 Type::Var(var) => {
194 vars.insert(*var);
195 }
196 Type::Row(row) => {
197 row.type_vars_impl(vars);
198 }
199 Type::Map(key, value) => {
200 key.type_vars_impl(vars);
201 value.type_vars_impl(vars);
202 }
203 Type::Struct(row, _) => {
204 row.type_vars_impl(vars);
205 }
206 Type::Variant(row) => {
207 row.type_vars_impl(vars);
208 }
209 }
210 }
211}
212
213#[derive(Clone, PartialEq, Eq, Hash)]
215pub struct GraphType {
216 pub inputs: RowType,
218 pub outputs: RowType,
220}
221
222impl GraphType {
223 pub fn new() -> Self {
225 Self {
226 inputs: Default::default(),
227 outputs: Default::default(),
228 }
229 }
230
231 pub fn add_input(&mut self, port: impl Into<Label>, type_: impl Into<Type>) {
233 self.inputs.content.insert(port.into(), type_.into());
234 }
235
236 pub fn add_output(&mut self, port: impl Into<Label>, type_: impl Into<Type>) {
238 self.outputs.content.insert(port.into(), type_.into());
239 }
240
241 fn type_vars_impl(&self, vars: &mut IndexSet<TypeVar>) {
242 self.inputs.type_vars_impl(vars);
243 self.outputs.type_vars_impl(vars);
244 }
245}
246
247impl Default for GraphType {
248 fn default() -> Self {
249 Self::new()
250 }
251}
252
253impl std::fmt::Debug for GraphType {
254 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
255 f.debug_struct("Graph")
256 .field("inputs", &self.inputs)
257 .field("outputs", &self.outputs)
258 .finish()
259 }
260}
261
262#[derive(Clone, PartialEq, Eq, Hash, Default)]
264pub struct RowType {
265 pub content: BTreeMap<Label, Type>,
267 pub rest: Option<TypeVar>,
271}
272
273impl std::fmt::Debug for RowType {
274 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
275 f.debug_map()
276 .entries(self.content.iter())
277 .entries(self.rest.iter().map(|rest| ("#", rest)))
278 .finish()
279 }
280}
281
282impl From<TypeVar> for RowType {
283 fn from(var: TypeVar) -> Self {
284 Self {
285 content: BTreeMap::new(),
286 rest: Some(var),
287 }
288 }
289}
290
291impl RowType {
292 fn type_vars_impl(&self, vars: &mut IndexSet<TypeVar>) {
293 for type_ in self.content.values() {
294 type_.type_vars_impl(vars);
295 }
296
297 if let Some(var) = self.rest {
298 vars.insert(var);
299 }
300 }
301}
302
/// The kind of a type variable, as recorded in `TypeScheme::variables`.
///
/// Variant order is significant for the derived `Ord`: `Star < Row`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Kind {
    /// Presumably the kind of ordinary types. TODO confirm.
    Star,
    /// Presumably the kind of row variables (e.g. `RowType::rest`). TODO confirm.
    Row,
}
312
313#[derive(Debug, Clone, PartialEq, Eq)]
315pub struct TypeScheme {
316 pub variables: IndexMap<TypeVar, Kind>,
320 pub constraints: Vec<Constraint>,
322 pub body: Type,
326}
327
328impl TypeScheme {
329 pub fn with_variable(mut self, var: impl Into<TypeVar>, kind: Kind) -> Self {
331 self.variables.insert(var.into(), kind);
332 self
333 }
334
335 pub fn with_constraint(mut self, constraint: impl Into<Constraint>) -> Self {
338 self.constraints.push(constraint.into());
339 self
340 }
341}
342
343impl From<Type> for TypeScheme {
344 fn from(type_: Type) -> Self {
345 TypeScheme {
346 variables: IndexMap::new(),
347 constraints: Vec::new(),
348 body: type_,
349 }
350 }
351}
352
353impl From<GraphType> for TypeScheme {
354 fn from(graph_type: GraphType) -> Self {
355 Type::Graph(graph_type).into()
356 }
357}
358
359#[derive(Debug, Clone, PartialEq, Eq)]
361pub enum Constraint {
362 Lacks {
364 row: Type,
366 label: Label,
368 },
369 Partition {
372 left: Type,
374 right: Type,
376 union: Type,
379 },
380}
381
382#[derive(Clone, Debug, PartialEq)]
384pub enum Node {
385 Input,
387
388 Output,
390
391 Const(Value),
393
394 Box(Location, Graph),
398
399 Function(FunctionName, Option<u32>),
403
404 Match,
406
407 Tag(Label),
409}
410
411impl Node {
412 pub fn is_external(&self) -> bool {
414 match self {
415 Node::Function(f, _) => !f.is_builtin(),
416 Node::Box(l, _) => l != &Location::local(),
417 _ => false,
418 }
419 }
420
421 pub fn local_box(graph: Graph) -> Self {
424 Node::Box(Location::local(), graph)
425 }
426}
427
428impl From<Value> for Node {
430 fn from(value: Value) -> Self {
431 Node::Const(value)
432 }
433}
434
435impl TryFrom<Value> for Node {
436 type Error = Infallible;
437
438 fn try_from(value: Value) -> Result<Self, Self::Error> {
439 Ok(value.into())
440 }
441}
442
443impl From<Graph> for Node {
445 fn from(graph: Graph) -> Self {
446 Node::local_box(graph)
447 }
448}
449
450impl TryFrom<Graph> for Node {
451 type Error = Infallible;
452
453 fn try_from(graph: Graph) -> Result<Self, Self::Error> {
454 Ok(graph.into())
455 }
456}
457
458impl From<FunctionName> for Node {
460 fn from(function: FunctionName) -> Self {
461 Node::Function(function, None)
462 }
463}
464
465impl TryFrom<FunctionName> for Node {
466 type Error = Infallible;
467
468 fn try_from(function: FunctionName) -> Result<Self, Self::Error> {
469 Ok(function.into())
470 }
471}
472
473impl From<(FunctionName, u32)> for Node {
475 fn from(fn_t: (FunctionName, u32)) -> Self {
476 Node::Function(fn_t.0, Some(fn_t.1))
477 }
478}
479
480impl TryFrom<(FunctionName, u32)> for Node {
481 type Error = Infallible;
482
483 fn try_from(fn_t: (FunctionName, u32)) -> Result<Self, Self::Error> {
484 Ok(fn_t.into())
485 }
486}
487
488impl TryFrom<&str> for Node {
489 type Error = SymbolError;
490
491 fn try_from(value: &str) -> Result<Self, Self::Error> {
492 Ok(Node::Function(value.parse()?, None))
493 }
494}
495
496impl TryFrom<(&str, u32)> for Node {
497 type Error = SymbolError;
498
499 fn try_from(value: (&str, u32)) -> Result<Self, Self::Error> {
500 Ok(Node::Function(value.0.parse()?, Some(value.1)))
501 }
502}
503
504#[derive(Clone, Debug, PartialEq, Eq, Hash)]
506pub struct Edge {
507 pub source: NodePort,
509 pub target: NodePort,
511 pub edge_type: Option<Type>,
514}
515
516#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
518#[allow(missing_docs)]
519pub struct NodePort {
520 pub node: NodeIndex,
521 pub port: Label,
522}
523
524impl std::fmt::Display for NodePort {
525 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
526 write!(f, "{}:{}", self.node.index(), self.port)
527 }
528}
529
530impl<N, P> From<(N, P)> for NodePort
531where
532 N: Into<NodeIndex>,
533 P: Into<Label>,
534{
535 fn from((node, port): (N, P)) -> Self {
536 Self {
537 node: node.into(),
538 port: port.into(),
539 }
540 }
541}
542
543impl<N, P> TryFrom<(N, P)> for NodePort
544where
545 N: Into<NodeIndex>,
546 P: TryInto<Label>,
547 <P as TryInto<Label>>::Error: Into<SymbolError>,
548{
549 type Error = SymbolError;
550
551 fn try_from((node, port): (N, P)) -> Result<Self, Self::Error> {
552 Ok(Self {
553 node: node.into(),
554 port: TryInto::try_into(port).map_err(|e| e.into())?,
555 })
556 }
557}
558
559#[derive(Clone, Debug)]
571pub struct Graph {
572 internal_graph: super::portgraph::graph::Graph<Node, Edge>,
573 name: String,
574 input_order: Vec<Label>,
575 output_order: Vec<Label>,
576}
577
578impl Graph {
579 pub fn boundary() -> [NodeIndex; 2] {
581 [NodeIndex::new(0), NodeIndex::new(1)]
582 }
583
584 pub fn node_indices(&self) -> impl Iterator<Item = NodeIndex> + '_ {
586 self.internal_graph.node_indices()
587 }
588
589 pub fn nodes(&self) -> impl Iterator<Item = &Node> {
591 self.internal_graph.node_weights()
592 }
593
594 pub fn edges(&self) -> impl Iterator<Item = &Edge> {
596 self.internal_graph.edge_weights()
597 }
598
599 pub fn node(&self, node: impl Into<NodeIndex>) -> Option<&Node> {
601 self.internal_graph.node_weight(node.into())
602 }
603
604 pub fn input<E>(&self, port: impl TryInto<Label, Error = E>) -> Option<&Edge>
607 where
608 E: Into<SymbolError>,
609 {
610 self.node_output((Self::boundary()[0], port))
611 }
612
613 pub fn output<E>(&self, port: impl TryInto<Label, Error = E>) -> Option<&Edge>
616 where
617 E: Into<SymbolError>,
618 {
619 self.node_input((Self::boundary()[1], port))
620 }
621
622 pub fn node_inputs(&self, node: NodeIndex) -> impl Iterator<Item = &Edge> + '_ {
625 self.internal_graph
626 .node_edges(node, Direction::Incoming)
627 .map(move |edge| {
628 self.internal_graph
629 .edge_weight(edge)
630 .expect("missing edge.")
631 })
632 }
633
634 pub fn node_outputs(&self, node: NodeIndex) -> impl Iterator<Item = &Edge> + '_ {
637 self.internal_graph
638 .node_edges(node, Direction::Outgoing)
639 .map(move |edge| {
640 self.internal_graph
641 .edge_weight(edge)
642 .expect("missing edge.")
643 })
644 }
645
646 pub fn node_input(&self, port: impl TryInto<NodePort>) -> Option<&Edge> {
649 let port = TryInto::try_into(port).ok()?;
650 self.node_inputs(port.node)
651 .find(|edge| edge.target.port == port.port)
652 }
653
654 pub fn node_output(&self, port: impl TryInto<NodePort>) -> Option<&Edge> {
657 let port = TryInto::try_into(port).ok()?;
658 self.node_outputs(port.node)
659 .find(|edge| edge.source.port == port.port)
660 }
661
662 pub fn name(&self) -> &String {
664 &self.name
665 }
666
667 pub fn inputs(&self) -> impl Iterator<Item = &Label> {
671 self.input_order.iter()
672 }
673
674 pub fn outputs(&self) -> impl Iterator<Item = &Label> {
678 self.output_order.iter()
679 }
680}
681
682pub struct GraphBuilder {
684 graph: Graph,
685}
686
687impl GraphBuilder {
688 pub fn new() -> Self {
691 let mut builder = GraphBuilder {
692 graph: Graph {
693 internal_graph: super::portgraph::graph::Graph::new(),
694 name: "".into(),
695 input_order: vec![],
696 output_order: vec![],
697 },
698 };
699
700 builder.add_node(Node::Input).unwrap();
701 builder.add_node(Node::Output).unwrap();
702
703 builder
704 }
705
706 pub fn set_name(&mut self, name: String) {
708 self.graph.name = name;
709 }
710
711 pub fn add_node_with_edges<F>(
719 &mut self,
720 node: impl TryInto<Node, Error = F>,
721 incoming: impl IntoIterator<Item = EdgeIndex>,
722 outgoing: impl IntoIterator<Item = EdgeIndex>,
723 ) -> Result<NodeIndex, GraphBuilderError>
724 where
725 F: Into<SymbolError>,
726 {
727 let node = TryInto::try_into(node).map_err(|e| GraphBuilderError::SymbolError(e.into()))?;
728
729 Ok(self
730 .graph
731 .internal_graph
732 .add_node_with_edges(node, incoming, outgoing)?)
733 }
734
735 pub fn add_node<F>(
741 &mut self,
742 node: impl TryInto<Node, Error = F>,
743 ) -> Result<NodeIndex, GraphBuilderError>
744 where
745 F: Into<SymbolError>,
746 {
747 self.add_node_with_edges(node, [], [])
748 }
749
750 pub fn add_edge<E, F>(
762 &mut self,
763 source: impl TryInto<NodePort, Error = E>,
764 target: impl TryInto<NodePort, Error = F>,
765 edge_type: impl Into<Option<Type>>,
766 ) -> Result<(), GraphBuilderError>
767 where
768 E: Into<SymbolError>,
769 F: Into<SymbolError>,
770 {
771 let source =
772 TryInto::try_into(source).map_err(|e| GraphBuilderError::SymbolError(e.into()))?;
773 let target =
774 TryInto::try_into(target).map_err(|e| GraphBuilderError::SymbolError(e.into()))?;
775 let edge_type = edge_type.into();
776
777 if self.graph.node_input(target).is_some() {
778 return Err(GraphBuilderError::OccupiedInputPort(target));
779 } else if self.graph.node_output(source).is_some() {
780 return Err(GraphBuilderError::OccupiedOutputPort(source));
781 }
782
783 let edge = Edge {
784 source,
785 target,
786 edge_type,
787 };
788
789 let eidx = self.graph.internal_graph.add_edge(edge);
790
791 self.graph
792 .internal_graph
793 .connect(source.node, eidx, Direction::Outgoing, None)?;
794 self.graph
795 .internal_graph
796 .connect(target.node, eidx, Direction::Incoming, None)?;
797
798 Ok(())
799 }
800
801 fn is_acyclic(&self) -> bool {
802 let mut stack: Vec<_> = self.graph.node_indices().collect();
803 let mut discovered = HashSet::new();
804 let mut finished = HashSet::new();
805
806 while let Some(node) = stack.pop() {
807 if !discovered.insert(node) {
808 finished.insert(node);
809 continue;
810 }
811
812 stack.push(node);
813
814 for edge in self.graph.node_outputs(node) {
815 let target = edge.target.node;
816 if !discovered.contains(&target) {
817 stack.push(target);
818 } else if !finished.contains(&target) {
819 return false;
820 }
821 }
822 }
823
824 true
825 }
826
827 pub fn build(self) -> Result<Graph, GraphBuilderError> {
833 if !self.is_acyclic() {
834 return Err(GraphBuilderError::CyclicGraph);
835 }
836
837 Ok(self.graph)
838 }
839
840 pub fn set_io_order(&mut self, io_order: [Vec<Label>; 2]) {
843 let [i, o] = io_order;
844 self.graph.input_order = i;
845 self.graph.output_order = o;
846 }
847}
848
849impl Default for GraphBuilder {
850 fn default() -> Self {
851 Self::new()
852 }
853}
854
855#[derive(Debug, Error)]
857pub enum GraphBuilderError {
858 #[error("there already is an edge at the input port '{0}'")]
860 OccupiedInputPort(NodePort),
861 #[error("there already is an edge at the output port '{0}'")]
863 OccupiedOutputPort(NodePort),
864 #[error("the graph must be acyclic")]
866 CyclicGraph,
867 #[error("encountered error importing symbols: {0}")]
869 SymbolError(#[from] SymbolError),
870 #[error("Error when connecting edges.")]
873 ConnectError(#[from] ConnectError),
874}
875
876impl PartialEq for Graph {
877 fn eq(&self, other: &Self) -> bool {
878 if self.name != other.name {
879 return false;
880 }
881 let our_nodes: Vec<&Node> = self.nodes().collect();
883 let other_nodes: Vec<&Node> = other.nodes().collect();
884
885 if our_nodes != other_nodes {
886 return false;
887 }
888
889 let our_edges: HashSet<&Edge> = self.edges().collect();
891 let other_edges: HashSet<&Edge> = other.edges().collect();
892
893 our_edges == other_edges
894 }
895}
896
#[cfg(test)]
mod test {
    use crate::graph::{Graph, GraphBuilder, GraphBuilderError};
    use std::error::Error;

    /// A cycle through two nodes (a -> b -> a on `value`) must be rejected.
    #[test]
    fn test_cyclic_graph() -> Result<(), Box<dyn Error>> {
        let mut builder = GraphBuilder::new();
        let [i, o] = Graph::boundary();
        let a = builder.add_node("cnot")?;
        let b = builder.add_node("cnot")?;

        builder.add_edge((i, "control"), (a, "control"), None)?;
        builder.add_edge((a, "control"), (b, "control"), None)?;
        builder.add_edge((a, "value"), (b, "value"), None)?;
        builder.add_edge((b, "control"), (o, "control"), None)?;
        builder.add_edge((b, "value"), (a, "value"), None)?;

        assert!(matches!(
            builder.build(),
            Err(GraphBuilderError::CyclicGraph)
        ));

        Ok(())
    }

    /// A self-loop must be rejected.
    #[test]
    fn test_cycle() -> Result<(), Box<dyn Error>> {
        let mut builder = GraphBuilder::new();

        let [i, o] = Graph::boundary();
        let a = builder.add_node("cnot")?;
        builder.add_edge((i, "control"), (a, "control"), None)?;
        builder.add_edge((a, "control"), (o, "control"), None)?;
        builder.add_edge((a, "value"), (a, "value"), None)?;

        assert!(matches!(
            builder.build(),
            Err(GraphBuilderError::CyclicGraph)
        ));

        Ok(())
    }

    /// A diamond-shaped DAG (node `d` is reachable from `a` along two paths)
    /// must build. This guards against false positives in the cycle check.
    #[test]
    fn test_acyclic_graph() -> Result<(), Box<dyn Error>> {
        let mut builder = GraphBuilder::new();
        let [i, o] = Graph::boundary();
        let a = builder.add_node("cnot")?;
        let b = builder.add_node("cnot")?;
        let c = builder.add_node("cnot")?;
        let d = builder.add_node("cnot")?;

        builder.add_edge((i, "control"), (a, "control"), None)?;
        builder.add_edge((i, "value"), (a, "value"), None)?;
        builder.add_edge((a, "control"), (b, "control"), None)?;
        builder.add_edge((a, "value"), (c, "value"), None)?;
        builder.add_edge((b, "control"), (d, "control"), None)?;
        builder.add_edge((c, "value"), (d, "value"), None)?;
        builder.add_edge((d, "control"), (o, "control"), None)?;
        builder.add_edge((d, "value"), (o, "value"), None)?;

        assert!(builder.build().is_ok());

        Ok(())
    }

    /// A second edge into an occupied input port must be rejected.
    #[test]
    fn test_duplicate_node_input() -> Result<(), Box<dyn Error>> {
        let mut builder = GraphBuilder::new();
        let [i, _] = Graph::boundary();

        let d = builder.add_node("discard")?;
        builder.add_edge((i, "a"), (d, "value"), None)?;
        let result = builder.add_edge((i, "b"), (d, "value"), None);

        assert!(matches!(
            result,
            Err(GraphBuilderError::OccupiedInputPort(_))
        ));

        Ok(())
    }

    /// A second edge out of an occupied output port must be rejected.
    #[test]
    fn test_duplicate_node_output() -> Result<(), Box<dyn Error>> {
        let mut builder = GraphBuilder::new();
        let [i, o] = Graph::boundary();

        let d = builder.add_node("discard")?;
        builder.add_edge((i, "a"), (d, "value"), None)?;
        let result = builder.add_edge((i, "a"), (o, "a"), None);

        assert!(matches!(
            result,
            Err(GraphBuilderError::OccupiedOutputPort(_))
        ));

        Ok(())
    }
}