sea_query_derive/
lib.rs

1use std::convert::{TryFrom, TryInto};
2
3use darling::FromMeta;
4use heck::{ToPascalCase, ToSnakeCase};
5use proc_macro::{self, TokenStream};
6use quote::{quote, quote_spanned};
7use syn::{
8    parse_macro_input, spanned::Spanned, Attribute, Data, DataEnum, DataStruct, DeriveInput,
9    Fields, Ident, Variant,
10};
11
12mod iden;
13
14use self::iden::{
15    attr::IdenAttr, error::ErrorMsg, path::IdenPath, write_arm::IdenVariant, DeriveIden,
16    DeriveIdenStatic,
17};
18
19#[proc_macro_derive(Iden, attributes(iden, method))]
20pub fn derive_iden(input: TokenStream) -> TokenStream {
21    let DeriveInput {
22        ident, data, attrs, ..
23    } = parse_macro_input!(input);
24    let table_name = match get_table_name(&ident, attrs) {
25        Ok(v) => v,
26        Err(e) => return e.to_compile_error().into(),
27    };
28
29    // Currently we only support enums and unit structs
30    let variants =
31        match data {
32            syn::Data::Enum(DataEnum { variants, .. }) => variants,
33            syn::Data::Struct(DataStruct {
34                fields: Fields::Unit,
35                ..
36            }) => return impl_iden_for_unit_struct(&ident, &table_name).into(),
37            _ => return quote_spanned! {
38                ident.span() => compile_error!("you can only derive Iden on enums or unit structs");
39            }
40            .into(),
41        };
42
43    if variants.is_empty() {
44        return TokenStream::new();
45    }
46
47    let output = impl_iden_for_enum(&ident, &table_name, variants.iter());
48
49    output.into()
50}
51
52#[proc_macro_derive(IdenStatic, attributes(iden, method))]
53pub fn derive_iden_static(input: TokenStream) -> TokenStream {
54    let sea_query_path = sea_query_path();
55
56    let DeriveInput {
57        ident, data, attrs, ..
58    } = parse_macro_input!(input);
59
60    let table_name = match get_table_name(&ident, attrs) {
61        Ok(v) => v,
62        Err(e) => return e.to_compile_error().into(),
63    };
64
65    // Currently we only support enums and unit structs
66    let variants =
67        match data {
68            syn::Data::Enum(DataEnum { variants, .. }) => variants,
69            syn::Data::Struct(DataStruct {
70                fields: Fields::Unit,
71                ..
72            }) => {
73                let impl_iden = impl_iden_for_unit_struct(&ident, &table_name);
74
75                return quote! {
76                    #impl_iden
77
78                    impl #sea_query_path::IdenStatic for #ident {
79                        fn as_str(&self) -> &'static str {
80                            #table_name
81                        }
82                    }
83
84                    impl std::convert::AsRef<str> for #ident {
85                        fn as_ref(&self) -> &str {
86                            self.as_str()
87                        }
88                    }
89                }
90                .into();
91            }
92            _ => return quote_spanned! {
93                ident.span() => compile_error!("you can only derive Iden on enums or unit structs");
94            }
95            .into(),
96        };
97
98    if variants.is_empty() {
99        return TokenStream::new();
100    }
101
102    let impl_iden = impl_iden_for_enum(&ident, &table_name, variants.iter());
103
104    let match_arms = match variants
105        .iter()
106        .map(|v| (table_name.as_str(), v))
107        .map(IdenVariant::<DeriveIdenStatic>::try_from)
108        .collect::<syn::Result<Vec<_>>>()
109    {
110        Ok(v) => quote! { #(#v),* },
111        Err(e) => return e.to_compile_error().into(),
112    };
113
114    let output = quote! {
115        #impl_iden
116
117        impl #sea_query_path::IdenStatic for #ident {
118            fn as_str(&self) -> &'static str {
119                match self {
120                    #match_arms
121                }
122            }
123        }
124
125        impl std::convert::AsRef<str> for #ident {
126            fn as_ref(&self) -> &'static str {
127                self.as_str()
128            }
129        }
130    };
131
132    output.into()
133}
134
135fn find_attr(attrs: &[Attribute]) -> Option<&Attribute> {
136    attrs.iter().find(|attr| {
137        attr.path().is_ident(&IdenPath::Iden) || attr.path().is_ident(&IdenPath::Method)
138    })
139}
140
141fn get_table_name(ident: &proc_macro2::Ident, attrs: Vec<Attribute>) -> Result<String, syn::Error> {
142    let table_name = match find_attr(&attrs) {
143        Some(att) => match att.try_into()? {
144            IdenAttr::Rename(lit) => lit,
145            _ => return Err(syn::Error::new_spanned(att, ErrorMsg::ContainerAttr)),
146        },
147        None => ident.to_string().to_snake_case(),
148    };
149    Ok(table_name)
150}
151
/// Returns `true` when `name` contains only ASCII letters (either case),
/// ASCII digits and `_`, and does not start with a digit.
/// An empty string counts as valid.
///
/// The generators use this to decide whether a `prepare` override may simply
/// wrap the unquoted name in the left/right quote characters.
fn must_be_valid_iden(name: &str) -> bool {
    let is_word_char = |c: char| c == '_' || c.is_ascii_alphanumeric();
    let leads_with_digit = name.chars().next().map_or(false, |c| c.is_ascii_digit());
    !leads_with_digit && name.chars().all(is_word_char)
}
159
160fn impl_iden_for_unit_struct(
161    ident: &proc_macro2::Ident,
162    table_name: &str,
163) -> proc_macro2::TokenStream {
164    let sea_query_path = sea_query_path();
165
166    let prepare = if must_be_valid_iden(table_name) {
167        quote! {
168            fn prepare(&self, s: &mut dyn ::std::fmt::Write, q: #sea_query_path::Quote) {
169                write!(s, "{}", q.left()).unwrap();
170                self.unquoted(s);
171                write!(s, "{}", q.right()).unwrap();
172            }
173        }
174    } else {
175        quote! {}
176    };
177
178    quote! {
179        impl #sea_query_path::Iden for #ident {
180            #prepare
181
182            fn unquoted(&self, s: &mut dyn ::std::fmt::Write) {
183                write!(s, #table_name).unwrap();
184            }
185        }
186    }
187}
188
189fn impl_iden_for_enum<'a, T>(
190    ident: &proc_macro2::Ident,
191    table_name: &str,
192    variants: T,
193) -> proc_macro2::TokenStream
194where
195    T: Iterator<Item = &'a Variant>,
196{
197    let sea_query_path = sea_query_path();
198
199    let mut is_all_valid = true;
200
201    let match_arms = match variants
202        .map(|v| (table_name, v))
203        .map(|v| {
204            let v = IdenVariant::<DeriveIden>::try_from(v)?;
205            is_all_valid &= v.must_be_valid_iden();
206            Ok(v)
207        })
208        .collect::<syn::Result<Vec<_>>>()
209    {
210        Ok(v) => quote! { #(#v),* },
211        Err(e) => return e.to_compile_error(),
212    };
213
214    let prepare = if is_all_valid {
215        quote! {
216            fn prepare(&self, s: &mut dyn ::std::fmt::Write, q: #sea_query_path::Quote) {
217                write!(s, "{}", q.left()).unwrap();
218                self.unquoted(s);
219                write!(s, "{}", q.right()).unwrap();
220            }
221        }
222    } else {
223        quote! {}
224    };
225
226    quote! {
227        impl #sea_query_path::Iden for #ident {
228            #prepare
229
230            fn unquoted(&self, s: &mut dyn ::std::fmt::Write) {
231                match self {
232                    #match_arms
233                };
234            }
235        }
236    }
237}
238
239fn sea_query_path() -> proc_macro2::TokenStream {
240    if cfg!(feature = "sea-orm") {
241        quote!(sea_orm::sea_query)
242    } else {
243        quote!(sea_query)
244    }
245}
246
247struct NamingHolder {
248    pub default: Ident,
249    pub pascal: Ident,
250}
251
252#[derive(Debug, FromMeta)]
253struct GenEnumArgs {
254    #[darling(default)]
255    pub prefix: Option<String>,
256    #[darling(default)]
257    pub suffix: Option<String>,
258    #[darling(default)]
259    pub crate_name: Option<String>,
260    #[darling(default)]
261    pub table_name: Option<String>,
262}
263
/// Crate name `#[enum_def]` targets when `crate_name` is not given.
const DEFAULT_CRATE_NAME: &str = "sea_query";
/// Enum-name prefix `#[enum_def]` uses when `prefix` is not given (none).
const DEFAULT_PREFIX: &str = "";
/// Enum-name suffix `#[enum_def]` uses when `suffix` is not given.
const DEFAULT_SUFFIX: &str = "Iden";
267
268impl Default for GenEnumArgs {
269    fn default() -> Self {
270        Self {
271            prefix: Some(DEFAULT_PREFIX.to_string()),
272            suffix: Some(DEFAULT_SUFFIX.to_string()),
273            crate_name: Some(DEFAULT_CRATE_NAME.to_string()),
274            table_name: None,
275        }
276    }
277}
278
279#[proc_macro_attribute]
280pub fn enum_def(args: TokenStream, input: TokenStream) -> TokenStream {
281    let attr_args = match darling::ast::NestedMeta::parse_meta_list(args.into()) {
282        Ok(v) => v,
283        Err(e) => {
284            return TokenStream::from(darling::Error::from(e).write_errors());
285        }
286    };
287    let input = parse_macro_input!(input as DeriveInput);
288
289    let args = match GenEnumArgs::from_list(&attr_args) {
290        Ok(v) => v,
291        Err(e) => {
292            return TokenStream::from(e.write_errors());
293        }
294    };
295
296    let fields =
297        match &input.data {
298            Data::Struct(DataStruct {
299                fields: Fields::Named(fields),
300                ..
301            }) => &fields.named,
302            _ => return quote_spanned! {
303                input.span() => compile_error!("#[enum_def] can only be used on non-tuple structs");
304            }
305            .into(),
306        };
307
308    let field_names: Vec<NamingHolder> = fields
309        .iter()
310        .map(|field| {
311            let ident = field.ident.as_ref().unwrap();
312            NamingHolder {
313                default: ident.clone(),
314                pascal: Ident::new(ident.to_string().to_pascal_case().as_str(), ident.span()),
315            }
316        })
317        .collect();
318
319    let table_name = Ident::new(
320        args.table_name
321            .unwrap_or_else(|| input.ident.to_string().to_snake_case())
322            .as_str(),
323        input.ident.span(),
324    );
325
326    let enum_name = quote::format_ident!(
327        "{}{}{}",
328        args.prefix.unwrap_or_else(|| DEFAULT_PREFIX.to_string()),
329        &input.ident,
330        args.suffix.unwrap_or_else(|| DEFAULT_SUFFIX.to_string())
331    );
332    let pascal_def_names = field_names.iter().map(|field| &field.pascal);
333    let pascal_def_names2 = pascal_def_names.clone();
334    let default_names = field_names.iter().map(|field| &field.default);
335    let default_names2 = default_names.clone();
336    let import_name = Ident::new(
337        args.crate_name
338            .unwrap_or_else(|| DEFAULT_CRATE_NAME.to_string())
339            .as_str(),
340        input.span(),
341    );
342
343    TokenStream::from(quote::quote! {
344        #input
345
346        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
347        pub enum #enum_name {
348            Table,
349            #(#pascal_def_names,)*
350        }
351
352        impl #import_name::IdenStatic for #enum_name {
353            fn as_str(&self) -> &'static str {
354                match self {
355                    #enum_name::Table => stringify!(#table_name),
356                    #(#enum_name::#pascal_def_names2 => stringify!(#default_names2)),*
357                }
358            }
359        }
360
361        impl #import_name::Iden for #enum_name {
362            fn unquoted(&self, s: &mut dyn sea_query::Write) {
363                write!(s, "{}", <Self as #import_name::IdenStatic>::as_str(&self)).unwrap();
364            }
365        }
366
367        impl ::std::convert::AsRef<str> for #enum_name {
368            fn as_ref(&self) -> &str {
369                <Self as #import_name::IdenStatic>::as_str(&self)
370            }
371        }
372    })
373}