vespertide_query/sql/
modify_column_nullable.rs

1use sea_query::{Alias, Query, Table};
2
3use vespertide_core::{ColumnDef, TableDef};
4
5use super::create_table::build_create_table_for_backend;
6use super::helpers::{build_sea_column_def_with_table, normalize_fill_with};
7use super::rename_table::build_rename_table;
8use super::types::{BuiltQuery, DatabaseBackend, RawSql};
9use crate::error::QueryError;
10
11/// Build SQL for changing column nullability.
12/// For nullable -> non-nullable transitions, fill_with should be provided to update NULL values.
13pub fn build_modify_column_nullable(
14    backend: &DatabaseBackend,
15    table: &str,
16    column: &str,
17    nullable: bool,
18    fill_with: Option<&str>,
19    current_schema: &[TableDef],
20) -> Result<Vec<BuiltQuery>, QueryError> {
21    let mut queries = Vec::new();
22
23    // If changing to NOT NULL, first update existing NULL values if fill_with is provided
24    if !nullable && let Some(fill_value) = normalize_fill_with(fill_with) {
25        let update_sql = match backend {
26            DatabaseBackend::Postgres | DatabaseBackend::Sqlite => format!(
27                "UPDATE \"{}\" SET \"{}\" = {} WHERE \"{}\" IS NULL",
28                table, column, fill_value, column
29            ),
30            DatabaseBackend::MySql => format!(
31                "UPDATE `{}` SET `{}` = {} WHERE `{}` IS NULL",
32                table, column, fill_value, column
33            ),
34        };
35        queries.push(BuiltQuery::Raw(RawSql::uniform(update_sql)));
36    }
37
38    // Generate ALTER TABLE statement based on backend
39    match backend {
40        DatabaseBackend::Postgres => {
41            let alter_sql = if nullable {
42                format!(
43                    "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" DROP NOT NULL",
44                    table, column
45                )
46            } else {
47                format!(
48                    "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" SET NOT NULL",
49                    table, column
50                )
51            };
52            queries.push(BuiltQuery::Raw(RawSql::uniform(alter_sql)));
53        }
54        DatabaseBackend::MySql => {
55            // MySQL requires the full column definition in MODIFY COLUMN
56            // We need to get the column type from current schema
57            let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::Other(format!("Table '{}' not found in current schema. MySQL requires current schema information to modify column nullability.", table)))?;
58
59            let column_def = table_def.columns.iter().find(|c| c.name == column).ok_or_else(|| QueryError::Other(format!("Column '{}' not found in table '{}'. MySQL requires column information to modify nullability.", column, table)))?;
60
61            // Create a modified column def with the new nullability
62            let modified_col_def = ColumnDef {
63                nullable,
64                ..column_def.clone()
65            };
66
67            // Build sea-query ColumnDef with all properties (type, nullable, default)
68            let sea_col = build_sea_column_def_with_table(backend, table, &modified_col_def);
69
70            let stmt = Table::alter()
71                .table(Alias::new(table))
72                .modify_column(sea_col)
73                .to_owned();
74            queries.push(BuiltQuery::AlterTable(Box::new(stmt)));
75        }
76        DatabaseBackend::Sqlite => {
77            // SQLite doesn't support ALTER COLUMN for nullability changes
78            // Use temporary table approach
79            let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::Other(format!("Table '{}' not found in current schema. SQLite requires current schema information to modify column nullability.", table)))?;
80
81            // Create modified columns with the new nullability
82            let mut new_columns = table_def.columns.clone();
83            if let Some(col) = new_columns.iter_mut().find(|c| c.name == column) {
84                col.nullable = nullable;
85            }
86
87            // Generate temporary table name
88            let temp_table = format!("{}_temp", table);
89
90            // 1. Create temporary table with modified column
91            let create_temp_table = build_create_table_for_backend(
92                backend,
93                &temp_table,
94                &new_columns,
95                &table_def.constraints,
96            );
97            queries.push(BuiltQuery::CreateTable(Box::new(create_temp_table)));
98
99            // 2. Copy data (all columns)
100            let column_aliases: Vec<Alias> = table_def
101                .columns
102                .iter()
103                .map(|c| Alias::new(&c.name))
104                .collect();
105            let mut select_query = Query::select();
106            for col_alias in &column_aliases {
107                select_query = select_query.column(col_alias.clone()).to_owned();
108            }
109            select_query = select_query.from(Alias::new(table)).to_owned();
110
111            let insert_stmt = Query::insert()
112                .into_table(Alias::new(&temp_table))
113                .columns(column_aliases.clone())
114                .select_from(select_query)
115                .unwrap()
116                .to_owned();
117            queries.push(BuiltQuery::Insert(Box::new(insert_stmt)));
118
119            // 3. Drop original table
120            let drop_table = Table::drop().table(Alias::new(table)).to_owned();
121            queries.push(BuiltQuery::DropTable(Box::new(drop_table)));
122
123            // 4. Rename temporary table to original name
124            queries.push(build_rename_table(&temp_table, table));
125
126            // 5. Recreate indexes from Index constraints
127            for constraint in &table_def.constraints {
128                if let vespertide_core::TableConstraint::Index {
129                    name: idx_name,
130                    columns: idx_cols,
131                } = constraint
132                {
133                    let index_name =
134                        vespertide_naming::build_index_name(table, idx_cols, idx_name.as_deref());
135                    let mut idx_stmt = sea_query::Index::create();
136                    idx_stmt = idx_stmt.name(&index_name).to_owned();
137                    for col_name in idx_cols {
138                        idx_stmt = idx_stmt.col(Alias::new(col_name)).to_owned();
139                    }
140                    idx_stmt = idx_stmt.table(Alias::new(table)).to_owned();
141                    queries.push(BuiltQuery::CreateIndex(Box::new(idx_stmt)));
142                }
143            }
144        }
145    }
146
147    Ok(queries)
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{ColumnDef, ColumnType, SimpleColumnType, TableConstraint};

    /// Column fixture: only name, type and nullability are set; every optional
    /// property is left as `None`.
    fn col(name: &str, ty: ColumnType, nullable: bool) -> ColumnDef {
        ColumnDef {
            name: name.to_owned(),
            r#type: ty,
            nullable,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    /// Table fixture without a description.
    fn table_def(
        name: &str,
        columns: Vec<ColumnDef>,
        constraints: Vec<TableConstraint>,
    ) -> TableDef {
        TableDef {
            name: name.to_owned(),
            description: None,
            columns,
            constraints,
        }
    }

    /// Non-nullable integer `id` column shared by the fixtures below.
    fn id_col() -> ColumnDef {
        col("id", ColumnType::Simple(SimpleColumnType::Integer), false)
    }

    /// Lower-case backend tag used as the leading part of snapshot suffixes.
    fn backend_label(backend: DatabaseBackend) -> &'static str {
        match backend {
            DatabaseBackend::Postgres => "postgres",
            DatabaseBackend::MySql => "mysql",
            DatabaseBackend::Sqlite => "sqlite",
        }
    }

    /// Renders every query for `backend`, one statement per line.
    fn render_sql(queries: &[BuiltQuery], backend: DatabaseBackend) -> String {
        let statements: Vec<String> = queries.iter().map(|q| q.build(backend)).collect();
        statements.join("\n")
    }

    /// Snapshot the generated SQL for both directions of the change, with and
    /// without a fill value, on every backend.
    #[rstest]
    #[case::postgres_set_not_null(DatabaseBackend::Postgres, false, None)]
    #[case::postgres_drop_not_null(DatabaseBackend::Postgres, true, None)]
    #[case::postgres_set_not_null_with_fill(DatabaseBackend::Postgres, false, Some("'unknown'"))]
    #[case::mysql_set_not_null(DatabaseBackend::MySql, false, None)]
    #[case::mysql_drop_not_null(DatabaseBackend::MySql, true, None)]
    #[case::mysql_set_not_null_with_fill(DatabaseBackend::MySql, false, Some("'unknown'"))]
    #[case::sqlite_set_not_null(DatabaseBackend::Sqlite, false, None)]
    #[case::sqlite_drop_not_null(DatabaseBackend::Sqlite, true, None)]
    #[case::sqlite_set_not_null_with_fill(DatabaseBackend::Sqlite, false, Some("'unknown'"))]
    fn test_build_modify_column_nullable(
        #[case] backend: DatabaseBackend,
        #[case] nullable: bool,
        #[case] fill_with: Option<&str>,
    ) {
        // The column starts in the opposite state of the requested one.
        let email = col(
            "email",
            ColumnType::Simple(SimpleColumnType::Text),
            !nullable,
        );
        let schema = vec![table_def("users", vec![id_col(), email], vec![])];

        let queries =
            build_modify_column_nullable(&backend, "users", "email", nullable, fill_with, &schema)
                .expect("building the nullability change should succeed");
        let sql = render_sql(&queries, backend);

        let direction = if nullable { "nullable" } else { "not_null" };
        let fill_tag = if fill_with.is_some() {
            "_with_fill"
        } else {
            ""
        };
        let suffix = format!(
            "{}_{}_users{}",
            backend_label(backend),
            direction,
            fill_tag
        );

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// An empty schema must produce a "table not found" error.
    #[rstest]
    #[case::postgres_table_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_table_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_table_not_found(DatabaseBackend::Sqlite)]
    fn test_table_not_found(#[case] backend: DatabaseBackend) {
        // Postgres never looks at the schema for nullability changes, so there
        // is nothing to assert for it.
        if matches!(backend, DatabaseBackend::Postgres) {
            return;
        }

        let err = build_modify_column_nullable(&backend, "users", "email", false, None, &[])
            .expect_err("a missing table should be rejected");
        assert!(err.to_string().contains("Table 'users' not found"));
    }

    /// A schema lacking the target column must produce a "column not found" error.
    #[rstest]
    #[case::postgres_column_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_column_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_column_not_found(DatabaseBackend::Sqlite)]
    fn test_column_not_found(#[case] backend: DatabaseBackend) {
        // Postgres never looks at the schema, and this test makes no assertion
        // about SQLite's handling of a missing column; only MySQL is checked.
        if matches!(
            backend,
            DatabaseBackend::Postgres | DatabaseBackend::Sqlite
        ) {
            return;
        }

        let schema = vec![table_def("users", vec![id_col()], vec![])];

        let err = build_modify_column_nullable(&backend, "users", "email", false, None, &schema)
            .expect_err("a missing column should be rejected");
        assert!(err.to_string().contains("Column 'email' not found"));
    }

    /// A table with an index constraint: SQLite must recreate the index after
    /// rebuilding the table.
    #[rstest]
    #[case::postgres_with_index(DatabaseBackend::Postgres)]
    #[case::mysql_with_index(DatabaseBackend::MySql)]
    #[case::sqlite_with_index(DatabaseBackend::Sqlite)]
    fn test_modify_nullable_with_index(#[case] backend: DatabaseBackend) {
        let email = col("email", ColumnType::Simple(SimpleColumnType::Text), true);
        let email_index = TableConstraint::Index {
            name: Some("idx_email".into()),
            columns: vec!["email".into()],
        };
        let schema = vec![table_def(
            "users",
            vec![id_col(), email],
            vec![email_index],
        )];

        let queries =
            build_modify_column_nullable(&backend, "users", "email", false, None, &schema)
                .expect("building the nullability change should succeed");
        let sql = render_sql(&queries, backend);

        // Only the SQLite path rebuilds the table and therefore re-issues the index.
        if matches!(backend, DatabaseBackend::Sqlite) {
            assert!(sql.contains("CREATE INDEX"));
            assert!(sql.contains("idx_email"));
        }

        let suffix = format!("{}_with_index", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// A column with a default value: the default must survive in the MySQL
    /// column definition and in the SQLite rebuilt table.
    #[rstest]
    #[case::postgres_with_default(DatabaseBackend::Postgres)]
    #[case::mysql_with_default(DatabaseBackend::MySql)]
    #[case::sqlite_with_default(DatabaseBackend::Sqlite)]
    fn test_with_default_value(#[case] backend: DatabaseBackend) {
        let email_col = ColumnDef {
            default: Some("'default@example.com'".into()),
            ..col("email", ColumnType::Simple(SimpleColumnType::Text), true)
        };
        let schema = vec![table_def("users", vec![id_col(), email_col], vec![])];

        let queries =
            build_modify_column_nullable(&backend, "users", "email", false, None, &schema)
                .expect("building the nullability change should succeed");
        let sql = render_sql(&queries, backend);

        // MySQL and SQLite both emit a full column definition, so DEFAULT must appear.
        if matches!(backend, DatabaseBackend::MySql | DatabaseBackend::Sqlite) {
            assert!(sql.contains("DEFAULT"));
        }

        let suffix = format!("{}_with_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }
}