tree_sitter_graph/
graph.rs

1// -*- coding: utf-8 -*-
2// ------------------------------------------------------------------------------------------------
3// Copyright © 2021, tree-sitter authors.
4// Licensed under either of Apache License, Version 2.0, or MIT license, at your option.
5// Please see the LICENSE-APACHE or LICENSE-MIT files in this distribution for license details.
6// ------------------------------------------------------------------------------------------------
7
8//! Defines data types for the graphs produced by the graph DSL
9
10use std::borrow::Borrow;
11use std::collections::hash_map::Entry;
12use std::collections::BTreeSet;
13use std::collections::HashMap;
14use std::fmt;
15use std::fs::File;
16use std::hash::Hash;
17use std::io::prelude::*;
18use std::io::stdout;
19use std::ops::Index;
20use std::ops::IndexMut;
21use std::path::Path;
22
23use serde::ser::SerializeMap;
24use serde::ser::SerializeSeq;
25use serde::Serialize;
26use serde::Serializer;
27use serde_json;
28use smallvec::SmallVec;
29use tree_sitter::Node;
30
31use crate::execution::error::ExecutionError;
32use crate::Identifier;
33use crate::Location;
34
/// A graph produced by executing a graph DSL file.  Graphs include a lifetime parameter to ensure
/// that they don't outlive the tree-sitter syntax tree that they are generated from.
#[derive(Default)]
pub struct Graph<'tree> {
    /// Syntax nodes that have been added to the graph, keyed by their syntax node ID.
    pub(crate) syntax_nodes: HashMap<SyntaxNodeID, Node<'tree>>,
    /// Graph nodes in creation order; a node's position in this vector is its graph node ID.
    graph_nodes: Vec<GraphNode>,
}
42
/// Identifier under which a syntax node is stored in a [`Graph`].
pub(crate) type SyntaxNodeID = u32;
/// Identifier of a graph node, i.e. its position in the owning graph's node list.
type GraphNodeID = u32;
45
46impl<'tree> Graph<'tree> {
47    /// Creates a new, empty graph.
48    pub fn new() -> Graph<'tree> {
49        Graph::default()
50    }
51
52    /// Adds a syntax node to the graph, returning a graph DSL reference to it.
53    ///
54    /// The graph won't contain _every_ syntax node in the parsed syntax tree; it will only contain
55    /// those nodes that are referenced at some point during the execution of the graph DSL file.
56    pub fn add_syntax_node(&mut self, node: Node<'tree>) -> SyntaxNodeRef {
57        let index = node.id() as SyntaxNodeID;
58        let node_ref = SyntaxNodeRef {
59            index,
60            kind: node.kind(),
61            position: node.start_position(),
62        };
63        self.syntax_nodes.entry(index).or_insert(node);
64        node_ref
65    }
66
67    /// Adds a new graph node to the graph, returning a graph DSL reference to it.
68    pub fn add_graph_node(&mut self) -> GraphNodeRef {
69        let graph_node = GraphNode::new();
70        let index = self.graph_nodes.len() as GraphNodeID;
71        self.graph_nodes.push(graph_node);
72        GraphNodeRef(index)
73    }
74
75    /// Pretty-prints the contents of this graph.
76    pub fn pretty_print<'a>(&'a self) -> impl fmt::Display + 'a {
77        struct DisplayGraph<'a, 'tree>(&'a Graph<'tree>);
78
79        impl<'a, 'tree> fmt::Display for DisplayGraph<'a, 'tree> {
80            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
81                let graph = self.0;
82                for (node_index, node) in graph.graph_nodes.iter().enumerate() {
83                    write!(f, "node {}\n{}", node_index, node.attributes)?;
84                    for (sink, edge) in &node.outgoing_edges {
85                        write!(f, "edge {} -> {}\n{}", node_index, *sink, edge.attributes)?;
86                    }
87                }
88                Ok(())
89            }
90        }
91
92        DisplayGraph(self)
93    }
94
95    pub fn display_json(&self, path: Option<&Path>) -> std::io::Result<()> {
96        let s = serde_json::to_string_pretty(self).unwrap();
97        path.map_or(stdout().write_all(s.as_bytes()), |path| {
98            File::create(path)?.write_all(s.as_bytes())
99        })
100    }
101
102    // Returns an iterator of references to all of the nodes in the graph.
103    pub fn iter_nodes(&self) -> impl Iterator<Item = GraphNodeRef> {
104        (0..self.graph_nodes.len() as u32).map(GraphNodeRef)
105    }
106
107    // Returns the number of nodes in the graph.
108    pub fn node_count(&self) -> usize {
109        self.graph_nodes.len()
110    }
111}
112
113impl<'tree> Index<SyntaxNodeRef> for Graph<'tree> {
114    type Output = Node<'tree>;
115    fn index(&self, node_ref: SyntaxNodeRef) -> &Node<'tree> {
116        &self.syntax_nodes[&node_ref.index]
117    }
118}
119
120impl Index<GraphNodeRef> for Graph<'_> {
121    type Output = GraphNode;
122    fn index(&self, index: GraphNodeRef) -> &GraphNode {
123        &self.graph_nodes[index.0 as usize]
124    }
125}
126
127impl<'tree> IndexMut<GraphNodeRef> for Graph<'_> {
128    fn index_mut(&mut self, index: GraphNodeRef) -> &mut GraphNode {
129        &mut self.graph_nodes[index.0 as usize]
130    }
131}
132
133impl<'tree> Serialize for Graph<'tree> {
134    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
135        let mut seq = serializer.serialize_seq(Some(self.graph_nodes.len()))?;
136        for (node_index, node) in self.graph_nodes.iter().enumerate() {
137            seq.serialize_element(&SerializeGraphNode(node_index, node))?;
138        }
139        seq.end()
140    }
141}
142
143/// A node in a graph
144pub struct GraphNode {
145    outgoing_edges: SmallVec<[(GraphNodeID, Edge); 8]>,
146    /// The set of attributes associated with this graph node
147    pub attributes: Attributes,
148}
149
150impl GraphNode {
151    fn new() -> GraphNode {
152        GraphNode {
153            outgoing_edges: SmallVec::new(),
154            attributes: Attributes::new(),
155        }
156    }
157
158    /// Adds an edge to this node.  There can be at most one edge connecting any two graph nodes;
159    /// the result indicates whether the edge is new (`Ok`) or already existed (`Err`).  In either
160    /// case, you also get a mutable reference to the [`Edge`][] instance for the edge.
161    pub fn add_edge(&mut self, sink: GraphNodeRef) -> Result<&mut Edge, &mut Edge> {
162        let sink = sink.0;
163        match self
164            .outgoing_edges
165            .binary_search_by_key(&sink, |(sink, _)| *sink)
166        {
167            Ok(index) => Err(&mut self.outgoing_edges[index].1),
168            Err(index) => {
169                self.outgoing_edges.insert(index, (sink, Edge::new()));
170                Ok(&mut self.outgoing_edges[index].1)
171            }
172        }
173    }
174
175    /// Returns a reference to an outgoing edge from this node, if it exists.
176    pub fn get_edge(&self, sink: GraphNodeRef) -> Option<&Edge> {
177        let sink = sink.0;
178        self.outgoing_edges
179            .binary_search_by_key(&sink, |(sink, _)| *sink)
180            .ok()
181            .map(|index| &self.outgoing_edges[index].1)
182    }
183
184    /// Returns a mutable reference to an outgoing edge from this node, if it exists.
185    pub fn get_edge_mut(&mut self, sink: GraphNodeRef) -> Option<&mut Edge> {
186        let sink = sink.0;
187        self.outgoing_edges
188            .binary_search_by_key(&sink, |(sink, _)| *sink)
189            .ok()
190            .map(move |index| &mut self.outgoing_edges[index].1)
191    }
192
193    // Returns an iterator of all of the outgoing edges from this node.
194    pub fn iter_edges(&self) -> impl Iterator<Item = (GraphNodeRef, &Edge)> + '_ {
195        self.outgoing_edges
196            .iter()
197            .map(|(id, edge)| (GraphNodeRef(*id), edge))
198    }
199
200    // Returns the number of outgoing edges from this node.
201    pub fn edge_count(&self) -> usize {
202        self.outgoing_edges.len()
203    }
204}
205
206struct SerializeGraphNode<'a>(usize, &'a GraphNode);
207
208impl<'a> Serialize for SerializeGraphNode<'a> {
209    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
210        let node_index = self.0;
211        let node = self.1;
212        // serializing as a map instead of a struct so we don't have to encode a struct name
213        let mut map = serializer.serialize_map(None)?;
214        map.serialize_entry("id", &node_index)?;
215        map.serialize_entry("edges", &SerializeGraphNodeEdges(&node.outgoing_edges))?;
216        map.serialize_entry("attrs", &node.attributes)?;
217        map.end()
218    }
219}
220
221struct SerializeGraphNodeEdges<'a>(&'a SmallVec<[(GraphNodeID, Edge); 8]>);
222
223impl<'a> Serialize for SerializeGraphNodeEdges<'a> {
224    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
225        let edges = self.0;
226        let mut seq = serializer.serialize_seq(Some(edges.len()))?;
227        for element in edges {
228            seq.serialize_element(&SerializeGraphNodeEdge(&element))?;
229        }
230        seq.end()
231    }
232}
233
234struct SerializeGraphNodeEdge<'a>(&'a (GraphNodeID, Edge));
235
236impl<'a> Serialize for SerializeGraphNodeEdge<'a> {
237    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
238        let wrapped = &self.0;
239        let sink = &wrapped.0;
240        let edge = &wrapped.1;
241        let mut map = serializer.serialize_map(None)?;
242        map.serialize_entry("sink", sink)?;
243        map.serialize_entry("attrs", &edge.attributes)?;
244        map.end()
245    }
246}
247
248/// An edge between two nodes in a graph
249pub struct Edge {
250    /// The set of attributes associated with this edge
251    pub attributes: Attributes,
252}
253
254impl Edge {
255    fn new() -> Edge {
256        Edge {
257            attributes: Attributes::new(),
258        }
259    }
260}
261
262/// A set of attributes associated with a graph node or edge
263#[derive(Clone, Debug)]
264pub struct Attributes {
265    values: HashMap<Identifier, Value>,
266}
267
268impl Attributes {
269    /// Creates a new, empty set of attributes.
270    pub fn new() -> Attributes {
271        Attributes {
272            values: HashMap::new(),
273        }
274    }
275
276    /// Adds an attribute to this attribute set.  If there was already an attribute with the same
277    /// name, replaces its value and returns `Err`.
278    pub fn add<V: Into<Value>>(&mut self, name: Identifier, value: V) -> Result<(), Value> {
279        match self.values.entry(name) {
280            Entry::Occupied(mut o) => {
281                let value = value.into();
282                if o.get() != &value {
283                    Err(o.insert(value.into()))
284                } else {
285                    Ok(())
286                }
287            }
288            Entry::Vacant(v) => {
289                v.insert(value.into());
290                Ok(())
291            }
292        }
293    }
294
295    /// Returns the value of a particular attribute, if it exists.
296    pub fn get<Q>(&self, name: &Q) -> Option<&Value>
297    where
298        Q: ?Sized + Eq + Hash,
299        Identifier: Borrow<Q>,
300    {
301        self.values.get(name.borrow())
302    }
303
304    pub fn iter(&self) -> impl Iterator<Item = (&Identifier, &Value)> {
305        self.values.iter()
306    }
307}
308
309impl std::fmt::Display for Attributes {
310    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
311        let mut keys = self.values.keys().collect::<Vec<_>>();
312        keys.sort_by(|a, b| a.cmp(b));
313        for key in &keys {
314            let value = &self.values[*key];
315            write!(f, "  {}: {:?}\n", key, value)?;
316        }
317        Ok(())
318    }
319}
320
321impl Serialize for Attributes {
322    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
323        let mut map = serializer.serialize_map(None)?;
324        for (key, value) in &self.values {
325            map.serialize_entry(key, value)?;
326        }
327        map.end()
328    }
329}
330
/// The value of an attribute
#[derive(Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum Value {
    // Scalar
    /// The null value
    Null,
    /// A boolean
    Boolean(bool),
    /// An unsigned 32-bit integer
    Integer(u32),
    /// A string
    String(String),
    // Compound
    /// An ordered list of values
    List(Vec<Value>),
    /// A set of values, stored in sorted order
    Set(BTreeSet<Value>),
    // References
    /// A reference to a syntax node
    SyntaxNode(SyntaxNodeRef),
    /// A reference to a graph node
    GraphNode(GraphNodeRef),
}
346
347impl Value {
348    /// Check if this value is null
349    pub fn is_null(&self) -> bool {
350        match self {
351            Value::Null => true,
352            _ => false,
353        }
354    }
355
356    /// Coerces this value into a boolean, returning an error if it's some other type of value.
357    pub fn into_boolean(self) -> Result<bool, ExecutionError> {
358        match self {
359            Value::Boolean(value) => Ok(value),
360            _ => Err(ExecutionError::ExpectedBoolean(format!("got {}", self))),
361        }
362    }
363
364    pub fn as_boolean(&self) -> Result<bool, ExecutionError> {
365        match self {
366            Value::Boolean(value) => Ok(*value),
367            _ => Err(ExecutionError::ExpectedBoolean(format!("got {}", self))),
368        }
369    }
370
371    /// Coerces this value into an integer, returning an error if it's some other type of value.
372    pub fn into_integer(self) -> Result<u32, ExecutionError> {
373        match self {
374            Value::Integer(value) => Ok(value),
375            _ => Err(ExecutionError::ExpectedInteger(format!("got {}", self))),
376        }
377    }
378
379    pub fn as_integer(&self) -> Result<u32, ExecutionError> {
380        match self {
381            Value::Integer(value) => Ok(*value),
382            _ => Err(ExecutionError::ExpectedInteger(format!("got {}", self))),
383        }
384    }
385
386    /// Coerces this value into a string, returning an error if it's some other type of value.
387    pub fn into_string(self) -> Result<String, ExecutionError> {
388        match self {
389            Value::String(value) => Ok(value),
390            _ => Err(ExecutionError::ExpectedString(format!("got {}", self))),
391        }
392    }
393
394    pub fn as_str(&self) -> Result<&str, ExecutionError> {
395        match self {
396            Value::String(value) => Ok(value),
397            _ => Err(ExecutionError::ExpectedString(format!("got {}", self))),
398        }
399    }
400
401    /// Coerces this value into a list, returning an error if it's some other type of value.
402    pub fn into_list(self) -> Result<Vec<Value>, ExecutionError> {
403        match self {
404            Value::List(values) => Ok(values),
405            _ => Err(ExecutionError::ExpectedList(format!("got {}", self))),
406        }
407    }
408
409    pub fn as_list(&self) -> Result<&Vec<Value>, ExecutionError> {
410        match self {
411            Value::List(values) => Ok(values),
412            _ => Err(ExecutionError::ExpectedList(format!("got {}", self))),
413        }
414    }
415
416    /// Coerces this value into a graph node reference, returning an error if it's some other type
417    /// of value.
418    pub fn into_graph_node_ref<'a, 'tree>(self) -> Result<GraphNodeRef, ExecutionError> {
419        match self {
420            Value::GraphNode(node) => Ok(node),
421            _ => Err(ExecutionError::ExpectedGraphNode(format!("got {}", self))),
422        }
423    }
424
425    pub fn as_graph_node_ref<'a, 'tree>(&self) -> Result<GraphNodeRef, ExecutionError> {
426        match self {
427            Value::GraphNode(node) => Ok(*node),
428            _ => Err(ExecutionError::ExpectedGraphNode(format!("got {}", self))),
429        }
430    }
431
432    /// Coerces this value into a syntax node reference, returning an error if it's some other type
433    /// of value.
434    pub fn into_syntax_node_ref<'a, 'tree>(self) -> Result<SyntaxNodeRef, ExecutionError> {
435        match self {
436            Value::SyntaxNode(node) => Ok(node),
437            _ => Err(ExecutionError::ExpectedSyntaxNode(format!("got {}", self))),
438        }
439    }
440
441    /// Coerces this value into a syntax node, returning an error if it's some other type
442    /// of value.
443    #[deprecated(note = "Use the pattern graph[value.into_syntax_node_ref(graph)] instead")]
444    pub fn into_syntax_node<'a, 'tree>(
445        self,
446        graph: &'a Graph<'tree>,
447    ) -> Result<&'a Node<'tree>, ExecutionError> {
448        Ok(&graph[self.into_syntax_node_ref()?])
449    }
450
451    pub fn as_syntax_node_ref<'a, 'tree>(&self) -> Result<SyntaxNodeRef, ExecutionError> {
452        match self {
453            Value::SyntaxNode(node) => Ok(*node),
454            _ => Err(ExecutionError::ExpectedSyntaxNode(format!("got {}", self))),
455        }
456    }
457}
458
459impl From<bool> for Value {
460    fn from(value: bool) -> Value {
461        Value::Boolean(value)
462    }
463}
464
465impl From<u32> for Value {
466    fn from(value: u32) -> Value {
467        Value::Integer(value)
468    }
469}
470
471impl From<&str> for Value {
472    fn from(value: &str) -> Value {
473        Value::String(value.to_string())
474    }
475}
476
477impl From<String> for Value {
478    fn from(value: String) -> Value {
479        Value::String(value)
480    }
481}
482
483impl From<Vec<Value>> for Value {
484    fn from(value: Vec<Value>) -> Value {
485        Value::List(value)
486    }
487}
488
489impl From<BTreeSet<Value>> for Value {
490    fn from(value: BTreeSet<Value>) -> Value {
491        Value::Set(value)
492    }
493}
494
495impl std::fmt::Display for Value {
496    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
497        match self {
498            Value::Null => write!(f, "#null"),
499            Value::Boolean(value) => {
500                if *value {
501                    write!(f, "#true")
502                } else {
503                    write!(f, "#false")
504                }
505            }
506            Value::Integer(value) => write!(f, "{}", value),
507            Value::String(value) => write!(f, "{}", value),
508            Value::List(value) => {
509                write!(f, "[")?;
510                let mut first = true;
511                for element in value {
512                    if first {
513                        write!(f, "{}", element)?;
514                        first = false;
515                    } else {
516                        write!(f, ", {}", element)?;
517                    }
518                }
519                write!(f, "]")
520            }
521            Value::Set(value) => {
522                write!(f, "{{")?;
523                let mut first = true;
524                for element in value {
525                    if first {
526                        write!(f, "{}", element)?;
527                        first = false;
528                    } else {
529                        write!(f, ", {}", element)?;
530                    }
531                }
532                write!(f, "}}")
533            }
534            Value::SyntaxNode(node) => node.fmt(f),
535            Value::GraphNode(node) => node.fmt(f),
536        }
537    }
538}
539
540impl std::fmt::Debug for Value {
541    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
542        match self {
543            Value::Null => write!(f, "#null"),
544            Value::Boolean(value) => {
545                if *value {
546                    write!(f, "#true")
547                } else {
548                    write!(f, "#false")
549                }
550            }
551            Value::Integer(value) => write!(f, "{:?}", value),
552            Value::String(value) => write!(f, "{:?}", value),
553            Value::List(value) => {
554                write!(f, "[")?;
555                let mut first = true;
556                for element in value {
557                    if first {
558                        write!(f, "{:?}", element)?;
559                        first = false;
560                    } else {
561                        write!(f, ", {:?}", element)?;
562                    }
563                }
564                write!(f, "]")
565            }
566            Value::Set(value) => {
567                write!(f, "{{")?;
568                let mut first = true;
569                for element in value {
570                    if first {
571                        write!(f, "{:?}", element)?;
572                        first = false;
573                    } else {
574                        write!(f, ", {:?}", element)?;
575                    }
576                }
577                write!(f, "}}")
578            }
579            Value::SyntaxNode(node) => node.fmt(f),
580            Value::GraphNode(node) => node.fmt(f),
581        }
582    }
583}
584
585impl Serialize for Value {
586    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
587        match self {
588            Value::Null => {
589                let mut map = serializer.serialize_map(None)?;
590                map.serialize_entry("type", "null")?;
591                map.end()
592            }
593            Value::Boolean(bool) => {
594                let mut map = serializer.serialize_map(None)?;
595                map.serialize_entry("type", "bool")?;
596                map.serialize_entry("bool", bool)?;
597                map.end()
598            }
599            Value::Integer(int) => {
600                let mut map = serializer.serialize_map(None)?;
601                map.serialize_entry("type", "int")?;
602                map.serialize_entry("int", int)?;
603                map.end()
604            }
605            Value::String(str) => {
606                let mut map = serializer.serialize_map(None)?;
607                map.serialize_entry("type", "string")?;
608                map.serialize_entry("string", str)?;
609                map.end()
610            }
611            Value::List(list) => {
612                let mut map = serializer.serialize_map(None)?;
613                map.serialize_entry("type", "list")?;
614                map.serialize_entry("values", list)?;
615                map.end()
616            }
617            Value::Set(set) => {
618                let mut map = serializer.serialize_map(None)?;
619                map.serialize_entry("type", "set")?;
620                map.serialize_entry("values", set)?;
621                map.end()
622            }
623            Value::SyntaxNode(node) => {
624                let mut map = serializer.serialize_map(None)?;
625                map.serialize_entry("type", "syntaxNode")?;
626                map.serialize_entry("id", &node.index)?;
627                map.end()
628            }
629            Value::GraphNode(node) => {
630                let mut map = serializer.serialize_map(None)?;
631                map.serialize_entry("type", "graphNode")?;
632                map.serialize_entry("id", &node.0)?;
633                map.end()
634            }
635        }
636    }
637}
638
/// A reference to a syntax node in a graph
#[derive(Clone, Copy, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct SyntaxNodeRef {
    /// ID under which the syntax node is stored in the graph
    pub(crate) index: SyntaxNodeID,
    /// The node's kind, as reported by tree-sitter when the reference was created
    kind: &'static str,
    /// The node's start position, as reported by tree-sitter when the reference was created
    position: tree_sitter::Point,
}
646
647impl From<tree_sitter::Point> for Location {
648    fn from(point: tree_sitter::Point) -> Location {
649        Location {
650            row: point.row,
651            column: point.column,
652        }
653    }
654}
655
656impl SyntaxNodeRef {
657    pub fn location(&self) -> Location {
658        Location::from(self.position)
659    }
660}
661
662impl From<SyntaxNodeRef> for Value {
663    fn from(value: SyntaxNodeRef) -> Value {
664        Value::SyntaxNode(value)
665    }
666}
667
668impl std::fmt::Display for SyntaxNodeRef {
669    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
670        write!(
671            f,
672            "[syntax node {} ({}, {})]",
673            self.kind,
674            self.position.row + 1,
675            self.position.column + 1,
676        )
677    }
678}
679
680impl std::fmt::Debug for SyntaxNodeRef {
681    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
682        write!(
683            f,
684            "[syntax node {} ({}, {})]",
685            self.kind,
686            self.position.row + 1,
687            self.position.column + 1,
688        )
689    }
690}
691
692/// A reference to a graph node
693#[derive(Clone, Copy, Eq, Hash, Ord, PartialEq, PartialOrd)]
694pub struct GraphNodeRef(GraphNodeID);
695
696impl GraphNodeRef {
697    /// Returns the index of the graph node that this reference refers to.
698    pub fn index(self) -> usize {
699        self.0 as usize
700    }
701}
702
703impl From<GraphNodeRef> for Value {
704    fn from(value: GraphNodeRef) -> Value {
705        Value::GraphNode(value)
706    }
707}
708
709impl std::fmt::Display for GraphNodeRef {
710    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
711        write!(f, "[graph node {}]", self.0)
712    }
713}
714
715impl std::fmt::Debug for GraphNodeRef {
716    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
717        write!(f, "[graph node {}]", self.0)
718    }
719}