Skip to main content

sqlx_gen/codegen/
enum_gen.rs

1use std::collections::BTreeSet;
2
3use heck::ToUpperCamelCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::{imports_for_derives, rust_type_name_for};
9use crate::introspect::{EnumInfo, SchemaInfo};
10
11/// Detect two SQL enum variants that collapse to the same Rust identifier after
12/// `to_upper_camel_case` (e.g. `"foo bar"` and `"foo_bar"` both become `FooBar`).
13/// Returns an error pointing at the offending pair so the user can rename one
14/// side in the database before regenerating.
15pub fn check_variant_collisions(enum_info: &EnumInfo) -> crate::error::Result<()> {
16    use std::collections::BTreeMap;
17    let mut seen: BTreeMap<String, &str> = BTreeMap::new();
18    for v in &enum_info.variants {
19        let pascal = v.to_upper_camel_case();
20        if let Some(prev) = seen.get(pascal.as_str()).copied() {
21            return Err(crate::error::Error::Config(format!(
22                "Enum '{}.{}': SQL variants '{}' and '{}' both map to Rust identifier '{}'. \
23                 Rename one of them in the database or use a custom mapping.",
24                enum_info.schema_name, enum_info.name, prev, v, pascal
25            )));
26        }
27        seen.insert(pascal, v.as_str());
28    }
29    Ok(())
30}
31
32pub fn generate_enum(
33    enum_info: &EnumInfo,
34    db_kind: DatabaseKind,
35    extra_derives: &[String],
36) -> (TokenStream, BTreeSet<String>) {
37    // Backwards-compatible entry point — uses an empty SchemaInfo so the
38    // enum keeps its bare PascalCase name (no schema prefix).
39    generate_enum_with_schema(enum_info, db_kind, extra_derives, &SchemaInfo::default())
40}
41
42pub fn generate_enum_with_schema(
43    enum_info: &EnumInfo,
44    db_kind: DatabaseKind,
45    extra_derives: &[String],
46    schema_info: &SchemaInfo,
47) -> (TokenStream, BTreeSet<String>) {
48    let mut imports = BTreeSet::new();
49    for imp in imports_for_derives(extra_derives) {
50        imports.insert(imp);
51    }
52
53    let rust_name = rust_type_name_for(schema_info, &enum_info.schema_name, &enum_info.name);
54    let enum_name = format_ident!("{}", rust_name);
55    let doc = format!("Enum: {}.{}", enum_info.schema_name, enum_info.name);
56    // For non-default schemas, remind the user that sqlx 0.8 can only resolve
57    // unqualified type_name attributes — the connection must have the schema
58    // in its search_path. Emitted as a /// doc-comment so it shows up both in
59    // generated source and in rustdoc.
60    let search_path_doc = if db_kind == DatabaseKind::Postgres
61        && !crate::codegen::is_default_schema(&enum_info.schema_name)
62    {
63        let msg = format!(
64            "Lives in PostgreSQL schema `{schema}`. The sqlx connection \
65             must include `{schema}` in its search_path so PG resolves the \
66             unqualified `type_name = \"{name}\"` to this enum. Example:\n\
67             \n\
68             ```ignore\n\
69             sqlx::query(\"SET search_path TO public, {schema}\")\n\
70             ```",
71            schema = enum_info.schema_name,
72            name = enum_info.name,
73        );
74        Some(msg)
75    } else {
76        None
77    };
78
79    imports.insert("use serde::{Serialize, Deserialize};".to_string());
80    imports.insert("use sqlx_gen::SqlxGen;".to_string());
81    let mut derive_tokens = vec![
82        quote! { Debug },
83        quote! { Clone },
84        quote! { PartialEq },
85        quote! { Eq },
86        quote! { Serialize },
87        quote! { Deserialize },
88        quote! { sqlx::Type },
89        quote! { SqlxGen },
90    ];
91    for d in extra_derives {
92        let ident = format_ident!("{}", d);
93        derive_tokens.push(quote! { #ident });
94    }
95
96    // For PG, add #[sqlx(type_name = "...")] — always unqualified.
97    // sqlx 0.8's PgTypeInfo::with_name does NOT accept schema-qualified names; emitting
98    // "schema.type" causes runtime decode failures. The user is expected to set
99    // `search_path` on the connection so that PG resolves the unqualified type name.
100    let type_attr = if db_kind == DatabaseKind::Postgres {
101        let pg_name = &enum_info.name;
102        quote! { #[sqlx(type_name = #pg_name)] }
103    } else {
104        quote! {}
105    };
106
107    let variants: Vec<TokenStream> = enum_info
108        .variants
109        .iter()
110        .map(|v| {
111            let variant_pascal = v.to_upper_camel_case();
112            let variant_ident = format_ident!("{}", variant_pascal);
113
114            let rename = if variant_pascal != *v {
115                quote! { #[sqlx(rename = #v)] }
116            } else {
117                quote! {}
118            };
119
120            quote! {
121                #rename
122                #variant_ident,
123            }
124        })
125        .collect();
126
127    let default_impl = if let Some(ref default_variant) = enum_info.default_variant {
128        let variant_pascal = default_variant.to_upper_camel_case();
129        let variant_ident = format_ident!("{}", variant_pascal);
130        quote! {
131            impl Default for #enum_name {
132                fn default() -> Self {
133                    Self::#variant_ident
134                }
135            }
136        }
137    } else {
138        quote! {}
139    };
140
141    // Postgres arrays: `#[derive(sqlx::Type)]` with `#[sqlx(type_name = "x")]`
142    // already auto-generates `impl PgHasArrayType` returning `_x` in sqlx 0.8+.
143    // Emitting a second impl here triggers E0119 (conflicting implementations)
144    // in the user's crate. Leave the derive in charge.
145    let _ = db_kind;
146
147    let schema_name_str = &enum_info.schema_name;
148    let enum_name_str = &enum_info.name;
149    let search_path_doc_tokens = match &search_path_doc {
150        Some(m) => quote! { #[doc = #m] },
151        None => quote! {},
152    };
153
154    let tokens = quote! {
155        #[doc = #doc]
156        #search_path_doc_tokens
157        #[derive(#(#derive_tokens),*)]
158        #[sqlx_gen(kind = "enum", schema = #schema_name_str, name = #enum_name_str)]
159        #type_attr
160        pub enum #enum_name {
161            #(#variants)*
162        }
163
164        #default_impl
165    };
166
167    (tokens, imports)
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;

    /// Build an `EnumInfo` in the given schema with no default variant.
    fn make_enum_in_schema(schema: &str, name: &str, variants: Vec<&str>) -> EnumInfo {
        EnumInfo {
            schema_name: schema.to_string(),
            name: name.to_string(),
            variants: variants.into_iter().map(str::to_string).collect(),
            default_variant: None,
        }
    }

    /// Build an `EnumInfo` in the `public` schema with no default variant.
    fn make_enum(name: &str, variants: Vec<&str>) -> EnumInfo {
        make_enum_in_schema("public", name, variants)
    }

    /// Same as `make_enum_in_schema`, but with `default_variant` set.
    fn make_enum_with_default(
        schema: &str,
        name: &str,
        variants: Vec<&str>,
        default: &str,
    ) -> EnumInfo {
        EnumInfo {
            default_variant: Some(default.to_string()),
            ..make_enum_in_schema(schema, name, variants)
        }
    }

    /// Generate with the given extra derives; returns formatted source and imports.
    fn render_with_derives(
        info: &EnumInfo,
        db: DatabaseKind,
        derives: &[String],
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_enum(info, db, derives);
        let formatted = parse_and_format(&tokens).unwrap();
        (formatted, imports)
    }

    /// Generate with no extra derives and return the formatted source only.
    fn render(info: &EnumInfo, db: DatabaseKind) -> String {
        render_with_derives(info, db, &[]).0
    }

    // --- basic structure ---

    #[test]
    fn test_enum_variants() {
        let info = make_enum("status", vec!["active", "inactive"]);
        let rendered = render(&info, DatabaseKind::Postgres);
        assert!(rendered.contains("Active"));
        assert!(rendered.contains("Inactive"));
    }

    #[test]
    fn test_enum_name_pascal_case() {
        let info = make_enum("user_status", vec!["a"]);
        let rendered = render(&info, DatabaseKind::Postgres);
        assert!(rendered.contains("pub enum UserStatus"));
    }

    #[test]
    fn test_doc_comment() {
        let info = make_enum("status", vec!["a"]);
        let rendered = render(&info, DatabaseKind::Postgres);
        assert!(rendered.contains("Enum: public.status"));
    }

    #[test]
    fn test_sqlx_gen_attr_has_schema_and_name() {
        let info = make_enum("status", vec!["a"]);
        let rendered = render(&info, DatabaseKind::Postgres);
        assert!(
            rendered.contains("sqlx_gen(kind = \"enum\", schema = \"public\", name = \"status\")")
        );
    }

    #[test]
    fn test_sqlx_gen_attr_non_public_schema() {
        let info = make_enum_in_schema("auth", "role", vec!["admin", "user"]);
        let rendered = render(&info, DatabaseKind::Postgres);
        assert!(rendered.contains("sqlx_gen(kind = \"enum\", schema = \"auth\", name = \"role\")"));
    }

    // --- sqlx attributes ---

    #[test]
    fn test_postgres_has_type_name() {
        let info = make_enum("user_status", vec!["a"]);
        let rendered = render(&info, DatabaseKind::Postgres);
        assert!(rendered.contains("sqlx(type_name = \"user_status\")"));
    }

    #[test]
    fn test_check_variant_collisions_detects_after_camel_case() {
        let info = make_enum("weird", vec!["foo bar", "foo_bar"]);
        let outcome = check_variant_collisions(&info);
        assert!(outcome.is_err(), "must detect collision");
        let msg = outcome.unwrap_err().to_string();
        assert!(
            msg.contains("FooBar"),
            "error must mention conflicting Rust ident, got: {}",
            msg
        );
        assert!(msg.contains("foo bar") || msg.contains("foo_bar"));
    }

    #[test]
    fn test_check_variant_collisions_accepts_distinct_variants() {
        let info = make_enum("status", vec!["active", "inactive"]);
        assert!(check_variant_collisions(&info).is_ok());
    }

    #[test]
    fn test_check_variant_collisions_accepts_single_variant() {
        let info = make_enum("status", vec!["only"]);
        assert!(check_variant_collisions(&info).is_ok());
    }

    #[test]
    fn test_does_not_emit_manual_pg_has_array_type_impl() {
        // Regression for E0119 — `#[derive(sqlx::Type)]` already provides this
        // impl when `type_name` is set, so emitting our own conflicted.
        let all_kinds = [
            DatabaseKind::Postgres,
            DatabaseKind::Mysql,
            DatabaseKind::Sqlite,
        ];
        for db in all_kinds {
            let info = make_enum("status", vec!["a", "b"]);
            let rendered = render(&info, db);
            assert!(
                !rendered.contains("PgHasArrayType"),
                "{:?}: must not emit a manual PgHasArrayType impl, got:\n{}",
                db,
                rendered
            );
        }
    }

    #[test]
    fn test_postgres_non_public_schema_type_name_is_unqualified() {
        // Regression: previously emitted "auth.role" which crashes sqlx 0.8 at runtime
        // (PgTypeInfo::with_name does not accept schema-qualified names).
        let info = make_enum_in_schema("auth", "role", vec!["admin", "user"]);
        let rendered = render(&info, DatabaseKind::Postgres);
        assert!(
            rendered.contains("sqlx(type_name = \"role\")"),
            "type_name must be unqualified for sqlx 0.8 compatibility, got:\n{}",
            rendered
        );
        assert!(
            !rendered.contains("\"auth.role\""),
            "type_name must NOT include schema; got:\n{}",
            rendered
        );
    }

    #[test]
    fn test_postgres_public_schema_not_qualified() {
        let info = make_enum("status", vec!["a"]);
        let rendered = render(&info, DatabaseKind::Postgres);
        assert!(rendered.contains("sqlx(type_name = \"status\")"));
        // type_name should NOT be schema-qualified for public schema
        assert!(!rendered.contains("type_name = \"public.status\""));
    }

    #[test]
    fn test_mysql_inline_enum_emits_rename_for_lowercase_variants() {
        // Inline MySQL ENUM('active', 'inactive') → Rust variants are PascalCase
        // and need #[sqlx(rename)] so encode/decode hits the SQL text values.
        let info = make_enum("status", vec!["active", "inactive"]);
        let rendered = render(&info, DatabaseKind::Mysql);
        assert!(
            rendered.contains("sqlx(rename = \"active\")"),
            "MySQL inline ENUM variant must carry rename for round-trip:\n{}",
            rendered
        );
        assert!(rendered.contains("sqlx(rename = \"inactive\")"));
        // type_name does NOT exist for MySQL — only PG-native enums need it.
        assert!(!rendered.contains("type_name"));
    }

    #[test]
    fn test_mysql_inline_enum_preserves_case_sensitive_variants() {
        let info = make_enum("priority", vec!["LOW", "HIGH"]);
        let rendered = render(&info, DatabaseKind::Mysql);
        // PascalCase("LOW") = "Low" → rename required so SQL sees "LOW"
        assert!(rendered.contains("sqlx(rename = \"LOW\")"));
        assert!(rendered.contains("sqlx(rename = \"HIGH\")"));
    }

    #[test]
    fn test_mysql_no_type_name() {
        let info = make_enum("status", vec!["a"]);
        let rendered = render(&info, DatabaseKind::Mysql);
        assert!(!rendered.contains("type_name"));
    }

    #[test]
    fn test_sqlite_no_type_name() {
        let info = make_enum("status", vec!["a"]);
        let rendered = render(&info, DatabaseKind::Sqlite);
        assert!(!rendered.contains("type_name"));
    }

    // --- rename variants ---

    #[test]
    fn test_snake_case_variant_renamed() {
        let info = make_enum("status", vec!["in_progress"]);
        let rendered = render(&info, DatabaseKind::Postgres);
        assert!(rendered.contains("InProgress"));
        assert!(rendered.contains("sqlx(rename = \"in_progress\")"));
    }

    #[test]
    fn test_lowercase_variant_renamed() {
        let info = make_enum("status", vec!["active"]);
        let rendered = render(&info, DatabaseKind::Postgres);
        assert!(rendered.contains("Active"));
        assert!(rendered.contains("sqlx(rename = \"active\")"));
    }

    #[test]
    fn test_already_pascal_no_rename() {
        let info = make_enum("status", vec!["Active"]);
        let rendered = render(&info, DatabaseKind::Postgres);
        assert!(rendered.contains("Active"));
        assert!(!rendered.contains("sqlx(rename"));
    }

    #[test]
    fn test_upper_case_variant_renamed() {
        let info = make_enum("status", vec!["UPPER_CASE"]);
        let rendered = render(&info, DatabaseKind::Postgres);
        assert!(rendered.contains("UpperCase"));
        assert!(rendered.contains("sqlx(rename = \"UPPER_CASE\")"));
    }

    // --- derives ---

    #[test]
    fn test_default_derives() {
        let info = make_enum("status", vec!["a"]);
        let rendered = render(&info, DatabaseKind::Postgres);
        assert!(rendered.contains("Debug"));
        assert!(rendered.contains("Clone"));
        assert!(rendered.contains("PartialEq"));
        assert!(rendered.contains("sqlx::Type") || rendered.contains("sqlx :: Type"));
    }

    #[test]
    fn test_extra_derive_serialize() {
        let info = make_enum("status", vec!["a"]);
        let extra = vec!["Serialize".to_string()];
        let (rendered, _) = render_with_derives(&info, DatabaseKind::Postgres, &extra);
        assert!(rendered.contains("Serialize"));
    }

    #[test]
    fn test_extra_derives_serde_imports() {
        let info = make_enum("status", vec!["a"]);
        let extra = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) = render_with_derives(&info, DatabaseKind::Postgres, &extra);
        assert!(imports.iter().any(|line| line.contains("serde")));
    }

    // --- imports ---

    #[test]
    fn test_no_extra_derives_has_serde_import() {
        let info = make_enum("status", vec!["a"]);
        let (_, imports) = render_with_derives(&info, DatabaseKind::Postgres, &[]);
        assert!(imports.iter().any(|line| line.contains("serde")));
    }

    #[test]
    fn test_serde_import_present() {
        let info = make_enum("status", vec!["a"]);
        let extra = vec!["Serialize".to_string()];
        let (_, imports) = render_with_derives(&info, DatabaseKind::Postgres, &extra);
        assert!(!imports.is_empty());
    }

    // --- edge cases ---

    #[test]
    fn test_single_variant() {
        let info = make_enum("status", vec!["only"]);
        let rendered = render(&info, DatabaseKind::Postgres);
        assert!(rendered.contains("Only"));
    }

    #[test]
    fn test_many_variants() {
        let info = make_enum(
            "status",
            vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"],
        );
        let rendered = render(&info, DatabaseKind::Postgres);
        assert!(rendered.contains("A,"));
        assert!(rendered.contains("J,"));
    }

    #[test]
    fn test_variant_with_digits() {
        let info = make_enum("version", vec!["v2"]);
        let rendered = render(&info, DatabaseKind::Postgres);
        assert!(rendered.contains("V2"));
    }

    #[test]
    fn test_enum_name_with_double_underscores() {
        let info = make_enum("my__enum", vec!["a"]);
        let rendered = render(&info, DatabaseKind::Postgres);
        assert!(rendered.contains("pub enum MyEnum"));
    }

    // --- impl Default ---

    #[test]
    fn test_default_impl_generated() {
        let info = make_enum_with_default(
            "public",
            "task_status",
            vec!["idle", "running", "done"],
            "idle",
        );
        let rendered = render(&info, DatabaseKind::Postgres);
        assert!(rendered.contains("impl Default for TaskStatus"));
        assert!(rendered.contains("Self::Idle"));
    }

    #[test]
    fn test_no_default_impl_when_none() {
        let info = make_enum("status", vec!["active", "inactive"]);
        let rendered = render(&info, DatabaseKind::Postgres);
        assert!(!rendered.contains("impl Default"));
    }

    #[test]
    fn test_default_impl_snake_case_variant() {
        let info = make_enum_with_default(
            "public",
            "status",
            vec!["in_progress", "done"],
            "in_progress",
        );
        let rendered = render(&info, DatabaseKind::Postgres);
        assert!(rendered.contains("impl Default for Status"));
        assert!(rendered.contains("Self::InProgress"));
    }

    // --- public vs named schema integration ---

    #[test]
    fn test_public_schema_full_output() {
        let info = make_enum_in_schema(
            "public",
            "order_status",
            vec!["pending", "shipped", "delivered"],
        );
        let rendered = render(&info, DatabaseKind::Postgres);

        assert!(rendered.contains("Enum: public.order_status"));
        assert!(rendered.contains("pub enum OrderStatus"));
        assert!(rendered.contains("sqlx(type_name = \"order_status\")"));
        assert!(!rendered.contains("sqlx(type_name = \"public.order_status\")"));
        assert!(rendered
            .contains("sqlx_gen(kind = \"enum\", schema = \"public\", name = \"order_status\")"));
        assert!(rendered.contains("Pending"));
        assert!(rendered.contains("Shipped"));
        assert!(rendered.contains("Delivered"));
    }

    #[test]
    fn test_named_schema_full_output() {
        let info = make_enum_in_schema(
            "analysis",
            "toolcall_status",
            vec!["PENDING", "RUNNING", "DONE"],
        );
        let rendered = render(&info, DatabaseKind::Postgres);

        assert!(rendered.contains("Enum: analysis.toolcall_status"));
        assert!(rendered.contains("pub enum ToolcallStatus"));
        assert!(rendered.contains("sqlx(type_name = \"toolcall_status\")"));
        assert!(!rendered.contains("\"analysis.toolcall_status\""));
        assert!(rendered.contains(
            "sqlx_gen(kind = \"enum\", schema = \"analysis\", name = \"toolcall_status\")"
        ));
        assert!(rendered.contains("Pending"));
        assert!(rendered.contains("Running"));
        assert!(rendered.contains("Done"));
    }

    #[test]
    fn test_named_schema_with_default_variant() {
        let info = make_enum_with_default(
            "billing",
            "payment_status",
            vec!["pending", "paid", "refunded"],
            "pending",
        );
        let rendered = render(&info, DatabaseKind::Postgres);

        assert!(rendered.contains("sqlx(type_name = \"payment_status\")"));
        assert!(!rendered.contains("\"billing.payment_status\""));
        assert!(rendered.contains("impl Default for PaymentStatus"));
        assert!(rendered.contains("Self::Pending"));
    }

    #[test]
    fn test_named_schema_variant_rename() {
        let info =
            make_enum_in_schema("audit", "log_level", vec!["info", "warn_high", "CRITICAL"]);
        let rendered = render(&info, DatabaseKind::Postgres);

        assert!(rendered.contains("sqlx(type_name = \"log_level\")"));
        assert!(!rendered.contains("\"audit.log_level\""));
        assert!(rendered.contains("sqlx(rename = \"info\")"));
        assert!(rendered.contains("sqlx(rename = \"warn_high\")"));
        assert!(rendered.contains("WarnHigh"));
        assert!(rendered.contains("sqlx(rename = \"CRITICAL\")"));
        assert!(rendered.contains("Critical"));
    }

    #[test]
    fn test_named_schema_mysql_no_type_name() {
        let info = make_enum_in_schema("analytics", "event_type", vec!["click", "view"]);
        let rendered = render(&info, DatabaseKind::Mysql);

        assert!(!rendered.contains("type_name"));
    }

    #[test]
    fn test_named_schema_sqlite_no_type_name() {
        let info = make_enum_in_schema("analytics", "event_type", vec!["click", "view"]);
        let rendered = render(&info, DatabaseKind::Sqlite);

        assert!(!rendered.contains("type_name"));
    }
}