1use crate::{
6 ArmPattern, CapturedNode, CodePattern, MatchResult, NameMatcher, NodeKind, PatternExpr, Span,
7};
8use ryo_source::pure::{PureBlock, PureExpr, PureFn, PureMatchArm, PurePattern, PureStmt};
9use std::collections::HashMap;
10
11#[derive(Debug, Default)]
13pub struct MatchContext {
14 pub captures: HashMap<String, CapturedNode>,
16}
17
18impl MatchContext {
19 pub fn new() -> Self {
21 Self::default()
22 }
23
24 pub fn capture(&mut self, name: impl Into<String>, text: impl Into<String>) {
26 let name = name.into();
27 self.captures
28 .insert(name, CapturedNode::new(Span::point(0, 0), text.into()));
29 }
30
31 pub fn merge(&mut self, other: MatchContext) {
33 self.captures.extend(other.captures);
34 }
35
36 pub fn into_match_result(self) -> MatchResult {
38 let mut result = MatchResult::matched();
39 result.captures = self.captures;
40 result
41 }
42}
43
44pub struct ExprMatcher<'p> {
46 pattern: &'p CodePattern,
47}
48
49impl<'p> ExprMatcher<'p> {
50 pub fn new(pattern: &'p CodePattern) -> Self {
52 Self { pattern }
53 }
54
55 pub fn matches(&self, expr: &PureExpr) -> Option<MatchContext> {
57 let mut ctx = MatchContext::new();
58
59 if self.match_expr(expr, &mut ctx) {
60 if let Some(ref capture_name) = self.pattern.capture {
62 ctx.capture(capture_name.clone(), expr_to_string(expr));
63 }
64 Some(ctx)
65 } else {
66 None
67 }
68 }
69
70 fn match_expr(&self, expr: &PureExpr, ctx: &mut MatchContext) -> bool {
72 match (&self.pattern.node, expr) {
73 (
75 NodeKind::MethodCall,
76 PureExpr::MethodCall {
77 receiver,
78 method,
79 args,
80 ..
81 },
82 ) => {
83 if let Some(PatternExpr::Name(name_matcher)) = self.pattern.children.get("method") {
85 if !match_name(name_matcher, method) {
86 return false;
87 }
88 }
89
90 if let Some(receiver_pattern) = self.pattern.children.get("receiver") {
92 if !self.match_pattern_expr(receiver_pattern, receiver, ctx) {
93 return false;
94 }
95 }
96
97 if let Some(PatternExpr::Pattern(args_pattern)) = self.pattern.children.get("args")
99 {
100 let _ = args_pattern;
102 let _ = args;
103 }
104
105 true
106 }
107
108 (NodeKind::FunctionCall, PureExpr::Call { func, args }) => {
110 if let Some(func_pattern) = self.pattern.children.get("func") {
112 if !self.match_pattern_expr(func_pattern, func, ctx) {
113 return false;
114 }
115 }
116 let _ = args;
117 true
118 }
119
120 (NodeKind::MacroCall, PureExpr::Macro { name, .. }) => {
122 if let Some(PatternExpr::Name(name_matcher)) = self.pattern.children.get("macro") {
124 if !match_name(name_matcher, name) {
125 return false;
126 }
127 }
128 true
129 }
130
131 (NodeKind::Try, PureExpr::Try(expr)) => {
133 if let Some(expr_pattern) = self.pattern.children.get("expr") {
134 if !self.match_pattern_expr(expr_pattern, expr, ctx) {
135 return false;
136 }
137 }
138 true
139 }
140
141 (NodeKind::Await, PureExpr::Await(expr)) => {
143 if let Some(expr_pattern) = self.pattern.children.get("expr") {
144 if !self.match_pattern_expr(expr_pattern, expr, ctx) {
145 return false;
146 }
147 }
148 true
149 }
150
151 (NodeKind::BinaryOp, PureExpr::Binary { op, left, right }) => {
153 if let Some(PatternExpr::Name(name_matcher)) = self.pattern.children.get("op") {
154 if !match_name(name_matcher, op) {
155 return false;
156 }
157 }
158 if let Some(left_pattern) = self.pattern.children.get("left") {
159 if !self.match_pattern_expr(left_pattern, left, ctx) {
160 return false;
161 }
162 }
163 if let Some(right_pattern) = self.pattern.children.get("right") {
164 if !self.match_pattern_expr(right_pattern, right, ctx) {
165 return false;
166 }
167 }
168 true
169 }
170
171 (NodeKind::Path, PureExpr::Path(path)) => {
173 if let Some(PatternExpr::Name(name_matcher)) = self.pattern.children.get("path") {
174 if !match_name(name_matcher, path) {
175 return false;
176 }
177 }
178 true
179 }
180
181 (NodeKind::Literal, PureExpr::Lit(lit)) => {
183 if let Some(value_pattern) = self.pattern.children.get("value") {
184 match value_pattern {
185 PatternExpr::Literal(expected) => {
186 if let Some(expected_str) = expected.as_str() {
187 if lit != expected_str {
188 return false;
189 }
190 }
191 }
192 PatternExpr::Name(NameMatcher::Exact(expected_str))
195 if lit != expected_str =>
196 {
197 return false;
198 }
199 _ => {}
200 }
201 }
202 true
203 }
204
205 (NodeKind::Expr, _) => true,
207
208 (NodeKind::Block, PureExpr::Block { .. }) => true,
210
211 (
213 NodeKind::If,
214 PureExpr::If {
215 cond,
216 then_branch,
217 else_branch,
218 },
219 ) => {
220 if let Some(cond_pattern) = self.pattern.children.get("cond") {
221 if !self.match_pattern_expr(cond_pattern, cond, ctx) {
222 return false;
223 }
224 }
225 let _ = (then_branch, else_branch);
226 true
227 }
228
229 (NodeKind::Match, PureExpr::Match { expr, arms }) => {
231 if let Some(expr_pattern) = self.pattern.children.get("expr") {
232 if !self.match_pattern_expr(expr_pattern, expr, ctx) {
233 return false;
234 }
235 }
236 if let Some(expected) = self.pattern.arm_count {
238 if arms.len() != expected {
239 return false;
240 }
241 }
242 if let Some(arm_patterns) = &self.pattern.arms {
244 for ap in arm_patterns {
245 if !arms.iter().any(|arm| match_arm_pattern(ap, arm, ctx)) {
246 return false;
247 }
248 }
249 }
250 true
251 }
252
253 (NodeKind::Return, PureExpr::Return(maybe_expr)) => {
255 if let Some(expr_pattern) = self.pattern.children.get("expr") {
256 if let Some(expr) = maybe_expr {
257 if !self.match_pattern_expr(expr_pattern, expr, ctx) {
258 return false;
259 }
260 } else {
261 return false;
262 }
263 }
264 true
265 }
266
267 (NodeKind::Loop, PureExpr::Loop { .. }) => true,
269 (NodeKind::Loop, PureExpr::While { .. }) => true,
270 (NodeKind::Loop, PureExpr::For { .. }) => true,
271
272 (NodeKind::Closure, PureExpr::Closure { .. }) => true,
274
275 (NodeKind::Index, PureExpr::Index { expr, index }) => {
277 if let Some(expr_pattern) = self.pattern.children.get("expr") {
278 if !self.match_pattern_expr(expr_pattern, expr, ctx) {
279 return false;
280 }
281 }
282 if let Some(index_pattern) = self.pattern.children.get("index") {
283 if !self.match_pattern_expr(index_pattern, index, ctx) {
284 return false;
285 }
286 }
287 true
288 }
289
290 _ => false,
291 }
292 }
293
294 fn match_pattern_expr(
296 &self,
297 pattern: &PatternExpr,
298 expr: &PureExpr,
299 ctx: &mut MatchContext,
300 ) -> bool {
301 match pattern {
302 PatternExpr::Pattern(nested) => {
303 let matcher = ExprMatcher::new(nested);
304 if let Some(nested_ctx) = matcher.matches(expr) {
305 ctx.merge(nested_ctx);
306 true
307 } else {
308 false
309 }
310 }
311 PatternExpr::Capture(var_name) => {
312 ctx.capture(var_name.clone(), expr_to_string(expr));
314 true
315 }
316 PatternExpr::Wildcard => true,
317 PatternExpr::Name(name_matcher) => {
318 if let PureExpr::Path(path) = expr {
319 match_name(name_matcher, path)
320 } else {
321 false
322 }
323 }
324 PatternExpr::Literal(expected) => {
325 if let PureExpr::Lit(lit) = expr {
326 if let Some(expected_str) = expected.as_str() {
327 lit == expected_str
328 } else {
329 false
330 }
331 } else {
332 false
333 }
334 }
335 }
336 }
337}
338
339fn match_name(matcher: &NameMatcher, name: &str) -> bool {
341 match matcher {
342 NameMatcher::Exact(expected) => name == expected,
343 NameMatcher::Pattern(pattern) => {
344 if let Some(ref prefix) = pattern.starts_with {
345 if !name.starts_with(prefix) {
346 return false;
347 }
348 }
349 if let Some(ref suffix) = pattern.ends_with {
350 if !name.ends_with(suffix) {
351 return false;
352 }
353 }
354 if let Some(ref substr) = pattern.contains {
355 if !name.contains(substr) {
356 return false;
357 }
358 }
359 if let Some(ref glob) = pattern.glob {
360 if !match_glob(glob, name) {
361 return false;
362 }
363 }
364 true
365 }
366 }
367}
368
/// Matches `name` against a glob `pattern` in which `*` stands for any
/// (possibly empty) run of characters; every other character must match
/// literally, so a pattern without `*` requires an exact match.
///
/// `*` may appear any number of times and anywhere in the pattern:
/// `"get_*"` (prefix), `"*_id"` (suffix), `"*name*"` (substring) and
/// `"get_*_id"` (prefix and suffix) are all supported.
fn match_glob(pattern: &str, name: &str) -> bool {
    // Everything before the first `*` must be a literal prefix of `name`.
    let (head, tail) = match pattern.split_once('*') {
        Some(parts) => parts,
        None => return pattern == name,
    };
    let mut rest = match name.strip_prefix(head) {
        Some(rest) => rest,
        None => return false,
    };

    // Everything after the last `*` must be a literal suffix of what remains.
    let (middle, suffix) = tail.rsplit_once('*').unwrap_or(("", tail));

    // Segments between stars must occur in order. Taking the leftmost
    // occurrence of each leaves the longest remainder for the segments after
    // it, so this greedy scan never rejects a name that could match.
    for segment in middle.split('*').filter(|s| !s.is_empty()) {
        match rest.find(segment) {
            Some(at) => rest = &rest[at + segment.len()..],
            None => return false,
        }
    }

    // Checking the suffix against the unconsumed remainder prevents it from
    // overlapping the prefix or middle segments (e.g. "a*a" must not match "a").
    rest.ends_with(suffix)
}
382
383fn match_arm_pattern(ap: &ArmPattern, arm: &PureMatchArm, _ctx: &mut MatchContext) -> bool {
385 if let Some(ref expected_path) = ap.pattern_path {
387 if !pattern_contains_path(&arm.pattern, expected_path) {
388 return false;
389 }
390 }
391 if let Some(ref body_pattern) = ap.body {
393 let matcher = ExprMatcher::new(body_pattern);
394 if matcher.matches(&arm.body).is_none() {
395 return false;
396 }
397 }
398 true
399}
400
401fn pattern_contains_path(pat: &PurePattern, expected: &str) -> bool {
403 match pat {
404 PurePattern::Path(p) => p == expected || p.ends_with(&format!("::{}", expected)),
405 PurePattern::Struct { path, .. } => {
406 path == expected || path.ends_with(&format!("::{}", expected))
407 }
408 PurePattern::Ident { name, .. } => name == expected,
409 PurePattern::Or(patterns) => patterns.iter().any(|p| pattern_contains_path(p, expected)),
410 PurePattern::Ref { pattern, .. } => pattern_contains_path(pattern, expected),
411 PurePattern::Tuple(patterns) => patterns.iter().any(|p| pattern_contains_path(p, expected)),
412 _ => false,
413 }
414}
415
416pub fn expr_to_string(expr: &PureExpr) -> String {
418 match expr {
419 PureExpr::Lit(s) => s.clone(),
420 PureExpr::Path(s) => s.clone(),
421 PureExpr::MethodCall {
422 receiver,
423 method,
424 args,
425 turbofish,
426 } => {
427 let receiver_str = expr_to_string(receiver);
428 let turbofish_str = turbofish
429 .as_ref()
430 .map(|t| format!("::{}", t))
431 .unwrap_or_default();
432 let args_str = args
433 .iter()
434 .map(expr_to_string)
435 .collect::<Vec<_>>()
436 .join(", ");
437 format!("{}.{}{}({})", receiver_str, method, turbofish_str, args_str)
438 }
439 PureExpr::Call { func, args } => {
440 let func_str = expr_to_string(func);
441 let args_str = args
442 .iter()
443 .map(expr_to_string)
444 .collect::<Vec<_>>()
445 .join(", ");
446 format!("{}({})", func_str, args_str)
447 }
448 PureExpr::Binary { op, left, right } => {
449 format!("{} {} {}", expr_to_string(left), op, expr_to_string(right))
450 }
451 PureExpr::Unary { op, expr } => {
452 format!("{}{}", op, expr_to_string(expr))
453 }
454 PureExpr::Try(expr) => {
455 format!("{}?", expr_to_string(expr))
456 }
457 PureExpr::Await(expr) => {
458 format!("{}.await", expr_to_string(expr))
459 }
460 PureExpr::Field { expr, field } => {
461 format!("{}.{}", expr_to_string(expr), field)
462 }
463 PureExpr::Return(Some(e)) => format!("return {}", expr_to_string(e)),
464 PureExpr::Return(None) => "return".to_string(),
465 PureExpr::Block { .. } => "{ ... }".to_string(),
466 PureExpr::If { .. } => "if ...".to_string(),
467 PureExpr::Match { .. } => "match ...".to_string(),
468 PureExpr::Closure { .. } => "|...| ...".to_string(),
469 PureExpr::Tuple(items) => {
470 let items_str = items
471 .iter()
472 .map(expr_to_string)
473 .collect::<Vec<_>>()
474 .join(", ");
475 format!("({})", items_str)
476 }
477 PureExpr::Array(items) => {
478 let items_str = items
479 .iter()
480 .map(expr_to_string)
481 .collect::<Vec<_>>()
482 .join(", ");
483 format!("[{}]", items_str)
484 }
485 PureExpr::Macro { name, .. } => format!("{}!(...)", name),
486 _ => "<expr>".to_string(),
487 }
488}
489
490pub struct BodyScanner<'p> {
492 pattern: &'p CodePattern,
493}
494
495impl<'p> BodyScanner<'p> {
496 pub fn new(pattern: &'p CodePattern) -> Self {
498 Self { pattern }
499 }
500
501 pub fn scan_fn(&self, func: &PureFn) -> Vec<MatchResult> {
503 let mut results = Vec::new();
504 self.scan_block(&func.body, &mut results);
505 results
506 }
507
508 fn scan_block(&self, block: &PureBlock, results: &mut Vec<MatchResult>) {
510 for stmt in &block.stmts {
511 self.scan_stmt(stmt, results);
512 }
513 }
514
515 fn scan_stmt(&self, stmt: &PureStmt, results: &mut Vec<MatchResult>) {
517 match stmt {
518 PureStmt::Local {
519 init: Some(expr), ..
520 } => {
521 self.scan_expr(expr, results);
522 }
523 PureStmt::Semi(expr) | PureStmt::Expr(expr) => {
524 self.scan_expr(expr, results);
525 }
526 _ => {}
527 }
528 }
529
530 fn scan_expr(&self, expr: &PureExpr, results: &mut Vec<MatchResult>) {
532 let matcher = ExprMatcher::new(self.pattern);
534 if let Some(ctx) = matcher.matches(expr) {
535 results.push(ctx.into_match_result());
536 }
537
538 match expr {
540 PureExpr::MethodCall { receiver, args, .. } => {
541 self.scan_expr(receiver, results);
542 for arg in args {
543 self.scan_expr(arg, results);
544 }
545 }
546 PureExpr::Call { func, args } => {
547 self.scan_expr(func, results);
548 for arg in args {
549 self.scan_expr(arg, results);
550 }
551 }
552 PureExpr::Binary { left, right, .. } => {
553 self.scan_expr(left, results);
554 self.scan_expr(right, results);
555 }
556 PureExpr::Unary { expr, .. } => {
557 self.scan_expr(expr, results);
558 }
559 PureExpr::Try(expr) => {
560 self.scan_expr(expr, results);
561 }
562 PureExpr::Await(expr) => {
563 self.scan_expr(expr, results);
564 }
565 PureExpr::Field { expr, .. } => {
566 self.scan_expr(expr, results);
567 }
568 PureExpr::Index { expr, index } => {
569 self.scan_expr(expr, results);
570 self.scan_expr(index, results);
571 }
572 PureExpr::Block { block, .. } => {
573 self.scan_block(block, results);
574 }
575 PureExpr::If {
576 cond,
577 then_branch,
578 else_branch,
579 } => {
580 self.scan_expr(cond, results);
581 self.scan_block(then_branch, results);
582 if let Some(else_expr) = else_branch {
583 self.scan_expr(else_expr, results);
584 }
585 }
586 PureExpr::Match { expr, arms } => {
587 self.scan_expr(expr, results);
588 for arm in arms {
589 self.scan_expr(&arm.body, results);
590 }
591 }
592 PureExpr::Loop { body, .. } => {
593 self.scan_block(body, results);
594 }
595 PureExpr::While { cond, body, .. } => {
596 self.scan_expr(cond, results);
597 self.scan_block(body, results);
598 }
599 PureExpr::For { expr, body, .. } => {
600 self.scan_expr(expr, results);
601 self.scan_block(body, results);
602 }
603 PureExpr::Return(Some(e)) => {
604 self.scan_expr(e, results);
605 }
606 PureExpr::Break { expr: Some(e), .. } => {
607 self.scan_expr(e, results);
608 }
609 PureExpr::Closure { body, .. } => {
610 self.scan_expr(body, results);
611 }
612 PureExpr::Struct { fields, .. } => {
613 for (_, field_expr) in fields {
614 self.scan_expr(field_expr, results);
615 }
616 }
617 PureExpr::Tuple(items) | PureExpr::Array(items) => {
618 for item in items {
619 self.scan_expr(item, results);
620 }
621 }
622 _ => {}
623 }
624 }
625}
626
#[cfg(test)]
mod tests {
    use super::*;
    use ryo_source::pure::MacroDelimiter;

