sql_schema/
lib.rs

1use std::fmt;
2
3use bon::bon;
4use diff::Diff;
5use migration::Migrate;
6use sqlparser::{
7    ast::Statement,
8    dialect::{Dialect, GenericDialect},
9    parser::{Parser, ParserError},
10};
11
12mod diff;
13mod migration;
14
15#[derive(Debug)]
16pub struct SyntaxTree(Vec<Statement>);
17
18#[bon]
19impl SyntaxTree {
20    #[builder]
21    pub fn new<'a>(
22        dialect: Option<&dyn Dialect>,
23        sql: impl Into<&'a str>,
24    ) -> Result<Self, ParserError> {
25        let generic = GenericDialect {};
26        let dialect = dialect.unwrap_or(&generic);
27        let ast = Parser::parse_sql(dialect, sql.into())?;
28        Ok(Self(ast))
29    }
30
31    pub fn empty() -> Self {
32        Self(vec![])
33    }
34}
35
36impl SyntaxTree {
37    pub fn diff(&self, other: &SyntaxTree) -> Option<Self> {
38        Diff::diff(&self.0, &other.0).map(Self)
39    }
40
41    pub fn migrate(self, other: &SyntaxTree) -> Option<Self> {
42        Migrate::migrate(self.0, &other.0).map(Self)
43    }
44}
45
46impl fmt::Display for SyntaxTree {
47    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
48        let mut iter = self.0.iter().peekable();
49        while let Some(s) = iter.next() {
50            let formatted = sqlformat::format(
51                format!("{s};").as_str(),
52                &sqlformat::QueryParams::None,
53                &sqlformat::FormatOptions::default(),
54            );
55            write!(f, "{formatted}")?;
56            if iter.peek().is_some() {
57                write!(f, "\n\n")?;
58            }
59        }
60        Ok(())
61    }
62}
63
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `sql` with the default (generic) dialect, panicking on syntax errors.
    fn parse(sql: &str) -> SyntaxTree {
        SyntaxTree::builder().sql(sql).build().unwrap()
    }

    /// Asserts that diffing `from` against `to` renders exactly as `expected`.
    fn assert_diff(from: &str, to: &str, expected: &str) {
        let diff = parse(from).diff(&parse(to));
        assert_eq!(diff.unwrap().to_string(), expected);
    }

    /// Asserts that migrating `base` with `migration` renders exactly as `expected`.
    fn assert_migrate(base: &str, migration: &str, expected: &str) {
        let migrated = parse(base).migrate(&parse(migration));
        assert_eq!(migrated.unwrap().to_string(), expected);
    }

    #[test]
    fn diff_create_table() {
        let sql_a = "CREATE TABLE foo(\
                id int PRIMARY KEY
            )";
        let sql_b = "CREATE TABLE foo(\
                id int PRIMARY KEY
            );\
            \
            CREATE TABLE bar (id INT PRIMARY KEY);";
        assert_diff(sql_a, sql_b, "CREATE TABLE bar (id INT PRIMARY KEY);");
    }

    #[test]
    fn diff_drop_table() {
        let sql_a = "CREATE TABLE foo(\
                    id int PRIMARY KEY
                );\
                \
                CREATE TABLE bar (id INT PRIMARY KEY);";
        let sql_b = "CREATE TABLE foo(\
                    id int PRIMARY KEY
                )";
        assert_diff(sql_a, sql_b, "DROP TABLE bar;");
    }

    #[test]
    fn diff_add_column() {
        let sql_a = "CREATE TABLE foo(\
                id int PRIMARY KEY
            )";
        let sql_b = "CREATE TABLE foo(\
                id int PRIMARY KEY,
                bar text
            )";
        assert_diff(sql_a, sql_b, "ALTER TABLE\n  foo\nADD\n  COLUMN bar TEXT;");
    }

    #[test]
    fn diff_drop_column() {
        let sql_a = "CREATE TABLE foo(\
                    id int PRIMARY KEY,
                    bar text
                )";
        let sql_b = "CREATE TABLE foo(\
                    id int PRIMARY KEY
                )";
        assert_diff(sql_a, sql_b, "ALTER TABLE\n  foo DROP COLUMN bar;");
    }

    #[test]
    fn apply_create_table() {
        let sql_a = "CREATE TABLE bar (id INT PRIMARY KEY);";
        let sql_b = "CREATE TABLE foo (id INT PRIMARY KEY);";
        assert_migrate(sql_a, sql_b, &format!("{sql_a}\n\n{sql_b}"));
    }

    #[test]
    fn apply_drop_table() {
        assert_migrate(
            "CREATE TABLE bar (id INT PRIMARY KEY)",
            "DROP TABLE bar; CREATE TABLE foo (id INT PRIMARY KEY)",
            "CREATE TABLE foo (id INT PRIMARY KEY);",
        );
    }

    #[test]
    fn apply_alter_table_add_column() {
        assert_migrate(
            "CREATE TABLE bar (id INT PRIMARY KEY)",
            "ALTER TABLE bar ADD COLUMN bar TEXT",
            "CREATE TABLE bar (id INT PRIMARY KEY, bar TEXT);",
        );
    }

    #[test]
    fn apply_alter_table_drop_column() {
        assert_migrate(
            "CREATE TABLE bar (bar TEXT, id INT PRIMARY KEY)",
            "ALTER TABLE bar DROP COLUMN bar",
            "CREATE TABLE bar (id INT PRIMARY KEY);",
        );
    }
}