Skip to main content

uni_query/query/rewrite/
walker.rs

//! Expression tree walker for applying rewrite rules.
2use crate::query::rewrite::context::RewriteContext;
3use crate::query::rewrite::error::RewriteError;
4use crate::query::rewrite::registry::RewriteRegistry;
5use uni_cypher::ast::{Expr, MapProjectionItem, Query, Statement};
6
7/// Walks expression trees and applies rewrite rules
8pub struct ExpressionWalker<'a> {
9    registry: &'a RewriteRegistry,
10    context: RewriteContext,
11}
12
13impl<'a> ExpressionWalker<'a> {
14    /// Create a new expression walker
15    pub fn new(registry: &'a RewriteRegistry, context: RewriteContext) -> Self {
16        Self { registry, context }
17    }
18
19    /// Get the rewrite context (for accessing statistics)
20    pub fn context(&self) -> &RewriteContext {
21        &self.context
22    }
23
24    /// Get a mutable reference to the rewrite context
25    pub fn context_mut(&mut self) -> &mut RewriteContext {
26        &mut self.context
27    }
28
29    /// Take ownership of the context (for retrieving statistics)
30    pub fn into_context(self) -> RewriteContext {
31        self.context
32    }
33
34    /// Rewrite a complete statement
35    pub fn rewrite_statement(&mut self, stmt: Statement) -> Statement {
36        Statement {
37            clauses: stmt
38                .clauses
39                .into_iter()
40                .map(|c| self.rewrite_clause(c))
41                .collect(),
42        }
43    }
44
45    /// Rewrite a query
46    pub fn rewrite_query(&mut self, query: Query) -> Query {
47        match query {
48            Query::Single(stmt) => Query::Single(self.rewrite_statement(stmt)),
49            Query::Union { left, right, all } => Query::Union {
50                left: Box::new(self.rewrite_query(*left)),
51                right: Box::new(self.rewrite_query(*right)),
52                all,
53            },
54            Query::Schema(schema_cmd) => Query::Schema(schema_cmd),
55            Query::Transaction(txn_cmd) => Query::Transaction(txn_cmd),
56            Query::Explain(inner) => Query::Explain(Box::new(self.rewrite_query(*inner))),
57            Query::TimeTravel { .. } => {
58                unreachable!("TimeTravel should be resolved at API layer before rewriting")
59            }
60        }
61    }
62
63    /// Rewrite a clause
64    fn rewrite_clause(&mut self, clause: uni_cypher::ast::Clause) -> uni_cypher::ast::Clause {
65        use uni_cypher::ast::Clause;
66
67        match clause {
68            Clause::Match(m) => Clause::Match(self.rewrite_match_clause(m)),
69            Clause::Create(c) => Clause::Create(self.rewrite_create_clause(c)),
70            Clause::Return(r) => Clause::Return(self.rewrite_return_clause(r)),
71            Clause::With(w) => Clause::With(self.rewrite_with_clause(w)),
72            Clause::Unwind(u) => Clause::Unwind(self.rewrite_unwind_clause(u)),
73            Clause::Set(s) => Clause::Set(self.rewrite_set_clause(s)),
74            Clause::Delete(d) => Clause::Delete(self.rewrite_delete_clause(d)),
75            Clause::Remove(r) => Clause::Remove(self.rewrite_remove_clause(r)),
76            // Other clauses that don't contain expressions or are not yet handled
77            other => other,
78        }
79    }
80
81    fn rewrite_match_clause(
82        &mut self,
83        m: uni_cypher::ast::MatchClause,
84    ) -> uni_cypher::ast::MatchClause {
85        uni_cypher::ast::MatchClause {
86            optional: m.optional,
87            pattern: self.rewrite_pattern(m.pattern),
88            where_clause: m.where_clause.map(|e| self.rewrite_expr(e)),
89        }
90    }
91
92    fn rewrite_create_clause(
93        &mut self,
94        c: uni_cypher::ast::CreateClause,
95    ) -> uni_cypher::ast::CreateClause {
96        uni_cypher::ast::CreateClause {
97            pattern: self.rewrite_pattern(c.pattern),
98        }
99    }
100
101    fn rewrite_delete_clause(
102        &mut self,
103        d: uni_cypher::ast::DeleteClause,
104    ) -> uni_cypher::ast::DeleteClause {
105        uni_cypher::ast::DeleteClause {
106            detach: d.detach,
107            items: d.items.into_iter().map(|e| self.rewrite_expr(e)).collect(),
108        }
109    }
110
111    fn rewrite_set_clause(&mut self, s: uni_cypher::ast::SetClause) -> uni_cypher::ast::SetClause {
112        uni_cypher::ast::SetClause {
113            items: s
114                .items
115                .into_iter()
116                .map(|item| self.rewrite_set_item(item))
117                .collect(),
118        }
119    }
120
121    fn rewrite_set_item(&mut self, item: uni_cypher::ast::SetItem) -> uni_cypher::ast::SetItem {
122        use uni_cypher::ast::SetItem;
123
124        match item {
125            SetItem::Property { expr, value } => SetItem::Property {
126                expr: self.rewrite_expr(expr),
127                value: self.rewrite_expr(value),
128            },
129            SetItem::Variable { variable, value } => SetItem::Variable {
130                variable,
131                value: self.rewrite_expr(value),
132            },
133            SetItem::VariablePlus { variable, value } => SetItem::VariablePlus {
134                variable,
135                value: self.rewrite_expr(value),
136            },
137            SetItem::Labels { variable, labels } => SetItem::Labels { variable, labels },
138        }
139    }
140
141    fn rewrite_remove_clause(
142        &mut self,
143        r: uni_cypher::ast::RemoveClause,
144    ) -> uni_cypher::ast::RemoveClause {
145        uni_cypher::ast::RemoveClause {
146            items: r
147                .items
148                .into_iter()
149                .map(|item| self.rewrite_remove_item(item))
150                .collect(),
151        }
152    }
153
154    fn rewrite_remove_item(
155        &mut self,
156        item: uni_cypher::ast::RemoveItem,
157    ) -> uni_cypher::ast::RemoveItem {
158        use uni_cypher::ast::RemoveItem;
159
160        match item {
161            RemoveItem::Property(expr) => RemoveItem::Property(self.rewrite_expr(expr)),
162            RemoveItem::Labels { variable, labels } => RemoveItem::Labels { variable, labels },
163        }
164    }
165
166    fn rewrite_unwind_clause(
167        &mut self,
168        u: uni_cypher::ast::UnwindClause,
169    ) -> uni_cypher::ast::UnwindClause {
170        uni_cypher::ast::UnwindClause {
171            expr: self.rewrite_expr(u.expr),
172            variable: u.variable,
173        }
174    }
175
176    fn rewrite_pattern(&mut self, pattern: uni_cypher::ast::Pattern) -> uni_cypher::ast::Pattern {
177        uni_cypher::ast::Pattern {
178            paths: pattern
179                .paths
180                .into_iter()
181                .map(|path| self.rewrite_path_pattern(path))
182                .collect(),
183        }
184    }
185
186    fn rewrite_path_pattern(
187        &mut self,
188        path: uni_cypher::ast::PathPattern,
189    ) -> uni_cypher::ast::PathPattern {
190        uni_cypher::ast::PathPattern {
191            variable: path.variable,
192            elements: path
193                .elements
194                .into_iter()
195                .map(|elem| self.rewrite_pattern_element(elem))
196                .collect(),
197            shortest_path_mode: path.shortest_path_mode,
198        }
199    }
200
201    fn rewrite_pattern_element(
202        &mut self,
203        elem: uni_cypher::ast::PatternElement,
204    ) -> uni_cypher::ast::PatternElement {
205        use uni_cypher::ast::PatternElement;
206
207        match elem {
208            PatternElement::Node(node) => PatternElement::Node(uni_cypher::ast::NodePattern {
209                variable: node.variable,
210                labels: node.labels,
211                properties: node.properties.map(|expr| self.rewrite_expr(expr)),
212                where_clause: node.where_clause.map(|expr| self.rewrite_expr(expr)),
213            }),
214            PatternElement::Relationship(rel) => {
215                PatternElement::Relationship(uni_cypher::ast::RelationshipPattern {
216                    variable: rel.variable,
217                    types: rel.types,
218                    direction: rel.direction,
219                    properties: rel.properties.map(|expr| self.rewrite_expr(expr)),
220                    range: rel.range,
221                    where_clause: rel.where_clause.map(|expr| self.rewrite_expr(expr)),
222                })
223            }
224            PatternElement::Parenthesized { pattern, range } => PatternElement::Parenthesized {
225                pattern: Box::new(self.rewrite_path_pattern(*pattern)),
226                range,
227            },
228        }
229    }
230
231    fn rewrite_order_by(
232        &mut self,
233        order_by: Option<Vec<uni_cypher::ast::SortItem>>,
234    ) -> Option<Vec<uni_cypher::ast::SortItem>> {
235        order_by.map(|items| {
236            items
237                .into_iter()
238                .map(|item| uni_cypher::ast::SortItem {
239                    expr: self.rewrite_expr(item.expr),
240                    ascending: item.ascending,
241                })
242                .collect()
243        })
244    }
245
246    fn rewrite_return_clause(
247        &mut self,
248        r: uni_cypher::ast::ReturnClause,
249    ) -> uni_cypher::ast::ReturnClause {
250        uni_cypher::ast::ReturnClause {
251            distinct: r.distinct,
252            items: r
253                .items
254                .into_iter()
255                .map(|item| self.rewrite_return_item(item))
256                .collect(),
257            order_by: self.rewrite_order_by(r.order_by),
258            skip: r.skip.map(|e| self.rewrite_expr(e)),
259            limit: r.limit.map(|e| self.rewrite_expr(e)),
260        }
261    }
262
263    fn rewrite_return_item(
264        &mut self,
265        item: uni_cypher::ast::ReturnItem,
266    ) -> uni_cypher::ast::ReturnItem {
267        use uni_cypher::ast::ReturnItem;
268
269        match item {
270            ReturnItem::All => ReturnItem::All,
271            ReturnItem::Expr {
272                expr,
273                alias,
274                source_text,
275            } => ReturnItem::Expr {
276                expr: self.rewrite_expr(expr),
277                alias,
278                source_text,
279            },
280        }
281    }
282
283    fn rewrite_with_clause(
284        &mut self,
285        w: uni_cypher::ast::WithClause,
286    ) -> uni_cypher::ast::WithClause {
287        uni_cypher::ast::WithClause {
288            distinct: w.distinct,
289            items: w
290                .items
291                .into_iter()
292                .map(|item| self.rewrite_return_item(item))
293                .collect(),
294            order_by: self.rewrite_order_by(w.order_by),
295            skip: w.skip.map(|e| self.rewrite_expr(e)),
296            limit: w.limit.map(|e| self.rewrite_expr(e)),
297            where_clause: w.where_clause.map(|e| self.rewrite_expr(e)),
298        }
299    }
300
301    /// Walk and rewrite an expression tree
302    pub fn rewrite_expr(&mut self, expr: Expr) -> Expr {
303        match expr {
304            Expr::PatternComprehension {
305                path_variable,
306                pattern,
307                where_clause,
308                map_expr,
309            } => Expr::PatternComprehension {
310                path_variable,
311                pattern, // Pattern structure doesn't need rewriting
312                where_clause: where_clause.map(|e| Box::new(self.rewrite_expr(*e))),
313                map_expr: Box::new(self.rewrite_expr(*map_expr)),
314            },
315            // TODO: Recurse into CollectSubquery inner query for consistency
316            // with Exists and CountSubquery handling below
317            Expr::CollectSubquery(_) => expr,
318            // Try to rewrite function calls
319            Expr::FunctionCall {
320                name,
321                args,
322                distinct,
323                window_spec,
324            } => self.try_rewrite_function(name, args, distinct, window_spec),
325
326            // Recursively handle all other expression variants
327            Expr::BinaryOp { left, op, right } => Expr::BinaryOp {
328                left: Box::new(self.rewrite_expr(*left)),
329                op,
330                right: Box::new(self.rewrite_expr(*right)),
331            },
332
333            Expr::UnaryOp { op, expr } => Expr::UnaryOp {
334                op,
335                expr: Box::new(self.rewrite_expr(*expr)),
336            },
337
338            Expr::Property(expr, prop) => Expr::Property(Box::new(self.rewrite_expr(*expr)), prop),
339
340            Expr::List(exprs) => {
341                Expr::List(exprs.into_iter().map(|e| self.rewrite_expr(e)).collect())
342            }
343
344            Expr::Map(entries) => Expr::Map(
345                entries
346                    .into_iter()
347                    .map(|(k, v)| (k, self.rewrite_expr(v)))
348                    .collect(),
349            ),
350
351            Expr::Case {
352                expr,
353                when_then,
354                else_expr,
355            } => Expr::Case {
356                expr: expr.map(|e| Box::new(self.rewrite_expr(*e))),
357                when_then: when_then
358                    .into_iter()
359                    .map(|(w, t)| (self.rewrite_expr(w), self.rewrite_expr(t)))
360                    .collect(),
361                else_expr: else_expr.map(|e| Box::new(self.rewrite_expr(*e))),
362            },
363
364            Expr::Exists {
365                query,
366                from_pattern_predicate,
367            } => Expr::Exists {
368                query: Box::new(self.rewrite_query(*query)),
369                from_pattern_predicate,
370            },
371
372            Expr::CountSubquery(query) => Expr::CountSubquery(Box::new(self.rewrite_query(*query))),
373
374            Expr::IsNull(expr) => Expr::IsNull(Box::new(self.rewrite_expr(*expr))),
375
376            Expr::IsNotNull(expr) => Expr::IsNotNull(Box::new(self.rewrite_expr(*expr))),
377
378            Expr::IsUnique(expr) => Expr::IsUnique(Box::new(self.rewrite_expr(*expr))),
379
380            Expr::In { expr, list } => Expr::In {
381                expr: Box::new(self.rewrite_expr(*expr)),
382                list: Box::new(self.rewrite_expr(*list)),
383            },
384
385            Expr::ArrayIndex { array, index } => Expr::ArrayIndex {
386                array: Box::new(self.rewrite_expr(*array)),
387                index: Box::new(self.rewrite_expr(*index)),
388            },
389
390            Expr::ArraySlice { array, start, end } => Expr::ArraySlice {
391                array: Box::new(self.rewrite_expr(*array)),
392                start: start.map(|e| Box::new(self.rewrite_expr(*e))),
393                end: end.map(|e| Box::new(self.rewrite_expr(*e))),
394            },
395
396            Expr::Quantifier {
397                quantifier,
398                variable,
399                list,
400                predicate,
401            } => Expr::Quantifier {
402                quantifier,
403                variable,
404                list: Box::new(self.rewrite_expr(*list)),
405                predicate: Box::new(self.rewrite_expr(*predicate)),
406            },
407
408            Expr::Reduce {
409                accumulator,
410                init,
411                variable,
412                list,
413                expr,
414            } => Expr::Reduce {
415                accumulator,
416                init: Box::new(self.rewrite_expr(*init)),
417                variable,
418                list: Box::new(self.rewrite_expr(*list)),
419                expr: Box::new(self.rewrite_expr(*expr)),
420            },
421
422            Expr::ListComprehension {
423                variable,
424                list,
425                where_clause,
426                map_expr,
427            } => Expr::ListComprehension {
428                variable,
429                list: Box::new(self.rewrite_expr(*list)),
430                where_clause: where_clause.map(|e| Box::new(self.rewrite_expr(*e))),
431                map_expr: Box::new(self.rewrite_expr(*map_expr)),
432            },
433
434            Expr::ValidAt {
435                entity,
436                timestamp,
437                start_prop,
438                end_prop,
439            } => Expr::ValidAt {
440                entity: Box::new(self.rewrite_expr(*entity)),
441                timestamp: Box::new(self.rewrite_expr(*timestamp)),
442                start_prop,
443                end_prop,
444            },
445
446            Expr::MapProjection { base, items } => Expr::MapProjection {
447                base: Box::new(self.rewrite_expr(*base)),
448                items: items
449                    .into_iter()
450                    .map(|item| match item {
451                        MapProjectionItem::LiteralEntry(k, v) => {
452                            MapProjectionItem::LiteralEntry(k, Box::new(self.rewrite_expr(*v)))
453                        }
454                        other => other,
455                    })
456                    .collect(),
457            },
458
459            Expr::LabelCheck { expr, labels } => Expr::LabelCheck {
460                expr: Box::new(self.rewrite_expr(*expr)),
461                labels,
462            },
463
464            // Leaf nodes - no rewriting needed
465            Expr::Literal(_) | Expr::Parameter(_) | Expr::Variable(_) | Expr::Wildcard => expr,
466        }
467    }
468
469    /// Try to rewrite a function call
470    fn try_rewrite_function(
471        &mut self,
472        name: String,
473        args: Vec<Expr>,
474        distinct: bool,
475        window_spec: Option<uni_cypher::ast::WindowSpec>,
476    ) -> Expr {
477        // First, recursively rewrite arguments
478        let rewritten_args: Vec<Expr> =
479            args.into_iter().map(|arg| self.rewrite_expr(arg)).collect();
480
481        // Record that we visited this function
482        self.context.stats.record_visit();
483
484        // Helper to construct fallback function call
485        let make_fallback = |name, args| Expr::FunctionCall {
486            name,
487            args,
488            distinct,
489            window_spec: window_spec.clone(),
490        };
491
492        // Check if we have a rewrite rule for this function
493        let Some(rule) = self.registry.get_rule(&name) else {
494            return make_fallback(name, rewritten_args);
495        };
496
497        // Validate arguments
498        if let Err(e) = rule.validate_args(&rewritten_args) {
499            self.context.stats.record_failure(&name, e);
500            if self.context.config.verbose_logging {
501                tracing::debug!(
502                    "Rewrite validation failed for {}: {:?}",
503                    name,
504                    self.context.stats.errors.last()
505                );
506            }
507            return make_fallback(name, rewritten_args);
508        }
509
510        // Check if rule is applicable in current context
511        if !rule.is_applicable(&self.context) {
512            let error = RewriteError::NotApplicable {
513                reason: "Context requirements not met".to_string(),
514            };
515            self.context.stats.record_failure(&name, error);
516            if self.context.config.verbose_logging {
517                tracing::debug!("Rewrite not applicable for {}", name);
518            }
519            return make_fallback(name, rewritten_args);
520        }
521
522        // Apply rewrite
523        match rule.rewrite(rewritten_args.clone(), &self.context) {
524            Ok(rewritten_expr) => {
525                self.context.stats.record_success(&name);
526                if self.context.config.verbose_logging {
527                    tracing::debug!("Rewrote function call: {} -> {:?}", name, rewritten_expr);
528                } else {
529                    tracing::info!("Rewrote function: {}", name);
530                }
531                rewritten_expr
532            }
533            Err(e) => {
534                self.context.stats.record_failure(&name, e);
535                if self.context.config.verbose_logging {
536                    tracing::debug!(
537                        "Rewrite failed for {}: {:?}",
538                        name,
539                        self.context.stats.errors.last()
540                    );
541                }
542                make_fallback(name, rewritten_args)
543            }
544        }
545    }
546}
547
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::rewrite::context::RewriteConfig;
    use uni_cypher::ast::CypherLiteral;

    /// Build a walker over `registry` with the default rewrite configuration.
    fn walker(registry: &RewriteRegistry) -> ExpressionWalker<'_> {
        ExpressionWalker::new(
            registry,
            RewriteContext::with_config(RewriteConfig::default()),
        )
    }

    /// Build a plain (non-DISTINCT, non-windowed) single-argument call.
    fn call(name: &str, arg: Expr) -> Expr {
        Expr::FunctionCall {
            name: name.into(),
            args: vec![arg],
            distinct: false,
            window_spec: None,
        }
    }

    #[test]
    fn test_walker_visits_nested_expressions() {
        let registry = RewriteRegistry::new();
        let mut walker = walker(&registry);

        // Nested expression with function calls
        let expr = Expr::BinaryOp {
            left: Box::new(call("func1", Expr::Literal(CypherLiteral::Integer(1)))),
            op: uni_cypher::ast::BinaryOp::And,
            right: Box::new(call("func2", Expr::Literal(CypherLiteral::Integer(2)))),
        };

        let _ = walker.rewrite_expr(expr);

        // Both function calls should have been visited
        assert_eq!(walker.context().stats.functions_visited, 2);
    }

    #[test]
    fn test_walker_visits_function_arguments() {
        let registry = RewriteRegistry::new();
        let mut walker = walker(&registry);

        // A call nested inside another call's argument list.
        let expr = call(
            "outer",
            call("inner", Expr::Literal(CypherLiteral::Integer(1))),
        );

        let _ = walker.rewrite_expr(expr);

        // Arguments are walked before the outer call, so both are visited.
        assert_eq!(walker.context().stats.functions_visited, 2);
    }

    #[test]
    fn test_walker_fallback_without_rules() {
        let registry = RewriteRegistry::new();
        let mut walker = walker(&registry);

        let rewritten = walker.rewrite_expr(call("unknown", Expr::Literal(CypherLiteral::Integer(1))));

        // Without a registered rule the call must come back intact: same
        // name, same arguments, same DISTINCT flag and no window spec.
        match rewritten {
            Expr::FunctionCall {
                name,
                args,
                distinct,
                window_spec,
            } => {
                assert_eq!(name, "unknown");
                assert_eq!(args.len(), 1);
                assert!(matches!(args[0], Expr::Literal(CypherLiteral::Integer(1))));
                assert!(!distinct);
                assert!(window_spec.is_none());
            }
            _ => panic!("expected the unmatched call to be returned as a FunctionCall"),
        }
        assert_eq!(walker.context().stats.functions_visited, 1);
    }
}