1use std::{cmp::Ordering, collections::HashSet, fmt};
2
3use bon::bon;
4use sqlparser::ast::{
5 helpers::attached_token::AttachedToken, AlterTable, AlterTableOperation, AlterType,
6 AlterTypeAddValue, AlterTypeAddValuePosition, AlterTypeOperation, CreateDomain,
7 CreateExtension, CreateIndex, CreateTable, DropDomain, DropExtension, Ident, ObjectName,
8 ObjectType, Statement, UserDefinedTypeRepresentation,
9};
10use thiserror::Error;
11
/// Error returned when two schemas (or two individual statements) cannot be diffed.
///
/// Carries the reason for the failure plus, when available, up to two statements that
/// were involved; all of them are rendered by the `Display` implementation.
#[derive(Error, Debug)]
pub struct DiffError {
    /// Why the diff failed.
    kind: DiffErrorKind,
    /// First statement involved in the failure, if any. Boxed — presumably to keep
    /// `DiffError` (and every `Result` carrying it) small; confirm.
    statement_a: Option<Box<Statement>>,
    /// Second statement involved in the failure (typically the one `statement_a` was
    /// being compared against), if any.
    statement_b: Option<Box<Statement>>,
}
18
19impl fmt::Display for DiffError {
20 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21 write!(
22 f,
23 "Oops, we couldn't diff that: {reason}",
24 reason = self.kind
25 )?;
26 if let Some(statement_a) = &self.statement_a {
27 write!(f, "\n\nStatement A:\n{statement_a}")?;
28 }
29 if let Some(statement_b) = &self.statement_b {
30 write!(f, "\n\nStatement B:\n{statement_b}")?;
31 }
32 Ok(())
33 }
34}
35
36#[bon]
37impl DiffError {
38 #[builder]
39 fn new(
40 kind: DiffErrorKind,
41 statement_a: Option<Statement>,
42 statement_b: Option<Statement>,
43 ) -> Self {
44 Self {
45 kind,
46 statement_a: statement_a.map(Box::new),
47 statement_b: statement_b.map(Box::new),
48 }
49 }
50}
51
/// The reason a [`DiffError`] was raised; its `#[error]` text is the user-facing message.
#[derive(Error, Debug)]
#[non_exhaustive]
enum DiffErrorKind {
    /// An index that has to be dropped has no name, so no `DROP INDEX` can be emitted.
    #[error("can't drop unnamed index")]
    DropUnnamedIndex,
    /// Two differing index definitions were matched, but at least one of them is unnamed.
    #[error("can't compare unnamed index")]
    CompareUnnamedIndex,
    /// An enum type's labels would have to be removed (e.g. the new definition has
    /// fewer labels than the old one).
    #[error("removing enum labels is not supported")]
    RemoveEnumLabel,
    /// The statement kind, or the particular change, is not handled by the differ yet.
    #[error("not yet supported")]
    NotImplemented,
}
64
/// Computes what is needed to turn `self` (the current state) into `other` (the
/// desired state).
pub(crate) trait Diff: Sized {
    /// The migration produced by [`Diff::diff`].
    type Diff;

