tree_sitter_traversal2/
lib.rs

1//! Iterators to traverse tree-sitter [`Tree`]s using a [`TreeCursor`],
2//! with a [`Cursor`] trait to allow for traversing arbitrary n-ary trees.
3//!
4//! # Examples
5//!
6//! Basic usage:
7//!
8//! ```
9//! # #[cfg(feature = "tree-sitter")]
10//! # {
11//! use tree_sitter::{Node, Tree};
12//! use std::collections::HashSet;
13//! use std::iter::FromIterator;
14//!
15//! use tree_sitter_traversal::{traverse, traverse_tree, Order};
16//! # fn get_tree() -> Tree {
17//! #     use tree_sitter::Parser;
18//! #     let mut parser = Parser::new();
19//! #     let lang = tree_sitter_rust::language();
20//! #     parser.set_language(&lang).expect("Error loading Rust grammar");
21//! #     return parser.parse("fn double(x: usize) -> usize { x * 2 }", None).expect("Error parsing provided code");
22//! # }
23//!
24//! // Non-existent method, imagine it gets a valid Tree with >1 node
25//! let tree: Tree = get_tree();
26//! let preorder: Vec<Node<'_>> = traverse(tree.walk(), Order::Pre).collect::<Vec<_>>();
27//! let postorder: Vec<Node<'_>> = traverse_tree(&tree, Order::Post).collect::<Vec<_>>();
28//! // For any tree with more than just a root node,
29//! // the order of preorder and postorder will be different
30//! assert_ne!(preorder, postorder);
31//! // However, they will have the same amount of nodes
32//! assert_eq!(preorder.len(), postorder.len());
33//! // Specifically, they will have the exact same nodes, just in a different order
34//! assert_eq!(
35//!     <HashSet<_>>::from_iter(preorder.into_iter()),
36//!     <HashSet<_>>::from_iter(postorder.into_iter())
37//! );
38//! # }
39//! ```
40//!
41//! [`Tree`]: tree_sitter::Tree
42//! [`TreeCursor`]: tree_sitter::TreeCursor
43//! [`Cursor`]: crate::Cursor
44#![no_std]
45
46use core::iter::FusedIterator;
47
/// A stateful cursor into an n-ary tree.
///
/// The cursor always rests on exactly one node. It can step down to that node's first
/// child, across to its next sibling, or up to its parent, and it can report the node it
/// is currently on. Each movement method reports whether the move succeeded; `false`
/// means there was nowhere to go in that direction.
pub trait Cursor {
    /// Type of the value produced for the node under the cursor; the cursor is always
    /// positioned on exactly one such node.
    type Node;

    /// Steps down to the first child of the current node.
    ///
    /// Returns `true` on success, or `false` when the current node has no children.
    fn goto_first_child(&mut self) -> bool;

    /// Steps up to the parent of the current node.
    ///
    /// Returns `true` on success, or `false` when there is no parent, i.e. the cursor is
    /// already on the root node.
    fn goto_parent(&mut self) -> bool;

    /// Steps across to the next sibling of the current node.
    ///
    /// Returns `true` on success, or `false` when the current node has no next sibling.
    fn goto_next_sibling(&mut self) -> bool;

    /// Returns the node the cursor is currently on.
    fn node(&self) -> Self::Node;
}
77
78impl<'a, T> Cursor for &'a mut T
79where
80    T: Cursor,
81{
82    type Node = T::Node;
83
84    fn goto_first_child(&mut self) -> bool {
85        T::goto_first_child(self)
86    }
87
88    fn goto_parent(&mut self) -> bool {
89        T::goto_parent(self)
90    }
91
92    fn goto_next_sibling(&mut self) -> bool {
93        T::goto_next_sibling(self)
94    }
95
96    fn node(&self) -> Self::Node {
97        T::node(self)
98    }
99}
100
/// Quintessential implementation of [`Cursor`] for tree-sitter's [`TreeCursor`]
///
/// Each method forwards to the `TreeCursor` method of the same name. The calls are written
/// through the type path; inherent methods take precedence over trait methods in that
/// lookup, so they reach tree-sitter's own implementation rather than this impl.
///
/// [`TreeCursor`]: tree_sitter::TreeCursor
/// [`Cursor`]: crate::Cursor
#[cfg(feature = "tree-sitter")]
impl<'a> Cursor for tree_sitter::TreeCursor<'a> {
    type Node = tree_sitter::Node<'a>;

