1use std::fmt;
2
3use bon::bon;
4use diff::Diff;
5use migration::Migrate;
6use sqlparser::{
7 ast::Statement,
8 dialect::{Dialect, GenericDialect},
9 parser::{Parser, ParserError},
10};
11
12mod diff;
13mod migration;
14
15#[derive(Debug)]
16pub struct SyntaxTree(Vec<Statement>);
17
18#[bon]
19impl SyntaxTree {
20 #[builder]
21 pub fn new<'a>(
22 dialect: Option<&dyn Dialect>,
23 sql: impl Into<&'a str>,
24 ) -> Result<Self, ParserError> {
25 let generic = GenericDialect {};
26 let dialect = dialect.unwrap_or(&generic);
27 let ast = Parser::parse_sql(dialect, sql.into())?;
28 Ok(Self(ast))
29 }
30
31 pub fn empty() -> Self {
32 Self(vec![])
33 }
34}
35
36impl SyntaxTree {
37 pub fn diff(&self, other: &SyntaxTree) -> Option<Self> {
38 Diff::diff(&self.0, &other.0).map(Self)
39 }
40
41 pub fn migrate(self, other: &SyntaxTree) -> Option<Self> {
42 Migrate::migrate(self.0, &other.0).map(Self)
43 }
44}
45
46impl fmt::Display for SyntaxTree {
47 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
48 let mut iter = self.0.iter().peekable();
49 while let Some(s) = iter.next() {
50 let formatted = sqlformat::format(
51 format!("{s};").as_str(),
52 &sqlformat::QueryParams::None,
53 &sqlformat::FormatOptions::default(),
54 );
55 write!(f, "{formatted}")?;
56 if iter.peek().is_some() {
57 write!(f, "\n\n")?;
58 }
59 }
60 Ok(())
61 }
62}
63
#[cfg(test)]
mod tests {
    // Each test parses SQL with the default (generic) dialect through the
    // builder, runs `diff` or `migrate`, and compares the `Display` output of
    // the result with an expected string.
    //
    // NOTE(review): the expected strings are exact `sqlformat` output, so they
    // depend on `FormatOptions::default()` (indentation width, line breaks).
    // Confirm that the indentation inside these literals matches what the
    // formatter emits.
    use super::*;

