1use sea_query::{Alias, Expr, Query, Table, TableAlterStatement};
2
3use vespertide_core::{ColumnDef, TableDef};
4
5use super::helpers::{
6 build_create_enum_type_sql, build_sea_column_def_with_table, build_sqlite_temp_table_create,
7 convert_default_for_backend, normalize_enum_default, normalize_fill_with,
8 recreate_indexes_after_rebuild,
9};
10use super::rename_table::build_rename_table;
11use super::types::{BuiltQuery, DatabaseBackend};
12use crate::error::QueryError;
13
14fn build_add_column_alter_for_backend(
15 backend: &DatabaseBackend,
16 table: &str,
17 column: &ColumnDef,
18) -> TableAlterStatement {
19 let col_def = build_sea_column_def_with_table(backend, table, column);
20 Table::alter()
21 .table(Alias::new(table))
22 .add_column(col_def)
23 .to_owned()
24}
25
26fn is_enum_column(column: &ColumnDef) -> bool {
28 matches!(
29 column.r#type,
30 vespertide_core::ColumnType::Complex(vespertide_core::ComplexColumnType::Enum { .. })
31 )
32}
33
34pub fn build_add_column(
35 backend: &DatabaseBackend,
36 table: &str,
37 column: &ColumnDef,
38 fill_with: Option<&str>,
39 current_schema: &[TableDef],
40) -> Result<Vec<BuiltQuery>, QueryError> {
41 let sqlite_needs_recreation =
44 *backend == DatabaseBackend::Sqlite && (!column.nullable || is_enum_column(column));
45
46 if sqlite_needs_recreation {
47 let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::Other(format!("Table '{}' not found in current schema. SQLite requires current schema information to add columns.", table)))?;
48
49 let mut new_columns = table_def.columns.clone();
50 new_columns.push(column.clone());
51
52 let temp_table = format!("{}_temp", table);
53
54 let create_query = build_sqlite_temp_table_create(
56 backend,
57 &temp_table,
58 table,
59 &new_columns,
60 &table_def.constraints,
61 );
62
63 let mut select_query = Query::select();
65 for col in &table_def.columns {
66 select_query = select_query.column(Alias::new(&col.name)).to_owned();
67 }
68 let normalized_fill = normalize_fill_with(fill_with);
69 let fill_expr = if let Some(fill) = normalized_fill.as_deref() {
70 let converted = convert_default_for_backend(fill, backend);
71 Expr::cust(normalize_enum_default(&column.r#type, &converted))
72 } else if let Some(def) = &column.default {
73 let converted = convert_default_for_backend(&def.to_sql(), backend);
74 Expr::cust(normalize_enum_default(&column.r#type, &converted))
75 } else {
76 Expr::cust("NULL")
77 };
78 select_query = select_query
79 .expr_as(fill_expr, Alias::new(&column.name))
80 .from(Alias::new(table))
81 .to_owned();
82
83 let mut columns_alias: Vec<Alias> = table_def
84 .columns
85 .iter()
86 .map(|c| Alias::new(&c.name))
87 .collect();
88 columns_alias.push(Alias::new(&column.name));
89 let insert_stmt = Query::insert()
90 .into_table(Alias::new(&temp_table))
91 .columns(columns_alias)
92 .select_from(select_query)
93 .unwrap()
94 .to_owned();
95 let insert_query = BuiltQuery::Insert(Box::new(insert_stmt));
96
97 let drop_query =
98 BuiltQuery::DropTable(Box::new(Table::drop().table(Alias::new(table)).to_owned()));
99 let rename_query = build_rename_table(&temp_table, table);
100
101 let index_queries = recreate_indexes_after_rebuild(table, &table_def.constraints, &[]);
103
104 let mut stmts = vec![create_query, insert_query, drop_query, rename_query];
105 stmts.extend(index_queries);
106 return Ok(stmts);
107 }
108
109 let mut stmts: Vec<BuiltQuery> = Vec::new();
110
111 if let Some(create_type_sql) = build_create_enum_type_sql(table, &column.r#type) {
113 stmts.push(BuiltQuery::Raw(create_type_sql));
114 }
115
116 let needs_backfill = !column.nullable && column.default.is_none() && fill_with.is_some();
118
119 if needs_backfill {
120 let mut temp_col = column.clone();
122 temp_col.nullable = true;
123
124 stmts.push(BuiltQuery::AlterTable(Box::new(
125 build_add_column_alter_for_backend(backend, table, &temp_col),
126 )));
127
128 if let Some(fill) = normalize_fill_with(fill_with) {
130 let fill = convert_default_for_backend(&fill, backend);
131 let update_stmt = Query::update()
132 .table(Alias::new(table))
133 .value(Alias::new(&column.name), Expr::cust(fill))
134 .to_owned();
135 stmts.push(BuiltQuery::Update(Box::new(update_stmt)));
136 }
137
138 let not_null_col = build_sea_column_def_with_table(backend, table, column);
140 let alter_not_null = Table::alter()
141 .table(Alias::new(table))
142 .modify_column(not_null_col)
143 .to_owned();
144 stmts.push(BuiltQuery::AlterTable(Box::new(alter_not_null)));
145 } else {
146 stmts.push(BuiltQuery::AlterTable(Box::new(
147 build_add_column_alter_for_backend(backend, table, column),
148 )));
149 }
150
151 Ok(stmts)
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{
        ColumnType, ComplexColumnType, EnumValues, SimpleColumnType, TableConstraint, TableDef,
    };

    /// Column with the given name, type and nullability; every optional attribute unset.
    fn plain_column(name: &str, ty: ColumnType, nullable: bool) -> ColumnDef {
        ColumnDef {
            name: name.into(),
            r#type: ty,
            nullable,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    /// One-table schema: `table` with a single NOT NULL integer `id` column.
    fn schema_with_id(table: &str, constraints: Vec<TableConstraint>) -> Vec<TableDef> {
        vec![TableDef {
            name: table.into(),
            description: None,
            columns: vec![plain_column(
                "id",
                ColumnType::Simple(SimpleColumnType::Integer),
                false,
            )],
            constraints,
        }]
    }

    /// Renders every query for `backend` and joins the results with `separator`.
    fn render(queries: &[BuiltQuery], backend: DatabaseBackend, separator: &str) -> String {
        queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join(separator)
    }

    #[rstest]
    #[case::add_column_with_backfill_postgres(
        "add_column_with_backfill_postgres",
        DatabaseBackend::Postgres,
        &["ALTER TABLE \"users\" ADD COLUMN \"nickname\" text"]
    )]
    #[case::add_column_with_backfill_mysql(
        "add_column_with_backfill_mysql",
        DatabaseBackend::MySql,
        &["ALTER TABLE `users` ADD COLUMN `nickname` text"]
    )]
    #[case::add_column_with_backfill_sqlite(
        "add_column_with_backfill_sqlite",
        DatabaseBackend::Sqlite,
        &["CREATE TABLE \"users_temp\""]
    )]
    #[case::add_column_simple_postgres(
        "add_column_simple_postgres",
        DatabaseBackend::Postgres,
        &["ALTER TABLE \"users\" ADD COLUMN \"nickname\""]
    )]
    #[case::add_column_simple_mysql(
        "add_column_simple_mysql",
        DatabaseBackend::MySql,
        &["ALTER TABLE `users` ADD COLUMN `nickname` text"]
    )]
    #[case::add_column_simple_sqlite(
        "add_column_simple_sqlite",
        DatabaseBackend::Sqlite,
        &["ALTER TABLE \"users\" ADD COLUMN \"nickname\""]
    )]
    #[case::add_column_nullable_postgres(
        "add_column_nullable_postgres",
        DatabaseBackend::Postgres,
        &["ALTER TABLE \"users\" ADD COLUMN \"email\" text"]
    )]
    #[case::add_column_nullable_mysql(
        "add_column_nullable_mysql",
        DatabaseBackend::MySql,
        &["ALTER TABLE `users` ADD COLUMN `email` text"]
    )]
    #[case::add_column_nullable_sqlite(
        "add_column_nullable_sqlite",
        DatabaseBackend::Sqlite,
        &["ALTER TABLE \"users\" ADD COLUMN \"email\" text"]
    )]
    fn test_add_column(
        #[case] title: &str,
        #[case] backend: DatabaseBackend,
        #[case] expected: &[&str],
    ) {
        // The case title drives the column shape: "backfill" cases are NOT NULL
        // with a fill value, "nullable" cases add `email`, the rest add `nickname`.
        let name = if title.contains("age") {
            "age"
        } else if title.contains("nullable") {
            "email"
        } else {
            "nickname"
        };
        let ty = if title.contains("age") {
            ColumnType::Simple(SimpleColumnType::Integer)
        } else {
            ColumnType::Simple(SimpleColumnType::Text)
        };
        let column = plain_column(name, ty, !title.contains("backfill"));
        let fill_with = if title.contains("backfill") {
            Some("0")
        } else {
            None
        };
        let current_schema = schema_with_id("users", vec![]);

        let result =
            build_add_column(&backend, "users", &column, fill_with, &current_schema).unwrap();
        let sql = result[0].build(backend);
        for exp in expected {
            assert!(
                sql.contains(exp),
                "Expected SQL to contain '{}', got: {}",
                exp,
                sql
            );
        }

        with_settings!({ snapshot_suffix => format!("add_column_{}", title) }, {
            assert_snapshot!(result.iter().map(|q| q.build(backend)).collect::<Vec<String>>().join("\n"));
        });
    }

    #[test]
    fn test_add_column_sqlite_table_not_found() {
        let column = plain_column("nickname", ColumnType::Simple(SimpleColumnType::Text), false);
        let current_schema: Vec<TableDef> = vec![];
        let result = build_add_column(
            &DatabaseBackend::Sqlite,
            "users",
            &column,
            None,
            &current_schema,
        );
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Table 'users' not found in current schema"));
    }

    #[test]
    fn test_add_column_sqlite_with_default() {
        let column = ColumnDef {
            default: Some("18".into()),
            ..plain_column("age", ColumnType::Simple(SimpleColumnType::Integer), false)
        };
        let current_schema = schema_with_id("users", vec![]);
        let result = build_add_column(
            &DatabaseBackend::Sqlite,
            "users",
            &column,
            None,
            &current_schema,
        );
        assert!(result.is_ok());
        let sql = render(&result.unwrap(), DatabaseBackend::Sqlite, "\n");
        assert!(sql.contains("18"));
    }

    #[test]
    fn test_add_column_sqlite_without_fill_or_default() {
        let column = plain_column("age", ColumnType::Simple(SimpleColumnType::Integer), false);
        let current_schema = schema_with_id("users", vec![]);
        let result = build_add_column(
            &DatabaseBackend::Sqlite,
            "users",
            &column,
            None,
            &current_schema,
        );
        assert!(result.is_ok());
        let sql = render(&result.unwrap(), DatabaseBackend::Sqlite, "\n");
        assert!(sql.contains("NULL"));
    }

    #[test]
    fn test_add_column_sqlite_with_indexes() {
        let column = plain_column("nickname", ColumnType::Simple(SimpleColumnType::Text), false);
        let current_schema = schema_with_id(
            "users",
            vec![TableConstraint::Index {
                name: Some("idx_id".into()),
                columns: vec!["id".into()],
            }],
        );
        let result = build_add_column(
            &DatabaseBackend::Sqlite,
            "users",
            &column,
            None,
            &current_schema,
        );
        assert!(result.is_ok());
        let sql = render(&result.unwrap(), DatabaseBackend::Sqlite, "\n");
        // The rebuild drops the original table, so its indexes must be recreated.
        assert!(sql.contains("CREATE INDEX"));
        assert!(sql.contains("idx_id"));
    }

    #[rstest]
    #[case::add_column_with_enum_type_postgres(DatabaseBackend::Postgres)]
    #[case::add_column_with_enum_type_mysql(DatabaseBackend::MySql)]
    #[case::add_column_with_enum_type_sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_with_enum_type(#[case] backend: DatabaseBackend) {
        let column = plain_column(
            "status",
            ColumnType::Complex(ComplexColumnType::Enum {
                name: "status_type".into(),
                values: EnumValues::String(vec!["active".into(), "inactive".into()]),
            }),
            true,
        );
        let current_schema = schema_with_id("users", vec![]);
        let result = build_add_column(&backend, "users", &column, None, &current_schema);
        assert!(result.is_ok());
        let sql = render(&result.unwrap(), backend, ";\n");

        with_settings!({ snapshot_suffix => format!("add_column_with_enum_type_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_enum_non_nullable_with_default(#[case] backend: DatabaseBackend) {
        let column = ColumnDef {
            default: Some("active".into()),
            ..plain_column(
                "status",
                ColumnType::Complex(ComplexColumnType::Enum {
                    name: "user_status".into(),
                    values: EnumValues::String(vec![
                        "active".into(),
                        "inactive".into(),
                        "pending".into(),
                    ]),
                }),
                false,
            )
        };
        let current_schema = schema_with_id("users", vec![]);
        let result = build_add_column(&backend, "users", &column, None, &current_schema);
        assert!(result.is_ok());
        let sql = render(&result.unwrap(), backend, ";\n");

        with_settings!({ snapshot_suffix => format!("enum_non_nullable_with_default_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_with_empty_string_default(#[case] backend: DatabaseBackend) {
        let column = ColumnDef {
            default: Some("".into()),
            ..plain_column("nickname", ColumnType::Simple(SimpleColumnType::Text), false)
        };
        let current_schema = schema_with_id("users", vec![]);
        let result = build_add_column(&backend, "users", &column, None, &current_schema);
        assert!(result.is_ok());
        let sql = render(&result.unwrap(), backend, ";\n");

        // An empty default must be rendered as the SQL literal '' rather than dropped.
        assert!(
            sql.contains("''"),
            "Expected SQL to contain empty string literal '', got: {}",
            sql
        );

        with_settings!({ snapshot_suffix => format!("empty_string_default_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_with_pg_type_cast_default(#[case] backend: DatabaseBackend) {
        let column = ColumnDef {
            default: Some("'[]'::json".into()),
            ..plain_column("story_index", ColumnType::Simple(SimpleColumnType::Json), false)
        };
        let current_schema = schema_with_id("project", vec![]);
        let result =
            build_add_column(&backend, "project", &column, None, &current_schema).unwrap();
        let sql = render(&result, backend, "\n");

        // The Postgres-only `::json` cast must be stripped for the other backends.
        if backend == DatabaseBackend::Sqlite {
            assert!(
                !sql.contains("::json"),
                "SQLite SQL should not contain ::json cast, got: {}",
                sql
            );
        }

        if backend == DatabaseBackend::MySql {
            assert!(
                !sql.contains("::json"),
                "MySQL SQL should not contain ::json cast, got: {}",
                sql
            );
        }

        with_settings!({ snapshot_suffix => format!("pg_type_cast_default_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_with_fill_with_empty_string(#[case] backend: DatabaseBackend) {
        let column = plain_column("nickname", ColumnType::Simple(SimpleColumnType::Text), false);
        let current_schema = schema_with_id("users", vec![]);
        let result = build_add_column(&backend, "users", &column, Some(""), &current_schema);
        assert!(result.is_ok());
        let sql = render(&result.unwrap(), backend, ";\n");

        // An empty fill value must be rendered as the SQL literal ''.
        assert!(
            sql.contains("''"),
            "Expected SQL to contain empty string literal '', got: {}",
            sql
        );

        with_settings!({ snapshot_suffix => format!("fill_with_empty_string_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }
}