structopt_toml_derive/
lib.rs

1extern crate proc_macro;
2extern crate syn;
3#[macro_use]
4extern crate quote;
5
6use heck::ToKebabCase;
7use proc_macro::TokenStream;
8use proc_macro2::TokenTree;
9use syn::parse::{Parse, ParseStream};
10use syn::punctuated::Punctuated;
11use syn::token::Comma;
12use syn::{buffer::Cursor, DataStruct, DeriveInput, Field, Ident, LitStr};
13
14#[proc_macro_derive(StructOptToml, attributes(structopt))]
15pub fn structopt_toml(input: TokenStream) -> TokenStream {
16    let input: DeriveInput = syn::parse(input).unwrap();
17    let gen = impl_structopt_toml(&input);
18    gen.into()
19}
20
21fn impl_structopt_toml(input: &DeriveInput) -> proc_macro2::TokenStream {
22    use syn::Data::*;
23
24    let struct_name = &input.ident;
25    let inner_impl = match input.data {
26        Struct(DataStruct {
27            fields: syn::Fields::Named(ref fields),
28            ..
29        }) => impl_structopt_for_struct(struct_name, &fields.named),
30        _ => panic!("structopt_toml only supports non-tuple struct"),
31    };
32
33    quote!(#inner_impl)
34}
35
36fn impl_structopt_for_struct(
37    name: &Ident,
38    fields: &Punctuated<Field, Comma>,
39) -> proc_macro2::TokenStream {
40    let merged_fields = gen_merged_fields(fields);
41
42    quote! {
43        impl ::structopt_toml::StructOptToml for #name {
44            fn merge<'a>(from_toml: Self, from_args: Self, args: &::structopt_toml::clap::ArgMatches) -> Self where
45                Self: Sized,
46                Self: ::structopt_toml::structopt::StructOpt,
47                Self: ::structopt_toml::serde::de::Deserialize<'a>
48            {
49                Self {
50                    #merged_fields
51                }
52            }
53        }
54
55        impl Default for #name {
56            fn default() -> Self {
57                #name::from_args()
58            }
59        }
60    }
61}
62
63fn gen_merged_fields(fields: &Punctuated<Field, Comma>) -> proc_macro2::TokenStream {
64    let fields = fields.iter().map(|field| {
65        let explicit_name = load_explicit_name(field);
66
67        // If the field is decorated with `#[structopt(flatten)]` we have to treat it differently.
68        // We can't check its existence with `args.is_present` and `args.occurrences_of`
69        // and instead we delegate and call its own `StructOptToml` implementation of `merge`
70        let is_flatten = is_flatten(field);
71
72        // by default the clap arg name is the field name in kebab-case, unless overwritten with `name=<value>`
73        let field_name = field.ident.as_ref().unwrap();
74        let field_type = field.ty.clone();
75        let name_str = explicit_name.unwrap_or_else(|| format!("{}", field_name).to_kebab_case());
76        let structopt_name = LitStr::new(&name_str, field_name.span());
77        if is_flatten {
78            quote!(
79                #field_name: {
80                    <#field_type as ::structopt_toml::StructOptToml>::merge(
81                        from_toml.#field_name,
82                        from_args.#field_name,
83                        args
84                    )
85                }
86            )
87        } else {
88            quote!(
89                #field_name: {
90                    if args.is_present(#structopt_name) && args.occurrences_of(#structopt_name) > 0 {
91                        from_args.#field_name
92                    } else {
93                        from_toml.#field_name
94                    }
95                }
96            )
97        }
98    });
99    quote! (
100        #( #fields ),*
101    )
102}
103
104/// Loads the structopt name from the strcutopt attribute.
105/// i.e. from an attribute of the form `#[structopt(..., name = "some-name", ...)]`
106fn load_explicit_name(field: &Field) -> Option<String> {
107    field
108        .attrs
109        .iter()
110        .filter(|&attr| attr.path.is_ident("structopt"))
111        .filter_map(|attr| {
112            // extract parentheses
113            let ts = attr.parse_args().ok()?;
114            // find name = `value` in attribute
115            syn::parse2::<NameVal>(ts).map(|nv| nv.0).ok()
116        })
117        .next()
118}
119
120/// Checks whether the attribute is marked as flattened
121/// i.e. `#[structopt(flatten)]`
122fn is_flatten(field: &Field) -> bool {
123    field
124        .attrs
125        .iter()
126        .filter(|&attr| attr.path.is_ident("structopt"))
127        .filter_map(|attr| attr.parse_meta().ok())
128        .map(|meta| {
129            let list = match meta {
130                syn::Meta::List(list) => list,
131                _ => return false,
132            };
133            let nested = match list.nested.first() {
134                Some(nested) => nested,
135                _ => return false,
136            };
137            let inner_meta = match nested {
138                syn::NestedMeta::Meta(inner_meta) => inner_meta,
139                _ => return false,
140            };
141            let path = match inner_meta {
142                syn::Meta::Path(path) => path,
143                _ => return false,
144            };
145            path.is_ident("flatten")
146        })
147        .next()
148        .unwrap_or(false)
149}
150
151#[derive(Debug)]
152struct NameVal(String);
153
154impl Parse for NameVal {
155    fn parse(input: ParseStream) -> syn::Result<Self> {
156        #[derive(PartialEq, Eq, Debug)]
157        enum Match {
158            NameToken,
159            PunctEq,
160            LitVal,
161        }
162        let mut state = Match::NameToken;
163        let result = input.step(|cursor| {
164            let mut rest = *cursor;
165            while let Some((tt, next)) = rest.token_tree() {
166                match tt {
167                    TokenTree::Ident(ident) if ident == "name" && state == Match::NameToken => {
168                        state = Match::PunctEq;
169                    }
170                    TokenTree::Punct(punct)
171                        if punct.as_char() == '=' && state == Match::PunctEq =>
172                    {
173                        state = Match::LitVal;
174                    }
175                    TokenTree::Literal(lit) if state == Match::LitVal => {
176                        return Ok((lit.to_string().replace("\"", ""), Cursor::empty()));
177                    }
178                    _ => {
179                        // on first incorrect token reset
180                        state = Match::NameToken;
181                    }
182                }
183                rest = next;
184            }
185            Err(cursor.error("End reached"))
186        });
187        result.map(Self).map_err(|_| input.error("Not found"))
188    }
189}