Skip to main content

thread_rule_engine/rule/
deserialize_env.rs

1// SPDX-FileCopyrightText: 2022 Herrington Darkholme <2883231+HerringtonDarkholme@users.noreply.github.com>
2// SPDX-FileCopyrightText: 2025 Knitli Inc. <knitli@knit.li>
3// SPDX-FileContributor: Adam Poulemanos <adam@knit.li>
4//
5// SPDX-License-Identifier: AGPL-3.0-or-later AND MIT
6
7use super::referent_rule::{GlobalRules, ReferentRuleError, RuleRegistration};
8use crate::check_var::CheckHint;
9use crate::maybe::Maybe;
10use crate::rule::{self, Rule, RuleSerializeError, SerializableRule};
11use crate::rule_core::{RuleCoreError, SerializableRuleCore};
12use crate::transform::Trans;
13use thread_ast_engine::meta_var::MetaVariable;
14
15use thread_ast_engine::language::Language;
16
17use schemars::JsonSchema;
18use serde::{Deserialize, Serialize};
19
20use thread_utilities::RapidMap;
21
22#[derive(Serialize, Deserialize, Clone, JsonSchema)]
23pub struct SerializableGlobalRule<L: Language> {
24    #[serde(flatten)]
25    pub core: SerializableRuleCore,
26    /// Unique, descriptive identifier, e.g., no-unused-variable
27    pub id: String,
28    /// Specify the language to parse and the file extension to include in matching.
29    pub language: L,
30}
31
32fn into_map<L: Language>(
33    rules: Vec<SerializableGlobalRule<L>>,
34) -> RapidMap<String, (L, SerializableRuleCore)> {
35    rules
36        .into_iter()
37        .map(|r| (r.id, (r.language, r.core)))
38        .collect()
39}
40
/// Outcome of a topological ordering; the error carries the id at which a
/// dependency cycle was detected.
type OrderResult<T> = std::result::Result<T, String>;
42
43/// A struct to store information to deserialize rules.
44#[derive(Clone, Debug)]
45pub struct DeserializeEnv<L: Language> {
46    /// registration for global utility rules and local utility rules.
47    pub(crate) registration: RuleRegistration,
48    /// current rules' language
49    pub(crate) lang: L,
50}
51
52trait DependentRule: Sized {
53    fn visit_dependency<'a>(&'a self, sorter: &mut TopologicalSort<'a, Self>) -> OrderResult<()>;
54}
55
56impl DependentRule for SerializableRule {
57    fn visit_dependency<'a>(&'a self, sorter: &mut TopologicalSort<'a, Self>) -> OrderResult<()> {
58        visit_dependent_rule_ids(self, sorter)
59    }
60}
61
62impl<L: Language> DependentRule for (L, SerializableRuleCore) {
63    fn visit_dependency<'a>(&'a self, sorter: &mut TopologicalSort<'a, Self>) -> OrderResult<()> {
64        visit_dependent_rule_ids(&self.1.rule, sorter)
65    }
66}
67
68impl DependentRule for Trans<MetaVariable> {
69    fn visit_dependency<'a>(&'a self, sorter: &mut TopologicalSort<'a, Self>) -> OrderResult<()> {
70        let used_var = self.used_vars();
71        sorter.visit(used_var)
72    }
73}
74
75/// A struct to topological sort rules
76/// it is used to report cyclic dependency errors in rules/transformation
77struct TopologicalSort<'a, T: DependentRule> {
78    maps: &'a RapidMap<String, T>,
79    order: Vec<&'a str>,
80    // bool stands for if the rule has completed visit
81    seen: RapidMap<&'a str, bool>,
82}
83
84impl<'a, T: DependentRule> TopologicalSort<'a, T> {
85    fn get_order(maps: &RapidMap<String, T>) -> OrderResult<Vec<&str>> {
86        let mut top_sort = TopologicalSort::new(maps);
87        for key in maps.keys() {
88            top_sort.visit(key)?;
89        }
90        Ok(top_sort.order)
91    }
92
93    fn new(maps: &'a RapidMap<String, T>) -> Self {
94        Self {
95            maps,
96            order: vec![],
97            seen: RapidMap::default(),
98        }
99    }
100
101    fn visit(&mut self, key: &'a str) -> OrderResult<()> {
102        if let Some(&completed) = self.seen.get(key) {
103            // if the rule has been seen but not completed
104            // it means we have a cyclic dependency and report an error here
105            return if completed {
106                Ok(())
107            } else {
108                Err(key.to_string())
109            };
110        }
111        let Some(item) = self.maps.get(key) else {
112            // key can be found elsewhere
113            // e.g. if key is rule_id
114            // if rule_id not found in global, it can be a local rule
115            // if rule_id not found in local, it can be a global rule
116            // TODO: add check here and return Err if rule not found
117            return Ok(());
118        };
119        // mark the id as seen but not completed
120        self.seen.insert(key, false);
121        item.visit_dependency(self)?;
122        // mark the id as seen and completed
123        self.seen.insert(key, true);
124        self.order.push(key);
125        Ok(())
126    }
127}
128
129fn visit_dependent_rule_ids<'a, T: DependentRule>(
130    rule: &'a SerializableRule,
131    sort: &mut TopologicalSort<'a, T>,
132) -> OrderResult<()> {
133    // handle all composite rule here
134    if let Maybe::Present(matches) = &rule.matches {
135        sort.visit(matches)?;
136    }
137    if let Maybe::Present(all) = &rule.all {
138        for sub in all {
139            visit_dependent_rule_ids(sub, sort)?;
140        }
141    }
142    if let Maybe::Present(any) = &rule.any {
143        for sub in any {
144            visit_dependent_rule_ids(sub, sort)?;
145        }
146    }
147    if let Maybe::Present(not) = &rule.not {
148        visit_dependent_rule_ids(not, sort)?;
149    }
150    Ok(())
151}
152
153impl<L: Language> DeserializeEnv<L> {
154    pub fn new(lang: L) -> Self {
155        Self {
156            registration: Default::default(),
157            lang,
158        }
159    }
160
161    /// register utils rule in the DeserializeEnv for later usage.
162    /// N.B. This function will manage the util registration order
163    /// by their dependency. `potential_kinds` need ordered insertion.
164    pub fn with_utils(
165        self,
166        utils: &RapidMap<String, SerializableRule>,
167    ) -> Result<Self, RuleSerializeError> {
168        let order = TopologicalSort::get_order(utils)
169            .map_err(ReferentRuleError::CyclicRule)
170            .map_err(RuleSerializeError::MatchesReference)?;
171        for id in order {
172            let rule = utils.get(id).expect("must exist");
173            let rule = self.deserialize_rule(rule.clone())?;
174            self.registration.insert_local(id, rule)?;
175        }
176        Ok(self)
177    }
178
179    /// register global utils rule discovered in the config.
180    pub fn parse_global_utils(
181        utils: Vec<SerializableGlobalRule<L>>,
182    ) -> Result<GlobalRules, RuleCoreError> {
183        let registration = GlobalRules::default();
184        let utils = into_map(utils);
185        let order = TopologicalSort::get_order(&utils)
186            .map_err(ReferentRuleError::CyclicRule)
187            .map_err(RuleSerializeError::from)?;
188        for id in order {
189            let (lang, core) = utils.get(id).expect("must exist");
190            let env = DeserializeEnv::new(lang.clone()).with_globals(&registration);
191            let matcher = core.get_matcher_with_hint(env, CheckHint::Global)?;
192            registration
193                .insert(id, matcher)
194                .map_err(RuleSerializeError::MatchesReference)?;
195        }
196        Ok(registration)
197    }
198
199    pub fn deserialize_rule(
200        &self,
201        serialized: SerializableRule,
202    ) -> Result<Rule, RuleSerializeError> {
203        rule::deserialize_rule(serialized, self)
204    }
205
206    pub(crate) fn get_transform_order<'a>(
207        &self,
208        trans: &'a RapidMap<String, Trans<MetaVariable>>,
209    ) -> Result<Vec<&'a str>, String> {
210        TopologicalSort::get_order(trans)
211    }
212
213    pub fn with_globals(self, globals: &GlobalRules) -> Self {
214        Self {
215            registration: RuleRegistration::from_globals(globals),
216            lang: self.lang,
217        }
218    }
219}
220
#[cfg(test)]
mod test {
    use super::*;
    use crate::test::TypeScript;
    use crate::{Rule, from_str};
    use thread_ast_engine::Matcher;
    use thread_ast_engine::tree_sitter::LanguageExt;

