1use std::collections::BTreeSet;
2
3use heck::ToUpperCamelCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::imports_for_derives;
9use crate::introspect::EnumInfo;
10
11pub fn generate_enum(
12 enum_info: &EnumInfo,
13 db_kind: DatabaseKind,
14 extra_derives: &[String],
15) -> (TokenStream, BTreeSet<String>) {
16 let mut imports = BTreeSet::new();
17 for imp in imports_for_derives(extra_derives) {
18 imports.insert(imp);
19 }
20
21 let enum_name = format_ident!("{}", enum_info.name.to_upper_camel_case());
22 let doc = format!("Enum: {}.{}", enum_info.schema_name, enum_info.name);
23
24 imports.insert("use serde::{Serialize, Deserialize};".to_string());
25 imports.insert("use sqlx_gen::SqlxGen;".to_string());
26 let mut derive_tokens = vec![
27 quote! { Debug },
28 quote! { Clone },
29 quote! { PartialEq },
30 quote! { Eq },
31 quote! { Serialize },
32 quote! { Deserialize },
33 quote! { sqlx::Type },
34 quote! { SqlxGen },
35 ];
36 for d in extra_derives {
37 let ident = format_ident!("{}", d);
38 derive_tokens.push(quote! { #ident });
39 }
40
41 let type_attr = if db_kind == DatabaseKind::Postgres {
44 let pg_name = if enum_info.schema_name != "public" {
45 format!("{}.{}", enum_info.schema_name, enum_info.name)
46 } else {
47 enum_info.name.clone()
48 };
49 quote! { #[sqlx(type_name = #pg_name)] }
50 } else {
51 quote! {}
52 };
53
54 let variants: Vec<TokenStream> = enum_info
55 .variants
56 .iter()
57 .map(|v| {
58 let variant_pascal = v.to_upper_camel_case();
59 let variant_ident = format_ident!("{}", variant_pascal);
60
61 let rename = if variant_pascal != *v {
62 quote! { #[sqlx(rename = #v)] }
63 } else {
64 quote! {}
65 };
66
67 quote! {
68 #rename
69 #variant_ident,
70 }
71 })
72 .collect();
73
74 let tokens = quote! {
75 #[doc = #doc]
76 #[derive(#(#derive_tokens),*)]
77 #[sqlx_gen(kind = "enum")]
78 #type_attr
79 pub enum #enum_name {
80 #(#variants)*
81 }
82 };
83
84 (tokens, imports)
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;

    /// Builds an `EnumInfo` in the `public` schema with the given labels.
    fn make_enum(name: &str, variants: &[&str]) -> EnumInfo {
        EnumInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            variants: variants.iter().map(|&s| s.to_string()).collect(),
        }
    }

    /// Generates and formats an enum with no extra derives.
    /// (Not named `gen`: that is a reserved keyword in Rust 2024.)
    fn render(info: &EnumInfo, db: DatabaseKind) -> String {
        let (tokens, _) = generate_enum(info, db, &[]);
        parse_and_format(&tokens)
    }

    /// Generates and formats an enum with extra derives, also returning its imports.
    fn render_with_derives(
        info: &EnumInfo,
        db: DatabaseKind,
        derives: &[String],
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_enum(info, db, derives);
        (parse_and_format(&tokens), imports)
    }

    // --- naming ---

    #[test]
    fn test_enum_variants() {
        let e = make_enum("status", &["active", "inactive"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(code.contains("Inactive"));
    }

    #[test]
    fn test_enum_name_pascal_case() {
        let e = make_enum("user_status", &["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("pub enum UserStatus"));
    }

    #[test]
    fn test_doc_comment() {
        let e = make_enum("status", &["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Enum: public.status"));
    }

    #[test]
    fn test_enum_name_with_double_underscores() {
        let e = make_enum("my__enum", &["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("pub enum MyEnum"));
    }

    // --- sqlx type_name per backend ---

    #[test]
    fn test_postgres_has_type_name() {
        let e = make_enum("user_status", &["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"user_status\")"));
    }

    #[test]
    fn test_postgres_non_public_schema_qualified_type_name() {
        let e = EnumInfo {
            schema_name: "auth".to_string(),
            name: "role".to_string(),
            variants: vec!["admin".to_string(), "user".to_string()],
        };
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"auth.role\")"));
    }

    #[test]
    fn test_postgres_public_schema_not_qualified() {
        let e = make_enum("status", &["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx(type_name = \"status\")"));
        assert!(!code.contains("type_name = \"public.status\""));
    }

    #[test]
    fn test_mysql_no_type_name() {
        let e = make_enum("status", &["a"]);
        let code = render(&e, DatabaseKind::Mysql);
        assert!(!code.contains("type_name"));
    }

    #[test]
    fn test_sqlite_no_type_name() {
        let e = make_enum("status", &["a"]);
        let code = render(&e, DatabaseKind::Sqlite);
        assert!(!code.contains("type_name"));
    }

    // --- variant renames ---

    #[test]
    fn test_snake_case_variant_renamed() {
        let e = make_enum("status", &["in_progress"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("InProgress"));
        assert!(code.contains("sqlx(rename = \"in_progress\")"));
    }

    #[test]
    fn test_lowercase_variant_renamed() {
        let e = make_enum("status", &["active"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(code.contains("sqlx(rename = \"active\")"));
    }

    #[test]
    fn test_already_pascal_no_rename() {
        let e = make_enum("status", &["Active"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Active"));
        assert!(!code.contains("sqlx(rename"));
    }

    #[test]
    fn test_upper_case_variant_renamed() {
        let e = make_enum("status", &["UPPER_CASE"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("UpperCase"));
        assert!(code.contains("sqlx(rename = \"UPPER_CASE\")"));
    }

    // --- derives and attributes ---

    #[test]
    fn test_default_derives() {
        let e = make_enum("status", &["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("PartialEq"));
        // "Eq" occurs once inside "PartialEq"; a second hit is the `Eq` derive.
        assert!(code.matches("Eq").count() >= 2);
        assert!(code.contains("Serialize"));
        assert!(code.contains("Deserialize"));
        assert!(code.contains("sqlx::Type") || code.contains("sqlx :: Type"));
        assert!(code.contains("SqlxGen"));
    }

    #[test]
    fn test_sqlx_gen_kind_attribute() {
        let e = make_enum("status", &["a"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("sqlx_gen(kind = \"enum\")"));
    }

    #[test]
    fn test_extra_derive_serialize() {
        let e = make_enum("status", &["a"]);
        let derives = vec!["Serialize".to_string()];
        let (code, _) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_extra_non_default_derive_emitted() {
        let e = make_enum("status", &["a"]);
        let derives = vec!["Hash".to_string()];
        let (code, _) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(code.contains("Hash"));
    }

    // --- imports ---

    #[test]
    fn test_extra_derives_serde_imports() {
        let e = make_enum("status", &["a"]);
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    #[test]
    fn test_no_extra_derives_has_serde_import() {
        let e = make_enum("status", &["a"]);
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &[]);
        assert!(imports.contains("use serde::{Serialize, Deserialize};"));
    }

    #[test]
    fn test_serde_import_present() {
        let e = make_enum("status", &["a"]);
        let derives = vec!["Serialize".to_string()];
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &derives);
        assert!(imports.contains("use serde::{Serialize, Deserialize};"));
    }

    #[test]
    fn test_sqlx_gen_import_present() {
        let e = make_enum("status", &["a"]);
        let (_, imports) = render_with_derives(&e, DatabaseKind::Postgres, &[]);
        assert!(imports.contains("use sqlx_gen::SqlxGen;"));
    }

    // --- sizes and odd labels ---

    #[test]
    fn test_single_variant() {
        let e = make_enum("status", &["only"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("Only"));
    }

    #[test]
    fn test_many_variants() {
        let e = make_enum("status", &["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("A,"));
        assert!(code.contains("J,"));
    }

    #[test]
    fn test_variant_with_digits() {
        let e = make_enum("version", &["v2"]);
        let code = render(&e, DatabaseKind::Postgres);
        assert!(code.contains("V2"));
    }
}