sea_bae/
lib.rs

1//! `sea-bae` is a crate for proc macro authors, which simplifies parsing of attributes. It is
2//! heavily inspired by [`darling`](https://crates.io/crates/darling) but has a significantly
3//! simpler API.
4//!
5//! ```rust
6//! use sea_bae::FromAttributes;
7//!
8//! #[derive(
9//!     Debug,
10//!     Eq,
11//!     PartialEq,
12//!
13//!     // This will add two functions:
14//!     // ```
15//!     // fn from_attributes(attrs: &[syn::Attribute]) -> Result<MyAttr, syn::Error>
16//!     // fn try_from_attributes(attrs: &[syn::Attribute]) -> Result<Option<MyAttr>, syn::Error>
17//!     // ```
18//!     //
//!     // `try_from_attributes` returns `Ok(None)` if the attribute is missing, `Ok(Some(_))` if
//!     // it's there and is valid, and `Err(_)` otherwise.
21//!     FromAttributes,
22//! )]
23//! pub struct MyAttr {
24//!     // Anything that implements `syn::parse::Parse` is supported.
25//!     mandatory_type: syn::Type,
26//!     mandatory_ident: syn::Ident,
27//!
28//!     // Fields wrapped in `Option` are optional and default to `None` if
29//!     // not specified in the attribute.
30//!     optional_missing: Option<syn::Type>,
31//!     optional_given: Option<syn::Type>,
32//!
33//!     // A "switch" is something that doesn't take arguments.
//!     // All fields with type `Option<()>` are considered switches.
35//!     // They default to `None`.
36//!     switch: Option<()>,
37//! }
38//!
39//! // `MyAttr` is now equipped to parse attributes named `my_attr`. For example:
40//! //
41//! //     #[my_attr(
42//! //         switch,
43//! //         mandatory_ident = foo,
44//! //         mandatory_type = SomeType,
45//! //         optional_given = OtherType,
46//! //     )]
47//! //     struct Foo {
48//! //         ...
49//! //     }
50//!
//! // The input and output types would normally be `proc_macro::TokenStream`, but that
//! // type can only be used while the compiler is running a procedural macro.
53//! fn my_proc_macro(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
54//!     let item_struct = syn::parse2::<syn::ItemStruct>(input).unwrap();
55//!
56//!     let my_attr = MyAttr::from_attributes(&item_struct.attrs).unwrap();
57//!
58//!     assert_eq!(
59//!         my_attr.mandatory_type,
60//!         syn::parse_str::<syn::Type>("SomeType").unwrap()
61//!     );
62//!
63//!     assert_eq!(my_attr.optional_missing, None);
64//!
65//!     assert_eq!(
66//!         my_attr.optional_given,
67//!         Some(syn::parse_str::<syn::Type>("OtherType").unwrap())
68//!     );
69//!
70//!     assert_eq!(my_attr.mandatory_ident, syn::parse_str::<syn::Ident>("foo").unwrap());
71//!
72//!     assert_eq!(my_attr.switch.is_some(), true);
73//!
74//!     // ...
75//!     #
76//!     # quote::quote! {}
77//! }
78//! #
79//! # fn main() {
80//! #     let code = quote::quote! {
81//! #         #[other_random_attr]
82//! #         #[my_attr(
83//! #             switch,
84//! #             mandatory_ident = foo,
85//! #             mandatory_type = SomeType,
86//! #             optional_given = OtherType,
87//! #         )]
88//! #         struct Foo;
89//! #     };
90//! #     my_proc_macro(code);
91//! # }
92//! ```
93
94#![doc(html_root_url = "https://docs.rs/sea-bae/0.2.0")]
95#![allow(clippy::let_and_return)]
96#![deny(
97    unused_variables,
98    dead_code,
99    unused_must_use,
100    unused_imports,
101    missing_docs
102)]
103
104extern crate proc_macro;
105
106use heck::ToSnakeCase;
107use proc_macro2::TokenStream;
108use proc_macro_error2::{abort, proc_macro_error};
109use quote::*;
110use syn::{spanned::Spanned, *};
111
112/// See root module docs for more info.
113#[proc_macro_derive(FromAttributes, attributes())]
114#[proc_macro_error]
115pub fn from_attributes(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
116    let item = parse_macro_input!(input as ItemStruct);
117    FromAttributes::new(item).expand().into()
118}
119
120#[derive(Debug)]
121struct FromAttributes {
122    item: ItemStruct,
123    tokens: TokenStream,
124}
125
126impl FromAttributes {
127    fn new(item: ItemStruct) -> Self {
128        Self {
129            item,
130            tokens: TokenStream::new(),
131        }
132    }
133
134    fn expand(mut self) -> TokenStream {
135        self.expand_from_attributes_method();
136        self.expand_parse_impl();
137
138        if std::env::var("BAE_DEBUG").is_ok() {
139            eprintln!("{}", self.tokens);
140        }
141
142        self.tokens
143    }
144
145    fn struct_name(&self) -> &Ident {
146        &self.item.ident
147    }
148
149    fn attr_name(&self) -> LitStr {
150        let struct_name = self.struct_name();
151        let name = struct_name.to_string().to_snake_case();
152        LitStr::new(&name, struct_name.span())
153    }
154
155    fn expand_from_attributes_method(&mut self) {
156        let struct_name = self.struct_name();
157        let attr_name = self.attr_name();
158
159        let code = quote! {
160            impl #struct_name {
161                pub fn try_from_attributes(attrs: &[syn::Attribute]) -> syn::Result<Option<Self>> {
162                    use syn::spanned::Spanned;
163
164                    for attr in attrs {
165                        if attr.path().is_ident(#attr_name) {
166                            return Some(attr.parse_args::<Self>()).transpose()
167                        }
168                    }
169
170                    Ok(None)
171                }
172
173                pub fn from_attributes(attrs: &[syn::Attribute]) -> syn::Result<Self> {
174                    if let Some(attr) = Self::try_from_attributes(attrs)? {
175                        Ok(attr)
176                    } else {
177                        Err(syn::Error::new(
178                            proc_macro2::Span::call_site(),
179                            &format!("missing attribute `#[{}]`", #attr_name),
180                        ))
181                    }
182                }
183            }
184        };
185        self.tokens.extend(code);
186    }
187
188    fn expand_parse_impl(&mut self) {
189        let struct_name = self.struct_name();
190        let attr_name = self.attr_name();
191
192        let variable_declarations = self.item.fields.iter().map(|field| {
193            let name = &field.ident;
194            quote! { let mut #name = std::option::Option::None; }
195        });
196
197        let match_arms = self.item.fields.iter().map(|field| {
198            let field_name = get_field_name(field);
199            let pattern = LitStr::new(&field_name.to_string(), field.span());
200
201            if field_is_switch(field) {
202                quote! {
203                    #pattern => {
204                        #field_name = std::option::Option::Some(());
205                    }
206                }
207            } else {
208                quote! {
209                    #pattern => {
210                        input.parse::<syn::Token![=]>()?;
211                        #field_name = std::option::Option::Some(input.parse()?);
212                    }
213                }
214            }
215        });
216
217        let unwrap_mandatory_fields = self
218            .item
219            .fields
220            .iter()
221            .filter(|field| !field_is_optional(field))
222            .map(|field| {
223                let field_name = get_field_name(field);
224                let arg_name = LitStr::new(&field_name.to_string(), field.span());
225
226                quote! {
227                    let #field_name = if let std::option::Option::Some(#field_name) = #field_name {
228                        #field_name
229                    } else {
230                        return syn::Result::Err(
231                            input.error(
232                                &format!("`#[{}]` is missing `{}` argument", #attr_name, #arg_name),
233                            )
234                        );
235                    };
236                }
237            });
238
239        let set_fields = self.item.fields.iter().map(|field| {
240            let field_name = get_field_name(field);
241            quote! { #field_name, }
242        });
243
244        let mut supported_args = self
245            .item
246            .fields
247            .iter()
248            .map(|field| get_field_name(field))
249            .map(|field_name| format!("`{}`", field_name))
250            .collect::<Vec<_>>();
251        supported_args.sort_unstable();
252        let supported_args = supported_args.join(", ");
253
254        let code = quote! {
255            impl syn::parse::Parse for #struct_name {
256                #[allow(unreachable_code, unused_imports, unused_variables)]
257                fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
258                    #(#variable_declarations)*
259
260                    while !input.is_empty() {
261                        let bae_attr_ident = input.parse::<syn::Ident>()?;
262
263                        match &*bae_attr_ident.to_string() {
264                            #(#match_arms)*
265                            other => {
266                                return syn::Result::Err(
267                                    syn::Error::new(
268                                        bae_attr_ident.span(),
269                                        &format!(
270                                            "`#[{}]` got unknown `{}` argument. Supported arguments are {}",
271                                            #attr_name,
272                                            other,
273                                            #supported_args,
274                                        ),
275                                    )
276                                );
277                            }
278                        }
279
280                        input.parse::<syn::Token![,]>().ok();
281                    }
282
283                    #(#unwrap_mandatory_fields)*
284
285                    syn::Result::Ok(Self { #(#set_fields)* })
286                }
287            }
288        };
289        self.tokens.extend(code);
290    }
291}
292
293fn get_field_name(field: &Field) -> &Ident {
294    field
295        .ident
296        .as_ref()
297        .unwrap_or_else(|| abort!(field.span(), "Field without a name"))
298}
299
300fn field_is_optional(field: &Field) -> bool {
301    let type_path = if let Type::Path(type_path) = &field.ty {
302        type_path
303    } else {
304        return false;
305    };
306
307    let ident = &type_path
308        .path
309        .segments
310        .last()
311        .unwrap_or_else(|| abort!(field.span(), "Empty type path"))
312        .ident;
313
314    ident == "Option"
315}
316
317fn field_is_switch(field: &Field) -> bool {
318    let unit_type = syn::parse_str::<Type>("()").unwrap();
319    inner_type(&field.ty) == Some(&unit_type)
320}
321
322fn inner_type(ty: &Type) -> Option<&Type> {
323    let type_path = if let Type::Path(type_path) = ty {
324        type_path
325    } else {
326        return None;
327    };
328
329    let ty_args = &type_path
330        .path
331        .segments
332        .last()
333        .unwrap_or_else(|| abort!(ty.span(), "Empty type path"))
334        .arguments;
335
336    let ty_args = if let PathArguments::AngleBracketed(ty_args) = ty_args {
337        ty_args
338    } else {
339        return None;
340    };
341
342    let generic_arg = &ty_args
343        .args
344        .last()
345        .unwrap_or_else(|| abort!(ty_args.span(), "Empty generic argument"));
346
347    let ty = if let GenericArgument::Type(ty) = generic_arg {
348        ty
349    } else {
350        return None;
351    };
352
353    Some(ty)
354}
355
#[cfg(test)]
mod test {
    /// Runs the trybuild UI suite. Every file in `tests/compile_pass` must compile,
    /// and every file in `tests/compile_fail` must fail to compile.
    #[test]
    fn test_ui() {
        let t = trybuild::TestCases::new();
        t.pass("tests/compile_pass/*.rs");
        t.compile_fail("tests/compile_fail/*.rs");
    }
}