1use sea_query::{Alias, Expr, Query, Table, TableAlterStatement};
2
3use vespertide_core::{ColumnDef, TableDef};
4
5use super::create_table::build_create_table_for_backend;
6use super::helpers::{
7 build_create_enum_type_sql, build_schema_statement, build_sea_column_def_with_table,
8 collect_sqlite_enum_check_clauses, normalize_enum_default, normalize_fill_with,
9};
10use super::rename_table::build_rename_table;
11use super::types::{BuiltQuery, DatabaseBackend, RawSql};
12use crate::error::QueryError;
13
14fn build_add_column_alter_for_backend(
15 backend: &DatabaseBackend,
16 table: &str,
17 column: &ColumnDef,
18) -> TableAlterStatement {
19 let col_def = build_sea_column_def_with_table(backend, table, column);
20 Table::alter()
21 .table(Alias::new(table))
22 .add_column(col_def)
23 .to_owned()
24}
25
26fn is_enum_column(column: &ColumnDef) -> bool {
28 matches!(
29 column.r#type,
30 vespertide_core::ColumnType::Complex(vespertide_core::ComplexColumnType::Enum { .. })
31 )
32}
33
34pub fn build_add_column(
35 backend: &DatabaseBackend,
36 table: &str,
37 column: &ColumnDef,
38 fill_with: Option<&str>,
39 current_schema: &[TableDef],
40) -> Result<Vec<BuiltQuery>, QueryError> {
41 let sqlite_needs_recreation =
44 *backend == DatabaseBackend::Sqlite && (!column.nullable || is_enum_column(column));
45
46 if sqlite_needs_recreation {
47 let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::Other(format!("Table '{}' not found in current schema. SQLite requires current schema information to add columns.", table)))?;
48
49 let mut new_columns = table_def.columns.clone();
50 new_columns.push(column.clone());
51
52 let temp_table = format!("{}_temp", table);
53 let create_temp = build_create_table_for_backend(
54 backend,
55 &temp_table,
56 &new_columns,
57 &table_def.constraints,
58 );
59
60 let enum_check_clauses = collect_sqlite_enum_check_clauses(table, &new_columns);
63 let create_query = if !enum_check_clauses.is_empty() {
64 let base_sql = build_schema_statement(&create_temp, *backend);
65 let mut modified_sql = base_sql;
66 if let Some(pos) = modified_sql.rfind(')') {
67 let check_sql = enum_check_clauses.join(", ");
68 modified_sql.insert_str(pos, &format!(", {}", check_sql));
69 }
70 BuiltQuery::Raw(RawSql::per_backend(
71 modified_sql.clone(),
72 modified_sql.clone(),
73 modified_sql,
74 ))
75 } else {
76 BuiltQuery::CreateTable(Box::new(create_temp))
77 };
78
79 let mut select_query = Query::select();
81 for col in &table_def.columns {
82 select_query = select_query.column(Alias::new(&col.name)).to_owned();
83 }
84 let normalized_fill = normalize_fill_with(fill_with);
85 let fill_expr = if let Some(fill) = normalized_fill.as_deref() {
86 Expr::cust(normalize_enum_default(&column.r#type, fill))
87 } else if let Some(def) = &column.default {
88 Expr::cust(normalize_enum_default(&column.r#type, &def.to_sql()))
89 } else {
90 Expr::cust("NULL")
91 };
92 select_query = select_query
93 .expr_as(fill_expr, Alias::new(&column.name))
94 .from(Alias::new(table))
95 .to_owned();
96
97 let mut columns_alias: Vec<Alias> = table_def
98 .columns
99 .iter()
100 .map(|c| Alias::new(&c.name))
101 .collect();
102 columns_alias.push(Alias::new(&column.name));
103 let insert_stmt = Query::insert()
104 .into_table(Alias::new(&temp_table))
105 .columns(columns_alias)
106 .select_from(select_query)
107 .unwrap()
108 .to_owned();
109 let insert_query = BuiltQuery::Insert(Box::new(insert_stmt));
110
111 let drop_query =
112 BuiltQuery::DropTable(Box::new(Table::drop().table(Alias::new(table)).to_owned()));
113 let rename_query = build_rename_table(&temp_table, table);
114
115 let mut index_queries = Vec::new();
117 for constraint in &table_def.constraints {
118 if let vespertide_core::TableConstraint::Index { name, columns } = constraint {
119 let index_name =
120 vespertide_naming::build_index_name(table, columns, name.as_deref());
121 let mut idx_stmt = sea_query::Index::create();
122 idx_stmt = idx_stmt.name(&index_name).to_owned();
123 for col_name in columns {
124 idx_stmt = idx_stmt.col(Alias::new(col_name)).to_owned();
125 }
126 idx_stmt = idx_stmt.table(Alias::new(table)).to_owned();
127 index_queries.push(BuiltQuery::CreateIndex(Box::new(idx_stmt)));
128 }
129 }
130
131 let mut stmts = vec![create_query, insert_query, drop_query, rename_query];
132 stmts.extend(index_queries);
133 return Ok(stmts);
134 }
135
136 let mut stmts: Vec<BuiltQuery> = Vec::new();
137
138 if let Some(create_type_sql) = build_create_enum_type_sql(table, &column.r#type) {
140 stmts.push(BuiltQuery::Raw(create_type_sql));
141 }
142
143 let needs_backfill = !column.nullable && column.default.is_none() && fill_with.is_some();
145
146 if needs_backfill {
147 let mut temp_col = column.clone();
149 temp_col.nullable = true;
150
151 stmts.push(BuiltQuery::AlterTable(Box::new(
152 build_add_column_alter_for_backend(backend, table, &temp_col),
153 )));
154
155 if let Some(fill) = normalize_fill_with(fill_with) {
157 let update_stmt = Query::update()
158 .table(Alias::new(table))
159 .value(Alias::new(&column.name), Expr::cust(fill))
160 .to_owned();
161 stmts.push(BuiltQuery::Update(Box::new(update_stmt)));
162 }
163
164 let not_null_col = build_sea_column_def_with_table(backend, table, column);
166 let alter_not_null = Table::alter()
167 .table(Alias::new(table))
168 .modify_column(not_null_col)
169 .to_owned();
170 stmts.push(BuiltQuery::AlterTable(Box::new(alter_not_null)));
171 } else {
172 stmts.push(BuiltQuery::AlterTable(Box::new(
173 build_add_column_alter_for_backend(backend, table, column),
174 )));
175 }
176
177 Ok(stmts)
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{
        ColumnType, ComplexColumnType, EnumValues, SimpleColumnType, TableConstraint, TableDef,
    };

    /// Builds a column with the given name, type, nullability and default.
    /// Every other attribute (comment, primary key, unique, index, foreign key)
    /// is left unset.
    fn make_column(
        name: &str,
        r#type: ColumnType,
        nullable: bool,
        default: Option<&str>,
    ) -> ColumnDef {
        ColumnDef {
            name: name.into(),
            r#type,
            nullable,
            default: default.map(Into::into),
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    /// Schema with a single `users` table holding a non-nullable integer `id`
    /// column and the given table-level constraints.
    fn users_schema(constraints: Vec<TableConstraint>) -> Vec<TableDef> {
        vec![TableDef {
            name: "users".into(),
            description: None,
            columns: vec![make_column(
                "id",
                ColumnType::Simple(SimpleColumnType::Integer),
                false,
                None,
            )],
            constraints,
        }]
    }

    /// Renders every query for `backend` and joins the results with `separator`.
    fn render(queries: &[BuiltQuery], backend: DatabaseBackend, separator: &str) -> String {
        queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join(separator)
    }

    #[rstest]
    #[case::add_column_with_backfill_postgres(
        "add_column_with_backfill_postgres",
        DatabaseBackend::Postgres,
        &["ALTER TABLE \"users\" ADD COLUMN \"nickname\" text"]
    )]
    #[case::add_column_with_backfill_mysql(
        "add_column_with_backfill_mysql",
        DatabaseBackend::MySql,
        &["ALTER TABLE `users` ADD COLUMN `nickname` text"]
    )]
    #[case::add_column_with_backfill_sqlite(
        "add_column_with_backfill_sqlite",
        DatabaseBackend::Sqlite,
        &["CREATE TABLE \"users_temp\""]
    )]
    #[case::add_column_simple_postgres(
        "add_column_simple_postgres",
        DatabaseBackend::Postgres,
        &["ALTER TABLE \"users\" ADD COLUMN \"nickname\""]
    )]
    #[case::add_column_simple_mysql(
        "add_column_simple_mysql",
        DatabaseBackend::MySql,
        &["ALTER TABLE `users` ADD COLUMN `nickname` text"]
    )]
    #[case::add_column_simple_sqlite(
        "add_column_simple_sqlite",
        DatabaseBackend::Sqlite,
        &["ALTER TABLE \"users\" ADD COLUMN \"nickname\""]
    )]
    #[case::add_column_nullable_postgres(
        "add_column_nullable_postgres",
        DatabaseBackend::Postgres,
        &["ALTER TABLE \"users\" ADD COLUMN \"email\" text"]
    )]
    #[case::add_column_nullable_mysql(
        "add_column_nullable_mysql",
        DatabaseBackend::MySql,
        &["ALTER TABLE `users` ADD COLUMN `email` text"]
    )]
    #[case::add_column_nullable_sqlite(
        "add_column_nullable_sqlite",
        DatabaseBackend::Sqlite,
        &["ALTER TABLE \"users\" ADD COLUMN \"email\" text"]
    )]
    fn test_add_column(
        #[case] title: &str,
        #[case] backend: DatabaseBackend,
        #[case] expected: &[&str],
    ) {
        // "backfill" cases add a NOT NULL column with a fill value; all others
        // add a nullable column without one.
        let backfill = title.contains("backfill");
        let name = if title.contains("nullable") {
            "email"
        } else {
            "nickname"
        };
        let column = make_column(
            name,
            ColumnType::Simple(SimpleColumnType::Text),
            !backfill,
            None,
        );
        let fill_with = if backfill { Some("0") } else { None };
        let current_schema = users_schema(vec![]);

        let result =
            build_add_column(&backend, "users", &column, fill_with, &current_schema).unwrap();
        let sql = result[0].build(backend);
        for exp in expected {
            assert!(
                sql.contains(exp),
                "Expected SQL to contain '{}', got: {}",
                exp,
                sql
            );
        }

        with_settings!({ snapshot_suffix => format!("add_column_{}", title) }, {
            assert_snapshot!(result.iter().map(|q| q.build(backend)).collect::<Vec<String>>().join("\n"));
        });
    }

    #[test]
    fn test_add_column_sqlite_table_not_found() {
        let column = make_column(
            "nickname",
            ColumnType::Simple(SimpleColumnType::Text),
            false,
            None,
        );
        let current_schema: Vec<TableDef> = Vec::new();

        let result = build_add_column(
            &DatabaseBackend::Sqlite,
            "users",
            &column,
            None,
            &current_schema,
        );
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Table 'users' not found in current schema"));
    }

    #[test]
    fn test_add_column_sqlite_with_default() {
        let column = make_column(
            "age",
            ColumnType::Simple(SimpleColumnType::Integer),
            false,
            Some("18"),
        );
        let current_schema = users_schema(vec![]);

        let queries = build_add_column(
            &DatabaseBackend::Sqlite,
            "users",
            &column,
            None,
            &current_schema,
        )
        .unwrap();
        let sql = render(&queries, DatabaseBackend::Sqlite, "\n");
        assert!(sql.contains("18"));
    }

    #[test]
    fn test_add_column_sqlite_without_fill_or_default() {
        let column = make_column(
            "age",
            ColumnType::Simple(SimpleColumnType::Integer),
            false,
            None,
        );
        let current_schema = users_schema(vec![]);

        let queries = build_add_column(
            &DatabaseBackend::Sqlite,
            "users",
            &column,
            None,
            &current_schema,
        )
        .unwrap();
        let sql = render(&queries, DatabaseBackend::Sqlite, "\n");
        assert!(sql.contains("NULL"));
    }

    #[test]
    fn test_add_column_sqlite_with_indexes() {
        let column = make_column(
            "nickname",
            ColumnType::Simple(SimpleColumnType::Text),
            false,
            None,
        );
        let current_schema = users_schema(vec![TableConstraint::Index {
            name: Some("idx_id".into()),
            columns: vec!["id".into()],
        }]);

        let queries = build_add_column(
            &DatabaseBackend::Sqlite,
            "users",
            &column,
            None,
            &current_schema,
        )
        .unwrap();
        let sql = render(&queries, DatabaseBackend::Sqlite, "\n");
        assert!(sql.contains("CREATE INDEX"));
        assert!(sql.contains("idx_id"));
    }

    #[rstest]
    #[case::add_column_with_enum_type_postgres(DatabaseBackend::Postgres)]
    #[case::add_column_with_enum_type_mysql(DatabaseBackend::MySql)]
    #[case::add_column_with_enum_type_sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_with_enum_type(#[case] backend: DatabaseBackend) {
        let column = make_column(
            "status",
            ColumnType::Complex(ComplexColumnType::Enum {
                name: "status_type".into(),
                values: EnumValues::String(vec!["active".into(), "inactive".into()]),
            }),
            true,
            None,
        );
        let current_schema = users_schema(vec![]);

        let queries =
            build_add_column(&backend, "users", &column, None, &current_schema).unwrap();
        let sql = render(&queries, backend, ";\n");

        with_settings!({ snapshot_suffix => format!("add_column_with_enum_type_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_enum_non_nullable_with_default(#[case] backend: DatabaseBackend) {
        let column = make_column(
            "status",
            ColumnType::Complex(ComplexColumnType::Enum {
                name: "user_status".into(),
                values: EnumValues::String(vec![
                    "active".into(),
                    "inactive".into(),
                    "pending".into(),
                ]),
            }),
            false,
            Some("active"),
        );
        let current_schema = users_schema(vec![]);

        let queries =
            build_add_column(&backend, "users", &column, None, &current_schema).unwrap();
        let sql = render(&queries, backend, ";\n");

        with_settings!({ snapshot_suffix => format!("enum_non_nullable_with_default_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_with_empty_string_default(#[case] backend: DatabaseBackend) {
        let column = make_column(
            "nickname",
            ColumnType::Simple(SimpleColumnType::Text),
            false,
            Some(""),
        );
        let current_schema = users_schema(vec![]);

        let queries =
            build_add_column(&backend, "users", &column, None, &current_schema).unwrap();
        let sql = render(&queries, backend, ";\n");

        // An empty-string default must still be rendered as a quoted literal.
        assert!(
            sql.contains("''"),
            "Expected SQL to contain empty string literal '', got: {}",
            sql
        );

        with_settings!({ snapshot_suffix => format!("empty_string_default_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_with_fill_with_empty_string(#[case] backend: DatabaseBackend) {
        let column = make_column(
            "nickname",
            ColumnType::Simple(SimpleColumnType::Text),
            false,
            None,
        );
        let current_schema = users_schema(vec![]);

        let queries =
            build_add_column(&backend, "users", &column, Some(""), &current_schema).unwrap();
        let sql = render(&queries, backend, ";\n");

        // An empty-string fill value must still be rendered as a quoted literal.
        assert!(
            sql.contains("''"),
            "Expected SQL to contain empty string literal '', got: {}",
            sql
        );

        with_settings!({ snapshot_suffix => format!("fill_with_empty_string_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }
}