    fn goto_first_child(&mut self) -> bool {
        tree_sitter::TreeCursor::goto_first_child(self)
    }

    fn goto_parent(&mut self) -> bool {
        tree_sitter::TreeCursor::goto_parent(self)
    }

    fn goto_next_sibling(&mut self) -> bool {
        tree_sitter::TreeCursor::goto_next_sibling(self)
    }

    fn node(&self) -> Self::Node {
        tree_sitter::TreeCursor::node(self)
    }
}
125
/// The depth-first order in which a traversal yields the nodes of an n-ary tree.
///
/// Only pre-order and post-order are offered; an in-order visit has no natural meaning
/// once a node can have more than two children.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Order {
    /// Yield each node before any of its descendants.
    Pre,
    /// Yield each node after all of its descendants.
    Post,
}
133
134/// Iterative traversal of the tree; serves as a reference for both
135/// PreorderTraversal and PostorderTraversal, as they both will call the exact same
136/// cursor methods in the exact same order as this function for a given tree; the order
137/// is also the same as traverse_recursive.
138#[allow(dead_code)]
139fn traverse_iterative<C: Cursor, F>(mut c: C, order: Order, mut cb: F)
140where
141    F: FnMut(C::Node),
142{
143    loop {
144        // This is the first time we've encountered the node, so we'll call if preorder
145        if order == Order::Pre {
146            cb(c.node());
147        }
148
149        // Keep travelling down the tree as far as we can
150        if c.goto_first_child() {
151            continue;
152        }
153
154        let node = c.node();
155
156        // If we can't travel any further down, try going to next sibling and repeating
157        if c.goto_next_sibling() {
158            // If we succeed in going to the previous nodes sibling,
159            // we won't be encountering that node again, so we'll call if postorder
160            if order == Order::Post {
161                cb(node);
162            }
163            continue;
164        }
165
166        // Otherwise, we must travel back up; we'll loop until we reach the root or can
167        // go to the next sibling of a node again.
168        loop {
169            // Since we're retracing back up the tree, this is the last time we'll encounter
170            // this node, so we'll call if postorder
171            if order == Order::Post {
172                cb(c.node());
173            }
174            if !c.goto_parent() {
175                // We have arrived back at the root, so we are done.
176                return;
177            }
178
179            let node = c.node();
180
181            if c.goto_next_sibling() {
182                // If we succeed in going to the previous node's sibling,
183                // we will go back to travelling down that sibling's tree, and we also
184                // won't be encountering the previous node again, so we'll call if postorder
185                if order == Order::Post {
186                    cb(node);
187                }
188                break;
189            }
190        }
191    }
192}
193
194/// Idiomatic recursive traversal of the tree; this version is easier to understand
195/// conceptually, but the recursion is actually unnecessary and can cause stack overflow.
196#[allow(dead_code)]
197fn traverse_recursive<C: Cursor, F>(mut c: C, order: Order, mut cb: F)
198where
199    F: FnMut(C::Node),
200{
201    traverse_helper(&mut c, order, &mut cb);
202}
203
204fn traverse_helper<C: Cursor, F>(c: &mut C, order: Order, cb: &mut F)
205where
206    F: FnMut(C::Node),
207{
208    // If preorder, call the callback when we first touch the node
209    if order == Order::Pre {
210        cb(c.node());
211    }
212    if c.goto_first_child() {
213        // If there is a child, recursively call on
214        // that child and all its siblings
215        loop {
216            traverse_helper(c, order, cb);
217            if !c.goto_next_sibling() {
218                break;
219            }
220        }
221        // Make sure to reset back to the original node;
222        // this must always return true, as we only get here if we go to a child
223        // of the original node.
224        assert!(c.goto_parent());
225    }
226    // If preorder, call the callback after the recursive calls on child nodes
227    if order == Order::Post {
228        cb(c.node());
229    }
230}
231
232struct PreorderTraverse<C> {
233    cursor: Option<C>,
234}
235
236impl<C> PreorderTraverse<C> {
237    pub fn new(c: C) -> Self {
238        PreorderTraverse { cursor: Some(c) }
239    }
240}
241
242impl<C> Iterator for PreorderTraverse<C>
243where
244    C: Cursor,
245{
246    type Item = C::Node;
247
248    fn next(&mut self) -> Option<Self::Item> {
249        let c = match self.cursor.as_mut() {
250            None => {
251                return None;
252            }
253            Some(c) => c,
254        };
255
256        // We will always return the node we were on at the start;
257        // the node we traverse to will either be returned on the next iteration,
258        // or will be back to the root node, at which point we'll clear out
259        // the reference to the cursor
260        let node = c.node();
261
262        // First, try to go to a child or a sibling; if either succeed, this will be the
263        // first time we touch that node, so it'll be the next starting node
264        if c.goto_first_child() || c.goto_next_sibling() {
265            return Some(node);
266        }
267
268        loop {
269            // If we can't go to the parent, then that means we've reached the root, and our
270            // iterator will be done in the next iteration
271            if !c.goto_parent() {
272                self.cursor = None;
273                break;
274            }
275
276            // If we get to a sibling, then this will be the first time we touch that node,
277            // so it'll be the next starting node
278            if c.goto_next_sibling() {
279                break;
280            }
281        }
282
283        Some(node)
284    }
285}
286
287struct PostorderTraverse<C> {
288    cursor: Option<C>,
289    retracing: bool,
290}
291
292impl<C> PostorderTraverse<C> {
293    pub fn new(c: C) -> Self {
294        PostorderTraverse {
295            cursor: Some(c),
296            retracing: false,
297        }
298    }
299}
300
301impl<C> Iterator for PostorderTraverse<C>
302where
303    C: Cursor,
304{
305    type Item = C::Node;
306
307    fn next(&mut self) -> Option<Self::Item> {
308        let c = match self.cursor.as_mut() {
309            None => {
310                return None;
311            }
312            Some(c) => c,
313        };
314
315        // For the postorder traversal, we will only return a node when we are travelling back up
316        // the tree structure. Therefore, we go all the way to the leaves of the tree immediately,
317        // and only when we are retracing do we return elements
318        if !self.retracing {
319            while c.goto_first_child() {}
320        }
321
322        // Much like in preorder traversal, we want to return the node we were previously at.
323        // We know this will be the last time we touch this node, as we will either be going
324        // to its next sibling or retracing back up the tree
325        let node = c.node();
326        if c.goto_next_sibling() {
327            // If we successfully go to a sibling of this node, we want to go back down
328            // the tree on the next iteration
329            self.retracing = false;
330        } else {
331            // If we weren't already retracing, we are now; travel upwards until we can
332            // go to the next sibling or reach the root again
333            self.retracing = true;
334            if !c.goto_parent() {
335                // We've reached the root again, and our iteration is done
336                self.cursor = None;
337            }
338        }
339
340        Some(node)
341    }
342}
343
344// Used for visibility purposes, in case this struct becomes public
345struct Traverse<C> {
346    inner: TraverseInner<C>,
347}
348
349enum TraverseInner<C> {
350    Post(PostorderTraverse<C>),
351    Pre(PreorderTraverse<C>),
352}
353
354impl<C> Traverse<C> {
355    pub fn new(c: C, order: Order) -> Self {
356        let inner = match order {
357            Order::Pre => TraverseInner::Pre(PreorderTraverse::new(c)),
358            Order::Post => TraverseInner::Post(PostorderTraverse::new(c)),
359        };
360        Self { inner }
361    }
362}
363
#[cfg(feature = "tree-sitter")]
impl<'a> Traverse<tree_sitter::TreeCursor<'a>> {
    /// Builds a traversal of `tree` using the cursor returned by `tree.walk()`.
    // Not reachable from the public API yet, hence the lint allowance.
    #[allow(dead_code)]
    pub fn from_tree(tree: &'a tree_sitter::Tree, order: Order) -> Self {
        let root_cursor = tree.walk();
        Self::new(root_cursor, order)
    }
}
371
/// Convenience function to traverse a tree-sitter [`Tree`] in the order given by `order`.
///
/// This is shorthand for `traverse(tree.walk(), order)`. The returned iterator borrows
/// `tree` and yields every node of it exactly once.
///
/// [`Tree`]: tree_sitter::Tree
#[cfg(feature = "tree-sitter")]
pub fn traverse_tree(
    tree: &tree_sitter::Tree,
    order: Order,
) -> impl FusedIterator<Item = tree_sitter::Node<'_>> {
    traverse(tree.walk(), order)
}
382
383/// Traverse an n-ary tree using `cursor`, returning the nodes of the tree through an iterator
384/// in an order according to `order`.
385///
386/// `cursor` must be at the root of the tree
387/// (i.e. `cursor.goto_parent()` must return false)
388pub fn traverse<C: Cursor>(mut cursor: C, order: Order) -> impl FusedIterator<Item = C::Node> {
389    assert!(!cursor.goto_parent());
390    Traverse::new(cursor, order)
391}
392
393impl<C> Iterator for Traverse<C>
394where
395    C: Cursor,
396{
397    type Item = C::Node;
398
399    fn next(&mut self) -> Option<Self::Item> {
400        match self.inner {
401            TraverseInner::Post(ref mut i) => i.next(),
402            TraverseInner::Pre(ref mut i) => i.next(),
403        }
404    }
405}
406
407// We know that PreorderTraverse and PostorderTraverse are fused due to their implementation,
408// so we can add this bound for free.
409impl<C> FusedIterator for Traverse<C> where C: Cursor {}
410
#[cfg(test)]
#[cfg(feature = "tree-sitter")]
mod tree_sitter_tests {
    //! Cross-checks the public iterator against the callback-based reference traversals
    //! on real tree-sitter parse trees.

