Skip to main content

vespertide_query/sql/
modify_column_nullable.rs

1use sea_query::{Alias, Query, Table};
2
3use vespertide_core::{ColumnDef, TableDef};
4
5use super::helpers::{
6    build_sea_column_def_with_table, build_sqlite_temp_table_create, convert_default_for_backend,
7    normalize_fill_with, recreate_indexes_after_rebuild,
8};
9use super::rename_table::build_rename_table;
10use super::types::{BuiltQuery, DatabaseBackend, RawSql};
11use crate::error::QueryError;
12
13/// Build SQL for changing column nullability.
14/// For nullable -> non-nullable transitions, fill_with should be provided to update NULL values.
15pub fn build_modify_column_nullable(
16    backend: &DatabaseBackend,
17    table: &str,
18    column: &str,
19    nullable: bool,
20    fill_with: Option<&str>,
21    delete_null_rows: bool,
22    current_schema: &[TableDef],
23) -> Result<Vec<BuiltQuery>, QueryError> {
24    let mut queries = Vec::new();
25
26    // If delete_null_rows is set, delete rows with NULL values instead of updating
27    if !nullable && delete_null_rows {
28        let delete_sql = match backend {
29            DatabaseBackend::Postgres | DatabaseBackend::Sqlite => {
30                format!("DELETE FROM \"{}\" WHERE \"{}\" IS NULL", table, column)
31            }
32            DatabaseBackend::MySql => {
33                format!("DELETE FROM `{}` WHERE `{}` IS NULL", table, column)
34            }
35        };
36        queries.push(BuiltQuery::Raw(RawSql::uniform(delete_sql)));
37    }
38    // If changing to NOT NULL, first update existing NULL values if fill_with is provided
39    else if !nullable && let Some(fill_value) = normalize_fill_with(fill_with) {
40        let fill_value = convert_default_for_backend(&fill_value, backend);
41        let update_sql = match backend {
42            DatabaseBackend::Postgres | DatabaseBackend::Sqlite => format!(
43                "UPDATE \"{}\" SET \"{}\" = {} WHERE \"{}\" IS NULL",
44                table, column, fill_value, column
45            ),
46            DatabaseBackend::MySql => format!(
47                "UPDATE `{}` SET `{}` = {} WHERE `{}` IS NULL",
48                table, column, fill_value, column
49            ),
50        };
51        queries.push(BuiltQuery::Raw(RawSql::uniform(update_sql)));
52    }
53
54    // Generate ALTER TABLE statement based on backend
55    match backend {
56        DatabaseBackend::Postgres => {
57            let alter_sql = if nullable {
58                format!(
59                    "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" DROP NOT NULL",
60                    table, column
61                )
62            } else {
63                format!(
64                    "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" SET NOT NULL",
65                    table, column
66                )
67            };
68            queries.push(BuiltQuery::Raw(RawSql::uniform(alter_sql)));
69        }
70        DatabaseBackend::MySql => {
71            // MySQL requires the full column definition in MODIFY COLUMN
72            // We need to get the column type from current schema
73            let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::Other(format!("Table '{}' not found in current schema. MySQL requires current schema information to modify column nullability.", table)))?;
74
75            let column_def = table_def.columns.iter().find(|c| c.name == column).ok_or_else(|| QueryError::Other(format!("Column '{}' not found in table '{}'. MySQL requires column information to modify nullability.", column, table)))?;
76
77            // Create a modified column def with the new nullability
78            let modified_col_def = ColumnDef {
79                nullable,
80                ..column_def.clone()
81            };
82
83            // Build sea-query ColumnDef with all properties (type, nullable, default)
84            let sea_col = build_sea_column_def_with_table(backend, table, &modified_col_def);
85
86            let stmt = Table::alter()
87                .table(Alias::new(table))
88                .modify_column(sea_col)
89                .to_owned();
90            queries.push(BuiltQuery::AlterTable(Box::new(stmt)));
91        }
92        DatabaseBackend::Sqlite => {
93            // SQLite doesn't support ALTER COLUMN for nullability changes
94            // Use temporary table approach
95            let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::Other(format!("Table '{}' not found in current schema. SQLite requires current schema information to modify column nullability.", table)))?;
96
97            // Create modified columns with the new nullability
98            let mut new_columns = table_def.columns.clone();
99            if let Some(col) = new_columns.iter_mut().find(|c| c.name == column) {
100                col.nullable = nullable;
101            }
102
103            // Generate temporary table name
104            let temp_table = format!("{}_temp", table);
105
106            // 1. Create temporary table with modified column + CHECK constraints
107            let create_query = build_sqlite_temp_table_create(
108                backend,
109                &temp_table,
110                table,
111                &new_columns,
112                &table_def.constraints,
113            );
114            queries.push(create_query);
115
116            // 2. Copy data (all columns)
117            let column_aliases: Vec<Alias> = table_def
118                .columns
119                .iter()
120                .map(|c| Alias::new(&c.name))
121                .collect();
122            let mut select_query = Query::select();
123            for col_alias in &column_aliases {
124                select_query = select_query.column(col_alias.clone()).to_owned();
125            }
126            select_query = select_query.from(Alias::new(table)).to_owned();
127
128            let insert_stmt = Query::insert()
129                .into_table(Alias::new(&temp_table))
130                .columns(column_aliases.clone())
131                .select_from(select_query)
132                .unwrap()
133                .to_owned();
134            queries.push(BuiltQuery::Insert(Box::new(insert_stmt)));
135
136            // 3. Drop original table
137            let drop_table = Table::drop().table(Alias::new(table)).to_owned();
138            queries.push(BuiltQuery::DropTable(Box::new(drop_table)));
139
140            // 4. Rename temporary table to original name
141            queries.push(build_rename_table(&temp_table, table));
142
143            // 5. Recreate indexes (both regular and UNIQUE)
144            queries.extend(recreate_indexes_after_rebuild(
145                table,
146                &table_def.constraints,
147                &[],
148            ));
149        }
150    }
151
152    Ok(queries)
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{ColumnDef, ColumnType, SimpleColumnType, TableConstraint};

    /// Column with the given name, type and nullability; every optional attribute unset.
    fn col(name: &str, ty: ColumnType, nullable: bool) -> ColumnDef {
        ColumnDef {
            name: name.to_owned(),
            r#type: ty,
            nullable,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    /// Integer column shorthand.
    fn int_col(name: &str, nullable: bool) -> ColumnDef {
        col(name, ColumnType::Simple(SimpleColumnType::Integer), nullable)
    }

    /// Text column shorthand.
    fn text_col(name: &str, nullable: bool) -> ColumnDef {
        col(name, ColumnType::Simple(SimpleColumnType::Text), nullable)
    }

    /// Table definition without a description.
    fn table_def(
        name: &str,
        columns: Vec<ColumnDef>,
        constraints: Vec<TableConstraint>,
    ) -> TableDef {
        TableDef {
            name: name.to_owned(),
            description: None,
            columns,
            constraints,
        }
    }

    /// Backend name used as the prefix of every snapshot suffix.
    fn backend_label(backend: DatabaseBackend) -> &'static str {
        match backend {
            DatabaseBackend::Postgres => "postgres",
            DatabaseBackend::MySql => "mysql",
            DatabaseBackend::Sqlite => "sqlite",
        }
    }

    /// Run the builder, require success, and render every statement joined by newlines.
    fn render_sql(
        backend: DatabaseBackend,
        table: &str,
        column: &str,
        nullable: bool,
        fill_with: Option<&str>,
        delete_null_rows: bool,
        schema: &[TableDef],
    ) -> String {
        let queries = build_modify_column_nullable(
            &backend,
            table,
            column,
            nullable,
            fill_with,
            delete_null_rows,
            schema,
        )
        .expect("build_modify_column_nullable should succeed");
        queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join("\n")
    }

    #[rstest]
    #[case::postgres_set_not_null(DatabaseBackend::Postgres, false, None)]
    #[case::postgres_drop_not_null(DatabaseBackend::Postgres, true, None)]
    #[case::postgres_set_not_null_with_fill(DatabaseBackend::Postgres, false, Some("'unknown'"))]
    #[case::mysql_set_not_null(DatabaseBackend::MySql, false, None)]
    #[case::mysql_drop_not_null(DatabaseBackend::MySql, true, None)]
    #[case::mysql_set_not_null_with_fill(DatabaseBackend::MySql, false, Some("'unknown'"))]
    #[case::sqlite_set_not_null(DatabaseBackend::Sqlite, false, None)]
    #[case::sqlite_drop_not_null(DatabaseBackend::Sqlite, true, None)]
    #[case::sqlite_set_not_null_with_fill(DatabaseBackend::Sqlite, false, Some("'unknown'"))]
    fn test_build_modify_column_nullable(
        #[case] backend: DatabaseBackend,
        #[case] nullable: bool,
        #[case] fill_with: Option<&str>,
    ) {
        // `email` starts in the opposite state from the one requested.
        let schema = [table_def(
            "users",
            vec![int_col("id", false), text_col("email", !nullable)],
            vec![],
        )];
        let sql = render_sql(backend, "users", "email", nullable, fill_with, false, &schema);

        let state = if nullable { "nullable" } else { "not_null" };
        let fill = match fill_with {
            Some(_) => "_with_fill",
            None => "",
        };
        let suffix = format!("{}_{}_users{}", backend_label(backend), state, fill);

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Test table not found error
    #[rstest]
    #[case::postgres_table_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_table_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_table_not_found(DatabaseBackend::Sqlite)]
    fn test_table_not_found(#[case] backend: DatabaseBackend) {
        // Postgres never reads the schema for nullability changes: nothing to assert.
        if matches!(backend, DatabaseBackend::Postgres) {
            return;
        }

        let err = build_modify_column_nullable(&backend, "users", "email", false, None, false, &[])
            .expect_err("an empty schema should be rejected");
        assert!(err.to_string().contains("Table 'users' not found"));
    }

    /// Test column not found error
    #[rstest]
    #[case::postgres_column_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_column_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_column_not_found(DatabaseBackend::Sqlite)]
    fn test_column_not_found(#[case] backend: DatabaseBackend) {
        // Postgres never reads the schema for nullability changes; SQLite's handling of a
        // missing column is not asserted here.
        if matches!(backend, DatabaseBackend::Postgres | DatabaseBackend::Sqlite) {
            return;
        }

        let schema = [table_def("users", vec![int_col("id", false)], vec![])];

        let err =
            build_modify_column_nullable(&backend, "users", "email", false, None, false, &schema)
                .expect_err("a missing column should be rejected");
        assert!(err.to_string().contains("Column 'email' not found"));
    }

    /// Test with index - should recreate index after table rebuild (SQLite)
    #[rstest]
    #[case::postgres_with_index(DatabaseBackend::Postgres)]
    #[case::mysql_with_index(DatabaseBackend::MySql)]
    #[case::sqlite_with_index(DatabaseBackend::Sqlite)]
    fn test_modify_nullable_with_index(#[case] backend: DatabaseBackend) {
        let index = TableConstraint::Index {
            name: Some("idx_email".into()),
            columns: vec!["email".into()],
        };
        let schema = [table_def(
            "users",
            vec![int_col("id", false), text_col("email", true)],
            vec![index],
        )];
        let sql = render_sql(backend, "users", "email", false, None, false, &schema);

        // The SQLite rebuild drops the table, so its index must be recreated afterwards.
        if matches!(backend, DatabaseBackend::Sqlite) {
            assert!(sql.contains("CREATE INDEX"));
            assert!(sql.contains("idx_email"));
        }

        let suffix = format!("{}_with_index", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Test fill_with containing NOW() should be converted to CURRENT_TIMESTAMP for all backends
    #[rstest]
    #[case::postgres_fill_now(DatabaseBackend::Postgres)]
    #[case::mysql_fill_now(DatabaseBackend::MySql)]
    #[case::sqlite_fill_now(DatabaseBackend::Sqlite)]
    fn test_fill_with_now_converted_to_current_timestamp(#[case] backend: DatabaseBackend) {
        let paid_at = col(
            "paid_at",
            ColumnType::Simple(SimpleColumnType::Timestamptz),
            true,
        );
        let schema = [table_def("orders", vec![int_col("id", false), paid_at], vec![])];
        let sql = render_sql(
            backend,
            "orders",
            "paid_at",
            false,
            Some("NOW()"),
            false,
            &schema,
        );

        assert!(!sql.contains("NOW()"), "SQL should not contain NOW(), got: {sql}");
        assert!(
            sql.contains("CURRENT_TIMESTAMP"),
            "SQL should contain CURRENT_TIMESTAMP, got: {sql}"
        );

        let suffix = format!("{}_fill_now", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Test with default value - should preserve default in MODIFY COLUMN (MySQL)
    #[rstest]
    #[case::postgres_with_default(DatabaseBackend::Postgres)]
    #[case::mysql_with_default(DatabaseBackend::MySql)]
    #[case::sqlite_with_default(DatabaseBackend::Sqlite)]
    fn test_with_default_value(#[case] backend: DatabaseBackend) {
        let email_col = ColumnDef {
            default: Some("'default@example.com'".into()),
            ..text_col("email", true)
        };
        let schema = [table_def("users", vec![int_col("id", false), email_col], vec![])];
        let sql = render_sql(backend, "users", "email", false, None, false, &schema);

        // Backends that restate the full column definition must keep the DEFAULT clause.
        if matches!(backend, DatabaseBackend::MySql | DatabaseBackend::Sqlite) {
            assert!(sql.contains("DEFAULT"));
        }

        let suffix = format!("{}_with_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Test delete_null_rows generates DELETE instead of UPDATE
    #[rstest]
    #[case::postgres_delete_null_rows(DatabaseBackend::Postgres)]
    #[case::mysql_delete_null_rows(DatabaseBackend::MySql)]
    #[case::sqlite_delete_null_rows(DatabaseBackend::Sqlite)]
    fn test_delete_null_rows(#[case] backend: DatabaseBackend) {
        let schema = [table_def(
            "orders",
            vec![int_col("id", false), int_col("user_id", true)],
            vec![],
        )];
        let sql = render_sql(backend, "orders", "user_id", false, None, true, &schema);

        assert!(sql.contains("DELETE FROM"), "Expected DELETE FROM in SQL, got: {sql}");
        assert!(sql.contains("IS NULL"), "Expected IS NULL in SQL, got: {sql}");
        assert!(!sql.contains("UPDATE"), "Should NOT contain UPDATE, got: {sql}");

        let suffix = format!("{}_delete_null_rows", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Test delete_null_rows=true with nullable=true does nothing special
    #[rstest]
    #[case::postgres_delete_null_rows_nullable(DatabaseBackend::Postgres)]
    fn test_delete_null_rows_with_nullable_true(#[case] backend: DatabaseBackend) {
        let schema = [table_def(
            "orders",
            vec![int_col("id", false), int_col("user_id", false)],
            vec![],
        )];
        let sql = render_sql(backend, "orders", "user_id", true, None, true, &schema);

        assert!(
            !sql.contains("DELETE FROM"),
            "Should NOT contain DELETE when nullable=true, got: {sql}"
        );
    }
}