1use anyhow::Result;
20use uni_cypher::ast::{
21 Clause, Expr, MapProjectionItem, Pattern, PatternElement, Query, RemoveItem, ReturnItem,
22 SetItem, SortItem, Statement,
23};
24
25pub fn rewrite_function_calls_in_query<F>(query: Query, rename: &mut F) -> Result<Query>
30where
31 F: FnMut(&str) -> Result<Option<String>>,
32{
33 match query {
34 Query::Single(stmt) => Ok(Query::Single(rewrite_statement(stmt, rename)?)),
35 Query::Union { left, right, all } => Ok(Query::Union {
36 left: Box::new(rewrite_function_calls_in_query(*left, rename)?),
37 right: Box::new(rewrite_function_calls_in_query(*right, rename)?),
38 all,
39 }),
40 Query::Schema(s) => Ok(Query::Schema(s)),
41 Query::Explain(inner) => Ok(Query::Explain(Box::new(rewrite_function_calls_in_query(
42 *inner, rename,
43 )?))),
44 Query::TimeTravel { .. } => Ok(query),
45 }
46}
47
48fn rewrite_statement<F>(stmt: Statement, rename: &mut F) -> Result<Statement>
49where
50 F: FnMut(&str) -> Result<Option<String>>,
51{
52 let mut clauses = Vec::with_capacity(stmt.clauses.len());
53 for c in stmt.clauses {
54 clauses.push(rewrite_clause(c, rename)?);
55 }
56 Ok(Statement { clauses })
57}
58
59fn rewrite_clause<F>(clause: Clause, rename: &mut F) -> Result<Clause>
60where
61 F: FnMut(&str) -> Result<Option<String>>,
62{
63 Ok(match clause {
64 Clause::Match(m) => Clause::Match(uni_cypher::ast::MatchClause {
65 optional: m.optional,
66 for_update: m.for_update,
67 pattern: rewrite_pattern(m.pattern, rename)?,
68 where_clause: opt_expr(m.where_clause, rename)?,
69 }),
70 Clause::Create(c) => Clause::Create(uni_cypher::ast::CreateClause {
71 pattern: rewrite_pattern(c.pattern, rename)?,
72 }),
73 Clause::Return(r) => Clause::Return(uni_cypher::ast::ReturnClause {
74 distinct: r.distinct,
75 items: r
76 .items
77 .into_iter()
78 .map(|item| rewrite_return_item(item, rename))
79 .collect::<Result<_>>()?,
80 order_by: rewrite_order_by(r.order_by, rename)?,
81 skip: opt_expr(r.skip, rename)?,
82 limit: opt_expr(r.limit, rename)?,
83 }),
84 Clause::With(w) => Clause::With(uni_cypher::ast::WithClause {
85 distinct: w.distinct,
86 items: w
87 .items
88 .into_iter()
89 .map(|item| rewrite_return_item(item, rename))
90 .collect::<Result<_>>()?,
91 order_by: rewrite_order_by(w.order_by, rename)?,
92 skip: opt_expr(w.skip, rename)?,
93 limit: opt_expr(w.limit, rename)?,
94 where_clause: opt_expr(w.where_clause, rename)?,
95 }),
96 Clause::Unwind(u) => Clause::Unwind(uni_cypher::ast::UnwindClause {
97 expr: rewrite_expr(u.expr, rename)?,
98 variable: u.variable,
99 }),
100 Clause::Set(s) => Clause::Set(uni_cypher::ast::SetClause {
101 items: s
102 .items
103 .into_iter()
104 .map(|item| rewrite_set_item(item, rename))
105 .collect::<Result<_>>()?,
106 }),
107 Clause::Delete(d) => Clause::Delete(uni_cypher::ast::DeleteClause {
108 detach: d.detach,
109 items: d
110 .items
111 .into_iter()
112 .map(|e| rewrite_expr(e, rename))
113 .collect::<Result<_>>()?,
114 }),
115 Clause::Remove(r) => Clause::Remove(uni_cypher::ast::RemoveClause {
116 items: r
117 .items
118 .into_iter()
119 .map(|item| rewrite_remove_item(item, rename))
120 .collect::<Result<_>>()?,
121 }),
122 Clause::Call(mut call) => {
123 match &mut call.kind {
125 uni_cypher::ast::CallKind::Procedure { arguments, .. } => {
126 let mut new_args = Vec::with_capacity(arguments.len());
127 for a in arguments.drain(..) {
128 new_args.push(rewrite_expr(a, rename)?);
129 }
130 *arguments = new_args;
131 }
132 uni_cypher::ast::CallKind::Subquery(query) => {
133 let q = std::mem::replace(
134 query.as_mut(),
135 Query::Single(Statement { clauses: vec![] }),
136 );
137 **query = rewrite_function_calls_in_query(q, rename)?;
138 }
139 }
140 if let Some(w) = call.where_clause.take() {
141 call.where_clause = Some(rewrite_expr(w, rename)?);
142 }
143 Clause::Call(call)
144 }
145 other => other,
147 })
148}
149
150fn rewrite_set_item<F>(item: SetItem, rename: &mut F) -> Result<SetItem>
151where
152 F: FnMut(&str) -> Result<Option<String>>,
153{
154 Ok(match item {
155 SetItem::Property { expr, value } => SetItem::Property {
156 expr: rewrite_expr(expr, rename)?,
157 value: rewrite_expr(value, rename)?,
158 },
159 SetItem::Variable { variable, value } => SetItem::Variable {
160 variable,
161 value: rewrite_expr(value, rename)?,
162 },
163 SetItem::VariablePlus { variable, value } => SetItem::VariablePlus {
164 variable,
165 value: rewrite_expr(value, rename)?,
166 },
167 SetItem::Labels { variable, labels } => SetItem::Labels { variable, labels },
168 })
169}
170
171fn rewrite_remove_item<F>(item: RemoveItem, rename: &mut F) -> Result<RemoveItem>
172where
173 F: FnMut(&str) -> Result<Option<String>>,
174{
175 Ok(match item {
176 RemoveItem::Property(e) => RemoveItem::Property(rewrite_expr(e, rename)?),
177 RemoveItem::Labels { variable, labels } => RemoveItem::Labels { variable, labels },
178 })
179}
180
181fn rewrite_return_item<F>(item: ReturnItem, rename: &mut F) -> Result<ReturnItem>
182where
183 F: FnMut(&str) -> Result<Option<String>>,
184{
185 Ok(match item {
186 ReturnItem::All => ReturnItem::All,
187 ReturnItem::Expr {
188 expr,
189 alias,
190 source_text,
191 } => ReturnItem::Expr {
192 expr: rewrite_expr(expr, rename)?,
193 alias,
194 source_text,
195 },
196 })
197}
198
199fn rewrite_order_by<F>(
200 order_by: Option<Vec<SortItem>>,
201 rename: &mut F,
202) -> Result<Option<Vec<SortItem>>>
203where
204 F: FnMut(&str) -> Result<Option<String>>,
205{
206 let Some(items) = order_by else {
207 return Ok(None);
208 };
209 let mut out = Vec::with_capacity(items.len());
210 for item in items {
211 out.push(SortItem {
212 expr: rewrite_expr(item.expr, rename)?,
213 ascending: item.ascending,
214 });
215 }
216 Ok(Some(out))
217}
218
219fn rewrite_pattern<F>(pattern: Pattern, rename: &mut F) -> Result<Pattern>
220where
221 F: FnMut(&str) -> Result<Option<String>>,
222{
223 let mut paths = Vec::with_capacity(pattern.paths.len());
224 for path in pattern.paths {
225 paths.push(uni_cypher::ast::PathPattern {
226 variable: path.variable,
227 elements: path
228 .elements
229 .into_iter()
230 .map(|e| rewrite_pattern_element(e, rename))
231 .collect::<Result<_>>()?,
232 shortest_path_mode: path.shortest_path_mode,
233 });
234 }
235 Ok(Pattern { paths })
236}
237
238fn rewrite_pattern_element<F>(elem: PatternElement, rename: &mut F) -> Result<PatternElement>
239where
240 F: FnMut(&str) -> Result<Option<String>>,
241{
242 Ok(match elem {
243 PatternElement::Node(n) => PatternElement::Node(uni_cypher::ast::NodePattern {
244 variable: n.variable,
245 labels: n.labels,
246 properties: opt_expr(n.properties, rename)?,
247 where_clause: opt_expr(n.where_clause, rename)?,
248 }),
249 PatternElement::Relationship(r) => {
250 PatternElement::Relationship(uni_cypher::ast::RelationshipPattern {
251 variable: r.variable,
252 types: r.types,
253 direction: r.direction,
254 properties: opt_expr(r.properties, rename)?,
255 range: r.range,
256 where_clause: opt_expr(r.where_clause, rename)?,
257 })
258 }
259 PatternElement::Parenthesized { pattern, range } => PatternElement::Parenthesized {
260 pattern: Box::new(uni_cypher::ast::PathPattern {
261 variable: pattern.variable,
262 elements: pattern
263 .elements
264 .into_iter()
265 .map(|e| rewrite_pattern_element(e, rename))
266 .collect::<Result<_>>()?,
267 shortest_path_mode: pattern.shortest_path_mode,
268 }),
269 range,
270 },
271 })
272}
273
274fn opt_expr<F>(e: Option<Expr>, rename: &mut F) -> Result<Option<Expr>>
275where
276 F: FnMut(&str) -> Result<Option<String>>,
277{
278 match e {
279 Some(e) => Ok(Some(rewrite_expr(e, rename)?)),
280 None => Ok(None),
281 }
282}
283
284fn rewrite_expr<F>(expr: Expr, rename: &mut F) -> Result<Expr>
285where
286 F: FnMut(&str) -> Result<Option<String>>,
287{
288 Ok(match expr {
289 Expr::FunctionCall {
290 name,
291 args,
292 distinct,
293 window_spec,
294 } => {
295 let mut new_args = Vec::with_capacity(args.len());
296 for a in args {
297 new_args.push(rewrite_expr(a, rename)?);
298 }
299 let new_name = rename(&name)?.unwrap_or(name);
300 Expr::FunctionCall {
301 name: new_name,
302 args: new_args,
303 distinct,
304 window_spec,
305 }
306 }
307 Expr::BinaryOp { left, op, right } => Expr::BinaryOp {
308 left: Box::new(rewrite_expr(*left, rename)?),
309 op,
310 right: Box::new(rewrite_expr(*right, rename)?),
311 },
312 Expr::UnaryOp { op, expr } => Expr::UnaryOp {
313 op,
314 expr: Box::new(rewrite_expr(*expr, rename)?),
315 },
316 Expr::Property(base, prop) => Expr::Property(Box::new(rewrite_expr(*base, rename)?), prop),
317 Expr::List(exprs) => Expr::List(
318 exprs
319 .into_iter()
320 .map(|e| rewrite_expr(e, rename))
321 .collect::<Result<_>>()?,
322 ),
323 Expr::Map(entries) => {
324 let mut out = Vec::with_capacity(entries.len());
325 for (k, v) in entries {
326 out.push((k, rewrite_expr(v, rename)?));
327 }
328 Expr::Map(out)
329 }
330 Expr::Case {
331 expr,
332 when_then,
333 else_expr,
334 } => {
335 let expr = match expr {
336 Some(e) => Some(Box::new(rewrite_expr(*e, rename)?)),
337 None => None,
338 };
339 let mut new_when = Vec::with_capacity(when_then.len());
340 for (w, t) in when_then {
341 new_when.push((rewrite_expr(w, rename)?, rewrite_expr(t, rename)?));
342 }
343 let else_expr = match else_expr {
344 Some(e) => Some(Box::new(rewrite_expr(*e, rename)?)),
345 None => None,
346 };
347 Expr::Case {
348 expr,
349 when_then: new_when,
350 else_expr,
351 }
352 }
353 Expr::Exists {
354 query,
355 from_pattern_predicate,
356 } => Expr::Exists {
357 query: Box::new(rewrite_function_calls_in_query(*query, rename)?),
358 from_pattern_predicate,
359 },
360 Expr::CountSubquery(q) => {
361 Expr::CountSubquery(Box::new(rewrite_function_calls_in_query(*q, rename)?))
362 }
363 Expr::CollectSubquery(q) => {
364 Expr::CollectSubquery(Box::new(rewrite_function_calls_in_query(*q, rename)?))
365 }
366 Expr::IsNull(e) => Expr::IsNull(Box::new(rewrite_expr(*e, rename)?)),
367 Expr::IsNotNull(e) => Expr::IsNotNull(Box::new(rewrite_expr(*e, rename)?)),
368 Expr::IsUnique(e) => Expr::IsUnique(Box::new(rewrite_expr(*e, rename)?)),
369 Expr::In { expr, list } => Expr::In {
370 expr: Box::new(rewrite_expr(*expr, rename)?),
371 list: Box::new(rewrite_expr(*list, rename)?),
372 },
373 Expr::ArrayIndex { array, index } => Expr::ArrayIndex {
374 array: Box::new(rewrite_expr(*array, rename)?),
375 index: Box::new(rewrite_expr(*index, rename)?),
376 },
377 Expr::ArraySlice { array, start, end } => Expr::ArraySlice {
378 array: Box::new(rewrite_expr(*array, rename)?),
379 start: match start {
380 Some(s) => Some(Box::new(rewrite_expr(*s, rename)?)),
381 None => None,
382 },
383 end: match end {
384 Some(e) => Some(Box::new(rewrite_expr(*e, rename)?)),
385 None => None,
386 },
387 },
388 Expr::Quantifier {
389 quantifier,
390 variable,
391 list,
392 predicate,
393 } => Expr::Quantifier {
394 quantifier,
395 variable,
396 list: Box::new(rewrite_expr(*list, rename)?),
397 predicate: Box::new(rewrite_expr(*predicate, rename)?),
398 },
399 Expr::Reduce {
400 accumulator,
401 init,
402 variable,
403 list,
404 expr,
405 } => Expr::Reduce {
406 accumulator,
407 init: Box::new(rewrite_expr(*init, rename)?),
408 variable,
409 list: Box::new(rewrite_expr(*list, rename)?),
410 expr: Box::new(rewrite_expr(*expr, rename)?),
411 },
412 Expr::ListComprehension {
413 variable,
414 list,
415 where_clause,
416 map_expr,
417 } => Expr::ListComprehension {
418 variable,
419 list: Box::new(rewrite_expr(*list, rename)?),
420 where_clause: match where_clause {
421 Some(w) => Some(Box::new(rewrite_expr(*w, rename)?)),
422 None => None,
423 },
424 map_expr: Box::new(rewrite_expr(*map_expr, rename)?),
425 },
426 Expr::PatternComprehension {
427 path_variable,
428 pattern,
429 where_clause,
430 map_expr,
431 } => Expr::PatternComprehension {
432 path_variable,
433 pattern: rewrite_pattern(pattern, rename)?,
434 where_clause: match where_clause {
435 Some(w) => Some(Box::new(rewrite_expr(*w, rename)?)),
436 None => None,
437 },
438 map_expr: Box::new(rewrite_expr(*map_expr, rename)?),
439 },
440 Expr::ValidAt {
441 entity,
442 timestamp,
443 start_prop,
444 end_prop,
445 } => Expr::ValidAt {
446 entity: Box::new(rewrite_expr(*entity, rename)?),
447 timestamp: Box::new(rewrite_expr(*timestamp, rename)?),
448 start_prop,
449 end_prop,
450 },
451 Expr::MapProjection { base, items } => {
452 let mut new_items = Vec::with_capacity(items.len());
453 for item in items {
454 new_items.push(match item {
455 MapProjectionItem::LiteralEntry(k, v) => {
456 MapProjectionItem::LiteralEntry(k, Box::new(rewrite_expr(*v, rename)?))
457 }
458 other => other,
459 });
460 }
461 Expr::MapProjection {
462 base: Box::new(rewrite_expr(*base, rename)?),
463 items: new_items,
464 }
465 }
466 Expr::LabelCheck { expr, labels } => Expr::LabelCheck {
467 expr: Box::new(rewrite_expr(*expr, rename)?),
468 labels,
469 },
470 leaf @ (Expr::Literal(_) | Expr::Parameter(_) | Expr::Variable(_) | Expr::Wildcard) => leaf,
472 })
473}