Skip to main content

vespertide_query/sql/
add_column.rs

1use sea_query::{Alias, Expr, Query, Table, TableAlterStatement};
2
3use vespertide_core::{ColumnDef, TableDef};
4
5use super::helpers::{
6    build_create_enum_type_sql, build_sea_column_def_with_table, build_sqlite_temp_table_create,
7    convert_default_for_backend, normalize_enum_default, normalize_fill_with,
8    recreate_indexes_after_rebuild,
9};
10use super::rename_table::build_rename_table;
11use super::types::{BuiltQuery, DatabaseBackend};
12use crate::error::QueryError;
13
14fn build_add_column_alter_for_backend(
15    backend: &DatabaseBackend,
16    table: &str,
17    column: &ColumnDef,
18) -> TableAlterStatement {
19    let col_def = build_sea_column_def_with_table(backend, table, column);
20    Table::alter()
21        .table(Alias::new(table))
22        .add_column(col_def)
23        .to_owned()
24}
25
26/// Check if the column type is an enum
27fn is_enum_column(column: &ColumnDef) -> bool {
28    matches!(
29        column.r#type,
30        vespertide_core::ColumnType::Complex(vespertide_core::ComplexColumnType::Enum { .. })
31    )
32}
33
34pub fn build_add_column(
35    backend: &DatabaseBackend,
36    table: &str,
37    column: &ColumnDef,
38    fill_with: Option<&str>,
39    current_schema: &[TableDef],
40    pending_constraints: &[vespertide_core::TableConstraint],
41) -> Result<Vec<BuiltQuery>, QueryError> {
42    // SQLite: NOT NULL additions or enum columns require table recreation
43    // (enum columns need CHECK constraint which requires table recreation in SQLite)
44    let sqlite_needs_recreation =
45        *backend == DatabaseBackend::Sqlite && (!column.nullable || is_enum_column(column));
46
47    if sqlite_needs_recreation {
48        let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::Other(format!("Table '{}' not found in current schema. SQLite requires current schema information to add columns.", table)))?;
49
50        let mut new_columns = table_def.columns.clone();
51        new_columns.push(column.clone());
52
53        let temp_table = format!("{}_temp", table);
54
55        // 1. Create temporary table with all CHECK constraints (enum + explicit)
56        let create_query = build_sqlite_temp_table_create(
57            backend,
58            &temp_table,
59            table,
60            &new_columns,
61            &table_def.constraints,
62        );
63
64        // Copy existing data, filling new column
65        let mut select_query = Query::select();
66        for col in &table_def.columns {
67            select_query = select_query.column(Alias::new(&col.name)).to_owned();
68        }
69        let normalized_fill = normalize_fill_with(fill_with);
70        let fill_expr = if let Some(fill) = normalized_fill.as_deref() {
71            let converted = convert_default_for_backend(fill, backend);
72            Expr::cust(normalize_enum_default(&column.r#type, &converted))
73        } else if let Some(def) = &column.default {
74            let converted = convert_default_for_backend(&def.to_sql(), backend);
75            Expr::cust(normalize_enum_default(&column.r#type, &converted))
76        } else {
77            Expr::cust("NULL")
78        };
79        select_query = select_query
80            .expr_as(fill_expr, Alias::new(&column.name))
81            .from(Alias::new(table))
82            .to_owned();
83
84        let mut columns_alias: Vec<Alias> = table_def
85            .columns
86            .iter()
87            .map(|c| Alias::new(&c.name))
88            .collect();
89        columns_alias.push(Alias::new(&column.name));
90        let insert_stmt = Query::insert()
91            .into_table(Alias::new(&temp_table))
92            .columns(columns_alias)
93            .select_from(select_query)
94            .unwrap()
95            .to_owned();
96        let insert_query = BuiltQuery::Insert(Box::new(insert_stmt));
97
98        let drop_query =
99            BuiltQuery::DropTable(Box::new(Table::drop().table(Alias::new(table)).to_owned()));
100        let rename_query = build_rename_table(&temp_table, table);
101
102        // Recreate indexes (both regular and UNIQUE)
103        // Skip pending constraints that will be created by future AddConstraint actions
104        let index_queries =
105            recreate_indexes_after_rebuild(table, &table_def.constraints, pending_constraints);
106
107        let mut stmts = vec![create_query, insert_query, drop_query, rename_query];
108        stmts.extend(index_queries);
109        return Ok(stmts);
110    }
111
112    let mut stmts: Vec<BuiltQuery> = Vec::new();
113
114    // If column type is an enum, create the type first (PostgreSQL only)
115    if let Some(create_type_sql) = build_create_enum_type_sql(table, &column.r#type) {
116        stmts.push(BuiltQuery::Raw(create_type_sql));
117    }
118
119    // If adding NOT NULL without default, we need special handling
120    let needs_backfill = !column.nullable && column.default.is_none() && fill_with.is_some();
121
122    if needs_backfill {
123        // Add as nullable first
124        let mut temp_col = column.clone();
125        temp_col.nullable = true;
126
127        stmts.push(BuiltQuery::AlterTable(Box::new(
128            build_add_column_alter_for_backend(backend, table, &temp_col),
129        )));
130
131        // Backfill with provided value
132        if let Some(fill) = normalize_fill_with(fill_with) {
133            let fill = convert_default_for_backend(&fill, backend);
134            let update_stmt = Query::update()
135                .table(Alias::new(table))
136                .value(Alias::new(&column.name), Expr::cust(fill))
137                .to_owned();
138            stmts.push(BuiltQuery::Update(Box::new(update_stmt)));
139        }
140
141        // Set NOT NULL
142        let not_null_col = build_sea_column_def_with_table(backend, table, column);
143        let alter_not_null = Table::alter()
144            .table(Alias::new(table))
145            .modify_column(not_null_col)
146            .to_owned();
147        stmts.push(BuiltQuery::AlterTable(Box::new(alter_not_null)));
148    } else {
149        stmts.push(BuiltQuery::AlterTable(Box::new(
150            build_add_column_alter_for_backend(backend, table, column),
151        )));
152    }
153
154    Ok(stmts)
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{
        ColumnType, ComplexColumnType, EnumValues, SimpleColumnType, TableConstraint, TableDef,
    };

    /// Column with the given name, type, nullability and default; every other
    /// attribute (comment, primary key, unique, index, foreign key) is unset.
    fn column_def(name: &str, ty: ColumnType, nullable: bool, default: Option<&str>) -> ColumnDef {
        ColumnDef {
            name: name.into(),
            r#type: ty,
            nullable,
            default: default.map(Into::into),
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    /// Simple `text` column type.
    fn text() -> ColumnType {
        ColumnType::Simple(SimpleColumnType::Text)
    }

    /// Simple `integer` column type.
    fn integer() -> ColumnType {
        ColumnType::Simple(SimpleColumnType::Integer)
    }

    /// One-table schema named `table` whose only column is a NOT NULL integer `id`.
    fn schema(table: &str, constraints: Vec<TableConstraint>) -> Vec<TableDef> {
        vec![TableDef {
            name: table.into(),
            description: None,
            columns: vec![column_def("id", integer(), false, None)],
            constraints,
        }]
    }

    /// Render every query for `backend` and join the SQL with `separator`.
    fn render(queries: &[BuiltQuery], backend: DatabaseBackend, separator: &str) -> String {
        queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join(separator)
    }

    #[rstest]
    #[case::add_column_with_backfill_postgres(
        "add_column_with_backfill_postgres",
        DatabaseBackend::Postgres,
        &["ALTER TABLE \"users\" ADD COLUMN \"nickname\" text"]
    )]
    #[case::add_column_with_backfill_mysql(
        "add_column_with_backfill_mysql",
        DatabaseBackend::MySql,
        &["ALTER TABLE `users` ADD COLUMN `nickname` text"]
    )]
    #[case::add_column_with_backfill_sqlite(
        "add_column_with_backfill_sqlite",
        DatabaseBackend::Sqlite,
        &["CREATE TABLE \"users_temp\""]
    )]
    #[case::add_column_simple_postgres(
        "add_column_simple_postgres",
        DatabaseBackend::Postgres,
        &["ALTER TABLE \"users\" ADD COLUMN \"nickname\""]
    )]
    #[case::add_column_simple_mysql(
        "add_column_simple_mysql",
        DatabaseBackend::MySql,
        &["ALTER TABLE `users` ADD COLUMN `nickname` text"]
    )]
    #[case::add_column_simple_sqlite(
        "add_column_simple_sqlite",
        DatabaseBackend::Sqlite,
        &["ALTER TABLE \"users\" ADD COLUMN \"nickname\""]
    )]
    #[case::add_column_nullable_postgres(
        "add_column_nullable_postgres",
        DatabaseBackend::Postgres,
        &["ALTER TABLE \"users\" ADD COLUMN \"email\" text"]
    )]
    #[case::add_column_nullable_mysql(
        "add_column_nullable_mysql",
        DatabaseBackend::MySql,
        &["ALTER TABLE `users` ADD COLUMN `email` text"]
    )]
    #[case::add_column_nullable_sqlite(
        "add_column_nullable_sqlite",
        DatabaseBackend::Sqlite,
        &["ALTER TABLE \"users\" ADD COLUMN \"email\" text"]
    )]
    fn test_add_column(
        #[case] title: &str,
        #[case] backend: DatabaseBackend,
        #[case] expected: &[&str],
    ) {
        // The case title selects the column shape: "backfill" cases are NOT NULL
        // with fill value "0", "nullable" cases add `email`, the rest `nickname`.
        let (name, ty) = if title.contains("age") {
            ("age", integer())
        } else if title.contains("nullable") {
            ("email", text())
        } else {
            ("nickname", text())
        };
        let is_backfill = title.contains("backfill");
        let column = column_def(name, ty, !is_backfill, None);
        let fill_with = is_backfill.then_some("0");
        let current_schema = schema("users", vec![]);

        let result =
            build_add_column(&backend, "users", &column, fill_with, &current_schema, &[]).unwrap();

        let first = result[0].build(backend);
        for needle in expected {
            assert!(
                first.contains(needle),
                "Expected SQL to contain '{}', got: {}",
                needle,
                first
            );
        }

        with_settings!({ snapshot_suffix => format!("add_column_{}", title) }, {
            assert_snapshot!(render(&result, backend, "\n"));
        });
    }

    #[test]
    fn test_add_column_sqlite_table_not_found() {
        let column = column_def("nickname", text(), false, None);
        // Empty schema: the target table cannot be found.
        let current_schema: Vec<TableDef> = Vec::new();
        let err_msg = build_add_column(
            &DatabaseBackend::Sqlite,
            "users",
            &column,
            None,
            &current_schema,
            &[],
        )
        .unwrap_err()
        .to_string();
        assert!(err_msg.contains("Table 'users' not found in current schema"));
    }

    #[test]
    fn test_add_column_sqlite_with_default() {
        let column = column_def("age", integer(), false, Some("18"));
        let current_schema = schema("users", vec![]);
        let queries = build_add_column(
            &DatabaseBackend::Sqlite,
            "users",
            &column,
            None,
            &current_schema,
            &[],
        )
        .unwrap();
        let sql = render(&queries, DatabaseBackend::Sqlite, "\n");
        // The default value (18) is used as the fill for existing rows.
        assert!(sql.contains("18"));
    }

    #[test]
    fn test_add_column_sqlite_without_fill_or_default() {
        let column = column_def("age", integer(), false, None);
        let current_schema = schema("users", vec![]);
        let queries = build_add_column(
            &DatabaseBackend::Sqlite,
            "users",
            &column,
            None,
            &current_schema,
            &[],
        )
        .unwrap();
        let sql = render(&queries, DatabaseBackend::Sqlite, "\n");
        // With neither fill nor default, existing rows get NULL.
        assert!(sql.contains("NULL"));
    }

    #[test]
    fn test_add_column_sqlite_with_indexes() {
        let column = column_def("nickname", text(), false, None);
        let current_schema = schema(
            "users",
            vec![TableConstraint::Index {
                name: Some("idx_id".into()),
                columns: vec!["id".into()],
            }],
        );
        let queries = build_add_column(
            &DatabaseBackend::Sqlite,
            "users",
            &column,
            None,
            &current_schema,
            &[],
        )
        .unwrap();
        let sql = render(&queries, DatabaseBackend::Sqlite, "\n");
        // The pre-existing index must be recreated after the rebuild.
        assert!(sql.contains("CREATE INDEX"));
        assert!(sql.contains("idx_id"));
    }

    #[rstest]
    #[case::add_column_with_enum_type_postgres(DatabaseBackend::Postgres)]
    #[case::add_column_with_enum_type_mysql(DatabaseBackend::MySql)]
    #[case::add_column_with_enum_type_sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_with_enum_type(#[case] backend: DatabaseBackend) {
        // Adding an enum column creates the enum type first (PostgreSQL only).
        let status_type = ColumnType::Complex(ComplexColumnType::Enum {
            name: "status_type".into(),
            values: EnumValues::String(vec!["active".into(), "inactive".into()]),
        });
        let column = column_def("status", status_type, true, None);
        let current_schema = schema("users", vec![]);
        let queries =
            build_add_column(&backend, "users", &column, None, &current_schema, &[]).unwrap();
        let sql = render(&queries, backend, ";\n");

        with_settings!({ snapshot_suffix => format!("add_column_with_enum_type_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_enum_non_nullable_with_default(#[case] backend: DatabaseBackend) {
        // A NOT NULL enum column carrying a default value.
        let user_status = ColumnType::Complex(ComplexColumnType::Enum {
            name: "user_status".into(),
            values: EnumValues::String(vec![
                "active".into(),
                "inactive".into(),
                "pending".into(),
            ]),
        });
        let column = column_def("status", user_status, false, Some("active"));
        let current_schema = schema("users", vec![]);
        let queries =
            build_add_column(&backend, "users", &column, None, &current_schema, &[]).unwrap();
        let sql = render(&queries, backend, ";\n");

        with_settings!({ snapshot_suffix => format!("enum_non_nullable_with_default_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_with_empty_string_default(#[case] backend: DatabaseBackend) {
        // A NOT NULL text column whose default is the empty string.
        let column = column_def("nickname", text(), false, Some(""));
        let current_schema = schema("users", vec![]);
        let queries =
            build_add_column(&backend, "users", &column, None, &current_schema, &[]).unwrap();
        let sql = render(&queries, backend, ";\n");

        // The empty default must be rendered as the literal ''.
        assert!(
            sql.contains("''"),
            "Expected SQL to contain empty string literal '', got: {}",
            sql
        );

        with_settings!({ snapshot_suffix => format!("empty_string_default_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    /// Adding a NOT NULL JSON column whose default carries a PostgreSQL cast (`'[]'::json`).
    /// Neither SQLite nor MySQL output may contain the `::json` cast syntax.
    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_with_pg_type_cast_default(#[case] backend: DatabaseBackend) {
        let column = column_def(
            "story_index",
            ColumnType::Simple(SimpleColumnType::Json),
            false,
            Some("'[]'::json"),
        );
        let current_schema = schema("project", vec![]);
        let result =
            build_add_column(&backend, "project", &column, None, &current_schema, &[]).unwrap();
        let sql = render(&result, backend, "\n");

        let non_pg_label = match backend {
            DatabaseBackend::Sqlite => Some("SQLite"),
            DatabaseBackend::MySql => Some("MySQL"),
            _ => None,
        };
        if let Some(label) = non_pg_label {
            assert!(
                !sql.contains("::json"),
                "{} SQL should not contain ::json cast, got: {}",
                label,
                sql
            );
        }

        with_settings!({ snapshot_suffix => format!("pg_type_cast_default_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_with_fill_with_empty_string(#[case] backend: DatabaseBackend) {
        // A NOT NULL text column with no default, backfilled with the empty string.
        let column = column_def("nickname", text(), false, None);
        let current_schema = schema("users", vec![]);
        let queries =
            build_add_column(&backend, "users", &column, Some(""), &current_schema, &[])
                .unwrap();
        let sql = render(&queries, backend, ";\n");

        // fill_with "" must be rendered as the literal ''.
        assert!(
            sql.contains("''"),
            "Expected SQL to contain empty string literal '', got: {}",
            sql
        );

        with_settings!({ snapshot_suffix => format!("fill_with_empty_string_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }
}