1use std::collections::BTreeSet;
2
3use heck::ToUpperCamelCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::imports_for_derives;
9use crate::introspect::EnumInfo;
10
11pub fn generate_enum(
12 enum_info: &EnumInfo,
13 db_kind: DatabaseKind,
14 extra_derives: &[String],
15) -> (TokenStream, BTreeSet<String>) {
16 let mut imports = BTreeSet::new();
17 for imp in imports_for_derives(extra_derives) {
18 imports.insert(imp);
19 }
20
21 let enum_name = format_ident!("{}", enum_info.name.to_upper_camel_case());
22 let doc = format!("Enum: {}.{}", enum_info.schema_name, enum_info.name);
23
24 imports.insert("use serde::{Serialize, Deserialize};".to_string());
25 imports.insert("use sqlx_gen::SqlxGen;".to_string());
26 let mut derive_tokens = vec![
27 quote! { Debug },
28 quote! { Clone },
29 quote! { PartialEq },
30 quote! { Eq },
31 quote! { Serialize },
32 quote! { Deserialize },
33 quote! { sqlx::Type },
34 quote! { SqlxGen },
35 ];
36 for d in extra_derives {
37 let ident = format_ident!("{}", d);
38 derive_tokens.push(quote! { #ident });
39 }
40
41 let type_attr = if db_kind == DatabaseKind::Postgres {
44 let pg_name = if enum_info.schema_name != "public" {
45 format!("{}.{}", enum_info.schema_name, enum_info.name)
46 } else {
47 enum_info.name.clone()
48 };
49 quote! { #[sqlx(type_name = #pg_name)] }
50 } else {
51 quote! {}
52 };
53
54 let variants: Vec<TokenStream> = enum_info
55 .variants
56 .iter()
57 .map(|v| {
58 let variant_pascal = v.to_upper_camel_case();
59 let variant_ident = format_ident!("{}", variant_pascal);
60
61 let rename = if variant_pascal != *v {
62 quote! { #[sqlx(rename = #v)] }
63 } else {
64 quote! {}
65 };
66
67 quote! {
68 #rename
69 #variant_ident,
70 }
71 })
72 .collect();
73
74 let default_impl = if let Some(ref default_variant) = enum_info.default_variant {
75 let variant_pascal = default_variant.to_upper_camel_case();
76 let variant_ident = format_ident!("{}", variant_pascal);
77 quote! {
78 impl Default for #enum_name {
79 fn default() -> Self {
80 Self::#variant_ident
81 }
82 }
83 }
84 } else {
85 quote! {}
86 };
87
88 let tokens = quote! {
89 #[doc = #doc]
90 #[derive(#(#derive_tokens),*)]
91 #[sqlx_gen(kind = "enum")]
92 #type_attr
93 pub enum #enum_name {
94 #(#variants)*
95 }
96
97 #default_impl
98 };
99
100 (tokens, imports)
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;

    /// Builds a `public`-schema enum with the given labels and no default variant.
    fn make_enum(name: &str, variants: Vec<&str>) -> EnumInfo {
        EnumInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            variants: variants.into_iter().map(|s| s.to_string()).collect(),
            default_variant: None,
        }
    }

    /// Generates `info` without extra derives and returns the formatted code.
    ///
    /// Named `render` rather than `gen` because `gen` is a reserved keyword in
    /// the Rust 2024 edition.
    fn render(info: &EnumInfo, db: DatabaseKind) -> String {
        let (tokens, _) = generate_enum(info, db, &[]);
        parse_and_format(&tokens)
    }

    /// Generates `info` with `derives` and returns the formatted code and imports.
    fn render_with_derives(
        info: &EnumInfo,
        db: DatabaseKind,
        derives: &[String],
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_enum(info, db, derives);
        (parse_and_format(&tokens), imports)
    }

    // ---- enum and variant naming ----

    #[test]
    fn test_enum_variants() {
        let e = make_enum("status", vec!["active", "inactive"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(code.contains("Inactive"));
    }

    #[test]
    fn test_enum_name_pascal_case() {
        let e = make_enum("user_status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("pub enum UserStatus"));
    }

    #[test]
    fn test_doc_comment() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Enum: public.status"));
    }

    // ---- database-specific type_name ----

    #[test]
    fn test_postgres_has_type_name() {
        let e = make_enum("user_status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"user_status\")"));
    }

    #[test]
    fn test_postgres_non_public_schema_qualified_type_name() {
        let e = EnumInfo {
            schema_name: "auth".to_string(),
            name: "role".to_string(),
            variants: vec!["admin".to_string(), "user".to_string()],
            default_variant: None,
        };
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"auth.role\")"));
    }

    #[test]
    fn test_postgres_public_schema_not_qualified() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"status\")"));
        assert!(!code.contains("type_name = \"public.status\""));
    }

    #[test]
    fn test_mysql_no_type_name() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Mysql);
        assert!(!code.contains("type_name"));
    }

    #[test]
    fn test_sqlite_no_type_name() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Sqlite);
        assert!(!code.contains("type_name"));
    }

    // ---- variant renames ----

    #[test]
    fn test_snake_case_variant_renamed() {
        let e = make_enum("status", vec!["in_progress"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("InProgress"));
        assert!(code.contains("sqlx(rename = \"in_progress\")"));
    }

    #[test]
    fn test_lowercase_variant_renamed() {
        let e = make_enum("status", vec!["active"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(code.contains("sqlx(rename = \"active\")"));
    }

    #[test]
    fn test_already_pascal_no_rename() {
        let e = make_enum("status", vec!["Active"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(!code.contains("sqlx(rename"));
    }

    #[test]
    fn test_upper_case_variant_renamed() {
        let e = make_enum("status", vec!["UPPER_CASE"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("UpperCase"));
        assert!(code.contains("sqlx(rename = \"UPPER_CASE\")"));
    }

    // ---- derives and imports ----

    #[test]
    fn test_default_derives() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("PartialEq"));
        assert!(code.contains("Serialize"));
        assert!(code.contains("Deserialize"));
        assert!(code.contains("SqlxGen"));
        assert!(code.contains("sqlx::Type") || code.contains("sqlx :: Type"));
    }

    #[test]
    fn test_extra_derive_serialize() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string()];
        let (code, _) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_extra_derive_appended() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Hash".to_string()];
        let (code, _) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(code.contains("Hash"));
    }

    #[test]
    fn test_extra_derives_serde_imports() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    #[test]
    fn test_no_extra_derives_has_serde_import() {
        let e = make_enum("status", vec!["a"]);
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &[]);
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    #[test]
    fn test_serde_import_present() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string()];
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(imports.contains("use serde::{Serialize, Deserialize};"));
    }

    #[test]
    fn test_sqlx_gen_import_always_present() {
        let e = make_enum("status", vec!["a"]);
        let (_, imports) = render_with_derives(&e, DatabaseKind::Sqlite, &[]);
        assert!(imports.contains("use sqlx_gen::SqlxGen;"));
    }

    // ---- variant counts and shapes ----

    #[test]
    fn test_single_variant() {
        let e = make_enum("status", vec!["only"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Only"));
    }

    #[test]
    fn test_many_variants() {
        let variants: Vec<&str> = vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"];
        let e = make_enum("status", variants);
        let code = render(&e, DatabaseKind::Postgres);
        for letter in 'A'..='J' {
            let expected = format!("{},", letter);
            assert!(code.contains(&expected), "missing variant {}", letter);
        }
    }

    #[test]
    fn test_variant_with_digits() {
        let e = make_enum("version", vec!["v2"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("V2"));
    }

    #[test]
    fn test_enum_name_with_double_underscores() {
        let e = make_enum("my__enum", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("pub enum MyEnum"));
    }

    // ---- Default impl ----

    #[test]
    fn test_default_impl_generated() {
        let e = EnumInfo {
            schema_name: "public".to_string(),
            name: "task_status".to_string(),
            variants: vec!["idle".to_string(), "running".to_string(), "done".to_string()],
            default_variant: Some("idle".to_string()),
        };
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("impl Default for TaskStatus"));
        assert!(code.contains("Self::Idle"));
    }

    #[test]
    fn test_no_default_impl_when_none() {
        let e = make_enum("status", vec!["active", "inactive"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(!code.contains("impl Default"));
    }

    #[test]
    fn test_default_impl_snake_case_variant() {
        let e = EnumInfo {
            schema_name: "public".to_string(),
            name: "status".to_string(),
            variants: vec!["in_progress".to_string(), "done".to_string()],
            default_variant: Some("in_progress".to_string()),
        };
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("impl Default for Status"));
        assert!(code.contains("Self::InProgress"));
    }
}
345}