vespertide_query/sql/
helpers.rs

1use sea_query::{
2    Alias, ColumnDef as SeaColumnDef, ForeignKeyAction, MysqlQueryBuilder, PostgresQueryBuilder,
3    QueryStatementWriter, SchemaStatementBuilder, SimpleExpr, SqliteQueryBuilder,
4};
5
6use vespertide_core::{
7    ColumnDef, ColumnType, ComplexColumnType, ReferenceAction, SimpleColumnType,
8};
9
10use super::types::DatabaseBackend;
11
/// Normalize a `fill_with` value.
///
/// An empty string is replaced by the SQL empty-string literal `''`; any other
/// value is returned unchanged as an owned `String`. `None` stays `None`.
pub fn normalize_fill_with(fill_with: Option<&str>) -> Option<String> {
    let raw = fill_with?;
    let normalized = match raw {
        "" => String::from("''"),
        other => other.to_owned(),
    };
    Some(normalized)
}
22
23/// Helper function to convert a schema statement to SQL for a specific backend
24pub fn build_schema_statement<T: SchemaStatementBuilder>(
25    stmt: &T,
26    backend: DatabaseBackend,
27) -> String {
28    match backend {
29        DatabaseBackend::Postgres => stmt.to_string(PostgresQueryBuilder),
30        DatabaseBackend::MySql => stmt.to_string(MysqlQueryBuilder),
31        DatabaseBackend::Sqlite => stmt.to_string(SqliteQueryBuilder),
32    }
33}
34
35/// Helper function to convert a query statement (INSERT, SELECT, etc.) to SQL for a specific backend
36pub fn build_query_statement<T: QueryStatementWriter>(
37    stmt: &T,
38    backend: DatabaseBackend,
39) -> String {
40    match backend {
41        DatabaseBackend::Postgres => stmt.to_string(PostgresQueryBuilder),
42        DatabaseBackend::MySql => stmt.to_string(MysqlQueryBuilder),
43        DatabaseBackend::Sqlite => stmt.to_string(SqliteQueryBuilder),
44    }
45}
46
47/// Apply vespertide ColumnType to sea_query ColumnDef with table-aware enum type naming
48pub fn apply_column_type_with_table(col: &mut SeaColumnDef, ty: &ColumnType, table: &str) {
49    match ty {
50        ColumnType::Simple(simple) => match simple {
51            SimpleColumnType::SmallInt => {
52                col.small_integer();
53            }
54            SimpleColumnType::Integer => {
55                col.integer();
56            }
57            SimpleColumnType::BigInt => {
58                col.big_integer();
59            }
60            SimpleColumnType::Real => {
61                col.float();
62            }
63            SimpleColumnType::DoublePrecision => {
64                col.double();
65            }
66            SimpleColumnType::Text => {
67                col.text();
68            }
69            SimpleColumnType::Boolean => {
70                col.boolean();
71            }
72            SimpleColumnType::Date => {
73                col.date();
74            }
75            SimpleColumnType::Time => {
76                col.time();
77            }
78            SimpleColumnType::Timestamp => {
79                col.timestamp();
80            }
81            SimpleColumnType::Timestamptz => {
82                col.timestamp_with_time_zone();
83            }
84            SimpleColumnType::Interval => {
85                col.interval(None, None);
86            }
87            SimpleColumnType::Bytea => {
88                col.binary();
89            }
90            SimpleColumnType::Uuid => {
91                col.uuid();
92            }
93            SimpleColumnType::Json => {
94                col.json();
95            }
96            SimpleColumnType::Inet => {
97                col.custom(Alias::new("INET"));
98            }
99            SimpleColumnType::Cidr => {
100                col.custom(Alias::new("CIDR"));
101            }
102            SimpleColumnType::Macaddr => {
103                col.custom(Alias::new("MACADDR"));
104            }
105            SimpleColumnType::Xml => {
106                col.custom(Alias::new("XML"));
107            }
108        },
109        ColumnType::Complex(complex) => match complex {
110            ComplexColumnType::Varchar { length } => {
111                col.string_len(*length);
112            }
113            ComplexColumnType::Numeric { precision, scale } => {
114                col.decimal_len(*precision, *scale);
115            }
116            ComplexColumnType::Char { length } => {
117                col.char_len(*length);
118            }
119            ComplexColumnType::Custom { custom_type } => {
120                col.custom(Alias::new(custom_type));
121            }
122            ComplexColumnType::Enum { name, values } => {
123                // For integer enums, use INTEGER type instead of ENUM
124                if values.is_integer() {
125                    col.integer();
126                } else {
127                    // Use table-prefixed enum type name to avoid conflicts
128                    let type_name = build_enum_type_name(table, name);
129                    col.enumeration(
130                        Alias::new(&type_name),
131                        values
132                            .variant_names()
133                            .into_iter()
134                            .map(Alias::new)
135                            .collect::<Vec<Alias>>(),
136                    );
137                }
138            }
139        },
140    }
141}
142
143/// Convert vespertide ReferenceAction to sea_query ForeignKeyAction
144pub fn to_sea_fk_action(action: &ReferenceAction) -> ForeignKeyAction {
145    match action {
146        ReferenceAction::Cascade => ForeignKeyAction::Cascade,
147        ReferenceAction::Restrict => ForeignKeyAction::Restrict,
148        ReferenceAction::SetNull => ForeignKeyAction::SetNull,
149        ReferenceAction::SetDefault => ForeignKeyAction::SetDefault,
150        ReferenceAction::NoAction => ForeignKeyAction::NoAction,
151    }
152}
153
154/// Convert vespertide ReferenceAction to SQL string
155pub fn reference_action_sql(action: &ReferenceAction) -> &'static str {
156    match action {
157        ReferenceAction::Cascade => "CASCADE",
158        ReferenceAction::Restrict => "RESTRICT",
159        ReferenceAction::SetNull => "SET NULL",
160        ReferenceAction::SetDefault => "SET DEFAULT",
161        ReferenceAction::NoAction => "NO ACTION",
162    }
163}
164
165/// Convert a default value string to the appropriate backend-specific expression
166pub fn convert_default_for_backend(default: &str, backend: &DatabaseBackend) -> String {
167    let lower = default.to_lowercase();
168
169    // UUID generation functions
170    if lower == "gen_random_uuid()" || lower == "uuid()" || lower == "lower(hex(randomblob(16)))" {
171        return match backend {
172            DatabaseBackend::Postgres => "gen_random_uuid()".to_string(),
173            DatabaseBackend::MySql => "(UUID())".to_string(),
174            DatabaseBackend::Sqlite => "lower(hex(randomblob(16)))".to_string(),
175        };
176    }
177
178    // Timestamp functions (case-insensitive)
179    if lower == "current_timestamp()"
180        || lower == "now()"
181        || lower == "current_timestamp"
182        || lower == "getdate()"
183    {
184        return match backend {
185            DatabaseBackend::Postgres => "CURRENT_TIMESTAMP".to_string(),
186            DatabaseBackend::MySql => "CURRENT_TIMESTAMP".to_string(),
187            DatabaseBackend::Sqlite => "CURRENT_TIMESTAMP".to_string(),
188        };
189    }
190
191    default.to_string()
192}
193
194/// Check if the column type is an enum type
195fn is_enum_type(column_type: &ColumnType) -> bool {
196    matches!(
197        column_type,
198        ColumnType::Complex(ComplexColumnType::Enum { .. })
199    )
200}
201
202/// Normalize a default value for enum columns - add quotes if needed
203/// This is used for SQL expressions (INSERT, UPDATE) where enum values need quoting
204pub fn normalize_enum_default(column_type: &ColumnType, value: &str) -> String {
205    if is_enum_type(column_type) && needs_quoting(value) {
206        format!("'{}'", value)
207    } else {
208        value.to_string()
209    }
210}
211
/// Decide whether a default value is a bare literal that must be wrapped in single quotes.
///
/// The check runs on the trimmed value. It returns `false` (leave as-is) when the value:
/// - is already quoted (starts with `'` or `"`);
/// - contains a parenthesis, so it is treated as a function call or expression;
/// - is one of `NULL`, `CURRENT_TIMESTAMP`, `CURRENT_DATE` or `CURRENT_TIME`
///   (ASCII case-insensitive).
///
/// Every other value returns `true`, including an empty or all-whitespace value.
fn needs_quoting(default_str: &str) -> bool {
    let candidate = default_str.trim();
    if candidate.is_empty() {
        return true;
    }

    let already_quoted = candidate.starts_with(['\'', '"']);
    let looks_like_call = candidate.contains(['(', ')']);
    let is_sql_keyword = matches!(
        candidate.to_ascii_lowercase().as_str(),
        "null" | "current_timestamp" | "current_date" | "current_time"
    );

    !(already_quoted || looks_like_call || is_sql_keyword)
}
240
241/// Build sea_query ColumnDef from vespertide ColumnDef for a specific backend with table-aware enum naming
242pub fn build_sea_column_def_with_table(
243    backend: &DatabaseBackend,
244    table: &str,
245    column: &ColumnDef,
246) -> SeaColumnDef {
247    let mut col = SeaColumnDef::new(Alias::new(&column.name));
248    apply_column_type_with_table(&mut col, &column.r#type, table);
249
250    if !column.nullable {
251        col.not_null();
252    }
253
254    if let Some(default) = &column.default {
255        let default_str = default.to_sql();
256        let converted = convert_default_for_backend(&default_str, backend);
257
258        // Auto-quote enum default values if the value is a string and needs quoting
259        let final_default =
260            if is_enum_type(&column.r#type) && default.is_string() && needs_quoting(&converted) {
261                format!("'{}'", converted)
262            } else {
263                converted
264            };
265
266        col.default(Into::<SimpleExpr>::into(sea_query::Expr::cust(
267            final_default,
268        )));
269    }
270
271    col
272}
273
274/// Generate CREATE TYPE SQL for an enum type (PostgreSQL only)
275/// Returns None for non-PostgreSQL backends or non-enum types
276///
277/// The enum type name will be prefixed with the table name to avoid conflicts
278/// across tables using the same enum name (e.g., "status", "gender").
279pub fn build_create_enum_type_sql(
280    table: &str,
281    column_type: &ColumnType,
282) -> Option<super::types::RawSql> {
283    if let ColumnType::Complex(ComplexColumnType::Enum { name, values }) = column_type {
284        // Integer enums don't need CREATE TYPE - they use INTEGER column
285        if values.is_integer() {
286            return None;
287        }
288
289        let values_sql = values.to_sql_values().join(", ");
290
291        // Generate unique type name with table prefix
292        let type_name = build_enum_type_name(table, name);
293
294        // PostgreSQL: CREATE TYPE {table}_{name} AS ENUM (...)
295        let pg_sql = format!("CREATE TYPE \"{}\" AS ENUM ({})", type_name, values_sql);
296
297        // MySQL: ENUMs are inline, no CREATE TYPE needed
298        // SQLite: Uses TEXT, no CREATE TYPE needed
299        Some(super::types::RawSql::per_backend(
300            pg_sql,
301            String::new(),
302            String::new(),
303        ))
304    } else {
305        None
306    }
307}
308
309/// Generate DROP TYPE SQL for an enum type (PostgreSQL only)
310/// Returns None for non-PostgreSQL backends or non-enum types
311///
312/// The enum type name will be prefixed with the table name to match the CREATE TYPE.
313pub fn build_drop_enum_type_sql(
314    table: &str,
315    column_type: &ColumnType,
316) -> Option<super::types::RawSql> {
317    if let ColumnType::Complex(ComplexColumnType::Enum { name, .. }) = column_type {
318        // Generate the same unique type name used in CREATE TYPE
319        let type_name = build_enum_type_name(table, name);
320
321        // PostgreSQL: DROP TYPE IF EXISTS {table}_{name}
322        let pg_sql = format!("DROP TYPE IF EXISTS \"{}\"", type_name);
323
324        // MySQL/SQLite: No action needed
325        Some(super::types::RawSql::per_backend(
326            pg_sql,
327            String::new(),
328            String::new(),
329        ))
330    } else {
331        None
332    }
333}
334
335// Re-export naming functions from vespertide-naming
336pub use vespertide_naming::{
337    build_check_constraint_name, build_enum_type_name, build_foreign_key_name, build_index_name,
338    build_unique_constraint_name,
339};
340
341/// Alias for build_check_constraint_name for SQLite enum columns
342pub fn build_sqlite_enum_check_name(table: &str, column: &str) -> String {
343    build_check_constraint_name(table, column)
344}
345
346/// Generate CHECK constraint expression for SQLite enum column
347/// Returns the constraint clause like: CONSTRAINT "chk_table_col" CHECK (col IN ('val1', 'val2'))
348pub fn build_sqlite_enum_check_clause(
349    table: &str,
350    column: &str,
351    column_type: &ColumnType,
352) -> Option<String> {
353    if let ColumnType::Complex(ComplexColumnType::Enum { values, .. }) = column_type {
354        let name = build_sqlite_enum_check_name(table, column);
355        let values_sql = values.to_sql_values().join(", ");
356        Some(format!(
357            "CONSTRAINT \"{}\" CHECK (\"{}\" IN ({}))",
358            name, column, values_sql
359        ))
360    } else {
361        None
362    }
363}
364
365/// Collect all CHECK constraints for enum columns in a table (for SQLite)
366pub fn collect_sqlite_enum_check_clauses(table: &str, columns: &[ColumnDef]) -> Vec<String> {
367    columns
368        .iter()
369        .filter_map(|col| build_sqlite_enum_check_clause(table, &col.name, &col.r#type))
370        .collect()
371}
372
373/// Extract enum name from column type if it's an enum
374pub fn get_enum_name(column_type: &ColumnType) -> Option<&str> {
375    if let ColumnType::Complex(ComplexColumnType::Enum { name, .. }) = column_type {
376        Some(name.as_str())
377    } else {
378        None
379    }
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    use sea_query::{Alias, ColumnDef as SeaColumnDef, ForeignKeyAction};
    use vespertide_core::EnumValues;

    #[rstest]
    #[case(ColumnType::Simple(SimpleColumnType::Integer))]
    #[case(ColumnType::Simple(SimpleColumnType::BigInt))]
    #[case(ColumnType::Simple(SimpleColumnType::Text))]
    #[case(ColumnType::Simple(SimpleColumnType::Boolean))]
    #[case(ColumnType::Simple(SimpleColumnType::Timestamp))]
    #[case(ColumnType::Simple(SimpleColumnType::Uuid))]
    #[case(ColumnType::Complex(ComplexColumnType::Varchar { length: 255 }))]
    #[case(ColumnType::Complex(ComplexColumnType::Numeric { precision: 10, scale: 2 }))]
    fn test_column_type_conversion(#[case] ty: ColumnType) {
        // Just ensure no panic - test by creating a column with this type
        let mut col = SeaColumnDef::new(Alias::new("test"));
        apply_column_type_with_table(&mut col, &ty, "test_table");
    }

    #[rstest]
    #[case(SimpleColumnType::SmallInt)]
    #[case(SimpleColumnType::Integer)]
    #[case(SimpleColumnType::BigInt)]
    #[case(SimpleColumnType::Real)]
    #[case(SimpleColumnType::DoublePrecision)]
    #[case(SimpleColumnType::Text)]
    #[case(SimpleColumnType::Boolean)]
    #[case(SimpleColumnType::Date)]
    #[case(SimpleColumnType::Time)]
    #[case(SimpleColumnType::Timestamp)]
    #[case(SimpleColumnType::Timestamptz)]
    #[case(SimpleColumnType::Interval)]
    #[case(SimpleColumnType::Bytea)]
    #[case(SimpleColumnType::Uuid)]
    #[case(SimpleColumnType::Json)]
    #[case(SimpleColumnType::Inet)]
    #[case(SimpleColumnType::Cidr)]
    #[case(SimpleColumnType::Macaddr)]
    #[case(SimpleColumnType::Xml)]
    fn test_all_simple_types_cover_branches(#[case] ty: SimpleColumnType) {
        let mut col = SeaColumnDef::new(Alias::new("t"));
        apply_column_type_with_table(&mut col, &ColumnType::Simple(ty), "test_table");
    }

    #[rstest]
    #[case(ComplexColumnType::Varchar { length: 42 })]
    #[case(ComplexColumnType::Numeric { precision: 8, scale: 3 })]
    #[case(ComplexColumnType::Char { length: 3 })]
    #[case(ComplexColumnType::Custom { custom_type: "GEOGRAPHY".into() })]
    #[case(ComplexColumnType::Enum { name: "status".into(), values: EnumValues::String(vec!["active".into(), "inactive".into()]) })]
    fn test_all_complex_types_cover_branches(#[case] ty: ComplexColumnType) {
        let mut col = SeaColumnDef::new(Alias::new("t"));
        apply_column_type_with_table(&mut col, &ColumnType::Complex(ty), "test_table");
    }

    #[rstest]
    #[case::cascade(ReferenceAction::Cascade, ForeignKeyAction::Cascade)]
    #[case::restrict(ReferenceAction::Restrict, ForeignKeyAction::Restrict)]
    #[case::set_null(ReferenceAction::SetNull, ForeignKeyAction::SetNull)]
    #[case::set_default(ReferenceAction::SetDefault, ForeignKeyAction::SetDefault)]
    #[case::no_action(ReferenceAction::NoAction, ForeignKeyAction::NoAction)]
    fn test_reference_action_conversion(
        #[case] action: ReferenceAction,
        #[case] expected: ForeignKeyAction,
    ) {
        let result = to_sea_fk_action(&action);
        // Compare the variants explicitly. A `matches!(result, expected)` pattern would
        // treat `expected` as a fresh catch-all binding and could never fail.
        assert_eq!(
            std::mem::discriminant(&result),
            std::mem::discriminant(&expected),
            "Expected {:?}, got {:?}",
            expected,
            result
        );
    }

    #[rstest]
    #[case(ReferenceAction::Cascade, "CASCADE")]
    #[case(ReferenceAction::Restrict, "RESTRICT")]
    #[case(ReferenceAction::SetNull, "SET NULL")]
    #[case(ReferenceAction::SetDefault, "SET DEFAULT")]
    #[case(ReferenceAction::NoAction, "NO ACTION")]
    fn test_reference_action_sql_all_variants(
        #[case] action: ReferenceAction,
        #[case] expected: &str,
    ) {
        assert_eq!(reference_action_sql(&action), expected);
    }

    #[rstest]
    #[case::gen_random_uuid_postgres(
        "gen_random_uuid()",
        DatabaseBackend::Postgres,
        "gen_random_uuid()"
    )]
    #[case::gen_random_uuid_mysql("gen_random_uuid()", DatabaseBackend::MySql, "(UUID())")]
    #[case::gen_random_uuid_sqlite(
        "gen_random_uuid()",
        DatabaseBackend::Sqlite,
        "lower(hex(randomblob(16)))"
    )]
    #[case::current_timestamp_postgres(
        "current_timestamp()",
        DatabaseBackend::Postgres,
        "CURRENT_TIMESTAMP"
    )]
    #[case::current_timestamp_mysql(
        "current_timestamp()",
        DatabaseBackend::MySql,
        "CURRENT_TIMESTAMP"
    )]
    #[case::current_timestamp_sqlite(
        "current_timestamp()",
        DatabaseBackend::Sqlite,
        "CURRENT_TIMESTAMP"
    )]
    #[case::now_postgres("now()", DatabaseBackend::Postgres, "CURRENT_TIMESTAMP")]
    #[case::now_mysql("now()", DatabaseBackend::MySql, "CURRENT_TIMESTAMP")]
    #[case::now_sqlite("now()", DatabaseBackend::Sqlite, "CURRENT_TIMESTAMP")]
    #[case::now_upper_postgres("NOW()", DatabaseBackend::Postgres, "CURRENT_TIMESTAMP")]
    #[case::now_upper_mysql("NOW()", DatabaseBackend::MySql, "CURRENT_TIMESTAMP")]
    #[case::now_upper_sqlite("NOW()", DatabaseBackend::Sqlite, "CURRENT_TIMESTAMP")]
    #[case::current_timestamp_upper_postgres(
        "CURRENT_TIMESTAMP",
        DatabaseBackend::Postgres,
        "CURRENT_TIMESTAMP"
    )]
    #[case::current_timestamp_upper_mysql(
        "CURRENT_TIMESTAMP",
        DatabaseBackend::MySql,
        "CURRENT_TIMESTAMP"
    )]
    #[case::current_timestamp_upper_sqlite(
        "CURRENT_TIMESTAMP",
        DatabaseBackend::Sqlite,
        "CURRENT_TIMESTAMP"
    )]
    fn test_convert_default_for_backend(
        #[case] default: &str,
        #[case] backend: DatabaseBackend,
        #[case] expected: &str,
    ) {
        let result = convert_default_for_backend(default, &backend);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_is_enum_type_true() {
        let enum_type = ColumnType::Complex(ComplexColumnType::Enum {
            name: "status".into(),
            values: EnumValues::String(vec!["active".into(), "inactive".into()]),
        });
        assert!(is_enum_type(&enum_type));
    }

    #[test]
    fn test_is_enum_type_false() {
        let text_type = ColumnType::Simple(SimpleColumnType::Text);
        assert!(!is_enum_type(&text_type));
    }

    #[test]
    fn test_get_enum_name_some() {
        let enum_type = ColumnType::Complex(ComplexColumnType::Enum {
            name: "user_status".into(),
            values: EnumValues::String(vec!["active".into(), "inactive".into()]),
        });
        assert_eq!(get_enum_name(&enum_type), Some("user_status"));
    }

    #[test]
    fn test_get_enum_name_none() {
        let text_type = ColumnType::Simple(SimpleColumnType::Text);
        assert_eq!(get_enum_name(&text_type), None);
    }

    #[test]
    fn test_apply_column_type_integer_enum() {
        use vespertide_core::NumValue;
        let integer_enum = ColumnType::Complex(ComplexColumnType::Enum {
            name: "color".into(),
            values: EnumValues::Integer(vec![
                NumValue {
                    name: "Black".into(),
                    value: 0,
                },
                NumValue {
                    name: "White".into(),
                    value: 1,
                },
            ]),
        });
        let mut col = SeaColumnDef::new(Alias::new("color"));
        apply_column_type_with_table(&mut col, &integer_enum, "test_table");
        // Integer enums should use INTEGER type, not ENUM: render the column and check.
        let sql = sea_query::Table::create()
            .table(Alias::new("test_table"))
            .col(&mut col)
            .to_string(PostgresQueryBuilder);
        assert!(sql.contains("integer"), "unexpected SQL: {}", sql);
    }

    #[test]
    fn test_build_create_enum_type_sql_integer_enum_returns_none() {
        use vespertide_core::NumValue;
        let integer_enum = ColumnType::Complex(ComplexColumnType::Enum {
            name: "priority".into(),
            values: EnumValues::Integer(vec![
                NumValue {
                    name: "Low".into(),
                    value: 0,
                },
                NumValue {
                    name: "High".into(),
                    value: 10,
                },
            ]),
        });
        // Integer enums should return None (no CREATE TYPE needed)
        assert!(build_create_enum_type_sql("test_table", &integer_enum).is_none());
    }

    #[rstest]
    // Empty strings need quoting
    #[case::empty("", true)]
    #[case::whitespace_only("   ", true)]
    // Function calls should not be quoted
    #[case::now_func("now()", false)]
    #[case::coalesce_func("COALESCE(old_value, 'default')", false)]
    #[case::uuid_func("gen_random_uuid()", false)]
    // NULL keyword should not be quoted
    #[case::null_upper("NULL", false)]
    #[case::null_lower("null", false)]
    #[case::null_mixed("Null", false)]
    // SQL date/time keywords should not be quoted
    #[case::current_timestamp_upper("CURRENT_TIMESTAMP", false)]
    #[case::current_timestamp_lower("current_timestamp", false)]
    #[case::current_date_upper("CURRENT_DATE", false)]
    #[case::current_date_lower("current_date", false)]
    #[case::current_time_upper("CURRENT_TIME", false)]
    #[case::current_time_lower("current_time", false)]
    // Already quoted strings should not be re-quoted
    #[case::single_quoted("'active'", false)]
    #[case::double_quoted("\"active\"", false)]
    // Plain strings need quoting
    #[case::plain_active("active", true)]
    #[case::plain_pending("pending", true)]
    #[case::plain_underscore("some_value", true)]
    fn test_needs_quoting(#[case] input: &str, #[case] expected: bool) {
        assert_eq!(needs_quoting(input), expected);
    }
}