Skip to main content

sqlx_gen/codegen/
enum_gen.rs

1use std::collections::BTreeSet;
2
3use heck::ToUpperCamelCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::imports_for_derives;
9use crate::introspect::EnumInfo;
10
11pub fn generate_enum(
12    enum_info: &EnumInfo,
13    db_kind: DatabaseKind,
14    extra_derives: &[String],
15) -> (TokenStream, BTreeSet<String>) {
16    let mut imports = BTreeSet::new();
17    for imp in imports_for_derives(extra_derives) {
18        imports.insert(imp);
19    }
20    let enum_name = format_ident!("{}", enum_info.name.to_upper_camel_case());
21
22    let doc = format!("Enum: {}.{}", enum_info.schema_name, enum_info.name);
23
24    let mut derive_tokens = vec![
25        quote! { Debug },
26        quote! { Clone },
27        quote! { PartialEq },
28        quote! { sqlx::Type },
29    ];
30    for d in extra_derives {
31        let ident = format_ident!("{}", d);
32        derive_tokens.push(quote! { #ident });
33    }
34
35    // For PG, add #[sqlx(type_name = "...")]
36    let type_attr = if db_kind == DatabaseKind::Postgres {
37        let pg_name = &enum_info.name;
38        quote! { #[sqlx(type_name = #pg_name)] }
39    } else {
40        quote! {}
41    };
42
43    let variants: Vec<TokenStream> = enum_info
44        .variants
45        .iter()
46        .map(|v| {
47            let variant_pascal = v.to_upper_camel_case();
48            let variant_ident = format_ident!("{}", variant_pascal);
49
50            let rename = if variant_pascal != *v {
51                quote! { #[sqlx(rename = #v)] }
52            } else {
53                quote! {}
54            };
55
56            quote! {
57                #rename
58                #variant_ident,
59            }
60        })
61        .collect();
62
63    let tokens = quote! {
64        #[doc = #doc]
65        #[derive(#(#derive_tokens),*)]
66        #type_attr
67        pub enum #enum_name {
68            #(#variants)*
69        }
70    };
71
72    (tokens, imports)
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;

    /// Builds an `EnumInfo` in the `public` schema with the given labels.
    fn make_enum(name: &str, labels: &[&str]) -> EnumInfo {
        EnumInfo {
            schema_name: String::from("public"),
            name: String::from(name),
            variants: labels.iter().map(|label| label.to_string()).collect(),
        }
    }

    /// Converts string slices into the owned form `generate_enum` expects for derives.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Generates the enum with extra derives and returns the formatted source and imports.
    fn gen_with_derives(
        info: &EnumInfo,
        db: DatabaseKind,
        derives: &[String],
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_enum(info, db, derives);
        (parse_and_format(&tokens), imports)
    }

    /// Generates the enum without extra derives and returns only the formatted source.
    fn render(info: &EnumInfo, db: DatabaseKind) -> String {
        let (tokens, _imports) = generate_enum(info, db, &[]);
        parse_and_format(&tokens)
    }

    // --- basic structure ---

    #[test]
    fn test_enum_variants() {
        let code = render(&make_enum("status", &["active", "inactive"]), DatabaseKind::Postgres);
        for expected in ["Active", "Inactive"] {
            assert!(code.contains(expected));
        }
    }

    #[test]
    fn test_enum_name_pascal_case() {
        let code = render(&make_enum("user_status", &["a"]), DatabaseKind::Postgres);
        assert!(code.contains("pub enum UserStatus"));
    }

    #[test]
    fn test_doc_comment() {
        let code = render(&make_enum("status", &["a"]), DatabaseKind::Postgres);
        assert!(code.contains("Enum: public.status"));
    }

    // --- sqlx attributes ---

    #[test]
    fn test_postgres_has_type_name() {
        let code = render(&make_enum("user_status", &["a"]), DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"user_status\")"));
    }

    #[test]
    fn test_mysql_no_type_name() {
        let code = render(&make_enum("status", &["a"]), DatabaseKind::Mysql);
        assert!(!code.contains("type_name"));
    }

    #[test]
    fn test_sqlite_no_type_name() {
        let code = render(&make_enum("status", &["a"]), DatabaseKind::Sqlite);
        assert!(!code.contains("type_name"));
    }

    // --- rename variants ---

    #[test]
    fn test_snake_case_variant_renamed() {
        let code = render(&make_enum("status", &["in_progress"]), DatabaseKind::Postgres);
        assert!(code.contains("InProgress"));
        assert!(code.contains("sqlx(rename = \"in_progress\")"));
    }

    #[test]
    fn test_lowercase_variant_renamed() {
        let code = render(&make_enum("status", &["active"]), DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(code.contains("sqlx(rename = \"active\")"));
    }

    #[test]
    fn test_already_pascal_no_rename() {
        let code = render(&make_enum("status", &["Active"]), DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(!code.contains("sqlx(rename"));
    }

    #[test]
    fn test_upper_case_variant_renamed() {
        let code = render(&make_enum("status", &["UPPER_CASE"]), DatabaseKind::Postgres);
        assert!(code.contains("UpperCase"));
        assert!(code.contains("sqlx(rename = \"UPPER_CASE\")"));
    }

    // --- derives ---

    #[test]
    fn test_default_derives() {
        let code = render(&make_enum("status", &["a"]), DatabaseKind::Postgres);
        for derive in ["Debug", "Clone", "PartialEq"] {
            assert!(code.contains(derive));
        }
        let has_sqlx_type = code.contains("sqlx::Type") || code.contains("sqlx :: Type");
        assert!(has_sqlx_type);
    }

    #[test]
    fn test_extra_derive_serialize() {
        let info = make_enum("status", &["a"]);
        let (code, _) = gen_with_derives(&info, DatabaseKind::Postgres, &owned(&["Serialize"]));
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_extra_derives_serde_imports() {
        let info = make_enum("status", &["a"]);
        let derives = owned(&["Serialize", "Deserialize"]);
        let (_, imports) = gen_with_derives(&info, DatabaseKind::Postgres, &derives);
        assert!(imports.iter().any(|line| line.contains("serde")));
    }

    // --- imports ---

    #[test]
    fn test_no_derives_empty_imports() {
        let info = make_enum("status", &["a"]);
        let (_, imports) = gen_with_derives(&info, DatabaseKind::Postgres, &[]);
        assert!(imports.is_empty());
    }

    #[test]
    fn test_serde_import_present() {
        let info = make_enum("status", &["a"]);
        let (_, imports) = gen_with_derives(&info, DatabaseKind::Postgres, &owned(&["Serialize"]));
        assert!(!imports.is_empty());
    }

    // --- edge cases ---

    #[test]
    fn test_single_variant() {
        let code = render(&make_enum("status", &["only"]), DatabaseKind::Postgres);
        assert!(code.contains("Only"));
    }

    #[test]
    fn test_many_variants() {
        let labels = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"];
        let code = render(&make_enum("status", &labels), DatabaseKind::Postgres);
        assert!(code.contains("A,"));
        assert!(code.contains("J,"));
    }

    #[test]
    fn test_variant_with_digits() {
        let code = render(&make_enum("version", &["v2"]), DatabaseKind::Postgres);
        assert!(code.contains("V2"));
    }

    #[test]
    fn test_enum_name_with_double_underscores() {
        let code = render(&make_enum("my__enum", &["a"]), DatabaseKind::Postgres);
        assert!(code.contains("pub enum MyEnum"));
    }
}