1use super::referent_rule::{GlobalRules, ReferentRuleError, RuleRegistration};
8use crate::check_var::CheckHint;
9use crate::maybe::Maybe;
10use crate::rule::{self, Rule, RuleSerializeError, SerializableRule};
11use crate::rule_core::{RuleCoreError, SerializableRuleCore};
12use crate::transform::Trans;
13use thread_ast_engine::meta_var::MetaVariable;
14
15use thread_ast_engine::language::Language;
16
17use schemars::JsonSchema;
18use serde::{Deserialize, Serialize};
19
20use thread_utilities::RapidMap;
21
/// Serialized form of a global utility rule: a rule core together with the
/// id it is registered under and the language it is parsed for.
///
/// Consumed by [`DeserializeEnv::parse_global_utils`], which compiles each
/// entry into the shared [`GlobalRules`] registry.
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct SerializableGlobalRule<L: Language> {
    /// The rule body. Flattened, so its fields appear alongside `id` and
    /// `language` in the serialized form rather than nested under `core`.
    #[serde(flatten)]
    pub core: SerializableRuleCore,
    /// Identifier other rules use to reference this one (e.g. via `matches`).
    pub id: String,
    /// Language this rule is compiled for.
    pub language: L,
}
31
32fn into_map<L: Language>(
33 rules: Vec<SerializableGlobalRule<L>>,
34) -> RapidMap<String, (L, SerializableRuleCore)> {
35 rules
36 .into_iter()
37 .map(|r| (r.id, (r.language, r.core)))
38 .collect()
39}
40
/// Result of ordering dependent items. The `Err` payload is the id at which a
/// dependency cycle was detected.
type OrderResult<T> = Result<T, String>;
42
/// Environment used while deserializing rules.
///
/// Bundles the language rules are compiled for with the registry of named
/// rules (local utilities and/or globals) that references resolve against.
#[derive(Clone, Debug)]
pub struct DeserializeEnv<L: Language> {
    /// Named rules that deserialized rules may reference by id.
    pub(crate) registration: RuleRegistration,
    /// Language that rules deserialized in this environment target.
    pub(crate) lang: L,
}
51
/// An item that can report the ids it depends on to a [`TopologicalSort`],
/// so a collection of such items can be processed dependencies-first.
trait DependentRule: Sized {
    /// Passes every id this item depends on to `sorter.visit`, propagating
    /// the cycle error if one is detected.
    fn visit_dependency<'a>(&'a self, sorter: &mut TopologicalSort<'a, Self>) -> OrderResult<()>;
}
55
56impl DependentRule for SerializableRule {
57 fn visit_dependency<'a>(&'a self, sorter: &mut TopologicalSort<'a, Self>) -> OrderResult<()> {
58 visit_dependent_rule_ids(self, sorter)
59 }
60}
61
62impl<L: Language> DependentRule for (L, SerializableRuleCore) {
63 fn visit_dependency<'a>(&'a self, sorter: &mut TopologicalSort<'a, Self>) -> OrderResult<()> {
64 visit_dependent_rule_ids(&self.1.rule, sorter)
65 }
66}
67
68impl DependentRule for Trans<MetaVariable> {
69 fn visit_dependency<'a>(&'a self, sorter: &mut TopologicalSort<'a, Self>) -> OrderResult<()> {
70 let used_var = self.used_vars();
71 sorter.visit(used_var)
72 }
73}
74
/// Depth-first topological sorter over a map of named items.
///
/// Emits ids dependencies-first. It reports a cycle when an id is reached
/// again while its own dependencies are still being explored.
struct TopologicalSort<'a, T: DependentRule> {
    /// All items that can be ordered, keyed by id.
    maps: &'a RapidMap<String, T>,
    /// Finished ids; each id is pushed only after all ids it depends on.
    order: Vec<&'a str>,
    /// Visit state per id: `false` while in progress, `true` once finished.
    seen: RapidMap<&'a str, bool>,
}
83
84impl<'a, T: DependentRule> TopologicalSort<'a, T> {
85 fn get_order(maps: &RapidMap<String, T>) -> OrderResult<Vec<&str>> {
86 let mut top_sort = TopologicalSort::new(maps);
87 for key in maps.keys() {
88 top_sort.visit(key)?;
89 }
90 Ok(top_sort.order)
91 }
92
93 fn new(maps: &'a RapidMap<String, T>) -> Self {
94 Self {
95 maps,
96 order: vec![],
97 seen: RapidMap::default(),
98 }
99 }
100
101 fn visit(&mut self, key: &'a str) -> OrderResult<()> {
102 if let Some(&completed) = self.seen.get(key) {
103 return if completed {
106 Ok(())
107 } else {
108 Err(key.to_string())
109 };
110 }
111 let Some(item) = self.maps.get(key) else {
112 return Ok(());
118 };
119 self.seen.insert(key, false);
121 item.visit_dependency(self)?;
122 self.seen.insert(key, true);
124 self.order.push(key);
125 Ok(())
126 }
127}
128
129fn visit_dependent_rule_ids<'a, T: DependentRule>(
130 rule: &'a SerializableRule,
131 sort: &mut TopologicalSort<'a, T>,
132) -> OrderResult<()> {
133 if let Maybe::Present(matches) = &rule.matches {
135 sort.visit(matches)?;
136 }
137 if let Maybe::Present(all) = &rule.all {
138 for sub in all {
139 visit_dependent_rule_ids(sub, sort)?;
140 }
141 }
142 if let Maybe::Present(any) = &rule.any {
143 for sub in any {
144 visit_dependent_rule_ids(sub, sort)?;
145 }
146 }
147 if let Maybe::Present(not) = &rule.not {
148 visit_dependent_rule_ids(not, sort)?;
149 }
150 Ok(())
151}
152
153impl<L: Language> DeserializeEnv<L> {
154 pub fn new(lang: L) -> Self {
155 Self {
156 registration: Default::default(),
157 lang,
158 }
159 }
160
161 pub fn with_utils(
165 self,
166 utils: &RapidMap<String, SerializableRule>,
167 ) -> Result<Self, RuleSerializeError> {
168 let order = TopologicalSort::get_order(utils)
169 .map_err(ReferentRuleError::CyclicRule)
170 .map_err(RuleSerializeError::MatchesReference)?;
171 for id in order {
172 let rule = utils.get(id).expect("must exist");
173 let rule = self.deserialize_rule(rule.clone())?;
174 self.registration.insert_local(id, rule)?;
175 }
176 Ok(self)
177 }
178
179 pub fn parse_global_utils(
181 utils: Vec<SerializableGlobalRule<L>>,
182 ) -> Result<GlobalRules, RuleCoreError> {
183 let registration = GlobalRules::default();
184 let utils = into_map(utils);
185 let order = TopologicalSort::get_order(&utils)
186 .map_err(ReferentRuleError::CyclicRule)
187 .map_err(RuleSerializeError::from)?;
188 for id in order {
189 let (lang, core) = utils.get(id).expect("must exist");
190 let env = DeserializeEnv::new(lang.clone()).with_globals(®istration);
191 let matcher = core.get_matcher_with_hint(env, CheckHint::Global)?;
192 registration
193 .insert(id, matcher)
194 .map_err(RuleSerializeError::MatchesReference)?;
195 }
196 Ok(registration)
197 }
198
199 pub fn deserialize_rule(
200 &self,
201 serialized: SerializableRule,
202 ) -> Result<Rule, RuleSerializeError> {
203 rule::deserialize_rule(serialized, self)
204 }
205
206 pub(crate) fn get_transform_order<'a>(
207 &self,
208 trans: &'a RapidMap<String, Trans<MetaVariable>>,
209 ) -> Result<Vec<&'a str>, String> {
210 TopologicalSort::get_order(trans)
211 }
212
213 pub fn with_globals(self, globals: &GlobalRules) -> Self {
214 Self {
215 registration: RuleRegistration::from_globals(globals),
216 lang: self.lang,
217 }
218 }
219}
220
#[cfg(test)]
mod test {
    use super::*;
    use crate::test::TypeScript;
    use crate::{Rule, from_str};
    use thread_ast_engine::Matcher;
    use thread_ast_engine::tree_sitter::LanguageExt;

