Skip to main content

thread_ast_engine/matchers/
pattern.rs

1// SPDX-FileCopyrightText: 2022 Herrington Darkholme <2883231+HerringtonDarkholme@users.noreply.github.com>
2// SPDX-FileCopyrightText: 2025 Knitli Inc. <knitli@knit.li>
3// SPDX-FileContributor: Adam Poulemanos <adam@knit.li>
4//
5// SPDX-License-Identifier: AGPL-3.0-or-later AND MIT
6
7use super::kind::{KindMatcher, kind_utils};
8use super::matcher::Matcher;
9pub use super::types::{MatchStrictness, Pattern, PatternBuilder, PatternError, PatternNode};
10use crate::language::Language;
11use crate::match_tree::{match_end_non_recursive, match_node_non_recursive};
12use crate::meta_var::{MetaVarEnv, MetaVariable};
13use crate::source::SgNode;
14use crate::{Doc, Node, Root};
15
16use bit_set::BitSet;
17use std::borrow::Cow;
18use thread_utilities::RapidSet;
19
20impl PatternBuilder<'_> {
21    pub fn build<D, F>(&self, parse: F) -> Result<Pattern, PatternError>
22    where
23        F: FnOnce(&str) -> Result<D, String>,
24        D: Doc,
25    {
26        let doc = parse(&self.src).map_err(PatternError::Parse)?;
27        let root = Root::doc(doc);
28        if let Some(selector) = self.selector {
29            self.contextual(&root, selector)
30        } else {
31            self.single(&root)
32        }
33    }
34    fn single<D: Doc>(&self, root: &Root<D>) -> Result<Pattern, PatternError> {
35        let goal = root.root();
36        if goal.children().len() == 0 {
37            return Err(PatternError::NoContent(self.src.to_string()));
38        }
39        if !is_single_node(&goal.inner) {
40            return Err(PatternError::MultipleNode(self.src.to_string()));
41        }
42        let node = Pattern::single_matcher(root);
43        Ok(Pattern::from(node))
44    }
45
46    fn contextual<D: Doc>(&self, root: &Root<D>, selector: &str) -> Result<Pattern, PatternError> {
47        let goal = root.root();
48        let kind_matcher = KindMatcher::try_new(selector, root.lang())?;
49        let Some(node) = goal.find(&kind_matcher) else {
50            return Err(PatternError::NoSelectorInContext {
51                context: self.src.to_string(),
52                selector: selector.into(),
53            });
54        };
55        Ok(Pattern {
56            root_kind: Some(node.kind_id()),
57            node: convert_node_to_pattern(node.get_node()),
58            strictness: MatchStrictness::Smart,
59        })
60    }
61}
62
63impl PatternNode {
64    // for skipping trivial nodes in goal after ellipsis
65    #[must_use]
66    pub const fn is_trivial(&self) -> bool {
67        match self {
68            Self::Terminal { is_named, .. } => !*is_named,
69            _ => false,
70        }
71    }
72
73    #[inline]
74    #[must_use]
75    pub fn fixed_string(&self) -> Cow<'_, str> {
76        match &self {
77            Self::Terminal { text, .. } => Cow::Borrowed(text),
78            Self::MetaVar { .. } => Cow::Borrowed(""),
79            Self::Internal { children, .. } => children.iter().map(|n| n.fixed_string()).fold(
80                Cow::Borrowed(""),
81                |longest, curr| {
82                    if longest.len() >= curr.len() {
83                        longest
84                    } else {
85                        curr
86                    }
87                },
88            ),
89        }
90    }
91}
92impl<'r, D: Doc> From<Node<'r, D>> for PatternNode {
93    fn from(node: Node<'r, D>) -> Self {
94        convert_node_to_pattern(&node)
95    }
96}
97
98impl<'r, D: Doc> From<Node<'r, D>> for Pattern {
99    fn from(node: Node<'r, D>) -> Self {
100        Self {
101            node: convert_node_to_pattern(&node),
102            root_kind: None,
103            strictness: MatchStrictness::Smart,
104        }
105    }
106}
107
108fn convert_node_to_pattern<D: Doc>(node: &Node<'_, D>) -> PatternNode {
109    if let Some(meta_var) = extract_var_from_node(node) {
110        PatternNode::MetaVar { meta_var }
111    } else if node.is_leaf() {
112        PatternNode::Terminal {
113            text: node.text().to_string(),
114            is_named: node.is_named(),
115            kind_id: node.kind_id(),
116        }
117    } else {
118        // Pre-allocate vector with estimated capacity to reduce allocations
119        let child_count = node.children().count();
120        let mut children = Vec::with_capacity(child_count);
121
122        for child in node.children() {
123            if !child.is_missing() {
124                children.push(PatternNode::from(child));
125            }
126        }
127
128        PatternNode::Internal {
129            kind_id: node.kind_id(),
130            children,
131        }
132    }
133}
134
135fn extract_var_from_node<D: Doc>(goal: &Node<'_, D>) -> Option<MetaVariable> {
136    let key = goal.text();
137    goal.lang().extract_meta_var(&key)
138}
139
140#[inline]
141fn is_single_node<'r, N: SgNode<'r>>(n: &N) -> bool {
142    match n.children().len() {
143        1 => true,
144        2 => {
145            let c = n.child(1).expect("second child must exist");
146            // some language will have weird empty syntax node at the end
147            // see golang's `$A = 0` pattern test case
148            c.is_missing() || c.kind().is_empty()
149        }
150        _ => false,
151    }
152}
153impl Pattern {
154    #[must_use]
155    pub const fn has_error(&self) -> bool {
156        let kind = match &self.node {
157            PatternNode::Terminal { kind_id, .. } | PatternNode::Internal { kind_id, .. } => {
158                *kind_id
159            }
160            PatternNode::MetaVar { .. } => match self.root_kind {
161                Some(k) => k,
162                None => return false,
163            },
164        };
165        kind_utils::is_error_kind(kind)
166    }
167
168    #[must_use]
169    pub fn fixed_string(&self) -> Cow<'_, str> {
170        self.node.fixed_string()
171    }
172
173    /// Get all defined variables in the pattern.
174    /// Used for validating rules and report undefined variables.
175    #[must_use]
176    pub fn defined_vars(&self) -> RapidSet<&str> {
177        let mut vars = RapidSet::default();
178        collect_vars(&self.node, &mut vars);
179        vars
180    }
181}
182
183fn meta_var_name(meta_var: &MetaVariable) -> Option<&str> {
184    use MetaVariable as MV;
185    match meta_var {
186        MV::Capture(name, _) | MV::MultiCapture(name) => Some(name),
187        MV::Dropped(_) | MV::Multiple => None,
188    }
189}
190
191fn collect_vars<'p>(p: &'p PatternNode, vars: &mut RapidSet<&'p str>) {
192    match p {
193        PatternNode::MetaVar { meta_var, .. } => {
194            if let Some(name) = meta_var_name(meta_var) {
195                vars.insert(name);
196            }
197        }
198        PatternNode::Terminal { .. } => {
199            // collect nothing for terminal nodes!
200        }
201        PatternNode::Internal { children, .. } => {
202            for c in children {
203                collect_vars(c, vars);
204            }
205        }
206    }
207}
208
209impl Pattern {
210    pub fn try_new<L: Language>(src: &str, lang: &L) -> Result<Self, PatternError> {
211        let processed = lang.pre_process_pattern(src);
212        let builder = PatternBuilder {
213            selector: None,
214            src: processed,
215        };
216        lang.build_pattern(&builder)
217    }
218
219    pub fn new<L: Language>(src: &str, lang: &L) -> Self {
220        Self::try_new(src, lang).unwrap()
221    }
222
223    #[must_use]
224    pub const fn with_strictness(mut self, strictness: MatchStrictness) -> Self {
225        self.strictness = strictness;
226        self
227    }
228
229    pub fn contextual<L: Language>(
230        context: &str,
231        selector: &str,
232        lang: &L,
233    ) -> Result<Self, PatternError> {
234        let processed = lang.pre_process_pattern(context);
235        let builder = PatternBuilder {
236            selector: Some(selector),
237            src: processed,
238        };
239        lang.build_pattern(&builder)
240    }
241    fn single_matcher<D: Doc>(root: &Root<D>) -> Node<'_, D> {
242        // debug_assert!(matches!(self.style, PatternStyle::Single));
243        let node = root.root();
244        let mut inner = node.inner;
245        while is_single_node(&inner) {
246            inner = inner.child(0).unwrap();
247        }
248        Node { inner, root }
249    }
250}
251
252impl Matcher for Pattern {
253    fn match_node_with_env<'tree, D: Doc>(
254        &self,
255        node: Node<'tree, D>,
256        env: &mut Cow<MetaVarEnv<'tree, D>>,
257    ) -> Option<Node<'tree, D>> {
258        if let Some(k) = self.root_kind
259            && node.kind_id() != k
260        {
261            return None;
262        }
263        // do not pollute the env if pattern does not match
264        let mut may_write = Cow::Borrowed(env.as_ref());
265        let node = match_node_non_recursive(self, node, &mut may_write)?;
266        if let Cow::Owned(map) = may_write {
267            // only change env when pattern matches
268            *env = Cow::Owned(map);
269        }
270        Some(node)
271    }
272
273    fn potential_kinds(&self) -> Option<bit_set::BitSet> {
274        let kind = match self.node {
275            PatternNode::Terminal { kind_id, .. } => kind_id,
276            PatternNode::MetaVar { .. } => self.root_kind?,
277            PatternNode::Internal { kind_id, .. } => {
278                if kind_utils::is_error_kind(kind_id) {
279                    // error can match any kind
280                    return None;
281                }
282                kind_id
283            }
284        };
285
286        let mut kinds = BitSet::new();
287        kinds.insert(kind.into());
288        Some(kinds)
289    }
290
291    fn get_match_len<D: Doc>(&self, node: Node<'_, D>) -> Option<usize> {
292        let start = node.range().start;
293        let end = match_end_non_recursive(self, &node)?;
294        Some(end - start)
295    }
296}
297impl std::fmt::Debug for PatternNode {
298    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
299        match self {
300            Self::MetaVar { meta_var, .. } => write!(f, "{meta_var:?}"),
301            Self::Terminal { text, .. } => write!(f, "{text}"),
302            Self::Internal { children, .. } => write!(f, "{children:?}"),
303        }
304    }
305}
306
307impl std::fmt::Debug for Pattern {
308    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309        write!(f, "{:?}", self.node)
310    }
311}
312
#[cfg(test)]
mod test {
    use super::*;
    use crate::language::Tsx;
    use crate::matcher::MatcherExt;
    use crate::meta_var::MetaVarEnv;
    use crate::tree_sitter::StrDoc;
    use thread_utilities::RapidMap;

