1use sea_query::{Alias, Query, Table};
2
3use vespertide_core::{ColumnDef, TableDef};
4
5use super::helpers::{
6 build_sea_column_def_with_table, build_sqlite_temp_table_create, normalize_fill_with,
7 recreate_indexes_after_rebuild,
8};
9use super::rename_table::build_rename_table;
10use super::types::{BuiltQuery, DatabaseBackend, RawSql};
11use crate::error::QueryError;
12
13pub fn build_modify_column_nullable(
16 backend: &DatabaseBackend,
17 table: &str,
18 column: &str,
19 nullable: bool,
20 fill_with: Option<&str>,
21 current_schema: &[TableDef],
22) -> Result<Vec<BuiltQuery>, QueryError> {
23 let mut queries = Vec::new();
24
25 if !nullable && let Some(fill_value) = normalize_fill_with(fill_with) {
27 let update_sql = match backend {
28 DatabaseBackend::Postgres | DatabaseBackend::Sqlite => format!(
29 "UPDATE \"{}\" SET \"{}\" = {} WHERE \"{}\" IS NULL",
30 table, column, fill_value, column
31 ),
32 DatabaseBackend::MySql => format!(
33 "UPDATE `{}` SET `{}` = {} WHERE `{}` IS NULL",
34 table, column, fill_value, column
35 ),
36 };
37 queries.push(BuiltQuery::Raw(RawSql::uniform(update_sql)));
38 }
39
40 match backend {
42 DatabaseBackend::Postgres => {
43 let alter_sql = if nullable {
44 format!(
45 "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" DROP NOT NULL",
46 table, column
47 )
48 } else {
49 format!(
50 "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" SET NOT NULL",
51 table, column
52 )
53 };
54 queries.push(BuiltQuery::Raw(RawSql::uniform(alter_sql)));
55 }
56 DatabaseBackend::MySql => {
57 let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::Other(format!("Table '{}' not found in current schema. MySQL requires current schema information to modify column nullability.", table)))?;
60
61 let column_def = table_def.columns.iter().find(|c| c.name == column).ok_or_else(|| QueryError::Other(format!("Column '{}' not found in table '{}'. MySQL requires column information to modify nullability.", column, table)))?;
62
63 let modified_col_def = ColumnDef {
65 nullable,
66 ..column_def.clone()
67 };
68
69 let sea_col = build_sea_column_def_with_table(backend, table, &modified_col_def);
71
72 let stmt = Table::alter()
73 .table(Alias::new(table))
74 .modify_column(sea_col)
75 .to_owned();
76 queries.push(BuiltQuery::AlterTable(Box::new(stmt)));
77 }
78 DatabaseBackend::Sqlite => {
79 let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::Other(format!("Table '{}' not found in current schema. SQLite requires current schema information to modify column nullability.", table)))?;
82
83 let mut new_columns = table_def.columns.clone();
85 if let Some(col) = new_columns.iter_mut().find(|c| c.name == column) {
86 col.nullable = nullable;
87 }
88
89 let temp_table = format!("{}_temp", table);
91
92 let create_query = build_sqlite_temp_table_create(
94 backend,
95 &temp_table,
96 table,
97 &new_columns,
98 &table_def.constraints,
99 );
100 queries.push(create_query);
101
102 let column_aliases: Vec<Alias> = table_def
104 .columns
105 .iter()
106 .map(|c| Alias::new(&c.name))
107 .collect();
108 let mut select_query = Query::select();
109 for col_alias in &column_aliases {
110 select_query = select_query.column(col_alias.clone()).to_owned();
111 }
112 select_query = select_query.from(Alias::new(table)).to_owned();
113
114 let insert_stmt = Query::insert()
115 .into_table(Alias::new(&temp_table))
116 .columns(column_aliases.clone())
117 .select_from(select_query)
118 .unwrap()
119 .to_owned();
120 queries.push(BuiltQuery::Insert(Box::new(insert_stmt)));
121
122 let drop_table = Table::drop().table(Alias::new(table)).to_owned();
124 queries.push(BuiltQuery::DropTable(Box::new(drop_table)));
125
126 queries.push(build_rename_table(&temp_table, table));
128
129 queries.extend(recreate_indexes_after_rebuild(
131 table,
132 &table_def.constraints,
133 &[],
134 ));
135 }
136 }
137
138 Ok(queries)
139}
140
// Tests for `build_modify_column_nullable`. SQL output is pinned with insta
// snapshots. Snapshot file names depend on the test function names and the
// `snapshot_suffix` strings below, so renaming either requires re-accepting
// the snapshots.
#[cfg(test)]
mod tests {
    use super::*;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{ColumnDef, ColumnType, SimpleColumnType, TableConstraint};

