1use sea_query::{Alias, Query, Table};
2
3use vespertide_core::{ColumnDef, TableDef};
4
5use super::helpers::{
6 build_sea_column_def_with_table, build_sqlite_temp_table_create, convert_default_for_backend,
7 normalize_fill_with, recreate_indexes_after_rebuild,
8};
9use super::rename_table::build_rename_table;
10use super::types::{BuiltQuery, DatabaseBackend, RawSql};
11use crate::error::QueryError;
12
13#[allow(clippy::too_many_arguments)]
16pub fn build_modify_column_nullable(
17 backend: &DatabaseBackend,
18 table: &str,
19 column: &str,
20 nullable: bool,
21 fill_with: Option<&str>,
22 delete_null_rows: bool,
23 current_schema: &[TableDef],
24 pending_constraints: &[vespertide_core::TableConstraint],
25) -> Result<Vec<BuiltQuery>, QueryError> {
26 let mut queries = Vec::new();
27
28 if !nullable && delete_null_rows {
30 let delete_sql = match backend {
31 DatabaseBackend::Postgres | DatabaseBackend::Sqlite => {
32 format!("DELETE FROM \"{}\" WHERE \"{}\" IS NULL", table, column)
33 }
34 DatabaseBackend::MySql => {
35 format!("DELETE FROM `{}` WHERE `{}` IS NULL", table, column)
36 }
37 };
38 queries.push(BuiltQuery::Raw(RawSql::uniform(delete_sql)));
39 }
40 else if !nullable && let Some(fill_value) = normalize_fill_with(fill_with) {
42 let fill_value = convert_default_for_backend(&fill_value, backend);
43 let update_sql = match backend {
44 DatabaseBackend::Postgres | DatabaseBackend::Sqlite => format!(
45 "UPDATE \"{}\" SET \"{}\" = {} WHERE \"{}\" IS NULL",
46 table, column, fill_value, column
47 ),
48 DatabaseBackend::MySql => format!(
49 "UPDATE `{}` SET `{}` = {} WHERE `{}` IS NULL",
50 table, column, fill_value, column
51 ),
52 };
53 queries.push(BuiltQuery::Raw(RawSql::uniform(update_sql)));
54 }
55
56 match backend {
58 DatabaseBackend::Postgres => {
59 let alter_sql = if nullable {
60 format!(
61 "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" DROP NOT NULL",
62 table, column
63 )
64 } else {
65 format!(
66 "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" SET NOT NULL",
67 table, column
68 )
69 };
70 queries.push(BuiltQuery::Raw(RawSql::uniform(alter_sql)));
71 }
72 DatabaseBackend::MySql => {
73 let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::Other(format!("Table '{}' not found in current schema. MySQL requires current schema information to modify column nullability.", table)))?;
76
77 let column_def = table_def.columns.iter().find(|c| c.name == column).ok_or_else(|| QueryError::Other(format!("Column '{}' not found in table '{}'. MySQL requires column information to modify nullability.", column, table)))?;
78
79 let modified_col_def = ColumnDef {
81 nullable,
82 ..column_def.clone()
83 };
84
85 let sea_col = build_sea_column_def_with_table(backend, table, &modified_col_def);
87
88 let stmt = Table::alter()
89 .table(Alias::new(table))
90 .modify_column(sea_col)
91 .to_owned();
92 queries.push(BuiltQuery::AlterTable(Box::new(stmt)));
93 }
94 DatabaseBackend::Sqlite => {
95 let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::Other(format!("Table '{}' not found in current schema. SQLite requires current schema information to modify column nullability.", table)))?;
98
99 let mut new_columns = table_def.columns.clone();
101 if let Some(col) = new_columns.iter_mut().find(|c| c.name == column) {
102 col.nullable = nullable;
103 }
104
105 let temp_table = format!("{}_temp", table);
107
108 let create_query = build_sqlite_temp_table_create(
110 backend,
111 &temp_table,
112 table,
113 &new_columns,
114 &table_def.constraints,
115 );
116 queries.push(create_query);
117
118 let column_aliases: Vec<Alias> = table_def
120 .columns
121 .iter()
122 .map(|c| Alias::new(&c.name))
123 .collect();
124 let mut select_query = Query::select();
125 for col_alias in &column_aliases {
126 select_query = select_query.column(col_alias.clone()).to_owned();
127 }
128 select_query = select_query.from(Alias::new(table)).to_owned();
129
130 let insert_stmt = Query::insert()
131 .into_table(Alias::new(&temp_table))
132 .columns(column_aliases.clone())
133 .select_from(select_query)
134 .unwrap()
135 .to_owned();
136 queries.push(BuiltQuery::Insert(Box::new(insert_stmt)));
137
138 let drop_table = Table::drop().table(Alias::new(table)).to_owned();
140 queries.push(BuiltQuery::DropTable(Box::new(drop_table)));
141
142 queries.push(build_rename_table(&temp_table, table));
144
145 queries.extend(recreate_indexes_after_rebuild(
147 table,
148 &table_def.constraints,
149 pending_constraints,
150 ));
151 }
152 }
153
154 Ok(queries)
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{ColumnDef, ColumnType, SimpleColumnType, TableConstraint};

    /// Builds a column with the given name, type and nullability; every other
    /// attribute (default, comment, keys, index, foreign key) is unset.
    fn col(name: &str, ty: ColumnType, nullable: bool) -> ColumnDef {
        ColumnDef {
            name: name.to_string(),
            r#type: ty,
            nullable,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    /// Builds a table definition without a description.
    fn table_def(
        name: &str,
        columns: Vec<ColumnDef>,
        constraints: Vec<TableConstraint>,
    ) -> TableDef {
        TableDef {
            name: name.to_string(),
            description: None,
            columns,
            constraints,
        }
    }

    /// Backend name used as the prefix of snapshot suffixes.
    fn backend_label(backend: DatabaseBackend) -> &'static str {
        match backend {
            DatabaseBackend::Postgres => "postgres",
            DatabaseBackend::MySql => "mysql",
            DatabaseBackend::Sqlite => "sqlite",
        }
    }

    /// Renders every built query for `backend`, one statement per line.
    fn render_sql(queries: &[BuiltQuery], backend: DatabaseBackend) -> String {
        queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join("\n")
    }

    #[rstest]
    #[case::postgres_set_not_null(DatabaseBackend::Postgres, false, None)]
    #[case::postgres_drop_not_null(DatabaseBackend::Postgres, true, None)]
    #[case::postgres_set_not_null_with_fill(DatabaseBackend::Postgres, false, Some("'unknown'"))]
    #[case::mysql_set_not_null(DatabaseBackend::MySql, false, None)]
    #[case::mysql_drop_not_null(DatabaseBackend::MySql, true, None)]
    #[case::mysql_set_not_null_with_fill(DatabaseBackend::MySql, false, Some("'unknown'"))]
    #[case::sqlite_set_not_null(DatabaseBackend::Sqlite, false, None)]
    #[case::sqlite_drop_not_null(DatabaseBackend::Sqlite, true, None)]
    #[case::sqlite_set_not_null_with_fill(DatabaseBackend::Sqlite, false, Some("'unknown'"))]
    fn test_build_modify_column_nullable(
        #[case] backend: DatabaseBackend,
        #[case] nullable: bool,
        #[case] fill_with: Option<&str>,
    ) {
        // The column starts in the opposite state of the requested one.
        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "email",
                    ColumnType::Simple(SimpleColumnType::Text),
                    !nullable,
                ),
            ],
            vec![],
        )];

        let queries = build_modify_column_nullable(
            &backend, "users", "email", nullable, fill_with, false, &schema, &[],
        )
        .expect("changing nullability of an existing column should succeed");
        let sql = render_sql(&queries, backend);

        let suffix = format!(
            "{}_{}_users{}",
            backend_label(backend),
            if nullable { "nullable" } else { "not_null" },
            if fill_with.is_some() { "_with_fill" } else { "" }
        );

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_table_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_table_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_table_not_found(DatabaseBackend::Sqlite)]
    fn test_table_not_found(#[case] backend: DatabaseBackend) {
        let result =
            build_modify_column_nullable(&backend, "users", "email", false, None, false, &[], &[]);

        if backend == DatabaseBackend::Postgres {
            // PostgreSQL alters the column in place and never reads the
            // schema, so an empty schema is not an error.
            if let Err(err) = result {
                panic!("PostgreSQL should not require schema information, got: {err}");
            }
            return;
        }

        let err_msg = match result {
            Ok(_) => panic!(
                "expected a missing-table error for {}",
                backend_label(backend)
            ),
            Err(err) => err.to_string(),
        };
        assert!(
            err_msg.contains("Table 'users' not found"),
            "unexpected error: {err_msg}"
        );
    }

    #[rstest]
    #[case::postgres_column_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_column_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_column_not_found(DatabaseBackend::Sqlite)]
    fn test_column_not_found(#[case] backend: DatabaseBackend) {
        if backend == DatabaseBackend::Sqlite {
            // SQLite's handling of a missing column is not asserted by this test.
            return;
        }

        let schema = vec![table_def(
            "users",
            vec![col(
                "id",
                ColumnType::Simple(SimpleColumnType::Integer),
                false,
            )],
            vec![],
        )];

        let result = build_modify_column_nullable(
            &backend, "users", "email", false, None, false, &schema, &[],
        );

        if backend == DatabaseBackend::Postgres {
            // PostgreSQL never looks up the column definition.
            if let Err(err) = result {
                panic!("PostgreSQL should not require column information, got: {err}");
            }
            return;
        }

        let err_msg = match result {
            Ok(_) => panic!(
                "expected a missing-column error for {}",
                backend_label(backend)
            ),
            Err(err) => err.to_string(),
        };
        assert!(
            err_msg.contains("Column 'email' not found"),
            "unexpected error: {err_msg}"
        );
    }

    #[rstest]
    #[case::postgres_with_index(DatabaseBackend::Postgres)]
    #[case::mysql_with_index(DatabaseBackend::MySql)]
    #[case::sqlite_with_index(DatabaseBackend::Sqlite)]
    fn test_modify_nullable_with_index(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col("email", ColumnType::Simple(SimpleColumnType::Text), true),
            ],
            vec![TableConstraint::Index {
                name: Some("idx_email".into()),
                columns: vec!["email".into()],
            }],
        )];

        let queries = build_modify_column_nullable(
            &backend, "users", "email", false, None, false, &schema, &[],
        )
        .expect("changing nullability of an indexed column should succeed");
        let sql = render_sql(&queries, backend);

        if backend == DatabaseBackend::Sqlite {
            // The rebuild drops the original table, so its index must be recreated.
            assert!(sql.contains("CREATE INDEX"));
            assert!(sql.contains("idx_email"));
        }

        let suffix = format!("{}_with_index", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_fill_now(DatabaseBackend::Postgres)]
    #[case::mysql_fill_now(DatabaseBackend::MySql)]
    #[case::sqlite_fill_now(DatabaseBackend::Sqlite)]
    fn test_fill_with_now_converted_to_current_timestamp(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "orders",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "paid_at",
                    ColumnType::Simple(SimpleColumnType::Timestamptz),
                    true,
                ),
            ],
            vec![],
        )];

        let queries = build_modify_column_nullable(
            &backend,
            "orders",
            "paid_at",
            false,
            Some("NOW()"),
            false,
            &schema,
            &[],
        )
        .expect("filling NULLs with NOW() should succeed");
        let sql = render_sql(&queries, backend);

        assert!(
            !sql.contains("NOW()"),
            "SQL should not contain NOW(), got: {}",
            sql
        );
        assert!(
            sql.contains("CURRENT_TIMESTAMP"),
            "SQL should contain CURRENT_TIMESTAMP, got: {}",
            sql
        );

        let suffix = format!("{}_fill_now", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_with_default(DatabaseBackend::Postgres)]
    #[case::mysql_with_default(DatabaseBackend::MySql)]
    #[case::sqlite_with_default(DatabaseBackend::Sqlite)]
    fn test_with_default_value(#[case] backend: DatabaseBackend) {
        let mut email_col = col("email", ColumnType::Simple(SimpleColumnType::Text), true);
        email_col.default = Some("'default@example.com'".into());

        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                email_col,
            ],
            vec![],
        )];

        let queries = build_modify_column_nullable(
            &backend, "users", "email", false, None, false, &schema, &[],
        )
        .expect("changing nullability of a column with a default should succeed");
        let sql = render_sql(&queries, backend);

        // MySQL and SQLite restate the full column definition, so the default
        // must survive the change.
        if matches!(backend, DatabaseBackend::MySql | DatabaseBackend::Sqlite) {
            assert!(sql.contains("DEFAULT"));
        }

        let suffix = format!("{}_with_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_delete_null_rows(DatabaseBackend::Postgres)]
    #[case::mysql_delete_null_rows(DatabaseBackend::MySql)]
    #[case::sqlite_delete_null_rows(DatabaseBackend::Sqlite)]
    fn test_delete_null_rows(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "orders",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "user_id",
                    ColumnType::Simple(SimpleColumnType::Integer),
                    true,
                ),
            ],
            vec![],
        )];

        let queries = build_modify_column_nullable(
            &backend, "orders", "user_id", false, None, true, &schema, &[],
        )
        .expect("deleting NULL rows before SET NOT NULL should succeed");
        let sql = render_sql(&queries, backend);

        assert!(
            sql.contains("DELETE FROM"),
            "Expected DELETE FROM in SQL, got: {}",
            sql
        );
        assert!(
            sql.contains("IS NULL"),
            "Expected IS NULL in SQL, got: {}",
            sql
        );
        assert!(
            !sql.contains("UPDATE"),
            "Should NOT contain UPDATE, got: {}",
            sql
        );

        let suffix = format!("{}_delete_null_rows", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_delete_null_rows_nullable(DatabaseBackend::Postgres)]
    fn test_delete_null_rows_with_nullable_true(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "orders",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "user_id",
                    ColumnType::Simple(SimpleColumnType::Integer),
                    false,
                ),
            ],
            vec![],
        )];

        let queries = build_modify_column_nullable(
            &backend, "orders", "user_id", true, None, true, &schema, &[],
        )
        .expect("making a column nullable should succeed");
        let sql = render_sql(&queries, backend);

        // delete_null_rows only applies when tightening to NOT NULL.
        assert!(
            !sql.contains("DELETE FROM"),
            "Should NOT contain DELETE when nullable=true, got: {}",
            sql
        );
    }
}