Skip to main content

vespertide_query/sql/
add_column.rs

1use sea_query::{Alias, Expr, Query, Table, TableAlterStatement};
2
3use vespertide_core::{ColumnDef, TableDef};
4
5use super::helpers::{
6    build_create_enum_type_sql, build_sea_column_def_with_table, build_sqlite_temp_table_create,
7    normalize_enum_default, normalize_fill_with, recreate_indexes_after_rebuild,
8};
9use super::rename_table::build_rename_table;
10use super::types::{BuiltQuery, DatabaseBackend};
11use crate::error::QueryError;
12
13fn build_add_column_alter_for_backend(
14    backend: &DatabaseBackend,
15    table: &str,
16    column: &ColumnDef,
17) -> TableAlterStatement {
18    let col_def = build_sea_column_def_with_table(backend, table, column);
19    Table::alter()
20        .table(Alias::new(table))
21        .add_column(col_def)
22        .to_owned()
23}
24
25/// Check if the column type is an enum
26fn is_enum_column(column: &ColumnDef) -> bool {
27    matches!(
28        column.r#type,
29        vespertide_core::ColumnType::Complex(vespertide_core::ComplexColumnType::Enum { .. })
30    )
31}
32
33pub fn build_add_column(
34    backend: &DatabaseBackend,
35    table: &str,
36    column: &ColumnDef,
37    fill_with: Option<&str>,
38    current_schema: &[TableDef],
39) -> Result<Vec<BuiltQuery>, QueryError> {
40    // SQLite: NOT NULL additions or enum columns require table recreation
41    // (enum columns need CHECK constraint which requires table recreation in SQLite)
42    let sqlite_needs_recreation =
43        *backend == DatabaseBackend::Sqlite && (!column.nullable || is_enum_column(column));
44
45    if sqlite_needs_recreation {
46        let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::Other(format!("Table '{}' not found in current schema. SQLite requires current schema information to add columns.", table)))?;
47
48        let mut new_columns = table_def.columns.clone();
49        new_columns.push(column.clone());
50
51        let temp_table = format!("{}_temp", table);
52
53        // 1. Create temporary table with all CHECK constraints (enum + explicit)
54        let create_query = build_sqlite_temp_table_create(
55            backend,
56            &temp_table,
57            table,
58            &new_columns,
59            &table_def.constraints,
60        );
61
62        // Copy existing data, filling new column
63        let mut select_query = Query::select();
64        for col in &table_def.columns {
65            select_query = select_query.column(Alias::new(&col.name)).to_owned();
66        }
67        let normalized_fill = normalize_fill_with(fill_with);
68        let fill_expr = if let Some(fill) = normalized_fill.as_deref() {
69            Expr::cust(normalize_enum_default(&column.r#type, fill))
70        } else if let Some(def) = &column.default {
71            Expr::cust(normalize_enum_default(&column.r#type, &def.to_sql()))
72        } else {
73            Expr::cust("NULL")
74        };
75        select_query = select_query
76            .expr_as(fill_expr, Alias::new(&column.name))
77            .from(Alias::new(table))
78            .to_owned();
79
80        let mut columns_alias: Vec<Alias> = table_def
81            .columns
82            .iter()
83            .map(|c| Alias::new(&c.name))
84            .collect();
85        columns_alias.push(Alias::new(&column.name));
86        let insert_stmt = Query::insert()
87            .into_table(Alias::new(&temp_table))
88            .columns(columns_alias)
89            .select_from(select_query)
90            .unwrap()
91            .to_owned();
92        let insert_query = BuiltQuery::Insert(Box::new(insert_stmt));
93
94        let drop_query =
95            BuiltQuery::DropTable(Box::new(Table::drop().table(Alias::new(table)).to_owned()));
96        let rename_query = build_rename_table(&temp_table, table);
97
98        // Recreate indexes (both regular and UNIQUE)
99        let index_queries = recreate_indexes_after_rebuild(table, &table_def.constraints, &[]);
100
101        let mut stmts = vec![create_query, insert_query, drop_query, rename_query];
102        stmts.extend(index_queries);
103        return Ok(stmts);
104    }
105
106    let mut stmts: Vec<BuiltQuery> = Vec::new();
107
108    // If column type is an enum, create the type first (PostgreSQL only)
109    if let Some(create_type_sql) = build_create_enum_type_sql(table, &column.r#type) {
110        stmts.push(BuiltQuery::Raw(create_type_sql));
111    }
112
113    // If adding NOT NULL without default, we need special handling
114    let needs_backfill = !column.nullable && column.default.is_none() && fill_with.is_some();
115
116    if needs_backfill {
117        // Add as nullable first
118        let mut temp_col = column.clone();
119        temp_col.nullable = true;
120
121        stmts.push(BuiltQuery::AlterTable(Box::new(
122            build_add_column_alter_for_backend(backend, table, &temp_col),
123        )));
124
125        // Backfill with provided value
126        if let Some(fill) = normalize_fill_with(fill_with) {
127            let update_stmt = Query::update()
128                .table(Alias::new(table))
129                .value(Alias::new(&column.name), Expr::cust(fill))
130                .to_owned();
131            stmts.push(BuiltQuery::Update(Box::new(update_stmt)));
132        }
133
134        // Set NOT NULL
135        let not_null_col = build_sea_column_def_with_table(backend, table, column);
136        let alter_not_null = Table::alter()
137            .table(Alias::new(table))
138            .modify_column(not_null_col)
139            .to_owned();
140        stmts.push(BuiltQuery::AlterTable(Box::new(alter_not_null)));
141    } else {
142        stmts.push(BuiltQuery::AlterTable(Box::new(
143            build_add_column_alter_for_backend(backend, table, column),
144        )));
145    }
146
147    Ok(stmts)
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{
        ColumnType, ComplexColumnType, EnumValues, SimpleColumnType, TableConstraint, TableDef,
    };

    /// Builds a column with no comment, key, unique/index flag or foreign key.
    fn col(name: &str, ty: ColumnType, nullable: bool, default: Option<&str>) -> ColumnDef {
        ColumnDef {
            name: name.into(),
            r#type: ty,
            nullable,
            default: default.map(Into::into),
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    /// Schema with a single `users` table holding a NOT NULL integer `id`
    /// column plus the given table-level constraints.
    fn users_schema(constraints: Vec<TableConstraint>) -> Vec<TableDef> {
        vec![TableDef {
            name: "users".into(),
            description: None,
            columns: vec![col(
                "id",
                ColumnType::Simple(SimpleColumnType::Integer),
                false,
                None,
            )],
            constraints,
        }]
    }

    /// Renders every query for `backend`, joined with `separator`.
    fn render(queries: &[BuiltQuery], backend: DatabaseBackend, separator: &str) -> String {
        queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join(separator)
    }

    #[rstest]
    #[case::add_column_with_backfill_postgres(
        "add_column_with_backfill_postgres",
        DatabaseBackend::Postgres,
        &["ALTER TABLE \"users\" ADD COLUMN \"nickname\" text"]
    )]
    #[case::add_column_with_backfill_mysql(
        "add_column_with_backfill_mysql",
        DatabaseBackend::MySql,
        &["ALTER TABLE `users` ADD COLUMN `nickname` text"]
    )]
    #[case::add_column_with_backfill_sqlite(
        "add_column_with_backfill_sqlite",
        DatabaseBackend::Sqlite,
        &["CREATE TABLE \"users_temp\""]
    )]
    #[case::add_column_simple_postgres(
        "add_column_simple_postgres",
        DatabaseBackend::Postgres,
        &["ALTER TABLE \"users\" ADD COLUMN \"nickname\""]
    )]
    #[case::add_column_simple_mysql(
        "add_column_simple_mysql",
        DatabaseBackend::MySql,
        &["ALTER TABLE `users` ADD COLUMN `nickname` text"]
    )]
    #[case::add_column_simple_sqlite(
        "add_column_simple_sqlite",
        DatabaseBackend::Sqlite,
        &["ALTER TABLE \"users\" ADD COLUMN \"nickname\""]
    )]
    #[case::add_column_nullable_postgres(
        "add_column_nullable_postgres",
        DatabaseBackend::Postgres,
        &["ALTER TABLE \"users\" ADD COLUMN \"email\" text"]
    )]
    #[case::add_column_nullable_mysql(
        "add_column_nullable_mysql",
        DatabaseBackend::MySql,
        &["ALTER TABLE `users` ADD COLUMN `email` text"]
    )]
    #[case::add_column_nullable_sqlite(
        "add_column_nullable_sqlite",
        DatabaseBackend::Sqlite,
        &["ALTER TABLE \"users\" ADD COLUMN \"email\" text"]
    )]
    fn test_add_column(
        #[case] title: &str,
        #[case] backend: DatabaseBackend,
        #[case] expected: &[&str],
    ) {
        // The case title selects the fixture:
        // - "nullable": a nullable text `email` column;
        // - "backfill": a NOT NULL text `nickname` column filled with "0";
        // - anything else: a nullable text `nickname` column.
        let name = if title.contains("nullable") {
            "email"
        } else {
            "nickname"
        };
        let backfill = title.contains("backfill");
        let column = col(
            name,
            ColumnType::Simple(SimpleColumnType::Text),
            !backfill,
            None,
        );
        let fill_with = backfill.then_some("0");
        let current_schema = users_schema(vec![]);

        let result =
            build_add_column(&backend, "users", &column, fill_with, &current_schema).unwrap();
        let sql = result[0].build(backend);
        for exp in expected {
            assert!(
                sql.contains(exp),
                "Expected SQL to contain '{}', got: {}",
                exp,
                sql
            );
        }

        with_settings!({ snapshot_suffix => format!("add_column_{}", title) }, {
            assert_snapshot!(result.iter().map(|q| q.build(backend)).collect::<Vec<String>>().join("\n"));
        });
    }

    #[test]
    fn test_add_column_sqlite_table_not_found() {
        let column = col(
            "nickname",
            ColumnType::Simple(SimpleColumnType::Text),
            false,
            None,
        );
        // Empty schema: the table cannot be found.
        let result = build_add_column(&DatabaseBackend::Sqlite, "users", &column, None, &[]);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Table 'users' not found in current schema"));
    }

    #[test]
    fn test_add_column_sqlite_with_default() {
        let column = col(
            "age",
            ColumnType::Simple(SimpleColumnType::Integer),
            false,
            Some("18"),
        );
        let queries = build_add_column(
            &DatabaseBackend::Sqlite,
            "users",
            &column,
            None,
            &users_schema(vec![]),
        )
        .expect("SQLite rebuild should succeed");

        // Inspect only the copy statement (create, insert, drop, rename, ...).
        // The CREATE TABLE also carries the default, which would make a
        // whole-script check vacuous.
        let insert_sql = queries[1].build(DatabaseBackend::Sqlite);
        assert!(
            insert_sql.contains("18"),
            "Expected the copy to fill with the default 18, got: {}",
            insert_sql
        );
    }

    #[test]
    fn test_add_column_sqlite_without_fill_or_default() {
        let column = col(
            "age",
            ColumnType::Simple(SimpleColumnType::Integer),
            false,
            None,
        );
        let queries = build_add_column(
            &DatabaseBackend::Sqlite,
            "users",
            &column,
            None,
            &users_schema(vec![]),
        )
        .expect("SQLite rebuild should succeed");

        // Inspect only the copy statement. The CREATE TABLE contains "NOT NULL",
        // so checking the whole script for "NULL" would always pass.
        let insert_sql = queries[1].build(DatabaseBackend::Sqlite);
        assert!(
            insert_sql.contains("NULL"),
            "Expected the copy to fill with NULL, got: {}",
            insert_sql
        );
    }

    #[test]
    fn test_add_column_sqlite_with_indexes() {
        let column = col(
            "nickname",
            ColumnType::Simple(SimpleColumnType::Text),
            false,
            None,
        );
        let current_schema = users_schema(vec![TableConstraint::Index {
            name: Some("idx_id".into()),
            columns: vec!["id".into()],
        }]);
        let queries = build_add_column(
            &DatabaseBackend::Sqlite,
            "users",
            &column,
            None,
            &current_schema,
        )
        .expect("SQLite rebuild should succeed");
        let sql = render(&queries, DatabaseBackend::Sqlite, "\n");

        // The index must be recreated after the rebuild.
        assert!(sql.contains("CREATE INDEX"));
        assert!(sql.contains("idx_id"));
    }

    #[rstest]
    #[case::add_column_with_enum_type_postgres(DatabaseBackend::Postgres)]
    #[case::add_column_with_enum_type_mysql(DatabaseBackend::MySql)]
    #[case::add_column_with_enum_type_sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_with_enum_type(#[case] backend: DatabaseBackend) {
        // Adding an enum column creates the enum type first (PostgreSQL only).
        let column = col(
            "status",
            ColumnType::Complex(ComplexColumnType::Enum {
                name: "status_type".into(),
                values: EnumValues::String(vec!["active".into(), "inactive".into()]),
            }),
            true,
            None,
        );
        let queries = build_add_column(&backend, "users", &column, None, &users_schema(vec![]))
            .expect("adding an enum column should succeed");
        let sql = render(&queries, backend, ";\n");

        with_settings!({ snapshot_suffix => format!("add_column_with_enum_type_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_enum_non_nullable_with_default(#[case] backend: DatabaseBackend) {
        // A non-nullable enum column with a default value.
        let column = col(
            "status",
            ColumnType::Complex(ComplexColumnType::Enum {
                name: "user_status".into(),
                values: EnumValues::String(vec![
                    "active".into(),
                    "inactive".into(),
                    "pending".into(),
                ]),
            }),
            false,
            Some("active"),
        );
        let queries = build_add_column(&backend, "users", &column, None, &users_schema(vec![]))
            .expect("adding an enum column with a default should succeed");
        let sql = render(&queries, backend, ";\n");

        with_settings!({ snapshot_suffix => format!("enum_non_nullable_with_default_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_with_empty_string_default(#[case] backend: DatabaseBackend) {
        // A NOT NULL text column whose default is the empty string.
        let column = col(
            "nickname",
            ColumnType::Simple(SimpleColumnType::Text),
            false,
            Some(""),
        );
        let queries = build_add_column(&backend, "users", &column, None, &users_schema(vec![]))
            .expect("adding a column with an empty-string default should succeed");
        let sql = render(&queries, backend, ";\n");

        // The empty string must be rendered as the SQL literal ''.
        assert!(
            sql.contains("''"),
            "Expected SQL to contain empty string literal '', got: {}",
            sql
        );

        with_settings!({ snapshot_suffix => format!("empty_string_default_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_add_column_with_fill_with_empty_string(#[case] backend: DatabaseBackend) {
        // A NOT NULL text column without a default, filled with the empty string.
        let column = col(
            "nickname",
            ColumnType::Simple(SimpleColumnType::Text),
            false,
            None,
        );
        let queries =
            build_add_column(&backend, "users", &column, Some(""), &users_schema(vec![]))
                .expect("adding a column with an empty-string fill should succeed");
        let sql = render(&queries, backend, ";\n");

        // fill_with "" must be rendered as the SQL literal ''.
        assert!(
            sql.contains("''"),
            "Expected SQL to contain empty string literal '', got: {}",
            sql
        );

        with_settings!({ snapshot_suffix => format!("fill_with_empty_string_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }
}