sqlx_models_parser/
test_utils.rs

1// Licensed under the Apache License, Version 2.0 (the "License");
2// you may not use this file except in compliance with the License.
3// You may obtain a copy of the License at
4//
5// http://www.apache.org/licenses/LICENSE-2.0
6//
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13/// This module contains internal utilities used for testing the library.
14/// While technically public, the library's users are not supposed to rely
15/// on this module, as it will change without notice.
16//
17// Integration tests (i.e. everything under `tests/`) import this
18// via `tests/test_utils/mod.rs`.
19
20#[cfg(not(feature = "std"))]
21use alloc::{
22    boxed::Box,
23    string::{String, ToString},
24    vec,
25    vec::Vec,
26};
27use core::fmt::Debug;
28
29use crate::ast::*;
30use crate::dialect::*;
31use crate::parser::{Parser, ParserError};
32use crate::tokenizer::Tokenizer;
33
34/// Tests use the methods on this struct to invoke the parser on one or
35/// multiple dialects.
36pub struct TestedDialects {
37    pub dialects: Vec<Box<dyn Dialect>>,
38}
39
40impl TestedDialects {
41    /// Run the given function for all of `self.dialects`, assert that they
42    /// return the same result, and return that result.
43    pub fn one_of_identical_results<F, T: Debug + PartialEq>(&self, f: F) -> T
44    where
45        F: Fn(&dyn Dialect) -> T,
46    {
47        let parse_results = self.dialects.iter().map(|dialect| (dialect, f(&**dialect)));
48        parse_results
49            .fold(None, |s, (dialect, parsed)| {
50                if let Some((prev_dialect, prev_parsed)) = s {
51                    assert_eq!(
52                        prev_parsed, parsed,
53                        "Parse results with {:?} are different from {:?}",
54                        prev_dialect, dialect
55                    );
56                }
57                Some((dialect, parsed))
58            })
59            .unwrap()
60            .1
61    }
62
63    pub fn run_parser_method<F, T: Debug + PartialEq>(&self, sql: &str, f: F) -> T
64    where
65        F: Fn(&mut Parser) -> T,
66    {
67        self.one_of_identical_results(|dialect| {
68            let mut tokenizer = Tokenizer::new(dialect, sql);
69            let tokens = tokenizer.tokenize().unwrap();
70            f(&mut Parser::new(tokens, dialect))
71        })
72    }
73
74    pub fn parse_sql_statements(&self, sql: &str) -> Result<Vec<Statement>, ParserError> {
75        self.one_of_identical_results(|dialect| Parser::parse_sql(dialect, sql))
76        // To fail the `ensure_multiple_dialects_are_tested` test:
77        // Parser::parse_sql(&**self.dialects.first().unwrap(), sql)
78    }
79
80    /// Ensures that `sql` parses as a single statement and returns it.
81    /// If non-empty `canonical` SQL representation is provided,
82    /// additionally asserts that parsing `sql` results in the same parse
83    /// tree as parsing `canonical`, and that serializing it back to string
84    /// results in the `canonical` representation.
85    pub fn one_statement_parses_to(&self, sql: &str, canonical: &str) -> Statement {
86        let mut statements = self.parse_sql_statements(sql).unwrap();
87        assert_eq!(statements.len(), 1);
88
89        if !canonical.is_empty() && sql != canonical {
90            assert_eq!(self.parse_sql_statements(canonical).unwrap(), statements);
91        }
92
93        let only_statement = statements.pop().unwrap();
94        if !canonical.is_empty() {
95            assert_eq!(canonical, only_statement.to_string())
96        }
97        only_statement
98    }
99
100    /// Ensures that `sql` parses as a single [Statement], and is not modified
101    /// after a serialization round-trip.
102    pub fn verified_stmt(&self, query: &str) -> Statement {
103        self.one_statement_parses_to(query, query)
104    }
105
106    /// Ensures that `sql` parses as a single [Query], and is not modified
107    /// after a serialization round-trip.
108    pub fn verified_query(&self, sql: &str) -> Query {
109        match self.verified_stmt(sql) {
110            Statement::Query(query) => *query,
111            _ => panic!("Expected Query"),
112        }
113    }
114
115    /// Ensures that `sql` parses as a single [Select], and is not modified
116    /// after a serialization round-trip.
117    pub fn verified_only_select(&self, query: &str) -> Select {
118        match self.verified_query(query).body {
119            SetExpr::Select(s) => *s,
120            _ => panic!("Expected SetExpr::Select"),
121        }
122    }
123
124    /// Ensures that `sql` parses as an expression, and is not modified
125    /// after a serialization round-trip.
126    pub fn verified_expr(&self, sql: &str) -> Expr {
127        let ast = self
128            .run_parser_method(sql, |parser| parser.parse_expr())
129            .unwrap();
130        assert_eq!(sql, &ast.to_string(), "round-tripping without changes");
131        ast
132    }
133}
134
135pub fn all_dialects() -> TestedDialects {
136    TestedDialects {
137        dialects: vec![
138            Box::new(GenericDialect {}),
139            Box::new(PostgreSqlDialect {}),
140            Box::new(MsSqlDialect {}),
141            Box::new(AnsiDialect {}),
142            Box::new(SnowflakeDialect {}),
143            Box::new(HiveDialect {}),
144        ],
145    }
146}
147
/// Returns the sole item of `v`.
///
/// # Panics
///
/// Panics if `v` yields zero items or more than one item.
pub fn only<T>(v: impl IntoIterator<Item = T>) -> T {
    let mut items = v.into_iter();
    // Pull at most two items: the first must exist and the second must not.
    match (items.next(), items.next()) {
        (Some(single), None) => single,
        _ => panic!("only called on collection without exactly one item"),
    }
}
156
157pub fn expr_from_projection(item: &SelectItem) -> &Expr {
158    match item {
159        SelectItem::UnnamedExpr(expr) => expr,
160        _ => panic!("Expected UnnamedExpr"),
161    }
162}
163
164pub fn number(n: &'static str) -> Value {
165    Value::Number(n.parse().unwrap(), false)
166}
167
168pub fn table_alias(name: impl Into<String>) -> Option<TableAlias> {
169    Some(TableAlias {
170        name: Ident::new(name),
171        columns: vec![],
172    })
173}
174
175pub fn table(name: impl Into<String>) -> TableFactor {
176    TableFactor::Table {
177        name: ObjectName(vec![Ident::new(name.into())]),
178        alias: None,
179        args: vec![],
180        with_hints: vec![],
181    }
182}
183
184pub fn join(relation: TableFactor) -> Join {
185    Join {
186        relation,
187        join_operator: JoinOperator::Inner(JoinConstraint::Natural),
188    }
189}