sqltk_parser/ast/
visitor.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Recursive visitors for ast Nodes. See [`Visitor`] for more details.
19
20use crate::ast::{Expr, ObjectName, Query, Statement, TableFactor};
21use core::ops::ControlFlow;
22
23/// A type that can be visited by a [`Visitor`]. See [`Visitor`] for
24/// recursively visiting parsed SQL statements.
25///
26/// # Note
27///
28/// This trait should be automatically derived for sqltk_parser AST nodes
29/// using the [Visit](sqltk_parser_derive::Visit) proc macro.
30///
31/// ```text
32/// #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
33/// ```
34pub trait Visit {
35    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break>;
36}
37
38/// A type that can be visited by a [`VisitorMut`]. See [`VisitorMut`] for
39/// recursively visiting parsed SQL statements.
40///
41/// # Note
42///
43/// This trait should be automatically derived for sqltk_parser AST nodes
44/// using the [VisitMut](sqltk_parser_derive::VisitMut) proc macro.
45///
46/// ```text
47/// #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
48/// ```
49pub trait VisitMut {
50    fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break>;
51}
52
53impl<T: Visit> Visit for Option<T> {
54    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
55        if let Some(s) = self {
56            s.visit(visitor)?;
57        }
58        ControlFlow::Continue(())
59    }
60}
61
62impl<T: Visit> Visit for Vec<T> {
63    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
64        for v in self {
65            v.visit(visitor)?;
66        }
67        ControlFlow::Continue(())
68    }
69}
70
71impl<T: Visit> Visit for Box<T> {
72    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
73        T::visit(self, visitor)
74    }
75}
76
77impl<T: VisitMut> VisitMut for Option<T> {
78    fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
79        if let Some(s) = self {
80            s.visit(visitor)?;
81        }
82        ControlFlow::Continue(())
83    }
84}
85
86impl<T: VisitMut> VisitMut for Vec<T> {
87    fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
88        for v in self {
89            v.visit(visitor)?;
90        }
91        ControlFlow::Continue(())
92    }
93}
94
95impl<T: VisitMut> VisitMut for Box<T> {
96    fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
97        T::visit(self, visitor)
98    }
99}
100
// Implements both `Visit` and `VisitMut` for each listed type as a no-op that
// immediately returns `ControlFlow::Continue(())` without touching the visitor.
// Intended for leaf types that contain no AST nodes.
macro_rules! visit_noop {
    ($($leaf:ty),+) => {
        $(
            impl Visit for $leaf {
                fn visit<V: Visitor>(&self, _visitor: &mut V) -> ControlFlow<V::Break> {
                    ControlFlow::Continue(())
                }
            }

            impl VisitMut for $leaf {
                fn visit<V: VisitorMut>(&mut self, _visitor: &mut V) -> ControlFlow<V::Break> {
                    ControlFlow::Continue(())
                }
            }
        )+
    };
}
115
116visit_noop!(u8, u16, u32, u64, i8, i16, i32, i64, char, bool, String);
117
118#[cfg(feature = "bigdecimal")]
119visit_noop!(bigdecimal::BigDecimal);
120
121/// A visitor that can be used to walk an AST tree.
122///
123/// `pre_visit_` methods are invoked before visiting all children of the
124/// node and `post_visit_` methods are invoked after visiting all
125/// children of the node.
126///
127/// # See also
128///
129/// These methods provide a more concise way of visiting nodes of a certain type:
130/// * [visit_relations]
131/// * [visit_expressions]
132/// * [visit_statements]
133///
134/// # Example
135/// ```
136/// # use sqltk_parser::parser::Parser;
137/// # use sqltk_parser::dialect::GenericDialect;
138/// # use sqltk_parser::ast::{Visit, Visitor, ObjectName, Expr};
139/// # use core::ops::ControlFlow;
140/// // A structure that records statements and relations
141/// #[derive(Default)]
142/// struct V {
143///    visited: Vec<String>,
144/// }
145///
146/// // Visit relations and exprs before children are visited (depth first walk)
147/// // Note you can also visit statements and visit exprs after children have been visited
148/// impl Visitor for V {
149///   type Break = ();
150///
151///   fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
152///     self.visited.push(format!("PRE: RELATION: {}", relation));
153///     ControlFlow::Continue(())
154///   }
155///
156///   fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
157///     self.visited.push(format!("PRE: EXPR: {}", expr));
158///     ControlFlow::Continue(())
159///   }
160/// }
161///
162/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
163/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
164///    .unwrap();
165///
166/// // Drive the visitor through the AST
167/// let mut visitor = V::default();
168/// statements.visit(&mut visitor);
169///
170/// // The visitor has visited statements and expressions in pre-traversal order
171/// let expected : Vec<_> = [
172///   "PRE: EXPR: a",
173///   "PRE: RELATION: foo",
174///   "PRE: EXPR: x IN (SELECT y FROM bar)",
175///   "PRE: EXPR: x",
176///   "PRE: EXPR: y",
177///   "PRE: RELATION: bar",
178/// ]
179///   .into_iter().map(|s| s.to_string()).collect();
180///
181/// assert_eq!(visitor.visited, expected);
182/// ```
183pub trait Visitor {
184    /// Type returned when the recursion returns early.
185    type Break;
186
187    /// Invoked for any queries that appear in the AST before visiting children
188    fn pre_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
189        ControlFlow::Continue(())
190    }
191
192    /// Invoked for any queries that appear in the AST after visiting children
193    fn post_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
194        ControlFlow::Continue(())
195    }
196
197    /// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
198    fn pre_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
199        ControlFlow::Continue(())
200    }
201
202    /// Invoked for any relations (e.g. tables) that appear in the AST after visiting children
203    fn post_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
204        ControlFlow::Continue(())
205    }
206
207    /// Invoked for any table factors that appear in the AST before visiting children
208    fn pre_visit_table_factor(&mut self, _table_factor: &TableFactor) -> ControlFlow<Self::Break> {
209        ControlFlow::Continue(())
210    }
211
212    /// Invoked for any table factors that appear in the AST after visiting children
213    fn post_visit_table_factor(&mut self, _table_factor: &TableFactor) -> ControlFlow<Self::Break> {
214        ControlFlow::Continue(())
215    }
216
217    /// Invoked for any expressions that appear in the AST before visiting children
218    fn pre_visit_expr(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
219        ControlFlow::Continue(())
220    }
221
222    /// Invoked for any expressions that appear in the AST
223    fn post_visit_expr(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
224        ControlFlow::Continue(())
225    }
226
227    /// Invoked for any statements that appear in the AST before visiting children
228    fn pre_visit_statement(&mut self, _statement: &Statement) -> ControlFlow<Self::Break> {
229        ControlFlow::Continue(())
230    }
231
232    /// Invoked for any statements that appear in the AST after visiting children
233    fn post_visit_statement(&mut self, _statement: &Statement) -> ControlFlow<Self::Break> {
234        ControlFlow::Continue(())
235    }
236}
237
238/// A visitor that can be used to mutate an AST tree.
239///
240/// `pre_visit_` methods are invoked before visiting all children of the
241/// node and `post_visit_` methods are invoked after visiting all
242/// children of the node.
243///
244/// # See also
245///
246/// These methods provide a more concise way of visiting nodes of a certain type:
247/// * [visit_relations_mut]
248/// * [visit_expressions_mut]
249/// * [visit_statements_mut]
250///
251/// # Example
252/// ```
253/// # use sqltk_parser::parser::Parser;
254/// # use sqltk_parser::dialect::GenericDialect;
255/// # use sqltk_parser::ast::{VisitMut, VisitorMut, ObjectName, Expr, Ident};
256/// # use core::ops::ControlFlow;
257///
258/// // A visitor that replaces "to_replace" with "replaced" in all expressions
259/// struct Replacer;
260///
261/// // Visit each expression after its children have been visited
262/// impl VisitorMut for Replacer {
263///   type Break = ();
264///
265///   fn post_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
266///     if let Expr::Identifier(Ident{ value, ..}) = expr {
267///         *value = value.replace("to_replace", "replaced")
268///     }
269///     ControlFlow::Continue(())
270///   }
271/// }
272///
273/// let sql = "SELECT to_replace FROM foo where to_replace IN (SELECT to_replace FROM bar)";
274/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
275///
276/// // Drive the visitor through the AST
277/// statements.visit(&mut Replacer);
278///
279/// assert_eq!(statements[0].to_string(), "SELECT replaced FROM foo WHERE replaced IN (SELECT replaced FROM bar)");
280/// ```
281pub trait VisitorMut {
282    /// Type returned when the recursion returns early.
283    type Break;
284
285    /// Invoked for any queries that appear in the AST before visiting children
286    fn pre_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
287        ControlFlow::Continue(())
288    }
289
290    /// Invoked for any queries that appear in the AST after visiting children
291    fn post_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
292        ControlFlow::Continue(())
293    }
294
295    /// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
296    fn pre_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
297        ControlFlow::Continue(())
298    }
299
300    /// Invoked for any relations (e.g. tables) that appear in the AST after visiting children
301    fn post_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
302        ControlFlow::Continue(())
303    }
304
305    /// Invoked for any table factors that appear in the AST before visiting children
306    fn pre_visit_table_factor(
307        &mut self,
308        _table_factor: &mut TableFactor,
309    ) -> ControlFlow<Self::Break> {
310        ControlFlow::Continue(())
311    }
312
313    /// Invoked for any table factors that appear in the AST after visiting children
314    fn post_visit_table_factor(
315        &mut self,
316        _table_factor: &mut TableFactor,
317    ) -> ControlFlow<Self::Break> {
318        ControlFlow::Continue(())
319    }
320
321    /// Invoked for any expressions that appear in the AST before visiting children
322    fn pre_visit_expr(&mut self, _expr: &mut Expr) -> ControlFlow<Self::Break> {
323        ControlFlow::Continue(())
324    }
325
326    /// Invoked for any expressions that appear in the AST
327    fn post_visit_expr(&mut self, _expr: &mut Expr) -> ControlFlow<Self::Break> {
328        ControlFlow::Continue(())
329    }
330
331    /// Invoked for any statements that appear in the AST before visiting children
332    fn pre_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> {
333        ControlFlow::Continue(())
334    }
335
336    /// Invoked for any statements that appear in the AST after visiting children
337    fn post_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> {
338        ControlFlow::Continue(())
339    }
340}
341
342struct RelationVisitor<F>(F);
343
344impl<E, F: FnMut(&ObjectName) -> ControlFlow<E>> Visitor for RelationVisitor<F> {
345    type Break = E;
346
347    fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
348        self.0(relation)
349    }
350}
351
352impl<E, F: FnMut(&mut ObjectName) -> ControlFlow<E>> VisitorMut for RelationVisitor<F> {
353    type Break = E;
354
355    fn post_visit_relation(&mut self, relation: &mut ObjectName) -> ControlFlow<Self::Break> {
356        self.0(relation)
357    }
358}
359
360/// Invokes the provided closure on all relations (e.g. table names) present in `v`
361///
362/// # Example
363/// ```
364/// # use sqltk_parser::parser::Parser;
365/// # use sqltk_parser::dialect::GenericDialect;
366/// # use sqltk_parser::ast::{visit_relations};
367/// # use core::ops::ControlFlow;
368/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
369/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
370///    .unwrap();
371///
372/// // visit statements, capturing relations (table names)
373/// let mut visited = vec![];
374/// visit_relations(&statements, |relation| {
375///   visited.push(format!("RELATION: {}", relation));
376///   ControlFlow::<()>::Continue(())
377/// });
378///
379/// let expected : Vec<_> = [
380///   "RELATION: foo",
381///   "RELATION: bar",
382/// ]
383///   .into_iter().map(|s| s.to_string()).collect();
384///
385/// assert_eq!(visited, expected);
386/// ```
387pub fn visit_relations<V, E, F>(v: &V, f: F) -> ControlFlow<E>
388where
389    V: Visit,
390    F: FnMut(&ObjectName) -> ControlFlow<E>,
391{
392    let mut visitor = RelationVisitor(f);
393    v.visit(&mut visitor)?;
394    ControlFlow::Continue(())
395}
396
397/// Invokes the provided closure with a mutable reference to all relations (e.g. table names)
398/// present in `v`.
399///
400/// When the closure mutates its argument, the new mutated relation will not be visited again.
401///
402/// # Example
403/// ```
404/// # use sqltk_parser::parser::Parser;
405/// # use sqltk_parser::dialect::GenericDialect;
406/// # use sqltk_parser::ast::{ObjectName, visit_relations_mut};
407/// # use core::ops::ControlFlow;
408/// let sql = "SELECT a FROM foo";
409/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql)
410///    .unwrap();
411///
412/// // visit statements, renaming table foo to bar
413/// visit_relations_mut(&mut statements, |table| {
414///   table.0[0].value = table.0[0].value.replace("foo", "bar");
415///   ControlFlow::<()>::Continue(())
416/// });
417///
418/// assert_eq!(statements[0].to_string(), "SELECT a FROM bar");
419/// ```
420pub fn visit_relations_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
421where
422    V: VisitMut,
423    F: FnMut(&mut ObjectName) -> ControlFlow<E>,
424{
425    let mut visitor = RelationVisitor(f);
426    v.visit(&mut visitor)?;
427    ControlFlow::Continue(())
428}
429
430struct ExprVisitor<F>(F);
431
432impl<E, F: FnMut(&Expr) -> ControlFlow<E>> Visitor for ExprVisitor<F> {
433    type Break = E;
434
435    fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
436        self.0(expr)
437    }
438}
439
440impl<E, F: FnMut(&mut Expr) -> ControlFlow<E>> VisitorMut for ExprVisitor<F> {
441    type Break = E;
442
443    fn post_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
444        self.0(expr)
445    }
446}
447
448/// Invokes the provided closure on all expressions (e.g. `1 + 2`) present in `v`
449///
450/// # Example
451/// ```
452/// # use sqltk_parser::parser::Parser;
453/// # use sqltk_parser::dialect::GenericDialect;
454/// # use sqltk_parser::ast::{visit_expressions};
455/// # use core::ops::ControlFlow;
456/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
457/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
458///    .unwrap();
459///
460/// // visit all expressions
461/// let mut visited = vec![];
462/// visit_expressions(&statements, |expr| {
463///   visited.push(format!("EXPR: {}", expr));
464///   ControlFlow::<()>::Continue(())
465/// });
466///
467/// let expected : Vec<_> = [
468///   "EXPR: a",
469///   "EXPR: x IN (SELECT y FROM bar)",
470///   "EXPR: x",
471///   "EXPR: y",
472/// ]
473///   .into_iter().map(|s| s.to_string()).collect();
474///
475/// assert_eq!(visited, expected);
476/// ```
477pub fn visit_expressions<V, E, F>(v: &V, f: F) -> ControlFlow<E>
478where
479    V: Visit,
480    F: FnMut(&Expr) -> ControlFlow<E>,
481{
482    let mut visitor = ExprVisitor(f);
483    v.visit(&mut visitor)?;
484    ControlFlow::Continue(())
485}
486
487/// Invokes the provided closure iteratively with a mutable reference to all expressions
488/// present in `v`.
489///
490/// This performs a depth-first search, so if the closure mutates the expression
491///
492/// # Example
493///
494/// ## Remove all select limits in sub-queries
495/// ```
496/// # use sqltk_parser::parser::Parser;
497/// # use sqltk_parser::dialect::GenericDialect;
498/// # use sqltk_parser::ast::{Expr, visit_expressions_mut, visit_statements_mut};
499/// # use core::ops::ControlFlow;
500/// let sql = "SELECT (SELECT y FROM z LIMIT 9) FROM t LIMIT 3";
501/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
502///
503/// // Remove all select limits in sub-queries
504/// visit_expressions_mut(&mut statements, |expr| {
505///   if let Expr::Subquery(q) = expr {
506///      q.limit = None
507///   }
508///   ControlFlow::<()>::Continue(())
509/// });
510///
511/// assert_eq!(statements[0].to_string(), "SELECT (SELECT y FROM z) FROM t LIMIT 3");
512/// ```
513///
514/// ## Wrap column name in function call
515///
516/// This demonstrates how to effectively replace an expression with another more complicated one
517/// that references the original. This example avoids unnecessary allocations by using the
518/// [`std::mem`] family of functions.
519///
520/// ```
521/// # use sqltk_parser::parser::Parser;
522/// # use sqltk_parser::dialect::GenericDialect;
523/// # use sqltk_parser::ast::*;
524/// # use core::ops::ControlFlow;
525/// let sql = "SELECT x, y FROM t";
526/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
527///
528/// visit_expressions_mut(&mut statements, |expr| {
529///   if matches!(expr, Expr::Identifier(col_name) if col_name.value == "x") {
530///     let old_expr = std::mem::replace(expr, Expr::Value(Value::Null));
531///     *expr = Expr::Function(Function {
532///           name: ObjectName(vec![Ident::new("f")]),
533///           args: FunctionArguments::List(FunctionArgumentList {
534///               duplicate_treatment: None,
535///               args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(old_expr))],
536///               clauses: vec![],
537///           }),
538///           null_treatment: None,
539///           filter: None,
540///           over: None,
541///           parameters: FunctionArguments::None,
542///           within_group: vec![],
543///      });
544///   }
545///   ControlFlow::<()>::Continue(())
546/// });
547///
548/// assert_eq!(statements[0].to_string(), "SELECT f(x), y FROM t");
549/// ```
550pub fn visit_expressions_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
551where
552    V: VisitMut,
553    F: FnMut(&mut Expr) -> ControlFlow<E>,
554{
555    v.visit(&mut ExprVisitor(f))?;
556    ControlFlow::Continue(())
557}
558
559struct StatementVisitor<F>(F);
560
561impl<E, F: FnMut(&Statement) -> ControlFlow<E>> Visitor for StatementVisitor<F> {
562    type Break = E;
563
564    fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
565        self.0(statement)
566    }
567}
568
569impl<E, F: FnMut(&mut Statement) -> ControlFlow<E>> VisitorMut for StatementVisitor<F> {
570    type Break = E;
571
572    fn post_visit_statement(&mut self, statement: &mut Statement) -> ControlFlow<Self::Break> {
573        self.0(statement)
574    }
575}
576
577/// Invokes the provided closure iteratively with a mutable reference to all statements
578/// present in `v` (e.g. `SELECT`, `CREATE TABLE`, etc).
579///
580/// # Example
581/// ```
582/// # use sqltk_parser::parser::Parser;
583/// # use sqltk_parser::dialect::GenericDialect;
584/// # use sqltk_parser::ast::{visit_statements};
585/// # use core::ops::ControlFlow;
586/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar); CREATE TABLE baz(q int)";
587/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
588///    .unwrap();
589///
590/// // visit all statements
591/// let mut visited = vec![];
592/// visit_statements(&statements, |stmt| {
593///   visited.push(format!("STATEMENT: {}", stmt));
594///   ControlFlow::<()>::Continue(())
595/// });
596///
597/// let expected : Vec<_> = [
598///   "STATEMENT: SELECT a FROM foo WHERE x IN (SELECT y FROM bar)",
599///   "STATEMENT: CREATE TABLE baz (q INT)"
600/// ]
601///   .into_iter().map(|s| s.to_string()).collect();
602///
603/// assert_eq!(visited, expected);
604/// ```
605pub fn visit_statements<V, E, F>(v: &V, f: F) -> ControlFlow<E>
606where
607    V: Visit,
608    F: FnMut(&Statement) -> ControlFlow<E>,
609{
610    let mut visitor = StatementVisitor(f);
611    v.visit(&mut visitor)?;
612    ControlFlow::Continue(())
613}
614
615/// Invokes the provided closure on all statements (e.g. `SELECT`, `CREATE TABLE`, etc) present in `v`
616///
617/// # Example
618/// ```
619/// # use sqltk_parser::parser::Parser;
620/// # use sqltk_parser::dialect::GenericDialect;
621/// # use sqltk_parser::ast::{Statement, visit_statements_mut};
622/// # use core::ops::ControlFlow;
623/// let sql = "SELECT x FROM foo LIMIT 9+$limit; SELECT * FROM t LIMIT f()";
624/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
625///
626/// // Remove all select limits in outer statements (not in sub-queries)
627/// visit_statements_mut(&mut statements, |stmt| {
628///   if let Statement::Query(q) = stmt {
629///      q.limit = None
630///   }
631///   ControlFlow::<()>::Continue(())
632/// });
633///
634/// assert_eq!(statements[0].to_string(), "SELECT x FROM foo");
635/// assert_eq!(statements[1].to_string(), "SELECT * FROM t");
636/// ```
637pub fn visit_statements_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
638where
639    V: VisitMut,
640    F: FnMut(&mut Statement) -> ControlFlow<E>,
641{
642    v.visit(&mut StatementVisitor(f))?;
643    ControlFlow::Continue(())
644}
645
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialect::GenericDialect;
    use crate::parser::Parser;
    use crate::tokenizer::Tokenizer;

    /// Visitor that records every hook invocation as a `"PRE: ..."` / `"POST: ..."` line.
    #[derive(Default)]
    struct TestVisitor {
        visited: Vec<String>,
    }

    impl Visitor for TestVisitor {
        type Break = ();

        /// Invoked for any queries that appear in the AST before visiting children
        fn pre_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
            self.visited.push(format!("PRE: QUERY: {query}"));
            ControlFlow::Continue(())
        }

        /// Invoked for any queries that appear in the AST after visiting children
        fn post_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
            self.visited.push(format!("POST: QUERY: {query}"));
            ControlFlow::Continue(())
        }

        fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
            self.visited.push(format!("PRE: RELATION: {relation}"));
            ControlFlow::Continue(())
        }

        fn post_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
            self.visited.push(format!("POST: RELATION: {relation}"));
            ControlFlow::Continue(())
        }

        fn pre_visit_table_factor(
            &mut self,
            table_factor: &TableFactor,
        ) -> ControlFlow<Self::Break> {
            self.visited
                .push(format!("PRE: TABLE FACTOR: {table_factor}"));
            ControlFlow::Continue(())
        }

        fn post_visit_table_factor(
            &mut self,
            table_factor: &TableFactor,
        ) -> ControlFlow<Self::Break> {
            self.visited
                .push(format!("POST: TABLE FACTOR: {table_factor}"));
            ControlFlow::Continue(())
        }

        fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
            self.visited.push(format!("PRE: EXPR: {expr}"));
            ControlFlow::Continue(())
        }

        fn post_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
            self.visited.push(format!("POST: EXPR: {expr}"));
            ControlFlow::Continue(())
        }

        fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
            self.visited.push(format!("PRE: STATEMENT: {statement}"));
            ControlFlow::Continue(())
        }

        fn post_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
            self.visited.push(format!("POST: STATEMENT: {statement}"));
            ControlFlow::Continue(())
        }
    }

    /// Parses a single statement from `sql` and returns the hooks `TestVisitor` saw, in order.
    fn do_visit(sql: &str) -> Vec<String> {
        let dialect = GenericDialect {};
        let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
        let s = Parser::new(&dialect)
            .with_tokens(tokens)
            .parse_statement()
            .unwrap();

        let mut visitor = TestVisitor::default();
        // `TestVisitor` never breaks, so the walk must run to completion.
        assert_eq!(s.visit(&mut visitor), ControlFlow::Continue(()));
        visitor.visited
    }

    #[test]
    fn test_sql() {
        let tests = vec![
            (
                "SELECT * from table_name as my_table",
                vec![
                    "PRE: STATEMENT: SELECT * FROM table_name AS my_table",
                    "PRE: QUERY: SELECT * FROM table_name AS my_table",
                    "PRE: TABLE FACTOR: table_name AS my_table",
                    "PRE: RELATION: table_name",
                    "POST: RELATION: table_name",
                    "POST: TABLE FACTOR: table_name AS my_table",
                    "POST: QUERY: SELECT * FROM table_name AS my_table",
                    "POST: STATEMENT: SELECT * FROM table_name AS my_table",
                ],
            ),
            (
                "SELECT * from t1 join t2 on t1.id = t2.t1_id",
                vec![
                    "PRE: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
                    "PRE: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
                    "PRE: TABLE FACTOR: t1",
                    "PRE: RELATION: t1",
                    "POST: RELATION: t1",
                    "POST: TABLE FACTOR: t1",
                    "PRE: TABLE FACTOR: t2",
                    "PRE: RELATION: t2",
                    "POST: RELATION: t2",
                    "POST: TABLE FACTOR: t2",
                    "PRE: EXPR: t1.id = t2.t1_id",
                    "PRE: EXPR: t1.id",
                    "POST: EXPR: t1.id",
                    "PRE: EXPR: t2.t1_id",
                    "POST: EXPR: t2.t1_id",
                    "POST: EXPR: t1.id = t2.t1_id",
                    "POST: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
                    "POST: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
                ],
            ),
            (
                "SELECT * from t1 where EXISTS(SELECT column from t2)",
                vec![
                    "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
                    "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
                    "PRE: TABLE FACTOR: t1",
                    "PRE: RELATION: t1",
                    "POST: RELATION: t1",
                    "POST: TABLE FACTOR: t1",
                    "PRE: EXPR: EXISTS (SELECT column FROM t2)",
                    "PRE: QUERY: SELECT column FROM t2",
                    "PRE: EXPR: column",
                    "POST: EXPR: column",
                    "PRE: TABLE FACTOR: t2",
                    "PRE: RELATION: t2",
                    "POST: RELATION: t2",
                    "POST: TABLE FACTOR: t2",
                    "POST: QUERY: SELECT column FROM t2",
                    "POST: EXPR: EXISTS (SELECT column FROM t2)",
                    "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
                    "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
                ],
            ),
            (
                "SELECT * from t1 where EXISTS(SELECT column from t2) UNION SELECT * from t3",
                vec![
                    "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
                    "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
                    "PRE: TABLE FACTOR: t1",
                    "PRE: RELATION: t1",
                    "POST: RELATION: t1",
                    "POST: TABLE FACTOR: t1",
                    "PRE: EXPR: EXISTS (SELECT column FROM t2)",
                    "PRE: QUERY: SELECT column FROM t2",
                    "PRE: EXPR: column",
                    "POST: EXPR: column",
                    "PRE: TABLE FACTOR: t2",
                    "PRE: RELATION: t2",
                    "POST: RELATION: t2",
                    "POST: TABLE FACTOR: t2",
                    "POST: QUERY: SELECT column FROM t2",
                    "POST: EXPR: EXISTS (SELECT column FROM t2)",
                    "PRE: TABLE FACTOR: t3",
                    "PRE: RELATION: t3",
                    "POST: RELATION: t3",
                    "POST: TABLE FACTOR: t3",
                    "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
                    "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
                ],
            ),
            (
                concat!(
                    "SELECT * FROM monthly_sales ",
                    "PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
                    "ORDER BY EMPID"
                ),
                vec![
                    "PRE: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
                    "PRE: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
                    "PRE: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
                    "PRE: TABLE FACTOR: monthly_sales",
                    "PRE: RELATION: monthly_sales",
                    "POST: RELATION: monthly_sales",
                    "POST: TABLE FACTOR: monthly_sales",
                    "PRE: EXPR: SUM(a.amount)",
                    "PRE: EXPR: a.amount",
                    "POST: EXPR: a.amount",
                    "POST: EXPR: SUM(a.amount)",
                    "PRE: EXPR: 'JAN'",
                    "POST: EXPR: 'JAN'",
                    "PRE: EXPR: 'FEB'",
                    "POST: EXPR: 'FEB'",
                    "PRE: EXPR: 'MAR'",
                    "POST: EXPR: 'MAR'",
                    "PRE: EXPR: 'APR'",
                    "POST: EXPR: 'APR'",
                    "POST: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
                    "PRE: EXPR: EMPID",
                    "POST: EXPR: EMPID",
                    "POST: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
                    "POST: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
                ]
            )
        ];
        for (sql, expected) in tests {
            let actual = do_visit(sql);
            let actual: Vec<_> = actual.iter().map(|x| x.as_str()).collect();
            assert_eq!(actual, expected, "unexpected visit order for: {sql}")
        }
    }
}