1use sea_query::{Alias, Query, Table};
2
3use vespertide_core::{ColumnDef, TableDef};
4
5use super::create_table::build_create_table_for_backend;
6use super::helpers::build_sea_column_def_with_table;
7use super::rename_table::build_rename_table;
8use super::types::{BuiltQuery, DatabaseBackend, RawSql};
9use crate::error::QueryError;
10
11pub fn build_modify_column_default(
13 backend: &DatabaseBackend,
14 table: &str,
15 column: &str,
16 new_default: Option<&str>,
17 current_schema: &[TableDef],
18) -> Result<Vec<BuiltQuery>, QueryError> {
19 let mut queries = Vec::new();
20
21 match backend {
22 DatabaseBackend::Postgres => {
23 let alter_sql = if let Some(default_value) = new_default {
24 format!(
25 "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" SET DEFAULT {}",
26 table, column, default_value
27 )
28 } else {
29 format!(
30 "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" DROP DEFAULT",
31 table, column
32 )
33 };
34 queries.push(BuiltQuery::Raw(RawSql::uniform(alter_sql)));
35 }
36 DatabaseBackend::MySql => {
37 let table_def = current_schema
39 .iter()
40 .find(|t| t.name == table)
41 .ok_or_else(|| {
42 QueryError::Other(format!("Table '{}' not found in current schema.", table))
43 })?;
44
45 let column_def = table_def
46 .columns
47 .iter()
48 .find(|c| c.name == column)
49 .ok_or_else(|| {
50 QueryError::Other(format!(
51 "Column '{}' not found in table '{}'.",
52 column, table
53 ))
54 })?;
55
56 let modified_col_def = ColumnDef {
58 default: new_default.map(|s| s.into()),
59 ..column_def.clone()
60 };
61
62 let sea_col = build_sea_column_def_with_table(backend, table, &modified_col_def);
63
64 let stmt = Table::alter()
65 .table(Alias::new(table))
66 .modify_column(sea_col)
67 .to_owned();
68 queries.push(BuiltQuery::AlterTable(Box::new(stmt)));
69 }
70 DatabaseBackend::Sqlite => {
71 let table_def = current_schema
74 .iter()
75 .find(|t| t.name == table)
76 .ok_or_else(|| {
77 QueryError::Other(format!("Table '{}' not found in current schema.", table))
78 })?;
79
80 let mut new_columns = table_def.columns.clone();
82 if let Some(col) = new_columns.iter_mut().find(|c| c.name == column) {
83 col.default = new_default.map(|s| s.into());
84 }
85
86 let temp_table = format!("{}_temp", table);
88
89 let create_temp_table = build_create_table_for_backend(
91 backend,
92 &temp_table,
93 &new_columns,
94 &table_def.constraints,
95 );
96 queries.push(BuiltQuery::CreateTable(Box::new(create_temp_table)));
97
98 let column_aliases: Vec<Alias> = table_def
100 .columns
101 .iter()
102 .map(|c| Alias::new(&c.name))
103 .collect();
104 let mut select_query = Query::select();
105 for col_alias in &column_aliases {
106 select_query = select_query.column(col_alias.clone()).to_owned();
107 }
108 select_query = select_query.from(Alias::new(table)).to_owned();
109
110 let insert_stmt = Query::insert()
111 .into_table(Alias::new(&temp_table))
112 .columns(column_aliases.clone())
113 .select_from(select_query)
114 .unwrap()
115 .to_owned();
116 queries.push(BuiltQuery::Insert(Box::new(insert_stmt)));
117
118 let drop_table = Table::drop().table(Alias::new(table)).to_owned();
120 queries.push(BuiltQuery::DropTable(Box::new(drop_table)));
121
122 queries.push(build_rename_table(&temp_table, table));
124
125 for constraint in &table_def.constraints {
127 if let vespertide_core::TableConstraint::Index {
128 name: idx_name,
129 columns: idx_cols,
130 } = constraint
131 {
132 let index_name =
133 vespertide_naming::build_index_name(table, idx_cols, idx_name.as_deref());
134 let mut idx_stmt = sea_query::Index::create();
135 idx_stmt = idx_stmt.name(&index_name).to_owned();
136 for col_name in idx_cols {
137 idx_stmt = idx_stmt.col(Alias::new(col_name)).to_owned();
138 }
139 idx_stmt = idx_stmt.table(Alias::new(table)).to_owned();
140 queries.push(BuiltQuery::CreateIndex(Box::new(idx_stmt)));
141 }
142 }
143 }
144 }
145
146 Ok(queries)
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{ColumnDef, ColumnType, SimpleColumnType, TableConstraint};

    /// Builds a column with no default, comment, key, uniqueness, index or foreign key.
    fn col(name: &str, ty: ColumnType, nullable: bool) -> ColumnDef {
        ColumnDef {
            name: name.to_string(),
            r#type: ty,
            nullable,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    /// Builds a table definition without a description.
    fn table_def(
        name: &str,
        columns: Vec<ColumnDef>,
        constraints: Vec<TableConstraint>,
    ) -> TableDef {
        TableDef {
            name: name.to_string(),
            description: None,
            columns,
            constraints,
        }
    }

    /// Lowercase backend label used as the prefix of every snapshot suffix.
    fn backend_label(backend: DatabaseBackend) -> &'static str {
        match backend {
            DatabaseBackend::Postgres => "postgres",
            DatabaseBackend::MySql => "mysql",
            DatabaseBackend::Sqlite => "sqlite",
        }
    }

    /// Renders the generated statements for `backend`, one statement per line.
    fn render(queries: &[BuiltQuery], backend: DatabaseBackend) -> String {
        queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join("\n")
    }

    #[rstest]
    #[case::postgres_set_default(DatabaseBackend::Postgres, Some("'unknown'"))]
    #[case::postgres_drop_default(DatabaseBackend::Postgres, None)]
    #[case::mysql_set_default(DatabaseBackend::MySql, Some("'unknown'"))]
    #[case::mysql_drop_default(DatabaseBackend::MySql, None)]
    #[case::sqlite_set_default(DatabaseBackend::Sqlite, Some("'unknown'"))]
    #[case::sqlite_drop_default(DatabaseBackend::Sqlite, None)]
    fn test_build_modify_column_default(
        #[case] backend: DatabaseBackend,
        #[case] new_default: Option<&str>,
    ) {
        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col("email", ColumnType::Simple(SimpleColumnType::Text), true),
            ],
            vec![],
        )];

        let queries =
            build_modify_column_default(&backend, "users", "email", new_default, &schema)
                .expect("modifying the default of an existing column should succeed");
        let sql = render(&queries, backend);

        let action = if new_default.is_some() {
            "set_default"
        } else {
            "drop_default"
        };
        let suffix = format!("{}_{}_users", backend_label(backend), action);

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_table_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_table_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_table_not_found(DatabaseBackend::Sqlite)]
    fn test_table_not_found(#[case] backend: DatabaseBackend) {
        // The Postgres path never consults `current_schema`, so it cannot report a
        // missing table; only the schema-driven backends are asserted.
        if matches!(backend, DatabaseBackend::Postgres) {
            return;
        }

        let err_msg =
            build_modify_column_default(&backend, "users", "email", Some("'default'"), &[])
                .expect_err("an empty schema should be rejected")
                .to_string();
        assert!(err_msg.contains("Table 'users' not found"));
    }

    #[rstest]
    #[case::postgres_column_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_column_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_column_not_found(DatabaseBackend::Sqlite)]
    fn test_column_not_found(#[case] backend: DatabaseBackend) {
        // Only MySQL is asserted here: Postgres never consults `current_schema`, and the
        // SQLite path is not covered by this case.
        if matches!(backend, DatabaseBackend::Postgres | DatabaseBackend::Sqlite) {
            return;
        }

        let schema = vec![table_def(
            "users",
            vec![col(
                "id",
                ColumnType::Simple(SimpleColumnType::Integer),
                false,
            )],
            vec![],
        )];

        let err_msg =
            build_modify_column_default(&backend, "users", "email", Some("'default'"), &schema)
                .expect_err("a missing column should be rejected")
                .to_string();
        assert!(err_msg.contains("Column 'email' not found"));
    }

    #[rstest]
    #[case::postgres_with_index(DatabaseBackend::Postgres)]
    #[case::mysql_with_index(DatabaseBackend::MySql)]
    #[case::sqlite_with_index(DatabaseBackend::Sqlite)]
    fn test_modify_default_with_index(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col("email", ColumnType::Simple(SimpleColumnType::Text), true),
            ],
            vec![TableConstraint::Index {
                name: Some("idx_users_email".into()),
                columns: vec!["email".into()],
            }],
        )];

        let queries = build_modify_column_default(
            &backend,
            "users",
            "email",
            Some("'default@example.com'"),
            &schema,
        )
        .expect("modifying the default of an indexed column should succeed");
        let sql = render(&queries, backend);

        // SQLite rebuilds the table, so the index must be recreated afterwards.
        if matches!(backend, DatabaseBackend::Sqlite) {
            assert!(sql.contains("CREATE INDEX"));
            assert!(sql.contains("idx_users_email"));
        }

        let suffix = format!("{}_with_index", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_change_default(DatabaseBackend::Postgres)]
    #[case::mysql_change_default(DatabaseBackend::MySql)]
    #[case::sqlite_change_default(DatabaseBackend::Sqlite)]
    fn test_change_default_value(#[case] backend: DatabaseBackend) {
        let mut email_col = col("email", ColumnType::Simple(SimpleColumnType::Text), true);
        email_col.default = Some("'old@example.com'".into());

        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                email_col,
            ],
            vec![],
        )];

        let queries = build_modify_column_default(
            &backend,
            "users",
            "email",
            Some("'new@example.com'"),
            &schema,
        )
        .expect("replacing an existing default should succeed");
        let sql = render(&queries, backend);

        let suffix = format!("{}_change_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_integer_default(DatabaseBackend::Postgres)]
    #[case::mysql_integer_default(DatabaseBackend::MySql)]
    #[case::sqlite_integer_default(DatabaseBackend::Sqlite)]
    fn test_integer_default(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "products",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "quantity",
                    ColumnType::Simple(SimpleColumnType::Integer),
                    false,
                ),
            ],
            vec![],
        )];

        let queries =
            build_modify_column_default(&backend, "products", "quantity", Some("0"), &schema)
                .expect("setting an integer default should succeed");
        let sql = render(&queries, backend);

        let suffix = format!("{}_integer_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_boolean_default(DatabaseBackend::Postgres)]
    #[case::mysql_boolean_default(DatabaseBackend::MySql)]
    #[case::sqlite_boolean_default(DatabaseBackend::Sqlite)]
    fn test_boolean_default(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "is_active",
                    ColumnType::Simple(SimpleColumnType::Boolean),
                    false,
                ),
            ],
            vec![],
        )];

        let queries =
            build_modify_column_default(&backend, "users", "is_active", Some("true"), &schema)
                .expect("setting a boolean default should succeed");
        let sql = render(&queries, backend);

        let suffix = format!("{}_boolean_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_function_default(DatabaseBackend::Postgres)]
    #[case::mysql_function_default(DatabaseBackend::MySql)]
    #[case::sqlite_function_default(DatabaseBackend::Sqlite)]
    fn test_function_default(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "events",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "created_at",
                    ColumnType::Simple(SimpleColumnType::Timestamp),
                    false,
                ),
            ],
            vec![],
        )];

        // Each backend gets the timestamp expression it is exercised with.
        let default_value = match backend {
            DatabaseBackend::Postgres => "NOW()",
            DatabaseBackend::MySql => "CURRENT_TIMESTAMP",
            DatabaseBackend::Sqlite => "CURRENT_TIMESTAMP",
        };

        let queries = build_modify_column_default(
            &backend,
            "events",
            "created_at",
            Some(default_value),
            &schema,
        )
        .expect("setting a function default should succeed");
        let sql = render(&queries, backend);

        let suffix = format!("{}_function_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_drop_existing_default(DatabaseBackend::Postgres)]
    #[case::mysql_drop_existing_default(DatabaseBackend::MySql)]
    #[case::sqlite_drop_existing_default(DatabaseBackend::Sqlite)]
    fn test_drop_existing_default(#[case] backend: DatabaseBackend) {
        let mut status_col = col("status", ColumnType::Simple(SimpleColumnType::Text), false);
        status_col.default = Some("'pending'".into());

        let schema = vec![table_def(
            "orders",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                status_col,
            ],
            vec![],
        )];

        let queries = build_modify_column_default(&backend, "orders", "status", None, &schema)
            .expect("dropping an existing default should succeed");
        let sql = render(&queries, backend);

        let suffix = format!("{}_drop_existing_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }
}
561}