Skip to main content

vespertide_query/sql/
modify_column_nullable.rs

1use sea_query::{Alias, Query, Table};
2
3use vespertide_core::{ColumnDef, TableDef};
4
5use super::helpers::{
6    build_sea_column_def_with_table, build_sqlite_temp_table_create, convert_default_for_backend,
7    normalize_fill_with, recreate_indexes_after_rebuild,
8};
9use super::rename_table::build_rename_table;
10use super::types::{BuiltQuery, DatabaseBackend, RawSql};
11use crate::error::QueryError;
12
13/// Build SQL for changing column nullability.
14/// For nullable -> non-nullable transitions, fill_with should be provided to update NULL values.
15#[allow(clippy::too_many_arguments)]
16pub fn build_modify_column_nullable(
17    backend: &DatabaseBackend,
18    table: &str,
19    column: &str,
20    nullable: bool,
21    fill_with: Option<&str>,
22    delete_null_rows: bool,
23    current_schema: &[TableDef],
24    pending_constraints: &[vespertide_core::TableConstraint],
25) -> Result<Vec<BuiltQuery>, QueryError> {
26    let mut queries = Vec::new();
27
28    // If delete_null_rows is set, delete rows with NULL values instead of updating
29    if !nullable && delete_null_rows {
30        let delete_sql = match backend {
31            DatabaseBackend::Postgres | DatabaseBackend::Sqlite => {
32                format!("DELETE FROM \"{}\" WHERE \"{}\" IS NULL", table, column)
33            }
34            DatabaseBackend::MySql => {
35                format!("DELETE FROM `{}` WHERE `{}` IS NULL", table, column)
36            }
37        };
38        queries.push(BuiltQuery::Raw(RawSql::uniform(delete_sql)));
39    }
40    // If changing to NOT NULL, first update existing NULL values if fill_with is provided
41    else if !nullable && let Some(fill_value) = normalize_fill_with(fill_with) {
42        let fill_value = convert_default_for_backend(&fill_value, backend);
43        let update_sql = match backend {
44            DatabaseBackend::Postgres | DatabaseBackend::Sqlite => format!(
45                "UPDATE \"{}\" SET \"{}\" = {} WHERE \"{}\" IS NULL",
46                table, column, fill_value, column
47            ),
48            DatabaseBackend::MySql => format!(
49                "UPDATE `{}` SET `{}` = {} WHERE `{}` IS NULL",
50                table, column, fill_value, column
51            ),
52        };
53        queries.push(BuiltQuery::Raw(RawSql::uniform(update_sql)));
54    }
55
56    // Generate ALTER TABLE statement based on backend
57    match backend {
58        DatabaseBackend::Postgres => {
59            let alter_sql = if nullable {
60                format!(
61                    "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" DROP NOT NULL",
62                    table, column
63                )
64            } else {
65                format!(
66                    "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" SET NOT NULL",
67                    table, column
68                )
69            };
70            queries.push(BuiltQuery::Raw(RawSql::uniform(alter_sql)));
71        }
72        DatabaseBackend::MySql => {
73            // MySQL requires the full column definition in MODIFY COLUMN
74            // We need to get the column type from current schema
75            let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::Other(format!("Table '{}' not found in current schema. MySQL requires current schema information to modify column nullability.", table)))?;
76
77            let column_def = table_def.columns.iter().find(|c| c.name == column).ok_or_else(|| QueryError::Other(format!("Column '{}' not found in table '{}'. MySQL requires column information to modify nullability.", column, table)))?;
78
79            // Create a modified column def with the new nullability
80            let modified_col_def = ColumnDef {
81                nullable,
82                ..column_def.clone()
83            };
84
85            // Build sea-query ColumnDef with all properties (type, nullable, default)
86            let sea_col = build_sea_column_def_with_table(backend, table, &modified_col_def);
87
88            let stmt = Table::alter()
89                .table(Alias::new(table))
90                .modify_column(sea_col)
91                .to_owned();
92            queries.push(BuiltQuery::AlterTable(Box::new(stmt)));
93        }
94        DatabaseBackend::Sqlite => {
95            // SQLite doesn't support ALTER COLUMN for nullability changes
96            // Use temporary table approach
97            let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::Other(format!("Table '{}' not found in current schema. SQLite requires current schema information to modify column nullability.", table)))?;
98
99            // Create modified columns with the new nullability
100            let mut new_columns = table_def.columns.clone();
101            if let Some(col) = new_columns.iter_mut().find(|c| c.name == column) {
102                col.nullable = nullable;
103            }
104
105            // Generate temporary table name
106            let temp_table = format!("{}_temp", table);
107
108            // 1. Create temporary table with modified column + CHECK constraints
109            let create_query = build_sqlite_temp_table_create(
110                backend,
111                &temp_table,
112                table,
113                &new_columns,
114                &table_def.constraints,
115            );
116            queries.push(create_query);
117
118            // 2. Copy data (all columns)
119            let column_aliases: Vec<Alias> = table_def
120                .columns
121                .iter()
122                .map(|c| Alias::new(&c.name))
123                .collect();
124            let mut select_query = Query::select();
125            for col_alias in &column_aliases {
126                select_query = select_query.column(col_alias.clone()).to_owned();
127            }
128            select_query = select_query.from(Alias::new(table)).to_owned();
129
130            let insert_stmt = Query::insert()
131                .into_table(Alias::new(&temp_table))
132                .columns(column_aliases.clone())
133                .select_from(select_query)
134                .unwrap()
135                .to_owned();
136            queries.push(BuiltQuery::Insert(Box::new(insert_stmt)));
137
138            // 3. Drop original table
139            let drop_table = Table::drop().table(Alias::new(table)).to_owned();
140            queries.push(BuiltQuery::DropTable(Box::new(drop_table)));
141
142            // 4. Rename temporary table to original name
143            queries.push(build_rename_table(&temp_table, table));
144
145            // 5. Recreate indexes (both regular and UNIQUE)
146            queries.extend(recreate_indexes_after_rebuild(
147                table,
148                &table_def.constraints,
149                pending_constraints,
150            ));
151        }
152    }
153
154    Ok(queries)
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{ColumnDef, ColumnType, SimpleColumnType, TableConstraint};

    /// Column definition with the given name, type and nullability; every
    /// optional attribute is left unset.
    fn col(name: &str, ty: ColumnType, nullable: bool) -> ColumnDef {
        ColumnDef {
            name: name.to_string(),
            r#type: ty,
            nullable,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    /// Table definition without a description.
    fn table_def(
        name: &str,
        columns: Vec<ColumnDef>,
        constraints: Vec<TableConstraint>,
    ) -> TableDef {
        TableDef {
            name: name.to_string(),
            description: None,
            columns,
            constraints,
        }
    }

    /// Lower-case backend name used as the leading part of snapshot suffixes.
    fn backend_label(backend: DatabaseBackend) -> &'static str {
        match backend {
            DatabaseBackend::Postgres => "postgres",
            DatabaseBackend::MySql => "mysql",
            DatabaseBackend::Sqlite => "sqlite",
        }
    }

    /// Unwrap a successful build result and render its queries for `backend`,
    /// one statement per line.
    fn render_sql(
        result: Result<Vec<BuiltQuery>, QueryError>,
        backend: DatabaseBackend,
    ) -> String {
        result
            .expect("building the nullability change should succeed")
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join("\n")
    }

    /// Two-column `users` schema (`id`, `email`) with the given `email`
    /// nullability and table constraints.
    fn users_schema(email_nullable: bool, constraints: Vec<TableConstraint>) -> Vec<TableDef> {
        vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "email",
                    ColumnType::Simple(SimpleColumnType::Text),
                    email_nullable,
                ),
            ],
            constraints,
        )]
    }

    #[rstest]
    #[case::postgres_set_not_null(DatabaseBackend::Postgres, false, None)]
    #[case::postgres_drop_not_null(DatabaseBackend::Postgres, true, None)]
    #[case::postgres_set_not_null_with_fill(DatabaseBackend::Postgres, false, Some("'unknown'"))]
    #[case::mysql_set_not_null(DatabaseBackend::MySql, false, None)]
    #[case::mysql_drop_not_null(DatabaseBackend::MySql, true, None)]
    #[case::mysql_set_not_null_with_fill(DatabaseBackend::MySql, false, Some("'unknown'"))]
    #[case::sqlite_set_not_null(DatabaseBackend::Sqlite, false, None)]
    #[case::sqlite_drop_not_null(DatabaseBackend::Sqlite, true, None)]
    #[case::sqlite_set_not_null_with_fill(DatabaseBackend::Sqlite, false, Some("'unknown'"))]
    fn test_build_modify_column_nullable(
        #[case] backend: DatabaseBackend,
        #[case] nullable: bool,
        #[case] fill_with: Option<&str>,
    ) {
        // The column starts out with the opposite nullability of the target.
        let schema = users_schema(!nullable, vec![]);

        let result = build_modify_column_nullable(
            &backend,
            "users",
            "email",
            nullable,
            fill_with,
            false,
            &schema,
            &[],
        );
        let sql = render_sql(result, backend);

        let suffix = format!(
            "{}_{}_users{}",
            backend_label(backend),
            if nullable { "nullable" } else { "not_null" },
            if fill_with.is_some() {
                "_with_fill"
            } else {
                ""
            }
        );

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// A table missing from the schema is an error on MySQL and SQLite, and is
    /// irrelevant on Postgres.
    #[rstest]
    #[case::postgres_table_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_table_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_table_not_found(DatabaseBackend::Sqlite)]
    fn test_table_not_found(#[case] backend: DatabaseBackend) {
        let result =
            build_modify_column_nullable(&backend, "users", "email", false, None, false, &[], &[]);

        // Postgres doesn't need schema lookup for nullability changes
        if backend == DatabaseBackend::Postgres {
            assert!(
                result.is_ok(),
                "Postgres should not require the table in the schema"
            );
            return;
        }

        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Table 'users' not found"));
    }

    /// A column missing from the table is an error on MySQL only.
    #[rstest]
    #[case::postgres_column_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_column_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_column_not_found(DatabaseBackend::Sqlite)]
    fn test_column_not_found(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "users",
            vec![col(
                "id",
                ColumnType::Simple(SimpleColumnType::Integer),
                false,
            )],
            vec![],
        )];

        let result = build_modify_column_nullable(
            &backend,
            "users",
            "email",
            false,
            None,
            false,
            &schema,
            &[],
        );

        // Postgres doesn't need schema lookup for nullability changes
        // SQLite doesn't validate column existence in modify_column_nullable
        if backend == DatabaseBackend::Postgres || backend == DatabaseBackend::Sqlite {
            assert!(
                result.is_ok(),
                "a missing column should not be reported for this backend"
            );
            return;
        }

        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Column 'email' not found"));
    }

    /// Test with index - should recreate index after table rebuild (SQLite)
    #[rstest]
    #[case::postgres_with_index(DatabaseBackend::Postgres)]
    #[case::mysql_with_index(DatabaseBackend::MySql)]
    #[case::sqlite_with_index(DatabaseBackend::Sqlite)]
    fn test_modify_nullable_with_index(#[case] backend: DatabaseBackend) {
        let schema = users_schema(
            true,
            vec![TableConstraint::Index {
                name: Some("idx_email".into()),
                columns: vec!["email".into()],
            }],
        );

        let result = build_modify_column_nullable(
            &backend,
            "users",
            "email",
            false,
            None,
            false,
            &schema,
            &[],
        );
        let sql = render_sql(result, backend);

        // SQLite should recreate the index after table rebuild
        if backend == DatabaseBackend::Sqlite {
            assert!(sql.contains("CREATE INDEX"));
            assert!(sql.contains("idx_email"));
        }

        let suffix = format!("{}_with_index", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Test fill_with containing NOW() should be converted to CURRENT_TIMESTAMP for all backends
    #[rstest]
    #[case::postgres_fill_now(DatabaseBackend::Postgres)]
    #[case::mysql_fill_now(DatabaseBackend::MySql)]
    #[case::sqlite_fill_now(DatabaseBackend::Sqlite)]
    fn test_fill_with_now_converted_to_current_timestamp(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "orders",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "paid_at",
                    ColumnType::Simple(SimpleColumnType::Timestamptz),
                    true,
                ),
            ],
            vec![],
        )];

        let result = build_modify_column_nullable(
            &backend,
            "orders",
            "paid_at",
            false,
            Some("NOW()"),
            false,
            &schema,
            &[],
        );
        let sql = render_sql(result, backend);

        // NOW() should be converted to CURRENT_TIMESTAMP for all backends
        assert!(
            !sql.contains("NOW()"),
            "SQL should not contain NOW(), got: {}",
            sql
        );
        assert!(
            sql.contains("CURRENT_TIMESTAMP"),
            "SQL should contain CURRENT_TIMESTAMP, got: {}",
            sql
        );

        let suffix = format!("{}_fill_now", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Test with default value - should preserve default in MODIFY COLUMN (MySQL)
    #[rstest]
    #[case::postgres_with_default(DatabaseBackend::Postgres)]
    #[case::mysql_with_default(DatabaseBackend::MySql)]
    #[case::sqlite_with_default(DatabaseBackend::Sqlite)]
    fn test_with_default_value(#[case] backend: DatabaseBackend) {
        let mut email_col = col("email", ColumnType::Simple(SimpleColumnType::Text), true);
        email_col.default = Some("'default@example.com'".into());

        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                email_col,
            ],
            vec![],
        )];

        let result = build_modify_column_nullable(
            &backend,
            "users",
            "email",
            false,
            None,
            false,
            &schema,
            &[],
        );
        let sql = render_sql(result, backend);

        // MySQL and SQLite should include DEFAULT clause
        if backend == DatabaseBackend::MySql || backend == DatabaseBackend::Sqlite {
            assert!(sql.contains("DEFAULT"));
        }

        let suffix = format!("{}_with_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// `orders` schema (`id`, `user_id`) with the given `user_id` nullability.
    fn orders_schema(user_id_nullable: bool) -> Vec<TableDef> {
        vec![table_def(
            "orders",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "user_id",
                    ColumnType::Simple(SimpleColumnType::Integer),
                    user_id_nullable,
                ),
            ],
            vec![],
        )]
    }

    /// Test delete_null_rows generates DELETE instead of UPDATE
    #[rstest]
    #[case::postgres_delete_null_rows(DatabaseBackend::Postgres)]
    #[case::mysql_delete_null_rows(DatabaseBackend::MySql)]
    #[case::sqlite_delete_null_rows(DatabaseBackend::Sqlite)]
    fn test_delete_null_rows(#[case] backend: DatabaseBackend) {
        let schema = orders_schema(true);

        let result = build_modify_column_nullable(
            &backend,
            "orders",
            "user_id",
            false,
            None,
            true,
            &schema,
            &[],
        );
        let sql = render_sql(result, backend);

        assert!(
            sql.contains("DELETE FROM"),
            "Expected DELETE FROM in SQL, got: {}",
            sql
        );
        assert!(
            sql.contains("IS NULL"),
            "Expected IS NULL in SQL, got: {}",
            sql
        );
        assert!(
            !sql.contains("UPDATE"),
            "Should NOT contain UPDATE, got: {}",
            sql
        );

        let suffix = format!("{}_delete_null_rows", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Test delete_null_rows=true with nullable=true does nothing special
    #[rstest]
    #[case::postgres_delete_null_rows_nullable(DatabaseBackend::Postgres)]
    fn test_delete_null_rows_with_nullable_true(#[case] backend: DatabaseBackend) {
        let schema = orders_schema(false);

        let result = build_modify_column_nullable(
            &backend,
            "orders",
            "user_id",
            true,
            None,
            true,
            &schema,
            &[],
        );
        let sql = render_sql(result, backend);

        assert!(
            !sql.contains("DELETE FROM"),
            "Should NOT contain DELETE when nullable=true, got: {}",
            sql
        );
    }
}