sql_schema/
lib.rs

1use std::fmt;
2
3use bon::bon;
4use diff::Diff;
5use migration::Migrate;
6use sqlparser::{
7    ast::Statement,
8    dialect::{self},
9    parser::{self, Parser},
10};
11use thiserror::Error;
12
13mod diff;
14mod migration;
15pub mod name_gen;
16pub mod path_template;
17
18#[derive(Error, Debug)]
19#[error("Oops, we couldn't parse that!")]
20pub struct ParseError(#[from] parser::ParserError);
21
/// SQL dialect used when parsing schema files.
///
/// Each variant is turned into the `sqlparser` dialect of the same name when parsing.
/// [`Dialect::Generic`] is the default. With the `clap` feature enabled the enum can be
/// used directly as a command-line value, with variant names rendered in lowercase.
///
/// Variant order is significant: it drives the derived `PartialOrd`/`Ord` and the order
/// in which `clap` lists the possible values.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
#[cfg_attr(feature = "clap", derive(clap::ValueEnum), clap(rename_all = "lower"))]
#[non_exhaustive]
pub enum Dialect {
    Ansi,
    BigQuery,
    ClickHouse,
    Databricks,
    DuckDb,
    /// Dialect used when none is specified.
    #[default]
    Generic,
    Hive,
    MsSql,
    MySql,
    PostgreSql,
    RedshiftSql,
    Snowflake,
    SQLite,
}
41
42impl Dialect {
43    fn to_sqlparser_dialect(self) -> Box<dyn dialect::Dialect> {
44        match self {
45            Self::Ansi => Box::new(dialect::AnsiDialect {}),
46            Self::BigQuery => Box::new(dialect::BigQueryDialect {}),
47            Self::ClickHouse => Box::new(dialect::ClickHouseDialect {}),
48            Self::Databricks => Box::new(dialect::DatabricksDialect {}),
49            Self::DuckDb => Box::new(dialect::DuckDbDialect {}),
50            Self::Generic => Box::new(dialect::GenericDialect {}),
51            Self::Hive => Box::new(dialect::HiveDialect {}),
52            Self::MsSql => Box::new(dialect::MsSqlDialect {}),
53            Self::MySql => Box::new(dialect::MySqlDialect {}),
54            Self::PostgreSql => Box::new(dialect::PostgreSqlDialect {}),
55            Self::RedshiftSql => Box::new(dialect::RedshiftSqlDialect {}),
56            Self::Snowflake => Box::new(dialect::SnowflakeDialect {}),
57            Self::SQLite => Box::new(dialect::SQLiteDialect {}),
58        }
59    }
60}
61
62impl fmt::Display for Dialect {
63    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64        // NOTE: this must match how clap::ValueEnum displays variants
65        write!(
66            f,
67            "{}",
68            format!("{self:?}")
69                .to_ascii_lowercase()
70                .split('-')
71                .collect::<String>()
72        )
73    }
74}
75
76#[derive(Debug, Clone)]
77pub struct SyntaxTree(pub(crate) Vec<Statement>);
78
79#[bon]
80impl SyntaxTree {
81    #[builder]
82    pub fn new<'a>(dialect: Option<Dialect>, sql: impl Into<&'a str>) -> Result<Self, ParseError> {
83        let dialect = dialect.unwrap_or_default().to_sqlparser_dialect();
84        let ast = Parser::parse_sql(dialect.as_ref(), sql.into())?;
85        Ok(Self(ast))
86    }
87
88    pub fn empty() -> Self {
89        Self(vec![])
90    }
91}
92
93pub use diff::DiffError;
94pub use migration::MigrateError;
95
96impl SyntaxTree {
97    pub fn diff(&self, other: &SyntaxTree) -> Result<Option<Self>, DiffError> {
98        Ok(Diff::diff(&self.0, &other.0)?.map(Self))
99    }
100
101    pub fn migrate(self, other: &SyntaxTree) -> Result<Option<Self>, MigrateError> {
102        Ok(Migrate::migrate(self.0, &other.0)?.map(Self))
103    }
104}
105
106impl fmt::Display for SyntaxTree {
107    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
108        let mut iter = self.0.iter().peekable();
109        while let Some(s) = iter.next() {
110            let formatted = sqlformat::format(
111                format!("{s};").as_str(),
112                &sqlformat::QueryParams::None,
113                &sqlformat::FormatOptions::default(),
114            );
115            write!(f, "{formatted}")?;
116            if iter.peek().is_some() {
117                write!(f, "\n\n")?;
118            }
119        }
120        Ok(())
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    /// One scenario: parse `sql_a` and `sql_b` with `dialect`, combine them with the
    /// operation under test, and compare the rendered result with `expect`.
    #[derive(Debug)]
    struct TestCase {
        dialect: Dialect,
        sql_a: &'static str,
        sql_b: &'static str,
        expect: &'static str,
    }

    /// Parses `sql` with `dialect`, panicking with the offending SQL and the parser
    /// error if it is rejected.
    fn parse(dialect: Dialect, sql: &'static str) -> SyntaxTree {
        SyntaxTree::builder()
            .dialect(dialect)
            .sql(sql)
            .build()
            .unwrap_or_else(|err| panic!("invalid SQL: {sql:?} ({err:?})"))
    }

    /// Runs a single [`TestCase`] through `testfn` and asserts that the rendered output
    /// equals `tc.expect`.
    fn run_test_case<F>(tc: &TestCase, testfn: F)
    where
        F: Fn(SyntaxTree, SyntaxTree) -> SyntaxTree,
    {
        let ast_a = parse(tc.dialect, tc.sql_a);
        let ast_b = parse(tc.dialect, tc.sql_b);
        // Only checks that the expectation is itself valid SQL, so a typo in a test
        // fails loudly; the comparison below is on the rendered text.
        parse(tc.dialect, tc.expect);
        let actual = testfn(ast_a, ast_b);
        assert_eq!(actual.to_string(), tc.expect, "{tc:?}");
    }

    /// Runs every case through a fallible `testfn`; an `Err` or a `None` result fails
    /// the test, reporting the case that caused it.
    fn run_test_cases<F, E: fmt::Debug>(test_cases: Vec<TestCase>, testfn: F)
    where
        F: Fn(SyntaxTree, SyntaxTree) -> Result<Option<SyntaxTree>, E>,
    {
        for tc in &test_cases {
            run_test_case(tc, |ast_a, ast_b| {
                testfn(ast_a, ast_b)
                    .unwrap_or_else(|err| panic!("Error: {err:?} in {tc:?}"))
                    .unwrap_or_else(|| panic!("no syntax tree returned for {tc:?}"))
            });
        }
    }

    #[test]
    fn diff_create_table() {
        run_test_cases(
            vec![TestCase {
                dialect: Dialect::Generic,
                sql_a: "CREATE TABLE foo(\
                            id int PRIMARY KEY
                        )",
                sql_b: "CREATE TABLE foo(\
                            id int PRIMARY KEY
                        );\
                        CREATE TABLE bar (id INT PRIMARY KEY);",
                expect: "CREATE TABLE bar (id INT PRIMARY KEY);",
            }],
            |ast_a, ast_b| ast_a.diff(&ast_b),
        );
    }

    #[test]
    fn diff_drop_table() {
        run_test_cases(
            vec![TestCase {
                dialect: Dialect::Generic,
                sql_a: "CREATE TABLE foo(\
                        id int PRIMARY KEY
                    );\
                    CREATE TABLE bar (id INT PRIMARY KEY);",
                sql_b: "CREATE TABLE foo(\
                        id int PRIMARY KEY
                    )",
                expect: "DROP TABLE bar;",
            }],
            |ast_a, ast_b| ast_a.diff(&ast_b),
        );
    }

    #[test]
    fn diff_add_column() {
        run_test_cases(
            vec![TestCase {
                dialect: Dialect::Generic,
                sql_a: "CREATE TABLE foo(\
                        id int PRIMARY KEY
                    )",
                sql_b: "CREATE TABLE foo(\
                        id int PRIMARY KEY,
                        bar text
                    )",
                expect: "ALTER TABLE\n  foo\nADD\n  COLUMN bar TEXT;",
            }],
            |ast_a, ast_b| ast_a.diff(&ast_b),
        );
    }

    #[test]
    fn diff_drop_column() {
        run_test_cases(
            vec![TestCase {
                dialect: Dialect::Generic,
                sql_a: "CREATE TABLE foo(\
                        id int PRIMARY KEY,
                        bar text
                    )",
                sql_b: "CREATE TABLE foo(\
                        id int PRIMARY KEY
                    )",
                expect: "ALTER TABLE\n  foo DROP COLUMN bar;",
            }],
            |ast_a, ast_b| ast_a.diff(&ast_b),
        );
    }

    #[test]
    fn diff_create_index() {
        run_test_cases(
            vec![
                TestCase {
                    dialect: Dialect::Generic,
                    sql_a: "CREATE UNIQUE INDEX title_idx ON films (title);",
                    sql_b: "CREATE UNIQUE INDEX title_idx ON films ((lower(title)));",
                    expect: "DROP INDEX title_idx;\n\nCREATE UNIQUE INDEX title_idx ON films((lower(title)));",
                },
                TestCase {
                    dialect: Dialect::Generic,
                    sql_a: "CREATE UNIQUE INDEX IF NOT EXISTS title_idx ON films (title);",
                    sql_b: "CREATE UNIQUE INDEX IF NOT EXISTS title_idx ON films ((lower(title)));",
                    expect: "DROP INDEX IF EXISTS title_idx;\n\nCREATE UNIQUE INDEX IF NOT EXISTS title_idx ON films((lower(title)));",
                },
            ],
            |ast_a, ast_b| ast_a.diff(&ast_b),
        );
    }

    #[test]
    fn diff_create_type() {
        run_test_cases(
            vec![
                TestCase {
                    dialect: Dialect::Generic,
                    sql_a: "CREATE TYPE bug_status AS ENUM ('new', 'open');",
                    sql_b: "CREATE TYPE foo AS ENUM ('bar');",
                    expect: "DROP TYPE bug_status;\n\nCREATE TYPE foo AS ENUM ('bar');",
                },
                TestCase {
                    dialect: Dialect::Generic,
                    sql_a: "CREATE TYPE bug_status AS ENUM ('new', 'open', 'closed');",
                    sql_b: "CREATE TYPE bug_status AS ENUM ('new', 'open', 'assigned', 'closed');",
                    expect: "ALTER TYPE bug_status\nADD\n  VALUE 'assigned'\nAFTER\n  'open';",
                },
                TestCase {
                    dialect: Dialect::Generic,
                    sql_a: "CREATE TYPE bug_status AS ENUM ('open', 'closed');",
                    sql_b: "CREATE TYPE bug_status AS ENUM ('new', 'open', 'closed');",
                    expect: "ALTER TYPE bug_status\nADD\n  VALUE 'new' BEFORE 'open';",
                },
                TestCase {
                    dialect: Dialect::Generic,
                    sql_a: "CREATE TYPE bug_status AS ENUM ('new', 'open');",
                    sql_b: "CREATE TYPE bug_status AS ENUM ('new', 'open', 'closed');",
                    expect: "ALTER TYPE bug_status\nADD\n  VALUE 'closed';",
                },
                TestCase {
                    dialect: Dialect::Generic,
                    sql_a: "CREATE TYPE bug_status AS ENUM ('new', 'open');",
                    sql_b: "CREATE TYPE bug_status AS ENUM ('new', 'open', 'assigned', 'closed');",
                    expect: "ALTER TYPE bug_status\nADD\n  VALUE 'assigned';\n\nALTER TYPE bug_status\nADD\n  VALUE 'closed';",
                },
                TestCase {
                    dialect: Dialect::Generic,
                    sql_a: "CREATE TYPE bug_status AS ENUM ('open', 'critical');",
                    sql_b: "CREATE TYPE bug_status AS ENUM ('new', 'open', 'assigned', 'closed', 'critical');",
                    expect: "ALTER TYPE bug_status\nADD\n  VALUE 'new' BEFORE 'open';\n\nALTER TYPE bug_status\nADD\n  VALUE 'assigned'\nAFTER\n  'open';\n\nALTER TYPE bug_status\nADD\n  VALUE 'closed'\nAFTER\n  'assigned';",
                },
                TestCase {
                    dialect: Dialect::Generic,
                    sql_a: "CREATE TYPE bug_status AS ENUM ('open');",
                    sql_b: "CREATE TYPE bug_status AS ENUM ('new', 'open', 'closed');",
                    expect: "ALTER TYPE bug_status\nADD\n  VALUE 'new' BEFORE 'open';\n\nALTER TYPE bug_status\nADD\n  VALUE 'closed';",
                },
            ],
            |ast_a, ast_b| ast_a.diff(&ast_b),
        );
    }

    #[test]
    fn diff_create_extension() {
        run_test_cases(
            vec![TestCase {
                dialect: Dialect::Generic,
                sql_a: "CREATE EXTENSION hstore;",
                sql_b: "CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\";",
                expect: "DROP EXTENSION hstore;\n\nCREATE EXTENSION IF NOT EXISTS \"uuid-ossp\";",
            }],
            |ast_a, ast_b| ast_a.diff(&ast_b),
        );
    }

    #[test]
    fn apply_create_table() {
        run_test_cases(
            vec![TestCase {
                dialect: Dialect::Generic,
                sql_a: "CREATE TABLE bar (id INT PRIMARY KEY);",
                sql_b: "CREATE TABLE foo (id INT PRIMARY KEY);",
                expect: "CREATE TABLE bar (id INT PRIMARY KEY);\n\nCREATE TABLE foo (id INT PRIMARY KEY);",
            }],
            |ast_a, ast_b| ast_a.migrate(&ast_b),
        );
    }

    #[test]
    fn apply_drop_table() {
        run_test_cases(
            vec![TestCase {
                dialect: Dialect::Generic,
                sql_a: "CREATE TABLE bar (id INT PRIMARY KEY)",
                sql_b: "DROP TABLE bar; CREATE TABLE foo (id INT PRIMARY KEY)",
                expect: "CREATE TABLE foo (id INT PRIMARY KEY);",
            }],
            |ast_a, ast_b| ast_a.migrate(&ast_b),
        );
    }

    #[test]
    fn apply_alter_table_add_column() {
        run_test_cases(
            vec![TestCase {
                dialect: Dialect::Generic,
                sql_a: "CREATE TABLE bar (id INT PRIMARY KEY)",
                sql_b: "ALTER TABLE bar ADD COLUMN bar TEXT",
                expect: "CREATE TABLE bar (id INT PRIMARY KEY, bar TEXT);",
            }],
            |ast_a, ast_b| ast_a.migrate(&ast_b),
        );
    }

    #[test]
    fn apply_alter_table_drop_column() {
        run_test_cases(
            vec![TestCase {
                dialect: Dialect::Generic,
                sql_a: "CREATE TABLE bar (bar TEXT, id INT PRIMARY KEY)",
                sql_b: "ALTER TABLE bar DROP COLUMN bar",
                expect: "CREATE TABLE bar (id INT PRIMARY KEY);",
            }],
            |ast_a, ast_b| ast_a.migrate(&ast_b),
        );
    }

    #[test]
    fn apply_alter_table_alter_column() {
        run_test_cases(
            vec![
                TestCase {
                    dialect: Dialect::Generic,
                    sql_a: "CREATE TABLE bar (bar TEXT, id INT PRIMARY KEY)",
                    sql_b: "ALTER TABLE bar ALTER COLUMN bar SET NOT NULL",
                    expect: "CREATE TABLE bar (bar TEXT NOT NULL, id INT PRIMARY KEY);",
                },
                TestCase {
                    dialect: Dialect::Generic,
                    sql_a: "CREATE TABLE bar (bar TEXT NOT NULL, id INT PRIMARY KEY)",
                    sql_b: "ALTER TABLE bar ALTER COLUMN bar DROP NOT NULL",
                    expect: "CREATE TABLE bar (bar TEXT, id INT PRIMARY KEY);",
                },
                TestCase {
                    dialect: Dialect::Generic,
                    sql_a: "CREATE TABLE bar (bar TEXT NOT NULL DEFAULT 'foo', id INT PRIMARY KEY)",
                    sql_b: "ALTER TABLE bar ALTER COLUMN bar DROP DEFAULT",
                    expect: "CREATE TABLE bar (bar TEXT NOT NULL, id INT PRIMARY KEY);",
                },
                TestCase {
                    dialect: Dialect::Generic,
                    sql_a: "CREATE TABLE bar (bar TEXT, id INT PRIMARY KEY)",
                    sql_b: "ALTER TABLE bar ALTER COLUMN bar SET DATA TYPE INTEGER",
                    expect: "CREATE TABLE bar (bar INTEGER, id INT PRIMARY KEY);",
                },
                TestCase {
                    dialect: Dialect::PostgreSql,
                    sql_a: "CREATE TABLE bar (bar TEXT, id INT PRIMARY KEY)",
                    sql_b: "ALTER TABLE bar ALTER COLUMN bar SET DATA TYPE timestamp with time zone\n USING timestamp with time zone 'epoch' + foo_timestamp * interval '1 second'",
                    expect: "CREATE TABLE bar (bar TIMESTAMP WITH TIME ZONE, id INT PRIMARY KEY);",
                },
                TestCase {
                    dialect: Dialect::Generic,
                    sql_a: "CREATE TABLE bar (bar INTEGER, id INT PRIMARY KEY)",
                    sql_b: "ALTER TABLE bar ALTER COLUMN bar ADD GENERATED BY DEFAULT AS IDENTITY",
                    expect: "CREATE TABLE bar (\n  bar INTEGER GENERATED BY DEFAULT AS IDENTITY,\n  id INT PRIMARY KEY\n);",
                },
                TestCase {
                    dialect: Dialect::Generic,
                    sql_a: "CREATE TABLE bar (bar INTEGER, id INT PRIMARY KEY)",
                    sql_b: "ALTER TABLE bar ALTER COLUMN bar ADD GENERATED ALWAYS AS IDENTITY (START WITH 10)",
                    expect: "CREATE TABLE bar (\n  bar INTEGER GENERATED ALWAYS AS IDENTITY (START WITH 10),\n  id INT PRIMARY KEY\n);",
                },
            ],
            |ast_a, ast_b| ast_a.migrate(&ast_b),
        );
    }

    #[test]
    fn apply_create_index() {
        run_test_cases(
            vec![
                TestCase {
                    dialect: Dialect::Generic,
                    sql_a: "CREATE UNIQUE INDEX title_idx ON films (title);",
                    sql_b: "CREATE INDEX code_idx ON films (code);",
                    expect: "CREATE UNIQUE INDEX title_idx ON films(title);\n\nCREATE INDEX code_idx ON films(code);",
                },
                TestCase {
                    dialect: Dialect::Generic,
                    sql_a: "CREATE UNIQUE INDEX title_idx ON films (title);",
                    sql_b: "DROP INDEX title_idx;",
                    expect: "",
                },
                TestCase {
                    dialect: Dialect::Generic,
                    sql_a: "CREATE UNIQUE INDEX title_idx ON films (title);",
                    sql_b: "DROP INDEX title_idx;CREATE INDEX code_idx ON films (code);",
                    expect: "CREATE INDEX code_idx ON films(code);",
                },
            ],
            |ast_a, ast_b| ast_a.migrate(&ast_b),
        );
    }

    #[test]
    fn apply_alter_create_type() {
        run_test_cases(
            vec![TestCase {
                dialect: Dialect::Generic,
                sql_a: "CREATE TYPE bug_status AS ENUM ('open', 'closed');",
                sql_b: "CREATE TYPE compfoo AS (f1 int, f2 text);",
                expect: "CREATE TYPE bug_status AS ENUM ('open', 'closed');\n\nCREATE TYPE compfoo AS (f1 INT, f2 TEXT);",
            }],
            |ast_a, ast_b| ast_a.migrate(&ast_b),
        );
    }

    #[test]
    fn apply_alter_type_rename() {
        run_test_cases(
            vec![TestCase {
                dialect: Dialect::Generic,
                sql_a: "CREATE TYPE bug_status AS ENUM ('open', 'closed');",
                sql_b: "ALTER TYPE bug_status RENAME TO issue_status",
                expect: "CREATE TYPE issue_status AS ENUM ('open', 'closed');",
            }],
            |ast_a, ast_b| ast_a.migrate(&ast_b),
        );
    }

    #[test]
    fn apply_alter_type_add_value() {
        run_test_cases(
            vec![
                TestCase {
                    dialect: Dialect::Generic,
                    sql_a: "CREATE TYPE bug_status AS ENUM ('open');",
                    sql_b: "ALTER TYPE bug_status ADD VALUE 'new' BEFORE 'open';",
                    expect: "CREATE TYPE bug_status AS ENUM ('new', 'open');",
                },
                TestCase {
                    dialect: Dialect::Generic,
                    sql_a: "CREATE TYPE bug_status AS ENUM ('open');",
                    sql_b: "ALTER TYPE bug_status ADD VALUE 'closed' AFTER 'open';",
                    expect: "CREATE TYPE bug_status AS ENUM ('open', 'closed');",
                },
                TestCase {
                    dialect: Dialect::Generic,
                    sql_a: "CREATE TYPE bug_status AS ENUM ('open');",
                    sql_b: "ALTER TYPE bug_status ADD VALUE 'closed';",
                    expect: "CREATE TYPE bug_status AS ENUM ('open', 'closed');",
                },
            ],
            |ast_a, ast_b| ast_a.migrate(&ast_b),
        );
    }

    #[test]
    fn apply_alter_type_rename_value() {
        run_test_cases(
            vec![TestCase {
                dialect: Dialect::Generic,
                sql_a: "CREATE TYPE bug_status AS ENUM ('new', 'closed');",
                sql_b: "ALTER TYPE bug_status RENAME VALUE 'new' TO 'open';",
                expect: "CREATE TYPE bug_status AS ENUM ('open', 'closed');",
            }],
            |ast_a, ast_b| ast_a.migrate(&ast_b),
        );
    }

    #[test]
    fn apply_create_extension() {
        run_test_cases(
            vec![TestCase {
                dialect: Dialect::Generic,
                sql_a: "CREATE EXTENSION hstore;",
                sql_b: "CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\";",
                expect: "CREATE EXTENSION hstore;\n\nCREATE EXTENSION IF NOT EXISTS \"uuid-ossp\";",
            }],
            |ast_a, ast_b| ast_a.migrate(&ast_b),
        );
    }
}