Skip to main content

sqlx_gen/codegen/
enum_gen.rs

1use std::collections::BTreeSet;
2
3use heck::ToUpperCamelCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::imports_for_derives;
9use crate::introspect::EnumInfo;
10
11pub fn generate_enum(
12    enum_info: &EnumInfo,
13    db_kind: DatabaseKind,
14    extra_derives: &[String],
15) -> (TokenStream, BTreeSet<String>) {
16    let mut imports = BTreeSet::new();
17    for imp in imports_for_derives(extra_derives) {
18        imports.insert(imp);
19    }
20
21    let enum_name = format_ident!("{}", enum_info.name.to_upper_camel_case());
22    let doc = format!("Enum: {}.{}", enum_info.schema_name, enum_info.name);
23
24    imports.insert("use serde::{Serialize, Deserialize};".to_string());
25    imports.insert("use sqlx_gen::SqlxGen;".to_string());
26    let mut derive_tokens = vec![
27        quote! { Debug },
28        quote! { Clone },
29        quote! { PartialEq },
30        quote! { Eq },
31        quote! { Serialize },
32        quote! { Deserialize },
33        quote! { sqlx::Type },
34        quote! { SqlxGen },
35    ];
36    for d in extra_derives {
37        let ident = format_ident!("{}", d);
38        derive_tokens.push(quote! { #ident });
39    }
40
41    // For PG, add #[sqlx(type_name = "...")]
42    // Schema-qualify the type name for non-public schemas so sqlx can find the type
43    let type_attr = if db_kind == DatabaseKind::Postgres {
44        let pg_name = if enum_info.schema_name != "public" {
45            format!("{}.{}", enum_info.schema_name, enum_info.name)
46        } else {
47            enum_info.name.clone()
48        };
49        quote! { #[sqlx(type_name = #pg_name)] }
50    } else {
51        quote! {}
52    };
53
54    let variants: Vec<TokenStream> = enum_info
55        .variants
56        .iter()
57        .map(|v| {
58            let variant_pascal = v.to_upper_camel_case();
59            let variant_ident = format_ident!("{}", variant_pascal);
60
61            let rename = if variant_pascal != *v {
62                quote! { #[sqlx(rename = #v)] }
63            } else {
64                quote! {}
65            };
66
67            quote! {
68                #rename
69                #variant_ident,
70            }
71        })
72        .collect();
73
74    let default_impl = if let Some(ref default_variant) = enum_info.default_variant {
75        let variant_pascal = default_variant.to_upper_camel_case();
76        let variant_ident = format_ident!("{}", variant_pascal);
77        quote! {
78            impl Default for #enum_name {
79                fn default() -> Self {
80                    Self::#variant_ident
81                }
82            }
83        }
84    } else {
85        quote! {}
86    };
87
88    let schema_name_str = &enum_info.schema_name;
89    let enum_name_str = &enum_info.name;
90
91    let tokens = quote! {
92        #[doc = #doc]
93        #[derive(#(#derive_tokens),*)]
94        #[sqlx_gen(kind = "enum", schema = #schema_name_str, name = #enum_name_str)]
95        #type_attr
96        pub enum #enum_name {
97            #(#variants)*
98        }
99
100        #default_impl
101    };
102
103    (tokens, imports)
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;

    /// Builds an `EnumInfo` in `schema` with the given variants and no default.
    fn make_enum_in(schema: &str, name: &str, variants: Vec<&str>) -> EnumInfo {
        EnumInfo {
            schema_name: schema.to_string(),
            name: name.to_string(),
            variants: variants.into_iter().map(|s| s.to_string()).collect(),
            default_variant: None,
        }
    }

    /// Builds an `EnumInfo` in the `public` schema with no default variant.
    fn make_enum(name: &str, variants: Vec<&str>) -> EnumInfo {
        make_enum_in("public", name, variants)
    }

    /// Builds an `EnumInfo` in the `public` schema whose default variant is `default`.
    fn make_enum_with_default(name: &str, variants: Vec<&str>, default: &str) -> EnumInfo {
        EnumInfo {
            default_variant: Some(default.to_string()),
            ..make_enum(name, variants)
        }
    }

    /// Generates and formats the enum with no extra derives.
    fn gen(info: &EnumInfo, db: DatabaseKind) -> String {
        let (tokens, _) = generate_enum(info, db, &[]);
        parse_and_format(&tokens)
    }

    /// Generates and formats the enum with `derives`, also returning its imports.
    fn gen_with_derives(
        info: &EnumInfo,
        db: DatabaseKind,
        derives: &[String],
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_enum(info, db, derives);
        (parse_and_format(&tokens), imports)
    }

    // --- basic structure ---

    #[test]
    fn test_enum_variants() {
        let e = make_enum("status", vec!["active", "inactive"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(code.contains("Inactive"));
    }

    #[test]
    fn test_enum_name_pascal_case() {
        let e = make_enum("user_status", vec!["a"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("pub enum UserStatus"));
    }

    #[test]
    fn test_doc_comment() {
        let e = make_enum("status", vec!["a"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("Enum: public.status"));
    }

    #[test]
    fn test_doc_comment_non_public_schema() {
        let e = make_enum_in("auth", "role", vec!["admin"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("Enum: auth.role"));
    }

    #[test]
    fn test_sqlx_gen_attr_has_schema_and_name() {
        let e = make_enum("status", vec!["a"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx_gen(kind = \"enum\", schema = \"public\", name = \"status\")"));
    }

    #[test]
    fn test_sqlx_gen_attr_non_public_schema() {
        let e = make_enum_in("auth", "role", vec!["admin", "user"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx_gen(kind = \"enum\", schema = \"auth\", name = \"role\")"));
    }

    // --- sqlx attributes ---

    #[test]
    fn test_postgres_has_type_name() {
        let e = make_enum("user_status", vec!["a"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"user_status\")"));
    }

    #[test]
    fn test_postgres_non_public_schema_qualified_type_name() {
        let e = make_enum_in("auth", "role", vec!["admin", "user"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"auth.role\")"));
    }

    #[test]
    fn test_postgres_public_schema_not_qualified() {
        let e = make_enum("status", vec!["a"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"status\")"));
        // type_name should NOT be schema-qualified for public schema
        assert!(!code.contains("type_name = \"public.status\""));
    }

    #[test]
    fn test_mysql_no_type_name() {
        let e = make_enum("status", vec!["a"]);
        let code = gen(&e, DatabaseKind::Mysql);
        assert!(!code.contains("type_name"));
    }

    #[test]
    fn test_sqlite_no_type_name() {
        let e = make_enum("status", vec!["a"]);
        let code = gen(&e, DatabaseKind::Sqlite);
        assert!(!code.contains("type_name"));
    }

    // --- rename variants ---

    #[test]
    fn test_snake_case_variant_renamed() {
        let e = make_enum("status", vec!["in_progress"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("InProgress"));
        assert!(code.contains("sqlx(rename = \"in_progress\")"));
    }

    #[test]
    fn test_lowercase_variant_renamed() {
        let e = make_enum("status", vec!["active"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(code.contains("sqlx(rename = \"active\")"));
    }

    #[test]
    fn test_already_pascal_no_rename() {
        let e = make_enum("status", vec!["Active"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(!code.contains("sqlx(rename"));
    }

    #[test]
    fn test_upper_case_variant_renamed() {
        let e = make_enum("status", vec!["UPPER_CASE"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("UpperCase"));
        assert!(code.contains("sqlx(rename = \"UPPER_CASE\")"));
    }

    // --- derives ---

    #[test]
    fn test_default_derives() {
        let e = make_enum("status", vec!["a"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("PartialEq"));
        // Leading space so `PartialEq` alone does not satisfy the check.
        assert!(code.contains(" Eq"));
        assert!(code.contains("Serialize"));
        assert!(code.contains("Deserialize"));
        assert!(code.contains("SqlxGen"));
        assert!(code.contains("sqlx::Type") || code.contains("sqlx :: Type"));
    }

    #[test]
    fn test_extra_derive_serialize() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string()];
        let (code, _) = gen_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_extra_derive_appended() {
        // `Hash` is not a built-in derive, so it only shows up if the extra is emitted.
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Hash".to_string()];
        let (code, _) = gen_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(code.contains("Hash"));
    }

    #[test]
    fn test_extra_derives_serde_imports() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) = gen_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    // --- imports ---

    #[test]
    fn test_no_extra_derives_has_serde_import() {
        let e = make_enum("status", vec!["a"]);
        let (_, imports) = gen_with_derives(&e, DatabaseKind::Postgres, &[]);
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    #[test]
    fn test_serde_import_present() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string()];
        let (_, imports) = gen_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(imports.contains("use serde::{Serialize, Deserialize};"));
    }

    #[test]
    fn test_sqlx_gen_import_present() {
        let e = make_enum("status", vec!["a"]);
        let (_, imports) = gen_with_derives(&e, DatabaseKind::Postgres, &[]);
        assert!(imports.contains("use sqlx_gen::SqlxGen;"));
    }

    // --- edge cases ---

    #[test]
    fn test_single_variant() {
        let e = make_enum("status", vec!["only"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("Only"));
    }

    #[test]
    fn test_many_variants() {
        let variants: Vec<&str> = vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"];
        let e = make_enum("status", variants);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("A,"));
        assert!(code.contains("J,"));
    }

    #[test]
    fn test_variant_with_digits() {
        let e = make_enum("version", vec!["v2"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("V2"));
    }

    #[test]
    fn test_enum_name_with_double_underscores() {
        let e = make_enum("my__enum", vec!["a"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("pub enum MyEnum"));
    }

    // --- impl Default ---

    #[test]
    fn test_default_impl_generated() {
        let e = make_enum_with_default("task_status", vec!["idle", "running", "done"], "idle");
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("impl Default for TaskStatus"));
        assert!(code.contains("Self::Idle"));
    }

    #[test]
    fn test_no_default_impl_when_none() {
        let e = make_enum("status", vec!["active", "inactive"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(!code.contains("impl Default"));
    }

    #[test]
    fn test_default_impl_snake_case_variant() {
        let e = make_enum_with_default("status", vec!["in_progress", "done"], "in_progress");
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("impl Default for Status"));
        assert!(code.contains("Self::InProgress"));
    }
}