    /// Compares `self` against `other`.
    ///
    /// # Errors
    ///
    /// Returns a [`DiffError`] when the two values cannot be diffed.
    fn diff(&self, other: &Self) -> Result<Self::Diff, DiffError>;
}
70
71impl Diff for Vec<Statement> {
72 type Diff = Option<Vec<Statement>>;
73
74 fn diff(&self, other: &Self) -> Result<Self::Diff, DiffError> {
75 let res = self
76 .iter()
77 .filter_map(|sa| {
78 match sa {
79 Statement::CreateTable(a) => find_and_compare_create_table(sa, a, other),
82 Statement::CreateIndex(a) => find_and_compare_create_index(sa, a, other),
83 Statement::CreateType { name, .. } => {
84 find_and_compare_create_type(sa, name, other)
85 }
86 Statement::CreateExtension(CreateExtension {
87 name,
88 if_not_exists,
89 cascade,
90 ..
91 }) => {
92 find_and_compare_create_extension(sa, name, *if_not_exists, *cascade, other)
93 }
94 Statement::CreateDomain(a) => find_and_compare_create_domain(sa, a, other),
95 _ => Err(DiffError::builder()
96 .kind(DiffErrorKind::NotImplemented)
97 .statement_a(sa.clone())
98 .build()),
99 }
100 .transpose()
101 })
102 .chain(other.iter().filter_map(|sb| {
104 match sb {
105 Statement::CreateTable(b) => Ok(self.iter().find(|sa| match sa {
106 Statement::CreateTable(a) => a.name == b.name,
107 _ => false,
108 })),
109 Statement::CreateIndex(b) => Ok(self.iter().find(|sa| match sa {
110 Statement::CreateIndex(a) => a.name == b.name,
111 _ => false,
112 })),
113 Statement::CreateType { name: b_name, .. } => {
114 Ok(self.iter().find(|sa| match sa {
115 Statement::CreateType { name: a_name, .. } => a_name == b_name,
116 _ => false,
117 }))
118 }
119 Statement::CreateExtension(CreateExtension { name: b_name, .. }) => {
120 Ok(self.iter().find(|sa| match sa {
121 Statement::CreateExtension(CreateExtension {
122 name: a_name, ..
123 }) => a_name == b_name,
124 _ => false,
125 }))
126 }
127 Statement::CreateDomain(b) => Ok(self.iter().find(|sa| match sa {
128 Statement::CreateDomain(a) => a.name == b.name,
129 _ => false,
130 })),
131 _ => Err(DiffError::builder()
132 .kind(DiffErrorKind::NotImplemented)
133 .statement_a(sb.clone())
134 .build()),
135 }
136 .transpose()
137 .map_or_else(|| Some(Ok(vec![sb.clone()])), |_| None)
139 }))
140 .collect::<Result<Vec<_>, _>>()?
141 .into_iter()
142 .flatten()
143 .collect::<Vec<_>>();
144
145 if res.is_empty() {
146 Ok(None)
147 } else {
148 Ok(Some(res))
149 }
150 }
151}
152
153fn find_and_compare<MF, DF>(
154 sa: &Statement,
155 other: &[Statement],
156 match_fn: MF,
157 drop_fn: DF,
158) -> Result<Option<Vec<Statement>>, DiffError>
159where
160 MF: Fn(&&Statement) -> bool,
161 DF: Fn() -> Result<Option<Vec<Statement>>, DiffError>,
162{
163 other.iter().find(match_fn).map_or_else(
164 drop_fn,
166 |sb| sa.diff(sb),
168 )
169}
170
171fn find_and_compare_create_table(
172 sa: &Statement,
173 a: &CreateTable,
174 other: &[Statement],
175) -> Result<Option<Vec<Statement>>, DiffError> {
176 find_and_compare(
177 sa,
178 other,
179 |sb| match sb {
180 Statement::CreateTable(b) => a.name == b.name,
181 _ => false,
182 },
183 || {
184 Ok(Some(vec![Statement::Drop {
185 object_type: sqlparser::ast::ObjectType::Table,
186 if_exists: a.if_not_exists,
187 names: vec![a.name.clone()],
188 cascade: false,
189 restrict: false,
190 purge: false,
191 temporary: false,
192 table: None,
193 }]))
194 },
195 )
196}
197
198fn find_and_compare_create_index(
199 sa: &Statement,
200 a: &CreateIndex,
201 other: &[Statement],
202) -> Result<Option<Vec<Statement>>, DiffError> {
203 find_and_compare(
204 sa,
205 other,
206 |sb| match sb {
207 Statement::CreateIndex(b) => a.name == b.name,
208 _ => false,
209 },
210 || {
211 let name = a.name.clone().ok_or_else(|| {
212 DiffError::builder()
213 .kind(DiffErrorKind::DropUnnamedIndex)
214 .statement_a(sa.clone())
215 .build()
216 })?;
217
218 Ok(Some(vec![Statement::Drop {
219 object_type: sqlparser::ast::ObjectType::Index,
220 if_exists: a.if_not_exists,
221 names: vec![name],
222 cascade: false,
223 restrict: false,
224 purge: false,
225 temporary: false,
226 table: None,
227 }]))
228 },
229 )
230}
231
232fn find_and_compare_create_type(
233 sa: &Statement,
234 a_name: &ObjectName,
235 other: &[Statement],
236) -> Result<Option<Vec<Statement>>, DiffError> {
237 find_and_compare(
238 sa,
239 other,
240 |sb| match sb {
241 Statement::CreateType { name: b_name, .. } => a_name == b_name,
242 _ => false,
243 },
244 || {
245 Ok(Some(vec![Statement::Drop {
246 object_type: sqlparser::ast::ObjectType::Type,
247 if_exists: false,
248 names: vec![a_name.clone()],
249 cascade: false,
250 restrict: false,
251 purge: false,
252 temporary: false,
253 table: None,
254 }]))
255 },
256 )
257}
258
259fn find_and_compare_create_extension(
260 sa: &Statement,
261 a_name: &Ident,
262 if_not_exists: bool,
263 cascade: bool,
264 other: &[Statement],
265) -> Result<Option<Vec<Statement>>, DiffError> {
266 find_and_compare(
267 sa,
268 other,
269 |sb| match sb {
270 Statement::CreateExtension(CreateExtension { name: b_name, .. }) => a_name == b_name,
271 _ => false,
272 },
273 || {
274 Ok(Some(vec![Statement::DropExtension(DropExtension {
275 names: vec![a_name.clone()],
276 if_exists: if_not_exists,
277 cascade_or_restrict: if cascade {
278 Some(sqlparser::ast::ReferentialAction::Cascade)
279 } else {
280 None
281 },
282 })]))
283 },
284 )
285}
286
287fn find_and_compare_create_domain(
288 orig: &Statement,
289 domain: &CreateDomain,
290 other: &[Statement],
291) -> Result<Option<Vec<Statement>>, DiffError> {
292 let res = other
293 .iter()
294 .find(|sb| match sb {
295 Statement::CreateDomain(b) => b.name == domain.name,
296 _ => false,
297 })
298 .map(|sb| orig.diff(sb))
299 .transpose()?
300 .flatten();
301 Ok(res)
302}
303
304impl Diff for Statement {
305 type Diff = Option<Vec<Statement>>;
306
307 fn diff(&self, other: &Self) -> Result<Self::Diff, DiffError> {
308 match self {
309 Self::CreateTable(a) => match other {
310 Self::CreateTable(b) => Ok(compare_create_table(a, b)),
311 _ => Ok(None),
312 },
313 Self::CreateIndex(a) => match other {
314 Self::CreateIndex(b) => compare_create_index(a, b),
315 _ => Ok(None),
316 },
317 Self::CreateType {
318 name: a_name,
319 representation: a_rep,
320 } => match other {
321 Self::CreateType {
322 name: b_name,
323 representation: b_rep,
324 } => compare_create_type(self, a_name, a_rep, other, b_name, b_rep),
325 _ => Ok(None),
326 },
327 Self::CreateDomain(a) => match other {
328 Self::CreateDomain(b) => Ok(compare_create_domain(a, b)),
329 _ => Ok(None),
330 },
331 _ => Err(DiffError::builder()
332 .kind(DiffErrorKind::NotImplemented)
333 .statement_a(self.clone())
334 .statement_b(other.clone())
335 .build()),
336 }
337 }
338}
339
340fn compare_create_table(a: &CreateTable, b: &CreateTable) -> Option<Vec<Statement>> {
341 if a == b {
342 return None;
343 }
344
345 let a_column_names: HashSet<_> = a.columns.iter().map(|c| c.name.value.clone()).collect();
346 let b_column_names: HashSet<_> = b.columns.iter().map(|c| c.name.value.clone()).collect();
347
348 let operations: Vec<_> = a
349 .columns
350 .iter()
351 .filter_map(|ac| {
352 if b_column_names.contains(&ac.name.value) {
353 None
354 } else {
355 Some(AlterTableOperation::DropColumn {
357 column_names: vec![ac.name.clone()],
358 if_exists: a.if_not_exists,
359 drop_behavior: None,
360 has_column_keyword: true,
361 })
362 }
363 })
364 .chain(b.columns.iter().filter_map(|bc| {
365 if a_column_names.contains(&bc.name.value) {
366 None
367 } else {
368 Some(AlterTableOperation::AddColumn {
370 column_keyword: true,
371 if_not_exists: a.if_not_exists,
372 column_def: bc.clone(),
373 column_position: None,
374 })
375 }
376 }))
377 .collect();
378
379 if operations.is_empty() {
380 return None;
381 }
382
383 Some(vec![Statement::AlterTable(AlterTable {
384 table_type: None,
385 name: a.name.clone(),
386 if_exists: a.if_not_exists,
387 only: false,
388 operations,
389 location: None,
390 on_cluster: a.on_cluster.clone(),
391 end_token: AttachedToken::empty(),
392 })])
393}
394
395fn compare_create_index(
396 a: &CreateIndex,
397 b: &CreateIndex,
398) -> Result<Option<Vec<Statement>>, DiffError> {
399 if a == b {
400 return Ok(None);
401 }
402
403 if a.name.is_none() || b.name.is_none() {
404 return Err(DiffError::builder()
405 .kind(DiffErrorKind::CompareUnnamedIndex)
406 .statement_a(Statement::CreateIndex(a.clone()))
407 .statement_b(Statement::CreateIndex(b.clone()))
408 .build());
409 }
410 let name = a.name.clone().unwrap();
411
412 Ok(Some(vec![
413 Statement::Drop {
414 object_type: ObjectType::Index,
415 if_exists: a.if_not_exists,
416 names: vec![name],
417 cascade: false,
418 restrict: false,
419 purge: false,
420 temporary: false,
421 table: None,
422 },
423 Statement::CreateIndex(b.clone()),
424 ]))
425}
426
427fn compare_create_type(
428 a: &Statement,
429 a_name: &ObjectName,
430 a_rep: &Option<UserDefinedTypeRepresentation>,
431 b: &Statement,
432 b_name: &ObjectName,
433 b_rep: &Option<UserDefinedTypeRepresentation>,
434) -> Result<Option<Vec<Statement>>, DiffError> {
435 if a_name == b_name && a_rep == b_rep {
436 return Ok(None);
437 }
438
439 let operations = match a_rep {
440 Some(UserDefinedTypeRepresentation::Enum { labels: a_labels }) => match b_rep {
441 Some(UserDefinedTypeRepresentation::Enum { labels: b_labels }) => {
442 match a_labels.len().cmp(&b_labels.len()) {
443 Ordering::Equal => {
444 let rename_labels: Vec<_> = a_labels
445 .iter()
446 .zip(b_labels.iter())
447 .filter_map(|(a, b)| {
448 if a == b {
449 None
450 } else {
451 Some(AlterTypeOperation::RenameValue(
452 sqlparser::ast::AlterTypeRenameValue {
453 from: a.clone(),
454 to: b.clone(),
455 },
456 ))
457 }
458 })
459 .collect();
460 rename_labels
461 }
462 Ordering::Less => {
463 let mut a_labels_iter = a_labels.iter().peekable();
464 let mut operations = Vec::new();
465 let mut prev = None;
466 for b in b_labels {
467 match a_labels_iter.peek() {
468 Some(a) => {
469 let a = *a;
470 if a == b {
471 prev = Some(a);
472 a_labels_iter.next();
473 continue;
474 }
475
476 let position = match prev {
477 Some(a) => AlterTypeAddValuePosition::After(a.clone()),
478 None => AlterTypeAddValuePosition::Before(a.clone()),
479 };
480
481 prev = Some(b);
482 operations.push(AlterTypeOperation::AddValue(
483 AlterTypeAddValue {
484 if_not_exists: false,
485 value: b.clone(),
486 position: Some(position),
487 },
488 ));
489 }
490 None => {
491 if a_labels.contains(b) {
492 continue;
493 }
494 operations.push(AlterTypeOperation::AddValue(
496 AlterTypeAddValue {
497 if_not_exists: false,
498 value: b.clone(),
499 position: None,
500 },
501 ));
502 }
503 }
504 }
505 operations
506 }
507 _ => {
508 return Err(DiffError::builder()
509 .kind(DiffErrorKind::RemoveEnumLabel)
510 .statement_a(a.clone())
511 .statement_b(b.clone())
512 .build());
513 }
514 }
515 }
516 _ => {
517 return Err(DiffError::builder()
519 .kind(DiffErrorKind::NotImplemented)
520 .statement_a(a.clone())
521 .statement_b(b.clone())
522 .build());
523 }
524 },
525 _ => {
526 return Err(DiffError::builder()
528 .kind(DiffErrorKind::NotImplemented)
529 .statement_a(a.clone())
530 .statement_b(b.clone())
531 .build());
532 }
533 };
534
535 if operations.is_empty() {
536 return Ok(None);
537 }
538
539 Ok(Some(
540 operations
541 .into_iter()
542 .map(|operation| {
543 Statement::AlterType(AlterType {
544 name: a_name.clone(),
545 operation,
546 })
547 })
548 .collect(),
549 ))
550}
551
552fn compare_create_domain(a: &CreateDomain, b: &CreateDomain) -> Option<Vec<Statement>> {
553 if a == b {
554 return None;
555 }
556
557 Some(vec![
558 Statement::DropDomain(DropDomain {
559 if_exists: true,
560 name: a.name.clone(),
561 drop_behavior: None,
562 }),
563 Statement::CreateDomain(b.clone()),
564 ])
565}