    fn pattern_node(s: &str) -> Root<StrDoc<Tsx>> {
        Root::str(s, Tsx)
    }

    fn test_match(s1: &str, s2: &str) {
        let pattern = Pattern::new(s1, &Tsx);
        let cand = pattern_node(s2);
        let cand = cand.root();
        assert!(
            pattern.find_node(cand.clone()).is_some(),
            "goal: {:?}, candidate: {}",
            pattern,
            cand.get_inner_node().to_sexp(),
        );
    }
    fn test_non_match(s1: &str, s2: &str) {
        let pattern = Pattern::new(s1, &Tsx);
        let cand = pattern_node(s2);
        let cand = cand.root();
        assert!(
            pattern.find_node(cand.clone()).is_none(),
            "goal: {:?}, candidate: {}",
            pattern,
            cand.get_inner_node().to_sexp(),
        );
    }

    #[test]
    fn test_meta_variable() {
        test_match("const a = $VALUE", "const a = 123");
        test_match("const $VARIABLE = $VALUE", "const a = 123");
        // A different binding name and a non-literal value must bind both meta-variables too.
        test_match("const $VARIABLE = $VALUE", "const b = foo()");
    }

    #[test]
    fn test_whitespace() {
        test_match("function t() { }", "function t() {}");
        test_match("function t() {}", "function t() {  }");
    }

