Skip to main content

uni_query_functions/rewrite/
function_rename.rs

1//! Function-call name-substitution walker.
2//!
3//! Used by the planner to apply `ReplacementScanProvider`-driven function
4//! rewrites (M5 follow-up #5) as an AST pass before logical planning. The
5//! walker descends into all expression-bearing positions and calls a
6//! caller-supplied closure for each [`Expr::FunctionCall`] name. If the
7//! closure returns `Some(new_name)`, the call's name is substituted in
8//! place; arguments are recursively walked either way. Errors short-circuit
9//! the entire traversal.
10//!
11//! Mirrors the traversal shape of
12//! [`crate::rewrite::walker::ExpressionWalker`] — the two could
13//! eventually share a visitor trait, but for now the duplication is
14//! deliberate: this walker takes `&mut FnMut` (so it can capture the
15//! planner's `&self` plus mutable state for hop-cap enforcement) and
16//! propagates `Result` (so a wrong-variant or already-rerouted error
17//! aborts cleanly), neither of which the rule-driven walker supports.
18
19use anyhow::Result;
20use uni_cypher::ast::{
21    Clause, Expr, MapProjectionItem, Pattern, PatternElement, Query, RemoveItem, ReturnItem,
22    SetItem, SortItem, Statement,
23};
24
25/// Walk `query`, calling `rename` on every [`Expr::FunctionCall`] name
26/// (post-order: arguments are visited first). When `rename` returns
27/// `Some(new_name)`, the call's `name` is replaced. When it returns
28/// `None`, the original name is kept. Errors propagate.
29pub fn rewrite_function_calls_in_query<F>(query: Query, rename: &mut F) -> Result<Query>
30where
31    F: FnMut(&str) -> Result<Option<String>>,
32{
33    match query {
34        Query::Single(stmt) => Ok(Query::Single(rewrite_statement(stmt, rename)?)),
35        Query::Union { left, right, all } => Ok(Query::Union {
36            left: Box::new(rewrite_function_calls_in_query(*left, rename)?),
37            right: Box::new(rewrite_function_calls_in_query(*right, rename)?),
38            all,
39        }),
40        Query::Schema(s) => Ok(Query::Schema(s)),
41        Query::Explain(inner) => Ok(Query::Explain(Box::new(rewrite_function_calls_in_query(
42            *inner, rename,
43        )?))),
44        Query::TimeTravel { .. } => Ok(query),
45    }
46}
47
48fn rewrite_statement<F>(stmt: Statement, rename: &mut F) -> Result<Statement>
49where
50    F: FnMut(&str) -> Result<Option<String>>,
51{
52    let mut clauses = Vec::with_capacity(stmt.clauses.len());
53    for c in stmt.clauses {
54        clauses.push(rewrite_clause(c, rename)?);
55    }
56    Ok(Statement { clauses })
57}
58
59fn rewrite_clause<F>(clause: Clause, rename: &mut F) -> Result<Clause>
60where
61    F: FnMut(&str) -> Result<Option<String>>,
62{
63    Ok(match clause {
64        Clause::Match(m) => Clause::Match(uni_cypher::ast::MatchClause {
65            optional: m.optional,
66            for_update: m.for_update,
67            pattern: rewrite_pattern(m.pattern, rename)?,
68            where_clause: opt_expr(m.where_clause, rename)?,
69        }),
70        Clause::Create(c) => Clause::Create(uni_cypher::ast::CreateClause {
71            pattern: rewrite_pattern(c.pattern, rename)?,
72        }),
73        Clause::Return(r) => Clause::Return(uni_cypher::ast::ReturnClause {
74            distinct: r.distinct,
75            items: r
76                .items
77                .into_iter()
78                .map(|item| rewrite_return_item(item, rename))
79                .collect::<Result<_>>()?,
80            order_by: rewrite_order_by(r.order_by, rename)?,
81            skip: opt_expr(r.skip, rename)?,
82            limit: opt_expr(r.limit, rename)?,
83        }),
84        Clause::With(w) => Clause::With(uni_cypher::ast::WithClause {
85            distinct: w.distinct,
86            items: w
87                .items
88                .into_iter()
89                .map(|item| rewrite_return_item(item, rename))
90                .collect::<Result<_>>()?,
91            order_by: rewrite_order_by(w.order_by, rename)?,
92            skip: opt_expr(w.skip, rename)?,
93            limit: opt_expr(w.limit, rename)?,
94            where_clause: opt_expr(w.where_clause, rename)?,
95        }),
96        Clause::Unwind(u) => Clause::Unwind(uni_cypher::ast::UnwindClause {
97            expr: rewrite_expr(u.expr, rename)?,
98            variable: u.variable,
99        }),
100        Clause::Set(s) => Clause::Set(uni_cypher::ast::SetClause {
101            items: s
102                .items
103                .into_iter()
104                .map(|item| rewrite_set_item(item, rename))
105                .collect::<Result<_>>()?,
106        }),
107        Clause::Delete(d) => Clause::Delete(uni_cypher::ast::DeleteClause {
108            detach: d.detach,
109            items: d
110                .items
111                .into_iter()
112                .map(|e| rewrite_expr(e, rename))
113                .collect::<Result<_>>()?,
114        }),
115        Clause::Remove(r) => Clause::Remove(uni_cypher::ast::RemoveClause {
116            items: r
117                .items
118                .into_iter()
119                .map(|item| rewrite_remove_item(item, rename))
120                .collect::<Result<_>>()?,
121        }),
122        Clause::Call(mut call) => {
123            // Procedure arguments and YIELD where-clauses can carry FunctionCalls.
124            match &mut call.kind {
125                uni_cypher::ast::CallKind::Procedure { arguments, .. } => {
126                    let mut new_args = Vec::with_capacity(arguments.len());
127                    for a in arguments.drain(..) {
128                        new_args.push(rewrite_expr(a, rename)?);
129                    }
130                    *arguments = new_args;
131                }
132                uni_cypher::ast::CallKind::Subquery(query) => {
133                    let q = std::mem::replace(
134                        query.as_mut(),
135                        Query::Single(Statement { clauses: vec![] }),
136                    );
137                    **query = rewrite_function_calls_in_query(q, rename)?;
138                }
139            }
140            if let Some(w) = call.where_clause.take() {
141                call.where_clause = Some(rewrite_expr(w, rename)?);
142            }
143            Clause::Call(call)
144        }
145        // Clauses we don't traverse (no expressions, or not user-rewritable here).
146        other => other,
147    })
148}
149
150fn rewrite_set_item<F>(item: SetItem, rename: &mut F) -> Result<SetItem>
151where
152    F: FnMut(&str) -> Result<Option<String>>,
153{
154    Ok(match item {
155        SetItem::Property { expr, value } => SetItem::Property {
156            expr: rewrite_expr(expr, rename)?,
157            value: rewrite_expr(value, rename)?,
158        },
159        SetItem::Variable { variable, value } => SetItem::Variable {
160            variable,
161            value: rewrite_expr(value, rename)?,
162        },
163        SetItem::VariablePlus { variable, value } => SetItem::VariablePlus {
164            variable,
165            value: rewrite_expr(value, rename)?,
166        },
167        SetItem::Labels { variable, labels } => SetItem::Labels { variable, labels },
168    })
169}
170
171fn rewrite_remove_item<F>(item: RemoveItem, rename: &mut F) -> Result<RemoveItem>
172where
173    F: FnMut(&str) -> Result<Option<String>>,
174{
175    Ok(match item {
176        RemoveItem::Property(e) => RemoveItem::Property(rewrite_expr(e, rename)?),
177        RemoveItem::Labels { variable, labels } => RemoveItem::Labels { variable, labels },
178    })
179}
180
181fn rewrite_return_item<F>(item: ReturnItem, rename: &mut F) -> Result<ReturnItem>
182where
183    F: FnMut(&str) -> Result<Option<String>>,
184{
185    Ok(match item {
186        ReturnItem::All => ReturnItem::All,
187        ReturnItem::Expr {
188            expr,
189            alias,
190            source_text,
191        } => ReturnItem::Expr {
192            expr: rewrite_expr(expr, rename)?,
193            alias,
194            source_text,
195        },
196    })
197}
198
199fn rewrite_order_by<F>(
200    order_by: Option<Vec<SortItem>>,
201    rename: &mut F,
202) -> Result<Option<Vec<SortItem>>>
203where
204    F: FnMut(&str) -> Result<Option<String>>,
205{
206    let Some(items) = order_by else {
207        return Ok(None);
208    };
209    let mut out = Vec::with_capacity(items.len());
210    for item in items {
211        out.push(SortItem {
212            expr: rewrite_expr(item.expr, rename)?,
213            ascending: item.ascending,
214        });
215    }
216    Ok(Some(out))
217}
218
219fn rewrite_pattern<F>(pattern: Pattern, rename: &mut F) -> Result<Pattern>
220where
221    F: FnMut(&str) -> Result<Option<String>>,
222{
223    let mut paths = Vec::with_capacity(pattern.paths.len());
224    for path in pattern.paths {
225        paths.push(uni_cypher::ast::PathPattern {
226            variable: path.variable,
227            elements: path
228                .elements
229                .into_iter()
230                .map(|e| rewrite_pattern_element(e, rename))
231                .collect::<Result<_>>()?,
232            shortest_path_mode: path.shortest_path_mode,
233        });
234    }
235    Ok(Pattern { paths })
236}
237
238fn rewrite_pattern_element<F>(elem: PatternElement, rename: &mut F) -> Result<PatternElement>
239where
240    F: FnMut(&str) -> Result<Option<String>>,
241{
242    Ok(match elem {
243        PatternElement::Node(n) => PatternElement::Node(uni_cypher::ast::NodePattern {
244            variable: n.variable,
245            labels: n.labels,
246            properties: opt_expr(n.properties, rename)?,
247            where_clause: opt_expr(n.where_clause, rename)?,
248        }),
249        PatternElement::Relationship(r) => {
250            PatternElement::Relationship(uni_cypher::ast::RelationshipPattern {
251                variable: r.variable,
252                types: r.types,
253                direction: r.direction,
254                properties: opt_expr(r.properties, rename)?,
255                range: r.range,
256                where_clause: opt_expr(r.where_clause, rename)?,
257            })
258        }
259        PatternElement::Parenthesized { pattern, range } => PatternElement::Parenthesized {
260            pattern: Box::new(uni_cypher::ast::PathPattern {
261                variable: pattern.variable,
262                elements: pattern
263                    .elements
264                    .into_iter()
265                    .map(|e| rewrite_pattern_element(e, rename))
266                    .collect::<Result<_>>()?,
267                shortest_path_mode: pattern.shortest_path_mode,
268            }),
269            range,
270        },
271    })
272}
273
274fn opt_expr<F>(e: Option<Expr>, rename: &mut F) -> Result<Option<Expr>>
275where
276    F: FnMut(&str) -> Result<Option<String>>,
277{
278    match e {
279        Some(e) => Ok(Some(rewrite_expr(e, rename)?)),
280        None => Ok(None),
281    }
282}
283
284fn rewrite_expr<F>(expr: Expr, rename: &mut F) -> Result<Expr>
285where
286    F: FnMut(&str) -> Result<Option<String>>,
287{
288    Ok(match expr {
289        Expr::FunctionCall {
290            name,
291            args,
292            distinct,
293            window_spec,
294        } => {
295            let mut new_args = Vec::with_capacity(args.len());
296            for a in args {
297                new_args.push(rewrite_expr(a, rename)?);
298            }
299            let new_name = rename(&name)?.unwrap_or(name);
300            Expr::FunctionCall {
301                name: new_name,
302                args: new_args,
303                distinct,
304                window_spec,
305            }
306        }
307        Expr::BinaryOp { left, op, right } => Expr::BinaryOp {
308            left: Box::new(rewrite_expr(*left, rename)?),
309            op,
310            right: Box::new(rewrite_expr(*right, rename)?),
311        },
312        Expr::UnaryOp { op, expr } => Expr::UnaryOp {
313            op,
314            expr: Box::new(rewrite_expr(*expr, rename)?),
315        },
316        Expr::Property(base, prop) => Expr::Property(Box::new(rewrite_expr(*base, rename)?), prop),
317        Expr::List(exprs) => Expr::List(
318            exprs
319                .into_iter()
320                .map(|e| rewrite_expr(e, rename))
321                .collect::<Result<_>>()?,
322        ),
323        Expr::Map(entries) => {
324            let mut out = Vec::with_capacity(entries.len());
325            for (k, v) in entries {
326                out.push((k, rewrite_expr(v, rename)?));
327            }
328            Expr::Map(out)
329        }
330        Expr::Case {
331            expr,
332            when_then,
333            else_expr,
334        } => {
335            let expr = match expr {
336                Some(e) => Some(Box::new(rewrite_expr(*e, rename)?)),
337                None => None,
338            };
339            let mut new_when = Vec::with_capacity(when_then.len());
340            for (w, t) in when_then {
341                new_when.push((rewrite_expr(w, rename)?, rewrite_expr(t, rename)?));
342            }
343            let else_expr = match else_expr {
344                Some(e) => Some(Box::new(rewrite_expr(*e, rename)?)),
345                None => None,
346            };
347            Expr::Case {
348                expr,
349                when_then: new_when,
350                else_expr,
351            }
352        }
353        Expr::Exists {
354            query,
355            from_pattern_predicate,
356        } => Expr::Exists {
357            query: Box::new(rewrite_function_calls_in_query(*query, rename)?),
358            from_pattern_predicate,
359        },
360        Expr::CountSubquery(q) => {
361            Expr::CountSubquery(Box::new(rewrite_function_calls_in_query(*q, rename)?))
362        }
363        Expr::CollectSubquery(q) => {
364            Expr::CollectSubquery(Box::new(rewrite_function_calls_in_query(*q, rename)?))
365        }
366        Expr::IsNull(e) => Expr::IsNull(Box::new(rewrite_expr(*e, rename)?)),
367        Expr::IsNotNull(e) => Expr::IsNotNull(Box::new(rewrite_expr(*e, rename)?)),
368        Expr::IsUnique(e) => Expr::IsUnique(Box::new(rewrite_expr(*e, rename)?)),
369        Expr::In { expr, list } => Expr::In {
370            expr: Box::new(rewrite_expr(*expr, rename)?),
371            list: Box::new(rewrite_expr(*list, rename)?),
372        },
373        Expr::ArrayIndex { array, index } => Expr::ArrayIndex {
374            array: Box::new(rewrite_expr(*array, rename)?),
375            index: Box::new(rewrite_expr(*index, rename)?),
376        },
377        Expr::ArraySlice { array, start, end } => Expr::ArraySlice {
378            array: Box::new(rewrite_expr(*array, rename)?),
379            start: match start {
380                Some(s) => Some(Box::new(rewrite_expr(*s, rename)?)),
381                None => None,
382            },
383            end: match end {
384                Some(e) => Some(Box::new(rewrite_expr(*e, rename)?)),
385                None => None,
386            },
387        },
388        Expr::Quantifier {
389            quantifier,
390            variable,
391            list,
392            predicate,
393        } => Expr::Quantifier {
394            quantifier,
395            variable,
396            list: Box::new(rewrite_expr(*list, rename)?),
397            predicate: Box::new(rewrite_expr(*predicate, rename)?),
398        },
399        Expr::Reduce {
400            accumulator,
401            init,
402            variable,
403            list,
404            expr,
405        } => Expr::Reduce {
406            accumulator,
407            init: Box::new(rewrite_expr(*init, rename)?),
408            variable,
409            list: Box::new(rewrite_expr(*list, rename)?),
410            expr: Box::new(rewrite_expr(*expr, rename)?),
411        },
412        Expr::ListComprehension {
413            variable,
414            list,
415            where_clause,
416            map_expr,
417        } => Expr::ListComprehension {
418            variable,
419            list: Box::new(rewrite_expr(*list, rename)?),
420            where_clause: match where_clause {
421                Some(w) => Some(Box::new(rewrite_expr(*w, rename)?)),
422                None => None,
423            },
424            map_expr: Box::new(rewrite_expr(*map_expr, rename)?),
425        },
426        Expr::PatternComprehension {
427            path_variable,
428            pattern,
429            where_clause,
430            map_expr,
431        } => Expr::PatternComprehension {
432            path_variable,
433            pattern: rewrite_pattern(pattern, rename)?,
434            where_clause: match where_clause {
435                Some(w) => Some(Box::new(rewrite_expr(*w, rename)?)),
436                None => None,
437            },
438            map_expr: Box::new(rewrite_expr(*map_expr, rename)?),
439        },
440        Expr::ValidAt {
441            entity,
442            timestamp,
443            start_prop,
444            end_prop,
445        } => Expr::ValidAt {
446            entity: Box::new(rewrite_expr(*entity, rename)?),
447            timestamp: Box::new(rewrite_expr(*timestamp, rename)?),
448            start_prop,
449            end_prop,
450        },
451        Expr::MapProjection { base, items } => {
452            let mut new_items = Vec::with_capacity(items.len());
453            for item in items {
454                new_items.push(match item {
455                    MapProjectionItem::LiteralEntry(k, v) => {
456                        MapProjectionItem::LiteralEntry(k, Box::new(rewrite_expr(*v, rename)?))
457                    }
458                    other => other,
459                });
460            }
461            Expr::MapProjection {
462                base: Box::new(rewrite_expr(*base, rename)?),
463                items: new_items,
464            }
465        }
466        Expr::LabelCheck { expr, labels } => Expr::LabelCheck {
467            expr: Box::new(rewrite_expr(*expr, rename)?),
468            labels,
469        },
470        // Leaves.
471        leaf @ (Expr::Literal(_) | Expr::Parameter(_) | Expr::Variable(_) | Expr::Wildcard) => leaf,
472    })
473}