1use sea_query::{Alias, ForeignKey, Query, Table};
2
3use vespertide_core::{TableConstraint, TableDef};
4
5use super::helpers::{
6 build_sqlite_temp_table_create, recreate_indexes_after_rebuild, to_sea_fk_action,
7};
8use super::rename_table::build_rename_table;
9use super::types::{BuiltQuery, DatabaseBackend};
10use crate::error::QueryError;
11
12pub fn build_replace_constraint(
20 backend: &DatabaseBackend,
21 table: &str,
22 from: &TableConstraint,
23 to: &TableConstraint,
24 current_schema: &[TableDef],
25 pending_constraints: &[TableConstraint],
26) -> Result<Vec<BuiltQuery>, QueryError> {
27 match (from, to) {
28 (
29 TableConstraint::ForeignKey {
30 name: old_name,
31 columns: old_columns,
32 ..
33 },
34 TableConstraint::ForeignKey {
35 name: new_name,
36 columns: new_columns,
37 ref_table,
38 ref_columns,
39 on_delete,
40 on_update,
41 },
42 ) => {
43 if *backend == DatabaseBackend::Sqlite {
44 build_sqlite_constraint_replace(
45 backend,
46 table,
47 from,
48 to,
49 current_schema,
50 pending_constraints,
51 )
52 } else {
53 let old_fk_name = vespertide_naming::build_foreign_key_name(
55 table,
56 old_columns,
57 old_name.as_deref(),
58 );
59 let fk_drop = ForeignKey::drop()
60 .name(&old_fk_name)
61 .table(Alias::new(table))
62 .to_owned();
63
64 let new_fk_name = vespertide_naming::build_foreign_key_name(
65 table,
66 new_columns,
67 new_name.as_deref(),
68 );
69 let mut fk_create = ForeignKey::create();
70 fk_create = fk_create.name(&new_fk_name).to_owned();
71 fk_create = fk_create.from_tbl(Alias::new(table)).to_owned();
72 for col in new_columns {
73 fk_create = fk_create.from_col(Alias::new(col)).to_owned();
74 }
75 fk_create = fk_create.to_tbl(Alias::new(ref_table)).to_owned();
76 for col in ref_columns {
77 fk_create = fk_create.to_col(Alias::new(col)).to_owned();
78 }
79 if let Some(action) = on_delete {
80 fk_create = fk_create.on_delete(to_sea_fk_action(action)).to_owned();
81 }
82 if let Some(action) = on_update {
83 fk_create = fk_create.on_update(to_sea_fk_action(action)).to_owned();
84 }
85
86 Ok(vec![
87 BuiltQuery::DropForeignKey(Box::new(fk_drop)),
88 BuiltQuery::CreateForeignKey(Box::new(fk_create)),
89 ])
90 }
91 }
92 _ => {
94 if *backend == DatabaseBackend::Sqlite {
95 build_sqlite_constraint_replace(
96 backend,
97 table,
98 from,
99 to,
100 current_schema,
101 pending_constraints,
102 )
103 } else {
104 let mut queries = super::remove_constraint::build_remove_constraint(
105 backend,
106 table,
107 from,
108 current_schema,
109 pending_constraints,
110 )?;
111
112 let modified_schema: Vec<TableDef> = current_schema
114 .iter()
115 .map(|t| {
116 if t.name == table {
117 let mut modified = t.clone();
118 modified.constraints.retain(|c| c != from);
119 modified.constraints.push(to.clone());
120 modified
121 } else {
122 t.clone()
123 }
124 })
125 .collect();
126
127 queries.extend(super::add_constraint::build_add_constraint(
128 backend,
129 table,
130 to,
131 &modified_schema,
132 pending_constraints,
133 )?);
134 Ok(queries)
135 }
136 }
137 }
138}
139
140fn build_sqlite_constraint_replace(
143 backend: &DatabaseBackend,
144 table: &str,
145 from: &TableConstraint,
146 to: &TableConstraint,
147 current_schema: &[TableDef],
148 pending_constraints: &[TableConstraint],
149) -> Result<Vec<BuiltQuery>, QueryError> {
150 let table_def = current_schema
151 .iter()
152 .find(|t| t.name == table)
153 .ok_or_else(|| {
154 QueryError::Other(format!(
155 "Table '{}' not found in current schema. SQLite requires current schema \
156 information to replace constraints.",
157 table
158 ))
159 })?;
160
161 let new_constraints: Vec<TableConstraint> = table_def
163 .constraints
164 .iter()
165 .map(|c| if c == from { to.clone() } else { c.clone() })
166 .collect();
167
168 let temp_table = format!("{}_temp", table);
169
170 let create_query = build_sqlite_temp_table_create(
172 backend,
173 &temp_table,
174 table,
175 &table_def.columns,
176 &new_constraints,
177 );
178
179 let column_aliases: Vec<Alias> = table_def
181 .columns
182 .iter()
183 .map(|c| Alias::new(&c.name))
184 .collect();
185 let mut select_query = Query::select();
186 for col_alias in &column_aliases {
187 select_query = select_query.column(col_alias.clone()).to_owned();
188 }
189 select_query = select_query.from(Alias::new(table)).to_owned();
190
191 let insert_stmt = Query::insert()
192 .into_table(Alias::new(&temp_table))
193 .columns(column_aliases.clone())
194 .select_from(select_query)
195 .unwrap()
196 .to_owned();
197 let insert_query = BuiltQuery::Insert(Box::new(insert_stmt));
198
199 let drop_table = Table::drop().table(Alias::new(table)).to_owned();
201 let drop_query = BuiltQuery::DropTable(Box::new(drop_table));
202
203 let rename_query = build_rename_table(&temp_table, table);
205
206 let index_queries =
208 recreate_indexes_after_rebuild(table, &table_def.constraints, pending_constraints);
209
210 let mut queries = vec![create_query, insert_query, drop_query, rename_query];
211 queries.extend(index_queries);
212 Ok(queries)
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{
        ColumnDef, ColumnType, ReferenceAction, SimpleColumnType, TableConstraint, TableDef,
    };

    /// Non-nullable column of simple type `ty` with every optional attribute unset.
    fn column(name: &str, ty: SimpleColumnType) -> ColumnDef {
        ColumnDef {
            name: name.into(),
            r#type: ColumnType::Simple(ty),
            nullable: false,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    /// Non-auto-increment primary key over the single column `id`.
    fn id_primary_key() -> TableConstraint {
        TableConstraint::PrimaryKey {
            auto_increment: false,
            columns: vec!["id".into()],
        }
    }

    /// The `fk_user` foreign key (`user_id` -> `users.id`) with the given referential actions.
    fn fk_user(
        on_delete: Option<ReferenceAction>,
        on_update: Option<ReferenceAction>,
    ) -> TableConstraint {
        TableConstraint::ForeignKey {
            name: Some("fk_user".into()),
            columns: vec!["user_id".into()],
            ref_table: "users".into(),
            ref_columns: vec!["id".into()],
            on_delete,
            on_update,
        }
    }

    /// Renders every query for `backend` and joins them into a single snapshot string.
    fn render(queries: &[BuiltQuery], backend: DatabaseBackend) -> String {
        queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join(";\n")
    }

    /// `users(id)` plus `posts(id, user_id)`, where `posts` carries the `fk_user` foreign key.
    fn test_schema() -> Vec<TableDef> {
        let users = TableDef {
            name: "users".into(),
            columns: vec![column("id", SimpleColumnType::Integer)],
            constraints: vec![id_primary_key()],
            description: None,
        };
        let posts = TableDef {
            name: "posts".into(),
            columns: vec![
                column("id", SimpleColumnType::Integer),
                column("user_id", SimpleColumnType::Integer),
            ],
            constraints: vec![id_primary_key(), fk_user(None, None)],
            description: None,
        };
        vec![users, posts]
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn replace_fk_on_delete(#[case] backend: DatabaseBackend) {
        let schema = test_schema();
        let from = fk_user(None, None);
        let to = fk_user(Some(ReferenceAction::Cascade), None);

        let queries = build_replace_constraint(&backend, "posts", &from, &to, &schema, &[])
            .expect("should succeed");
        let combined = render(&queries, backend);

        with_settings!({
            description => format!("replace FK on_delete for {:?}", backend),
            omit_expression => true,
            snapshot_suffix => format!("replace_fk_on_delete_{:?}", backend),
        }, {
            assert_snapshot!(combined);
        });
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn replace_fk_on_update(#[case] backend: DatabaseBackend) {
        let schema = test_schema();
        let from = fk_user(None, None);
        let to = fk_user(None, Some(ReferenceAction::Cascade));

        let queries = build_replace_constraint(&backend, "posts", &from, &to, &schema, &[])
            .expect("should succeed");
        let combined = render(&queries, backend);

        with_settings!({
            description => format!("replace FK on_update for {:?}", backend),
            omit_expression => true,
            snapshot_suffix => format!("replace_fk_on_update_{:?}", backend),
        }, {
            assert_snapshot!(combined);
        });
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn replace_unique_constraint(#[case] backend: DatabaseBackend) {
        let unique_email = |name: &str| TableConstraint::Unique {
            name: Some(name.into()),
            columns: vec!["email".into()],
        };

        let schema = vec![
            TableDef {
                name: "other".into(),
                description: None,
                columns: vec![column("id", SimpleColumnType::Integer)],
                constraints: vec![],
            },
            TableDef {
                name: "users".into(),
                description: None,
                columns: vec![
                    column("id", SimpleColumnType::Integer),
                    column("email", SimpleColumnType::Text),
                ],
                constraints: vec![id_primary_key(), unique_email("uq_email")],
            },
        ];
        let from = unique_email("uq_email");
        let to = unique_email("uq_email_new");

        let queries = build_replace_constraint(&backend, "users", &from, &to, &schema, &[])
            .expect("should succeed");
        let combined = render(&queries, backend);

        with_settings!({
            description => format!("replace unique constraint for {:?}", backend),
            omit_expression => true,
            snapshot_suffix => format!("replace_unique_{:?}", backend),
        }, {
            assert_snapshot!(combined);
        });
    }

    #[test]
    fn replace_constraint_table_not_found_sqlite() {
        let unique_on_col = |name: &str| TableConstraint::Unique {
            name: Some(name.into()),
            columns: vec!["col".into()],
        };

        let err = build_replace_constraint(
            &DatabaseBackend::Sqlite,
            "missing",
            &unique_on_col("uq_old"),
            &unique_on_col("uq_new"),
            &[],
            &[],
        )
        .unwrap_err();

        assert!(err.to_string().contains("missing"));
    }
}