1use sea_query::{Alias, Query, Table};
2
3use vespertide_core::{ColumnDef, TableDef};
4
5use super::helpers::{
6 build_sea_column_def_with_table, build_sqlite_temp_table_create, normalize_enum_default,
7 recreate_indexes_after_rebuild,
8};
9use super::rename_table::build_rename_table;
10use super::types::{BuiltQuery, DatabaseBackend, RawSql};
11use crate::error::QueryError;
12
13pub fn build_modify_column_default(
15 backend: &DatabaseBackend,
16 table: &str,
17 column: &str,
18 new_default: Option<&str>,
19 current_schema: &[TableDef],
20) -> Result<Vec<BuiltQuery>, QueryError> {
21 let mut queries = Vec::new();
22
23 match backend {
24 DatabaseBackend::Postgres => {
25 let alter_sql = if let Some(default_value) = new_default {
26 let column_type = current_schema
28 .iter()
29 .find(|t| t.name == table)
30 .and_then(|t| t.columns.iter().find(|c| c.name == column))
31 .map(|c| &c.r#type);
32
33 let normalized_default = if let Some(col_type) = column_type {
34 normalize_enum_default(col_type, default_value)
35 } else {
36 default_value.to_string()
37 };
38
39 format!(
40 "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" SET DEFAULT {}",
41 table, column, normalized_default
42 )
43 } else {
44 format!(
45 "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" DROP DEFAULT",
46 table, column
47 )
48 };
49 queries.push(BuiltQuery::Raw(RawSql::uniform(alter_sql)));
50 }
51 DatabaseBackend::MySql => {
52 let table_def = current_schema
54 .iter()
55 .find(|t| t.name == table)
56 .ok_or_else(|| {
57 QueryError::Other(format!("Table '{}' not found in current schema.", table))
58 })?;
59
60 let column_def = table_def
61 .columns
62 .iter()
63 .find(|c| c.name == column)
64 .ok_or_else(|| {
65 QueryError::Other(format!(
66 "Column '{}' not found in table '{}'.",
67 column, table
68 ))
69 })?;
70
71 let modified_col_def = ColumnDef {
73 default: new_default.map(|s| s.into()),
74 ..column_def.clone()
75 };
76
77 let sea_col = build_sea_column_def_with_table(backend, table, &modified_col_def);
78
79 let stmt = Table::alter()
80 .table(Alias::new(table))
81 .modify_column(sea_col)
82 .to_owned();
83 queries.push(BuiltQuery::AlterTable(Box::new(stmt)));
84 }
85 DatabaseBackend::Sqlite => {
86 let table_def = current_schema
89 .iter()
90 .find(|t| t.name == table)
91 .ok_or_else(|| {
92 QueryError::Other(format!("Table '{}' not found in current schema.", table))
93 })?;
94
95 let mut new_columns = table_def.columns.clone();
97 if let Some(col) = new_columns.iter_mut().find(|c| c.name == column) {
98 col.default = new_default.map(|s| s.into());
99 }
100
101 let temp_table = format!("{}_temp", table);
103
104 let create_query = build_sqlite_temp_table_create(
106 backend,
107 &temp_table,
108 table,
109 &new_columns,
110 &table_def.constraints,
111 );
112 queries.push(create_query);
113
114 let column_aliases: Vec<Alias> = table_def
116 .columns
117 .iter()
118 .map(|c| Alias::new(&c.name))
119 .collect();
120 let mut select_query = Query::select();
121 for col_alias in &column_aliases {
122 select_query = select_query.column(col_alias.clone()).to_owned();
123 }
124 select_query = select_query.from(Alias::new(table)).to_owned();
125
126 let insert_stmt = Query::insert()
127 .into_table(Alias::new(&temp_table))
128 .columns(column_aliases.clone())
129 .select_from(select_query)
130 .unwrap()
131 .to_owned();
132 queries.push(BuiltQuery::Insert(Box::new(insert_stmt)));
133
134 let drop_table = Table::drop().table(Alias::new(table)).to_owned();
136 queries.push(BuiltQuery::DropTable(Box::new(drop_table)));
137
138 queries.push(build_rename_table(&temp_table, table));
140
141 queries.extend(recreate_indexes_after_rebuild(
143 table,
144 &table_def.constraints,
145 &[],
146 ));
147 }
148 }
149
150 Ok(queries)
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{ColumnDef, ColumnType, SimpleColumnType, TableConstraint};

    /// Builds a column with the given name, type and nullability; every optional
    /// attribute (default, comment, keys, index, foreign key) is left unset.
    fn col(name: &str, ty: ColumnType, nullable: bool) -> ColumnDef {
        ColumnDef {
            name: name.to_string(),
            r#type: ty,
            nullable,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    /// Builds a table definition without a description.
    fn table_def(
        name: &str,
        columns: Vec<ColumnDef>,
        constraints: Vec<TableConstraint>,
    ) -> TableDef {
        TableDef {
            name: name.to_string(),
            description: None,
            columns,
            constraints,
        }
    }

    /// Renders every query for `backend` and joins them with newlines; this is the
    /// text compared against the stored snapshots.
    fn render(queries: &[BuiltQuery], backend: DatabaseBackend) -> String {
        queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join("\n")
    }

    /// Lower-case backend name used as the first part of every snapshot suffix.
    fn backend_label(backend: DatabaseBackend) -> &'static str {
        match backend {
            DatabaseBackend::Postgres => "postgres",
            DatabaseBackend::MySql => "mysql",
            DatabaseBackend::Sqlite => "sqlite",
        }
    }

    #[rstest]
    #[case::postgres_set_default(DatabaseBackend::Postgres, Some("'unknown'"))]
    #[case::postgres_drop_default(DatabaseBackend::Postgres, None)]
    #[case::mysql_set_default(DatabaseBackend::MySql, Some("'unknown'"))]
    #[case::mysql_drop_default(DatabaseBackend::MySql, None)]
    #[case::sqlite_set_default(DatabaseBackend::Sqlite, Some("'unknown'"))]
    #[case::sqlite_drop_default(DatabaseBackend::Sqlite, None)]
    fn test_build_modify_column_default(
        #[case] backend: DatabaseBackend,
        #[case] new_default: Option<&str>,
    ) {
        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col("email", ColumnType::Simple(SimpleColumnType::Text), true),
            ],
            vec![],
        )];

        let queries =
            build_modify_column_default(&backend, "users", "email", new_default, &schema)
                .expect("modifying an existing column's default should succeed");
        let sql = render(&queries, backend);

        let action = if new_default.is_some() {
            "set_default"
        } else {
            "drop_default"
        };
        let suffix = format!("{}_{}_users", backend_label(backend), action);

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_table_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_table_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_table_not_found(DatabaseBackend::Sqlite)]
    fn test_table_not_found(#[case] backend: DatabaseBackend) {
        let result =
            build_modify_column_default(&backend, "users", "email", Some("'default'"), &[]);

        if backend == DatabaseBackend::Postgres {
            // Postgres only consults the schema to normalize the default, so an
            // unknown table still yields a statement instead of an error.
            assert!(result.is_ok());
            return;
        }

        let err_msg = result
            .expect_err("a missing table should be rejected")
            .to_string();
        assert!(err_msg.contains("Table 'users' not found"));
    }

    #[rstest]
    #[case::postgres_column_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_column_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_column_not_found(DatabaseBackend::Sqlite)]
    fn test_column_not_found(#[case] backend: DatabaseBackend) {
        if backend == DatabaseBackend::Sqlite {
            // The SQLite missing-column outcome is not pinned by this test.
            return;
        }

        let schema = vec![table_def(
            "users",
            vec![col(
                "id",
                ColumnType::Simple(SimpleColumnType::Integer),
                false,
            )],
            vec![],
        )];

        let result =
            build_modify_column_default(&backend, "users", "email", Some("'default'"), &schema);

        if backend == DatabaseBackend::Postgres {
            // Postgres falls back to the default as given when the column is unknown.
            assert!(result.is_ok());
            return;
        }

        let err_msg = result
            .expect_err("a missing column should be rejected")
            .to_string();
        assert!(err_msg.contains("Column 'email' not found"));
    }

    #[test]
    fn test_postgres_column_not_in_schema_uses_default_as_is() {
        let schema = vec![table_def(
            "users",
            vec![col(
                "id",
                ColumnType::Simple(SimpleColumnType::Integer),
                false,
            )],
            vec![],
        )];

        let queries = build_modify_column_default(
            &DatabaseBackend::Postgres,
            "users",
            "status",
            Some("'active'"),
            &schema,
        )
        .expect("postgres should not require the column to be in the schema");
        let sql = render(&queries, DatabaseBackend::Postgres);

        assert!(sql.contains("ALTER TABLE \"users\" ALTER COLUMN \"status\" SET DEFAULT 'active'"));
    }

    #[rstest]
    #[case::postgres_with_index(DatabaseBackend::Postgres)]
    #[case::mysql_with_index(DatabaseBackend::MySql)]
    #[case::sqlite_with_index(DatabaseBackend::Sqlite)]
    fn test_modify_default_with_index(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col("email", ColumnType::Simple(SimpleColumnType::Text), true),
            ],
            vec![TableConstraint::Index {
                name: Some("idx_users_email".into()),
                columns: vec!["email".into()],
            }],
        )];

        let queries = build_modify_column_default(
            &backend,
            "users",
            "email",
            Some("'default@example.com'"),
            &schema,
        )
        .expect("modifying a default on an indexed column should succeed");
        let sql = render(&queries, backend);

        if backend == DatabaseBackend::Sqlite {
            // The SQLite rebuild drops the table, so its indexes must be re-created.
            assert!(sql.contains("CREATE INDEX"));
            assert!(sql.contains("idx_users_email"));
        }

        let suffix = format!("{}_with_index", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_change_default(DatabaseBackend::Postgres)]
    #[case::mysql_change_default(DatabaseBackend::MySql)]
    #[case::sqlite_change_default(DatabaseBackend::Sqlite)]
    fn test_change_default_value(#[case] backend: DatabaseBackend) {
        let mut email_col = col("email", ColumnType::Simple(SimpleColumnType::Text), true);
        email_col.default = Some("'old@example.com'".into());

        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                email_col,
            ],
            vec![],
        )];

        let queries = build_modify_column_default(
            &backend,
            "users",
            "email",
            Some("'new@example.com'"),
            &schema,
        )
        .expect("replacing an existing default should succeed");
        let sql = render(&queries, backend);

        let suffix = format!("{}_change_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_integer_default(DatabaseBackend::Postgres)]
    #[case::mysql_integer_default(DatabaseBackend::MySql)]
    #[case::sqlite_integer_default(DatabaseBackend::Sqlite)]
    fn test_integer_default(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "products",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "quantity",
                    ColumnType::Simple(SimpleColumnType::Integer),
                    false,
                ),
            ],
            vec![],
        )];

        let queries =
            build_modify_column_default(&backend, "products", "quantity", Some("0"), &schema)
                .expect("setting an integer default should succeed");
        let sql = render(&queries, backend);

        let suffix = format!("{}_integer_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_boolean_default(DatabaseBackend::Postgres)]
    #[case::mysql_boolean_default(DatabaseBackend::MySql)]
    #[case::sqlite_boolean_default(DatabaseBackend::Sqlite)]
    fn test_boolean_default(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "is_active",
                    ColumnType::Simple(SimpleColumnType::Boolean),
                    false,
                ),
            ],
            vec![],
        )];

        let queries =
            build_modify_column_default(&backend, "users", "is_active", Some("true"), &schema)
                .expect("setting a boolean default should succeed");
        let sql = render(&queries, backend);

        let suffix = format!("{}_boolean_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_function_default(DatabaseBackend::Postgres)]
    #[case::mysql_function_default(DatabaseBackend::MySql)]
    #[case::sqlite_function_default(DatabaseBackend::Sqlite)]
    fn test_function_default(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "events",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "created_at",
                    ColumnType::Simple(SimpleColumnType::Timestamp),
                    false,
                ),
            ],
            vec![],
        )];

        let default_value = match backend {
            DatabaseBackend::Postgres => "NOW()",
            DatabaseBackend::MySql | DatabaseBackend::Sqlite => "CURRENT_TIMESTAMP",
        };

        let queries = build_modify_column_default(
            &backend,
            "events",
            "created_at",
            Some(default_value),
            &schema,
        )
        .expect("setting a function default should succeed");
        let sql = render(&queries, backend);

        let suffix = format!("{}_function_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_drop_existing_default(DatabaseBackend::Postgres)]
    #[case::mysql_drop_existing_default(DatabaseBackend::MySql)]
    #[case::sqlite_drop_existing_default(DatabaseBackend::Sqlite)]
    fn test_drop_existing_default(#[case] backend: DatabaseBackend) {
        let mut status_col = col("status", ColumnType::Simple(SimpleColumnType::Text), false);
        status_col.default = Some("'pending'".into());

        let schema = vec![table_def(
            "orders",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                status_col,
            ],
            vec![],
        )];

        let queries = build_modify_column_default(&backend, "orders", "status", None, &schema)
            .expect("dropping an existing default should succeed");
        let sql = render(&queries, backend);

        let suffix = format!("{}_drop_existing_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }
}
600}