wdl_grammar/tree/
dive.rs

1//! Utilities for traversing a syntax tree while collecting elements of interest
2//! (i.e., "diving" for elements).
3
4use std::iter::FusedIterator;
5
6use rowan::Language;
7use rowan::api::PreorderWithTokens;
8
9use crate::SyntaxElement;
10use crate::SyntaxNode;
11use crate::WorkflowDescriptionLanguage;
12
13/// An iterator that performs a pre-order traversal of a
14/// [`SyntaxNode`](rowan::SyntaxNode)'s descendants and yields all elements
15/// while ignoring undesirable subtrees.
16#[allow(missing_debug_implementations)]
17pub struct DiveIterator<L, I>
18where
19    L: Language,
20    I: Fn(&rowan::SyntaxNode<L>) -> bool,
21{
22    /// The iterator that performs the pre-order traversal of elements.
23    it: PreorderWithTokens<L>,
24    /// The function that evaluates when checking if the subtree beneath a node
25    /// should be ignored.
26    ignore_predicate: I,
27}
28
29impl<L, I> DiveIterator<L, I>
30where
31    L: Language,
32    I: Fn(&rowan::SyntaxNode<L>) -> bool,
33{
34    /// Creates a new [`DiveIterator`].
35    pub fn new(root: rowan::SyntaxNode<L>, ignore_predicate: I) -> Self {
36        Self {
37            it: root.preorder_with_tokens(),
38            ignore_predicate,
39        }
40    }
41}
42
43impl<L, I> Iterator for DiveIterator<L, I>
44where
45    L: Language,
46    I: Fn(&rowan::SyntaxNode<L>) -> bool,
47{
48    type Item = rowan::SyntaxElement<L>;
49
50    fn next(&mut self) -> Option<Self::Item> {
51        while let Some(event) = self.it.next() {
52            let element = match event {
53                rowan::WalkEvent::Enter(element) => element,
54                rowan::WalkEvent::Leave(_) => continue,
55            };
56
57            if let rowan::SyntaxElement::Node(node) = &element
58                && (self.ignore_predicate)(node)
59            {
60                self.it.skip_subtree();
61                continue;
62            }
63
64            return Some(element);
65        }
66
67        None
68    }
69}
70
71impl<L, I> FusedIterator for DiveIterator<L, I>
72where
73    L: Language,
74    I: Fn(&rowan::SyntaxNode<L>) -> bool,
75{
76}
77
78/// Elements of a syntax tree upon which a dive can be performed.
79pub trait Divable<L>
80where
81    L: Language,
82{
83    /// Iterates over every element in the tree at the current root and yields
84    /// the elements for which the given `match_predicate` evaluates to
85    /// `true`.
86    fn dive<M>(&self, match_predicate: M) -> impl Iterator<Item = rowan::SyntaxElement<L>>
87    where
88        M: Fn(&rowan::SyntaxElement<L>) -> bool,
89    {
90        self.dive_with_ignore(match_predicate, |_| false)
91    }
92
93    /// Iterates over every element in the tree at the current root and yields
94    /// the elements for which the given `match_predicate` evaluates to
95    /// `true`.
96    ///
97    /// If the `ignore_predicate` evaluates to `true`, the subtree at the given
98    /// node will not be traversed.
99    fn dive_with_ignore<M, I>(
100        &self,
101        match_predicate: M,
102        ignore_predicate: I,
103    ) -> impl Iterator<Item = rowan::SyntaxElement<L>>
104    where
105        M: Fn(&rowan::SyntaxElement<L>) -> bool,
106        I: Fn(&rowan::SyntaxNode<L>) -> bool;
107}
108
109impl<D, L> Divable<L> for &D
110where
111    D: Divable<L>,
112    L: Language,
113{
114    fn dive_with_ignore<M, I>(
115        &self,
116        match_predicate: M,
117        ignore_predicate: I,
118    ) -> impl Iterator<Item = rowan::SyntaxElement<L>>
119    where
120        M: Fn(&rowan::SyntaxElement<L>) -> bool,
121        I: Fn(&rowan::SyntaxNode<L>) -> bool,
122    {
123        D::dive_with_ignore(self, match_predicate, ignore_predicate)
124    }
125}
126
127impl Divable<WorkflowDescriptionLanguage> for SyntaxNode {
128    fn dive_with_ignore<M, I>(
129        &self,
130        match_predicate: M,
131        ignore_predicate: I,
132    ) -> impl Iterator<Item = SyntaxElement>
133    where
134        M: Fn(&SyntaxElement) -> bool,
135        I: Fn(&SyntaxNode) -> bool,
136    {
137        DiveIterator::new(
138            // NOTE: this is an inexpensive clone of a red node.
139            self.clone(),
140            ignore_predicate,
141        )
142        .filter(match_predicate)
143    }
144}
145
#[cfg(test)]
mod tests {
    use std::sync::OnceLock;

