Skip to main content

vespertide_query/sql/
replace_constraint.rs

1use sea_query::{Alias, ForeignKey, Query, Table};
2
3use vespertide_core::{TableConstraint, TableDef};
4
5use super::helpers::{
6    build_sqlite_temp_table_create, recreate_indexes_after_rebuild, to_sea_fk_action,
7};
8use super::rename_table::build_rename_table;
9use super::types::{BuiltQuery, DatabaseBackend};
10use crate::error::QueryError;
11
12/// Build SQL queries to replace a constraint in-place.
13///
14/// For PostgreSQL/MySQL: DROP old FK + ADD new FK (two ALTER TABLE statements).
15/// For `SQLite`: single temp table recreation with the new constraint swapped in.
16///
17/// This avoids the double table recreation that would occur with separate
18/// `RemoveConstraint` + `AddConstraint` on `SQLite`.
19pub fn build_replace_constraint(
20    backend: DatabaseBackend,
21    table: &str,
22    from: &TableConstraint,
23    to: &TableConstraint,
24    current_schema: &[TableDef],
25    pending_constraints: &[TableConstraint],
26) -> Result<Vec<BuiltQuery>, QueryError> {
27    match (from, to) {
28        (
29            TableConstraint::ForeignKey {
30                name: old_name,
31                columns: old_columns,
32                ..
33            },
34            TableConstraint::ForeignKey {
35                name: new_name,
36                columns: new_columns,
37                ref_table,
38                ref_columns,
39                on_delete,
40                on_update,
41                ..
42            },
43        ) => {
44            if backend == DatabaseBackend::Sqlite {
45                build_sqlite_constraint_replace(
46                    backend,
47                    table,
48                    from,
49                    to,
50                    current_schema,
51                    pending_constraints,
52                )
53            } else {
54                Ok(build_direct_foreign_key_replace(
55                    table,
56                    old_name.as_deref(),
57                    old_columns,
58                    new_name.as_deref(),
59                    new_columns,
60                    ref_table,
61                    ref_columns,
62                    on_delete.as_ref(),
63                    on_update.as_ref(),
64                ))
65            }
66        }
67        // For non-FK constraints: SQLite uses single temp table, PG/MySQL uses remove + add
68        _ => {
69            if backend == DatabaseBackend::Sqlite {
70                build_sqlite_constraint_replace(
71                    backend,
72                    table,
73                    from,
74                    to,
75                    current_schema,
76                    pending_constraints,
77                )
78            } else {
79                let mut queries = super::remove_constraint::build_remove_constraint(
80                    backend,
81                    table,
82                    from,
83                    current_schema,
84                    pending_constraints,
85                )?;
86
87                // Build a modified schema with the old constraint removed and new one added
88                let modified_schema: Vec<TableDef> = current_schema
89                    .iter()
90                    .map(|t| {
91                        if t.name == table {
92                            let mut modified = t.clone();
93                            modified.constraints.retain(|c| c != from);
94                            modified.constraints.push(to.clone());
95                            modified
96                        } else {
97                            t.clone()
98                        }
99                    })
100                    .collect();
101
102                queries.extend(super::add_constraint::build_add_constraint(
103                    backend,
104                    table,
105                    to,
106                    &modified_schema,
107                    pending_constraints,
108                )?);
109                Ok(queries)
110            }
111        }
112    }
113}
114
115#[expect(
116    clippy::too_many_arguments,
117    reason = "mirrors foreign key action fields"
118)]
119fn build_direct_foreign_key_replace<T: AsRef<str>, U: AsRef<str>, V: AsRef<str>>(
120    table: &str,
121    old_name: Option<&str>,
122    old_columns: &[T],
123    new_name: Option<&str>,
124    new_columns: &[U],
125    ref_table: &str,
126    ref_columns: &[V],
127    on_delete: Option<&vespertide_core::ReferenceAction>,
128    on_update: Option<&vespertide_core::ReferenceAction>,
129) -> Vec<BuiltQuery> {
130    let old_fk_name = vespertide_naming::build_foreign_key_name(table, old_columns, old_name);
131    let fk_drop = ForeignKey::drop()
132        .name(&old_fk_name)
133        .table(Alias::new(table))
134        .to_owned();
135    let fk_create = build_replacement_foreign_key(
136        table,
137        new_name,
138        new_columns,
139        ref_table,
140        ref_columns,
141        on_delete,
142        on_update,
143    );
144
145    vec![
146        BuiltQuery::DropForeignKey(Box::new(fk_drop)),
147        BuiltQuery::CreateForeignKey(Box::new(fk_create)),
148    ]
149}
150
151fn build_replacement_foreign_key<T: AsRef<str>, U: AsRef<str>>(
152    table: &str,
153    new_name: Option<&str>,
154    new_columns: &[T],
155    ref_table: &str,
156    ref_columns: &[U],
157    on_delete: Option<&vespertide_core::ReferenceAction>,
158    on_update: Option<&vespertide_core::ReferenceAction>,
159) -> sea_query::ForeignKeyCreateStatement {
160    let new_fk_name = vespertide_naming::build_foreign_key_name(table, new_columns, new_name);
161    let mut fk_create = ForeignKey::create();
162    fk_create.name(&new_fk_name);
163    fk_create.from_tbl(Alias::new(table));
164    for col in new_columns {
165        fk_create.from_col(Alias::new(col.as_ref()));
166    }
167    fk_create.to_tbl(Alias::new(ref_table));
168    for col in ref_columns {
169        fk_create.to_col(Alias::new(col.as_ref()));
170    }
171    if let Some(action) = on_delete {
172        fk_create.on_delete(to_sea_fk_action(action));
173    }
174    if let Some(action) = on_update {
175        fk_create.on_update(to_sea_fk_action(action));
176    }
177    fk_create
178}
179
180/// `SQLite`: single temp table recreation with the constraint replaced.
181/// Works for all constraint types (FK, Check, Unique, Index, PK).
182fn build_sqlite_constraint_replace(
183    backend: DatabaseBackend,
184    table: &str,
185    from: &TableConstraint,
186    to: &TableConstraint,
187    current_schema: &[TableDef],
188    pending_constraints: &[TableConstraint],
189) -> Result<Vec<BuiltQuery>, QueryError> {
190    let table_def = current_schema
191        .iter()
192        .find(|t| t.name == table)
193        .ok_or_else(|| {
194            QueryError::SchemaError(format!(
195                "Table '{table}' not found in current schema. SQLite requires current schema \
196                 information to replace constraints."
197            ))
198        })?;
199
200    // Build new constraints: replace old constraint with new one
201    let new_constraints: Vec<TableConstraint> = table_def
202        .constraints
203        .iter()
204        .map(|c| if c == from { to.clone() } else { c.clone() })
205        .collect();
206
207    let temp_table = format!("{table}_temp");
208
209    // 1. Create temporary table with replaced constraint
210    let create_query = build_sqlite_temp_table_create(
211        backend,
212        &temp_table,
213        table,
214        &table_def.columns,
215        &new_constraints,
216    );
217
218    // 2. Copy data (all columns)
219    let column_aliases: Vec<Alias> = table_def
220        .columns
221        .iter()
222        .map(|c| Alias::new(&c.name))
223        .collect();
224    let mut select_query = Query::select();
225    for col_alias in &column_aliases {
226        select_query.column(col_alias.clone());
227    }
228    select_query.from(Alias::new(table));
229
230    let insert_stmt = Query::insert()
231        .into_table(Alias::new(&temp_table))
232        .columns(column_aliases.clone())
233        .select_from(select_query)
234        .unwrap()
235        .to_owned();
236    let insert_query = BuiltQuery::Insert(Box::new(insert_stmt));
237
238    // 3. Drop original table
239    let drop_table = Table::drop().table(Alias::new(table)).to_owned();
240    let drop_query = BuiltQuery::DropTable(Box::new(drop_table));
241
242    // 4. Rename temporary table to original name
243    let rename_query = build_rename_table(&temp_table, table);
244
245    // 5. Recreate indexes (both regular and UNIQUE)
246    let index_queries =
247        recreate_indexes_after_rebuild(table, &table_def.constraints, pending_constraints);
248
249    let mut queries = vec![create_query, insert_query, drop_query, rename_query];
250    queries.extend(index_queries);
251    Ok(queries)
252}
253
#[cfg(test)]
mod tests {
    use super::*;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{
        ColumnDef, ColumnType, ReferenceAction, SimpleColumnType, TableConstraint, TableDef,
    };

