1mod deserialize_env;
8mod nth_child;
9mod range;
10pub mod referent_rule;
11mod relational_rule;
12mod stop_by;
13
14pub use deserialize_env::DeserializeEnv;
15pub use relational_rule::Relation;
16pub use stop_by::StopBy;
17
18use crate::maybe::Maybe;
19use nth_child::{NthChild, NthChildError, SerializableNthChild};
20use range::{RangeMatcher, RangeMatcherError, SerializableRange};
21use referent_rule::{ReferentRule, ReferentRuleError};
22use relational_rule::{Follows, Has, Inside, Precedes};
23
24use thread_ast_engine::language::Language;
25use thread_ast_engine::matcher::{KindMatcher, KindMatcherError, RegexMatcher, RegexMatcherError};
26use thread_ast_engine::meta_var::MetaVarEnv;
27use thread_ast_engine::{Doc, Node, ops as o};
28use thread_ast_engine::{MatchStrictness, Matcher, Pattern, PatternError};
29
30use bit_set::BitSet;
31use schemars::JsonSchema;
32use serde::{Deserialize, Serialize};
33use std::borrow::Cow;
34use thiserror::Error;
35use thread_utilities::RapidSet;
36
37#[derive(Serialize, Deserialize, Clone, Debug, Default, JsonSchema)]
47#[serde(deny_unknown_fields)]
48pub struct SerializableRule {
49 #[serde(default, skip_serializing_if = "Maybe::is_absent")]
54 pub pattern: Maybe<PatternStyle>,
55 #[serde(default, skip_serializing_if = "Maybe::is_absent")]
57 pub kind: Maybe<String>,
58 #[serde(default, skip_serializing_if = "Maybe::is_absent")]
60 pub regex: Maybe<String>,
61 #[serde(default, skip_serializing_if = "Maybe::is_absent", rename = "nthChild")]
64 pub nth_child: Maybe<SerializableNthChild>,
65 #[serde(default, skip_serializing_if = "Maybe::is_absent")]
68 pub range: Maybe<SerializableRange>,
69
70 #[serde(default, skip_serializing_if = "Maybe::is_absent")]
74 pub inside: Maybe<Box<Relation>>,
75 #[serde(default, skip_serializing_if = "Maybe::is_absent")]
78 pub has: Maybe<Box<Relation>>,
79 #[serde(default, skip_serializing_if = "Maybe::is_absent")]
82 pub precedes: Maybe<Box<Relation>>,
83 #[serde(default, skip_serializing_if = "Maybe::is_absent")]
86 pub follows: Maybe<Box<Relation>>,
87 #[serde(default, skip_serializing_if = "Maybe::is_absent")]
91 pub all: Maybe<Vec<SerializableRule>>,
92 #[serde(default, skip_serializing_if = "Maybe::is_absent")]
95 pub any: Maybe<Vec<SerializableRule>>,
96 #[serde(default, skip_serializing_if = "Maybe::is_absent")]
97 pub not: Maybe<Box<SerializableRule>>,
99 #[serde(default, skip_serializing_if = "Maybe::is_absent")]
101 pub matches: Maybe<String>,
102}
103
104struct Categorized {
105 pub atomic: AtomicRule,
106 pub relational: RelationalRule,
107 pub composite: CompositeRule,
108}
109
110impl SerializableRule {
111 fn categorized(self) -> Categorized {
112 Categorized {
113 atomic: AtomicRule {
114 pattern: self.pattern.into(),
115 kind: self.kind.into(),
116 regex: self.regex.into(),
117 nth_child: self.nth_child.into(),
118 range: self.range.into(),
119 },
120 relational: RelationalRule {
121 inside: self.inside.into(),
122 has: self.has.into(),
123 precedes: self.precedes.into(),
124 follows: self.follows.into(),
125 },
126 composite: CompositeRule {
127 all: self.all.into(),
128 any: self.any.into(),
129 not: self.not.into(),
130 matches: self.matches.into(),
131 },
132 }
133 }
134}
135
136pub struct AtomicRule {
137 pub pattern: Option<PatternStyle>,
138 pub kind: Option<String>,
139 pub regex: Option<String>,
140 pub nth_child: Option<SerializableNthChild>,
141 pub range: Option<SerializableRange>,
142}
143#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
144#[serde(rename_all = "camelCase")]
145pub enum Strictness {
146 Cst,
148 Smart,
150 Ast,
152 Relaxed,
154 Signature,
156}
157
158impl From<MatchStrictness> for Strictness {
159 fn from(value: MatchStrictness) -> Self {
160 use MatchStrictness as M;
161 use Strictness as S;
162 match value {
163 M::Cst => S::Cst,
164 M::Smart => S::Smart,
165 M::Ast => S::Ast,
166 M::Relaxed => S::Relaxed,
167 M::Signature => S::Signature,
168 }
169 }
170}
171
172impl From<Strictness> for MatchStrictness {
173 fn from(value: Strictness) -> Self {
174 use MatchStrictness as M;
175 use Strictness as S;
176 match value {
177 S::Cst => M::Cst,
178 S::Smart => M::Smart,
179 S::Ast => M::Ast,
180 S::Relaxed => M::Relaxed,
181 S::Signature => M::Signature,
182 }
183 }
184}
185
186#[derive(Serialize, Deserialize, Clone, JsonSchema, Debug)]
189#[serde(untagged)]
190pub enum PatternStyle {
191 Str(String),
192 Contextual {
193 context: String,
195 selector: Option<String>,
197 strictness: Option<Strictness>,
199 },
200}
201
202pub struct RelationalRule {
203 pub inside: Option<Box<Relation>>,
204 pub has: Option<Box<Relation>>,
205 pub precedes: Option<Box<Relation>>,
206 pub follows: Option<Box<Relation>>,
207}
208
209pub struct CompositeRule {
210 pub all: Option<Vec<SerializableRule>>,
211 pub any: Option<Vec<SerializableRule>>,
212 pub not: Option<Box<SerializableRule>>,
213 pub matches: Option<String>,
214}
215
216#[derive(Clone, Debug)]
217pub enum Rule {
218 Pattern(Pattern),
220 Kind(KindMatcher),
221 Regex(RegexMatcher),
222 NthChild(NthChild),
223 Range(RangeMatcher),
224 Inside(Box<Inside>),
226 Has(Box<Has>),
227 Precedes(Box<Precedes>),
228 Follows(Box<Follows>),
229 All(o::All<Rule>),
231 Any(o::Any<Rule>),
232 Not(Box<o::Not<Rule>>),
233 Matches(ReferentRule),
234}
235impl Rule {
236 pub(crate) fn check_cyclic(&self, id: &str) -> bool {
238 match self {
239 Rule::All(all) => all.inner().iter().any(|r| r.check_cyclic(id)),
240 Rule::Any(any) => any.inner().iter().any(|r| r.check_cyclic(id)),
241 Rule::Not(not) => not.inner().check_cyclic(id),
242 Rule::Matches(m) => m.rule_id == id,
243 _ => false,
244 }
245 }
246
247 pub fn defined_vars(&self) -> RapidSet<&str> {
248 match self {
249 Rule::Pattern(p) => p.defined_vars(),
250 Rule::Kind(_) => RapidSet::default(),
251 Rule::Regex(_) => RapidSet::default(),
252 Rule::NthChild(n) => n.defined_vars(),
253 Rule::Range(_) => RapidSet::default(),
254 Rule::Has(c) => c.defined_vars(),
255 Rule::Inside(p) => p.defined_vars(),
256 Rule::Precedes(f) => f.defined_vars(),
257 Rule::Follows(f) => f.defined_vars(),
258 Rule::All(sub) => sub.inner().iter().flat_map(|r| r.defined_vars()).collect(),
259 Rule::Any(sub) => sub.inner().iter().flat_map(|r| r.defined_vars()).collect(),
260 Rule::Not(sub) => sub.inner().defined_vars(),
261 Rule::Matches(_r) => RapidSet::default(),
263 }
264 }
265
266 pub fn verify_util(&self) -> Result<(), RuleSerializeError> {
268 match self {
269 Rule::Pattern(_) => Ok(()),
270 Rule::Kind(_) => Ok(()),
271 Rule::Regex(_) => Ok(()),
272 Rule::NthChild(n) => n.verify_util(),
273 Rule::Range(_) => Ok(()),
274 Rule::Has(c) => c.verify_util(),
275 Rule::Inside(p) => p.verify_util(),
276 Rule::Precedes(f) => f.verify_util(),
277 Rule::Follows(f) => f.verify_util(),
278 Rule::All(sub) => sub.inner().iter().try_for_each(|r| r.verify_util()),
279 Rule::Any(sub) => sub.inner().iter().try_for_each(|r| r.verify_util()),
280 Rule::Not(sub) => sub.inner().verify_util(),
281 Rule::Matches(r) => Ok(r.verify_util()?),
282 }
283 }
284}
285
286impl Matcher for Rule {
287 fn match_node_with_env<'tree, D: Doc>(
288 &self,
289 node: Node<'tree, D>,
290 env: &mut Cow<MetaVarEnv<'tree, D>>,
291 ) -> Option<Node<'tree, D>> {
292 use Rule::*;
293 match self {
294 Pattern(pattern) => pattern.match_node_with_env(node, env),
296 Kind(kind) => kind.match_node_with_env(node, env),
297 Regex(regex) => regex.match_node_with_env(node, env),
298 NthChild(nth_child) => nth_child.match_node_with_env(node, env),
299 Range(range) => range.match_node_with_env(node, env),
300 Inside(parent) => match_and_add_label(&**parent, node, env),
302 Has(child) => match_and_add_label(&**child, node, env),
303 Precedes(latter) => match_and_add_label(&**latter, node, env),
304 Follows(former) => match_and_add_label(&**former, node, env),
305 All(all) => all.match_node_with_env(node, env),
307 Any(any) => any.match_node_with_env(node, env),
308 Not(not) => not.match_node_with_env(node, env),
309 Matches(rule) => rule.match_node_with_env(node, env),
310 }
311 }
312
313 fn potential_kinds(&self) -> Option<BitSet> {
314 use Rule::*;
315 match self {
316 Pattern(pattern) => pattern.potential_kinds(),
318 Kind(kind) => kind.potential_kinds(),
319 Regex(regex) => regex.potential_kinds(),
320 NthChild(nth_child) => nth_child.potential_kinds(),
321 Range(range) => range.potential_kinds(),
322 Inside(parent) => parent.potential_kinds(),
324 Has(child) => child.potential_kinds(),
325 Precedes(latter) => latter.potential_kinds(),
326 Follows(former) => former.potential_kinds(),
327 All(all) => all.potential_kinds(),
329 Any(any) => any.potential_kinds(),
330 Not(not) => not.potential_kinds(),
331 Matches(rule) => rule.potential_kinds(),
332 }
333 }
334}
335
336impl Default for Rule {
339 fn default() -> Self {
340 Self::Any(o::Any::new(std::iter::empty()))
341 }
342}
343
344fn match_and_add_label<'tree, D: Doc, M: Matcher>(
345 inner: &M,
346 node: Node<'tree, D>,
347 env: &mut Cow<MetaVarEnv<'tree, D>>,
348) -> Option<Node<'tree, D>> {
349 let matched = inner.match_node_with_env(node, env)?;
350 env.to_mut().add_label("secondary", matched.clone());
351 Some(matched)
352}
353
354#[derive(Error, Debug)]
355pub enum RuleSerializeError {
356 #[error("Rule must have one positive matcher.")]
357 MissPositiveMatcher,
358 #[error("Rule contains invalid kind matcher.")]
359 InvalidKind(#[from] KindMatcherError),
360 #[error("Rule contains invalid pattern matcher.")]
361 InvalidPattern(#[from] PatternError),
362 #[error("Rule contains invalid nthChild.")]
363 NthChild(#[from] NthChildError),
364 #[error("Rule contains invalid regex matcher.")]
365 WrongRegex(#[from] RegexMatcherError),
366 #[error("Rule contains invalid matches reference.")]
367 MatchesReference(#[from] ReferentRuleError),
368 #[error("Rule contains invalid range matcher.")]
369 InvalidRange(#[from] RangeMatcherError),
370 #[error("field is only supported in has/inside.")]
371 FieldNotSupported,
372 #[error("Relational rule contains invalid field {0}.")]
373 InvalidField(String),
374}
375
376pub fn deserialize_rule<L: Language>(
378 serialized: SerializableRule,
379 env: &DeserializeEnv<L>,
380) -> Result<Rule, RuleSerializeError> {
381 let mut rules = Vec::with_capacity(1);
382 use Rule as R;
383 let categorized = serialized.categorized();
384 deserialize_atomic_rule(categorized.atomic, &mut rules, env)?;
387 deserialize_composite_rule(categorized.composite, &mut rules, env)?;
388 deserialize_relational_rule(categorized.relational, &mut rules, env)?;
389
390 if rules.is_empty() {
391 Err(RuleSerializeError::MissPositiveMatcher)
392 } else if rules.len() == 1 {
393 Ok(rules.pop().expect("should not be empty"))
394 } else {
395 Ok(R::All(o::All::new(rules)))
396 }
397}
398
399fn deserialize_composite_rule<L: Language>(
400 composite: CompositeRule,
401 rules: &mut Vec<Rule>,
402 env: &DeserializeEnv<L>,
403) -> Result<(), RuleSerializeError> {
404 use Rule as R;
405 let convert_rules = |rules: Vec<SerializableRule>| -> Result<_, RuleSerializeError> {
406 let mut inner = Vec::with_capacity(rules.len());
407 for rule in rules {
408 inner.push(deserialize_rule(rule, env)?);
409 }
410 Ok(inner)
411 };
412 if let Some(all) = composite.all {
413 rules.push(R::All(o::All::new(convert_rules(all)?)));
414 }
415 if let Some(any) = composite.any {
416 rules.push(R::Any(o::Any::new(convert_rules(any)?)));
417 }
418 if let Some(not) = composite.not {
419 let not = o::Not::new(deserialize_rule(*not, env)?);
420 rules.push(R::Not(Box::new(not)));
421 }
422 if let Some(id) = composite.matches {
423 let matches = ReferentRule::try_new(id, &env.registration)?;
424 rules.push(R::Matches(matches));
425 }
426 Ok(())
427}
428
429fn deserialize_relational_rule<L: Language>(
430 relational: RelationalRule,
431 rules: &mut Vec<Rule>,
432 env: &DeserializeEnv<L>,
433) -> Result<(), RuleSerializeError> {
434 use Rule as R;
435 if let Some(inside) = relational.inside {
437 rules.push(R::Inside(Box::new(Inside::try_new(*inside, env)?)));
438 }
439 if let Some(has) = relational.has {
440 rules.push(R::Has(Box::new(Has::try_new(*has, env)?)));
441 }
442 if let Some(precedes) = relational.precedes {
443 rules.push(R::Precedes(Box::new(Precedes::try_new(*precedes, env)?)));
444 }
445 if let Some(follows) = relational.follows {
446 rules.push(R::Follows(Box::new(Follows::try_new(*follows, env)?)));
447 }
448 Ok(())
449}
450
451fn deserialize_atomic_rule<L: Language>(
452 atomic: AtomicRule,
453 rules: &mut Vec<Rule>,
454 env: &DeserializeEnv<L>,
455) -> Result<(), RuleSerializeError> {
456 use Rule as R;
457 if let Some(pattern) = atomic.pattern {
458 rules.push(match pattern {
459 PatternStyle::Str(pat) => R::Pattern(Pattern::try_new(&pat, &env.lang)?),
460 PatternStyle::Contextual {
461 context,
462 selector,
463 strictness,
464 } => {
465 let pattern = if let Some(selector) = selector {
466 Pattern::contextual(&context, &selector, &env.lang)?
467 } else {
468 Pattern::try_new(&context, &env.lang)?
469 };
470 let pattern = if let Some(strictness) = strictness {
471 pattern.with_strictness(strictness.into())
472 } else {
473 pattern
474 };
475 R::Pattern(pattern)
476 }
477 });
478 }
479 if let Some(kind) = atomic.kind {
480 rules.push(R::Kind(KindMatcher::try_new(&kind, &env.lang)?));
481 }
482 if let Some(regex) = atomic.regex {
483 rules.push(R::Regex(RegexMatcher::try_new(®ex)?));
484 }
485 if let Some(nth_child) = atomic.nth_child {
486 rules.push(R::NthChild(NthChild::try_new(nth_child, env)?));
487 }
488 if let Some(range) = atomic.range {
489 rules.push(R::Range(RangeMatcher::try_new(range.start, range.end)?));
490 }
491 Ok(())
492}
493
#[cfg(test)]
mod test {
    use super::*;
    use crate::from_str;
    use crate::test::TypeScript;
    use PatternStyle::*;
    use thread_ast_engine::tree_sitter::LanguageExt;

    #[test]
    fn test_pattern() {
        let src = r"
pattern: Test
";
        let rule: SerializableRule = from_str(src).expect("cannot parse rule");
        assert!(rule.pattern.is_present());
        let src = r"
pattern:
  context: class $C { set $B() {} }
  selector: method_definition
";
        let rule: SerializableRule = from_str(src).expect("cannot parse rule");
        assert!(matches!(rule.pattern, Maybe::Present(Contextual { .. })));
    }

    #[test]
    fn test_augmentation() {
        let src = r"
pattern: class A {}
inside:
  pattern: function() {}
";
        let rule: SerializableRule = from_str(src).expect("cannot parse rule");
        assert!(rule.inside.is_present());
        assert!(rule.pattern.is_present());
    }

    #[test]
    fn test_multi_augmentation() {
        let src = r"
pattern: class A {}
inside:
  pattern: function() {}
has:
  pattern: Some()
";
        let rule: SerializableRule = from_str(src).expect("cannot parse rule");
        assert!(rule.inside.is_present());
        assert!(rule.has.is_present());
        assert!(rule.follows.is_absent());
        assert!(rule.precedes.is_absent());
        assert!(rule.pattern.is_present());
    }

    #[test]
    fn test_maybe_not() {
        let src = "not: 123";
        let ret: Result<SerializableRule, _> = from_str(src);
        assert!(ret.is_err());
        let src = "not:";
        let ret: Result<SerializableRule, _> = from_str(src);
        assert!(ret.is_err());
    }

    #[test]
    fn test_nested_augmentation() {
        let src = r"
pattern: class A {}
inside:
  pattern: function() {}
  inside:
    pattern:
      context: Some()
      selector: ss
";
        let rule: SerializableRule = from_str(src).expect("cannot parse rule");
        assert!(rule.inside.is_present());
        let inside = rule.inside.unwrap();
        assert!(inside.rule.pattern.is_present());
        assert!(inside.rule.inside.unwrap().rule.pattern.is_present());
    }

    #[test]
    fn test_precedes_follows() {
        let src = r"
pattern: class A {}
precedes:
  pattern: function() {}
follows:
  pattern:
    context: Some()
    selector: ss
";
        let rule: SerializableRule = from_str(src).expect("cannot parse rule");
        assert!(rule.precedes.is_present());
        assert!(rule.follows.is_present());
        // Inspect each relation's own inner rule: `precedes` uses a plain
        // string pattern, `follows` a contextual one.
        let precedes = rule.precedes.unwrap();
        assert!(matches!(precedes.rule.pattern, Maybe::Present(Str(_))));
        let follows = rule.follows.unwrap();
        assert!(matches!(follows.rule.pattern, Maybe::Present(Contextual { .. })));
    }

    #[test]
    fn test_deserialize_rule() {
        let src = r"
pattern: class A {}
kind: class_declaration
";
        let rule: SerializableRule = from_str(src).expect("cannot parse rule");
        let env = DeserializeEnv::new(TypeScript::Tsx);
        let rule = deserialize_rule(rule, &env).expect("should deserialize");
        let root = TypeScript::Tsx.ast_grep("class A {}");
        assert!(root.root().find(rule).is_some());
    }

    #[test]
    fn test_deserialize_order() {
        let src = r"
pattern: class A {}
inside:
  kind: class
";
        let rule: SerializableRule = from_str(src).expect("cannot parse rule");
        let env = DeserializeEnv::new(TypeScript::Tsx);
        let rule = deserialize_rule(rule, &env).expect("should deserialize");
        assert!(matches!(rule, Rule::All(_)));
    }

    #[test]
    fn test_defined_vars() {
        let src = r"
pattern: var $A = 123
inside:
  pattern: var $B = 456
";
        let rule: SerializableRule = from_str(src).expect("cannot parse rule");
        let env = DeserializeEnv::new(TypeScript::Tsx);
        let rule = deserialize_rule(rule, &env).expect("should deserialize");
        assert_eq!(rule.defined_vars(), ["A", "B"].into_iter().collect());
    }

    #[test]
    fn test_issue_1164() {
        let src = r"
kind: statement_block
has:
  pattern: this.$A = promise()
  stopBy: end";
        let rule: SerializableRule = from_str(src).expect("cannot parse rule");
        let env = DeserializeEnv::new(TypeScript::Tsx);
        let rule = deserialize_rule(rule, &env).expect("should deserialize");
        let root = TypeScript::Tsx.ast_grep(
            "if (a) {
        this.a = b;
        this.d = promise()
      }",
        );
        assert!(root.root().find(rule).is_some());
    }

    #[test]
    fn test_issue_1225() {
        let src = r"
kind: statement_block
has:
  pattern: $A
  regex: const";
        let rule: SerializableRule = from_str(src).expect("cannot parse rule");
        let env = DeserializeEnv::new(TypeScript::Tsx);
        let rule = deserialize_rule(rule, &env).expect("should deserialize");
        let root = TypeScript::Tsx.ast_grep(
            "{
        let x = 1;
        const z = 9;
      }",
        );
        assert!(root.root().find(rule).is_some());
    }
}