    /// Builds `receiver.method()` with no turbofish and no arguments.
    fn method_call(receiver: &str, method: &str) -> PureExpr {
        PureExpr::MethodCall {
            receiver: Box::new(PureExpr::Path(receiver.into())),
            method: method.into(),
            turbofish: None,
            args: vec![],
        }
    }

    /// Builds a `name!` invocation with the given delimiter and empty tokens.
    fn macro_call(name: &str, delimiter: MacroDelimiter) -> PureExpr {
        PureExpr::Macro {
            name: name.into(),
            delimiter,
            tokens: "".into(),
        }
    }

    /// Child constraint requiring an exact name / text match.
    fn exact(text: &str) -> PatternExpr {
        PatternExpr::Name(NameMatcher::Exact(text.into()))
    }

    /// Pattern for a method call whose method is named exactly `method`.
    fn method_named(method: &str) -> CodePattern {
        CodePattern::new(NodeKind::MethodCall).with_child("method", exact(method))
    }

    #[test]
    fn test_match_method_call() {
        let pattern = method_named("unwrap");
        let matcher = ExprMatcher::new(&pattern);
        assert!(matcher.matches(&method_call("result", "unwrap")).is_some());
    }

    #[test]
    fn test_no_match_different_method() {
        let pattern = method_named("unwrap");
        let matcher = ExprMatcher::new(&pattern);
        assert!(matcher.matches(&method_call("result", "expect")).is_none());
    }

