1#[cfg(not(feature = "std"))]
21use alloc::{
22 boxed::Box,
23 string::{String, ToString},
24 vec,
25 vec::Vec,
26};
27use core::fmt::Debug;
28
29use crate::ast::*;
30use crate::dialect::*;
31use crate::parser::{Parser, ParserError};
32
33pub struct TestedDialects {
36 pub dialects: Vec<Box<dyn Dialect>>,
37}
38
39impl TestedDialects {
40 pub fn one_of_identical_results<F, T: Debug + PartialEq>(&self, f: F) -> T
43 where
44 F: Fn(&dyn Dialect) -> T,
45 {
46 let parse_results = self.dialects.iter().map(|dialect| (dialect, f(&**dialect)));
47 parse_results
48 .fold(None, |s, (dialect, parsed)| {
49 if let Some((prev_dialect, prev_parsed)) = s {
50 assert_eq!(
51 prev_parsed, parsed,
52 "Parse results with {prev_dialect:?} are different from {dialect:?}"
53 );
54 }
55 Some((dialect, parsed))
56 })
57 .unwrap()
58 .1
59 }
60
61 pub fn run_parser_method<F, T: Debug + PartialEq>(&self, sql: &str, f: F) -> T
62 where
63 F: Fn(&mut Parser) -> T,
64 {
65 self.one_of_identical_results(|dialect| {
66 let mut parser = Parser::new(dialect).try_with_sql(sql).unwrap();
67 f(&mut parser)
68 })
69 }
70
71 pub fn parse_sql_statements(&self, sql: &str) -> Result<Vec<Statement>, ParserError> {
74 self.one_of_identical_results(|dialect| Parser::parse_sql(dialect, sql))
75 }
78
79 pub fn one_statement_parses_to(&self, sql: &str, canonical: &str) -> Statement {
91 let mut statements = self.parse_sql_statements(sql).unwrap();
92 assert_eq!(statements.len(), 1);
93
94 if !canonical.is_empty() && sql != canonical {
95 assert_eq!(self.parse_sql_statements(canonical).unwrap(), statements);
96 }
97
98 let only_statement = statements.pop().unwrap();
99 if !canonical.is_empty() {
100 assert_eq!(canonical, only_statement.to_string())
101 }
102 only_statement
103 }
104
105 pub fn verified_stmt(&self, sql: &str) -> Statement {
109 self.one_statement_parses_to(sql, sql)
110 }
111
112 pub fn verified_query(&self, sql: &str) -> Query {
116 match self.verified_stmt(sql) {
117 Statement::Query(query) => *query,
118 _ => panic!("Expected Query"),
119 }
120 }
121
122 pub fn verified_only_select(&self, query: &str) -> Select {
126 match *self.verified_query(query).body {
127 SetExpr::Select(s) => *s,
128 _ => panic!("Expected SetExpr::Select"),
129 }
130 }
131
132 pub fn verified_expr(&self, sql: &str) -> Expr {
136 let ast = self
137 .run_parser_method(sql, |parser| parser.parse_expr())
138 .unwrap();
139 assert_eq!(sql, &ast.to_string(), "round-tripping without changes");
140 ast
141 }
142}
143
144pub fn all_dialects() -> TestedDialects {
145 TestedDialects {
146 dialects: vec![
147 Box::new(GenericDialect {}),
148 Box::new(PostgreSqlDialect {}),
149 Box::new(MsSqlDialect {}),
150 Box::new(AnsiDialect {}),
151 Box::new(SnowflakeDialect {}),
152 Box::new(HiveDialect {}),
153 Box::new(RedshiftSqlDialect {}),
154 Box::new(MySqlDialect {}),
155 Box::new(BigQueryDialect {}),
156 Box::new(SQLiteDialect {}),
157 ],
158 }
159}
160
/// Asserts that the string form of every item in `actual` matches the
/// corresponding entry of `expected`.
///
/// # Panics
///
/// Panics if the two sequences differ in length or in any element.
pub fn assert_eq_vec<T: ToString>(expected: &[&str], actual: &[T]) {
    let rendered: Vec<_> = actual.iter().map(|item| item.to_string()).collect();
    assert_eq!(expected, rendered);
}
167
/// Returns the sole item yielded by `v`.
///
/// # Panics
///
/// Panics if `v` yields no items or more than one item.
pub fn only<T>(v: impl IntoIterator<Item = T>) -> T {
    let mut items = v.into_iter();
    // Pull two items up front: the second must be absent.
    let first = items.next();
    let second = items.next();
    match (first, second) {
        (Some(item), None) => item,
        _ => panic!("only called on collection without exactly one item"),
    }
}
176
177pub fn expr_from_projection(item: &SelectItem) -> &Expr {
178 match item {
179 SelectItem::UnnamedExpr(expr) => expr,
180 _ => panic!("Expected UnnamedExpr"),
181 }
182}
183
184pub fn number(n: &'static str) -> Value {
185 Value::Number(n.parse().unwrap(), false)
186}
187
188pub fn table_alias(name: impl Into<String>) -> Option<TableAlias> {
189 Some(TableAlias {
190 name: Ident::new(name),
191 columns: vec![],
192 })
193}
194
195pub fn table(name: impl Into<String>) -> TableFactor {
196 TableFactor::Table {
197 name: ObjectName(vec![Ident::new(name.into())]),
198 alias: None,
199 args: None,
200 with_hints: vec![],
201 }
202}
203
204pub fn join(relation: TableFactor) -> Join {
205 Join {
206 relation,
207 join_operator: JoinOperator::Inner(JoinConstraint::Natural),
208 }
209}