sql_schema/migration.rs

1use std::fmt;
2
3use bon::bon;
4use sqlparser::ast::{
5    AlterColumnOperation, AlterTableOperation, AlterType, AlterTypeAddValuePosition,
6    AlterTypeOperation, ColumnOption, ColumnOptionDef, CreateTable, GeneratedAs, ObjectName,
7    ObjectNamePart, ObjectType, Statement, UserDefinedTypeRepresentation,
8};
9use thiserror::Error;
10
/// Error returned when a schema statement cannot be migrated.
///
/// Built with `MigrateError::builder()`. `Display` is implemented by hand rather than via a
/// `#[error(...)]` attribute: it prints the failure reason and, when recorded, the statement
/// being migrated ("Subject") and the migration statement being applied ("Migration").
#[derive(Error, Debug)]
pub struct MigrateError {
    /// Why the migration failed.
    kind: MigrateErrorKind,
    /// The statement being migrated, if known; printed under "Subject".
    statement_a: Option<Statement>,
    /// The migration statement that was being applied, if known; printed under "Migration".
    statement_b: Option<Statement>,
}
17
18impl fmt::Display for MigrateError {
19    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
20        write!(
21            f,
22            "Oops, we couldn't migrate that: {reason}",
23            reason = self.kind
24        )?;
25        if let Some(statement_a) = &self.statement_a {
26            write!(f, "\n\nSubject:\n{statement_a}")?;
27        }
28        if let Some(statement_b) = &self.statement_b {
29            write!(f, "\n\nMigration:\n{statement_b}")?;
30        }
31        Ok(())
32    }
33}
34
35#[bon]
36impl MigrateError {
37    #[builder]
38    fn new(
39        kind: MigrateErrorKind,
40        statement_a: Option<Statement>,
41        statement_b: Option<Statement>,
42    ) -> Self {
43        Self {
44            kind,
45            statement_a,
46            statement_b,
47        }
48    }
49}
50
/// The reason a migration failed; its `Display` text is embedded in [`MigrateError`]'s message.
#[derive(Error, Debug)]
// NOTE(review): `#[non_exhaustive]` only restricts code outside the defining crate, so it has
// no effect on this private enum — presumably kept in case the type is made public.
#[non_exhaustive]
enum MigrateErrorKind {
    /// A `DROP` statement was checked against an index that was declared without a name.
    #[error("can't migrate unnamed index")]
    UnnamedIndex,
    /// An `ALTER TABLE` operation that `migrate_alter_table` does not handle.
    #[error("ALTER TABLE operation \"{0}\" not yet supported")]
    AlterTableOpNotImplemented(AlterTableOperation),
    /// An `ALTER TYPE` operation that cannot apply to the target type (adding or renaming a
    /// value on a composite type).
    #[error("invalid ALTER TYPE operation \"{0}\"")]
    AlterTypeInvalidOp(AlterTypeOperation),
    /// Any other statement, or pairing of statements, that the migrator does not support.
    #[error("not yet supported")]
    NotImplemented,
}
63
/// Applies a migration to a schema representation.
///
/// `self` is the current schema and `other` the migration to apply to it. The result is
/// `Ok(Some(next))` with the migrated value, or `Ok(None)` when the migration removes it
/// entirely (the `Statement` impl returns this when a `DROP` targets the statement).
pub(crate) trait Migrate: Sized {
    /// Consumes `self` and returns it with `other` applied.
    fn migrate(self, other: &Self) -> Result<Option<Self>, MigrateError>;
}
67
68impl Migrate for Vec<Statement> {
69    fn migrate(self, other: &Self) -> Result<Option<Self>, MigrateError> {
70        let next: Self = self
71            .into_iter()
72            // perform any transformations on existing schema (e.g. ALTER/DROP table)
73            .filter_map(|sa| {
74                let orig = sa.clone();
75                match &sa {
76                    Statement::CreateTable(ca) => other
77                        .iter()
78                        .find(|sb| match sb {
79                            Statement::AlterTable { name, .. } => *name == ca.name,
80                            Statement::Drop {
81                                object_type, names, ..
82                            } => {
83                                *object_type == ObjectType::Table
84                                    && names.len() == 1
85                                    && names[0] == ca.name
86                            }
87                            _ => false,
88                        })
89                        .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
90                    Statement::CreateIndex(a) => other
91                        .iter()
92                        .find(|sb| match sb {
93                            Statement::Drop {
94                                object_type, names, ..
95                            } => {
96                                *object_type == ObjectType::Index
97                                    && names.len() == 1
98                                    && Some(&names[0]) == a.name.as_ref()
99                            }
100                            _ => false,
101                        })
102                        .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
103                    Statement::CreateType { name, .. } => other
104                        .iter()
105                        .find(|sb| match sb {
106                            Statement::AlterType(b) => *name == b.name,
107                            Statement::Drop {
108                                object_type, names, ..
109                            } => {
110                                *object_type == ObjectType::Type
111                                    && names.len() == 1
112                                    && names[0] == *name
113                            }
114                            _ => false,
115                        })
116                        .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
117                    Statement::CreateExtension { name, .. } => other
118                        .iter()
119                        .find(|sb| match sb {
120                            Statement::DropExtension { names, .. } => names.contains(name),
121                            _ => false,
122                        })
123                        .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
124                    Statement::CreateDomain(a) => other
125                        .iter()
126                        .find(|sb| match sb {
127                            Statement::DropDomain(b) => a.name == b.name,
128                            _ => false,
129                        })
130                        .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
131                    _ => Some(Err(MigrateError::builder()
132                        .kind(MigrateErrorKind::NotImplemented)
133                        .statement_a(sa.clone())
134                        .build())),
135                }
136            })
137            // CREATE table etc.
138            .chain(other.iter().filter_map(|sb| match sb {
139                Statement::CreateTable(_)
140                | Statement::CreateIndex { .. }
141                | Statement::CreateType { .. }
142                | Statement::CreateExtension { .. }
143                | Statement::CreateDomain(..) => Some(Ok(sb.clone())),
144                _ => None,
145            }))
146            .collect::<Result<_, _>>()?;
147        Ok(Some(next))
148    }
149}
150
151impl Migrate for Statement {
152    fn migrate(self, other: &Self) -> Result<Option<Self>, MigrateError> {
153        match self {
154            Self::CreateTable(ca) => match other {
155                Self::AlterTable {
156                    name, operations, ..
157                } => {
158                    if *name == ca.name {
159                        Ok(Some(Self::CreateTable(migrate_alter_table(
160                            ca, operations,
161                        )?)))
162                    } else {
163                        // ALTER TABLE statement for another table
164                        Ok(Some(Self::CreateTable(ca)))
165                    }
166                }
167                Self::Drop {
168                    object_type, names, ..
169                } => {
170                    if *object_type == ObjectType::Table && names.contains(&ca.name) {
171                        Ok(None)
172                    } else {
173                        // DROP statement is for another table
174                        Ok(Some(Self::CreateTable(ca)))
175                    }
176                }
177                _ => Err(MigrateError::builder()
178                    .kind(MigrateErrorKind::NotImplemented)
179                    .statement_a(Self::CreateTable(ca))
180                    .statement_b(other.clone())
181                    .build()),
182            },
183            Self::CreateIndex(a) => match other {
184                Self::Drop {
185                    object_type, names, ..
186                } => {
187                    let name = a.name.clone().ok_or_else(|| {
188                        MigrateError::builder()
189                            .kind(MigrateErrorKind::UnnamedIndex)
190                            .statement_a(Self::CreateIndex(a.clone()))
191                            .statement_b(other.clone())
192                            .build()
193                    })?;
194                    if *object_type == ObjectType::Index && names.contains(&name) {
195                        Ok(None)
196                    } else {
197                        // DROP statement is for another index
198                        Ok(Some(Self::CreateIndex(a)))
199                    }
200                }
201                _ => Err(MigrateError::builder()
202                    .kind(MigrateErrorKind::NotImplemented)
203                    .statement_a(Self::CreateIndex(a))
204                    .statement_b(other.clone())
205                    .build()),
206            },
207            Self::CreateType {
208                name,
209                representation,
210            } => match other {
211                Self::AlterType(ba) => {
212                    if name == ba.name {
213                        let (name, representation) =
214                            migrate_alter_type(name.clone(), representation.clone(), ba)?;
215                        Ok(Some(Self::CreateType {
216                            name,
217                            representation,
218                        }))
219                    } else {
220                        // ALTER TYPE statement for another type
221                        Ok(Some(Self::CreateType {
222                            name,
223                            representation,
224                        }))
225                    }
226                }
227                Self::Drop {
228                    object_type, names, ..
229                } => {
230                    if *object_type == ObjectType::Type && names.contains(&name) {
231                        Ok(None)
232                    } else {
233                        // DROP statement is for another type
234                        Ok(Some(Self::CreateType {
235                            name,
236                            representation,
237                        }))
238                    }
239                }
240                _ => Err(MigrateError::builder()
241                    .kind(MigrateErrorKind::NotImplemented)
242                    .statement_a(Self::CreateType {
243                        name,
244                        representation,
245                    })
246                    .statement_b(other.clone())
247                    .build()),
248            },
249            _ => Err(MigrateError::builder()
250                .kind(MigrateErrorKind::NotImplemented)
251                .statement_a(self)
252                .statement_b(other.clone())
253                .build()),
254        }
255    }
256}
257
258fn migrate_alter_table(
259    mut t: CreateTable,
260    ops: &[AlterTableOperation],
261) -> Result<CreateTable, MigrateError> {
262    for op in ops.iter() {
263        match op {
264            AlterTableOperation::AddColumn { column_def, .. } => {
265                t.columns.push(column_def.clone());
266            }
267            AlterTableOperation::DropColumn { column_name, .. } => {
268                t.columns.retain(|c| c.name != *column_name);
269            }
270            AlterTableOperation::AlterColumn { column_name, op } => {
271                t.columns.iter_mut().for_each(|c| {
272                    if c.name != *column_name {
273                        return;
274                    }
275                    match op {
276                        AlterColumnOperation::SetNotNull => {
277                            c.options.push(ColumnOptionDef {
278                                name: None,
279                                option: ColumnOption::NotNull,
280                            });
281                        }
282                        AlterColumnOperation::DropNotNull => {
283                            c.options
284                                .retain(|o| !matches!(o.option, ColumnOption::NotNull));
285                        }
286                        AlterColumnOperation::SetDefault { value } => {
287                            c.options
288                                .retain(|o| !matches!(o.option, ColumnOption::Default(_)));
289                            c.options.push(ColumnOptionDef {
290                                name: None,
291                                option: ColumnOption::Default(value.clone()),
292                            });
293                        }
294                        AlterColumnOperation::DropDefault => {
295                            c.options
296                                .retain(|o| !matches!(o.option, ColumnOption::Default(_)));
297                        }
298                        AlterColumnOperation::SetDataType {
299                            data_type,
300                            using: _, // not applicable since we're not running the query
301                        } => {
302                            c.data_type = data_type.clone();
303                        }
304                        AlterColumnOperation::AddGenerated {
305                            generated_as,
306                            sequence_options,
307                        } => {
308                            c.options
309                                .retain(|o| !matches!(o.option, ColumnOption::Generated { .. }));
310                            c.options.push(ColumnOptionDef {
311                                name: None,
312                                option: ColumnOption::Generated {
313                                    generated_as: generated_as
314                                        .clone()
315                                        .unwrap_or(GeneratedAs::Always),
316                                    sequence_options: sequence_options.clone(),
317                                    generation_expr: None,
318                                    generation_expr_mode: None,
319                                    generated_keyword: true,
320                                },
321                            });
322                        }
323                    }
324                });
325            }
326            op => {
327                return Err(MigrateError::builder()
328                    .kind(MigrateErrorKind::AlterTableOpNotImplemented(op.clone()))
329                    .statement_a(Statement::CreateTable(t.clone()))
330                    .build())
331            }
332        }
333    }
334
335    Ok(t)
336}
337
338fn migrate_alter_type(
339    name: ObjectName,
340    representation: UserDefinedTypeRepresentation,
341    other: &AlterType,
342) -> Result<(ObjectName, UserDefinedTypeRepresentation), MigrateError> {
343    match &other.operation {
344        AlterTypeOperation::Rename(r) => {
345            let mut parts = name.0;
346            parts.pop();
347            parts.push(ObjectNamePart::Identifier(r.new_name.clone()));
348            let name = ObjectName(parts);
349
350            Ok((name, representation))
351        }
352        AlterTypeOperation::AddValue(a) => match representation {
353            UserDefinedTypeRepresentation::Enum { mut labels } => {
354                match &a.position {
355                    Some(AlterTypeAddValuePosition::Before(before_name)) => {
356                        let index = labels
357                            .iter()
358                            .enumerate()
359                            .find(|(_, l)| *l == before_name)
360                            .map(|(i, _)| i)
361                            // insert at the beginning if `before_name` can't be found
362                            .unwrap_or(0);
363                        labels.insert(index, a.value.clone());
364                    }
365                    Some(AlterTypeAddValuePosition::After(after_name)) => {
366                        let index = labels
367                            .iter()
368                            .enumerate()
369                            .find(|(_, l)| *l == after_name)
370                            .map(|(i, _)| i + 1);
371                        match index {
372                            Some(index) => labels.insert(index, a.value.clone()),
373                            // push it to the end if `after_name` can't be found
374                            None => labels.push(a.value.clone()),
375                        }
376                    }
377                    None => labels.push(a.value.clone()),
378                }
379
380                Ok((name, UserDefinedTypeRepresentation::Enum { labels }))
381            }
382            UserDefinedTypeRepresentation::Composite { .. } => Err(MigrateError::builder()
383                .kind(MigrateErrorKind::AlterTypeInvalidOp(
384                    other.operation.clone(),
385                ))
386                .statement_a(Statement::CreateType {
387                    name,
388                    representation,
389                })
390                .statement_b(Statement::AlterType(other.clone()))
391                .build()),
392        },
393        AlterTypeOperation::RenameValue(rv) => match representation {
394            UserDefinedTypeRepresentation::Enum { labels } => {
395                let labels = labels
396                    .into_iter()
397                    .map(|l| if l == rv.from { rv.to.clone() } else { l })
398                    .collect::<Vec<_>>();
399
400                Ok((name, UserDefinedTypeRepresentation::Enum { labels }))
401            }
402            UserDefinedTypeRepresentation::Composite { .. } => Err(MigrateError::builder()
403                .kind(MigrateErrorKind::AlterTypeInvalidOp(
404                    other.operation.clone(),
405                ))
406                .statement_a(Statement::CreateType {
407                    name,
408                    representation,
409                })
410                .statement_b(Statement::AlterType(other.clone()))
411                .build()),
412        },
413    }
414}