    #[test]
    fn test_capture_receiver() {
        let pattern = method_named("unwrap")
            .with_child("receiver", PatternExpr::Capture("$x".into()))
            .with_capture("$call");

        let ctx = ExprMatcher::new(&pattern)
            .matches(&method_call("my_result", "unwrap"))
            .expect("`my_result.unwrap()` should match");

        assert!(ctx.captures.contains_key("$x"));
        assert!(ctx.captures.contains_key("$call"));
        assert_eq!(ctx.captures["$x"].text, "my_result");
    }

    #[test]
    fn test_glob_matching() {
        let cases = [
            ("get_*", "get_name", true),
            ("*_id", "user_id", true),
            ("*", "anything", true),
            ("get_*", "set_name", false),
        ];
        for &(glob, name, expected) in cases.iter() {
            assert_eq!(
                match_glob(glob, name),
                expected,
                "glob {:?} vs {:?}",
                glob,
                name
            );
        }
    }

    #[test]
    fn test_literal_match_with_name_pattern() {
        let pattern = CodePattern::new(NodeKind::Literal).with_child("value", exact("true"));
        let matcher = ExprMatcher::new(&pattern);

        assert!(matcher.matches(&PureExpr::Lit("true".into())).is_some());
        assert!(matcher.matches(&PureExpr::Lit("false".into())).is_none());
        assert!(matcher.matches(&PureExpr::Lit("42".into())).is_none());
    }

