Skip to main content

vespertide_query/sql/
helpers.rs

1use std::borrow::Cow;
2
3use sea_query::{
4    Alias, ColumnDef as SeaColumnDef, ForeignKeyAction, MysqlQueryBuilder, PostgresQueryBuilder,
5    QueryStatementWriter, SchemaStatementBuilder, SimpleExpr, SqliteQueryBuilder,
6};
7
8use vespertide_core::{
9    ColumnDef, ColumnType, ComplexColumnType, ReferenceAction, SimpleColumnType, TableConstraint,
10};
11
12use super::create_table::build_create_table_for_backend;
13use super::types::{BuiltQuery, DatabaseBackend, RawSql};
14
/// Normalize a `fill_with` value so that an empty string becomes the SQL
/// empty-string literal `''`.
///
/// `None` is passed through unchanged. Any non-empty value is returned as-is.
/// Both outcomes borrow (either the input or a static literal), so this
/// function never allocates.
#[must_use]
pub fn normalize_fill_with(fill_with: Option<&str>) -> Option<Cow<'_, str>> {
    let value = fill_with?;
    let literal = if value.is_empty() { "''" } else { value };
    Some(Cow::Borrowed(literal))
}
27
28/// Helper function to convert a schema statement to SQL for a specific backend
29pub fn build_schema_statement<T: SchemaStatementBuilder>(
30    stmt: &T,
31    backend: DatabaseBackend,
32) -> String {
33    match backend {
34        DatabaseBackend::Postgres => stmt.to_string(PostgresQueryBuilder),
35        DatabaseBackend::MySql => stmt.to_string(MysqlQueryBuilder),
36        DatabaseBackend::Sqlite => stmt.to_string(SqliteQueryBuilder),
37    }
38}
39
40/// Helper function to convert a query statement (INSERT, SELECT, etc.) to SQL for a specific backend
41pub fn build_query_statement<T: QueryStatementWriter>(
42    stmt: &T,
43    backend: DatabaseBackend,
44) -> String {
45    match backend {
46        DatabaseBackend::Postgres => stmt.to_string(PostgresQueryBuilder),
47        DatabaseBackend::MySql => stmt.to_string(MysqlQueryBuilder),
48        DatabaseBackend::Sqlite => stmt.to_string(SqliteQueryBuilder),
49    }
50}
51
52/// Apply vespertide `ColumnType` to `sea_query` `ColumnDef` with table-aware enum type naming
53pub fn apply_column_type_with_table(
54    col: &mut SeaColumnDef,
55    ty: &ColumnType,
56    table: &str,
57    backend: DatabaseBackend,
58) {
59    match ty {
60        ColumnType::Simple(simple) => apply_simple_column_type(col, *simple, backend),
61        ColumnType::Complex(complex) => apply_complex_column_type(col, complex, table, backend),
62    }
63}
64
65fn apply_simple_column_type(
66    col: &mut SeaColumnDef,
67    simple: SimpleColumnType,
68    backend: DatabaseBackend,
69) {
70    match simple {
71        SimpleColumnType::SmallInt => {
72            col.small_integer();
73        }
74        SimpleColumnType::Integer => {
75            col.integer();
76        }
77        SimpleColumnType::BigInt => {
78            col.big_integer();
79        }
80        SimpleColumnType::Real => {
81            col.float();
82        }
83        SimpleColumnType::DoublePrecision => {
84            col.double();
85        }
86        SimpleColumnType::Text => {
87            col.text();
88        }
89        SimpleColumnType::Boolean => {
90            col.boolean();
91        }
92        SimpleColumnType::Date => {
93            col.date();
94        }
95        SimpleColumnType::Time => {
96            col.time();
97        }
98        SimpleColumnType::Timestamp => {
99            col.timestamp();
100        }
101        SimpleColumnType::Timestamptz => apply_timestamptz_type(col, backend),
102        SimpleColumnType::Interval => apply_interval_type(col, backend),
103        SimpleColumnType::Bytea => {
104            col.binary();
105        }
106        SimpleColumnType::Uuid => {
107            col.uuid();
108        }
109        SimpleColumnType::Json => {
110            col.json();
111        }
112        SimpleColumnType::Inet => apply_postgres_text_fallback_type(col, backend, "INET"),
113        SimpleColumnType::Cidr => apply_postgres_text_fallback_type(col, backend, "CIDR"),
114        SimpleColumnType::Macaddr => apply_postgres_text_fallback_type(col, backend, "MACADDR"),
115        SimpleColumnType::Xml => apply_postgres_text_fallback_type(col, backend, "XML"),
116        _ => unreachable!("SimpleColumnType is #[non_exhaustive]; all variants are matched above"),
117    }
118}
119
120fn apply_timestamptz_type(col: &mut SeaColumnDef, backend: DatabaseBackend) {
121    match backend {
122        DatabaseBackend::Postgres => {
123            col.timestamp_with_time_zone();
124        }
125        DatabaseBackend::MySql | DatabaseBackend::Sqlite => {
126            col.timestamp();
127        }
128    }
129}
130
131fn apply_interval_type(col: &mut SeaColumnDef, backend: DatabaseBackend) {
132    match backend {
133        DatabaseBackend::Postgres => {
134            col.interval(None, None);
135        }
136        DatabaseBackend::MySql | DatabaseBackend::Sqlite => {
137            col.text();
138        }
139    }
140}
141
142fn apply_postgres_text_fallback_type(
143    col: &mut SeaColumnDef,
144    backend: DatabaseBackend,
145    postgres_type: &str,
146) {
147    match backend {
148        DatabaseBackend::Postgres => {
149            col.custom(Alias::new(postgres_type));
150        }
151        DatabaseBackend::MySql | DatabaseBackend::Sqlite => {
152            col.text();
153        }
154    }
155}
156
157fn apply_complex_column_type(
158    col: &mut SeaColumnDef,
159    complex: &ComplexColumnType,
160    table: &str,
161    backend: DatabaseBackend,
162) {
163    match complex {
164        ComplexColumnType::Varchar { length } => {
165            col.string_len(*length);
166        }
167        ComplexColumnType::Numeric { precision, scale } => {
168            apply_numeric_type(col, *precision, *scale, backend);
169        }
170        ComplexColumnType::Char { length } => {
171            col.char_len(*length);
172        }
173        ComplexColumnType::Custom { custom_type } => {
174            col.custom(Alias::new(custom_type));
175        }
176        ComplexColumnType::Enum { name, values } => {
177            // For integer enums, use INTEGER type instead of ENUM
178            if values.is_integer() {
179                col.integer();
180            } else {
181                // Use table-prefixed enum type name to avoid conflicts
182                let type_name = build_enum_type_name(table, name);
183                let variants = values
184                    .variant_names()
185                    .into_iter()
186                    .map(Alias::new)
187                    .collect::<Vec<Alias>>();
188                col.enumeration(Alias::new(&type_name), variants);
189            }
190        }
191        _ => unreachable!("ComplexColumnType is #[non_exhaustive]; all variants are matched above"),
192    }
193}
194
195fn apply_numeric_type(
196    col: &mut SeaColumnDef,
197    precision: u32,
198    scale: u32,
199    backend: DatabaseBackend,
200) {
201    debug_assert!(
202        scale <= precision,
203        "numeric scale ({scale}) must be <= precision ({precision}); schema validation should reject this before SQL generation"
204    );
205    let safe_precision = precision.min(28);
206    let safe_scale = scale.min(safe_precision);
207    match backend {
208        DatabaseBackend::Postgres | DatabaseBackend::MySql => {
209            col.decimal_len(safe_precision, safe_scale);
210        }
211        DatabaseBackend::Sqlite => {
212            col.double();
213        }
214    }
215}
216
217/// Convert vespertide `ReferenceAction` to `sea_query` `ForeignKeyAction`
218pub fn to_sea_fk_action(action: &ReferenceAction) -> ForeignKeyAction {
219    match action {
220        ReferenceAction::Cascade => ForeignKeyAction::Cascade,
221        ReferenceAction::Restrict => ForeignKeyAction::Restrict,
222        ReferenceAction::SetNull => ForeignKeyAction::SetNull,
223        ReferenceAction::SetDefault => ForeignKeyAction::SetDefault,
224        ReferenceAction::NoAction => ForeignKeyAction::NoAction,
225        _ => unreachable!("ReferenceAction is #[non_exhaustive]; all variants are matched above"),
226    }
227}
228
229/// Convert vespertide `ReferenceAction` to SQL string
230pub fn reference_action_sql(action: &ReferenceAction) -> &'static str {
231    match action {
232        ReferenceAction::Cascade => "CASCADE",
233        ReferenceAction::Restrict => "RESTRICT",
234        ReferenceAction::SetNull => "SET NULL",
235        ReferenceAction::SetDefault => "SET DEFAULT",
236        ReferenceAction::NoAction => "NO ACTION",
237        _ => unreachable!("ReferenceAction is #[non_exhaustive]; all variants are matched above"),
238    }
239}
240
241/// Convert a default value string to the appropriate backend-specific expression
242pub fn convert_default_for_backend(default: &str, backend: DatabaseBackend) -> String {
243    let lower = default.to_lowercase();
244
245    // UUID generation functions
246    if lower == "gen_random_uuid()" || lower == "uuid()" || lower == "lower(hex(randomblob(16)))" {
247        return match backend {
248            DatabaseBackend::Postgres => "gen_random_uuid()".to_string(),
249            DatabaseBackend::MySql => "(UUID())".to_string(),
250            DatabaseBackend::Sqlite => "lower(hex(randomblob(16)))".to_string(),
251        };
252    }
253
254    // Timestamp functions (case-insensitive)
255    if lower == "current_timestamp()"
256        || lower == "now()"
257        || lower == "current_timestamp"
258        || lower == "getdate()"
259    {
260        return "CURRENT_TIMESTAMP".to_string();
261    }
262
263    // PostgreSQL-style type casts: 'value'::type or expr::type
264    if let Some((value, cast_type)) = parse_pg_type_cast(default) {
265        return convert_type_cast(&value, &cast_type, backend);
266    }
267
268    default.to_string()
269}
270
271/// Parse a PostgreSQL-style type cast expression (e.g., `'[]'::json`, `0::boolean`)
272/// Returns `(value, type)` if parsed, or None if not a type cast.
273pub(super) fn parse_pg_type_cast(expr: &str) -> Option<(String, String)> {
274    let trimmed = expr.trim();
275
276    // Handle quoted values: 'value'::type
277    if let Some(after_open) = trimmed.strip_prefix('\'') {
278        // Find the closing quote (handle escaped quotes '')
279        let mut chars = after_open.char_indices().peekable();
280        while let Some((i, ch)) = chars.next() {
281            if ch == '\'' {
282                // Check for escaped quote ''
283                if chars.next_if(|(_, next)| *next == '\'').is_some() {
284                    continue;
285                }
286                // Found closing quote
287                let value_end = i + ch.len_utf8(); // index in `after_open`
288                let rest = after_open.get(value_end..)?;
289                if let Some(stripped) = rest.strip_prefix("::") {
290                    let cast_type = stripped.trim().to_lowercase();
291                    if !cast_type.is_empty() {
292                        let value = format!("'{}'", after_open.get(..i)?);
293                        return Some((value, cast_type));
294                    }
295                }
296                return None;
297            }
298        }
299        return None;
300    }
301
302    // Handle unquoted values: expr::type (e.g., 0::boolean, NULL::json)
303    if let Some((value, cast_type)) = trimmed.split_once("::") {
304        let value = value.trim().to_string();
305        let cast_type = cast_type.trim().to_lowercase();
306        if !value.is_empty() && !cast_type.is_empty() {
307            return Some((value, cast_type));
308        }
309    }
310
311    None
312}
313
/// Map a (lowercased) `PostgreSQL` type name to the target type used in a
/// `MySQL` `CAST(... AS <type>)` expression.
///
/// Matching is exact and case-sensitive; unrecognized names map to `CHAR`.
fn pg_type_to_mysql_cast(pg_type: &str) -> &'static str {
    match pg_type {
        // Integers and booleans.
        "smallint" | "int2" | "integer" | "int" | "int4" | "bigint" | "int8" => "SIGNED",
        "bool" | "boolean" => "UNSIGNED",
        // Floating point and fixed point.
        "float4" | "real" | "float8" | "double precision" | "decimal" | "numeric" => "DECIMAL",
        // Date and time.
        "date" => "DATE",
        "time" => "TIME",
        "timestamp"
        | "timestamptz"
        | "timestamp without time zone"
        | "timestamp with time zone" => "DATETIME",
        // Structured and binary data.
        "jsonb" | "json" => "JSON",
        "bytea" => "BINARY",
        // Everything else is cast to a character type.
        _ => "CHAR",
    }
}
331
332/// Convert a type cast expression to the appropriate backend syntax
333fn convert_type_cast(value: &str, cast_type: &str, backend: DatabaseBackend) -> String {
334    match backend {
335        // PostgreSQL: keep native :: syntax
336        DatabaseBackend::Postgres => format!("{value}::{cast_type}"),
337        // MySQL: CAST(value AS type)
338        DatabaseBackend::MySql => {
339            let mysql_type = pg_type_to_mysql_cast(cast_type);
340            format!("CAST({value} AS {mysql_type})")
341        }
342        // SQLite: strip the cast, use raw value (SQLite is dynamically typed)
343        DatabaseBackend::Sqlite => value.to_string(),
344    }
345}
346
347/// Check if the column type is an enum type
348pub(super) fn is_enum_type(column_type: &ColumnType) -> bool {
349    matches!(
350        column_type,
351        ColumnType::Complex(ComplexColumnType::Enum { .. })
352    )
353}
354
355/// Normalize a default value for enum columns - add quotes if needed
356/// This is used for SQL expressions (INSERT, UPDATE) where enum values need quoting
357pub fn normalize_enum_default(column_type: &ColumnType, value: &str) -> String {
358    if is_enum_type(column_type) && needs_quoting(value) {
359        format!("'{value}'")
360    } else {
361        value.to_string()
362    }
363}
364
365/// Check if a string default value needs quoting (is a plain string literal without quotes/parens)
366pub(super) fn needs_quoting(default_str: &str) -> bool {
367    let trimmed = default_str.trim();
368    // Empty string always needs quoting to become ''
369    if trimmed.is_empty() {
370        return true;
371    }
372    // Don't quote if already quoted
373    if trimmed.starts_with('\'') || trimmed.starts_with('"') {
374        return false;
375    }
376    // Don't quote if it's a function call
377    if trimmed.contains('(') || trimmed.contains(')') {
378        return false;
379    }
380    // Don't quote NULL
381    if trimmed.eq_ignore_ascii_case("null") {
382        return false;
383    }
384    // Don't quote special SQL keywords
385    if trimmed.eq_ignore_ascii_case("current_timestamp")
386        || trimmed.eq_ignore_ascii_case("current_date")
387        || trimmed.eq_ignore_ascii_case("current_time")
388    {
389        return false;
390    }
391    true
392}
393
394/// Build `sea_query` `ColumnDef` from vespertide `ColumnDef` for a specific backend with table-aware enum naming
395pub fn build_sea_column_def_with_table(
396    backend: DatabaseBackend,
397    table: &str,
398    column: &ColumnDef,
399) -> SeaColumnDef {
400    let mut col = SeaColumnDef::new(Alias::new(&column.name));
401    apply_column_type_with_table(&mut col, &column.r#type, table, backend);
402
403    if !column.nullable {
404        col.not_null();
405    }
406
407    if let Some(default) = &column.default {
408        let default_str = default.to_sql();
409        let converted = convert_default_for_backend(&default_str, backend);
410
411        // Auto-quote enum default values if the value is a string and needs quoting
412        let final_default =
413            if is_enum_type(&column.r#type) && default.is_string() && needs_quoting(&converted) {
414                format!("'{converted}'")
415            } else {
416                converted
417            };
418
419        // SQLite requires DEFAULT (expr) for expressions containing function calls.
420        // Wrapping in parentheses is always safe for all backends.
421        let final_default = if backend == DatabaseBackend::Sqlite
422            && final_default.contains('(')
423            && !final_default.starts_with('(')
424        {
425            format!("({final_default})")
426        } else {
427            final_default
428        };
429
430        col.default(Into::<SimpleExpr>::into(sea_query::Expr::cust(
431            final_default,
432        )));
433    }
434
435    col
436}
437
438/// Generate CREATE TYPE SQL for an enum type (`PostgreSQL` only)
439/// Returns None for non-PostgreSQL backends or non-enum types
440///
441/// The enum type name will be prefixed with the table name to avoid conflicts
442/// across tables using the same enum name (e.g., "status", "gender").
443pub fn build_create_enum_type_sql(
444    table: &str,
445    column_type: &ColumnType,
446) -> Option<super::types::RawSql> {
447    if let ColumnType::Complex(ComplexColumnType::Enum { name, values }) = column_type {
448        // Integer enums don't need CREATE TYPE - they use INTEGER column
449        if values.is_integer() {
450            return None;
451        }
452
453        let values_sql = values.to_sql_values().join(", ");
454
455        // Generate unique type name with table prefix
456        let type_name = build_enum_type_name(table, name);
457
458        // PostgreSQL: CREATE TYPE {table}_{name} AS ENUM (...)
459        let type_name = quote_ident(&type_name, DatabaseBackend::Postgres);
460        let pg_sql = format!("CREATE TYPE {type_name} AS ENUM ({values_sql})");
461
462        // MySQL: ENUMs are inline, no CREATE TYPE needed
463        // SQLite: Uses TEXT, no CREATE TYPE needed
464        Some(super::types::RawSql::per_backend(
465            pg_sql,
466            String::new(),
467            String::new(),
468        ))
469    } else {
470        None
471    }
472}
473
474/// Generate DROP TYPE SQL for an enum type (`PostgreSQL` only)
475/// Returns None for non-PostgreSQL backends or non-enum types
476///
477/// The enum type name will be prefixed with the table name to match the CREATE TYPE.
478pub fn build_drop_enum_type_sql(
479    table: &str,
480    column_type: &ColumnType,
481) -> Option<super::types::RawSql> {
482    if let ColumnType::Complex(ComplexColumnType::Enum { name, .. }) = column_type {
483        // Generate the same unique type name used in CREATE TYPE
484        let type_name = build_enum_type_name(table, name);
485
486        // PostgreSQL: DROP TYPE {table}_{name}
487        let type_name = quote_ident(&type_name, DatabaseBackend::Postgres);
488        let pg_sql = format!("DROP TYPE {type_name}");
489
490        // MySQL/SQLite: No action needed
491        Some(super::types::RawSql::per_backend(
492            pg_sql,
493            String::new(),
494            String::new(),
495        ))
496    } else {
497        None
498    }
499}
500
501// Re-export naming functions from vespertide-naming
502pub use vespertide_naming::{
503    build_check_constraint_name, build_enum_type_name, build_foreign_key_name, build_index_name,
504    build_unique_constraint_name,
505};
506
507/// Generate CHECK constraint expression for `SQLite` enum column
508/// Returns the constraint clause like: CONSTRAINT "`chk_table_col`" CHECK (col IN ('val1', 'val2'))
509pub fn build_sqlite_enum_check_clause(
510    table: &str,
511    column: &str,
512    column_type: &ColumnType,
513) -> Option<String> {
514    if let ColumnType::Complex(ComplexColumnType::Enum { values, .. }) = column_type {
515        let name = build_check_constraint_name(table, column);
516        let values_sql = values.to_sql_values().join(", ");
517        let name = quote_ident(&name, DatabaseBackend::Sqlite);
518        let column = quote_ident(column, DatabaseBackend::Sqlite);
519        Some(format!(
520            "CONSTRAINT {name} CHECK ({column} IN ({values_sql}))"
521        ))
522    } else {
523        None
524    }
525}
526
527/// Collect all CHECK constraints for enum columns in a table (for `SQLite`)
528pub fn collect_sqlite_enum_check_clauses(table: &str, columns: &[ColumnDef]) -> Vec<String> {
529    columns
530        .iter()
531        .filter_map(|col| build_sqlite_enum_check_clause(table, &col.name, &col.r#type))
532        .collect()
533}
534
535/// Extract CHECK constraint clauses from a list of table constraints.
536/// Returns SQL fragments like: `CONSTRAINT "chk_name" CHECK (expr)`
537pub fn extract_check_clauses(constraints: &[TableConstraint]) -> Vec<String> {
538    constraints
539        .iter()
540        .filter_map(|c| {
541            if let TableConstraint::Check { name, expr, .. } = c {
542                let name = quote_ident(name, DatabaseBackend::Sqlite);
543                Some(format!("CONSTRAINT {name} CHECK ({expr})"))
544            } else {
545                None
546            }
547        })
548        .collect()
549}
550
551/// Collect ALL CHECK constraint clauses for a `SQLite` temp table.
552/// Combines both:
553/// - Enum-based CHECK constraints (from column types)
554/// - Explicit CHECK constraints (from `TableConstraint::Check`)
555///
556/// Returns deduplicated union of both.
557pub fn collect_all_check_clauses(
558    table: &str,
559    columns: &[ColumnDef],
560    constraints: &[TableConstraint],
561) -> Vec<String> {
562    let mut clauses = collect_sqlite_enum_check_clauses(table, columns);
563    let explicit = extract_check_clauses(constraints);
564    for clause in explicit {
565        if !clauses.contains(&clause) {
566            clauses.push(clause);
567        }
568    }
569    clauses
570}
571
572/// Build CREATE TABLE query with CHECK constraints properly embedded.
573/// sea-query doesn't support CHECK constraints natively, so we inject them
574/// by modifying the generated SQL string.
575pub fn build_create_with_checks(
576    backend: DatabaseBackend,
577    create_stmt: &sea_query::TableCreateStatement,
578    check_clauses: &[String],
579) -> BuiltQuery {
580    if check_clauses.is_empty() {
581        BuiltQuery::CreateTable(Box::new(create_stmt.clone()))
582    } else {
583        let base_sql = build_schema_statement(create_stmt, backend);
584        let mut modified_sql = base_sql;
585        if let Some(pos) = modified_sql.rfind(')') {
586            let check_sql = check_clauses.join(", ");
587            modified_sql.insert_str(pos, &format!(", {check_sql}"));
588        }
589        BuiltQuery::Raw(RawSql::per_backend(
590            modified_sql.clone(),
591            modified_sql.clone(),
592            modified_sql,
593        ))
594    }
595}
596
597/// Build the CREATE TABLE statement for a `SQLite` temp table, including all CHECK constraints.
598/// This combines `build_create_table_for_backend` with CHECK constraint injection.
599///
600/// `table` is the ORIGINAL table name (used for constraint naming).
601/// `temp_table` is the temporary table name.
602pub fn build_sqlite_temp_table_create(
603    backend: DatabaseBackend,
604    temp_table: &str,
605    table: &str,
606    columns: &[ColumnDef],
607    constraints: &[TableConstraint],
608) -> BuiltQuery {
609    let create_stmt = build_create_table_for_backend(backend, temp_table, columns, constraints);
610    let check_clauses = collect_all_check_clauses(table, columns, constraints);
611    build_create_with_checks(backend, &create_stmt, &check_clauses)
612}
613
614/// Recreate all indexes (both regular and UNIQUE) after a `SQLite` temp table rebuild.
615/// After DROP TABLE + RENAME, all original indexes are gone, so plain CREATE INDEX is correct.
616///
617/// `pending_constraints` are constraints that exist in the logical schema but haven't been
618/// physically created yet (e.g., promoted from inline column definitions by `AddColumn` normalization).
619/// These will be created by separate `AddConstraint` actions later, so we must NOT recreate them here.
620pub fn recreate_indexes_after_rebuild(
621    table: &str,
622    constraints: &[TableConstraint],
623    pending_constraints: &[TableConstraint],
624) -> Vec<BuiltQuery> {
625    // perf: capacity follows the upper bound of emitted index queries, avoiding reallocations.
626    let mut queries = Vec::with_capacity(constraints.len());
627    // perf: BTreeSet membership avoids nested Vec::contains scans during SQLite rebuilds.
628    let pending_constraints: std::collections::BTreeSet<_> = pending_constraints.iter().collect();
629    for constraint in constraints {
630        // Skip constraints that will be created by future AddConstraint actions
631        if pending_constraints.contains(constraint) {
632            continue;
633        }
634        match constraint {
635            TableConstraint::Index { name, columns } => {
636                let index_name = build_index_name(table, columns, name.as_deref());
637                let cols_sql = quote_idents(columns, DatabaseBackend::Sqlite);
638                let index_name = quote_ident(&index_name, DatabaseBackend::Sqlite);
639                let table = quote_ident(table, DatabaseBackend::Sqlite);
640                let sql = format!("CREATE INDEX {index_name} ON {table} ({cols_sql})");
641                queries.push(BuiltQuery::Raw(RawSql::per_backend(
642                    sql.clone(),
643                    sql.clone(),
644                    sql,
645                )));
646            }
647            TableConstraint::Unique { name, columns, .. } => {
648                let index_name = build_unique_constraint_name(table, columns, name.as_deref());
649                let cols_sql = quote_idents(columns, DatabaseBackend::Sqlite);
650                let index_name = quote_ident(&index_name, DatabaseBackend::Sqlite);
651                let table = quote_ident(table, DatabaseBackend::Sqlite);
652                let sql = format!("CREATE UNIQUE INDEX {index_name} ON {table} ({cols_sql})");
653                queries.push(BuiltQuery::Raw(RawSql::per_backend(
654                    sql.clone(),
655                    sql.clone(),
656                    sql,
657                )));
658            }
659            _ => {}
660        }
661    }
662    queries
663}
664
665/// Extract enum name from column type if it's an enum
666pub fn get_enum_name(column_type: &ColumnType) -> Option<&str> {
667    if let ColumnType::Complex(ComplexColumnType::Enum { name, .. }) = column_type {
668        Some(name.as_str())
669    } else {
670        None
671    }
672}
673
674/// Quote an identifier (table name, column name, constraint name) for the given backend.
675///
676/// Escapes any quote characters within the identifier to prevent SQL injection
677/// via malicious model names (defense-in-depth; identifier validation upstream
678/// is the primary defense).
679///
680/// - `PostgreSQL` / `SQLite`: `"identifier"` (double quotes; embedded `"` escaped as `""`)
681/// - `MySQL`: `` `identifier` `` (backticks; embedded `` ` `` escaped as ` `` `)
682#[must_use]
683pub fn quote_ident(name: &str, backend: DatabaseBackend) -> String {
684    match backend {
685        DatabaseBackend::Postgres | DatabaseBackend::Sqlite => {
686            let escaped = name.replace('"', "\"\"");
687            format!("\"{escaped}\"")
688        }
689        DatabaseBackend::MySql => {
690            let escaped = name.replace('`', "``");
691            format!("`{escaped}`")
692        }
693    }
694}
695
696/// Quote a list of identifiers and join them with comma.
697#[must_use]
698pub fn quote_idents<T: AsRef<str>>(names: &[T], backend: DatabaseBackend) -> String {
699    names
700        .iter()
701        .map(|n| quote_ident(n.as_ref(), backend))
702        .collect::<Vec<_>>()
703        .join(", ")
704}
705
#[cfg(test)]
mod tests {
    use super::*;
    use sea_query::{Alias, ColumnDef as SeaColDef, Table};

    /// With no CHECK clauses, `build_create_with_checks` must hand back the
    /// statement untouched as a `CreateTable` query instead of going through
    /// the render-and-splice path. Exercises the `check_clauses.is_empty()`
    /// branch.
    #[test]
    fn build_create_with_checks_empty_clauses_returns_plain_create_table() {
        let mut id_column = SeaColDef::new(Alias::new("id"));
        id_column.integer().not_null();

        let mut create_users = Table::create();
        create_users.table(Alias::new("users")).col(&mut id_column);

        let no_checks: &[String] = &[];
        let query = build_create_with_checks(DatabaseBackend::Postgres, &create_users, no_checks);
        let sql = query.build(DatabaseBackend::Postgres);

        assert!(
            sql.contains("CREATE TABLE"),
            "expected CREATE TABLE in: {sql}"
        );
        // Nothing was spliced into the statement.
        assert!(
            !sql.contains("CHECK ("),
            "no CHECK should be injected: {sql}"
        );
        // The early-return path yields the `CreateTable` variant, not `Raw`.
        assert!(
            matches!(query, BuiltQuery::CreateTable(_)),
            "empty-checks branch must return BuiltQuery::CreateTable"
        );
    }
}