1use sea_query::{Alias, Query, Table};
2
3use vespertide_core::{ColumnDef, TableDef};
4
5use super::helpers::{
6 build_sea_column_def_with_table, build_sqlite_temp_table_create, convert_default_for_backend,
7 normalize_fill_with, recreate_indexes_after_rebuild,
8};
9use super::rename_table::build_rename_table;
10use super::types::{BuiltQuery, DatabaseBackend, RawSql};
11use crate::error::QueryError;
12
13pub fn build_modify_column_nullable(
16 backend: &DatabaseBackend,
17 table: &str,
18 column: &str,
19 nullable: bool,
20 fill_with: Option<&str>,
21 current_schema: &[TableDef],
22) -> Result<Vec<BuiltQuery>, QueryError> {
23 let mut queries = Vec::new();
24
25 if !nullable && let Some(fill_value) = normalize_fill_with(fill_with) {
27 let fill_value = convert_default_for_backend(&fill_value, backend);
28 let update_sql = match backend {
29 DatabaseBackend::Postgres | DatabaseBackend::Sqlite => format!(
30 "UPDATE \"{}\" SET \"{}\" = {} WHERE \"{}\" IS NULL",
31 table, column, fill_value, column
32 ),
33 DatabaseBackend::MySql => format!(
34 "UPDATE `{}` SET `{}` = {} WHERE `{}` IS NULL",
35 table, column, fill_value, column
36 ),
37 };
38 queries.push(BuiltQuery::Raw(RawSql::uniform(update_sql)));
39 }
40
41 match backend {
43 DatabaseBackend::Postgres => {
44 let alter_sql = if nullable {
45 format!(
46 "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" DROP NOT NULL",
47 table, column
48 )
49 } else {
50 format!(
51 "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" SET NOT NULL",
52 table, column
53 )
54 };
55 queries.push(BuiltQuery::Raw(RawSql::uniform(alter_sql)));
56 }
57 DatabaseBackend::MySql => {
58 let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::Other(format!("Table '{}' not found in current schema. MySQL requires current schema information to modify column nullability.", table)))?;
61
62 let column_def = table_def.columns.iter().find(|c| c.name == column).ok_or_else(|| QueryError::Other(format!("Column '{}' not found in table '{}'. MySQL requires column information to modify nullability.", column, table)))?;
63
64 let modified_col_def = ColumnDef {
66 nullable,
67 ..column_def.clone()
68 };
69
70 let sea_col = build_sea_column_def_with_table(backend, table, &modified_col_def);
72
73 let stmt = Table::alter()
74 .table(Alias::new(table))
75 .modify_column(sea_col)
76 .to_owned();
77 queries.push(BuiltQuery::AlterTable(Box::new(stmt)));
78 }
79 DatabaseBackend::Sqlite => {
80 let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::Other(format!("Table '{}' not found in current schema. SQLite requires current schema information to modify column nullability.", table)))?;
83
84 let mut new_columns = table_def.columns.clone();
86 if let Some(col) = new_columns.iter_mut().find(|c| c.name == column) {
87 col.nullable = nullable;
88 }
89
90 let temp_table = format!("{}_temp", table);
92
93 let create_query = build_sqlite_temp_table_create(
95 backend,
96 &temp_table,
97 table,
98 &new_columns,
99 &table_def.constraints,
100 );
101 queries.push(create_query);
102
103 let column_aliases: Vec<Alias> = table_def
105 .columns
106 .iter()
107 .map(|c| Alias::new(&c.name))
108 .collect();
109 let mut select_query = Query::select();
110 for col_alias in &column_aliases {
111 select_query = select_query.column(col_alias.clone()).to_owned();
112 }
113 select_query = select_query.from(Alias::new(table)).to_owned();
114
115 let insert_stmt = Query::insert()
116 .into_table(Alias::new(&temp_table))
117 .columns(column_aliases.clone())
118 .select_from(select_query)
119 .unwrap()
120 .to_owned();
121 queries.push(BuiltQuery::Insert(Box::new(insert_stmt)));
122
123 let drop_table = Table::drop().table(Alias::new(table)).to_owned();
125 queries.push(BuiltQuery::DropTable(Box::new(drop_table)));
126
127 queries.push(build_rename_table(&temp_table, table));
129
130 queries.extend(recreate_indexes_after_rebuild(
132 table,
133 &table_def.constraints,
134 &[],
135 ));
136 }
137 }
138
139 Ok(queries)
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{ColumnDef, ColumnType, SimpleColumnType, TableConstraint};

    /// Column fixture with every optional attribute left unset.
    fn col(name: &str, ty: ColumnType, nullable: bool) -> ColumnDef {
        ColumnDef {
            name: name.to_string(),
            r#type: ty,
            nullable,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    /// Table fixture without a description.
    fn table_def(
        name: &str,
        columns: Vec<ColumnDef>,
        constraints: Vec<TableConstraint>,
    ) -> TableDef {
        TableDef {
            name: name.to_string(),
            description: None,
            columns,
            constraints,
        }
    }

    /// Renders every query for `backend`, one statement per line (the snapshotted form).
    fn render_sql(queries: &[BuiltQuery], backend: DatabaseBackend) -> String {
        queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join("\n")
    }

    /// Backend label used as the first component of snapshot suffixes.
    fn backend_label(backend: DatabaseBackend) -> &'static str {
        match backend {
            DatabaseBackend::Postgres => "postgres",
            DatabaseBackend::MySql => "mysql",
            DatabaseBackend::Sqlite => "sqlite",
        }
    }

    #[rstest]
    #[case::postgres_set_not_null(DatabaseBackend::Postgres, false, None)]
    #[case::postgres_drop_not_null(DatabaseBackend::Postgres, true, None)]
    #[case::postgres_set_not_null_with_fill(DatabaseBackend::Postgres, false, Some("'unknown'"))]
    #[case::mysql_set_not_null(DatabaseBackend::MySql, false, None)]
    #[case::mysql_drop_not_null(DatabaseBackend::MySql, true, None)]
    #[case::mysql_set_not_null_with_fill(DatabaseBackend::MySql, false, Some("'unknown'"))]
    #[case::sqlite_set_not_null(DatabaseBackend::Sqlite, false, None)]
    #[case::sqlite_drop_not_null(DatabaseBackend::Sqlite, true, None)]
    #[case::sqlite_set_not_null_with_fill(DatabaseBackend::Sqlite, false, Some("'unknown'"))]
    fn test_build_modify_column_nullable(
        #[case] backend: DatabaseBackend,
        #[case] nullable: bool,
        #[case] fill_with: Option<&str>,
    ) {
        // The `email` column starts in the opposite state of the requested one.
        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "email",
                    ColumnType::Simple(SimpleColumnType::Text),
                    !nullable,
                ),
            ],
            vec![],
        )];

        let queries =
            build_modify_column_nullable(&backend, "users", "email", nullable, fill_with, &schema)
                .expect("nullability change should build");
        let sql = render_sql(&queries, backend);

        let suffix = format!(
            "{}_{}_users{}",
            backend_label(backend),
            if nullable { "nullable" } else { "not_null" },
            if fill_with.is_some() {
                "_with_fill"
            } else {
                ""
            }
        );

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_table_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_table_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_table_not_found(DatabaseBackend::Sqlite)]
    fn test_table_not_found(#[case] backend: DatabaseBackend) {
        let result = build_modify_column_nullable(&backend, "users", "email", false, None, &[]);

        match backend {
            // Postgres emits a plain ALTER TABLE and never reads the schema, so an
            // empty schema is not an error there.
            DatabaseBackend::Postgres => {
                assert!(
                    result.is_ok(),
                    "Postgres should not need schema info: {:?}",
                    result.err()
                );
            }
            DatabaseBackend::MySql | DatabaseBackend::Sqlite => {
                let err_msg = result
                    .expect_err("missing table should be rejected")
                    .to_string();
                assert!(
                    err_msg.contains("Table 'users' not found"),
                    "unexpected error: {}",
                    err_msg
                );
            }
        }
    }

    #[rstest]
    #[case::postgres_column_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_column_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_column_not_found(DatabaseBackend::Sqlite)]
    fn test_column_not_found(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "users",
            vec![col(
                "id",
                ColumnType::Simple(SimpleColumnType::Integer),
                false,
            )],
            vec![],
        )];

        let result = build_modify_column_nullable(&backend, "users", "email", false, None, &schema);

        match backend {
            // Postgres never looks the column up, so it still builds successfully.
            DatabaseBackend::Postgres => {
                assert!(
                    result.is_ok(),
                    "Postgres should not need column info: {:?}",
                    result.err()
                );
            }
            // This test does not assert SQLite's handling of an unknown column; only the
            // MySQL error is pinned here.
            DatabaseBackend::Sqlite => {}
            DatabaseBackend::MySql => {
                let err_msg = result
                    .expect_err("missing column should be rejected")
                    .to_string();
                assert!(
                    err_msg.contains("Column 'email' not found"),
                    "unexpected error: {}",
                    err_msg
                );
            }
        }
    }

    #[rstest]
    #[case::postgres_with_index(DatabaseBackend::Postgres)]
    #[case::mysql_with_index(DatabaseBackend::MySql)]
    #[case::sqlite_with_index(DatabaseBackend::Sqlite)]
    fn test_modify_nullable_with_index(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col("email", ColumnType::Simple(SimpleColumnType::Text), true),
            ],
            vec![TableConstraint::Index {
                name: Some("idx_email".into()),
                columns: vec!["email".into()],
            }],
        )];

        let queries = build_modify_column_nullable(&backend, "users", "email", false, None, &schema)
            .expect("nullability change on indexed column should build");
        let sql = render_sql(&queries, backend);

        // The SQLite rebuild drops the original table, so its indexes must be recreated.
        if backend == DatabaseBackend::Sqlite {
            assert!(sql.contains("CREATE INDEX"));
            assert!(sql.contains("idx_email"));
        }

        let suffix = format!("{}_with_index", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_fill_now(DatabaseBackend::Postgres)]
    #[case::mysql_fill_now(DatabaseBackend::MySql)]
    #[case::sqlite_fill_now(DatabaseBackend::Sqlite)]
    fn test_fill_with_now_converted_to_current_timestamp(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "orders",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "paid_at",
                    ColumnType::Simple(SimpleColumnType::Timestamptz),
                    true,
                ),
            ],
            vec![],
        )];

        let queries = build_modify_column_nullable(
            &backend,
            "orders",
            "paid_at",
            false,
            Some("NOW()"),
            &schema,
        )
        .expect("nullability change with NOW() fill should build");
        let sql = render_sql(&queries, backend);

        assert!(
            !sql.contains("NOW()"),
            "SQL should not contain NOW(), got: {}",
            sql
        );
        assert!(
            sql.contains("CURRENT_TIMESTAMP"),
            "SQL should contain CURRENT_TIMESTAMP, got: {}",
            sql
        );

        let suffix = format!("{}_fill_now", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_with_default(DatabaseBackend::Postgres)]
    #[case::mysql_with_default(DatabaseBackend::MySql)]
    #[case::sqlite_with_default(DatabaseBackend::Sqlite)]
    fn test_with_default_value(#[case] backend: DatabaseBackend) {
        let mut email_col = col("email", ColumnType::Simple(SimpleColumnType::Text), true);
        email_col.default = Some("'default@example.com'".into());

        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                email_col,
            ],
            vec![],
        )];

        let queries = build_modify_column_nullable(&backend, "users", "email", false, None, &schema)
            .expect("nullability change on defaulted column should build");
        let sql = render_sql(&queries, backend);

        // MySQL and SQLite restate the full column definition, so the default must survive.
        if backend == DatabaseBackend::MySql || backend == DatabaseBackend::Sqlite {
            assert!(sql.contains("DEFAULT"));
        }

        let suffix = format!("{}_with_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }
}