sqlparser/ast/visitor.rs
1// Licensed under the Apache License, Version 2.0 (the "License");
2// you may not use this file except in compliance with the License.
3// You may obtain a copy of the License at
4//
5// http://www.apache.org/licenses/LICENSE-2.0
6//
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13//! Recursive visitors for ast Nodes. See [`Visitor`] for more details.
14
15use crate::ast::{Expr, ObjectName, Query, Statement, TableFactor};
16use core::ops::ControlFlow;
17
18/// A type that can be visited by a [`Visitor`]. See [`Visitor`] for
19/// recursively visiting parsed SQL statements.
20///
21/// # Note
22///
23/// This trait should be automatically derived for sqlparser AST nodes
24/// using the [Visit](sqlparser_derive::Visit) proc macro.
25///
26/// ```text
27/// #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
28/// ```
29pub trait Visit {
30 fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break>;
31}
32
33/// A type that can be visited by a [`VisitorMut`]. See [`VisitorMut`] for
34/// recursively visiting parsed SQL statements.
35///
36/// # Note
37///
38/// This trait should be automatically derived for sqlparser AST nodes
39/// using the [VisitMut](sqlparser_derive::VisitMut) proc macro.
40///
41/// ```text
42/// #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
43/// ```
44pub trait VisitMut {
45 fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break>;
46}
47
48impl<T: Visit> Visit for Option<T> {
49 fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
50 if let Some(s) = self {
51 s.visit(visitor)?;
52 }
53 ControlFlow::Continue(())
54 }
55}
56
57impl<T: Visit> Visit for Vec<T> {
58 fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
59 for v in self {
60 v.visit(visitor)?;
61 }
62 ControlFlow::Continue(())
63 }
64}
65
66impl<T: Visit> Visit for Box<T> {
67 fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
68 T::visit(self, visitor)
69 }
70}
71
72impl<T: VisitMut> VisitMut for Option<T> {
73 fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
74 if let Some(s) = self {
75 s.visit(visitor)?;
76 }
77 ControlFlow::Continue(())
78 }
79}
80
81impl<T: VisitMut> VisitMut for Vec<T> {
82 fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
83 for v in self {
84 v.visit(visitor)?;
85 }
86 ControlFlow::Continue(())
87 }
88}
89
90impl<T: VisitMut> VisitMut for Box<T> {
91 fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
92 T::visit(self, visitor)
93 }
94}
95
96macro_rules! visit_noop {
97 ($($t:ty),+) => {
98 $(impl Visit for $t {
99 fn visit<V: Visitor>(&self, _visitor: &mut V) -> ControlFlow<V::Break> {
100 ControlFlow::Continue(())
101 }
102 })+
103 $(impl VisitMut for $t {
104 fn visit<V: VisitorMut>(&mut self, _visitor: &mut V) -> ControlFlow<V::Break> {
105 ControlFlow::Continue(())
106 }
107 })+
108 };
109}
110
111visit_noop!(u8, u16, u32, u64, i8, i16, i32, i64, char, bool, String);
112
113#[cfg(feature = "bigdecimal")]
114visit_noop!(bigdecimal::BigDecimal);
115
116/// A visitor that can be used to walk an AST tree.
117///
118/// `pre_visit_` methods are invoked before visiting all children of the
119/// node and `post_visit_` methods are invoked after visiting all
120/// children of the node.
121///
122/// # See also
123///
124/// These methods provide a more concise way of visiting nodes of a certain type:
125/// * [visit_relations]
126/// * [visit_expressions]
127/// * [visit_statements]
128///
129/// # Example
130/// ```
131/// # use sqlparser::parser::Parser;
132/// # use sqlparser::dialect::GenericDialect;
133/// # use sqlparser::ast::{Visit, Visitor, ObjectName, Expr};
134/// # use core::ops::ControlFlow;
135/// // A structure that records statements and relations
136/// #[derive(Default)]
137/// struct V {
138/// visited: Vec<String>,
139/// }
140///
141/// // Visit relations and exprs before children are visited (depth first walk)
142/// // Note you can also visit statements and visit exprs after children have been visited
143/// impl Visitor for V {
144/// type Break = ();
145///
146/// fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
147/// self.visited.push(format!("PRE: RELATION: {}", relation));
148/// ControlFlow::Continue(())
149/// }
150///
151/// fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
152/// self.visited.push(format!("PRE: EXPR: {}", expr));
153/// ControlFlow::Continue(())
154/// }
155/// }
156///
157/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
158/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
159/// .unwrap();
160///
161/// // Drive the visitor through the AST
162/// let mut visitor = V::default();
163/// statements.visit(&mut visitor);
164///
165/// // The visitor has visited statements and expressions in pre-traversal order
166/// let expected : Vec<_> = [
167/// "PRE: EXPR: a",
168/// "PRE: RELATION: foo",
169/// "PRE: EXPR: x IN (SELECT y FROM bar)",
170/// "PRE: EXPR: x",
171/// "PRE: EXPR: y",
172/// "PRE: RELATION: bar",
173/// ]
174/// .into_iter().map(|s| s.to_string()).collect();
175///
176/// assert_eq!(visitor.visited, expected);
177/// ```
178pub trait Visitor {
179 /// Type returned when the recursion returns early.
180 type Break;
181
182 /// Invoked for any queries that appear in the AST before visiting children
183 fn pre_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
184 ControlFlow::Continue(())
185 }
186
187 /// Invoked for any queries that appear in the AST after visiting children
188 fn post_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
189 ControlFlow::Continue(())
190 }
191
192 /// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
193 fn pre_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
194 ControlFlow::Continue(())
195 }
196
197 /// Invoked for any relations (e.g. tables) that appear in the AST after visiting children
198 fn post_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
199 ControlFlow::Continue(())
200 }
201
202 /// Invoked for any table factors that appear in the AST before visiting children
203 fn pre_visit_table_factor(&mut self, _table_factor: &TableFactor) -> ControlFlow<Self::Break> {
204 ControlFlow::Continue(())
205 }
206
207 /// Invoked for any table factors that appear in the AST after visiting children
208 fn post_visit_table_factor(&mut self, _table_factor: &TableFactor) -> ControlFlow<Self::Break> {
209 ControlFlow::Continue(())
210 }
211
212 /// Invoked for any expressions that appear in the AST before visiting children
213 fn pre_visit_expr(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
214 ControlFlow::Continue(())
215 }
216
217 /// Invoked for any expressions that appear in the AST
218 fn post_visit_expr(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
219 ControlFlow::Continue(())
220 }
221
222 /// Invoked for any statements that appear in the AST before visiting children
223 fn pre_visit_statement(&mut self, _statement: &Statement) -> ControlFlow<Self::Break> {
224 ControlFlow::Continue(())
225 }
226
227 /// Invoked for any statements that appear in the AST after visiting children
228 fn post_visit_statement(&mut self, _statement: &Statement) -> ControlFlow<Self::Break> {
229 ControlFlow::Continue(())
230 }
231}
232
233/// A visitor that can be used to mutate an AST tree.
234///
235/// `pre_visit_` methods are invoked before visiting all children of the
236/// node and `post_visit_` methods are invoked after visiting all
237/// children of the node.
238///
239/// # See also
240///
241/// These methods provide a more concise way of visiting nodes of a certain type:
242/// * [visit_relations_mut]
243/// * [visit_expressions_mut]
244/// * [visit_statements_mut]
245///
246/// # Example
247/// ```
248/// # use sqlparser::parser::Parser;
249/// # use sqlparser::dialect::GenericDialect;
250/// # use sqlparser::ast::{VisitMut, VisitorMut, ObjectName, Expr, Ident};
251/// # use core::ops::ControlFlow;
252///
253/// // A visitor that replaces "to_replace" with "replaced" in all expressions
254/// struct Replacer;
255///
256/// // Visit each expression after its children have been visited
257/// impl VisitorMut for Replacer {
258/// type Break = ();
259///
260/// fn post_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
261/// if let Expr::Identifier(Ident{ value, ..}) = expr {
262/// *value = value.replace("to_replace", "replaced")
263/// }
264/// ControlFlow::Continue(())
265/// }
266/// }
267///
268/// let sql = "SELECT to_replace FROM foo where to_replace IN (SELECT to_replace FROM bar)";
269/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
270///
271/// // Drive the visitor through the AST
272/// statements.visit(&mut Replacer);
273///
274/// assert_eq!(statements[0].to_string(), "SELECT replaced FROM foo WHERE replaced IN (SELECT replaced FROM bar)");
275/// ```
276pub trait VisitorMut {
277 /// Type returned when the recursion returns early.
278 type Break;
279
280 /// Invoked for any queries that appear in the AST before visiting children
281 fn pre_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
282 ControlFlow::Continue(())
283 }
284
285 /// Invoked for any queries that appear in the AST after visiting children
286 fn post_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
287 ControlFlow::Continue(())
288 }
289
290 /// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
291 fn pre_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
292 ControlFlow::Continue(())
293 }
294
295 /// Invoked for any relations (e.g. tables) that appear in the AST after visiting children
296 fn post_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
297 ControlFlow::Continue(())
298 }
299
300 /// Invoked for any table factors that appear in the AST before visiting children
301 fn pre_visit_table_factor(
302 &mut self,
303 _table_factor: &mut TableFactor,
304 ) -> ControlFlow<Self::Break> {
305 ControlFlow::Continue(())
306 }
307
308 /// Invoked for any table factors that appear in the AST after visiting children
309 fn post_visit_table_factor(
310 &mut self,
311 _table_factor: &mut TableFactor,
312 ) -> ControlFlow<Self::Break> {
313 ControlFlow::Continue(())
314 }
315
316 /// Invoked for any expressions that appear in the AST before visiting children
317 fn pre_visit_expr(&mut self, _expr: &mut Expr) -> ControlFlow<Self::Break> {
318 ControlFlow::Continue(())
319 }
320
321 /// Invoked for any expressions that appear in the AST
322 fn post_visit_expr(&mut self, _expr: &mut Expr) -> ControlFlow<Self::Break> {
323 ControlFlow::Continue(())
324 }
325
326 /// Invoked for any statements that appear in the AST before visiting children
327 fn pre_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> {
328 ControlFlow::Continue(())
329 }
330
331 /// Invoked for any statements that appear in the AST after visiting children
332 fn post_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> {
333 ControlFlow::Continue(())
334 }
335}
336
337struct RelationVisitor<F>(F);
338
339impl<E, F: FnMut(&ObjectName) -> ControlFlow<E>> Visitor for RelationVisitor<F> {
340 type Break = E;
341
342 fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
343 self.0(relation)
344 }
345}
346
347impl<E, F: FnMut(&mut ObjectName) -> ControlFlow<E>> VisitorMut for RelationVisitor<F> {
348 type Break = E;
349
350 fn post_visit_relation(&mut self, relation: &mut ObjectName) -> ControlFlow<Self::Break> {
351 self.0(relation)
352 }
353}
354
355/// Invokes the provided closure on all relations (e.g. table names) present in `v`
356///
357/// # Example
358/// ```
359/// # use sqlparser::parser::Parser;
360/// # use sqlparser::dialect::GenericDialect;
361/// # use sqlparser::ast::{visit_relations};
362/// # use core::ops::ControlFlow;
363/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
364/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
365/// .unwrap();
366///
367/// // visit statements, capturing relations (table names)
368/// let mut visited = vec![];
369/// visit_relations(&statements, |relation| {
370/// visited.push(format!("RELATION: {}", relation));
371/// ControlFlow::<()>::Continue(())
372/// });
373///
374/// let expected : Vec<_> = [
375/// "RELATION: foo",
376/// "RELATION: bar",
377/// ]
378/// .into_iter().map(|s| s.to_string()).collect();
379///
380/// assert_eq!(visited, expected);
381/// ```
382pub fn visit_relations<V, E, F>(v: &V, f: F) -> ControlFlow<E>
383where
384 V: Visit,
385 F: FnMut(&ObjectName) -> ControlFlow<E>,
386{
387 let mut visitor = RelationVisitor(f);
388 v.visit(&mut visitor)?;
389 ControlFlow::Continue(())
390}
391
392/// Invokes the provided closure with a mutable reference to all relations (e.g. table names)
393/// present in `v`.
394///
395/// When the closure mutates its argument, the new mutated relation will not be visited again.
396///
397/// # Example
398/// ```
399/// # use sqlparser::parser::Parser;
400/// # use sqlparser::dialect::GenericDialect;
401/// # use sqlparser::ast::{ObjectName, visit_relations_mut};
402/// # use core::ops::ControlFlow;
403/// let sql = "SELECT a FROM foo";
404/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql)
405/// .unwrap();
406///
407/// // visit statements, renaming table foo to bar
408/// visit_relations_mut(&mut statements, |table| {
409/// table.0[0].value = table.0[0].value.replace("foo", "bar");
410/// ControlFlow::<()>::Continue(())
411/// });
412///
413/// assert_eq!(statements[0].to_string(), "SELECT a FROM bar");
414/// ```
415pub fn visit_relations_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
416where
417 V: VisitMut,
418 F: FnMut(&mut ObjectName) -> ControlFlow<E>,
419{
420 let mut visitor = RelationVisitor(f);
421 v.visit(&mut visitor)?;
422 ControlFlow::Continue(())
423}
424
425struct ExprVisitor<F>(F);
426
427impl<E, F: FnMut(&Expr) -> ControlFlow<E>> Visitor for ExprVisitor<F> {
428 type Break = E;
429
430 fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
431 self.0(expr)
432 }
433}
434
435impl<E, F: FnMut(&mut Expr) -> ControlFlow<E>> VisitorMut for ExprVisitor<F> {
436 type Break = E;
437
438 fn post_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
439 self.0(expr)
440 }
441}
442
443/// Invokes the provided closure on all expressions (e.g. `1 + 2`) present in `v`
444///
445/// # Example
446/// ```
447/// # use sqlparser::parser::Parser;
448/// # use sqlparser::dialect::GenericDialect;
449/// # use sqlparser::ast::{visit_expressions};
450/// # use core::ops::ControlFlow;
451/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
452/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
453/// .unwrap();
454///
455/// // visit all expressions
456/// let mut visited = vec![];
457/// visit_expressions(&statements, |expr| {
458/// visited.push(format!("EXPR: {}", expr));
459/// ControlFlow::<()>::Continue(())
460/// });
461///
462/// let expected : Vec<_> = [
463/// "EXPR: a",
464/// "EXPR: x IN (SELECT y FROM bar)",
465/// "EXPR: x",
466/// "EXPR: y",
467/// ]
468/// .into_iter().map(|s| s.to_string()).collect();
469///
470/// assert_eq!(visited, expected);
471/// ```
472pub fn visit_expressions<V, E, F>(v: &V, f: F) -> ControlFlow<E>
473where
474 V: Visit,
475 F: FnMut(&Expr) -> ControlFlow<E>,
476{
477 let mut visitor = ExprVisitor(f);
478 v.visit(&mut visitor)?;
479 ControlFlow::Continue(())
480}
481
482/// Invokes the provided closure iteratively with a mutable reference to all expressions
483/// present in `v`.
484///
485/// This performs a depth-first search, so if the closure mutates the expression
486///
487/// # Example
488///
489/// ## Remove all select limits in sub-queries
490/// ```
491/// # use sqlparser::parser::Parser;
492/// # use sqlparser::dialect::GenericDialect;
493/// # use sqlparser::ast::{Expr, visit_expressions_mut, visit_statements_mut};
494/// # use core::ops::ControlFlow;
495/// let sql = "SELECT (SELECT y FROM z LIMIT 9) FROM t LIMIT 3";
496/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
497///
498/// // Remove all select limits in sub-queries
499/// visit_expressions_mut(&mut statements, |expr| {
500/// if let Expr::Subquery(q) = expr {
501/// q.limit = None
502/// }
503/// ControlFlow::<()>::Continue(())
504/// });
505///
506/// assert_eq!(statements[0].to_string(), "SELECT (SELECT y FROM z) FROM t LIMIT 3");
507/// ```
508///
509/// ## Wrap column name in function call
510///
511/// This demonstrates how to effectively replace an expression with another more complicated one
512/// that references the original. This example avoids unnecessary allocations by using the
513/// [`std::mem`] family of functions.
514///
515/// ```
516/// # use sqlparser::parser::Parser;
517/// # use sqlparser::dialect::GenericDialect;
518/// # use sqlparser::ast::{Expr, Function, FunctionArg, FunctionArgExpr, Ident, ObjectName, Value, visit_expressions_mut, visit_statements_mut};
519/// # use core::ops::ControlFlow;
520/// let sql = "SELECT x, y FROM t";
521/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
522///
523/// visit_expressions_mut(&mut statements, |expr| {
524/// if matches!(expr, Expr::Identifier(col_name) if col_name.value == "x") {
525/// let old_expr = std::mem::replace(expr, Expr::Value(Value::Null));
526/// *expr = Expr::Function(Function {
527/// name: ObjectName(vec![Ident::new("f")]),
528/// args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(old_expr))],
529/// null_treatment: None,
530/// filter: None, over: None, distinct: false, special: false, order_by: vec![],
531/// });
532/// }
533/// ControlFlow::<()>::Continue(())
534/// });
535///
536/// assert_eq!(statements[0].to_string(), "SELECT f(x), y FROM t");
537/// ```
538pub fn visit_expressions_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
539where
540 V: VisitMut,
541 F: FnMut(&mut Expr) -> ControlFlow<E>,
542{
543 v.visit(&mut ExprVisitor(f))?;
544 ControlFlow::Continue(())
545}
546
547struct StatementVisitor<F>(F);
548
549impl<E, F: FnMut(&Statement) -> ControlFlow<E>> Visitor for StatementVisitor<F> {
550 type Break = E;
551
552 fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
553 self.0(statement)
554 }
555}
556
557impl<E, F: FnMut(&mut Statement) -> ControlFlow<E>> VisitorMut for StatementVisitor<F> {
558 type Break = E;
559
560 fn post_visit_statement(&mut self, statement: &mut Statement) -> ControlFlow<Self::Break> {
561 self.0(statement)
562 }
563}
564
565/// Invokes the provided closure iteratively with a mutable reference to all statements
566/// present in `v` (e.g. `SELECT`, `CREATE TABLE`, etc).
567///
568/// # Example
569/// ```
570/// # use sqlparser::parser::Parser;
571/// # use sqlparser::dialect::GenericDialect;
572/// # use sqlparser::ast::{visit_statements};
573/// # use core::ops::ControlFlow;
574/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar); CREATE TABLE baz(q int)";
575/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
576/// .unwrap();
577///
578/// // visit all statements
579/// let mut visited = vec![];
580/// visit_statements(&statements, |stmt| {
581/// visited.push(format!("STATEMENT: {}", stmt));
582/// ControlFlow::<()>::Continue(())
583/// });
584///
585/// let expected : Vec<_> = [
586/// "STATEMENT: SELECT a FROM foo WHERE x IN (SELECT y FROM bar)",
587/// "STATEMENT: CREATE TABLE baz (q INT)"
588/// ]
589/// .into_iter().map(|s| s.to_string()).collect();
590///
591/// assert_eq!(visited, expected);
592/// ```
593pub fn visit_statements<V, E, F>(v: &V, f: F) -> ControlFlow<E>
594where
595 V: Visit,
596 F: FnMut(&Statement) -> ControlFlow<E>,
597{
598 let mut visitor = StatementVisitor(f);
599 v.visit(&mut visitor)?;
600 ControlFlow::Continue(())
601}
602
603/// Invokes the provided closure on all statements (e.g. `SELECT`, `CREATE TABLE`, etc) present in `v`
604///
605/// # Example
606/// ```
607/// # use sqlparser::parser::Parser;
608/// # use sqlparser::dialect::GenericDialect;
609/// # use sqlparser::ast::{Statement, visit_statements_mut};
610/// # use core::ops::ControlFlow;
611/// let sql = "SELECT x FROM foo LIMIT 9+$limit; SELECT * FROM t LIMIT f()";
612/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
613///
614/// // Remove all select limits in outer statements (not in sub-queries)
615/// visit_statements_mut(&mut statements, |stmt| {
616/// if let Statement::Query(q) = stmt {
617/// q.limit = None
618/// }
619/// ControlFlow::<()>::Continue(())
620/// });
621///
622/// assert_eq!(statements[0].to_string(), "SELECT x FROM foo");
623/// assert_eq!(statements[1].to_string(), "SELECT * FROM t");
624/// ```
625pub fn visit_statements_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
626where
627 V: VisitMut,
628 F: FnMut(&mut Statement) -> ControlFlow<E>,
629{
630 v.visit(&mut StatementVisitor(f))?;
631 ControlFlow::Continue(())
632}
633
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialect::GenericDialect;
    use crate::parser::Parser;
    use crate::tokenizer::Tokenizer;

    /// Records every pre/post hook invocation as a readable line so that the
    /// traversal order can be asserted.
    #[derive(Default)]
    struct TestVisitor {
        visited: Vec<String>,
    }

    impl Visitor for TestVisitor {
        type Break = ();

        /// Invoked for any queries that appear in the AST before visiting children
        fn pre_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
            self.visited.push(format!("PRE: QUERY: {query}"));
            ControlFlow::Continue(())
        }

        /// Invoked for any queries that appear in the AST after visiting children
        fn post_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
            self.visited.push(format!("POST: QUERY: {query}"));
            ControlFlow::Continue(())
        }

        fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
            self.visited.push(format!("PRE: RELATION: {relation}"));
            ControlFlow::Continue(())
        }

        fn post_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
            self.visited.push(format!("POST: RELATION: {relation}"));
            ControlFlow::Continue(())
        }

        fn pre_visit_table_factor(
            &mut self,
            table_factor: &TableFactor,
        ) -> ControlFlow<Self::Break> {
            self.visited
                .push(format!("PRE: TABLE FACTOR: {table_factor}"));
            ControlFlow::Continue(())
        }

        fn post_visit_table_factor(
            &mut self,
            table_factor: &TableFactor,
        ) -> ControlFlow<Self::Break> {
            self.visited
                .push(format!("POST: TABLE FACTOR: {table_factor}"));
            ControlFlow::Continue(())
        }

        fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
            self.visited.push(format!("PRE: EXPR: {expr}"));
            ControlFlow::Continue(())
        }

        fn post_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
            self.visited.push(format!("POST: EXPR: {expr}"));
            ControlFlow::Continue(())
        }

        fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
            self.visited.push(format!("PRE: STATEMENT: {statement}"));
            ControlFlow::Continue(())
        }

        fn post_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
            self.visited.push(format!("POST: STATEMENT: {statement}"));
            ControlFlow::Continue(())
        }
    }

    /// Parses a single statement from `sql` and returns the hook log produced
    /// by walking it with [`TestVisitor`].
    fn do_visit(sql: &str) -> Vec<String> {
        let dialect = GenericDialect {};
        let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
        let s = Parser::new(&dialect)
            .with_tokens(tokens)
            .parse_statement()
            .unwrap();

        let mut visitor = TestVisitor::default();
        // TestVisitor never breaks, so the walk must always run to completion.
        let flow = s.visit(&mut visitor);
        assert_eq!(flow, ControlFlow::Continue(()));
        visitor.visited
    }

    #[test]
    fn test_sql() {
        let tests = vec![
            (
                "SELECT * from table_name as my_table",
                vec![
                    "PRE: STATEMENT: SELECT * FROM table_name AS my_table",
                    "PRE: QUERY: SELECT * FROM table_name AS my_table",
                    "PRE: TABLE FACTOR: table_name AS my_table",
                    "PRE: RELATION: table_name",
                    "POST: RELATION: table_name",
                    "POST: TABLE FACTOR: table_name AS my_table",
                    "POST: QUERY: SELECT * FROM table_name AS my_table",
                    "POST: STATEMENT: SELECT * FROM table_name AS my_table",
                ],
            ),
            (
                "SELECT * from t1 join t2 on t1.id = t2.t1_id",
                vec![
                    "PRE: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
                    "PRE: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
                    "PRE: TABLE FACTOR: t1",
                    "PRE: RELATION: t1",
                    "POST: RELATION: t1",
                    "POST: TABLE FACTOR: t1",
                    "PRE: TABLE FACTOR: t2",
                    "PRE: RELATION: t2",
                    "POST: RELATION: t2",
                    "POST: TABLE FACTOR: t2",
                    "PRE: EXPR: t1.id = t2.t1_id",
                    "PRE: EXPR: t1.id",
                    "POST: EXPR: t1.id",
                    "PRE: EXPR: t2.t1_id",
                    "POST: EXPR: t2.t1_id",
                    "POST: EXPR: t1.id = t2.t1_id",
                    "POST: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
                    "POST: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
                ],
            ),
            (
                "SELECT * from t1 where EXISTS(SELECT column from t2)",
                vec![
                    "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
                    "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
                    "PRE: TABLE FACTOR: t1",
                    "PRE: RELATION: t1",
                    "POST: RELATION: t1",
                    "POST: TABLE FACTOR: t1",
                    "PRE: EXPR: EXISTS (SELECT column FROM t2)",
                    "PRE: QUERY: SELECT column FROM t2",
                    "PRE: EXPR: column",
                    "POST: EXPR: column",
                    "PRE: TABLE FACTOR: t2",
                    "PRE: RELATION: t2",
                    "POST: RELATION: t2",
                    "POST: TABLE FACTOR: t2",
                    "POST: QUERY: SELECT column FROM t2",
                    "POST: EXPR: EXISTS (SELECT column FROM t2)",
                    "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
                    "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
                ],
            ),
            (
                "SELECT * from t1 where EXISTS(SELECT column from t2) UNION SELECT * from t3",
                vec![
                    "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
                    "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
                    "PRE: TABLE FACTOR: t1",
                    "PRE: RELATION: t1",
                    "POST: RELATION: t1",
                    "POST: TABLE FACTOR: t1",
                    "PRE: EXPR: EXISTS (SELECT column FROM t2)",
                    "PRE: QUERY: SELECT column FROM t2",
                    "PRE: EXPR: column",
                    "POST: EXPR: column",
                    "PRE: TABLE FACTOR: t2",
                    "PRE: RELATION: t2",
                    "POST: RELATION: t2",
                    "POST: TABLE FACTOR: t2",
                    "POST: QUERY: SELECT column FROM t2",
                    "POST: EXPR: EXISTS (SELECT column FROM t2)",
                    "PRE: TABLE FACTOR: t3",
                    "PRE: RELATION: t3",
                    "POST: RELATION: t3",
                    "POST: TABLE FACTOR: t3",
                    "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
                    "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
                ],
            ),
            (
                concat!(
                    "SELECT * FROM monthly_sales ",
                    "PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
                    "ORDER BY EMPID"
                ),
                vec![
                    "PRE: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
                    "PRE: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
                    "PRE: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
                    "PRE: TABLE FACTOR: monthly_sales",
                    "PRE: RELATION: monthly_sales",
                    "POST: RELATION: monthly_sales",
                    "POST: TABLE FACTOR: monthly_sales",
                    "PRE: EXPR: SUM(a.amount)",
                    "PRE: EXPR: a.amount",
                    "POST: EXPR: a.amount",
                    "POST: EXPR: SUM(a.amount)",
                    "POST: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
                    "PRE: EXPR: EMPID",
                    "POST: EXPR: EMPID",
                    "POST: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
                    "POST: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
                ]
            )
        ];
        for (sql, expected) in tests {
            let actual = do_visit(sql);
            let actual: Vec<_> = actual.iter().map(|x| x.as_str()).collect();
            assert_eq!(actual, expected)
        }
    }

    /// A `Break` returned by the closure must stop the walk immediately and be
    /// propagated back to the caller.
    #[test]
    fn test_visit_relations_break() {
        let sql = "SELECT a FROM foo JOIN bar ON foo.id = bar.id";
        let statements = Parser::parse_sql(&GenericDialect {}, sql).unwrap();

        let mut seen = Vec::new();
        let flow = visit_relations(&statements, |relation| {
            seen.push(relation.to_string());
            ControlFlow::Break(relation.to_string())
        });

        assert_eq!(flow, ControlFlow::Break("foo".to_string()));
        // `bar` must never be reached once the walk has broken out.
        assert_eq!(seen, vec!["foo".to_string()]);
    }
}