Skip to main content

vespertide_query/sql/
add_column.rs

1use sea_query::{Alias, Expr, Query, Table, TableAlterStatement};
2
3use vespertide_core::{ColumnDef, TableDef};
4
5use super::helpers::{
6    build_create_enum_type_sql, build_sea_column_def_with_table, build_sqlite_temp_table_create,
7    convert_default_for_backend, normalize_enum_default, normalize_fill_with,
8    recreate_indexes_after_rebuild,
9};
10use super::rename_table::build_rename_table;
11use super::types::{BuiltQuery, DatabaseBackend};
12use crate::error::QueryError;
13
14fn build_add_column_alter_for_backend(
15    backend: DatabaseBackend,
16    table: &str,
17    column: &ColumnDef,
18) -> TableAlterStatement {
19    let col_def = build_sea_column_def_with_table(backend, table, column);
20    Table::alter()
21        .table(Alias::new(table))
22        .add_column(col_def)
23        .to_owned()
24}
25
26/// Check if the column type is an enum
27fn is_enum_column(column: &ColumnDef) -> bool {
28    matches!(
29        column.r#type,
30        vespertide_core::ColumnType::Complex(vespertide_core::ComplexColumnType::Enum { .. })
31    )
32}
33
34pub fn build_add_column(
35    backend: DatabaseBackend,
36    table: &str,
37    column: &ColumnDef,
38    fill_with: Option<&str>,
39    current_schema: &[TableDef],
40    pending_constraints: &[vespertide_core::TableConstraint],
41) -> Result<Vec<BuiltQuery>, QueryError> {
42    // SQLite: NOT NULL additions or enum columns require table recreation
43    // (enum columns need CHECK constraint which requires table recreation in SQLite)
44    let sqlite_needs_recreation =
45        backend == DatabaseBackend::Sqlite && (!column.nullable || is_enum_column(column));
46
47    if sqlite_needs_recreation {
48        let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::SchemaError(format!("Table '{table}' not found in current schema. SQLite requires current schema information to add columns.")))?;
49
50        let mut new_columns = table_def.columns.clone();
51        new_columns.push(column.clone());
52
53        let temp_table = format!("{table}_temp");
54
55        // 1. Create temporary table with all CHECK constraints (enum + explicit)
56        let create_query = build_sqlite_temp_table_create(
57            backend,
58            &temp_table,
59            table,
60            &new_columns,
61            &table_def.constraints,
62        );
63
64        // Copy existing data, filling new column
65        let mut select_query = Query::select();
66        for col in &table_def.columns {
67            select_query.column(Alias::new(&col.name));
68        }
69        let normalized_fill = normalize_fill_with(fill_with);
70        let fill_expr = if let Some(fill) = normalized_fill.as_deref() {
71            let converted = convert_default_for_backend(fill, backend);
72            Expr::cust(normalize_enum_default(&column.r#type, &converted))
73        } else if let Some(def) = &column.default {
74            let converted = convert_default_for_backend(&def.to_sql(), backend);
75            Expr::cust(normalize_enum_default(&column.r#type, &converted))
76        } else {
77            Expr::cust("NULL")
78        };
79        select_query
80            .expr_as(fill_expr, Alias::new(&column.name))
81            .from(Alias::new(table));
82
83        let mut columns_alias: Vec<Alias> = table_def
84            .columns
85            .iter()
86            .map(|c| Alias::new(&c.name))
87            .collect();
88        columns_alias.push(Alias::new(&column.name));
89        let insert_stmt = Query::insert()
90            .into_table(Alias::new(&temp_table))
91            .columns(columns_alias)
92            .select_from(select_query)
93            .unwrap()
94            .to_owned();
95        let insert_query = BuiltQuery::Insert(Box::new(insert_stmt));
96
97        let drop_query =
98            BuiltQuery::DropTable(Box::new(Table::drop().table(Alias::new(table)).to_owned()));
99        let rename_query = build_rename_table(&temp_table, table);
100
101        // Recreate indexes (both regular and UNIQUE)
102        // Skip pending constraints that will be created by future AddConstraint actions
103        let index_queries =
104            recreate_indexes_after_rebuild(table, &table_def.constraints, pending_constraints);
105
106        let mut stmts = vec![create_query, insert_query, drop_query, rename_query];
107        stmts.extend(index_queries);
108        return Ok(stmts);
109    }
110
111    let mut stmts: Vec<BuiltQuery> = Vec::new();
112
113    // If column type is an enum, create the type first (PostgreSQL only)
114    if let Some(create_type_sql) = build_create_enum_type_sql(table, &column.r#type) {
115        stmts.push(BuiltQuery::Raw(create_type_sql));
116    }
117
118    // If adding NOT NULL without default, we need special handling
119    let needs_backfill = !column.nullable && column.default.is_none() && fill_with.is_some();
120
121    if needs_backfill {
122        // Add as nullable first
123        let mut temp_col = column.clone();
124        temp_col.nullable = true;
125
126        stmts.push(BuiltQuery::AlterTable(Box::new(
127            build_add_column_alter_for_backend(backend, table, &temp_col),
128        )));
129
130        // Backfill with provided value
131        if let Some(fill) = normalize_fill_with(fill_with) {
132            let fill = convert_default_for_backend(&fill, backend);
133            let update_stmt = Query::update()
134                .table(Alias::new(table))
135                .value(Alias::new(&column.name), Expr::cust(fill))
136                .to_owned();
137            stmts.push(BuiltQuery::Update(Box::new(update_stmt)));
138        }
139
140        // Set NOT NULL
141        let not_null_col = build_sea_column_def_with_table(backend, table, column);
142        let alter_not_null = Table::alter()
143            .table(Alias::new(table))
144            .modify_column(not_null_col)
145            .to_owned();
146        stmts.push(BuiltQuery::AlterTable(Box::new(alter_not_null)));
147    } else {
148        stmts.push(BuiltQuery::AlterTable(Box::new(
149            build_add_column_alter_for_backend(backend, table, column),
150        )));
151    }
152
153    Ok(stmts)
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{
        ColumnType, ComplexColumnType, EnumValues, SimpleColumnType, TableConstraint, TableDef,
    };

    /// Builds a column with only name, type and nullability set; every other attribute
    /// is `None`.
    fn make_column(name: &'static str, ty: ColumnType, nullable: bool) -> ColumnDef {
        ColumnDef {
            name: name.into(),
            r#type: ty,
            nullable,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    /// Builds a schema holding a single table `name` whose only column is a NOT NULL
    /// integer `id`, with the given table constraints.
    fn single_table_schema(
        name: &'static str,
        constraints: Vec<TableConstraint>,
    ) -> Vec<TableDef> {
        vec![TableDef {
            name: name.into(),
            description: None,
            columns: vec![make_column(
                "id",
                ColumnType::Simple(SimpleColumnType::Integer),
                false,
            )],
            constraints,
        }]
    }

    /// Renders every query for `backend` and joins the results with `separator`.
    fn render(queries: &[BuiltQuery], backend: DatabaseBackend, separator: &str) -> String {
        queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join(separator)
    }

    #[rstest]
    #[case::add_column_with_backfill_postgres(
        "add_column_with_backfill_postgres",
        DatabaseBackend::Postgres,
        &["ALTER TABLE \"users\" ADD COLUMN \"nickname\" text"]
    )]
    #[case::add_column_with_backfill_mysql(
        "add_column_with_backfill_mysql",
        DatabaseBackend::MySql,
        &["ALTER TABLE `users` ADD COLUMN `nickname` text"]
    )]
    #[case::add_column_with_backfill_sqlite(
        "add_column_with_backfill_sqlite",
        DatabaseBackend::Sqlite,
        &["CREATE TABLE \"users_temp\""]
    )]
    #[case::add_column_simple_postgres(
        "add_column_simple_postgres",
        DatabaseBackend::Postgres,
        &["ALTER TABLE \"users\" ADD COLUMN \"nickname\""]
    )]
    #[case::add_column_simple_mysql(
        "add_column_simple_mysql",
        DatabaseBackend::MySql,
        &["ALTER TABLE `users` ADD COLUMN `nickname` text"]
    )]
    #[case::add_column_simple_sqlite(
        "add_column_simple_sqlite",
        DatabaseBackend::Sqlite,
        &["ALTER TABLE \"users\" ADD COLUMN \"nickname\""]
    )]
    #[case::add_column_nullable_postgres(
        "add_column_nullable_postgres",
        DatabaseBackend::Postgres,
        &["ALTER TABLE \"users\" ADD COLUMN \"email\" text"]
    )]
    #[case::add_column_nullable_mysql(
        "add_column_nullable_mysql",
        DatabaseBackend::MySql,
        &["ALTER TABLE `users` ADD COLUMN `email` text"]
    )]
    #[case::add_column_nullable_sqlite(
        "add_column_nullable_sqlite",
        DatabaseBackend::Sqlite,
        &["ALTER TABLE \"users\" ADD COLUMN \"email\" text"]
    )]
    fn test_add_column(
        #[case] title: &str,
        #[case] backend: DatabaseBackend,
        #[case] expected: &[&str],
    ) {
        // The case title drives the fixture: "backfill" cases add a NOT NULL column with a
        // fill value, "nullable" cases use the `email` column name.
        let is_backfill = title.contains("backfill");
        let name = if title.contains("nullable") {
            "email"
        } else {
            "nickname"
        };
        let column = make_column(name, ColumnType::Simple(SimpleColumnType::Text), !is_backfill);
        let fill_with = is_backfill.then_some("0");
        let current_schema = single_table_schema("users", vec![]);

        let result =
            build_add_column(backend, "users", &column, fill_with, &current_schema, &[]).unwrap();
        let sql = result[0].build(backend);
        for exp in expected {
            assert!(
                sql.contains(exp),
                "Expected SQL to contain '{exp}', got: {sql}"
            );
        }

        with_settings!({ snapshot_suffix => format!("add_column_{}", title) }, {
            assert_snapshot!(result.iter().map(|q| q.build(backend)).collect::<Vec<String>>().join("\n"));
        });
    }

    #[test]
    fn test_add_column_sqlite_table_not_found() {
        let column = make_column(
            "nickname",
            ColumnType::Simple(SimpleColumnType::Text),
            false,
        );
        let current_schema = vec![]; // Empty schema - table not found
        let result = build_add_column(
            DatabaseBackend::Sqlite,
            "users",
            &column,
            None,
            &current_schema,
            &[],
        );
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Table 'users' not found in current schema"));
    }

    #[test]
    fn test_add_column_sqlite_with_default() {
        let column = ColumnDef {
            default: Some("18".into()),
            ..make_column("age", ColumnType::Simple(SimpleColumnType::Integer), false)
        };
        let current_schema = single_table_schema("users", vec![]);
        let queries = build_add_column(
            DatabaseBackend::Sqlite,
            "users",
            &column,
            None,
            &current_schema,
            &[],
        )
        .unwrap();
        let sql = render(&queries, DatabaseBackend::Sqlite, "\n");
        // Should use default value (18) for fill
        assert!(sql.contains("18"));
    }

    #[test]
    fn test_add_column_sqlite_without_fill_or_default() {
        let column = make_column("age", ColumnType::Simple(SimpleColumnType::Integer), false);
        let current_schema = single_table_schema("users", vec![]);
        let queries = build_add_column(
            DatabaseBackend::Sqlite,
            "users",
            &column,
            None,
            &current_schema,
            &[],
        )
        .unwrap();
        let sql = render(&queries, DatabaseBackend::Sqlite, "\n");
        // Should use NULL for fill
        assert!(sql.contains("NULL"));
    }

    #[test]
    fn test_add_column_sqlite_with_indexes() {
        let column = make_column(
            "nickname",
            ColumnType::Simple(SimpleColumnType::Text),
            false,
        );
        let current_schema = single_table_schema(
            "users",
            vec![TableConstraint::Index {
                name: Some("idx_id".into()),
                columns: vec!["id".into()],
            }],
        );
        let queries = build_add_column(
            DatabaseBackend::Sqlite,
            "users",
            &column,
            None,
            &current_schema,
            &[],
        )
        .unwrap();
        let sql = render(&queries, DatabaseBackend::Sqlite, "\n");
        // Should recreate index
        assert!(sql.contains("CREATE INDEX"));
        assert!(sql.contains("idx_id"));
    }

    #[rstest]
    #[case::add_column_with_enum_type_postgres(DatabaseBackend::Postgres)]
    #[case::add_column_with_enum_type_mysql(DatabaseBackend::MySql)]
    #[case::add_column_with_enum_type_sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_with_enum_type(#[case] backend: DatabaseBackend) {
        // Test that adding an enum column creates the enum type first (PostgreSQL only)
        let column = make_column(
            "status",
            ColumnType::Complex(ComplexColumnType::Enum {
                name: "status_type".into(),
                values: EnumValues::String(vec!["active".into(), "inactive".into()]),
            }),
            true,
        );
        let current_schema = single_table_schema("users", vec![]);
        let queries =
            build_add_column(backend, "users", &column, None, &current_schema, &[]).unwrap();
        let sql = render(&queries, backend, ";\n");

        with_settings!({ snapshot_suffix => format!("add_column_with_enum_type_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_enum_non_nullable_with_default(#[case] backend: DatabaseBackend) {
        // Test adding an enum column that is non-nullable with a default value
        let column = ColumnDef {
            default: Some("active".into()),
            ..make_column(
                "status",
                ColumnType::Complex(ComplexColumnType::Enum {
                    name: "user_status".into(),
                    values: EnumValues::String(vec![
                        "active".into(),
                        "inactive".into(),
                        "pending".into(),
                    ]),
                }),
                false,
            )
        };
        let current_schema = single_table_schema("users", vec![]);
        let queries =
            build_add_column(backend, "users", &column, None, &current_schema, &[]).unwrap();
        let sql = render(&queries, backend, ";\n");

        with_settings!({ snapshot_suffix => format!("enum_non_nullable_with_default_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_with_empty_string_default(#[case] backend: DatabaseBackend) {
        // Test adding a text column with empty string default
        let column = ColumnDef {
            default: Some("".into()), // Empty string default
            ..make_column(
                "nickname",
                ColumnType::Simple(SimpleColumnType::Text),
                false,
            )
        };
        let current_schema = single_table_schema("users", vec![]);
        let queries =
            build_add_column(backend, "users", &column, None, &current_schema, &[]).unwrap();
        let sql = render(&queries, backend, ";\n");

        // Verify empty string becomes ''
        assert!(
            sql.contains("''"),
            "Expected SQL to contain empty string literal '', got: {sql}"
        );

        with_settings!({ snapshot_suffix => format!("empty_string_default_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    /// Adds a NOT NULL JSON column whose default is the PostgreSQL-style `'[]'::json`.
    ///
    /// The `::json` cast must not appear in `SQLite` or `MySQL` output; the exact
    /// per-backend rendering is pinned by the snapshots.
    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_with_pg_type_cast_default(#[case] backend: DatabaseBackend) {
        let column = ColumnDef {
            default: Some("'[]'::json".into()),
            ..make_column(
                "story_index",
                ColumnType::Simple(SimpleColumnType::Json),
                false,
            )
        };
        let current_schema = single_table_schema("project", vec![]);
        let result =
            build_add_column(backend, "project", &column, None, &current_schema, &[]).unwrap();
        let sql = render(&result, backend, "\n");

        // SQLite must NOT contain ::json syntax
        if backend == DatabaseBackend::Sqlite {
            assert!(
                !sql.contains("::json"),
                "SQLite SQL should not contain ::json cast, got: {sql}"
            );
        }

        // MySQL must NOT contain ::json syntax either
        if backend == DatabaseBackend::MySql {
            assert!(
                !sql.contains("::json"),
                "MySQL SQL should not contain ::json cast, got: {sql}"
            );
        }

        with_settings!({ snapshot_suffix => format!("pg_type_cast_default_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_with_fill_with_empty_string(#[case] backend: DatabaseBackend) {
        // Test adding a column with fill_with as empty string
        let column = make_column(
            "nickname",
            ColumnType::Simple(SimpleColumnType::Text),
            false,
        );
        let current_schema = single_table_schema("users", vec![]);
        // fill_with empty string should become ''
        let queries =
            build_add_column(backend, "users", &column, Some(""), &current_schema, &[]).unwrap();
        let sql = render(&queries, backend, ";\n");

        // Verify empty string becomes ''
        assert!(
            sql.contains("''"),
            "Expected SQL to contain empty string literal '', got: {sql}"
        );

        with_settings!({ snapshot_suffix => format!("fill_with_empty_string_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }
}