    type Result<T> = std::result::Result<T, RuleSerializeError>;

    /// Builds a rule that references a local util, which in turn references another
    /// local util. The env is returned alongside the rule because the rule only holds
    /// a weak reference into the env's registration.
    fn get_dependent_utils() -> Result<(Rule, DeserializeEnv<TypeScript>)> {
        let utils = from_str(
            "
accessor-name:
  matches: member-name
  regex: whatever
member-name:
  kind: identifier
",
        )
        .expect("failed to parse utils");
        let env = DeserializeEnv::new(TypeScript::Tsx).with_utils(&utils)?;
        assert_eq!(utils.keys().count(), 2);
        let rule = from_str("matches: accessor-name").expect("failed to parse rule");
        let rule = env.deserialize_rule(rule)?;
        Ok((
            rule,
            env, // env is required for weak ref
        ))
    }

    #[test]
    fn test_local_util_matches() -> Result<()> {
        let (rule, _env) = get_dependent_utils()?;
        let grep = TypeScript::Tsx.ast_grep("whatever");
        assert!(grep.root().find(rule).is_some());
        Ok(())
    }

    #[test]
    #[ignore = "TODO, need to figure out potential_kinds"]
    fn test_local_util_kinds() -> Result<()> {
        // run multiple times to avoid an accidentally working order due to map iteration order
        for _ in 0..10 {
            let (rule, _env) = get_dependent_utils()?;
            assert!(rule.potential_kinds().is_some());
        }
        Ok(())
    }

