1use std::borrow::Cow;
2
3use sea_query::{
4 Alias, ColumnDef as SeaColumnDef, ForeignKeyAction, MysqlQueryBuilder, PostgresQueryBuilder,
5 QueryStatementWriter, SchemaStatementBuilder, SimpleExpr, SqliteQueryBuilder,
6};
7
8use vespertide_core::{
9 ColumnDef, ColumnType, ComplexColumnType, ReferenceAction, SimpleColumnType, TableConstraint,
10};
11
12use super::create_table::build_create_table_for_backend;
13use super::types::{BuiltQuery, DatabaseBackend, RawSql};
14
/// Normalizes an optional fill value used to back-fill existing rows.
///
/// An empty string is turned into the SQL empty-string literal `''`; any other value is
/// passed through untouched. `None` stays `None`. The result always borrows (either the
/// input or a static literal), so no allocation happens.
#[must_use]
pub fn normalize_fill_with(fill_with: Option<&str>) -> Option<Cow<'_, str>> {
    let raw = fill_with?;
    let normalized = match raw {
        "" => "''",
        other => other,
    };
    Some(Cow::Borrowed(normalized))
}
27
28pub fn build_schema_statement<T: SchemaStatementBuilder>(
30 stmt: &T,
31 backend: DatabaseBackend,
32) -> String {
33 match backend {
34 DatabaseBackend::Postgres => stmt.to_string(PostgresQueryBuilder),
35 DatabaseBackend::MySql => stmt.to_string(MysqlQueryBuilder),
36 DatabaseBackend::Sqlite => stmt.to_string(SqliteQueryBuilder),
37 }
38}
39
40pub fn build_query_statement<T: QueryStatementWriter>(
42 stmt: &T,
43 backend: DatabaseBackend,
44) -> String {
45 match backend {
46 DatabaseBackend::Postgres => stmt.to_string(PostgresQueryBuilder),
47 DatabaseBackend::MySql => stmt.to_string(MysqlQueryBuilder),
48 DatabaseBackend::Sqlite => stmt.to_string(SqliteQueryBuilder),
49 }
50}
51
52pub fn apply_column_type_with_table(
54 col: &mut SeaColumnDef,
55 ty: &ColumnType,
56 table: &str,
57 backend: DatabaseBackend,
58) {
59 match ty {
60 ColumnType::Simple(simple) => apply_simple_column_type(col, *simple, backend),
61 ColumnType::Complex(complex) => apply_complex_column_type(col, complex, table, backend),
62 }
63}
64
65fn apply_simple_column_type(
66 col: &mut SeaColumnDef,
67 simple: SimpleColumnType,
68 backend: DatabaseBackend,
69) {
70 match simple {
71 SimpleColumnType::SmallInt => {
72 col.small_integer();
73 }
74 SimpleColumnType::Integer => {
75 col.integer();
76 }
77 SimpleColumnType::BigInt => {
78 col.big_integer();
79 }
80 SimpleColumnType::Real => {
81 col.float();
82 }
83 SimpleColumnType::DoublePrecision => {
84 col.double();
85 }
86 SimpleColumnType::Text => {
87 col.text();
88 }
89 SimpleColumnType::Boolean => {
90 col.boolean();
91 }
92 SimpleColumnType::Date => {
93 col.date();
94 }
95 SimpleColumnType::Time => {
96 col.time();
97 }
98 SimpleColumnType::Timestamp => {
99 col.timestamp();
100 }
101 SimpleColumnType::Timestamptz => apply_timestamptz_type(col, backend),
102 SimpleColumnType::Interval => apply_interval_type(col, backend),
103 SimpleColumnType::Bytea => {
104 col.binary();
105 }
106 SimpleColumnType::Uuid => {
107 col.uuid();
108 }
109 SimpleColumnType::Json => {
110 col.json();
111 }
112 SimpleColumnType::Inet => apply_postgres_text_fallback_type(col, backend, "INET"),
113 SimpleColumnType::Cidr => apply_postgres_text_fallback_type(col, backend, "CIDR"),
114 SimpleColumnType::Macaddr => apply_postgres_text_fallback_type(col, backend, "MACADDR"),
115 SimpleColumnType::Xml => apply_postgres_text_fallback_type(col, backend, "XML"),
116 _ => unreachable!("SimpleColumnType is #[non_exhaustive]; all variants are matched above"),
117 }
118}
119
120fn apply_timestamptz_type(col: &mut SeaColumnDef, backend: DatabaseBackend) {
121 match backend {
122 DatabaseBackend::Postgres => {
123 col.timestamp_with_time_zone();
124 }
125 DatabaseBackend::MySql | DatabaseBackend::Sqlite => {
126 col.timestamp();
127 }
128 }
129}
130
131fn apply_interval_type(col: &mut SeaColumnDef, backend: DatabaseBackend) {
132 match backend {
133 DatabaseBackend::Postgres => {
134 col.interval(None, None);
135 }
136 DatabaseBackend::MySql | DatabaseBackend::Sqlite => {
137 col.text();
138 }
139 }
140}
141
142fn apply_postgres_text_fallback_type(
143 col: &mut SeaColumnDef,
144 backend: DatabaseBackend,
145 postgres_type: &str,
146) {
147 match backend {
148 DatabaseBackend::Postgres => {
149 col.custom(Alias::new(postgres_type));
150 }
151 DatabaseBackend::MySql | DatabaseBackend::Sqlite => {
152 col.text();
153 }
154 }
155}
156
157fn apply_complex_column_type(
158 col: &mut SeaColumnDef,
159 complex: &ComplexColumnType,
160 table: &str,
161 backend: DatabaseBackend,
162) {
163 match complex {
164 ComplexColumnType::Varchar { length } => {
165 col.string_len(*length);
166 }
167 ComplexColumnType::Numeric { precision, scale } => {
168 apply_numeric_type(col, *precision, *scale, backend);
169 }
170 ComplexColumnType::Char { length } => {
171 col.char_len(*length);
172 }
173 ComplexColumnType::Custom { custom_type } => {
174 col.custom(Alias::new(custom_type));
175 }
176 ComplexColumnType::Enum { name, values } => {
177 if values.is_integer() {
179 col.integer();
180 } else {
181 let type_name = build_enum_type_name(table, name);
183 let variants = values
184 .variant_names()
185 .into_iter()
186 .map(Alias::new)
187 .collect::<Vec<Alias>>();
188 col.enumeration(Alias::new(&type_name), variants);
189 }
190 }
191 _ => unreachable!("ComplexColumnType is #[non_exhaustive]; all variants are matched above"),
192 }
193}
194
195fn apply_numeric_type(
196 col: &mut SeaColumnDef,
197 precision: u32,
198 scale: u32,
199 backend: DatabaseBackend,
200) {
201 debug_assert!(
202 scale <= precision,
203 "numeric scale ({scale}) must be <= precision ({precision}); schema validation should reject this before SQL generation"
204 );
205 let safe_precision = precision.min(28);
206 let safe_scale = scale.min(safe_precision);
207 match backend {
208 DatabaseBackend::Postgres | DatabaseBackend::MySql => {
209 col.decimal_len(safe_precision, safe_scale);
210 }
211 DatabaseBackend::Sqlite => {
212 col.double();
213 }
214 }
215}
216
217pub fn to_sea_fk_action(action: &ReferenceAction) -> ForeignKeyAction {
219 match action {
220 ReferenceAction::Cascade => ForeignKeyAction::Cascade,
221 ReferenceAction::Restrict => ForeignKeyAction::Restrict,
222 ReferenceAction::SetNull => ForeignKeyAction::SetNull,
223 ReferenceAction::SetDefault => ForeignKeyAction::SetDefault,
224 ReferenceAction::NoAction => ForeignKeyAction::NoAction,
225 _ => unreachable!("ReferenceAction is #[non_exhaustive]; all variants are matched above"),
226 }
227}
228
229pub fn reference_action_sql(action: &ReferenceAction) -> &'static str {
231 match action {
232 ReferenceAction::Cascade => "CASCADE",
233 ReferenceAction::Restrict => "RESTRICT",
234 ReferenceAction::SetNull => "SET NULL",
235 ReferenceAction::SetDefault => "SET DEFAULT",
236 ReferenceAction::NoAction => "NO ACTION",
237 _ => unreachable!("ReferenceAction is #[non_exhaustive]; all variants are matched above"),
238 }
239}
240
241pub fn convert_default_for_backend(default: &str, backend: DatabaseBackend) -> String {
243 let lower = default.to_lowercase();
244
245 if lower == "gen_random_uuid()" || lower == "uuid()" || lower == "lower(hex(randomblob(16)))" {
247 return match backend {
248 DatabaseBackend::Postgres => "gen_random_uuid()".to_string(),
249 DatabaseBackend::MySql => "(UUID())".to_string(),
250 DatabaseBackend::Sqlite => "lower(hex(randomblob(16)))".to_string(),
251 };
252 }
253
254 if lower == "current_timestamp()"
256 || lower == "now()"
257 || lower == "current_timestamp"
258 || lower == "getdate()"
259 {
260 return "CURRENT_TIMESTAMP".to_string();
261 }
262
263 if let Some((value, cast_type)) = parse_pg_type_cast(default) {
265 return convert_type_cast(&value, &cast_type, backend);
266 }
267
268 default.to_string()
269}
270
271pub(super) fn parse_pg_type_cast(expr: &str) -> Option<(String, String)> {
274 let trimmed = expr.trim();
275
276 if let Some(after_open) = trimmed.strip_prefix('\'') {
278 let mut chars = after_open.char_indices().peekable();
280 while let Some((i, ch)) = chars.next() {
281 if ch == '\'' {
282 if chars.next_if(|(_, next)| *next == '\'').is_some() {
284 continue;
285 }
286 let value_end = i + ch.len_utf8(); let rest = after_open.get(value_end..)?;
289 if let Some(stripped) = rest.strip_prefix("::") {
290 let cast_type = stripped.trim().to_lowercase();
291 if !cast_type.is_empty() {
292 let value = format!("'{}'", after_open.get(..i)?);
293 return Some((value, cast_type));
294 }
295 }
296 return None;
297 }
298 }
299 return None;
300 }
301
302 if let Some((value, cast_type)) = trimmed.split_once("::") {
304 let value = value.trim().to_string();
305 let cast_type = cast_type.trim().to_lowercase();
306 if !value.is_empty() && !cast_type.is_empty() {
307 return Some((value, cast_type));
308 }
309 }
310
311 None
312}
313
/// Maps a lowercase PostgreSQL cast target to a type usable in MySQL's `CAST(... AS ...)`.
///
/// Several Postgres types collapse onto one MySQL cast target. Unrecognised types fall
/// back to `CHAR`.
///
/// Numeric and floating-point casts use an explicit `DECIMAL(65,30)`:
/// - A bare `DECIMAL` target means `DECIMAL(10,0)` in MySQL, so the fractional part would
///   be rounded away (`CAST('1.5' AS DECIMAL)` yields `2`).
/// - 65 and 30 are MySQL's maximum precision and scale.
/// - The explicit form works on every MySQL version, unlike `CAST(... AS DOUBLE)`.
fn pg_type_to_mysql_cast(pg_type: &str) -> &'static str {
    match pg_type {
        "json" | "jsonb" => "JSON",
        "integer" | "int" | "int4" | "smallint" | "int2" | "bigint" | "int8" => "SIGNED",
        "real" | "float4" | "double precision" | "float8" | "numeric" | "decimal" => {
            "DECIMAL(65,30)"
        }
        // MySQL has no boolean cast target; booleans are 0/1 integers there.
        "boolean" | "bool" => "UNSIGNED",
        "date" => "DATE",
        "time" => "TIME",
        "timestamp"
        | "timestamptz"
        | "timestamp with time zone"
        | "timestamp without time zone" => "DATETIME",
        "bytea" => "BINARY",
        _ => "CHAR",
    }
}
331
332fn convert_type_cast(value: &str, cast_type: &str, backend: DatabaseBackend) -> String {
334 match backend {
335 DatabaseBackend::Postgres => format!("{value}::{cast_type}"),
337 DatabaseBackend::MySql => {
339 let mysql_type = pg_type_to_mysql_cast(cast_type);
340 format!("CAST({value} AS {mysql_type})")
341 }
342 DatabaseBackend::Sqlite => value.to_string(),
344 }
345}
346
347pub(super) fn is_enum_type(column_type: &ColumnType) -> bool {
349 matches!(
350 column_type,
351 ColumnType::Complex(ComplexColumnType::Enum { .. })
352 )
353}
354
355pub fn normalize_enum_default(column_type: &ColumnType, value: &str) -> String {
358 if is_enum_type(column_type) && needs_quoting(value) {
359 format!("'{value}'")
360 } else {
361 value.to_string()
362 }
363}
364
365pub(super) fn needs_quoting(default_str: &str) -> bool {
367 let trimmed = default_str.trim();
368 if trimmed.is_empty() {
370 return true;
371 }
372 if trimmed.starts_with('\'') || trimmed.starts_with('"') {
374 return false;
375 }
376 if trimmed.contains('(') || trimmed.contains(')') {
378 return false;
379 }
380 if trimmed.eq_ignore_ascii_case("null") {
382 return false;
383 }
384 if trimmed.eq_ignore_ascii_case("current_timestamp")
386 || trimmed.eq_ignore_ascii_case("current_date")
387 || trimmed.eq_ignore_ascii_case("current_time")
388 {
389 return false;
390 }
391 true
392}
393
394pub fn build_sea_column_def_with_table(
396 backend: DatabaseBackend,
397 table: &str,
398 column: &ColumnDef,
399) -> SeaColumnDef {
400 let mut col = SeaColumnDef::new(Alias::new(&column.name));
401 apply_column_type_with_table(&mut col, &column.r#type, table, backend);
402
403 if !column.nullable {
404 col.not_null();
405 }
406
407 if let Some(default) = &column.default {
408 let default_str = default.to_sql();
409 let converted = convert_default_for_backend(&default_str, backend);
410
411 let final_default =
413 if is_enum_type(&column.r#type) && default.is_string() && needs_quoting(&converted) {
414 format!("'{converted}'")
415 } else {
416 converted
417 };
418
419 let final_default = if backend == DatabaseBackend::Sqlite
422 && final_default.contains('(')
423 && !final_default.starts_with('(')
424 {
425 format!("({final_default})")
426 } else {
427 final_default
428 };
429
430 col.default(Into::<SimpleExpr>::into(sea_query::Expr::cust(
431 final_default,
432 )));
433 }
434
435 col
436}
437
438pub fn build_create_enum_type_sql(
444 table: &str,
445 column_type: &ColumnType,
446) -> Option<super::types::RawSql> {
447 if let ColumnType::Complex(ComplexColumnType::Enum { name, values }) = column_type {
448 if values.is_integer() {
450 return None;
451 }
452
453 let values_sql = values.to_sql_values().join(", ");
454
455 let type_name = build_enum_type_name(table, name);
457
458 let type_name = quote_ident(&type_name, DatabaseBackend::Postgres);
460 let pg_sql = format!("CREATE TYPE {type_name} AS ENUM ({values_sql})");
461
462 Some(super::types::RawSql::per_backend(
465 pg_sql,
466 String::new(),
467 String::new(),
468 ))
469 } else {
470 None
471 }
472}
473
474pub fn build_drop_enum_type_sql(
479 table: &str,
480 column_type: &ColumnType,
481) -> Option<super::types::RawSql> {
482 if let ColumnType::Complex(ComplexColumnType::Enum { name, .. }) = column_type {
483 let type_name = build_enum_type_name(table, name);
485
486 let type_name = quote_ident(&type_name, DatabaseBackend::Postgres);
488 let pg_sql = format!("DROP TYPE {type_name}");
489
490 Some(super::types::RawSql::per_backend(
492 pg_sql,
493 String::new(),
494 String::new(),
495 ))
496 } else {
497 None
498 }
499}
500
501pub use vespertide_naming::{
503 build_check_constraint_name, build_enum_type_name, build_foreign_key_name, build_index_name,
504 build_unique_constraint_name,
505};
506
507pub fn build_sqlite_enum_check_clause(
510 table: &str,
511 column: &str,
512 column_type: &ColumnType,
513) -> Option<String> {
514 if let ColumnType::Complex(ComplexColumnType::Enum { values, .. }) = column_type {
515 let name = build_check_constraint_name(table, column);
516 let values_sql = values.to_sql_values().join(", ");
517 let name = quote_ident(&name, DatabaseBackend::Sqlite);
518 let column = quote_ident(column, DatabaseBackend::Sqlite);
519 Some(format!(
520 "CONSTRAINT {name} CHECK ({column} IN ({values_sql}))"
521 ))
522 } else {
523 None
524 }
525}
526
527pub fn collect_sqlite_enum_check_clauses(table: &str, columns: &[ColumnDef]) -> Vec<String> {
529 columns
530 .iter()
531 .filter_map(|col| build_sqlite_enum_check_clause(table, &col.name, &col.r#type))
532 .collect()
533}
534
535pub fn extract_check_clauses(constraints: &[TableConstraint]) -> Vec<String> {
538 constraints
539 .iter()
540 .filter_map(|c| {
541 if let TableConstraint::Check { name, expr, .. } = c {
542 let name = quote_ident(name, DatabaseBackend::Sqlite);
543 Some(format!("CONSTRAINT {name} CHECK ({expr})"))
544 } else {
545 None
546 }
547 })
548 .collect()
549}
550
551pub fn collect_all_check_clauses(
558 table: &str,
559 columns: &[ColumnDef],
560 constraints: &[TableConstraint],
561) -> Vec<String> {
562 let mut clauses = collect_sqlite_enum_check_clauses(table, columns);
563 let explicit = extract_check_clauses(constraints);
564 for clause in explicit {
565 if !clauses.contains(&clause) {
566 clauses.push(clause);
567 }
568 }
569 clauses
570}
571
572pub fn build_create_with_checks(
576 backend: DatabaseBackend,
577 create_stmt: &sea_query::TableCreateStatement,
578 check_clauses: &[String],
579) -> BuiltQuery {
580 if check_clauses.is_empty() {
581 BuiltQuery::CreateTable(Box::new(create_stmt.clone()))
582 } else {
583 let base_sql = build_schema_statement(create_stmt, backend);
584 let mut modified_sql = base_sql;
585 if let Some(pos) = modified_sql.rfind(')') {
586 let check_sql = check_clauses.join(", ");
587 modified_sql.insert_str(pos, &format!(", {check_sql}"));
588 }
589 BuiltQuery::Raw(RawSql::per_backend(
590 modified_sql.clone(),
591 modified_sql.clone(),
592 modified_sql,
593 ))
594 }
595}
596
597pub fn build_sqlite_temp_table_create(
603 backend: DatabaseBackend,
604 temp_table: &str,
605 table: &str,
606 columns: &[ColumnDef],
607 constraints: &[TableConstraint],
608) -> BuiltQuery {
609 let create_stmt = build_create_table_for_backend(backend, temp_table, columns, constraints);
610 let check_clauses = collect_all_check_clauses(table, columns, constraints);
611 build_create_with_checks(backend, &create_stmt, &check_clauses)
612}
613
614pub fn recreate_indexes_after_rebuild(
621 table: &str,
622 constraints: &[TableConstraint],
623 pending_constraints: &[TableConstraint],
624) -> Vec<BuiltQuery> {
625 let mut queries = Vec::with_capacity(constraints.len());
627 let pending_constraints: std::collections::BTreeSet<_> = pending_constraints.iter().collect();
629 for constraint in constraints {
630 if pending_constraints.contains(constraint) {
632 continue;
633 }
634 match constraint {
635 TableConstraint::Index { name, columns } => {
636 let index_name = build_index_name(table, columns, name.as_deref());
637 let cols_sql = quote_idents(columns, DatabaseBackend::Sqlite);
638 let index_name = quote_ident(&index_name, DatabaseBackend::Sqlite);
639 let table = quote_ident(table, DatabaseBackend::Sqlite);
640 let sql = format!("CREATE INDEX {index_name} ON {table} ({cols_sql})");
641 queries.push(BuiltQuery::Raw(RawSql::per_backend(
642 sql.clone(),
643 sql.clone(),
644 sql,
645 )));
646 }
647 TableConstraint::Unique { name, columns, .. } => {
648 let index_name = build_unique_constraint_name(table, columns, name.as_deref());
649 let cols_sql = quote_idents(columns, DatabaseBackend::Sqlite);
650 let index_name = quote_ident(&index_name, DatabaseBackend::Sqlite);
651 let table = quote_ident(table, DatabaseBackend::Sqlite);
652 let sql = format!("CREATE UNIQUE INDEX {index_name} ON {table} ({cols_sql})");
653 queries.push(BuiltQuery::Raw(RawSql::per_backend(
654 sql.clone(),
655 sql.clone(),
656 sql,
657 )));
658 }
659 _ => {}
660 }
661 }
662 queries
663}
664
665pub fn get_enum_name(column_type: &ColumnType) -> Option<&str> {
667 if let ColumnType::Complex(ComplexColumnType::Enum { name, .. }) = column_type {
668 Some(name.as_str())
669 } else {
670 None
671 }
672}
673
674#[must_use]
683pub fn quote_ident(name: &str, backend: DatabaseBackend) -> String {
684 match backend {
685 DatabaseBackend::Postgres | DatabaseBackend::Sqlite => {
686 let escaped = name.replace('"', "\"\"");
687 format!("\"{escaped}\"")
688 }
689 DatabaseBackend::MySql => {
690 let escaped = name.replace('`', "``");
691 format!("`{escaped}`")
692 }
693 }
694}
695
696#[must_use]
698pub fn quote_idents<T: AsRef<str>>(names: &[T], backend: DatabaseBackend) -> String {
699 names
700 .iter()
701 .map(|n| quote_ident(n.as_ref(), backend))
702 .collect::<Vec<_>>()
703 .join(", ")
704}
705
#[cfg(test)]
mod tests {
    use super::*;
    use sea_query::{Alias, ColumnDef as SeaColDef, Table};

    /// With no CHECK clauses, `build_create_with_checks` must return the structured
    /// CREATE TABLE statement untouched rather than a raw SQL string.
    #[test]
    fn build_create_with_checks_empty_clauses_returns_plain_create_table() {
        let mut create = Table::create();
        create
            .table(Alias::new("users"))
            .col(SeaColDef::new(Alias::new("id")).integer().not_null());

        let query = build_create_with_checks(DatabaseBackend::Postgres, &create, &[]);
        assert!(
            matches!(query, BuiltQuery::CreateTable(_)),
            "empty-checks branch must return BuiltQuery::CreateTable"
        );

        let rendered = query.build(DatabaseBackend::Postgres);
        assert!(
            rendered.contains("CREATE TABLE"),
            "expected CREATE TABLE in: {rendered}"
        );
        assert!(
            !rendered.contains("CHECK ("),
            "no CHECK should be injected: {rendered}"
        );
    }
}