sql_schema/
lib.rs

1use std::fmt;
2
3use bon::bon;
4use diff::Diff;
5use migration::Migrate;
6use sqlparser::{
7    ast::Statement,
8    dialect::{self},
9    parser::{Parser, ParserError},
10};
11
12mod diff;
13mod migration;
14
/// SQL dialect used when parsing schema text into a [`SyntaxTree`].
///
/// `Generic` is the default dialect. The enum is `#[non_exhaustive]`, so new
/// dialects can be added without breaking downstream `match` expressions.
#[derive(Clone, Copy, Debug, Default, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "clap", derive(clap::ValueEnum), clap(rename_all = "lower"))]
#[non_exhaustive]
pub enum Dialect {
    // NOTE: declaration order defines the derived `PartialOrd`/`Ord` ordering;
    // keep it stable when adding variants.
    Ansi,
    BigQuery,
    ClickHouse,
    Databricks,
    DuckDb,
    #[default]
    Generic,
    Hive,
    MsSql,
    MySql,
    PostgreSql,
    RedshiftSql,
    Snowflake,
    SQLite,
}
34
35impl Dialect {
36    fn to_sqlparser_dialect(self) -> Box<dyn dialect::Dialect> {
37        match self {
38            Self::Ansi => Box::new(dialect::AnsiDialect {}),
39            Self::BigQuery => Box::new(dialect::BigQueryDialect {}),
40            Self::ClickHouse => Box::new(dialect::ClickHouseDialect {}),
41            Self::Databricks => Box::new(dialect::DatabricksDialect {}),
42            Self::DuckDb => Box::new(dialect::DuckDbDialect {}),
43            Self::Generic => Box::new(dialect::GenericDialect {}),
44            Self::Hive => Box::new(dialect::HiveDialect {}),
45            Self::MsSql => Box::new(dialect::MsSqlDialect {}),
46            Self::MySql => Box::new(dialect::MySqlDialect {}),
47            Self::PostgreSql => Box::new(dialect::PostgreSqlDialect {}),
48            Self::RedshiftSql => Box::new(dialect::RedshiftSqlDialect {}),
49            Self::Snowflake => Box::new(dialect::SnowflakeDialect {}),
50            Self::SQLite => Box::new(dialect::SQLiteDialect {}),
51        }
52    }
53}
54
55impl fmt::Display for Dialect {
56    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57        write!(
58            f,
59            "{}",
60            format!("{self:?}")
61                .to_ascii_lowercase()
62                .split('-')
63                .collect::<String>()
64        )
65    }
66}
67
68#[derive(Debug)]
69pub struct SyntaxTree(Vec<Statement>);
70
71#[bon]
72impl SyntaxTree {
73    #[builder]
74    pub fn new<'a>(dialect: Option<Dialect>, sql: impl Into<&'a str>) -> Result<Self, ParserError> {
75        let dialect = dialect.unwrap_or_default().to_sqlparser_dialect();
76        let ast = Parser::parse_sql(dialect.as_ref(), sql.into())?;
77        Ok(Self(ast))
78    }
79
80    pub fn empty() -> Self {
81        Self(vec![])
82    }
83}
84
85impl SyntaxTree {
86    pub fn diff(&self, other: &SyntaxTree) -> Option<Self> {
87        Diff::diff(&self.0, &other.0).map(Self)
88    }
89
90    pub fn migrate(self, other: &SyntaxTree) -> Option<Self> {
91        Migrate::migrate(self.0, &other.0).map(Self)
92    }
93}
94
95impl fmt::Display for SyntaxTree {
96    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
97        let mut iter = self.0.iter().peekable();
98        while let Some(s) = iter.next() {
99            let formatted = sqlformat::format(
100                format!("{s};").as_str(),
101                &sqlformat::QueryParams::None,
102                &sqlformat::FormatOptions::default(),
103            );
104            write!(f, "{formatted}")?;
105            if iter.peek().is_some() {
106                write!(f, "\n\n")?;
107            }
108        }
109        Ok(())
110    }
111}
112
// End-to-end tests: parse schemas, then check the `Display` rendering of the
// results of `diff` and `migrate`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn diff_create_table() {
        let sql_a = "CREATE TABLE foo(\
                id int PRIMARY KEY
            )";
        let sql_b = "CREATE TABLE foo(\
                id int PRIMARY KEY
            );\
            \
            CREATE TABLE bar (id INT PRIMARY KEY);";
        let sql_diff = "CREATE TABLE bar (id INT PRIMARY KEY);";

        let ast_a = SyntaxTree::builder().sql(sql_a).build().unwrap();
        let ast_b = SyntaxTree::builder().sql(sql_b).build().unwrap();
        let ast_diff = ast_a.diff(&ast_b);

        assert_eq!(ast_diff.unwrap().to_string(), sql_diff);
    }

    #[test]
    fn diff_drop_table() {
        let sql_a = "CREATE TABLE foo(\
                    id int PRIMARY KEY
                );\
                \
                CREATE TABLE bar (id INT PRIMARY KEY);";
        let sql_b = "CREATE TABLE foo(\
                    id int PRIMARY KEY
                )";
        let sql_diff = "DROP TABLE bar;";

        let ast_a = SyntaxTree::builder().sql(sql_a).build().unwrap();
        let ast_b = SyntaxTree::builder().sql(sql_b).build().unwrap();
        let ast_diff = ast_a.diff(&ast_b);

        assert_eq!(ast_diff.unwrap().to_string(), sql_diff);
    }

    #[test]
    fn diff_add_column() {
        let sql_a = "CREATE TABLE foo(\
                id int PRIMARY KEY
            )";
        let sql_b = "CREATE TABLE foo(\
                id int PRIMARY KEY,
                bar text
            )";
        let sql_diff = "ALTER TABLE\n  foo\nADD\n  COLUMN bar TEXT;";

        let ast_a = SyntaxTree::builder().sql(sql_a).build().unwrap();
        let ast_b = SyntaxTree::builder().sql(sql_b).build().unwrap();
        let ast_diff = ast_a.diff(&ast_b);

        assert_eq!(ast_diff.unwrap().to_string(), sql_diff);
    }

    #[test]
    fn diff_drop_column() {
        let sql_a = "CREATE TABLE foo(\
                    id int PRIMARY KEY,
                    bar text
                )";
        let sql_b = "CREATE TABLE foo(\
                    id int PRIMARY KEY
                )";
        let sql_diff = "ALTER TABLE\n  foo DROP COLUMN bar;";

        let ast_a = SyntaxTree::builder().sql(sql_a).build().unwrap();
        let ast_b = SyntaxTree::builder().sql(sql_b).build().unwrap();
        let ast_diff = ast_a.diff(&ast_b);

        assert_eq!(ast_diff.unwrap().to_string(), sql_diff);
    }

    #[test]
    fn apply_create_table() {
        let sql_a = "CREATE TABLE bar (id INT PRIMARY KEY);";
        let sql_b = "CREATE TABLE foo (id INT PRIMARY KEY);";
        let sql_res = sql_a.to_owned() + "\n\n" + sql_b;

        let ast_a = SyntaxTree::builder().sql(sql_a).build().unwrap();
        let ast_b = SyntaxTree::builder().sql(sql_b).build().unwrap();
        let ast_res = ast_a.migrate(&ast_b);

        assert_eq!(ast_res.unwrap().to_string(), sql_res);
    }

    #[test]
    fn apply_drop_table() {
        let sql_a = "CREATE TABLE bar (id INT PRIMARY KEY)";
        let sql_b = "DROP TABLE bar; CREATE TABLE foo (id INT PRIMARY KEY)";
        let sql_res = "CREATE TABLE foo (id INT PRIMARY KEY);";

        let ast_a = SyntaxTree::builder().sql(sql_a).build().unwrap();
        let ast_b = SyntaxTree::builder().sql(sql_b).build().unwrap();
        let ast_res = ast_a.migrate(&ast_b);

        assert_eq!(ast_res.unwrap().to_string(), sql_res);
    }

    #[test]
    fn apply_alter_table_add_column() {
        let sql_a = "CREATE TABLE bar (id INT PRIMARY KEY)";
        let sql_b = "ALTER TABLE bar ADD COLUMN bar TEXT";
        let sql_res = "CREATE TABLE bar (id INT PRIMARY KEY, bar TEXT);";

        let ast_a = SyntaxTree::builder().sql(sql_a).build().unwrap();
        let ast_b = SyntaxTree::builder().sql(sql_b).build().unwrap();
        let ast_res = ast_a.migrate(&ast_b);

        assert_eq!(ast_res.unwrap().to_string(), sql_res);
    }

    #[test]
    fn apply_alter_table_drop_column() {
        let sql_a = "CREATE TABLE bar (bar TEXT, id INT PRIMARY KEY)";
        let sql_b = "ALTER TABLE bar DROP COLUMN bar";
        let sql_res = "CREATE TABLE bar (id INT PRIMARY KEY);";

        let ast_a = SyntaxTree::builder().sql(sql_a).build().unwrap();
        let ast_b = SyntaxTree::builder().sql(sql_b).build().unwrap();
        let ast_res = ast_a.migrate(&ast_b);

        assert_eq!(ast_res.unwrap().to_string(), sql_res);
    }

    #[test]
    fn apply_alter_table_alter_column() {
        #[derive(Debug)]
        struct TestCase {
            dialect: Dialect,
            sql_a: &'static str,
            sql_b: &'static str,
            expect: &'static str,
        }
        // A fixed-size array is enough for a static table of cases.
        let test_cases = [
            TestCase {
                dialect: Dialect::Generic,
                sql_a: "CREATE TABLE bar (bar TEXT, id INT PRIMARY KEY)",
                sql_b: "ALTER TABLE bar ALTER COLUMN bar SET NOT NULL",
                expect: "CREATE TABLE bar (bar TEXT NOT NULL, id INT PRIMARY KEY);",
            },
            TestCase {
                dialect: Dialect::Generic,
                sql_a: "CREATE TABLE bar (bar TEXT NOT NULL, id INT PRIMARY KEY)",
                sql_b: "ALTER TABLE bar ALTER COLUMN bar DROP NOT NULL",
                expect: "CREATE TABLE bar (bar TEXT, id INT PRIMARY KEY);",
            },
            TestCase {
                dialect: Dialect::Generic,
                sql_a: "CREATE TABLE bar (bar TEXT NOT NULL DEFAULT 'foo', id INT PRIMARY KEY)",
                sql_b: "ALTER TABLE bar ALTER COLUMN bar DROP DEFAULT",
                expect: "CREATE TABLE bar (bar TEXT NOT NULL, id INT PRIMARY KEY);",
            },
            TestCase {
                dialect: Dialect::Generic,
                sql_a: "CREATE TABLE bar (bar TEXT, id INT PRIMARY KEY)",
                sql_b: "ALTER TABLE bar ALTER COLUMN bar SET DATA TYPE INTEGER",
                expect: "CREATE TABLE bar (bar INTEGER, id INT PRIMARY KEY);",
            },
            TestCase {
                dialect: Dialect::PostgreSql,
                sql_a: "CREATE TABLE bar (bar TEXT, id INT PRIMARY KEY)",
                sql_b: "ALTER TABLE bar ALTER COLUMN bar SET DATA TYPE timestamp with time zone\n USING timestamp with time zone 'epoch' + foo_timestamp * interval '1 second'",
                expect: "CREATE TABLE bar (bar TIMESTAMP WITH TIME ZONE, id INT PRIMARY KEY);",
            },
            TestCase {
                dialect: Dialect::Generic,
                sql_a: "CREATE TABLE bar (bar INTEGER, id INT PRIMARY KEY)",
                sql_b: "ALTER TABLE bar ALTER COLUMN bar ADD GENERATED BY DEFAULT AS IDENTITY",
                expect: "CREATE TABLE bar (\n  bar INTEGER GENERATED BY DEFAULT AS IDENTITY,\n  id INT PRIMARY KEY\n);",
            },
            TestCase {
                dialect: Dialect::Generic,
                sql_a: "CREATE TABLE bar (bar INTEGER, id INT PRIMARY KEY)",
                sql_b: "ALTER TABLE bar ALTER COLUMN bar ADD GENERATED ALWAYS AS IDENTITY (START WITH 10)",
                expect: "CREATE TABLE bar (\n  bar INTEGER GENERATED ALWAYS AS IDENTITY (START WITH 10),\n  id INT PRIMARY KEY\n);",
            },
        ];

        for tc in test_cases {
            // `Dialect` is `Copy`, so it is passed by value; no clone needed.
            let ast_a = SyntaxTree::builder()
                .dialect(tc.dialect)
                .sql(tc.sql_a)
                .build()
                .unwrap();
            let ast_b = SyntaxTree::builder()
                .dialect(tc.dialect)
                .sql(tc.sql_b)
                .build()
                .unwrap();
            let ast_res = ast_a.migrate(&ast_b);

            assert_eq!(ast_res.unwrap().to_string(), tc.expect, "{tc:?}");
        }
    }
}