vespertide_query/sql/add_column.rs

1use sea_query::{Alias, Expr, Query, Table, TableAlterStatement};
2
3use vespertide_core::{ColumnDef, TableDef};
4
5use super::create_table::build_create_table_for_backend;
6use super::helpers::{
7    build_create_enum_type_sql, build_schema_statement, build_sea_column_def_with_table,
8    collect_sqlite_enum_check_clauses, normalize_enum_default, normalize_fill_with,
9};
10use super::rename_table::build_rename_table;
11use super::types::{BuiltQuery, DatabaseBackend, RawSql};
12use crate::error::QueryError;
13
14fn build_add_column_alter_for_backend(
15    backend: &DatabaseBackend,
16    table: &str,
17    column: &ColumnDef,
18) -> TableAlterStatement {
19    let col_def = build_sea_column_def_with_table(backend, table, column);
20    Table::alter()
21        .table(Alias::new(table))
22        .add_column(col_def)
23        .to_owned()
24}
25
26/// Check if the column type is an enum
27fn is_enum_column(column: &ColumnDef) -> bool {
28    matches!(
29        column.r#type,
30        vespertide_core::ColumnType::Complex(vespertide_core::ComplexColumnType::Enum { .. })
31    )
32}
33
34pub fn build_add_column(
35    backend: &DatabaseBackend,
36    table: &str,
37    column: &ColumnDef,
38    fill_with: Option<&str>,
39    current_schema: &[TableDef],
40) -> Result<Vec<BuiltQuery>, QueryError> {
41    // SQLite: NOT NULL additions or enum columns require table recreation
42    // (enum columns need CHECK constraint which requires table recreation in SQLite)
43    let sqlite_needs_recreation =
44        *backend == DatabaseBackend::Sqlite && (!column.nullable || is_enum_column(column));
45
46    if sqlite_needs_recreation {
47        let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::Other(format!("Table '{}' not found in current schema. SQLite requires current schema information to add columns.", table)))?;
48
49        let mut new_columns = table_def.columns.clone();
50        new_columns.push(column.clone());
51
52        let temp_table = format!("{}_temp", table);
53        let create_temp = build_create_table_for_backend(
54            backend,
55            &temp_table,
56            &new_columns,
57            &table_def.constraints,
58        );
59
60        // For SQLite, add CHECK constraints for enum columns
61        // Use original table name for constraint naming (table will be renamed back)
62        let enum_check_clauses = collect_sqlite_enum_check_clauses(table, &new_columns);
63        let create_query = if !enum_check_clauses.is_empty() {
64            let base_sql = build_schema_statement(&create_temp, *backend);
65            let mut modified_sql = base_sql;
66            if let Some(pos) = modified_sql.rfind(')') {
67                let check_sql = enum_check_clauses.join(", ");
68                modified_sql.insert_str(pos, &format!(", {}", check_sql));
69            }
70            BuiltQuery::Raw(RawSql::per_backend(
71                modified_sql.clone(),
72                modified_sql.clone(),
73                modified_sql,
74            ))
75        } else {
76            BuiltQuery::CreateTable(Box::new(create_temp))
77        };
78
79        // Copy existing data, filling new column
80        let mut select_query = Query::select();
81        for col in &table_def.columns {
82            select_query = select_query.column(Alias::new(&col.name)).to_owned();
83        }
84        let normalized_fill = normalize_fill_with(fill_with);
85        let fill_expr = if let Some(fill) = normalized_fill.as_deref() {
86            Expr::cust(normalize_enum_default(&column.r#type, fill))
87        } else if let Some(def) = &column.default {
88            Expr::cust(normalize_enum_default(&column.r#type, &def.to_sql()))
89        } else {
90            Expr::cust("NULL")
91        };
92        select_query = select_query
93            .expr_as(fill_expr, Alias::new(&column.name))
94            .from(Alias::new(table))
95            .to_owned();
96
97        let mut columns_alias: Vec<Alias> = table_def
98            .columns
99            .iter()
100            .map(|c| Alias::new(&c.name))
101            .collect();
102        columns_alias.push(Alias::new(&column.name));
103        let insert_stmt = Query::insert()
104            .into_table(Alias::new(&temp_table))
105            .columns(columns_alias)
106            .select_from(select_query)
107            .unwrap()
108            .to_owned();
109        let insert_query = BuiltQuery::Insert(Box::new(insert_stmt));
110
111        let drop_query =
112            BuiltQuery::DropTable(Box::new(Table::drop().table(Alias::new(table)).to_owned()));
113        let rename_query = build_rename_table(&temp_table, table);
114
115        // Recreate indexes from Index constraints
116        let mut index_queries = Vec::new();
117        for constraint in &table_def.constraints {
118            if let vespertide_core::TableConstraint::Index { name, columns } = constraint {
119                let index_name =
120                    vespertide_naming::build_index_name(table, columns, name.as_deref());
121                let mut idx_stmt = sea_query::Index::create();
122                idx_stmt = idx_stmt.name(&index_name).to_owned();
123                for col_name in columns {
124                    idx_stmt = idx_stmt.col(Alias::new(col_name)).to_owned();
125                }
126                idx_stmt = idx_stmt.table(Alias::new(table)).to_owned();
127                index_queries.push(BuiltQuery::CreateIndex(Box::new(idx_stmt)));
128            }
129        }
130
131        let mut stmts = vec![create_query, insert_query, drop_query, rename_query];
132        stmts.extend(index_queries);
133        return Ok(stmts);
134    }
135
136    let mut stmts: Vec<BuiltQuery> = Vec::new();
137
138    // If column type is an enum, create the type first (PostgreSQL only)
139    if let Some(create_type_sql) = build_create_enum_type_sql(table, &column.r#type) {
140        stmts.push(BuiltQuery::Raw(create_type_sql));
141    }
142
143    // If adding NOT NULL without default, we need special handling
144    let needs_backfill = !column.nullable && column.default.is_none() && fill_with.is_some();
145
146    if needs_backfill {
147        // Add as nullable first
148        let mut temp_col = column.clone();
149        temp_col.nullable = true;
150
151        stmts.push(BuiltQuery::AlterTable(Box::new(
152            build_add_column_alter_for_backend(backend, table, &temp_col),
153        )));
154
155        // Backfill with provided value
156        if let Some(fill) = normalize_fill_with(fill_with) {
157            let update_stmt = Query::update()
158                .table(Alias::new(table))
159                .value(Alias::new(&column.name), Expr::cust(fill))
160                .to_owned();
161            stmts.push(BuiltQuery::Update(Box::new(update_stmt)));
162        }
163
164        // Set NOT NULL
165        let not_null_col = build_sea_column_def_with_table(backend, table, column);
166        let alter_not_null = Table::alter()
167            .table(Alias::new(table))
168            .modify_column(not_null_col)
169            .to_owned();
170        stmts.push(BuiltQuery::AlterTable(Box::new(alter_not_null)));
171    } else {
172        stmts.push(BuiltQuery::AlterTable(Box::new(
173            build_add_column_alter_for_backend(backend, table, column),
174        )));
175    }
176
177    Ok(stmts)
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{ColumnType, SimpleColumnType, TableDef};

    /// A column with only name, type, nullability and default set.
    /// Comment, key, index and foreign-key attributes are all `None`.
    fn plain_column(
        name: &str,
        col_type: ColumnType,
        nullable: bool,
        default: Option<&str>,
    ) -> ColumnDef {
        ColumnDef {
            name: name.into(),
            r#type: col_type,
            nullable,
            default: default.map(Into::into),
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    /// Schema containing only `users(id integer NOT NULL)` with the given constraints.
    fn users_schema(constraints: Vec<vespertide_core::TableConstraint>) -> Vec<TableDef> {
        vec![TableDef {
            name: "users".into(),
            description: None,
            columns: vec![plain_column(
                "id",
                ColumnType::Simple(SimpleColumnType::Integer),
                false,
                None,
            )],
            constraints,
        }]
    }

    /// Render every query for `backend` and join the results with `separator`.
    fn render(queries: &[BuiltQuery], backend: DatabaseBackend, separator: &str) -> String {
        queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join(separator)
    }

    #[rstest]
    #[case::add_column_with_backfill_postgres(
        "add_column_with_backfill_postgres",
        DatabaseBackend::Postgres,
        &["ALTER TABLE \"users\" ADD COLUMN \"nickname\" text"]
    )]
    #[case::add_column_with_backfill_mysql(
        "add_column_with_backfill_mysql",
        DatabaseBackend::MySql,
        &["ALTER TABLE `users` ADD COLUMN `nickname` text"]
    )]
    #[case::add_column_with_backfill_sqlite(
        "add_column_with_backfill_sqlite",
        DatabaseBackend::Sqlite,
        &["CREATE TABLE \"users_temp\""]
    )]
    #[case::add_column_simple_postgres(
        "add_column_simple_postgres",
        DatabaseBackend::Postgres,
        &["ALTER TABLE \"users\" ADD COLUMN \"nickname\""]
    )]
    #[case::add_column_simple_mysql(
        "add_column_simple_mysql",
        DatabaseBackend::MySql,
        &["ALTER TABLE `users` ADD COLUMN `nickname` text"]
    )]
    #[case::add_column_simple_sqlite(
        "add_column_simple_sqlite",
        DatabaseBackend::Sqlite,
        &["ALTER TABLE \"users\" ADD COLUMN \"nickname\""]
    )]
    #[case::add_column_nullable_postgres(
        "add_column_nullable_postgres",
        DatabaseBackend::Postgres,
        &["ALTER TABLE \"users\" ADD COLUMN \"email\" text"]
    )]
    #[case::add_column_nullable_mysql(
        "add_column_nullable_mysql",
        DatabaseBackend::MySql,
        &["ALTER TABLE `users` ADD COLUMN `email` text"]
    )]
    #[case::add_column_nullable_sqlite(
        "add_column_nullable_sqlite",
        DatabaseBackend::Sqlite,
        &["ALTER TABLE \"users\" ADD COLUMN \"email\" text"]
    )]
    fn test_add_column(
        #[case] title: &str,
        #[case] backend: DatabaseBackend,
        #[case] expected: &[&str],
    ) {
        // The case title drives the fixture: "backfill" makes the column NOT NULL
        // with a fill value, and "nullable" selects the `email` column.
        let is_backfill = title.contains("backfill");
        let (col_name, col_type) = if title.contains("age") {
            ("age", ColumnType::Simple(SimpleColumnType::Integer))
        } else if title.contains("nullable") {
            ("email", ColumnType::Simple(SimpleColumnType::Text))
        } else {
            ("nickname", ColumnType::Simple(SimpleColumnType::Text))
        };
        let column = plain_column(col_name, col_type, !is_backfill, None);
        let fill_with = if is_backfill { Some("0") } else { None };
        let current_schema = users_schema(vec![]);

        let result =
            build_add_column(&backend, "users", &column, fill_with, &current_schema).unwrap();
        let first_sql = result[0].build(backend);
        for exp in expected {
            assert!(
                first_sql.contains(exp),
                "Expected SQL to contain '{}', got: {}",
                exp,
                first_sql
            );
        }

        with_settings!({ snapshot_suffix => format!("add_column_{}", title) }, {
            assert_snapshot!(result.iter().map(|q| q.build(backend)).collect::<Vec<String>>().join("\n"));
        });
    }

    #[test]
    fn test_add_column_sqlite_table_not_found() {
        let column = plain_column(
            "nickname",
            ColumnType::Simple(SimpleColumnType::Text),
            false,
            None,
        );
        // Empty schema: the table cannot be found.
        let outcome = build_add_column(&DatabaseBackend::Sqlite, "users", &column, None, &[]);
        let err_msg = match outcome {
            Ok(_) => panic!("expected an error when the table is missing from the schema"),
            Err(e) => e.to_string(),
        };
        assert!(err_msg.contains("Table 'users' not found in current schema"));
    }

    #[test]
    fn test_add_column_sqlite_with_default() {
        let column = plain_column(
            "age",
            ColumnType::Simple(SimpleColumnType::Integer),
            false,
            Some("18"),
        );
        let current_schema = users_schema(vec![]);
        let queries =
            build_add_column(&DatabaseBackend::Sqlite, "users", &column, None, &current_schema)
                .unwrap();
        let rendered = render(&queries, DatabaseBackend::Sqlite, "\n");
        // The default value (18) is used to fill existing rows.
        assert!(rendered.contains("18"));
    }

    #[test]
    fn test_add_column_sqlite_without_fill_or_default() {
        let column = plain_column(
            "age",
            ColumnType::Simple(SimpleColumnType::Integer),
            false,
            None,
        );
        let current_schema = users_schema(vec![]);
        let queries =
            build_add_column(&DatabaseBackend::Sqlite, "users", &column, None, &current_schema)
                .unwrap();
        let rendered = render(&queries, DatabaseBackend::Sqlite, "\n");
        // With neither fill nor default, existing rows are filled with NULL.
        assert!(rendered.contains("NULL"));
    }

    #[test]
    fn test_add_column_sqlite_with_indexes() {
        use vespertide_core::TableConstraint;

        let column = plain_column(
            "nickname",
            ColumnType::Simple(SimpleColumnType::Text),
            false,
            None,
        );
        let current_schema = users_schema(vec![TableConstraint::Index {
            name: Some("idx_id".into()),
            columns: vec!["id".into()],
        }]);
        let queries =
            build_add_column(&DatabaseBackend::Sqlite, "users", &column, None, &current_schema)
                .unwrap();
        let rendered = render(&queries, DatabaseBackend::Sqlite, "\n");
        // The index must be recreated after the table rebuild.
        assert!(rendered.contains("CREATE INDEX"));
        assert!(rendered.contains("idx_id"));
    }

    #[rstest]
    #[case::add_column_with_enum_type_postgres(DatabaseBackend::Postgres)]
    #[case::add_column_with_enum_type_mysql(DatabaseBackend::MySql)]
    #[case::add_column_with_enum_type_sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_with_enum_type(#[case] backend: DatabaseBackend) {
        use vespertide_core::{ComplexColumnType, EnumValues};

        // Adding an enum column creates the enum type first (PostgreSQL only).
        let enum_type = ColumnType::Complex(ComplexColumnType::Enum {
            name: "status_type".into(),
            values: EnumValues::String(vec!["active".into(), "inactive".into()]),
        });
        let column = plain_column("status", enum_type, true, None);
        let current_schema = users_schema(vec![]);
        let result = build_add_column(&backend, "users", &column, None, &current_schema);
        assert!(result.is_ok());
        let sql = render(&result.unwrap(), backend, ";\n");

        with_settings!({ snapshot_suffix => format!("add_column_with_enum_type_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_enum_non_nullable_with_default(#[case] backend: DatabaseBackend) {
        use vespertide_core::{ComplexColumnType, EnumValues};

        // A non-nullable enum column that carries a default value.
        let enum_type = ColumnType::Complex(ComplexColumnType::Enum {
            name: "user_status".into(),
            values: EnumValues::String(vec![
                "active".into(),
                "inactive".into(),
                "pending".into(),
            ]),
        });
        let column = plain_column("status", enum_type, false, Some("active"));
        let current_schema = users_schema(vec![]);
        let result = build_add_column(&backend, "users", &column, None, &current_schema);
        assert!(result.is_ok());
        let sql = render(&result.unwrap(), backend, ";\n");

        with_settings!({ snapshot_suffix => format!("enum_non_nullable_with_default_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_with_empty_string_default(#[case] backend: DatabaseBackend) {
        // A text column whose default is the empty string.
        let column = plain_column(
            "nickname",
            ColumnType::Simple(SimpleColumnType::Text),
            false,
            Some(""),
        );
        let current_schema = users_schema(vec![]);
        let result = build_add_column(&backend, "users", &column, None, &current_schema);
        assert!(result.is_ok());
        let sql = render(&result.unwrap(), backend, ";\n");

        // The empty string must be rendered as the literal ''.
        assert!(
            sql.contains("''"),
            "Expected SQL to contain empty string literal '', got: {}",
            sql
        );

        with_settings!({ snapshot_suffix => format!("empty_string_default_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_with_fill_with_empty_string(#[case] backend: DatabaseBackend) {
        // A NOT NULL column backfilled with an empty-string fill value.
        let column = plain_column(
            "nickname",
            ColumnType::Simple(SimpleColumnType::Text),
            false,
            None,
        );
        let current_schema = users_schema(vec![]);
        // A fill_with of "" should become ''.
        let result = build_add_column(&backend, "users", &column, Some(""), &current_schema);
        assert!(result.is_ok());
        let sql = render(&result.unwrap(), backend, ";\n");

        // The empty string must be rendered as the literal ''.
        assert!(
            sql.contains("''"),
            "Expected SQL to contain empty string literal '', got: {}",
            sql
        );

        with_settings!({ snapshot_suffix => format!("fill_with_empty_string_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }
}