sql_schema/
diff.rs

1use std::{cmp::Ordering, collections::HashSet, fmt};
2
3use bon::bon;
4use sqlparser::ast::{
5    AlterTableOperation, AlterType, AlterTypeAddValue, AlterTypeAddValuePosition,
6    AlterTypeOperation, CreateIndex, CreateTable, Ident, ObjectName, ObjectType, Statement,
7    UserDefinedTypeRepresentation,
8};
9use thiserror::Error;
10
11#[derive(Error, Debug)]
12pub struct DiffError {
13    kind: DiffErrorKind,
14    statement_a: Option<Statement>,
15    statement_b: Option<Statement>,
16}
17
18impl fmt::Display for DiffError {
19    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
20        write!(
21            f,
22            "Oops, we couldn't diff that: {reason}",
23            reason = self.kind
24        )?;
25        if let Some(statement_a) = &self.statement_a {
26            write!(f, "\n\nStatement A:\n{statement_a}")?;
27        }
28        if let Some(statement_b) = &self.statement_b {
29            write!(f, "\n\nStatement B:\n{statement_b}")?;
30        }
31        Ok(())
32    }
33}
34
35#[bon]
36impl DiffError {
37    #[builder]
38    fn new(
39        kind: DiffErrorKind,
40        statement_a: Option<Statement>,
41        statement_b: Option<Statement>,
42    ) -> Self {
43        Self {
44            kind,
45            statement_a,
46            statement_b,
47        }
48    }
49}
50
51#[derive(Error, Debug)]
52#[non_exhaustive]
53enum DiffErrorKind {
54    #[error("can't drop unnamed index")]
55    DropUnnamedIndex,
56    #[error("can't compare unnamed index")]
57    CompareUnnamedIndex,
58    #[error("removing enum labels is not supported")]
59    RemoveEnumLabel,
60    #[error("not yet supported")]
61    NotImplemented,
62}
63
64pub(crate) trait Diff: Sized {
65    type Diff;
66
67    fn diff(&self, other: &Self) -> Result<Self::Diff, DiffError>;
68}
69
70impl Diff for Vec<Statement> {
71    type Diff = Option<Vec<Statement>>;
72
73    fn diff(&self, other: &Self) -> Result<Self::Diff, DiffError> {
74        let res = self
75            .iter()
76            .filter_map(|sa| {
77                match sa {
78                    // CreateTable: compare against another CreateTable with the same name
79                    // TODO: handle renames (e.g. use comments to tag a previous name for a table in a schema)
80                    Statement::CreateTable(a) => find_and_compare_create_table(sa, a, other),
81                    Statement::CreateIndex(a) => find_and_compare_create_index(sa, a, other),
82                    Statement::CreateType { name, .. } => {
83                        find_and_compare_create_type(sa, name, other)
84                    }
85                    Statement::CreateExtension {
86                        name,
87                        if_not_exists,
88                        cascade,
89                        ..
90                    } => {
91                        find_and_compare_create_extension(sa, name, *if_not_exists, *cascade, other)
92                    }
93                    _ => Err(DiffError::builder()
94                        .kind(DiffErrorKind::NotImplemented)
95                        .statement_a(sa.clone())
96                        .build()),
97                }
98                .transpose()
99            })
100            // find resources that are in `other` but not in `self`
101            .chain(other.iter().filter_map(|sb| {
102                match sb {
103                    Statement::CreateTable(b) => Ok(self.iter().find(|sa| match sa {
104                        Statement::CreateTable(a) => a.name == b.name,
105                        _ => false,
106                    })),
107                    Statement::CreateIndex(b) => Ok(self.iter().find(|sa| match sa {
108                        Statement::CreateIndex(a) => a.name == b.name,
109                        _ => false,
110                    })),
111                    Statement::CreateType { name: b_name, .. } => {
112                        Ok(self.iter().find(|sa| match sa {
113                            Statement::CreateType { name: a_name, .. } => a_name == b_name,
114                            _ => false,
115                        }))
116                    }
117                    Statement::CreateExtension { name: b_name, .. } => {
118                        Ok(self.iter().find(|sa| match sa {
119                            Statement::CreateExtension { name: a_name, .. } => a_name == b_name,
120                            _ => false,
121                        }))
122                    }
123                    _ => Err(DiffError::builder()
124                        .kind(DiffErrorKind::NotImplemented)
125                        .statement_a(sb.clone())
126                        .build()),
127                }
128                .transpose()
129                // return the statement if it's not in `self`
130                .map_or_else(|| Some(Ok(vec![sb.clone()])), |_| None)
131            }))
132            .collect::<Result<Vec<_>, _>>()?
133            .into_iter()
134            .flatten()
135            .collect::<Vec<_>>();
136
137        if res.is_empty() {
138            Ok(None)
139        } else {
140            Ok(Some(res))
141        }
142    }
143}
144
145fn find_and_compare<MF, DF>(
146    sa: &Statement,
147    other: &[Statement],
148    match_fn: MF,
149    drop_fn: DF,
150) -> Result<Option<Vec<Statement>>, DiffError>
151where
152    MF: Fn(&&Statement) -> bool,
153    DF: Fn() -> Result<Option<Vec<Statement>>, DiffError>,
154{
155    other.iter().find(match_fn).map_or_else(
156        // drop the statement if it wasn't found in `other`
157        drop_fn,
158        // otherwise diff the two statements
159        |sb| sa.diff(sb),
160    )
161}
162
163fn find_and_compare_create_table(
164    sa: &Statement,
165    a: &CreateTable,
166    other: &[Statement],
167) -> Result<Option<Vec<Statement>>, DiffError> {
168    find_and_compare(
169        sa,
170        other,
171        |sb| match sb {
172            Statement::CreateTable(b) => a.name == b.name,
173            _ => false,
174        },
175        || {
176            Ok(Some(vec![Statement::Drop {
177                object_type: sqlparser::ast::ObjectType::Table,
178                if_exists: a.if_not_exists,
179                names: vec![a.name.clone()],
180                cascade: false,
181                restrict: false,
182                purge: false,
183                temporary: false,
184            }]))
185        },
186    )
187}
188
189fn find_and_compare_create_index(
190    sa: &Statement,
191    a: &CreateIndex,
192    other: &[Statement],
193) -> Result<Option<Vec<Statement>>, DiffError> {
194    find_and_compare(
195        sa,
196        other,
197        |sb| match sb {
198            Statement::CreateIndex(b) => a.name == b.name,
199            _ => false,
200        },
201        || {
202            let name = a.name.clone().ok_or_else(|| {
203                DiffError::builder()
204                    .kind(DiffErrorKind::DropUnnamedIndex)
205                    .statement_a(sa.clone())
206                    .build()
207            })?;
208
209            Ok(Some(vec![Statement::Drop {
210                object_type: sqlparser::ast::ObjectType::Index,
211                if_exists: a.if_not_exists,
212                names: vec![name],
213                cascade: false,
214                restrict: false,
215                purge: false,
216                temporary: false,
217            }]))
218        },
219    )
220}
221
222fn find_and_compare_create_type(
223    sa: &Statement,
224    a_name: &ObjectName,
225    other: &[Statement],
226) -> Result<Option<Vec<Statement>>, DiffError> {
227    find_and_compare(
228        sa,
229        other,
230        |sb| match sb {
231            Statement::CreateType { name: b_name, .. } => a_name == b_name,
232            _ => false,
233        },
234        || {
235            Ok(Some(vec![Statement::Drop {
236                object_type: sqlparser::ast::ObjectType::Type,
237                if_exists: false,
238                names: vec![a_name.clone()],
239                cascade: false,
240                restrict: false,
241                purge: false,
242                temporary: false,
243            }]))
244        },
245    )
246}
247
248fn find_and_compare_create_extension(
249    sa: &Statement,
250    a_name: &Ident,
251    if_not_exists: bool,
252    cascade: bool,
253    other: &[Statement],
254) -> Result<Option<Vec<Statement>>, DiffError> {
255    find_and_compare(
256        sa,
257        other,
258        |sb| match sb {
259            Statement::CreateExtension { name: b_name, .. } => a_name == b_name,
260            _ => false,
261        },
262        || {
263            Ok(Some(vec![Statement::DropExtension {
264                names: vec![a_name.clone()],
265                if_exists: if_not_exists,
266                cascade_or_restrict: if cascade {
267                    Some(sqlparser::ast::ReferentialAction::Cascade)
268                } else {
269                    None
270                },
271            }]))
272        },
273    )
274}
275
276impl Diff for Statement {
277    type Diff = Option<Vec<Statement>>;
278
279    fn diff(&self, other: &Self) -> Result<Self::Diff, DiffError> {
280        match self {
281            Self::CreateTable(a) => match other {
282                Self::CreateTable(b) => Ok(compare_create_table(a, b)),
283                _ => Ok(None),
284            },
285            Self::CreateIndex(a) => match other {
286                Self::CreateIndex(b) => compare_create_index(a, b),
287                _ => Ok(None),
288            },
289            Self::CreateType {
290                name: a_name,
291                representation: a_rep,
292            } => match other {
293                Self::CreateType {
294                    name: b_name,
295                    representation: b_rep,
296                } => compare_create_type(self, a_name, a_rep, other, b_name, b_rep),
297                _ => Ok(None),
298            },
299            _ => Err(DiffError::builder()
300                .kind(DiffErrorKind::NotImplemented)
301                .statement_a(self.clone())
302                .statement_b(other.clone())
303                .build()),
304        }
305    }
306}
307
308fn compare_create_table(a: &CreateTable, b: &CreateTable) -> Option<Vec<Statement>> {
309    if a == b {
310        return None;
311    }
312
313    let a_column_names: HashSet<_> = a.columns.iter().map(|c| c.name.clone()).collect();
314    let b_column_names: HashSet<_> = b.columns.iter().map(|c| c.name.clone()).collect();
315
316    let ops = a
317        .columns
318        .iter()
319        .filter_map(|ac| {
320            if b_column_names.contains(&ac.name) {
321                None
322            } else {
323                // drop column if it only exists in `a`
324                Some(AlterTableOperation::DropColumn {
325                    column_name: ac.name.clone(),
326                    if_exists: a.if_not_exists,
327                    drop_behavior: None,
328                })
329            }
330        })
331        .chain(b.columns.iter().filter_map(|bc| {
332            if a_column_names.contains(&bc.name) {
333                None
334            } else {
335                // add the column if it only exists in `b`
336                Some(AlterTableOperation::AddColumn {
337                    column_keyword: true,
338                    if_not_exists: a.if_not_exists,
339                    column_def: bc.clone(),
340                    column_position: None,
341                })
342            }
343        }))
344        .collect();
345
346    Some(vec![Statement::AlterTable {
347        name: a.name.clone(),
348        if_exists: a.if_not_exists,
349        only: false,
350        operations: ops,
351        location: None,
352        on_cluster: a.on_cluster.clone(),
353    }])
354}
355
356fn compare_create_index(
357    a: &CreateIndex,
358    b: &CreateIndex,
359) -> Result<Option<Vec<Statement>>, DiffError> {
360    if a == b {
361        return Ok(None);
362    }
363
364    if a.name.is_none() || b.name.is_none() {
365        return Err(DiffError::builder()
366            .kind(DiffErrorKind::CompareUnnamedIndex)
367            .statement_a(Statement::CreateIndex(a.clone()))
368            .statement_b(Statement::CreateIndex(b.clone()))
369            .build());
370    }
371    let name = a.name.clone().unwrap();
372
373    Ok(Some(vec![
374        Statement::Drop {
375            object_type: ObjectType::Index,
376            if_exists: a.if_not_exists,
377            names: vec![name],
378            cascade: false,
379            restrict: false,
380            purge: false,
381            temporary: false,
382        },
383        Statement::CreateIndex(b.clone()),
384    ]))
385}
386
387fn compare_create_type(
388    a: &Statement,
389    a_name: &ObjectName,
390    a_rep: &UserDefinedTypeRepresentation,
391    b: &Statement,
392    b_name: &ObjectName,
393    b_rep: &UserDefinedTypeRepresentation,
394) -> Result<Option<Vec<Statement>>, DiffError> {
395    if a_name == b_name && a_rep == b_rep {
396        return Ok(None);
397    }
398
399    let operations = match a_rep {
400        UserDefinedTypeRepresentation::Enum { labels: a_labels } => match b_rep {
401            UserDefinedTypeRepresentation::Enum { labels: b_labels } => {
402                match a_labels.len().cmp(&b_labels.len()) {
403                    Ordering::Equal => {
404                        let rename_labels: Vec<_> = a_labels
405                            .iter()
406                            .zip(b_labels.iter())
407                            .filter_map(|(a, b)| {
408                                if a == b {
409                                    None
410                                } else {
411                                    Some(AlterTypeOperation::RenameValue(
412                                        sqlparser::ast::AlterTypeRenameValue {
413                                            from: a.clone(),
414                                            to: b.clone(),
415                                        },
416                                    ))
417                                }
418                            })
419                            .collect();
420                        rename_labels
421                    }
422                    Ordering::Less => {
423                        let mut a_labels_iter = a_labels.iter().peekable();
424                        let mut operations = Vec::new();
425                        let mut prev = None;
426                        for b in b_labels {
427                            match a_labels_iter.peek() {
428                                Some(a) => {
429                                    let a = *a;
430                                    if a == b {
431                                        prev = Some(a);
432                                        a_labels_iter.next();
433                                        continue;
434                                    }
435
436                                    let position = match prev {
437                                        Some(a) => AlterTypeAddValuePosition::After(a.clone()),
438                                        None => AlterTypeAddValuePosition::Before(a.clone()),
439                                    };
440
441                                    prev = Some(b);
442                                    operations.push(AlterTypeOperation::AddValue(
443                                        AlterTypeAddValue {
444                                            if_not_exists: false,
445                                            value: b.clone(),
446                                            position: Some(position),
447                                        },
448                                    ));
449                                }
450                                None => {
451                                    if a_labels.contains(b) {
452                                        continue;
453                                    }
454                                    // labels occuring after all existing ones get added to the end
455                                    operations.push(AlterTypeOperation::AddValue(
456                                        AlterTypeAddValue {
457                                            if_not_exists: false,
458                                            value: b.clone(),
459                                            position: None,
460                                        },
461                                    ));
462                                }
463                            }
464                        }
465                        operations
466                    }
467                    _ => {
468                        return Err(DiffError::builder()
469                            .kind(DiffErrorKind::RemoveEnumLabel)
470                            .statement_a(a.clone())
471                            .statement_b(b.clone())
472                            .build());
473                    }
474                }
475            }
476            _ => {
477                // TODO: DROP and CREATE type
478                return Err(DiffError::builder()
479                    .kind(DiffErrorKind::NotImplemented)
480                    .statement_a(a.clone())
481                    .statement_b(b.clone())
482                    .build());
483            }
484        },
485        _ => {
486            // TODO: handle diffing composite attributes for CREATE TYPE
487            return Err(DiffError::builder()
488                .kind(DiffErrorKind::NotImplemented)
489                .statement_a(a.clone())
490                .statement_b(b.clone())
491                .build());
492        }
493    };
494
495    if operations.is_empty() {
496        return Ok(None);
497    }
498
499    Ok(Some(
500        operations
501            .into_iter()
502            .map(|operation| {
503                Statement::AlterType(AlterType {
504                    name: a_name.clone(),
505                    operation,
506                })
507            })
508            .collect(),
509    ))
510}