    use super::*;

    extern crate std;
    use std::vec::Vec;
    use tree_sitter::{Parser, Tree};

    // A small, valid Rust function.
    const EX1: &str = r#"
fn double(x: usize) -> usize {
    return 2 * x;
}"#;

    // Deliberately invalid Rust, to run the traversals on input that does not parse cleanly.
    const EX2: &str = r#"
// Intentionally invalid code below

"123

const DOUBLE = 2;

function double(x: usize) -> usize {
    return DOUBLE * x;
}"#;

    // Empty input.
    const EX3: &str = "";

    /// Asserts that the recursive reference, the iterative reference and the public
    /// iterator all visit the same nodes in the same sequence for `order`.
    fn generate_traversals(tree: &Tree, order: Order) {
        let mut via_recursion = Vec::new();
        traverse_recursive(tree.walk(), order, |node| via_recursion.push(node));

        let mut via_loop = Vec::new();
        traverse_iterative(tree.walk(), order, |node| via_loop.push(node));

        let via_iterator: Vec<_> = traverse(tree.walk(), order).collect();

        assert_eq!(via_recursion, via_loop);
        assert_eq!(via_loop, via_iterator);
    }

    /// Parses `code` with the Rust grammar, panicking if loading the grammar or parsing fails.
    fn get_tree(code: &str) -> Tree {
        let language = tree_sitter_rust::language();
        let mut parser = Parser::new();
        parser
            .set_language(&language)
            .expect("Error loading Rust grammar");
        parser
            .parse(code, None)
            .expect("Error parsing provided code")
    }

