Skip to main content

tierkreis_core/
graph.rs

1//! Crate versions of graph protobufs, plus [GraphBuilder]
2use super::portgraph::graph::{ConnectError, Direction};
3use crate::prelude::{TryFrom, TryInto};
4use crate::symbol::{FunctionName, Label, Location, SymbolError, TypeVar};
5use indexmap::{IndexMap, IndexSet};
6
7use std::collections::{BTreeMap, HashMap, HashSet};
8use std::convert::Infallible;
9use std::hash::{Hash, Hasher};
10use thiserror::Error;
11
12pub use super::portgraph::graph::{EdgeIndex, NodeIndex};
13
/// A value that can be passed around a Tierkreis graph.
/// A strict (no optional fields), crate version of protobuf `tierkreis.v1alpha.graph.Value`.
///
/// Equality and hashing are implemented by hand rather than derived. Hashing
/// panics for [Value::Graph], [Value::Map], [Value::Struct] and [Value::Float],
/// so a value that is, or contains, one of those must not be used as a key of
/// a [Value::Map].
#[derive(Clone, Debug)]
pub enum Value {
    /// Boolean value (true or false)
    Bool(bool),
    /// Signed (64-bit) integer
    Int(i64),
    /// String
    Str(String),
    /// Double-precision (64-bit) float
    Float(f64),
    /// A Tierkreis graph
    Graph(Graph),
    /// A pair of two other [Value]s (boxed, as the type is recursive)
    Pair(Box<(Value, Value)>),
    /// A map from keys to values. Keys must be hashable, i.e. must not be or contain
    /// [Value::Graph], [Value::Map], [Value::Struct] or [Value::Float].
    Map(HashMap<Value, Value>),
    /// List of [Value]s.
    Vec(Vec<Value>),
    /// A struct or record type with string-named fields (themselves [Value]s)
    Struct(HashMap<Label, Value>),
    /// A [Value] tagged with a string to make a disjoint union of component types
    Variant(Label, Box<Value>),
}
40
41impl PartialEq for Value {
42    fn eq(&self, other: &Value) -> bool {
43        match (self, other) {
44            (Value::Bool(x), Value::Bool(y)) => x == y,
45            (Value::Int(x), Value::Int(y)) => x == y,
46            (Value::Str(x), Value::Str(y)) => x == y,
47            (Value::Float(x), Value::Float(y)) => x == y,
48            (Value::Pair(x), Value::Pair(y)) => x == y,
49            (Value::Map(x), Value::Map(y)) => x == y,
50            (Value::Struct(f1), Value::Struct(f2)) => f1 == f2,
51            (Value::Vec(ar1), Value::Vec(ar2)) => ar1 == ar2,
52            (Value::Graph(x), Value::Graph(y)) => x == y,
53            (Value::Variant(t1, v1), Value::Variant(t2, v2)) => (t1 == t2) && (v1 == v2),
54            _ => false,
55        }
56    }
57}
58
// Required since we use `Value`s as the *keys* of `HashMap`s (see `Value::Map`).
// NOTE(review): `Value::Float` payloads are compared with `f64`'s `==`, under
// which NaN is not equal to itself, so the reflexivity that `Eq` promises does
// not hold for NaN floats — confirm that this is acceptable for all users.
impl Eq for Value {}
61
62impl Hash for Value {
63    fn hash<H: Hasher>(&self, state: &mut H) {
64        match self {
65            Value::Graph(_) | Value::Map(_) | Value::Struct(_) | Value::Float(_) => {
66                panic!("Value is not hashable: {:?}", self)
67            }
68            Value::Bool(x) => x.hash(state),
69            Value::Int(x) => x.hash(state),
70            Value::Str(x) => x.hash(state),
71            Value::Pair(x) => x.hash(state),
72            Value::Vec(x) => x.hash(state),
73            Value::Variant(tag, value) => {
74                tag.hash(state);
75                value.hash(state);
76            }
77        }
78    }
79}
80
/// Crate version of protobuf `tierkreis.v1alpha1.graph.Type`.
/// A type of values (or [Row](Type::Row)s) that can be passed around a Tierkreis [Graph].
///
/// `Debug` is implemented by hand (not derived) to give a more compact rendering.
#[derive(Clone, PartialEq, Eq, Hash)]
pub enum Type {
    /// Type ([Kind::Star]) of Booleans, i.e. with two values `true` and `false`
    Bool,
    /// Type ([Kind::Star]) of signed integers
    Int,
    /// Type ([Kind::Star]) of strings
    Str,
    /// Type ([Kind::Star]) of floating-point numbers (double-precision)
    Float,
    /// Type ([Kind::Star]) identifying the set of graphs with the specified
    /// row of named inputs and row of named outputs
    Graph(GraphType),
    /// Type ([Kind::Star]) of pairs (where the first value has one type, and
    /// the second another); each component must also be a [Kind::Star].
    Pair(Box<Type>, Box<Type>),
    /// Type ([Kind::Star]) of lists of elements all of the same type (given,
    /// also [Kind::Star]).
    Vec(Box<Type>),
    /// Type variable, used (in types) inside polymorphic [TypeScheme]s only.
    /// Can be a [Kind::Row] or a [Kind::Star].
    Var(TypeVar),
    /// A named row of types. Unlike the other variants (except possibly [Type::Var]),
    /// this is a row ([Kind::Row]), *not* a type of values, so it cannot be used as a
    /// member of any other [Type] (e.g. [Type::Pair], [Type::Vec]), or as the type of
    /// a field in a [RowType].
    /// However, it can appear in [Constraint::Lacks::row] or [Constraint::Partition],
    /// and can be returned in [TypeError::Unify]
    ///
    /// [TypeError::Unify]: super::type_checker::TypeError::Unify
    Row(RowType),
    /// Type ([Kind::Star]) of maps from a key type to value type (both [Kind::Star]).
    // We do nothing to rule out key *types* that are not hashable, only values
    Map(Box<Type>, Box<Type>),
    /// Struct type (i.e. [Kind::Star]): made up of an unordered collection of named
    /// fields each with a type ([Kind::Star]). Optionally, the type itself may have
    /// a name.
    Struct(RowType, Option<String>),
    /// A disjoint (tagged) union of other types, given as a row.
    /// May be open, for the output of a Tag operation, or closed,
    /// for the input to match (where the handlers are known).
    Variant(RowType),
}
126
127impl Type {
128    /// Makes an unnamed [Type::Struct] given a row of fields
129    pub fn struct_from_row(row: RowType) -> Self {
130        Self::Struct(row, None)
131    }
132}
133
134impl std::fmt::Debug for Type {
135    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136        match self {
137            Type::Bool => f.debug_struct("Bool").finish(),
138            Type::Int => f.debug_struct("Int").finish(),
139            Type::Str => f.debug_struct("Str").finish(),
140            Type::Float => f.debug_struct("Float").finish(),
141            Type::Graph(graph) => std::fmt::Debug::fmt(graph, f),
142            Type::Pair(first, second) => f.debug_tuple("Pair").field(first).field(second).finish(),
143            Type::Vec(element) => f.debug_tuple("Vector").field(element).finish(),
144            Type::Var(var) => std::fmt::Debug::fmt(var, f),
145            Type::Row(row) => f.debug_tuple("Row").field(row).finish(),
146            Type::Map(key, value) => f.debug_tuple("Map").field(key).field(value).finish(),
147            Type::Struct(row, name) => match name {
148                Some(name) => f.debug_struct(name).finish(),
149                None => f.debug_tuple("Struct").field(row).finish(),
150            },
151            Type::Variant(row) => f.debug_tuple("Variant").field(row).finish(),
152        }
153    }
154}
155
156impl From<GraphType> for Type {
157    fn from(t: GraphType) -> Self {
158        Type::Graph(t)
159    }
160}
161
162impl From<TypeVar> for Type {
163    fn from(t: TypeVar) -> Self {
164        Type::Var(t)
165    }
166}
167
168impl Type {
169    /// Iterator over the type variables in the type in the order they occur.
170    /// Each variable is returned only once.
171    pub fn type_vars(&self) -> impl Iterator<Item = TypeVar> + '_ {
172        let mut vars = IndexSet::new();
173        self.type_vars_impl(&mut vars);
174        vars.into_iter()
175    }
176
177    fn type_vars_impl(&self, vars: &mut IndexSet<TypeVar>) {
178        match self {
179            Type::Bool => {}
180            Type::Int => {}
181            Type::Str => {}
182            Type::Float => {}
183            Type::Graph(graph) => {
184                graph.type_vars_impl(vars);
185            }
186            Type::Pair(left, right) => {
187                left.type_vars_impl(vars);
188                right.type_vars_impl(vars);
189            }
190            Type::Vec(element) => {
191                element.type_vars_impl(vars);
192            }
193            Type::Var(var) => {
194                vars.insert(*var);
195            }
196            Type::Row(row) => {
197                row.type_vars_impl(vars);
198            }
199            Type::Map(key, value) => {
200                key.type_vars_impl(vars);
201                value.type_vars_impl(vars);
202            }
203            Type::Struct(row, _) => {
204                row.type_vars_impl(vars);
205            }
206            Type::Variant(row) => {
207                row.type_vars_impl(vars);
208            }
209        }
210    }
211}
212
/// Type of a Graph, i.e. a higher-order function value, with input and output rows.
///
/// `Debug` is implemented by hand and prints under the name `Graph`.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct GraphType {
    /// The inputs to the graph (known and/or variable)
    pub inputs: RowType,
    /// The outputs from the graph (known and/or variable)
    pub outputs: RowType,
}
221
222impl GraphType {
223    /// Creates a new instance with closed empty input and output rows
224    pub fn new() -> Self {
225        Self {
226            inputs: Default::default(),
227            outputs: Default::default(),
228        }
229    }
230
231    /// Adds a new input given a label and type
232    pub fn add_input(&mut self, port: impl Into<Label>, type_: impl Into<Type>) {
233        self.inputs.content.insert(port.into(), type_.into());
234    }
235
236    /// Adds a new output given a label and type
237    pub fn add_output(&mut self, port: impl Into<Label>, type_: impl Into<Type>) {
238        self.outputs.content.insert(port.into(), type_.into());
239    }
240
241    fn type_vars_impl(&self, vars: &mut IndexSet<TypeVar>) {
242        self.inputs.type_vars_impl(vars);
243        self.outputs.type_vars_impl(vars);
244    }
245}
246
247impl Default for GraphType {
248    fn default() -> Self {
249        Self::new()
250    }
251}
252
253impl std::fmt::Debug for GraphType {
254    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
255        f.debug_struct("Graph")
256            .field("inputs", &self.inputs)
257            .field("outputs", &self.outputs)
258            .finish()
259    }
260}
261
/// A row of named types, possibly including a variable for an unknown portion.
///
/// The derived `Default` is the closed, empty row (no known labels, no `rest`).
#[derive(Clone, PartialEq, Eq, Hash, Default)]
pub struct RowType {
    /// Known labels, and types for each (of course these may contain variables)
    pub content: BTreeMap<Label, Type>,
    /// Either `Some` (of a variable whose [Kind] is [Kind::Row]), in which case
    /// the [RowType] is an "open row" that may contain any number of other fields;
    /// or `None` for a "closed row" i.e. exactly/only the fields in [Self::content]
    pub rest: Option<TypeVar>,
}
272
273impl std::fmt::Debug for RowType {
274    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
275        f.debug_map()
276            .entries(self.content.iter())
277            .entries(self.rest.iter().map(|rest| ("#", rest)))
278            .finish()
279    }
280}
281
282impl From<TypeVar> for RowType {
283    fn from(var: TypeVar) -> Self {
284        Self {
285            content: BTreeMap::new(),
286            rest: Some(var),
287        }
288    }
289}
290
291impl RowType {
292    fn type_vars_impl(&self, vars: &mut IndexSet<TypeVar>) {
293        for type_ in self.content.values() {
294            type_.type_vars_impl(vars);
295        }
296
297        if let Some(var) = self.rest {
298            vars.insert(var);
299        }
300    }
301}
302
/// The kind of a type variable - i.e. whether the "type" variable stands for
/// a single type, or (some part of) a row.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Kind {
    /// A single type (of values)
    Star,
    /// A row of named types
    Row,
}
312
/// A polymorphic type scheme. Usually (but not necessarily) for a function.
///
/// A scheme with no variables and no constraints can be made from a plain
/// [Type] (or [GraphType]) via `From`; variables and constraints are then
/// added with [TypeScheme::with_variable] and [TypeScheme::with_constraint].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TypeScheme {
    /// Variables over which the scheme is polymorphic, each with its [Kind].
    /// A concrete type (usable for a port or edge) is obtained by supplying
    /// a type or row (according to the [Kind]) for each.
    pub variables: IndexMap<TypeVar, Kind>,
    /// Constraints restricting the legal instantiations of the [Self::variables]
    pub constraints: Vec<Constraint>,
    /// The body of the type scheme, i.e. potentially containing occurrences
    /// of the [Self::variables], into which values for the variables are
    /// substituted when the scheme is instantiated to a concrete type.
    pub body: Type,
}
327
328impl TypeScheme {
329    /// Adds (binds) a new variable of a given kind (destructive update)
330    pub fn with_variable(mut self, var: impl Into<TypeVar>, kind: Kind) -> Self {
331        self.variables.insert(var.into(), kind);
332        self
333    }
334
335    /// Adds a new constraint (destructive update). This should refer (only) to
336    /// variables bound by the scheme.
337    pub fn with_constraint(mut self, constraint: impl Into<Constraint>) -> Self {
338        self.constraints.push(constraint.into());
339        self
340    }
341}
342
343impl From<Type> for TypeScheme {
344    fn from(type_: Type) -> Self {
345        TypeScheme {
346            variables: IndexMap::new(),
347            constraints: Vec::new(),
348            body: type_,
349        }
350    }
351}
352
353impl From<GraphType> for TypeScheme {
354    fn from(graph_type: GraphType) -> Self {
355        Type::Graph(graph_type).into()
356    }
357}
358
/// Specifies restrictions on the instantiations of type variables in a [TypeScheme]
/// (see [TypeScheme::constraints]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Constraint {
    /// A row must *not* have a specified label
    Lacks {
        /// Either a [Type::Row] or a [Type::Var] of [Kind::Row].
        row: Type,
        /// Field name that must not be present in [Self::Lacks::row].
        label: Label,
    },
    /// A row must be the union of two other rows (which, for any label in common,
    /// must have [Type]s that can be made the same).
    Partition {
        /// One input to the union - a [Type::Row] or a [Type::Var] of [Kind::Row].
        left: Type,
        /// The other input to the union - a [Type::Row] or a [Type::Var] of [Kind::Row].
        right: Type,
        /// The result, i.e. left and right merged together.
        /// Could be a [Type::Row] or a [Type::Var] of [Kind::Row].
        union: Type,
    },
}
381
/// A node in a [`Graph`].
#[derive(Clone, Debug, PartialEq)]
pub enum Node {
    /// The node that emits a graph's input values.
    Input,