    #[test]
    fn test_literal_match_with_literal_pattern() {
        let pattern = CodePattern::new(NodeKind::Literal)
            .with_child("value", PatternExpr::Literal(serde_json::json!("true")));
        let matcher = ExprMatcher::new(&pattern);

        assert!(matcher.matches(&PureExpr::Lit("true".into())).is_some());
        assert!(matcher.matches(&PureExpr::Lit("false".into())).is_none());
    }

    #[test]
    fn test_macro_call_match() {
        let pattern = CodePattern::new(NodeKind::MacroCall).with_child("macro", exact("todo"));
        let matcher = ExprMatcher::new(&pattern);

        assert!(matcher
            .matches(&macro_call("todo", MacroDelimiter::Paren))
            .is_some());
        assert!(matcher
            .matches(&macro_call("println", MacroDelimiter::Paren))
            .is_none());
        assert!(matcher
            .matches(&macro_call("vec", MacroDelimiter::Bracket))
            .is_none());
    }

    #[test]
    fn test_macro_call_no_filter_matches_all() {
        let pattern = CodePattern::new(NodeKind::MacroCall);
        let matcher = ExprMatcher::new(&pattern);

        assert!(matcher
            .matches(&macro_call("todo", MacroDelimiter::Paren))
            .is_some());
        assert!(matcher
            .matches(&macro_call("vec", MacroDelimiter::Bracket))
            .is_some());
    }

    #[test]
    fn test_path_match_exact() {
        let pattern = CodePattern::new(NodeKind::Path).with_child("path", exact("Filter::Recurse"));
        let matcher = ExprMatcher::new(&pattern);

        assert!(matcher
            .matches(&PureExpr::Path("Filter::Recurse".into()))
            .is_some());
        assert!(matcher
            .matches(&PureExpr::Path("Filter::Include".into()))
            .is_none());
        assert!(matcher
            .matches(&PureExpr::Path("something_else".into()))
            .is_none());
    }

    #[test]
    fn test_path_no_filter_matches_all() {
        let pattern = CodePattern::new(NodeKind::Path);
        let matcher = ExprMatcher::new(&pattern);
        assert!(matcher.matches(&PureExpr::Path("anything".into())).is_some());
    }
}