    #[test]
    fn test_equivalence() {
        for tree in [EX1, EX2, EX3].iter().copied().map(get_tree) {
            generate_traversals(&tree, Order::Pre);
            generate_traversals(&tree, Order::Post);
        }
    }

    #[test]
    fn test_postconditions() {
        let parsed = get_tree(EX1);
        let mut walk = parsed.walk();
        for order in [Order::Pre, Order::Post] {
            let mut iter = traverse(&mut walk, order);
            // Drain the iterator completely.
            iter.by_ref().for_each(drop);
            // Once exhausted it must stay exhausted; probe twice to be sure it is fused.
            assert!(iter.next().is_none());
            assert!(iter.next().is_none());
            drop(iter);
            // The borrowed cursor is back on the root, ready for the next round.
            assert_eq!(walk.node(), parsed.root_node());
        }
    }

    #[test]
    #[should_panic]
    fn test_panic() {
        // Starting below the root violates `traverse`'s precondition.
        let parsed = get_tree(EX1);
        let mut walk = parsed.walk();
        walk.goto_first_child();
        traverse(&mut walk, Order::Pre).count();
    }

    #[test]
    fn example() {
        use std::collections::HashSet;
        use std::iter::FromIterator;
        use tree_sitter::Node;

        let tree: Tree = get_tree(EX1);
        let preorder: Vec<Node<'_>> = traverse(tree.walk(), Order::Pre).collect::<Vec<_>>();
        let postorder: Vec<Node<'_>> = traverse_tree(&tree, Order::Post).collect::<Vec<_>>();

        // Different sequences, but the same number of nodes and the very same node set.
        assert_ne!(preorder, postorder);
        assert_eq!(preorder.len(), postorder.len());
        assert_eq!(
            <HashSet<_>>::from_iter(preorder.into_iter()),
            <HashSet<_>>::from_iter(postorder.into_iter())
        );
    }
}
517
#[cfg(test)]
mod tests {
    //! Feature-independent tests using hand-written `Cursor` implementations, so the
    //! traversal logic is covered even without the `tree-sitter` feature.

