1use std::collections::BTreeSet;
2
3use heck::ToUpperCamelCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::{imports_for_derives, rust_type_name_for};
9use crate::introspect::{EnumInfo, SchemaInfo};
10
11pub fn check_variant_collisions(enum_info: &EnumInfo) -> crate::error::Result<()> {
16 use std::collections::BTreeMap;
17 let mut seen: BTreeMap<String, &str> = BTreeMap::new();
18 for v in &enum_info.variants {
19 let pascal = v.to_upper_camel_case();
20 if let Some(prev) = seen.get(pascal.as_str()).copied() {
21 return Err(crate::error::Error::Config(format!(
22 "Enum '{}.{}': SQL variants '{}' and '{}' both map to Rust identifier '{}'. \
23 Rename one of them in the database or use a custom mapping.",
24 enum_info.schema_name, enum_info.name, prev, v, pascal
25 )));
26 }
27 seen.insert(pascal, v.as_str());
28 }
29 Ok(())
30}
31
32pub fn generate_enum(
33 enum_info: &EnumInfo,
34 db_kind: DatabaseKind,
35 extra_derives: &[String],
36) -> (TokenStream, BTreeSet<String>) {
37 generate_enum_with_schema(enum_info, db_kind, extra_derives, &SchemaInfo::default())
40}
41
42pub fn generate_enum_with_schema(
43 enum_info: &EnumInfo,
44 db_kind: DatabaseKind,
45 extra_derives: &[String],
46 schema_info: &SchemaInfo,
47) -> (TokenStream, BTreeSet<String>) {
48 let mut imports = BTreeSet::new();
49 for imp in imports_for_derives(extra_derives) {
50 imports.insert(imp);
51 }
52
53 let rust_name = rust_type_name_for(schema_info, &enum_info.schema_name, &enum_info.name);
54 let enum_name = format_ident!("{}", rust_name);
55 let doc = format!("Enum: {}.{}", enum_info.schema_name, enum_info.name);
56 let search_path_doc = if db_kind == DatabaseKind::Postgres
61 && !crate::codegen::is_default_schema(&enum_info.schema_name)
62 {
63 let msg = format!(
64 "Lives in PostgreSQL schema `{schema}`. The sqlx connection \
65 must include `{schema}` in its search_path so PG resolves the \
66 unqualified `type_name = \"{name}\"` to this enum. Example:\n\
67 \n\
68 ```ignore\n\
69 sqlx::query(\"SET search_path TO public, {schema}\")\n\
70 ```",
71 schema = enum_info.schema_name,
72 name = enum_info.name,
73 );
74 Some(msg)
75 } else {
76 None
77 };
78
79 imports.insert("use serde::{Serialize, Deserialize};".to_string());
80 imports.insert("use sqlx_gen::SqlxGen;".to_string());
81 let mut derive_tokens = vec![
82 quote! { Debug },
83 quote! { Clone },
84 quote! { PartialEq },
85 quote! { Eq },
86 quote! { Serialize },
87 quote! { Deserialize },
88 quote! { sqlx::Type },
89 quote! { SqlxGen },
90 ];
91 if enum_info.default_variant.is_some() {
94 derive_tokens.push(quote! { Default });
95 }
96 for d in extra_derives {
97 let ident = format_ident!("{}", d);
98 derive_tokens.push(quote! { #ident });
99 }
100 let default_variant_pascal = enum_info
101 .default_variant
102 .as_ref()
103 .map(|v| v.to_upper_camel_case());
104
105 let type_attr = if db_kind == DatabaseKind::Postgres {
110 let pg_name = &enum_info.name;
111 quote! { #[sqlx(type_name = #pg_name)] }
112 } else {
113 quote! {}
114 };
115
116 let variants: Vec<TokenStream> = enum_info
117 .variants
118 .iter()
119 .map(|v| {
120 let variant_pascal = v.to_upper_camel_case();
121 let variant_ident = format_ident!("{}", variant_pascal);
122
123 let rename = if variant_pascal != *v {
124 quote! { #[sqlx(rename = #v)] }
125 } else {
126 quote! {}
127 };
128
129 let default_attr = if default_variant_pascal.as_deref() == Some(variant_pascal.as_str())
130 {
131 quote! { #[default] }
132 } else {
133 quote! {}
134 };
135
136 quote! {
137 #rename
138 #default_attr
139 #variant_ident,
140 }
141 })
142 .collect();
143
144 let _ = db_kind;
149
150 let schema_name_str = &enum_info.schema_name;
151 let enum_name_str = &enum_info.name;
152 let search_path_doc_tokens = match &search_path_doc {
153 Some(m) => quote! { #[doc = #m] },
154 None => quote! {},
155 };
156
157 let tokens = quote! {
158 #[doc = #doc]
159 #search_path_doc_tokens
160 #[derive(#(#derive_tokens),*)]
161 #[sqlx_gen(kind = "enum", schema = #schema_name_str, name = #enum_name_str)]
162 #type_attr
163 pub enum #enum_name {
164 #(#variants)*
165 }
166 };
167
168 (tokens, imports)
169}
170
#[cfg(test)]
mod tests {
    // Unit tests for the enum generator. Most tests render tokens through
    // `generate_enum`, format them with `parse_and_format`, and assert on
    // substrings of the resulting source text.
    use super::*;
    use crate::codegen::parse_and_format;

    // Builds an `EnumInfo` in the `public` schema with no default variant.
    fn make_enum(name: &str, variants: Vec<&str>) -> EnumInfo {
        EnumInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            variants: variants.into_iter().map(|s| s.to_string()).collect(),
            default_variant: None,
        }
    }

    // Generates the enum with no extra derives and returns the formatted source.
    // NOTE(review): `gen` is a reserved keyword in the Rust 2024 edition; rename
    // this helper (or use `r#gen`) when migrating editions.
    fn gen(info: &EnumInfo, db: DatabaseKind) -> String {
        let (tokens, _) = generate_enum(info, db, &[]);
        parse_and_format(&tokens).unwrap()
    }

    // Like `gen`, but passes `derives` as extra derives and also returns the
    // collected import lines.
    fn gen_with_derives(
        info: &EnumInfo,
        db: DatabaseKind,
        derives: &[String],
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_enum(info, db, derives);
        (parse_and_format(&tokens).unwrap(), imports)
    }

    // --- Basic shape: names, docs, and the sqlx_gen attribute ---

    #[test]
    fn test_enum_variants() {
        let e = make_enum("status", vec!["active", "inactive"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(code.contains("Inactive"));
    }

    #[test]
    fn test_enum_name_pascal_case() {
        let e = make_enum("user_status", vec!["a"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("pub enum UserStatus"));
    }

    #[test]
    fn test_doc_comment() {
        let e = make_enum("status", vec!["a"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("Enum: public.status"));
    }

    #[test]
    fn test_sqlx_gen_attr_has_schema_and_name() {
        let e = make_enum("status", vec!["a"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx_gen(kind = \"enum\", schema = \"public\", name = \"status\")"));
    }

    #[test]
    fn test_sqlx_gen_attr_non_public_schema() {
        let e = EnumInfo {
            schema_name: "auth".to_string(),
            name: "role".to_string(),
            variants: vec!["admin".to_string(), "user".to_string()],
            default_variant: None,
        };
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx_gen(kind = \"enum\", schema = \"auth\", name = \"role\")"));
    }

    // --- Postgres type_name and variant collision checks ---

    #[test]
    fn test_postgres_has_type_name() {
        let e = make_enum("user_status", vec!["a"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"user_status\")"));
    }

    // `foo bar` and `foo_bar` both convert to `FooBar`.
    #[test]
    fn test_check_variant_collisions_detects_after_camel_case() {
        let e = EnumInfo {
            schema_name: "public".into(),
            name: "weird".into(),
            variants: vec!["foo bar".into(), "foo_bar".into()],
            default_variant: None,
        };
        let result = check_variant_collisions(&e);
        assert!(result.is_err(), "must detect collision");
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("FooBar"),
            "error must mention conflicting Rust ident, got: {}",
            msg
        );
        assert!(msg.contains("foo bar") || msg.contains("foo_bar"));
    }

    #[test]
    fn test_check_variant_collisions_accepts_distinct_variants() {
        let e = make_enum("status", vec!["active", "inactive"]);
        assert!(check_variant_collisions(&e).is_ok());
    }

    #[test]
    fn test_check_variant_collisions_accepts_single_variant() {
        let e = make_enum("status", vec!["only"]);
        assert!(check_variant_collisions(&e).is_ok());
    }

    // The generator must not hand-write a `PgHasArrayType` impl for any backend.
    #[test]
    fn test_does_not_emit_manual_pg_has_array_type_impl() {
        for db in [
            DatabaseKind::Postgres,
            DatabaseKind::Mysql,
            DatabaseKind::Sqlite,
        ] {
            let e = make_enum("status", vec!["a", "b"]);
            let code = gen(&e, db);
            assert!(
                !code.contains("PgHasArrayType"),
                "{:?}: must not emit a manual PgHasArrayType impl, got:\n{}",
                db,
                code
            );
        }
    }

    #[test]
    fn test_postgres_non_public_schema_type_name_is_unqualified() {
        let e = EnumInfo {
            schema_name: "auth".to_string(),
            name: "role".to_string(),
            variants: vec!["admin".to_string(), "user".to_string()],
            default_variant: None,
        };
        let (tokens, _) = generate_enum(&e, DatabaseKind::Postgres, &[]);
        let code = parse_and_format(&tokens).unwrap();
        assert!(
            code.contains("sqlx(type_name = \"role\")"),
            "type_name must be unqualified for sqlx 0.8 compatibility, got:\n{}",
            code
        );
        assert!(
            !code.contains("\"auth.role\""),
            "type_name must NOT include schema; got:\n{}",
            code
        );
    }

    #[test]
    fn test_postgres_public_schema_not_qualified() {
        let e = make_enum("status", vec!["a"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"status\")"));
        assert!(!code.contains("type_name = \"public.status\""));
    }

    // --- MySQL / SQLite: no type_name, renames keep the SQL labels ---

    #[test]
    fn test_mysql_inline_enum_emits_rename_for_lowercase_variants() {
        let e = make_enum("status", vec!["active", "inactive"]);
        let code = gen(&e, DatabaseKind::Mysql);
        assert!(
            code.contains("sqlx(rename = \"active\")"),
            "MySQL inline ENUM variant must carry rename for round-trip:\n{}",
            code
        );
        assert!(code.contains("sqlx(rename = \"inactive\")"));
        assert!(!code.contains("type_name"));
    }

    #[test]
    fn test_mysql_inline_enum_preserves_case_sensitive_variants() {
        let e = make_enum("priority", vec!["LOW", "HIGH"]);
        let code = gen(&e, DatabaseKind::Mysql);
        assert!(code.contains("sqlx(rename = \"LOW\")"));
        assert!(code.contains("sqlx(rename = \"HIGH\")"));
    }

    #[test]
    fn test_mysql_no_type_name() {
        let e = make_enum("status", vec!["a"]);
        let code = gen(&e, DatabaseKind::Mysql);
        assert!(!code.contains("type_name"));
    }

    #[test]
    fn test_sqlite_no_type_name() {
        let e = make_enum("status", vec!["a"]);
        let code = gen(&e, DatabaseKind::Sqlite);
        assert!(!code.contains("type_name"));
    }

    // --- Variant renaming: rename is emitted only when PascalCase differs ---

    #[test]
    fn test_snake_case_variant_renamed() {
        let e = make_enum("status", vec!["in_progress"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("InProgress"));
        assert!(code.contains("sqlx(rename = \"in_progress\")"));
    }

    #[test]
    fn test_lowercase_variant_renamed() {
        let e = make_enum("status", vec!["active"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(code.contains("sqlx(rename = \"active\")"));
    }

    #[test]
    fn test_already_pascal_no_rename() {
        let e = make_enum("status", vec!["Active"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(!code.contains("sqlx(rename"));
    }

    #[test]
    fn test_upper_case_variant_renamed() {
        let e = make_enum("status", vec!["UPPER_CASE"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("UpperCase"));
        assert!(code.contains("sqlx(rename = \"UPPER_CASE\")"));
    }

    // --- Derives and imports ---

    #[test]
    fn test_default_derives() {
        let e = make_enum("status", vec!["a"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("PartialEq"));
        assert!(code.contains("sqlx::Type") || code.contains("sqlx :: Type"));
    }

    #[test]
    fn test_extra_derive_serialize() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string()];
        let (code, _) = gen_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_extra_derives_serde_imports() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) = gen_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    // The serde import is always added, even with no extra derives.
    #[test]
    fn test_no_extra_derives_has_serde_import() {
        let e = make_enum("status", vec!["a"]);
        let (_, imports) = gen_with_derives(&e, DatabaseKind::Postgres, &[]);
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    #[test]
    fn test_serde_import_present() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string()];
        let (_, imports) = gen_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(!imports.is_empty());
    }

    // --- Variant count and naming edge cases ---

    #[test]
    fn test_single_variant() {
        let e = make_enum("status", vec!["only"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("Only"));
    }

    #[test]
    fn test_many_variants() {
        let variants: Vec<&str> = vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"];
        let e = make_enum("status", variants);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("A,"));
        assert!(code.contains("J,"));
    }

    #[test]
    fn test_variant_with_digits() {
        let e = make_enum("version", vec!["v2"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("V2"));
    }

    #[test]
    fn test_enum_name_with_double_underscores() {
        let e = make_enum("my__enum", vec!["a"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(code.contains("pub enum MyEnum"));
    }

    // --- Default handling: derive(Default) + #[default] instead of a manual impl ---

    #[test]
    fn test_default_uses_derive_and_attribute() {
        let e = EnumInfo {
            schema_name: "public".to_string(),
            name: "task_status".to_string(),
            variants: vec![
                "idle".to_string(),
                "running".to_string(),
                "done".to_string(),
            ],
            default_variant: Some("idle".to_string()),
        };
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(
            code.contains("Default"),
            "expected `Default` in derive list, got:\n{}",
            code
        );
        assert!(
            code.contains("#[default]"),
            "expected #[default] attribute on the variant, got:\n{}",
            code
        );
        assert!(!code.contains("impl Default for TaskStatus"));
    }

    #[test]
    fn test_no_default_derive_when_no_default_variant() {
        let e = make_enum("status", vec!["active", "inactive"]);
        let code = gen(&e, DatabaseKind::Postgres);
        assert!(!code.contains("impl Default"));
        assert!(!code.contains("#[default]"));
        let derive_line = code
            .lines()
            .find(|l| l.contains("#[derive"))
            .expect("derive line");
        assert!(
            !derive_line.contains(", Default"),
            "derive list should not include Default, got: {}",
            derive_line
        );
    }

    // `#[default]` must sit directly above the variant named by `default_variant`,
    // matched after PascalCase conversion.
    #[test]
    fn test_default_attribute_on_correct_variant_snake_case() {
        let e = EnumInfo {
            schema_name: "public".to_string(),
            name: "status".to_string(),
            variants: vec!["in_progress".to_string(), "done".to_string()],
            default_variant: Some("in_progress".to_string()),
        };
        let code = gen(&e, DatabaseKind::Postgres);
        let in_progress_idx = code.find("InProgress").expect("InProgress");
        let default_attr_idx = code.find("#[default]").expect("#[default]");
        assert!(
            default_attr_idx < in_progress_idx,
            "#[default] must precede InProgress"
        );
        let between = &code[default_attr_idx..in_progress_idx];
        assert!(
            !between.contains("Done"),
            "#[default] landed on the wrong variant:\n{}",
            code
        );
    }

    // Builds an `EnumInfo` in an arbitrary schema with no default variant.
    fn make_enum_in_schema(schema: &str, name: &str, variants: Vec<&str>) -> EnumInfo {
        EnumInfo {
            schema_name: schema.to_string(),
            name: name.to_string(),
            variants: variants.into_iter().map(|s| s.to_string()).collect(),
            default_variant: None,
        }
    }

    // --- End-to-end output checks for public and named schemas ---

    #[test]
    fn test_public_schema_full_output() {
        let e = make_enum_in_schema(
            "public",
            "order_status",
            vec!["pending", "shipped", "delivered"],
        );
        let code = gen(&e, DatabaseKind::Postgres);

        assert!(code.contains("Enum: public.order_status"));
        assert!(code.contains("pub enum OrderStatus"));
        assert!(code.contains("sqlx(type_name = \"order_status\")"));
        assert!(!code.contains("sqlx(type_name = \"public.order_status\")"));
        assert!(code
            .contains("sqlx_gen(kind = \"enum\", schema = \"public\", name = \"order_status\")"));
        assert!(code.contains("Pending"));
        assert!(code.contains("Shipped"));
        assert!(code.contains("Delivered"));
    }

    #[test]
    fn test_named_schema_full_output() {
        let e = make_enum_in_schema(
            "analysis",
            "toolcall_status",
            vec!["PENDING", "RUNNING", "DONE"],
        );
        let code = gen(&e, DatabaseKind::Postgres);

        assert!(code.contains("Enum: analysis.toolcall_status"));
        assert!(code.contains("pub enum ToolcallStatus"));
        assert!(code.contains("sqlx(type_name = \"toolcall_status\")"));
        assert!(!code.contains("\"analysis.toolcall_status\""));
        assert!(code.contains(
            "sqlx_gen(kind = \"enum\", schema = \"analysis\", name = \"toolcall_status\")"
        ));
        assert!(code.contains("Pending"));
        assert!(code.contains("Running"));
        assert!(code.contains("Done"));
    }

    #[test]
    fn test_named_schema_with_default_variant() {
        let e = EnumInfo {
            schema_name: "billing".to_string(),
            name: "payment_status".to_string(),
            variants: vec![
                "pending".to_string(),
                "paid".to_string(),
                "refunded".to_string(),
            ],
            default_variant: Some("pending".to_string()),
        };
        let code = gen(&e, DatabaseKind::Postgres);

        assert!(code.contains("sqlx(type_name = \"payment_status\")"));
        assert!(!code.contains("\"billing.payment_status\""));
        assert!(code.contains("Default"));
        assert!(code.contains("#[default]"));
        assert!(!code.contains("impl Default for PaymentStatus"));
    }

    #[test]
    fn test_named_schema_variant_rename() {
        let e = make_enum_in_schema("audit", "log_level", vec!["info", "warn_high", "CRITICAL"]);
        let code = gen(&e, DatabaseKind::Postgres);

        assert!(code.contains("sqlx(type_name = \"log_level\")"));
        assert!(!code.contains("\"audit.log_level\""));
        assert!(code.contains("sqlx(rename = \"info\")"));
        assert!(code.contains("sqlx(rename = \"warn_high\")"));
        assert!(code.contains("WarnHigh"));
        assert!(code.contains("sqlx(rename = \"CRITICAL\")"));
        assert!(code.contains("Critical"));
    }

    #[test]
    fn test_named_schema_mysql_no_type_name() {
        let e = make_enum_in_schema("analytics", "event_type", vec!["click", "view"]);
        let code = gen(&e, DatabaseKind::Mysql);

        assert!(!code.contains("type_name"));
    }

    #[test]
    fn test_named_schema_sqlite_no_type_name() {
        let e = make_enum_in_schema("analytics", "event_type", vec!["click", "view"]);
        let code = gen(&e, DatabaseKind::Sqlite);

        assert!(!code.contains("type_name"));
    }
}