Skip to main content

sqlx_gen/codegen/
enum_gen.rs

1use std::collections::BTreeSet;
2
3use heck::ToUpperCamelCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::imports_for_derives;
9use crate::introspect::EnumInfo;
10
11pub fn generate_enum(
12    enum_info: &EnumInfo,
13    db_kind: DatabaseKind,
14    extra_derives: &[String],
15) -> (TokenStream, BTreeSet<String>) {
16    let mut imports = BTreeSet::new();
17    for imp in imports_for_derives(extra_derives) {
18        imports.insert(imp);
19    }
20
21    let enum_name = format_ident!("{}", enum_info.name.to_upper_camel_case());
22    let doc = format!("Enum: {}.{}", enum_info.schema_name, enum_info.name);
23
24    imports.insert("use serde::{Serialize, Deserialize};".to_string());
25    imports.insert("use sqlx_gen::SqlxGen;".to_string());
26    let mut derive_tokens = vec![
27        quote! { Debug },
28        quote! { Clone },
29        quote! { PartialEq },
30        quote! { Eq },
31        quote! { Serialize },
32        quote! { Deserialize },
33        quote! { sqlx::Type },
34        quote! { SqlxGen },
35    ];
36    for d in extra_derives {
37        let ident = format_ident!("{}", d);
38        derive_tokens.push(quote! { #ident });
39    }
40
41    // For PG, add #[sqlx(type_name = "...")]
42    let type_attr = if db_kind == DatabaseKind::Postgres {
43        let pg_name = &enum_info.name;
44        quote! { #[sqlx(type_name = #pg_name)] }
45    } else {
46        quote! {}
47    };
48
49    let variants: Vec<TokenStream> = enum_info
50        .variants
51        .iter()
52        .map(|v| {
53            let variant_pascal = v.to_upper_camel_case();
54            let variant_ident = format_ident!("{}", variant_pascal);
55
56            let rename = if variant_pascal != *v {
57                quote! { #[sqlx(rename = #v)] }
58            } else {
59                quote! {}
60            };
61
62            quote! {
63                #rename
64                #variant_ident,
65            }
66        })
67        .collect();
68
69    let default_impl = if let Some(ref default_variant) = enum_info.default_variant {
70        let variant_pascal = default_variant.to_upper_camel_case();
71        let variant_ident = format_ident!("{}", variant_pascal);
72        quote! {
73            impl Default for #enum_name {
74                fn default() -> Self {
75                    Self::#variant_ident
76                }
77            }
78        }
79    } else {
80        quote! {}
81    };
82
83    let schema_name_str = &enum_info.schema_name;
84    let enum_name_str = &enum_info.name;
85
86    let tokens = quote! {
87        #[doc = #doc]
88        #[derive(#(#derive_tokens),*)]
89        #[sqlx_gen(kind = "enum", schema = #schema_name_str, name = #enum_name_str)]
90        #type_attr
91        pub enum #enum_name {
92            #(#variants)*
93        }
94
95        #default_impl
96    };
97
98    (tokens, imports)
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;

    /// Builds an enum in the `public` schema with no default variant.
    fn make_enum(name: &str, variants: Vec<&str>) -> EnumInfo {
        EnumInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            variants: variants.into_iter().map(|s| s.to_string()).collect(),
            default_variant: None,
        }
    }

    /// Generates `info` with no extra derives and returns the formatted source.
    ///
    /// Named `render` rather than `gen`, which is a reserved keyword from Rust 2024 on.
    fn render(info: &EnumInfo, db: DatabaseKind) -> String {
        let (tokens, _) = generate_enum(info, db, &[]);
        parse_and_format(&tokens)
    }

    /// Generates `info` with `derives` and returns the formatted source plus imports.
    fn render_with_derives(
        info: &EnumInfo,
        db: DatabaseKind,
        derives: &[String],
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_enum(info, db, derives);
        (parse_and_format(&tokens), imports)
    }

    // --- basic structure ---

    #[test]
    fn test_enum_variants() {
        let e = make_enum("status", vec!["active", "inactive"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(code.contains("Inactive"));
    }

    #[test]
    fn test_enum_name_pascal_case() {
        let e = make_enum("user_status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("pub enum UserStatus"));
    }

    #[test]
    fn test_doc_comment() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Enum: public.status"));
    }

    #[test]
    fn test_sqlx_gen_attr_has_schema_and_name() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx_gen(kind = \"enum\", schema = \"public\", name = \"status\")"));
    }

    #[test]
    fn test_sqlx_gen_attr_non_public_schema() {
        let e = EnumInfo {
            schema_name: "auth".to_string(),
            name: "role".to_string(),
            variants: vec!["admin".to_string(), "user".to_string()],
            default_variant: None,
        };
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx_gen(kind = \"enum\", schema = \"auth\", name = \"role\")"));
    }

    // --- sqlx attributes ---

    #[test]
    fn test_postgres_has_type_name() {
        let e = make_enum("user_status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"user_status\")"));
    }

    #[test]
    fn test_postgres_non_public_schema_type_name_is_unqualified() {
        let e = EnumInfo {
            schema_name: "auth".to_string(),
            name: "role".to_string(),
            variants: vec!["admin".to_string(), "user".to_string()],
            default_variant: None,
        };
        let code = render(&e, DatabaseKind::Postgres);
        // The schema is carried by the `sqlx_gen` attribute; `type_name` stays bare.
        assert!(code.contains("sqlx(type_name = \"role\")"));
        assert!(!code.contains("type_name = \"auth.role\""));
    }

    #[test]
    fn test_postgres_public_schema_not_qualified() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"status\")"));
        // type_name should NOT be schema-qualified for public schema
        assert!(!code.contains("type_name = \"public.status\""));
    }

    #[test]
    fn test_mysql_no_type_name() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Mysql);
        assert!(!code.contains("type_name"));
    }

    #[test]
    fn test_sqlite_no_type_name() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Sqlite);
        assert!(!code.contains("type_name"));
    }

    // --- rename variants ---

    #[test]
    fn test_snake_case_variant_renamed() {
        let e = make_enum("status", vec!["in_progress"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("InProgress"));
        assert!(code.contains("sqlx(rename = \"in_progress\")"));
    }

    #[test]
    fn test_lowercase_variant_renamed() {
        let e = make_enum("status", vec!["active"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(code.contains("sqlx(rename = \"active\")"));
    }

    #[test]
    fn test_already_pascal_no_rename() {
        let e = make_enum("status", vec!["Active"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(!code.contains("sqlx(rename"));
    }

    #[test]
    fn test_upper_case_variant_renamed() {
        let e = make_enum("status", vec!["UPPER_CASE"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("UpperCase"));
        assert!(code.contains("sqlx(rename = \"UPPER_CASE\")"));
    }

    // --- derives ---

    #[test]
    fn test_default_derives() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("PartialEq"));
        assert!(code.contains("Serialize"));
        assert!(code.contains("Deserialize"));
        assert!(code.contains("SqlxGen"));
        assert!(code.contains("sqlx::Type") || code.contains("sqlx :: Type"));
    }

    #[test]
    fn test_extra_derive_serialize() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string()];
        let (code, _) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_extra_derives_serde_imports() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    // --- imports ---

    #[test]
    fn test_no_extra_derives_has_serde_import() {
        let e = make_enum("status", vec!["a"]);
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &[]);
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    #[test]
    fn test_serde_import_present() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string()];
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(imports.contains("use serde::{Serialize, Deserialize};"));
        assert!(imports.contains("use sqlx_gen::SqlxGen;"));
    }

    // --- edge cases ---

    #[test]
    fn test_single_variant() {
        let e = make_enum("status", vec!["only"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Only"));
    }

    #[test]
    fn test_many_variants() {
        let variants: Vec<&str> = vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"];
        let e = make_enum("status", variants);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("A,"));
        assert!(code.contains("J,"));
    }

    #[test]
    fn test_variant_with_digits() {
        let e = make_enum("version", vec!["v2"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("V2"));
    }

    #[test]
    fn test_enum_name_with_double_underscores() {
        let e = make_enum("my__enum", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("pub enum MyEnum"));
    }

    // --- impl Default ---

    #[test]
    fn test_default_impl_generated() {
        let e = EnumInfo {
            schema_name: "public".to_string(),
            name: "task_status".to_string(),
            variants: vec!["idle".to_string(), "running".to_string(), "done".to_string()],
            default_variant: Some("idle".to_string()),
        };
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("impl Default for TaskStatus"));
        assert!(code.contains("Self::Idle"));
    }

    #[test]
    fn test_no_default_impl_when_none() {
        let e = make_enum("status", vec!["active", "inactive"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(!code.contains("impl Default"));
    }

    #[test]
    fn test_default_impl_snake_case_variant() {
        let e = EnumInfo {
            schema_name: "public".to_string(),
            name: "status".to_string(),
            variants: vec!["in_progress".to_string(), "done".to_string()],
            default_variant: Some("in_progress".to_string()),
        };
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("impl Default for Status"));
        assert!(code.contains("Self::InProgress"));
    }
}