    fn match_env(goal_str: &str, cand: &str) -> RapidMap<String, String> {
        let pattern = Pattern::new(goal_str, &Tsx);
        let cand = pattern_node(cand);
        let cand = cand.root();
        let nm = pattern.find_node(cand).unwrap();
        RapidMap::from(nm.get_env().clone())
    }

    #[test]
    fn test_meta_variable_env() {
        let env = match_env("const a = $VALUE", "const a = 123");
        assert_eq!(env["VALUE"], "123");
    }

    #[test]
    fn test_pattern_should_not_pollute_env() {
        // gh issue #1164
        let pattern = Pattern::new("const $A = 114", &Tsx);
        let cand = pattern_node("const a = 514");
        let cand = cand.root().child(0).unwrap();
        let map = MetaVarEnv::new();
        let mut env = Cow::Borrowed(&map);
        let nm = pattern.match_node_with_env(cand, &mut env);
        assert!(nm.is_none());
        assert!(env.get_match("A").is_none());
        assert!(map.get_match("A").is_none());
    }

    #[test]
    fn test_match_non_atomic() {
        let env = match_env("const a = $VALUE", "const a = 5 + 3");
        assert_eq!(env["VALUE"], "5 + 3");
    }

    #[test]
    fn test_class_assignment() {
        test_match("class $C { $MEMBER = $VAL}", "class A {a = 123}");
        test_non_match("class $C { $MEMBER = $VAL; b = 123; }", "class A {a = 123}");
        // test_match("a = 123", "class A {a = 123}");
        test_non_match("a = 123", "class B {b = 123}");
    }