    type Result<T> = std::result::Result<T, RuleSerializeError>;

    /// Registers two local utils, where `accessor-name` references
    /// `member-name`, and returns a rule that matches `accessor-name` along
    /// with the environment it was deserialized in.
    fn get_dependent_utils() -> Result<(Rule, DeserializeEnv<TypeScript>)> {
        let utils = from_str(
            "
accessor-name:
 matches: member-name
 regex: whatever
member-name:
 kind: identifier
",
        )
        .expect("failed to parse utils");
        let env = DeserializeEnv::new(TypeScript::Tsx).with_utils(&utils)?;
        assert_eq!(utils.keys().count(), 2);
        let rule = from_str("matches: accessor-name").unwrap();
        Ok((
            env.deserialize_rule(rule).unwrap(),
            env,
        ))
    }

    /// A rule referencing a chain of local utils matches an identifier that
    /// satisfies the whole chain.
    #[test]
    fn test_local_util_matches() -> Result<()> {
        let (rule, _env) = get_dependent_utils()?;
        let grep = TypeScript::Tsx.ast_grep("whatever");
        assert!(grep.root().find(rule).is_some());
        Ok(())
    }

    /// Expects `potential_kinds` to be computed through local util references.
    #[test]
    #[ignore = "TODO, need to figure out potential_kinds"]
    fn test_local_util_kinds() -> Result<()> {
        // Repeated, presumably to catch results that depend on map iteration
        // order — TODO confirm intent.
        for _ in 0..10 {
            let (rule, _env) = get_dependent_utils()?;
            assert!(rule.potential_kinds().is_some());
        }
        Ok(())
    }