    use rowan::GreenNode;

    use crate::SyntaxKind;
    use crate::SyntaxNode;
    use crate::SyntaxTree;
    use crate::dive::Divable;

    /// Returns the root of a small WDL document containing one task and one
    /// workflow.
    ///
    /// The document is parsed only once and its green tree is cached. Each
    /// call builds a fresh red root over that cached tree.
    fn get_syntax_node() -> SyntaxNode {
        static GREEN_NODE: OnceLock<GreenNode> = OnceLock::new();

        let green_node = GREEN_NODE
            .get_or_init(|| {
                let (tree, diagnostics) = SyntaxTree::parse(
                    r#"version 1.2

task hello {
    String a_private_declaration = false
}

workflow world {
    String another_private_declaration = true
}"#,
                );

                assert!(diagnostics.is_empty());
                tree.green().into()
            })
            .clone();

        SyntaxNode::new_root(green_node)
    }

    #[test]
    fn it_dives_correctly() {
        let tree = get_syntax_node();

        let mut idents = tree.dive(|element| element.kind() == SyntaxKind::Ident);

        assert_eq!(idents.next().unwrap().as_token().unwrap().text(), "hello");

        assert_eq!(
            idents.next().unwrap().as_token().unwrap().text(),
            "a_private_declaration"
        );

        assert_eq!(idents.next().unwrap().as_token().unwrap().text(), "world");

        assert_eq!(
            idents.next().unwrap().as_token().unwrap().text(),
            "another_private_declaration"
        );

        assert!(idents.next().is_none());
    }

    #[test]
    fn it_dives_with_ignores_correctly() {
        let tree = get_syntax_node();

        let mut ignored_idents = tree.dive_with_ignore(
            |element| element.kind() == SyntaxKind::Ident,
            |node| node.kind() == SyntaxKind::WorkflowDefinitionNode,
        );

        assert_eq!(
            ignored_idents.next().unwrap().as_token().unwrap().text(),
            "hello"
        );
        assert_eq!(
            ignored_idents.next().unwrap().as_token().unwrap().text(),
            "a_private_declaration"
        );

        // The idents contained in the workflow are not included in the results,
        // as we explicitly ignored any workflow definition nodes.
        assert!(ignored_idents.next().is_none());
    }

    #[test]
    fn it_does_not_yield_ignored_nodes() {
        let tree = get_syntax_node();

        // Without an ignore predicate, the single workflow definition is found.
        let workflows = tree
            .dive(|element| element.kind() == SyntaxKind::WorkflowDefinitionNode)
            .count();
        assert_eq!(workflows, 1);

        // An ignored node is pruned together with its subtree, so it is never
        // yielded even though the match predicate would otherwise accept it.
        let mut ignored = tree.dive_with_ignore(
            |element| element.kind() == SyntaxKind::WorkflowDefinitionNode,
            |node| node.kind() == SyntaxKind::WorkflowDefinitionNode,
        );
        assert!(ignored.next().is_none());
    }
}