    #[test]
    fn test_return() {
        test_match("$A($B)", "return test(123)");
    }

    #[test]
    fn test_contextual_pattern() {
        let pattern = Pattern::contextual("class A { $F = $I }", "public_field_definition", &Tsx)
            .expect("test");
        let cand = pattern_node("class B { b = 123 }");
        assert!(pattern.find_node(cand.root()).is_some());
        let cand = pattern_node("let b = 123");
        assert!(pattern.find_node(cand.root()).is_none());
    }

    #[test]
    fn test_contextual_match_with_env() {
        let pattern = Pattern::contextual("class A { $F = $I }", "public_field_definition", &Tsx)
            .expect("test");
        let cand = pattern_node("class B { b = 123 }");
        let nm = pattern.find_node(cand.root()).expect("test");
        let env = nm.get_env();
        let env = RapidMap::from(env.clone());
        assert_eq!(env["F"], "b");
        assert_eq!(env["I"], "123");
    }

    #[test]
    fn test_contextual_unmatch_with_env() {
        let pattern = Pattern::contextual("class A { $F = $I }", "public_field_definition", &Tsx)
            .expect("test");
        let cand = pattern_node("let b = 123");
        let nm = pattern.find_node(cand.root());
        assert!(nm.is_none());
    }

    fn get_kind(kind_str: &str) -> usize {
        Tsx.kind_to_id(kind_str).into()
    }

    #[test]
    fn test_pattern_potential_kinds() {
        let pattern = Pattern::new("const a = 1", &Tsx);
        let kind = get_kind("lexical_declaration");
        let kinds = pattern.potential_kinds().expect("should have kinds");
        assert_eq!(kinds.len(), 1);
        assert!(kinds.contains(kind));
    }

    #[test]
    fn test_pattern_with_non_root_meta_var() {
        let pattern = Pattern::new("const $A = $B", &Tsx);
        let kind = get_kind("lexical_declaration");
        let kinds = pattern.potential_kinds().expect("should have kinds");
        assert_eq!(kinds.len(), 1);
        assert!(kinds.contains(kind));
    }

    #[test]
    fn test_bare_wildcard() {
        let pattern = Pattern::new("$A", &Tsx);
        // wildcard should match anything, so kinds should be None
        assert!(pattern.potential_kinds().is_none());
    }

    #[test]
    fn test_contextual_potential_kinds() {
        let pattern = Pattern::contextual("class A { $F = $I }", "public_field_definition", &Tsx)
            .expect("test");
        let kind = get_kind("public_field_definition");
        let kinds = pattern.potential_kinds().expect("should have kinds");
        assert_eq!(kinds.len(), 1);
        assert!(kinds.contains(kind));
    }

