1use std::convert::{TryFrom, TryInto};
2
3use darling::FromMeta;
4use heck::{ToPascalCase, ToSnakeCase};
5use proc_macro::{self, TokenStream};
6use quote::{quote, quote_spanned};
7use syn::{
8 parse_macro_input, spanned::Spanned, Attribute, Data, DataEnum, DataStruct, DeriveInput,
9 Fields, Ident, Variant,
10};
11
12mod iden;
13
14use self::iden::{
15 attr::IdenAttr, error::ErrorMsg, path::IdenPath, write_arm::IdenVariant, DeriveIden,
16 DeriveIdenStatic,
17};
18
19#[proc_macro_derive(Iden, attributes(iden, method))]
20pub fn derive_iden(input: TokenStream) -> TokenStream {
21 let DeriveInput {
22 ident, data, attrs, ..
23 } = parse_macro_input!(input);
24 let table_name = match get_table_name(&ident, attrs) {
25 Ok(v) => v,
26 Err(e) => return e.to_compile_error().into(),
27 };
28
29 let variants =
31 match data {
32 syn::Data::Enum(DataEnum { variants, .. }) => variants,
33 syn::Data::Struct(DataStruct {
34 fields: Fields::Unit,
35 ..
36 }) => return impl_iden_for_unit_struct(&ident, &table_name).into(),
37 _ => return quote_spanned! {
38 ident.span() => compile_error!("you can only derive Iden on enums or unit structs");
39 }
40 .into(),
41 };
42
43 if variants.is_empty() {
44 return TokenStream::new();
45 }
46
47 let output = impl_iden_for_enum(&ident, &table_name, variants.iter());
48
49 output.into()
50}
51
52#[proc_macro_derive(IdenStatic, attributes(iden, method))]
53pub fn derive_iden_static(input: TokenStream) -> TokenStream {
54 let sea_query_path = sea_query_path();
55
56 let DeriveInput {
57 ident, data, attrs, ..
58 } = parse_macro_input!(input);
59
60 let table_name = match get_table_name(&ident, attrs) {
61 Ok(v) => v,
62 Err(e) => return e.to_compile_error().into(),
63 };
64
65 let variants =
67 match data {
68 syn::Data::Enum(DataEnum { variants, .. }) => variants,
69 syn::Data::Struct(DataStruct {
70 fields: Fields::Unit,
71 ..
72 }) => {
73 let impl_iden = impl_iden_for_unit_struct(&ident, &table_name);
74
75 return quote! {
76 #impl_iden
77
78 impl #sea_query_path::IdenStatic for #ident {
79 fn as_str(&self) -> &'static str {
80 #table_name
81 }
82 }
83
84 impl std::convert::AsRef<str> for #ident {
85 fn as_ref(&self) -> &str {
86 self.as_str()
87 }
88 }
89 }
90 .into();
91 }
92 _ => return quote_spanned! {
93 ident.span() => compile_error!("you can only derive Iden on enums or unit structs");
94 }
95 .into(),
96 };
97
98 if variants.is_empty() {
99 return TokenStream::new();
100 }
101
102 let impl_iden = impl_iden_for_enum(&ident, &table_name, variants.iter());
103
104 let match_arms = match variants
105 .iter()
106 .map(|v| (table_name.as_str(), v))
107 .map(IdenVariant::<DeriveIdenStatic>::try_from)
108 .collect::<syn::Result<Vec<_>>>()
109 {
110 Ok(v) => quote! { #(#v),* },
111 Err(e) => return e.to_compile_error().into(),
112 };
113
114 let output = quote! {
115 #impl_iden
116
117 impl #sea_query_path::IdenStatic for #ident {
118 fn as_str(&self) -> &'static str {
119 match self {
120 #match_arms
121 }
122 }
123 }
124
125 impl std::convert::AsRef<str> for #ident {
126 fn as_ref(&self) -> &'static str {
127 self.as_str()
128 }
129 }
130 };
131
132 output.into()
133}
134
135fn find_attr(attrs: &[Attribute]) -> Option<&Attribute> {
136 attrs.iter().find(|attr| {
137 attr.path().is_ident(&IdenPath::Iden) || attr.path().is_ident(&IdenPath::Method)
138 })
139}
140
141fn get_table_name(ident: &proc_macro2::Ident, attrs: Vec<Attribute>) -> Result<String, syn::Error> {
142 let table_name = match find_attr(&attrs) {
143 Some(att) => match att.try_into()? {
144 IdenAttr::Rename(lit) => lit,
145 _ => return Err(syn::Error::new_spanned(att, ErrorMsg::ContainerAttr)),
146 },
147 None => ident.to_string().to_snake_case(),
148 };
149 Ok(table_name)
150}
151
/// Returns `true` when `name` looks like a plain identifier.
///
/// A plain identifier starts with an ASCII letter or `_` and contains only ASCII
/// alphanumerics and `_`. The empty string also returns `true`.
fn must_be_valid_iden(name: &str) -> bool {
    let mut chars = name.chars();
    match chars.next() {
        None => true,
        // The first character already satisfies the stricter rule, so only the rest needs
        // the alphanumeric check.
        Some(first) if first == '_' || first.is_ascii_alphabetic() => {
            chars.all(|c| c == '_' || c.is_ascii_alphanumeric())
        }
        Some(_) => false,
    }
}
159
160fn impl_iden_for_unit_struct(
161 ident: &proc_macro2::Ident,
162 table_name: &str,
163) -> proc_macro2::TokenStream {
164 let sea_query_path = sea_query_path();
165
166 let prepare = if must_be_valid_iden(table_name) {
167 quote! {
168 fn prepare(&self, s: &mut dyn ::std::fmt::Write, q: #sea_query_path::Quote) {
169 write!(s, "{}", q.left()).unwrap();
170 self.unquoted(s);
171 write!(s, "{}", q.right()).unwrap();
172 }
173 }
174 } else {
175 quote! {}
176 };
177
178 quote! {
179 impl #sea_query_path::Iden for #ident {
180 #prepare
181
182 fn unquoted(&self, s: &mut dyn ::std::fmt::Write) {
183 write!(s, #table_name).unwrap();
184 }
185 }
186 }
187}
188
189fn impl_iden_for_enum<'a, T>(
190 ident: &proc_macro2::Ident,
191 table_name: &str,
192 variants: T,
193) -> proc_macro2::TokenStream
194where
195 T: Iterator<Item = &'a Variant>,
196{
197 let sea_query_path = sea_query_path();
198
199 let mut is_all_valid = true;
200
201 let match_arms = match variants
202 .map(|v| (table_name, v))
203 .map(|v| {
204 let v = IdenVariant::<DeriveIden>::try_from(v)?;
205 is_all_valid &= v.must_be_valid_iden();
206 Ok(v)
207 })
208 .collect::<syn::Result<Vec<_>>>()
209 {
210 Ok(v) => quote! { #(#v),* },
211 Err(e) => return e.to_compile_error(),
212 };
213
214 let prepare = if is_all_valid {
215 quote! {
216 fn prepare(&self, s: &mut dyn ::std::fmt::Write, q: #sea_query_path::Quote) {
217 write!(s, "{}", q.left()).unwrap();
218 self.unquoted(s);
219 write!(s, "{}", q.right()).unwrap();
220 }
221 }
222 } else {
223 quote! {}
224 };
225
226 quote! {
227 impl #sea_query_path::Iden for #ident {
228 #prepare
229
230 fn unquoted(&self, s: &mut dyn ::std::fmt::Write) {
231 match self {
232 #match_arms
233 };
234 }
235 }
236 }
237}
238
239fn sea_query_path() -> proc_macro2::TokenStream {
240 if cfg!(feature = "sea-orm") {
241 quote!(sea_orm::sea_query)
242 } else {
243 quote!(sea_query)
244 }
245}
246
247struct NamingHolder {
248 pub default: Ident,
249 pub pascal: Ident,
250}
251
252#[derive(Debug, FromMeta)]
253struct GenEnumArgs {
254 #[darling(default)]
255 pub prefix: Option<String>,
256 #[darling(default)]
257 pub suffix: Option<String>,
258 #[darling(default)]
259 pub crate_name: Option<String>,
260 #[darling(default)]
261 pub table_name: Option<String>,
262}
263
/// Default `prefix` for `#[enum_def]`: empty, so the enum name begins with the struct name.
const DEFAULT_PREFIX: &str = "";
/// Default `suffix` for `#[enum_def]`, e.g. a struct `Character` yields `CharacterIden`.
const DEFAULT_SUFFIX: &str = "Iden";
/// Default crate expected to provide the `Iden` and `IdenStatic` traits.
const DEFAULT_CRATE_NAME: &str = "sea_query";
267
268impl Default for GenEnumArgs {
269 fn default() -> Self {
270 Self {
271 prefix: Some(DEFAULT_PREFIX.to_string()),
272 suffix: Some(DEFAULT_SUFFIX.to_string()),
273 crate_name: Some(DEFAULT_CRATE_NAME.to_string()),
274 table_name: None,
275 }
276 }
277}
278
279#[proc_macro_attribute]
280pub fn enum_def(args: TokenStream, input: TokenStream) -> TokenStream {
281 let attr_args = match darling::ast::NestedMeta::parse_meta_list(args.into()) {
282 Ok(v) => v,
283 Err(e) => {
284 return TokenStream::from(darling::Error::from(e).write_errors());
285 }
286 };
287 let input = parse_macro_input!(input as DeriveInput);
288
289 let args = match GenEnumArgs::from_list(&attr_args) {
290 Ok(v) => v,
291 Err(e) => {
292 return TokenStream::from(e.write_errors());
293 }
294 };
295
296 let fields =
297 match &input.data {
298 Data::Struct(DataStruct {
299 fields: Fields::Named(fields),
300 ..
301 }) => &fields.named,
302 _ => return quote_spanned! {
303 input.span() => compile_error!("#[enum_def] can only be used on non-tuple structs");
304 }
305 .into(),
306 };
307
308 let field_names: Vec<NamingHolder> = fields
309 .iter()
310 .map(|field| {
311 let ident = field.ident.as_ref().unwrap();
312 NamingHolder {
313 default: ident.clone(),
314 pascal: Ident::new(ident.to_string().to_pascal_case().as_str(), ident.span()),
315 }
316 })
317 .collect();
318
319 let table_name = Ident::new(
320 args.table_name
321 .unwrap_or_else(|| input.ident.to_string().to_snake_case())
322 .as_str(),
323 input.ident.span(),
324 );
325
326 let enum_name = quote::format_ident!(
327 "{}{}{}",
328 args.prefix.unwrap_or_else(|| DEFAULT_PREFIX.to_string()),
329 &input.ident,
330 args.suffix.unwrap_or_else(|| DEFAULT_SUFFIX.to_string())
331 );
332 let pascal_def_names = field_names.iter().map(|field| &field.pascal);
333 let pascal_def_names2 = pascal_def_names.clone();
334 let default_names = field_names.iter().map(|field| &field.default);
335 let default_names2 = default_names.clone();
336 let import_name = Ident::new(
337 args.crate_name
338 .unwrap_or_else(|| DEFAULT_CRATE_NAME.to_string())
339 .as_str(),
340 input.span(),
341 );
342
343 TokenStream::from(quote::quote! {
344 #input
345
346 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
347 pub enum #enum_name {
348 Table,
349 #(#pascal_def_names,)*
350 }
351
352 impl #import_name::IdenStatic for #enum_name {
353 fn as_str(&self) -> &'static str {
354 match self {
355 #enum_name::Table => stringify!(#table_name),
356 #(#enum_name::#pascal_def_names2 => stringify!(#default_names2)),*
357 }
358 }
359 }
360
361 impl #import_name::Iden for #enum_name {
362 fn unquoted(&self, s: &mut dyn sea_query::Write) {
363 write!(s, "{}", <Self as #import_name::IdenStatic>::as_str(&self)).unwrap();
364 }
365 }
366
367 impl ::std::convert::AsRef<str> for #enum_name {
368 fn as_ref(&self) -> &str {
369 <Self as #import_name::IdenStatic>::as_str(&self)
370 }
371 }
372 })
373}