1use std::fmt;
2
3use bon::bon;
4use sqlparser::ast::{
5 AlterColumnOperation, AlterTableOperation, AlterType, AlterTypeAddValuePosition,
6 AlterTypeOperation, ColumnOption, ColumnOptionDef, CreateTable, GeneratedAs, ObjectName,
7 ObjectNamePart, ObjectType, Statement, UserDefinedTypeRepresentation,
8};
9use thiserror::Error;
10
11#[derive(Error, Debug)]
12pub struct MigrateError {
13 kind: MigrateErrorKind,
14 statement_a: Option<Statement>,
15 statement_b: Option<Statement>,
16}
17
18impl fmt::Display for MigrateError {
19 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
20 write!(
21 f,
22 "Oops, we couldn't migrate that: {reason}",
23 reason = self.kind
24 )?;
25 if let Some(statement_a) = &self.statement_a {
26 write!(f, "\n\nSubject:\n{statement_a}")?;
27 }
28 if let Some(statement_b) = &self.statement_b {
29 write!(f, "\n\nMigration:\n{statement_b}")?;
30 }
31 Ok(())
32 }
33}
34
35#[bon]
36impl MigrateError {
37 #[builder]
38 fn new(
39 kind: MigrateErrorKind,
40 statement_a: Option<Statement>,
41 statement_b: Option<Statement>,
42 ) -> Self {
43 Self {
44 kind,
45 statement_a,
46 statement_b,
47 }
48 }
49}
50
51#[derive(Error, Debug)]
52#[non_exhaustive]
53enum MigrateErrorKind {
54 #[error("can't migrate unnamed index")]
55 UnnamedIndex,
56 #[error("ALTER TABLE operation \"{0}\" not yet supported")]
57 AlterTableOpNotImplemented(AlterTableOperation),
58 #[error("invalid ALTER TYPE operation \"{0}\"")]
59 AlterTypeInvalidOp(AlterTypeOperation),
60 #[error("not yet supported")]
61 NotImplemented,
62}
63
/// Applies a migration to a schema description, consuming the original value.
pub(crate) trait Migrate: Sized {
    /// Applies `other` (the migration) to `self` (the subject).
    ///
    /// Returns `Ok(Some(next))` with the migrated value. Implementations return
    /// `Ok(None)` when the migration removes the subject entirely (e.g. a matching `DROP`).
    ///
    /// # Errors
    ///
    /// Returns a [`MigrateError`] when the subject/migration combination is not supported.
    fn migrate(self, other: &Self) -> Result<Option<Self>, MigrateError>;
}
67
68impl Migrate for Vec<Statement> {
69 fn migrate(self, other: &Self) -> Result<Option<Self>, MigrateError> {
70 let next: Self = self
71 .into_iter()
72 .filter_map(|sa| {
74 let orig = sa.clone();
75 match &sa {
76 Statement::CreateTable(ca) => other
77 .iter()
78 .find(|sb| match sb {
79 Statement::AlterTable { name, .. } => *name == ca.name,
80 Statement::Drop {
81 object_type, names, ..
82 } => {
83 *object_type == ObjectType::Table
84 && names.len() == 1
85 && names[0] == ca.name
86 }
87 _ => false,
88 })
89 .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
90 Statement::CreateIndex(a) => other
91 .iter()
92 .find(|sb| match sb {
93 Statement::Drop {
94 object_type, names, ..
95 } => {
96 *object_type == ObjectType::Index
97 && names.len() == 1
98 && Some(&names[0]) == a.name.as_ref()
99 }
100 _ => false,
101 })
102 .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
103 Statement::CreateType { name, .. } => other
104 .iter()
105 .find(|sb| match sb {
106 Statement::AlterType(b) => *name == b.name,
107 Statement::Drop {
108 object_type, names, ..
109 } => {
110 *object_type == ObjectType::Type
111 && names.len() == 1
112 && names[0] == *name
113 }
114 _ => false,
115 })
116 .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
117 Statement::CreateExtension { name, .. } => other
118 .iter()
119 .find(|sb| match sb {
120 Statement::DropExtension { names, .. } => names.contains(name),
121 _ => false,
122 })
123 .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
124 _ => Some(Err(MigrateError::builder()
125 .kind(MigrateErrorKind::NotImplemented)
126 .statement_a(sa.clone())
127 .build())),
128 }
129 })
130 .chain(other.iter().filter_map(|sb| match sb {
132 Statement::CreateTable(_)
133 | Statement::CreateIndex { .. }
134 | Statement::CreateType { .. }
135 | Statement::CreateExtension { .. } => Some(Ok(sb.clone())),
136 _ => None,
137 }))
138 .collect::<Result<_, _>>()?;
139 Ok(Some(next))
140 }
141}
142
143impl Migrate for Statement {
144 fn migrate(self, other: &Self) -> Result<Option<Self>, MigrateError> {
145 match self {
146 Self::CreateTable(ca) => match other {
147 Self::AlterTable {
148 name, operations, ..
149 } => {
150 if *name == ca.name {
151 Ok(Some(Self::CreateTable(migrate_alter_table(
152 ca, operations,
153 )?)))
154 } else {
155 Ok(Some(Self::CreateTable(ca)))
157 }
158 }
159 Self::Drop {
160 object_type, names, ..
161 } => {
162 if *object_type == ObjectType::Table && names.contains(&ca.name) {
163 Ok(None)
164 } else {
165 Ok(Some(Self::CreateTable(ca)))
167 }
168 }
169 _ => Err(MigrateError::builder()
170 .kind(MigrateErrorKind::NotImplemented)
171 .statement_a(Self::CreateTable(ca))
172 .statement_b(other.clone())
173 .build()),
174 },
175 Self::CreateIndex(a) => match other {
176 Self::Drop {
177 object_type, names, ..
178 } => {
179 let name = a.name.clone().ok_or_else(|| {
180 MigrateError::builder()
181 .kind(MigrateErrorKind::UnnamedIndex)
182 .statement_a(Self::CreateIndex(a.clone()))
183 .statement_b(other.clone())
184 .build()
185 })?;
186 if *object_type == ObjectType::Index && names.contains(&name) {
187 Ok(None)
188 } else {
189 Ok(Some(Self::CreateIndex(a)))
191 }
192 }
193 _ => Err(MigrateError::builder()
194 .kind(MigrateErrorKind::NotImplemented)
195 .statement_a(Self::CreateIndex(a))
196 .statement_b(other.clone())
197 .build()),
198 },
199 Self::CreateType {
200 name,
201 representation,
202 } => match other {
203 Self::AlterType(ba) => {
204 if name == ba.name {
205 let (name, representation) =
206 migrate_alter_type(name.clone(), representation.clone(), ba)?;
207 Ok(Some(Self::CreateType {
208 name,
209 representation,
210 }))
211 } else {
212 Ok(Some(Self::CreateType {
214 name,
215 representation,
216 }))
217 }
218 }
219 Self::Drop {
220 object_type, names, ..
221 } => {
222 if *object_type == ObjectType::Type && names.contains(&name) {
223 Ok(None)
224 } else {
225 Ok(Some(Self::CreateType {
227 name,
228 representation,
229 }))
230 }
231 }
232 _ => Err(MigrateError::builder()
233 .kind(MigrateErrorKind::NotImplemented)
234 .statement_a(Self::CreateType {
235 name,
236 representation,
237 })
238 .statement_b(other.clone())
239 .build()),
240 },
241 _ => Err(MigrateError::builder()
242 .kind(MigrateErrorKind::NotImplemented)
243 .statement_a(self)
244 .statement_b(other.clone())
245 .build()),
246 }
247 }
248}
249
250fn migrate_alter_table(
251 mut t: CreateTable,
252 ops: &[AlterTableOperation],
253) -> Result<CreateTable, MigrateError> {
254 for op in ops.iter() {
255 match op {
256 AlterTableOperation::AddColumn { column_def, .. } => {
257 t.columns.push(column_def.clone());
258 }
259 AlterTableOperation::DropColumn { column_name, .. } => {
260 t.columns.retain(|c| c.name != *column_name);
261 }
262 AlterTableOperation::AlterColumn { column_name, op } => {
263 t.columns.iter_mut().for_each(|c| {
264 if c.name != *column_name {
265 return;
266 }
267 match op {
268 AlterColumnOperation::SetNotNull => {
269 c.options.push(ColumnOptionDef {
270 name: None,
271 option: ColumnOption::NotNull,
272 });
273 }
274 AlterColumnOperation::DropNotNull => {
275 c.options
276 .retain(|o| !matches!(o.option, ColumnOption::NotNull));
277 }
278 AlterColumnOperation::SetDefault { value } => {
279 c.options
280 .retain(|o| !matches!(o.option, ColumnOption::Default(_)));
281 c.options.push(ColumnOptionDef {
282 name: None,
283 option: ColumnOption::Default(value.clone()),
284 });
285 }
286 AlterColumnOperation::DropDefault => {
287 c.options
288 .retain(|o| !matches!(o.option, ColumnOption::Default(_)));
289 }
290 AlterColumnOperation::SetDataType {
291 data_type,
292 using: _, } => {
294 c.data_type = data_type.clone();
295 }
296 AlterColumnOperation::AddGenerated {
297 generated_as,
298 sequence_options,
299 } => {
300 c.options
301 .retain(|o| !matches!(o.option, ColumnOption::Generated { .. }));
302 c.options.push(ColumnOptionDef {
303 name: None,
304 option: ColumnOption::Generated {
305 generated_as: generated_as
306 .clone()
307 .unwrap_or(GeneratedAs::Always),
308 sequence_options: sequence_options.clone(),
309 generation_expr: None,
310 generation_expr_mode: None,
311 generated_keyword: true,
312 },
313 });
314 }
315 }
316 });
317 }
318 op => {
319 return Err(MigrateError::builder()
320 .kind(MigrateErrorKind::AlterTableOpNotImplemented(op.clone()))
321 .statement_a(Statement::CreateTable(t.clone()))
322 .build())
323 }
324 }
325 }
326
327 Ok(t)
328}
329
330fn migrate_alter_type(
331 name: ObjectName,
332 representation: UserDefinedTypeRepresentation,
333 other: &AlterType,
334) -> Result<(ObjectName, UserDefinedTypeRepresentation), MigrateError> {
335 match &other.operation {
336 AlterTypeOperation::Rename(r) => {
337 let mut parts = name.0;
338 parts.pop();
339 parts.push(ObjectNamePart::Identifier(r.new_name.clone()));
340 let name = ObjectName(parts);
341
342 Ok((name, representation))
343 }
344 AlterTypeOperation::AddValue(a) => match representation {
345 UserDefinedTypeRepresentation::Enum { mut labels } => {
346 match &a.position {
347 Some(AlterTypeAddValuePosition::Before(before_name)) => {
348 let index = labels
349 .iter()
350 .enumerate()
351 .find(|(_, l)| *l == before_name)
352 .map(|(i, _)| i)
353 .unwrap_or(0);
355 labels.insert(index, a.value.clone());
356 }
357 Some(AlterTypeAddValuePosition::After(after_name)) => {
358 let index = labels
359 .iter()
360 .enumerate()
361 .find(|(_, l)| *l == after_name)
362 .map(|(i, _)| i + 1);
363 match index {
364 Some(index) => labels.insert(index, a.value.clone()),
365 None => labels.push(a.value.clone()),
367 }
368 }
369 None => labels.push(a.value.clone()),
370 }
371
372 Ok((name, UserDefinedTypeRepresentation::Enum { labels }))
373 }
374 UserDefinedTypeRepresentation::Composite { .. } => Err(MigrateError::builder()
375 .kind(MigrateErrorKind::AlterTypeInvalidOp(
376 other.operation.clone(),
377 ))
378 .statement_a(Statement::CreateType {
379 name,
380 representation,
381 })
382 .statement_b(Statement::AlterType(other.clone()))
383 .build()),
384 },
385 AlterTypeOperation::RenameValue(rv) => match representation {
386 UserDefinedTypeRepresentation::Enum { labels } => {
387 let labels = labels
388 .into_iter()
389 .map(|l| if l == rv.from { rv.to.clone() } else { l })
390 .collect::<Vec<_>>();
391
392 Ok((name, UserDefinedTypeRepresentation::Enum { labels }))
393 }
394 UserDefinedTypeRepresentation::Composite { .. } => Err(MigrateError::builder()
395 .kind(MigrateErrorKind::AlterTypeInvalidOp(
396 other.operation.clone(),
397 ))
398 .statement_a(Statement::CreateType {
399 name,
400 representation,
401 })
402 .statement_b(Statement::AlterType(other.clone()))
403 .build()),
404 },
405 }
406}