Skip to main content

vespertide_query/sql/
replace_constraint.rs

1use sea_query::{Alias, ForeignKey, Query, Table};
2
3use vespertide_core::{TableConstraint, TableDef};
4
5use super::helpers::{
6    build_sqlite_temp_table_create, recreate_indexes_after_rebuild, to_sea_fk_action,
7};
8use super::rename_table::build_rename_table;
9use super::types::{BuiltQuery, DatabaseBackend};
10use crate::error::QueryError;
11
12/// Build SQL queries to replace a constraint in-place.
13///
14/// For PostgreSQL/MySQL: DROP old FK + ADD new FK (two ALTER TABLE statements).
15/// For SQLite: single temp table recreation with the new constraint swapped in.
16///
17/// This avoids the double table recreation that would occur with separate
18/// RemoveConstraint + AddConstraint on SQLite.
19pub fn build_replace_constraint(
20    backend: &DatabaseBackend,
21    table: &str,
22    from: &TableConstraint,
23    to: &TableConstraint,
24    current_schema: &[TableDef],
25    pending_constraints: &[TableConstraint],
26) -> Result<Vec<BuiltQuery>, QueryError> {
27    match (from, to) {
28        (
29            TableConstraint::ForeignKey {
30                name: old_name,
31                columns: old_columns,
32                ..
33            },
34            TableConstraint::ForeignKey {
35                name: new_name,
36                columns: new_columns,
37                ref_table,
38                ref_columns,
39                on_delete,
40                on_update,
41            },
42        ) => {
43            if *backend == DatabaseBackend::Sqlite {
44                build_sqlite_constraint_replace(
45                    backend,
46                    table,
47                    from,
48                    to,
49                    current_schema,
50                    pending_constraints,
51                )
52            } else {
53                // PostgreSQL/MySQL: DROP old FK + ADD new FK
54                let old_fk_name = vespertide_naming::build_foreign_key_name(
55                    table,
56                    old_columns,
57                    old_name.as_deref(),
58                );
59                let fk_drop = ForeignKey::drop()
60                    .name(&old_fk_name)
61                    .table(Alias::new(table))
62                    .to_owned();
63
64                let new_fk_name = vespertide_naming::build_foreign_key_name(
65                    table,
66                    new_columns,
67                    new_name.as_deref(),
68                );
69                let mut fk_create = ForeignKey::create();
70                fk_create = fk_create.name(&new_fk_name).to_owned();
71                fk_create = fk_create.from_tbl(Alias::new(table)).to_owned();
72                for col in new_columns {
73                    fk_create = fk_create.from_col(Alias::new(col)).to_owned();
74                }
75                fk_create = fk_create.to_tbl(Alias::new(ref_table)).to_owned();
76                for col in ref_columns {
77                    fk_create = fk_create.to_col(Alias::new(col)).to_owned();
78                }
79                if let Some(action) = on_delete {
80                    fk_create = fk_create.on_delete(to_sea_fk_action(action)).to_owned();
81                }
82                if let Some(action) = on_update {
83                    fk_create = fk_create.on_update(to_sea_fk_action(action)).to_owned();
84                }
85
86                Ok(vec![
87                    BuiltQuery::DropForeignKey(Box::new(fk_drop)),
88                    BuiltQuery::CreateForeignKey(Box::new(fk_create)),
89                ])
90            }
91        }
92        // For non-FK constraints: SQLite uses single temp table, PG/MySQL uses remove + add
93        _ => {
94            if *backend == DatabaseBackend::Sqlite {
95                build_sqlite_constraint_replace(
96                    backend,
97                    table,
98                    from,
99                    to,
100                    current_schema,
101                    pending_constraints,
102                )
103            } else {
104                let mut queries = super::remove_constraint::build_remove_constraint(
105                    backend,
106                    table,
107                    from,
108                    current_schema,
109                    pending_constraints,
110                )?;
111
112                // Build a modified schema with the old constraint removed and new one added
113                let modified_schema: Vec<TableDef> = current_schema
114                    .iter()
115                    .map(|t| {
116                        if t.name == table {
117                            let mut modified = t.clone();
118                            modified.constraints.retain(|c| c != from);
119                            modified.constraints.push(to.clone());
120                            modified
121                        } else {
122                            t.clone()
123                        }
124                    })
125                    .collect();
126
127                queries.extend(super::add_constraint::build_add_constraint(
128                    backend,
129                    table,
130                    to,
131                    &modified_schema,
132                    pending_constraints,
133                )?);
134                Ok(queries)
135            }
136        }
137    }
138}
139
140/// SQLite: single temp table recreation with the constraint replaced.
141/// Works for all constraint types (FK, Check, Unique, Index, PK).
142fn build_sqlite_constraint_replace(
143    backend: &DatabaseBackend,
144    table: &str,
145    from: &TableConstraint,
146    to: &TableConstraint,
147    current_schema: &[TableDef],
148    pending_constraints: &[TableConstraint],
149) -> Result<Vec<BuiltQuery>, QueryError> {
150    let table_def = current_schema
151        .iter()
152        .find(|t| t.name == table)
153        .ok_or_else(|| {
154            QueryError::Other(format!(
155                "Table '{}' not found in current schema. SQLite requires current schema \
156                 information to replace constraints.",
157                table
158            ))
159        })?;
160
161    // Build new constraints: replace old constraint with new one
162    let new_constraints: Vec<TableConstraint> = table_def
163        .constraints
164        .iter()
165        .map(|c| if c == from { to.clone() } else { c.clone() })
166        .collect();
167
168    let temp_table = format!("{}_temp", table);
169
170    // 1. Create temporary table with replaced constraint
171    let create_query = build_sqlite_temp_table_create(
172        backend,
173        &temp_table,
174        table,
175        &table_def.columns,
176        &new_constraints,
177    );
178
179    // 2. Copy data (all columns)
180    let column_aliases: Vec<Alias> = table_def
181        .columns
182        .iter()
183        .map(|c| Alias::new(&c.name))
184        .collect();
185    let mut select_query = Query::select();
186    for col_alias in &column_aliases {
187        select_query = select_query.column(col_alias.clone()).to_owned();
188    }
189    select_query = select_query.from(Alias::new(table)).to_owned();
190
191    let insert_stmt = Query::insert()
192        .into_table(Alias::new(&temp_table))
193        .columns(column_aliases.clone())
194        .select_from(select_query)
195        .unwrap()
196        .to_owned();
197    let insert_query = BuiltQuery::Insert(Box::new(insert_stmt));
198
199    // 3. Drop original table
200    let drop_table = Table::drop().table(Alias::new(table)).to_owned();
201    let drop_query = BuiltQuery::DropTable(Box::new(drop_table));
202
203    // 4. Rename temporary table to original name
204    let rename_query = build_rename_table(&temp_table, table);
205
206    // 5. Recreate indexes (both regular and UNIQUE)
207    let index_queries =
208        recreate_indexes_after_rebuild(table, &table_def.constraints, pending_constraints);
209
210    let mut queries = vec![create_query, insert_query, drop_query, rename_query];
211    queries.extend(index_queries);
212    Ok(queries)
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{
        ColumnDef, ColumnType, ReferenceAction, SimpleColumnType, TableConstraint, TableDef,
    };

    /// A non-nullable column of a simple type with no extra attributes.
    fn plain_column(name: &str, ty: SimpleColumnType) -> ColumnDef {
        ColumnDef {
            name: name.into(),
            r#type: ColumnType::Simple(ty),
            nullable: false,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    /// Non-auto-increment primary key over the single column `id`.
    fn id_primary_key() -> TableConstraint {
        TableConstraint::PrimaryKey {
            auto_increment: false,
            columns: vec!["id".into()],
        }
    }

    /// The `fk_user` foreign key (`user_id` -> `users.id`) with the given actions.
    fn fk_user(
        on_delete: Option<ReferenceAction>,
        on_update: Option<ReferenceAction>,
    ) -> TableConstraint {
        TableConstraint::ForeignKey {
            name: Some("fk_user".into()),
            columns: vec!["user_id".into()],
            ref_table: "users".into(),
            ref_columns: vec!["id".into()],
            on_delete,
            on_update,
        }
    }

    /// A named UNIQUE constraint over the `email` column.
    fn unique_email(name: &str) -> TableConstraint {
        TableConstraint::Unique {
            name: Some(name.into()),
            columns: vec!["email".into()],
        }
    }

    /// Builds each query for `backend` and joins the SQL with `;\n`.
    fn render(queries: &[BuiltQuery], backend: DatabaseBackend) -> String {
        queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join(";\n")
    }

    /// `users(id PK)` plus `posts(id PK, user_id FK fk_user -> users.id)`.
    fn test_schema() -> Vec<TableDef> {
        let users = TableDef {
            name: "users".into(),
            columns: vec![plain_column("id", SimpleColumnType::Integer)],
            constraints: vec![id_primary_key()],
            description: None,
        };
        let posts = TableDef {
            name: "posts".into(),
            columns: vec![
                plain_column("id", SimpleColumnType::Integer),
                plain_column("user_id", SimpleColumnType::Integer),
            ],
            constraints: vec![id_primary_key(), fk_user(None, None)],
            description: None,
        };
        vec![users, posts]
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn replace_fk_on_delete(#[case] backend: DatabaseBackend) {
        let schema = test_schema();
        let from = fk_user(None, None);
        let to = fk_user(Some(ReferenceAction::Cascade), None);

        let queries = build_replace_constraint(&backend, "posts", &from, &to, &schema, &[])
            .expect("should succeed");
        let combined = render(&queries, backend);

        with_settings!({
            description => format!("replace FK on_delete for {:?}", backend),
            omit_expression => true,
            snapshot_suffix => format!("replace_fk_on_delete_{:?}", backend),
        }, {
            assert_snapshot!(combined);
        });
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn replace_fk_on_update(#[case] backend: DatabaseBackend) {
        let schema = test_schema();
        let from = fk_user(None, None);
        let to = fk_user(None, Some(ReferenceAction::Cascade));

        let queries = build_replace_constraint(&backend, "posts", &from, &to, &schema, &[])
            .expect("should succeed");
        let combined = render(&queries, backend);

        with_settings!({
            description => format!("replace FK on_update for {:?}", backend),
            omit_expression => true,
            snapshot_suffix => format!("replace_fk_on_update_{:?}", backend),
        }, {
            assert_snapshot!(combined);
        });
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn replace_unique_constraint(#[case] backend: DatabaseBackend) {
        // Non-FK constraint: PG/MySQL go through remove + add, SQLite rebuilds once.
        // The extra "other" table (listed first) exercises the untouched-table path
        // when the modified schema is built.
        let schema = vec![
            TableDef {
                name: "other".into(),
                description: None,
                columns: vec![plain_column("id", SimpleColumnType::Integer)],
                constraints: vec![],
            },
            TableDef {
                name: "users".into(),
                description: None,
                columns: vec![
                    plain_column("id", SimpleColumnType::Integer),
                    plain_column("email", SimpleColumnType::Text),
                ],
                constraints: vec![id_primary_key(), unique_email("uq_email")],
            },
        ];
        let from = unique_email("uq_email");
        let to = unique_email("uq_email_new");

        let queries = build_replace_constraint(&backend, "users", &from, &to, &schema, &[])
            .expect("should succeed");
        let combined = render(&queries, backend);

        with_settings!({
            description => format!("replace unique constraint for {:?}", backend),
            omit_expression => true,
            snapshot_suffix => format!("replace_unique_{:?}", backend),
        }, {
            assert_snapshot!(combined);
        });
    }

    #[test]
    fn replace_constraint_table_not_found_sqlite() {
        let unique_on_col = |name: &str| TableConstraint::Unique {
            name: Some(name.into()),
            columns: vec!["col".into()],
        };
        let from = unique_on_col("uq_old");
        let to = unique_on_col("uq_new");

        let err =
            build_replace_constraint(&DatabaseBackend::Sqlite, "missing", &from, &to, &[], &[])
                .unwrap_err();
        let message = err.to_string();
        assert!(message.contains("missing"));
    }
}