sql_schema/
diff.rs

1use std::{cmp::Ordering, collections::HashSet, fmt};
2
3use bon::bon;
4use sqlparser::ast::{
5    AlterTableOperation, AlterType, AlterTypeAddValue, AlterTypeAddValuePosition,
6    AlterTypeOperation, CreateDomain, CreateIndex, CreateTable, DropDomain, Ident, ObjectName,
7    ObjectType, Statement, UserDefinedTypeRepresentation,
8};
9use thiserror::Error;
10
11#[derive(Error, Debug)]
12pub struct DiffError {
13    kind: DiffErrorKind,
14    statement_a: Option<Box<Statement>>,
15    statement_b: Option<Box<Statement>>,
16}
17
18impl fmt::Display for DiffError {
19    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
20        write!(
21            f,
22            "Oops, we couldn't diff that: {reason}",
23            reason = self.kind
24        )?;
25        if let Some(statement_a) = &self.statement_a {
26            write!(f, "\n\nStatement A:\n{statement_a}")?;
27        }
28        if let Some(statement_b) = &self.statement_b {
29            write!(f, "\n\nStatement B:\n{statement_b}")?;
30        }
31        Ok(())
32    }
33}
34
35#[bon]
36impl DiffError {
37    #[builder]
38    fn new(
39        kind: DiffErrorKind,
40        statement_a: Option<Statement>,
41        statement_b: Option<Statement>,
42    ) -> Self {
43        Self {
44            kind,
45            statement_a: statement_a.map(Box::new),
46            statement_b: statement_b.map(Box::new),
47        }
48    }
49}
50
51#[derive(Error, Debug)]
52#[non_exhaustive]
53enum DiffErrorKind {
54    #[error("can't drop unnamed index")]
55    DropUnnamedIndex,
56    #[error("can't compare unnamed index")]
57    CompareUnnamedIndex,
58    #[error("removing enum labels is not supported")]
59    RemoveEnumLabel,
60    #[error("not yet supported")]
61    NotImplemented,
62}
63
64pub(crate) trait Diff: Sized {
65    type Diff;
66
67    fn diff(&self, other: &Self) -> Result<Self::Diff, DiffError>;
68}
69
70impl Diff for Vec<Statement> {
71    type Diff = Option<Vec<Statement>>;
72
73    fn diff(&self, other: &Self) -> Result<Self::Diff, DiffError> {
74        let res = self
75            .iter()
76            .filter_map(|sa| {
77                match sa {
78                    // CreateTable: compare against another CreateTable with the same name
79                    // TODO: handle renames (e.g. use comments to tag a previous name for a table in a schema)
80                    Statement::CreateTable(a) => find_and_compare_create_table(sa, a, other),
81                    Statement::CreateIndex(a) => find_and_compare_create_index(sa, a, other),
82                    Statement::CreateType { name, .. } => {
83                        find_and_compare_create_type(sa, name, other)
84                    }
85                    Statement::CreateExtension {
86                        name,
87                        if_not_exists,
88                        cascade,
89                        ..
90                    } => {
91                        find_and_compare_create_extension(sa, name, *if_not_exists, *cascade, other)
92                    }
93                    Statement::CreateDomain(a) => find_and_compare_create_domain(sa, a, other),
94                    _ => Err(DiffError::builder()
95                        .kind(DiffErrorKind::NotImplemented)
96                        .statement_a(sa.clone())
97                        .build()),
98                }
99                .transpose()
100            })
101            // find resources that are in `other` but not in `self`
102            .chain(other.iter().filter_map(|sb| {
103                match sb {
104                    Statement::CreateTable(b) => Ok(self.iter().find(|sa| match sa {
105                        Statement::CreateTable(a) => a.name == b.name,
106                        _ => false,
107                    })),
108                    Statement::CreateIndex(b) => Ok(self.iter().find(|sa| match sa {
109                        Statement::CreateIndex(a) => a.name == b.name,
110                        _ => false,
111                    })),
112                    Statement::CreateType { name: b_name, .. } => {
113                        Ok(self.iter().find(|sa| match sa {
114                            Statement::CreateType { name: a_name, .. } => a_name == b_name,
115                            _ => false,
116                        }))
117                    }
118                    Statement::CreateExtension { name: b_name, .. } => {
119                        Ok(self.iter().find(|sa| match sa {
120                            Statement::CreateExtension { name: a_name, .. } => a_name == b_name,
121                            _ => false,
122                        }))
123                    }
124                    Statement::CreateDomain(b) => Ok(self.iter().find(|sa| match sa {
125                        Statement::CreateDomain(a) => a.name == b.name,
126                        _ => false,
127                    })),
128                    _ => Err(DiffError::builder()
129                        .kind(DiffErrorKind::NotImplemented)
130                        .statement_a(sb.clone())
131                        .build()),
132                }
133                .transpose()
134                // return the statement if it's not in `self`
135                .map_or_else(|| Some(Ok(vec![sb.clone()])), |_| None)
136            }))
137            .collect::<Result<Vec<_>, _>>()?
138            .into_iter()
139            .flatten()
140            .collect::<Vec<_>>();
141
142        if res.is_empty() {
143            Ok(None)
144        } else {
145            Ok(Some(res))
146        }
147    }
148}
149
150fn find_and_compare<MF, DF>(
151    sa: &Statement,
152    other: &[Statement],
153    match_fn: MF,
154    drop_fn: DF,
155) -> Result<Option<Vec<Statement>>, DiffError>
156where
157    MF: Fn(&&Statement) -> bool,
158    DF: Fn() -> Result<Option<Vec<Statement>>, DiffError>,
159{
160    other.iter().find(match_fn).map_or_else(
161        // drop the statement if it wasn't found in `other`
162        drop_fn,
163        // otherwise diff the two statements
164        |sb| sa.diff(sb),
165    )
166}
167
168fn find_and_compare_create_table(
169    sa: &Statement,
170    a: &CreateTable,
171    other: &[Statement],
172) -> Result<Option<Vec<Statement>>, DiffError> {
173    find_and_compare(
174        sa,
175        other,
176        |sb| match sb {
177            Statement::CreateTable(b) => a.name == b.name,
178            _ => false,
179        },
180        || {
181            Ok(Some(vec![Statement::Drop {
182                object_type: sqlparser::ast::ObjectType::Table,
183                if_exists: a.if_not_exists,
184                names: vec![a.name.clone()],
185                cascade: false,
186                restrict: false,
187                purge: false,
188                temporary: false,
189                table: None,
190            }]))
191        },
192    )
193}
194
195fn find_and_compare_create_index(
196    sa: &Statement,
197    a: &CreateIndex,
198    other: &[Statement],
199) -> Result<Option<Vec<Statement>>, DiffError> {
200    find_and_compare(
201        sa,
202        other,
203        |sb| match sb {
204            Statement::CreateIndex(b) => a.name == b.name,
205            _ => false,
206        },
207        || {
208            let name = a.name.clone().ok_or_else(|| {
209                DiffError::builder()
210                    .kind(DiffErrorKind::DropUnnamedIndex)
211                    .statement_a(sa.clone())
212                    .build()
213            })?;
214
215            Ok(Some(vec![Statement::Drop {
216                object_type: sqlparser::ast::ObjectType::Index,
217                if_exists: a.if_not_exists,
218                names: vec![name],
219                cascade: false,
220                restrict: false,
221                purge: false,
222                temporary: false,
223                table: None,
224            }]))
225        },
226    )
227}
228
229fn find_and_compare_create_type(
230    sa: &Statement,
231    a_name: &ObjectName,
232    other: &[Statement],
233) -> Result<Option<Vec<Statement>>, DiffError> {
234    find_and_compare(
235        sa,
236        other,
237        |sb| match sb {
238            Statement::CreateType { name: b_name, .. } => a_name == b_name,
239            _ => false,
240        },
241        || {
242            Ok(Some(vec![Statement::Drop {
243                object_type: sqlparser::ast::ObjectType::Type,
244                if_exists: false,
245                names: vec![a_name.clone()],
246                cascade: false,
247                restrict: false,
248                purge: false,
249                temporary: false,
250                table: None,
251            }]))
252        },
253    )
254}
255
256fn find_and_compare_create_extension(
257    sa: &Statement,
258    a_name: &Ident,
259    if_not_exists: bool,
260    cascade: bool,
261    other: &[Statement],
262) -> Result<Option<Vec<Statement>>, DiffError> {
263    find_and_compare(
264        sa,
265        other,
266        |sb| match sb {
267            Statement::CreateExtension { name: b_name, .. } => a_name == b_name,
268            _ => false,
269        },
270        || {
271            Ok(Some(vec![Statement::DropExtension {
272                names: vec![a_name.clone()],
273                if_exists: if_not_exists,
274                cascade_or_restrict: if cascade {
275                    Some(sqlparser::ast::ReferentialAction::Cascade)
276                } else {
277                    None
278                },
279            }]))
280        },
281    )
282}
283
284fn find_and_compare_create_domain(
285    orig: &Statement,
286    domain: &CreateDomain,
287    other: &[Statement],
288) -> Result<Option<Vec<Statement>>, DiffError> {
289    let res = other
290        .iter()
291        .find(|sb| match sb {
292            Statement::CreateDomain(b) => b.name == domain.name,
293            _ => false,
294        })
295        .map(|sb| orig.diff(sb))
296        .transpose()?
297        .flatten();
298    Ok(res)
299}
300
301impl Diff for Statement {
302    type Diff = Option<Vec<Statement>>;
303
304    fn diff(&self, other: &Self) -> Result<Self::Diff, DiffError> {
305        match self {
306            Self::CreateTable(a) => match other {
307                Self::CreateTable(b) => Ok(compare_create_table(a, b)),
308                _ => Ok(None),
309            },
310            Self::CreateIndex(a) => match other {
311                Self::CreateIndex(b) => compare_create_index(a, b),
312                _ => Ok(None),
313            },
314            Self::CreateType {
315                name: a_name,
316                representation: a_rep,
317            } => match other {
318                Self::CreateType {
319                    name: b_name,
320                    representation: b_rep,
321                } => compare_create_type(self, a_name, a_rep, other, b_name, b_rep),
322                _ => Ok(None),
323            },
324            Self::CreateDomain(a) => match other {
325                Self::CreateDomain(b) => Ok(compare_create_domain(a, b)),
326                _ => Ok(None),
327            },
328            _ => Err(DiffError::builder()
329                .kind(DiffErrorKind::NotImplemented)
330                .statement_a(self.clone())
331                .statement_b(other.clone())
332                .build()),
333        }
334    }
335}
336
337fn compare_create_table(a: &CreateTable, b: &CreateTable) -> Option<Vec<Statement>> {
338    if a == b {
339        return None;
340    }
341
342    let a_column_names: HashSet<_> = a.columns.iter().map(|c| c.name.clone()).collect();
343    let b_column_names: HashSet<_> = b.columns.iter().map(|c| c.name.clone()).collect();
344
345    let ops = a
346        .columns
347        .iter()
348        .filter_map(|ac| {
349            if b_column_names.contains(&ac.name) {
350                None
351            } else {
352                // drop column if it only exists in `a`
353                Some(AlterTableOperation::DropColumn {
354                    column_name: ac.name.clone(),
355                    if_exists: a.if_not_exists,
356                    drop_behavior: None,
357                    has_column_keyword: true,
358                })
359            }
360        })
361        .chain(b.columns.iter().filter_map(|bc| {
362            if a_column_names.contains(&bc.name) {
363                None
364            } else {
365                // add the column if it only exists in `b`
366                Some(AlterTableOperation::AddColumn {
367                    column_keyword: true,
368                    if_not_exists: a.if_not_exists,
369                    column_def: bc.clone(),
370                    column_position: None,
371                })
372            }
373        }))
374        .collect();
375
376    Some(vec![Statement::AlterTable {
377        name: a.name.clone(),
378        if_exists: a.if_not_exists,
379        only: false,
380        operations: ops,
381        location: None,
382        on_cluster: a.on_cluster.clone(),
383        iceberg: false,
384    }])
385}
386
387fn compare_create_index(
388    a: &CreateIndex,
389    b: &CreateIndex,
390) -> Result<Option<Vec<Statement>>, DiffError> {
391    if a == b {
392        return Ok(None);
393    }
394
395    if a.name.is_none() || b.name.is_none() {
396        return Err(DiffError::builder()
397            .kind(DiffErrorKind::CompareUnnamedIndex)
398            .statement_a(Statement::CreateIndex(a.clone()))
399            .statement_b(Statement::CreateIndex(b.clone()))
400            .build());
401    }
402    let name = a.name.clone().unwrap();
403
404    Ok(Some(vec![
405        Statement::Drop {
406            object_type: ObjectType::Index,
407            if_exists: a.if_not_exists,
408            names: vec![name],
409            cascade: false,
410            restrict: false,
411            purge: false,
412            temporary: false,
413            table: None,
414        },
415        Statement::CreateIndex(b.clone()),
416    ]))
417}
418
419fn compare_create_type(
420    a: &Statement,
421    a_name: &ObjectName,
422    a_rep: &UserDefinedTypeRepresentation,
423    b: &Statement,
424    b_name: &ObjectName,
425    b_rep: &UserDefinedTypeRepresentation,
426) -> Result<Option<Vec<Statement>>, DiffError> {
427    if a_name == b_name && a_rep == b_rep {
428        return Ok(None);
429    }
430
431    let operations = match a_rep {
432        UserDefinedTypeRepresentation::Enum { labels: a_labels } => match b_rep {
433            UserDefinedTypeRepresentation::Enum { labels: b_labels } => {
434                match a_labels.len().cmp(&b_labels.len()) {
435                    Ordering::Equal => {
436                        let rename_labels: Vec<_> = a_labels
437                            .iter()
438                            .zip(b_labels.iter())
439                            .filter_map(|(a, b)| {
440                                if a == b {
441                                    None
442                                } else {
443                                    Some(AlterTypeOperation::RenameValue(
444                                        sqlparser::ast::AlterTypeRenameValue {
445                                            from: a.clone(),
446                                            to: b.clone(),
447                                        },
448                                    ))
449                                }
450                            })
451                            .collect();
452                        rename_labels
453                    }
454                    Ordering::Less => {
455                        let mut a_labels_iter = a_labels.iter().peekable();
456                        let mut operations = Vec::new();
457                        let mut prev = None;
458                        for b in b_labels {
459                            match a_labels_iter.peek() {
460                                Some(a) => {
461                                    let a = *a;
462                                    if a == b {
463                                        prev = Some(a);
464                                        a_labels_iter.next();
465                                        continue;
466                                    }
467
468                                    let position = match prev {
469                                        Some(a) => AlterTypeAddValuePosition::After(a.clone()),
470                                        None => AlterTypeAddValuePosition::Before(a.clone()),
471                                    };
472
473                                    prev = Some(b);
474                                    operations.push(AlterTypeOperation::AddValue(
475                                        AlterTypeAddValue {
476                                            if_not_exists: false,
477                                            value: b.clone(),
478                                            position: Some(position),
479                                        },
480                                    ));
481                                }
482                                None => {
483                                    if a_labels.contains(b) {
484                                        continue;
485                                    }
486                                    // labels occuring after all existing ones get added to the end
487                                    operations.push(AlterTypeOperation::AddValue(
488                                        AlterTypeAddValue {
489                                            if_not_exists: false,
490                                            value: b.clone(),
491                                            position: None,
492                                        },
493                                    ));
494                                }
495                            }
496                        }
497                        operations
498                    }
499                    _ => {
500                        return Err(DiffError::builder()
501                            .kind(DiffErrorKind::RemoveEnumLabel)
502                            .statement_a(a.clone())
503                            .statement_b(b.clone())
504                            .build());
505                    }
506                }
507            }
508            _ => {
509                // TODO: DROP and CREATE type
510                return Err(DiffError::builder()
511                    .kind(DiffErrorKind::NotImplemented)
512                    .statement_a(a.clone())
513                    .statement_b(b.clone())
514                    .build());
515            }
516        },
517        _ => {
518            // TODO: handle diffing composite attributes for CREATE TYPE
519            return Err(DiffError::builder()
520                .kind(DiffErrorKind::NotImplemented)
521                .statement_a(a.clone())
522                .statement_b(b.clone())
523                .build());
524        }
525    };
526
527    if operations.is_empty() {
528        return Ok(None);
529    }
530
531    Ok(Some(
532        operations
533            .into_iter()
534            .map(|operation| {
535                Statement::AlterType(AlterType {
536                    name: a_name.clone(),
537                    operation,
538                })
539            })
540            .collect(),
541    ))
542}
543
544fn compare_create_domain(a: &CreateDomain, b: &CreateDomain) -> Option<Vec<Statement>> {
545    if a == b {
546        return None;
547    }
548
549    Some(vec![
550        Statement::DropDomain(DropDomain {
551            if_exists: true,
552            name: a.name.clone(),
553            drop_behavior: None,
554        }),
555        Statement::CreateDomain(b.clone()),
556    ])
557}