Skip to main content

vespertide_query/sql/
add_column.rs

1use sea_query::{Alias, Expr, Query, Table, TableAlterStatement};
2
3use vespertide_core::{ColumnDef, TableDef};
4
5use super::helpers::{
6    build_create_enum_type_sql, build_sea_column_def_with_table, build_sqlite_temp_table_create,
7    convert_default_for_backend, normalize_enum_default, normalize_fill_with,
8    recreate_indexes_after_rebuild,
9};
10use super::rename_table::build_rename_table;
11use super::types::{BuiltQuery, DatabaseBackend};
12use crate::error::QueryError;
13
14fn build_add_column_alter_for_backend(
15    backend: &DatabaseBackend,
16    table: &str,
17    column: &ColumnDef,
18) -> TableAlterStatement {
19    let col_def = build_sea_column_def_with_table(backend, table, column);
20    Table::alter()
21        .table(Alias::new(table))
22        .add_column(col_def)
23        .to_owned()
24}
25
26/// Check if the column type is an enum
27fn is_enum_column(column: &ColumnDef) -> bool {
28    matches!(
29        column.r#type,
30        vespertide_core::ColumnType::Complex(vespertide_core::ComplexColumnType::Enum { .. })
31    )
32}
33
34pub fn build_add_column(
35    backend: &DatabaseBackend,
36    table: &str,
37    column: &ColumnDef,
38    fill_with: Option<&str>,
39    current_schema: &[TableDef],
40) -> Result<Vec<BuiltQuery>, QueryError> {
41    // SQLite: NOT NULL additions or enum columns require table recreation
42    // (enum columns need CHECK constraint which requires table recreation in SQLite)
43    let sqlite_needs_recreation =
44        *backend == DatabaseBackend::Sqlite && (!column.nullable || is_enum_column(column));
45
46    if sqlite_needs_recreation {
47        let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::Other(format!("Table '{}' not found in current schema. SQLite requires current schema information to add columns.", table)))?;
48
49        let mut new_columns = table_def.columns.clone();
50        new_columns.push(column.clone());
51
52        let temp_table = format!("{}_temp", table);
53
54        // 1. Create temporary table with all CHECK constraints (enum + explicit)
55        let create_query = build_sqlite_temp_table_create(
56            backend,
57            &temp_table,
58            table,
59            &new_columns,
60            &table_def.constraints,
61        );
62
63        // Copy existing data, filling new column
64        let mut select_query = Query::select();
65        for col in &table_def.columns {
66            select_query = select_query.column(Alias::new(&col.name)).to_owned();
67        }
68        let normalized_fill = normalize_fill_with(fill_with);
69        let fill_expr = if let Some(fill) = normalized_fill.as_deref() {
70            let converted = convert_default_for_backend(fill, backend);
71            Expr::cust(normalize_enum_default(&column.r#type, &converted))
72        } else if let Some(def) = &column.default {
73            let converted = convert_default_for_backend(&def.to_sql(), backend);
74            Expr::cust(normalize_enum_default(&column.r#type, &converted))
75        } else {
76            Expr::cust("NULL")
77        };
78        select_query = select_query
79            .expr_as(fill_expr, Alias::new(&column.name))
80            .from(Alias::new(table))
81            .to_owned();
82
83        let mut columns_alias: Vec<Alias> = table_def
84            .columns
85            .iter()
86            .map(|c| Alias::new(&c.name))
87            .collect();
88        columns_alias.push(Alias::new(&column.name));
89        let insert_stmt = Query::insert()
90            .into_table(Alias::new(&temp_table))
91            .columns(columns_alias)
92            .select_from(select_query)
93            .unwrap()
94            .to_owned();
95        let insert_query = BuiltQuery::Insert(Box::new(insert_stmt));
96
97        let drop_query =
98            BuiltQuery::DropTable(Box::new(Table::drop().table(Alias::new(table)).to_owned()));
99        let rename_query = build_rename_table(&temp_table, table);
100
101        // Recreate indexes (both regular and UNIQUE)
102        let index_queries = recreate_indexes_after_rebuild(table, &table_def.constraints, &[]);
103
104        let mut stmts = vec![create_query, insert_query, drop_query, rename_query];
105        stmts.extend(index_queries);
106        return Ok(stmts);
107    }
108
109    let mut stmts: Vec<BuiltQuery> = Vec::new();
110
111    // If column type is an enum, create the type first (PostgreSQL only)
112    if let Some(create_type_sql) = build_create_enum_type_sql(table, &column.r#type) {
113        stmts.push(BuiltQuery::Raw(create_type_sql));
114    }
115
116    // If adding NOT NULL without default, we need special handling
117    let needs_backfill = !column.nullable && column.default.is_none() && fill_with.is_some();
118
119    if needs_backfill {
120        // Add as nullable first
121        let mut temp_col = column.clone();
122        temp_col.nullable = true;
123
124        stmts.push(BuiltQuery::AlterTable(Box::new(
125            build_add_column_alter_for_backend(backend, table, &temp_col),
126        )));
127
128        // Backfill with provided value
129        if let Some(fill) = normalize_fill_with(fill_with) {
130            let fill = convert_default_for_backend(&fill, backend);
131            let update_stmt = Query::update()
132                .table(Alias::new(table))
133                .value(Alias::new(&column.name), Expr::cust(fill))
134                .to_owned();
135            stmts.push(BuiltQuery::Update(Box::new(update_stmt)));
136        }
137
138        // Set NOT NULL
139        let not_null_col = build_sea_column_def_with_table(backend, table, column);
140        let alter_not_null = Table::alter()
141            .table(Alias::new(table))
142            .modify_column(not_null_col)
143            .to_owned();
144        stmts.push(BuiltQuery::AlterTable(Box::new(alter_not_null)));
145    } else {
146        stmts.push(BuiltQuery::AlterTable(Box::new(
147            build_add_column_alter_for_backend(backend, table, column),
148        )));
149    }
150
151    Ok(stmts)
152}
153
// Tests for `build_add_column`. Most cases use a current schema containing a
// single `users` table with one `id INTEGER NOT NULL` column, render every
// emitted statement for the backend under test, and pin the output with insta
// snapshots. Some cases also assert on specific SQL fragments.
#[cfg(test)]
mod tests {
    use super::*;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{ColumnType, SimpleColumnType, TableDef};

    // Table-driven cases whose title selects the column shape (see the body):
    // - "nullable" titles add a nullable `email` TEXT column;
    // - "backfill" titles add a NOT NULL `nickname` TEXT column with fill_with = "0";
    // - all other titles add a nullable `nickname` TEXT column.
    // `expected` fragments are checked against the FIRST emitted statement only;
    // the full statement list goes into the snapshot.
    #[rstest]
    #[case::add_column_with_backfill_postgres(
        "add_column_with_backfill_postgres",
        DatabaseBackend::Postgres,
        &["ALTER TABLE \"users\" ADD COLUMN \"nickname\" text"]
    )]
    #[case::add_column_with_backfill_mysql(
        "add_column_with_backfill_mysql",
        DatabaseBackend::MySql,
        &["ALTER TABLE `users` ADD COLUMN `nickname` text"]
    )]
    // SQLite takes the table-rebuild path for NOT NULL columns, so the first
    // statement is the temp-table CREATE.
    #[case::add_column_with_backfill_sqlite(
        "add_column_with_backfill_sqlite",
        DatabaseBackend::Sqlite,
        &["CREATE TABLE \"users_temp\""]
    )]
    #[case::add_column_simple_postgres(
        "add_column_simple_postgres",
        DatabaseBackend::Postgres,
        &["ALTER TABLE \"users\" ADD COLUMN \"nickname\""]
    )]
    #[case::add_column_simple_mysql(
        "add_column_simple_mysql",
        DatabaseBackend::MySql,
        &["ALTER TABLE `users` ADD COLUMN `nickname` text"]
    )]
    #[case::add_column_simple_sqlite(
        "add_column_simple_sqlite",
        DatabaseBackend::Sqlite,
        &["ALTER TABLE \"users\" ADD COLUMN \"nickname\""]
    )]
    #[case::add_column_nullable_postgres(
        "add_column_nullable_postgres",
        DatabaseBackend::Postgres,
        &["ALTER TABLE \"users\" ADD COLUMN \"email\" text"]
    )]
    #[case::add_column_nullable_mysql(
        "add_column_nullable_mysql",
        DatabaseBackend::MySql,
        &["ALTER TABLE `users` ADD COLUMN `email` text"]
    )]
    #[case::add_column_nullable_sqlite(
        "add_column_nullable_sqlite",
        DatabaseBackend::Sqlite,
        &["ALTER TABLE \"users\" ADD COLUMN \"email\" text"]
    )]
    fn test_add_column(
        #[case] title: &str,
        #[case] backend: DatabaseBackend,
        #[case] expected: &[&str],
    ) {
        // NOTE: no current case title contains "age", so the `age` /
        // Integer branches below are not exercised by the cases above.
        let column = ColumnDef {
            name: if title.contains("age") {
                "age"
            } else if title.contains("nullable") {
                "email"
            } else {
                "nickname"
            }
            .into(),
            r#type: if title.contains("age") {
                ColumnType::Simple(SimpleColumnType::Integer)
            } else {
                ColumnType::Simple(SimpleColumnType::Text)
            },
            nullable: !title.contains("backfill"),
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        };
        let fill_with = if title.contains("backfill") {
            Some("0")
        } else {
            None
        };
        let current_schema = vec![TableDef {
            name: "users".into(),
            description: None,
            columns: vec![ColumnDef {
                name: "id".into(),
                r#type: ColumnType::Simple(SimpleColumnType::Integer),
                nullable: false,
                default: None,
                comment: None,
                primary_key: None,
                unique: None,
                index: None,
                foreign_key: None,
            }],
            constraints: vec![],
        }];
        let result =
            build_add_column(&backend, "users", &column, fill_with, &current_schema).unwrap();
        // Fragment checks apply to the first statement only.
        let sql = result[0].build(backend);
        for exp in expected {
            assert!(
                sql.contains(exp),
                "Expected SQL to contain '{}', got: {}",
                exp,
                sql
            );
        }

        with_settings!({ snapshot_suffix => format!("add_column_{}", title) }, {
            assert_snapshot!(result.iter().map(|q| q.build(backend)).collect::<Vec<String>>().join("\n"));
        });
    }

    // The SQLite rebuild path needs the table in `current_schema`; an empty
    // schema must produce an error naming the missing table.
    #[test]
    fn test_add_column_sqlite_table_not_found() {
        let column = ColumnDef {
            name: "nickname".into(),
            r#type: ColumnType::Simple(SimpleColumnType::Text),
            nullable: false,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        };
        let current_schema = vec![]; // Empty schema - table not found
        let result = build_add_column(
            &DatabaseBackend::Sqlite,
            "users",
            &column,
            None,
            &current_schema,
        );
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Table 'users' not found in current schema"));
    }

    // SQLite rebuild with no `fill_with`: the column default ("18") should be
    // used to fill existing rows.
    #[test]
    fn test_add_column_sqlite_with_default() {
        let column = ColumnDef {
            name: "age".into(),
            r#type: ColumnType::Simple(SimpleColumnType::Integer),
            nullable: false,
            default: Some("18".into()),
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        };
        let current_schema = vec![TableDef {
            name: "users".into(),
            description: None,
            columns: vec![ColumnDef {
                name: "id".into(),
                r#type: ColumnType::Simple(SimpleColumnType::Integer),
                nullable: false,
                default: None,
                comment: None,
                primary_key: None,
                unique: None,
                index: None,
                foreign_key: None,
            }],
            constraints: vec![],
        }];
        let result = build_add_column(
            &DatabaseBackend::Sqlite,
            "users",
            &column,
            None,
            &current_schema,
        );
        assert!(result.is_ok());
        let queries = result.unwrap();
        let sql = queries
            .iter()
            .map(|q| q.build(DatabaseBackend::Sqlite))
            .collect::<Vec<String>>()
            .join("\n");
        // Should use default value (18) for fill
        assert!(sql.contains("18"));
    }

    // SQLite rebuild with neither `fill_with` nor a default: the copy SELECT
    // falls back to a NULL fill expression.
    #[test]
    fn test_add_column_sqlite_without_fill_or_default() {
        let column = ColumnDef {
            name: "age".into(),
            r#type: ColumnType::Simple(SimpleColumnType::Integer),
            nullable: false,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        };
        let current_schema = vec![TableDef {
            name: "users".into(),
            description: None,
            columns: vec![ColumnDef {
                name: "id".into(),
                r#type: ColumnType::Simple(SimpleColumnType::Integer),
                nullable: false,
                default: None,
                comment: None,
                primary_key: None,
                unique: None,
                index: None,
                foreign_key: None,
            }],
            constraints: vec![],
        }];
        let result = build_add_column(
            &DatabaseBackend::Sqlite,
            "users",
            &column,
            None,
            &current_schema,
        );
        assert!(result.is_ok());
        let queries = result.unwrap();
        let sql = queries
            .iter()
            .map(|q| q.build(DatabaseBackend::Sqlite))
            .collect::<Vec<String>>()
            .join("\n");
        // Should use NULL for fill
        // NOTE(review): this substring check may also be satisfied by a
        // `NOT NULL` column constraint in the CREATE TABLE; consider asserting
        // on the INSERT ... SELECT statement specifically.
        assert!(sql.contains("NULL"));
    }

    // Indexes declared on the original table must be recreated after the
    // SQLite rebuild (drop + rename loses them).
    #[test]
    fn test_add_column_sqlite_with_indexes() {
        use vespertide_core::TableConstraint;

        let column = ColumnDef {
            name: "nickname".into(),
            r#type: ColumnType::Simple(SimpleColumnType::Text),
            nullable: false,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        };
        let current_schema = vec![TableDef {
            name: "users".into(),
            description: None,
            columns: vec![ColumnDef {
                name: "id".into(),
                r#type: ColumnType::Simple(SimpleColumnType::Integer),
                nullable: false,
                default: None,
                comment: None,
                primary_key: None,
                unique: None,
                index: None,
                foreign_key: None,
            }],
            constraints: vec![TableConstraint::Index {
                name: Some("idx_id".into()),
                columns: vec!["id".into()],
            }],
        }];
        let result = build_add_column(
            &DatabaseBackend::Sqlite,
            "users",
            &column,
            None,
            &current_schema,
        );
        assert!(result.is_ok());
        let queries = result.unwrap();
        let sql = queries
            .iter()
            .map(|q| q.build(DatabaseBackend::Sqlite))
            .collect::<Vec<String>>()
            .join("\n");
        // Should recreate index
        assert!(sql.contains("CREATE INDEX"));
        assert!(sql.contains("idx_id"));
    }

    // Nullable enum column on every backend; output is pinned by snapshot.
    // On SQLite an enum column triggers the rebuild path even though it is nullable.
    #[rstest]
    #[case::add_column_with_enum_type_postgres(DatabaseBackend::Postgres)]
    #[case::add_column_with_enum_type_mysql(DatabaseBackend::MySql)]
    #[case::add_column_with_enum_type_sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_with_enum_type(#[case] backend: DatabaseBackend) {
        use insta::{assert_snapshot, with_settings};
        use vespertide_core::{ComplexColumnType, EnumValues};

        // Test that adding an enum column creates the enum type first (PostgreSQL only)
        let column = ColumnDef {
            name: "status".into(),
            r#type: ColumnType::Complex(ComplexColumnType::Enum {
                name: "status_type".into(),
                values: EnumValues::String(vec!["active".into(), "inactive".into()]),
            }),
            nullable: true,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        };
        let current_schema = vec![TableDef {
            name: "users".into(),
            description: None,
            columns: vec![ColumnDef {
                name: "id".into(),
                r#type: ColumnType::Simple(SimpleColumnType::Integer),
                nullable: false,
                default: None,
                comment: None,
                primary_key: None,
                unique: None,
                index: None,
                foreign_key: None,
            }],
            constraints: vec![],
        }];
        let result = build_add_column(&backend, "users", &column, None, &current_schema);
        assert!(result.is_ok());
        let queries = result.unwrap();
        let sql = queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join(";\n");

        with_settings!({ snapshot_suffix => format!("add_column_with_enum_type_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    // NOT NULL enum column with a default ("active"); output is pinned by snapshot.
    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_enum_non_nullable_with_default(#[case] backend: DatabaseBackend) {
        use insta::{assert_snapshot, with_settings};
        use vespertide_core::{ComplexColumnType, EnumValues};

        // Test adding an enum column that is non-nullable with a default value
        let column = ColumnDef {
            name: "status".into(),
            r#type: ColumnType::Complex(ComplexColumnType::Enum {
                name: "user_status".into(),
                values: EnumValues::String(vec![
                    "active".into(),
                    "inactive".into(),
                    "pending".into(),
                ]),
            }),
            nullable: false,
            default: Some("active".into()),
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        };
        let current_schema = vec![TableDef {
            name: "users".into(),
            description: None,
            columns: vec![ColumnDef {
                name: "id".into(),
                r#type: ColumnType::Simple(SimpleColumnType::Integer),
                nullable: false,
                default: None,
                comment: None,
                primary_key: None,
                unique: None,
                index: None,
                foreign_key: None,
            }],
            constraints: vec![],
        }];
        let result = build_add_column(&backend, "users", &column, None, &current_schema);
        assert!(result.is_ok());
        let queries = result.unwrap();
        let sql = queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join(";\n");

        with_settings!({ snapshot_suffix => format!("enum_non_nullable_with_default_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    // An empty-string default must be rendered as the SQL literal '' on every backend.
    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_with_empty_string_default(#[case] backend: DatabaseBackend) {
        use insta::{assert_snapshot, with_settings};

        // Test adding a text column with empty string default
        let column = ColumnDef {
            name: "nickname".into(),
            r#type: ColumnType::Simple(SimpleColumnType::Text),
            nullable: false,
            default: Some("".into()), // Empty string default
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        };
        let current_schema = vec![TableDef {
            name: "users".into(),
            description: None,
            columns: vec![ColumnDef {
                name: "id".into(),
                r#type: ColumnType::Simple(SimpleColumnType::Integer),
                nullable: false,
                default: None,
                comment: None,
                primary_key: None,
                unique: None,
                index: None,
                foreign_key: None,
            }],
            constraints: vec![],
        }];
        let result = build_add_column(&backend, "users", &column, None, &current_schema);
        assert!(result.is_ok());
        let queries = result.unwrap();
        let sql = queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join(";\n");

        // Verify empty string becomes ''
        assert!(
            sql.contains("''"),
            "Expected SQL to contain empty string literal '', got: {}",
            sql
        );

        with_settings!({ snapshot_suffix => format!("empty_string_default_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    /// Test adding a NOT NULL JSON column whose default is the PostgreSQL-style
    /// `'[]'::json` cast, on every backend.
    /// SQLite and MySQL output must not contain the `::json` cast; the exact
    /// per-backend SQL (including any MySQL CAST form) is pinned by snapshot.
    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_with_pg_type_cast_default(#[case] backend: DatabaseBackend) {
        let column = ColumnDef {
            name: "story_index".into(),
            r#type: ColumnType::Simple(SimpleColumnType::Json),
            nullable: false,
            default: Some("'[]'::json".into()),
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        };
        let current_schema = vec![TableDef {
            name: "project".into(),
            description: None,
            columns: vec![ColumnDef {
                name: "id".into(),
                r#type: ColumnType::Simple(SimpleColumnType::Integer),
                nullable: false,
                default: None,
                comment: None,
                primary_key: None,
                unique: None,
                index: None,
                foreign_key: None,
            }],
            constraints: vec![],
        }];
        let result = build_add_column(&backend, "project", &column, None, &current_schema).unwrap();
        let sql = result
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join("\n");

        // SQLite must NOT contain ::json syntax
        if backend == DatabaseBackend::Sqlite {
            assert!(
                !sql.contains("::json"),
                "SQLite SQL should not contain ::json cast, got: {}",
                sql
            );
        }

        // MySQL must not contain ::json syntax either (its CAST form is only
        // checked through the snapshot below)
        if backend == DatabaseBackend::MySql {
            assert!(
                !sql.contains("::json"),
                "MySQL SQL should not contain ::json cast, got: {}",
                sql
            );
        }

        with_settings!({ snapshot_suffix => format!("pg_type_cast_default_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    // A `fill_with` of "" must be rendered as the SQL literal '' on every backend
    // (backfill UPDATE on Postgres/MySQL, rebuild SELECT on SQLite).
    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_with_fill_with_empty_string(#[case] backend: DatabaseBackend) {
        use insta::{assert_snapshot, with_settings};

        // Test adding a column with fill_with as empty string
        let column = ColumnDef {
            name: "nickname".into(),
            r#type: ColumnType::Simple(SimpleColumnType::Text),
            nullable: false,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        };
        let current_schema = vec![TableDef {
            name: "users".into(),
            description: None,
            columns: vec![ColumnDef {
                name: "id".into(),
                r#type: ColumnType::Simple(SimpleColumnType::Integer),
                nullable: false,
                default: None,
                comment: None,
                primary_key: None,
                unique: None,
                index: None,
                foreign_key: None,
            }],
            constraints: vec![],
        }];
        // fill_with empty string should become ''
        let result = build_add_column(&backend, "users", &column, Some(""), &current_schema);
        assert!(result.is_ok());
        let queries = result.unwrap();
        let sql = queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join(";\n");

        // Verify empty string becomes ''
        assert!(
            sql.contains("''"),
            "Expected SQL to contain empty string literal '', got: {}",
            sql
        );

        with_settings!({ snapshot_suffix => format!("fill_with_empty_string_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }
}