1use std::fmt;
2
3use bon::bon;
4use sqlparser::ast::{
5 AlterColumnOperation, AlterTableOperation, AlterType, AlterTypeAddValuePosition,
6 AlterTypeOperation, ColumnOption, ColumnOptionDef, CreateTable, GeneratedAs, ObjectName,
7 ObjectNamePart, ObjectType, Statement, UserDefinedTypeRepresentation,
8};
9use thiserror::Error;
10
11#[derive(Error, Debug)]
12pub struct MigrateError {
13 kind: MigrateErrorKind,
14 statement_a: Option<Box<Statement>>,
15 statement_b: Option<Box<Statement>>,
16}
17
18impl fmt::Display for MigrateError {
19 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
20 write!(
21 f,
22 "Oops, we couldn't migrate that: {reason}",
23 reason = self.kind
24 )?;
25 if let Some(statement_a) = &self.statement_a {
26 write!(f, "\n\nSubject:\n{statement_a}")?;
27 }
28 if let Some(statement_b) = &self.statement_b {
29 write!(f, "\n\nMigration:\n{statement_b}")?;
30 }
31 Ok(())
32 }
33}
34
35#[bon]
36impl MigrateError {
37 #[builder]
38 fn new(
39 kind: MigrateErrorKind,
40 statement_a: Option<Statement>,
41 statement_b: Option<Statement>,
42 ) -> Self {
43 Self {
44 kind,
45 statement_a: statement_a.map(Box::new),
46 statement_b: statement_b.map(Box::new),
47 }
48 }
49}
50
51#[derive(Error, Debug)]
52#[non_exhaustive]
53enum MigrateErrorKind {
54 #[error("can't migrate unnamed index")]
55 UnnamedIndex,
56 #[error("ALTER TABLE operation \"{0}\" not yet supported")]
57 AlterTableOpNotImplemented(Box<AlterTableOperation>),
58 #[error("invalid ALTER TYPE operation \"{0}\"")]
59 AlterTypeInvalidOp(Box<AlterTypeOperation>),
60 #[error("not yet supported")]
61 NotImplemented,
62}
63
64pub(crate) trait Migrate: Sized {
65 fn migrate(self, other: &Self) -> Result<Option<Self>, MigrateError>;
66}
67
68impl Migrate for Vec<Statement> {
69 fn migrate(self, other: &Self) -> Result<Option<Self>, MigrateError> {
70 let next: Self = self
71 .into_iter()
72 .filter_map(|sa| {
74 let orig = sa.clone();
75 match &sa {
76 Statement::CreateTable(ca) => other
77 .iter()
78 .find(|sb| match sb {
79 Statement::AlterTable { name, .. } => *name == ca.name,
80 Statement::Drop {
81 object_type, names, ..
82 } => {
83 *object_type == ObjectType::Table
84 && names.len() == 1
85 && names[0] == ca.name
86 }
87 _ => false,
88 })
89 .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
90 Statement::CreateIndex(a) => other
91 .iter()
92 .find(|sb| match sb {
93 Statement::Drop {
94 object_type, names, ..
95 } => {
96 *object_type == ObjectType::Index
97 && names.len() == 1
98 && Some(&names[0]) == a.name.as_ref()
99 }
100 _ => false,
101 })
102 .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
103 Statement::CreateType { name, .. } => other
104 .iter()
105 .find(|sb| match sb {
106 Statement::AlterType(b) => *name == b.name,
107 Statement::Drop {
108 object_type, names, ..
109 } => {
110 *object_type == ObjectType::Type
111 && names.len() == 1
112 && names[0] == *name
113 }
114 _ => false,
115 })
116 .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
117 Statement::CreateExtension { name, .. } => other
118 .iter()
119 .find(|sb| match sb {
120 Statement::DropExtension { names, .. } => names.contains(name),
121 _ => false,
122 })
123 .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
124 Statement::CreateDomain(a) => other
125 .iter()
126 .find(|sb| match sb {
127 Statement::DropDomain(b) => a.name == b.name,
128 _ => false,
129 })
130 .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
131 _ => Some(Err(MigrateError::builder()
132 .kind(MigrateErrorKind::NotImplemented)
133 .statement_a(sa.clone())
134 .build())),
135 }
136 })
137 .chain(other.iter().filter_map(|sb| match sb {
139 Statement::CreateTable(_)
140 | Statement::CreateIndex { .. }
141 | Statement::CreateType { .. }
142 | Statement::CreateExtension { .. }
143 | Statement::CreateDomain(..) => Some(Ok(sb.clone())),
144 _ => None,
145 }))
146 .collect::<Result<_, _>>()?;
147 Ok(Some(next))
148 }
149}
150
151impl Migrate for Statement {
152 fn migrate(self, other: &Self) -> Result<Option<Self>, MigrateError> {
153 match self {
154 Self::CreateTable(ca) => match other {
155 Self::AlterTable {
156 name, operations, ..
157 } => {
158 if *name == ca.name {
159 Ok(Some(Self::CreateTable(migrate_alter_table(
160 ca, operations,
161 )?)))
162 } else {
163 Ok(Some(Self::CreateTable(ca)))
165 }
166 }
167 Self::Drop {
168 object_type, names, ..
169 } => {
170 if *object_type == ObjectType::Table && names.contains(&ca.name) {
171 Ok(None)
172 } else {
173 Ok(Some(Self::CreateTable(ca)))
175 }
176 }
177 _ => Err(MigrateError::builder()
178 .kind(MigrateErrorKind::NotImplemented)
179 .statement_a(Self::CreateTable(ca))
180 .statement_b(other.clone())
181 .build()),
182 },
183 Self::CreateIndex(a) => match other {
184 Self::Drop {
185 object_type, names, ..
186 } => {
187 let name = a.name.clone().ok_or_else(|| {
188 MigrateError::builder()
189 .kind(MigrateErrorKind::UnnamedIndex)
190 .statement_a(Self::CreateIndex(a.clone()))
191 .statement_b(other.clone())
192 .build()
193 })?;
194 if *object_type == ObjectType::Index && names.contains(&name) {
195 Ok(None)
196 } else {
197 Ok(Some(Self::CreateIndex(a)))
199 }
200 }
201 _ => Err(MigrateError::builder()
202 .kind(MigrateErrorKind::NotImplemented)
203 .statement_a(Self::CreateIndex(a))
204 .statement_b(other.clone())
205 .build()),
206 },
207 Self::CreateType {
208 name,
209 representation,
210 } => match other {
211 Self::AlterType(ba) => {
212 if name == ba.name {
213 let (name, representation) =
214 migrate_alter_type(name.clone(), representation.clone(), ba)?;
215 Ok(Some(Self::CreateType {
216 name,
217 representation,
218 }))
219 } else {
220 Ok(Some(Self::CreateType {
222 name,
223 representation,
224 }))
225 }
226 }
227 Self::Drop {
228 object_type, names, ..
229 } => {
230 if *object_type == ObjectType::Type && names.contains(&name) {
231 Ok(None)
232 } else {
233 Ok(Some(Self::CreateType {
235 name,
236 representation,
237 }))
238 }
239 }
240 _ => Err(MigrateError::builder()
241 .kind(MigrateErrorKind::NotImplemented)
242 .statement_a(Self::CreateType {
243 name,
244 representation,
245 })
246 .statement_b(other.clone())
247 .build()),
248 },
249 _ => Err(MigrateError::builder()
250 .kind(MigrateErrorKind::NotImplemented)
251 .statement_a(self)
252 .statement_b(other.clone())
253 .build()),
254 }
255 }
256}
257
258fn migrate_alter_table(
259 mut t: CreateTable,
260 ops: &[AlterTableOperation],
261) -> Result<CreateTable, MigrateError> {
262 for op in ops.iter() {
263 match op {
264 AlterTableOperation::AddColumn { column_def, .. } => {
265 t.columns.push(column_def.clone());
266 }
267 AlterTableOperation::DropColumn { column_name, .. } => {
268 t.columns.retain(|c| c.name != *column_name);
269 }
270 AlterTableOperation::AlterColumn { column_name, op } => {
271 t.columns.iter_mut().for_each(|c| {
272 if c.name != *column_name {
273 return;
274 }
275 match op {
276 AlterColumnOperation::SetNotNull => {
277 c.options.push(ColumnOptionDef {
278 name: None,
279 option: ColumnOption::NotNull,
280 });
281 }
282 AlterColumnOperation::DropNotNull => {
283 c.options
284 .retain(|o| !matches!(o.option, ColumnOption::NotNull));
285 }
286 AlterColumnOperation::SetDefault { value } => {
287 c.options
288 .retain(|o| !matches!(o.option, ColumnOption::Default(_)));
289 c.options.push(ColumnOptionDef {
290 name: None,
291 option: ColumnOption::Default(value.clone()),
292 });
293 }
294 AlterColumnOperation::DropDefault => {
295 c.options
296 .retain(|o| !matches!(o.option, ColumnOption::Default(_)));
297 }
298 AlterColumnOperation::SetDataType {
299 data_type,
300 using: _, } => {
302 c.data_type = data_type.clone();
303 }
304 AlterColumnOperation::AddGenerated {
305 generated_as,
306 sequence_options,
307 } => {
308 c.options
309 .retain(|o| !matches!(o.option, ColumnOption::Generated { .. }));
310 c.options.push(ColumnOptionDef {
311 name: None,
312 option: ColumnOption::Generated {
313 generated_as: generated_as
314 .clone()
315 .unwrap_or(GeneratedAs::Always),
316 sequence_options: sequence_options.clone(),
317 generation_expr: None,
318 generation_expr_mode: None,
319 generated_keyword: true,
320 },
321 });
322 }
323 }
324 });
325 }
326 op => {
327 return Err(MigrateError::builder()
328 .kind(MigrateErrorKind::AlterTableOpNotImplemented(Box::new(
329 op.clone(),
330 )))
331 .statement_a(Statement::CreateTable(t.clone()))
332 .build())
333 }
334 }
335 }
336
337 Ok(t)
338}
339
340fn migrate_alter_type(
341 name: ObjectName,
342 representation: UserDefinedTypeRepresentation,
343 other: &AlterType,
344) -> Result<(ObjectName, UserDefinedTypeRepresentation), MigrateError> {
345 match &other.operation {
346 AlterTypeOperation::Rename(r) => {
347 let mut parts = name.0;
348 parts.pop();
349 parts.push(ObjectNamePart::Identifier(r.new_name.clone()));
350 let name = ObjectName(parts);
351
352 Ok((name, representation))
353 }
354 AlterTypeOperation::AddValue(a) => match representation {
355 UserDefinedTypeRepresentation::Enum { mut labels } => {
356 match &a.position {
357 Some(AlterTypeAddValuePosition::Before(before_name)) => {
358 let index = labels
359 .iter()
360 .enumerate()
361 .find(|(_, l)| *l == before_name)
362 .map(|(i, _)| i)
363 .unwrap_or(0);
365 labels.insert(index, a.value.clone());
366 }
367 Some(AlterTypeAddValuePosition::After(after_name)) => {
368 let index = labels
369 .iter()
370 .enumerate()
371 .find(|(_, l)| *l == after_name)
372 .map(|(i, _)| i + 1);
373 match index {
374 Some(index) => labels.insert(index, a.value.clone()),
375 None => labels.push(a.value.clone()),
377 }
378 }
379 None => labels.push(a.value.clone()),
380 }
381
382 Ok((name, UserDefinedTypeRepresentation::Enum { labels }))
383 }
384 UserDefinedTypeRepresentation::Composite { .. } => Err(MigrateError::builder()
385 .kind(MigrateErrorKind::AlterTypeInvalidOp(Box::new(
386 other.operation.clone(),
387 )))
388 .statement_a(Statement::CreateType {
389 name,
390 representation,
391 })
392 .statement_b(Statement::AlterType(other.clone()))
393 .build()),
394 },
395 AlterTypeOperation::RenameValue(rv) => match representation {
396 UserDefinedTypeRepresentation::Enum { labels } => {
397 let labels = labels
398 .into_iter()
399 .map(|l| if l == rv.from { rv.to.clone() } else { l })
400 .collect::<Vec<_>>();
401
402 Ok((name, UserDefinedTypeRepresentation::Enum { labels }))
403 }
404 UserDefinedTypeRepresentation::Composite { .. } => Err(MigrateError::builder()
405 .kind(MigrateErrorKind::AlterTypeInvalidOp(Box::new(
406 other.operation.clone(),
407 )))
408 .statement_a(Statement::CreateType {
409 name,
410 representation,
411 })
412 .statement_b(Statement::AlterType(other.clone()))
413 .build()),
414 },
415 }
416}