    #[test]
    fn test_contextual_wildcard() {
        let pattern =
            Pattern::contextual("class A { $F }", "property_identifier", &Tsx).expect("test");
        let kind = get_kind("property_identifier");
        let kinds = pattern.potential_kinds().expect("should have kinds");
        assert_eq!(kinds.len(), 1);
        assert!(kinds.contains(kind));
    }

    #[test]
    #[ignore = "multi-node patterns not yet implemented"]
    fn test_multi_node_pattern() {
        let pattern = Pattern::new("a;b;c;", &Tsx);
        let kinds = pattern.potential_kinds().expect("should have kinds");
        assert_eq!(kinds.len(), 1);
        test_match("a;b;c", "a;b;c;");
    }

    #[test]
    #[ignore = "multi-node patterns not yet implemented"]
    fn test_multi_node_meta_var() {
        let env = match_env("a;$B;c", "a;b;c");
        assert_eq!(env["B"], "b");
        let env = match_env("a;$B;c", "a;1+2+3;c");
        assert_eq!(env["B"], "1+2+3");
    }

    #[test]
    #[ignore = "struct layout is compiler and platform specific"]
    fn test_pattern_size() {
        assert_eq!(std::mem::size_of::<Pattern>(), 40);
    }

    #[test]
    fn test_error_kind() {
        let ret = Pattern::contextual("a", "property_identifier", &Tsx);
        assert!(ret.is_err());
        let ret = Pattern::new("123+", &Tsx);
        assert!(ret.has_error());
    }

    #[test]
    fn test_bare_wildcard_in_context() {
        let pattern =
            Pattern::contextual("class A { $F }", "property_identifier", &Tsx).expect("test");
        let cand = pattern_node("let b = 123");
        // it should not match
        assert!(pattern.find_node(cand.root()).is_none());
    }

    #[test]
    fn test_pattern_fixed_string() {
        let pattern = Pattern::new("class A { $F }", &Tsx);
        assert_eq!(pattern.fixed_string(), "class");
        let pattern =
            Pattern::contextual("class A { $F }", "property_identifier", &Tsx).expect("test");
        assert!(pattern.fixed_string().is_empty());
    }

    #[test]
    fn test_pattern_error() {
        let pattern = Pattern::try_new("", &Tsx);
        assert!(matches!(pattern, Err(PatternError::NoContent(_))));
        let pattern = Pattern::try_new("12  3344", &Tsx);
        assert!(matches!(pattern, Err(PatternError::MultipleNode(_))));
    }

    #[test]
    fn test_debug_pattern() {
        let pattern = Pattern::new("var $A = 1", &Tsx);
        assert_eq!(
            format!("{pattern:?}"),
            "[var, [Capture(\"A\", true), =, 1]]"
        );
    }

    fn defined_vars(s: &str) -> Vec<String> {
        let pattern = Pattern::new(s, &Tsx);
        let mut vars: Vec<_> = pattern
            .defined_vars()
            .into_iter()
            .map(String::from)
            .collect();
        vars.sort();
        vars
    }

    #[test]
    fn test_extract_meta_var_from_pattern() {
        let vars = defined_vars("var $A = 1");
        assert_eq!(vars, ["A"]);
    }

    #[test]
    fn test_extract_complex_meta_var() {
        let vars = defined_vars("function $FUNC($$$ARGS): $RET { $$$BODY }");
        assert_eq!(vars, ["ARGS", "BODY", "FUNC", "RET"]);
    }

    #[test]
    fn test_extract_duplicate_meta_var() {
        let vars = defined_vars("var $A = $A");
        assert_eq!(vars, ["A"]);
    }

    #[test]
    fn test_contextual_pattern_vars() {
        let pattern =
            Pattern::contextual("<div ref={$A}/>", "jsx_attribute", &Tsx).expect("correct");
        assert_eq!(pattern.defined_vars(), std::iter::once("A").collect());
    }

    #[test]
    fn test_gh_1087() {
        test_match("($P) => $F($P)", "(x) => bar(x)");
    }
}