1use crate::query::rewrite::context::RewriteContext;
3use crate::query::rewrite::error::RewriteError;
4use crate::query::rewrite::registry::RewriteRegistry;
5use uni_cypher::ast::{Expr, MapProjectionItem, Query, Statement};
6
7pub struct ExpressionWalker<'a> {
9 registry: &'a RewriteRegistry,
10 context: RewriteContext,
11}
12
13impl<'a> ExpressionWalker<'a> {
14 pub fn new(registry: &'a RewriteRegistry, context: RewriteContext) -> Self {
16 Self { registry, context }
17 }
18
19 pub fn context(&self) -> &RewriteContext {
21 &self.context
22 }
23
24 pub fn context_mut(&mut self) -> &mut RewriteContext {
26 &mut self.context
27 }
28
29 pub fn into_context(self) -> RewriteContext {
31 self.context
32 }
33
34 pub fn rewrite_statement(&mut self, stmt: Statement) -> Statement {
36 Statement {
37 clauses: stmt
38 .clauses
39 .into_iter()
40 .map(|c| self.rewrite_clause(c))
41 .collect(),
42 }
43 }
44
45 pub fn rewrite_query(&mut self, query: Query) -> Query {
47 match query {
48 Query::Single(stmt) => Query::Single(self.rewrite_statement(stmt)),
49 Query::Union { left, right, all } => Query::Union {
50 left: Box::new(self.rewrite_query(*left)),
51 right: Box::new(self.rewrite_query(*right)),
52 all,
53 },
54 Query::Schema(schema_cmd) => Query::Schema(schema_cmd),
55 Query::Transaction(txn_cmd) => Query::Transaction(txn_cmd),
56 Query::Explain(inner) => Query::Explain(Box::new(self.rewrite_query(*inner))),
57 Query::TimeTravel { .. } => {
58 unreachable!("TimeTravel should be resolved at API layer before rewriting")
59 }
60 }
61 }
62
63 fn rewrite_clause(&mut self, clause: uni_cypher::ast::Clause) -> uni_cypher::ast::Clause {
65 use uni_cypher::ast::Clause;
66
67 match clause {
68 Clause::Match(m) => Clause::Match(self.rewrite_match_clause(m)),
69 Clause::Create(c) => Clause::Create(self.rewrite_create_clause(c)),
70 Clause::Return(r) => Clause::Return(self.rewrite_return_clause(r)),
71 Clause::With(w) => Clause::With(self.rewrite_with_clause(w)),
72 Clause::Unwind(u) => Clause::Unwind(self.rewrite_unwind_clause(u)),
73 Clause::Set(s) => Clause::Set(self.rewrite_set_clause(s)),
74 Clause::Delete(d) => Clause::Delete(self.rewrite_delete_clause(d)),
75 Clause::Remove(r) => Clause::Remove(self.rewrite_remove_clause(r)),
76 other => other,
78 }
79 }
80
81 fn rewrite_match_clause(
82 &mut self,
83 m: uni_cypher::ast::MatchClause,
84 ) -> uni_cypher::ast::MatchClause {
85 uni_cypher::ast::MatchClause {
86 optional: m.optional,
87 pattern: self.rewrite_pattern(m.pattern),
88 where_clause: m.where_clause.map(|e| self.rewrite_expr(e)),
89 }
90 }
91
92 fn rewrite_create_clause(
93 &mut self,
94 c: uni_cypher::ast::CreateClause,
95 ) -> uni_cypher::ast::CreateClause {
96 uni_cypher::ast::CreateClause {
97 pattern: self.rewrite_pattern(c.pattern),
98 }
99 }
100
101 fn rewrite_delete_clause(
102 &mut self,
103 d: uni_cypher::ast::DeleteClause,
104 ) -> uni_cypher::ast::DeleteClause {
105 uni_cypher::ast::DeleteClause {
106 detach: d.detach,
107 items: d.items.into_iter().map(|e| self.rewrite_expr(e)).collect(),
108 }
109 }
110
111 fn rewrite_set_clause(&mut self, s: uni_cypher::ast::SetClause) -> uni_cypher::ast::SetClause {
112 uni_cypher::ast::SetClause {
113 items: s
114 .items
115 .into_iter()
116 .map(|item| self.rewrite_set_item(item))
117 .collect(),
118 }
119 }
120
121 fn rewrite_set_item(&mut self, item: uni_cypher::ast::SetItem) -> uni_cypher::ast::SetItem {
122 use uni_cypher::ast::SetItem;
123
124 match item {
125 SetItem::Property { expr, value } => SetItem::Property {
126 expr: self.rewrite_expr(expr),
127 value: self.rewrite_expr(value),
128 },
129 SetItem::Variable { variable, value } => SetItem::Variable {
130 variable,
131 value: self.rewrite_expr(value),
132 },
133 SetItem::VariablePlus { variable, value } => SetItem::VariablePlus {
134 variable,
135 value: self.rewrite_expr(value),
136 },
137 SetItem::Labels { variable, labels } => SetItem::Labels { variable, labels },
138 }
139 }
140
141 fn rewrite_remove_clause(
142 &mut self,
143 r: uni_cypher::ast::RemoveClause,
144 ) -> uni_cypher::ast::RemoveClause {
145 uni_cypher::ast::RemoveClause {
146 items: r
147 .items
148 .into_iter()
149 .map(|item| self.rewrite_remove_item(item))
150 .collect(),
151 }
152 }
153
154 fn rewrite_remove_item(
155 &mut self,
156 item: uni_cypher::ast::RemoveItem,
157 ) -> uni_cypher::ast::RemoveItem {
158 use uni_cypher::ast::RemoveItem;
159
160 match item {
161 RemoveItem::Property(expr) => RemoveItem::Property(self.rewrite_expr(expr)),
162 RemoveItem::Labels { variable, labels } => RemoveItem::Labels { variable, labels },
163 }
164 }
165
166 fn rewrite_unwind_clause(
167 &mut self,
168 u: uni_cypher::ast::UnwindClause,
169 ) -> uni_cypher::ast::UnwindClause {
170 uni_cypher::ast::UnwindClause {
171 expr: self.rewrite_expr(u.expr),
172 variable: u.variable,
173 }
174 }
175
176 fn rewrite_pattern(&mut self, pattern: uni_cypher::ast::Pattern) -> uni_cypher::ast::Pattern {
177 uni_cypher::ast::Pattern {
178 paths: pattern
179 .paths
180 .into_iter()
181 .map(|path| self.rewrite_path_pattern(path))
182 .collect(),
183 }
184 }
185
186 fn rewrite_path_pattern(
187 &mut self,
188 path: uni_cypher::ast::PathPattern,
189 ) -> uni_cypher::ast::PathPattern {
190 uni_cypher::ast::PathPattern {
191 variable: path.variable,
192 elements: path
193 .elements
194 .into_iter()
195 .map(|elem| self.rewrite_pattern_element(elem))
196 .collect(),
197 shortest_path_mode: path.shortest_path_mode,
198 }
199 }
200
201 fn rewrite_pattern_element(
202 &mut self,
203 elem: uni_cypher::ast::PatternElement,
204 ) -> uni_cypher::ast::PatternElement {
205 use uni_cypher::ast::PatternElement;
206
207 match elem {
208 PatternElement::Node(node) => PatternElement::Node(uni_cypher::ast::NodePattern {
209 variable: node.variable,
210 labels: node.labels,
211 properties: node.properties.map(|expr| self.rewrite_expr(expr)),
212 where_clause: node.where_clause.map(|expr| self.rewrite_expr(expr)),
213 }),
214 PatternElement::Relationship(rel) => {
215 PatternElement::Relationship(uni_cypher::ast::RelationshipPattern {
216 variable: rel.variable,
217 types: rel.types,
218 direction: rel.direction,
219 properties: rel.properties.map(|expr| self.rewrite_expr(expr)),
220 range: rel.range,
221 where_clause: rel.where_clause.map(|expr| self.rewrite_expr(expr)),
222 })
223 }
224 PatternElement::Parenthesized { pattern, range } => PatternElement::Parenthesized {
225 pattern: Box::new(self.rewrite_path_pattern(*pattern)),
226 range,
227 },
228 }
229 }
230
231 fn rewrite_order_by(
232 &mut self,
233 order_by: Option<Vec<uni_cypher::ast::SortItem>>,
234 ) -> Option<Vec<uni_cypher::ast::SortItem>> {
235 order_by.map(|items| {
236 items
237 .into_iter()
238 .map(|item| uni_cypher::ast::SortItem {
239 expr: self.rewrite_expr(item.expr),
240 ascending: item.ascending,
241 })
242 .collect()
243 })
244 }
245
246 fn rewrite_return_clause(
247 &mut self,
248 r: uni_cypher::ast::ReturnClause,
249 ) -> uni_cypher::ast::ReturnClause {
250 uni_cypher::ast::ReturnClause {
251 distinct: r.distinct,
252 items: r
253 .items
254 .into_iter()
255 .map(|item| self.rewrite_return_item(item))
256 .collect(),
257 order_by: self.rewrite_order_by(r.order_by),
258 skip: r.skip.map(|e| self.rewrite_expr(e)),
259 limit: r.limit.map(|e| self.rewrite_expr(e)),
260 }
261 }
262
263 fn rewrite_return_item(
264 &mut self,
265 item: uni_cypher::ast::ReturnItem,
266 ) -> uni_cypher::ast::ReturnItem {
267 use uni_cypher::ast::ReturnItem;
268
269 match item {
270 ReturnItem::All => ReturnItem::All,
271 ReturnItem::Expr {
272 expr,
273 alias,
274 source_text,
275 } => ReturnItem::Expr {
276 expr: self.rewrite_expr(expr),
277 alias,
278 source_text,
279 },
280 }
281 }
282
283 fn rewrite_with_clause(
284 &mut self,
285 w: uni_cypher::ast::WithClause,
286 ) -> uni_cypher::ast::WithClause {
287 uni_cypher::ast::WithClause {
288 distinct: w.distinct,
289 items: w
290 .items
291 .into_iter()
292 .map(|item| self.rewrite_return_item(item))
293 .collect(),
294 order_by: self.rewrite_order_by(w.order_by),
295 skip: w.skip.map(|e| self.rewrite_expr(e)),
296 limit: w.limit.map(|e| self.rewrite_expr(e)),
297 where_clause: w.where_clause.map(|e| self.rewrite_expr(e)),
298 }
299 }
300
301 pub fn rewrite_expr(&mut self, expr: Expr) -> Expr {
303 match expr {
304 Expr::PatternComprehension {
305 path_variable,
306 pattern,
307 where_clause,
308 map_expr,
309 } => Expr::PatternComprehension {
310 path_variable,
311 pattern, where_clause: where_clause.map(|e| Box::new(self.rewrite_expr(*e))),
313 map_expr: Box::new(self.rewrite_expr(*map_expr)),
314 },
315 Expr::CollectSubquery(_) => expr,
318 Expr::FunctionCall {
320 name,
321 args,
322 distinct,
323 window_spec,
324 } => self.try_rewrite_function(name, args, distinct, window_spec),
325
326 Expr::BinaryOp { left, op, right } => Expr::BinaryOp {
328 left: Box::new(self.rewrite_expr(*left)),
329 op,
330 right: Box::new(self.rewrite_expr(*right)),
331 },
332
333 Expr::UnaryOp { op, expr } => Expr::UnaryOp {
334 op,
335 expr: Box::new(self.rewrite_expr(*expr)),
336 },
337
338 Expr::Property(expr, prop) => Expr::Property(Box::new(self.rewrite_expr(*expr)), prop),
339
340 Expr::List(exprs) => {
341 Expr::List(exprs.into_iter().map(|e| self.rewrite_expr(e)).collect())
342 }
343
344 Expr::Map(entries) => Expr::Map(
345 entries
346 .into_iter()
347 .map(|(k, v)| (k, self.rewrite_expr(v)))
348 .collect(),
349 ),
350
351 Expr::Case {
352 expr,
353 when_then,
354 else_expr,
355 } => Expr::Case {
356 expr: expr.map(|e| Box::new(self.rewrite_expr(*e))),
357 when_then: when_then
358 .into_iter()
359 .map(|(w, t)| (self.rewrite_expr(w), self.rewrite_expr(t)))
360 .collect(),
361 else_expr: else_expr.map(|e| Box::new(self.rewrite_expr(*e))),
362 },
363
364 Expr::Exists {
365 query,
366 from_pattern_predicate,
367 } => Expr::Exists {
368 query: Box::new(self.rewrite_query(*query)),
369 from_pattern_predicate,
370 },
371
372 Expr::CountSubquery(query) => Expr::CountSubquery(Box::new(self.rewrite_query(*query))),
373
374 Expr::IsNull(expr) => Expr::IsNull(Box::new(self.rewrite_expr(*expr))),
375
376 Expr::IsNotNull(expr) => Expr::IsNotNull(Box::new(self.rewrite_expr(*expr))),
377
378 Expr::IsUnique(expr) => Expr::IsUnique(Box::new(self.rewrite_expr(*expr))),
379
380 Expr::In { expr, list } => Expr::In {
381 expr: Box::new(self.rewrite_expr(*expr)),
382 list: Box::new(self.rewrite_expr(*list)),
383 },
384
385 Expr::ArrayIndex { array, index } => Expr::ArrayIndex {
386 array: Box::new(self.rewrite_expr(*array)),
387 index: Box::new(self.rewrite_expr(*index)),
388 },
389
390 Expr::ArraySlice { array, start, end } => Expr::ArraySlice {
391 array: Box::new(self.rewrite_expr(*array)),
392 start: start.map(|e| Box::new(self.rewrite_expr(*e))),
393 end: end.map(|e| Box::new(self.rewrite_expr(*e))),
394 },
395
396 Expr::Quantifier {
397 quantifier,
398 variable,
399 list,
400 predicate,
401 } => Expr::Quantifier {
402 quantifier,
403 variable,
404 list: Box::new(self.rewrite_expr(*list)),
405 predicate: Box::new(self.rewrite_expr(*predicate)),
406 },
407
408 Expr::Reduce {
409 accumulator,
410 init,
411 variable,
412 list,
413 expr,
414 } => Expr::Reduce {
415 accumulator,
416 init: Box::new(self.rewrite_expr(*init)),
417 variable,
418 list: Box::new(self.rewrite_expr(*list)),
419 expr: Box::new(self.rewrite_expr(*expr)),
420 },
421
422 Expr::ListComprehension {
423 variable,
424 list,
425 where_clause,
426 map_expr,
427 } => Expr::ListComprehension {
428 variable,
429 list: Box::new(self.rewrite_expr(*list)),
430 where_clause: where_clause.map(|e| Box::new(self.rewrite_expr(*e))),
431 map_expr: Box::new(self.rewrite_expr(*map_expr)),
432 },
433
434 Expr::ValidAt {
435 entity,
436 timestamp,
437 start_prop,
438 end_prop,
439 } => Expr::ValidAt {
440 entity: Box::new(self.rewrite_expr(*entity)),
441 timestamp: Box::new(self.rewrite_expr(*timestamp)),
442 start_prop,
443 end_prop,
444 },
445
446 Expr::MapProjection { base, items } => Expr::MapProjection {
447 base: Box::new(self.rewrite_expr(*base)),
448 items: items
449 .into_iter()
450 .map(|item| match item {
451 MapProjectionItem::LiteralEntry(k, v) => {
452 MapProjectionItem::LiteralEntry(k, Box::new(self.rewrite_expr(*v)))
453 }
454 other => other,
455 })
456 .collect(),
457 },
458
459 Expr::LabelCheck { expr, labels } => Expr::LabelCheck {
460 expr: Box::new(self.rewrite_expr(*expr)),
461 labels,
462 },
463
464 Expr::Literal(_) | Expr::Parameter(_) | Expr::Variable(_) | Expr::Wildcard => expr,
466 }
467 }
468
469 fn try_rewrite_function(
471 &mut self,
472 name: String,
473 args: Vec<Expr>,
474 distinct: bool,
475 window_spec: Option<uni_cypher::ast::WindowSpec>,
476 ) -> Expr {
477 let rewritten_args: Vec<Expr> =
479 args.into_iter().map(|arg| self.rewrite_expr(arg)).collect();
480
481 self.context.stats.record_visit();
483
484 let make_fallback = |name, args| Expr::FunctionCall {
486 name,
487 args,
488 distinct,
489 window_spec: window_spec.clone(),
490 };
491
492 let Some(rule) = self.registry.get_rule(&name) else {
494 return make_fallback(name, rewritten_args);
495 };
496
497 if let Err(e) = rule.validate_args(&rewritten_args) {
499 self.context.stats.record_failure(&name, e);
500 if self.context.config.verbose_logging {
501 tracing::debug!(
502 "Rewrite validation failed for {}: {:?}",
503 name,
504 self.context.stats.errors.last()
505 );
506 }
507 return make_fallback(name, rewritten_args);
508 }
509
510 if !rule.is_applicable(&self.context) {
512 let error = RewriteError::NotApplicable {
513 reason: "Context requirements not met".to_string(),
514 };
515 self.context.stats.record_failure(&name, error);
516 if self.context.config.verbose_logging {
517 tracing::debug!("Rewrite not applicable for {}", name);
518 }
519 return make_fallback(name, rewritten_args);
520 }
521
522 match rule.rewrite(rewritten_args.clone(), &self.context) {
524 Ok(rewritten_expr) => {
525 self.context.stats.record_success(&name);
526 if self.context.config.verbose_logging {
527 tracing::debug!("Rewrote function call: {} -> {:?}", name, rewritten_expr);
528 } else {
529 tracing::info!("Rewrote function: {}", name);
530 }
531 rewritten_expr
532 }
533 Err(e) => {
534 self.context.stats.record_failure(&name, e);
535 if self.context.config.verbose_logging {
536 tracing::debug!(
537 "Rewrite failed for {}: {:?}",
538 name,
539 self.context.stats.errors.last()
540 );
541 }
542 make_fallback(name, rewritten_args)
543 }
544 }
545 }
546}
547
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::rewrite::context::RewriteConfig;
    use uni_cypher::ast::CypherLiteral;

    /// Builds a walker over `registry` with the default rewrite configuration.
    fn new_walker(registry: &RewriteRegistry) -> ExpressionWalker<'_> {
        ExpressionWalker::new(registry, RewriteContext::with_config(RewriteConfig::default()))
    }

    /// Builds a plain (non-distinct, non-windowed) call `name(args...)`.
    fn call(name: &str, args: Vec<Expr>) -> Expr {
        Expr::FunctionCall {
            name: name.into(),
            args,
            distinct: false,
            window_spec: None,
        }
    }

    #[test]
    fn test_walker_visits_nested_expressions() {
        let registry = RewriteRegistry::new();
        let mut walker = new_walker(&registry);

        // func1(1) AND func2(2): both operands of the binary op are calls.
        let expr = Expr::BinaryOp {
            left: Box::new(call("func1", vec![Expr::Literal(CypherLiteral::Integer(1))])),
            op: uni_cypher::ast::BinaryOp::And,
            right: Box::new(call("func2", vec![Expr::Literal(CypherLiteral::Integer(2))])),
        };

        let _ = walker.rewrite_expr(expr);

        assert_eq!(walker.context().stats.functions_visited, 2);
    }

    #[test]
    fn test_walker_visits_calls_nested_in_arguments() {
        let registry = RewriteRegistry::new();
        let mut walker = new_walker(&registry);

        // outer(inner(1)): the inner call is reached through the outer call's args.
        let expr = call(
            "outer",
            vec![call("inner", vec![Expr::Literal(CypherLiteral::Integer(1))])],
        );

        let _ = walker.rewrite_expr(expr);

        assert_eq!(walker.context().stats.functions_visited, 2);
    }

    #[test]
    fn test_walker_fallback_without_rules() {
        let registry = RewriteRegistry::new();
        let mut walker = new_walker(&registry);

        let original = call("unknown", vec![Expr::Literal(CypherLiteral::Integer(1))]);

        let rewritten = walker.rewrite_expr(original);

        // With no rule registered the call must come back structurally intact.
        assert!(matches!(
            rewritten,
            Expr::FunctionCall { name, args, distinct: false, window_spec: None }
                if name == "unknown" && args.len() == 1
        ));
        assert_eq!(walker.context().stats.functions_visited, 1);
        // A missing rule is not a failure.
        assert!(walker.context().stats.errors.is_empty());
    }
}