1use std::collections::BTreeSet;
2
3use heck::ToUpperCamelCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::imports_for_derives;
9use crate::introspect::EnumInfo;
10
11pub fn generate_enum(
12 enum_info: &EnumInfo,
13 db_kind: DatabaseKind,
14 extra_derives: &[String],
15) -> (TokenStream, BTreeSet<String>) {
16 let mut imports = BTreeSet::new();
17 for imp in imports_for_derives(extra_derives) {
18 imports.insert(imp);
19 }
20
21 let enum_name = format_ident!("{}", enum_info.name.to_upper_camel_case());
22 let doc = format!("Enum: {}.{}", enum_info.schema_name, enum_info.name);
23
24 imports.insert("use serde::{Serialize, Deserialize};".to_string());
25 let mut derive_tokens = vec![
26 quote! { Debug },
27 quote! { Clone },
28 quote! { PartialEq },
29 quote! { Eq },
30 quote! { Serialize },
31 quote! { Deserialize },
32 quote! { sqlx::Type },
33 ];
34 for d in extra_derives {
35 let ident = format_ident!("{}", d);
36 derive_tokens.push(quote! { #ident });
37 }
38
39 let type_attr = if db_kind == DatabaseKind::Postgres {
42 let pg_name = if enum_info.schema_name != "public" {
43 format!("{}.{}", enum_info.schema_name, enum_info.name)
44 } else {
45 enum_info.name.clone()
46 };
47 quote! { #[sqlx(type_name = #pg_name)] }
48 } else {
49 quote! {}
50 };
51
52 let variants: Vec<TokenStream> = enum_info
53 .variants
54 .iter()
55 .map(|v| {
56 let variant_pascal = v.to_upper_camel_case();
57 let variant_ident = format_ident!("{}", variant_pascal);
58
59 let rename = if variant_pascal != *v {
60 quote! { #[sqlx(rename = #v)] }
61 } else {
62 quote! {}
63 };
64
65 quote! {
66 #rename
67 #variant_ident,
68 }
69 })
70 .collect();
71
72 let tokens = quote! {
73 #[doc = #doc]
74 #[derive(#(#derive_tokens),*)]
75 #type_attr
76 pub enum #enum_name {
77 #(#variants)*
78 }
79 };
80
81 (tokens, imports)
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;

    /// Builds an `EnumInfo` in the `public` schema.
    fn make_enum(name: &str, variants: Vec<&str>) -> EnumInfo {
        make_enum_in("public", name, variants)
    }

    /// Builds an `EnumInfo` in an arbitrary schema.
    fn make_enum_in(schema: &str, name: &str, variants: Vec<&str>) -> EnumInfo {
        EnumInfo {
            schema_name: schema.to_string(),
            name: name.to_string(),
            variants: variants.into_iter().map(|s| s.to_string()).collect(),
        }
    }

    /// Generates and pretty-prints an enum with no extra derives.
    /// (Not named `gen`: that is a reserved keyword in Rust 2024.)
    fn render(info: &EnumInfo, db: DatabaseKind) -> String {
        let (tokens, _) = generate_enum(info, db, &[]);
        parse_and_format(&tokens)
    }

    /// Generates and pretty-prints an enum with the given extra derives, returning its imports.
    fn render_with_derives(
        info: &EnumInfo,
        db: DatabaseKind,
        derives: &[String],
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_enum(info, db, derives);
        (parse_and_format(&tokens), imports)
    }

    #[test]
    fn test_enum_variants() {
        let e = make_enum("status", vec!["active", "inactive"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(code.contains("Inactive"));
    }

    #[test]
    fn test_enum_name_pascal_case() {
        let e = make_enum("user_status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("pub enum UserStatus"));
    }

    #[test]
    fn test_doc_comment() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Enum: public.status"));
    }

    #[test]
    fn test_postgres_has_type_name() {
        let e = make_enum("user_status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"user_status\")"));
    }

    #[test]
    fn test_postgres_non_public_schema_qualified_type_name() {
        let e = make_enum_in("auth", "role", vec!["admin", "user"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"auth.role\")"));
        assert!(code.contains("Enum: auth.role"));
    }

    #[test]
    fn test_postgres_public_schema_not_qualified() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"status\")"));
        assert!(!code.contains("type_name = \"public.status\""));
    }

    #[test]
    fn test_mysql_no_type_name() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Mysql);
        assert!(!code.contains("type_name"));
    }

    #[test]
    fn test_sqlite_no_type_name() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Sqlite);
        assert!(!code.contains("type_name"));
    }

    #[test]
    fn test_mysql_variant_renamed() {
        let e = make_enum("status", vec!["active"]);
        let code = render(&e, DatabaseKind::Mysql);
        assert!(code.contains("Active"));
        assert!(code.contains("sqlx(rename = \"active\")"));
    }

    #[test]
    fn test_snake_case_variant_renamed() {
        let e = make_enum("status", vec!["in_progress"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("InProgress"));
        assert!(code.contains("sqlx(rename = \"in_progress\")"));
    }

    #[test]
    fn test_lowercase_variant_renamed() {
        let e = make_enum("status", vec!["active"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(code.contains("sqlx(rename = \"active\")"));
    }

    #[test]
    fn test_already_pascal_no_rename() {
        let e = make_enum("status", vec!["Active"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(!code.contains("sqlx(rename"));
    }

    #[test]
    fn test_upper_case_variant_renamed() {
        let e = make_enum("status", vec!["UPPER_CASE"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("UpperCase"));
        assert!(code.contains("sqlx(rename = \"UPPER_CASE\")"));
    }

    #[test]
    fn test_default_derives() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("PartialEq"));
        assert!(code.contains("Serialize"));
        assert!(code.contains("Deserialize"));
        assert!(code.contains("sqlx::Type") || code.contains("sqlx :: Type"));
    }

    #[test]
    fn test_extra_derive_serialize() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string()];
        let (code, _) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_extra_derives_serde_imports() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    #[test]
    fn test_no_extra_derives_has_serde_import() {
        let e = make_enum("status", vec!["a"]);
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &[]);
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    #[test]
    fn test_serde_import_present() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string()];
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(imports.contains("use serde::{Serialize, Deserialize};"));
    }

    #[test]
    fn test_single_variant() {
        let e = make_enum("status", vec!["only"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Only"));
    }

    #[test]
    fn test_many_variants() {
        let variants: Vec<&str> = vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"];
        let e = make_enum("status", variants);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("A,"));
        assert!(code.contains("J,"));
    }

    #[test]
    fn test_variant_with_digits() {
        let e = make_enum("version", vec!["v2"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("V2"));
    }

    #[test]
    fn test_enum_name_with_double_underscores() {
        let e = make_enum("my__enum", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("pub enum MyEnum"));
    }
}
289}