1use sea_query::{Alias, Expr, Query, Table, TableAlterStatement};
2
3use vespertide_core::{ColumnDef, TableDef};
4
5use super::helpers::{
6 build_create_enum_type_sql, build_sea_column_def_with_table, build_sqlite_temp_table_create,
7 normalize_enum_default, normalize_fill_with, recreate_indexes_after_rebuild,
8};
9use super::rename_table::build_rename_table;
10use super::types::{BuiltQuery, DatabaseBackend};
11use crate::error::QueryError;
12
13fn build_add_column_alter_for_backend(
14 backend: &DatabaseBackend,
15 table: &str,
16 column: &ColumnDef,
17) -> TableAlterStatement {
18 let col_def = build_sea_column_def_with_table(backend, table, column);
19 Table::alter()
20 .table(Alias::new(table))
21 .add_column(col_def)
22 .to_owned()
23}
24
25fn is_enum_column(column: &ColumnDef) -> bool {
27 matches!(
28 column.r#type,
29 vespertide_core::ColumnType::Complex(vespertide_core::ComplexColumnType::Enum { .. })
30 )
31}
32
33pub fn build_add_column(
34 backend: &DatabaseBackend,
35 table: &str,
36 column: &ColumnDef,
37 fill_with: Option<&str>,
38 current_schema: &[TableDef],
39) -> Result<Vec<BuiltQuery>, QueryError> {
40 let sqlite_needs_recreation =
43 *backend == DatabaseBackend::Sqlite && (!column.nullable || is_enum_column(column));
44
45 if sqlite_needs_recreation {
46 let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::Other(format!("Table '{}' not found in current schema. SQLite requires current schema information to add columns.", table)))?;
47
48 let mut new_columns = table_def.columns.clone();
49 new_columns.push(column.clone());
50
51 let temp_table = format!("{}_temp", table);
52
53 let create_query = build_sqlite_temp_table_create(
55 backend,
56 &temp_table,
57 table,
58 &new_columns,
59 &table_def.constraints,
60 );
61
62 let mut select_query = Query::select();
64 for col in &table_def.columns {
65 select_query = select_query.column(Alias::new(&col.name)).to_owned();
66 }
67 let normalized_fill = normalize_fill_with(fill_with);
68 let fill_expr = if let Some(fill) = normalized_fill.as_deref() {
69 Expr::cust(normalize_enum_default(&column.r#type, fill))
70 } else if let Some(def) = &column.default {
71 Expr::cust(normalize_enum_default(&column.r#type, &def.to_sql()))
72 } else {
73 Expr::cust("NULL")
74 };
75 select_query = select_query
76 .expr_as(fill_expr, Alias::new(&column.name))
77 .from(Alias::new(table))
78 .to_owned();
79
80 let mut columns_alias: Vec<Alias> = table_def
81 .columns
82 .iter()
83 .map(|c| Alias::new(&c.name))
84 .collect();
85 columns_alias.push(Alias::new(&column.name));
86 let insert_stmt = Query::insert()
87 .into_table(Alias::new(&temp_table))
88 .columns(columns_alias)
89 .select_from(select_query)
90 .unwrap()
91 .to_owned();
92 let insert_query = BuiltQuery::Insert(Box::new(insert_stmt));
93
94 let drop_query =
95 BuiltQuery::DropTable(Box::new(Table::drop().table(Alias::new(table)).to_owned()));
96 let rename_query = build_rename_table(&temp_table, table);
97
98 let index_queries = recreate_indexes_after_rebuild(table, &table_def.constraints, &[]);
100
101 let mut stmts = vec![create_query, insert_query, drop_query, rename_query];
102 stmts.extend(index_queries);
103 return Ok(stmts);
104 }
105
106 let mut stmts: Vec<BuiltQuery> = Vec::new();
107
108 if let Some(create_type_sql) = build_create_enum_type_sql(table, &column.r#type) {
110 stmts.push(BuiltQuery::Raw(create_type_sql));
111 }
112
113 let needs_backfill = !column.nullable && column.default.is_none() && fill_with.is_some();
115
116 if needs_backfill {
117 let mut temp_col = column.clone();
119 temp_col.nullable = true;
120
121 stmts.push(BuiltQuery::AlterTable(Box::new(
122 build_add_column_alter_for_backend(backend, table, &temp_col),
123 )));
124
125 if let Some(fill) = normalize_fill_with(fill_with) {
127 let update_stmt = Query::update()
128 .table(Alias::new(table))
129 .value(Alias::new(&column.name), Expr::cust(fill))
130 .to_owned();
131 stmts.push(BuiltQuery::Update(Box::new(update_stmt)));
132 }
133
134 let not_null_col = build_sea_column_def_with_table(backend, table, column);
136 let alter_not_null = Table::alter()
137 .table(Alias::new(table))
138 .modify_column(not_null_col)
139 .to_owned();
140 stmts.push(BuiltQuery::AlterTable(Box::new(alter_not_null)));
141 } else {
142 stmts.push(BuiltQuery::AlterTable(Box::new(
143 build_add_column_alter_for_backend(backend, table, column),
144 )));
145 }
146
147 Ok(stmts)
148}
149
150#[cfg(test)]
151mod tests {
152 use super::*;
153 use insta::{assert_snapshot, with_settings};
154 use rstest::rstest;
155 use vespertide_core::{ColumnType, SimpleColumnType, TableDef};
156
157 #[rstest]
158 #[case::add_column_with_backfill_postgres(
159 "add_column_with_backfill_postgres",
160 DatabaseBackend::Postgres,
161 &["ALTER TABLE \"users\" ADD COLUMN \"nickname\" text"]
162 )]
163 #[case::add_column_with_backfill_mysql(
164 "add_column_with_backfill_mysql",
165 DatabaseBackend::MySql,
166 &["ALTER TABLE `users` ADD COLUMN `nickname` text"]
167 )]
168 #[case::add_column_with_backfill_sqlite(
169 "add_column_with_backfill_sqlite",
170 DatabaseBackend::Sqlite,
171 &["CREATE TABLE \"users_temp\""]
172 )]
173 #[case::add_column_simple_postgres(
174 "add_column_simple_postgres",
175 DatabaseBackend::Postgres,
176 &["ALTER TABLE \"users\" ADD COLUMN \"nickname\""]
177 )]
178 #[case::add_column_simple_mysql(
179 "add_column_simple_mysql",
180 DatabaseBackend::MySql,
181 &["ALTER TABLE `users` ADD COLUMN `nickname` text"]
182 )]
183 #[case::add_column_simple_sqlite(
184 "add_column_simple_sqlite",
185 DatabaseBackend::Sqlite,
186 &["ALTER TABLE \"users\" ADD COLUMN \"nickname\""]
187 )]
188 #[case::add_column_nullable_postgres(
189 "add_column_nullable_postgres",
190 DatabaseBackend::Postgres,
191 &["ALTER TABLE \"users\" ADD COLUMN \"email\" text"]
192 )]
193 #[case::add_column_nullable_mysql(
194 "add_column_nullable_mysql",
195 DatabaseBackend::MySql,
196 &["ALTER TABLE `users` ADD COLUMN `email` text"]
197 )]
198 #[case::add_column_nullable_sqlite(
199 "add_column_nullable_sqlite",
200 DatabaseBackend::Sqlite,
201 &["ALTER TABLE \"users\" ADD COLUMN \"email\" text"]
202 )]
203 fn test_add_column(
204 #[case] title: &str,
205 #[case] backend: DatabaseBackend,
206 #[case] expected: &[&str],
207 ) {
208 let column = ColumnDef {
209 name: if title.contains("age") {
210 "age"
211 } else if title.contains("nullable") {
212 "email"
213 } else {
214 "nickname"
215 }
216 .into(),
217 r#type: if title.contains("age") {
218 ColumnType::Simple(SimpleColumnType::Integer)
219 } else {
220 ColumnType::Simple(SimpleColumnType::Text)
221 },
222 nullable: !title.contains("backfill"),
223 default: None,
224 comment: None,
225 primary_key: None,
226 unique: None,
227 index: None,
228 foreign_key: None,
229 };
230 let fill_with = if title.contains("backfill") {
231 Some("0")
232 } else {
233 None
234 };
235 let current_schema = vec![TableDef {
236 name: "users".into(),
237 description: None,
238 columns: vec![ColumnDef {
239 name: "id".into(),
240 r#type: ColumnType::Simple(SimpleColumnType::Integer),
241 nullable: false,
242 default: None,
243 comment: None,
244 primary_key: None,
245 unique: None,
246 index: None,
247 foreign_key: None,
248 }],
249 constraints: vec![],
250 }];
251 let result =
252 build_add_column(&backend, "users", &column, fill_with, ¤t_schema).unwrap();
253 let sql = result[0].build(backend);
254 for exp in expected {
255 assert!(
256 sql.contains(exp),
257 "Expected SQL to contain '{}', got: {}",
258 exp,
259 sql
260 );
261 }
262
263 with_settings!({ snapshot_suffix => format!("add_column_{}", title) }, {
264 assert_snapshot!(result.iter().map(|q| q.build(backend)).collect::<Vec<String>>().join("\n"));
265 });
266 }
267
268 #[test]
269 fn test_add_column_sqlite_table_not_found() {
270 let column = ColumnDef {
271 name: "nickname".into(),
272 r#type: ColumnType::Simple(SimpleColumnType::Text),
273 nullable: false,
274 default: None,
275 comment: None,
276 primary_key: None,
277 unique: None,
278 index: None,
279 foreign_key: None,
280 };
281 let current_schema = vec![]; let result = build_add_column(
283 &DatabaseBackend::Sqlite,
284 "users",
285 &column,
286 None,
287 ¤t_schema,
288 );
289 assert!(result.is_err());
290 let err_msg = result.unwrap_err().to_string();
291 assert!(err_msg.contains("Table 'users' not found in current schema"));
292 }
293
294 #[test]
295 fn test_add_column_sqlite_with_default() {
296 let column = ColumnDef {
297 name: "age".into(),
298 r#type: ColumnType::Simple(SimpleColumnType::Integer),
299 nullable: false,
300 default: Some("18".into()),
301 comment: None,
302 primary_key: None,
303 unique: None,
304 index: None,
305 foreign_key: None,
306 };
307 let current_schema = vec![TableDef {
308 name: "users".into(),
309 description: None,
310 columns: vec![ColumnDef {
311 name: "id".into(),
312 r#type: ColumnType::Simple(SimpleColumnType::Integer),
313 nullable: false,
314 default: None,
315 comment: None,
316 primary_key: None,
317 unique: None,
318 index: None,
319 foreign_key: None,
320 }],
321 constraints: vec![],
322 }];
323 let result = build_add_column(
324 &DatabaseBackend::Sqlite,
325 "users",
326 &column,
327 None,
328 ¤t_schema,
329 );
330 assert!(result.is_ok());
331 let queries = result.unwrap();
332 let sql = queries
333 .iter()
334 .map(|q| q.build(DatabaseBackend::Sqlite))
335 .collect::<Vec<String>>()
336 .join("\n");
337 assert!(sql.contains("18"));
339 }
340
341 #[test]
342 fn test_add_column_sqlite_without_fill_or_default() {
343 let column = ColumnDef {
344 name: "age".into(),
345 r#type: ColumnType::Simple(SimpleColumnType::Integer),
346 nullable: false,
347 default: None,
348 comment: None,
349 primary_key: None,
350 unique: None,
351 index: None,
352 foreign_key: None,
353 };
354 let current_schema = vec![TableDef {
355 name: "users".into(),
356 description: None,
357 columns: vec![ColumnDef {
358 name: "id".into(),
359 r#type: ColumnType::Simple(SimpleColumnType::Integer),
360 nullable: false,
361 default: None,
362 comment: None,
363 primary_key: None,
364 unique: None,
365 index: None,
366 foreign_key: None,
367 }],
368 constraints: vec![],
369 }];
370 let result = build_add_column(
371 &DatabaseBackend::Sqlite,
372 "users",
373 &column,
374 None,
375 ¤t_schema,
376 );
377 assert!(result.is_ok());
378 let queries = result.unwrap();
379 let sql = queries
380 .iter()
381 .map(|q| q.build(DatabaseBackend::Sqlite))
382 .collect::<Vec<String>>()
383 .join("\n");
384 assert!(sql.contains("NULL"));
386 }
387
388 #[test]
389 fn test_add_column_sqlite_with_indexes() {
390 use vespertide_core::TableConstraint;
391
392 let column = ColumnDef {
393 name: "nickname".into(),
394 r#type: ColumnType::Simple(SimpleColumnType::Text),
395 nullable: false,
396 default: None,
397 comment: None,
398 primary_key: None,
399 unique: None,
400 index: None,
401 foreign_key: None,
402 };
403 let current_schema = vec![TableDef {
404 name: "users".into(),
405 description: None,
406 columns: vec![ColumnDef {
407 name: "id".into(),
408 r#type: ColumnType::Simple(SimpleColumnType::Integer),
409 nullable: false,
410 default: None,
411 comment: None,
412 primary_key: None,
413 unique: None,
414 index: None,
415 foreign_key: None,
416 }],
417 constraints: vec![TableConstraint::Index {
418 name: Some("idx_id".into()),
419 columns: vec!["id".into()],
420 }],
421 }];
422 let result = build_add_column(
423 &DatabaseBackend::Sqlite,
424 "users",
425 &column,
426 None,
427 ¤t_schema,
428 );
429 assert!(result.is_ok());
430 let queries = result.unwrap();
431 let sql = queries
432 .iter()
433 .map(|q| q.build(DatabaseBackend::Sqlite))
434 .collect::<Vec<String>>()
435 .join("\n");
436 assert!(sql.contains("CREATE INDEX"));
438 assert!(sql.contains("idx_id"));
439 }
440
441 #[rstest]
442 #[case::add_column_with_enum_type_postgres(DatabaseBackend::Postgres)]
443 #[case::add_column_with_enum_type_mysql(DatabaseBackend::MySql)]
444 #[case::add_column_with_enum_type_sqlite(DatabaseBackend::Sqlite)]
445 fn test_add_column_with_enum_type(#[case] backend: DatabaseBackend) {
446 use insta::{assert_snapshot, with_settings};
447 use vespertide_core::{ComplexColumnType, EnumValues};
448
449 let column = ColumnDef {
451 name: "status".into(),
452 r#type: ColumnType::Complex(ComplexColumnType::Enum {
453 name: "status_type".into(),
454 values: EnumValues::String(vec!["active".into(), "inactive".into()]),
455 }),
456 nullable: true,
457 default: None,
458 comment: None,
459 primary_key: None,
460 unique: None,
461 index: None,
462 foreign_key: None,
463 };
464 let current_schema = vec![TableDef {
465 name: "users".into(),
466 description: None,
467 columns: vec![ColumnDef {
468 name: "id".into(),
469 r#type: ColumnType::Simple(SimpleColumnType::Integer),
470 nullable: false,
471 default: None,
472 comment: None,
473 primary_key: None,
474 unique: None,
475 index: None,
476 foreign_key: None,
477 }],
478 constraints: vec![],
479 }];
480 let result = build_add_column(&backend, "users", &column, None, ¤t_schema);
481 assert!(result.is_ok());
482 let queries = result.unwrap();
483 let sql = queries
484 .iter()
485 .map(|q| q.build(backend))
486 .collect::<Vec<String>>()
487 .join(";\n");
488
489 with_settings!({ snapshot_suffix => format!("add_column_with_enum_type_{:?}", backend) }, {
490 assert_snapshot!(sql);
491 });
492 }
493
494 #[rstest]
495 #[case::postgres(DatabaseBackend::Postgres)]
496 #[case::mysql(DatabaseBackend::MySql)]
497 #[case::sqlite(DatabaseBackend::Sqlite)]
498 fn test_add_column_enum_non_nullable_with_default(#[case] backend: DatabaseBackend) {
499 use insta::{assert_snapshot, with_settings};
500 use vespertide_core::{ComplexColumnType, EnumValues};
501
502 let column = ColumnDef {
504 name: "status".into(),
505 r#type: ColumnType::Complex(ComplexColumnType::Enum {
506 name: "user_status".into(),
507 values: EnumValues::String(vec![
508 "active".into(),
509 "inactive".into(),
510 "pending".into(),
511 ]),
512 }),
513 nullable: false,
514 default: Some("active".into()),
515 comment: None,
516 primary_key: None,
517 unique: None,
518 index: None,
519 foreign_key: None,
520 };
521 let current_schema = vec![TableDef {
522 name: "users".into(),
523 description: None,
524 columns: vec![ColumnDef {
525 name: "id".into(),
526 r#type: ColumnType::Simple(SimpleColumnType::Integer),
527 nullable: false,
528 default: None,
529 comment: None,
530 primary_key: None,
531 unique: None,
532 index: None,
533 foreign_key: None,
534 }],
535 constraints: vec![],
536 }];
537 let result = build_add_column(&backend, "users", &column, None, ¤t_schema);
538 assert!(result.is_ok());
539 let queries = result.unwrap();
540 let sql = queries
541 .iter()
542 .map(|q| q.build(backend))
543 .collect::<Vec<String>>()
544 .join(";\n");
545
546 with_settings!({ snapshot_suffix => format!("enum_non_nullable_with_default_{:?}", backend) }, {
547 assert_snapshot!(sql);
548 });
549 }
550
551 #[rstest]
552 #[case::postgres(DatabaseBackend::Postgres)]
553 #[case::mysql(DatabaseBackend::MySql)]
554 #[case::sqlite(DatabaseBackend::Sqlite)]
555 fn test_add_column_with_empty_string_default(#[case] backend: DatabaseBackend) {
556 use insta::{assert_snapshot, with_settings};
557
558 let column = ColumnDef {
560 name: "nickname".into(),
561 r#type: ColumnType::Simple(SimpleColumnType::Text),
562 nullable: false,
563 default: Some("".into()), comment: None,
565 primary_key: None,
566 unique: None,
567 index: None,
568 foreign_key: None,
569 };
570 let current_schema = vec![TableDef {
571 name: "users".into(),
572 description: None,
573 columns: vec![ColumnDef {
574 name: "id".into(),
575 r#type: ColumnType::Simple(SimpleColumnType::Integer),
576 nullable: false,
577 default: None,
578 comment: None,
579 primary_key: None,
580 unique: None,
581 index: None,
582 foreign_key: None,
583 }],
584 constraints: vec![],
585 }];
586 let result = build_add_column(&backend, "users", &column, None, ¤t_schema);
587 assert!(result.is_ok());
588 let queries = result.unwrap();
589 let sql = queries
590 .iter()
591 .map(|q| q.build(backend))
592 .collect::<Vec<String>>()
593 .join(";\n");
594
595 assert!(
597 sql.contains("''"),
598 "Expected SQL to contain empty string literal '', got: {}",
599 sql
600 );
601
602 with_settings!({ snapshot_suffix => format!("empty_string_default_{:?}", backend) }, {
603 assert_snapshot!(sql);
604 });
605 }
606
607 #[rstest]
608 #[case::postgres(DatabaseBackend::Postgres)]
609 #[case::mysql(DatabaseBackend::MySql)]
610 #[case::sqlite(DatabaseBackend::Sqlite)]
611 fn test_add_column_with_fill_with_empty_string(#[case] backend: DatabaseBackend) {
612 use insta::{assert_snapshot, with_settings};
613
614 let column = ColumnDef {
616 name: "nickname".into(),
617 r#type: ColumnType::Simple(SimpleColumnType::Text),
618 nullable: false,
619 default: None,
620 comment: None,
621 primary_key: None,
622 unique: None,
623 index: None,
624 foreign_key: None,
625 };
626 let current_schema = vec![TableDef {
627 name: "users".into(),
628 description: None,
629 columns: vec![ColumnDef {
630 name: "id".into(),
631 r#type: ColumnType::Simple(SimpleColumnType::Integer),
632 nullable: false,
633 default: None,
634 comment: None,
635 primary_key: None,
636 unique: None,
637 index: None,
638 foreign_key: None,
639 }],
640 constraints: vec![],
641 }];
642 let result = build_add_column(&backend, "users", &column, Some(""), ¤t_schema);
644 assert!(result.is_ok());
645 let queries = result.unwrap();
646 let sql = queries
647 .iter()
648 .map(|q| q.build(backend))
649 .collect::<Vec<String>>()
650 .join(";\n");
651
652 assert!(
654 sql.contains("''"),
655 "Expected SQL to contain empty string literal '', got: {}",
656 sql
657 );
658
659 with_settings!({ snapshot_suffix => format!("fill_with_empty_string_{:?}", backend) }, {
660 assert_snapshot!(sql);
661 });
662 }
663}