Skip to main content

vespertide_query/sql/
modify_column_nullable.rs

1use sea_query::{Alias, Query, Table};
2
3use vespertide_core::{ColumnDef, TableDef};
4
5use super::helpers::{
6    build_sea_column_def_with_table, build_sqlite_temp_table_create, convert_default_for_backend,
7    normalize_fill_with, recreate_indexes_after_rebuild,
8};
9use super::rename_table::build_rename_table;
10use super::types::{BuiltQuery, DatabaseBackend, RawSql};
11use crate::error::QueryError;
12
13/// Build SQL for changing column nullability.
14/// For nullable -> non-nullable transitions, fill_with should be provided to update NULL values.
15pub fn build_modify_column_nullable(
16    backend: &DatabaseBackend,
17    table: &str,
18    column: &str,
19    nullable: bool,
20    fill_with: Option<&str>,
21    current_schema: &[TableDef],
22) -> Result<Vec<BuiltQuery>, QueryError> {
23    let mut queries = Vec::new();
24
25    // If changing to NOT NULL, first update existing NULL values if fill_with is provided
26    if !nullable && let Some(fill_value) = normalize_fill_with(fill_with) {
27        let fill_value = convert_default_for_backend(&fill_value, backend);
28        let update_sql = match backend {
29            DatabaseBackend::Postgres | DatabaseBackend::Sqlite => format!(
30                "UPDATE \"{}\" SET \"{}\" = {} WHERE \"{}\" IS NULL",
31                table, column, fill_value, column
32            ),
33            DatabaseBackend::MySql => format!(
34                "UPDATE `{}` SET `{}` = {} WHERE `{}` IS NULL",
35                table, column, fill_value, column
36            ),
37        };
38        queries.push(BuiltQuery::Raw(RawSql::uniform(update_sql)));
39    }
40
41    // Generate ALTER TABLE statement based on backend
42    match backend {
43        DatabaseBackend::Postgres => {
44            let alter_sql = if nullable {
45                format!(
46                    "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" DROP NOT NULL",
47                    table, column
48                )
49            } else {
50                format!(
51                    "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" SET NOT NULL",
52                    table, column
53                )
54            };
55            queries.push(BuiltQuery::Raw(RawSql::uniform(alter_sql)));
56        }
57        DatabaseBackend::MySql => {
58            // MySQL requires the full column definition in MODIFY COLUMN
59            // We need to get the column type from current schema
60            let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::Other(format!("Table '{}' not found in current schema. MySQL requires current schema information to modify column nullability.", table)))?;
61
62            let column_def = table_def.columns.iter().find(|c| c.name == column).ok_or_else(|| QueryError::Other(format!("Column '{}' not found in table '{}'. MySQL requires column information to modify nullability.", column, table)))?;
63
64            // Create a modified column def with the new nullability
65            let modified_col_def = ColumnDef {
66                nullable,
67                ..column_def.clone()
68            };
69
70            // Build sea-query ColumnDef with all properties (type, nullable, default)
71            let sea_col = build_sea_column_def_with_table(backend, table, &modified_col_def);
72
73            let stmt = Table::alter()
74                .table(Alias::new(table))
75                .modify_column(sea_col)
76                .to_owned();
77            queries.push(BuiltQuery::AlterTable(Box::new(stmt)));
78        }
79        DatabaseBackend::Sqlite => {
80            // SQLite doesn't support ALTER COLUMN for nullability changes
81            // Use temporary table approach
82            let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::Other(format!("Table '{}' not found in current schema. SQLite requires current schema information to modify column nullability.", table)))?;
83
84            // Create modified columns with the new nullability
85            let mut new_columns = table_def.columns.clone();
86            if let Some(col) = new_columns.iter_mut().find(|c| c.name == column) {
87                col.nullable = nullable;
88            }
89
90            // Generate temporary table name
91            let temp_table = format!("{}_temp", table);
92
93            // 1. Create temporary table with modified column + CHECK constraints
94            let create_query = build_sqlite_temp_table_create(
95                backend,
96                &temp_table,
97                table,
98                &new_columns,
99                &table_def.constraints,
100            );
101            queries.push(create_query);
102
103            // 2. Copy data (all columns)
104            let column_aliases: Vec<Alias> = table_def
105                .columns
106                .iter()
107                .map(|c| Alias::new(&c.name))
108                .collect();
109            let mut select_query = Query::select();
110            for col_alias in &column_aliases {
111                select_query = select_query.column(col_alias.clone()).to_owned();
112            }
113            select_query = select_query.from(Alias::new(table)).to_owned();
114
115            let insert_stmt = Query::insert()
116                .into_table(Alias::new(&temp_table))
117                .columns(column_aliases.clone())
118                .select_from(select_query)
119                .unwrap()
120                .to_owned();
121            queries.push(BuiltQuery::Insert(Box::new(insert_stmt)));
122
123            // 3. Drop original table
124            let drop_table = Table::drop().table(Alias::new(table)).to_owned();
125            queries.push(BuiltQuery::DropTable(Box::new(drop_table)));
126
127            // 4. Rename temporary table to original name
128            queries.push(build_rename_table(&temp_table, table));
129
130            // 5. Recreate indexes (both regular and UNIQUE)
131            queries.extend(recreate_indexes_after_rebuild(
132                table,
133                &table_def.constraints,
134                &[],
135            ));
136        }
137    }
138
139    Ok(queries)
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{ColumnDef, ColumnType, SimpleColumnType, TableConstraint};

    /// Minimal column definition: only name, type and nullability are set.
    fn col(name: &str, ty: ColumnType, nullable: bool) -> ColumnDef {
        ColumnDef {
            name: name.to_string(),
            r#type: ty,
            nullable,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    /// Table definition without a description.
    fn table_def(
        name: &str,
        columns: Vec<ColumnDef>,
        constraints: Vec<TableConstraint>,
    ) -> TableDef {
        TableDef {
            name: name.to_string(),
            description: None,
            columns,
            constraints,
        }
    }

    /// Lowercase backend name used as the leading part of snapshot suffixes.
    fn backend_label(backend: DatabaseBackend) -> &'static str {
        match backend {
            DatabaseBackend::Postgres => "postgres",
            DatabaseBackend::MySql => "mysql",
            DatabaseBackend::Sqlite => "sqlite",
        }
    }

    /// Assert the build succeeded and render every query for `backend`, one per line.
    fn render_sql(
        result: Result<Vec<BuiltQuery>, QueryError>,
        backend: DatabaseBackend,
    ) -> String {
        assert!(result.is_ok());
        let rendered: Vec<String> = result
            .unwrap()
            .iter()
            .map(|query| query.build(backend))
            .collect();
        rendered.join("\n")
    }

    /// Snapshot the generated SQL for both directions of the change, with and without fill.
    #[rstest]
    #[case::postgres_set_not_null(DatabaseBackend::Postgres, false, None)]
    #[case::postgres_drop_not_null(DatabaseBackend::Postgres, true, None)]
    #[case::postgres_set_not_null_with_fill(DatabaseBackend::Postgres, false, Some("'unknown'"))]
    #[case::mysql_set_not_null(DatabaseBackend::MySql, false, None)]
    #[case::mysql_drop_not_null(DatabaseBackend::MySql, true, None)]
    #[case::mysql_set_not_null_with_fill(DatabaseBackend::MySql, false, Some("'unknown'"))]
    #[case::sqlite_set_not_null(DatabaseBackend::Sqlite, false, None)]
    #[case::sqlite_drop_not_null(DatabaseBackend::Sqlite, true, None)]
    #[case::sqlite_set_not_null_with_fill(DatabaseBackend::Sqlite, false, Some("'unknown'"))]
    fn test_build_modify_column_nullable(
        #[case] backend: DatabaseBackend,
        #[case] nullable: bool,
        #[case] fill_with: Option<&str>,
    ) {
        // The column starts with the opposite nullability of the requested one.
        let id_col = col("id", ColumnType::Simple(SimpleColumnType::Integer), false);
        let email_col = col(
            "email",
            ColumnType::Simple(SimpleColumnType::Text),
            !nullable,
        );
        let schema = vec![table_def("users", vec![id_col, email_col], vec![])];

        let sql = render_sql(
            build_modify_column_nullable(&backend, "users", "email", nullable, fill_with, &schema),
            backend,
        );

        let direction = if nullable { "nullable" } else { "not_null" };
        let fill_part = if fill_with.is_some() { "_with_fill" } else { "" };
        let suffix = format!("{}_{}_users{}", backend_label(backend), direction, fill_part);

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Test table not found error
    #[rstest]
    #[case::postgres_table_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_table_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_table_not_found(DatabaseBackend::Sqlite)]
    fn test_table_not_found(#[case] backend: DatabaseBackend) {
        // Postgres doesn't need schema lookup for nullability changes
        if matches!(backend, DatabaseBackend::Postgres) {
            return;
        }

        let outcome = build_modify_column_nullable(&backend, "users", "email", false, None, &[]);
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Table 'users' not found"));
    }

    /// Test column not found error
    #[rstest]
    #[case::postgres_column_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_column_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_column_not_found(DatabaseBackend::Sqlite)]
    fn test_column_not_found(#[case] backend: DatabaseBackend) {
        // Only MySQL is asserted here: Postgres doesn't need a schema lookup for
        // nullability changes, and the SQLite path is not covered by this test.
        if matches!(
            backend,
            DatabaseBackend::Postgres | DatabaseBackend::Sqlite
        ) {
            return;
        }

        let id_col = col("id", ColumnType::Simple(SimpleColumnType::Integer), false);
        let schema = vec![table_def("users", vec![id_col], vec![])];

        let outcome =
            build_modify_column_nullable(&backend, "users", "email", false, None, &schema);
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Column 'email' not found"));
    }

    /// Test with index - should recreate index after table rebuild (SQLite)
    #[rstest]
    #[case::postgres_with_index(DatabaseBackend::Postgres)]
    #[case::mysql_with_index(DatabaseBackend::MySql)]
    #[case::sqlite_with_index(DatabaseBackend::Sqlite)]
    fn test_modify_nullable_with_index(#[case] backend: DatabaseBackend) {
        let email_index = TableConstraint::Index {
            name: Some("idx_email".into()),
            columns: vec!["email".into()],
        };
        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col("email", ColumnType::Simple(SimpleColumnType::Text), true),
            ],
            vec![email_index],
        )];

        let sql = render_sql(
            build_modify_column_nullable(&backend, "users", "email", false, None, &schema),
            backend,
        );

        // SQLite should recreate the index after table rebuild
        if matches!(backend, DatabaseBackend::Sqlite) {
            assert!(sql.contains("CREATE INDEX"));
            assert!(sql.contains("idx_email"));
        }

        let suffix = format!("{}_with_index", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Test fill_with containing NOW() should be converted to CURRENT_TIMESTAMP for all backends
    #[rstest]
    #[case::postgres_fill_now(DatabaseBackend::Postgres)]
    #[case::mysql_fill_now(DatabaseBackend::MySql)]
    #[case::sqlite_fill_now(DatabaseBackend::Sqlite)]
    fn test_fill_with_now_converted_to_current_timestamp(#[case] backend: DatabaseBackend) {
        let paid_at = col(
            "paid_at",
            ColumnType::Simple(SimpleColumnType::Timestamptz),
            true,
        );
        let schema = vec![table_def(
            "orders",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                paid_at,
            ],
            vec![],
        )];

        let sql = render_sql(
            build_modify_column_nullable(
                &backend,
                "orders",
                "paid_at",
                false,
                Some("NOW()"),
                &schema,
            ),
            backend,
        );

        // NOW() should be converted to CURRENT_TIMESTAMP for all backends
        assert!(
            !sql.contains("NOW()"),
            "SQL should not contain NOW(), got: {}",
            sql
        );
        assert!(
            sql.contains("CURRENT_TIMESTAMP"),
            "SQL should contain CURRENT_TIMESTAMP, got: {}",
            sql
        );

        let suffix = format!("{}_fill_now", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Test with default value - should preserve default in MODIFY COLUMN (MySQL)
    #[rstest]
    #[case::postgres_with_default(DatabaseBackend::Postgres)]
    #[case::mysql_with_default(DatabaseBackend::MySql)]
    #[case::sqlite_with_default(DatabaseBackend::Sqlite)]
    fn test_with_default_value(#[case] backend: DatabaseBackend) {
        let email_col = ColumnDef {
            default: Some("'default@example.com'".into()),
            ..col("email", ColumnType::Simple(SimpleColumnType::Text), true)
        };

        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                email_col,
            ],
            vec![],
        )];

        let sql = render_sql(
            build_modify_column_nullable(&backend, "users", "email", false, None, &schema),
            backend,
        );

        // MySQL and SQLite should include DEFAULT clause
        if matches!(backend, DatabaseBackend::MySql | DatabaseBackend::Sqlite) {
            assert!(sql.contains("DEFAULT"));
        }

        let suffix = format!("{}_with_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }
}