Skip to main content

ryo_pattern/
code_pattern.rs

1//! CodePattern - AST structural matching
2//!
3//! Describes structural patterns over AST nodes for pattern matching.
4
5use schemars::JsonSchema;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
/// AST node type to match.
///
/// Variants are grouped below into expressions, items, parts of items or
/// expressions, and special kinds. Each variant is serialized under its
/// PascalCase name (for example `"MethodCall"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "PascalCase")]
pub enum NodeKind {
    // Expressions
    /// Any expression.
    Expr,
    /// Literal (number / string / bool / etc.).
    Literal,
    /// Path expression (`foo::bar`).
    Path,
    /// Method call (`x.foo()`).
    MethodCall,
    /// Free function call (`foo()`).
    FunctionCall,
    /// Macro invocation (`println!()`).
    MacroCall,
    /// Binary operator expression.
    BinaryOp,
    /// Unary operator expression.
    UnaryOp,
    /// `if` / `if let` expression.
    If,
    /// `match` expression.
    Match,
    /// `loop` / `while` / `for` expression.
    Loop,
    /// Block expression `{ ... }`.
    Block,
    /// Closure expression (`|x| ...`).
    Closure,
    /// `await` expression.
    Await,
    /// `?` (try) expression.
    Try,
    /// `return` expression.
    Return,
    /// Index expression (`a[i]`).
    Index,

    // Items
    /// `fn` item.
    Function,
    /// `struct` item.
    Struct,
    /// `enum` item.
    Enum,
    /// `trait` item.
    Trait,
    /// `impl` block.
    Impl,
    /// `mod` item.
    Mod,
    /// `use` declaration.
    Use,
    /// `const` item.
    Const,
    /// `static` item.
    Static,
    /// `type` alias.
    TypeAlias,

    // Parts
    /// Struct / enum field.
    Field,
    /// Enum variant.
    Variant,
    /// Function parameter.
    Param,
    /// Call site argument.
    Arg,
    /// Generic type / lifetime argument.
    GenericArg,
    /// Lifetime token.
    Lifetime,
    /// Attribute (`#[...]`).
    Attribute,

