1use sea_query::{
2 Alias, ColumnDef as SeaColumnDef, ForeignKeyAction, MysqlQueryBuilder, PostgresQueryBuilder,
3 QueryStatementWriter, SchemaStatementBuilder, SimpleExpr, SqliteQueryBuilder,
4};
5
6use vespertide_core::{
7 ColumnDef, ColumnType, ComplexColumnType, ReferenceAction, SimpleColumnType,
8};
9
10use super::types::DatabaseBackend;
11
/// Normalizes an optional `fill_with` value.
///
/// An empty string becomes the SQL empty-string literal `''`. Any other value is copied
/// through verbatim, and `None` stays `None`.
pub fn normalize_fill_with(fill_with: Option<&str>) -> Option<String> {
    let raw = fill_with?;
    let normalized = match raw {
        "" => String::from("''"),
        other => other.to_owned(),
    };
    Some(normalized)
}
22
23pub fn build_schema_statement<T: SchemaStatementBuilder>(
25 stmt: &T,
26 backend: DatabaseBackend,
27) -> String {
28 match backend {
29 DatabaseBackend::Postgres => stmt.to_string(PostgresQueryBuilder),
30 DatabaseBackend::MySql => stmt.to_string(MysqlQueryBuilder),
31 DatabaseBackend::Sqlite => stmt.to_string(SqliteQueryBuilder),
32 }
33}
34
35pub fn build_query_statement<T: QueryStatementWriter>(
37 stmt: &T,
38 backend: DatabaseBackend,
39) -> String {
40 match backend {
41 DatabaseBackend::Postgres => stmt.to_string(PostgresQueryBuilder),
42 DatabaseBackend::MySql => stmt.to_string(MysqlQueryBuilder),
43 DatabaseBackend::Sqlite => stmt.to_string(SqliteQueryBuilder),
44 }
45}
46
47pub fn apply_column_type_with_table(col: &mut SeaColumnDef, ty: &ColumnType, table: &str) {
49 match ty {
50 ColumnType::Simple(simple) => match simple {
51 SimpleColumnType::SmallInt => {
52 col.small_integer();
53 }
54 SimpleColumnType::Integer => {
55 col.integer();
56 }
57 SimpleColumnType::BigInt => {
58 col.big_integer();
59 }
60 SimpleColumnType::Real => {
61 col.float();
62 }
63 SimpleColumnType::DoublePrecision => {
64 col.double();
65 }
66 SimpleColumnType::Text => {
67 col.text();
68 }
69 SimpleColumnType::Boolean => {
70 col.boolean();
71 }
72 SimpleColumnType::Date => {
73 col.date();
74 }
75 SimpleColumnType::Time => {
76 col.time();
77 }
78 SimpleColumnType::Timestamp => {
79 col.timestamp();
80 }
81 SimpleColumnType::Timestamptz => {
82 col.timestamp_with_time_zone();
83 }
84 SimpleColumnType::Interval => {
85 col.interval(None, None);
86 }
87 SimpleColumnType::Bytea => {
88 col.binary();
89 }
90 SimpleColumnType::Uuid => {
91 col.uuid();
92 }
93 SimpleColumnType::Json => {
94 col.json();
95 }
96 SimpleColumnType::Inet => {
97 col.custom(Alias::new("INET"));
98 }
99 SimpleColumnType::Cidr => {
100 col.custom(Alias::new("CIDR"));
101 }
102 SimpleColumnType::Macaddr => {
103 col.custom(Alias::new("MACADDR"));
104 }
105 SimpleColumnType::Xml => {
106 col.custom(Alias::new("XML"));
107 }
108 },
109 ColumnType::Complex(complex) => match complex {
110 ComplexColumnType::Varchar { length } => {
111 col.string_len(*length);
112 }
113 ComplexColumnType::Numeric { precision, scale } => {
114 col.decimal_len(*precision, *scale);
115 }
116 ComplexColumnType::Char { length } => {
117 col.char_len(*length);
118 }
119 ComplexColumnType::Custom { custom_type } => {
120 col.custom(Alias::new(custom_type));
121 }
122 ComplexColumnType::Enum { name, values } => {
123 if values.is_integer() {
125 col.integer();
126 } else {
127 let type_name = build_enum_type_name(table, name);
129 col.enumeration(
130 Alias::new(&type_name),
131 values
132 .variant_names()
133 .into_iter()
134 .map(Alias::new)
135 .collect::<Vec<Alias>>(),
136 );
137 }
138 }
139 },
140 }
141}
142
143pub fn to_sea_fk_action(action: &ReferenceAction) -> ForeignKeyAction {
145 match action {
146 ReferenceAction::Cascade => ForeignKeyAction::Cascade,
147 ReferenceAction::Restrict => ForeignKeyAction::Restrict,
148 ReferenceAction::SetNull => ForeignKeyAction::SetNull,
149 ReferenceAction::SetDefault => ForeignKeyAction::SetDefault,
150 ReferenceAction::NoAction => ForeignKeyAction::NoAction,
151 }
152}
153
154pub fn reference_action_sql(action: &ReferenceAction) -> &'static str {
156 match action {
157 ReferenceAction::Cascade => "CASCADE",
158 ReferenceAction::Restrict => "RESTRICT",
159 ReferenceAction::SetNull => "SET NULL",
160 ReferenceAction::SetDefault => "SET DEFAULT",
161 ReferenceAction::NoAction => "NO ACTION",
162 }
163}
164
165pub fn convert_default_for_backend(default: &str, backend: &DatabaseBackend) -> String {
167 let lower = default.to_lowercase();
168
169 if lower == "gen_random_uuid()" || lower == "uuid()" || lower == "lower(hex(randomblob(16)))" {
171 return match backend {
172 DatabaseBackend::Postgres => "gen_random_uuid()".to_string(),
173 DatabaseBackend::MySql => "(UUID())".to_string(),
174 DatabaseBackend::Sqlite => "lower(hex(randomblob(16)))".to_string(),
175 };
176 }
177
178 if lower == "current_timestamp()"
180 || lower == "now()"
181 || lower == "current_timestamp"
182 || lower == "getdate()"
183 {
184 return match backend {
185 DatabaseBackend::Postgres => "CURRENT_TIMESTAMP".to_string(),
186 DatabaseBackend::MySql => "CURRENT_TIMESTAMP".to_string(),
187 DatabaseBackend::Sqlite => "CURRENT_TIMESTAMP".to_string(),
188 };
189 }
190
191 default.to_string()
192}
193
194fn is_enum_type(column_type: &ColumnType) -> bool {
196 matches!(
197 column_type,
198 ColumnType::Complex(ComplexColumnType::Enum { .. })
199 )
200}
201
202pub fn normalize_enum_default(column_type: &ColumnType, value: &str) -> String {
205 if is_enum_type(column_type) && needs_quoting(value) {
206 format!("'{}'", value)
207 } else {
208 value.to_string()
209 }
210}
211
/// Decides whether a default-value string is a bare literal that must be wrapped in quotes.
///
/// After trimming, an empty string needs quoting. The following do not:
///
/// * text that already starts with a single or double quote;
/// * anything containing parentheses, which is treated as a function call or expression;
/// * the keywords `NULL`, `CURRENT_TIMESTAMP`, `CURRENT_DATE` and `CURRENT_TIME`, in any
///   ASCII case.
///
/// Everything else is treated as a bare literal and needs quoting.
fn needs_quoting(default_str: &str) -> bool {
    let candidate = default_str.trim();
    if candidate.is_empty() {
        return true;
    }

    let already_quoted = candidate.starts_with('\'') || candidate.starts_with('"');
    let is_expression = candidate.chars().any(|c| c == '(' || c == ')');
    let is_keyword = ["null", "current_timestamp", "current_date", "current_time"]
        .iter()
        .any(|keyword| candidate.eq_ignore_ascii_case(keyword));

    !(already_quoted || is_expression || is_keyword)
}
240
241pub fn build_sea_column_def_with_table(
243 backend: &DatabaseBackend,
244 table: &str,
245 column: &ColumnDef,
246) -> SeaColumnDef {
247 let mut col = SeaColumnDef::new(Alias::new(&column.name));
248 apply_column_type_with_table(&mut col, &column.r#type, table);
249
250 if !column.nullable {
251 col.not_null();
252 }
253
254 if let Some(default) = &column.default {
255 let default_str = default.to_sql();
256 let converted = convert_default_for_backend(&default_str, backend);
257
258 let final_default =
260 if is_enum_type(&column.r#type) && default.is_string() && needs_quoting(&converted) {
261 format!("'{}'", converted)
262 } else {
263 converted
264 };
265
266 col.default(Into::<SimpleExpr>::into(sea_query::Expr::cust(
267 final_default,
268 )));
269 }
270
271 col
272}
273
274pub fn build_create_enum_type_sql(
280 table: &str,
281 column_type: &ColumnType,
282) -> Option<super::types::RawSql> {
283 if let ColumnType::Complex(ComplexColumnType::Enum { name, values }) = column_type {
284 if values.is_integer() {
286 return None;
287 }
288
289 let values_sql = values.to_sql_values().join(", ");
290
291 let type_name = build_enum_type_name(table, name);
293
294 let pg_sql = format!("CREATE TYPE \"{}\" AS ENUM ({})", type_name, values_sql);
296
297 Some(super::types::RawSql::per_backend(
300 pg_sql,
301 String::new(),
302 String::new(),
303 ))
304 } else {
305 None
306 }
307}
308
309pub fn build_drop_enum_type_sql(
314 table: &str,
315 column_type: &ColumnType,
316) -> Option<super::types::RawSql> {
317 if let ColumnType::Complex(ComplexColumnType::Enum { name, .. }) = column_type {
318 let type_name = build_enum_type_name(table, name);
320
321 let pg_sql = format!("DROP TYPE IF EXISTS \"{}\"", type_name);
323
324 Some(super::types::RawSql::per_backend(
326 pg_sql,
327 String::new(),
328 String::new(),
329 ))
330 } else {
331 None
332 }
333}
334
335pub use vespertide_naming::{
337 build_check_constraint_name, build_enum_type_name, build_foreign_key_name, build_index_name,
338 build_unique_constraint_name,
339};
340
341pub fn build_sqlite_enum_check_name(table: &str, column: &str) -> String {
343 build_check_constraint_name(table, column)
344}
345
346pub fn build_sqlite_enum_check_clause(
349 table: &str,
350 column: &str,
351 column_type: &ColumnType,
352) -> Option<String> {
353 if let ColumnType::Complex(ComplexColumnType::Enum { values, .. }) = column_type {
354 let name = build_sqlite_enum_check_name(table, column);
355 let values_sql = values.to_sql_values().join(", ");
356 Some(format!(
357 "CONSTRAINT \"{}\" CHECK (\"{}\" IN ({}))",
358 name, column, values_sql
359 ))
360 } else {
361 None
362 }
363}
364
365pub fn collect_sqlite_enum_check_clauses(table: &str, columns: &[ColumnDef]) -> Vec<String> {
367 columns
368 .iter()
369 .filter_map(|col| build_sqlite_enum_check_clause(table, &col.name, &col.r#type))
370 .collect()
371}
372
373pub fn get_enum_name(column_type: &ColumnType) -> Option<&str> {
375 if let ColumnType::Complex(ComplexColumnType::Enum { name, .. }) = column_type {
376 Some(name.as_str())
377 } else {
378 None
379 }
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    use sea_query::{Alias, ColumnDef as SeaColumnDef, ForeignKeyAction};
    use vespertide_core::{EnumValues, NumValue};

    // Smoke tests: these only verify that applying each column type does not panic.
    #[rstest]
    #[case(ColumnType::Simple(SimpleColumnType::Integer))]
    #[case(ColumnType::Simple(SimpleColumnType::BigInt))]
    #[case(ColumnType::Simple(SimpleColumnType::Text))]
    #[case(ColumnType::Simple(SimpleColumnType::Boolean))]
    #[case(ColumnType::Simple(SimpleColumnType::Timestamp))]
    #[case(ColumnType::Simple(SimpleColumnType::Uuid))]
    #[case(ColumnType::Complex(ComplexColumnType::Varchar { length: 255 }))]
    #[case(ColumnType::Complex(ComplexColumnType::Numeric { precision: 10, scale: 2 }))]
    fn test_column_type_conversion(#[case] ty: ColumnType) {
        let mut col = SeaColumnDef::new(Alias::new("test"));
        apply_column_type_with_table(&mut col, &ty, "test_table");
    }

    #[rstest]
    #[case(SimpleColumnType::SmallInt)]
    #[case(SimpleColumnType::Integer)]
    #[case(SimpleColumnType::BigInt)]
    #[case(SimpleColumnType::Real)]
    #[case(SimpleColumnType::DoublePrecision)]
    #[case(SimpleColumnType::Text)]
    #[case(SimpleColumnType::Boolean)]
    #[case(SimpleColumnType::Date)]
    #[case(SimpleColumnType::Time)]
    #[case(SimpleColumnType::Timestamp)]
    #[case(SimpleColumnType::Timestamptz)]
    #[case(SimpleColumnType::Interval)]
    #[case(SimpleColumnType::Bytea)]
    #[case(SimpleColumnType::Uuid)]
    #[case(SimpleColumnType::Json)]
    #[case(SimpleColumnType::Inet)]
    #[case(SimpleColumnType::Cidr)]
    #[case(SimpleColumnType::Macaddr)]
    #[case(SimpleColumnType::Xml)]
    fn test_all_simple_types_cover_branches(#[case] ty: SimpleColumnType) {
        let mut col = SeaColumnDef::new(Alias::new("t"));
        apply_column_type_with_table(&mut col, &ColumnType::Simple(ty), "test_table");
    }

    #[rstest]
    #[case(ComplexColumnType::Varchar { length: 42 })]
    #[case(ComplexColumnType::Numeric { precision: 8, scale: 3 })]
    #[case(ComplexColumnType::Char { length: 3 })]
    #[case(ComplexColumnType::Custom { custom_type: "GEOGRAPHY".into() })]
    #[case(ComplexColumnType::Enum { name: "status".into(), values: EnumValues::String(vec!["active".into(), "inactive".into()]) })]
    fn test_all_complex_types_cover_branches(#[case] ty: ComplexColumnType) {
        let mut col = SeaColumnDef::new(Alias::new("t"));
        apply_column_type_with_table(&mut col, &ColumnType::Complex(ty), "test_table");
    }

    #[rstest]
    #[case::cascade(ReferenceAction::Cascade, ForeignKeyAction::Cascade)]
    #[case::restrict(ReferenceAction::Restrict, ForeignKeyAction::Restrict)]
    #[case::set_null(ReferenceAction::SetNull, ForeignKeyAction::SetNull)]
    #[case::set_default(ReferenceAction::SetDefault, ForeignKeyAction::SetDefault)]
    #[case::no_action(ReferenceAction::NoAction, ForeignKeyAction::NoAction)]
    fn test_reference_action_conversion(
        #[case] action: ReferenceAction,
        #[case] expected: ForeignKeyAction,
    ) {
        let result = to_sea_fk_action(&action);
        // The previous `matches!(result, _expected)` bound `_expected` as a catch-all pattern,
        // so it could never fail. Comparing discriminants checks the variant without needing
        // `PartialEq` on `ForeignKeyAction`.
        assert_eq!(
            std::mem::discriminant(&result),
            std::mem::discriminant(&expected),
            "Expected {:?}, got {:?}",
            expected,
            result
        );
    }

    #[rstest]
    #[case(ReferenceAction::Cascade, "CASCADE")]
    #[case(ReferenceAction::Restrict, "RESTRICT")]
    #[case(ReferenceAction::SetNull, "SET NULL")]
    #[case(ReferenceAction::SetDefault, "SET DEFAULT")]
    #[case(ReferenceAction::NoAction, "NO ACTION")]
    fn test_reference_action_sql_all_variants(
        #[case] action: ReferenceAction,
        #[case] expected: &str,
    ) {
        assert_eq!(reference_action_sql(&action), expected);
    }

    #[rstest]
    #[case::gen_random_uuid_postgres(
        "gen_random_uuid()",
        DatabaseBackend::Postgres,
        "gen_random_uuid()"
    )]
    #[case::gen_random_uuid_mysql("gen_random_uuid()", DatabaseBackend::MySql, "(UUID())")]
    #[case::gen_random_uuid_sqlite(
        "gen_random_uuid()",
        DatabaseBackend::Sqlite,
        "lower(hex(randomblob(16)))"
    )]
    #[case::current_timestamp_postgres(
        "current_timestamp()",
        DatabaseBackend::Postgres,
        "CURRENT_TIMESTAMP"
    )]
    #[case::current_timestamp_mysql(
        "current_timestamp()",
        DatabaseBackend::MySql,
        "CURRENT_TIMESTAMP"
    )]
    #[case::current_timestamp_sqlite(
        "current_timestamp()",
        DatabaseBackend::Sqlite,
        "CURRENT_TIMESTAMP"
    )]
    #[case::now_postgres("now()", DatabaseBackend::Postgres, "CURRENT_TIMESTAMP")]
    #[case::now_mysql("now()", DatabaseBackend::MySql, "CURRENT_TIMESTAMP")]
    #[case::now_sqlite("now()", DatabaseBackend::Sqlite, "CURRENT_TIMESTAMP")]
    #[case::now_upper_postgres("NOW()", DatabaseBackend::Postgres, "CURRENT_TIMESTAMP")]
    #[case::now_upper_mysql("NOW()", DatabaseBackend::MySql, "CURRENT_TIMESTAMP")]
    #[case::now_upper_sqlite("NOW()", DatabaseBackend::Sqlite, "CURRENT_TIMESTAMP")]
    #[case::current_timestamp_upper_postgres(
        "CURRENT_TIMESTAMP",
        DatabaseBackend::Postgres,
        "CURRENT_TIMESTAMP"
    )]
    #[case::current_timestamp_upper_mysql(
        "CURRENT_TIMESTAMP",
        DatabaseBackend::MySql,
        "CURRENT_TIMESTAMP"
    )]
    #[case::current_timestamp_upper_sqlite(
        "CURRENT_TIMESTAMP",
        DatabaseBackend::Sqlite,
        "CURRENT_TIMESTAMP"
    )]
    fn test_convert_default_for_backend(
        #[case] default: &str,
        #[case] backend: DatabaseBackend,
        #[case] expected: &str,
    ) {
        let result = convert_default_for_backend(default, &backend);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_is_enum_type_true() {
        let enum_type = ColumnType::Complex(ComplexColumnType::Enum {
            name: "status".into(),
            values: EnumValues::String(vec!["active".into(), "inactive".into()]),
        });
        assert!(is_enum_type(&enum_type));
    }

    #[test]
    fn test_is_enum_type_false() {
        let text_type = ColumnType::Simple(SimpleColumnType::Text);
        assert!(!is_enum_type(&text_type));
    }

    #[test]
    fn test_get_enum_name_some() {
        let enum_type = ColumnType::Complex(ComplexColumnType::Enum {
            name: "user_status".into(),
            values: EnumValues::String(vec!["active".into(), "inactive".into()]),
        });
        assert_eq!(get_enum_name(&enum_type), Some("user_status"));
    }

    #[test]
    fn test_get_enum_name_none() {
        let text_type = ColumnType::Simple(SimpleColumnType::Text);
        assert_eq!(get_enum_name(&text_type), None);
    }

    #[test]
    fn test_apply_column_type_integer_enum() {
        let integer_enum = ColumnType::Complex(ComplexColumnType::Enum {
            name: "color".into(),
            values: EnumValues::Integer(vec![
                NumValue {
                    name: "Black".into(),
                    value: 0,
                },
                NumValue {
                    name: "White".into(),
                    value: 1,
                },
            ]),
        });
        let mut col = SeaColumnDef::new(Alias::new("color"));
        apply_column_type_with_table(&mut col, &integer_enum, "test_table");
    }

    #[test]
    fn test_build_create_enum_type_sql_integer_enum_returns_none() {
        let integer_enum = ColumnType::Complex(ComplexColumnType::Enum {
            name: "priority".into(),
            values: EnumValues::Integer(vec![
                NumValue {
                    name: "Low".into(),
                    value: 0,
                },
                NumValue {
                    name: "High".into(),
                    value: 10,
                },
            ]),
        });
        assert!(build_create_enum_type_sql("test_table", &integer_enum).is_none());
    }

    #[rstest]
    #[case::none(None, None)]
    #[case::empty(Some(""), Some("''"))]
    #[case::value(Some("0"), Some("0"))]
    fn test_normalize_fill_with(#[case] input: Option<&str>, #[case] expected: Option<&str>) {
        assert_eq!(normalize_fill_with(input).as_deref(), expected);
    }

    #[test]
    fn test_normalize_enum_default_quotes_only_bare_enum_values() {
        let enum_type = ColumnType::Complex(ComplexColumnType::Enum {
            name: "status".into(),
            values: EnumValues::String(vec!["active".into(), "inactive".into()]),
        });
        let text_type = ColumnType::Simple(SimpleColumnType::Text);
        assert_eq!(normalize_enum_default(&enum_type, "active"), "'active'");
        assert_eq!(normalize_enum_default(&enum_type, "'active'"), "'active'");
        assert_eq!(normalize_enum_default(&enum_type, "NULL"), "NULL");
        assert_eq!(normalize_enum_default(&text_type, "active"), "active");
    }

    #[rstest]
    #[case::empty("", true)]
    #[case::whitespace_only("   ", true)]
    #[case::now_func("now()", false)]
    #[case::coalesce_func("COALESCE(old_value, 'default')", false)]
    #[case::uuid_func("gen_random_uuid()", false)]
    #[case::null_upper("NULL", false)]
    #[case::null_lower("null", false)]
    #[case::null_mixed("Null", false)]
    #[case::current_timestamp_upper("CURRENT_TIMESTAMP", false)]
    #[case::current_timestamp_lower("current_timestamp", false)]
    #[case::current_date_upper("CURRENT_DATE", false)]
    #[case::current_date_lower("current_date", false)]
    #[case::current_time_upper("CURRENT_TIME", false)]
    #[case::current_time_lower("current_time", false)]
    #[case::single_quoted("'active'", false)]
    #[case::double_quoted("\"active\"", false)]
    #[case::plain_active("active", true)]
    #[case::plain_pending("pending", true)]
    #[case::plain_underscore("some_value", true)]
    fn test_needs_quoting(#[case] input: &str, #[case] expected: bool) {
        assert_eq!(needs_quoting(input), expected);
    }
}
634}