    /// Builds a plain column definition with every optional attribute
    /// (default, comment, primary key, unique, index, foreign key) unset.
    fn col(name: &str, ty: ColumnType, nullable: bool) -> ColumnDef {
        ColumnDef {
            name: name.to_string(),
            r#type: ty,
            nullable,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    /// Builds a table definition with the given columns and constraints and no
    /// description.
    fn table_def(
        name: &str,
        columns: Vec<ColumnDef>,
        constraints: Vec<TableConstraint>,
    ) -> TableDef {
        TableDef {
            name: name.to_string(),
            description: None,
            columns,
            constraints,
        }
    }

    // Every backend in both directions (SET / DROP NOT NULL). There is also a
    // NOT NULL case with a fill value, which is expected to emit a backfill
    // UPDATE before the nullability change.
    #[rstest]
    #[case::postgres_set_not_null(DatabaseBackend::Postgres, false, None)]
    #[case::postgres_drop_not_null(DatabaseBackend::Postgres, true, None)]
    #[case::postgres_set_not_null_with_fill(DatabaseBackend::Postgres, false, Some("'unknown'"))]
    #[case::mysql_set_not_null(DatabaseBackend::MySql, false, None)]
    #[case::mysql_drop_not_null(DatabaseBackend::MySql, true, None)]
    #[case::mysql_set_not_null_with_fill(DatabaseBackend::MySql, false, Some("'unknown'"))]
    #[case::sqlite_set_not_null(DatabaseBackend::Sqlite, false, None)]
    #[case::sqlite_drop_not_null(DatabaseBackend::Sqlite, true, None)]
    #[case::sqlite_set_not_null_with_fill(DatabaseBackend::Sqlite, false, Some("'unknown'"))]
    fn test_build_modify_column_nullable(
        #[case] backend: DatabaseBackend,
        #[case] nullable: bool,
        #[case] fill_with: Option<&str>,
    ) {
        // `email` starts with the opposite nullability of the target, so each
        // case represents an actual change.
        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "email",
                    ColumnType::Simple(SimpleColumnType::Text),
                    !nullable,
                ),
            ],
            vec![],
        )];

        let result =
            build_modify_column_nullable(&backend, "users", "email", nullable, fill_with, &schema);
        assert!(result.is_ok());
        let queries = result.unwrap();
        // Render every generated query for the target backend, one per line.
        let sql = queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join("\n");

        // The suffix encodes backend, direction and fill presence, giving each
        // case its own snapshot.
        let suffix = format!(
            "{}_{}_users{}",
            match backend {
                DatabaseBackend::Postgres => "postgres",
                DatabaseBackend::MySql => "mysql",
                DatabaseBackend::Sqlite => "sqlite",
            },
            if nullable { "nullable" } else { "not_null" },
            if fill_with.is_some() {
                "_with_fill"
            } else {
                ""
            }
        );

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_table_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_table_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_table_not_found(DatabaseBackend::Sqlite)]
    fn test_table_not_found(#[case] backend: DatabaseBackend) {
        // The Postgres branch never reads `current_schema`, so an empty schema
        // is not an error there and the case returns early.
        if backend == DatabaseBackend::Postgres {
            return;
        }

        let result = build_modify_column_nullable(&backend, "users", "email", false, None, &[]);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Table 'users' not found"));
    }

    #[rstest]
    #[case::postgres_column_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_column_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_column_not_found(DatabaseBackend::Sqlite)]
    fn test_column_not_found(#[case] backend: DatabaseBackend) {
        // Only the MySQL missing-column error is asserted; Postgres never looks
        // the column up.
        // NOTE(review): the SQLite case is skipped too; confirm whether SQLite is
        // expected to reject a missing column as well.
        if backend == DatabaseBackend::Postgres || backend == DatabaseBackend::Sqlite {
            return;
        }

        // The schema has the `users` table but not the `email` column being modified.
        let schema = vec![table_def(
            "users",
            vec![col(
                "id",
                ColumnType::Simple(SimpleColumnType::Integer),
                false,
            )],
            vec![],
        )];

        let result = build_modify_column_nullable(&backend, "users", "email", false, None, &schema);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Column 'email' not found"));
    }

    #[rstest]
    #[case::postgres_with_index(DatabaseBackend::Postgres)]
    #[case::mysql_with_index(DatabaseBackend::MySql)]
    #[case::sqlite_with_index(DatabaseBackend::Sqlite)]
    fn test_modify_nullable_with_index(#[case] backend: DatabaseBackend) {
        // `email` is covered by a named index that a table rebuild must not lose.
        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col("email", ColumnType::Simple(SimpleColumnType::Text), true),
            ],
            vec![TableConstraint::Index {
                name: Some("idx_email".into()),
                columns: vec!["email".into()],
            }],
        )];

        let result = build_modify_column_nullable(&backend, "users", "email", false, None, &schema);
        assert!(result.is_ok());
        let queries = result.unwrap();
        let sql = queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join("\n");

        // The SQLite rebuild drops the original table, so the index has to be
        // recreated afterwards.
        if backend == DatabaseBackend::Sqlite {
            assert!(sql.contains("CREATE INDEX"));
            assert!(sql.contains("idx_email"));
        }

        let suffix = format!(
            "{}_with_index",
            match backend {
                DatabaseBackend::Postgres => "postgres",
                DatabaseBackend::MySql => "mysql",
                DatabaseBackend::Sqlite => "sqlite",
            }
        );

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_with_default(DatabaseBackend::Postgres)]
    #[case::mysql_with_default(DatabaseBackend::MySql)]
    #[case::sqlite_with_default(DatabaseBackend::Sqlite)]
    fn test_with_default_value(#[case] backend: DatabaseBackend) {
        let mut email_col = col("email", ColumnType::Simple(SimpleColumnType::Text), true);
        email_col.default = Some("'default@example.com'".into());

        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                email_col,
            ],
            vec![],
        )];

        let result = build_modify_column_nullable(&backend, "users", "email", false, None, &schema);
        assert!(result.is_ok());
        let queries = result.unwrap();
        let sql = queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join("\n");

        // MySQL (MODIFY COLUMN) and SQLite (temp-table CREATE) restate the whole
        // column, so its DEFAULT must survive. The Postgres ALTER only touches
        // the NOT NULL flag.
        if backend == DatabaseBackend::MySql || backend == DatabaseBackend::Sqlite {
            assert!(sql.contains("DEFAULT"));
        }

        let suffix = format!(
            "{}_with_default",
            match backend {
                DatabaseBackend::Postgres => "postgres",
                DatabaseBackend::MySql => "mysql",
                DatabaseBackend::Sqlite => "sqlite",
            }
        );

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }
}