1use std::collections::BTreeSet;
2
3use heck::ToUpperCamelCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::imports_for_derives;
9use crate::introspect::EnumInfo;
10
11pub fn generate_enum(
12 enum_info: &EnumInfo,
13 db_kind: DatabaseKind,
14 extra_derives: &[String],
15) -> (TokenStream, BTreeSet<String>) {
16 let mut imports = BTreeSet::new();
17 for imp in imports_for_derives(extra_derives) {
18 imports.insert(imp);
19 }
20
21 let enum_name = format_ident!("{}", enum_info.name.to_upper_camel_case());
22 let doc = format!("Enum: {}.{}", enum_info.schema_name, enum_info.name);
23
24 imports.insert("use serde::{Serialize, Deserialize};".to_string());
25 imports.insert("use sqlx_gen::SqlxGen;".to_string());
26 let mut derive_tokens = vec![
27 quote! { Debug },
28 quote! { Clone },
29 quote! { PartialEq },
30 quote! { Eq },
31 quote! { Serialize },
32 quote! { Deserialize },
33 quote! { sqlx::Type },
34 quote! { SqlxGen },
35 ];
36 for d in extra_derives {
37 let ident = format_ident!("{}", d);
38 derive_tokens.push(quote! { #ident });
39 }
40
41 let type_attr = if db_kind == DatabaseKind::Postgres {
44 let pg_name = if enum_info.schema_name != "public" {
45 format!("{}.{}", enum_info.schema_name, enum_info.name)
46 } else {
47 enum_info.name.clone()
48 };
49 quote! { #[sqlx(type_name = #pg_name)] }
50 } else {
51 quote! {}
52 };
53
54 let variants: Vec<TokenStream> = enum_info
55 .variants
56 .iter()
57 .map(|v| {
58 let variant_pascal = v.to_upper_camel_case();
59 let variant_ident = format_ident!("{}", variant_pascal);
60
61 let rename = if variant_pascal != *v {
62 quote! { #[sqlx(rename = #v)] }
63 } else {
64 quote! {}
65 };
66
67 quote! {
68 #rename
69 #variant_ident,
70 }
71 })
72 .collect();
73
74 let default_impl = if let Some(ref default_variant) = enum_info.default_variant {
75 let variant_pascal = default_variant.to_upper_camel_case();
76 let variant_ident = format_ident!("{}", variant_pascal);
77 quote! {
78 impl Default for #enum_name {
79 fn default() -> Self {
80 Self::#variant_ident
81 }
82 }
83 }
84 } else {
85 quote! {}
86 };
87
88 let schema_name_str = &enum_info.schema_name;
89 let enum_name_str = &enum_info.name;
90
91 let tokens = quote! {
92 #[doc = #doc]
93 #[derive(#(#derive_tokens),*)]
94 #[sqlx_gen(kind = "enum", schema = #schema_name_str, name = #enum_name_str)]
95 #type_attr
96 pub enum #enum_name {
97 #(#variants)*
98 }
99
100 #default_impl
101 };
102
103 (tokens, imports)
104}
105
#[cfg(test)]
mod tests {
    //! Tests for `generate_enum`, checked against the prettified source text
    //! and the returned import set.

    use super::*;
    use crate::codegen::parse_and_format;

    /// Builds an enum in schema `schema` with the given labels and no default.
    fn make_enum_in(schema: &str, name: &str, variants: Vec<&str>) -> EnumInfo {
        EnumInfo {
            schema_name: schema.to_string(),
            name: name.to_string(),
            variants: variants.into_iter().map(|s| s.to_string()).collect(),
            default_variant: None,
        }
    }

    /// Builds an enum in the `public` schema with the given labels and no default.
    fn make_enum(name: &str, variants: Vec<&str>) -> EnumInfo {
        make_enum_in("public", name, variants)
    }

    /// Generates `info` for `db` without extra derives and returns the formatted code.
    ///
    /// Named `render` rather than `gen`, which is a reserved keyword in the
    /// Rust 2024 edition.
    fn render(info: &EnumInfo, db: DatabaseKind) -> String {
        let (tokens, _) = generate_enum(info, db, &[]);
        parse_and_format(&tokens)
    }

    /// Generates `info` for `db` with `derives` and returns the formatted code
    /// together with the collected imports.
    fn render_with_derives(
        info: &EnumInfo,
        db: DatabaseKind,
        derives: &[String],
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_enum(info, db, derives);
        (parse_and_format(&tokens), imports)
    }

    // --- enum shape, naming and attributes ---

    #[test]
    fn test_enum_variants() {
        let e = make_enum("status", vec!["active", "inactive"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(code.contains("Inactive"));
    }

    #[test]
    fn test_enum_name_pascal_case() {
        let e = make_enum("user_status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("pub enum UserStatus"));
    }

    #[test]
    fn test_doc_comment() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Enum: public.status"));
    }

    #[test]
    fn test_sqlx_gen_attr_has_schema_and_name() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx_gen(kind = \"enum\", schema = \"public\", name = \"status\")"));
    }

    #[test]
    fn test_sqlx_gen_attr_non_public_schema() {
        let e = make_enum_in("auth", "role", vec!["admin", "user"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx_gen(kind = \"enum\", schema = \"auth\", name = \"role\")"));
    }

    // --- sqlx type_name per database backend ---

    #[test]
    fn test_postgres_has_type_name() {
        let e = make_enum("user_status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"user_status\")"));
    }

    #[test]
    fn test_postgres_non_public_schema_qualified_type_name() {
        let e = make_enum_in("auth", "role", vec!["admin", "user"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"auth.role\")"));
    }

    #[test]
    fn test_postgres_public_schema_not_qualified() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"status\")"));
        assert!(!code.contains("type_name = \"public.status\""));
    }

    #[test]
    fn test_mysql_no_type_name() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Mysql);
        assert!(!code.contains("type_name"));
    }

    #[test]
    fn test_sqlite_no_type_name() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Sqlite);
        assert!(!code.contains("type_name"));
    }

    // --- variant renaming ---

    #[test]
    fn test_snake_case_variant_renamed() {
        let e = make_enum("status", vec!["in_progress"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("InProgress"));
        assert!(code.contains("sqlx(rename = \"in_progress\")"));
    }

    #[test]
    fn test_lowercase_variant_renamed() {
        let e = make_enum("status", vec!["active"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(code.contains("sqlx(rename = \"active\")"));
    }

    #[test]
    fn test_already_pascal_no_rename() {
        let e = make_enum("status", vec!["Active"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(!code.contains("sqlx(rename"));
    }

    #[test]
    fn test_upper_case_variant_renamed() {
        let e = make_enum("status", vec!["UPPER_CASE"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("UpperCase"));
        assert!(code.contains("sqlx(rename = \"UPPER_CASE\")"));
    }

    // --- derives and imports ---

    #[test]
    fn test_default_derives() {
        let e = make_enum("status", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("PartialEq"));
        assert!(code.contains("sqlx::Type") || code.contains("sqlx :: Type"));
    }

    #[test]
    fn test_extra_derive_appended() {
        // `Hash` is not a default derive, so its presence proves the extra
        // derive list is actually emitted.
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Hash".to_string()];
        let (code, _) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(code.contains("Hash"));
    }

    #[test]
    fn test_extra_derives_serde_imports() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    #[test]
    fn test_no_extra_derives_has_serde_import() {
        let e = make_enum("status", vec!["a"]);
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &[]);
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    #[test]
    fn test_serde_import_present() {
        let e = make_enum("status", vec!["a"]);
        let derives = vec!["Serialize".to_string()];
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(imports.contains("use serde::{Serialize, Deserialize};"));
        assert!(imports.contains("use sqlx_gen::SqlxGen;"));
    }

    // --- edge cases ---

    #[test]
    fn test_single_variant() {
        let e = make_enum("status", vec!["only"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Only"));
    }

    #[test]
    fn test_many_variants() {
        let variants: Vec<&str> = vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"];
        let e = make_enum("status", variants);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("A,"));
        assert!(code.contains("J,"));
    }

    #[test]
    fn test_variant_with_digits() {
        let e = make_enum("version", vec!["v2"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("V2"));
    }

    #[test]
    fn test_enum_name_with_double_underscores() {
        let e = make_enum("my__enum", vec!["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("pub enum MyEnum"));
    }

    // --- Default impl ---

    #[test]
    fn test_default_impl_generated() {
        let e = EnumInfo {
            default_variant: Some("idle".to_string()),
            ..make_enum("task_status", vec!["idle", "running", "done"])
        };
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("impl Default for TaskStatus"));
        assert!(code.contains("Self::Idle"));
    }

    #[test]
    fn test_no_default_impl_when_none() {
        let e = make_enum("status", vec!["active", "inactive"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(!code.contains("impl Default"));
    }

    #[test]
    fn test_default_impl_snake_case_variant() {
        let e = EnumInfo {
            default_variant: Some("in_progress".to_string()),
            ..make_enum("status", vec!["in_progress", "done"])
        };
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("impl Default for Status"));
        assert!(code.contains("Self::InProgress"));
    }
}
367}