1use std::collections::BTreeSet;
2
3use heck::ToUpperCamelCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::{imports_for_derives, rust_type_name_for};
9use crate::introspect::{EnumInfo, SchemaInfo};
10
11pub fn check_variant_collisions(enum_info: &EnumInfo) -> crate::error::Result<()> {
16 use std::collections::BTreeMap;
17 let mut seen: BTreeMap<String, &str> = BTreeMap::new();
18 for v in &enum_info.variants {
19 let pascal = v.to_upper_camel_case();
20 if let Some(prev) = seen.get(pascal.as_str()).copied() {
21 return Err(crate::error::Error::Config(format!(
22 "Enum '{}.{}': SQL variants '{}' and '{}' both map to Rust identifier '{}'. \
23 Rename one of them in the database or use a custom mapping.",
24 enum_info.schema_name, enum_info.name, prev, v, pascal
25 )));
26 }
27 seen.insert(pascal, v.as_str());
28 }
29 Ok(())
30}
31
32pub fn generate_enum(
33 enum_info: &EnumInfo,
34 db_kind: DatabaseKind,
35 extra_derives: &[String],
36) -> (TokenStream, BTreeSet<String>) {
37 generate_enum_with_schema(enum_info, db_kind, extra_derives, &SchemaInfo::default())
40}
41
42pub fn generate_enum_with_schema(
43 enum_info: &EnumInfo,
44 db_kind: DatabaseKind,
45 extra_derives: &[String],
46 schema_info: &SchemaInfo,
47) -> (TokenStream, BTreeSet<String>) {
48 let mut imports = BTreeSet::new();
49 for imp in imports_for_derives(extra_derives) {
50 imports.insert(imp);
51 }
52
53 let rust_name = rust_type_name_for(schema_info, &enum_info.schema_name, &enum_info.name);
54 let enum_name = format_ident!("{}", rust_name);
55 let doc = format!("Enum: {}.{}", enum_info.schema_name, enum_info.name);
56 let search_path_doc = if db_kind == DatabaseKind::Postgres
61 && !crate::codegen::is_default_schema(&enum_info.schema_name)
62 {
63 let msg = format!(
64 "Lives in PostgreSQL schema `{schema}`. The sqlx connection \
65 must include `{schema}` in its search_path so PG resolves the \
66 unqualified `type_name = \"{name}\"` to this enum. Example:\n\
67 \n\
68 ```ignore\n\
69 sqlx::query(\"SET search_path TO public, {schema}\")\n\
70 ```",
71 schema = enum_info.schema_name,
72 name = enum_info.name,
73 );
74 Some(msg)
75 } else {
76 None
77 };
78
79 imports.insert("use serde::{Serialize, Deserialize};".to_string());
80 imports.insert("use sqlx_gen::SqlxGen;".to_string());
81 let mut derive_tokens = vec![
82 quote! { Debug },
83 quote! { Clone },
84 quote! { PartialEq },
85 quote! { Eq },
86 quote! { Serialize },
87 quote! { Deserialize },
88 quote! { sqlx::Type },
89 quote! { SqlxGen },
90 ];
91 for d in extra_derives {
92 let ident = format_ident!("{}", d);
93 derive_tokens.push(quote! { #ident });
94 }
95
96 let type_attr = if db_kind == DatabaseKind::Postgres {
101 let pg_name = &enum_info.name;
102 quote! { #[sqlx(type_name = #pg_name)] }
103 } else {
104 quote! {}
105 };
106
107 let variants: Vec<TokenStream> = enum_info
108 .variants
109 .iter()
110 .map(|v| {
111 let variant_pascal = v.to_upper_camel_case();
112 let variant_ident = format_ident!("{}", variant_pascal);
113
114 let rename = if variant_pascal != *v {
115 quote! { #[sqlx(rename = #v)] }
116 } else {
117 quote! {}
118 };
119
120 quote! {
121 #rename
122 #variant_ident,
123 }
124 })
125 .collect();
126
127 let default_impl = if let Some(ref default_variant) = enum_info.default_variant {
128 let variant_pascal = default_variant.to_upper_camel_case();
129 let variant_ident = format_ident!("{}", variant_pascal);
130 quote! {
131 impl Default for #enum_name {
132 fn default() -> Self {
133 Self::#variant_ident
134 }
135 }
136 }
137 } else {
138 quote! {}
139 };
140
141 let _ = db_kind;
146
147 let schema_name_str = &enum_info.schema_name;
148 let enum_name_str = &enum_info.name;
149 let search_path_doc_tokens = match &search_path_doc {
150 Some(m) => quote! { #[doc = #m] },
151 None => quote! {},
152 };
153
154 let tokens = quote! {
155 #[doc = #doc]
156 #search_path_doc_tokens
157 #[derive(#(#derive_tokens),*)]
158 #[sqlx_gen(kind = "enum", schema = #schema_name_str, name = #enum_name_str)]
159 #type_attr
160 pub enum #enum_name {
161 #(#variants)*
162 }
163
164 #default_impl
165 };
166
167 (tokens, imports)
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;

    /// Builds an `EnumInfo` in `schema` with no default variant.
    fn make_enum_in_schema(schema: &str, name: &str, variants: Vec<&str>) -> EnumInfo {
        EnumInfo {
            schema_name: schema.to_string(),
            name: name.to_string(),
            variants: variants.into_iter().map(|s| s.to_string()).collect(),
            default_variant: None,
        }
    }

    /// Builds an `EnumInfo` in the `public` schema with no default variant.
    fn make_enum(name: &str, variants: Vec<&str>) -> EnumInfo {
        make_enum_in_schema("public", name, variants)
    }

    /// Generates the enum without extra derives and pretty-prints it.
    // Named `render` rather than `gen`: `gen` is a reserved keyword in Rust 2024.
    fn render(info: &EnumInfo, db: DatabaseKind) -> String {
        let (tokens, _) = generate_enum(info, db, &[]);
        parse_and_format(&tokens).unwrap()
    }

    /// Generates the enum with `derives`, returning the formatted code and the
    /// collected imports.
    fn render_with_derives(
        info: &EnumInfo,
        db: DatabaseKind,
        derives: &[String],
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_enum(info, db, derives);
        (parse_and_format(&tokens).unwrap(), imports)
    }

    #[test]
    fn test_enum_variants() {
        let e = make_enum("status", vec!["active", "inactive"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(code.contains("Inactive"));
    }

    #[test]
    fn test_enum_name_pascal_case() {
        let e = make_enum("user_status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("pub enum UserStatus"));
    }

    #[test]
    fn test_doc_comment() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Enum: public.status"));
    }

    #[test]
    fn test_sqlx_gen_attr_has_schema_and_name() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx_gen(kind = \"enum\", schema = \"public\", name = \"status\")"));
    }

    #[test]
    fn test_sqlx_gen_attr_non_public_schema() {
        let e = make_enum_in_schema("auth", "role", vec!["admin", "user"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx_gen(kind = \"enum\", schema = \"auth\", name = \"role\")"));
    }

    #[test]
    fn test_postgres_has_type_name() {
        let e = make_enum("user_status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"user_status\")"));
    }

    #[test]
    fn test_check_variant_collisions_detects_after_camel_case() {
        let e = make_enum("weird", vec!["foo bar", "foo_bar"]);
        let result = check_variant_collisions(&e);
        assert!(result.is_err(), "must detect collision");
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("FooBar"),
            "error must mention conflicting Rust ident, got: {}",
            msg
        );
        assert!(msg.contains("foo bar") || msg.contains("foo_bar"));
    }

    #[test]
    fn test_check_variant_collisions_detects_exact_duplicates() {
        let e = make_enum("status", vec!["a", "a"]);
        assert!(check_variant_collisions(&e).is_err());
    }

    #[test]
    fn test_check_variant_collisions_accepts_distinct_variants() {
        let e = make_enum("status", vec!["active", "inactive"]);
        assert!(check_variant_collisions(&e).is_ok());
    }

    #[test]
    fn test_check_variant_collisions_accepts_single_variant() {
        let e = make_enum("status", vec!["only"]);
        assert!(check_variant_collisions(&e).is_ok());
    }

    #[test]
    fn test_does_not_emit_manual_pg_has_array_type_impl() {
        for db in [
            DatabaseKind::Postgres,
            DatabaseKind::Mysql,
            DatabaseKind::Sqlite,
        ] {
            let e = make_enum("status", vec!["a", "b"]);
            let code = render(&e, db);
            assert!(
                !code.contains("PgHasArrayType"),
                "{:?}: must not emit a manual PgHasArrayType impl, got:\n{}",
                db,
                code
            );
        }
    }

    #[test]
    fn test_postgres_non_public_schema_type_name_is_unqualified() {
        let e = make_enum_in_schema("auth", "role", vec!["admin", "user"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(
            code.contains("sqlx(type_name = \"role\")"),
            "type_name must be unqualified for sqlx 0.8 compatibility, got:\n{}",
            code
        );
        assert!(
            !code.contains("\"auth.role\""),
            "type_name must NOT include schema; got:\n{}",
            code
        );
    }

    #[test]
    fn test_postgres_non_public_schema_documents_search_path() {
        let e = make_enum_in_schema("auth", "role", vec!["admin", "user"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(
            code.contains("search_path"),
            "non-default schema enum must document the search_path requirement:\n{}",
            code
        );
    }

    #[test]
    fn test_non_postgres_never_documents_search_path() {
        for db in [DatabaseKind::Mysql, DatabaseKind::Sqlite] {
            let e = make_enum_in_schema("analytics", "event_type", vec!["click", "view"]);
            let code = render(&e, db);
            assert!(
                !code.contains("search_path"),
                "{:?}: search_path note is Postgres-only, got:\n{}",
                db,
                code
            );
        }
    }

    #[test]
    fn test_postgres_public_schema_not_qualified() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"status\")"));
        assert!(!code.contains("type_name = \"public.status\""));
    }

    #[test]
    fn test_mysql_inline_enum_emits_rename_for_lowercase_variants() {
        let e = make_enum("status", vec!["active", "inactive"]);
        let code = render(&e, DatabaseKind::Mysql);
        assert!(
            code.contains("sqlx(rename = \"active\")"),
            "MySQL inline ENUM variant must carry rename for round-trip:\n{}",
            code
        );
        assert!(code.contains("sqlx(rename = \"inactive\")"));
        assert!(!code.contains("type_name"));
    }

    #[test]
    fn test_mysql_inline_enum_preserves_case_sensitive_variants() {
        let e = make_enum("priority", vec!["LOW", "HIGH"]);
        let code = render(&e, DatabaseKind::Mysql);
        assert!(code.contains("sqlx(rename = \"LOW\")"));
        assert!(code.contains("sqlx(rename = \"HIGH\")"));
    }

    #[test]
    fn test_mysql_no_type_name() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Mysql);
        assert!(!code.contains("type_name"));
    }

    #[test]
    fn test_sqlite_no_type_name() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Sqlite);
        assert!(!code.contains("type_name"));
    }

    #[test]
    fn test_snake_case_variant_renamed() {
        let e = make_enum("status", vec!["in_progress"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("InProgress"));
        assert!(code.contains("sqlx(rename = \"in_progress\")"));
    }

    #[test]
    fn test_lowercase_variant_renamed() {
        let e = make_enum("status", vec!["active"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(code.contains("sqlx(rename = \"active\")"));
    }

    #[test]
    fn test_already_pascal_no_rename() {
        let e = make_enum("status", vec!["Active"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(!code.contains("sqlx(rename"));
    }

    #[test]
    fn test_upper_case_variant_renamed() {
        let e = make_enum("status", vec!["UPPER_CASE"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("UpperCase"));
        assert!(code.contains("sqlx(rename = \"UPPER_CASE\")"));
    }

    #[test]
    fn test_default_derives() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("PartialEq"));
        assert!(code.contains("sqlx::Type") || code.contains("sqlx :: Type"));
    }

    #[test]
    fn test_extra_derive_serialize() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string()];
        let (code, _) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_extra_derives_serde_imports() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    #[test]
    fn test_no_extra_derives_has_serde_import() {
        let e = make_enum("status", vec!["a"]);
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &[]);
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    #[test]
    fn test_serde_import_present() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string()];
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(!imports.is_empty());
    }

    #[test]
    fn test_single_variant() {
        let e = make_enum("status", vec!["only"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Only"));
    }

    #[test]
    fn test_many_variants() {
        let variants: Vec<&str> = vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"];
        let e = make_enum("status", variants);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("A,"));
        assert!(code.contains("J,"));
    }

    #[test]
    fn test_variant_with_digits() {
        let e = make_enum("version", vec!["v2"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("V2"));
    }

    #[test]
    fn test_enum_name_with_double_underscores() {
        let e = make_enum("my__enum", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("pub enum MyEnum"));
    }

    #[test]
    fn test_default_impl_generated() {
        let e = EnumInfo {
            schema_name: "public".to_string(),
            name: "task_status".to_string(),
            variants: vec![
                "idle".to_string(),
                "running".to_string(),
                "done".to_string(),
            ],
            default_variant: Some("idle".to_string()),
        };
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("impl Default for TaskStatus"));
        assert!(code.contains("Self::Idle"));
    }

    #[test]
    fn test_no_default_impl_when_none() {
        let e = make_enum("status", vec!["active", "inactive"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(!code.contains("impl Default"));
    }

    #[test]
    fn test_default_impl_snake_case_variant() {
        let e = EnumInfo {
            schema_name: "public".to_string(),
            name: "status".to_string(),
            variants: vec!["in_progress".to_string(), "done".to_string()],
            default_variant: Some("in_progress".to_string()),
        };
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("impl Default for Status"));
        assert!(code.contains("Self::InProgress"));
    }

    #[test]
    fn test_public_schema_full_output() {
        let e = make_enum_in_schema(
            "public",
            "order_status",
            vec!["pending", "shipped", "delivered"],
        );
        let code = render(&e, DatabaseKind::Postgres);

        assert!(code.contains("Enum: public.order_status"));
        assert!(code.contains("pub enum OrderStatus"));
        assert!(code.contains("sqlx(type_name = \"order_status\")"));
        assert!(!code.contains("sqlx(type_name = \"public.order_status\")"));
        assert!(code
            .contains("sqlx_gen(kind = \"enum\", schema = \"public\", name = \"order_status\")"));
        assert!(code.contains("Pending"));
        assert!(code.contains("Shipped"));
        assert!(code.contains("Delivered"));
    }

    #[test]
    fn test_named_schema_full_output() {
        let e = make_enum_in_schema(
            "analysis",
            "toolcall_status",
            vec!["PENDING", "RUNNING", "DONE"],
        );
        let code = render(&e, DatabaseKind::Postgres);

        assert!(code.contains("Enum: analysis.toolcall_status"));
        assert!(code.contains("pub enum ToolcallStatus"));
        assert!(code.contains("sqlx(type_name = \"toolcall_status\")"));
        assert!(!code.contains("\"analysis.toolcall_status\""));
        assert!(code.contains(
            "sqlx_gen(kind = \"enum\", schema = \"analysis\", name = \"toolcall_status\")"
        ));
        assert!(code.contains("Pending"));
        assert!(code.contains("Running"));
        assert!(code.contains("Done"));
    }

    #[test]
    fn test_named_schema_with_default_variant() {
        let e = EnumInfo {
            schema_name: "billing".to_string(),
            name: "payment_status".to_string(),
            variants: vec![
                "pending".to_string(),
                "paid".to_string(),
                "refunded".to_string(),
            ],
            default_variant: Some("pending".to_string()),
        };
        let code = render(&e, DatabaseKind::Postgres);

        assert!(code.contains("sqlx(type_name = \"payment_status\")"));
        assert!(!code.contains("\"billing.payment_status\""));
        assert!(code.contains("impl Default for PaymentStatus"));
        assert!(code.contains("Self::Pending"));
    }

    #[test]
    fn test_named_schema_variant_rename() {
        let e = make_enum_in_schema("audit", "log_level", vec!["info", "warn_high", "CRITICAL"]);
        let code = render(&e, DatabaseKind::Postgres);

        assert!(code.contains("sqlx(type_name = \"log_level\")"));
        assert!(!code.contains("\"audit.log_level\""));
        assert!(code.contains("sqlx(rename = \"info\")"));
        assert!(code.contains("sqlx(rename = \"warn_high\")"));
        assert!(code.contains("WarnHigh"));
        assert!(code.contains("sqlx(rename = \"CRITICAL\")"));
        assert!(code.contains("Critical"));
    }

    #[test]
    fn test_named_schema_mysql_no_type_name() {
        let e = make_enum_in_schema("analytics", "event_type", vec!["click", "view"]);
        let code = render(&e, DatabaseKind::Mysql);

        assert!(!code.contains("type_name"));
    }

    #[test]
    fn test_named_schema_sqlite_no_type_name() {
        let e = make_enum_in_schema("analytics", "event_type", vec!["click", "view"]);
        let code = render(&e, DatabaseKind::Sqlite);

        assert!(!code.contains("type_name"));
    }
}