1use std::{cmp::Ordering, collections::HashSet, fmt};
2
3use bon::bon;
4use sqlparser::ast::{
5 AlterTableOperation, AlterType, AlterTypeAddValue, AlterTypeAddValuePosition,
6 AlterTypeOperation, CreateDomain, CreateIndex, CreateTable, DropDomain, Ident, ObjectName,
7 ObjectType, Statement, UserDefinedTypeRepresentation,
8};
9use thiserror::Error;
10
11#[derive(Error, Debug)]
12pub struct DiffError {
13 kind: DiffErrorKind,
14 statement_a: Option<Box<Statement>>,
15 statement_b: Option<Box<Statement>>,
16}
17
18impl fmt::Display for DiffError {
19 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
20 write!(
21 f,
22 "Oops, we couldn't diff that: {reason}",
23 reason = self.kind
24 )?;
25 if let Some(statement_a) = &self.statement_a {
26 write!(f, "\n\nStatement A:\n{statement_a}")?;
27 }
28 if let Some(statement_b) = &self.statement_b {
29 write!(f, "\n\nStatement B:\n{statement_b}")?;
30 }
31 Ok(())
32 }
33}
34
35#[bon]
36impl DiffError {
37 #[builder]
38 fn new(
39 kind: DiffErrorKind,
40 statement_a: Option<Statement>,
41 statement_b: Option<Statement>,
42 ) -> Self {
43 Self {
44 kind,
45 statement_a: statement_a.map(Box::new),
46 statement_b: statement_b.map(Box::new),
47 }
48 }
49}
50
51#[derive(Error, Debug)]
52#[non_exhaustive]
53enum DiffErrorKind {
54 #[error("can't drop unnamed index")]
55 DropUnnamedIndex,
56 #[error("can't compare unnamed index")]
57 CompareUnnamedIndex,
58 #[error("removing enum labels is not supported")]
59 RemoveEnumLabel,
60 #[error("not yet supported")]
61 NotImplemented,
62}
63
64pub(crate) trait Diff: Sized {
65 type Diff;
66
67 fn diff(&self, other: &Self) -> Result<Self::Diff, DiffError>;
68}
69
70impl Diff for Vec<Statement> {
71 type Diff = Option<Vec<Statement>>;
72
73 fn diff(&self, other: &Self) -> Result<Self::Diff, DiffError> {
74 let res = self
75 .iter()
76 .filter_map(|sa| {
77 match sa {
78 Statement::CreateTable(a) => find_and_compare_create_table(sa, a, other),
81 Statement::CreateIndex(a) => find_and_compare_create_index(sa, a, other),
82 Statement::CreateType { name, .. } => {
83 find_and_compare_create_type(sa, name, other)
84 }
85 Statement::CreateExtension {
86 name,
87 if_not_exists,
88 cascade,
89 ..
90 } => {
91 find_and_compare_create_extension(sa, name, *if_not_exists, *cascade, other)
92 }
93 Statement::CreateDomain(a) => find_and_compare_create_domain(sa, a, other),
94 _ => Err(DiffError::builder()
95 .kind(DiffErrorKind::NotImplemented)
96 .statement_a(sa.clone())
97 .build()),
98 }
99 .transpose()
100 })
101 .chain(other.iter().filter_map(|sb| {
103 match sb {
104 Statement::CreateTable(b) => Ok(self.iter().find(|sa| match sa {
105 Statement::CreateTable(a) => a.name == b.name,
106 _ => false,
107 })),
108 Statement::CreateIndex(b) => Ok(self.iter().find(|sa| match sa {
109 Statement::CreateIndex(a) => a.name == b.name,
110 _ => false,
111 })),
112 Statement::CreateType { name: b_name, .. } => {
113 Ok(self.iter().find(|sa| match sa {
114 Statement::CreateType { name: a_name, .. } => a_name == b_name,
115 _ => false,
116 }))
117 }
118 Statement::CreateExtension { name: b_name, .. } => {
119 Ok(self.iter().find(|sa| match sa {
120 Statement::CreateExtension { name: a_name, .. } => a_name == b_name,
121 _ => false,
122 }))
123 }
124 Statement::CreateDomain(b) => Ok(self.iter().find(|sa| match sa {
125 Statement::CreateDomain(a) => a.name == b.name,
126 _ => false,
127 })),
128 _ => Err(DiffError::builder()
129 .kind(DiffErrorKind::NotImplemented)
130 .statement_a(sb.clone())
131 .build()),
132 }
133 .transpose()
134 .map_or_else(|| Some(Ok(vec![sb.clone()])), |_| None)
136 }))
137 .collect::<Result<Vec<_>, _>>()?
138 .into_iter()
139 .flatten()
140 .collect::<Vec<_>>();
141
142 if res.is_empty() {
143 Ok(None)
144 } else {
145 Ok(Some(res))
146 }
147 }
148}
149
150fn find_and_compare<MF, DF>(
151 sa: &Statement,
152 other: &[Statement],
153 match_fn: MF,
154 drop_fn: DF,
155) -> Result<Option<Vec<Statement>>, DiffError>
156where
157 MF: Fn(&&Statement) -> bool,
158 DF: Fn() -> Result<Option<Vec<Statement>>, DiffError>,
159{
160 other.iter().find(match_fn).map_or_else(
161 drop_fn,
163 |sb| sa.diff(sb),
165 )
166}
167
168fn find_and_compare_create_table(
169 sa: &Statement,
170 a: &CreateTable,
171 other: &[Statement],
172) -> Result<Option<Vec<Statement>>, DiffError> {
173 find_and_compare(
174 sa,
175 other,
176 |sb| match sb {
177 Statement::CreateTable(b) => a.name == b.name,
178 _ => false,
179 },
180 || {
181 Ok(Some(vec![Statement::Drop {
182 object_type: sqlparser::ast::ObjectType::Table,
183 if_exists: a.if_not_exists,
184 names: vec![a.name.clone()],
185 cascade: false,
186 restrict: false,
187 purge: false,
188 temporary: false,
189 table: None,
190 }]))
191 },
192 )
193}
194
195fn find_and_compare_create_index(
196 sa: &Statement,
197 a: &CreateIndex,
198 other: &[Statement],
199) -> Result<Option<Vec<Statement>>, DiffError> {
200 find_and_compare(
201 sa,
202 other,
203 |sb| match sb {
204 Statement::CreateIndex(b) => a.name == b.name,
205 _ => false,
206 },
207 || {
208 let name = a.name.clone().ok_or_else(|| {
209 DiffError::builder()
210 .kind(DiffErrorKind::DropUnnamedIndex)
211 .statement_a(sa.clone())
212 .build()
213 })?;
214
215 Ok(Some(vec![Statement::Drop {
216 object_type: sqlparser::ast::ObjectType::Index,
217 if_exists: a.if_not_exists,
218 names: vec![name],
219 cascade: false,
220 restrict: false,
221 purge: false,
222 temporary: false,
223 table: None,
224 }]))
225 },
226 )
227}
228
229fn find_and_compare_create_type(
230 sa: &Statement,
231 a_name: &ObjectName,
232 other: &[Statement],
233) -> Result<Option<Vec<Statement>>, DiffError> {
234 find_and_compare(
235 sa,
236 other,
237 |sb| match sb {
238 Statement::CreateType { name: b_name, .. } => a_name == b_name,
239 _ => false,
240 },
241 || {
242 Ok(Some(vec![Statement::Drop {
243 object_type: sqlparser::ast::ObjectType::Type,
244 if_exists: false,
245 names: vec![a_name.clone()],
246 cascade: false,
247 restrict: false,
248 purge: false,
249 temporary: false,
250 table: None,
251 }]))
252 },
253 )
254}
255
256fn find_and_compare_create_extension(
257 sa: &Statement,
258 a_name: &Ident,
259 if_not_exists: bool,
260 cascade: bool,
261 other: &[Statement],
262) -> Result<Option<Vec<Statement>>, DiffError> {
263 find_and_compare(
264 sa,
265 other,
266 |sb| match sb {
267 Statement::CreateExtension { name: b_name, .. } => a_name == b_name,
268 _ => false,
269 },
270 || {
271 Ok(Some(vec![Statement::DropExtension {
272 names: vec![a_name.clone()],
273 if_exists: if_not_exists,
274 cascade_or_restrict: if cascade {
275 Some(sqlparser::ast::ReferentialAction::Cascade)
276 } else {
277 None
278 },
279 }]))
280 },
281 )
282}
283
284fn find_and_compare_create_domain(
285 orig: &Statement,
286 domain: &CreateDomain,
287 other: &[Statement],
288) -> Result<Option<Vec<Statement>>, DiffError> {
289 let res = other
290 .iter()
291 .find(|sb| match sb {
292 Statement::CreateDomain(b) => b.name == domain.name,
293 _ => false,
294 })
295 .map(|sb| orig.diff(sb))
296 .transpose()?
297 .flatten();
298 Ok(res)
299}
300
301impl Diff for Statement {
302 type Diff = Option<Vec<Statement>>;
303
304 fn diff(&self, other: &Self) -> Result<Self::Diff, DiffError> {
305 match self {
306 Self::CreateTable(a) => match other {
307 Self::CreateTable(b) => Ok(compare_create_table(a, b)),
308 _ => Ok(None),
309 },
310 Self::CreateIndex(a) => match other {
311 Self::CreateIndex(b) => compare_create_index(a, b),
312 _ => Ok(None),
313 },
314 Self::CreateType {
315 name: a_name,
316 representation: a_rep,
317 } => match other {
318 Self::CreateType {
319 name: b_name,
320 representation: b_rep,
321 } => compare_create_type(self, a_name, a_rep, other, b_name, b_rep),
322 _ => Ok(None),
323 },
324 Self::CreateDomain(a) => match other {
325 Self::CreateDomain(b) => Ok(compare_create_domain(a, b)),
326 _ => Ok(None),
327 },
328 _ => Err(DiffError::builder()
329 .kind(DiffErrorKind::NotImplemented)
330 .statement_a(self.clone())
331 .statement_b(other.clone())
332 .build()),
333 }
334 }
335}
336
337fn compare_create_table(a: &CreateTable, b: &CreateTable) -> Option<Vec<Statement>> {
338 if a == b {
339 return None;
340 }
341
342 let a_column_names: HashSet<_> = a.columns.iter().map(|c| c.name.clone()).collect();
343 let b_column_names: HashSet<_> = b.columns.iter().map(|c| c.name.clone()).collect();
344
345 let ops = a
346 .columns
347 .iter()
348 .filter_map(|ac| {
349 if b_column_names.contains(&ac.name) {
350 None
351 } else {
352 Some(AlterTableOperation::DropColumn {
354 column_name: ac.name.clone(),
355 if_exists: a.if_not_exists,
356 drop_behavior: None,
357 has_column_keyword: true,
358 })
359 }
360 })
361 .chain(b.columns.iter().filter_map(|bc| {
362 if a_column_names.contains(&bc.name) {
363 None
364 } else {
365 Some(AlterTableOperation::AddColumn {
367 column_keyword: true,
368 if_not_exists: a.if_not_exists,
369 column_def: bc.clone(),
370 column_position: None,
371 })
372 }
373 }))
374 .collect();
375
376 Some(vec![Statement::AlterTable {
377 name: a.name.clone(),
378 if_exists: a.if_not_exists,
379 only: false,
380 operations: ops,
381 location: None,
382 on_cluster: a.on_cluster.clone(),
383 iceberg: false,
384 }])
385}
386
387fn compare_create_index(
388 a: &CreateIndex,
389 b: &CreateIndex,
390) -> Result<Option<Vec<Statement>>, DiffError> {
391 if a == b {
392 return Ok(None);
393 }
394
395 if a.name.is_none() || b.name.is_none() {
396 return Err(DiffError::builder()
397 .kind(DiffErrorKind::CompareUnnamedIndex)
398 .statement_a(Statement::CreateIndex(a.clone()))
399 .statement_b(Statement::CreateIndex(b.clone()))
400 .build());
401 }
402 let name = a.name.clone().unwrap();
403
404 Ok(Some(vec![
405 Statement::Drop {
406 object_type: ObjectType::Index,
407 if_exists: a.if_not_exists,
408 names: vec![name],
409 cascade: false,
410 restrict: false,
411 purge: false,
412 temporary: false,
413 table: None,
414 },
415 Statement::CreateIndex(b.clone()),
416 ]))
417}
418
419fn compare_create_type(
420 a: &Statement,
421 a_name: &ObjectName,
422 a_rep: &UserDefinedTypeRepresentation,
423 b: &Statement,
424 b_name: &ObjectName,
425 b_rep: &UserDefinedTypeRepresentation,
426) -> Result<Option<Vec<Statement>>, DiffError> {
427 if a_name == b_name && a_rep == b_rep {
428 return Ok(None);
429 }
430
431 let operations = match a_rep {
432 UserDefinedTypeRepresentation::Enum { labels: a_labels } => match b_rep {
433 UserDefinedTypeRepresentation::Enum { labels: b_labels } => {
434 match a_labels.len().cmp(&b_labels.len()) {
435 Ordering::Equal => {
436 let rename_labels: Vec<_> = a_labels
437 .iter()
438 .zip(b_labels.iter())
439 .filter_map(|(a, b)| {
440 if a == b {
441 None
442 } else {
443 Some(AlterTypeOperation::RenameValue(
444 sqlparser::ast::AlterTypeRenameValue {
445 from: a.clone(),
446 to: b.clone(),
447 },
448 ))
449 }
450 })
451 .collect();
452 rename_labels
453 }
454 Ordering::Less => {
455 let mut a_labels_iter = a_labels.iter().peekable();
456 let mut operations = Vec::new();
457 let mut prev = None;
458 for b in b_labels {
459 match a_labels_iter.peek() {
460 Some(a) => {
461 let a = *a;
462 if a == b {
463 prev = Some(a);
464 a_labels_iter.next();
465 continue;
466 }
467
468 let position = match prev {
469 Some(a) => AlterTypeAddValuePosition::After(a.clone()),
470 None => AlterTypeAddValuePosition::Before(a.clone()),
471 };
472
473 prev = Some(b);
474 operations.push(AlterTypeOperation::AddValue(
475 AlterTypeAddValue {
476 if_not_exists: false,
477 value: b.clone(),
478 position: Some(position),
479 },
480 ));
481 }
482 None => {
483 if a_labels.contains(b) {
484 continue;
485 }
486 operations.push(AlterTypeOperation::AddValue(
488 AlterTypeAddValue {
489 if_not_exists: false,
490 value: b.clone(),
491 position: None,
492 },
493 ));
494 }
495 }
496 }
497 operations
498 }
499 _ => {
500 return Err(DiffError::builder()
501 .kind(DiffErrorKind::RemoveEnumLabel)
502 .statement_a(a.clone())
503 .statement_b(b.clone())
504 .build());
505 }
506 }
507 }
508 _ => {
509 return Err(DiffError::builder()
511 .kind(DiffErrorKind::NotImplemented)
512 .statement_a(a.clone())
513 .statement_b(b.clone())
514 .build());
515 }
516 },
517 _ => {
518 return Err(DiffError::builder()
520 .kind(DiffErrorKind::NotImplemented)
521 .statement_a(a.clone())
522 .statement_b(b.clone())
523 .build());
524 }
525 };
526
527 if operations.is_empty() {
528 return Ok(None);
529 }
530
531 Ok(Some(
532 operations
533 .into_iter()
534 .map(|operation| {
535 Statement::AlterType(AlterType {
536 name: a_name.clone(),
537 operation,
538 })
539 })
540 .collect(),
541 ))
542}
543
544fn compare_create_domain(a: &CreateDomain, b: &CreateDomain) -> Option<Vec<Statement>> {
545 if a == b {
546 return None;
547 }
548
549 Some(vec![
550 Statement::DropDomain(DropDomain {
551 if_exists: true,
552 name: a.name.clone(),
553 drop_behavior: None,
554 }),
555 Statement::CreateDomain(b.clone()),
556 ])
557}