// sqlx_gen/codegen/enum_gen.rs

1use std::collections::BTreeSet;
2
3use heck::ToUpperCamelCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::{imports_for_derives, rust_type_name_for};
9use crate::introspect::{EnumInfo, SchemaInfo};
10
11/// Detect two SQL enum variants that collapse to the same Rust identifier after
12/// `to_upper_camel_case` (e.g. `"foo bar"` and `"foo_bar"` both become `FooBar`).
13/// Returns an error pointing at the offending pair so the user can rename one
14/// side in the database before regenerating.
15pub fn check_variant_collisions(enum_info: &EnumInfo) -> crate::error::Result<()> {
16    use std::collections::BTreeMap;
17    let mut seen: BTreeMap<String, &str> = BTreeMap::new();
18    for v in &enum_info.variants {
19        let pascal = v.to_upper_camel_case();
20        if let Some(prev) = seen.get(pascal.as_str()).copied() {
21            return Err(crate::error::Error::Config(format!(
22                "Enum '{}.{}': SQL variants '{}' and '{}' both map to Rust identifier '{}'. \
23                 Rename one of them in the database or use a custom mapping.",
24                enum_info.schema_name, enum_info.name, prev, v, pascal
25            )));
26        }
27        seen.insert(pascal, v.as_str());
28    }
29    Ok(())
30}
31
32pub fn generate_enum(
33    enum_info: &EnumInfo,
34    db_kind: DatabaseKind,
35    extra_derives: &[String],
36) -> (TokenStream, BTreeSet<String>) {
37    // Backwards-compatible entry point — uses an empty SchemaInfo so the
38    // enum keeps its bare PascalCase name (no schema prefix).
39    generate_enum_with_schema(enum_info, db_kind, extra_derives, &SchemaInfo::default())
40}
41
42pub fn generate_enum_with_schema(
43    enum_info: &EnumInfo,
44    db_kind: DatabaseKind,
45    extra_derives: &[String],
46    schema_info: &SchemaInfo,
47) -> (TokenStream, BTreeSet<String>) {
48    let mut imports = BTreeSet::new();
49    for imp in imports_for_derives(extra_derives) {
50        imports.insert(imp);
51    }
52
53    let rust_name = rust_type_name_for(schema_info, &enum_info.schema_name, &enum_info.name);
54    let enum_name = format_ident!("{}", rust_name);
55    let doc = format!("Enum: {}.{}", enum_info.schema_name, enum_info.name);
56    // For non-default schemas, remind the user that sqlx 0.8 can only resolve
57    // unqualified type_name attributes — the connection must have the schema
58    // in its search_path. Emitted as a /// doc-comment so it shows up both in
59    // generated source and in rustdoc.
60    let search_path_doc = if db_kind == DatabaseKind::Postgres
61        && !crate::codegen::is_default_schema(&enum_info.schema_name)
62    {
63        let msg = format!(
64            "Lives in PostgreSQL schema `{schema}`. The sqlx connection \
65             must include `{schema}` in its search_path so PG resolves the \
66             unqualified `type_name = \"{name}\"` to this enum. Example:\n\
67             \n\
68             ```ignore\n\
69             sqlx::query(\"SET search_path TO public, {schema}\")\n\
70             ```",
71            schema = enum_info.schema_name,
72            name = enum_info.name,
73        );
74        Some(msg)
75    } else {
76        None
77    };
78
79    imports.insert("use serde::{Serialize, Deserialize};".to_string());
80    imports.insert("use sqlx_gen::SqlxGen;".to_string());
81    let mut derive_tokens = vec![
82        quote! { Debug },
83        quote! { Clone },
84        quote! { PartialEq },
85        quote! { Eq },
86        quote! { Serialize },
87        quote! { Deserialize },
88        quote! { sqlx::Type },
89        quote! { SqlxGen },
90    ];
91    // Use the standard `#[derive(Default)] + #[default]` pattern (stable since
92    // 1.62) when a default variant exists, instead of hand-rolling an impl.
93    if enum_info.default_variant.is_some() {
94        derive_tokens.push(quote! { Default });
95    }
96    for d in extra_derives {
97        let ident = format_ident!("{}", d);
98        derive_tokens.push(quote! { #ident });
99    }
100    let default_variant_pascal = enum_info
101        .default_variant
102        .as_ref()
103        .map(|v| v.to_upper_camel_case());
104
105    // For PG, add #[sqlx(type_name = "...")] — always unqualified.
106    // sqlx 0.8's PgTypeInfo::with_name does NOT accept schema-qualified names; emitting
107    // "schema.type" causes runtime decode failures. The user is expected to set
108    // `search_path` on the connection so that PG resolves the unqualified type name.
109    let type_attr = if db_kind == DatabaseKind::Postgres {
110        let pg_name = &enum_info.name;
111        quote! { #[sqlx(type_name = #pg_name)] }
112    } else {
113        quote! {}
114    };
115
116    let variants: Vec<TokenStream> = enum_info
117        .variants
118        .iter()
119        .map(|v| {
120            let variant_pascal = v.to_upper_camel_case();
121            let variant_ident = format_ident!("{}", variant_pascal);
122
123            let rename = if variant_pascal != *v {
124                quote! { #[sqlx(rename = #v)] }
125            } else {
126                quote! {}
127            };
128
129            let default_attr = if default_variant_pascal.as_deref() == Some(variant_pascal.as_str())
130            {
131                quote! { #[default] }
132            } else {
133                quote! {}
134            };
135
136            quote! {
137                #rename
138                #default_attr
139                #variant_ident,
140            }
141        })
142        .collect();
143
144    // Postgres arrays: `#[derive(sqlx::Type)]` with `#[sqlx(type_name = "x")]`
145    // already auto-generates `impl PgHasArrayType` returning `_x` in sqlx 0.8+.
146    // Emitting a second impl here triggers E0119 (conflicting implementations)
147    // in the user's crate. Leave the derive in charge.
148    let _ = db_kind;
149
150    let schema_name_str = &enum_info.schema_name;
151    let enum_name_str = &enum_info.name;
152    let search_path_doc_tokens = match &search_path_doc {
153        Some(m) => quote! { #[doc = #m] },
154        None => quote! {},
155    };
156
157    let tokens = quote! {
158        #[doc = #doc]
159        #search_path_doc_tokens
160        #[derive(#(#derive_tokens),*)]
161        #[sqlx_gen(kind = "enum", schema = #schema_name_str, name = #enum_name_str)]
162        #type_attr
163        pub enum #enum_name {
164            #(#variants)*
165        }
166    };
167
168    (tokens, imports)
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;

    // Build a `public`-schema enum with no default variant.
    fn make_enum(name: &str, variants: Vec<&str>) -> EnumInfo {
        EnumInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            variants: variants.into_iter().map(|s| s.to_string()).collect(),
            default_variant: None,
        }
    }

    // Generate and pretty-print with no extra derives.
    fn gen(info: &EnumInfo, db: DatabaseKind) -> String {
        let (tokens, _) = generate_enum(info, db, &[]);
        parse_and_format(&tokens).unwrap()
    }

    // Generate and pretty-print with extra derives; also returns the imports.
    fn gen_with_derives(
        info: &EnumInfo,
        db: DatabaseKind,
        derives: &[String],
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_enum(info, db, derives);
        (parse_and_format(&tokens).unwrap(), imports)
    }

    // --- basic structure ---

    #[test]
    fn test_enum_variants() {
        let e = make_enum("status", vec!["active", "inactive"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(code.contains("Inactive"));
    }

    #[test]
    fn test_enum_name_pascal_case() {
        let e = make_enum("user_status", vec!["a"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("pub enum UserStatus"));
    }

    #[test]
    fn test_doc_comment() {
        let e = make_enum("status", vec!["a"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("Enum: public.status"));
    }

    #[test]
    fn test_sqlx_gen_attr_has_schema_and_name() {
        let e = make_enum("status", vec!["a"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx_gen(kind = \"enum\", schema = \"public\", name = \"status\")"));
    }

    #[test]
    fn test_sqlx_gen_attr_non_public_schema() {
        let e = EnumInfo {
            schema_name: "auth".to_string(),
            name: "role".to_string(),
            variants: vec!["admin".to_string(), "user".to_string()],
            default_variant: None,
        };
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx_gen(kind = \"enum\", schema = \"auth\", name = \"role\")"));
    }

    // --- sqlx attributes ---

    #[test]
    fn test_postgres_has_type_name() {
        let e = make_enum("user_status", vec!["a"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"user_status\")"));
    }

    #[test]
    fn test_check_variant_collisions_detects_after_camel_case() {
        let e = EnumInfo {
            schema_name: "public".into(),
            name: "weird".into(),
            variants: vec!["foo bar".into(), "foo_bar".into()],
            default_variant: None,
        };
        let result = check_variant_collisions(&e);
        assert!(result.is_err(), "must detect collision");
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("FooBar"),
            "error must mention conflicting Rust ident, got: {}",
            msg
        );
        assert!(msg.contains("foo bar") || msg.contains("foo_bar"));
    }

    #[test]
    fn test_check_variant_collisions_accepts_distinct_variants() {
        let e = make_enum("status", vec!["active", "inactive"]);
        assert!(check_variant_collisions(&e).is_ok());
    }

    #[test]
    fn test_check_variant_collisions_accepts_single_variant() {
        let e = make_enum("status", vec!["only"]);
        assert!(check_variant_collisions(&e).is_ok());
    }

    #[test]
    fn test_does_not_emit_manual_pg_has_array_type_impl() {
        // Regression for E0119 — `#[derive(sqlx::Type)]` already provides this
        // impl when `type_name` is set, so emitting our own conflicted.
        for db in [
            DatabaseKind::Postgres,
            DatabaseKind::Mysql,
            DatabaseKind::Sqlite,
        ] {
            let e = make_enum("status", vec!["a", "b"]);
            let code = gen(&e, db);
            assert!(
                !code.contains("PgHasArrayType"),
                "{:?}: must not emit a manual PgHasArrayType impl, got:\n{}",
                db,
                code
            );
        }
    }

    #[test]
    fn test_postgres_non_public_schema_type_name_is_unqualified() {
        // Regression: previously emitted "auth.role" which crashes sqlx 0.8 at runtime
        // (PgTypeInfo::with_name does not accept schema-qualified names).
        let e = EnumInfo {
            schema_name: "auth".to_string(),
            name: "role".to_string(),
            variants: vec!["admin".to_string(), "user".to_string()],
            default_variant: None,
        };
        let (tokens, _) = generate_enum(&e, DatabaseKind::Postgres, &[]);
        let code = parse_and_format(&tokens).unwrap();
        assert!(
            code.contains("sqlx(type_name = \"role\")"),
            "type_name must be unqualified for sqlx 0.8 compatibility, got:\n{}",
            code
        );
        assert!(
            !code.contains("\"auth.role\""),
            "type_name must NOT include schema; got:\n{}",
            code
        );
    }

    #[test]
    fn test_postgres_public_schema_not_qualified() {
        let e = make_enum("status", vec!["a"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"status\")"));
        // type_name should NOT be schema-qualified for public schema
        assert!(!code.contains("type_name = \"public.status\""));
    }

    #[test]
    fn test_search_path_doc_only_for_non_default_pg_schema() {
        // The search_path reminder is a Postgres-only concern and only applies
        // when the enum lives outside the default schema.
        let named = make_enum_in_schema("auth", "role", vec!["admin"]);
        let pg_named = gen(&named, DatabaseKind::Postgres);
        assert!(
            pg_named.contains("search_path"),
            "non-default PG schema must carry the search_path doc, got:\n{}",
            pg_named
        );

        let public = make_enum("status", vec!["a"]);
        assert!(!gen(&public, DatabaseKind::Postgres).contains("search_path"));
        assert!(!gen(&named, DatabaseKind::Mysql).contains("search_path"));
    }

    #[test]
    fn test_mysql_inline_enum_emits_rename_for_lowercase_variants() {
        // Inline MySQL ENUM('active', 'inactive') → Rust variants are PascalCase
        // and need #[sqlx(rename)] so encode/decode hits the SQL text values.
        let e = make_enum("status", vec!["active", "inactive"]);
        let code = gen(&e, DatabaseKind::Mysql);
        assert!(
            code.contains("sqlx(rename = \"active\")"),
            "MySQL inline ENUM variant must carry rename for round-trip:\n{}",
            code
        );
        assert!(code.contains("sqlx(rename = \"inactive\")"));
        // type_name does NOT exist for MySQL — only PG-native enums need it.
        assert!(!code.contains("type_name"));
    }

    #[test]
    fn test_mysql_inline_enum_preserves_case_sensitive_variants() {
        let e = make_enum("priority", vec!["LOW", "HIGH"]);
        let code = gen(&e, DatabaseKind::Mysql);
        // PascalCase("LOW") = "Low" → rename required so SQL sees "LOW"
        assert!(code.contains("sqlx(rename = \"LOW\")"));
        assert!(code.contains("sqlx(rename = \"HIGH\")"));
    }

    #[test]
    fn test_mysql_no_type_name() {
        let e = make_enum("status", vec!["a"]);
        let code = gen(&e, DatabaseKind::Mysql);
        assert!(!code.contains("type_name"));
    }

    #[test]
    fn test_sqlite_no_type_name() {
        let e = make_enum("status", vec!["a"]);
        let code = gen(&e, DatabaseKind::Sqlite);
        assert!(!code.contains("type_name"));
    }

    // --- rename variants ---

    #[test]
    fn test_snake_case_variant_renamed() {
        let e = make_enum("status", vec!["in_progress"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("InProgress"));
        assert!(code.contains("sqlx(rename = \"in_progress\")"));
    }

    #[test]
    fn test_lowercase_variant_renamed() {
        let e = make_enum("status", vec!["active"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(code.contains("sqlx(rename = \"active\")"));
    }

    #[test]
    fn test_already_pascal_no_rename() {
        let e = make_enum("status", vec!["Active"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(!code.contains("sqlx(rename"));
    }

    #[test]
    fn test_upper_case_variant_renamed() {
        let e = make_enum("status", vec!["UPPER_CASE"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("UpperCase"));
        assert!(code.contains("sqlx(rename = \"UPPER_CASE\")"));
    }

    // --- derives ---

    #[test]
    fn test_default_derives() {
        let e = make_enum("status", vec!["a"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("PartialEq"));
        assert!(code.contains("Eq"));
        assert!(code.contains("Serialize"));
        assert!(code.contains("Deserialize"));
        assert!(code.contains("SqlxGen"));
        assert!(code.contains("sqlx::Type") || code.contains("sqlx :: Type"));
    }

    #[test]
    fn test_extra_derive_serialize() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string()];
        let (code, _) = gen_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_extra_derive_not_in_base_set_is_emitted() {
        // `Hash` is not one of the always-on derives, so it can only appear
        // in the output if extra derives are actually forwarded.
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Hash".to_string()];
        let (code, _) = gen_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(
            code.contains("Hash"),
            "extra derive `Hash` must be emitted, got:\n{}",
            code
        );
    }

    #[test]
    fn test_extra_derives_serde_imports() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) = gen_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    // --- imports ---

    #[test]
    fn test_no_extra_derives_has_serde_import() {
        let e = make_enum("status", vec!["a"]);
        let (_, imports) = gen_with_derives(&e, DatabaseKind::Postgres, &[]);
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    #[test]
    fn test_serde_import_present() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string()];
        let (_, imports) = gen_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(
            imports.iter().any(|i| i.contains("serde")),
            "expected a serde import, got: {:?}",
            imports
        );
    }

    #[test]
    fn test_sqlx_gen_import_present() {
        let e = make_enum("status", vec!["a"]);
        let (_, imports) = gen_with_derives(&e, DatabaseKind::Postgres, &[]);
        assert!(
            imports.contains("use sqlx_gen::SqlxGen;"),
            "the SqlxGen derive needs its import, got: {:?}",
            imports
        );
    }

    // --- edge cases ---

    #[test]
    fn test_single_variant() {
        let e = make_enum("status", vec!["only"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("Only"));
    }

    #[test]
    fn test_many_variants() {
        let variants: Vec<&str> = vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"];
        let e = make_enum("status", variants);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("A,"));
        assert!(code.contains("J,"));
    }

    #[test]
    fn test_variant_with_digits() {
        let e = make_enum("version", vec!["v2"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("V2"));
    }

    #[test]
    fn test_enum_name_with_double_underscores() {
        let e = make_enum("my__enum", vec!["a"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("pub enum MyEnum"));
    }

    // --- impl Default ---

    #[test]
    fn test_default_uses_derive_and_attribute() {
        let e = EnumInfo {
            schema_name: "public".to_string(),
            name: "task_status".to_string(),
            variants: vec![
                "idle".to_string(),
                "running".to_string(),
                "done".to_string(),
            ],
            default_variant: Some("idle".to_string()),
        };
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(
            code.contains("Default"),
            "expected `Default` in derive list, got:\n{}",
            code
        );
        assert!(
            code.contains("#[default]"),
            "expected #[default] attribute on the variant, got:\n{}",
            code
        );
        // No hand-rolled impl Default block.
        assert!(!code.contains("impl Default for TaskStatus"));
    }

    #[test]
    fn test_no_default_derive_when_no_default_variant() {
        let e = make_enum("status", vec!["active", "inactive"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(!code.contains("impl Default"));
        assert!(!code.contains("#[default]"));
        // The derive line must NOT contain a free-standing Default token.
        let derive_line = code
            .lines()
            .find(|l| l.contains("#[derive"))
            .expect("derive line");
        assert!(
            !derive_line.contains(", Default"),
            "derive list should not include Default, got: {}",
            derive_line
        );
    }

    #[test]
    fn test_default_attribute_on_correct_variant_snake_case() {
        let e = EnumInfo {
            schema_name: "public".to_string(),
            name: "status".to_string(),
            variants: vec!["in_progress".to_string(), "done".to_string()],
            default_variant: Some("in_progress".to_string()),
        };
        let code = gen(&e, DatabaseKind::Postgres);
        // The `#[default]` attribute must sit directly above the `InProgress`
        // variant — not on `Done`.
        let in_progress_idx = code.find("InProgress").expect("InProgress");
        let default_attr_idx = code.find("#[default]").expect("#[default]");
        assert!(
            default_attr_idx < in_progress_idx,
            "#[default] must precede InProgress"
        );
        let between = &code[default_attr_idx..in_progress_idx];
        assert!(
            !between.contains("Done"),
            "#[default] landed on the wrong variant:\n{}",
            code
        );
    }

    // --- public vs named schema integration ---

    // Build an enum in an arbitrary schema with no default variant.
    fn make_enum_in_schema(schema: &str, name: &str, variants: Vec<&str>) -> EnumInfo {
        EnumInfo {
            schema_name: schema.to_string(),
            name: name.to_string(),
            variants: variants.into_iter().map(|s| s.to_string()).collect(),
            default_variant: None,
        }
    }

    #[test]
    fn test_public_schema_full_output() {
        let e = make_enum_in_schema(
            "public",
            "order_status",
            vec!["pending", "shipped", "delivered"],
        );
        let code = gen(&e, DatabaseKind::Postgres);

        assert!(code.contains("Enum: public.order_status"));
        assert!(code.contains("pub enum OrderStatus"));
        assert!(code.contains("sqlx(type_name = \"order_status\")"));
        assert!(!code.contains("sqlx(type_name = \"public.order_status\")"));
        assert!(code
            .contains("sqlx_gen(kind = \"enum\", schema = \"public\", name = \"order_status\")"));
        assert!(code.contains("Pending"));
        assert!(code.contains("Shipped"));
        assert!(code.contains("Delivered"));
    }

    #[test]
    fn test_named_schema_full_output() {
        let e = make_enum_in_schema(
            "analysis",
            "toolcall_status",
            vec!["PENDING", "RUNNING", "DONE"],
        );
        let code = gen(&e, DatabaseKind::Postgres);

        assert!(code.contains("Enum: analysis.toolcall_status"));
        assert!(code.contains("pub enum ToolcallStatus"));
        assert!(code.contains("sqlx(type_name = \"toolcall_status\")"));
        assert!(!code.contains("\"analysis.toolcall_status\""));
        assert!(code.contains(
            "sqlx_gen(kind = \"enum\", schema = \"analysis\", name = \"toolcall_status\")"
        ));
        assert!(code.contains("Pending"));
        assert!(code.contains("Running"));
        assert!(code.contains("Done"));
    }

    #[test]
    fn test_named_schema_with_default_variant() {
        let e = EnumInfo {
            schema_name: "billing".to_string(),
            name: "payment_status".to_string(),
            variants: vec![
                "pending".to_string(),
                "paid".to_string(),
                "refunded".to_string(),
            ],
            default_variant: Some("pending".to_string()),
        };
        let code = gen(&e, DatabaseKind::Postgres);

        assert!(code.contains("sqlx(type_name = \"payment_status\")"));
        assert!(!code.contains("\"billing.payment_status\""));
        // Uses #[derive(Default)] + #[default] instead of a hand-rolled impl.
        assert!(code.contains("Default"));
        assert!(code.contains("#[default]"));
        assert!(!code.contains("impl Default for PaymentStatus"));
    }

    #[test]
    fn test_named_schema_variant_rename() {
        let e = make_enum_in_schema("audit", "log_level", vec!["info", "warn_high", "CRITICAL"]);
        let code = gen(&e, DatabaseKind::Postgres);

        assert!(code.contains("sqlx(type_name = \"log_level\")"));
        assert!(!code.contains("\"audit.log_level\""));
        assert!(code.contains("sqlx(rename = \"info\")"));
        assert!(code.contains("sqlx(rename = \"warn_high\")"));
        assert!(code.contains("WarnHigh"));
        assert!(code.contains("sqlx(rename = \"CRITICAL\")"));
        assert!(code.contains("Critical"));
    }

    #[test]
    fn test_named_schema_mysql_no_type_name() {
        let e = make_enum_in_schema("analytics", "event_type", vec!["click", "view"]);
        let code = gen(&e, DatabaseKind::Mysql);

        assert!(!code.contains("type_name"));
    }

    #[test]
    fn test_named_schema_sqlite_no_type_name() {
        let e = make_enum_in_schema("analytics", "event_type", vec!["click", "view"]);
        let code = gen(&e, DatabaseKind::Sqlite);

        assert!(!code.contains("type_name"));
    }
}