1use std::{cmp::Ordering, collections::HashSet, fmt};
2
3use bon::bon;
4use sqlparser::ast::{
5 AlterTableOperation, AlterType, AlterTypeAddValue, AlterTypeAddValuePosition,
6 AlterTypeOperation, CreateIndex, CreateTable, Ident, ObjectName, ObjectType, Statement,
7 UserDefinedTypeRepresentation,
8};
9use thiserror::Error;
10
/// Error returned when two sets of SQL statements cannot be diffed.
///
/// Holds the reason (`kind`) plus, when available, the statement(s) involved in the
/// failure. `Display` is implemented by hand (rather than via `#[error(...)]`) so those
/// statements can be appended to the message.
#[derive(Error, Debug)]
pub struct DiffError {
    /// Why the diff failed.
    kind: DiffErrorKind,
    /// First statement involved in the failure, printed as "Statement A".
    statement_a: Option<Statement>,
    /// Second statement involved in the failure, printed as "Statement B".
    statement_b: Option<Statement>,
}
17
18impl fmt::Display for DiffError {
19 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
20 write!(
21 f,
22 "Oops, we couldn't diff that: {reason}",
23 reason = self.kind
24 )?;
25 if let Some(statement_a) = &self.statement_a {
26 write!(f, "\n\nStatement A:\n{statement_a}")?;
27 }
28 if let Some(statement_b) = &self.statement_b {
29 write!(f, "\n\nStatement B:\n{statement_b}")?;
30 }
31 Ok(())
32 }
33}
34
#[bon]
impl DiffError {
    /// Creates a `DiffError`; callers go through the `bon`-generated `DiffError::builder()`.
    ///
    /// `kind` is required. `statement_a` and `statement_b` are `Option`s, and call sites in
    /// this file omit them on the builder when there is no statement to report.
    #[builder]
    fn new(
        kind: DiffErrorKind,
        statement_a: Option<Statement>,
        statement_b: Option<Statement>,
    ) -> Self {
        Self {
            kind,
            statement_a,
            statement_b,
        }
    }
}
50
51#[derive(Error, Debug)]
52#[non_exhaustive]
53enum DiffErrorKind {
54 #[error("can't drop unnamed index")]
55 DropUnnamedIndex,
56 #[error("can't compare unnamed index")]
57 CompareUnnamedIndex,
58 #[error("removing enum labels is not supported")]
59 RemoveEnumLabel,
60 #[error("not yet supported")]
61 NotImplemented,
62}
63
64pub(crate) trait Diff: Sized {
65 type Diff;
66
67 fn diff(&self, other: &Self) -> Result<Self::Diff, DiffError>;
68}
69
70impl Diff for Vec<Statement> {
71 type Diff = Option<Vec<Statement>>;
72
73 fn diff(&self, other: &Self) -> Result<Self::Diff, DiffError> {
74 let res = self
75 .iter()
76 .filter_map(|sa| {
77 match sa {
78 Statement::CreateTable(a) => find_and_compare_create_table(sa, a, other),
81 Statement::CreateIndex(a) => find_and_compare_create_index(sa, a, other),
82 Statement::CreateType { name, .. } => {
83 find_and_compare_create_type(sa, name, other)
84 }
85 Statement::CreateExtension {
86 name,
87 if_not_exists,
88 cascade,
89 ..
90 } => {
91 find_and_compare_create_extension(sa, name, *if_not_exists, *cascade, other)
92 }
93 _ => Err(DiffError::builder()
94 .kind(DiffErrorKind::NotImplemented)
95 .statement_a(sa.clone())
96 .build()),
97 }
98 .transpose()
99 })
100 .chain(other.iter().filter_map(|sb| {
102 match sb {
103 Statement::CreateTable(b) => Ok(self.iter().find(|sa| match sa {
104 Statement::CreateTable(a) => a.name == b.name,
105 _ => false,
106 })),
107 Statement::CreateIndex(b) => Ok(self.iter().find(|sa| match sa {
108 Statement::CreateIndex(a) => a.name == b.name,
109 _ => false,
110 })),
111 Statement::CreateType { name: b_name, .. } => {
112 Ok(self.iter().find(|sa| match sa {
113 Statement::CreateType { name: a_name, .. } => a_name == b_name,
114 _ => false,
115 }))
116 }
117 Statement::CreateExtension { name: b_name, .. } => {
118 Ok(self.iter().find(|sa| match sa {
119 Statement::CreateExtension { name: a_name, .. } => a_name == b_name,
120 _ => false,
121 }))
122 }
123 _ => Err(DiffError::builder()
124 .kind(DiffErrorKind::NotImplemented)
125 .statement_a(sb.clone())
126 .build()),
127 }
128 .transpose()
129 .map_or_else(|| Some(Ok(vec![sb.clone()])), |_| None)
131 }))
132 .collect::<Result<Vec<_>, _>>()?
133 .into_iter()
134 .flatten()
135 .collect::<Vec<_>>();
136
137 if res.is_empty() {
138 Ok(None)
139 } else {
140 Ok(Some(res))
141 }
142 }
143}
144
145fn find_and_compare<MF, DF>(
146 sa: &Statement,
147 other: &[Statement],
148 match_fn: MF,
149 drop_fn: DF,
150) -> Result<Option<Vec<Statement>>, DiffError>
151where
152 MF: Fn(&&Statement) -> bool,
153 DF: Fn() -> Result<Option<Vec<Statement>>, DiffError>,
154{
155 other.iter().find(match_fn).map_or_else(
156 drop_fn,
158 |sb| sa.diff(sb),
160 )
161}
162
163fn find_and_compare_create_table(
164 sa: &Statement,
165 a: &CreateTable,
166 other: &[Statement],
167) -> Result<Option<Vec<Statement>>, DiffError> {
168 find_and_compare(
169 sa,
170 other,
171 |sb| match sb {
172 Statement::CreateTable(b) => a.name == b.name,
173 _ => false,
174 },
175 || {
176 Ok(Some(vec![Statement::Drop {
177 object_type: sqlparser::ast::ObjectType::Table,
178 if_exists: a.if_not_exists,
179 names: vec![a.name.clone()],
180 cascade: false,
181 restrict: false,
182 purge: false,
183 temporary: false,
184 }]))
185 },
186 )
187}
188
189fn find_and_compare_create_index(
190 sa: &Statement,
191 a: &CreateIndex,
192 other: &[Statement],
193) -> Result<Option<Vec<Statement>>, DiffError> {
194 find_and_compare(
195 sa,
196 other,
197 |sb| match sb {
198 Statement::CreateIndex(b) => a.name == b.name,
199 _ => false,
200 },
201 || {
202 let name = a.name.clone().ok_or_else(|| {
203 DiffError::builder()
204 .kind(DiffErrorKind::DropUnnamedIndex)
205 .statement_a(sa.clone())
206 .build()
207 })?;
208
209 Ok(Some(vec![Statement::Drop {
210 object_type: sqlparser::ast::ObjectType::Index,
211 if_exists: a.if_not_exists,
212 names: vec![name],
213 cascade: false,
214 restrict: false,
215 purge: false,
216 temporary: false,
217 }]))
218 },
219 )
220}
221
222fn find_and_compare_create_type(
223 sa: &Statement,
224 a_name: &ObjectName,
225 other: &[Statement],
226) -> Result<Option<Vec<Statement>>, DiffError> {
227 find_and_compare(
228 sa,
229 other,
230 |sb| match sb {
231 Statement::CreateType { name: b_name, .. } => a_name == b_name,
232 _ => false,
233 },
234 || {
235 Ok(Some(vec![Statement::Drop {
236 object_type: sqlparser::ast::ObjectType::Type,
237 if_exists: false,
238 names: vec![a_name.clone()],
239 cascade: false,
240 restrict: false,
241 purge: false,
242 temporary: false,
243 }]))
244 },
245 )
246}
247
248fn find_and_compare_create_extension(
249 sa: &Statement,
250 a_name: &Ident,
251 if_not_exists: bool,
252 cascade: bool,
253 other: &[Statement],
254) -> Result<Option<Vec<Statement>>, DiffError> {
255 find_and_compare(
256 sa,
257 other,
258 |sb| match sb {
259 Statement::CreateExtension { name: b_name, .. } => a_name == b_name,
260 _ => false,
261 },
262 || {
263 Ok(Some(vec![Statement::DropExtension {
264 names: vec![a_name.clone()],
265 if_exists: if_not_exists,
266 cascade_or_restrict: if cascade {
267 Some(sqlparser::ast::ReferentialAction::Cascade)
268 } else {
269 None
270 },
271 }]))
272 },
273 )
274}
275
276impl Diff for Statement {
277 type Diff = Option<Vec<Statement>>;
278
279 fn diff(&self, other: &Self) -> Result<Self::Diff, DiffError> {
280 match self {
281 Self::CreateTable(a) => match other {
282 Self::CreateTable(b) => Ok(compare_create_table(a, b)),
283 _ => Ok(None),
284 },
285 Self::CreateIndex(a) => match other {
286 Self::CreateIndex(b) => compare_create_index(a, b),
287 _ => Ok(None),
288 },
289 Self::CreateType {
290 name: a_name,
291 representation: a_rep,
292 } => match other {
293 Self::CreateType {
294 name: b_name,
295 representation: b_rep,
296 } => compare_create_type(self, a_name, a_rep, other, b_name, b_rep),
297 _ => Ok(None),
298 },
299 _ => Err(DiffError::builder()
300 .kind(DiffErrorKind::NotImplemented)
301 .statement_a(self.clone())
302 .statement_b(other.clone())
303 .build()),
304 }
305 }
306}
307
308fn compare_create_table(a: &CreateTable, b: &CreateTable) -> Option<Vec<Statement>> {
309 if a == b {
310 return None;
311 }
312
313 let a_column_names: HashSet<_> = a.columns.iter().map(|c| c.name.clone()).collect();
314 let b_column_names: HashSet<_> = b.columns.iter().map(|c| c.name.clone()).collect();
315
316 let ops = a
317 .columns
318 .iter()
319 .filter_map(|ac| {
320 if b_column_names.contains(&ac.name) {
321 None
322 } else {
323 Some(AlterTableOperation::DropColumn {
325 column_name: ac.name.clone(),
326 if_exists: a.if_not_exists,
327 drop_behavior: None,
328 })
329 }
330 })
331 .chain(b.columns.iter().filter_map(|bc| {
332 if a_column_names.contains(&bc.name) {
333 None
334 } else {
335 Some(AlterTableOperation::AddColumn {
337 column_keyword: true,
338 if_not_exists: a.if_not_exists,
339 column_def: bc.clone(),
340 column_position: None,
341 })
342 }
343 }))
344 .collect();
345
346 Some(vec![Statement::AlterTable {
347 name: a.name.clone(),
348 if_exists: a.if_not_exists,
349 only: false,
350 operations: ops,
351 location: None,
352 on_cluster: a.on_cluster.clone(),
353 }])
354}
355
356fn compare_create_index(
357 a: &CreateIndex,
358 b: &CreateIndex,
359) -> Result<Option<Vec<Statement>>, DiffError> {
360 if a == b {
361 return Ok(None);
362 }
363
364 if a.name.is_none() || b.name.is_none() {
365 return Err(DiffError::builder()
366 .kind(DiffErrorKind::CompareUnnamedIndex)
367 .statement_a(Statement::CreateIndex(a.clone()))
368 .statement_b(Statement::CreateIndex(b.clone()))
369 .build());
370 }
371 let name = a.name.clone().unwrap();
372
373 Ok(Some(vec![
374 Statement::Drop {
375 object_type: ObjectType::Index,
376 if_exists: a.if_not_exists,
377 names: vec![name],
378 cascade: false,
379 restrict: false,
380 purge: false,
381 temporary: false,
382 },
383 Statement::CreateIndex(b.clone()),
384 ]))
385}
386
387fn compare_create_type(
388 a: &Statement,
389 a_name: &ObjectName,
390 a_rep: &UserDefinedTypeRepresentation,
391 b: &Statement,
392 b_name: &ObjectName,
393 b_rep: &UserDefinedTypeRepresentation,
394) -> Result<Option<Vec<Statement>>, DiffError> {
395 if a_name == b_name && a_rep == b_rep {
396 return Ok(None);
397 }
398
399 let operations = match a_rep {
400 UserDefinedTypeRepresentation::Enum { labels: a_labels } => match b_rep {
401 UserDefinedTypeRepresentation::Enum { labels: b_labels } => {
402 match a_labels.len().cmp(&b_labels.len()) {
403 Ordering::Equal => {
404 let rename_labels: Vec<_> = a_labels
405 .iter()
406 .zip(b_labels.iter())
407 .filter_map(|(a, b)| {
408 if a == b {
409 None
410 } else {
411 Some(AlterTypeOperation::RenameValue(
412 sqlparser::ast::AlterTypeRenameValue {
413 from: a.clone(),
414 to: b.clone(),
415 },
416 ))
417 }
418 })
419 .collect();
420 rename_labels
421 }
422 Ordering::Less => {
423 let mut a_labels_iter = a_labels.iter().peekable();
424 let mut operations = Vec::new();
425 let mut prev = None;
426 for b in b_labels {
427 match a_labels_iter.peek() {
428 Some(a) => {
429 let a = *a;
430 if a == b {
431 prev = Some(a);
432 a_labels_iter.next();
433 continue;
434 }
435
436 let position = match prev {
437 Some(a) => AlterTypeAddValuePosition::After(a.clone()),
438 None => AlterTypeAddValuePosition::Before(a.clone()),
439 };
440
441 prev = Some(b);
442 operations.push(AlterTypeOperation::AddValue(
443 AlterTypeAddValue {
444 if_not_exists: false,
445 value: b.clone(),
446 position: Some(position),
447 },
448 ));
449 }
450 None => {
451 if a_labels.contains(b) {
452 continue;
453 }
454 operations.push(AlterTypeOperation::AddValue(
456 AlterTypeAddValue {
457 if_not_exists: false,
458 value: b.clone(),
459 position: None,
460 },
461 ));
462 }
463 }
464 }
465 operations
466 }
467 _ => {
468 return Err(DiffError::builder()
469 .kind(DiffErrorKind::RemoveEnumLabel)
470 .statement_a(a.clone())
471 .statement_b(b.clone())
472 .build());
473 }
474 }
475 }
476 _ => {
477 return Err(DiffError::builder()
479 .kind(DiffErrorKind::NotImplemented)
480 .statement_a(a.clone())
481 .statement_b(b.clone())
482 .build());
483 }
484 },
485 _ => {
486 return Err(DiffError::builder()
488 .kind(DiffErrorKind::NotImplemented)
489 .statement_a(a.clone())
490 .statement_b(b.clone())
491 .build());
492 }
493 };
494
495 if operations.is_empty() {
496 return Ok(None);
497 }
498
499 Ok(Some(
500 operations
501 .into_iter()
502 .map(|operation| {
503 Statement::AlterType(AlterType {
504 name: a_name.clone(),
505 operation,
506 })
507 })
508 .collect(),
509 ))
510}