    /// The node that receives a graph's output values.
    Output,

    /// A node that emits a constant value.
    Const(Value),

    /// A subgraph embedded as a single node. The ports of the node are the ports of the embedded
    /// graph. Box nodes can be used to conveniently compose common subgraphs.
    /// The box will be run on the specified location.
    Box(Location, Graph),

    /// A node that executes a function with a specified name. The type and runtime behavior of a
    /// function node depend on the functions provided by the environment in which the graph is
    /// interpreted.
    ///
    /// The second field is an optional timeout; `None` means the default timeout.
    // NOTE(review): the unit of the timeout is not stated anywhere in this file — confirm.
    Function(FunctionName, Option<u32>),

    /// Perform pattern matching on a variant type.
    Match,

    /// Create a variant. Tag(tag) :: forall T. T -> Variant[tag:T|...]
    Tag(Label),
}
410
411impl Node {
412    /// Does this node run something on a machine/process outside this runtime?
413    pub fn is_external(&self) -> bool {
414        match self {
415            Node::Function(f, _) => !f.is_builtin(),
416            Node::Box(l, _) => l != &Location::local(),
417            _ => false,
418        }
419    }
420
421    /// Makes a [Node::Box] that runs a given graph on the same runtime
422    /// as is running the node's parent (i.e. at [Location::local])
423    pub fn local_box(graph: Graph) -> Self {
424        Node::Box(Location::local(), graph)
425    }
426}
427
428/// Convert a [`Value`] into a constant node.
429impl From<Value> for Node {
430    fn from(value: Value) -> Self {
431        Node::Const(value)
432    }
433}
434
435impl TryFrom<Value> for Node {
436    type Error = Infallible;
437
438    fn try_from(value: Value) -> Result<Self, Self::Error> {
439        Ok(value.into())
440    }
441}
442
443/// Convert a [`Graph`] into a box node.
444impl From<Graph> for Node {
445    fn from(graph: Graph) -> Self {
446        Node::local_box(graph)
447    }
448}
449
450impl TryFrom<Graph> for Node {
451    type Error = Infallible;
452
453    fn try_from(graph: Graph) -> Result<Self, Self::Error> {
454        Ok(graph.into())
455    }
456}
457
458/// Convert a [`FunctionName`] into a function node, with default timeout
459impl From<FunctionName> for Node {
460    fn from(function: FunctionName) -> Self {
461        Node::Function(function, None)
462    }
463}
464
465impl TryFrom<FunctionName> for Node {
466    type Error = Infallible;
467
468    fn try_from(function: FunctionName) -> Result<Self, Self::Error> {
469        Ok(function.into())
470    }
471}
472
473/// Convert a [`FunctionName`] and timeout into a function node
474impl From<(FunctionName, u32)> for Node {
475    fn from(fn_t: (FunctionName, u32)) -> Self {
476        Node::Function(fn_t.0, Some(fn_t.1))
477    }
478}
479
480impl TryFrom<(FunctionName, u32)> for Node {
481    type Error = Infallible;
482
483    fn try_from(fn_t: (FunctionName, u32)) -> Result<Self, Self::Error> {
484        Ok(fn_t.into())
485    }
486}
487
488impl TryFrom<&str> for Node {
489    type Error = SymbolError;
490
491    fn try_from(value: &str) -> Result<Self, Self::Error> {
492        Ok(Node::Function(value.parse()?, None))
493    }
494}
495
496impl TryFrom<(&str, u32)> for Node {
497    type Error = SymbolError;
498
499    fn try_from(value: (&str, u32)) -> Result<Self, Self::Error> {
500        Ok(Node::Function(value.0.parse()?, Some(value.1)))
501    }
502}
503
/// An edge in a [`Graph`], connecting an out-port of one node to an in-port
/// of a node.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Edge {
    /// Source (out-)port
    pub source: NodePort,
    /// Target/destination (in-)port
    pub target: NodePort,
    /// Explicit annotation of the type of the edge. May (optionally) be
    /// provided by the client; will be filled in by typechecking.
    pub edge_type: Option<Type>,
}
515
/// Uniquely identifies a port in the graph by the node and the port (label)
/// of that node. Whether it denotes an in-port or an out-port depends on
/// where it is used (e.g. [Edge::source] versus [Edge::target]).
///
/// Displays as `<node index>:<port label>`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[allow(missing_docs)]
pub struct NodePort {
    /// Index of the node owning the port
    pub node: NodeIndex,
    /// Label of the port on that node
    pub port: Label,
}
523
524impl std::fmt::Display for NodePort {
525    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
526        write!(f, "{}:{}", self.node.index(), self.port)
527    }
528}
529
530impl<N, P> From<(N, P)> for NodePort
531where
532    N: Into<NodeIndex>,
533    P: Into<Label>,
534{
535    fn from((node, port): (N, P)) -> Self {
536        Self {
537            node: node.into(),
538            port: port.into(),
539        }
540    }
541}
542
543impl<N, P> TryFrom<(N, P)> for NodePort
544where
545    N: Into<NodeIndex>,
546    P: TryInto<Label>,
547    <P as TryInto<Label>>::Error: Into<SymbolError>,
548{
549    type Error = SymbolError;
550
551    fn try_from((node, port): (N, P)) -> Result<Self, Self::Error> {
552        Ok(Self {
553            node: node.into(),
554            port: TryInto::try_into(port).map_err(|e| e.into())?,
555        })
556    }
557}
558
/// Computation graph.
///
/// A computation graph is a directed acyclic port graph in which data flows along the edges and
/// is processed by the nodes.  Nodes in the graph are addressed by their [NodeIndex]. The
/// inputs of the graph are emitted by the [Node::Input] node and its outputs are received by the
/// [Node::Output] node (see [Graph::boundary]).
/// Each node has labelled input and output ports, encoded implicitly as the endpoints of the
/// graph's edges. The port labels are unique among a node's input and output ports individually,
/// but input and output ports with the same label are considered different. Any port in the graph
/// has exactly one edge connected to it. An edge can optionally be annotated with a type.
///
/// This type is immutable. The [`GraphBuilder`] can be used to construct new instances.
#[derive(Clone, Debug)]
pub struct Graph {
    /// Underlying port graph, with [Node]s and [Edge]s as the weights
    internal_graph: super::portgraph::graph::Graph<Node, Edge>,
    /// Name of the graph, as set by [GraphBuilder::set_name] (empty by default)
    name: String,
    /// Order of the graph's input labels, as set by [GraphBuilder::set_io_order]
    input_order: Vec<Label>,
    /// Order of the graph's output labels, as set by [GraphBuilder::set_io_order]
    output_order: Vec<Label>,
}
577
578impl Graph {
579    /// Gets the indices of the [Node::Input] node and [Node::Output] node.
580    pub fn boundary() -> [NodeIndex; 2] {
581        [NodeIndex::new(0), NodeIndex::new(1)]
582    }
583
584    /// Iterator over the node indices of the graph
585    pub fn node_indices(&self) -> impl Iterator<Item = NodeIndex> + '_ {
586        self.internal_graph.node_indices()
587    }
588
589    /// Iterator over the nodes of the graph
590    pub fn nodes(&self) -> impl Iterator<Item = &Node> {
591        self.internal_graph.node_weights()
592    }
593
594    /// Iterator over the edges of the graph.
595    pub fn edges(&self) -> impl Iterator<Item = &Edge> {
596        self.internal_graph.edge_weights()
597    }
598
599    /// Returns the node with the given index.
600    pub fn node(&self, node: impl Into<NodeIndex>) -> Option<&Node> {
601        self.internal_graph.node_weight(node.into())
602    }
603
604    /// Returns the edge connected to an input port of the graph.
605    /// This is an output port of the unique input node.
606    pub fn input<E>(&self, port: impl TryInto<Label, Error = E>) -> Option<&Edge>
607    where
608        E: Into<SymbolError>,
609    {
610        self.node_output((Self::boundary()[0], port))
611    }
612
613    /// Returns the edge connected to an output port of the graph.
614    /// This is an input port of the unique output node.
615    pub fn output<E>(&self, port: impl TryInto<Label, Error = E>) -> Option<&Edge>
616    where
617        E: Into<SymbolError>,
618    {
619        self.node_input((Self::boundary()[1], port))
620    }
621
622    /// Iterator over the input edges for a node.
623    /// An empty iterator is returned when there is no node with the given index.
624    pub fn node_inputs(&self, node: NodeIndex) -> impl Iterator<Item = &Edge> + '_ {
625        self.internal_graph
626            .node_edges(node, Direction::Incoming)
627            .map(move |edge| {
628                self.internal_graph
629                    .edge_weight(edge)
630                    .expect("missing edge.")
631            })
632    }
633
634    /// Iterator over the output edges for a node.
635    /// An empty iterator is returned when there is no node with the given index.
636    pub fn node_outputs(&self, node: NodeIndex) -> impl Iterator<Item = &Edge> + '_ {
637        self.internal_graph
638            .node_edges(node, Direction::Outgoing)
639            .map(move |edge| {
640                self.internal_graph
641                    .edge_weight(edge)
642                    .expect("missing edge.")
643            })
644    }
645
646    /// Returns the edge connected to a given input port.
647    /// If the node or the edge does not exist `None` is returned.
648    pub fn node_input(&self, port: impl TryInto<NodePort>) -> Option<&Edge> {
649        let port = TryInto::try_into(port).ok()?;
650        self.node_inputs(port.node)
651            .find(|edge| edge.target.port == port.port)
652    }
653
654    /// Returns the edge connected to a given output port.
655    /// If the node or the edge does not exist `None` is returned.
656    pub fn node_output(&self, port: impl TryInto<NodePort>) -> Option<&Edge> {
657        let port = TryInto::try_into(port).ok()?;
658        self.node_outputs(port.node)
659            .find(|edge| edge.source.port == port.port)
660    }
661
662    /// The name of the graph, provided by the creator.
663    pub fn name(&self) -> &String {
664        &self.name
665    }
666
667    /// Iterator over the labels of the input ports to the graph, in the
668    /// order specified for debugging/display by [GraphBuilder::set_io_order].
669    /// Equivalently, the out-ports of the [Node::Input].
670    pub fn inputs(&self) -> impl Iterator<Item = &Label> {
671        self.input_order.iter()
672    }
673
674    /// Iterator over the labels of the input ports to the graph, in the
675    /// order specified for debugging/display by [GraphBuilder::set_io_order].
676    /// Equivalently, the in-ports of the [Node::Output].
677    pub fn outputs(&self) -> impl Iterator<Item = &Label> {
678        self.output_order.iter()
679    }
680}
681
/// Graph builder to construct [`Graph`] instances.
///
/// Nodes and edges are added incrementally; [GraphBuilder::build] then checks
/// that the result is acyclic and hands over the finished [Graph].
pub struct GraphBuilder {
    /// The graph under construction
    graph: Graph,
}
686
687impl GraphBuilder {
688    /// Creates a new graph builder.
689    /// The input and output nodes are created automatically.
690    pub fn new() -> Self {
691        let mut builder = GraphBuilder {
692            graph: Graph {
693                internal_graph: super::portgraph::graph::Graph::new(),
694                name: "".into(),
695                input_order: vec![],
696                output_order: vec![],
697            },
698        };
699
700        builder.add_node(Node::Input).unwrap();
701        builder.add_node(Node::Output).unwrap();
702
703        builder
704    }
705
706    /// Sets the name of the Graph being built
707    pub fn set_name(&mut self, name: String) {
708        self.graph.name = name;
709    }
710
711    /// Add a new node to the graph with a set of incoming and outgoing edges.
712    /// Most efficient way to set edge orderings.
713    ///
714    /// # Errors
715    ///
716    /// * Attempting to create an additional input or output node.
717    /// * Already connected edges
718    pub fn add_node_with_edges<F>(
719        &mut self,
720        node: impl TryInto<Node, Error = F>,
721        incoming: impl IntoIterator<Item = EdgeIndex>,
722        outgoing: impl IntoIterator<Item = EdgeIndex>,
723    ) -> Result<NodeIndex, GraphBuilderError>
724    where
725        F: Into<SymbolError>,
726    {
727        let node = TryInto::try_into(node).map_err(|e| GraphBuilderError::SymbolError(e.into()))?;
728
729        Ok(self
730            .graph
731            .internal_graph
732            .add_node_with_edges(node, incoming, outgoing)?)
733    }
734
735    /// Add a new node to the graph.
736    ///
737    /// # Errors
738    ///
739    /// * Attempting to create an additional input or output node.
740    pub fn add_node<F>(
741        &mut self,
742        node: impl TryInto<Node, Error = F>,
743    ) -> Result<NodeIndex, GraphBuilderError>
744    where
745        F: Into<SymbolError>,
746    {
747        self.add_node_with_edges(node, [], [])
748    }
749
750    /// Add a new edge to the graph connecting the given input and output ports.
751    /// The edge can also be optionally annotated with a type.
752    ///
753    /// # Errors
754    ///
755    /// * The source or target node does not exist.
756    /// * There already is an edge connected to the source or target port.
757    ///
758    /// This method does not fail in the case a cycle is introduced into the graph.
759    /// The check of acyclicity is deferred to [`GraphBuilder::build`] which can
760    /// check the entire graph at once in linear time.
761    pub fn add_edge<E, F>(
762        &mut self,
763        source: impl TryInto<NodePort, Error = E>,
764        target: impl TryInto<NodePort, Error = F>,
765        edge_type: impl Into<Option<Type>>,
766    ) -> Result<(), GraphBuilderError>
767    where
768        E: Into<SymbolError>,
769        F: Into<SymbolError>,
770    {
771        let source =
772            TryInto::try_into(source).map_err(|e| GraphBuilderError::SymbolError(e.into()))?;
773        let target =
774            TryInto::try_into(target).map_err(|e| GraphBuilderError::SymbolError(e.into()))?;
775        let edge_type = edge_type.into();
776
777        if self.graph.node_input(target).is_some() {
778            return Err(GraphBuilderError::OccupiedInputPort(target));
779        } else if self.graph.node_output(source).is_some() {
780            return Err(GraphBuilderError::OccupiedOutputPort(source));
781        }
782
783        let edge = Edge {
784            source,
785            target,
786            edge_type,
787        };
788
789        let eidx = self.graph.internal_graph.add_edge(edge);
790
791        self.graph
792            .internal_graph
793            .connect(source.node, eidx, Direction::Outgoing, None)?;
794        self.graph
795            .internal_graph
796            .connect(target.node, eidx, Direction::Incoming, None)?;
797
798        Ok(())
799    }
800
801    fn is_acyclic(&self) -> bool {
802        let mut stack: Vec<_> = self.graph.node_indices().collect();
803        let mut discovered = HashSet::new();
804        let mut finished = HashSet::new();
805
806        while let Some(node) = stack.pop() {
807            if !discovered.insert(node) {
808                finished.insert(node);
809                continue;
810            }
811
812            stack.push(node);
813
814            for edge in self.graph.node_outputs(node) {
815                let target = edge.target.node;
816                if !discovered.contains(&target) {
817                    stack.push(target);
818                } else if !finished.contains(&target) {
819                    return false;
820                }
821            }
822        }
823
824        true
825    }
826
827    /// Construct the finished graph.
828    ///
829    /// # Errors
830    ///
831    /// This method fails if there is a cycle in the graph.
832    pub fn build(self) -> Result<Graph, GraphBuilderError> {
833        if !self.is_acyclic() {
834            return Err(GraphBuilderError::CyclicGraph);
835        }
836
837        Ok(self.graph)
838    }
839
840    /// Sets the order of inputs, and of outputs, for debugging/display.
841    /// (Replaces any previously-provided order.)
842    pub fn set_io_order(&mut self, io_order: [Vec<Label>; 2]) {
843        let [i, o] = io_order;
844        self.graph.input_order = i;
845        self.graph.output_order = o;
846    }
847}
848
849impl Default for GraphBuilder {
850    fn default() -> Self {
851        Self::new()
852    }
853}
854
/// Error in building a graph with a [GraphBuilder]
#[derive(Debug, Error)]
pub enum GraphBuilderError {
    /// Tried to add an edge to an in-port that already has an incoming edge.
    #[error("there already is an edge at the input port '{0}'")]
    OccupiedInputPort(NodePort),
    /// Tried to add an edge from an out-port that already has an outgoing edge.
    #[error("there already is an edge at the output port '{0}'")]
    OccupiedOutputPort(NodePort),
    /// The edges (dependencies) in a Graph contained a cycle.
    /// Reported by [GraphBuilder::build].
    #[error("the graph must be acyclic")]
    CyclicGraph,
    /// An error in a (supposed) qualified name, e.g. when a string given for
    /// a node or a port could not be converted.
    #[error("encountered error importing symbols: {0}")]
    SymbolError(#[from] SymbolError),
    /// Error from [super::portgraph] library that edge could not be added,
    /// e.g. if the source or target node-indices do not exist.
    #[error("Error when connecting edges.")]
    ConnectError(#[from] ConnectError),
}
875
876impl PartialEq for Graph {
877    fn eq(&self, other: &Self) -> bool {
878        if self.name != other.name {
879            return false;
880        }
881        // Compare nodes
882        let our_nodes: Vec<&Node> = self.nodes().collect();
883        let other_nodes: Vec<&Node> = other.nodes().collect();
884
885        if our_nodes != other_nodes {
886            return false;
887        }
888
889        // Compare edges
890        let our_edges: HashSet<&Edge> = self.edges().collect();
891        let other_edges: HashSet<&Edge> = other.edges().collect();
892
893        our_edges == other_edges
894    }
895}
896
// Unit tests for the error cases of `GraphBuilder`.
#[cfg(test)]
mod test {
    use crate::graph::{Graph, GraphBuilder, GraphBuilderError};
    use std::error::Error;

    /// Test that graphs with a cycle between different nodes can not be constructed. The graph
    /// constructed in this test is a partial trace on the value bit of two cnot gates wired in
    /// sequence.
    #[test]
    fn test_cyclic_graph() -> Result<(), Box<dyn Error>> {
        let mut builder = GraphBuilder::new();
        let [i, o] = Graph::boundary();
        let a = builder.add_node("cnot")?;
        let b = builder.add_node("cnot")?;

        builder.add_edge((i, "control"), (a, "control"), None)?;
        builder.add_edge((a, "control"), (b, "control"), None)?;
        builder.add_edge((a, "value"), (b, "value"), None)?;
        builder.add_edge((b, "control"), (o, "control"), None)?;
        // Feeding b's value back into a closes the cycle a -> b -> a.
        builder.add_edge((b, "value"), (a, "value"), None)?;

        // Adding the edges succeeded; the cycle is only detected by `build`.
        assert!(matches!(
            builder.build(),
            Err(GraphBuilderError::CyclicGraph)
        ));

        Ok(())
    }

    /// Test that graphs with a cycle on a node are rejected. The graph constructed in this test is
    /// a partial trace on the value bit of a cnot gate.
    #[test]
    fn test_cycle() -> Result<(), Box<dyn Error>> {
        let mut builder = GraphBuilder::new();

        let [i, o] = Graph::boundary();
        let a = builder.add_node("cnot")?;
        builder.add_edge((i, "control"), (a, "control"), None)?;
        builder.add_edge((a, "control"), (o, "control"), None)?;
        // Self-loop: a's value output is wired to its own value input.
        builder.add_edge((a, "value"), (a, "value"), None)?;

        assert!(matches!(
            builder.build(),
            Err(GraphBuilderError::CyclicGraph)
        ));

        Ok(())
    }

    /// Test that doubled edges at an input port are rejected.
    #[test]
    fn test_duplicate_node_input() -> Result<(), Box<dyn Error>> {
        let mut builder = GraphBuilder::new();
        let [i, _] = Graph::boundary();

        let d = builder.add_node("discard")?;
        builder.add_edge((i, "a"), (d, "value"), None)?;
        // Second edge into the same in-port of `d`: rejected immediately by `add_edge`.
        let result = builder.add_edge((i, "b"), (d, "value"), None);

        assert!(matches!(
            result,
            Err(GraphBuilderError::OccupiedInputPort(_))
        ));

        Ok(())
    }
}