    // Special
    /// `let` expression (`if let` / `while let`).
    LetExpr,
    /// Wildcard (matches any node).
    Wildcard,
    /// Unit `()`.
    Unit,
}
95
/// Name matcher for symbol/method names.
///
/// Serialized untagged: `Exact` is written as a plain string and `Pattern`
/// as a map of the [`NamePattern`] fields, with no variant name in the data.
/// On deserialization the variants are tried in declaration order.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
pub enum NameMatcher {
    /// Exact match against the given name.
    Exact(String),
    /// Pattern-based match (glob, regex, prefix, suffix, or substring).
    Pattern(NamePattern),
}
105
/// Pattern-based name matching.
///
/// Every field is optional: a missing field deserializes to `None`, and a
/// `None` field is omitted when serializing.
///
/// NOTE(review): how several populated fields combine (all must hold vs. any
/// may hold) is decided by the matcher that consumes this struct, not here.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub struct NamePattern {
    /// Glob pattern (e.g., "get_*")
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub glob: Option<String>,
    /// Regex pattern (e.g., "^is_"); stored as text, not compiled here
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub regex: Option<String>,
    /// Starts with prefix
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub starts_with: Option<String>,
    /// Ends with suffix
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub ends_with: Option<String>,
    /// Contains substring
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub contains: Option<String>,
}
125
/// Pattern expression (recursive).
///
/// Serialized untagged, so no variant name appears in the data. On
/// deserialization serde tries the variants in declaration order and keeps
/// the first one that fits, which has consequences worth knowing:
///
/// - a plain string always becomes `Name(NameMatcher::Exact(..))`, never
///   `Capture`;
/// - `Literal` accepts any value, so `Wildcard` and `Capture`, which are
///   declared after it, are not produced by deserialization.
///
/// NOTE(review): `Wildcard` and `Capture` therefore look construct-only
/// (built in code); confirm whether they are meant to round-trip. Likewise a
/// map without a `node` key presumably falls through to
/// `Name(NameMatcher::Pattern(..))` before `Literal` — verify if object
/// literals matter.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
pub enum PatternExpr {
    /// Nested code pattern (boxed because the type is recursive)
    Pattern(Box<CodePattern>),
    /// Name matcher
    Name(NameMatcher),
    /// Literal value, held as arbitrary JSON
    Literal(serde_json::Value),
    /// Wildcard (_) - matches anything
    Wildcard,
    /// Capture variable reference (e.g., "$VAR")
    Capture(String),
}
141
/// Match arm pattern for Match expressions.
///
/// Both fields are optional; a `None` field is omitted when serializing and
/// a missing field deserializes to `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct ArmPattern {
    /// Match on the pattern path (e.g., "Some", "None", "Ok", "Err")
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub pattern_path: Option<String>,

    /// Match on the arm body expression (boxed because `CodePattern` is recursive)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub body: Option<Box<CodePattern>>,
}
153
/// AST Pattern for structural matching.
///
/// Only `node` is required. In serialized form the `children` entries are
/// flattened, i.e. they appear at the same level as `node`, `capture`, and
/// the other declared fields rather than under a `children` key.
///
/// NOTE(review): because of the flattening, a child whose name equals a
/// declared field (`node`, `arm_count`, `arms`, `capture`, `ellipsis`)
/// presumably cannot round-trip — confirm child names never collide.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct CodePattern {
    /// AST node type to match
    pub node: NodeKind,

    /// Required arm count (for Match nodes)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub arm_count: Option<usize>,

    /// Arm patterns (for Match nodes, order-independent matching)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub arms: Option<Vec<ArmPattern>>,

    /// Child patterns (field name -> pattern); flattened into the parent map
    #[serde(flatten)]
    pub children: HashMap<String, PatternExpr>,

    /// Capture variable (e.g., "$RECEIVER")
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub capture: Option<String>,

    /// Match zero or more items (ellipsis); defaults to `false` when absent
    /// and, unlike the optional fields, is always written when serializing
    #[serde(default)]
    pub ellipsis: bool,
}
180
181impl CodePattern {
182    /// Create a new CodePattern for the given node kind
183    pub fn new(node: NodeKind) -> Self {
184        Self {
185            node,
186            arm_count: None,
187            arms: None,
188            children: HashMap::new(),
189            capture: None,
190            ellipsis: false,
191        }
192    }
193
194    /// Add a child pattern
195    pub fn with_child(mut self, name: impl Into<String>, pattern: PatternExpr) -> Self {
196        self.children.insert(name.into(), pattern);
197        self
198    }
199
200    /// Set capture variable
201    pub fn with_capture(mut self, var: impl Into<String>) -> Self {
202        self.capture = Some(var.into());
203        self
204    }
205
206    /// Set ellipsis mode
207    pub fn with_ellipsis(mut self) -> Self {
208        self.ellipsis = true;
209        self
210    }
211}
212
/// Body match conditions for symbol bodies.
///
/// Every list is optional; `None` lists are omitted when serializing, and
/// the `Default` value has all three unset.
///
/// NOTE(review): as documented, `contains` and `all_of` read as the same
/// condition (every listed pattern needs at least one match) — confirm the
/// intended difference against the matcher that consumes this struct.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct BodyMatch {
    /// At least one node matches each pattern (existential)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub contains: Option<Vec<CodePattern>>,

    /// No node matches these patterns (negation)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub not_contains: Option<Vec<CodePattern>>,

    /// All listed patterns must have at least one match
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub all_of: Option<Vec<CodePattern>>,
}
228
229impl BodyMatch {
230    /// Construct an empty `BodyMatch`.
231    pub fn new() -> Self {
232        Self::default()
233    }
234
235    /// Add a contains pattern
236    pub fn contains(mut self, pattern: CodePattern) -> Self {
237        self.contains.get_or_insert_with(Vec::new).push(pattern);
238        self
239    }
240
241    /// Add a not_contains pattern
242    pub fn not_contains(mut self, pattern: CodePattern) -> Self {
243        self.not_contains.get_or_insert_with(Vec::new).push(pattern);
244        self
245    }
246
247    /// Add an all_of pattern
248    pub fn all_of(mut self, pattern: CodePattern) -> Self {
249        self.all_of.get_or_insert_with(Vec::new).push(pattern);
250        self
251    }
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods store exactly what they were given and leave the
    /// remaining fields at their defaults.
    #[test]
    fn test_code_pattern_builder() {
        let pattern = CodePattern::new(NodeKind::MethodCall)
            .with_child(
                "method",
                PatternExpr::Name(NameMatcher::Exact("unwrap".into())),
            )
            .with_capture("$UNWRAP");

        assert_eq!(pattern.node, NodeKind::MethodCall);
        assert_eq!(pattern.capture.as_deref(), Some("$UNWRAP"));
        assert_eq!(
            pattern.children.get("method"),
            Some(&PatternExpr::Name(NameMatcher::Exact("unwrap".into())))
        );
        assert!(!pattern.ellipsis);
    }

    /// Each `BodyMatch` builder touches only its own list.
    #[test]
    fn test_body_match_builder() {
        let body = BodyMatch::new()
            .contains(CodePattern::new(NodeKind::MethodCall))
            .not_contains(CodePattern::new(NodeKind::MacroCall));

        assert_eq!(body.contains.as_ref().map(Vec::len), Some(1));
        assert_eq!(body.not_contains.as_ref().map(Vec::len), Some(1));
        assert!(body.all_of.is_none());
    }

    /// A pattern survives a JSON round trip unchanged. The whole value is
    /// compared, not just `node`, so lost children or flags are caught.
    #[test]
    fn test_serialize_deserialize() {
        let pattern = CodePattern::new(NodeKind::MethodCall).with_child(
            "method",
            PatternExpr::Name(NameMatcher::Exact("unwrap".into())),
        );

        let json = serde_json::to_string(&pattern).unwrap();
        let deserialized: CodePattern = serde_json::from_str(&json).unwrap();

        assert_eq!(pattern, deserialized);
    }
}