    /// Non-nullable column of the given simple type with every optional attribute unset.
    fn column(name: &'static str, ty: SimpleColumnType) -> ColumnDef {
        ColumnDef {
            name: name.into(),
            r#type: ColumnType::Simple(ty),
            nullable: false,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    /// Non-auto-increment primary key on `id` with the default addition strategy.
    fn id_primary_key() -> TableConstraint {
        TableConstraint::PrimaryKey {
            auto_increment: false,
            columns: vec!["id".into()],
            strategy: vespertide_core::PrimaryKeyAdditionStrategy::default(),
        }
    }

    /// `fk_user`: `user_id -> users.id` with the given referential actions.
    fn user_fk(
        on_delete: Option<ReferenceAction>,
        on_update: Option<ReferenceAction>,
    ) -> TableConstraint {
        TableConstraint::ForeignKey {
            name: Some("fk_user".into()),
            columns: vec!["user_id".into()],
            ref_table: "users".into(),
            ref_columns: vec!["id".into()],
            on_delete,
            on_update,
            orphan_strategy: vespertide_core::ForeignKeyOrphanStrategy::default(),
        }
    }

    /// Named single-column unique constraint that deletes duplicates, keeping the first.
    fn unique_on(name: &'static str, col: &'static str) -> TableConstraint {
        TableConstraint::Unique {
            name: Some(name.into()),
            columns: vec![col.into()],
            strategy: vespertide_core::UniqueConstraintStrategy::DeleteDuplicates {
                keep: vespertide_core::KeepPolicy::First,
            },
        }
    }

    /// Builds every query for `backend` and joins the SQL with `";\n"`.
    fn render(queries: &[BuiltQuery], backend: DatabaseBackend) -> String {
        queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join(";\n")
    }

    fn test_schema() -> Vec<TableDef> {
        let users = TableDef {
            name: "users".into(),
            columns: vec![column("id", SimpleColumnType::Integer)],
            constraints: vec![id_primary_key()],
            description: None,
        };
        let posts = TableDef {
            name: "posts".into(),
            columns: vec![
                column("id", SimpleColumnType::Integer),
                column("user_id", SimpleColumnType::Integer),
            ],
            constraints: vec![id_primary_key(), user_fk(None, None)],
            description: None,
        };
        vec![users, posts]
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn replace_fk_on_delete(#[case] backend: DatabaseBackend) {
        let schema = test_schema();
        let from = user_fk(None, None);
        let to = user_fk(Some(ReferenceAction::Cascade), None);

        let queries = build_replace_constraint(backend, "posts", &from, &to, &schema, &[])
            .expect("should succeed");
        let combined = render(&queries, backend);

        with_settings!({
            description => format!("replace FK on_delete for {:?}", backend),
            omit_expression => true,
            snapshot_suffix => format!("replace_fk_on_delete_{:?}", backend),
        }, {
            assert_snapshot!(combined);
        });
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn replace_fk_on_update(#[case] backend: DatabaseBackend) {
        let schema = test_schema();
        let from = user_fk(None, None);
        let to = user_fk(None, Some(ReferenceAction::Cascade));

        let queries = build_replace_constraint(backend, "posts", &from, &to, &schema, &[])
            .expect("should succeed");
        let combined = render(&queries, backend);

        with_settings!({
            description => format!("replace FK on_update for {:?}", backend),
            omit_expression => true,
            snapshot_suffix => format!("replace_fk_on_update_{:?}", backend),
        }, {
            assert_snapshot!(combined);
        });
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn replace_unique_constraint(#[case] backend: DatabaseBackend) {
        // Non-FK constraint: PG/MySQL uses remove+add, SQLite uses temp table.
        // The extra "other" table exercises the untouched-table path of the schema rewrite.
        let other = TableDef {
            name: "other".into(),
            description: None,
            columns: vec![column("id", SimpleColumnType::Integer)],
            constraints: vec![],
        };
        let users = TableDef {
            name: "users".into(),
            description: None,
            columns: vec![
                column("id", SimpleColumnType::Integer),
                column("email", SimpleColumnType::Text),
            ],
            constraints: vec![id_primary_key(), unique_on("uq_email", "email")],
        };
        let schema = vec![other, users];
        let from = unique_on("uq_email", "email");
        let to = unique_on("uq_email_new", "email");

        let queries = build_replace_constraint(backend, "users", &from, &to, &schema, &[])
            .expect("should succeed");
        let combined = render(&queries, backend);

        with_settings!({
            description => format!("replace unique constraint for {:?}", backend),
            omit_expression => true,
            snapshot_suffix => format!("replace_unique_{:?}", backend),
        }, {
            assert_snapshot!(combined);
        });
    }

    #[test]
    fn replace_constraint_table_not_found_sqlite() {
        let from = unique_on("uq_old", "col");
        let to = unique_on("uq_new", "col");

        let err =
            build_replace_constraint(DatabaseBackend::Sqlite, "missing", &from, &to, &[], &[])
                .unwrap_err();
        assert!(err.to_string().contains("missing"));
    }
}