sqlparser/ast/
visitor.rs

1// Licensed under the Apache License, Version 2.0 (the "License");
2// you may not use this file except in compliance with the License.
3// You may obtain a copy of the License at
4//
5// http://www.apache.org/licenses/LICENSE-2.0
6//
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
//! Recursive visitors for AST nodes. See [`Visitor`] for more details.
14
15use crate::ast::{Expr, ObjectName, Query, Statement, TableFactor};
16use core::ops::ControlFlow;
17
18/// A type that can be visited by a [`Visitor`]. See [`Visitor`] for
19/// recursively visiting parsed SQL statements.
20///
21/// # Note
22///
23/// This trait should be automatically derived for sqlparser AST nodes
24/// using the [Visit](sqlparser_derive::Visit) proc macro.
25///
26/// ```text
27/// #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
28/// ```
29pub trait Visit {
30    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break>;
31}
32
33/// A type that can be visited by a [`VisitorMut`]. See [`VisitorMut`] for
34/// recursively visiting parsed SQL statements.
35///
36/// # Note
37///
38/// This trait should be automatically derived for sqlparser AST nodes
39/// using the [VisitMut](sqlparser_derive::VisitMut) proc macro.
40///
41/// ```text
42/// #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
43/// ```
44pub trait VisitMut {
45    fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break>;
46}
47
48impl<T: Visit> Visit for Option<T> {
49    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
50        if let Some(s) = self {
51            s.visit(visitor)?;
52        }
53        ControlFlow::Continue(())
54    }
55}
56
57impl<T: Visit> Visit for Vec<T> {
58    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
59        for v in self {
60            v.visit(visitor)?;
61        }
62        ControlFlow::Continue(())
63    }
64}
65
66impl<T: Visit> Visit for Box<T> {
67    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
68        T::visit(self, visitor)
69    }
70}
71
72impl<T: VisitMut> VisitMut for Option<T> {
73    fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
74        if let Some(s) = self {
75            s.visit(visitor)?;
76        }
77        ControlFlow::Continue(())
78    }
79}
80
81impl<T: VisitMut> VisitMut for Vec<T> {
82    fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
83        for v in self {
84            v.visit(visitor)?;
85        }
86        ControlFlow::Continue(())
87    }
88}
89
90impl<T: VisitMut> VisitMut for Box<T> {
91    fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
92        T::visit(self, visitor)
93    }
94}
95
/// Implements both [`Visit`] and [`VisitMut`] as no-ops for each listed leaf
/// type (numbers, strings, ...), which contain no AST nodes to visit.
macro_rules! visit_noop {
    ($($t:ty),+) => {
        $(
            impl Visit for $t {
                fn visit<V: Visitor>(&self, _: &mut V) -> ControlFlow<V::Break> {
                    ControlFlow::Continue(())
                }
            }

            impl VisitMut for $t {
                fn visit<V: VisitorMut>(&mut self, _: &mut V) -> ControlFlow<V::Break> {
                    ControlFlow::Continue(())
                }
            }
        )+
    };
}
110
111visit_noop!(u8, u16, u32, u64, i8, i16, i32, i64, char, bool, String);
112
113#[cfg(feature = "bigdecimal")]
114visit_noop!(bigdecimal::BigDecimal);
115
116/// A visitor that can be used to walk an AST tree.
117///
118/// `pre_visit_` methods are invoked before visiting all children of the
119/// node and `post_visit_` methods are invoked after visiting all
120/// children of the node.
121///
122/// # See also
123///
124/// These methods provide a more concise way of visiting nodes of a certain type:
125/// * [visit_relations]
126/// * [visit_expressions]
127/// * [visit_statements]
128///
129/// # Example
130/// ```
131/// # use sqlparser::parser::Parser;
132/// # use sqlparser::dialect::GenericDialect;
133/// # use sqlparser::ast::{Visit, Visitor, ObjectName, Expr};
134/// # use core::ops::ControlFlow;
135/// // A structure that records statements and relations
136/// #[derive(Default)]
137/// struct V {
138///    visited: Vec<String>,
139/// }
140///
141/// // Visit relations and exprs before children are visited (depth first walk)
142/// // Note you can also visit statements and visit exprs after children have been visited
143/// impl Visitor for V {
144///   type Break = ();
145///
146///   fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
147///     self.visited.push(format!("PRE: RELATION: {}", relation));
148///     ControlFlow::Continue(())
149///   }
150///
151///   fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
152///     self.visited.push(format!("PRE: EXPR: {}", expr));
153///     ControlFlow::Continue(())
154///   }
155/// }
156///
157/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
158/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
159///    .unwrap();
160///
161/// // Drive the visitor through the AST
162/// let mut visitor = V::default();
163/// statements.visit(&mut visitor);
164///
165/// // The visitor has visited statements and expressions in pre-traversal order
166/// let expected : Vec<_> = [
167///   "PRE: EXPR: a",
168///   "PRE: RELATION: foo",
169///   "PRE: EXPR: x IN (SELECT y FROM bar)",
170///   "PRE: EXPR: x",
171///   "PRE: EXPR: y",
172///   "PRE: RELATION: bar",
173/// ]
174///   .into_iter().map(|s| s.to_string()).collect();
175///
176/// assert_eq!(visitor.visited, expected);
177/// ```
178pub trait Visitor {
179    /// Type returned when the recursion returns early.
180    type Break;
181
182    /// Invoked for any queries that appear in the AST before visiting children
183    fn pre_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
184        ControlFlow::Continue(())
185    }
186
187    /// Invoked for any queries that appear in the AST after visiting children
188    fn post_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
189        ControlFlow::Continue(())
190    }
191
192    /// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
193    fn pre_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
194        ControlFlow::Continue(())
195    }
196
197    /// Invoked for any relations (e.g. tables) that appear in the AST after visiting children
198    fn post_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
199        ControlFlow::Continue(())
200    }
201
202    /// Invoked for any table factors that appear in the AST before visiting children
203    fn pre_visit_table_factor(&mut self, _table_factor: &TableFactor) -> ControlFlow<Self::Break> {
204        ControlFlow::Continue(())
205    }
206
207    /// Invoked for any table factors that appear in the AST after visiting children
208    fn post_visit_table_factor(&mut self, _table_factor: &TableFactor) -> ControlFlow<Self::Break> {
209        ControlFlow::Continue(())
210    }
211
212    /// Invoked for any expressions that appear in the AST before visiting children
213    fn pre_visit_expr(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
214        ControlFlow::Continue(())
215    }
216
217    /// Invoked for any expressions that appear in the AST
218    fn post_visit_expr(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
219        ControlFlow::Continue(())
220    }
221
222    /// Invoked for any statements that appear in the AST before visiting children
223    fn pre_visit_statement(&mut self, _statement: &Statement) -> ControlFlow<Self::Break> {
224        ControlFlow::Continue(())
225    }
226
227    /// Invoked for any statements that appear in the AST after visiting children
228    fn post_visit_statement(&mut self, _statement: &Statement) -> ControlFlow<Self::Break> {
229        ControlFlow::Continue(())
230    }
231}
232
233/// A visitor that can be used to mutate an AST tree.
234///
235/// `pre_visit_` methods are invoked before visiting all children of the
236/// node and `post_visit_` methods are invoked after visiting all
237/// children of the node.
238///
239/// # See also
240///
241/// These methods provide a more concise way of visiting nodes of a certain type:
242/// * [visit_relations_mut]
243/// * [visit_expressions_mut]
244/// * [visit_statements_mut]
245///
246/// # Example
247/// ```
248/// # use sqlparser::parser::Parser;
249/// # use sqlparser::dialect::GenericDialect;
250/// # use sqlparser::ast::{VisitMut, VisitorMut, ObjectName, Expr, Ident};
251/// # use core::ops::ControlFlow;
252///
253/// // A visitor that replaces "to_replace" with "replaced" in all expressions
254/// struct Replacer;
255///
256/// // Visit each expression after its children have been visited
257/// impl VisitorMut for Replacer {
258///   type Break = ();
259///
260///   fn post_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
261///     if let Expr::Identifier(Ident{ value, ..}) = expr {
262///         *value = value.replace("to_replace", "replaced")
263///     }
264///     ControlFlow::Continue(())
265///   }
266/// }
267///
268/// let sql = "SELECT to_replace FROM foo where to_replace IN (SELECT to_replace FROM bar)";
269/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
270///
271/// // Drive the visitor through the AST
272/// statements.visit(&mut Replacer);
273///
274/// assert_eq!(statements[0].to_string(), "SELECT replaced FROM foo WHERE replaced IN (SELECT replaced FROM bar)");
275/// ```
276pub trait VisitorMut {
277    /// Type returned when the recursion returns early.
278    type Break;
279
280    /// Invoked for any queries that appear in the AST before visiting children
281    fn pre_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
282        ControlFlow::Continue(())
283    }
284
285    /// Invoked for any queries that appear in the AST after visiting children
286    fn post_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
287        ControlFlow::Continue(())
288    }
289
290    /// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
291    fn pre_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
292        ControlFlow::Continue(())
293    }
294
295    /// Invoked for any relations (e.g. tables) that appear in the AST after visiting children
296    fn post_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
297        ControlFlow::Continue(())
298    }
299
300    /// Invoked for any table factors that appear in the AST before visiting children
301    fn pre_visit_table_factor(
302        &mut self,
303        _table_factor: &mut TableFactor,
304    ) -> ControlFlow<Self::Break> {
305        ControlFlow::Continue(())
306    }
307
308    /// Invoked for any table factors that appear in the AST after visiting children
309    fn post_visit_table_factor(
310        &mut self,
311        _table_factor: &mut TableFactor,
312    ) -> ControlFlow<Self::Break> {
313        ControlFlow::Continue(())
314    }
315
316    /// Invoked for any expressions that appear in the AST before visiting children
317    fn pre_visit_expr(&mut self, _expr: &mut Expr) -> ControlFlow<Self::Break> {
318        ControlFlow::Continue(())
319    }
320
321    /// Invoked for any expressions that appear in the AST
322    fn post_visit_expr(&mut self, _expr: &mut Expr) -> ControlFlow<Self::Break> {
323        ControlFlow::Continue(())
324    }
325
326    /// Invoked for any statements that appear in the AST before visiting children
327    fn pre_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> {
328        ControlFlow::Continue(())
329    }
330
331    /// Invoked for any statements that appear in the AST after visiting children
332    fn post_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> {
333        ControlFlow::Continue(())
334    }
335}
336
337struct RelationVisitor<F>(F);
338
339impl<E, F: FnMut(&ObjectName) -> ControlFlow<E>> Visitor for RelationVisitor<F> {
340    type Break = E;
341
342    fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
343        self.0(relation)
344    }
345}
346
347impl<E, F: FnMut(&mut ObjectName) -> ControlFlow<E>> VisitorMut for RelationVisitor<F> {
348    type Break = E;
349
350    fn post_visit_relation(&mut self, relation: &mut ObjectName) -> ControlFlow<Self::Break> {
351        self.0(relation)
352    }
353}
354
355/// Invokes the provided closure on all relations (e.g. table names) present in `v`
356///
357/// # Example
358/// ```
359/// # use sqlparser::parser::Parser;
360/// # use sqlparser::dialect::GenericDialect;
361/// # use sqlparser::ast::{visit_relations};
362/// # use core::ops::ControlFlow;
363/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
364/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
365///    .unwrap();
366///
367/// // visit statements, capturing relations (table names)
368/// let mut visited = vec![];
369/// visit_relations(&statements, |relation| {
370///   visited.push(format!("RELATION: {}", relation));
371///   ControlFlow::<()>::Continue(())
372/// });
373///
374/// let expected : Vec<_> = [
375///   "RELATION: foo",
376///   "RELATION: bar",
377/// ]
378///   .into_iter().map(|s| s.to_string()).collect();
379///
380/// assert_eq!(visited, expected);
381/// ```
382pub fn visit_relations<V, E, F>(v: &V, f: F) -> ControlFlow<E>
383where
384    V: Visit,
385    F: FnMut(&ObjectName) -> ControlFlow<E>,
386{
387    let mut visitor = RelationVisitor(f);
388    v.visit(&mut visitor)?;
389    ControlFlow::Continue(())
390}
391
392/// Invokes the provided closure with a mutable reference to all relations (e.g. table names)
393/// present in `v`.
394///
395/// When the closure mutates its argument, the new mutated relation will not be visited again.
396///
397/// # Example
398/// ```
399/// # use sqlparser::parser::Parser;
400/// # use sqlparser::dialect::GenericDialect;
401/// # use sqlparser::ast::{ObjectName, visit_relations_mut};
402/// # use core::ops::ControlFlow;
403/// let sql = "SELECT a FROM foo";
404/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql)
405///    .unwrap();
406///
407/// // visit statements, renaming table foo to bar
408/// visit_relations_mut(&mut statements, |table| {
409///   table.0[0].value = table.0[0].value.replace("foo", "bar");
410///   ControlFlow::<()>::Continue(())
411/// });
412///
413/// assert_eq!(statements[0].to_string(), "SELECT a FROM bar");
414/// ```
415pub fn visit_relations_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
416where
417    V: VisitMut,
418    F: FnMut(&mut ObjectName) -> ControlFlow<E>,
419{
420    let mut visitor = RelationVisitor(f);
421    v.visit(&mut visitor)?;
422    ControlFlow::Continue(())
423}
424
425struct ExprVisitor<F>(F);
426
427impl<E, F: FnMut(&Expr) -> ControlFlow<E>> Visitor for ExprVisitor<F> {
428    type Break = E;
429
430    fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
431        self.0(expr)
432    }
433}
434
435impl<E, F: FnMut(&mut Expr) -> ControlFlow<E>> VisitorMut for ExprVisitor<F> {
436    type Break = E;
437
438    fn post_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
439        self.0(expr)
440    }
441}
442
443/// Invokes the provided closure on all expressions (e.g. `1 + 2`) present in `v`
444///
445/// # Example
446/// ```
447/// # use sqlparser::parser::Parser;
448/// # use sqlparser::dialect::GenericDialect;
449/// # use sqlparser::ast::{visit_expressions};
450/// # use core::ops::ControlFlow;
451/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
452/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
453///    .unwrap();
454///
455/// // visit all expressions
456/// let mut visited = vec![];
457/// visit_expressions(&statements, |expr| {
458///   visited.push(format!("EXPR: {}", expr));
459///   ControlFlow::<()>::Continue(())
460/// });
461///
462/// let expected : Vec<_> = [
463///   "EXPR: a",
464///   "EXPR: x IN (SELECT y FROM bar)",
465///   "EXPR: x",
466///   "EXPR: y",
467/// ]
468///   .into_iter().map(|s| s.to_string()).collect();
469///
470/// assert_eq!(visited, expected);
471/// ```
472pub fn visit_expressions<V, E, F>(v: &V, f: F) -> ControlFlow<E>
473where
474    V: Visit,
475    F: FnMut(&Expr) -> ControlFlow<E>,
476{
477    let mut visitor = ExprVisitor(f);
478    v.visit(&mut visitor)?;
479    ControlFlow::Continue(())
480}
481
482/// Invokes the provided closure iteratively with a mutable reference to all expressions
483/// present in `v`.
484///
485/// This performs a depth-first search, so if the closure mutates the expression
486///
487/// # Example
488///
489/// ## Remove all select limits in sub-queries
490/// ```
491/// # use sqlparser::parser::Parser;
492/// # use sqlparser::dialect::GenericDialect;
493/// # use sqlparser::ast::{Expr, visit_expressions_mut, visit_statements_mut};
494/// # use core::ops::ControlFlow;
495/// let sql = "SELECT (SELECT y FROM z LIMIT 9) FROM t LIMIT 3";
496/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
497///
498/// // Remove all select limits in sub-queries
499/// visit_expressions_mut(&mut statements, |expr| {
500///   if let Expr::Subquery(q) = expr {
501///      q.limit = None
502///   }
503///   ControlFlow::<()>::Continue(())
504/// });
505///
506/// assert_eq!(statements[0].to_string(), "SELECT (SELECT y FROM z) FROM t LIMIT 3");
507/// ```
508///
509/// ## Wrap column name in function call
510///
511/// This demonstrates how to effectively replace an expression with another more complicated one
512/// that references the original. This example avoids unnecessary allocations by using the
513/// [`std::mem`] family of functions.
514///
515/// ```
516/// # use sqlparser::parser::Parser;
517/// # use sqlparser::dialect::GenericDialect;
518/// # use sqlparser::ast::{Expr, Function, FunctionArg, FunctionArgExpr, Ident, ObjectName, Value, visit_expressions_mut, visit_statements_mut};
519/// # use core::ops::ControlFlow;
520/// let sql = "SELECT x, y FROM t";
521/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
522///
523/// visit_expressions_mut(&mut statements, |expr| {
524///   if matches!(expr, Expr::Identifier(col_name) if col_name.value == "x") {
525///     let old_expr = std::mem::replace(expr, Expr::Value(Value::Null));
526///     *expr = Expr::Function(Function {
527///           name: ObjectName(vec![Ident::new("f")]),
528///           args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(old_expr))],
529///           null_treatment: None,
530///           filter: None, over: None, distinct: false, special: false, order_by: vec![],
531///      });
532///   }
533///   ControlFlow::<()>::Continue(())
534/// });
535///
536/// assert_eq!(statements[0].to_string(), "SELECT f(x), y FROM t");
537/// ```
538pub fn visit_expressions_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
539where
540    V: VisitMut,
541    F: FnMut(&mut Expr) -> ControlFlow<E>,
542{
543    v.visit(&mut ExprVisitor(f))?;
544    ControlFlow::Continue(())
545}
546
547struct StatementVisitor<F>(F);
548
549impl<E, F: FnMut(&Statement) -> ControlFlow<E>> Visitor for StatementVisitor<F> {
550    type Break = E;
551
552    fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
553        self.0(statement)
554    }
555}
556
557impl<E, F: FnMut(&mut Statement) -> ControlFlow<E>> VisitorMut for StatementVisitor<F> {
558    type Break = E;
559
560    fn post_visit_statement(&mut self, statement: &mut Statement) -> ControlFlow<Self::Break> {
561        self.0(statement)
562    }
563}
564
565/// Invokes the provided closure iteratively with a mutable reference to all statements
566/// present in `v` (e.g. `SELECT`, `CREATE TABLE`, etc).
567///
568/// # Example
569/// ```
570/// # use sqlparser::parser::Parser;
571/// # use sqlparser::dialect::GenericDialect;
572/// # use sqlparser::ast::{visit_statements};
573/// # use core::ops::ControlFlow;
574/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar); CREATE TABLE baz(q int)";
575/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
576///    .unwrap();
577///
578/// // visit all statements
579/// let mut visited = vec![];
580/// visit_statements(&statements, |stmt| {
581///   visited.push(format!("STATEMENT: {}", stmt));
582///   ControlFlow::<()>::Continue(())
583/// });
584///
585/// let expected : Vec<_> = [
586///   "STATEMENT: SELECT a FROM foo WHERE x IN (SELECT y FROM bar)",
587///   "STATEMENT: CREATE TABLE baz (q INT)"
588/// ]
589///   .into_iter().map(|s| s.to_string()).collect();
590///
591/// assert_eq!(visited, expected);
592/// ```
593pub fn visit_statements<V, E, F>(v: &V, f: F) -> ControlFlow<E>
594where
595    V: Visit,
596    F: FnMut(&Statement) -> ControlFlow<E>,
597{
598    let mut visitor = StatementVisitor(f);
599    v.visit(&mut visitor)?;
600    ControlFlow::Continue(())
601}
602
603/// Invokes the provided closure on all statements (e.g. `SELECT`, `CREATE TABLE`, etc) present in `v`
604///
605/// # Example
606/// ```
607/// # use sqlparser::parser::Parser;
608/// # use sqlparser::dialect::GenericDialect;
609/// # use sqlparser::ast::{Statement, visit_statements_mut};
610/// # use core::ops::ControlFlow;
611/// let sql = "SELECT x FROM foo LIMIT 9+$limit; SELECT * FROM t LIMIT f()";
612/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
613///
614/// // Remove all select limits in outer statements (not in sub-queries)
615/// visit_statements_mut(&mut statements, |stmt| {
616///   if let Statement::Query(q) = stmt {
617///      q.limit = None
618///   }
619///   ControlFlow::<()>::Continue(())
620/// });
621///
622/// assert_eq!(statements[0].to_string(), "SELECT x FROM foo");
623/// assert_eq!(statements[1].to_string(), "SELECT * FROM t");
624/// ```
625pub fn visit_statements_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
626where
627    V: VisitMut,
628    F: FnMut(&mut Statement) -> ControlFlow<E>,
629{
630    v.visit(&mut StatementVisitor(f))?;
631    ControlFlow::Continue(())
632}
633
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialect::GenericDialect;
    use crate::parser::Parser;
    use crate::tokenizer::Tokenizer;

    /// Records every hook invocation as a `"PRE|POST: KIND: <node>"` string.
    #[derive(Default)]
    struct TestVisitor {
        visited: Vec<String>,
    }

    impl Visitor for TestVisitor {
        type Break = ();

        /// Invoked for any queries that appear in the AST before visiting children
        fn pre_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
            self.visited.push(format!("PRE: QUERY: {query}"));
            ControlFlow::Continue(())
        }

        /// Invoked for any queries that appear in the AST after visiting children
        fn post_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
            self.visited.push(format!("POST: QUERY: {query}"));
            ControlFlow::Continue(())
        }

        fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
            self.visited.push(format!("PRE: RELATION: {relation}"));
            ControlFlow::Continue(())
        }

        fn post_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
            self.visited.push(format!("POST: RELATION: {relation}"));
            ControlFlow::Continue(())
        }

        fn pre_visit_table_factor(
            &mut self,
            table_factor: &TableFactor,
        ) -> ControlFlow<Self::Break> {
            self.visited
                .push(format!("PRE: TABLE FACTOR: {table_factor}"));
            ControlFlow::Continue(())
        }

        fn post_visit_table_factor(
            &mut self,
            table_factor: &TableFactor,
        ) -> ControlFlow<Self::Break> {
            self.visited
                .push(format!("POST: TABLE FACTOR: {table_factor}"));
            ControlFlow::Continue(())
        }

        fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
            self.visited.push(format!("PRE: EXPR: {expr}"));
            ControlFlow::Continue(())
        }

        fn post_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
            self.visited.push(format!("POST: EXPR: {expr}"));
            ControlFlow::Continue(())
        }

        fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
            self.visited.push(format!("PRE: STATEMENT: {statement}"));
            ControlFlow::Continue(())
        }

        fn post_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
            self.visited.push(format!("POST: STATEMENT: {statement}"));
            ControlFlow::Continue(())
        }
    }

    /// Parses `sql` as a single statement and returns the hooks fired while visiting it.
    fn do_visit(sql: &str) -> Vec<String> {
        let dialect = GenericDialect {};
        let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
        let s = Parser::new(&dialect)
            .with_tokens(tokens)
            .parse_statement()
            .unwrap();

        let mut visitor = TestVisitor::default();
        // `TestVisitor` never breaks, so the walk must run to completion.
        assert_eq!(s.visit(&mut visitor), ControlFlow::Continue(()));
        visitor.visited
    }

    #[test]
    fn test_sql() {
        let tests = vec![
            (
                "SELECT * from table_name as my_table",
                vec![
                    "PRE: STATEMENT: SELECT * FROM table_name AS my_table",
                    "PRE: QUERY: SELECT * FROM table_name AS my_table",
                    "PRE: TABLE FACTOR: table_name AS my_table",
                    "PRE: RELATION: table_name",
                    "POST: RELATION: table_name",
                    "POST: TABLE FACTOR: table_name AS my_table",
                    "POST: QUERY: SELECT * FROM table_name AS my_table",
                    "POST: STATEMENT: SELECT * FROM table_name AS my_table",
                ],
            ),
            (
                "SELECT * from t1 join t2 on t1.id = t2.t1_id",
                vec![
                    "PRE: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
                    "PRE: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
                    "PRE: TABLE FACTOR: t1",
                    "PRE: RELATION: t1",
                    "POST: RELATION: t1",
                    "POST: TABLE FACTOR: t1",
                    "PRE: TABLE FACTOR: t2",
                    "PRE: RELATION: t2",
                    "POST: RELATION: t2",
                    "POST: TABLE FACTOR: t2",
                    "PRE: EXPR: t1.id = t2.t1_id",
                    "PRE: EXPR: t1.id",
                    "POST: EXPR: t1.id",
                    "PRE: EXPR: t2.t1_id",
                    "POST: EXPR: t2.t1_id",
                    "POST: EXPR: t1.id = t2.t1_id",
                    "POST: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
                    "POST: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
                ],
            ),
            (
                "SELECT * from t1 where EXISTS(SELECT column from t2)",
                vec![
                    "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
                    "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
                    "PRE: TABLE FACTOR: t1",
                    "PRE: RELATION: t1",
                    "POST: RELATION: t1",
                    "POST: TABLE FACTOR: t1",
                    "PRE: EXPR: EXISTS (SELECT column FROM t2)",
                    "PRE: QUERY: SELECT column FROM t2",
                    "PRE: EXPR: column",
                    "POST: EXPR: column",
                    "PRE: TABLE FACTOR: t2",
                    "PRE: RELATION: t2",
                    "POST: RELATION: t2",
                    "POST: TABLE FACTOR: t2",
                    "POST: QUERY: SELECT column FROM t2",
                    "POST: EXPR: EXISTS (SELECT column FROM t2)",
                    "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
                    "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
                ],
            ),
            (
                "SELECT * from t1 where EXISTS(SELECT column from t2) UNION SELECT * from t3",
                vec![
                    "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
                    "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
                    "PRE: TABLE FACTOR: t1",
                    "PRE: RELATION: t1",
                    "POST: RELATION: t1",
                    "POST: TABLE FACTOR: t1",
                    "PRE: EXPR: EXISTS (SELECT column FROM t2)",
                    "PRE: QUERY: SELECT column FROM t2",
                    "PRE: EXPR: column",
                    "POST: EXPR: column",
                    "PRE: TABLE FACTOR: t2",
                    "PRE: RELATION: t2",
                    "POST: RELATION: t2",
                    "POST: TABLE FACTOR: t2",
                    "POST: QUERY: SELECT column FROM t2",
                    "POST: EXPR: EXISTS (SELECT column FROM t2)",
                    "PRE: TABLE FACTOR: t3",
                    "PRE: RELATION: t3",
                    "POST: RELATION: t3",
                    "POST: TABLE FACTOR: t3",
                    "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
                    "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
                ],
            ),
            (
                concat!(
                    "SELECT * FROM monthly_sales ",
                    "PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
                    "ORDER BY EMPID"
                ),
                vec![
                    "PRE: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
                    "PRE: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
                    "PRE: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
                    "PRE: TABLE FACTOR: monthly_sales",
                    "PRE: RELATION: monthly_sales",
                    "POST: RELATION: monthly_sales",
                    "POST: TABLE FACTOR: monthly_sales",
                    "PRE: EXPR: SUM(a.amount)",
                    "PRE: EXPR: a.amount",
                    "POST: EXPR: a.amount",
                    "POST: EXPR: SUM(a.amount)",
                    "POST: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
                    "PRE: EXPR: EMPID",
                    "POST: EXPR: EMPID",
                    "POST: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
                    "POST: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
                ]
            )
        ];
        for (sql, expected) in tests {
            let actual = do_visit(sql);
            let actual: Vec<_> = actual.iter().map(|x| x.as_str()).collect();
            assert_eq!(actual, expected, "unexpected visit order for {sql}")
        }
    }
}