1use sea_query::{Alias, Query, Table};
2
3use vespertide_core::{ColumnDef, TableDef};
4
5use super::create_table::build_create_table_for_backend;
6use super::helpers::{build_sea_column_def_with_table, normalize_enum_default};
7use super::rename_table::build_rename_table;
8use super::types::{BuiltQuery, DatabaseBackend, RawSql};
9use crate::error::QueryError;
10
11pub fn build_modify_column_default(
13 backend: &DatabaseBackend,
14 table: &str,
15 column: &str,
16 new_default: Option<&str>,
17 current_schema: &[TableDef],
18) -> Result<Vec<BuiltQuery>, QueryError> {
19 let mut queries = Vec::new();
20
21 match backend {
22 DatabaseBackend::Postgres => {
23 let alter_sql = if let Some(default_value) = new_default {
24 let column_type = current_schema
26 .iter()
27 .find(|t| t.name == table)
28 .and_then(|t| t.columns.iter().find(|c| c.name == column))
29 .map(|c| &c.r#type);
30
31 let normalized_default = if let Some(col_type) = column_type {
32 normalize_enum_default(col_type, default_value)
33 } else {
34 default_value.to_string()
35 };
36
37 format!(
38 "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" SET DEFAULT {}",
39 table, column, normalized_default
40 )
41 } else {
42 format!(
43 "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" DROP DEFAULT",
44 table, column
45 )
46 };
47 queries.push(BuiltQuery::Raw(RawSql::uniform(alter_sql)));
48 }
49 DatabaseBackend::MySql => {
50 let table_def = current_schema
52 .iter()
53 .find(|t| t.name == table)
54 .ok_or_else(|| {
55 QueryError::Other(format!("Table '{}' not found in current schema.", table))
56 })?;
57
58 let column_def = table_def
59 .columns
60 .iter()
61 .find(|c| c.name == column)
62 .ok_or_else(|| {
63 QueryError::Other(format!(
64 "Column '{}' not found in table '{}'.",
65 column, table
66 ))
67 })?;
68
69 let modified_col_def = ColumnDef {
71 default: new_default.map(|s| s.into()),
72 ..column_def.clone()
73 };
74
75 let sea_col = build_sea_column_def_with_table(backend, table, &modified_col_def);
76
77 let stmt = Table::alter()
78 .table(Alias::new(table))
79 .modify_column(sea_col)
80 .to_owned();
81 queries.push(BuiltQuery::AlterTable(Box::new(stmt)));
82 }
83 DatabaseBackend::Sqlite => {
84 let table_def = current_schema
87 .iter()
88 .find(|t| t.name == table)
89 .ok_or_else(|| {
90 QueryError::Other(format!("Table '{}' not found in current schema.", table))
91 })?;
92
93 let mut new_columns = table_def.columns.clone();
95 if let Some(col) = new_columns.iter_mut().find(|c| c.name == column) {
96 col.default = new_default.map(|s| s.into());
97 }
98
99 let temp_table = format!("{}_temp", table);
101
102 let create_temp_table = build_create_table_for_backend(
104 backend,
105 &temp_table,
106 &new_columns,
107 &table_def.constraints,
108 );
109 queries.push(BuiltQuery::CreateTable(Box::new(create_temp_table)));
110
111 let column_aliases: Vec<Alias> = table_def
113 .columns
114 .iter()
115 .map(|c| Alias::new(&c.name))
116 .collect();
117 let mut select_query = Query::select();
118 for col_alias in &column_aliases {
119 select_query = select_query.column(col_alias.clone()).to_owned();
120 }
121 select_query = select_query.from(Alias::new(table)).to_owned();
122
123 let insert_stmt = Query::insert()
124 .into_table(Alias::new(&temp_table))
125 .columns(column_aliases.clone())
126 .select_from(select_query)
127 .unwrap()
128 .to_owned();
129 queries.push(BuiltQuery::Insert(Box::new(insert_stmt)));
130
131 let drop_table = Table::drop().table(Alias::new(table)).to_owned();
133 queries.push(BuiltQuery::DropTable(Box::new(drop_table)));
134
135 queries.push(build_rename_table(&temp_table, table));
137
138 for constraint in &table_def.constraints {
140 if let vespertide_core::TableConstraint::Index {
141 name: idx_name,
142 columns: idx_cols,
143 } = constraint
144 {
145 let index_name =
146 vespertide_naming::build_index_name(table, idx_cols, idx_name.as_deref());
147 let mut idx_stmt = sea_query::Index::create();
148 idx_stmt = idx_stmt.name(&index_name).to_owned();
149 for col_name in idx_cols {
150 idx_stmt = idx_stmt.col(Alias::new(col_name)).to_owned();
151 }
152 idx_stmt = idx_stmt.table(Alias::new(table)).to_owned();
153 queries.push(BuiltQuery::CreateIndex(Box::new(idx_stmt)));
154 }
155 }
156 }
157 }
158
159 Ok(queries)
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{ColumnDef, ColumnType, SimpleColumnType, TableConstraint};

    /// Builds a column with no default, comment, key, unique, index, or foreign key.
    fn col(name: &str, ty: ColumnType, nullable: bool) -> ColumnDef {
        ColumnDef {
            name: name.to_string(),
            r#type: ty,
            nullable,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    /// Builds a table definition with no description.
    fn table_def(
        name: &str,
        columns: Vec<ColumnDef>,
        constraints: Vec<TableConstraint>,
    ) -> TableDef {
        TableDef {
            name: name.to_string(),
            description: None,
            columns,
            constraints,
        }
    }

    /// Backend prefix used in snapshot suffixes.
    fn backend_label(backend: DatabaseBackend) -> &'static str {
        match backend {
            DatabaseBackend::Postgres => "postgres",
            DatabaseBackend::MySql => "mysql",
            DatabaseBackend::Sqlite => "sqlite",
        }
    }

    /// Renders every query for `backend`, one statement per line.
    fn render(queries: &[BuiltQuery], backend: DatabaseBackend) -> String {
        queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join("\n")
    }

    #[rstest]
    #[case::postgres_set_default(DatabaseBackend::Postgres, Some("'unknown'"))]
    #[case::postgres_drop_default(DatabaseBackend::Postgres, None)]
    #[case::mysql_set_default(DatabaseBackend::MySql, Some("'unknown'"))]
    #[case::mysql_drop_default(DatabaseBackend::MySql, None)]
    #[case::sqlite_set_default(DatabaseBackend::Sqlite, Some("'unknown'"))]
    #[case::sqlite_drop_default(DatabaseBackend::Sqlite, None)]
    fn test_build_modify_column_default(
        #[case] backend: DatabaseBackend,
        #[case] new_default: Option<&str>,
    ) {
        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col("email", ColumnType::Simple(SimpleColumnType::Text), true),
            ],
            vec![],
        )];

        let queries =
            build_modify_column_default(&backend, "users", "email", new_default, &schema)
                .expect("modifying an existing column's default should succeed");
        let sql = render(&queries, backend);

        let action = if new_default.is_some() {
            "set_default"
        } else {
            "drop_default"
        };
        let suffix = format!("{}_{}_users", backend_label(backend), action);

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// MySQL and SQLite need the table definition, so an empty schema is an error.
    #[rstest]
    #[case::mysql_table_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_table_not_found(DatabaseBackend::Sqlite)]
    fn test_table_not_found(#[case] backend: DatabaseBackend) {
        let err = build_modify_column_default(&backend, "users", "email", Some("'default'"), &[])
            .unwrap_err();
        assert!(err.to_string().contains("Table 'users' not found"));
    }

    /// Postgres emits raw SQL and only uses the schema for enum normalization, so a table
    /// missing from the schema is not an error and the default is passed through as-is.
    #[test]
    fn test_postgres_does_not_require_schema() {
        let backend = DatabaseBackend::Postgres;
        let queries =
            build_modify_column_default(&backend, "users", "email", Some("'default'"), &[])
                .expect("postgres should not need the table in the schema");
        let sql = render(&queries, backend);
        assert!(sql.contains("SET DEFAULT 'default'"));
    }

    /// MySQL must restate the full column definition, so an unknown column is an error.
    /// Postgres does not need the column (see the no-schema test above). SQLite is not
    /// asserted here.
    #[rstest]
    #[case::mysql_column_not_found(DatabaseBackend::MySql)]
    fn test_column_not_found(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "users",
            vec![col(
                "id",
                ColumnType::Simple(SimpleColumnType::Integer),
                false,
            )],
            vec![],
        )];

        let err =
            build_modify_column_default(&backend, "users", "email", Some("'default'"), &schema)
                .unwrap_err();
        assert!(err.to_string().contains("Column 'email' not found"));
    }

    #[rstest]
    #[case::postgres_with_index(DatabaseBackend::Postgres)]
    #[case::mysql_with_index(DatabaseBackend::MySql)]
    #[case::sqlite_with_index(DatabaseBackend::Sqlite)]
    fn test_modify_default_with_index(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col("email", ColumnType::Simple(SimpleColumnType::Text), true),
            ],
            vec![TableConstraint::Index {
                name: Some("idx_users_email".into()),
                columns: vec!["email".into()],
            }],
        )];

        let queries = build_modify_column_default(
            &backend,
            "users",
            "email",
            Some("'default@example.com'"),
            &schema,
        )
        .expect("modifying a default on an indexed column should succeed");
        let sql = render(&queries, backend);

        // The SQLite table rebuild must recreate the index that was dropped with the old table.
        if backend == DatabaseBackend::Sqlite {
            assert!(sql.contains("CREATE INDEX"));
            assert!(sql.contains("idx_users_email"));
        }

        let suffix = format!("{}_with_index", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_change_default(DatabaseBackend::Postgres)]
    #[case::mysql_change_default(DatabaseBackend::MySql)]
    #[case::sqlite_change_default(DatabaseBackend::Sqlite)]
    fn test_change_default_value(#[case] backend: DatabaseBackend) {
        let mut email_col = col("email", ColumnType::Simple(SimpleColumnType::Text), true);
        email_col.default = Some("'old@example.com'".into());

        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                email_col,
            ],
            vec![],
        )];

        let queries = build_modify_column_default(
            &backend,
            "users",
            "email",
            Some("'new@example.com'"),
            &schema,
        )
        .expect("replacing an existing default should succeed");
        let sql = render(&queries, backend);

        let suffix = format!("{}_change_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_integer_default(DatabaseBackend::Postgres)]
    #[case::mysql_integer_default(DatabaseBackend::MySql)]
    #[case::sqlite_integer_default(DatabaseBackend::Sqlite)]
    fn test_integer_default(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "products",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "quantity",
                    ColumnType::Simple(SimpleColumnType::Integer),
                    false,
                ),
            ],
            vec![],
        )];

        let queries =
            build_modify_column_default(&backend, "products", "quantity", Some("0"), &schema)
                .expect("setting an integer default should succeed");
        let sql = render(&queries, backend);

        let suffix = format!("{}_integer_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_boolean_default(DatabaseBackend::Postgres)]
    #[case::mysql_boolean_default(DatabaseBackend::MySql)]
    #[case::sqlite_boolean_default(DatabaseBackend::Sqlite)]
    fn test_boolean_default(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "is_active",
                    ColumnType::Simple(SimpleColumnType::Boolean),
                    false,
                ),
            ],
            vec![],
        )];

        let queries =
            build_modify_column_default(&backend, "users", "is_active", Some("true"), &schema)
                .expect("setting a boolean default should succeed");
        let sql = render(&queries, backend);

        let suffix = format!("{}_boolean_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_function_default(DatabaseBackend::Postgres)]
    #[case::mysql_function_default(DatabaseBackend::MySql)]
    #[case::sqlite_function_default(DatabaseBackend::Sqlite)]
    fn test_function_default(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "events",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "created_at",
                    ColumnType::Simple(SimpleColumnType::Timestamp),
                    false,
                ),
            ],
            vec![],
        )];

        let default_value = match backend {
            DatabaseBackend::Postgres => "NOW()",
            DatabaseBackend::MySql | DatabaseBackend::Sqlite => "CURRENT_TIMESTAMP",
        };

        let queries = build_modify_column_default(
            &backend,
            "events",
            "created_at",
            Some(default_value),
            &schema,
        )
        .expect("setting a function-call default should succeed");
        let sql = render(&queries, backend);

        let suffix = format!("{}_function_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_drop_existing_default(DatabaseBackend::Postgres)]
    #[case::mysql_drop_existing_default(DatabaseBackend::MySql)]
    #[case::sqlite_drop_existing_default(DatabaseBackend::Sqlite)]
    fn test_drop_existing_default(#[case] backend: DatabaseBackend) {
        let mut status_col = col("status", ColumnType::Simple(SimpleColumnType::Text), false);
        status_col.default = Some("'pending'".into());

        let schema = vec![table_def(
            "orders",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                status_col,
            ],
            vec![],
        )];

        let queries = build_modify_column_default(&backend, "orders", "status", None, &schema)
            .expect("dropping an existing default should succeed");
        let sql = render(&queries, backend);

        let suffix = format!("{}_drop_existing_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }
}