sql_schema/
migration.rs

1use std::fmt;
2
3use bon::bon;
4use sqlparser::ast::{
5    AlterColumnOperation, AlterTableOperation, AlterType, AlterTypeAddValuePosition,
6    AlterTypeOperation, ColumnOption, ColumnOptionDef, CreateTable, GeneratedAs, ObjectName,
7    ObjectNamePart, ObjectType, Statement, UserDefinedTypeRepresentation,
8};
9use thiserror::Error;
10
11#[derive(Error, Debug)]
12pub struct MigrateError {
13    kind: MigrateErrorKind,
14    statement_a: Option<Statement>,
15    statement_b: Option<Statement>,
16}
17
18impl fmt::Display for MigrateError {
19    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
20        write!(
21            f,
22            "Oops, we couldn't migrate that: {reason}",
23            reason = self.kind
24        )?;
25        if let Some(statement_a) = &self.statement_a {
26            write!(f, "\n\nSubject:\n{statement_a}")?;
27        }
28        if let Some(statement_b) = &self.statement_b {
29            write!(f, "\n\nMigration:\n{statement_b}")?;
30        }
31        Ok(())
32    }
33}
34
35#[bon]
36impl MigrateError {
37    #[builder]
38    fn new(
39        kind: MigrateErrorKind,
40        statement_a: Option<Statement>,
41        statement_b: Option<Statement>,
42    ) -> Self {
43        Self {
44            kind,
45            statement_a,
46            statement_b,
47        }
48    }
49}
50
51#[derive(Error, Debug)]
52#[non_exhaustive]
53enum MigrateErrorKind {
54    #[error("can't migrate unnamed index")]
55    UnnamedIndex,
56    #[error("ALTER TABLE operation \"{0}\" not yet supported")]
57    AlterTableOpNotImplemented(AlterTableOperation),
58    #[error("invalid ALTER TYPE operation \"{0}\"")]
59    AlterTypeInvalidOp(AlterTypeOperation),
60    #[error("not yet supported")]
61    NotImplemented,
62}
63
64pub(crate) trait Migrate: Sized {
65    fn migrate(self, other: &Self) -> Result<Option<Self>, MigrateError>;
66}
67
68impl Migrate for Vec<Statement> {
69    fn migrate(self, other: &Self) -> Result<Option<Self>, MigrateError> {
70        let next: Self = self
71            .into_iter()
72            // perform any transformations on existing schema (e.g. ALTER/DROP table)
73            .filter_map(|sa| {
74                let orig = sa.clone();
75                match &sa {
76                    Statement::CreateTable(ca) => other
77                        .iter()
78                        .find(|sb| match sb {
79                            Statement::AlterTable { name, .. } => *name == ca.name,
80                            Statement::Drop {
81                                object_type, names, ..
82                            } => {
83                                *object_type == ObjectType::Table
84                                    && names.len() == 1
85                                    && names[0] == ca.name
86                            }
87                            _ => false,
88                        })
89                        .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
90                    Statement::CreateIndex(a) => other
91                        .iter()
92                        .find(|sb| match sb {
93                            Statement::Drop {
94                                object_type, names, ..
95                            } => {
96                                *object_type == ObjectType::Index
97                                    && names.len() == 1
98                                    && Some(&names[0]) == a.name.as_ref()
99                            }
100                            _ => false,
101                        })
102                        .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
103                    Statement::CreateType { name, .. } => other
104                        .iter()
105                        .find(|sb| match sb {
106                            Statement::AlterType(b) => *name == b.name,
107                            Statement::Drop {
108                                object_type, names, ..
109                            } => {
110                                *object_type == ObjectType::Type
111                                    && names.len() == 1
112                                    && names[0] == *name
113                            }
114                            _ => false,
115                        })
116                        .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
117                    Statement::CreateExtension { name, .. } => other
118                        .iter()
119                        .find(|sb| match sb {
120                            Statement::DropExtension { names, .. } => names.contains(name),
121                            _ => false,
122                        })
123                        .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
124                    _ => Some(Err(MigrateError::builder()
125                        .kind(MigrateErrorKind::NotImplemented)
126                        .statement_a(sa.clone())
127                        .build())),
128                }
129            })
130            // CREATE table etc.
131            .chain(other.iter().filter_map(|sb| match sb {
132                Statement::CreateTable(_)
133                | Statement::CreateIndex { .. }
134                | Statement::CreateType { .. }
135                | Statement::CreateExtension { .. } => Some(Ok(sb.clone())),
136                _ => None,
137            }))
138            .collect::<Result<_, _>>()?;
139        Ok(Some(next))
140    }
141}
142
143impl Migrate for Statement {
144    fn migrate(self, other: &Self) -> Result<Option<Self>, MigrateError> {
145        match self {
146            Self::CreateTable(ca) => match other {
147                Self::AlterTable {
148                    name, operations, ..
149                } => {
150                    if *name == ca.name {
151                        Ok(Some(Self::CreateTable(migrate_alter_table(
152                            ca, operations,
153                        )?)))
154                    } else {
155                        // ALTER TABLE statement for another table
156                        Ok(Some(Self::CreateTable(ca)))
157                    }
158                }
159                Self::Drop {
160                    object_type, names, ..
161                } => {
162                    if *object_type == ObjectType::Table && names.contains(&ca.name) {
163                        Ok(None)
164                    } else {
165                        // DROP statement is for another table
166                        Ok(Some(Self::CreateTable(ca)))
167                    }
168                }
169                _ => Err(MigrateError::builder()
170                    .kind(MigrateErrorKind::NotImplemented)
171                    .statement_a(Self::CreateTable(ca))
172                    .statement_b(other.clone())
173                    .build()),
174            },
175            Self::CreateIndex(a) => match other {
176                Self::Drop {
177                    object_type, names, ..
178                } => {
179                    let name = a.name.clone().ok_or_else(|| {
180                        MigrateError::builder()
181                            .kind(MigrateErrorKind::UnnamedIndex)
182                            .statement_a(Self::CreateIndex(a.clone()))
183                            .statement_b(other.clone())
184                            .build()
185                    })?;
186                    if *object_type == ObjectType::Index && names.contains(&name) {
187                        Ok(None)
188                    } else {
189                        // DROP statement is for another index
190                        Ok(Some(Self::CreateIndex(a)))
191                    }
192                }
193                _ => Err(MigrateError::builder()
194                    .kind(MigrateErrorKind::NotImplemented)
195                    .statement_a(Self::CreateIndex(a))
196                    .statement_b(other.clone())
197                    .build()),
198            },
199            Self::CreateType {
200                name,
201                representation,
202            } => match other {
203                Self::AlterType(ba) => {
204                    if name == ba.name {
205                        let (name, representation) =
206                            migrate_alter_type(name.clone(), representation.clone(), ba)?;
207                        Ok(Some(Self::CreateType {
208                            name,
209                            representation,
210                        }))
211                    } else {
212                        // ALTER TYPE statement for another type
213                        Ok(Some(Self::CreateType {
214                            name,
215                            representation,
216                        }))
217                    }
218                }
219                Self::Drop {
220                    object_type, names, ..
221                } => {
222                    if *object_type == ObjectType::Type && names.contains(&name) {
223                        Ok(None)
224                    } else {
225                        // DROP statement is for another type
226                        Ok(Some(Self::CreateType {
227                            name,
228                            representation,
229                        }))
230                    }
231                }
232                _ => Err(MigrateError::builder()
233                    .kind(MigrateErrorKind::NotImplemented)
234                    .statement_a(Self::CreateType {
235                        name,
236                        representation,
237                    })
238                    .statement_b(other.clone())
239                    .build()),
240            },
241            _ => Err(MigrateError::builder()
242                .kind(MigrateErrorKind::NotImplemented)
243                .statement_a(self)
244                .statement_b(other.clone())
245                .build()),
246        }
247    }
248}
249
250fn migrate_alter_table(
251    mut t: CreateTable,
252    ops: &[AlterTableOperation],
253) -> Result<CreateTable, MigrateError> {
254    for op in ops.iter() {
255        match op {
256            AlterTableOperation::AddColumn { column_def, .. } => {
257                t.columns.push(column_def.clone());
258            }
259            AlterTableOperation::DropColumn { column_name, .. } => {
260                t.columns.retain(|c| c.name != *column_name);
261            }
262            AlterTableOperation::AlterColumn { column_name, op } => {
263                t.columns.iter_mut().for_each(|c| {
264                    if c.name != *column_name {
265                        return;
266                    }
267                    match op {
268                        AlterColumnOperation::SetNotNull => {
269                            c.options.push(ColumnOptionDef {
270                                name: None,
271                                option: ColumnOption::NotNull,
272                            });
273                        }
274                        AlterColumnOperation::DropNotNull => {
275                            c.options
276                                .retain(|o| !matches!(o.option, ColumnOption::NotNull));
277                        }
278                        AlterColumnOperation::SetDefault { value } => {
279                            c.options
280                                .retain(|o| !matches!(o.option, ColumnOption::Default(_)));
281                            c.options.push(ColumnOptionDef {
282                                name: None,
283                                option: ColumnOption::Default(value.clone()),
284                            });
285                        }
286                        AlterColumnOperation::DropDefault => {
287                            c.options
288                                .retain(|o| !matches!(o.option, ColumnOption::Default(_)));
289                        }
290                        AlterColumnOperation::SetDataType {
291                            data_type,
292                            using: _, // not applicable since we're not running the query
293                        } => {
294                            c.data_type = data_type.clone();
295                        }
296                        AlterColumnOperation::AddGenerated {
297                            generated_as,
298                            sequence_options,
299                        } => {
300                            c.options
301                                .retain(|o| !matches!(o.option, ColumnOption::Generated { .. }));
302                            c.options.push(ColumnOptionDef {
303                                name: None,
304                                option: ColumnOption::Generated {
305                                    generated_as: generated_as
306                                        .clone()
307                                        .unwrap_or(GeneratedAs::Always),
308                                    sequence_options: sequence_options.clone(),
309                                    generation_expr: None,
310                                    generation_expr_mode: None,
311                                    generated_keyword: true,
312                                },
313                            });
314                        }
315                    }
316                });
317            }
318            op => {
319                return Err(MigrateError::builder()
320                    .kind(MigrateErrorKind::AlterTableOpNotImplemented(op.clone()))
321                    .statement_a(Statement::CreateTable(t.clone()))
322                    .build())
323            }
324        }
325    }
326
327    Ok(t)
328}
329
330fn migrate_alter_type(
331    name: ObjectName,
332    representation: UserDefinedTypeRepresentation,
333    other: &AlterType,
334) -> Result<(ObjectName, UserDefinedTypeRepresentation), MigrateError> {
335    match &other.operation {
336        AlterTypeOperation::Rename(r) => {
337            let mut parts = name.0;
338            parts.pop();
339            parts.push(ObjectNamePart::Identifier(r.new_name.clone()));
340            let name = ObjectName(parts);
341
342            Ok((name, representation))
343        }
344        AlterTypeOperation::AddValue(a) => match representation {
345            UserDefinedTypeRepresentation::Enum { mut labels } => {
346                match &a.position {
347                    Some(AlterTypeAddValuePosition::Before(before_name)) => {
348                        let index = labels
349                            .iter()
350                            .enumerate()
351                            .find(|(_, l)| *l == before_name)
352                            .map(|(i, _)| i)
353                            // insert at the beginning if `before_name` can't be found
354                            .unwrap_or(0);
355                        labels.insert(index, a.value.clone());
356                    }
357                    Some(AlterTypeAddValuePosition::After(after_name)) => {
358                        let index = labels
359                            .iter()
360                            .enumerate()
361                            .find(|(_, l)| *l == after_name)
362                            .map(|(i, _)| i + 1);
363                        match index {
364                            Some(index) => labels.insert(index, a.value.clone()),
365                            // push it to the end if `after_name` can't be found
366                            None => labels.push(a.value.clone()),
367                        }
368                    }
369                    None => labels.push(a.value.clone()),
370                }
371
372                Ok((name, UserDefinedTypeRepresentation::Enum { labels }))
373            }
374            UserDefinedTypeRepresentation::Composite { .. } => Err(MigrateError::builder()
375                .kind(MigrateErrorKind::AlterTypeInvalidOp(
376                    other.operation.clone(),
377                ))
378                .statement_a(Statement::CreateType {
379                    name,
380                    representation,
381                })
382                .statement_b(Statement::AlterType(other.clone()))
383                .build()),
384        },
385        AlterTypeOperation::RenameValue(rv) => match representation {
386            UserDefinedTypeRepresentation::Enum { labels } => {
387                let labels = labels
388                    .into_iter()
389                    .map(|l| if l == rv.from { rv.to.clone() } else { l })
390                    .collect::<Vec<_>>();
391
392                Ok((name, UserDefinedTypeRepresentation::Enum { labels }))
393            }
394            UserDefinedTypeRepresentation::Composite { .. } => Err(MigrateError::builder()
395                .kind(MigrateErrorKind::AlterTypeInvalidOp(
396                    other.operation.clone(),
397                ))
398                .statement_a(Statement::CreateType {
399                    name,
400                    representation,
401                })
402                .statement_b(Statement::AlterType(other.clone()))
403                .build()),
404        },
405    }
406}