vespertide_query/sql/
modify_column_default.rs

1use sea_query::{Alias, Query, Table};
2
3use vespertide_core::{ColumnDef, TableDef};
4
5use super::create_table::build_create_table_for_backend;
6use super::helpers::build_sea_column_def_with_table;
7use super::rename_table::build_rename_table;
8use super::types::{BuiltQuery, DatabaseBackend, RawSql};
9use crate::error::QueryError;
10
11/// Build SQL for changing column default value.
12pub fn build_modify_column_default(
13    backend: &DatabaseBackend,
14    table: &str,
15    column: &str,
16    new_default: Option<&str>,
17    current_schema: &[TableDef],
18) -> Result<Vec<BuiltQuery>, QueryError> {
19    let mut queries = Vec::new();
20
21    match backend {
22        DatabaseBackend::Postgres => {
23            let alter_sql = if let Some(default_value) = new_default {
24                format!(
25                    "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" SET DEFAULT {}",
26                    table, column, default_value
27                )
28            } else {
29                format!(
30                    "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" DROP DEFAULT",
31                    table, column
32                )
33            };
34            queries.push(BuiltQuery::Raw(RawSql::uniform(alter_sql)));
35        }
36        DatabaseBackend::MySql => {
37            // MySQL requires the full column definition in ALTER COLUMN
38            let table_def = current_schema
39                .iter()
40                .find(|t| t.name == table)
41                .ok_or_else(|| {
42                    QueryError::Other(format!("Table '{}' not found in current schema.", table))
43                })?;
44
45            let column_def = table_def
46                .columns
47                .iter()
48                .find(|c| c.name == column)
49                .ok_or_else(|| {
50                    QueryError::Other(format!(
51                        "Column '{}' not found in table '{}'.",
52                        column, table
53                    ))
54                })?;
55
56            // Create a modified column def with the new default
57            let modified_col_def = ColumnDef {
58                default: new_default.map(|s| s.into()),
59                ..column_def.clone()
60            };
61
62            let sea_col = build_sea_column_def_with_table(backend, table, &modified_col_def);
63
64            let stmt = Table::alter()
65                .table(Alias::new(table))
66                .modify_column(sea_col)
67                .to_owned();
68            queries.push(BuiltQuery::AlterTable(Box::new(stmt)));
69        }
70        DatabaseBackend::Sqlite => {
71            // SQLite doesn't support ALTER COLUMN for default changes
72            // Use temporary table approach
73            let table_def = current_schema
74                .iter()
75                .find(|t| t.name == table)
76                .ok_or_else(|| {
77                    QueryError::Other(format!("Table '{}' not found in current schema.", table))
78                })?;
79
80            // Create modified columns with the new default
81            let mut new_columns = table_def.columns.clone();
82            if let Some(col) = new_columns.iter_mut().find(|c| c.name == column) {
83                col.default = new_default.map(|s| s.into());
84            }
85
86            // Generate temporary table name
87            let temp_table = format!("{}_temp", table);
88
89            // 1. Create temporary table with modified column
90            let create_temp_table = build_create_table_for_backend(
91                backend,
92                &temp_table,
93                &new_columns,
94                &table_def.constraints,
95            );
96            queries.push(BuiltQuery::CreateTable(Box::new(create_temp_table)));
97
98            // 2. Copy data (all columns)
99            let column_aliases: Vec<Alias> = table_def
100                .columns
101                .iter()
102                .map(|c| Alias::new(&c.name))
103                .collect();
104            let mut select_query = Query::select();
105            for col_alias in &column_aliases {
106                select_query = select_query.column(col_alias.clone()).to_owned();
107            }
108            select_query = select_query.from(Alias::new(table)).to_owned();
109
110            let insert_stmt = Query::insert()
111                .into_table(Alias::new(&temp_table))
112                .columns(column_aliases.clone())
113                .select_from(select_query)
114                .unwrap()
115                .to_owned();
116            queries.push(BuiltQuery::Insert(Box::new(insert_stmt)));
117
118            // 3. Drop original table
119            let drop_table = Table::drop().table(Alias::new(table)).to_owned();
120            queries.push(BuiltQuery::DropTable(Box::new(drop_table)));
121
122            // 4. Rename temporary table to original name
123            queries.push(build_rename_table(&temp_table, table));
124
125            // 5. Recreate indexes from Index constraints
126            for constraint in &table_def.constraints {
127                if let vespertide_core::TableConstraint::Index {
128                    name: idx_name,
129                    columns: idx_cols,
130                } = constraint
131                {
132                    let index_name =
133                        vespertide_naming::build_index_name(table, idx_cols, idx_name.as_deref());
134                    let mut idx_stmt = sea_query::Index::create();
135                    idx_stmt = idx_stmt.name(&index_name).to_owned();
136                    for col_name in idx_cols {
137                        idx_stmt = idx_stmt.col(Alias::new(col_name)).to_owned();
138                    }
139                    idx_stmt = idx_stmt.table(Alias::new(table)).to_owned();
140                    queries.push(BuiltQuery::CreateIndex(Box::new(idx_stmt)));
141                }
142            }
143        }
144    }
145
146    Ok(queries)
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{ColumnDef, ColumnType, SimpleColumnType, TableConstraint};

    /// Column with every optional attribute (default, comment, keys, index) unset.
    fn col(name: &str, ty: ColumnType, nullable: bool) -> ColumnDef {
        ColumnDef {
            name: name.to_string(),
            r#type: ty,
            nullable,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    /// Shorthand for a column of a `SimpleColumnType`.
    fn simple(name: &str, ty: SimpleColumnType, nullable: bool) -> ColumnDef {
        col(name, ColumnType::Simple(ty), nullable)
    }

    /// Table definition without a description.
    fn table_def(
        name: &str,
        columns: Vec<ColumnDef>,
        constraints: Vec<TableConstraint>,
    ) -> TableDef {
        TableDef {
            name: name.to_string(),
            description: None,
            columns,
            constraints,
        }
    }

    /// Backend label used as the leading part of every snapshot suffix.
    fn backend_label(backend: DatabaseBackend) -> &'static str {
        match backend {
            DatabaseBackend::Postgres => "postgres",
            DatabaseBackend::MySql => "mysql",
            DatabaseBackend::Sqlite => "sqlite",
        }
    }

    /// Runs the builder, requires it to succeed, and renders each query for `backend`,
    /// one statement per line.
    fn render_ok(
        backend: DatabaseBackend,
        table: &str,
        column: &str,
        new_default: Option<&str>,
        schema: &[TableDef],
    ) -> String {
        let result = build_modify_column_default(&backend, table, column, new_default, schema);
        assert!(result.is_ok());
        let rendered: Vec<String> = result
            .unwrap()
            .iter()
            .map(|query| query.build(backend))
            .collect();
        rendered.join("\n")
    }

    #[rstest]
    #[case::postgres_set_default(DatabaseBackend::Postgres, Some("'unknown'"))]
    #[case::postgres_drop_default(DatabaseBackend::Postgres, None)]
    #[case::mysql_set_default(DatabaseBackend::MySql, Some("'unknown'"))]
    #[case::mysql_drop_default(DatabaseBackend::MySql, None)]
    #[case::sqlite_set_default(DatabaseBackend::Sqlite, Some("'unknown'"))]
    #[case::sqlite_drop_default(DatabaseBackend::Sqlite, None)]
    fn test_build_modify_column_default(
        #[case] backend: DatabaseBackend,
        #[case] new_default: Option<&str>,
    ) {
        let schema = vec![table_def(
            "users",
            vec![
                simple("id", SimpleColumnType::Integer, false),
                simple("email", SimpleColumnType::Text, true),
            ],
            vec![],
        )];

        let sql = render_ok(backend, "users", "email", new_default, &schema);

        let action = match new_default {
            Some(_) => "set_default",
            None => "drop_default",
        };
        let suffix = format!("{}_{}_users", backend_label(backend), action);

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// An unknown table must be reported by the backends that consult the schema.
    #[rstest]
    #[case::postgres_table_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_table_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_table_not_found(DatabaseBackend::Sqlite)]
    fn test_table_not_found(#[case] backend: DatabaseBackend) {
        // Postgres never looks the table up, so there is nothing to detect there.
        if matches!(backend, DatabaseBackend::Postgres) {
            return;
        }

        let outcome =
            build_modify_column_default(&backend, "users", "email", Some("'default'"), &[]);
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Table 'users' not found"));
    }

    /// An unknown column must be reported (checked here for MySQL only).
    #[rstest]
    #[case::postgres_column_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_column_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_column_not_found(DatabaseBackend::Sqlite)]
    fn test_column_not_found(#[case] backend: DatabaseBackend) {
        // Postgres does not consult the schema; SQLite is not asserted on here.
        if matches!(backend, DatabaseBackend::Postgres | DatabaseBackend::Sqlite) {
            return;
        }

        let schema = vec![table_def(
            "users",
            vec![simple("id", SimpleColumnType::Integer, false)],
            vec![],
        )];

        let outcome =
            build_modify_column_default(&backend, "users", "email", Some("'default'"), &schema);
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Column 'email' not found"));
    }

    /// SQLite must recreate index constraints after rebuilding the table.
    #[rstest]
    #[case::postgres_with_index(DatabaseBackend::Postgres)]
    #[case::mysql_with_index(DatabaseBackend::MySql)]
    #[case::sqlite_with_index(DatabaseBackend::Sqlite)]
    fn test_modify_default_with_index(#[case] backend: DatabaseBackend) {
        let index = TableConstraint::Index {
            name: Some("idx_users_email".into()),
            columns: vec!["email".into()],
        };
        let schema = vec![table_def(
            "users",
            vec![
                simple("id", SimpleColumnType::Integer, false),
                simple("email", SimpleColumnType::Text, true),
            ],
            vec![index],
        )];

        let sql = render_ok(
            backend,
            "users",
            "email",
            Some("'default@example.com'"),
            &schema,
        );

        if backend == DatabaseBackend::Sqlite {
            assert!(sql.contains("CREATE INDEX"));
            assert!(sql.contains("idx_users_email"));
        }

        let suffix = format!("{}_with_index", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Replacing an existing default with a different one.
    #[rstest]
    #[case::postgres_change_default(DatabaseBackend::Postgres)]
    #[case::mysql_change_default(DatabaseBackend::MySql)]
    #[case::sqlite_change_default(DatabaseBackend::Sqlite)]
    fn test_change_default_value(#[case] backend: DatabaseBackend) {
        let email_col = ColumnDef {
            default: Some("'old@example.com'".into()),
            ..simple("email", SimpleColumnType::Text, true)
        };
        let schema = vec![table_def(
            "users",
            vec![simple("id", SimpleColumnType::Integer, false), email_col],
            vec![],
        )];

        let sql = render_ok(
            backend,
            "users",
            "email",
            Some("'new@example.com'"),
            &schema,
        );

        let suffix = format!("{}_change_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Numeric literal as the new default.
    #[rstest]
    #[case::postgres_integer_default(DatabaseBackend::Postgres)]
    #[case::mysql_integer_default(DatabaseBackend::MySql)]
    #[case::sqlite_integer_default(DatabaseBackend::Sqlite)]
    fn test_integer_default(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "products",
            vec![
                simple("id", SimpleColumnType::Integer, false),
                simple("quantity", SimpleColumnType::Integer, false),
            ],
            vec![],
        )];

        let sql = render_ok(backend, "products", "quantity", Some("0"), &schema);

        let suffix = format!("{}_integer_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Boolean literal as the new default.
    #[rstest]
    #[case::postgres_boolean_default(DatabaseBackend::Postgres)]
    #[case::mysql_boolean_default(DatabaseBackend::MySql)]
    #[case::sqlite_boolean_default(DatabaseBackend::Sqlite)]
    fn test_boolean_default(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "users",
            vec![
                simple("id", SimpleColumnType::Integer, false),
                simple("is_active", SimpleColumnType::Boolean, false),
            ],
            vec![],
        )];

        let sql = render_ok(backend, "users", "is_active", Some("true"), &schema);

        let suffix = format!("{}_boolean_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Function-call default (`NOW()` / `CURRENT_TIMESTAMP`).
    #[rstest]
    #[case::postgres_function_default(DatabaseBackend::Postgres)]
    #[case::mysql_function_default(DatabaseBackend::MySql)]
    #[case::sqlite_function_default(DatabaseBackend::Sqlite)]
    fn test_function_default(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "events",
            vec![
                simple("id", SimpleColumnType::Integer, false),
                simple("created_at", SimpleColumnType::Timestamp, false),
            ],
            vec![],
        )];

        let default_value = if backend == DatabaseBackend::Postgres {
            "NOW()"
        } else {
            "CURRENT_TIMESTAMP"
        };

        let sql = render_ok(
            backend,
            "events",
            "created_at",
            Some(default_value),
            &schema,
        );

        let suffix = format!("{}_function_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Removing a default that is currently set.
    #[rstest]
    #[case::postgres_drop_existing_default(DatabaseBackend::Postgres)]
    #[case::mysql_drop_existing_default(DatabaseBackend::MySql)]
    #[case::sqlite_drop_existing_default(DatabaseBackend::Sqlite)]
    fn test_drop_existing_default(#[case] backend: DatabaseBackend) {
        let status_col = ColumnDef {
            default: Some("'pending'".into()),
            ..simple("status", SimpleColumnType::Text, false)
        };
        let schema = vec![table_def(
            "orders",
            vec![simple("id", SimpleColumnType::Integer, false), status_col],
            vec![],
        )];

        // `None` drops the default.
        let sql = render_ok(backend, "orders", "status", None, &schema);

        let suffix = format!("{}_drop_existing_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }
}