Skip to main content

sqlx_gen/codegen/
enum_gen.rs

1use std::collections::BTreeSet;
2
3use heck::ToUpperCamelCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::imports_for_derives;
9use crate::introspect::EnumInfo;
10
11pub fn generate_enum(
12    enum_info: &EnumInfo,
13    db_kind: DatabaseKind,
14    extra_derives: &[String],
15) -> (TokenStream, BTreeSet<String>) {
16    let mut imports = BTreeSet::new();
17    for imp in imports_for_derives(extra_derives) {
18        imports.insert(imp);
19    }
20
21    let enum_name = format_ident!("{}", enum_info.name.to_upper_camel_case());
22    let doc = format!("Enum: {}.{}", enum_info.schema_name, enum_info.name);
23
24    imports.insert("use serde::{Serialize, Deserialize};".to_string());
25    imports.insert("use sqlx_gen::SqlxGen;".to_string());
26    let mut derive_tokens = vec![
27        quote! { Debug },
28        quote! { Clone },
29        quote! { PartialEq },
30        quote! { Eq },
31        quote! { Serialize },
32        quote! { Deserialize },
33        quote! { sqlx::Type },
34        quote! { SqlxGen },
35    ];
36    for d in extra_derives {
37        let ident = format_ident!("{}", d);
38        derive_tokens.push(quote! { #ident });
39    }
40
41    // For PG, add #[sqlx(type_name = "...")] — schema-qualify for non-public schemas
42    let type_attr = if db_kind == DatabaseKind::Postgres {
43        let pg_name = if enum_info.schema_name != "public" {
44            format!("{}.{}", enum_info.schema_name, enum_info.name)
45        } else {
46            enum_info.name.clone()
47        };
48        quote! { #[sqlx(type_name = #pg_name)] }
49    } else {
50        quote! {}
51    };
52
53    let variants: Vec<TokenStream> = enum_info
54        .variants
55        .iter()
56        .map(|v| {
57            let variant_pascal = v.to_upper_camel_case();
58            let variant_ident = format_ident!("{}", variant_pascal);
59
60            let rename = if variant_pascal != *v {
61                quote! { #[sqlx(rename = #v)] }
62            } else {
63                quote! {}
64            };
65
66            quote! {
67                #rename
68                #variant_ident,
69            }
70        })
71        .collect();
72
73    let default_impl = if let Some(ref default_variant) = enum_info.default_variant {
74        let variant_pascal = default_variant.to_upper_camel_case();
75        let variant_ident = format_ident!("{}", variant_pascal);
76        quote! {
77            impl Default for #enum_name {
78                fn default() -> Self {
79                    Self::#variant_ident
80                }
81            }
82        }
83    } else {
84        quote! {}
85    };
86
87    let schema_name_str = &enum_info.schema_name;
88    let enum_name_str = &enum_info.name;
89
90    let tokens = quote! {
91        #[doc = #doc]
92        #[derive(#(#derive_tokens),*)]
93        #[sqlx_gen(kind = "enum", schema = #schema_name_str, name = #enum_name_str)]
94        #type_attr
95        pub enum #enum_name {
96            #(#variants)*
97        }
98
99        #default_impl
100    };
101
102    (tokens, imports)
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;

    // ---- fixtures & helpers ------------------------------------------------

    /// Builds an `EnumInfo` from borrowed parts.
    fn enum_info(schema: &str, name: &str, labels: &[&str], default: Option<&str>) -> EnumInfo {
        EnumInfo {
            schema_name: schema.to_owned(),
            name: name.to_owned(),
            variants: labels.iter().map(|label| (*label).to_owned()).collect(),
            default_variant: default.map(str::to_owned),
        }
    }

    /// Shorthand for an enum in the `public` schema without a default label.
    fn public_enum(name: &str, labels: &[&str]) -> EnumInfo {
        enum_info("public", name, labels, None)
    }

    /// Runs the generator with extra derives; returns formatted code and imports.
    fn render_with(
        info: &EnumInfo,
        db: DatabaseKind,
        derives: &[String],
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_enum(info, db, derives);
        (parse_and_format(&tokens), imports)
    }

    /// Runs the generator without extra derives; returns the formatted code.
    fn render(info: &EnumInfo, db: DatabaseKind) -> String {
        render_with(info, db, &[]).0
    }

    /// Asserts that every needle occurs in the generated code.
    fn assert_has(code: &str, needles: &[&str]) {
        for &needle in needles {
            assert!(code.contains(needle), "expected {:?} in generated code:\n{}", needle, code);
        }
    }

    /// Asserts that no needle occurs in the generated code.
    fn assert_lacks(code: &str, needles: &[&str]) {
        for &needle in needles {
            assert!(!code.contains(needle), "unexpected {:?} in generated code:\n{}", needle, code);
        }
    }

    // ---- basic structure -------------------------------------------------

    #[test]
    fn test_enum_variants() {
        let code = render(&public_enum("status", &["active", "inactive"]), DatabaseKind::Postgres);
        assert_has(&code, &["Active", "Inactive"]);
    }

    #[test]
    fn test_enum_name_pascal_case() {
        let code = render(&public_enum("user_status", &["a"]), DatabaseKind::Postgres);
        assert_has(&code, &["pub enum UserStatus"]);
    }

    #[test]
    fn test_doc_comment() {
        let code = render(&public_enum("status", &["a"]), DatabaseKind::Postgres);
        assert_has(&code, &["Enum: public.status"]);
    }

    #[test]
    fn test_sqlx_gen_attr_has_schema_and_name() {
        let code = render(&public_enum("status", &["a"]), DatabaseKind::Postgres);
        assert_has(&code, &[r#"sqlx_gen(kind = "enum", schema = "public", name = "status")"#]);
    }

    #[test]
    fn test_sqlx_gen_attr_non_public_schema() {
        let info = enum_info("auth", "role", &["admin", "user"], None);
        let code = render(&info, DatabaseKind::Postgres);
        assert_has(&code, &[r#"sqlx_gen(kind = "enum", schema = "auth", name = "role")"#]);
    }

    // ---- sqlx attributes -------------------------------------------------

    #[test]
    fn test_postgres_has_type_name() {
        let code = render(&public_enum("user_status", &["a"]), DatabaseKind::Postgres);
        assert_has(&code, &[r#"sqlx(type_name = "user_status")"#]);
    }

    #[test]
    fn test_postgres_non_public_schema_qualified_type_name() {
        let info = enum_info("auth", "role", &["admin", "user"], None);
        let code = render(&info, DatabaseKind::Postgres);
        assert_has(&code, &[r#"sqlx(type_name = "auth.role")"#]);
    }

    #[test]
    fn test_postgres_public_schema_not_qualified() {
        let code = render(&public_enum("status", &["a"]), DatabaseKind::Postgres);
        assert_has(&code, &[r#"sqlx(type_name = "status")"#]);
        // The public schema is never part of the type name.
        assert_lacks(&code, &[r#"type_name = "public.status""#]);
    }

    #[test]
    fn test_mysql_no_type_name() {
        let code = render(&public_enum("status", &["a"]), DatabaseKind::Mysql);
        assert_lacks(&code, &["type_name"]);
    }

    #[test]
    fn test_sqlite_no_type_name() {
        let code = render(&public_enum("status", &["a"]), DatabaseKind::Sqlite);
        assert_lacks(&code, &["type_name"]);
    }

    // ---- rename variants -------------------------------------------------

    #[test]
    fn test_snake_case_variant_renamed() {
        let code = render(&public_enum("status", &["in_progress"]), DatabaseKind::Postgres);
        assert_has(&code, &["InProgress", r#"sqlx(rename = "in_progress")"#]);
    }

    #[test]
    fn test_lowercase_variant_renamed() {
        let code = render(&public_enum("status", &["active"]), DatabaseKind::Postgres);
        assert_has(&code, &["Active", r#"sqlx(rename = "active")"#]);
    }

    #[test]
    fn test_already_pascal_no_rename() {
        let code = render(&public_enum("status", &["Active"]), DatabaseKind::Postgres);
        assert_has(&code, &["Active"]);
        assert_lacks(&code, &["sqlx(rename"]);
    }

    #[test]
    fn test_upper_case_variant_renamed() {
        let code = render(&public_enum("status", &["UPPER_CASE"]), DatabaseKind::Postgres);
        assert_has(&code, &["UpperCase", r#"sqlx(rename = "UPPER_CASE")"#]);
    }

    // ---- derives ---------------------------------------------------------

    #[test]
    fn test_default_derives() {
        let code = render(&public_enum("status", &["a"]), DatabaseKind::Postgres);
        assert_has(&code, &["Debug", "Clone", "PartialEq"]);
        assert!(code.contains("sqlx::Type") || code.contains("sqlx :: Type"));
    }

    #[test]
    fn test_extra_derive_serialize() {
        let derives = ["Serialize".to_string()];
        let (code, _) =
            render_with(&public_enum("status", &["a"]), DatabaseKind::Postgres, &derives);
        assert_has(&code, &["Serialize"]);
    }

    #[test]
    fn test_extra_derives_serde_imports() {
        let derives = ["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) =
            render_with(&public_enum("status", &["a"]), DatabaseKind::Postgres, &derives);
        assert!(imports.iter().any(|line| line.contains("serde")));
    }

    // ---- imports ---------------------------------------------------------

    #[test]
    fn test_no_extra_derives_has_serde_import() {
        let (_, imports) = render_with(&public_enum("status", &["a"]), DatabaseKind::Postgres, &[]);
        assert!(imports.iter().any(|line| line.contains("serde")));
    }

    #[test]
    fn test_serde_import_present() {
        let derives = ["Serialize".to_string()];
        let (_, imports) =
            render_with(&public_enum("status", &["a"]), DatabaseKind::Postgres, &derives);
        assert!(!imports.is_empty());
    }

    // ---- edge cases ------------------------------------------------------

    #[test]
    fn test_single_variant() {
        let code = render(&public_enum("status", &["only"]), DatabaseKind::Postgres);
        assert_has(&code, &["Only"]);
    }

    #[test]
    fn test_many_variants() {
        let labels = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"];
        let code = render(&public_enum("status", &labels), DatabaseKind::Postgres);
        assert_has(&code, &["A,", "J,"]);
    }

    #[test]
    fn test_variant_with_digits() {
        let code = render(&public_enum("version", &["v2"]), DatabaseKind::Postgres);
        assert_has(&code, &["V2"]);
    }

    #[test]
    fn test_enum_name_with_double_underscores() {
        let code = render(&public_enum("my__enum", &["a"]), DatabaseKind::Postgres);
        assert_has(&code, &["pub enum MyEnum"]);
    }

    // ---- impl Default ----------------------------------------------------

    #[test]
    fn test_default_impl_generated() {
        let info = enum_info("public", "task_status", &["idle", "running", "done"], Some("idle"));
        let code = render(&info, DatabaseKind::Postgres);
        assert_has(&code, &["impl Default for TaskStatus", "Self::Idle"]);
    }

    #[test]
    fn test_no_default_impl_when_none() {
        let code = render(&public_enum("status", &["active", "inactive"]), DatabaseKind::Postgres);
        assert_lacks(&code, &["impl Default"]);
    }

    #[test]
    fn test_default_impl_snake_case_variant() {
        let info = enum_info("public", "status", &["in_progress", "done"], Some("in_progress"));
        let code = render(&info, DatabaseKind::Postgres);
        assert_has(&code, &["impl Default for Status", "Self::InProgress"]);
    }

    // ---- public vs named schema integration ------------------------------

    #[test]
    fn test_public_schema_full_output() {
        let info = enum_info("public", "order_status", &["pending", "shipped", "delivered"], None);
        let code = render(&info, DatabaseKind::Postgres);

        assert_has(
            &code,
            &[
                "Enum: public.order_status",
                "pub enum OrderStatus",
                r#"sqlx(type_name = "order_status")"#,
                r#"sqlx_gen(kind = "enum", schema = "public", name = "order_status")"#,
                "Pending",
                "Shipped",
                "Delivered",
            ],
        );
        assert_lacks(&code, &[r#"sqlx(type_name = "public.order_status")"#]);
    }

    #[test]
    fn test_named_schema_full_output() {
        let info = enum_info("analysis", "toolcall_status", &["PENDING", "RUNNING", "DONE"], None);
        let code = render(&info, DatabaseKind::Postgres);

        assert_has(
            &code,
            &[
                "Enum: analysis.toolcall_status",
                "pub enum ToolcallStatus",
                r#"sqlx(type_name = "analysis.toolcall_status")"#,
                r#"sqlx_gen(kind = "enum", schema = "analysis", name = "toolcall_status")"#,
                "Pending",
                "Running",
                "Done",
            ],
        );
        assert_lacks(&code, &[r#"sqlx(type_name = "toolcall_status")"#]);
    }

    #[test]
    fn test_named_schema_with_default_variant() {
        let info = enum_info(
            "billing",
            "payment_status",
            &["pending", "paid", "refunded"],
            Some("pending"),
        );
        let code = render(&info, DatabaseKind::Postgres);

        assert_has(
            &code,
            &[
                r#"sqlx(type_name = "billing.payment_status")"#,
                "impl Default for PaymentStatus",
                "Self::Pending",
            ],
        );
    }

    #[test]
    fn test_named_schema_variant_rename() {
        let info = enum_info("audit", "log_level", &["info", "warn_high", "CRITICAL"], None);
        let code = render(&info, DatabaseKind::Postgres);

        assert_has(
            &code,
            &[
                r#"sqlx(type_name = "audit.log_level")"#,
                r#"sqlx(rename = "info")"#,
                r#"sqlx(rename = "warn_high")"#,
                "WarnHigh",
                r#"sqlx(rename = "CRITICAL")"#,
                "Critical",
            ],
        );
    }

    #[test]
    fn test_named_schema_mysql_no_type_name() {
        let info = enum_info("analytics", "event_type", &["click", "view"], None);
        let code = render(&info, DatabaseKind::Mysql);
        assert_lacks(&code, &["type_name"]);
    }

    #[test]
    fn test_named_schema_sqlite_no_type_name() {
        let info = enum_info("analytics", "event_type", &["click", "view"], None);
        let code = render(&info, DatabaseKind::Sqlite);
        assert_lacks(&code, &["type_name"]);
    }
}