Skip to main content

thread_rule_engine/
rule_core.rs

1// SPDX-FileCopyrightText: 2022 Herrington Darkholme <2883231+HerringtonDarkholme@users.noreply.github.com>
2// SPDX-FileCopyrightText: 2025 Knitli Inc. <knitli@knit.li>
3// SPDX-FileContributor: Adam Poulemanos <adam@knit.li>
4//
5// SPDX-License-Identifier: AGPL-3.0-or-later AND MIT
6
7use crate::DeserializeEnv;
8use crate::check_var::{CheckHint, check_rule_with_hint};
9use crate::fixer::{Fixer, FixerError, SerializableFixer};
10use crate::rule::Rule;
11use crate::rule::referent_rule::RuleRegistration;
12use crate::rule::{RuleSerializeError, SerializableRule};
13use crate::transform::{Transform, TransformError, Transformation};
14
15use serde::{Deserialize, Serialize};
16use serde_yaml::Error as YamlError;
17use thread_ast_engine::language::Language;
18use thread_ast_engine::meta_var::MetaVarEnv;
19use thread_ast_engine::{Doc, Matcher, Node};
20
21use bit_set::BitSet;
22use schemars::JsonSchema;
23use thiserror::Error;
24
25use std::borrow::Cow;
26use std::ops::Deref;
27use thread_utilities::{RapidMap, RapidSet};
28
/// Errors produced while building a [`RuleCore`] from a [`SerializableRuleCore`].
#[derive(Error, Debug)]
pub enum RuleCoreError {
    /// The YAML input could not be parsed.
    #[error("Fail to parse yaml as RuleConfig")]
    Yaml(#[from] YamlError),
    /// A rule under `utils` was rejected. Constructed explicitly with
    /// `map_err`, which is why it carries `#[source]` instead of `#[from]`.
    #[error("`utils` is not configured correctly.")]
    Utils(#[source] RuleSerializeError),
    /// The main `rule` was rejected. This is the variant `?` converts a
    /// [`RuleSerializeError`] into.
    #[error("`rule` is not configured correctly.")]
    Rule(#[from] RuleSerializeError),
    /// A rule under `constraints` was rejected. Constructed explicitly with
    /// `map_err`, like [`RuleCoreError::Utils`].
    #[error("`constraints` is not configured correctly.")]
    Constraints(#[source] RuleSerializeError),
    /// The `transform` section was rejected.
    #[error("`transform` is not configured correctly.")]
    Transform(#[from] TransformError),
    /// The `fix` section was rejected.
    #[error("`fix` pattern is invalid.")]
    Fixer(#[from] FixerError),
    /// A meta variable (first field) is used in the named section (second
    /// field) without being defined.
    // NOTE(review): not constructed in this file; presumably raised by the
    // `check_var` validation — confirm.
    #[error("Undefined meta var `{0}` used in `{1}`.")]
    UndefinedMetaVar(String, &'static str),
}
46
47type RResult<T> = std::result::Result<T, RuleCoreError>;
48
// NOTE(review): this struct derives `JsonSchema`; schemars is documented to
// copy `///` doc comments into the generated schema's `description`, so the
// doc comments below are presumably user-facing text — confirm before
// rewording them.
/// Used for global rules, rewriters, and pyo3/napi
#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
pub struct SerializableRuleCore {
    /// A rule object to find matching AST nodes
    pub rule: SerializableRule,
    /// Additional meta variables pattern to filter matching
    pub constraints: Option<RapidMap<String, SerializableRule>>,
    /// Utility rules that can be used in `matches`
    pub utils: Option<RapidMap<String, SerializableRule>>,
    /// A dictionary for metavariable manipulation. Dict key is the new variable name.
    /// Dict value is a [transformation] that specifies how meta var is processed.
    /// See [transformation doc](https://ast-grep.github.io/reference/yaml/transformation.html).
    pub transform: Option<RapidMap<String, Transformation>>,
    /// A pattern string or a FixConfig object to auto fix the issue.
    /// It can reference metavariables appeared in rule.
    /// See details in fix [object reference](https://ast-grep.github.io/reference/yaml/fix.html#fixconfig).
    pub fix: Option<SerializableFixer>,
}
67
68impl SerializableRuleCore {
69    /// This function assumes env's local is empty.
70    fn get_deserialize_env<L: Language>(
71        &self,
72        env: DeserializeEnv<L>,
73    ) -> RResult<DeserializeEnv<L>> {
74        if let Some(utils) = &self.utils {
75            let env = env.with_utils(utils).map_err(RuleCoreError::Utils)?;
76            Ok(env)
77        } else {
78            Ok(env)
79        }
80    }
81
82    fn get_constraints<L: Language>(
83        &self,
84        env: &DeserializeEnv<L>,
85    ) -> RResult<RapidMap<thread_ast_engine::meta_var::MetaVariableID, Rule>> {
86        let mut constraints = RapidMap::default();
87        let Some(serde_cons) = &self.constraints else {
88            return Ok(constraints);
89        };
90        for (key, ser) in serde_cons {
91            let constraint = env
92                .deserialize_rule(ser.clone())
93                .map_err(RuleCoreError::Constraints)?;
94            constraints.insert(std::sync::Arc::from(key.as_str()), constraint);
95        }
96        Ok(constraints)
97    }
98
99    fn get_fixer<L: Language>(&self, env: &DeserializeEnv<L>) -> RResult<Vec<Fixer>> {
100        if let Some(fix) = &self.fix {
101            let parsed = Fixer::parse(fix, env, &self.transform)?;
102            Ok(parsed)
103        } else {
104            Ok(vec![])
105        }
106    }
107
108    fn get_matcher_from_env<L: Language>(&self, env: &DeserializeEnv<L>) -> RResult<RuleCore> {
109        let rule = env.deserialize_rule(self.rule.clone())?;
110        let constraints = self.get_constraints(env)?;
111        let transform = self
112            .transform
113            .as_ref()
114            .map(|t| Transform::deserialize(t, env))
115            .transpose()?;
116        let fixer = self.get_fixer(env)?;
117        Ok(RuleCore::new(rule)
118            .with_matchers(constraints)
119            .with_registration(env.registration.clone())
120            .with_transform(transform)
121            .with_fixer(fixer))
122    }
123
124    pub fn get_matcher<L: Language>(&self, env: DeserializeEnv<L>) -> RResult<RuleCore> {
125        self.get_matcher_with_hint(env, CheckHint::Normal)
126    }
127
128    pub(crate) fn get_matcher_with_hint<L: Language>(
129        &self,
130        env: DeserializeEnv<L>,
131        hint: CheckHint,
132    ) -> RResult<RuleCore> {
133        let env = self.get_deserialize_env(env)?;
134        let ret = self.get_matcher_from_env(&env)?;
135        check_rule_with_hint(
136            &ret.rule,
137            &ret.registration,
138            &ret.constraints,
139            &ret.transform,
140            &ret.fixer,
141            hint,
142        )?;
143        Ok(ret)
144    }
145}
146
/// A ready-to-run rule: the main [`Rule`] together with its constraints,
/// optional transform, fixers, and the registration holding util rules.
#[derive(Clone, Debug)]
pub struct RuleCore {
    /// The main rule; matched first in `do_match` and exposed through `Deref`.
    rule: Rule,
    /// Additional rules keyed by meta variable id, checked after `rule` matches.
    constraints: RapidMap<thread_ast_engine::meta_var::MetaVariableID, Rule>,
    /// Result of `rule.potential_kinds()` captured in `new`; `do_match` uses it
    /// to reject nodes by kind before running the rule. `None` skips that check.
    kinds: Option<BitSet>,
    /// Transform run on the match environment once the constraints pass.
    pub(crate) transform: Option<Transform>,
    /// Fixers attached with `with_fixer`; empty by default.
    pub fixer: Vec<Fixer>,
    // this is required to hold util rule reference
    registration: RuleRegistration,
}
157
158impl RuleCore {
159    #[inline]
160    pub fn new(rule: Rule) -> Self {
161        let kinds = rule.potential_kinds();
162        Self {
163            rule,
164            kinds,
165            ..Default::default()
166        }
167    }
168
169    #[inline]
170    pub fn with_matchers(
171        self,
172        constraints: RapidMap<thread_ast_engine::meta_var::MetaVariableID, Rule>,
173    ) -> Self {
174        Self {
175            constraints,
176            ..self
177        }
178    }
179
180    #[inline]
181    pub fn with_registration(self, registration: RuleRegistration) -> Self {
182        Self {
183            registration,
184            ..self
185        }
186    }
187
188    #[inline]
189    pub fn with_transform(self, transform: Option<Transform>) -> Self {
190        Self { transform, ..self }
191    }
192
193    #[inline]
194    pub fn with_fixer(self, fixer: Vec<Fixer>) -> Self {
195        Self { fixer, ..self }
196    }
197
198    pub fn get_env<L: Language>(&self, lang: L) -> DeserializeEnv<L> {
199        DeserializeEnv {
200            lang,
201            registration: self.registration.clone(),
202        }
203    }
204    /// Get the meta variables that have real ast node matches
205    /// that is, meta vars defined in the rules and constraints
206    pub(crate) fn defined_node_vars(&self) -> RapidSet<&str> {
207        let mut ret = self.rule.defined_vars();
208        for v in self.registration.get_local_util_vars() {
209            ret.insert(v);
210        }
211        for constraint in self.constraints.values() {
212            for var in constraint.defined_vars() {
213                ret.insert(var);
214            }
215        }
216        ret
217    }
218
219    pub fn defined_vars(&self) -> RapidSet<&str> {
220        let mut ret = self.defined_node_vars();
221        if let Some(trans) = &self.transform {
222            for key in trans.keys() {
223                ret.insert(key);
224            }
225        }
226        ret
227    }
228
229    pub(crate) fn do_match<'tree, D: Doc>(
230        &self,
231        node: Node<'tree, D>,
232        env: &mut Cow<MetaVarEnv<'tree, D>>,
233        enclosing_env: Option<&MetaVarEnv<'tree, D>>,
234    ) -> Option<Node<'tree, D>> {
235        if let Some(kinds) = &self.kinds
236            && !kinds.contains(node.kind_id().into())
237        {
238            return None;
239        }
240        let ret = self.rule.match_node_with_env(node, env)?;
241        if !env.to_mut().match_constraints(&self.constraints) {
242            return None;
243        }
244        if let Some(trans) = &self.transform {
245            let rewriters = self.registration.get_rewriters();
246            let env = env.to_mut();
247            if let Some(enclosing) = enclosing_env {
248                trans.apply_transform(env, rewriters, enclosing);
249            } else {
250                let enclosing = env.clone();
251                trans.apply_transform(env, rewriters, &enclosing);
252            };
253        }
254        Some(ret)
255    }
256}
257impl Deref for RuleCore {
258    type Target = Rule;
259    fn deref(&self) -> &Self::Target {
260        &self.rule
261    }
262}
263
264impl Default for RuleCore {
265    #[inline]
266    fn default() -> Self {
267        Self {
268            rule: Rule::default(),
269            constraints: RapidMap::default(),
270            kinds: None,
271            transform: None,
272            fixer: vec![],
273            registration: RuleRegistration::default(),
274        }
275    }
276}
277
278impl Matcher for RuleCore {
279    fn match_node_with_env<'tree, D: Doc>(
280        &self,
281        node: Node<'tree, D>,
282        env: &mut Cow<MetaVarEnv<'tree, D>>,
283    ) -> Option<Node<'tree, D>> {
284        self.do_match(node, env, None)
285    }
286
287    fn potential_kinds(&self) -> Option<BitSet> {
288        self.rule.potential_kinds()
289    }
290}
291
#[cfg(test)]
mod test {
    use super::*;
    use crate::from_str;
    use crate::rule::referent_rule::{ReferentRule, ReferentRuleError};
    use crate::test::TypeScript;
    use thread_ast_engine::matcher::{Pattern, RegexMatcher};
    use thread_ast_engine::tree_sitter::LanguageExt;

    /// Parses `src` as a `SerializableRuleCore` and builds its matcher in a
    /// fresh TSX environment. Panics if `src` is not valid config syntax.
    fn get_matcher(src: &str) -> RResult<RuleCore> {
        let env = DeserializeEnv::new(TypeScript::Tsx);
        let rule: SerializableRuleCore = from_str(src).expect("should work");
        rule.get_matcher(env)
    }

    /// An invalid main rule is reported as `RuleCoreError::Rule`.
    #[test]
    fn test_rule_error() {
        let ret = get_matcher(r"rule: {kind: bbb}");
        assert!(matches!(ret, Err(RuleCoreError::Rule(_))));
    }

    /// An invalid util rule is reported as `RuleCoreError::Utils`.
    #[test]
    fn test_utils_error() {
        let ret = get_matcher(
            r"
rule: { kind: number }
utils: { testa: {kind: bbb} }
  ",
        );
        assert!(matches!(ret, Err(RuleCoreError::Utils(_))));
    }

    /// Referencing an unknown util in `matches` names that util in the error.
    #[test]
    fn test_undefined_utils_error() {
        let ret = get_matcher(r"rule: { kind: number, matches: undefined-util }");
        match ret {
            Err(RuleCoreError::Rule(RuleSerializeError::MatchesReference(
                ReferentRuleError::UndefinedUtil(name),
            ))) => {
                assert_eq!(name, "undefined-util");
            }
            // Show what actually came back so a failure is diagnosable.
            other => panic!("wrong error: {other:?}"),
        }
    }

    /// Transforms that depend on each other are rejected as cyclic.
    #[test]
    fn test_cyclic_transform_error() {
        let ret = get_matcher(
            r"
rule: { kind: number }
transform:
  A: {substring: {source: $B}}
  B: {substring: {source: $A}}",
        );
        assert!(matches!(
            ret,
            Err(RuleCoreError::Transform(TransformError::Cyclic(_)))
        ));
    }

    /// Utils registered through the rule are visible to `ReferentRule`s that
    /// share the same registration.
    #[test]
    fn test_rule_reg_with_utils() {
        let env = DeserializeEnv::new(TypeScript::Tsx);
        let ser_rule: SerializableRuleCore =
            from_str("{rule: {matches: test}, utils: {test: {kind: number}} }")
                .expect("should deser");
        let rule = ReferentRule::try_new("test".into(), &env.registration).expect("should work");
        let not = ReferentRule::try_new("test2".into(), &env.registration).expect("should work");
        let matcher = ser_rule.get_matcher(env).expect("should parse");
        let grep = TypeScript::Tsx.ast_grep("a = 123");
        assert!(grep.root().find(&matcher).is_some());
        assert!(grep.root().find(&rule).is_some());
        assert!(grep.root().find(&not).is_none());
        let grep = TypeScript::Tsx.ast_grep("a = '123'");
        assert!(grep.root().find(&matcher).is_none());
        assert!(grep.root().find(&rule).is_none());
        assert!(grep.root().find(&not).is_none());
    }

    /// A constraint on `$A` filters matches of the main pattern.
    #[test]
    fn test_rule_with_constraints() {
        let mut constraints = RapidMap::default();
        constraints.insert(
            std::sync::Arc::from("A"),
            Rule::Regex(RegexMatcher::try_new("a").unwrap()),
        );
        let rule = RuleCore::new(Rule::Pattern(Pattern::new("$A", &TypeScript::Tsx)))
            .with_matchers(constraints);
        let grep = TypeScript::Tsx.ast_grep("a");
        assert!(grep.root().find(&rule).is_some());
        let grep = TypeScript::Tsx.ast_grep("bbb");
        assert!(grep.root().find(&rule).is_none());
    }

    /// A constraint can refer to a meta variable bound by the main rule.
    #[test]
    fn test_constraints_inheriting_env() {
        let env = DeserializeEnv::new(TypeScript::Tsx);
        let ser_rule: SerializableRuleCore =
            from_str("{rule: {pattern: $A = $B}, constraints: {A: {pattern: $B}} }")
                .expect("should deser");
        let matcher = ser_rule.get_matcher(env).expect("should parse");
        let grep = TypeScript::Tsx.ast_grep("a = a");
        assert!(grep.root().find(&matcher).is_some());
        let grep = TypeScript::Tsx.ast_grep("a = b");
        assert!(grep.root().find(&matcher).is_none());
    }

    /// Meta variables bound inside a constraint end up in the match env.
    #[test]
    fn test_constraints_writing_to_env() {
        let env = DeserializeEnv::new(TypeScript::Tsx);
        let ser_rule: SerializableRuleCore =
            from_str("{rule: {pattern: $A = $B}, constraints: {B: {pattern: $C + $D}} }")
                .expect("should deser");
        let matcher = ser_rule.get_matcher(env).expect("should parse");
        let grep = TypeScript::Tsx.ast_grep("a = a");
        assert!(grep.root().find(&matcher).is_none());
        let grep = TypeScript::Tsx.ast_grep("a = 1 + 2");
        let nm = grep.root().find(&matcher).expect("should match");
        let env = nm.get_env();
        let matched = env.get_match("C").expect("should match C").text();
        assert_eq!(matched, "1");
        let matched = env.get_match("D").expect("should match D").text();
        assert_eq!(matched, "2");
    }

    /// Builds the rewriter used by `test_rewriter_writing_to_env`, returned
    /// with the id it should be registered under.
    fn get_rewriters() -> (&'static str, RuleCore) {
        // NOTE: initializing a fresh DeserializeEnv here is not 100% correct:
        // it does not inherit global rules or local rules
        let env = DeserializeEnv::new(TypeScript::Tsx);
        let rewriter: SerializableRuleCore =
            from_str("{rule: {kind: number, pattern: $REWRITE}, fix: yjsnp}")
                .expect("should parse");
        let rewriter = rewriter.get_matcher(env).expect("should work");
        ("re", rewriter)
    }

    /// A `rewrite` transform stores its output under the transform key and
    /// does not leak the rewriter's own meta variables into the match env.
    #[test]
    fn test_rewriter_writing_to_env() {
        let (id, rewriter) = get_rewriters();
        let env = DeserializeEnv::new(TypeScript::Tsx);
        env.registration.insert_rewriter(id, rewriter);
        let ser_rule: SerializableRuleCore = from_str(
            r"
rule: {pattern: $A = $B}
transform:
  C:
    rewrite:
      source: $B
      rewriters: [re]",
        )
        .expect("should deser");
        let matcher = ser_rule.get_matcher(env).expect("should parse");
        let grep = TypeScript::Tsx.ast_grep("a = 1 + 2");
        let nm = grep.root().find(&matcher).expect("should match");
        let env = nm.get_env();
        let matched = env.get_match("B").expect("should match").text();
        assert_eq!(matched, "1 + 2");
        let matched = env.get_match("A").expect("should match").text();
        assert_eq!(matched, "a");
        let transformed = env.get_transformed("C").expect("should transform");
        assert_eq!(String::from_utf8_lossy(transformed), "yjsnp + yjsnp");
        assert!(env.get_match("REWRITE").is_none());

        let grep = TypeScript::Tsx.ast_grep("a = a");
        let nm = grep.root().find(&matcher).expect("should match");
        let env = nm.get_env();
        let matched = env.get_match("B").expect("should match").text();
        assert_eq!(matched, "a");
        let transformed = env.get_transformed("C").expect("should transform");
        assert_eq!(String::from_utf8_lossy(transformed), "a");
    }
}