Skip to main content

vespertide_query/sql/
modify_column_default.rs

1use sea_query::{Alias, Query, Table};
2
3use vespertide_core::{ColumnDef, TableDef};
4
5use super::create_table::build_create_table_for_backend;
6use super::helpers::{build_sea_column_def_with_table, normalize_enum_default};
7use super::rename_table::build_rename_table;
8use super::types::{BuiltQuery, DatabaseBackend, RawSql};
9use crate::error::QueryError;
10
11/// Build SQL for changing column default value.
12pub fn build_modify_column_default(
13    backend: &DatabaseBackend,
14    table: &str,
15    column: &str,
16    new_default: Option<&str>,
17    current_schema: &[TableDef],
18) -> Result<Vec<BuiltQuery>, QueryError> {
19    let mut queries = Vec::new();
20
21    match backend {
22        DatabaseBackend::Postgres => {
23            let alter_sql = if let Some(default_value) = new_default {
24                // Look up column type to properly quote enum defaults
25                let column_type = current_schema
26                    .iter()
27                    .find(|t| t.name == table)
28                    .and_then(|t| t.columns.iter().find(|c| c.name == column))
29                    .map(|c| &c.r#type);
30
31                let normalized_default = if let Some(col_type) = column_type {
32                    normalize_enum_default(col_type, default_value)
33                } else {
34                    default_value.to_string()
35                };
36
37                format!(
38                    "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" SET DEFAULT {}",
39                    table, column, normalized_default
40                )
41            } else {
42                format!(
43                    "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" DROP DEFAULT",
44                    table, column
45                )
46            };
47            queries.push(BuiltQuery::Raw(RawSql::uniform(alter_sql)));
48        }
49        DatabaseBackend::MySql => {
50            // MySQL requires the full column definition in ALTER COLUMN
51            let table_def = current_schema
52                .iter()
53                .find(|t| t.name == table)
54                .ok_or_else(|| {
55                    QueryError::Other(format!("Table '{}' not found in current schema.", table))
56                })?;
57
58            let column_def = table_def
59                .columns
60                .iter()
61                .find(|c| c.name == column)
62                .ok_or_else(|| {
63                    QueryError::Other(format!(
64                        "Column '{}' not found in table '{}'.",
65                        column, table
66                    ))
67                })?;
68
69            // Create a modified column def with the new default
70            let modified_col_def = ColumnDef {
71                default: new_default.map(|s| s.into()),
72                ..column_def.clone()
73            };
74
75            let sea_col = build_sea_column_def_with_table(backend, table, &modified_col_def);
76
77            let stmt = Table::alter()
78                .table(Alias::new(table))
79                .modify_column(sea_col)
80                .to_owned();
81            queries.push(BuiltQuery::AlterTable(Box::new(stmt)));
82        }
83        DatabaseBackend::Sqlite => {
84            // SQLite doesn't support ALTER COLUMN for default changes
85            // Use temporary table approach
86            let table_def = current_schema
87                .iter()
88                .find(|t| t.name == table)
89                .ok_or_else(|| {
90                    QueryError::Other(format!("Table '{}' not found in current schema.", table))
91                })?;
92
93            // Create modified columns with the new default
94            let mut new_columns = table_def.columns.clone();
95            if let Some(col) = new_columns.iter_mut().find(|c| c.name == column) {
96                col.default = new_default.map(|s| s.into());
97            }
98
99            // Generate temporary table name
100            let temp_table = format!("{}_temp", table);
101
102            // 1. Create temporary table with modified column
103            let create_temp_table = build_create_table_for_backend(
104                backend,
105                &temp_table,
106                &new_columns,
107                &table_def.constraints,
108            );
109            queries.push(BuiltQuery::CreateTable(Box::new(create_temp_table)));
110
111            // 2. Copy data (all columns)
112            let column_aliases: Vec<Alias> = table_def
113                .columns
114                .iter()
115                .map(|c| Alias::new(&c.name))
116                .collect();
117            let mut select_query = Query::select();
118            for col_alias in &column_aliases {
119                select_query = select_query.column(col_alias.clone()).to_owned();
120            }
121            select_query = select_query.from(Alias::new(table)).to_owned();
122
123            let insert_stmt = Query::insert()
124                .into_table(Alias::new(&temp_table))
125                .columns(column_aliases.clone())
126                .select_from(select_query)
127                .unwrap()
128                .to_owned();
129            queries.push(BuiltQuery::Insert(Box::new(insert_stmt)));
130
131            // 3. Drop original table
132            let drop_table = Table::drop().table(Alias::new(table)).to_owned();
133            queries.push(BuiltQuery::DropTable(Box::new(drop_table)));
134
135            // 4. Rename temporary table to original name
136            queries.push(build_rename_table(&temp_table, table));
137
138            // 5. Recreate indexes from Index constraints
139            for constraint in &table_def.constraints {
140                if let vespertide_core::TableConstraint::Index {
141                    name: idx_name,
142                    columns: idx_cols,
143                } = constraint
144                {
145                    let index_name =
146                        vespertide_naming::build_index_name(table, idx_cols, idx_name.as_deref());
147                    let mut idx_stmt = sea_query::Index::create();
148                    idx_stmt = idx_stmt.name(&index_name).to_owned();
149                    for col_name in idx_cols {
150                        idx_stmt = idx_stmt.col(Alias::new(col_name)).to_owned();
151                    }
152                    idx_stmt = idx_stmt.table(Alias::new(table)).to_owned();
153                    queries.push(BuiltQuery::CreateIndex(Box::new(idx_stmt)));
154                }
155            }
156        }
157    }
158
159    Ok(queries)
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{ColumnDef, ColumnType, SimpleColumnType, TableConstraint};

    /// Column with the given name, type and nullability; every optional attribute unset.
    fn col(name: &str, ty: ColumnType, nullable: bool) -> ColumnDef {
        ColumnDef {
            name: name.to_string(),
            r#type: ty,
            nullable,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    /// Table definition without a description.
    fn table_def(
        name: &str,
        columns: Vec<ColumnDef>,
        constraints: Vec<TableConstraint>,
    ) -> TableDef {
        TableDef {
            name: name.to_string(),
            description: None,
            columns,
            constraints,
        }
    }

    /// Backend label used as the prefix of every snapshot suffix.
    fn backend_label(backend: DatabaseBackend) -> &'static str {
        match backend {
            DatabaseBackend::Postgres => "postgres",
            DatabaseBackend::MySql => "mysql",
            DatabaseBackend::Sqlite => "sqlite",
        }
    }

    /// Build the statements (the call must succeed) and render them for `backend`,
    /// one statement per line.
    fn render_sql(
        backend: DatabaseBackend,
        table: &str,
        column: &str,
        new_default: Option<&str>,
        schema: &[TableDef],
    ) -> String {
        build_modify_column_default(&backend, table, column, new_default, schema)
            .expect("build_modify_column_default should succeed")
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<String>>()
            .join("\n")
    }

    #[rstest]
    #[case::postgres_set_default(DatabaseBackend::Postgres, Some("'unknown'"))]
    #[case::postgres_drop_default(DatabaseBackend::Postgres, None)]
    #[case::mysql_set_default(DatabaseBackend::MySql, Some("'unknown'"))]
    #[case::mysql_drop_default(DatabaseBackend::MySql, None)]
    #[case::sqlite_set_default(DatabaseBackend::Sqlite, Some("'unknown'"))]
    #[case::sqlite_drop_default(DatabaseBackend::Sqlite, None)]
    fn test_build_modify_column_default(
        #[case] backend: DatabaseBackend,
        #[case] new_default: Option<&str>,
    ) {
        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col("email", ColumnType::Simple(SimpleColumnType::Text), true),
            ],
            vec![],
        )];

        let sql = render_sql(backend, "users", "email", new_default, &schema);

        let action = if new_default.is_some() {
            "set_default"
        } else {
            "drop_default"
        };
        let suffix = format!("{}_{}_users", backend_label(backend), action);

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Test table not found error
    #[rstest]
    #[case::postgres_table_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_table_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_table_not_found(DatabaseBackend::Sqlite)]
    fn test_table_not_found(#[case] backend: DatabaseBackend) {
        let result =
            build_modify_column_default(&backend, "users", "email", Some("'default'"), &[]);

        if backend == DatabaseBackend::Postgres {
            // Postgres doesn't need the table definition for default changes,
            // so an unknown table is not an error there.
            assert!(result.is_ok());
            return;
        }

        let err_msg = result
            .expect_err("an unknown table should be rejected")
            .to_string();
        assert!(err_msg.contains("Table 'users' not found"));
    }

    /// Test column not found error
    #[rstest]
    #[case::postgres_column_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_column_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_column_not_found(DatabaseBackend::Sqlite)]
    fn test_column_not_found(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "users",
            vec![col(
                "id",
                ColumnType::Simple(SimpleColumnType::Integer),
                false,
            )],
            vec![],
        )];

        let result =
            build_modify_column_default(&backend, "users", "email", Some("'default'"), &schema);

        match backend {
            // Postgres doesn't need the column definition for default changes,
            // so an unknown column is not an error there.
            DatabaseBackend::Postgres => assert!(result.is_ok()),
            DatabaseBackend::MySql => {
                let err_msg = result
                    .expect_err("an unknown column should be rejected")
                    .to_string();
                assert!(err_msg.contains("Column 'email' not found"));
            }
            // SQLite's handling of an unknown column is not pinned by this test.
            DatabaseBackend::Sqlite => {}
        }
    }

    /// Test with index - should recreate index after table rebuild (SQLite)
    #[rstest]
    #[case::postgres_with_index(DatabaseBackend::Postgres)]
    #[case::mysql_with_index(DatabaseBackend::MySql)]
    #[case::sqlite_with_index(DatabaseBackend::Sqlite)]
    fn test_modify_default_with_index(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col("email", ColumnType::Simple(SimpleColumnType::Text), true),
            ],
            vec![TableConstraint::Index {
                name: Some("idx_users_email".into()),
                columns: vec!["email".into()],
            }],
        )];

        let sql = render_sql(
            backend,
            "users",
            "email",
            Some("'default@example.com'"),
            &schema,
        );

        // SQLite should recreate the index after table rebuild
        if backend == DatabaseBackend::Sqlite {
            assert!(sql.contains("CREATE INDEX"));
            assert!(sql.contains("idx_users_email"));
        }

        let suffix = format!("{}_with_index", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Test changing default value from one to another
    #[rstest]
    #[case::postgres_change_default(DatabaseBackend::Postgres)]
    #[case::mysql_change_default(DatabaseBackend::MySql)]
    #[case::sqlite_change_default(DatabaseBackend::Sqlite)]
    fn test_change_default_value(#[case] backend: DatabaseBackend) {
        let mut email_col = col("email", ColumnType::Simple(SimpleColumnType::Text), true);
        email_col.default = Some("'old@example.com'".into());

        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                email_col,
            ],
            vec![],
        )];

        let sql = render_sql(
            backend,
            "users",
            "email",
            Some("'new@example.com'"),
            &schema,
        );

        let suffix = format!("{}_change_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Test with integer default value
    #[rstest]
    #[case::postgres_integer_default(DatabaseBackend::Postgres)]
    #[case::mysql_integer_default(DatabaseBackend::MySql)]
    #[case::sqlite_integer_default(DatabaseBackend::Sqlite)]
    fn test_integer_default(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "products",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "quantity",
                    ColumnType::Simple(SimpleColumnType::Integer),
                    false,
                ),
            ],
            vec![],
        )];

        let sql = render_sql(backend, "products", "quantity", Some("0"), &schema);

        let suffix = format!("{}_integer_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Test with boolean default value
    #[rstest]
    #[case::postgres_boolean_default(DatabaseBackend::Postgres)]
    #[case::mysql_boolean_default(DatabaseBackend::MySql)]
    #[case::sqlite_boolean_default(DatabaseBackend::Sqlite)]
    fn test_boolean_default(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "users",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "is_active",
                    ColumnType::Simple(SimpleColumnType::Boolean),
                    false,
                ),
            ],
            vec![],
        )];

        let sql = render_sql(backend, "users", "is_active", Some("true"), &schema);

        let suffix = format!("{}_boolean_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Test with function default (e.g., NOW(), CURRENT_TIMESTAMP)
    #[rstest]
    #[case::postgres_function_default(DatabaseBackend::Postgres)]
    #[case::mysql_function_default(DatabaseBackend::MySql)]
    #[case::sqlite_function_default(DatabaseBackend::Sqlite)]
    fn test_function_default(#[case] backend: DatabaseBackend) {
        let schema = vec![table_def(
            "events",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                col(
                    "created_at",
                    ColumnType::Simple(SimpleColumnType::Timestamp),
                    false,
                ),
            ],
            vec![],
        )];

        let default_value = match backend {
            DatabaseBackend::Postgres => "NOW()",
            DatabaseBackend::MySql | DatabaseBackend::Sqlite => "CURRENT_TIMESTAMP",
        };

        let sql = render_sql(backend, "events", "created_at", Some(default_value), &schema);

        let suffix = format!("{}_function_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    /// Test dropping default from column that had one
    #[rstest]
    #[case::postgres_drop_existing_default(DatabaseBackend::Postgres)]
    #[case::mysql_drop_existing_default(DatabaseBackend::MySql)]
    #[case::sqlite_drop_existing_default(DatabaseBackend::Sqlite)]
    fn test_drop_existing_default(#[case] backend: DatabaseBackend) {
        let mut status_col = col("status", ColumnType::Simple(SimpleColumnType::Text), false);
        status_col.default = Some("'pending'".into());

        let schema = vec![table_def(
            "orders",
            vec![
                col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
                status_col,
            ],
            vec![],
        )];

        // `None` drops the existing default.
        let sql = render_sql(backend, "orders", "status", None, &schema);

        let suffix = format!("{}_drop_existing_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }
}