1use std::collections::BTreeSet;
2
3use heck::ToUpperCamelCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::imports_for_derives;
9use crate::introspect::EnumInfo;
10
11pub fn generate_enum(
12 enum_info: &EnumInfo,
13 db_kind: DatabaseKind,
14 extra_derives: &[String],
15) -> (TokenStream, BTreeSet<String>) {
16 let mut imports = BTreeSet::new();
17 for imp in imports_for_derives(extra_derives) {
18 imports.insert(imp);
19 }
20
21 let enum_name = format_ident!("{}", enum_info.name.to_upper_camel_case());
22 let doc = format!("Enum: {}.{}", enum_info.schema_name, enum_info.name);
23
24 imports.insert("use serde::{Serialize, Deserialize};".to_string());
25 imports.insert("use sqlx_gen::SqlxGen;".to_string());
26 let mut derive_tokens = vec![
27 quote! { Debug },
28 quote! { Clone },
29 quote! { PartialEq },
30 quote! { Eq },
31 quote! { Serialize },
32 quote! { Deserialize },
33 quote! { sqlx::Type },
34 quote! { SqlxGen },
35 ];
36 for d in extra_derives {
37 let ident = format_ident!("{}", d);
38 derive_tokens.push(quote! { #ident });
39 }
40
41 let type_attr = if db_kind == DatabaseKind::Postgres {
43 let pg_name = if enum_info.schema_name != "public" {
44 format!("{}.{}", enum_info.schema_name, enum_info.name)
45 } else {
46 enum_info.name.clone()
47 };
48 quote! { #[sqlx(type_name = #pg_name)] }
49 } else {
50 quote! {}
51 };
52
53 let variants: Vec<TokenStream> = enum_info
54 .variants
55 .iter()
56 .map(|v| {
57 let variant_pascal = v.to_upper_camel_case();
58 let variant_ident = format_ident!("{}", variant_pascal);
59
60 let rename = if variant_pascal != *v {
61 quote! { #[sqlx(rename = #v)] }
62 } else {
63 quote! {}
64 };
65
66 quote! {
67 #rename
68 #variant_ident,
69 }
70 })
71 .collect();
72
73 let default_impl = if let Some(ref default_variant) = enum_info.default_variant {
74 let variant_pascal = default_variant.to_upper_camel_case();
75 let variant_ident = format_ident!("{}", variant_pascal);
76 quote! {
77 impl Default for #enum_name {
78 fn default() -> Self {
79 Self::#variant_ident
80 }
81 }
82 }
83 } else {
84 quote! {}
85 };
86
87 let schema_name_str = &enum_info.schema_name;
88 let enum_name_str = &enum_info.name;
89
90 let tokens = quote! {
91 #[doc = #doc]
92 #[derive(#(#derive_tokens),*)]
93 #[sqlx_gen(kind = "enum", schema = #schema_name_str, name = #enum_name_str)]
94 #type_attr
95 pub enum #enum_name {
96 #(#variants)*
97 }
98
99 #default_impl
100 };
101
102 (tokens, imports)
103}
104
#[cfg(test)]
mod tests {
    // Unit tests for `generate_enum`. Assertions match substrings of the source text
    // returned by `parse_and_format` for the generated tokens.
    use super::*;
    use crate::codegen::parse_and_format;

    /// Builds an `EnumInfo` in `schema` with the given labels and no default variant.
    fn make_enum_in_schema(schema: &str, name: &str, variants: Vec<&str>) -> EnumInfo {
        EnumInfo {
            schema_name: schema.to_string(),
            name: name.to_string(),
            variants: variants.into_iter().map(|s| s.to_string()).collect(),
            default_variant: None,
        }
    }

    /// Builds an `EnumInfo` in the `public` schema with no default variant.
    fn make_enum(name: &str, variants: Vec<&str>) -> EnumInfo {
        make_enum_in_schema("public", name, variants)
    }

    /// Like `make_enum_in_schema`, but with `default_variant` set to `default`.
    fn make_enum_with_default(
        schema: &str,
        name: &str,
        variants: Vec<&str>,
        default: &str,
    ) -> EnumInfo {
        EnumInfo {
            default_variant: Some(default.to_string()),
            ..make_enum_in_schema(schema, name, variants)
        }
    }

    /// Generates code with no extra derives and returns the formatted source.
    /// (Named `gen_code` because `gen` is a reserved keyword in the 2024 edition.)
    fn gen_code(info: &EnumInfo, db: DatabaseKind) -> String {
        let (tokens, _) = generate_enum(info, db, &[]);
        parse_and_format(&tokens)
    }

    /// Generates code with `derives` and returns the formatted source plus imports.
    fn gen_with_derives(
        info: &EnumInfo,
        db: DatabaseKind,
        derives: &[String],
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_enum(info, db, derives);
        (parse_and_format(&tokens), imports)
    }

    // --- naming and attributes ---

    #[test]
    fn test_enum_variants() {
        let e = make_enum("status", vec!["active", "inactive"]);
        let code = gen_code(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(code.contains("Inactive"));
    }

    #[test]
    fn test_enum_name_pascal_case() {
        let e = make_enum("user_status", vec!["a"]);
        let code = gen_code(&e, DatabaseKind::Postgres);
        assert!(code.contains("pub enum UserStatus"));
    }

    #[test]
    fn test_doc_comment() {
        let e = make_enum("status", vec!["a"]);
        let code = gen_code(&e, DatabaseKind::Postgres);
        assert!(code.contains("Enum: public.status"));
    }

    #[test]
    fn test_sqlx_gen_attr_has_schema_and_name() {
        let e = make_enum("status", vec!["a"]);
        let code = gen_code(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx_gen(kind = \"enum\", schema = \"public\", name = \"status\")"));
    }

    #[test]
    fn test_sqlx_gen_attr_non_public_schema() {
        let e = make_enum_in_schema("auth", "role", vec!["admin", "user"]);
        let code = gen_code(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx_gen(kind = \"enum\", schema = \"auth\", name = \"role\")"));
    }

    // --- sqlx type_name per backend ---

    #[test]
    fn test_postgres_has_type_name() {
        let e = make_enum("user_status", vec!["a"]);
        let code = gen_code(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"user_status\")"));
    }

    #[test]
    fn test_postgres_non_public_schema_qualified_type_name() {
        let e = make_enum_in_schema("auth", "role", vec!["admin", "user"]);
        let code = gen_code(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"auth.role\")"));
    }

    #[test]
    fn test_postgres_public_schema_not_qualified() {
        let e = make_enum("status", vec!["a"]);
        let code = gen_code(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"status\")"));
        assert!(!code.contains("type_name = \"public.status\""));
    }

    #[test]
    fn test_mysql_no_type_name() {
        let e = make_enum("status", vec!["a"]);
        let code = gen_code(&e, DatabaseKind::Mysql);
        assert!(!code.contains("type_name"));
    }

    #[test]
    fn test_sqlite_no_type_name() {
        let e = make_enum("status", vec!["a"]);
        let code = gen_code(&e, DatabaseKind::Sqlite);
        assert!(!code.contains("type_name"));
    }

    // --- variant renames ---

    #[test]
    fn test_snake_case_variant_renamed() {
        let e = make_enum("status", vec!["in_progress"]);
        let code = gen_code(&e, DatabaseKind::Postgres);
        assert!(code.contains("InProgress"));
        assert!(code.contains("sqlx(rename = \"in_progress\")"));
    }

    #[test]
    fn test_lowercase_variant_renamed() {
        let e = make_enum("status", vec!["active"]);
        let code = gen_code(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(code.contains("sqlx(rename = \"active\")"));
    }

    #[test]
    fn test_already_pascal_no_rename() {
        let e = make_enum("status", vec!["Active"]);
        let code = gen_code(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(!code.contains("sqlx(rename"));
    }

    #[test]
    fn test_upper_case_variant_renamed() {
        let e = make_enum("status", vec!["UPPER_CASE"]);
        let code = gen_code(&e, DatabaseKind::Postgres);
        assert!(code.contains("UpperCase"));
        assert!(code.contains("sqlx(rename = \"UPPER_CASE\")"));
    }

    // --- derives and imports ---

    #[test]
    fn test_default_derives() {
        let e = make_enum("status", vec!["a"]);
        let code = gen_code(&e, DatabaseKind::Postgres);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("PartialEq"));
        assert!(code.contains("Eq"));
        assert!(code.contains("Serialize"));
        assert!(code.contains("Deserialize"));
        assert!(code.contains("SqlxGen"));
        assert!(code.contains("sqlx::Type") || code.contains("sqlx :: Type"));
    }

    #[test]
    fn test_extra_derive_serialize() {
        // `Serialize` is also a built-in derive; it must still appear in the output.
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string()];
        let (code, _) = gen_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_extra_derive_not_in_defaults_is_appended() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Hash".to_string()];
        let (code, _) = gen_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(code.contains("Hash"));
    }

    #[test]
    fn test_extra_derives_serde_imports() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) = gen_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    #[test]
    fn test_no_extra_derives_has_serde_import() {
        let e = make_enum("status", vec!["a"]);
        let (_, imports) = gen_with_derives(&e, DatabaseKind::Postgres, &[]);
        assert!(imports.iter().any(|i| i.contains("serde")));
        assert!(imports.contains("use sqlx_gen::SqlxGen;"));
    }

    #[test]
    fn test_serde_import_present() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string()];
        let (_, imports) = gen_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(imports.contains("use serde::{Serialize, Deserialize};"));
    }

    // --- variant edge cases ---

    #[test]
    fn test_single_variant() {
        let e = make_enum("status", vec!["only"]);
        let code = gen_code(&e, DatabaseKind::Postgres);
        assert!(code.contains("Only"));
    }

    #[test]
    fn test_many_variants() {
        let variants: Vec<&str> = vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"];
        let e = make_enum("status", variants);
        let code = gen_code(&e, DatabaseKind::Postgres);
        assert!(code.contains("A,"));
        assert!(code.contains("J,"));
    }

    #[test]
    fn test_variant_with_digits() {
        let e = make_enum("version", vec!["v2"]);
        let code = gen_code(&e, DatabaseKind::Postgres);
        assert!(code.contains("V2"));
    }

    #[test]
    fn test_enum_name_with_double_underscores() {
        let e = make_enum("my__enum", vec!["a"]);
        let code = gen_code(&e, DatabaseKind::Postgres);
        assert!(code.contains("pub enum MyEnum"));
    }

    // --- Default impl ---

    #[test]
    fn test_default_impl_generated() {
        let e = make_enum_with_default(
            "public",
            "task_status",
            vec!["idle", "running", "done"],
            "idle",
        );
        let code = gen_code(&e, DatabaseKind::Postgres);
        assert!(code.contains("impl Default for TaskStatus"));
        assert!(code.contains("Self::Idle"));
    }

    #[test]
    fn test_no_default_impl_when_none() {
        let e = make_enum("status", vec!["active", "inactive"]);
        let code = gen_code(&e, DatabaseKind::Postgres);
        assert!(!code.contains("impl Default"));
    }

    #[test]
    fn test_default_impl_snake_case_variant() {
        let e = make_enum_with_default("public", "status", vec!["in_progress", "done"], "in_progress");
        let code = gen_code(&e, DatabaseKind::Postgres);
        assert!(code.contains("impl Default for Status"));
        assert!(code.contains("Self::InProgress"));
    }

    // --- end-to-end output per schema ---

    #[test]
    fn test_public_schema_full_output() {
        let e = make_enum_in_schema("public", "order_status", vec!["pending", "shipped", "delivered"]);
        let code = gen_code(&e, DatabaseKind::Postgres);

        assert!(code.contains("Enum: public.order_status"));
        assert!(code.contains("pub enum OrderStatus"));
        assert!(code.contains("sqlx(type_name = \"order_status\")"));
        assert!(!code.contains("sqlx(type_name = \"public.order_status\")"));
        assert!(code.contains("sqlx_gen(kind = \"enum\", schema = \"public\", name = \"order_status\")"));
        assert!(code.contains("Pending"));
        assert!(code.contains("Shipped"));
        assert!(code.contains("Delivered"));
    }

    #[test]
    fn test_named_schema_full_output() {
        let e = make_enum_in_schema("analysis", "toolcall_status", vec!["PENDING", "RUNNING", "DONE"]);
        let code = gen_code(&e, DatabaseKind::Postgres);

        assert!(code.contains("Enum: analysis.toolcall_status"));
        assert!(code.contains("pub enum ToolcallStatus"));
        assert!(code.contains("sqlx(type_name = \"analysis.toolcall_status\")"));
        assert!(!code.contains("sqlx(type_name = \"toolcall_status\")"));
        assert!(code.contains("sqlx_gen(kind = \"enum\", schema = \"analysis\", name = \"toolcall_status\")"));
        assert!(code.contains("Pending"));
        assert!(code.contains("Running"));
        assert!(code.contains("Done"));
    }

    #[test]
    fn test_named_schema_with_default_variant() {
        let e = make_enum_with_default(
            "billing",
            "payment_status",
            vec!["pending", "paid", "refunded"],
            "pending",
        );
        let code = gen_code(&e, DatabaseKind::Postgres);

        assert!(code.contains("sqlx(type_name = \"billing.payment_status\")"));
        assert!(code.contains("impl Default for PaymentStatus"));
        assert!(code.contains("Self::Pending"));
    }

    #[test]
    fn test_named_schema_variant_rename() {
        let e = make_enum_in_schema("audit", "log_level", vec!["info", "warn_high", "CRITICAL"]);
        let code = gen_code(&e, DatabaseKind::Postgres);

        assert!(code.contains("sqlx(type_name = \"audit.log_level\")"));
        assert!(code.contains("sqlx(rename = \"info\")"));
        assert!(code.contains("sqlx(rename = \"warn_high\")"));
        assert!(code.contains("WarnHigh"));
        assert!(code.contains("sqlx(rename = \"CRITICAL\")"));
        assert!(code.contains("Critical"));
    }

    #[test]
    fn test_named_schema_mysql_no_type_name() {
        let e = make_enum_in_schema("analytics", "event_type", vec!["click", "view"]);
        let code = gen_code(&e, DatabaseKind::Mysql);

        assert!(!code.contains("type_name"));
    }

    #[test]
    fn test_named_schema_sqlite_no_type_name() {
        let e = make_enum_in_schema("analytics", "event_type", vec!["click", "view"]);
        let code = gen_code(&e, DatabaseKind::Sqlite);

        assert!(!code.contains("type_name"));
    }
}
451}