Skip to main content

uni_query/query/rewrite/
walker.rs

//! Expression tree walker for applying rewrite rules
2use crate::query::rewrite::context::RewriteContext;
3use crate::query::rewrite::error::RewriteError;
4use crate::query::rewrite::registry::RewriteRegistry;
5use uni_cypher::ast::{Expr, MapProjectionItem, Query, Statement};
6
7/// Walks expression trees and applies rewrite rules
8pub struct ExpressionWalker<'a> {
9    registry: &'a RewriteRegistry,
10    context: RewriteContext,
11}
12
13impl<'a> ExpressionWalker<'a> {
14    /// Create a new expression walker
15    pub fn new(registry: &'a RewriteRegistry, context: RewriteContext) -> Self {
16        Self { registry, context }
17    }
18
19    /// Get the rewrite context (for accessing statistics)
20    pub fn context(&self) -> &RewriteContext {
21        &self.context
22    }
23
24    /// Get a mutable reference to the rewrite context
25    pub fn context_mut(&mut self) -> &mut RewriteContext {
26        &mut self.context
27    }
28
29    /// Take ownership of the context (for retrieving statistics)
30    pub fn into_context(self) -> RewriteContext {
31        self.context
32    }
33
34    /// Rewrite a complete statement
35    pub fn rewrite_statement(&mut self, stmt: Statement) -> Statement {
36        Statement {
37            clauses: stmt
38                .clauses
39                .into_iter()
40                .map(|c| self.rewrite_clause(c))
41                .collect(),
42        }
43    }
44
45    /// Rewrite a query
46    pub fn rewrite_query(&mut self, query: Query) -> Query {
47        match query {
48            Query::Single(stmt) => Query::Single(self.rewrite_statement(stmt)),
49            Query::Union { left, right, all } => Query::Union {
50                left: Box::new(self.rewrite_query(*left)),
51                right: Box::new(self.rewrite_query(*right)),
52                all,
53            },
54            Query::Schema(schema_cmd) => Query::Schema(schema_cmd),
55            Query::Explain(inner) => Query::Explain(Box::new(self.rewrite_query(*inner))),
56            Query::TimeTravel { .. } => {
57                unreachable!("TimeTravel should be resolved at API layer before rewriting")
58            }
59        }
60    }
61
62    /// Rewrite a clause
63    fn rewrite_clause(&mut self, clause: uni_cypher::ast::Clause) -> uni_cypher::ast::Clause {
64        use uni_cypher::ast::Clause;
65
66        match clause {
67            Clause::Match(m) => Clause::Match(self.rewrite_match_clause(m)),
68            Clause::Create(c) => Clause::Create(self.rewrite_create_clause(c)),
69            Clause::Return(r) => Clause::Return(self.rewrite_return_clause(r)),
70            Clause::With(w) => Clause::With(self.rewrite_with_clause(w)),
71            Clause::Unwind(u) => Clause::Unwind(self.rewrite_unwind_clause(u)),
72            Clause::Set(s) => Clause::Set(self.rewrite_set_clause(s)),
73            Clause::Delete(d) => Clause::Delete(self.rewrite_delete_clause(d)),
74            Clause::Remove(r) => Clause::Remove(self.rewrite_remove_clause(r)),
75            // Other clauses that don't contain expressions or are not yet handled
76            other => other,
77        }
78    }
79
80    fn rewrite_match_clause(
81        &mut self,
82        m: uni_cypher::ast::MatchClause,
83    ) -> uni_cypher::ast::MatchClause {
84        uni_cypher::ast::MatchClause {
85            optional: m.optional,
86            pattern: self.rewrite_pattern(m.pattern),
87            where_clause: m.where_clause.map(|e| self.rewrite_expr(e)),
88        }
89    }
90
91    fn rewrite_create_clause(
92        &mut self,
93        c: uni_cypher::ast::CreateClause,
94    ) -> uni_cypher::ast::CreateClause {
95        uni_cypher::ast::CreateClause {
96            pattern: self.rewrite_pattern(c.pattern),
97        }
98    }
99
100    fn rewrite_delete_clause(
101        &mut self,
102        d: uni_cypher::ast::DeleteClause,
103    ) -> uni_cypher::ast::DeleteClause {
104        uni_cypher::ast::DeleteClause {
105            detach: d.detach,
106            items: d.items.into_iter().map(|e| self.rewrite_expr(e)).collect(),
107        }
108    }
109
110    fn rewrite_set_clause(&mut self, s: uni_cypher::ast::SetClause) -> uni_cypher::ast::SetClause {
111        uni_cypher::ast::SetClause {
112            items: s
113                .items
114                .into_iter()
115                .map(|item| self.rewrite_set_item(item))
116                .collect(),
117        }
118    }
119
120    fn rewrite_set_item(&mut self, item: uni_cypher::ast::SetItem) -> uni_cypher::ast::SetItem {
121        use uni_cypher::ast::SetItem;
122
123        match item {
124            SetItem::Property { expr, value } => SetItem::Property {
125                expr: self.rewrite_expr(expr),
126                value: self.rewrite_expr(value),
127            },
128            SetItem::Variable { variable, value } => SetItem::Variable {
129                variable,
130                value: self.rewrite_expr(value),
131            },
132            SetItem::VariablePlus { variable, value } => SetItem::VariablePlus {
133                variable,
134                value: self.rewrite_expr(value),
135            },
136            SetItem::Labels { variable, labels } => SetItem::Labels { variable, labels },
137        }
138    }
139
140    fn rewrite_remove_clause(
141        &mut self,
142        r: uni_cypher::ast::RemoveClause,
143    ) -> uni_cypher::ast::RemoveClause {
144        uni_cypher::ast::RemoveClause {
145            items: r
146                .items
147                .into_iter()
148                .map(|item| self.rewrite_remove_item(item))
149                .collect(),
150        }
151    }
152
153    fn rewrite_remove_item(
154        &mut self,
155        item: uni_cypher::ast::RemoveItem,
156    ) -> uni_cypher::ast::RemoveItem {
157        use uni_cypher::ast::RemoveItem;
158
159        match item {
160            RemoveItem::Property(expr) => RemoveItem::Property(self.rewrite_expr(expr)),
161            RemoveItem::Labels { variable, labels } => RemoveItem::Labels { variable, labels },
162        }
163    }
164
165    fn rewrite_unwind_clause(
166        &mut self,
167        u: uni_cypher::ast::UnwindClause,
168    ) -> uni_cypher::ast::UnwindClause {
169        uni_cypher::ast::UnwindClause {
170            expr: self.rewrite_expr(u.expr),
171            variable: u.variable,
172        }
173    }
174
175    fn rewrite_pattern(&mut self, pattern: uni_cypher::ast::Pattern) -> uni_cypher::ast::Pattern {
176        uni_cypher::ast::Pattern {
177            paths: pattern
178                .paths
179                .into_iter()
180                .map(|path| self.rewrite_path_pattern(path))
181                .collect(),
182        }
183    }
184
185    fn rewrite_path_pattern(
186        &mut self,
187        path: uni_cypher::ast::PathPattern,
188    ) -> uni_cypher::ast::PathPattern {
189        uni_cypher::ast::PathPattern {
190            variable: path.variable,
191            elements: path
192                .elements
193                .into_iter()
194                .map(|elem| self.rewrite_pattern_element(elem))
195                .collect(),
196            shortest_path_mode: path.shortest_path_mode,
197        }
198    }
199
200    fn rewrite_pattern_element(
201        &mut self,
202        elem: uni_cypher::ast::PatternElement,
203    ) -> uni_cypher::ast::PatternElement {
204        use uni_cypher::ast::PatternElement;
205
206        match elem {
207            PatternElement::Node(node) => PatternElement::Node(uni_cypher::ast::NodePattern {
208                variable: node.variable,
209                labels: node.labels,
210                properties: node.properties.map(|expr| self.rewrite_expr(expr)),
211                where_clause: node.where_clause.map(|expr| self.rewrite_expr(expr)),
212            }),
213            PatternElement::Relationship(rel) => {
214                PatternElement::Relationship(uni_cypher::ast::RelationshipPattern {
215                    variable: rel.variable,
216                    types: rel.types,
217                    direction: rel.direction,
218                    properties: rel.properties.map(|expr| self.rewrite_expr(expr)),
219                    range: rel.range,
220                    where_clause: rel.where_clause.map(|expr| self.rewrite_expr(expr)),
221                })
222            }
223            PatternElement::Parenthesized { pattern, range } => PatternElement::Parenthesized {
224                pattern: Box::new(self.rewrite_path_pattern(*pattern)),
225                range,
226            },
227        }
228    }
229
230    fn rewrite_order_by(
231        &mut self,
232        order_by: Option<Vec<uni_cypher::ast::SortItem>>,
233    ) -> Option<Vec<uni_cypher::ast::SortItem>> {
234        order_by.map(|items| {
235            items
236                .into_iter()
237                .map(|item| uni_cypher::ast::SortItem {
238                    expr: self.rewrite_expr(item.expr),
239                    ascending: item.ascending,
240                })
241                .collect()
242        })
243    }
244
245    fn rewrite_return_clause(
246        &mut self,
247        r: uni_cypher::ast::ReturnClause,
248    ) -> uni_cypher::ast::ReturnClause {
249        uni_cypher::ast::ReturnClause {
250            distinct: r.distinct,
251            items: r
252                .items
253                .into_iter()
254                .map(|item| self.rewrite_return_item(item))
255                .collect(),
256            order_by: self.rewrite_order_by(r.order_by),
257            skip: r.skip.map(|e| self.rewrite_expr(e)),
258            limit: r.limit.map(|e| self.rewrite_expr(e)),
259        }
260    }
261
262    fn rewrite_return_item(
263        &mut self,
264        item: uni_cypher::ast::ReturnItem,
265    ) -> uni_cypher::ast::ReturnItem {
266        use uni_cypher::ast::ReturnItem;
267
268        match item {
269            ReturnItem::All => ReturnItem::All,
270            ReturnItem::Expr {
271                expr,
272                alias,
273                source_text,
274            } => ReturnItem::Expr {
275                expr: self.rewrite_expr(expr),
276                alias,
277                source_text,
278            },
279        }
280    }
281
282    fn rewrite_with_clause(
283        &mut self,
284        w: uni_cypher::ast::WithClause,
285    ) -> uni_cypher::ast::WithClause {
286        uni_cypher::ast::WithClause {
287            distinct: w.distinct,
288            items: w
289                .items
290                .into_iter()
291                .map(|item| self.rewrite_return_item(item))
292                .collect(),
293            order_by: self.rewrite_order_by(w.order_by),
294            skip: w.skip.map(|e| self.rewrite_expr(e)),
295            limit: w.limit.map(|e| self.rewrite_expr(e)),
296            where_clause: w.where_clause.map(|e| self.rewrite_expr(e)),
297        }
298    }
299
300    /// Walk and rewrite an expression tree
301    pub fn rewrite_expr(&mut self, expr: Expr) -> Expr {
302        match expr {
303            Expr::PatternComprehension {
304                path_variable,
305                pattern,
306                where_clause,
307                map_expr,
308            } => Expr::PatternComprehension {
309                path_variable,
310                pattern, // Pattern structure doesn't need rewriting
311                where_clause: where_clause.map(|e| Box::new(self.rewrite_expr(*e))),
312                map_expr: Box::new(self.rewrite_expr(*map_expr)),
313            },
314            Expr::CollectSubquery(query) => {
315                Expr::CollectSubquery(Box::new(self.rewrite_query(*query)))
316            }
317            // Try to rewrite function calls
318            Expr::FunctionCall {
319                name,
320                args,
321                distinct,
322                window_spec,
323            } => self.try_rewrite_function(name, args, distinct, window_spec),
324
325            // Recursively handle all other expression variants
326            Expr::BinaryOp { left, op, right } => Expr::BinaryOp {
327                left: Box::new(self.rewrite_expr(*left)),
328                op,
329                right: Box::new(self.rewrite_expr(*right)),
330            },
331
332            Expr::UnaryOp { op, expr } => Expr::UnaryOp {
333                op,
334                expr: Box::new(self.rewrite_expr(*expr)),
335            },
336
337            Expr::Property(expr, prop) => Expr::Property(Box::new(self.rewrite_expr(*expr)), prop),
338
339            Expr::List(exprs) => {
340                Expr::List(exprs.into_iter().map(|e| self.rewrite_expr(e)).collect())
341            }
342
343            Expr::Map(entries) => Expr::Map(
344                entries
345                    .into_iter()
346                    .map(|(k, v)| (k, self.rewrite_expr(v)))
347                    .collect(),
348            ),
349
350            Expr::Case {
351                expr,
352                when_then,
353                else_expr,
354            } => Expr::Case {
355                expr: expr.map(|e| Box::new(self.rewrite_expr(*e))),
356                when_then: when_then
357                    .into_iter()
358                    .map(|(w, t)| (self.rewrite_expr(w), self.rewrite_expr(t)))
359                    .collect(),
360                else_expr: else_expr.map(|e| Box::new(self.rewrite_expr(*e))),
361            },
362
363            Expr::Exists {
364                query,
365                from_pattern_predicate,
366            } => Expr::Exists {
367                query: Box::new(self.rewrite_query(*query)),
368                from_pattern_predicate,
369            },
370
371            Expr::CountSubquery(query) => Expr::CountSubquery(Box::new(self.rewrite_query(*query))),
372
373            Expr::IsNull(expr) => Expr::IsNull(Box::new(self.rewrite_expr(*expr))),
374
375            Expr::IsNotNull(expr) => Expr::IsNotNull(Box::new(self.rewrite_expr(*expr))),
376
377            Expr::IsUnique(expr) => Expr::IsUnique(Box::new(self.rewrite_expr(*expr))),
378
379            Expr::In { expr, list } => Expr::In {
380                expr: Box::new(self.rewrite_expr(*expr)),
381                list: Box::new(self.rewrite_expr(*list)),
382            },
383
384            Expr::ArrayIndex { array, index } => Expr::ArrayIndex {
385                array: Box::new(self.rewrite_expr(*array)),
386                index: Box::new(self.rewrite_expr(*index)),
387            },
388
389            Expr::ArraySlice { array, start, end } => Expr::ArraySlice {
390                array: Box::new(self.rewrite_expr(*array)),
391                start: start.map(|e| Box::new(self.rewrite_expr(*e))),
392                end: end.map(|e| Box::new(self.rewrite_expr(*e))),
393            },
394
395            Expr::Quantifier {
396                quantifier,
397                variable,
398                list,
399                predicate,
400            } => Expr::Quantifier {
401                quantifier,
402                variable,
403                list: Box::new(self.rewrite_expr(*list)),
404                predicate: Box::new(self.rewrite_expr(*predicate)),
405            },
406
407            Expr::Reduce {
408                accumulator,
409                init,
410                variable,
411                list,
412                expr,
413            } => Expr::Reduce {
414                accumulator,
415                init: Box::new(self.rewrite_expr(*init)),
416                variable,
417                list: Box::new(self.rewrite_expr(*list)),
418                expr: Box::new(self.rewrite_expr(*expr)),
419            },
420
421            Expr::ListComprehension {
422                variable,
423                list,
424                where_clause,
425                map_expr,
426            } => Expr::ListComprehension {
427                variable,
428                list: Box::new(self.rewrite_expr(*list)),
429                where_clause: where_clause.map(|e| Box::new(self.rewrite_expr(*e))),
430                map_expr: Box::new(self.rewrite_expr(*map_expr)),
431            },
432
433            Expr::ValidAt {
434                entity,
435                timestamp,
436                start_prop,
437                end_prop,
438            } => Expr::ValidAt {
439                entity: Box::new(self.rewrite_expr(*entity)),
440                timestamp: Box::new(self.rewrite_expr(*timestamp)),
441                start_prop,
442                end_prop,
443            },
444
445            Expr::MapProjection { base, items } => Expr::MapProjection {
446                base: Box::new(self.rewrite_expr(*base)),
447                items: items
448                    .into_iter()
449                    .map(|item| match item {
450                        MapProjectionItem::LiteralEntry(k, v) => {
451                            MapProjectionItem::LiteralEntry(k, Box::new(self.rewrite_expr(*v)))
452                        }
453                        other => other,
454                    })
455                    .collect(),
456            },
457
458            Expr::LabelCheck { expr, labels } => Expr::LabelCheck {
459                expr: Box::new(self.rewrite_expr(*expr)),
460                labels,
461            },
462
463            // Leaf nodes - no rewriting needed
464            Expr::Literal(_) | Expr::Parameter(_) | Expr::Variable(_) | Expr::Wildcard => expr,
465        }
466    }
467
468    /// Try to rewrite a function call
469    fn try_rewrite_function(
470        &mut self,
471        name: String,
472        args: Vec<Expr>,
473        distinct: bool,
474        window_spec: Option<uni_cypher::ast::WindowSpec>,
475    ) -> Expr {
476        // First, recursively rewrite arguments
477        let rewritten_args: Vec<Expr> =
478            args.into_iter().map(|arg| self.rewrite_expr(arg)).collect();
479
480        // Record that we visited this function
481        self.context.stats.record_visit();
482
483        // Helper to construct fallback function call
484        let make_fallback = |name, args| Expr::FunctionCall {
485            name,
486            args,
487            distinct,
488            window_spec: window_spec.clone(),
489        };
490
491        // Check if we have a rewrite rule for this function
492        let Some(rule) = self.registry.get_rule(&name) else {
493            return make_fallback(name, rewritten_args);
494        };
495
496        // Validate arguments
497        if let Err(e) = rule.validate_args(&rewritten_args) {
498            self.context.stats.record_failure(&name, e);
499            if self.context.config.verbose_logging {
500                tracing::debug!(
501                    "Rewrite validation failed for {}: {:?}",
502                    name,
503                    self.context.stats.errors.last()
504                );
505            }
506            return make_fallback(name, rewritten_args);
507        }
508
509        // Check if rule is applicable in current context
510        if !rule.is_applicable(&self.context) {
511            let error = RewriteError::NotApplicable {
512                reason: "Context requirements not met".to_string(),
513            };
514            self.context.stats.record_failure(&name, error);
515            if self.context.config.verbose_logging {
516                tracing::debug!("Rewrite not applicable for {}", name);
517            }
518            return make_fallback(name, rewritten_args);
519        }
520
521        // Apply rewrite
522        match rule.rewrite(rewritten_args.clone(), &self.context) {
523            Ok(rewritten_expr) => {
524                self.context.stats.record_success(&name);
525                if self.context.config.verbose_logging {
526                    tracing::debug!("Rewrote function call: {} -> {:?}", name, rewritten_expr);
527                } else {
528                    tracing::info!("Rewrote function: {}", name);
529                }
530                rewritten_expr
531            }
532            Err(e) => {
533                self.context.stats.record_failure(&name, e);
534                if self.context.config.verbose_logging {
535                    tracing::debug!(
536                        "Rewrite failed for {}: {:?}",
537                        name,
538                        self.context.stats.errors.last()
539                    );
540                }
541                make_fallback(name, rewritten_args)
542            }
543        }
544    }
545}
546
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::rewrite::context::RewriteConfig;
    use uni_cypher::ast::CypherLiteral;

    /// Every nested function call is visited exactly once, and the surrounding
    /// operator node is preserved.
    #[test]
    fn test_walker_visits_nested_expressions() {
        let registry = RewriteRegistry::new();
        let config = RewriteConfig::default();
        let mut walker = ExpressionWalker::new(&registry, RewriteContext::with_config(config));

        // Nested expression with function calls
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::FunctionCall {
                name: "func1".into(),
                args: vec![Expr::Literal(CypherLiteral::Integer(1))],
                distinct: false,
                window_spec: None,
            }),
            op: uni_cypher::ast::BinaryOp::And,
            right: Box::new(Expr::FunctionCall {
                name: "func2".into(),
                args: vec![Expr::Literal(CypherLiteral::Integer(2))],
                distinct: false,
                window_spec: None,
            }),
        };

        let result = walker.rewrite_expr(expr);

        // The operator node itself is not a rewrite target and must survive
        assert!(matches!(result, Expr::BinaryOp { .. }));
        // Both function calls should have been visited
        assert_eq!(walker.context().stats.functions_visited, 1 + 1);
    }

    /// Without any registered rule, a function call comes back unchanged.
    #[test]
    fn test_walker_fallback_without_rules() {
        let registry = RewriteRegistry::new();
        let config = RewriteConfig::default();
        let mut walker = ExpressionWalker::new(&registry, RewriteContext::with_config(config));

        let original = Expr::FunctionCall {
            name: "unknown".into(),
            args: vec![Expr::Literal(CypherLiteral::Integer(1))],
            distinct: false,
            window_spec: None,
        };

        let rewritten = walker.rewrite_expr(original);

        // Name, arguments and call modifiers must all be preserved
        match rewritten {
            Expr::FunctionCall {
                name,
                args,
                distinct,
                window_spec,
            } => {
                assert_eq!(name, "unknown");
                assert_eq!(args.len(), 1);
                assert!(matches!(args[0], Expr::Literal(CypherLiteral::Integer(1))));
                assert!(!distinct);
                assert!(window_spec.is_none());
            }
            other => panic!("expected a FunctionCall, got {other:?}"),
        }
        assert_eq!(walker.context().stats.functions_visited, 1);
    }
}