1use std::fmt;
2
3use bon::bon;
4use diff::Diff;
5use migration::Migrate;
6use sqlparser::{
7 ast::Statement,
8 dialect::{self},
9 parser::{Parser, ParserError},
10};
11
12mod diff;
13mod migration;
14
/// SQL dialect used when parsing input with `sqlparser`.
///
/// Each variant maps to exactly one `sqlparser::dialect` implementation (see
/// `Dialect::to_sqlparser_dialect`). `Generic` is the default.
///
/// With the `clap` feature enabled the enum derives `clap::ValueEnum`, and
/// variants are accepted on the command line in lowercase (e.g. `postgresql`).
// NOTE: `PartialOrd`/`Ord` are derived, so the declaration order of the
// variants below is significant; append new variants rather than reordering.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Default)]
#[cfg_attr(feature = "clap", derive(clap::ValueEnum), clap(rename_all = "lower"))]
#[non_exhaustive]
pub enum Dialect {
    Ansi,
    BigQuery,
    ClickHouse,
    Databricks,
    DuckDb,
    /// Dialect used when none is specified.
    #[default]
    Generic,
    Hive,
    MsSql,
    MySql,
    PostgreSql,
    RedshiftSql,
    Snowflake,
    SQLite,
}
34
35impl Dialect {
36 fn to_sqlparser_dialect(self) -> Box<dyn dialect::Dialect> {
37 match self {
38 Self::Ansi => Box::new(dialect::AnsiDialect {}),
39 Self::BigQuery => Box::new(dialect::BigQueryDialect {}),
40 Self::ClickHouse => Box::new(dialect::ClickHouseDialect {}),
41 Self::Databricks => Box::new(dialect::DatabricksDialect {}),
42 Self::DuckDb => Box::new(dialect::DuckDbDialect {}),
43 Self::Generic => Box::new(dialect::GenericDialect {}),
44 Self::Hive => Box::new(dialect::HiveDialect {}),
45 Self::MsSql => Box::new(dialect::MsSqlDialect {}),
46 Self::MySql => Box::new(dialect::MySqlDialect {}),
47 Self::PostgreSql => Box::new(dialect::PostgreSqlDialect {}),
48 Self::RedshiftSql => Box::new(dialect::RedshiftSqlDialect {}),
49 Self::Snowflake => Box::new(dialect::SnowflakeDialect {}),
50 Self::SQLite => Box::new(dialect::SQLiteDialect {}),
51 }
52 }
53}
54
55impl fmt::Display for Dialect {
56 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57 write!(
58 f,
59 "{}",
60 format!("{self:?}")
61 .to_ascii_lowercase()
62 .split('-')
63 .collect::<String>()
64 )
65 }
66}
67
/// A parsed SQL script: the ordered list of statements produced by
/// `sqlparser::parser::Parser::parse_sql`.
#[derive(Debug)]
pub struct SyntaxTree(Vec<Statement>);
70
71#[bon]
72impl SyntaxTree {
73 #[builder]
74 pub fn new<'a>(dialect: Option<Dialect>, sql: impl Into<&'a str>) -> Result<Self, ParserError> {
75 let dialect = dialect.unwrap_or_default().to_sqlparser_dialect();
76 let ast = Parser::parse_sql(dialect.as_ref(), sql.into())?;
77 Ok(Self(ast))
78 }
79
80 pub fn empty() -> Self {
81 Self(vec![])
82 }
83}
84
85impl SyntaxTree {
86 pub fn diff(&self, other: &SyntaxTree) -> Option<Self> {
87 Diff::diff(&self.0, &other.0).map(Self)
88 }
89
90 pub fn migrate(self, other: &SyntaxTree) -> Option<Self> {
91 Migrate::migrate(self.0, &other.0).map(Self)
92 }
93}
94
95impl fmt::Display for SyntaxTree {
96 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
97 let mut iter = self.0.iter().peekable();
98 while let Some(s) = iter.next() {
99 let formatted = sqlformat::format(
100 format!("{s};").as_str(),
101 &sqlformat::QueryParams::None,
102 &sqlformat::FormatOptions::default(),
103 );
104 write!(f, "{formatted}")?;
105 if iter.peek().is_some() {
106 write!(f, "\n\n")?;
107 }
108 }
109 Ok(())
110 }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn diff_create_table() {
        let sql_a = "CREATE TABLE foo(\
            id int PRIMARY KEY
        )";
        let sql_b = "CREATE TABLE foo(\
            id int PRIMARY KEY
        );\
        \
        CREATE TABLE bar (id INT PRIMARY KEY);";
        let sql_diff = "CREATE TABLE bar (id INT PRIMARY KEY);";

        let ast_a = SyntaxTree::builder().sql(sql_a).build().unwrap();
        let ast_b = SyntaxTree::builder().sql(sql_b).build().unwrap();
        let ast_diff = ast_a.diff(&ast_b);

        assert_eq!(ast_diff.unwrap().to_string(), sql_diff);
    }

    #[test]
    fn diff_drop_table() {
        let sql_a = "CREATE TABLE foo(\
            id int PRIMARY KEY
        );\
        \
        CREATE TABLE bar (id INT PRIMARY KEY);";
        let sql_b = "CREATE TABLE foo(\
            id int PRIMARY KEY
        )";
        let sql_diff = "DROP TABLE bar;";

        let ast_a = SyntaxTree::builder().sql(sql_a).build().unwrap();
        let ast_b = SyntaxTree::builder().sql(sql_b).build().unwrap();
        let ast_diff = ast_a.diff(&ast_b);

        assert_eq!(ast_diff.unwrap().to_string(), sql_diff);
    }

    #[test]
    fn diff_add_column() {
        let sql_a = "CREATE TABLE foo(\
            id int PRIMARY KEY
        )";
        let sql_b = "CREATE TABLE foo(\
            id int PRIMARY KEY,
            bar text
        )";
        let sql_diff = "ALTER TABLE\n  foo\nADD\n  COLUMN bar TEXT;";

        let ast_a = SyntaxTree::builder().sql(sql_a).build().unwrap();
        let ast_b = SyntaxTree::builder().sql(sql_b).build().unwrap();
        let ast_diff = ast_a.diff(&ast_b);

        assert_eq!(ast_diff.unwrap().to_string(), sql_diff);
    }

    #[test]
    fn diff_drop_column() {
        let sql_a = "CREATE TABLE foo(\
            id int PRIMARY KEY,
            bar text
        )";
        let sql_b = "CREATE TABLE foo(\
            id int PRIMARY KEY
        )";
        let sql_diff = "ALTER TABLE\n  foo DROP COLUMN bar;";

        let ast_a = SyntaxTree::builder().sql(sql_a).build().unwrap();
        let ast_b = SyntaxTree::builder().sql(sql_b).build().unwrap();
        let ast_diff = ast_a.diff(&ast_b);

        assert_eq!(ast_diff.unwrap().to_string(), sql_diff);
    }

    #[test]
    fn apply_create_table() {
        let sql_a = "CREATE TABLE bar (id INT PRIMARY KEY);";
        let sql_b = "CREATE TABLE foo (id INT PRIMARY KEY);";
        let sql_res = sql_a.to_owned() + "\n\n" + sql_b;

        let ast_a = SyntaxTree::builder().sql(sql_a).build().unwrap();
        let ast_b = SyntaxTree::builder().sql(sql_b).build().unwrap();
        let ast_res = ast_a.migrate(&ast_b);

        assert_eq!(ast_res.unwrap().to_string(), sql_res);
    }

    #[test]
    fn apply_drop_table() {
        let sql_a = "CREATE TABLE bar (id INT PRIMARY KEY)";
        let sql_b = "DROP TABLE bar; CREATE TABLE foo (id INT PRIMARY KEY)";
        let sql_res = "CREATE TABLE foo (id INT PRIMARY KEY);";

        let ast_a = SyntaxTree::builder().sql(sql_a).build().unwrap();
        let ast_b = SyntaxTree::builder().sql(sql_b).build().unwrap();
        let ast_res = ast_a.migrate(&ast_b);

        assert_eq!(ast_res.unwrap().to_string(), sql_res);
    }

    #[test]
    fn apply_alter_table_add_column() {
        let sql_a = "CREATE TABLE bar (id INT PRIMARY KEY)";
        let sql_b = "ALTER TABLE bar ADD COLUMN bar TEXT";
        let sql_res = "CREATE TABLE bar (id INT PRIMARY KEY, bar TEXT);";

        let ast_a = SyntaxTree::builder().sql(sql_a).build().unwrap();
        let ast_b = SyntaxTree::builder().sql(sql_b).build().unwrap();
        let ast_res = ast_a.migrate(&ast_b);

        assert_eq!(ast_res.unwrap().to_string(), sql_res);
    }

    #[test]
    fn apply_alter_table_drop_column() {
        let sql_a = "CREATE TABLE bar (bar TEXT, id INT PRIMARY KEY)";
        let sql_b = "ALTER TABLE bar DROP COLUMN bar";
        let sql_res = "CREATE TABLE bar (id INT PRIMARY KEY);";

        let ast_a = SyntaxTree::builder().sql(sql_a).build().unwrap();
        let ast_b = SyntaxTree::builder().sql(sql_b).build().unwrap();
        let ast_res = ast_a.migrate(&ast_b);

        assert_eq!(ast_res.unwrap().to_string(), sql_res);
    }

    #[test]
    fn apply_alter_table_alter_column() {
        /// One `ALTER COLUMN` migration scenario: `sql_b` is applied on top of
        /// `sql_a`, and the rendered result must equal `expect`.
        #[derive(Debug)]
        struct TestCase {
            dialect: Dialect,
            sql_a: &'static str,
            sql_b: &'static str,
            expect: &'static str,
        }

        let test_cases = [
            TestCase {
                dialect: Dialect::Generic,
                sql_a: "CREATE TABLE bar (bar TEXT, id INT PRIMARY KEY)",
                sql_b: "ALTER TABLE bar ALTER COLUMN bar SET NOT NULL",
                expect: "CREATE TABLE bar (bar TEXT NOT NULL, id INT PRIMARY KEY);",
            },
            TestCase {
                dialect: Dialect::Generic,
                sql_a: "CREATE TABLE bar (bar TEXT NOT NULL, id INT PRIMARY KEY)",
                sql_b: "ALTER TABLE bar ALTER COLUMN bar DROP NOT NULL",
                expect: "CREATE TABLE bar (bar TEXT, id INT PRIMARY KEY);",
            },
            TestCase {
                dialect: Dialect::Generic,
                sql_a: "CREATE TABLE bar (bar TEXT NOT NULL DEFAULT 'foo', id INT PRIMARY KEY)",
                sql_b: "ALTER TABLE bar ALTER COLUMN bar DROP DEFAULT",
                expect: "CREATE TABLE bar (bar TEXT NOT NULL, id INT PRIMARY KEY);",
            },
            TestCase {
                dialect: Dialect::Generic,
                sql_a: "CREATE TABLE bar (bar TEXT, id INT PRIMARY KEY)",
                sql_b: "ALTER TABLE bar ALTER COLUMN bar SET DATA TYPE INTEGER",
                expect: "CREATE TABLE bar (bar INTEGER, id INT PRIMARY KEY);",
            },
            TestCase {
                dialect: Dialect::PostgreSql,
                sql_a: "CREATE TABLE bar (bar TEXT, id INT PRIMARY KEY)",
                sql_b: "ALTER TABLE bar ALTER COLUMN bar SET DATA TYPE timestamp with time zone\n USING timestamp with time zone 'epoch' + foo_timestamp * interval '1 second'",
                expect: "CREATE TABLE bar (bar TIMESTAMP WITH TIME ZONE, id INT PRIMARY KEY);",
            },
            TestCase {
                dialect: Dialect::Generic,
                sql_a: "CREATE TABLE bar (bar INTEGER, id INT PRIMARY KEY)",
                sql_b: "ALTER TABLE bar ALTER COLUMN bar ADD GENERATED BY DEFAULT AS IDENTITY",
                expect: "CREATE TABLE bar (\n  bar INTEGER GENERATED BY DEFAULT AS IDENTITY,\n  id INT PRIMARY KEY\n);",
            },
            TestCase {
                dialect: Dialect::Generic,
                sql_a: "CREATE TABLE bar (bar INTEGER, id INT PRIMARY KEY)",
                sql_b: "ALTER TABLE bar ALTER COLUMN bar ADD GENERATED ALWAYS AS IDENTITY (START WITH 10)",
                expect: "CREATE TABLE bar (\n  bar INTEGER GENERATED ALWAYS AS IDENTITY (START WITH 10),\n  id INT PRIMARY KEY\n);",
            },
        ];

        for tc in test_cases {
            // `Dialect` is `Copy`, so it is passed by value; no `.clone()` needed.
            let ast_a = SyntaxTree::builder()
                .dialect(tc.dialect)
                .sql(tc.sql_a)
                .build()
                .unwrap_or_else(|e| panic!("failed to parse sql_a of {tc:?}: {e}"));
            let ast_b = SyntaxTree::builder()
                .dialect(tc.dialect)
                .sql(tc.sql_b)
                .build()
                .unwrap_or_else(|e| panic!("failed to parse sql_b of {tc:?}: {e}"));
            let ast_res = ast_a.migrate(&ast_b);

            assert_eq!(ast_res.unwrap().to_string(), tc.expect, "{tc:?}");
        }
    }
}