1use sea_query::{Alias, Query, Table};
2
3use vespertide_core::{ColumnDef, TableDef};
4
5use super::helpers::{
6 build_sea_column_def_with_table, build_sqlite_temp_table_create, normalize_enum_default,
7 recreate_indexes_after_rebuild,
8};
9use super::rename_table::build_rename_table;
10use super::types::{BuiltQuery, DatabaseBackend, RawSql};
11use crate::error::QueryError;
12
13pub fn build_modify_column_default(
15 backend: &DatabaseBackend,
16 table: &str,
17 column: &str,
18 new_default: Option<&str>,
19 current_schema: &[TableDef],
20 pending_constraints: &[vespertide_core::TableConstraint],
21) -> Result<Vec<BuiltQuery>, QueryError> {
22 let mut queries = Vec::new();
23
24 match backend {
25 DatabaseBackend::Postgres => {
26 let alter_sql = if let Some(default_value) = new_default {
27 let column_type = current_schema
29 .iter()
30 .find(|t| t.name == table)
31 .and_then(|t| t.columns.iter().find(|c| c.name == column))
32 .map(|c| &c.r#type);
33
34 let normalized_default = if let Some(col_type) = column_type {
35 normalize_enum_default(col_type, default_value)
36 } else {
37 default_value.to_string()
38 };
39
40 format!(
41 "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" SET DEFAULT {}",
42 table, column, normalized_default
43 )
44 } else {
45 format!(
46 "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" DROP DEFAULT",
47 table, column
48 )
49 };
50 queries.push(BuiltQuery::Raw(RawSql::uniform(alter_sql)));
51 }
52 DatabaseBackend::MySql => {
53 let table_def = current_schema
55 .iter()
56 .find(|t| t.name == table)
57 .ok_or_else(|| {
58 QueryError::Other(format!("Table '{}' not found in current schema.", table))
59 })?;
60
61 let column_def = table_def
62 .columns
63 .iter()
64 .find(|c| c.name == column)
65 .ok_or_else(|| {
66 QueryError::Other(format!(
67 "Column '{}' not found in table '{}'.",
68 column, table
69 ))
70 })?;
71
72 let modified_col_def = ColumnDef {
74 default: new_default.map(|s| s.into()),
75 ..column_def.clone()
76 };
77
78 let sea_col = build_sea_column_def_with_table(backend, table, &modified_col_def);
79
80 let stmt = Table::alter()
81 .table(Alias::new(table))
82 .modify_column(sea_col)
83 .to_owned();
84 queries.push(BuiltQuery::AlterTable(Box::new(stmt)));
85 }
86 DatabaseBackend::Sqlite => {
87 let table_def = current_schema
90 .iter()
91 .find(|t| t.name == table)
92 .ok_or_else(|| {
93 QueryError::Other(format!("Table '{}' not found in current schema.", table))
94 })?;
95
96 let mut new_columns = table_def.columns.clone();
98 if let Some(col) = new_columns.iter_mut().find(|c| c.name == column) {
99 col.default = new_default.map(|s| s.into());
100 }
101
102 let temp_table = format!("{}_temp", table);
104
105 let create_query = build_sqlite_temp_table_create(
107 backend,
108 &temp_table,
109 table,
110 &new_columns,
111 &table_def.constraints,
112 );
113 queries.push(create_query);
114
115 let column_aliases: Vec<Alias> = table_def
117 .columns
118 .iter()
119 .map(|c| Alias::new(&c.name))
120 .collect();
121 let mut select_query = Query::select();
122 for col_alias in &column_aliases {
123 select_query = select_query.column(col_alias.clone()).to_owned();
124 }
125 select_query = select_query.from(Alias::new(table)).to_owned();
126
127 let insert_stmt = Query::insert()
128 .into_table(Alias::new(&temp_table))
129 .columns(column_aliases.clone())
130 .select_from(select_query)
131 .unwrap()
132 .to_owned();
133 queries.push(BuiltQuery::Insert(Box::new(insert_stmt)));
134
135 let drop_table = Table::drop().table(Alias::new(table)).to_owned();
137 queries.push(BuiltQuery::DropTable(Box::new(drop_table)));
138
139 queries.push(build_rename_table(&temp_table, table));
141
142 queries.extend(recreate_indexes_after_rebuild(
144 table,
145 &table_def.constraints,
146 pending_constraints,
147 ));
148 }
149 }
150
151 Ok(queries)
152}
153
154#[cfg(test)]
155mod tests {
156 use super::*;
157 use insta::{assert_snapshot, with_settings};
158 use rstest::rstest;
159 use vespertide_core::{ColumnDef, ColumnType, SimpleColumnType, TableConstraint};
160
161 fn col(name: &str, ty: ColumnType, nullable: bool) -> ColumnDef {
162 ColumnDef {
163 name: name.to_string(),
164 r#type: ty,
165 nullable,
166 default: None,
167 comment: None,
168 primary_key: None,
169 unique: None,
170 index: None,
171 foreign_key: None,
172 }
173 }
174
175 fn table_def(
176 name: &str,
177 columns: Vec<ColumnDef>,
178 constraints: Vec<TableConstraint>,
179 ) -> TableDef {
180 TableDef {
181 name: name.to_string(),
182 description: None,
183 columns,
184 constraints,
185 }
186 }
187
188 #[rstest]
189 #[case::postgres_set_default(DatabaseBackend::Postgres, Some("'unknown'"))]
190 #[case::postgres_drop_default(DatabaseBackend::Postgres, None)]
191 #[case::mysql_set_default(DatabaseBackend::MySql, Some("'unknown'"))]
192 #[case::mysql_drop_default(DatabaseBackend::MySql, None)]
193 #[case::sqlite_set_default(DatabaseBackend::Sqlite, Some("'unknown'"))]
194 #[case::sqlite_drop_default(DatabaseBackend::Sqlite, None)]
195 fn test_build_modify_column_default(
196 #[case] backend: DatabaseBackend,
197 #[case] new_default: Option<&str>,
198 ) {
199 let schema = vec![table_def(
200 "users",
201 vec![
202 col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
203 col("email", ColumnType::Simple(SimpleColumnType::Text), true),
204 ],
205 vec![],
206 )];
207
208 let result =
209 build_modify_column_default(&backend, "users", "email", new_default, &schema, &[]);
210 assert!(result.is_ok());
211 let queries = result.unwrap();
212 let sql = queries
213 .iter()
214 .map(|q| q.build(backend))
215 .collect::<Vec<String>>()
216 .join("\n");
217
218 let suffix = format!(
219 "{}_{}_users",
220 match backend {
221 DatabaseBackend::Postgres => "postgres",
222 DatabaseBackend::MySql => "mysql",
223 DatabaseBackend::Sqlite => "sqlite",
224 },
225 if new_default.is_some() {
226 "set_default"
227 } else {
228 "drop_default"
229 }
230 );
231
232 with_settings!({ snapshot_suffix => suffix }, {
233 assert_snapshot!(sql);
234 });
235 }
236
237 #[rstest]
239 #[case::postgres_table_not_found(DatabaseBackend::Postgres)]
240 #[case::mysql_table_not_found(DatabaseBackend::MySql)]
241 #[case::sqlite_table_not_found(DatabaseBackend::Sqlite)]
242 fn test_table_not_found(#[case] backend: DatabaseBackend) {
243 if backend == DatabaseBackend::Postgres {
245 return;
246 }
247
248 let result =
249 build_modify_column_default(&backend, "users", "email", Some("'default'"), &[], &[]);
250 assert!(result.is_err());
251 let err_msg = result.unwrap_err().to_string();
252 assert!(err_msg.contains("Table 'users' not found"));
253 }
254
255 #[rstest]
257 #[case::postgres_column_not_found(DatabaseBackend::Postgres)]
258 #[case::mysql_column_not_found(DatabaseBackend::MySql)]
259 #[case::sqlite_column_not_found(DatabaseBackend::Sqlite)]
260 fn test_column_not_found(#[case] backend: DatabaseBackend) {
261 if backend == DatabaseBackend::Postgres || backend == DatabaseBackend::Sqlite {
264 return;
265 }
266
267 let schema = vec![table_def(
268 "users",
269 vec![col(
270 "id",
271 ColumnType::Simple(SimpleColumnType::Integer),
272 false,
273 )],
274 vec![],
275 )];
276
277 let result = build_modify_column_default(
278 &backend,
279 "users",
280 "email",
281 Some("'default'"),
282 &schema,
283 &[],
284 );
285 assert!(result.is_err());
286 let err_msg = result.unwrap_err().to_string();
287 assert!(err_msg.contains("Column 'email' not found"));
288 }
289
290 #[test]
293 fn test_postgres_column_not_in_schema_uses_default_as_is() {
294 let schema = vec![table_def(
295 "users",
296 vec![col(
297 "id",
298 ColumnType::Simple(SimpleColumnType::Integer),
299 false,
300 )],
301 vec![],
303 )];
304
305 let result = build_modify_column_default(
307 &DatabaseBackend::Postgres,
308 "users",
309 "status", Some("'active'"),
311 &schema,
312 &[],
313 );
314 assert!(result.is_ok());
315 let queries = result.unwrap();
316 let sql = queries
317 .iter()
318 .map(|q| q.build(DatabaseBackend::Postgres))
319 .collect::<Vec<String>>()
320 .join("\n");
321
322 assert!(sql.contains("ALTER TABLE \"users\" ALTER COLUMN \"status\" SET DEFAULT 'active'"));
324 }
325
326 #[rstest]
328 #[case::postgres_with_index(DatabaseBackend::Postgres)]
329 #[case::mysql_with_index(DatabaseBackend::MySql)]
330 #[case::sqlite_with_index(DatabaseBackend::Sqlite)]
331 fn test_modify_default_with_index(#[case] backend: DatabaseBackend) {
332 let schema = vec![table_def(
333 "users",
334 vec![
335 col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
336 col("email", ColumnType::Simple(SimpleColumnType::Text), true),
337 ],
338 vec![TableConstraint::Index {
339 name: Some("idx_users_email".into()),
340 columns: vec!["email".into()],
341 }],
342 )];
343
344 let result = build_modify_column_default(
345 &backend,
346 "users",
347 "email",
348 Some("'default@example.com'"),
349 &schema,
350 &[],
351 );
352 assert!(result.is_ok());
353 let queries = result.unwrap();
354 let sql = queries
355 .iter()
356 .map(|q| q.build(backend))
357 .collect::<Vec<String>>()
358 .join("\n");
359
360 if backend == DatabaseBackend::Sqlite {
362 assert!(sql.contains("CREATE INDEX"));
363 assert!(sql.contains("idx_users_email"));
364 }
365
366 let suffix = format!(
367 "{}_with_index",
368 match backend {
369 DatabaseBackend::Postgres => "postgres",
370 DatabaseBackend::MySql => "mysql",
371 DatabaseBackend::Sqlite => "sqlite",
372 }
373 );
374
375 with_settings!({ snapshot_suffix => suffix }, {
376 assert_snapshot!(sql);
377 });
378 }
379
380 #[rstest]
382 #[case::postgres_change_default(DatabaseBackend::Postgres)]
383 #[case::mysql_change_default(DatabaseBackend::MySql)]
384 #[case::sqlite_change_default(DatabaseBackend::Sqlite)]
385 fn test_change_default_value(#[case] backend: DatabaseBackend) {
386 let mut email_col = col("email", ColumnType::Simple(SimpleColumnType::Text), true);
387 email_col.default = Some("'old@example.com'".into());
388
389 let schema = vec![table_def(
390 "users",
391 vec![
392 col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
393 email_col,
394 ],
395 vec![],
396 )];
397
398 let result = build_modify_column_default(
399 &backend,
400 "users",
401 "email",
402 Some("'new@example.com'"),
403 &schema,
404 &[],
405 );
406 assert!(result.is_ok());
407 let queries = result.unwrap();
408 let sql = queries
409 .iter()
410 .map(|q| q.build(backend))
411 .collect::<Vec<String>>()
412 .join("\n");
413
414 let suffix = format!(
415 "{}_change_default",
416 match backend {
417 DatabaseBackend::Postgres => "postgres",
418 DatabaseBackend::MySql => "mysql",
419 DatabaseBackend::Sqlite => "sqlite",
420 }
421 );
422
423 with_settings!({ snapshot_suffix => suffix }, {
424 assert_snapshot!(sql);
425 });
426 }
427
428 #[rstest]
430 #[case::postgres_integer_default(DatabaseBackend::Postgres)]
431 #[case::mysql_integer_default(DatabaseBackend::MySql)]
432 #[case::sqlite_integer_default(DatabaseBackend::Sqlite)]
433 fn test_integer_default(#[case] backend: DatabaseBackend) {
434 let schema = vec![table_def(
435 "products",
436 vec![
437 col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
438 col(
439 "quantity",
440 ColumnType::Simple(SimpleColumnType::Integer),
441 false,
442 ),
443 ],
444 vec![],
445 )];
446
447 let result =
448 build_modify_column_default(&backend, "products", "quantity", Some("0"), &schema, &[]);
449 assert!(result.is_ok());
450 let queries = result.unwrap();
451 let sql = queries
452 .iter()
453 .map(|q| q.build(backend))
454 .collect::<Vec<String>>()
455 .join("\n");
456
457 let suffix = format!(
458 "{}_integer_default",
459 match backend {
460 DatabaseBackend::Postgres => "postgres",
461 DatabaseBackend::MySql => "mysql",
462 DatabaseBackend::Sqlite => "sqlite",
463 }
464 );
465
466 with_settings!({ snapshot_suffix => suffix }, {
467 assert_snapshot!(sql);
468 });
469 }
470
471 #[rstest]
473 #[case::postgres_boolean_default(DatabaseBackend::Postgres)]
474 #[case::mysql_boolean_default(DatabaseBackend::MySql)]
475 #[case::sqlite_boolean_default(DatabaseBackend::Sqlite)]
476 fn test_boolean_default(#[case] backend: DatabaseBackend) {
477 let schema = vec![table_def(
478 "users",
479 vec![
480 col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
481 col(
482 "is_active",
483 ColumnType::Simple(SimpleColumnType::Boolean),
484 false,
485 ),
486 ],
487 vec![],
488 )];
489
490 let result =
491 build_modify_column_default(&backend, "users", "is_active", Some("true"), &schema, &[]);
492 assert!(result.is_ok());
493 let queries = result.unwrap();
494 let sql = queries
495 .iter()
496 .map(|q| q.build(backend))
497 .collect::<Vec<String>>()
498 .join("\n");
499
500 let suffix = format!(
501 "{}_boolean_default",
502 match backend {
503 DatabaseBackend::Postgres => "postgres",
504 DatabaseBackend::MySql => "mysql",
505 DatabaseBackend::Sqlite => "sqlite",
506 }
507 );
508
509 with_settings!({ snapshot_suffix => suffix }, {
510 assert_snapshot!(sql);
511 });
512 }
513
514 #[rstest]
516 #[case::postgres_function_default(DatabaseBackend::Postgres)]
517 #[case::mysql_function_default(DatabaseBackend::MySql)]
518 #[case::sqlite_function_default(DatabaseBackend::Sqlite)]
519 fn test_function_default(#[case] backend: DatabaseBackend) {
520 let schema = vec![table_def(
521 "events",
522 vec![
523 col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
524 col(
525 "created_at",
526 ColumnType::Simple(SimpleColumnType::Timestamp),
527 false,
528 ),
529 ],
530 vec![],
531 )];
532
533 let default_value = match backend {
534 DatabaseBackend::Postgres => "NOW()",
535 DatabaseBackend::MySql => "CURRENT_TIMESTAMP",
536 DatabaseBackend::Sqlite => "CURRENT_TIMESTAMP",
537 };
538
539 let result = build_modify_column_default(
540 &backend,
541 "events",
542 "created_at",
543 Some(default_value),
544 &schema,
545 &[],
546 );
547 assert!(result.is_ok());
548 let queries = result.unwrap();
549 let sql = queries
550 .iter()
551 .map(|q| q.build(backend))
552 .collect::<Vec<String>>()
553 .join("\n");
554
555 let suffix = format!(
556 "{}_function_default",
557 match backend {
558 DatabaseBackend::Postgres => "postgres",
559 DatabaseBackend::MySql => "mysql",
560 DatabaseBackend::Sqlite => "sqlite",
561 }
562 );
563
564 with_settings!({ snapshot_suffix => suffix }, {
565 assert_snapshot!(sql);
566 });
567 }
568
569 #[rstest]
571 #[case::postgres_drop_existing_default(DatabaseBackend::Postgres)]
572 #[case::mysql_drop_existing_default(DatabaseBackend::MySql)]
573 #[case::sqlite_drop_existing_default(DatabaseBackend::Sqlite)]
574 fn test_drop_existing_default(#[case] backend: DatabaseBackend) {
575 let mut status_col = col("status", ColumnType::Simple(SimpleColumnType::Text), false);
576 status_col.default = Some("'pending'".into());
577
578 let schema = vec![table_def(
579 "orders",
580 vec![
581 col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
582 status_col,
583 ],
584 vec![],
585 )];
586
587 let result = build_modify_column_default(
588 &backend,
589 "orders",
590 "status",
591 None, &schema,
593 &[],
594 );
595 assert!(result.is_ok());
596 let queries = result.unwrap();
597 let sql = queries
598 .iter()
599 .map(|q| q.build(backend))
600 .collect::<Vec<String>>()
601 .join("\n");
602
603 let suffix = format!(
604 "{}_drop_existing_default",
605 match backend {
606 DatabaseBackend::Postgres => "postgres",
607 DatabaseBackend::MySql => "mysql",
608 DatabaseBackend::Sqlite => "sqlite",
609 }
610 );
611
612 with_settings!({ snapshot_suffix => suffix }, {
613 assert_snapshot!(sql);
614 });
615 }
616}