1use sea_query::{Alias, Query, Table};
2
3use vespertide_core::{ColumnDef, TableDef};
4
5use super::helpers::{
6 build_sea_column_def_with_table, build_sqlite_temp_table_create, convert_default_for_backend,
7 normalize_fill_with, recreate_indexes_after_rebuild,
8};
9use super::rename_table::build_rename_table;
10use super::types::{BuiltQuery, DatabaseBackend, RawSql};
11use crate::error::QueryError;
12
13pub fn build_modify_column_nullable(
16 backend: &DatabaseBackend,
17 table: &str,
18 column: &str,
19 nullable: bool,
20 fill_with: Option<&str>,
21 delete_null_rows: bool,
22 current_schema: &[TableDef],
23) -> Result<Vec<BuiltQuery>, QueryError> {
24 let mut queries = Vec::new();
25
26 if !nullable && delete_null_rows {
28 let delete_sql = match backend {
29 DatabaseBackend::Postgres | DatabaseBackend::Sqlite => {
30 format!("DELETE FROM \"{}\" WHERE \"{}\" IS NULL", table, column)
31 }
32 DatabaseBackend::MySql => {
33 format!("DELETE FROM `{}` WHERE `{}` IS NULL", table, column)
34 }
35 };
36 queries.push(BuiltQuery::Raw(RawSql::uniform(delete_sql)));
37 }
38 else if !nullable && let Some(fill_value) = normalize_fill_with(fill_with) {
40 let fill_value = convert_default_for_backend(&fill_value, backend);
41 let update_sql = match backend {
42 DatabaseBackend::Postgres | DatabaseBackend::Sqlite => format!(
43 "UPDATE \"{}\" SET \"{}\" = {} WHERE \"{}\" IS NULL",
44 table, column, fill_value, column
45 ),
46 DatabaseBackend::MySql => format!(
47 "UPDATE `{}` SET `{}` = {} WHERE `{}` IS NULL",
48 table, column, fill_value, column
49 ),
50 };
51 queries.push(BuiltQuery::Raw(RawSql::uniform(update_sql)));
52 }
53
54 match backend {
56 DatabaseBackend::Postgres => {
57 let alter_sql = if nullable {
58 format!(
59 "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" DROP NOT NULL",
60 table, column
61 )
62 } else {
63 format!(
64 "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" SET NOT NULL",
65 table, column
66 )
67 };
68 queries.push(BuiltQuery::Raw(RawSql::uniform(alter_sql)));
69 }
70 DatabaseBackend::MySql => {
71 let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::Other(format!("Table '{}' not found in current schema. MySQL requires current schema information to modify column nullability.", table)))?;
74
75 let column_def = table_def.columns.iter().find(|c| c.name == column).ok_or_else(|| QueryError::Other(format!("Column '{}' not found in table '{}'. MySQL requires column information to modify nullability.", column, table)))?;
76
77 let modified_col_def = ColumnDef {
79 nullable,
80 ..column_def.clone()
81 };
82
83 let sea_col = build_sea_column_def_with_table(backend, table, &modified_col_def);
85
86 let stmt = Table::alter()
87 .table(Alias::new(table))
88 .modify_column(sea_col)
89 .to_owned();
90 queries.push(BuiltQuery::AlterTable(Box::new(stmt)));
91 }
92 DatabaseBackend::Sqlite => {
93 let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::Other(format!("Table '{}' not found in current schema. SQLite requires current schema information to modify column nullability.", table)))?;
96
97 let mut new_columns = table_def.columns.clone();
99 if let Some(col) = new_columns.iter_mut().find(|c| c.name == column) {
100 col.nullable = nullable;
101 }
102
103 let temp_table = format!("{}_temp", table);
105
106 let create_query = build_sqlite_temp_table_create(
108 backend,
109 &temp_table,
110 table,
111 &new_columns,
112 &table_def.constraints,
113 );
114 queries.push(create_query);
115
116 let column_aliases: Vec<Alias> = table_def
118 .columns
119 .iter()
120 .map(|c| Alias::new(&c.name))
121 .collect();
122 let mut select_query = Query::select();
123 for col_alias in &column_aliases {
124 select_query = select_query.column(col_alias.clone()).to_owned();
125 }
126 select_query = select_query.from(Alias::new(table)).to_owned();
127
128 let insert_stmt = Query::insert()
129 .into_table(Alias::new(&temp_table))
130 .columns(column_aliases.clone())
131 .select_from(select_query)
132 .unwrap()
133 .to_owned();
134 queries.push(BuiltQuery::Insert(Box::new(insert_stmt)));
135
136 let drop_table = Table::drop().table(Alias::new(table)).to_owned();
138 queries.push(BuiltQuery::DropTable(Box::new(drop_table)));
139
140 queries.push(build_rename_table(&temp_table, table));
142
143 queries.extend(recreate_indexes_after_rebuild(
145 table,
146 &table_def.constraints,
147 &[],
148 ));
149 }
150 }
151
152 Ok(queries)
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{ColumnDef, ColumnType, SimpleColumnType, TableConstraint};

    /// Minimal column definition with only name, type and nullability set.
    fn col(name: &str, ty: ColumnType, nullable: bool) -> ColumnDef {
        ColumnDef {
            name: name.to_string(),
            r#type: ty,
            nullable,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    /// Table definition without a description.
    fn table_def(
        name: &str,
        columns: Vec<ColumnDef>,
        constraints: Vec<TableConstraint>,
    ) -> TableDef {
        TableDef {
            name: name.to_string(),
            description: None,
            columns,
            constraints,
        }
    }

    /// Renders every built query for `backend`, one statement per line.
    fn render(backend: DatabaseBackend, queries: &[BuiltQuery]) -> String {
        queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join("\n")
    }

    /// Short backend label used in snapshot suffixes.
    fn backend_name(backend: DatabaseBackend) -> &'static str {
        match backend {
            DatabaseBackend::Postgres => "postgres",
            DatabaseBackend::MySql => "mysql",
            DatabaseBackend::Sqlite => "sqlite",
        }
    }

    #[rstest]
    #[case::postgres_set_not_null(DatabaseBackend::Postgres, false, None)]
    #[case::postgres_drop_not_null(DatabaseBackend::Postgres, true, None)]
    #[case::postgres_set_not_null_with_fill(DatabaseBackend::Postgres, false, Some("'unknown'"))]
    #[case::mysql_set_not_null(DatabaseBackend::MySql, false, None)]
    #[case::mysql_drop_not_null(DatabaseBackend::MySql, true, None)]
    #[case::mysql_set_not_null_with_fill(DatabaseBackend::MySql, false, Some("'unknown'"))]
    #[case::sqlite_set_not_null(DatabaseBackend::Sqlite, false, None)]
    #[case::sqlite_drop_not_null(DatabaseBackend::Sqlite, true, None)]
    #[case::sqlite_set_not_null_with_fill(DatabaseBackend::Sqlite, false, Some("'unknown'"))]
    fn test_build_modify_column_nullable(
        #[case] backend: DatabaseBackend,
        #[case] nullable: bool,
        #[case] fill_with: Option<&str>,
    ) {
        // The column starts with the opposite nullability of the requested one.
        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "email",
                    ColumnType::Simple(SimpleColumnType::Text),
                    !nullable,
                ),
            ],
            vec![],
        )];

        let result = build_modify_column_nullable(
            &backend, "users", "email", nullable, fill_with, false, &schema,
        );
        assert!(result.is_ok());
        let sql = render(backend, &result.unwrap());

        let suffix = format!(
            "{}_{}_users{}",
            backend_name(backend),
            if nullable { "nullable" } else { "not_null" },
            if fill_with.is_some() {
                "_with_fill"
            } else {
                ""
            }
        );

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_table_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_table_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_table_not_found(DatabaseBackend::Sqlite)]
    fn test_table_not_found(#[case] backend: DatabaseBackend) {
        let result =
            build_modify_column_nullable(&backend, "users", "email", false, None, false, &[]);

        if backend == DatabaseBackend::Postgres {
            // Postgres alters nullability in place and never consults the schema,
            // so an unknown table is not an error at build time.
            assert!(result.is_ok());
            return;
        }

        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Table 'users' not found"));
    }

    #[rstest]
    #[case::postgres_column_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_column_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_column_not_found(DatabaseBackend::Sqlite)]
    fn test_column_not_found(#[case] backend: DatabaseBackend) {
        if backend == DatabaseBackend::Sqlite {
            // SQLite's handling of a column missing from the schema is not pinned here.
            return;
        }

        let schema = vec![table_def(
            "users",
            vec![col(
                "id",
                ColumnType::Simple(SimpleColumnType::Integer),
                false,
            )],
            vec![],
        )];

        let result =
            build_modify_column_nullable(&backend, "users", "email", false, None, false, &schema);

        if backend == DatabaseBackend::Postgres {
            // Postgres does not look the column up in the schema.
            assert!(result.is_ok());
            return;
        }

        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Column 'email' not found"));
    }

    #[rstest]
    #[case::postgres_with_index(DatabaseBackend::Postgres)]
    #[case::mysql_with_index(DatabaseBackend::MySql)]
    #[case::sqlite_with_index(DatabaseBackend::Sqlite)]
    fn test_modify_nullable_with_index(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col("email", ColumnType::Simple(SimpleColumnType::Text), true),
            ],
            vec![TableConstraint::Index {
                name: Some("idx_email".into()),
                columns: vec!["email".into()],
            }],
        )];

        let result =
            build_modify_column_nullable(&backend, "users", "email", false, None, false, &schema);
        assert!(result.is_ok());
        let sql = render(backend, &result.unwrap());

        // The SQLite rebuild drops the table, so its indexes must be recreated.
        if backend == DatabaseBackend::Sqlite {
            assert!(sql.contains("CREATE INDEX"));
            assert!(sql.contains("idx_email"));
        }

        let suffix = format!("{}_with_index", backend_name(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_fill_now(DatabaseBackend::Postgres)]
    #[case::mysql_fill_now(DatabaseBackend::MySql)]
    #[case::sqlite_fill_now(DatabaseBackend::Sqlite)]
    fn test_fill_with_now_converted_to_current_timestamp(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "orders",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "paid_at",
                    ColumnType::Simple(SimpleColumnType::Timestamptz),
                    true,
                ),
            ],
            vec![],
        )];

        let result = build_modify_column_nullable(
            &backend,
            "orders",
            "paid_at",
            false,
            Some("NOW()"),
            false,
            &schema,
        );
        assert!(result.is_ok());
        let sql = render(backend, &result.unwrap());

        assert!(
            !sql.contains("NOW()"),
            "SQL should not contain NOW(), got: {}",
            sql
        );
        assert!(
            sql.contains("CURRENT_TIMESTAMP"),
            "SQL should contain CURRENT_TIMESTAMP, got: {}",
            sql
        );

        let suffix = format!("{}_fill_now", backend_name(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_with_default(DatabaseBackend::Postgres)]
    #[case::mysql_with_default(DatabaseBackend::MySql)]
    #[case::sqlite_with_default(DatabaseBackend::Sqlite)]
    fn test_with_default_value(#[case] backend: DatabaseBackend) {
        let mut email_col = col("email", ColumnType::Simple(SimpleColumnType::Text), true);
        email_col.default = Some("'default@example.com'".into());

        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                email_col,
            ],
            vec![],
        )];

        let result =
            build_modify_column_nullable(&backend, "users", "email", false, None, false, &schema);
        assert!(result.is_ok());
        let sql = render(backend, &result.unwrap());

        // MySQL and SQLite re-emit the full column definition, default included.
        if backend == DatabaseBackend::MySql || backend == DatabaseBackend::Sqlite {
            assert!(sql.contains("DEFAULT"));
        }

        let suffix = format!("{}_with_default", backend_name(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_delete_null_rows(DatabaseBackend::Postgres)]
    #[case::mysql_delete_null_rows(DatabaseBackend::MySql)]
    #[case::sqlite_delete_null_rows(DatabaseBackend::Sqlite)]
    fn test_delete_null_rows(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "orders",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "user_id",
                    ColumnType::Simple(SimpleColumnType::Integer),
                    true,
                ),
            ],
            vec![],
        )];

        let result =
            build_modify_column_nullable(&backend, "orders", "user_id", false, None, true, &schema);
        assert!(result.is_ok());
        let sql = render(backend, &result.unwrap());

        assert!(
            sql.contains("DELETE FROM"),
            "Expected DELETE FROM in SQL, got: {}",
            sql
        );
        assert!(
            sql.contains("IS NULL"),
            "Expected IS NULL in SQL, got: {}",
            sql
        );
        assert!(
            !sql.contains("UPDATE"),
            "Should NOT contain UPDATE, got: {}",
            sql
        );

        let suffix = format!("{}_delete_null_rows", backend_name(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_delete_null_rows_nullable(DatabaseBackend::Postgres)]
    fn test_delete_null_rows_with_nullable_true(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "orders",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "user_id",
                    ColumnType::Simple(SimpleColumnType::Integer),
                    false,
                ),
            ],
            vec![],
        )];

        let result =
            build_modify_column_nullable(&backend, "orders", "user_id", true, None, true, &schema);
        assert!(result.is_ok());
        let sql = render(backend, &result.unwrap());

        // `delete_null_rows` only applies when the column is becoming NOT NULL.
        assert!(
            !sql.contains("DELETE FROM"),
            "Should NOT contain DELETE when nullable=true, got: {}",
            sql
        );
    }
}