1use std::collections::BTreeSet;
2
3use heck::ToUpperCamelCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::imports_for_derives;
9use crate::introspect::EnumInfo;
10
11pub fn generate_enum(
12 enum_info: &EnumInfo,
13 db_kind: DatabaseKind,
14 extra_derives: &[String],
15) -> (TokenStream, BTreeSet<String>) {
16 let mut imports = BTreeSet::new();
17 for imp in imports_for_derives(extra_derives) {
18 imports.insert(imp);
19 }
20 let enum_name = format_ident!("{}", enum_info.name.to_upper_camel_case());
21
22 let doc = format!("Enum: {}.{}", enum_info.schema_name, enum_info.name);
23
24 let mut derive_tokens = vec![
25 quote! { Debug },
26 quote! { Clone },
27 quote! { PartialEq },
28 quote! { sqlx::Type },
29 ];
30 for d in extra_derives {
31 let ident = format_ident!("{}", d);
32 derive_tokens.push(quote! { #ident });
33 }
34
35 let type_attr = if db_kind == DatabaseKind::Postgres {
37 let pg_name = &enum_info.name;
38 quote! { #[sqlx(type_name = #pg_name)] }
39 } else {
40 quote! {}
41 };
42
43 let variants: Vec<TokenStream> = enum_info
44 .variants
45 .iter()
46 .map(|v| {
47 let variant_pascal = v.to_upper_camel_case();
48 let variant_ident = format_ident!("{}", variant_pascal);
49
50 let rename = if variant_pascal != *v {
51 quote! { #[sqlx(rename = #v)] }
52 } else {
53 quote! {}
54 };
55
56 quote! {
57 #rename
58 #variant_ident,
59 }
60 })
61 .collect();
62
63 let tokens = quote! {
64 #[doc = #doc]
65 #[derive(#(#derive_tokens),*)]
66 #type_attr
67 pub enum #enum_name {
68 #(#variants)*
69 }
70 };
71
72 (tokens, imports)
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;

    /// Builds an `EnumInfo` in the `public` schema with the given labels.
    fn make_enum(name: &str, variants: &[&str]) -> EnumInfo {
        EnumInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            variants: variants.iter().map(|s| s.to_string()).collect(),
        }
    }

    /// Generates and pretty-prints the enum with no extra derives.
    ///
    /// Deliberately not called `gen`: `gen` is a reserved keyword as of Rust 2024.
    fn generate(info: &EnumInfo, db: DatabaseKind) -> String {
        gen_with_derives(info, db, &[]).0
    }

    /// Generates and pretty-prints the enum, also returning the collected imports.
    fn gen_with_derives(
        info: &EnumInfo,
        db: DatabaseKind,
        derives: &[String],
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_enum(info, db, derives);
        (parse_and_format(&tokens), imports)
    }

    // --- naming and docs ---

    #[test]
    fn test_enum_variants() {
        let e = make_enum("status", &["active", "inactive"]);
        let code = generate(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(code.contains("Inactive"));
    }

    #[test]
    fn test_enum_name_pascal_case() {
        let e = make_enum("user_status", &["a"]);
        let code = generate(&e, DatabaseKind::Postgres);
        assert!(code.contains("pub enum UserStatus"));
    }

    #[test]
    fn test_doc_comment() {
        let e = make_enum("status", &["a"]);
        let code = generate(&e, DatabaseKind::Postgres);
        assert!(code.contains("Enum: public.status"));
    }

    // --- backend-specific type attribute ---

    #[test]
    fn test_postgres_has_type_name() {
        let e = make_enum("user_status", &["a"]);
        let code = generate(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"user_status\")"));
    }

    #[test]
    fn test_mysql_no_type_name() {
        let e = make_enum("status", &["a"]);
        let code = generate(&e, DatabaseKind::Mysql);
        assert!(!code.contains("type_name"));
    }

    #[test]
    fn test_sqlite_no_type_name() {
        let e = make_enum("status", &["a"]);
        let code = generate(&e, DatabaseKind::Sqlite);
        assert!(!code.contains("type_name"));
    }

    // --- variant renames ---

    #[test]
    fn test_snake_case_variant_renamed() {
        let e = make_enum("status", &["in_progress"]);
        let code = generate(&e, DatabaseKind::Postgres);
        assert!(code.contains("InProgress"));
        assert!(code.contains("sqlx(rename = \"in_progress\")"));
    }

    #[test]
    fn test_lowercase_variant_renamed() {
        let e = make_enum("status", &["active"]);
        let code = generate(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(code.contains("sqlx(rename = \"active\")"));
    }

    #[test]
    fn test_already_pascal_no_rename() {
        let e = make_enum("status", &["Active"]);
        let code = generate(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(!code.contains("sqlx(rename"));
    }

    #[test]
    fn test_upper_case_variant_renamed() {
        let e = make_enum("status", &["UPPER_CASE"]);
        let code = generate(&e, DatabaseKind::Postgres);
        assert!(code.contains("UpperCase"));
        assert!(code.contains("sqlx(rename = \"UPPER_CASE\")"));
    }

    // --- derives and imports ---

    #[test]
    fn test_default_derives() {
        let e = make_enum("status", &["a"]);
        let code = generate(&e, DatabaseKind::Postgres);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("PartialEq"));
        assert!(code.contains("sqlx::Type") || code.contains("sqlx :: Type"));
    }

    #[test]
    fn test_extra_derive_serialize() {
        let e = make_enum("status", &["a"]);
        let derives = ["Serialize".to_string()];
        let (code, _) = gen_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_extra_derives_serde_imports() {
        let e = make_enum("status", &["a"]);
        let derives = ["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) = gen_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    #[test]
    fn test_no_derives_empty_imports() {
        let e = make_enum("status", &["a"]);
        let (_, imports) = gen_with_derives(&e, DatabaseKind::Postgres, &[]);
        assert!(imports.is_empty());
    }

    #[test]
    fn test_serde_import_present() {
        let e = make_enum("status", &["a"]);
        let derives = ["Serialize".to_string()];
        let (_, imports) = gen_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(!imports.is_empty());
    }

    // --- shapes and edge cases ---

    #[test]
    fn test_single_variant() {
        let e = make_enum("status", &["only"]);
        let code = generate(&e, DatabaseKind::Postgres);
        assert!(code.contains("Only"));
    }

    #[test]
    fn test_many_variants() {
        let e = make_enum("status", &["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"]);
        let code = generate(&e, DatabaseKind::Postgres);
        assert!(code.contains("A,"));
        assert!(code.contains("J,"));
    }

    #[test]
    fn test_variant_with_digits() {
        let e = make_enum("version", &["v2"]);
        let code = generate(&e, DatabaseKind::Postgres);
        assert!(code.contains("V2"));
    }

    #[test]
    fn test_enum_name_with_double_underscores() {
        let e = make_enum("my__enum", &["a"]);
        let code = generate(&e, DatabaseKind::Postgres);
        assert!(code.contains("pub enum MyEnum"));
    }
}
255}