1use sea_query::{Alias, Expr, Query, Table, TableAlterStatement};
2
3use vespertide_core::{ColumnDef, TableDef};
4
5use super::helpers::{
6 build_create_enum_type_sql, build_sea_column_def_with_table, build_sqlite_temp_table_create,
7 convert_default_for_backend, normalize_enum_default, normalize_fill_with,
8 recreate_indexes_after_rebuild,
9};
10use super::rename_table::build_rename_table;
11use super::types::{BuiltQuery, DatabaseBackend};
12use crate::error::QueryError;
13
14fn build_add_column_alter_for_backend(
15 backend: &DatabaseBackend,
16 table: &str,
17 column: &ColumnDef,
18) -> TableAlterStatement {
19 let col_def = build_sea_column_def_with_table(backend, table, column);
20 Table::alter()
21 .table(Alias::new(table))
22 .add_column(col_def)
23 .to_owned()
24}
25
26fn is_enum_column(column: &ColumnDef) -> bool {
28 matches!(
29 column.r#type,
30 vespertide_core::ColumnType::Complex(vespertide_core::ComplexColumnType::Enum { .. })
31 )
32}
33
34pub fn build_add_column(
35 backend: &DatabaseBackend,
36 table: &str,
37 column: &ColumnDef,
38 fill_with: Option<&str>,
39 current_schema: &[TableDef],
40 pending_constraints: &[vespertide_core::TableConstraint],
41) -> Result<Vec<BuiltQuery>, QueryError> {
42 let sqlite_needs_recreation =
45 *backend == DatabaseBackend::Sqlite && (!column.nullable || is_enum_column(column));
46
47 if sqlite_needs_recreation {
48 let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::Other(format!("Table '{}' not found in current schema. SQLite requires current schema information to add columns.", table)))?;
49
50 let mut new_columns = table_def.columns.clone();
51 new_columns.push(column.clone());
52
53 let temp_table = format!("{}_temp", table);
54
55 let create_query = build_sqlite_temp_table_create(
57 backend,
58 &temp_table,
59 table,
60 &new_columns,
61 &table_def.constraints,
62 );
63
64 let mut select_query = Query::select();
66 for col in &table_def.columns {
67 select_query = select_query.column(Alias::new(&col.name)).to_owned();
68 }
69 let normalized_fill = normalize_fill_with(fill_with);
70 let fill_expr = if let Some(fill) = normalized_fill.as_deref() {
71 let converted = convert_default_for_backend(fill, backend);
72 Expr::cust(normalize_enum_default(&column.r#type, &converted))
73 } else if let Some(def) = &column.default {
74 let converted = convert_default_for_backend(&def.to_sql(), backend);
75 Expr::cust(normalize_enum_default(&column.r#type, &converted))
76 } else {
77 Expr::cust("NULL")
78 };
79 select_query = select_query
80 .expr_as(fill_expr, Alias::new(&column.name))
81 .from(Alias::new(table))
82 .to_owned();
83
84 let mut columns_alias: Vec<Alias> = table_def
85 .columns
86 .iter()
87 .map(|c| Alias::new(&c.name))
88 .collect();
89 columns_alias.push(Alias::new(&column.name));
90 let insert_stmt = Query::insert()
91 .into_table(Alias::new(&temp_table))
92 .columns(columns_alias)
93 .select_from(select_query)
94 .unwrap()
95 .to_owned();
96 let insert_query = BuiltQuery::Insert(Box::new(insert_stmt));
97
98 let drop_query =
99 BuiltQuery::DropTable(Box::new(Table::drop().table(Alias::new(table)).to_owned()));
100 let rename_query = build_rename_table(&temp_table, table);
101
102 let index_queries =
105 recreate_indexes_after_rebuild(table, &table_def.constraints, pending_constraints);
106
107 let mut stmts = vec![create_query, insert_query, drop_query, rename_query];
108 stmts.extend(index_queries);
109 return Ok(stmts);
110 }
111
112 let mut stmts: Vec<BuiltQuery> = Vec::new();
113
114 if let Some(create_type_sql) = build_create_enum_type_sql(table, &column.r#type) {
116 stmts.push(BuiltQuery::Raw(create_type_sql));
117 }
118
119 let needs_backfill = !column.nullable && column.default.is_none() && fill_with.is_some();
121
122 if needs_backfill {
123 let mut temp_col = column.clone();
125 temp_col.nullable = true;
126
127 stmts.push(BuiltQuery::AlterTable(Box::new(
128 build_add_column_alter_for_backend(backend, table, &temp_col),
129 )));
130
131 if let Some(fill) = normalize_fill_with(fill_with) {
133 let fill = convert_default_for_backend(&fill, backend);
134 let update_stmt = Query::update()
135 .table(Alias::new(table))
136 .value(Alias::new(&column.name), Expr::cust(fill))
137 .to_owned();
138 stmts.push(BuiltQuery::Update(Box::new(update_stmt)));
139 }
140
141 let not_null_col = build_sea_column_def_with_table(backend, table, column);
143 let alter_not_null = Table::alter()
144 .table(Alias::new(table))
145 .modify_column(not_null_col)
146 .to_owned();
147 stmts.push(BuiltQuery::AlterTable(Box::new(alter_not_null)));
148 } else {
149 stmts.push(BuiltQuery::AlterTable(Box::new(
150 build_add_column_alter_for_backend(backend, table, column),
151 )));
152 }
153
154 Ok(stmts)
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{
        ColumnType, ComplexColumnType, EnumValues, SimpleColumnType, TableConstraint, TableDef,
    };

    /// The non-nullable integer `id` column that every fixture table starts with.
    fn id_column() -> ColumnDef {
        ColumnDef {
            name: "id".into(),
            r#type: ColumnType::Simple(SimpleColumnType::Integer),
            nullable: false,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    /// A schema with a single table named `table` holding only [`id_column`] and no
    /// constraints.
    fn schema_with_id_table(table: &str) -> Vec<TableDef> {
        vec![TableDef {
            name: table.into(),
            description: None,
            columns: vec![id_column()],
            constraints: vec![],
        }]
    }

    /// Renders every query for `backend` and joins the SQL strings with `sep`.
    fn render(queries: &[BuiltQuery], backend: DatabaseBackend, sep: &str) -> String {
        queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join(sep)
    }

    #[rstest]
    #[case::add_column_with_backfill_postgres(
        "add_column_with_backfill_postgres",
        DatabaseBackend::Postgres,
        &["ALTER TABLE \"users\" ADD COLUMN \"nickname\" text"]
    )]
    #[case::add_column_with_backfill_mysql(
        "add_column_with_backfill_mysql",
        DatabaseBackend::MySql,
        &["ALTER TABLE `users` ADD COLUMN `nickname` text"]
    )]
    #[case::add_column_with_backfill_sqlite(
        "add_column_with_backfill_sqlite",
        DatabaseBackend::Sqlite,
        &["CREATE TABLE \"users_temp\""]
    )]
    #[case::add_column_simple_postgres(
        "add_column_simple_postgres",
        DatabaseBackend::Postgres,
        &["ALTER TABLE \"users\" ADD COLUMN \"nickname\""]
    )]
    #[case::add_column_simple_mysql(
        "add_column_simple_mysql",
        DatabaseBackend::MySql,
        &["ALTER TABLE `users` ADD COLUMN `nickname` text"]
    )]
    #[case::add_column_simple_sqlite(
        "add_column_simple_sqlite",
        DatabaseBackend::Sqlite,
        &["ALTER TABLE \"users\" ADD COLUMN \"nickname\""]
    )]
    #[case::add_column_nullable_postgres(
        "add_column_nullable_postgres",
        DatabaseBackend::Postgres,
        &["ALTER TABLE \"users\" ADD COLUMN \"email\" text"]
    )]
    #[case::add_column_nullable_mysql(
        "add_column_nullable_mysql",
        DatabaseBackend::MySql,
        &["ALTER TABLE `users` ADD COLUMN `email` text"]
    )]
    #[case::add_column_nullable_sqlite(
        "add_column_nullable_sqlite",
        DatabaseBackend::Sqlite,
        &["ALTER TABLE \"users\" ADD COLUMN \"email\" text"]
    )]
    fn test_add_column(
        #[case] title: &str,
        #[case] backend: DatabaseBackend,
        #[case] expected: &[&str],
    ) {
        let column = ColumnDef {
            name: if title.contains("age") {
                "age"
            } else if title.contains("nullable") {
                "email"
            } else {
                "nickname"
            }
            .into(),
            r#type: if title.contains("age") {
                ColumnType::Simple(SimpleColumnType::Integer)
            } else {
                ColumnType::Simple(SimpleColumnType::Text)
            },
            nullable: !title.contains("backfill"),
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        };
        let fill_with = if title.contains("backfill") {
            Some("0")
        } else {
            None
        };
        let current_schema = schema_with_id_table("users");
        let result =
            build_add_column(&backend, "users", &column, fill_with, &current_schema, &[]).unwrap();
        let sql = result[0].build(backend);
        for exp in expected {
            assert!(
                sql.contains(exp),
                "Expected SQL to contain '{}', got: {}",
                exp,
                sql
            );
        }

        with_settings!({ snapshot_suffix => format!("add_column_{}", title) }, {
            assert_snapshot!(result.iter().map(|q| q.build(backend)).collect::<Vec<String>>().join("\n"));
        });
    }

    #[test]
    fn test_add_column_sqlite_table_not_found() {
        let column = ColumnDef {
            name: "nickname".into(),
            r#type: ColumnType::Simple(SimpleColumnType::Text),
            nullable: false,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        };
        let current_schema: Vec<TableDef> = vec![];
        let result = build_add_column(
            &DatabaseBackend::Sqlite,
            "users",
            &column,
            None,
            &current_schema,
            &[],
        );
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Table 'users' not found in current schema"));
    }

    #[test]
    fn test_add_column_sqlite_with_default() {
        let column = ColumnDef {
            name: "age".into(),
            r#type: ColumnType::Simple(SimpleColumnType::Integer),
            nullable: false,
            default: Some("18".into()),
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        };
        let current_schema = schema_with_id_table("users");
        let queries = build_add_column(
            &DatabaseBackend::Sqlite,
            "users",
            &column,
            None,
            &current_schema,
            &[],
        )
        .expect("build_add_column should succeed");
        let sql = render(&queries, DatabaseBackend::Sqlite, "\n");
        // The rebuild copies existing rows using the column default as the fill value.
        assert!(sql.contains("18"));
    }

    #[test]
    fn test_add_column_sqlite_without_fill_or_default() {
        let column = ColumnDef {
            name: "age".into(),
            r#type: ColumnType::Simple(SimpleColumnType::Integer),
            nullable: false,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        };
        let current_schema = schema_with_id_table("users");
        let queries = build_add_column(
            &DatabaseBackend::Sqlite,
            "users",
            &column,
            None,
            &current_schema,
            &[],
        )
        .expect("build_add_column should succeed");
        let sql = render(&queries, DatabaseBackend::Sqlite, "\n");
        // With neither fill_with nor a default, the copied rows receive NULL.
        assert!(sql.contains("NULL"));
    }

    #[test]
    fn test_add_column_sqlite_with_indexes() {
        let column = ColumnDef {
            name: "nickname".into(),
            r#type: ColumnType::Simple(SimpleColumnType::Text),
            nullable: false,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        };
        let mut current_schema = schema_with_id_table("users");
        current_schema[0].constraints = vec![TableConstraint::Index {
            name: Some("idx_id".into()),
            columns: vec!["id".into()],
        }];
        let queries = build_add_column(
            &DatabaseBackend::Sqlite,
            "users",
            &column,
            None,
            &current_schema,
            &[],
        )
        .expect("build_add_column should succeed");
        let sql = render(&queries, DatabaseBackend::Sqlite, "\n");
        // Indexes are dropped with the original table and must be recreated afterwards.
        assert!(sql.contains("CREATE INDEX"));
        assert!(sql.contains("idx_id"));
    }

    #[rstest]
    #[case::add_column_with_enum_type_postgres(DatabaseBackend::Postgres)]
    #[case::add_column_with_enum_type_mysql(DatabaseBackend::MySql)]
    #[case::add_column_with_enum_type_sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_with_enum_type(#[case] backend: DatabaseBackend) {
        let column = ColumnDef {
            name: "status".into(),
            r#type: ColumnType::Complex(ComplexColumnType::Enum {
                name: "status_type".into(),
                values: EnumValues::String(vec!["active".into(), "inactive".into()]),
            }),
            nullable: true,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        };
        let current_schema = schema_with_id_table("users");
        let queries = build_add_column(&backend, "users", &column, None, &current_schema, &[])
            .expect("build_add_column should succeed");
        let sql = render(&queries, backend, ";\n");

        with_settings!({ snapshot_suffix => format!("add_column_with_enum_type_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_enum_non_nullable_with_default(#[case] backend: DatabaseBackend) {
        let column = ColumnDef {
            name: "status".into(),
            r#type: ColumnType::Complex(ComplexColumnType::Enum {
                name: "user_status".into(),
                values: EnumValues::String(vec![
                    "active".into(),
                    "inactive".into(),
                    "pending".into(),
                ]),
            }),
            nullable: false,
            default: Some("active".into()),
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        };
        let current_schema = schema_with_id_table("users");
        let queries = build_add_column(&backend, "users", &column, None, &current_schema, &[])
            .expect("build_add_column should succeed");
        let sql = render(&queries, backend, ";\n");

        with_settings!({ snapshot_suffix => format!("enum_non_nullable_with_default_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_with_empty_string_default(#[case] backend: DatabaseBackend) {
        let column = ColumnDef {
            name: "nickname".into(),
            r#type: ColumnType::Simple(SimpleColumnType::Text),
            nullable: false,
            default: Some("".into()),
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        };
        let current_schema = schema_with_id_table("users");
        let queries = build_add_column(&backend, "users", &column, None, &current_schema, &[])
            .expect("build_add_column should succeed");
        let sql = render(&queries, backend, ";\n");

        // An empty-string default must be rendered as a quoted literal, not dropped.
        assert!(
            sql.contains("''"),
            "Expected SQL to contain empty string literal '', got: {}",
            sql
        );

        with_settings!({ snapshot_suffix => format!("empty_string_default_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_with_pg_type_cast_default(#[case] backend: DatabaseBackend) {
        let column = ColumnDef {
            name: "story_index".into(),
            r#type: ColumnType::Simple(SimpleColumnType::Json),
            nullable: false,
            default: Some("'[]'::json".into()),
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        };
        let current_schema = schema_with_id_table("project");
        let result =
            build_add_column(&backend, "project", &column, None, &current_schema, &[]).unwrap();
        let sql = render(&result, backend, "\n");

        // The Postgres-only `::json` cast must be stripped for the other backends.
        if backend == DatabaseBackend::Sqlite {
            assert!(
                !sql.contains("::json"),
                "SQLite SQL should not contain ::json cast, got: {}",
                sql
            );
        }

        if backend == DatabaseBackend::MySql {
            assert!(
                !sql.contains("::json"),
                "MySQL SQL should not contain ::json cast, got: {}",
                sql
            );
        }

        with_settings!({ snapshot_suffix => format!("pg_type_cast_default_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_with_fill_with_empty_string(#[case] backend: DatabaseBackend) {
        let column = ColumnDef {
            name: "nickname".into(),
            r#type: ColumnType::Simple(SimpleColumnType::Text),
            nullable: false,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        };
        let current_schema = schema_with_id_table("users");
        let queries =
            build_add_column(&backend, "users", &column, Some(""), &current_schema, &[])
                .expect("build_add_column should succeed");
        let sql = render(&queries, backend, ";\n");

        // An empty fill_with must backfill with a quoted empty literal.
        assert!(
            sql.contains("''"),
            "Expected SQL to contain empty string literal '', got: {}",
            sql
        );

        with_settings!({ snapshot_suffix => format!("fill_with_empty_string_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }
}