Skip to main content

sqlx_gen/codegen/
enum_gen.rs

1use std::collections::BTreeSet;
2
3use heck::ToUpperCamelCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::imports_for_derives;
9use crate::introspect::EnumInfo;
10
11pub fn generate_enum(
12    enum_info: &EnumInfo,
13    db_kind: DatabaseKind,
14    extra_derives: &[String],
15) -> (TokenStream, BTreeSet<String>) {
16    let mut imports = BTreeSet::new();
17    for imp in imports_for_derives(extra_derives) {
18        imports.insert(imp);
19    }
20
21    let enum_name = format_ident!("{}", enum_info.name.to_upper_camel_case());
22    let doc = format!("Enum: {}.{}", enum_info.schema_name, enum_info.name);
23
24    imports.insert("use serde::{Serialize, Deserialize};".to_string());
25    let mut derive_tokens = vec![
26        quote! { Debug },
27        quote! { Clone },
28        quote! { PartialEq },
29        quote! { Eq },
30        quote! { Serialize },
31        quote! { Deserialize },
32        quote! { sqlx::Type },
33    ];
34    for d in extra_derives {
35        let ident = format_ident!("{}", d);
36        derive_tokens.push(quote! { #ident });
37    }
38
39    // For PG, add #[sqlx(type_name = "...")]
40    // Schema-qualify the type name for non-public schemas so sqlx can find the type
41    let type_attr = if db_kind == DatabaseKind::Postgres {
42        let pg_name = if enum_info.schema_name != "public" {
43            format!("{}.{}", enum_info.schema_name, enum_info.name)
44        } else {
45            enum_info.name.clone()
46        };
47        quote! { #[sqlx(type_name = #pg_name)] }
48    } else {
49        quote! {}
50    };
51
52    let variants: Vec<TokenStream> = enum_info
53        .variants
54        .iter()
55        .map(|v| {
56            let variant_pascal = v.to_upper_camel_case();
57            let variant_ident = format_ident!("{}", variant_pascal);
58
59            let rename = if variant_pascal != *v {
60                quote! { #[sqlx(rename = #v)] }
61            } else {
62                quote! {}
63            };
64
65            quote! {
66                #rename
67                #variant_ident,
68            }
69        })
70        .collect();
71
72    let tokens = quote! {
73        #[doc = #doc]
74        #[derive(#(#derive_tokens),*)]
75        #type_attr
76        pub enum #enum_name {
77            #(#variants)*
78        }
79    };
80
81    (tokens, imports)
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;

    /// Builds an `EnumInfo` in the `public` schema with the given labels.
    fn make_enum(name: &str, variants: &[&str]) -> EnumInfo {
        EnumInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            variants: variants.iter().map(|s| s.to_string()).collect(),
        }
    }

    /// Generates and pretty-prints the enum with no extra derives.
    ///
    /// Named `render` rather than `gen`, which is a reserved keyword in the
    /// Rust 2024 edition.
    fn render(info: &EnumInfo, db: DatabaseKind) -> String {
        let (tokens, _) = generate_enum(info, db, &[]);
        parse_and_format(&tokens)
    }

    /// Generates the enum with `derives` and returns the formatted code and imports.
    fn render_with_derives(
        info: &EnumInfo,
        db: DatabaseKind,
        derives: &[String],
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_enum(info, db, derives);
        (parse_and_format(&tokens), imports)
    }

    // --- basic structure ---

    #[test]
    fn test_enum_variants() {
        let e = make_enum("status", &["active", "inactive"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(code.contains("Inactive"));
    }

    #[test]
    fn test_enum_name_pascal_case() {
        let e = make_enum("user_status", &["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("pub enum UserStatus"));
    }

    #[test]
    fn test_doc_comment() {
        let e = make_enum("status", &["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Enum: public.status"));
    }

    // --- sqlx attributes ---

    #[test]
    fn test_postgres_has_type_name() {
        let e = make_enum("user_status", &["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"user_status\")"));
    }

    #[test]
    fn test_postgres_non_public_schema_qualified_type_name() {
        let e = EnumInfo {
            schema_name: "auth".to_string(),
            name: "role".to_string(),
            variants: vec!["admin".to_string(), "user".to_string()],
        };
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"auth.role\")"));
    }

    #[test]
    fn test_postgres_public_schema_not_qualified() {
        let e = make_enum("status", &["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"status\")"));
        // type_name should NOT be schema-qualified for public schema
        assert!(!code.contains("type_name = \"public.status\""));
    }

    #[test]
    fn test_mysql_no_type_name() {
        let e = make_enum("status", &["a"]);
        let code = render(&e, DatabaseKind::Mysql);
        assert!(!code.contains("type_name"));
    }

    #[test]
    fn test_sqlite_no_type_name() {
        let e = make_enum("status", &["a"]);
        let code = render(&e, DatabaseKind::Sqlite);
        assert!(!code.contains("type_name"));
    }

    // --- rename variants ---

    #[test]
    fn test_snake_case_variant_renamed() {
        let e = make_enum("status", &["in_progress"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("InProgress"));
        assert!(code.contains("sqlx(rename = \"in_progress\")"));
    }

    #[test]
    fn test_lowercase_variant_renamed() {
        let e = make_enum("status", &["active"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(code.contains("sqlx(rename = \"active\")"));
    }

    #[test]
    fn test_already_pascal_no_rename() {
        let e = make_enum("status", &["Active"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(!code.contains("sqlx(rename"));
    }

    #[test]
    fn test_upper_case_variant_renamed() {
        let e = make_enum("status", &["UPPER_CASE"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("UpperCase"));
        assert!(code.contains("sqlx(rename = \"UPPER_CASE\")"));
    }

    // --- derives ---

    #[test]
    fn test_default_derives() {
        let e = make_enum("status", &["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("PartialEq"));
        // "Eq" occurs inside "PartialEq" too, so require a second occurrence.
        assert!(code.matches("Eq").count() >= 2);
        assert!(code.contains("Serialize"));
        assert!(code.contains("Deserialize"));
        assert!(code.contains("sqlx::Type") || code.contains("sqlx :: Type"));
    }

    #[test]
    fn test_extra_derive_serialize() {
        let e = make_enum("status", &["a"]);
        let derives = vec!["Serialize".to_string()];
        let (code, _) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_extra_derive_appended() {
        let e = make_enum("status", &["a"]);
        let derives = vec!["Hash".to_string()];
        let (code, _) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(code.contains("Hash"));
        // Defaults are still present alongside the extra derive.
        assert!(code.contains("Debug"));
    }

    #[test]
    fn test_extra_derives_serde_imports() {
        let e = make_enum("status", &["a"]);
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    // --- imports ---

    #[test]
    fn test_no_extra_derives_has_serde_import() {
        let e = make_enum("status", &["a"]);
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &[]);
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    #[test]
    fn test_serde_import_present() {
        let e = make_enum("status", &["a"]);
        let derives = vec!["Serialize".to_string()];
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(!imports.is_empty());
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    // --- edge cases ---

    #[test]
    fn test_single_variant() {
        let e = make_enum("status", &["only"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Only"));
    }

    #[test]
    fn test_many_variants() {
        let e = make_enum("status", &["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("A,"));
        assert!(code.contains("J,"));
    }

    #[test]
    fn test_variant_with_digits() {
        let e = make_enum("version", &["v2"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("V2"));
    }

    #[test]
    fn test_enum_name_with_double_underscores() {
        let e = make_enum("my__enum", &["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("pub enum MyEnum"));
    }
}