Skip to main content

vespertide_query/sql/
modify_column_default.rs

1use sea_query::{Alias, Query, Table};
2
3use vespertide_core::{ColumnDef, TableDef};
4
5use super::create_table::build_create_table_for_backend;
6use super::helpers::{build_sea_column_def_with_table, normalize_enum_default};
7use super::rename_table::build_rename_table;
8use super::types::{BuiltQuery, DatabaseBackend, RawSql};
9use crate::error::QueryError;
10
11/// Build SQL for changing column default value.
12pub fn build_modify_column_default(
13    backend: &DatabaseBackend,
14    table: &str,
15    column: &str,
16    new_default: Option<&str>,
17    current_schema: &[TableDef],
18) -> Result<Vec<BuiltQuery>, QueryError> {
19    let mut queries = Vec::new();
20
21    match backend {
22        DatabaseBackend::Postgres => {
23            let alter_sql = if let Some(default_value) = new_default {
24                // Look up column type to properly quote enum defaults
25                let column_type = current_schema
26                    .iter()
27                    .find(|t| t.name == table)
28                    .and_then(|t| t.columns.iter().find(|c| c.name == column))
29                    .map(|c| &c.r#type);
30
31                let normalized_default = if let Some(col_type) = column_type {
32                    normalize_enum_default(col_type, default_value)
33                } else {
34                    default_value.to_string()
35                };
36
37                format!(
38                    "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" SET DEFAULT {}",
39                    table, column, normalized_default
40                )
41            } else {
42                format!(
43                    "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" DROP DEFAULT",
44                    table, column
45                )
46            };
47            queries.push(BuiltQuery::Raw(RawSql::uniform(alter_sql)));
48        }
49        DatabaseBackend::MySql => {
50            // MySQL requires the full column definition in ALTER COLUMN
51            let table_def = current_schema
52                .iter()
53                .find(|t| t.name == table)
54                .ok_or_else(|| {
55                    QueryError::Other(format!("Table '{}' not found in current schema.", table))
56                })?;
57
58            let column_def = table_def
59                .columns
60                .iter()
61                .find(|c| c.name == column)
62                .ok_or_else(|| {
63                    QueryError::Other(format!(
64                        "Column '{}' not found in table '{}'.",
65                        column, table
66                    ))
67                })?;
68
69            // Create a modified column def with the new default
70            let modified_col_def = ColumnDef {
71                default: new_default.map(|s| s.into()),
72                ..column_def.clone()
73            };
74
75            let sea_col = build_sea_column_def_with_table(backend, table, &modified_col_def);
76
77            let stmt = Table::alter()
78                .table(Alias::new(table))
79                .modify_column(sea_col)
80                .to_owned();
81            queries.push(BuiltQuery::AlterTable(Box::new(stmt)));
82        }
83        DatabaseBackend::Sqlite => {
84            // SQLite doesn't support ALTER COLUMN for default changes
85            // Use temporary table approach
86            let table_def = current_schema
87                .iter()
88                .find(|t| t.name == table)
89                .ok_or_else(|| {
90                    QueryError::Other(format!("Table '{}' not found in current schema.", table))
91                })?;
92
93            // Create modified columns with the new default
94            let mut new_columns = table_def.columns.clone();
95            if let Some(col) = new_columns.iter_mut().find(|c| c.name == column) {
96                col.default = new_default.map(|s| s.into());
97            }
98
99            // Generate temporary table name
100            let temp_table = format!("{}_temp", table);
101
102            // 1. Create temporary table with modified column
103            let create_temp_table = build_create_table_for_backend(
104                backend,
105                &temp_table,
106                &new_columns,
107                &table_def.constraints,
108            );
109            queries.push(BuiltQuery::CreateTable(Box::new(create_temp_table)));
110
111            // 2. Copy data (all columns)
112            let column_aliases: Vec<Alias> = table_def
113                .columns
114                .iter()
115                .map(|c| Alias::new(&c.name))
116                .collect();
117            let mut select_query = Query::select();
118            for col_alias in &column_aliases {
119                select_query = select_query.column(col_alias.clone()).to_owned();
120            }
121            select_query = select_query.from(Alias::new(table)).to_owned();
122
123            let insert_stmt = Query::insert()
124                .into_table(Alias::new(&temp_table))
125                .columns(column_aliases.clone())
126                .select_from(select_query)
127                .unwrap()
128                .to_owned();
129            queries.push(BuiltQuery::Insert(Box::new(insert_stmt)));
130
131            // 3. Drop original table
132            let drop_table = Table::drop().table(Alias::new(table)).to_owned();
133            queries.push(BuiltQuery::DropTable(Box::new(drop_table)));
134
135            // 4. Rename temporary table to original name
136            queries.push(build_rename_table(&temp_table, table));
137
138            // 5. Recreate indexes from Index constraints
139            for constraint in &table_def.constraints {
140                if let vespertide_core::TableConstraint::Index {
141                    name: idx_name,
142                    columns: idx_cols,
143                } = constraint
144                {
145                    let index_name =
146                        vespertide_naming::build_index_name(table, idx_cols, idx_name.as_deref());
147                    let mut idx_stmt = sea_query::Index::create();
148                    idx_stmt = idx_stmt.name(&index_name).to_owned();
149                    for col_name in idx_cols {
150                        idx_stmt = idx_stmt.col(Alias::new(col_name)).to_owned();
151                    }
152                    idx_stmt = idx_stmt.table(Alias::new(table)).to_owned();
153                    queries.push(BuiltQuery::CreateIndex(Box::new(idx_stmt)));
154                }
155            }
156        }
157    }
158
159    Ok(queries)
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{ColumnDef, ColumnType, SimpleColumnType, TableConstraint};

    /// Column fixture with no default, comment, key, uniqueness, index or foreign key.
    fn col(name: &str, ty: ColumnType, nullable: bool) -> ColumnDef {
        ColumnDef {
            name: name.to_string(),
            r#type: ty,
            nullable,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    /// Table fixture without a description.
    fn table_def(
        name: &str,
        columns: Vec<ColumnDef>,
        constraints: Vec<TableConstraint>,
    ) -> TableDef {
        TableDef {
            name: name.to_string(),
            description: None,
            columns,
            constraints,
        }
    }

    /// Backend label used as the prefix of snapshot suffixes.
    fn backend_name(backend: DatabaseBackend) -> &'static str {
        match backend {
            DatabaseBackend::Postgres => "postgres",
            DatabaseBackend::MySql => "mysql",
            DatabaseBackend::Sqlite => "sqlite",
        }
    }

    /// Render every query for `backend`, one statement per line.
    fn render(queries: &[BuiltQuery], backend: DatabaseBackend) -> String {
        queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join("\n")
    }

    /// Run `build_modify_column_default`, panic with the error on failure, and render
    /// the result.
    fn modify_sql(
        backend: DatabaseBackend,
        table: &str,
        column: &str,
        new_default: Option<&str>,
        schema: &[TableDef],
    ) -> String {
        let queries = build_modify_column_default(&backend, table, column, new_default, schema)
            .expect("build_modify_column_default should succeed");
        render(&queries, backend)
    }

    #[rstest]
    #[case::postgres_set_default(DatabaseBackend::Postgres, Some("'unknown'"))]
    #[case::postgres_drop_default(DatabaseBackend::Postgres, None)]
    #[case::mysql_set_default(DatabaseBackend::MySql, Some("'unknown'"))]
    #[case::mysql_drop_default(DatabaseBackend::MySql, None)]
    #[case::sqlite_set_default(DatabaseBackend::Sqlite, Some("'unknown'"))]
    #[case::sqlite_drop_default(DatabaseBackend::Sqlite, None)]
    fn test_build_modify_column_default(
        #[case] backend: DatabaseBackend,
        #[case] new_default: Option<&str>,
    ) {
        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col("email", ColumnType::Simple(SimpleColumnType::Text), true),
            ],
            vec![],
        )];

        let sql = modify_sql(backend, "users", "email", new_default, &schema);

        let action = if new_default.is_some() {
            "set_default"
        } else {
            "drop_default"
        };
        let suffix = format!("{}_{}_users", backend_name(backend), action);

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// A table missing from the schema is rejected where the schema is required.
    #[rstest]
    #[case::postgres_table_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_table_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_table_not_found(DatabaseBackend::Sqlite)]
    fn test_table_not_found(#[case] backend: DatabaseBackend) {
        let result =
            build_modify_column_default(&backend, "users", "email", Some("'default'"), &[]);

        match backend {
            DatabaseBackend::Postgres => {
                // Postgres only uses the schema for the column type, so it falls back
                // to emitting the default verbatim.
                let queries = result.expect("Postgres should not require the table in the schema");
                let sql = render(&queries, backend);
                assert!(sql
                    .contains("ALTER TABLE \"users\" ALTER COLUMN \"email\" SET DEFAULT 'default'"));
            }
            DatabaseBackend::MySql | DatabaseBackend::Sqlite => {
                let err_msg = result
                    .expect_err("a missing table should be rejected")
                    .to_string();
                assert!(err_msg.contains("Table 'users' not found"));
            }
        }
    }

    /// A column missing from an existing table.
    #[rstest]
    #[case::postgres_column_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_column_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_column_not_found(DatabaseBackend::Sqlite)]
    fn test_column_not_found(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "users",
            vec![col(
                "id",
                ColumnType::Simple(SimpleColumnType::Integer),
                false,
            )],
            vec![],
        )];

        let result =
            build_modify_column_default(&backend, "users", "email", Some("'default'"), &schema);

        match backend {
            DatabaseBackend::Postgres => {
                // Postgres falls back to emitting the default verbatim.
                let queries = result.expect("Postgres should not require the column in the schema");
                let sql = render(&queries, backend);
                assert!(sql
                    .contains("ALTER TABLE \"users\" ALTER COLUMN \"email\" SET DEFAULT 'default'"));
            }
            DatabaseBackend::MySql => {
                let err_msg = result
                    .expect_err("a missing column should be rejected")
                    .to_string();
                assert!(err_msg.contains("Column 'email' not found"));
            }
            // SQLite's handling of a missing column is not asserted by this test.
            DatabaseBackend::Sqlite => {}
        }
    }

    /// Test Postgres default change when column is not in schema.
    /// This covers the fallback path where the column type is unknown.
    #[test]
    fn test_postgres_column_not_in_schema_uses_default_as_is() {
        let schema = vec![table_def(
            "users",
            vec![col(
                "id",
                ColumnType::Simple(SimpleColumnType::Integer),
                false,
            )],
            // Note: "status" column is NOT in the schema
            vec![],
        )];

        // Postgres doesn't error when the column isn't found; it uses the default as-is.
        let sql = modify_sql(
            DatabaseBackend::Postgres,
            "users",
            "status", // column not in schema
            Some("'active'"),
            &schema,
        );

        assert!(sql.contains("ALTER TABLE \"users\" ALTER COLUMN \"status\" SET DEFAULT 'active'"));
    }

    /// With an index constraint, SQLite must recreate the index after the table rebuild.
    #[rstest]
    #[case::postgres_with_index(DatabaseBackend::Postgres)]
    #[case::mysql_with_index(DatabaseBackend::MySql)]
    #[case::sqlite_with_index(DatabaseBackend::Sqlite)]
    fn test_modify_default_with_index(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col("email", ColumnType::Simple(SimpleColumnType::Text), true),
            ],
            vec![TableConstraint::Index {
                name: Some("idx_users_email".into()),
                columns: vec!["email".into()],
            }],
        )];

        let sql = modify_sql(
            backend,
            "users",
            "email",
            Some("'default@example.com'"),
            &schema,
        );

        if backend == DatabaseBackend::Sqlite {
            assert!(sql.contains("CREATE INDEX"));
            assert!(sql.contains("idx_users_email"));
        }

        let suffix = format!("{}_with_index", backend_name(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Replacing an existing default with a different one.
    #[rstest]
    #[case::postgres_change_default(DatabaseBackend::Postgres)]
    #[case::mysql_change_default(DatabaseBackend::MySql)]
    #[case::sqlite_change_default(DatabaseBackend::Sqlite)]
    fn test_change_default_value(#[case] backend: DatabaseBackend) {
        let mut email_col = col("email", ColumnType::Simple(SimpleColumnType::Text), true);
        email_col.default = Some("'old@example.com'".into());

        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                email_col,
            ],
            vec![],
        )];

        let sql = modify_sql(
            backend,
            "users",
            "email",
            Some("'new@example.com'"),
            &schema,
        );

        let suffix = format!("{}_change_default", backend_name(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Integer literal default.
    #[rstest]
    #[case::postgres_integer_default(DatabaseBackend::Postgres)]
    #[case::mysql_integer_default(DatabaseBackend::MySql)]
    #[case::sqlite_integer_default(DatabaseBackend::Sqlite)]
    fn test_integer_default(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "products",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "quantity",
                    ColumnType::Simple(SimpleColumnType::Integer),
                    false,
                ),
            ],
            vec![],
        )];

        let sql = modify_sql(backend, "products", "quantity", Some("0"), &schema);

        let suffix = format!("{}_integer_default", backend_name(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Boolean literal default.
    #[rstest]
    #[case::postgres_boolean_default(DatabaseBackend::Postgres)]
    #[case::mysql_boolean_default(DatabaseBackend::MySql)]
    #[case::sqlite_boolean_default(DatabaseBackend::Sqlite)]
    fn test_boolean_default(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "is_active",
                    ColumnType::Simple(SimpleColumnType::Boolean),
                    false,
                ),
            ],
            vec![],
        )];

        let sql = modify_sql(backend, "users", "is_active", Some("true"), &schema);

        let suffix = format!("{}_boolean_default", backend_name(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Function-call default (e.g. `NOW()`, `CURRENT_TIMESTAMP`).
    #[rstest]
    #[case::postgres_function_default(DatabaseBackend::Postgres)]
    #[case::mysql_function_default(DatabaseBackend::MySql)]
    #[case::sqlite_function_default(DatabaseBackend::Sqlite)]
    fn test_function_default(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "events",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "created_at",
                    ColumnType::Simple(SimpleColumnType::Timestamp),
                    false,
                ),
            ],
            vec![],
        )];

        let default_value = match backend {
            DatabaseBackend::Postgres => "NOW()",
            DatabaseBackend::MySql | DatabaseBackend::Sqlite => "CURRENT_TIMESTAMP",
        };

        let sql = modify_sql(
            backend,
            "events",
            "created_at",
            Some(default_value),
            &schema,
        );

        let suffix = format!("{}_function_default", backend_name(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Dropping the default from a column that currently has one.
    #[rstest]
    #[case::postgres_drop_existing_default(DatabaseBackend::Postgres)]
    #[case::mysql_drop_existing_default(DatabaseBackend::MySql)]
    #[case::sqlite_drop_existing_default(DatabaseBackend::Sqlite)]
    fn test_drop_existing_default(#[case] backend: DatabaseBackend) {
        let mut status_col = col("status", ColumnType::Simple(SimpleColumnType::Text), false);
        status_col.default = Some("'pending'".into());

        let schema = vec![table_def(
            "orders",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                status_col,
            ],
            vec![],
        )];

        // `None` drops the default.
        let sql = modify_sql(backend, "orders", "status", None, &schema);

        let suffix = format!("{}_drop_existing_default", backend_name(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }
}