1use sea_query::{Alias, Query, Table};
2
3use vespertide_core::{ColumnDef, TableDef};
4
5use super::create_table::build_create_table_for_backend;
6use super::helpers::{build_sea_column_def_with_table, normalize_enum_default};
7use super::rename_table::build_rename_table;
8use super::types::{BuiltQuery, DatabaseBackend, RawSql};
9use crate::error::QueryError;
10
11pub fn build_modify_column_default(
13 backend: &DatabaseBackend,
14 table: &str,
15 column: &str,
16 new_default: Option<&str>,
17 current_schema: &[TableDef],
18) -> Result<Vec<BuiltQuery>, QueryError> {
19 let mut queries = Vec::new();
20
21 match backend {
22 DatabaseBackend::Postgres => {
23 let alter_sql = if let Some(default_value) = new_default {
24 let column_type = current_schema
26 .iter()
27 .find(|t| t.name == table)
28 .and_then(|t| t.columns.iter().find(|c| c.name == column))
29 .map(|c| &c.r#type);
30
31 let normalized_default = if let Some(col_type) = column_type {
32 normalize_enum_default(col_type, default_value)
33 } else {
34 default_value.to_string()
35 };
36
37 format!(
38 "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" SET DEFAULT {}",
39 table, column, normalized_default
40 )
41 } else {
42 format!(
43 "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" DROP DEFAULT",
44 table, column
45 )
46 };
47 queries.push(BuiltQuery::Raw(RawSql::uniform(alter_sql)));
48 }
49 DatabaseBackend::MySql => {
50 let table_def = current_schema
52 .iter()
53 .find(|t| t.name == table)
54 .ok_or_else(|| {
55 QueryError::Other(format!("Table '{}' not found in current schema.", table))
56 })?;
57
58 let column_def = table_def
59 .columns
60 .iter()
61 .find(|c| c.name == column)
62 .ok_or_else(|| {
63 QueryError::Other(format!(
64 "Column '{}' not found in table '{}'.",
65 column, table
66 ))
67 })?;
68
69 let modified_col_def = ColumnDef {
71 default: new_default.map(|s| s.into()),
72 ..column_def.clone()
73 };
74
75 let sea_col = build_sea_column_def_with_table(backend, table, &modified_col_def);
76
77 let stmt = Table::alter()
78 .table(Alias::new(table))
79 .modify_column(sea_col)
80 .to_owned();
81 queries.push(BuiltQuery::AlterTable(Box::new(stmt)));
82 }
83 DatabaseBackend::Sqlite => {
84 let table_def = current_schema
87 .iter()
88 .find(|t| t.name == table)
89 .ok_or_else(|| {
90 QueryError::Other(format!("Table '{}' not found in current schema.", table))
91 })?;
92
93 let mut new_columns = table_def.columns.clone();
95 if let Some(col) = new_columns.iter_mut().find(|c| c.name == column) {
96 col.default = new_default.map(|s| s.into());
97 }
98
99 let temp_table = format!("{}_temp", table);
101
102 let create_temp_table = build_create_table_for_backend(
104 backend,
105 &temp_table,
106 &new_columns,
107 &table_def.constraints,
108 );
109 queries.push(BuiltQuery::CreateTable(Box::new(create_temp_table)));
110
111 let column_aliases: Vec<Alias> = table_def
113 .columns
114 .iter()
115 .map(|c| Alias::new(&c.name))
116 .collect();
117 let mut select_query = Query::select();
118 for col_alias in &column_aliases {
119 select_query = select_query.column(col_alias.clone()).to_owned();
120 }
121 select_query = select_query.from(Alias::new(table)).to_owned();
122
123 let insert_stmt = Query::insert()
124 .into_table(Alias::new(&temp_table))
125 .columns(column_aliases.clone())
126 .select_from(select_query)
127 .unwrap()
128 .to_owned();
129 queries.push(BuiltQuery::Insert(Box::new(insert_stmt)));
130
131 let drop_table = Table::drop().table(Alias::new(table)).to_owned();
133 queries.push(BuiltQuery::DropTable(Box::new(drop_table)));
134
135 queries.push(build_rename_table(&temp_table, table));
137
138 for constraint in &table_def.constraints {
140 if let vespertide_core::TableConstraint::Index {
141 name: idx_name,
142 columns: idx_cols,
143 } = constraint
144 {
145 let index_name =
146 vespertide_naming::build_index_name(table, idx_cols, idx_name.as_deref());
147 let mut idx_stmt = sea_query::Index::create();
148 idx_stmt = idx_stmt.name(&index_name).to_owned();
149 for col_name in idx_cols {
150 idx_stmt = idx_stmt.col(Alias::new(col_name)).to_owned();
151 }
152 idx_stmt = idx_stmt.table(Alias::new(table)).to_owned();
153 queries.push(BuiltQuery::CreateIndex(Box::new(idx_stmt)));
154 }
155 }
156 }
157 }
158
159 Ok(queries)
160}
161
162#[cfg(test)]
163mod tests {
164 use super::*;
165 use insta::{assert_snapshot, with_settings};
166 use rstest::rstest;
167 use vespertide_core::{ColumnDef, ColumnType, SimpleColumnType, TableConstraint};
168
169 fn col(name: &str, ty: ColumnType, nullable: bool) -> ColumnDef {
170 ColumnDef {
171 name: name.to_string(),
172 r#type: ty,
173 nullable,
174 default: None,
175 comment: None,
176 primary_key: None,
177 unique: None,
178 index: None,
179 foreign_key: None,
180 }
181 }
182
183 fn table_def(
184 name: &str,
185 columns: Vec<ColumnDef>,
186 constraints: Vec<TableConstraint>,
187 ) -> TableDef {
188 TableDef {
189 name: name.to_string(),
190 description: None,
191 columns,
192 constraints,
193 }
194 }
195
196 #[rstest]
197 #[case::postgres_set_default(DatabaseBackend::Postgres, Some("'unknown'"))]
198 #[case::postgres_drop_default(DatabaseBackend::Postgres, None)]
199 #[case::mysql_set_default(DatabaseBackend::MySql, Some("'unknown'"))]
200 #[case::mysql_drop_default(DatabaseBackend::MySql, None)]
201 #[case::sqlite_set_default(DatabaseBackend::Sqlite, Some("'unknown'"))]
202 #[case::sqlite_drop_default(DatabaseBackend::Sqlite, None)]
203 fn test_build_modify_column_default(
204 #[case] backend: DatabaseBackend,
205 #[case] new_default: Option<&str>,
206 ) {
207 let schema = vec![table_def(
208 "users",
209 vec![
210 col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
211 col("email", ColumnType::Simple(SimpleColumnType::Text), true),
212 ],
213 vec![],
214 )];
215
216 let result = build_modify_column_default(&backend, "users", "email", new_default, &schema);
217 assert!(result.is_ok());
218 let queries = result.unwrap();
219 let sql = queries
220 .iter()
221 .map(|q| q.build(backend))
222 .collect::<Vec<String>>()
223 .join("\n");
224
225 let suffix = format!(
226 "{}_{}_users",
227 match backend {
228 DatabaseBackend::Postgres => "postgres",
229 DatabaseBackend::MySql => "mysql",
230 DatabaseBackend::Sqlite => "sqlite",
231 },
232 if new_default.is_some() {
233 "set_default"
234 } else {
235 "drop_default"
236 }
237 );
238
239 with_settings!({ snapshot_suffix => suffix }, {
240 assert_snapshot!(sql);
241 });
242 }
243
244 #[rstest]
246 #[case::postgres_table_not_found(DatabaseBackend::Postgres)]
247 #[case::mysql_table_not_found(DatabaseBackend::MySql)]
248 #[case::sqlite_table_not_found(DatabaseBackend::Sqlite)]
249 fn test_table_not_found(#[case] backend: DatabaseBackend) {
250 if backend == DatabaseBackend::Postgres {
252 return;
253 }
254
255 let result =
256 build_modify_column_default(&backend, "users", "email", Some("'default'"), &[]);
257 assert!(result.is_err());
258 let err_msg = result.unwrap_err().to_string();
259 assert!(err_msg.contains("Table 'users' not found"));
260 }
261
262 #[rstest]
264 #[case::postgres_column_not_found(DatabaseBackend::Postgres)]
265 #[case::mysql_column_not_found(DatabaseBackend::MySql)]
266 #[case::sqlite_column_not_found(DatabaseBackend::Sqlite)]
267 fn test_column_not_found(#[case] backend: DatabaseBackend) {
268 if backend == DatabaseBackend::Postgres || backend == DatabaseBackend::Sqlite {
271 return;
272 }
273
274 let schema = vec![table_def(
275 "users",
276 vec![col(
277 "id",
278 ColumnType::Simple(SimpleColumnType::Integer),
279 false,
280 )],
281 vec![],
282 )];
283
284 let result =
285 build_modify_column_default(&backend, "users", "email", Some("'default'"), &schema);
286 assert!(result.is_err());
287 let err_msg = result.unwrap_err().to_string();
288 assert!(err_msg.contains("Column 'email' not found"));
289 }
290
/// Postgres: when the column is absent from the schema the default is emitted verbatim.
#[test]
fn test_postgres_column_not_in_schema_uses_default_as_is() {
    let backend = DatabaseBackend::Postgres;
    let id_col = col("id", ColumnType::Simple(SimpleColumnType::Integer), false);
    let schema = vec![table_def("users", vec![id_col], vec![])];

    // "status" is deliberately not part of the schema above.
    let result = build_modify_column_default(&backend, "users", "status", Some("'active'"), &schema);
    assert!(result.is_ok());

    let statements: Vec<String> = result
        .unwrap()
        .iter()
        .map(|q| q.build(DatabaseBackend::Postgres))
        .collect();
    let sql = statements.join("\n");

    let expected = "ALTER TABLE \"users\" ALTER COLUMN \"status\" SET DEFAULT 'active'";
    assert!(sql.contains(expected));
}
325
326 #[rstest]
328 #[case::postgres_with_index(DatabaseBackend::Postgres)]
329 #[case::mysql_with_index(DatabaseBackend::MySql)]
330 #[case::sqlite_with_index(DatabaseBackend::Sqlite)]
331 fn test_modify_default_with_index(#[case] backend: DatabaseBackend) {
332 let schema = vec![table_def(
333 "users",
334 vec![
335 col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
336 col("email", ColumnType::Simple(SimpleColumnType::Text), true),
337 ],
338 vec![TableConstraint::Index {
339 name: Some("idx_users_email".into()),
340 columns: vec!["email".into()],
341 }],
342 )];
343
344 let result = build_modify_column_default(
345 &backend,
346 "users",
347 "email",
348 Some("'default@example.com'"),
349 &schema,
350 );
351 assert!(result.is_ok());
352 let queries = result.unwrap();
353 let sql = queries
354 .iter()
355 .map(|q| q.build(backend))
356 .collect::<Vec<String>>()
357 .join("\n");
358
359 if backend == DatabaseBackend::Sqlite {
361 assert!(sql.contains("CREATE INDEX"));
362 assert!(sql.contains("idx_users_email"));
363 }
364
365 let suffix = format!(
366 "{}_with_index",
367 match backend {
368 DatabaseBackend::Postgres => "postgres",
369 DatabaseBackend::MySql => "mysql",
370 DatabaseBackend::Sqlite => "sqlite",
371 }
372 );
373
374 with_settings!({ snapshot_suffix => suffix }, {
375 assert_snapshot!(sql);
376 });
377 }
378
379 #[rstest]
381 #[case::postgres_change_default(DatabaseBackend::Postgres)]
382 #[case::mysql_change_default(DatabaseBackend::MySql)]
383 #[case::sqlite_change_default(DatabaseBackend::Sqlite)]
384 fn test_change_default_value(#[case] backend: DatabaseBackend) {
385 let mut email_col = col("email", ColumnType::Simple(SimpleColumnType::Text), true);
386 email_col.default = Some("'old@example.com'".into());
387
388 let schema = vec![table_def(
389 "users",
390 vec![
391 col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
392 email_col,
393 ],
394 vec![],
395 )];
396
397 let result = build_modify_column_default(
398 &backend,
399 "users",
400 "email",
401 Some("'new@example.com'"),
402 &schema,
403 );
404 assert!(result.is_ok());
405 let queries = result.unwrap();
406 let sql = queries
407 .iter()
408 .map(|q| q.build(backend))
409 .collect::<Vec<String>>()
410 .join("\n");
411
412 let suffix = format!(
413 "{}_change_default",
414 match backend {
415 DatabaseBackend::Postgres => "postgres",
416 DatabaseBackend::MySql => "mysql",
417 DatabaseBackend::Sqlite => "sqlite",
418 }
419 );
420
421 with_settings!({ snapshot_suffix => suffix }, {
422 assert_snapshot!(sql);
423 });
424 }
425
426 #[rstest]
428 #[case::postgres_integer_default(DatabaseBackend::Postgres)]
429 #[case::mysql_integer_default(DatabaseBackend::MySql)]
430 #[case::sqlite_integer_default(DatabaseBackend::Sqlite)]
431 fn test_integer_default(#[case] backend: DatabaseBackend) {
432 let schema = vec![table_def(
433 "products",
434 vec![
435 col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
436 col(
437 "quantity",
438 ColumnType::Simple(SimpleColumnType::Integer),
439 false,
440 ),
441 ],
442 vec![],
443 )];
444
445 let result =
446 build_modify_column_default(&backend, "products", "quantity", Some("0"), &schema);
447 assert!(result.is_ok());
448 let queries = result.unwrap();
449 let sql = queries
450 .iter()
451 .map(|q| q.build(backend))
452 .collect::<Vec<String>>()
453 .join("\n");
454
455 let suffix = format!(
456 "{}_integer_default",
457 match backend {
458 DatabaseBackend::Postgres => "postgres",
459 DatabaseBackend::MySql => "mysql",
460 DatabaseBackend::Sqlite => "sqlite",
461 }
462 );
463
464 with_settings!({ snapshot_suffix => suffix }, {
465 assert_snapshot!(sql);
466 });
467 }
468
469 #[rstest]
471 #[case::postgres_boolean_default(DatabaseBackend::Postgres)]
472 #[case::mysql_boolean_default(DatabaseBackend::MySql)]
473 #[case::sqlite_boolean_default(DatabaseBackend::Sqlite)]
474 fn test_boolean_default(#[case] backend: DatabaseBackend) {
475 let schema = vec![table_def(
476 "users",
477 vec![
478 col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
479 col(
480 "is_active",
481 ColumnType::Simple(SimpleColumnType::Boolean),
482 false,
483 ),
484 ],
485 vec![],
486 )];
487
488 let result =
489 build_modify_column_default(&backend, "users", "is_active", Some("true"), &schema);
490 assert!(result.is_ok());
491 let queries = result.unwrap();
492 let sql = queries
493 .iter()
494 .map(|q| q.build(backend))
495 .collect::<Vec<String>>()
496 .join("\n");
497
498 let suffix = format!(
499 "{}_boolean_default",
500 match backend {
501 DatabaseBackend::Postgres => "postgres",
502 DatabaseBackend::MySql => "mysql",
503 DatabaseBackend::Sqlite => "sqlite",
504 }
505 );
506
507 with_settings!({ snapshot_suffix => suffix }, {
508 assert_snapshot!(sql);
509 });
510 }
511
512 #[rstest]
514 #[case::postgres_function_default(DatabaseBackend::Postgres)]
515 #[case::mysql_function_default(DatabaseBackend::MySql)]
516 #[case::sqlite_function_default(DatabaseBackend::Sqlite)]
517 fn test_function_default(#[case] backend: DatabaseBackend) {
518 let schema = vec![table_def(
519 "events",
520 vec![
521 col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
522 col(
523 "created_at",
524 ColumnType::Simple(SimpleColumnType::Timestamp),
525 false,
526 ),
527 ],
528 vec![],
529 )];
530
531 let default_value = match backend {
532 DatabaseBackend::Postgres => "NOW()",
533 DatabaseBackend::MySql => "CURRENT_TIMESTAMP",
534 DatabaseBackend::Sqlite => "CURRENT_TIMESTAMP",
535 };
536
537 let result = build_modify_column_default(
538 &backend,
539 "events",
540 "created_at",
541 Some(default_value),
542 &schema,
543 );
544 assert!(result.is_ok());
545 let queries = result.unwrap();
546 let sql = queries
547 .iter()
548 .map(|q| q.build(backend))
549 .collect::<Vec<String>>()
550 .join("\n");
551
552 let suffix = format!(
553 "{}_function_default",
554 match backend {
555 DatabaseBackend::Postgres => "postgres",
556 DatabaseBackend::MySql => "mysql",
557 DatabaseBackend::Sqlite => "sqlite",
558 }
559 );
560
561 with_settings!({ snapshot_suffix => suffix }, {
562 assert_snapshot!(sql);
563 });
564 }
565
566 #[rstest]
568 #[case::postgres_drop_existing_default(DatabaseBackend::Postgres)]
569 #[case::mysql_drop_existing_default(DatabaseBackend::MySql)]
570 #[case::sqlite_drop_existing_default(DatabaseBackend::Sqlite)]
571 fn test_drop_existing_default(#[case] backend: DatabaseBackend) {
572 let mut status_col = col("status", ColumnType::Simple(SimpleColumnType::Text), false);
573 status_col.default = Some("'pending'".into());
574
575 let schema = vec![table_def(
576 "orders",
577 vec![
578 col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
579 status_col,
580 ],
581 vec![],
582 )];
583
584 let result = build_modify_column_default(
585 &backend, "orders", "status", None, &schema,
587 );
588 assert!(result.is_ok());
589 let queries = result.unwrap();
590 let sql = queries
591 .iter()
592 .map(|q| q.build(backend))
593 .collect::<Vec<String>>()
594 .join("\n");
595
596 let suffix = format!(
597 "{}_drop_existing_default",
598 match backend {
599 DatabaseBackend::Postgres => "postgres",
600 DatabaseBackend::MySql => "mysql",
601 DatabaseBackend::Sqlite => "sqlite",
602 }
603 );
604
605 with_settings!({ snapshot_suffix => suffix }, {
606 assert_snapshot!(sql);
607 });
608 }
609}