    #[test]
    fn test_dependency_precedes_dependent() {
        // repeat with fresh maps in case iteration order varies between instances
        for _ in 0..10 {
            let utils: RapidMap<String, SerializableRule> = from_str(
                "
top:
  all:
    - matches: middle
middle:
  not: {matches: leaf}
leaf:
  kind: identifier
",
            )
            .expect("failed to parse utils");
            let order = TopologicalSort::get_order(&utils).expect("utils are acyclic");
            assert_eq!(order.len(), 3);
            let position = |id: &str| {
                order
                    .iter()
                    .position(|ordered| *ordered == id)
                    .expect("every util must be ordered")
            };
            assert!(position("leaf") < position("middle"));
            assert!(position("middle") < position("top"));
        }
    }

    #[test]
    fn test_using_global_rule_in_local() -> Result<()> {
        let utils = from_str(
            "
local-rule:
  matches: global-rule
",
        )
        .expect("failed to parse utils");
        // an id that is not a local util may be global, so registration must succeed
        DeserializeEnv::new(TypeScript::Tsx).with_utils(&utils)?;
        Ok(())
    }

    #[test]
    fn test_using_cyclic_local() -> Result<()> {
        let utils = from_str(
            "
local-rule:
  matches: local-rule
",
        )
        .expect("failed to parse utils");
        let ret = DeserializeEnv::new(TypeScript::Tsx).with_utils(&utils);
        match ret {
            Err(RuleSerializeError::MatchesReference(ReferentRuleError::CyclicRule(id))) => {
                assert_eq!(id, "local-rule");
            }
            _ => panic!("expected a cyclic rule error"),
        }
        Ok(())
    }

    #[test]
    fn test_using_transitive_cycle() -> Result<()> {
        let utils = from_str(
            "
local-rule-a:
  matches: local-rule-b
local-rule-b:
  all:
    - matches: local-rule-c
local-rule-c:
  any:
    - matches: local-rule-a
",
        )
        .expect("failed to parse utils");
        let ret = DeserializeEnv::new(TypeScript::Tsx).with_utils(&utils);
        match ret {
            // which member is reported depends on where the traversal started
            Err(RuleSerializeError::MatchesReference(ReferentRuleError::CyclicRule(id))) => {
                assert!(["local-rule-a", "local-rule-b", "local-rule-c"].contains(&id.as_str()));
            }
            _ => panic!("expected a cyclic rule error"),
        }
        Ok(())
    }

    #[test]
    fn test_cyclic_not() -> Result<()> {
        let utils = from_str(
            "
local-rule-a:
  not: {matches: local-rule-b}
local-rule-b:
  matches: local-rule-a",
        )
        .expect("failed to parse utils");
        let ret = DeserializeEnv::new(TypeScript::Tsx).with_utils(&utils);
        assert!(matches!(
            ret,
            Err(RuleSerializeError::MatchesReference(
                ReferentRuleError::CyclicRule(_)
            ))
        ));
        Ok(())
    }
}