    /// Referencing an id that is not one of the local utils (here a would-be
    /// global rule) must not make local registration fail.
    #[test]
    fn test_using_global_rule_in_local() -> Result<()> {
        let utils = from_str(
            "
local-rule:
 matches: global-rule
",
        )
        .expect("failed to parse utils");
        DeserializeEnv::new(TypeScript::Tsx).with_utils(&utils)?;
        Ok(())
    }

    /// A util that references itself is a cycle and must be rejected.
    #[test]
    fn test_using_cyclic_local() -> Result<()> {
        let utils = from_str(
            "
local-rule:
 matches: local-rule
",
        )
        .expect("failed to parse utils");
        let ret = DeserializeEnv::new(TypeScript::Tsx).with_utils(&utils);
        assert!(ret.is_err());
        Ok(())
    }

    /// A cycle a -> b -> c -> a that passes through `all` and `any` is detected.
    #[test]
    fn test_using_transitive_cycle() -> Result<()> {
        let utils = from_str(
            "
local-rule-a:
 matches: local-rule-b
local-rule-b:
 all:
 - matches: local-rule-c
local-rule-c:
 any:
 - matches: local-rule-a
",
        )
        .expect("failed to parse utils");
        let ret = DeserializeEnv::new(TypeScript::Tsx).with_utils(&utils);
        assert!(ret.is_err());
        Ok(())
    }

    /// A cycle passing through `not` is detected and reported specifically as
    /// `MatchesReference(CyclicRule(_))`.
    #[test]
    fn test_cyclic_not() -> Result<()> {
        let utils = from_str(
            "
local-rule-a:
 not: {matches: local-rule-b}
local-rule-b:
 matches: local-rule-a",
        )
        .expect("failed to parse utils");
        let ret = DeserializeEnv::new(TypeScript::Tsx).with_utils(&utils);
        assert!(matches!(
            ret,
            Err(RuleSerializeError::MatchesReference(
                ReferentRuleError::CyclicRule(_)
            ))
        ));
        Ok(())
    }
}