Skip to main content

sql_schema/
migration.rs

1use std::fmt;
2
3use bon::bon;
4use sqlparser::ast::{
5    AlterColumnOperation, AlterTable, AlterTableOperation, AlterType, AlterTypeAddValuePosition,
6    AlterTypeOperation, ColumnOption, ColumnOptionDef, CreateExtension, CreateTable, DropExtension,
7    GeneratedAs, ObjectName, ObjectNamePart, ObjectType, Statement, UserDefinedTypeRepresentation,
8};
9use thiserror::Error;
10
11#[derive(Error, Debug)]
12pub struct MigrateError {
13    kind: MigrateErrorKind,
14    statement_a: Option<Box<Statement>>,
15    statement_b: Option<Box<Statement>>,
16}
17
18impl fmt::Display for MigrateError {
19    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
20        write!(
21            f,
22            "Oops, we couldn't migrate that: {reason}",
23            reason = self.kind
24        )?;
25        if let Some(statement_a) = &self.statement_a {
26            write!(f, "\n\nSubject:\n{statement_a}")?;
27        }
28        if let Some(statement_b) = &self.statement_b {
29            write!(f, "\n\nMigration:\n{statement_b}")?;
30        }
31        Ok(())
32    }
33}
34
35#[bon]
36impl MigrateError {
37    #[builder]
38    fn new(
39        kind: MigrateErrorKind,
40        statement_a: Option<Statement>,
41        statement_b: Option<Statement>,
42    ) -> Self {
43        Self {
44            kind,
45            statement_a: statement_a.map(Box::new),
46            statement_b: statement_b.map(Box::new),
47        }
48    }
49}
50
51#[derive(Error, Debug)]
52#[non_exhaustive]
53enum MigrateErrorKind {
54    #[error("can't migrate unnamed index")]
55    UnnamedIndex,
56    #[error("ALTER TABLE operation \"{0}\" not yet supported")]
57    AlterTableOpNotImplemented(Box<AlterTableOperation>),
58    #[error("invalid ALTER TYPE operation \"{0}\"")]
59    AlterTypeInvalidOp(Box<AlterTypeOperation>),
60    #[error("not yet supported")]
61    NotImplemented,
62}
63
64pub(crate) trait Migrate: Sized {
65    fn migrate(self, other: &Self) -> Result<Option<Self>, MigrateError>;
66}
67
68impl Migrate for Vec<Statement> {
69    fn migrate(self, other: &Self) -> Result<Option<Self>, MigrateError> {
70        let next: Self = self
71            .into_iter()
72            // perform any transformations on existing schema (e.g. ALTER/DROP table)
73            .filter_map(|sa| {
74                let orig = sa.clone();
75                match &sa {
76                    Statement::CreateTable(ca) => other
77                        .iter()
78                        .find(|sb| match sb {
79                            Statement::AlterTable(AlterTable { name, .. }) => *name == ca.name,
80                            Statement::Drop {
81                                object_type, names, ..
82                            } => {
83                                *object_type == ObjectType::Table
84                                    && names.len() == 1
85                                    && names[0] == ca.name
86                            }
87                            _ => false,
88                        })
89                        .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
90                    Statement::CreateIndex(a) => other
91                        .iter()
92                        .find(|sb| match sb {
93                            Statement::Drop {
94                                object_type, names, ..
95                            } => {
96                                *object_type == ObjectType::Index
97                                    && names.len() == 1
98                                    && Some(&names[0]) == a.name.as_ref()
99                            }
100                            _ => false,
101                        })
102                        .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
103                    Statement::CreateType { name, .. } => other
104                        .iter()
105                        .find(|sb| match sb {
106                            Statement::AlterType(b) => *name == b.name,
107                            Statement::Drop {
108                                object_type, names, ..
109                            } => {
110                                *object_type == ObjectType::Type
111                                    && names.len() == 1
112                                    && names[0] == *name
113                            }
114                            _ => false,
115                        })
116                        .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
117                    Statement::CreateExtension(CreateExtension { name, .. }) => other
118                        .iter()
119                        .find(|sb| match sb {
120                            Statement::DropExtension(DropExtension { names, .. }) => {
121                                names.contains(name)
122                            }
123                            _ => false,
124                        })
125                        .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
126                    Statement::CreateDomain(a) => other
127                        .iter()
128                        .find(|sb| match sb {
129                            Statement::DropDomain(b) => a.name == b.name,
130                            _ => false,
131                        })
132                        .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
133                    _ => Some(Err(MigrateError::builder()
134                        .kind(MigrateErrorKind::NotImplemented)
135                        .statement_a(sa.clone())
136                        .build())),
137                }
138            })
139            // CREATE table etc.
140            .chain(other.iter().filter_map(|sb| match sb {
141                Statement::CreateTable(_)
142                | Statement::CreateIndex { .. }
143                | Statement::CreateType { .. }
144                | Statement::CreateExtension { .. }
145                | Statement::CreateDomain(..) => Some(Ok(sb.clone())),
146                _ => None,
147            }))
148            .collect::<Result<_, _>>()?;
149        Ok(Some(next))
150    }
151}
152
153impl Migrate for Statement {
154    fn migrate(self, other: &Self) -> Result<Option<Self>, MigrateError> {
155        match self {
156            Self::CreateTable(ca) => match other {
157                Self::AlterTable(AlterTable {
158                    name, operations, ..
159                }) => {
160                    if *name == ca.name {
161                        Ok(Some(Self::CreateTable(migrate_alter_table(
162                            ca, operations,
163                        )?)))
164                    } else {
165                        // ALTER TABLE statement for another table
166                        Ok(Some(Self::CreateTable(ca)))
167                    }
168                }
169                Self::Drop {
170                    object_type, names, ..
171                } => {
172                    if *object_type == ObjectType::Table && names.contains(&ca.name) {
173                        Ok(None)
174                    } else {
175                        // DROP statement is for another table
176                        Ok(Some(Self::CreateTable(ca)))
177                    }
178                }
179                _ => Err(MigrateError::builder()
180                    .kind(MigrateErrorKind::NotImplemented)
181                    .statement_a(Self::CreateTable(ca))
182                    .statement_b(other.clone())
183                    .build()),
184            },
185            Self::CreateIndex(a) => match other {
186                Self::Drop {
187                    object_type, names, ..
188                } => {
189                    let name = a.name.clone().ok_or_else(|| {
190                        MigrateError::builder()
191                            .kind(MigrateErrorKind::UnnamedIndex)
192                            .statement_a(Self::CreateIndex(a.clone()))
193                            .statement_b(other.clone())
194                            .build()
195                    })?;
196                    if *object_type == ObjectType::Index && names.contains(&name) {
197                        Ok(None)
198                    } else {
199                        // DROP statement is for another index
200                        Ok(Some(Self::CreateIndex(a)))
201                    }
202                }
203                _ => Err(MigrateError::builder()
204                    .kind(MigrateErrorKind::NotImplemented)
205                    .statement_a(Self::CreateIndex(a))
206                    .statement_b(other.clone())
207                    .build()),
208            },
209            Self::CreateType {
210                name,
211                representation,
212            } => match other {
213                Self::AlterType(ba) => {
214                    if name == ba.name {
215                        let (name, representation) =
216                            migrate_alter_type(name.clone(), representation.clone(), ba)?;
217                        Ok(Some(Self::CreateType {
218                            name,
219                            representation,
220                        }))
221                    } else {
222                        // ALTER TYPE statement for another type
223                        Ok(Some(Self::CreateType {
224                            name,
225                            representation,
226                        }))
227                    }
228                }
229                Self::Drop {
230                    object_type, names, ..
231                } => {
232                    if *object_type == ObjectType::Type && names.contains(&name) {
233                        Ok(None)
234                    } else {
235                        // DROP statement is for another type
236                        Ok(Some(Self::CreateType {
237                            name,
238                            representation,
239                        }))
240                    }
241                }
242                _ => Err(MigrateError::builder()
243                    .kind(MigrateErrorKind::NotImplemented)
244                    .statement_a(Self::CreateType {
245                        name,
246                        representation,
247                    })
248                    .statement_b(other.clone())
249                    .build()),
250            },
251            _ => Err(MigrateError::builder()
252                .kind(MigrateErrorKind::NotImplemented)
253                .statement_a(self)
254                .statement_b(other.clone())
255                .build()),
256        }
257    }
258}
259
260fn migrate_alter_table(
261    mut t: CreateTable,
262    ops: &[AlterTableOperation],
263) -> Result<CreateTable, MigrateError> {
264    for op in ops.iter() {
265        match op {
266            AlterTableOperation::AddColumn { column_def, .. } => {
267                t.columns.push(column_def.clone());
268            }
269            AlterTableOperation::DropColumn { column_names, .. } => {
270                t.columns.retain(|c| {
271                    !column_names
272                        .iter().any(|name| c.name.value == name.value)
273                });
274            }
275            AlterTableOperation::AlterColumn { column_name, op } => {
276                t.columns.iter_mut().for_each(|c| {
277                    if c.name != *column_name {
278                        return;
279                    }
280                    match op {
281                        AlterColumnOperation::SetNotNull => {
282                            c.options.push(ColumnOptionDef {
283                                name: None,
284                                option: ColumnOption::NotNull,
285                            });
286                        }
287                        AlterColumnOperation::DropNotNull => {
288                            c.options
289                                .retain(|o| !matches!(o.option, ColumnOption::NotNull));
290                        }
291                        AlterColumnOperation::SetDefault { value } => {
292                            c.options
293                                .retain(|o| !matches!(o.option, ColumnOption::Default(_)));
294                            c.options.push(ColumnOptionDef {
295                                name: None,
296                                option: ColumnOption::Default(value.clone()),
297                            });
298                        }
299                        AlterColumnOperation::DropDefault => {
300                            c.options
301                                .retain(|o| !matches!(o.option, ColumnOption::Default(_)));
302                        }
303                        AlterColumnOperation::SetDataType {
304                            data_type,
305                            using: _,   // not applicable since we're not running the query
306                            had_set: _, // this doesn't change the meaning
307                        } => {
308                            c.data_type = data_type.clone();
309                        }
310                        AlterColumnOperation::AddGenerated {
311                            generated_as,
312                            sequence_options,
313                        } => {
314                            c.options
315                                .retain(|o| !matches!(o.option, ColumnOption::Generated { .. }));
316                            c.options.push(ColumnOptionDef {
317                                name: None,
318                                option: ColumnOption::Generated {
319                                    generated_as: (*generated_as)
320                                        .unwrap_or(GeneratedAs::Always),
321                                    sequence_options: sequence_options.clone(),
322                                    generation_expr: None,
323                                    generation_expr_mode: None,
324                                    generated_keyword: true,
325                                },
326                            });
327                        }
328                    }
329                });
330            }
331            op => {
332                return Err(MigrateError::builder()
333                    .kind(MigrateErrorKind::AlterTableOpNotImplemented(Box::new(
334                        op.clone(),
335                    )))
336                    .statement_a(Statement::CreateTable(t.clone()))
337                    .build())
338            }
339        }
340    }
341
342    Ok(t)
343}
344
345fn migrate_alter_type(
346    name: ObjectName,
347    representation: Option<UserDefinedTypeRepresentation>,
348    other: &AlterType,
349) -> Result<(ObjectName, Option<UserDefinedTypeRepresentation>), MigrateError> {
350    match &other.operation {
351        AlterTypeOperation::Rename(r) => {
352            let mut parts = name.0;
353            parts.pop();
354            parts.push(ObjectNamePart::Identifier(r.new_name.clone()));
355            let name = ObjectName(parts);
356
357            Ok((name, representation))
358        }
359        AlterTypeOperation::AddValue(a) => match representation {
360            Some(UserDefinedTypeRepresentation::Enum { mut labels }) => {
361                match &a.position {
362                    Some(AlterTypeAddValuePosition::Before(before_name)) => {
363                        let index = labels
364                            .iter()
365                            .enumerate()
366                            .find(|(_, l)| *l == before_name)
367                            .map(|(i, _)| i)
368                            // insert at the beginning if `before_name` can't be found
369                            .unwrap_or(0);
370                        labels.insert(index, a.value.clone());
371                    }
372                    Some(AlterTypeAddValuePosition::After(after_name)) => {
373                        let index = labels
374                            .iter()
375                            .enumerate()
376                            .find(|(_, l)| *l == after_name)
377                            .map(|(i, _)| i + 1);
378                        match index {
379                            Some(index) => labels.insert(index, a.value.clone()),
380                            // push it to the end if `after_name` can't be found
381                            None => labels.push(a.value.clone()),
382                        }
383                    }
384                    None => labels.push(a.value.clone()),
385                }
386
387                Ok((name, Some(UserDefinedTypeRepresentation::Enum { labels })))
388            }
389            Some(_) | None => Err(MigrateError::builder()
390                .kind(MigrateErrorKind::AlterTypeInvalidOp(Box::new(
391                    other.operation.clone(),
392                )))
393                .statement_a(Statement::CreateType {
394                    name,
395                    representation,
396                })
397                .statement_b(Statement::AlterType(other.clone()))
398                .build()),
399        },
400        AlterTypeOperation::RenameValue(rv) => match representation {
401            Some(UserDefinedTypeRepresentation::Enum { labels }) => {
402                let labels = labels
403                    .into_iter()
404                    .map(|l| if l == rv.from { rv.to.clone() } else { l })
405                    .collect::<Vec<_>>();
406
407                Ok((name, Some(UserDefinedTypeRepresentation::Enum { labels })))
408            }
409            Some(_) | None => Err(MigrateError::builder()
410                .kind(MigrateErrorKind::AlterTypeInvalidOp(Box::new(
411                    other.operation.clone(),
412                )))
413                .statement_a(Statement::CreateType {
414                    name,
415                    representation,
416                })
417                .statement_b(Statement::AlterType(other.clone()))
418                .build()),
419        },
420    }
421}