Skip to main content

vespertide_query/sql/
modify_column_nullable.rs

1use sea_query::{Alias, Query, Table};
2
3use vespertide_core::{ColumnDef, TableDef};
4
5use super::helpers::{
6    build_sea_column_def_with_table, build_sqlite_temp_table_create, normalize_fill_with,
7    recreate_indexes_after_rebuild,
8};
9use super::rename_table::build_rename_table;
10use super::types::{BuiltQuery, DatabaseBackend, RawSql};
11use crate::error::QueryError;
12
13/// Build SQL for changing column nullability.
14/// For nullable -> non-nullable transitions, fill_with should be provided to update NULL values.
15pub fn build_modify_column_nullable(
16    backend: &DatabaseBackend,
17    table: &str,
18    column: &str,
19    nullable: bool,
20    fill_with: Option<&str>,
21    current_schema: &[TableDef],
22) -> Result<Vec<BuiltQuery>, QueryError> {
23    let mut queries = Vec::new();
24
25    // If changing to NOT NULL, first update existing NULL values if fill_with is provided
26    if !nullable && let Some(fill_value) = normalize_fill_with(fill_with) {
27        let update_sql = match backend {
28            DatabaseBackend::Postgres | DatabaseBackend::Sqlite => format!(
29                "UPDATE \"{}\" SET \"{}\" = {} WHERE \"{}\" IS NULL",
30                table, column, fill_value, column
31            ),
32            DatabaseBackend::MySql => format!(
33                "UPDATE `{}` SET `{}` = {} WHERE `{}` IS NULL",
34                table, column, fill_value, column
35            ),
36        };
37        queries.push(BuiltQuery::Raw(RawSql::uniform(update_sql)));
38    }
39
40    // Generate ALTER TABLE statement based on backend
41    match backend {
42        DatabaseBackend::Postgres => {
43            let alter_sql = if nullable {
44                format!(
45                    "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" DROP NOT NULL",
46                    table, column
47                )
48            } else {
49                format!(
50                    "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" SET NOT NULL",
51                    table, column
52                )
53            };
54            queries.push(BuiltQuery::Raw(RawSql::uniform(alter_sql)));
55        }
56        DatabaseBackend::MySql => {
57            // MySQL requires the full column definition in MODIFY COLUMN
58            // We need to get the column type from current schema
59            let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::Other(format!("Table '{}' not found in current schema. MySQL requires current schema information to modify column nullability.", table)))?;
60
61            let column_def = table_def.columns.iter().find(|c| c.name == column).ok_or_else(|| QueryError::Other(format!("Column '{}' not found in table '{}'. MySQL requires column information to modify nullability.", column, table)))?;
62
63            // Create a modified column def with the new nullability
64            let modified_col_def = ColumnDef {
65                nullable,
66                ..column_def.clone()
67            };
68
69            // Build sea-query ColumnDef with all properties (type, nullable, default)
70            let sea_col = build_sea_column_def_with_table(backend, table, &modified_col_def);
71
72            let stmt = Table::alter()
73                .table(Alias::new(table))
74                .modify_column(sea_col)
75                .to_owned();
76            queries.push(BuiltQuery::AlterTable(Box::new(stmt)));
77        }
78        DatabaseBackend::Sqlite => {
79            // SQLite doesn't support ALTER COLUMN for nullability changes
80            // Use temporary table approach
81            let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::Other(format!("Table '{}' not found in current schema. SQLite requires current schema information to modify column nullability.", table)))?;
82
83            // Create modified columns with the new nullability
84            let mut new_columns = table_def.columns.clone();
85            if let Some(col) = new_columns.iter_mut().find(|c| c.name == column) {
86                col.nullable = nullable;
87            }
88
89            // Generate temporary table name
90            let temp_table = format!("{}_temp", table);
91
92            // 1. Create temporary table with modified column + CHECK constraints
93            let create_query = build_sqlite_temp_table_create(
94                backend,
95                &temp_table,
96                table,
97                &new_columns,
98                &table_def.constraints,
99            );
100            queries.push(create_query);
101
102            // 2. Copy data (all columns)
103            let column_aliases: Vec<Alias> = table_def
104                .columns
105                .iter()
106                .map(|c| Alias::new(&c.name))
107                .collect();
108            let mut select_query = Query::select();
109            for col_alias in &column_aliases {
110                select_query = select_query.column(col_alias.clone()).to_owned();
111            }
112            select_query = select_query.from(Alias::new(table)).to_owned();
113
114            let insert_stmt = Query::insert()
115                .into_table(Alias::new(&temp_table))
116                .columns(column_aliases.clone())
117                .select_from(select_query)
118                .unwrap()
119                .to_owned();
120            queries.push(BuiltQuery::Insert(Box::new(insert_stmt)));
121
122            // 3. Drop original table
123            let drop_table = Table::drop().table(Alias::new(table)).to_owned();
124            queries.push(BuiltQuery::DropTable(Box::new(drop_table)));
125
126            // 4. Rename temporary table to original name
127            queries.push(build_rename_table(&temp_table, table));
128
129            // 5. Recreate indexes (both regular and UNIQUE)
130            queries.extend(recreate_indexes_after_rebuild(
131                table,
132                &table_def.constraints,
133                &[],
134            ));
135        }
136    }
137
138    Ok(queries)
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{ColumnDef, ColumnType, SimpleColumnType, TableConstraint};

    /// Build a column with the given name, type and nullability; every other
    /// attribute is left unset.
    fn col(name: &str, ty: ColumnType, nullable: bool) -> ColumnDef {
        ColumnDef {
            name: name.to_string(),
            r#type: ty,
            nullable,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    /// Build a table definition without a description.
    fn table_def(
        name: &str,
        columns: Vec<ColumnDef>,
        constraints: Vec<TableConstraint>,
    ) -> TableDef {
        TableDef {
            name: name.to_string(),
            description: None,
            columns,
            constraints,
        }
    }

    /// Lower-case backend label used in snapshot suffixes.
    fn backend_label(backend: DatabaseBackend) -> &'static str {
        match backend {
            DatabaseBackend::Postgres => "postgres",
            DatabaseBackend::MySql => "mysql",
            DatabaseBackend::Sqlite => "sqlite",
        }
    }

    /// Render every query for `backend`, one statement per line.
    fn render_sql(queries: &[BuiltQuery], backend: DatabaseBackend) -> String {
        queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join("\n")
    }

    /// Snapshot the generated SQL for each backend / direction / fill combination.
    #[rstest]
    #[case::postgres_set_not_null(DatabaseBackend::Postgres, false, None)]
    #[case::postgres_drop_not_null(DatabaseBackend::Postgres, true, None)]
    #[case::postgres_set_not_null_with_fill(DatabaseBackend::Postgres, false, Some("'unknown'"))]
    #[case::mysql_set_not_null(DatabaseBackend::MySql, false, None)]
    #[case::mysql_drop_not_null(DatabaseBackend::MySql, true, None)]
    #[case::mysql_set_not_null_with_fill(DatabaseBackend::MySql, false, Some("'unknown'"))]
    #[case::sqlite_set_not_null(DatabaseBackend::Sqlite, false, None)]
    #[case::sqlite_drop_not_null(DatabaseBackend::Sqlite, true, None)]
    #[case::sqlite_set_not_null_with_fill(DatabaseBackend::Sqlite, false, Some("'unknown'"))]
    fn test_build_modify_column_nullable(
        #[case] backend: DatabaseBackend,
        #[case] nullable: bool,
        #[case] fill_with: Option<&str>,
    ) {
        // The column starts with the opposite nullability so the change is a real transition.
        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "email",
                    ColumnType::Simple(SimpleColumnType::Text),
                    !nullable,
                ),
            ],
            vec![],
        )];

        let result =
            build_modify_column_nullable(&backend, "users", "email", nullable, fill_with, &schema);
        assert!(result.is_ok());
        let sql = render_sql(&result.unwrap(), backend);

        let suffix = format!(
            "{}_{}_users{}",
            backend_label(backend),
            if nullable { "nullable" } else { "not_null" },
            if fill_with.is_some() {
                "_with_fill"
            } else {
                ""
            }
        );

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// A missing table is an error for MySQL and SQLite; Postgres needs no schema
    /// lookup and still succeeds.
    #[rstest]
    #[case::postgres_table_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_table_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_table_not_found(DatabaseBackend::Sqlite)]
    fn test_table_not_found(#[case] backend: DatabaseBackend) {
        let result = build_modify_column_nullable(&backend, "users", "email", false, None, &[]);

        match backend {
            // Postgres doesn't need schema lookup for nullability changes
            DatabaseBackend::Postgres => assert!(result.is_ok()),
            DatabaseBackend::MySql | DatabaseBackend::Sqlite => {
                assert!(result.is_err());
                let err_msg = result.unwrap_err().to_string();
                assert!(err_msg.contains("Table 'users' not found"));
            }
        }
    }

    /// A missing column is an error for MySQL only; Postgres and SQLite do not
    /// validate column existence and still succeed.
    #[rstest]
    #[case::postgres_column_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_column_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_column_not_found(DatabaseBackend::Sqlite)]
    fn test_column_not_found(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "users",
            vec![col(
                "id",
                ColumnType::Simple(SimpleColumnType::Integer),
                false,
            )],
            vec![],
        )];

        let result = build_modify_column_nullable(&backend, "users", "email", false, None, &schema);

        match backend {
            DatabaseBackend::MySql => {
                assert!(result.is_err());
                let err_msg = result.unwrap_err().to_string();
                assert!(err_msg.contains("Column 'email' not found"));
            }
            // Postgres doesn't need schema lookup for nullability changes
            // SQLite doesn't validate column existence in modify_column_nullable
            DatabaseBackend::Postgres | DatabaseBackend::Sqlite => assert!(result.is_ok()),
        }
    }

    /// Test with index - should recreate index after table rebuild (SQLite)
    #[rstest]
    #[case::postgres_with_index(DatabaseBackend::Postgres)]
    #[case::mysql_with_index(DatabaseBackend::MySql)]
    #[case::sqlite_with_index(DatabaseBackend::Sqlite)]
    fn test_modify_nullable_with_index(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col("email", ColumnType::Simple(SimpleColumnType::Text), true),
            ],
            vec![TableConstraint::Index {
                name: Some("idx_email".into()),
                columns: vec!["email".into()],
            }],
        )];

        let result = build_modify_column_nullable(&backend, "users", "email", false, None, &schema);
        assert!(result.is_ok());
        let sql = render_sql(&result.unwrap(), backend);

        // SQLite should recreate the index after table rebuild
        if backend == DatabaseBackend::Sqlite {
            assert!(sql.contains("CREATE INDEX"));
            assert!(sql.contains("idx_email"));
        }

        let suffix = format!("{}_with_index", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Test with default value - should preserve default in MODIFY COLUMN (MySQL)
    #[rstest]
    #[case::postgres_with_default(DatabaseBackend::Postgres)]
    #[case::mysql_with_default(DatabaseBackend::MySql)]
    #[case::sqlite_with_default(DatabaseBackend::Sqlite)]
    fn test_with_default_value(#[case] backend: DatabaseBackend) {
        let mut email_col = col("email", ColumnType::Simple(SimpleColumnType::Text), true);
        email_col.default = Some("'default@example.com'".into());

        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                email_col,
            ],
            vec![],
        )];

        let result = build_modify_column_nullable(&backend, "users", "email", false, None, &schema);
        assert!(result.is_ok());
        let sql = render_sql(&result.unwrap(), backend);

        // MySQL and SQLite should include DEFAULT clause
        if backend == DatabaseBackend::MySql || backend == DatabaseBackend::Sqlite {
            assert!(sql.contains("DEFAULT"));
        }

        let suffix = format!("{}_with_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }
}