Skip to main content

vespertide_query/sql/
modify_column_default.rs

1use sea_query::{Alias, Query, Table};
2
3use vespertide_core::{ColumnDef, TableDef};
4
5use super::helpers::{
6    build_sea_column_def_with_table, build_sqlite_temp_table_create, normalize_enum_default,
7    recreate_indexes_after_rebuild,
8};
9use super::rename_table::build_rename_table;
10use super::types::{BuiltQuery, DatabaseBackend, RawSql};
11use crate::error::QueryError;
12
13/// Build SQL for changing column default value.
14pub fn build_modify_column_default(
15    backend: &DatabaseBackend,
16    table: &str,
17    column: &str,
18    new_default: Option<&str>,
19    current_schema: &[TableDef],
20) -> Result<Vec<BuiltQuery>, QueryError> {
21    let mut queries = Vec::new();
22
23    match backend {
24        DatabaseBackend::Postgres => {
25            let alter_sql = if let Some(default_value) = new_default {
26                // Look up column type to properly quote enum defaults
27                let column_type = current_schema
28                    .iter()
29                    .find(|t| t.name == table)
30                    .and_then(|t| t.columns.iter().find(|c| c.name == column))
31                    .map(|c| &c.r#type);
32
33                let normalized_default = if let Some(col_type) = column_type {
34                    normalize_enum_default(col_type, default_value)
35                } else {
36                    default_value.to_string()
37                };
38
39                format!(
40                    "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" SET DEFAULT {}",
41                    table, column, normalized_default
42                )
43            } else {
44                format!(
45                    "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" DROP DEFAULT",
46                    table, column
47                )
48            };
49            queries.push(BuiltQuery::Raw(RawSql::uniform(alter_sql)));
50        }
51        DatabaseBackend::MySql => {
52            // MySQL requires the full column definition in ALTER COLUMN
53            let table_def = current_schema
54                .iter()
55                .find(|t| t.name == table)
56                .ok_or_else(|| {
57                    QueryError::Other(format!("Table '{}' not found in current schema.", table))
58                })?;
59
60            let column_def = table_def
61                .columns
62                .iter()
63                .find(|c| c.name == column)
64                .ok_or_else(|| {
65                    QueryError::Other(format!(
66                        "Column '{}' not found in table '{}'.",
67                        column, table
68                    ))
69                })?;
70
71            // Create a modified column def with the new default
72            let modified_col_def = ColumnDef {
73                default: new_default.map(|s| s.into()),
74                ..column_def.clone()
75            };
76
77            let sea_col = build_sea_column_def_with_table(backend, table, &modified_col_def);
78
79            let stmt = Table::alter()
80                .table(Alias::new(table))
81                .modify_column(sea_col)
82                .to_owned();
83            queries.push(BuiltQuery::AlterTable(Box::new(stmt)));
84        }
85        DatabaseBackend::Sqlite => {
86            // SQLite doesn't support ALTER COLUMN for default changes
87            // Use temporary table approach
88            let table_def = current_schema
89                .iter()
90                .find(|t| t.name == table)
91                .ok_or_else(|| {
92                    QueryError::Other(format!("Table '{}' not found in current schema.", table))
93                })?;
94
95            // Create modified columns with the new default
96            let mut new_columns = table_def.columns.clone();
97            if let Some(col) = new_columns.iter_mut().find(|c| c.name == column) {
98                col.default = new_default.map(|s| s.into());
99            }
100
101            // Generate temporary table name
102            let temp_table = format!("{}_temp", table);
103
104            // 1. Create temporary table with modified column + CHECK constraints
105            let create_query = build_sqlite_temp_table_create(
106                backend,
107                &temp_table,
108                table,
109                &new_columns,
110                &table_def.constraints,
111            );
112            queries.push(create_query);
113
114            // 2. Copy data (all columns)
115            let column_aliases: Vec<Alias> = table_def
116                .columns
117                .iter()
118                .map(|c| Alias::new(&c.name))
119                .collect();
120            let mut select_query = Query::select();
121            for col_alias in &column_aliases {
122                select_query = select_query.column(col_alias.clone()).to_owned();
123            }
124            select_query = select_query.from(Alias::new(table)).to_owned();
125
126            let insert_stmt = Query::insert()
127                .into_table(Alias::new(&temp_table))
128                .columns(column_aliases.clone())
129                .select_from(select_query)
130                .unwrap()
131                .to_owned();
132            queries.push(BuiltQuery::Insert(Box::new(insert_stmt)));
133
134            // 3. Drop original table
135            let drop_table = Table::drop().table(Alias::new(table)).to_owned();
136            queries.push(BuiltQuery::DropTable(Box::new(drop_table)));
137
138            // 4. Rename temporary table to original name
139            queries.push(build_rename_table(&temp_table, table));
140
141            // 5. Recreate indexes (both regular and UNIQUE)
142            queries.extend(recreate_indexes_after_rebuild(
143                table,
144                &table_def.constraints,
145                &[],
146            ));
147        }
148    }
149
150    Ok(queries)
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{ColumnDef, ColumnType, SimpleColumnType, TableConstraint};

    /// Column fixture: only name, type and nullability are set; every optional
    /// attribute (default, comment, keys, index, foreign key) is `None`.
    fn col(name: &str, ty: ColumnType, nullable: bool) -> ColumnDef {
        ColumnDef {
            name: name.to_string(),
            r#type: ty,
            nullable,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    /// Table fixture without a description.
    fn table_def(
        name: &str,
        columns: Vec<ColumnDef>,
        constraints: Vec<TableConstraint>,
    ) -> TableDef {
        TableDef {
            name: name.to_string(),
            description: None,
            columns,
            constraints,
        }
    }

    /// Wraps a simple column type in `ColumnType::Simple`.
    fn simple(ty: SimpleColumnType) -> ColumnType {
        ColumnType::Simple(ty)
    }

    /// The non-nullable integer `id` column every fixture table starts with.
    fn id_col() -> ColumnDef {
        col("id", simple(SimpleColumnType::Integer), false)
    }

    /// `users(id, email)` schema with the given table constraints.
    fn users_schema(constraints: Vec<TableConstraint>) -> Vec<TableDef> {
        vec![table_def(
            "users",
            vec![id_col(), col("email", simple(SimpleColumnType::Text), true)],
            constraints,
        )]
    }

    /// Backend label used as the leading part of every snapshot suffix.
    fn backend_label(backend: DatabaseBackend) -> &'static str {
        match backend {
            DatabaseBackend::Postgres => "postgres",
            DatabaseBackend::MySql => "mysql",
            DatabaseBackend::Sqlite => "sqlite",
        }
    }

    /// Runs the builder, asserts that it succeeded, and renders every produced
    /// query for `backend`, one query per line.
    fn render_sql(
        backend: DatabaseBackend,
        table: &str,
        column: &str,
        new_default: Option<&str>,
        schema: &[TableDef],
    ) -> String {
        let result = build_modify_column_default(&backend, table, column, new_default, schema);
        assert!(result.is_ok());
        result
            .unwrap()
            .iter()
            .map(|query| query.build(backend))
            .collect::<Vec<String>>()
            .join("\n")
    }

    #[rstest]
    #[case::postgres_set_default(DatabaseBackend::Postgres, Some("'unknown'"))]
    #[case::postgres_drop_default(DatabaseBackend::Postgres, None)]
    #[case::mysql_set_default(DatabaseBackend::MySql, Some("'unknown'"))]
    #[case::mysql_drop_default(DatabaseBackend::MySql, None)]
    #[case::sqlite_set_default(DatabaseBackend::Sqlite, Some("'unknown'"))]
    #[case::sqlite_drop_default(DatabaseBackend::Sqlite, None)]
    fn test_build_modify_column_default(
        #[case] backend: DatabaseBackend,
        #[case] new_default: Option<&str>,
    ) {
        let schema = users_schema(vec![]);
        let sql = render_sql(backend, "users", "email", new_default, &schema);

        let action = match new_default {
            Some(_) => "set_default",
            None => "drop_default",
        };
        let suffix = format!("{}_{}_users", backend_label(backend), action);

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Test table not found error
    #[rstest]
    #[case::postgres_table_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_table_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_table_not_found(DatabaseBackend::Sqlite)]
    fn test_table_not_found(#[case] backend: DatabaseBackend) {
        // The Postgres builder never requires the table to be in the schema,
        // so only the schema-dependent backends are checked.
        if backend == DatabaseBackend::Postgres {
            return;
        }

        let result =
            build_modify_column_default(&backend, "users", "email", Some("'default'"), &[]);
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains("Table 'users' not found"));
    }

    /// Test column not found error
    #[rstest]
    #[case::postgres_column_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_column_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_column_not_found(DatabaseBackend::Sqlite)]
    fn test_column_not_found(#[case] backend: DatabaseBackend) {
        // Postgres falls back to the raw default for unknown columns, and the
        // SQLite column check is not asserted here: only MySQL is verified.
        if matches!(backend, DatabaseBackend::Postgres | DatabaseBackend::Sqlite) {
            return;
        }

        let schema = vec![table_def("users", vec![id_col()], vec![])];

        let result =
            build_modify_column_default(&backend, "users", "email", Some("'default'"), &schema);
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains("Column 'email' not found"));
    }

    /// Postgres must still emit SQL when the column is absent from the schema;
    /// this exercises the branch where no column type is available and the
    /// default is passed through unchanged.
    #[test]
    fn test_postgres_column_not_in_schema_uses_default_as_is() {
        // Only `id` exists; `status` is deliberately missing from the schema.
        let schema = vec![table_def("users", vec![id_col()], vec![])];

        let sql = render_sql(
            DatabaseBackend::Postgres,
            "users",
            "status",
            Some("'active'"),
            &schema,
        );

        assert!(sql.contains("ALTER TABLE \"users\" ALTER COLUMN \"status\" SET DEFAULT 'active'"));
    }

    /// Test with index - should recreate index after table rebuild (SQLite)
    #[rstest]
    #[case::postgres_with_index(DatabaseBackend::Postgres)]
    #[case::mysql_with_index(DatabaseBackend::MySql)]
    #[case::sqlite_with_index(DatabaseBackend::Sqlite)]
    fn test_modify_default_with_index(#[case] backend: DatabaseBackend) {
        let schema = users_schema(vec![TableConstraint::Index {
            name: Some("idx_users_email".into()),
            columns: vec!["email".into()],
        }]);

        let sql = render_sql(
            backend,
            "users",
            "email",
            Some("'default@example.com'"),
            &schema,
        );

        // The SQLite rebuild drops the table, so its index has to come back.
        if backend == DatabaseBackend::Sqlite {
            assert!(sql.contains("CREATE INDEX"));
            assert!(sql.contains("idx_users_email"));
        }

        let suffix = format!("{}_with_index", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Test changing default value from one to another
    #[rstest]
    #[case::postgres_change_default(DatabaseBackend::Postgres)]
    #[case::mysql_change_default(DatabaseBackend::MySql)]
    #[case::sqlite_change_default(DatabaseBackend::Sqlite)]
    fn test_change_default_value(#[case] backend: DatabaseBackend) {
        let email_col = ColumnDef {
            default: Some("'old@example.com'".into()),
            ..col("email", simple(SimpleColumnType::Text), true)
        };
        let schema = vec![table_def("users", vec![id_col(), email_col], vec![])];

        let sql = render_sql(
            backend,
            "users",
            "email",
            Some("'new@example.com'"),
            &schema,
        );

        let suffix = format!("{}_change_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Test with integer default value
    #[rstest]
    #[case::postgres_integer_default(DatabaseBackend::Postgres)]
    #[case::mysql_integer_default(DatabaseBackend::MySql)]
    #[case::sqlite_integer_default(DatabaseBackend::Sqlite)]
    fn test_integer_default(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "products",
            vec![
                id_col(),
                col("quantity", simple(SimpleColumnType::Integer), false),
            ],
            vec![],
        )];

        let sql = render_sql(backend, "products", "quantity", Some("0"), &schema);

        let suffix = format!("{}_integer_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Test with boolean default value
    #[rstest]
    #[case::postgres_boolean_default(DatabaseBackend::Postgres)]
    #[case::mysql_boolean_default(DatabaseBackend::MySql)]
    #[case::sqlite_boolean_default(DatabaseBackend::Sqlite)]
    fn test_boolean_default(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "users",
            vec![
                id_col(),
                col("is_active", simple(SimpleColumnType::Boolean), false),
            ],
            vec![],
        )];

        let sql = render_sql(backend, "users", "is_active", Some("true"), &schema);

        let suffix = format!("{}_boolean_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Test with function default (e.g., NOW(), CURRENT_TIMESTAMP)
    #[rstest]
    #[case::postgres_function_default(DatabaseBackend::Postgres)]
    #[case::mysql_function_default(DatabaseBackend::MySql)]
    #[case::sqlite_function_default(DatabaseBackend::Sqlite)]
    fn test_function_default(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "events",
            vec![
                id_col(),
                col("created_at", simple(SimpleColumnType::Timestamp), false),
            ],
            vec![],
        )];

        // Postgres uses NOW(); MySQL and SQLite both use CURRENT_TIMESTAMP.
        let default_value = if backend == DatabaseBackend::Postgres {
            "NOW()"
        } else {
            "CURRENT_TIMESTAMP"
        };

        let sql = render_sql(backend, "events", "created_at", Some(default_value), &schema);

        let suffix = format!("{}_function_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Test dropping default from column that had one
    #[rstest]
    #[case::postgres_drop_existing_default(DatabaseBackend::Postgres)]
    #[case::mysql_drop_existing_default(DatabaseBackend::MySql)]
    #[case::sqlite_drop_existing_default(DatabaseBackend::Sqlite)]
    fn test_drop_existing_default(#[case] backend: DatabaseBackend) {
        let status_col = ColumnDef {
            default: Some("'pending'".into()),
            ..col("status", simple(SimpleColumnType::Text), false)
        };
        let schema = vec![table_def("orders", vec![id_col(), status_col], vec![])];

        // `None` asks the builder to drop the existing default.
        let sql = render_sql(backend, "orders", "status", None, &schema);

        let suffix = format!("{}_drop_existing_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }
}