Skip to main content

vespertide_query/sql/
modify_column_default.rs

1use sea_query::{Alias, Query, Table};
2
3use vespertide_core::{ColumnDef, TableDef};
4
5use super::helpers::{
6    build_sea_column_def_with_table, build_sqlite_temp_table_create, normalize_enum_default,
7    recreate_indexes_after_rebuild,
8};
9use super::rename_table::build_rename_table;
10use super::types::{BuiltQuery, DatabaseBackend, RawSql};
11use crate::error::QueryError;
12
13/// Build SQL for changing column default value.
14pub fn build_modify_column_default(
15    backend: &DatabaseBackend,
16    table: &str,
17    column: &str,
18    new_default: Option<&str>,
19    current_schema: &[TableDef],
20    pending_constraints: &[vespertide_core::TableConstraint],
21) -> Result<Vec<BuiltQuery>, QueryError> {
22    let mut queries = Vec::new();
23
24    match backend {
25        DatabaseBackend::Postgres => {
26            let alter_sql = if let Some(default_value) = new_default {
27                // Look up column type to properly quote enum defaults
28                let column_type = current_schema
29                    .iter()
30                    .find(|t| t.name == table)
31                    .and_then(|t| t.columns.iter().find(|c| c.name == column))
32                    .map(|c| &c.r#type);
33
34                let normalized_default = if let Some(col_type) = column_type {
35                    normalize_enum_default(col_type, default_value)
36                } else {
37                    default_value.to_string()
38                };
39
40                format!(
41                    "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" SET DEFAULT {}",
42                    table, column, normalized_default
43                )
44            } else {
45                format!(
46                    "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" DROP DEFAULT",
47                    table, column
48                )
49            };
50            queries.push(BuiltQuery::Raw(RawSql::uniform(alter_sql)));
51        }
52        DatabaseBackend::MySql => {
53            // MySQL requires the full column definition in ALTER COLUMN
54            let table_def = current_schema
55                .iter()
56                .find(|t| t.name == table)
57                .ok_or_else(|| {
58                    QueryError::Other(format!("Table '{}' not found in current schema.", table))
59                })?;
60
61            let column_def = table_def
62                .columns
63                .iter()
64                .find(|c| c.name == column)
65                .ok_or_else(|| {
66                    QueryError::Other(format!(
67                        "Column '{}' not found in table '{}'.",
68                        column, table
69                    ))
70                })?;
71
72            // Create a modified column def with the new default
73            let modified_col_def = ColumnDef {
74                default: new_default.map(|s| s.into()),
75                ..column_def.clone()
76            };
77
78            let sea_col = build_sea_column_def_with_table(backend, table, &modified_col_def);
79
80            let stmt = Table::alter()
81                .table(Alias::new(table))
82                .modify_column(sea_col)
83                .to_owned();
84            queries.push(BuiltQuery::AlterTable(Box::new(stmt)));
85        }
86        DatabaseBackend::Sqlite => {
87            // SQLite doesn't support ALTER COLUMN for default changes
88            // Use temporary table approach
89            let table_def = current_schema
90                .iter()
91                .find(|t| t.name == table)
92                .ok_or_else(|| {
93                    QueryError::Other(format!("Table '{}' not found in current schema.", table))
94                })?;
95
96            // Create modified columns with the new default
97            let mut new_columns = table_def.columns.clone();
98            if let Some(col) = new_columns.iter_mut().find(|c| c.name == column) {
99                col.default = new_default.map(|s| s.into());
100            }
101
102            // Generate temporary table name
103            let temp_table = format!("{}_temp", table);
104
105            // 1. Create temporary table with modified column + CHECK constraints
106            let create_query = build_sqlite_temp_table_create(
107                backend,
108                &temp_table,
109                table,
110                &new_columns,
111                &table_def.constraints,
112            );
113            queries.push(create_query);
114
115            // 2. Copy data (all columns)
116            let column_aliases: Vec<Alias> = table_def
117                .columns
118                .iter()
119                .map(|c| Alias::new(&c.name))
120                .collect();
121            let mut select_query = Query::select();
122            for col_alias in &column_aliases {
123                select_query = select_query.column(col_alias.clone()).to_owned();
124            }
125            select_query = select_query.from(Alias::new(table)).to_owned();
126
127            let insert_stmt = Query::insert()
128                .into_table(Alias::new(&temp_table))
129                .columns(column_aliases.clone())
130                .select_from(select_query)
131                .unwrap()
132                .to_owned();
133            queries.push(BuiltQuery::Insert(Box::new(insert_stmt)));
134
135            // 3. Drop original table
136            let drop_table = Table::drop().table(Alias::new(table)).to_owned();
137            queries.push(BuiltQuery::DropTable(Box::new(drop_table)));
138
139            // 4. Rename temporary table to original name
140            queries.push(build_rename_table(&temp_table, table));
141
142            // 5. Recreate indexes (both regular and UNIQUE)
143            queries.extend(recreate_indexes_after_rebuild(
144                table,
145                &table_def.constraints,
146                pending_constraints,
147            ));
148        }
149    }
150
151    Ok(queries)
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{ColumnDef, ColumnType, SimpleColumnType, TableConstraint};

    /// Build a column with the given name, type and nullability; every optional attribute
    /// (default, comment, keys, index, foreign key) is left unset.
    fn col(name: &str, ty: ColumnType, nullable: bool) -> ColumnDef {
        ColumnDef {
            name: name.to_string(),
            r#type: ty,
            nullable,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    /// Build a table definition without a description.
    fn table_def(
        name: &str,
        columns: Vec<ColumnDef>,
        constraints: Vec<TableConstraint>,
    ) -> TableDef {
        TableDef {
            name: name.to_string(),
            description: None,
            columns,
            constraints,
        }
    }

    /// Lower-case backend name used as the first component of snapshot suffixes.
    fn backend_label(backend: DatabaseBackend) -> &'static str {
        match backend {
            DatabaseBackend::Postgres => "postgres",
            DatabaseBackend::MySql => "mysql",
            DatabaseBackend::Sqlite => "sqlite",
        }
    }

    /// Run `build_modify_column_default` (with no pending constraints) and render every
    /// resulting query for `backend`, one statement per line.
    ///
    /// Panics with the builder's own error message on failure, so a broken case reports why
    /// it failed instead of a bare `result.is_ok()` assertion.
    fn build_sql(
        backend: DatabaseBackend,
        table: &str,
        column: &str,
        new_default: Option<&str>,
        schema: &[TableDef],
    ) -> String {
        let queries =
            build_modify_column_default(&backend, table, column, new_default, schema, &[])
                .unwrap_or_else(|e| panic!("build_modify_column_default failed: {}", e));
        queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join("\n")
    }

    #[rstest]
    #[case::postgres_set_default(DatabaseBackend::Postgres, Some("'unknown'"))]
    #[case::postgres_drop_default(DatabaseBackend::Postgres, None)]
    #[case::mysql_set_default(DatabaseBackend::MySql, Some("'unknown'"))]
    #[case::mysql_drop_default(DatabaseBackend::MySql, None)]
    #[case::sqlite_set_default(DatabaseBackend::Sqlite, Some("'unknown'"))]
    #[case::sqlite_drop_default(DatabaseBackend::Sqlite, None)]
    fn test_build_modify_column_default(
        #[case] backend: DatabaseBackend,
        #[case] new_default: Option<&str>,
    ) {
        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col("email", ColumnType::Simple(SimpleColumnType::Text), true),
            ],
            vec![],
        )];

        let sql = build_sql(backend, "users", "email", new_default, &schema);

        let action = if new_default.is_some() {
            "set_default"
        } else {
            "drop_default"
        };
        let suffix = format!("{}_{}_users", backend_label(backend), action);

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Test table not found error
    #[rstest]
    #[case::postgres_table_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_table_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_table_not_found(DatabaseBackend::Sqlite)]
    fn test_table_not_found(#[case] backend: DatabaseBackend) {
        // Postgres doesn't need schema lookup for default changes
        if backend == DatabaseBackend::Postgres {
            return;
        }

        let result =
            build_modify_column_default(&backend, "users", "email", Some("'default'"), &[], &[]);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Table 'users' not found"));
    }

    /// Test column not found error
    #[rstest]
    #[case::postgres_column_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_column_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_column_not_found(DatabaseBackend::Sqlite)]
    fn test_column_not_found(#[case] backend: DatabaseBackend) {
        // Postgres doesn't need schema lookup for default changes
        // SQLite doesn't validate column existence in modify_column_default
        if backend == DatabaseBackend::Postgres || backend == DatabaseBackend::Sqlite {
            return;
        }

        let schema = vec![table_def(
            "users",
            vec![col(
                "id",
                ColumnType::Simple(SimpleColumnType::Integer),
                false,
            )],
            vec![],
        )];

        let result = build_modify_column_default(
            &backend,
            "users",
            "email",
            Some("'default'"),
            &schema,
            &[],
        );
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Column 'email' not found"));
    }

    /// Test Postgres default change when column is not in schema
    /// This covers the fallback path where column_type is None
    #[test]
    fn test_postgres_column_not_in_schema_uses_default_as_is() {
        let schema = vec![table_def(
            "users",
            vec![col(
                "id",
                ColumnType::Simple(SimpleColumnType::Integer),
                false,
            )],
            // Note: "status" column is NOT in the schema
            vec![],
        )];

        // Postgres doesn't error when column isn't found - it just uses the default as-is
        let sql = build_sql(
            DatabaseBackend::Postgres,
            "users",
            "status", // column not in schema
            Some("'active'"),
            &schema,
        );

        // Should still generate valid SQL, using the default value as-is
        assert!(sql.contains("ALTER TABLE \"users\" ALTER COLUMN \"status\" SET DEFAULT 'active'"));
    }

    /// Test with index - should recreate index after table rebuild (SQLite)
    #[rstest]
    #[case::postgres_with_index(DatabaseBackend::Postgres)]
    #[case::mysql_with_index(DatabaseBackend::MySql)]
    #[case::sqlite_with_index(DatabaseBackend::Sqlite)]
    fn test_modify_default_with_index(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col("email", ColumnType::Simple(SimpleColumnType::Text), true),
            ],
            vec![TableConstraint::Index {
                name: Some("idx_users_email".into()),
                columns: vec!["email".into()],
            }],
        )];

        let sql = build_sql(
            backend,
            "users",
            "email",
            Some("'default@example.com'"),
            &schema,
        );

        // SQLite should recreate the index after table rebuild
        if backend == DatabaseBackend::Sqlite {
            assert!(sql.contains("CREATE INDEX"));
            assert!(sql.contains("idx_users_email"));
        }

        let suffix = format!("{}_with_index", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Test changing default value from one to another
    #[rstest]
    #[case::postgres_change_default(DatabaseBackend::Postgres)]
    #[case::mysql_change_default(DatabaseBackend::MySql)]
    #[case::sqlite_change_default(DatabaseBackend::Sqlite)]
    fn test_change_default_value(#[case] backend: DatabaseBackend) {
        let mut email_col = col("email", ColumnType::Simple(SimpleColumnType::Text), true);
        email_col.default = Some("'old@example.com'".into());

        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                email_col,
            ],
            vec![],
        )];

        let sql = build_sql(
            backend,
            "users",
            "email",
            Some("'new@example.com'"),
            &schema,
        );

        let suffix = format!("{}_change_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Test with integer default value
    #[rstest]
    #[case::postgres_integer_default(DatabaseBackend::Postgres)]
    #[case::mysql_integer_default(DatabaseBackend::MySql)]
    #[case::sqlite_integer_default(DatabaseBackend::Sqlite)]
    fn test_integer_default(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "products",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "quantity",
                    ColumnType::Simple(SimpleColumnType::Integer),
                    false,
                ),
            ],
            vec![],
        )];

        let sql = build_sql(backend, "products", "quantity", Some("0"), &schema);

        let suffix = format!("{}_integer_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Test with boolean default value
    #[rstest]
    #[case::postgres_boolean_default(DatabaseBackend::Postgres)]
    #[case::mysql_boolean_default(DatabaseBackend::MySql)]
    #[case::sqlite_boolean_default(DatabaseBackend::Sqlite)]
    fn test_boolean_default(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "is_active",
                    ColumnType::Simple(SimpleColumnType::Boolean),
                    false,
                ),
            ],
            vec![],
        )];

        let sql = build_sql(backend, "users", "is_active", Some("true"), &schema);

        let suffix = format!("{}_boolean_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Test with function default (e.g., NOW(), CURRENT_TIMESTAMP)
    #[rstest]
    #[case::postgres_function_default(DatabaseBackend::Postgres)]
    #[case::mysql_function_default(DatabaseBackend::MySql)]
    #[case::sqlite_function_default(DatabaseBackend::Sqlite)]
    fn test_function_default(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "events",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "created_at",
                    ColumnType::Simple(SimpleColumnType::Timestamp),
                    false,
                ),
            ],
            vec![],
        )];

        let default_value = match backend {
            DatabaseBackend::Postgres => "NOW()",
            DatabaseBackend::MySql => "CURRENT_TIMESTAMP",
            DatabaseBackend::Sqlite => "CURRENT_TIMESTAMP",
        };

        let sql = build_sql(
            backend,
            "events",
            "created_at",
            Some(default_value),
            &schema,
        );

        let suffix = format!("{}_function_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Test dropping default from column that had one
    #[rstest]
    #[case::postgres_drop_existing_default(DatabaseBackend::Postgres)]
    #[case::mysql_drop_existing_default(DatabaseBackend::MySql)]
    #[case::sqlite_drop_existing_default(DatabaseBackend::Sqlite)]
    fn test_drop_existing_default(#[case] backend: DatabaseBackend) {
        let mut status_col = col("status", ColumnType::Simple(SimpleColumnType::Text), false);
        status_col.default = Some("'pending'".into());

        let schema = vec![table_def(
            "orders",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                status_col,
            ],
            vec![],
        )];

        let sql = build_sql(
            backend,
            "orders",
            "status",
            None, // Drop default
            &schema,
        );

        let suffix = format!("{}_drop_existing_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }
}