    /// `diff` reports a table that exists only in the target as `CREATE TABLE`.
    #[test]
    fn diff_create_table() {
        let sql_a = "CREATE TABLE foo(\
            id int PRIMARY KEY
        )";
        // The trailing `\` continuations join both statements into one script.
        let sql_b = "CREATE TABLE foo(\
            id int PRIMARY KEY
        );\
        \
        CREATE TABLE bar (id INT PRIMARY KEY);";
        let sql_diff = "CREATE TABLE bar (id INT PRIMARY KEY);";

        let ast_a = SyntaxTree::builder().sql(sql_a).build().unwrap();
        let ast_b = SyntaxTree::builder().sql(sql_b).build().unwrap();
        let ast_diff = ast_a.diff(&ast_b);

        assert_eq!(ast_diff.unwrap().to_string(), sql_diff);
    }

    /// `diff` reports a table that exists only in the source as `DROP TABLE`.
    #[test]
    fn diff_drop_table() {
        let sql_a = "CREATE TABLE foo(\
            id int PRIMARY KEY
        );\
        \
        CREATE TABLE bar (id INT PRIMARY KEY);";
        let sql_b = "CREATE TABLE foo(\
            id int PRIMARY KEY
        )";
        let sql_diff = "DROP TABLE bar;";

        let ast_a = SyntaxTree::builder().sql(sql_a).build().unwrap();
        let ast_b = SyntaxTree::builder().sql(sql_b).build().unwrap();
        let ast_diff = ast_a.diff(&ast_b);

        assert_eq!(ast_diff.unwrap().to_string(), sql_diff);
    }

    /// `diff` reports a column added to an existing table as `ALTER TABLE .. ADD COLUMN`.
    #[test]
    fn diff_add_column() {
        let sql_a = "CREATE TABLE foo(\
            id int PRIMARY KEY
        )";
        let sql_b = "CREATE TABLE foo(\
            id int PRIMARY KEY,
            bar text
        )";
        // Multi-line layout comes from `sqlformat`, not from `Statement`'s own Display.
        let sql_diff = "ALTER TABLE\n foo\nADD\n COLUMN bar TEXT;";

        let ast_a = SyntaxTree::builder().sql(sql_a).build().unwrap();
        let ast_b = SyntaxTree::builder().sql(sql_b).build().unwrap();
        let ast_diff = ast_a.diff(&ast_b);

        assert_eq!(ast_diff.unwrap().to_string(), sql_diff);
    }

    /// `diff` reports a column removed from an existing table as `ALTER TABLE .. DROP COLUMN`.
    #[test]
    fn diff_drop_column() {
        let sql_a = "CREATE TABLE foo(\
            id int PRIMARY KEY,
            bar text
        )";
        let sql_b = "CREATE TABLE foo(\
            id int PRIMARY KEY
        )";
        let sql_diff = "ALTER TABLE\n foo DROP COLUMN bar;";

        let ast_a = SyntaxTree::builder().sql(sql_a).build().unwrap();
        let ast_b = SyntaxTree::builder().sql(sql_b).build().unwrap();
        let ast_diff = ast_a.diff(&ast_b);

        assert_eq!(ast_diff.unwrap().to_string(), sql_diff);
    }

    /// `migrate` appends a new `CREATE TABLE` after the existing statements.
    #[test]
    fn apply_create_table() {
        let sql_a = "CREATE TABLE bar (id INT PRIMARY KEY);";
        let sql_b = "CREATE TABLE foo (id INT PRIMARY KEY);";
        // `Display` separates consecutive statements with a blank line.
        let sql_res = sql_a.to_owned() + "\n\n" + sql_b;

        let ast_a = SyntaxTree::builder().sql(sql_a).build().unwrap();
        let ast_b = SyntaxTree::builder().sql(sql_b).build().unwrap();
        let ast_res = ast_a.migrate(&ast_b);

        assert_eq!(ast_res.unwrap().to_string(), sql_res);
    }

    /// `migrate` removes a dropped table and keeps a table created in the same migration.
    #[test]
    fn apply_drop_table() {
        let sql_a = "CREATE TABLE bar (id INT PRIMARY KEY)";
        let sql_b = "DROP TABLE bar; CREATE TABLE foo (id INT PRIMARY KEY)";
        let sql_res = "CREATE TABLE foo (id INT PRIMARY KEY);";

        let ast_a = SyntaxTree::builder().sql(sql_a).build().unwrap();
        let ast_b = SyntaxTree::builder().sql(sql_b).build().unwrap();
        let ast_res = ast_a.migrate(&ast_b);

        assert_eq!(ast_res.unwrap().to_string(), sql_res);
    }

    /// `migrate` folds `ALTER TABLE .. ADD COLUMN` into the table's `CREATE TABLE`.
    #[test]
    fn apply_alter_table_add_column() {
        let sql_a = "CREATE TABLE bar (id INT PRIMARY KEY)";
        let sql_b = "ALTER TABLE bar ADD COLUMN bar TEXT";
        let sql_res = "CREATE TABLE bar (id INT PRIMARY KEY, bar TEXT);";

        let ast_a = SyntaxTree::builder().sql(sql_a).build().unwrap();
        let ast_b = SyntaxTree::builder().sql(sql_b).build().unwrap();
        let ast_res = ast_a.migrate(&ast_b);

        assert_eq!(ast_res.unwrap().to_string(), sql_res);
    }

    /// `migrate` removes a dropped column from the table's `CREATE TABLE`.
    #[test]
    fn apply_alter_table_drop_column() {
        let sql_a = "CREATE TABLE bar (bar TEXT, id INT PRIMARY KEY)";
        let sql_b = "ALTER TABLE bar DROP COLUMN bar";
        let sql_res = "CREATE TABLE bar (id INT PRIMARY KEY);";

        let ast_a = SyntaxTree::builder().sql(sql_a).build().unwrap();
        let ast_b = SyntaxTree::builder().sql(sql_b).build().unwrap();
        let ast_res = ast_a.migrate(&ast_b);

        assert_eq!(ast_res.unwrap().to_string(), sql_res);
    }
}