Skip to main content

sqlx_gen/codegen/
enum_gen.rs

1use std::collections::BTreeSet;
2
3use heck::ToUpperCamelCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::imports_for_derives;
9use crate::introspect::EnumInfo;
10
11pub fn generate_enum(
12    enum_info: &EnumInfo,
13    db_kind: DatabaseKind,
14    extra_derives: &[String],
15) -> (TokenStream, BTreeSet<String>) {
16    let mut imports = BTreeSet::new();
17    for imp in imports_for_derives(extra_derives) {
18        imports.insert(imp);
19    }
20
21    let enum_name = format_ident!("{}", enum_info.name.to_upper_camel_case());
22    let doc = format!("Enum: {}.{}", enum_info.schema_name, enum_info.name);
23
24    imports.insert("use serde::{Serialize, Deserialize};".to_string());
25    imports.insert("use sqlx_gen::SqlxGen;".to_string());
26    let mut derive_tokens = vec![
27        quote! { Debug },
28        quote! { Clone },
29        quote! { PartialEq },
30        quote! { Eq },
31        quote! { Serialize },
32        quote! { Deserialize },
33        quote! { sqlx::Type },
34        quote! { SqlxGen },
35    ];
36    for d in extra_derives {
37        let ident = format_ident!("{}", d);
38        derive_tokens.push(quote! { #ident });
39    }
40
41    // For PG, add #[sqlx(type_name = "...")]
42    // Schema-qualify the type name for non-public schemas so sqlx can find the type
43    let type_attr = if db_kind == DatabaseKind::Postgres {
44        let pg_name = if enum_info.schema_name != "public" {
45            format!("{}.{}", enum_info.schema_name, enum_info.name)
46        } else {
47            enum_info.name.clone()
48        };
49        quote! { #[sqlx(type_name = #pg_name)] }
50    } else {
51        quote! {}
52    };
53
54    let variants: Vec<TokenStream> = enum_info
55        .variants
56        .iter()
57        .map(|v| {
58            let variant_pascal = v.to_upper_camel_case();
59            let variant_ident = format_ident!("{}", variant_pascal);
60
61            let rename = if variant_pascal != *v {
62                quote! { #[sqlx(rename = #v)] }
63            } else {
64                quote! {}
65            };
66
67            quote! {
68                #rename
69                #variant_ident,
70            }
71        })
72        .collect();
73
74    let tokens = quote! {
75        #[doc = #doc]
76        #[derive(#(#derive_tokens),*)]
77        #[sqlx_gen(kind = "enum")]
78        #type_attr
79        pub enum #enum_name {
80            #(#variants)*
81        }
82    };
83
84    (tokens, imports)
85}
86
#[cfg(test)]
mod tests {
    //! Unit tests for `generate_enum`: the output is rendered to a string via
    //! `parse_and_format` and checked with substring assertions.

    use super::*;
    use crate::codegen::parse_and_format;

    /// Builds an `EnumInfo` living in the given schema.
    fn make_enum_in(schema: &str, name: &str, variants: Vec<&str>) -> EnumInfo {
        EnumInfo {
            schema_name: schema.to_string(),
            name: name.to_string(),
            variants: variants.into_iter().map(|s| s.to_string()).collect(),
        }
    }

    /// Builds an `EnumInfo` in the `public` schema.
    fn make_enum(name: &str, variants: Vec<&str>) -> EnumInfo {
        make_enum_in("public", name, variants)
    }

    /// Renders the generated enum without extra derives.
    ///
    /// Deliberately not called `gen`: that identifier is a reserved keyword
    /// in the Rust 2024 edition.
    fn render(info: &EnumInfo, db: DatabaseKind) -> String {
        let (tokens, _) = generate_enum(info, db, &[]);
        parse_and_format(&tokens)
    }

    /// Renders the generated enum with extra derives and also returns the
    /// collected import lines.
    fn render_with_derives(
        info: &EnumInfo,
        db: DatabaseKind,
        derives: &[String],
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_enum(info, db, derives);
        (parse_and_format(&tokens), imports)
    }

    // --- basic structure ---

    #[test]
    fn test_enum_variants() {
        let e = make_enum("status", vec!["active", "inactive"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(code.contains("Inactive"));
    }

    #[test]
    fn test_enum_name_pascal_case() {
        let e = make_enum("user_status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("pub enum UserStatus"));
    }

    #[test]
    fn test_doc_comment() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Enum: public.status"));
    }

    #[test]
    fn test_sqlx_gen_kind_attribute() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx_gen"));
        assert!(code.contains("kind = \"enum\""));
    }

    // --- sqlx attributes ---

    #[test]
    fn test_postgres_has_type_name() {
        let e = make_enum("user_status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"user_status\")"));
    }

    #[test]
    fn test_postgres_non_public_schema_qualified_type_name() {
        let e = make_enum_in("auth", "role", vec!["admin", "user"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"auth.role\")"));
    }

    #[test]
    fn test_postgres_public_schema_not_qualified() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"status\")"));
        // type_name should NOT be schema-qualified for public schema
        assert!(!code.contains("type_name = \"public.status\""));
    }

    #[test]
    fn test_mysql_no_type_name() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Mysql);
        assert!(!code.contains("type_name"));
    }

    #[test]
    fn test_sqlite_no_type_name() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Sqlite);
        assert!(!code.contains("type_name"));
    }

    // --- rename variants ---

    #[test]
    fn test_snake_case_variant_renamed() {
        let e = make_enum("status", vec!["in_progress"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("InProgress"));
        assert!(code.contains("sqlx(rename = \"in_progress\")"));
    }

    #[test]
    fn test_lowercase_variant_renamed() {
        let e = make_enum("status", vec!["active"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(code.contains("sqlx(rename = \"active\")"));
    }

    #[test]
    fn test_already_pascal_no_rename() {
        let e = make_enum("status", vec!["Active"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(!code.contains("sqlx(rename"));
    }

    #[test]
    fn test_upper_case_variant_renamed() {
        let e = make_enum("status", vec!["UPPER_CASE"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("UpperCase"));
        assert!(code.contains("sqlx(rename = \"UPPER_CASE\")"));
    }

    // --- derives ---

    #[test]
    fn test_default_derives() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("PartialEq"));
        assert!(code.contains("Serialize"));
        assert!(code.contains("Deserialize"));
        assert!(code.contains("SqlxGen"));
        assert!(code.contains("sqlx::Type") || code.contains("sqlx :: Type"));
    }

    #[test]
    fn test_extra_derive_serialize() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string()];
        let (code, _) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_extra_derives_serde_imports() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    // --- imports ---

    #[test]
    fn test_no_extra_derives_has_serde_import() {
        let e = make_enum("status", vec!["a"]);
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &[]);
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    #[test]
    fn test_serde_import_present() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string()];
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        // Checking non-emptiness alone would be vacuous: the generator always
        // inserts imports for its built-in derives.
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    #[test]
    fn test_sqlx_gen_import_present() {
        let e = make_enum("status", vec!["a"]);
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &[]);
        assert!(imports.contains("use sqlx_gen::SqlxGen;"));
    }

    // --- edge cases ---

    #[test]
    fn test_single_variant() {
        let e = make_enum("status", vec!["only"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Only"));
    }

    #[test]
    fn test_many_variants() {
        let variants: Vec<&str> = vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"];
        let e = make_enum("status", variants);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("A,"));
        assert!(code.contains("J,"));
    }

    #[test]
    fn test_variant_with_digits() {
        let e = make_enum("version", vec!["v2"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("V2"));
    }

    #[test]
    fn test_enum_name_with_double_underscores() {
        let e = make_enum("my__enum", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("pub enum MyEnum"));
    }
}
292}