Skip to main content

sqlx_gen/codegen/
enum_gen.rs

1use std::collections::BTreeSet;
2
3use heck::ToUpperCamelCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::imports_for_derives;
9use crate::introspect::EnumInfo;
10
11pub fn generate_enum(
12    enum_info: &EnumInfo,
13    db_kind: DatabaseKind,
14    extra_derives: &[String],
15) -> (TokenStream, BTreeSet<String>) {
16    let mut imports = BTreeSet::new();
17    for imp in imports_for_derives(extra_derives) {
18        imports.insert(imp);
19    }
20
21    let enum_name = format_ident!("{}", enum_info.name.to_upper_camel_case());
22    let doc = format!("Enum: {}.{}", enum_info.schema_name, enum_info.name);
23
24    imports.insert("use serde::{Serialize, Deserialize};".to_string());
25    imports.insert("use sqlx_gen::SqlxGen;".to_string());
26    let mut derive_tokens = vec![
27        quote! { Debug },
28        quote! { Clone },
29        quote! { PartialEq },
30        quote! { Eq },
31        quote! { Serialize },
32        quote! { Deserialize },
33        quote! { sqlx::Type },
34        quote! { SqlxGen },
35    ];
36    for d in extra_derives {
37        let ident = format_ident!("{}", d);
38        derive_tokens.push(quote! { #ident });
39    }
40
41    // For PG, add #[sqlx(type_name = "...")]
42    // Schema-qualify the type name for non-public schemas so sqlx can find the type
43    let type_attr = if db_kind == DatabaseKind::Postgres {
44        let pg_name = if enum_info.schema_name != "public" {
45            format!("{}.{}", enum_info.schema_name, enum_info.name)
46        } else {
47            enum_info.name.clone()
48        };
49        quote! { #[sqlx(type_name = #pg_name)] }
50    } else {
51        quote! {}
52    };
53
54    let variants: Vec<TokenStream> = enum_info
55        .variants
56        .iter()
57        .map(|v| {
58            let variant_pascal = v.to_upper_camel_case();
59            let variant_ident = format_ident!("{}", variant_pascal);
60
61            let rename = if variant_pascal != *v {
62                quote! { #[sqlx(rename = #v)] }
63            } else {
64                quote! {}
65            };
66
67            quote! {
68                #rename
69                #variant_ident,
70            }
71        })
72        .collect();
73
74    let default_impl = if let Some(ref default_variant) = enum_info.default_variant {
75        let variant_pascal = default_variant.to_upper_camel_case();
76        let variant_ident = format_ident!("{}", variant_pascal);
77        quote! {
78            impl Default for #enum_name {
79                fn default() -> Self {
80                    Self::#variant_ident
81                }
82            }
83        }
84    } else {
85        quote! {}
86    };
87
88    let tokens = quote! {
89        #[doc = #doc]
90        #[derive(#(#derive_tokens),*)]
91        #[sqlx_gen(kind = "enum")]
92        #type_attr
93        pub enum #enum_name {
94            #(#variants)*
95        }
96
97        #default_impl
98    };
99
100    (tokens, imports)
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;

    // --- helpers ---

    /// Builds an `EnumInfo` with an explicit schema and optional default variant.
    fn make_enum_with(
        schema: &str,
        name: &str,
        variants: &[&str],
        default_variant: Option<&str>,
    ) -> EnumInfo {
        EnumInfo {
            schema_name: schema.to_string(),
            name: name.to_string(),
            variants: variants.iter().map(|s| s.to_string()).collect(),
            default_variant: default_variant.map(String::from),
        }
    }

    /// Builds an `EnumInfo` in the `public` schema with no default variant.
    fn make_enum(name: &str, variants: Vec<&str>) -> EnumInfo {
        make_enum_with("public", name, &variants, None)
    }

    /// Generates `info` with the given extra derives.
    ///
    /// Returns the formatted source and the collected imports. Named `render*`
    /// rather than `gen*` because `gen` is a reserved keyword in Rust 2024.
    fn render_with_derives(
        info: &EnumInfo,
        db: DatabaseKind,
        derives: &[String],
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_enum(info, db, derives);
        (parse_and_format(&tokens), imports)
    }

    /// Generates `info` without extra derives and returns the formatted source.
    fn render(info: &EnumInfo, db: DatabaseKind) -> String {
        render_with_derives(info, db, &[]).0
    }

    // --- basic structure ---

    #[test]
    fn test_enum_variants() {
        let e = make_enum("status", vec!["active", "inactive"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(code.contains("Inactive"));
    }

    #[test]
    fn test_enum_name_pascal_case() {
        let e = make_enum("user_status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("pub enum UserStatus"));
    }

    #[test]
    fn test_doc_comment() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Enum: public.status"));
    }

    #[test]
    fn test_doc_comment_non_public_schema() {
        let e = make_enum_with("auth", "role", &["admin"], None);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Enum: auth.role"));
    }

    #[test]
    fn test_sqlx_gen_kind_attribute() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx_gen(kind = \"enum\")"));
    }

    // --- sqlx attributes ---

    #[test]
    fn test_postgres_has_type_name() {
        let e = make_enum("user_status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"user_status\")"));
    }

    #[test]
    fn test_postgres_non_public_schema_qualified_type_name() {
        let e = make_enum_with("auth", "role", &["admin", "user"], None);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"auth.role\")"));
    }

    #[test]
    fn test_postgres_public_schema_not_qualified() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"status\")"));
        // type_name should NOT be schema-qualified for the public schema
        assert!(!code.contains("type_name = \"public.status\""));
    }

    #[test]
    fn test_mysql_no_type_name() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Mysql);
        assert!(!code.contains("type_name"));
    }

    #[test]
    fn test_sqlite_no_type_name() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Sqlite);
        assert!(!code.contains("type_name"));
    }

    // --- rename variants ---

    #[test]
    fn test_snake_case_variant_renamed() {
        let e = make_enum("status", vec!["in_progress"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("InProgress"));
        assert!(code.contains("sqlx(rename = \"in_progress\")"));
    }

    #[test]
    fn test_lowercase_variant_renamed() {
        let e = make_enum("status", vec!["active"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(code.contains("sqlx(rename = \"active\")"));
    }

    #[test]
    fn test_already_pascal_no_rename() {
        let e = make_enum("status", vec!["Active"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(!code.contains("sqlx(rename"));
    }

    #[test]
    fn test_upper_case_variant_renamed() {
        let e = make_enum("status", vec!["UPPER_CASE"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("UpperCase"));
        assert!(code.contains("sqlx(rename = \"UPPER_CASE\")"));
    }

    #[test]
    fn test_mysql_variant_renamed() {
        // Renames are independent of the database kind.
        let e = make_enum("status", vec!["in_progress"]);
        let code = render(&e, DatabaseKind::Mysql);
        assert!(code.contains("InProgress"));
        assert!(code.contains("sqlx(rename = \"in_progress\")"));
    }

    // --- derives ---

    #[test]
    fn test_default_derives() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("PartialEq"));
        // Leading space so that "PartialEq," does not satisfy the check.
        assert!(code.contains(" Eq,"));
        assert!(code.contains("Serialize"));
        assert!(code.contains("Deserialize"));
        assert!(code.contains("SqlxGen"));
        assert!(code.contains("sqlx::Type") || code.contains("sqlx :: Type"));
    }

    #[test]
    fn test_extra_derive_serialize() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string()];
        let (code, _) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_extra_derive_non_builtin() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Hash".to_string()];
        let (code, _) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(code.contains("Hash"));
    }

    #[test]
    fn test_extra_derives_serde_imports() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    // --- imports ---

    #[test]
    fn test_no_extra_derives_has_serde_import() {
        let e = make_enum("status", vec!["a"]);
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &[]);
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    #[test]
    fn test_serde_import_present() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string()];
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(imports.contains("use serde::{Serialize, Deserialize};"));
    }

    #[test]
    fn test_sqlx_gen_import_present() {
        let e = make_enum("status", vec!["a"]);
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &[]);
        assert!(imports.contains("use sqlx_gen::SqlxGen;"));
    }

    // --- edge cases ---

    #[test]
    fn test_single_variant() {
        let e = make_enum("status", vec!["only"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Only"));
    }

    #[test]
    fn test_many_variants() {
        let variants: Vec<&str> = vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"];
        let e = make_enum("status", variants);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("A,"));
        assert!(code.contains("J,"));
    }

    #[test]
    fn test_variant_with_digits() {
        let e = make_enum("version", vec!["v2"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("V2"));
    }

    #[test]
    fn test_enum_name_with_double_underscores() {
        let e = make_enum("my__enum", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("pub enum MyEnum"));
    }

    // --- impl Default ---

    #[test]
    fn test_default_impl_generated() {
        let e = make_enum_with(
            "public",
            "task_status",
            &["idle", "running", "done"],
            Some("idle"),
        );
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("impl Default for TaskStatus"));
        assert!(code.contains("Self::Idle"));
    }

    #[test]
    fn test_no_default_impl_when_none() {
        let e = make_enum("status", vec!["active", "inactive"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(!code.contains("impl Default"));
    }

    #[test]
    fn test_default_impl_snake_case_variant() {
        let e = make_enum_with("public", "status", &["in_progress", "done"], Some("in_progress"));
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("impl Default for Status"));
        assert!(code.contains("Self::InProgress"));
    }
}