1use crate::query::rewrite::context::RewriteContext;
3use crate::query::rewrite::error::RewriteError;
4use crate::query::rewrite::registry::RewriteRegistry;
5use uni_cypher::ast::{Expr, MapProjectionItem, Query, Statement};
6
7pub struct ExpressionWalker<'a> {
9 registry: &'a RewriteRegistry,
10 context: RewriteContext,
11}
12
13impl<'a> ExpressionWalker<'a> {
14 pub fn new(registry: &'a RewriteRegistry, context: RewriteContext) -> Self {
16 Self { registry, context }
17 }
18
19 pub fn context(&self) -> &RewriteContext {
21 &self.context
22 }
23
24 pub fn context_mut(&mut self) -> &mut RewriteContext {
26 &mut self.context
27 }
28
29 pub fn into_context(self) -> RewriteContext {
31 self.context
32 }
33
34 pub fn rewrite_statement(&mut self, stmt: Statement) -> Statement {
36 Statement {
37 clauses: stmt
38 .clauses
39 .into_iter()
40 .map(|c| self.rewrite_clause(c))
41 .collect(),
42 }
43 }
44
45 pub fn rewrite_query(&mut self, query: Query) -> Query {
47 match query {
48 Query::Single(stmt) => Query::Single(self.rewrite_statement(stmt)),
49 Query::Union { left, right, all } => Query::Union {
50 left: Box::new(self.rewrite_query(*left)),
51 right: Box::new(self.rewrite_query(*right)),
52 all,
53 },
54 Query::Schema(schema_cmd) => Query::Schema(schema_cmd),
55 Query::Explain(inner) => Query::Explain(Box::new(self.rewrite_query(*inner))),
56 Query::TimeTravel { .. } => {
57 unreachable!("TimeTravel should be resolved at API layer before rewriting")
58 }
59 }
60 }
61
62 fn rewrite_clause(&mut self, clause: uni_cypher::ast::Clause) -> uni_cypher::ast::Clause {
64 use uni_cypher::ast::Clause;
65
66 match clause {
67 Clause::Match(m) => Clause::Match(self.rewrite_match_clause(m)),
68 Clause::Create(c) => Clause::Create(self.rewrite_create_clause(c)),
69 Clause::Return(r) => Clause::Return(self.rewrite_return_clause(r)),
70 Clause::With(w) => Clause::With(self.rewrite_with_clause(w)),
71 Clause::Unwind(u) => Clause::Unwind(self.rewrite_unwind_clause(u)),
72 Clause::Set(s) => Clause::Set(self.rewrite_set_clause(s)),
73 Clause::Delete(d) => Clause::Delete(self.rewrite_delete_clause(d)),
74 Clause::Remove(r) => Clause::Remove(self.rewrite_remove_clause(r)),
75 other => other,
77 }
78 }
79
80 fn rewrite_match_clause(
81 &mut self,
82 m: uni_cypher::ast::MatchClause,
83 ) -> uni_cypher::ast::MatchClause {
84 uni_cypher::ast::MatchClause {
85 optional: m.optional,
86 pattern: self.rewrite_pattern(m.pattern),
87 where_clause: m.where_clause.map(|e| self.rewrite_expr(e)),
88 }
89 }
90
91 fn rewrite_create_clause(
92 &mut self,
93 c: uni_cypher::ast::CreateClause,
94 ) -> uni_cypher::ast::CreateClause {
95 uni_cypher::ast::CreateClause {
96 pattern: self.rewrite_pattern(c.pattern),
97 }
98 }
99
100 fn rewrite_delete_clause(
101 &mut self,
102 d: uni_cypher::ast::DeleteClause,
103 ) -> uni_cypher::ast::DeleteClause {
104 uni_cypher::ast::DeleteClause {
105 detach: d.detach,
106 items: d.items.into_iter().map(|e| self.rewrite_expr(e)).collect(),
107 }
108 }
109
110 fn rewrite_set_clause(&mut self, s: uni_cypher::ast::SetClause) -> uni_cypher::ast::SetClause {
111 uni_cypher::ast::SetClause {
112 items: s
113 .items
114 .into_iter()
115 .map(|item| self.rewrite_set_item(item))
116 .collect(),
117 }
118 }
119
120 fn rewrite_set_item(&mut self, item: uni_cypher::ast::SetItem) -> uni_cypher::ast::SetItem {
121 use uni_cypher::ast::SetItem;
122
123 match item {
124 SetItem::Property { expr, value } => SetItem::Property {
125 expr: self.rewrite_expr(expr),
126 value: self.rewrite_expr(value),
127 },
128 SetItem::Variable { variable, value } => SetItem::Variable {
129 variable,
130 value: self.rewrite_expr(value),
131 },
132 SetItem::VariablePlus { variable, value } => SetItem::VariablePlus {
133 variable,
134 value: self.rewrite_expr(value),
135 },
136 SetItem::Labels { variable, labels } => SetItem::Labels { variable, labels },
137 }
138 }
139
140 fn rewrite_remove_clause(
141 &mut self,
142 r: uni_cypher::ast::RemoveClause,
143 ) -> uni_cypher::ast::RemoveClause {
144 uni_cypher::ast::RemoveClause {
145 items: r
146 .items
147 .into_iter()
148 .map(|item| self.rewrite_remove_item(item))
149 .collect(),
150 }
151 }
152
153 fn rewrite_remove_item(
154 &mut self,
155 item: uni_cypher::ast::RemoveItem,
156 ) -> uni_cypher::ast::RemoveItem {
157 use uni_cypher::ast::RemoveItem;
158
159 match item {
160 RemoveItem::Property(expr) => RemoveItem::Property(self.rewrite_expr(expr)),
161 RemoveItem::Labels { variable, labels } => RemoveItem::Labels { variable, labels },
162 }
163 }
164
165 fn rewrite_unwind_clause(
166 &mut self,
167 u: uni_cypher::ast::UnwindClause,
168 ) -> uni_cypher::ast::UnwindClause {
169 uni_cypher::ast::UnwindClause {
170 expr: self.rewrite_expr(u.expr),
171 variable: u.variable,
172 }
173 }
174
175 fn rewrite_pattern(&mut self, pattern: uni_cypher::ast::Pattern) -> uni_cypher::ast::Pattern {
176 uni_cypher::ast::Pattern {
177 paths: pattern
178 .paths
179 .into_iter()
180 .map(|path| self.rewrite_path_pattern(path))
181 .collect(),
182 }
183 }
184
185 fn rewrite_path_pattern(
186 &mut self,
187 path: uni_cypher::ast::PathPattern,
188 ) -> uni_cypher::ast::PathPattern {
189 uni_cypher::ast::PathPattern {
190 variable: path.variable,
191 elements: path
192 .elements
193 .into_iter()
194 .map(|elem| self.rewrite_pattern_element(elem))
195 .collect(),
196 shortest_path_mode: path.shortest_path_mode,
197 }
198 }
199
200 fn rewrite_pattern_element(
201 &mut self,
202 elem: uni_cypher::ast::PatternElement,
203 ) -> uni_cypher::ast::PatternElement {
204 use uni_cypher::ast::PatternElement;
205
206 match elem {
207 PatternElement::Node(node) => PatternElement::Node(uni_cypher::ast::NodePattern {
208 variable: node.variable,
209 labels: node.labels,
210 properties: node.properties.map(|expr| self.rewrite_expr(expr)),
211 where_clause: node.where_clause.map(|expr| self.rewrite_expr(expr)),
212 }),
213 PatternElement::Relationship(rel) => {
214 PatternElement::Relationship(uni_cypher::ast::RelationshipPattern {
215 variable: rel.variable,
216 types: rel.types,
217 direction: rel.direction,
218 properties: rel.properties.map(|expr| self.rewrite_expr(expr)),
219 range: rel.range,
220 where_clause: rel.where_clause.map(|expr| self.rewrite_expr(expr)),
221 })
222 }
223 PatternElement::Parenthesized { pattern, range } => PatternElement::Parenthesized {
224 pattern: Box::new(self.rewrite_path_pattern(*pattern)),
225 range,
226 },
227 }
228 }
229
230 fn rewrite_order_by(
231 &mut self,
232 order_by: Option<Vec<uni_cypher::ast::SortItem>>,
233 ) -> Option<Vec<uni_cypher::ast::SortItem>> {
234 order_by.map(|items| {
235 items
236 .into_iter()
237 .map(|item| uni_cypher::ast::SortItem {
238 expr: self.rewrite_expr(item.expr),
239 ascending: item.ascending,
240 })
241 .collect()
242 })
243 }
244
245 fn rewrite_return_clause(
246 &mut self,
247 r: uni_cypher::ast::ReturnClause,
248 ) -> uni_cypher::ast::ReturnClause {
249 uni_cypher::ast::ReturnClause {
250 distinct: r.distinct,
251 items: r
252 .items
253 .into_iter()
254 .map(|item| self.rewrite_return_item(item))
255 .collect(),
256 order_by: self.rewrite_order_by(r.order_by),
257 skip: r.skip.map(|e| self.rewrite_expr(e)),
258 limit: r.limit.map(|e| self.rewrite_expr(e)),
259 }
260 }
261
262 fn rewrite_return_item(
263 &mut self,
264 item: uni_cypher::ast::ReturnItem,
265 ) -> uni_cypher::ast::ReturnItem {
266 use uni_cypher::ast::ReturnItem;
267
268 match item {
269 ReturnItem::All => ReturnItem::All,
270 ReturnItem::Expr {
271 expr,
272 alias,
273 source_text,
274 } => ReturnItem::Expr {
275 expr: self.rewrite_expr(expr),
276 alias,
277 source_text,
278 },
279 }
280 }
281
282 fn rewrite_with_clause(
283 &mut self,
284 w: uni_cypher::ast::WithClause,
285 ) -> uni_cypher::ast::WithClause {
286 uni_cypher::ast::WithClause {
287 distinct: w.distinct,
288 items: w
289 .items
290 .into_iter()
291 .map(|item| self.rewrite_return_item(item))
292 .collect(),
293 order_by: self.rewrite_order_by(w.order_by),
294 skip: w.skip.map(|e| self.rewrite_expr(e)),
295 limit: w.limit.map(|e| self.rewrite_expr(e)),
296 where_clause: w.where_clause.map(|e| self.rewrite_expr(e)),
297 }
298 }
299
300 pub fn rewrite_expr(&mut self, expr: Expr) -> Expr {
302 match expr {
303 Expr::PatternComprehension {
304 path_variable,
305 pattern,
306 where_clause,
307 map_expr,
308 } => Expr::PatternComprehension {
309 path_variable,
310 pattern, where_clause: where_clause.map(|e| Box::new(self.rewrite_expr(*e))),
312 map_expr: Box::new(self.rewrite_expr(*map_expr)),
313 },
314 Expr::CollectSubquery(query) => {
315 Expr::CollectSubquery(Box::new(self.rewrite_query(*query)))
316 }
317 Expr::FunctionCall {
319 name,
320 args,
321 distinct,
322 window_spec,
323 } => self.try_rewrite_function(name, args, distinct, window_spec),
324
325 Expr::BinaryOp { left, op, right } => Expr::BinaryOp {
327 left: Box::new(self.rewrite_expr(*left)),
328 op,
329 right: Box::new(self.rewrite_expr(*right)),
330 },
331
332 Expr::UnaryOp { op, expr } => Expr::UnaryOp {
333 op,
334 expr: Box::new(self.rewrite_expr(*expr)),
335 },
336
337 Expr::Property(expr, prop) => Expr::Property(Box::new(self.rewrite_expr(*expr)), prop),
338
339 Expr::List(exprs) => {
340 Expr::List(exprs.into_iter().map(|e| self.rewrite_expr(e)).collect())
341 }
342
343 Expr::Map(entries) => Expr::Map(
344 entries
345 .into_iter()
346 .map(|(k, v)| (k, self.rewrite_expr(v)))
347 .collect(),
348 ),
349
350 Expr::Case {
351 expr,
352 when_then,
353 else_expr,
354 } => Expr::Case {
355 expr: expr.map(|e| Box::new(self.rewrite_expr(*e))),
356 when_then: when_then
357 .into_iter()
358 .map(|(w, t)| (self.rewrite_expr(w), self.rewrite_expr(t)))
359 .collect(),
360 else_expr: else_expr.map(|e| Box::new(self.rewrite_expr(*e))),
361 },
362
363 Expr::Exists {
364 query,
365 from_pattern_predicate,
366 } => Expr::Exists {
367 query: Box::new(self.rewrite_query(*query)),
368 from_pattern_predicate,
369 },
370
371 Expr::CountSubquery(query) => Expr::CountSubquery(Box::new(self.rewrite_query(*query))),
372
373 Expr::IsNull(expr) => Expr::IsNull(Box::new(self.rewrite_expr(*expr))),
374
375 Expr::IsNotNull(expr) => Expr::IsNotNull(Box::new(self.rewrite_expr(*expr))),
376
377 Expr::IsUnique(expr) => Expr::IsUnique(Box::new(self.rewrite_expr(*expr))),
378
379 Expr::In { expr, list } => Expr::In {
380 expr: Box::new(self.rewrite_expr(*expr)),
381 list: Box::new(self.rewrite_expr(*list)),
382 },
383
384 Expr::ArrayIndex { array, index } => Expr::ArrayIndex {
385 array: Box::new(self.rewrite_expr(*array)),
386 index: Box::new(self.rewrite_expr(*index)),
387 },
388
389 Expr::ArraySlice { array, start, end } => Expr::ArraySlice {
390 array: Box::new(self.rewrite_expr(*array)),
391 start: start.map(|e| Box::new(self.rewrite_expr(*e))),
392 end: end.map(|e| Box::new(self.rewrite_expr(*e))),
393 },
394
395 Expr::Quantifier {
396 quantifier,
397 variable,
398 list,
399 predicate,
400 } => Expr::Quantifier {
401 quantifier,
402 variable,
403 list: Box::new(self.rewrite_expr(*list)),
404 predicate: Box::new(self.rewrite_expr(*predicate)),
405 },
406
407 Expr::Reduce {
408 accumulator,
409 init,
410 variable,
411 list,
412 expr,
413 } => Expr::Reduce {
414 accumulator,
415 init: Box::new(self.rewrite_expr(*init)),
416 variable,
417 list: Box::new(self.rewrite_expr(*list)),
418 expr: Box::new(self.rewrite_expr(*expr)),
419 },
420
421 Expr::ListComprehension {
422 variable,
423 list,
424 where_clause,
425 map_expr,
426 } => Expr::ListComprehension {
427 variable,
428 list: Box::new(self.rewrite_expr(*list)),
429 where_clause: where_clause.map(|e| Box::new(self.rewrite_expr(*e))),
430 map_expr: Box::new(self.rewrite_expr(*map_expr)),
431 },
432
433 Expr::ValidAt {
434 entity,
435 timestamp,
436 start_prop,
437 end_prop,
438 } => Expr::ValidAt {
439 entity: Box::new(self.rewrite_expr(*entity)),
440 timestamp: Box::new(self.rewrite_expr(*timestamp)),
441 start_prop,
442 end_prop,
443 },
444
445 Expr::MapProjection { base, items } => Expr::MapProjection {
446 base: Box::new(self.rewrite_expr(*base)),
447 items: items
448 .into_iter()
449 .map(|item| match item {
450 MapProjectionItem::LiteralEntry(k, v) => {
451 MapProjectionItem::LiteralEntry(k, Box::new(self.rewrite_expr(*v)))
452 }
453 other => other,
454 })
455 .collect(),
456 },
457
458 Expr::LabelCheck { expr, labels } => Expr::LabelCheck {
459 expr: Box::new(self.rewrite_expr(*expr)),
460 labels,
461 },
462
463 Expr::Literal(_) | Expr::Parameter(_) | Expr::Variable(_) | Expr::Wildcard => expr,
465 }
466 }
467
468 fn try_rewrite_function(
470 &mut self,
471 name: String,
472 args: Vec<Expr>,
473 distinct: bool,
474 window_spec: Option<uni_cypher::ast::WindowSpec>,
475 ) -> Expr {
476 let rewritten_args: Vec<Expr> =
478 args.into_iter().map(|arg| self.rewrite_expr(arg)).collect();
479
480 self.context.stats.record_visit();
482
483 let make_fallback = |name, args| Expr::FunctionCall {
485 name,
486 args,
487 distinct,
488 window_spec: window_spec.clone(),
489 };
490
491 let Some(rule) = self.registry.get_rule(&name) else {
493 return make_fallback(name, rewritten_args);
494 };
495
496 if let Err(e) = rule.validate_args(&rewritten_args) {
498 self.context.stats.record_failure(&name, e);
499 if self.context.config.verbose_logging {
500 tracing::debug!(
501 "Rewrite validation failed for {}: {:?}",
502 name,
503 self.context.stats.errors.last()
504 );
505 }
506 return make_fallback(name, rewritten_args);
507 }
508
509 if !rule.is_applicable(&self.context) {
511 let error = RewriteError::NotApplicable {
512 reason: "Context requirements not met".to_string(),
513 };
514 self.context.stats.record_failure(&name, error);
515 if self.context.config.verbose_logging {
516 tracing::debug!("Rewrite not applicable for {}", name);
517 }
518 return make_fallback(name, rewritten_args);
519 }
520
521 match rule.rewrite(rewritten_args.clone(), &self.context) {
523 Ok(rewritten_expr) => {
524 self.context.stats.record_success(&name);
525 if self.context.config.verbose_logging {
526 tracing::debug!("Rewrote function call: {} -> {:?}", name, rewritten_expr);
527 } else {
528 tracing::info!("Rewrote function: {}", name);
529 }
530 rewritten_expr
531 }
532 Err(e) => {
533 self.context.stats.record_failure(&name, e);
534 if self.context.config.verbose_logging {
535 tracing::debug!(
536 "Rewrite failed for {}: {:?}",
537 name,
538 self.context.stats.errors.last()
539 );
540 }
541 make_fallback(name, rewritten_args)
542 }
543 }
544 }
545}
546
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::rewrite::context::RewriteConfig;
    use uni_cypher::ast::CypherLiteral;

    /// Builds a walker over `registry` using the default configuration.
    fn walker(registry: &RewriteRegistry) -> ExpressionWalker<'_> {
        ExpressionWalker::new(registry, RewriteContext::with_config(RewriteConfig::default()))
    }

    /// Builds a plain (non-distinct, non-windowed) call `name(args...)`.
    fn call(name: &str, args: Vec<Expr>) -> Expr {
        Expr::FunctionCall {
            name: name.into(),
            args,
            distinct: false,
            window_spec: None,
        }
    }

    /// Integer literal helper.
    fn int(n: i64) -> Expr {
        Expr::Literal(CypherLiteral::Integer(n))
    }

    #[test]
    fn test_walker_visits_nested_expressions() {
        let registry = RewriteRegistry::new();
        let mut walker = walker(&registry);

        let expr = Expr::BinaryOp {
            left: Box::new(call("func1", vec![int(1)])),
            op: uni_cypher::ast::BinaryOp::And,
            right: Box::new(call("func2", vec![int(2)])),
        };

        let _ = walker.rewrite_expr(expr);

        assert_eq!(walker.context().stats.functions_visited, 2);
    }

    #[test]
    fn test_walker_fallback_without_rules() {
        let registry = RewriteRegistry::new();
        let mut walker = walker(&registry);

        let rewritten = walker.rewrite_expr(call("unknown", vec![int(1)]));

        assert!(matches!(rewritten, Expr::FunctionCall { name, .. } if name == "unknown"));
        assert_eq!(walker.context().stats.functions_visited, 1);
    }

    #[test]
    fn test_walker_visits_calls_nested_in_arguments() {
        let registry = RewriteRegistry::new();
        let mut walker = walker(&registry);

        let rewritten = walker.rewrite_expr(call("outer", vec![call("inner", vec![int(1)])]));

        // Both the outer call and the call inside its arguments are visited,
        // and the fallback preserves the nesting.
        assert_eq!(walker.context().stats.functions_visited, 2);
        match rewritten {
            Expr::FunctionCall { name, args, .. } => {
                assert_eq!(name, "outer");
                assert_eq!(args.len(), 1);
                assert!(
                    matches!(&args[0], Expr::FunctionCall { name, .. } if name == "inner")
                );
            }
            _ => panic!("expected outer function call to be preserved"),
        }
    }

    #[test]
    fn test_walker_visits_calls_inside_list() {
        let registry = RewriteRegistry::new();
        let mut walker = walker(&registry);

        let expr = Expr::List(vec![
            call("a", vec![]),
            int(7),
            call("b", vec![int(1)]),
        ]);

        let _ = walker.rewrite_expr(expr);

        assert_eq!(walker.context().stats.functions_visited, 2);
    }
}