sqlparser/
test_utils.rs

1// Licensed under the Apache License, Version 2.0 (the "License");
2// you may not use this file except in compliance with the License.
3// You may obtain a copy of the License at
4//
5// http://www.apache.org/licenses/LICENSE-2.0
6//
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
//! This module contains internal utilities used for testing the library.
//! While technically public, the library's users are not supposed to rely
//! on this module, as it will change without notice.
//
// Integration tests (i.e. everything under `tests/`) import this
// via `tests/test_utils/helpers`.
19
20#[cfg(not(feature = "std"))]
21use alloc::{
22    boxed::Box,
23    string::{String, ToString},
24    vec,
25    vec::Vec,
26};
27use core::fmt::Debug;
28
29use crate::ast::*;
30use crate::dialect::*;
31use crate::parser::{Parser, ParserError};
32
33/// Tests use the methods on this struct to invoke the parser on one or
34/// multiple dialects.
35pub struct TestedDialects {
36    pub dialects: Vec<Box<dyn Dialect>>,
37}
38
39impl TestedDialects {
40    /// Run the given function for all of `self.dialects`, assert that they
41    /// return the same result, and return that result.
42    pub fn one_of_identical_results<F, T: Debug + PartialEq>(&self, f: F) -> T
43    where
44        F: Fn(&dyn Dialect) -> T,
45    {
46        let parse_results = self.dialects.iter().map(|dialect| (dialect, f(&**dialect)));
47        parse_results
48            .fold(None, |s, (dialect, parsed)| {
49                if let Some((prev_dialect, prev_parsed)) = s {
50                    assert_eq!(
51                        prev_parsed, parsed,
52                        "Parse results with {prev_dialect:?} are different from {dialect:?}"
53                    );
54                }
55                Some((dialect, parsed))
56            })
57            .unwrap()
58            .1
59    }
60
61    pub fn run_parser_method<F, T: Debug + PartialEq>(&self, sql: &str, f: F) -> T
62    where
63        F: Fn(&mut Parser) -> T,
64    {
65        self.one_of_identical_results(|dialect| {
66            let mut parser = Parser::new(dialect).try_with_sql(sql).unwrap();
67            f(&mut parser)
68        })
69    }
70
71    /// Parses a single SQL string into multiple statements, ensuring
72    /// the result is the same for all tested dialects.
73    pub fn parse_sql_statements(&self, sql: &str) -> Result<Vec<Statement>, ParserError> {
74        self.one_of_identical_results(|dialect| Parser::parse_sql(dialect, sql))
75        // To fail the `ensure_multiple_dialects_are_tested` test:
76        // Parser::parse_sql(&**self.dialects.first().unwrap(), sql)
77    }
78
79    /// Ensures that `sql` parses as a single [Statement] for all tested
80    /// dialects.
81    ///
82    /// If `canonical` is non empty,this function additionally asserts
83    /// that:
84    ///
85    /// 1. parsing `sql` results in the same [`Statement`] as parsing
86    /// `canonical`.
87    ///
88    /// 2. re-serializing the result of parsing `sql` produces the same
89    /// `canonical` sql string
90    pub fn one_statement_parses_to(&self, sql: &str, canonical: &str) -> Statement {
91        let mut statements = self.parse_sql_statements(sql).unwrap();
92        assert_eq!(statements.len(), 1);
93
94        if !canonical.is_empty() && sql != canonical {
95            assert_eq!(self.parse_sql_statements(canonical).unwrap(), statements);
96        }
97
98        let only_statement = statements.pop().unwrap();
99        if !canonical.is_empty() {
100            assert_eq!(canonical, only_statement.to_string())
101        }
102        only_statement
103    }
104
105    /// Ensures that `sql` parses as a single [Statement], and that
106    /// re-serializing the parse result produces the same `sql`
107    /// string (is not modified after a serialization round-trip).
108    pub fn verified_stmt(&self, sql: &str) -> Statement {
109        self.one_statement_parses_to(sql, sql)
110    }
111
112    /// Ensures that `sql` parses as a single [Query], and that
113    /// re-serializing the parse result produces the same `sql`
114    /// string (is not modified after a serialization round-trip).
115    pub fn verified_query(&self, sql: &str) -> Query {
116        match self.verified_stmt(sql) {
117            Statement::Query(query) => *query,
118            _ => panic!("Expected Query"),
119        }
120    }
121
122    /// Ensures that `sql` parses as a single [Select], and that
123    /// re-serializing the parse result produces the same `sql`
124    /// string (is not modified after a serialization round-trip).
125    pub fn verified_only_select(&self, query: &str) -> Select {
126        match *self.verified_query(query).body {
127            SetExpr::Select(s) => *s,
128            _ => panic!("Expected SetExpr::Select"),
129        }
130    }
131
132    /// Ensures that `sql` parses as an [`Expr`], and that
133    /// re-serializing the parse result produces the same `sql`
134    /// string (is not modified after a serialization round-trip).
135    pub fn verified_expr(&self, sql: &str) -> Expr {
136        let ast = self
137            .run_parser_method(sql, |parser| parser.parse_expr())
138            .unwrap();
139        assert_eq!(sql, &ast.to_string(), "round-tripping without changes");
140        ast
141    }
142}
143
144pub fn all_dialects() -> TestedDialects {
145    TestedDialects {
146        dialects: vec![
147            Box::new(GenericDialect {}),
148            Box::new(PostgreSqlDialect {}),
149            Box::new(MsSqlDialect {}),
150            Box::new(AnsiDialect {}),
151            Box::new(SnowflakeDialect {}),
152            Box::new(HiveDialect {}),
153            Box::new(RedshiftSqlDialect {}),
154            Box::new(MySqlDialect {}),
155            Box::new(BigQueryDialect {}),
156            Box::new(SQLiteDialect {}),
157        ],
158    }
159}
160
/// Asserts that rendering each element of `actual` with `to_string`
/// yields exactly the strings in `expected`, in the same order.
///
/// # Panics
///
/// Panics (via `assert_eq!`) when the lengths or any rendered element differ.
pub fn assert_eq_vec<T: ToString>(expected: &[&str], actual: &[T]) {
    let rendered: Vec<String> = actual.iter().map(|item| item.to_string()).collect();
    assert_eq!(expected, rendered);
}
167
/// Returns the single item yielded by `v`.
///
/// # Panics
///
/// Panics if `v` yields zero items or more than one item.
pub fn only<T>(v: impl IntoIterator<Item = T>) -> T {
    let mut items = v.into_iter();
    // Always pull two items so that "exactly one" can be checked.
    let first = items.next();
    let second = items.next();
    match (first, second) {
        (Some(item), None) => item,
        _ => panic!("only called on collection without exactly one item"),
    }
}
176
177pub fn expr_from_projection(item: &SelectItem) -> &Expr {
178    match item {
179        SelectItem::UnnamedExpr(expr) => expr,
180        _ => panic!("Expected UnnamedExpr"),
181    }
182}
183
184pub fn number(n: &'static str) -> Value {
185    Value::Number(n.parse().unwrap(), false)
186}
187
188pub fn table_alias(name: impl Into<String>) -> Option<TableAlias> {
189    Some(TableAlias {
190        name: Ident::new(name),
191        columns: vec![],
192    })
193}
194
195pub fn table(name: impl Into<String>) -> TableFactor {
196    TableFactor::Table {
197        name: ObjectName(vec![Ident::new(name.into())]),
198        alias: None,
199        args: None,
200        with_hints: vec![],
201    }
202}
203
204pub fn join(relation: TableFactor) -> Join {
205    Join {
206        relation,
207        join_operator: JoinOperator::Inner(JoinConstraint::Natural),
208    }
209}