sqltk_parser/ast/visitor.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Recursive visitors for AST nodes. See [`Visitor`] for more details.
19
20use crate::ast::{Expr, ObjectName, Query, Statement, TableFactor};
21use core::ops::ControlFlow;
22
23/// A type that can be visited by a [`Visitor`]. See [`Visitor`] for
24/// recursively visiting parsed SQL statements.
25///
26/// # Note
27///
28/// This trait should be automatically derived for sqltk_parser AST nodes
29/// using the [Visit](sqltk_parser_derive::Visit) proc macro.
30///
31/// ```text
32/// #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
33/// ```
34pub trait Visit {
35 fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break>;
36}
37
38/// A type that can be visited by a [`VisitorMut`]. See [`VisitorMut`] for
39/// recursively visiting parsed SQL statements.
40///
41/// # Note
42///
43/// This trait should be automatically derived for sqltk_parser AST nodes
44/// using the [VisitMut](sqltk_parser_derive::VisitMut) proc macro.
45///
46/// ```text
47/// #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
48/// ```
49pub trait VisitMut {
50 fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break>;
51}
52
53impl<T: Visit> Visit for Option<T> {
54 fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
55 if let Some(s) = self {
56 s.visit(visitor)?;
57 }
58 ControlFlow::Continue(())
59 }
60}
61
62impl<T: Visit> Visit for Vec<T> {
63 fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
64 for v in self {
65 v.visit(visitor)?;
66 }
67 ControlFlow::Continue(())
68 }
69}
70
71impl<T: Visit> Visit for Box<T> {
72 fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
73 T::visit(self, visitor)
74 }
75}
76
77impl<T: VisitMut> VisitMut for Option<T> {
78 fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
79 if let Some(s) = self {
80 s.visit(visitor)?;
81 }
82 ControlFlow::Continue(())
83 }
84}
85
86impl<T: VisitMut> VisitMut for Vec<T> {
87 fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
88 for v in self {
89 v.visit(visitor)?;
90 }
91 ControlFlow::Continue(())
92 }
93}
94
95impl<T: VisitMut> VisitMut for Box<T> {
96 fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
97 T::visit(self, visitor)
98 }
99}
100
101macro_rules! visit_noop {
102 ($($t:ty),+) => {
103 $(impl Visit for $t {
104 fn visit<V: Visitor>(&self, _visitor: &mut V) -> ControlFlow<V::Break> {
105 ControlFlow::Continue(())
106 }
107 })+
108 $(impl VisitMut for $t {
109 fn visit<V: VisitorMut>(&mut self, _visitor: &mut V) -> ControlFlow<V::Break> {
110 ControlFlow::Continue(())
111 }
112 })+
113 };
114}
115
116visit_noop!(u8, u16, u32, u64, i8, i16, i32, i64, char, bool, String);
117
118#[cfg(feature = "bigdecimal")]
119visit_noop!(bigdecimal::BigDecimal);
120
121/// A visitor that can be used to walk an AST tree.
122///
123/// `pre_visit_` methods are invoked before visiting all children of the
124/// node and `post_visit_` methods are invoked after visiting all
125/// children of the node.
126///
127/// # See also
128///
129/// These methods provide a more concise way of visiting nodes of a certain type:
130/// * [visit_relations]
131/// * [visit_expressions]
132/// * [visit_statements]
133///
134/// # Example
135/// ```
136/// # use sqltk_parser::parser::Parser;
137/// # use sqltk_parser::dialect::GenericDialect;
138/// # use sqltk_parser::ast::{Visit, Visitor, ObjectName, Expr};
139/// # use core::ops::ControlFlow;
140/// // A structure that records statements and relations
141/// #[derive(Default)]
142/// struct V {
143/// visited: Vec<String>,
144/// }
145///
146/// // Visit relations and exprs before children are visited (depth first walk)
147/// // Note you can also visit statements and visit exprs after children have been visited
148/// impl Visitor for V {
149/// type Break = ();
150///
151/// fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
152/// self.visited.push(format!("PRE: RELATION: {}", relation));
153/// ControlFlow::Continue(())
154/// }
155///
156/// fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
157/// self.visited.push(format!("PRE: EXPR: {}", expr));
158/// ControlFlow::Continue(())
159/// }
160/// }
161///
162/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
163/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
164/// .unwrap();
165///
166/// // Drive the visitor through the AST
167/// let mut visitor = V::default();
168/// statements.visit(&mut visitor);
169///
170/// // The visitor has visited statements and expressions in pre-traversal order
171/// let expected : Vec<_> = [
172/// "PRE: EXPR: a",
173/// "PRE: RELATION: foo",
174/// "PRE: EXPR: x IN (SELECT y FROM bar)",
175/// "PRE: EXPR: x",
176/// "PRE: EXPR: y",
177/// "PRE: RELATION: bar",
178/// ]
179/// .into_iter().map(|s| s.to_string()).collect();
180///
181/// assert_eq!(visitor.visited, expected);
182/// ```
183pub trait Visitor {
184 /// Type returned when the recursion returns early.
185 type Break;
186
187 /// Invoked for any queries that appear in the AST before visiting children
188 fn pre_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
189 ControlFlow::Continue(())
190 }
191
192 /// Invoked for any queries that appear in the AST after visiting children
193 fn post_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
194 ControlFlow::Continue(())
195 }
196
197 /// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
198 fn pre_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
199 ControlFlow::Continue(())
200 }
201
202 /// Invoked for any relations (e.g. tables) that appear in the AST after visiting children
203 fn post_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
204 ControlFlow::Continue(())
205 }
206
207 /// Invoked for any table factors that appear in the AST before visiting children
208 fn pre_visit_table_factor(&mut self, _table_factor: &TableFactor) -> ControlFlow<Self::Break> {
209 ControlFlow::Continue(())
210 }
211
212 /// Invoked for any table factors that appear in the AST after visiting children
213 fn post_visit_table_factor(&mut self, _table_factor: &TableFactor) -> ControlFlow<Self::Break> {
214 ControlFlow::Continue(())
215 }
216
217 /// Invoked for any expressions that appear in the AST before visiting children
218 fn pre_visit_expr(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
219 ControlFlow::Continue(())
220 }
221
222 /// Invoked for any expressions that appear in the AST
223 fn post_visit_expr(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
224 ControlFlow::Continue(())
225 }
226
227 /// Invoked for any statements that appear in the AST before visiting children
228 fn pre_visit_statement(&mut self, _statement: &Statement) -> ControlFlow<Self::Break> {
229 ControlFlow::Continue(())
230 }
231
232 /// Invoked for any statements that appear in the AST after visiting children
233 fn post_visit_statement(&mut self, _statement: &Statement) -> ControlFlow<Self::Break> {
234 ControlFlow::Continue(())
235 }
236}
237
238/// A visitor that can be used to mutate an AST tree.
239///
240/// `pre_visit_` methods are invoked before visiting all children of the
241/// node and `post_visit_` methods are invoked after visiting all
242/// children of the node.
243///
244/// # See also
245///
246/// These methods provide a more concise way of visiting nodes of a certain type:
247/// * [visit_relations_mut]
248/// * [visit_expressions_mut]
249/// * [visit_statements_mut]
250///
251/// # Example
252/// ```
253/// # use sqltk_parser::parser::Parser;
254/// # use sqltk_parser::dialect::GenericDialect;
255/// # use sqltk_parser::ast::{VisitMut, VisitorMut, ObjectName, Expr, Ident};
256/// # use core::ops::ControlFlow;
257///
258/// // A visitor that replaces "to_replace" with "replaced" in all expressions
259/// struct Replacer;
260///
261/// // Visit each expression after its children have been visited
262/// impl VisitorMut for Replacer {
263/// type Break = ();
264///
265/// fn post_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
266/// if let Expr::Identifier(Ident{ value, ..}) = expr {
267/// *value = value.replace("to_replace", "replaced")
268/// }
269/// ControlFlow::Continue(())
270/// }
271/// }
272///
273/// let sql = "SELECT to_replace FROM foo where to_replace IN (SELECT to_replace FROM bar)";
274/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
275///
276/// // Drive the visitor through the AST
277/// statements.visit(&mut Replacer);
278///
279/// assert_eq!(statements[0].to_string(), "SELECT replaced FROM foo WHERE replaced IN (SELECT replaced FROM bar)");
280/// ```
281pub trait VisitorMut {
282 /// Type returned when the recursion returns early.
283 type Break;
284
285 /// Invoked for any queries that appear in the AST before visiting children
286 fn pre_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
287 ControlFlow::Continue(())
288 }
289
290 /// Invoked for any queries that appear in the AST after visiting children
291 fn post_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
292 ControlFlow::Continue(())
293 }
294
295 /// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
296 fn pre_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
297 ControlFlow::Continue(())
298 }
299
300 /// Invoked for any relations (e.g. tables) that appear in the AST after visiting children
301 fn post_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
302 ControlFlow::Continue(())
303 }
304
305 /// Invoked for any table factors that appear in the AST before visiting children
306 fn pre_visit_table_factor(
307 &mut self,
308 _table_factor: &mut TableFactor,
309 ) -> ControlFlow<Self::Break> {
310 ControlFlow::Continue(())
311 }
312
313 /// Invoked for any table factors that appear in the AST after visiting children
314 fn post_visit_table_factor(
315 &mut self,
316 _table_factor: &mut TableFactor,
317 ) -> ControlFlow<Self::Break> {
318 ControlFlow::Continue(())
319 }
320
321 /// Invoked for any expressions that appear in the AST before visiting children
322 fn pre_visit_expr(&mut self, _expr: &mut Expr) -> ControlFlow<Self::Break> {
323 ControlFlow::Continue(())
324 }
325
326 /// Invoked for any expressions that appear in the AST
327 fn post_visit_expr(&mut self, _expr: &mut Expr) -> ControlFlow<Self::Break> {
328 ControlFlow::Continue(())
329 }
330
331 /// Invoked for any statements that appear in the AST before visiting children
332 fn pre_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> {
333 ControlFlow::Continue(())
334 }
335
336 /// Invoked for any statements that appear in the AST after visiting children
337 fn post_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> {
338 ControlFlow::Continue(())
339 }
340}
341
342struct RelationVisitor<F>(F);
343
344impl<E, F: FnMut(&ObjectName) -> ControlFlow<E>> Visitor for RelationVisitor<F> {
345 type Break = E;
346
347 fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
348 self.0(relation)
349 }
350}
351
352impl<E, F: FnMut(&mut ObjectName) -> ControlFlow<E>> VisitorMut for RelationVisitor<F> {
353 type Break = E;
354
355 fn post_visit_relation(&mut self, relation: &mut ObjectName) -> ControlFlow<Self::Break> {
356 self.0(relation)
357 }
358}
359
360/// Invokes the provided closure on all relations (e.g. table names) present in `v`
361///
362/// # Example
363/// ```
364/// # use sqltk_parser::parser::Parser;
365/// # use sqltk_parser::dialect::GenericDialect;
366/// # use sqltk_parser::ast::{visit_relations};
367/// # use core::ops::ControlFlow;
368/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
369/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
370/// .unwrap();
371///
372/// // visit statements, capturing relations (table names)
373/// let mut visited = vec![];
374/// visit_relations(&statements, |relation| {
375/// visited.push(format!("RELATION: {}", relation));
376/// ControlFlow::<()>::Continue(())
377/// });
378///
379/// let expected : Vec<_> = [
380/// "RELATION: foo",
381/// "RELATION: bar",
382/// ]
383/// .into_iter().map(|s| s.to_string()).collect();
384///
385/// assert_eq!(visited, expected);
386/// ```
387pub fn visit_relations<V, E, F>(v: &V, f: F) -> ControlFlow<E>
388where
389 V: Visit,
390 F: FnMut(&ObjectName) -> ControlFlow<E>,
391{
392 let mut visitor = RelationVisitor(f);
393 v.visit(&mut visitor)?;
394 ControlFlow::Continue(())
395}
396
397/// Invokes the provided closure with a mutable reference to all relations (e.g. table names)
398/// present in `v`.
399///
400/// When the closure mutates its argument, the new mutated relation will not be visited again.
401///
402/// # Example
403/// ```
404/// # use sqltk_parser::parser::Parser;
405/// # use sqltk_parser::dialect::GenericDialect;
406/// # use sqltk_parser::ast::{ObjectName, visit_relations_mut};
407/// # use core::ops::ControlFlow;
408/// let sql = "SELECT a FROM foo";
409/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql)
410/// .unwrap();
411///
412/// // visit statements, renaming table foo to bar
413/// visit_relations_mut(&mut statements, |table| {
414/// table.0[0].value = table.0[0].value.replace("foo", "bar");
415/// ControlFlow::<()>::Continue(())
416/// });
417///
418/// assert_eq!(statements[0].to_string(), "SELECT a FROM bar");
419/// ```
420pub fn visit_relations_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
421where
422 V: VisitMut,
423 F: FnMut(&mut ObjectName) -> ControlFlow<E>,
424{
425 let mut visitor = RelationVisitor(f);
426 v.visit(&mut visitor)?;
427 ControlFlow::Continue(())
428}
429
430struct ExprVisitor<F>(F);
431
432impl<E, F: FnMut(&Expr) -> ControlFlow<E>> Visitor for ExprVisitor<F> {
433 type Break = E;
434
435 fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
436 self.0(expr)
437 }
438}
439
440impl<E, F: FnMut(&mut Expr) -> ControlFlow<E>> VisitorMut for ExprVisitor<F> {
441 type Break = E;
442
443 fn post_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
444 self.0(expr)
445 }
446}
447
448/// Invokes the provided closure on all expressions (e.g. `1 + 2`) present in `v`
449///
450/// # Example
451/// ```
452/// # use sqltk_parser::parser::Parser;
453/// # use sqltk_parser::dialect::GenericDialect;
454/// # use sqltk_parser::ast::{visit_expressions};
455/// # use core::ops::ControlFlow;
456/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
457/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
458/// .unwrap();
459///
460/// // visit all expressions
461/// let mut visited = vec![];
462/// visit_expressions(&statements, |expr| {
463/// visited.push(format!("EXPR: {}", expr));
464/// ControlFlow::<()>::Continue(())
465/// });
466///
467/// let expected : Vec<_> = [
468/// "EXPR: a",
469/// "EXPR: x IN (SELECT y FROM bar)",
470/// "EXPR: x",
471/// "EXPR: y",
472/// ]
473/// .into_iter().map(|s| s.to_string()).collect();
474///
475/// assert_eq!(visited, expected);
476/// ```
477pub fn visit_expressions<V, E, F>(v: &V, f: F) -> ControlFlow<E>
478where
479 V: Visit,
480 F: FnMut(&Expr) -> ControlFlow<E>,
481{
482 let mut visitor = ExprVisitor(f);
483 v.visit(&mut visitor)?;
484 ControlFlow::Continue(())
485}
486
487/// Invokes the provided closure iteratively with a mutable reference to all expressions
488/// present in `v`.
489///
490/// This performs a depth-first search, so if the closure mutates the expression
491///
492/// # Example
493///
494/// ## Remove all select limits in sub-queries
495/// ```
496/// # use sqltk_parser::parser::Parser;
497/// # use sqltk_parser::dialect::GenericDialect;
498/// # use sqltk_parser::ast::{Expr, visit_expressions_mut, visit_statements_mut};
499/// # use core::ops::ControlFlow;
500/// let sql = "SELECT (SELECT y FROM z LIMIT 9) FROM t LIMIT 3";
501/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
502///
503/// // Remove all select limits in sub-queries
504/// visit_expressions_mut(&mut statements, |expr| {
505/// if let Expr::Subquery(q) = expr {
506/// q.limit = None
507/// }
508/// ControlFlow::<()>::Continue(())
509/// });
510///
511/// assert_eq!(statements[0].to_string(), "SELECT (SELECT y FROM z) FROM t LIMIT 3");
512/// ```
513///
514/// ## Wrap column name in function call
515///
516/// This demonstrates how to effectively replace an expression with another more complicated one
517/// that references the original. This example avoids unnecessary allocations by using the
518/// [`std::mem`] family of functions.
519///
520/// ```
521/// # use sqltk_parser::parser::Parser;
522/// # use sqltk_parser::dialect::GenericDialect;
523/// # use sqltk_parser::ast::*;
524/// # use core::ops::ControlFlow;
525/// let sql = "SELECT x, y FROM t";
526/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
527///
528/// visit_expressions_mut(&mut statements, |expr| {
529/// if matches!(expr, Expr::Identifier(col_name) if col_name.value == "x") {
530/// let old_expr = std::mem::replace(expr, Expr::Value(Value::Null));
531/// *expr = Expr::Function(Function {
532/// name: ObjectName(vec![Ident::new("f")]),
533/// args: FunctionArguments::List(FunctionArgumentList {
534/// duplicate_treatment: None,
535/// args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(old_expr))],
536/// clauses: vec![],
537/// }),
538/// null_treatment: None,
539/// filter: None,
540/// over: None,
541/// parameters: FunctionArguments::None,
542/// within_group: vec![],
543/// });
544/// }
545/// ControlFlow::<()>::Continue(())
546/// });
547///
548/// assert_eq!(statements[0].to_string(), "SELECT f(x), y FROM t");
549/// ```
550pub fn visit_expressions_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
551where
552 V: VisitMut,
553 F: FnMut(&mut Expr) -> ControlFlow<E>,
554{
555 v.visit(&mut ExprVisitor(f))?;
556 ControlFlow::Continue(())
557}
558
559struct StatementVisitor<F>(F);
560
561impl<E, F: FnMut(&Statement) -> ControlFlow<E>> Visitor for StatementVisitor<F> {
562 type Break = E;
563
564 fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
565 self.0(statement)
566 }
567}
568
569impl<E, F: FnMut(&mut Statement) -> ControlFlow<E>> VisitorMut for StatementVisitor<F> {
570 type Break = E;
571
572 fn post_visit_statement(&mut self, statement: &mut Statement) -> ControlFlow<Self::Break> {
573 self.0(statement)
574 }
575}
576
577/// Invokes the provided closure iteratively with a mutable reference to all statements
578/// present in `v` (e.g. `SELECT`, `CREATE TABLE`, etc).
579///
580/// # Example
581/// ```
582/// # use sqltk_parser::parser::Parser;
583/// # use sqltk_parser::dialect::GenericDialect;
584/// # use sqltk_parser::ast::{visit_statements};
585/// # use core::ops::ControlFlow;
586/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar); CREATE TABLE baz(q int)";
587/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
588/// .unwrap();
589///
590/// // visit all statements
591/// let mut visited = vec![];
592/// visit_statements(&statements, |stmt| {
593/// visited.push(format!("STATEMENT: {}", stmt));
594/// ControlFlow::<()>::Continue(())
595/// });
596///
597/// let expected : Vec<_> = [
598/// "STATEMENT: SELECT a FROM foo WHERE x IN (SELECT y FROM bar)",
599/// "STATEMENT: CREATE TABLE baz (q INT)"
600/// ]
601/// .into_iter().map(|s| s.to_string()).collect();
602///
603/// assert_eq!(visited, expected);
604/// ```
605pub fn visit_statements<V, E, F>(v: &V, f: F) -> ControlFlow<E>
606where
607 V: Visit,
608 F: FnMut(&Statement) -> ControlFlow<E>,
609{
610 let mut visitor = StatementVisitor(f);
611 v.visit(&mut visitor)?;
612 ControlFlow::Continue(())
613}
614
615/// Invokes the provided closure on all statements (e.g. `SELECT`, `CREATE TABLE`, etc) present in `v`
616///
617/// # Example
618/// ```
619/// # use sqltk_parser::parser::Parser;
620/// # use sqltk_parser::dialect::GenericDialect;
621/// # use sqltk_parser::ast::{Statement, visit_statements_mut};
622/// # use core::ops::ControlFlow;
623/// let sql = "SELECT x FROM foo LIMIT 9+$limit; SELECT * FROM t LIMIT f()";
624/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
625///
626/// // Remove all select limits in outer statements (not in sub-queries)
627/// visit_statements_mut(&mut statements, |stmt| {
628/// if let Statement::Query(q) = stmt {
629/// q.limit = None
630/// }
631/// ControlFlow::<()>::Continue(())
632/// });
633///
634/// assert_eq!(statements[0].to_string(), "SELECT x FROM foo");
635/// assert_eq!(statements[1].to_string(), "SELECT * FROM t");
636/// ```
637pub fn visit_statements_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
638where
639 V: VisitMut,
640 F: FnMut(&mut Statement) -> ControlFlow<E>,
641{
642 v.visit(&mut StatementVisitor(f))?;
643 ControlFlow::Continue(())
644}
645
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialect::GenericDialect;
    use crate::parser::Parser;
    use crate::tokenizer::Tokenizer;

    /// A [`Visitor`] that records one line per pre/post callback, so the tests can
    /// assert on the exact traversal order.
    #[derive(Default)]
    struct TestVisitor {
        visited: Vec<String>,
    }

    impl Visitor for TestVisitor {
        type Break = ();

        /// Invoked for any queries that appear in the AST before visiting children
        fn pre_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
            self.visited.push(format!("PRE: QUERY: {query}"));
            ControlFlow::Continue(())
        }

        /// Invoked for any queries that appear in the AST after visiting children
        fn post_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
            self.visited.push(format!("POST: QUERY: {query}"));
            ControlFlow::Continue(())
        }

        fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
            self.visited.push(format!("PRE: RELATION: {relation}"));
            ControlFlow::Continue(())
        }

        fn post_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
            self.visited.push(format!("POST: RELATION: {relation}"));
            ControlFlow::Continue(())
        }

        fn pre_visit_table_factor(
            &mut self,
            table_factor: &TableFactor,
        ) -> ControlFlow<Self::Break> {
            self.visited
                .push(format!("PRE: TABLE FACTOR: {table_factor}"));
            ControlFlow::Continue(())
        }

        fn post_visit_table_factor(
            &mut self,
            table_factor: &TableFactor,
        ) -> ControlFlow<Self::Break> {
            self.visited
                .push(format!("POST: TABLE FACTOR: {table_factor}"));
            ControlFlow::Continue(())
        }

        fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
            self.visited.push(format!("PRE: EXPR: {expr}"));
            ControlFlow::Continue(())
        }

        fn post_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
            self.visited.push(format!("POST: EXPR: {expr}"));
            ControlFlow::Continue(())
        }

        fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
            self.visited.push(format!("PRE: STATEMENT: {statement}"));
            ControlFlow::Continue(())
        }

        fn post_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
            self.visited.push(format!("POST: STATEMENT: {statement}"));
            ControlFlow::Continue(())
        }
    }

    /// Parses a single statement from `sql` with the generic dialect and returns the
    /// traversal log recorded by [`TestVisitor`].
    fn do_visit(sql: &str) -> Vec<String> {
        let dialect = GenericDialect {};
        let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
        let s = Parser::new(&dialect)
            .with_tokens(tokens)
            .parse_statement()
            .unwrap();

        let mut visitor = TestVisitor::default();
        // `TestVisitor` never breaks, so the walk must always run to completion.
        let flow = s.visit(&mut visitor);
        assert!(matches!(flow, ControlFlow::Continue(())));
        visitor.visited
    }

    #[test]
    fn test_sql() {
        let tests = vec![
            (
                "SELECT * from table_name as my_table",
                vec![
                    "PRE: STATEMENT: SELECT * FROM table_name AS my_table",
                    "PRE: QUERY: SELECT * FROM table_name AS my_table",
                    "PRE: TABLE FACTOR: table_name AS my_table",
                    "PRE: RELATION: table_name",
                    "POST: RELATION: table_name",
                    "POST: TABLE FACTOR: table_name AS my_table",
                    "POST: QUERY: SELECT * FROM table_name AS my_table",
                    "POST: STATEMENT: SELECT * FROM table_name AS my_table",
                ],
            ),
            (
                "SELECT * from t1 join t2 on t1.id = t2.t1_id",
                vec![
                    "PRE: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
                    "PRE: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
                    "PRE: TABLE FACTOR: t1",
                    "PRE: RELATION: t1",
                    "POST: RELATION: t1",
                    "POST: TABLE FACTOR: t1",
                    "PRE: TABLE FACTOR: t2",
                    "PRE: RELATION: t2",
                    "POST: RELATION: t2",
                    "POST: TABLE FACTOR: t2",
                    "PRE: EXPR: t1.id = t2.t1_id",
                    "PRE: EXPR: t1.id",
                    "POST: EXPR: t1.id",
                    "PRE: EXPR: t2.t1_id",
                    "POST: EXPR: t2.t1_id",
                    "POST: EXPR: t1.id = t2.t1_id",
                    "POST: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
                    "POST: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
                ],
            ),
            (
                "SELECT * from t1 where EXISTS(SELECT column from t2)",
                vec![
                    "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
                    "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
                    "PRE: TABLE FACTOR: t1",
                    "PRE: RELATION: t1",
                    "POST: RELATION: t1",
                    "POST: TABLE FACTOR: t1",
                    "PRE: EXPR: EXISTS (SELECT column FROM t2)",
                    "PRE: QUERY: SELECT column FROM t2",
                    "PRE: EXPR: column",
                    "POST: EXPR: column",
                    "PRE: TABLE FACTOR: t2",
                    "PRE: RELATION: t2",
                    "POST: RELATION: t2",
                    "POST: TABLE FACTOR: t2",
                    "POST: QUERY: SELECT column FROM t2",
                    "POST: EXPR: EXISTS (SELECT column FROM t2)",
                    "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
                    "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
                ],
            ),
            (
                "SELECT * from t1 where EXISTS(SELECT column from t2) UNION SELECT * from t3",
                vec![
                    "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
                    "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
                    "PRE: TABLE FACTOR: t1",
                    "PRE: RELATION: t1",
                    "POST: RELATION: t1",
                    "POST: TABLE FACTOR: t1",
                    "PRE: EXPR: EXISTS (SELECT column FROM t2)",
                    "PRE: QUERY: SELECT column FROM t2",
                    "PRE: EXPR: column",
                    "POST: EXPR: column",
                    "PRE: TABLE FACTOR: t2",
                    "PRE: RELATION: t2",
                    "POST: RELATION: t2",
                    "POST: TABLE FACTOR: t2",
                    "POST: QUERY: SELECT column FROM t2",
                    "POST: EXPR: EXISTS (SELECT column FROM t2)",
                    "PRE: TABLE FACTOR: t3",
                    "PRE: RELATION: t3",
                    "POST: RELATION: t3",
                    "POST: TABLE FACTOR: t3",
                    "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
                    "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
                ],
            ),
            (
                concat!(
                    "SELECT * FROM monthly_sales ",
                    "PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
                    "ORDER BY EMPID"
                ),
                vec![
                    "PRE: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
                    "PRE: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
                    "PRE: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
                    "PRE: TABLE FACTOR: monthly_sales",
                    "PRE: RELATION: monthly_sales",
                    "POST: RELATION: monthly_sales",
                    "POST: TABLE FACTOR: monthly_sales",
                    "PRE: EXPR: SUM(a.amount)",
                    "PRE: EXPR: a.amount",
                    "POST: EXPR: a.amount",
                    "POST: EXPR: SUM(a.amount)",
                    "PRE: EXPR: 'JAN'",
                    "POST: EXPR: 'JAN'",
                    "PRE: EXPR: 'FEB'",
                    "POST: EXPR: 'FEB'",
                    "PRE: EXPR: 'MAR'",
                    "POST: EXPR: 'MAR'",
                    "PRE: EXPR: 'APR'",
                    "POST: EXPR: 'APR'",
                    "POST: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
                    "PRE: EXPR: EMPID",
                    "POST: EXPR: EMPID",
                    "POST: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
                    "POST: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
                ]
            )
        ];
        for (sql, expected) in tests {
            let actual = do_visit(sql);
            let actual: Vec<_> = actual.iter().map(|x| x.as_str()).collect();
            assert_eq!(actual, expected, "unexpected traversal for: {sql}")
        }
    }
}