1use sea_query::{Alias, Query, Table};
2
3use vespertide_core::{ColumnDef, TableDef};
4
5use super::helpers::{
6 build_sea_column_def_with_table, build_sqlite_temp_table_create, convert_default_for_backend,
7 normalize_fill_with, quote_ident, recreate_indexes_after_rebuild,
8};
9use super::rename_table::build_rename_table;
10use super::types::{BuiltQuery, DatabaseBackend, RawSql};
11use crate::error::QueryError;
12
13#[expect(
16 clippy::too_many_arguments,
17 reason = "nullability builder needs action fields, fill strategy, backend, and SQLite rebuild context; NullabilityContext is deferred"
18)]
19pub fn build_modify_column_nullable(
20 backend: DatabaseBackend,
21 table: &str,
22 column: &str,
23 nullable: bool,
24 fill_with: Option<&str>,
25 delete_null_rows: bool,
26 current_schema: &[TableDef],
27 pending_constraints: &[vespertide_core::TableConstraint],
28) -> Result<Vec<BuiltQuery>, QueryError> {
29 let mut queries = Vec::new();
30
31 if !nullable && delete_null_rows {
33 let quoted_table = quote_ident(table, backend);
34 let quoted_column = quote_ident(column, backend);
35 let delete_sql = format!("DELETE FROM {quoted_table} WHERE {quoted_column} IS NULL");
36 queries.push(BuiltQuery::Raw(RawSql::uniform(delete_sql)));
37 }
38 else if !nullable && let Some(fill_value) = normalize_fill_with(fill_with) {
40 let fill_value = convert_default_for_backend(&fill_value, backend);
41 let quoted_table = quote_ident(table, backend);
42 let quoted_column = quote_ident(column, backend);
43 let update_sql = format!(
44 "UPDATE {quoted_table} SET {quoted_column} = {fill_value} WHERE {quoted_column} IS NULL"
45 );
46 queries.push(BuiltQuery::Raw(RawSql::uniform(update_sql)));
47 }
48
49 match backend {
51 DatabaseBackend::Postgres => {
52 let quoted_table = quote_ident(table, backend);
53 let quoted_column = quote_ident(column, backend);
54 let alter_sql = if nullable {
55 format!("ALTER TABLE {quoted_table} ALTER COLUMN {quoted_column} DROP NOT NULL")
56 } else {
57 format!("ALTER TABLE {quoted_table} ALTER COLUMN {quoted_column} SET NOT NULL")
58 };
59 queries.push(BuiltQuery::Raw(RawSql::uniform(alter_sql)));
60 }
61 DatabaseBackend::MySql => {
62 let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::SchemaError(format!("Table '{table}' not found in current schema. MySQL requires current schema information to modify column nullability.")))?;
65
66 let column_def = table_def.columns.iter().find(|c| c.name == column).ok_or_else(|| QueryError::SchemaError(format!("Column '{column}' not found in table '{table}'. MySQL requires column information to modify nullability.")))?;
67
68 let modified_col_def = ColumnDef {
70 nullable,
71 ..column_def.clone()
72 };
73
74 let sea_col = build_sea_column_def_with_table(backend, table, &modified_col_def);
76
77 let stmt = Table::alter()
78 .table(Alias::new(table))
79 .modify_column(sea_col)
80 .to_owned();
81 queries.push(BuiltQuery::AlterTable(Box::new(stmt)));
82 }
83 DatabaseBackend::Sqlite => {
84 let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::SchemaError(format!("Table '{table}' not found in current schema. SQLite requires current schema information to modify column nullability.")))?;
87
88 let mut new_columns = table_def.columns.clone();
90 if let Some(col) = new_columns.iter_mut().find(|c| c.name == column) {
91 col.nullable = nullable;
92 }
93
94 let temp_table = format!("{table}_temp");
96
97 let create_query = build_sqlite_temp_table_create(
99 backend,
100 &temp_table,
101 table,
102 &new_columns,
103 &table_def.constraints,
104 );
105 queries.push(create_query);
106
107 let column_aliases: Vec<Alias> = table_def
109 .columns
110 .iter()
111 .map(|c| Alias::new(&c.name))
112 .collect();
113 let mut select_query = Query::select();
114 for col_alias in &column_aliases {
115 select_query.column(col_alias.clone());
116 }
117 select_query.from(Alias::new(table));
118
119 let insert_stmt = Query::insert()
120 .into_table(Alias::new(&temp_table))
121 .columns(column_aliases.clone())
122 .select_from(select_query)
123 .unwrap()
124 .to_owned();
125 queries.push(BuiltQuery::Insert(Box::new(insert_stmt)));
126
127 let drop_table = Table::drop().table(Alias::new(table)).to_owned();
129 queries.push(BuiltQuery::DropTable(Box::new(drop_table)));
130
131 queries.push(build_rename_table(&temp_table, table));
133
134 queries.extend(recreate_indexes_after_rebuild(
136 table,
137 &table_def.constraints,
138 pending_constraints,
139 ));
140 }
141 }
142
143 Ok(queries)
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::col_n as col;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{ColumnDef, ColumnType, SimpleColumnType, TableConstraint};

    /// Builds a [`TableDef`] with no description.
    fn table_def(
        name: &str,
        columns: Vec<ColumnDef>,
        constraints: Vec<TableConstraint>,
    ) -> TableDef {
        TableDef {
            name: name.into(),
            description: None,
            columns,
            constraints,
        }
    }

    /// Wraps a simple column type in [`ColumnType::Simple`].
    fn simple(kind: SimpleColumnType) -> ColumnType {
        ColumnType::Simple(kind)
    }

    /// Lower-case backend name used as the first part of every snapshot suffix.
    fn backend_label(backend: DatabaseBackend) -> &'static str {
        match backend {
            DatabaseBackend::Postgres => "postgres",
            DatabaseBackend::MySql => "mysql",
            DatabaseBackend::Sqlite => "sqlite",
        }
    }

    /// Renders every query for `backend`, one statement per line.
    fn render_sql(queries: &[BuiltQuery], backend: DatabaseBackend) -> String {
        queries
            .iter()
            .map(|query| query.build(backend))
            .collect::<Vec<String>>()
            .join("\n")
    }

    #[rstest]
    #[case::postgres_set_not_null(DatabaseBackend::Postgres, false, None)]
    #[case::postgres_drop_not_null(DatabaseBackend::Postgres, true, None)]
    #[case::postgres_set_not_null_with_fill(DatabaseBackend::Postgres, false, Some("'unknown'"))]
    #[case::mysql_set_not_null(DatabaseBackend::MySql, false, None)]
    #[case::mysql_drop_not_null(DatabaseBackend::MySql, true, None)]
    #[case::mysql_set_not_null_with_fill(DatabaseBackend::MySql, false, Some("'unknown'"))]
    #[case::sqlite_set_not_null(DatabaseBackend::Sqlite, false, None)]
    #[case::sqlite_drop_not_null(DatabaseBackend::Sqlite, true, None)]
    #[case::sqlite_set_not_null_with_fill(DatabaseBackend::Sqlite, false, Some("'unknown'"))]
    fn test_build_modify_column_nullable(
        #[case] backend: DatabaseBackend,
        #[case] nullable: bool,
        #[case] fill_with: Option<&str>,
    ) {
        let schema = vec![table_def(
            "users",
            vec![
                col("id", simple(SimpleColumnType::Integer), false),
                col("email", simple(SimpleColumnType::Text), !nullable),
            ],
            vec![],
        )];

        let queries = build_modify_column_nullable(
            backend, "users", "email", nullable, fill_with, false, &schema, &[],
        )
        .expect("nullability change should build");
        let sql = render_sql(&queries, backend);

        let nullability = if nullable { "nullable" } else { "not_null" };
        let fill_tag = if fill_with.is_some() { "_with_fill" } else { "" };
        let suffix = format!("{}_{nullability}_users{fill_tag}", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_table_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_table_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_table_not_found(DatabaseBackend::Sqlite)]
    fn test_table_not_found(#[case] backend: DatabaseBackend) {
        // The Postgres path emits its ALTER without consulting the schema.
        if matches!(backend, DatabaseBackend::Postgres) {
            return;
        }

        let message =
            build_modify_column_nullable(backend, "users", "email", false, None, false, &[], &[])
                .unwrap_err()
                .to_string();
        assert!(message.contains("Table 'users' not found"));
    }

    #[rstest]
    #[case::postgres_column_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_column_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_column_not_found(DatabaseBackend::Sqlite)]
    fn test_column_not_found(#[case] backend: DatabaseBackend) {
        // Only the MySQL case is asserted here; Postgres never looks at the schema.
        if matches!(backend, DatabaseBackend::Postgres | DatabaseBackend::Sqlite) {
            return;
        }

        let schema = vec![table_def(
            "users",
            vec![col("id", simple(SimpleColumnType::Integer), false)],
            vec![],
        )];

        let message = build_modify_column_nullable(
            backend, "users", "email", false, None, false, &schema, &[],
        )
        .unwrap_err()
        .to_string();
        assert!(message.contains("Column 'email' not found"));
    }

    #[rstest]
    #[case::postgres_with_index(DatabaseBackend::Postgres)]
    #[case::mysql_with_index(DatabaseBackend::MySql)]
    #[case::sqlite_with_index(DatabaseBackend::Sqlite)]
    fn test_modify_nullable_with_index(#[case] backend: DatabaseBackend) {
        let email_index = TableConstraint::Index {
            name: Some("idx_email".into()),
            columns: vec!["email".into()],
        };
        let schema = vec![table_def(
            "users",
            vec![
                col("id", simple(SimpleColumnType::Integer), false),
                col("email", simple(SimpleColumnType::Text), true),
            ],
            vec![email_index],
        )];

        let queries = build_modify_column_nullable(
            backend, "users", "email", false, None, false, &schema, &[],
        )
        .expect("nullability change should build");
        let sql = render_sql(&queries, backend);

        // The SQLite rebuild drops the table, so the index must be recreated.
        if matches!(backend, DatabaseBackend::Sqlite) {
            assert!(sql.contains("CREATE INDEX"));
            assert!(sql.contains("idx_email"));
        }

        let suffix = format!("{}_with_index", backend_label(backend));
        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_fill_now(DatabaseBackend::Postgres)]
    #[case::mysql_fill_now(DatabaseBackend::MySql)]
    #[case::sqlite_fill_now(DatabaseBackend::Sqlite)]
    fn test_fill_with_now_converted_to_current_timestamp(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "orders",
            vec![
                col("id", simple(SimpleColumnType::Integer), false),
                col("paid_at", simple(SimpleColumnType::Timestamptz), true),
            ],
            vec![],
        )];

        let queries = build_modify_column_nullable(
            backend,
            "orders",
            "paid_at",
            false,
            Some("NOW()"),
            false,
            &schema,
            &[],
        )
        .expect("nullability change should build");
        let sql = render_sql(&queries, backend);

        assert!(
            !sql.contains("NOW()"),
            "SQL should not contain NOW(), got: {sql}"
        );
        assert!(
            sql.contains("CURRENT_TIMESTAMP"),
            "SQL should contain CURRENT_TIMESTAMP, got: {sql}"
        );

        let suffix = format!("{}_fill_now", backend_label(backend));
        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_with_default(DatabaseBackend::Postgres)]
    #[case::mysql_with_default(DatabaseBackend::MySql)]
    #[case::sqlite_with_default(DatabaseBackend::Sqlite)]
    fn test_with_default_value(#[case] backend: DatabaseBackend) {
        let mut email_col = col("email", simple(SimpleColumnType::Text), true);
        email_col.default = Some("'default@example.com'".into());

        let schema = vec![table_def(
            "users",
            vec![col("id", simple(SimpleColumnType::Integer), false), email_col],
            vec![],
        )];

        let queries = build_modify_column_nullable(
            backend, "users", "email", false, None, false, &schema, &[],
        )
        .expect("nullability change should build");
        let sql = render_sql(&queries, backend);

        // MySQL and SQLite restate the full column definition, default included.
        if matches!(backend, DatabaseBackend::MySql | DatabaseBackend::Sqlite) {
            assert!(sql.contains("DEFAULT"));
        }

        let suffix = format!("{}_with_default", backend_label(backend));
        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_delete_null_rows(DatabaseBackend::Postgres)]
    #[case::mysql_delete_null_rows(DatabaseBackend::MySql)]
    #[case::sqlite_delete_null_rows(DatabaseBackend::Sqlite)]
    fn test_delete_null_rows(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "orders",
            vec![
                col("id", simple(SimpleColumnType::Integer), false),
                col("user_id", simple(SimpleColumnType::Integer), true),
            ],
            vec![],
        )];

        let queries = build_modify_column_nullable(
            backend, "orders", "user_id", false, None, true, &schema, &[],
        )
        .expect("nullability change should build");
        let sql = render_sql(&queries, backend);

        assert!(
            sql.contains("DELETE FROM"),
            "Expected DELETE FROM in SQL, got: {sql}"
        );
        assert!(
            sql.contains("IS NULL"),
            "Expected IS NULL in SQL, got: {sql}"
        );
        assert!(
            !sql.contains("UPDATE"),
            "Should NOT contain UPDATE, got: {sql}"
        );

        let suffix = format!("{}_delete_null_rows", backend_label(backend));
        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_delete_null_rows_nullable(DatabaseBackend::Postgres)]
    fn test_delete_null_rows_with_nullable_true(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "orders",
            vec![
                col("id", simple(SimpleColumnType::Integer), false),
                col("user_id", simple(SimpleColumnType::Integer), false),
            ],
            vec![],
        )];

        let queries = build_modify_column_nullable(
            backend, "orders", "user_id", true, None, true, &schema, &[],
        )
        .expect("nullability change should build");
        let sql = render_sql(&queries, backend);

        assert!(
            !sql.contains("DELETE FROM"),
            "Should NOT contain DELETE when nullable=true, got: {sql}"
        );
    }
}