sql_schema/
migration.rs

1use std::fmt;
2
3use bon::bon;
4use sqlparser::ast::{
5    AlterColumnOperation, AlterTableOperation, AlterType, AlterTypeAddValuePosition,
6    AlterTypeOperation, ColumnOption, ColumnOptionDef, CreateTable, GeneratedAs, ObjectName,
7    ObjectNamePart, ObjectType, Statement, UserDefinedTypeRepresentation,
8};
9use thiserror::Error;
10
11#[derive(Error, Debug)]
12pub struct MigrateError {
13    kind: MigrateErrorKind,
14    statement_a: Option<Box<Statement>>,
15    statement_b: Option<Box<Statement>>,
16}
17
18impl fmt::Display for MigrateError {
19    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
20        write!(
21            f,
22            "Oops, we couldn't migrate that: {reason}",
23            reason = self.kind
24        )?;
25        if let Some(statement_a) = &self.statement_a {
26            write!(f, "\n\nSubject:\n{statement_a}")?;
27        }
28        if let Some(statement_b) = &self.statement_b {
29            write!(f, "\n\nMigration:\n{statement_b}")?;
30        }
31        Ok(())
32    }
33}
34
35#[bon]
36impl MigrateError {
37    #[builder]
38    fn new(
39        kind: MigrateErrorKind,
40        statement_a: Option<Statement>,
41        statement_b: Option<Statement>,
42    ) -> Self {
43        Self {
44            kind,
45            statement_a: statement_a.map(Box::new),
46            statement_b: statement_b.map(Box::new),
47        }
48    }
49}
50
51#[derive(Error, Debug)]
52#[non_exhaustive]
53enum MigrateErrorKind {
54    #[error("can't migrate unnamed index")]
55    UnnamedIndex,
56    #[error("ALTER TABLE operation \"{0}\" not yet supported")]
57    AlterTableOpNotImplemented(Box<AlterTableOperation>),
58    #[error("invalid ALTER TYPE operation \"{0}\"")]
59    AlterTypeInvalidOp(Box<AlterTypeOperation>),
60    #[error("not yet supported")]
61    NotImplemented,
62}
63
64pub(crate) trait Migrate: Sized {
65    fn migrate(self, other: &Self) -> Result<Option<Self>, MigrateError>;
66}
67
68impl Migrate for Vec<Statement> {
69    fn migrate(self, other: &Self) -> Result<Option<Self>, MigrateError> {
70        let next: Self = self
71            .into_iter()
72            // perform any transformations on existing schema (e.g. ALTER/DROP table)
73            .filter_map(|sa| {
74                let orig = sa.clone();
75                match &sa {
76                    Statement::CreateTable(ca) => other
77                        .iter()
78                        .find(|sb| match sb {
79                            Statement::AlterTable { name, .. } => *name == ca.name,
80                            Statement::Drop {
81                                object_type, names, ..
82                            } => {
83                                *object_type == ObjectType::Table
84                                    && names.len() == 1
85                                    && names[0] == ca.name
86                            }
87                            _ => false,
88                        })
89                        .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
90                    Statement::CreateIndex(a) => other
91                        .iter()
92                        .find(|sb| match sb {
93                            Statement::Drop {
94                                object_type, names, ..
95                            } => {
96                                *object_type == ObjectType::Index
97                                    && names.len() == 1
98                                    && Some(&names[0]) == a.name.as_ref()
99                            }
100                            _ => false,
101                        })
102                        .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
103                    Statement::CreateType { name, .. } => other
104                        .iter()
105                        .find(|sb| match sb {
106                            Statement::AlterType(b) => *name == b.name,
107                            Statement::Drop {
108                                object_type, names, ..
109                            } => {
110                                *object_type == ObjectType::Type
111                                    && names.len() == 1
112                                    && names[0] == *name
113                            }
114                            _ => false,
115                        })
116                        .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
117                    Statement::CreateExtension { name, .. } => other
118                        .iter()
119                        .find(|sb| match sb {
120                            Statement::DropExtension { names, .. } => names.contains(name),
121                            _ => false,
122                        })
123                        .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
124                    Statement::CreateDomain(a) => other
125                        .iter()
126                        .find(|sb| match sb {
127                            Statement::DropDomain(b) => a.name == b.name,
128                            _ => false,
129                        })
130                        .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
131                    _ => Some(Err(MigrateError::builder()
132                        .kind(MigrateErrorKind::NotImplemented)
133                        .statement_a(sa.clone())
134                        .build())),
135                }
136            })
137            // CREATE table etc.
138            .chain(other.iter().filter_map(|sb| match sb {
139                Statement::CreateTable(_)
140                | Statement::CreateIndex { .. }
141                | Statement::CreateType { .. }
142                | Statement::CreateExtension { .. }
143                | Statement::CreateDomain(..) => Some(Ok(sb.clone())),
144                _ => None,
145            }))
146            .collect::<Result<_, _>>()?;
147        Ok(Some(next))
148    }
149}
150
151impl Migrate for Statement {
152    fn migrate(self, other: &Self) -> Result<Option<Self>, MigrateError> {
153        match self {
154            Self::CreateTable(ca) => match other {
155                Self::AlterTable {
156                    name, operations, ..
157                } => {
158                    if *name == ca.name {
159                        Ok(Some(Self::CreateTable(migrate_alter_table(
160                            ca, operations,
161                        )?)))
162                    } else {
163                        // ALTER TABLE statement for another table
164                        Ok(Some(Self::CreateTable(ca)))
165                    }
166                }
167                Self::Drop {
168                    object_type, names, ..
169                } => {
170                    if *object_type == ObjectType::Table && names.contains(&ca.name) {
171                        Ok(None)
172                    } else {
173                        // DROP statement is for another table
174                        Ok(Some(Self::CreateTable(ca)))
175                    }
176                }
177                _ => Err(MigrateError::builder()
178                    .kind(MigrateErrorKind::NotImplemented)
179                    .statement_a(Self::CreateTable(ca))
180                    .statement_b(other.clone())
181                    .build()),
182            },
183            Self::CreateIndex(a) => match other {
184                Self::Drop {
185                    object_type, names, ..
186                } => {
187                    let name = a.name.clone().ok_or_else(|| {
188                        MigrateError::builder()
189                            .kind(MigrateErrorKind::UnnamedIndex)
190                            .statement_a(Self::CreateIndex(a.clone()))
191                            .statement_b(other.clone())
192                            .build()
193                    })?;
194                    if *object_type == ObjectType::Index && names.contains(&name) {
195                        Ok(None)
196                    } else {
197                        // DROP statement is for another index
198                        Ok(Some(Self::CreateIndex(a)))
199                    }
200                }
201                _ => Err(MigrateError::builder()
202                    .kind(MigrateErrorKind::NotImplemented)
203                    .statement_a(Self::CreateIndex(a))
204                    .statement_b(other.clone())
205                    .build()),
206            },
207            Self::CreateType {
208                name,
209                representation,
210            } => match other {
211                Self::AlterType(ba) => {
212                    if name == ba.name {
213                        let (name, representation) =
214                            migrate_alter_type(name.clone(), representation.clone(), ba)?;
215                        Ok(Some(Self::CreateType {
216                            name,
217                            representation,
218                        }))
219                    } else {
220                        // ALTER TYPE statement for another type
221                        Ok(Some(Self::CreateType {
222                            name,
223                            representation,
224                        }))
225                    }
226                }
227                Self::Drop {
228                    object_type, names, ..
229                } => {
230                    if *object_type == ObjectType::Type && names.contains(&name) {
231                        Ok(None)
232                    } else {
233                        // DROP statement is for another type
234                        Ok(Some(Self::CreateType {
235                            name,
236                            representation,
237                        }))
238                    }
239                }
240                _ => Err(MigrateError::builder()
241                    .kind(MigrateErrorKind::NotImplemented)
242                    .statement_a(Self::CreateType {
243                        name,
244                        representation,
245                    })
246                    .statement_b(other.clone())
247                    .build()),
248            },
249            _ => Err(MigrateError::builder()
250                .kind(MigrateErrorKind::NotImplemented)
251                .statement_a(self)
252                .statement_b(other.clone())
253                .build()),
254        }
255    }
256}
257
258fn migrate_alter_table(
259    mut t: CreateTable,
260    ops: &[AlterTableOperation],
261) -> Result<CreateTable, MigrateError> {
262    for op in ops.iter() {
263        match op {
264            AlterTableOperation::AddColumn { column_def, .. } => {
265                t.columns.push(column_def.clone());
266            }
267            AlterTableOperation::DropColumn { column_name, .. } => {
268                t.columns.retain(|c| c.name != *column_name);
269            }
270            AlterTableOperation::AlterColumn { column_name, op } => {
271                t.columns.iter_mut().for_each(|c| {
272                    if c.name != *column_name {
273                        return;
274                    }
275                    match op {
276                        AlterColumnOperation::SetNotNull => {
277                            c.options.push(ColumnOptionDef {
278                                name: None,
279                                option: ColumnOption::NotNull,
280                            });
281                        }
282                        AlterColumnOperation::DropNotNull => {
283                            c.options
284                                .retain(|o| !matches!(o.option, ColumnOption::NotNull));
285                        }
286                        AlterColumnOperation::SetDefault { value } => {
287                            c.options
288                                .retain(|o| !matches!(o.option, ColumnOption::Default(_)));
289                            c.options.push(ColumnOptionDef {
290                                name: None,
291                                option: ColumnOption::Default(value.clone()),
292                            });
293                        }
294                        AlterColumnOperation::DropDefault => {
295                            c.options
296                                .retain(|o| !matches!(o.option, ColumnOption::Default(_)));
297                        }
298                        AlterColumnOperation::SetDataType {
299                            data_type,
300                            using: _, // not applicable since we're not running the query
301                        } => {
302                            c.data_type = data_type.clone();
303                        }
304                        AlterColumnOperation::AddGenerated {
305                            generated_as,
306                            sequence_options,
307                        } => {
308                            c.options
309                                .retain(|o| !matches!(o.option, ColumnOption::Generated { .. }));
310                            c.options.push(ColumnOptionDef {
311                                name: None,
312                                option: ColumnOption::Generated {
313                                    generated_as: generated_as
314                                        .clone()
315                                        .unwrap_or(GeneratedAs::Always),
316                                    sequence_options: sequence_options.clone(),
317                                    generation_expr: None,
318                                    generation_expr_mode: None,
319                                    generated_keyword: true,
320                                },
321                            });
322                        }
323                    }
324                });
325            }
326            op => {
327                return Err(MigrateError::builder()
328                    .kind(MigrateErrorKind::AlterTableOpNotImplemented(Box::new(
329                        op.clone(),
330                    )))
331                    .statement_a(Statement::CreateTable(t.clone()))
332                    .build())
333            }
334        }
335    }
336
337    Ok(t)
338}
339
340fn migrate_alter_type(
341    name: ObjectName,
342    representation: UserDefinedTypeRepresentation,
343    other: &AlterType,
344) -> Result<(ObjectName, UserDefinedTypeRepresentation), MigrateError> {
345    match &other.operation {
346        AlterTypeOperation::Rename(r) => {
347            let mut parts = name.0;
348            parts.pop();
349            parts.push(ObjectNamePart::Identifier(r.new_name.clone()));
350            let name = ObjectName(parts);
351
352            Ok((name, representation))
353        }
354        AlterTypeOperation::AddValue(a) => match representation {
355            UserDefinedTypeRepresentation::Enum { mut labels } => {
356                match &a.position {
357                    Some(AlterTypeAddValuePosition::Before(before_name)) => {
358                        let index = labels
359                            .iter()
360                            .enumerate()
361                            .find(|(_, l)| *l == before_name)
362                            .map(|(i, _)| i)
363                            // insert at the beginning if `before_name` can't be found
364                            .unwrap_or(0);
365                        labels.insert(index, a.value.clone());
366                    }
367                    Some(AlterTypeAddValuePosition::After(after_name)) => {
368                        let index = labels
369                            .iter()
370                            .enumerate()
371                            .find(|(_, l)| *l == after_name)
372                            .map(|(i, _)| i + 1);
373                        match index {
374                            Some(index) => labels.insert(index, a.value.clone()),
375                            // push it to the end if `after_name` can't be found
376                            None => labels.push(a.value.clone()),
377                        }
378                    }
379                    None => labels.push(a.value.clone()),
380                }
381
382                Ok((name, UserDefinedTypeRepresentation::Enum { labels }))
383            }
384            UserDefinedTypeRepresentation::Composite { .. } => Err(MigrateError::builder()
385                .kind(MigrateErrorKind::AlterTypeInvalidOp(Box::new(
386                    other.operation.clone(),
387                )))
388                .statement_a(Statement::CreateType {
389                    name,
390                    representation,
391                })
392                .statement_b(Statement::AlterType(other.clone()))
393                .build()),
394        },
395        AlterTypeOperation::RenameValue(rv) => match representation {
396            UserDefinedTypeRepresentation::Enum { labels } => {
397                let labels = labels
398                    .into_iter()
399                    .map(|l| if l == rv.from { rv.to.clone() } else { l })
400                    .collect::<Vec<_>>();
401
402                Ok((name, UserDefinedTypeRepresentation::Enum { labels }))
403            }
404            UserDefinedTypeRepresentation::Composite { .. } => Err(MigrateError::builder()
405                .kind(MigrateErrorKind::AlterTypeInvalidOp(Box::new(
406                    other.operation.clone(),
407                )))
408                .statement_a(Statement::CreateType {
409                    name,
410                    representation,
411                })
412                .statement_b(Statement::AlterType(other.clone()))
413                .build()),
414        },
415    }
416}