1use sea_query::{Alias, Query, Table};
2
3use vespertide_core::{ColumnDef, TableDef};
4
5use super::create_table::build_create_table_for_backend;
6use super::helpers::{build_sea_column_def_with_table, normalize_fill_with};
7use super::rename_table::build_rename_table;
8use super::types::{BuiltQuery, DatabaseBackend, RawSql};
9use crate::error::QueryError;
10
11pub fn build_modify_column_nullable(
14 backend: &DatabaseBackend,
15 table: &str,
16 column: &str,
17 nullable: bool,
18 fill_with: Option<&str>,
19 current_schema: &[TableDef],
20) -> Result<Vec<BuiltQuery>, QueryError> {
21 let mut queries = Vec::new();
22
23 if !nullable && let Some(fill_value) = normalize_fill_with(fill_with) {
25 let update_sql = match backend {
26 DatabaseBackend::Postgres | DatabaseBackend::Sqlite => format!(
27 "UPDATE \"{}\" SET \"{}\" = {} WHERE \"{}\" IS NULL",
28 table, column, fill_value, column
29 ),
30 DatabaseBackend::MySql => format!(
31 "UPDATE `{}` SET `{}` = {} WHERE `{}` IS NULL",
32 table, column, fill_value, column
33 ),
34 };
35 queries.push(BuiltQuery::Raw(RawSql::uniform(update_sql)));
36 }
37
38 match backend {
40 DatabaseBackend::Postgres => {
41 let alter_sql = if nullable {
42 format!(
43 "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" DROP NOT NULL",
44 table, column
45 )
46 } else {
47 format!(
48 "ALTER TABLE \"{}\" ALTER COLUMN \"{}\" SET NOT NULL",
49 table, column
50 )
51 };
52 queries.push(BuiltQuery::Raw(RawSql::uniform(alter_sql)));
53 }
54 DatabaseBackend::MySql => {
55 let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::Other(format!("Table '{}' not found in current schema. MySQL requires current schema information to modify column nullability.", table)))?;
58
59 let column_def = table_def.columns.iter().find(|c| c.name == column).ok_or_else(|| QueryError::Other(format!("Column '{}' not found in table '{}'. MySQL requires column information to modify nullability.", column, table)))?;
60
61 let modified_col_def = ColumnDef {
63 nullable,
64 ..column_def.clone()
65 };
66
67 let sea_col = build_sea_column_def_with_table(backend, table, &modified_col_def);
69
70 let stmt = Table::alter()
71 .table(Alias::new(table))
72 .modify_column(sea_col)
73 .to_owned();
74 queries.push(BuiltQuery::AlterTable(Box::new(stmt)));
75 }
76 DatabaseBackend::Sqlite => {
77 let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::Other(format!("Table '{}' not found in current schema. SQLite requires current schema information to modify column nullability.", table)))?;
80
81 let mut new_columns = table_def.columns.clone();
83 if let Some(col) = new_columns.iter_mut().find(|c| c.name == column) {
84 col.nullable = nullable;
85 }
86
87 let temp_table = format!("{}_temp", table);
89
90 let create_temp_table = build_create_table_for_backend(
92 backend,
93 &temp_table,
94 &new_columns,
95 &table_def.constraints,
96 );
97 queries.push(BuiltQuery::CreateTable(Box::new(create_temp_table)));
98
99 let column_aliases: Vec<Alias> = table_def
101 .columns
102 .iter()
103 .map(|c| Alias::new(&c.name))
104 .collect();
105 let mut select_query = Query::select();
106 for col_alias in &column_aliases {
107 select_query = select_query.column(col_alias.clone()).to_owned();
108 }
109 select_query = select_query.from(Alias::new(table)).to_owned();
110
111 let insert_stmt = Query::insert()
112 .into_table(Alias::new(&temp_table))
113 .columns(column_aliases.clone())
114 .select_from(select_query)
115 .unwrap()
116 .to_owned();
117 queries.push(BuiltQuery::Insert(Box::new(insert_stmt)));
118
119 let drop_table = Table::drop().table(Alias::new(table)).to_owned();
121 queries.push(BuiltQuery::DropTable(Box::new(drop_table)));
122
123 queries.push(build_rename_table(&temp_table, table));
125
126 for constraint in &table_def.constraints {
128 if let vespertide_core::TableConstraint::Index {
129 name: idx_name,
130 columns: idx_cols,
131 } = constraint
132 {
133 let index_name =
134 vespertide_naming::build_index_name(table, idx_cols, idx_name.as_deref());
135 let mut idx_stmt = sea_query::Index::create();
136 idx_stmt = idx_stmt.name(&index_name).to_owned();
137 for col_name in idx_cols {
138 idx_stmt = idx_stmt.col(Alias::new(col_name)).to_owned();
139 }
140 idx_stmt = idx_stmt.table(Alias::new(table)).to_owned();
141 queries.push(BuiltQuery::CreateIndex(Box::new(idx_stmt)));
142 }
143 }
144 }
145 }
146
147 Ok(queries)
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{ColumnDef, ColumnType, SimpleColumnType, TableConstraint};

    /// Creates a column with the given name, type and nullability; every
    /// optional attribute is left unset.
    fn col(name: &str, ty: ColumnType, nullable: bool) -> ColumnDef {
        ColumnDef {
            name: name.to_string(),
            r#type: ty,
            nullable,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }
    }

    /// Creates a table definition without a description.
    fn table_def(
        name: &str,
        columns: Vec<ColumnDef>,
        constraints: Vec<TableConstraint>,
    ) -> TableDef {
        TableDef {
            name: name.to_string(),
            description: None,
            columns,
            constraints,
        }
    }

    /// Schema holding a single `users` table made of a non-null integer `id`
    /// column followed by the supplied `email` column.
    fn users_schema(email: ColumnDef, constraints: Vec<TableConstraint>) -> Vec<TableDef> {
        let id = col("id", ColumnType::Simple(SimpleColumnType::Integer), false);
        vec![table_def("users", vec![id, email], constraints)]
    }

    /// Lower-case backend name used as the leading part of snapshot suffixes.
    fn backend_label(backend: DatabaseBackend) -> &'static str {
        match backend {
            DatabaseBackend::Postgres => "postgres",
            DatabaseBackend::MySql => "mysql",
            DatabaseBackend::Sqlite => "sqlite",
        }
    }

    /// Renders every query for `backend`, one statement per line.
    fn render_sql(queries: &[BuiltQuery], backend: DatabaseBackend) -> String {
        let statements: Vec<String> = queries.iter().map(|q| q.build(backend)).collect();
        statements.join("\n")
    }

    #[rstest]
    #[case::postgres_set_not_null(DatabaseBackend::Postgres, false, None)]
    #[case::postgres_drop_not_null(DatabaseBackend::Postgres, true, None)]
    #[case::postgres_set_not_null_with_fill(DatabaseBackend::Postgres, false, Some("'unknown'"))]
    #[case::mysql_set_not_null(DatabaseBackend::MySql, false, None)]
    #[case::mysql_drop_not_null(DatabaseBackend::MySql, true, None)]
    #[case::mysql_set_not_null_with_fill(DatabaseBackend::MySql, false, Some("'unknown'"))]
    #[case::sqlite_set_not_null(DatabaseBackend::Sqlite, false, None)]
    #[case::sqlite_drop_not_null(DatabaseBackend::Sqlite, true, None)]
    #[case::sqlite_set_not_null_with_fill(DatabaseBackend::Sqlite, false, Some("'unknown'"))]
    fn test_build_modify_column_nullable(
        #[case] backend: DatabaseBackend,
        #[case] nullable: bool,
        #[case] fill_with: Option<&str>,
    ) {
        // The column starts in the opposite state of the one being requested.
        let email = col(
            "email",
            ColumnType::Simple(SimpleColumnType::Text),
            !nullable,
        );
        let schema = users_schema(email, vec![]);

        let queries =
            build_modify_column_nullable(&backend, "users", "email", nullable, fill_with, &schema)
                .expect("building the nullability change should succeed");
        let sql = render_sql(&queries, backend);

        let nullability = if nullable { "nullable" } else { "not_null" };
        let fill_marker = match fill_with {
            Some(_) => "_with_fill",
            None => "",
        };
        let suffix = format!(
            "{}_{}_users{}",
            backend_label(backend),
            nullability,
            fill_marker
        );

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_table_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_table_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_table_not_found(DatabaseBackend::Sqlite)]
    fn test_table_not_found(#[case] backend: DatabaseBackend) {
        // The Postgres case is skipped; only MySQL and SQLite are checked here.
        if matches!(backend, DatabaseBackend::Postgres) {
            return;
        }

        let err = build_modify_column_nullable(&backend, "users", "email", false, None, &[])
            .expect_err("an empty schema should be rejected");
        assert!(err.to_string().contains("Table 'users' not found"));
    }

    #[rstest]
    #[case::postgres_column_not_found(DatabaseBackend::Postgres)]
    #[case::mysql_column_not_found(DatabaseBackend::MySql)]
    #[case::sqlite_column_not_found(DatabaseBackend::Sqlite)]
    fn test_column_not_found(#[case] backend: DatabaseBackend) {
        // Only the MySQL case is exercised; the other backends are skipped.
        if matches!(
            backend,
            DatabaseBackend::Postgres | DatabaseBackend::Sqlite
        ) {
            return;
        }

        let id = col("id", ColumnType::Simple(SimpleColumnType::Integer), false);
        let schema = vec![table_def("users", vec![id], vec![])];

        let err = build_modify_column_nullable(&backend, "users", "email", false, None, &schema)
            .expect_err("a missing column should be rejected");
        assert!(err.to_string().contains("Column 'email' not found"));
    }

    #[rstest]
    #[case::postgres_with_index(DatabaseBackend::Postgres)]
    #[case::mysql_with_index(DatabaseBackend::MySql)]
    #[case::sqlite_with_index(DatabaseBackend::Sqlite)]
    fn test_modify_nullable_with_index(#[case] backend: DatabaseBackend) {
        let email = col("email", ColumnType::Simple(SimpleColumnType::Text), true);
        let email_index = TableConstraint::Index {
            name: Some("idx_email".into()),
            columns: vec!["email".into()],
        };
        let schema = users_schema(email, vec![email_index]);

        let queries =
            build_modify_column_nullable(&backend, "users", "email", false, None, &schema)
                .expect("building the nullability change should succeed");
        let sql = render_sql(&queries, backend);

        // Only the SQLite output is checked for the recreated index.
        if matches!(backend, DatabaseBackend::Sqlite) {
            assert!(sql.contains("CREATE INDEX"));
            assert!(sql.contains("idx_email"));
        }

        let suffix = format!("{}_with_index", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres_with_default(DatabaseBackend::Postgres)]
    #[case::mysql_with_default(DatabaseBackend::MySql)]
    #[case::sqlite_with_default(DatabaseBackend::Sqlite)]
    fn test_with_default_value(#[case] backend: DatabaseBackend) {
        let email = ColumnDef {
            default: Some("'default@example.com'".into()),
            ..col("email", ColumnType::Simple(SimpleColumnType::Text), true)
        };
        let schema = users_schema(email, vec![]);

        let queries =
            build_modify_column_nullable(&backend, "users", "email", false, None, &schema)
                .expect("building the nullability change should succeed");
        let sql = render_sql(&queries, backend);

        // Only the MySQL and SQLite outputs are checked for the DEFAULT clause.
        if matches!(backend, DatabaseBackend::MySql | DatabaseBackend::Sqlite) {
            assert!(sql.contains("DEFAULT"));
        }

        let suffix = format!("{}_with_default", backend_label(backend));

        with_settings!({ snapshot_suffix => suffix }, {
            assert_snapshot!(sql);
        });
    }
}