1use std::collections::BTreeSet;
2
3use heck::ToUpperCamelCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::imports_for_derives;
9use crate::introspect::EnumInfo;
10
11pub fn generate_enum(
12 enum_info: &EnumInfo,
13 db_kind: DatabaseKind,
14 extra_derives: &[String],
15) -> (TokenStream, BTreeSet<String>) {
16 let mut imports = BTreeSet::new();
17 for imp in imports_for_derives(extra_derives) {
18 imports.insert(imp);
19 }
20
21 let enum_name = format_ident!("{}", enum_info.name.to_upper_camel_case());
22 let doc = format!("Enum: {}.{}", enum_info.schema_name, enum_info.name);
23
24 imports.insert("use serde::{Serialize, Deserialize};".to_string());
25 imports.insert("use sqlx_gen::SqlxGen;".to_string());
26 let mut derive_tokens = vec![
27 quote! { Debug },
28 quote! { Clone },
29 quote! { PartialEq },
30 quote! { Eq },
31 quote! { Serialize },
32 quote! { Deserialize },
33 quote! { sqlx::Type },
34 quote! { SqlxGen },
35 ];
36 for d in extra_derives {
37 let ident = format_ident!("{}", d);
38 derive_tokens.push(quote! { #ident });
39 }
40
41 let type_attr = if db_kind == DatabaseKind::Postgres {
43 let pg_name = &enum_info.name;
44 quote! { #[sqlx(type_name = #pg_name)] }
45 } else {
46 quote! {}
47 };
48
49 let variants: Vec<TokenStream> = enum_info
50 .variants
51 .iter()
52 .map(|v| {
53 let variant_pascal = v.to_upper_camel_case();
54 let variant_ident = format_ident!("{}", variant_pascal);
55
56 let rename = if variant_pascal != *v {
57 quote! { #[sqlx(rename = #v)] }
58 } else {
59 quote! {}
60 };
61
62 quote! {
63 #rename
64 #variant_ident,
65 }
66 })
67 .collect();
68
69 let default_impl = if let Some(ref default_variant) = enum_info.default_variant {
70 let variant_pascal = default_variant.to_upper_camel_case();
71 let variant_ident = format_ident!("{}", variant_pascal);
72 quote! {
73 impl Default for #enum_name {
74 fn default() -> Self {
75 Self::#variant_ident
76 }
77 }
78 }
79 } else {
80 quote! {}
81 };
82
83 let schema_name_str = &enum_info.schema_name;
84 let enum_name_str = &enum_info.name;
85
86 let tokens = quote! {
87 #[doc = #doc]
88 #[derive(#(#derive_tokens),*)]
89 #[sqlx_gen(kind = "enum", schema = #schema_name_str, name = #enum_name_str)]
90 #type_attr
91 pub enum #enum_name {
92 #(#variants)*
93 }
94
95 #default_impl
96 };
97
98 (tokens, imports)
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;

    /// Builds an `EnumInfo` in `schema` with the given variants and no default.
    fn make_enum_in(schema: &str, name: &str, variants: Vec<&str>) -> EnumInfo {
        EnumInfo {
            schema_name: schema.to_string(),
            name: name.to_string(),
            variants: variants.into_iter().map(|s| s.to_string()).collect(),
            default_variant: None,
        }
    }

    /// Builds an `EnumInfo` in the `public` schema with no default variant.
    fn make_enum(name: &str, variants: Vec<&str>) -> EnumInfo {
        make_enum_in("public", name, variants)
    }

    /// Returns `info` with its default variant set to `default`.
    fn with_default(mut info: EnumInfo, default: &str) -> EnumInfo {
        info.default_variant = Some(default.to_string());
        info
    }

    /// Generates and pretty-prints an enum with no extra derives.
    ///
    /// Named `render` rather than `gen`: `gen` is a reserved keyword from the
    /// Rust 2024 edition onward.
    fn render(info: &EnumInfo, db: DatabaseKind) -> String {
        let (tokens, _) = generate_enum(info, db, &[]);
        parse_and_format(&tokens)
    }

    /// Generates an enum with `derives`, returning the formatted code and imports.
    fn render_with_derives(
        info: &EnumInfo,
        db: DatabaseKind,
        derives: &[String],
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_enum(info, db, derives);
        (parse_and_format(&tokens), imports)
    }

    // --- naming and doc comment ---

    #[test]
    fn test_enum_variants() {
        let e = make_enum("status", vec!["active", "inactive"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(code.contains("Inactive"));
    }

    #[test]
    fn test_enum_name_pascal_case() {
        let e = make_enum("user_status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("pub enum UserStatus"));
    }

    #[test]
    fn test_doc_comment() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Enum: public.status"));
    }

    // --- sqlx_gen attribute ---

    #[test]
    fn test_sqlx_gen_attr_has_schema_and_name() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx_gen(kind = \"enum\", schema = \"public\", name = \"status\")"));
    }

    #[test]
    fn test_sqlx_gen_attr_non_public_schema() {
        let e = make_enum_in("auth", "role", vec!["admin", "user"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx_gen(kind = \"enum\", schema = \"auth\", name = \"role\")"));
    }

    // --- sqlx type_name per database ---

    #[test]
    fn test_postgres_has_type_name() {
        let e = make_enum("user_status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"user_status\")"));
    }

    #[test]
    fn test_postgres_non_public_schema_type_name_unqualified() {
        // The type name stays bare even outside `public`; the schema is only
        // carried by the `sqlx_gen` attribute.
        let e = make_enum_in("auth", "role", vec!["admin", "user"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"role\")"));
        assert!(!code.contains("type_name = \"auth.role\""));
    }

    #[test]
    fn test_postgres_public_schema_not_qualified() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"status\")"));
        assert!(!code.contains("type_name = \"public.status\""));
    }

    #[test]
    fn test_mysql_no_type_name() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Mysql);
        assert!(!code.contains("type_name"));
    }

    #[test]
    fn test_sqlite_no_type_name() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Sqlite);
        assert!(!code.contains("type_name"));
    }

    // --- variant renaming ---

    #[test]
    fn test_snake_case_variant_renamed() {
        let e = make_enum("status", vec!["in_progress"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("InProgress"));
        assert!(code.contains("sqlx(rename = \"in_progress\")"));
    }

    #[test]
    fn test_lowercase_variant_renamed() {
        let e = make_enum("status", vec!["active"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(code.contains("sqlx(rename = \"active\")"));
    }

    #[test]
    fn test_already_pascal_no_rename() {
        let e = make_enum("status", vec!["Active"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(!code.contains("sqlx(rename"));
    }

    #[test]
    fn test_upper_case_variant_renamed() {
        let e = make_enum("status", vec!["UPPER_CASE"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("UpperCase"));
        assert!(code.contains("sqlx(rename = \"UPPER_CASE\")"));
    }

    // --- derives and imports ---

    #[test]
    fn test_default_derives() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("PartialEq"));
        assert!(code.contains("sqlx::Type") || code.contains("sqlx :: Type"));
    }

    #[test]
    fn test_extra_derive_serialize() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string()];
        let (code, _) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_extra_derives_serde_imports() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    #[test]
    fn test_no_extra_derives_has_serde_import() {
        let e = make_enum("status", vec!["a"]);
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &[]);
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    #[test]
    fn test_serde_import_present() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string()];
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(!imports.is_empty());
    }

    // --- variant counts and unusual names ---

    #[test]
    fn test_single_variant() {
        let e = make_enum("status", vec!["only"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Only"));
    }

    #[test]
    fn test_many_variants() {
        let variants: Vec<&str> = vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"];
        let e = make_enum("status", variants);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("A,"));
        assert!(code.contains("J,"));
    }

    #[test]
    fn test_variant_with_digits() {
        let e = make_enum("version", vec!["v2"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("V2"));
    }

    #[test]
    fn test_enum_name_with_double_underscores() {
        let e = make_enum("my__enum", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("pub enum MyEnum"));
    }

    // --- Default impl ---

    #[test]
    fn test_default_impl_generated() {
        let e = with_default(
            make_enum("task_status", vec!["idle", "running", "done"]),
            "idle",
        );
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("impl Default for TaskStatus"));
        assert!(code.contains("Self::Idle"));
    }

    #[test]
    fn test_no_default_impl_when_none() {
        let e = make_enum("status", vec!["active", "inactive"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(!code.contains("impl Default"));
    }

    #[test]
    fn test_default_impl_snake_case_variant() {
        let e = with_default(make_enum("status", vec!["in_progress", "done"]), "in_progress");
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("impl Default for Status"));
        assert!(code.contains("Self::InProgress"));
    }
}
362}