    use super::*;

    /// A tree consisting of nothing but its root; the node type is `()`.
    struct Root;

    impl Cursor for Root {
        type Node = ();

        fn goto_first_child(&mut self) -> bool {
            false
        }

        fn goto_parent(&mut self) -> bool {
            false
        }

        fn goto_next_sibling(&mut self) -> bool {
            false
        }

        fn node(&self) -> Self::Node {}
    }

    // Child lists of a fixed six-node tree; node `i`'s children are `CHILDREN[i]`:
    //
    //       0
    //      / \
    //     1   4
    //    / \   \
    //   2   3   5
    const CHILDREN: [&[usize]; 6] = [&[1, 4], &[2, 3], &[], &[], &[5], &[]];

    /// Allocation-free cursor over `CHILDREN` (the crate is `no_std`).
    ///
    /// `path[..depth]` holds `(node, index among its siblings)` from the root downwards.
    struct Fixed {
        path: [(usize, usize); 6],
        depth: usize,
    }

    impl Fixed {
        fn new() -> Self {
            Fixed {
                path: [(0, 0); 6],
                depth: 1,
            }
        }

        fn current(&self) -> (usize, usize) {
            self.path[self.depth - 1]
        }
    }

    impl Cursor for Fixed {
        type Node = usize;

        fn goto_first_child(&mut self) -> bool {
            match CHILDREN[self.current().0].first() {
                Some(&child) => {
                    self.path[self.depth] = (child, 0);
                    self.depth += 1;
                    true
                }
                None => false,
            }
        }

        fn goto_parent(&mut self) -> bool {
            if self.depth > 1 {
                self.depth -= 1;
                true
            } else {
                false
            }
        }

        fn goto_next_sibling(&mut self) -> bool {
            // The root has no siblings.
            if self.depth < 2 {
                return false;
            }
            let parent = self.path[self.depth - 2].0;
            let index = self.current().1 + 1;
            match CHILDREN[parent].get(index) {
                Some(&sibling) => {
                    self.path[self.depth - 1] = (sibling, index);
                    true
                }
                None => false,
            }
        }

        fn node(&self) -> usize {
            self.current().0
        }
    }

    #[test]
    fn test_root() {
        assert_eq!(1, traverse(Root, Order::Pre).count());
        assert_eq!(1, traverse(Root, Order::Post).count());
    }

    #[test]
    fn test_multi_node_orders() {
        assert!(traverse(Fixed::new(), Order::Pre).eq([0, 1, 2, 3, 4, 5]));
        assert!(traverse(Fixed::new(), Order::Post).eq([2, 3, 1, 5, 4, 0]));
    }

    #[test]
    fn test_fused_and_cursor_reset() {
        for order in [Order::Pre, Order::Post] {
            let mut cursor = Fixed::new();
            let mut iter = traverse(&mut cursor, order);
            assert_eq!(iter.by_ref().count(), 6);
            assert!(iter.next().is_none());
            assert!(iter.next().is_none());
            drop(iter);
            // The borrowed cursor is left on the root and can be reused.
            assert_eq!(cursor.node(), 0);
            assert!(!cursor.goto_parent());
        }
    }
}