structopt_yaml_derive/
lib.rs

1extern crate proc_macro;
2extern crate syn;
3#[macro_use]
4extern crate quote;
5
6use proc_macro::TokenStream;
7use proc_macro2::TokenTree;
8use syn::parse::{Parse, ParseStream};
9use syn::punctuated::Punctuated;
10use syn::token::Comma;
11use syn::{buffer::Cursor, DataStruct, DeriveInput, Field, Ident, LitStr};
12
13#[proc_macro_derive(StructOptYaml, attributes(structopt))]
14pub fn structopt_yaml(input: TokenStream) -> TokenStream {
15    let input: DeriveInput = syn::parse(input).unwrap();
16    let gen = impl_structopt_yaml(&input);
17    gen.into()
18}
19
20fn impl_structopt_yaml(input: &DeriveInput) -> proc_macro2::TokenStream {
21    use syn::Data::*;
22
23    let struct_name = &input.ident;
24    let inner_impl = match input.data {
25        Struct(DataStruct {
26            fields: syn::Fields::Named(ref fields),
27            ..
28        }) => impl_structopt_for_struct(struct_name, &fields.named),
29        _ => panic!("structopt_yaml only supports non-tuple struct"),
30    };
31
32    quote!(#inner_impl)
33}
34
35fn impl_structopt_for_struct(
36    name: &Ident,
37    fields: &Punctuated<Field, Comma>,
38) -> proc_macro2::TokenStream {
39    let merged_fields = gen_merged_fields(fields);
40
41    quote! {
42        impl ::structopt_yaml::StructOptYaml for #name {
43            fn merge<'a>(from_yaml: Self, from_args: Self, args: &::structopt_yaml::clap::ArgMatches) -> Self where
44                Self: Sized,
45                Self: ::structopt_yaml::structopt::StructOpt,
46                Self: ::structopt_yaml::serde::de::Deserialize<'a>
47            {
48                Self {
49                    #merged_fields
50                }
51            }
52        }
53
54        impl Default for #name {
55            fn default() -> Self {
56                #name::from_args()
57            }
58        }
59    }
60}
61
62fn gen_merged_fields(fields: &Punctuated<Field, Comma>) -> proc_macro2::TokenStream {
63    let fields = fields.iter().map(|field| {
64        let explicit_name = load_explicit_name(field);
65
66        // If the field is decorated with `#[structopt(flatten)]` we have to treat it differently.
67        // We can't check its existence with `args.is_present` and `args.occurrences_of`
68        // and instead we delegate and call its own `StructOptYaml` implementation of `merge`
69        let is_flatten = is_flatten(field);
70
71        // by default the clap arg name is the field name, unless overwritten with `name=<value>`
72        let field_name = field.ident.as_ref().unwrap();
73        let field_type = field.ty.clone();
74        let name_str = explicit_name.unwrap_or(format!("{}", field_name));
75        let structopt_name = LitStr::new(&name_str, field_name.span());
76        if is_flatten {
77            quote!(
78                #field_name: {
79                    <#field_type as ::structopt_yaml::StructOptYaml>::merge(
80                        from_yaml.#field_name,
81                        from_args.#field_name,
82                        args
83                    )
84                }
85            )
86        } else {
87            quote!(
88                #field_name: {
89                    if args.is_present(#structopt_name) && args.occurrences_of(#structopt_name) > 0 {
90                        from_args.#field_name
91                    } else {
92                        from_yaml.#field_name
93                    }
94                }
95            )
96        }
97    });
98    quote! (
99        #( #fields ),*
100    )
101}
102
103/// Loads the structopt name from the strcutopt attribute.
104/// i.e. from an attribute of the form `#[structopt(..., name = "some-name", ...)]`
105fn load_explicit_name(field: &Field) -> Option<String> {
106    field
107        .attrs
108        .iter()
109        .filter(|&attr| attr.path.is_ident("structopt"))
110        .filter_map(|attr| {
111            // extract parentheses
112            let ts = attr.parse_args().ok()?;
113            // find name = `value` in attribute
114            syn::parse2::<NameVal>(ts).map(|nv| nv.0).ok()
115        })
116        .nth(0)
117}
118
119/// Checks whether the attribute is marked as flattened
120/// i.e. `#[structopt(flatten)]`
121fn is_flatten(field: &Field) -> bool {
122    field
123        .attrs
124        .iter()
125        .filter(|&attr| attr.path.is_ident("structopt"))
126        .filter_map(|attr| attr.parse_meta().ok())
127        .map(|meta| {
128            let list = match meta {
129                syn::Meta::List(list) => list,
130                _ => return false,
131            };
132            let nested = match list.nested.first() {
133                Some(nested) => nested,
134                _ => return false,
135            };
136            let inner_meta = match nested {
137                syn::NestedMeta::Meta(inner_meta) => inner_meta,
138                _ => return false,
139            };
140            let path = match inner_meta {
141                syn::Meta::Path(path) => path,
142                _ => return false,
143            };
144            path.is_ident("flatten")
145        })
146        .nth(0)
147        .unwrap_or(false)
148}
149
150#[derive(Debug)]
151struct NameVal(String);
152
153impl Parse for NameVal {
154    fn parse(input: ParseStream) -> syn::Result<Self> {
155        #[derive(PartialEq, Eq, Debug)]
156        enum Match {
157            NameToken,
158            PunctEq,
159            LitVal,
160        }
161        let mut state = Match::NameToken;
162        let result = input.step(|cursor| {
163            let mut rest = *cursor;
164            while let Some((tt, next)) = rest.token_tree() {
165                match tt {
166                    TokenTree::Ident(ident) if ident == "name" && state == Match::NameToken => {
167                        state = Match::PunctEq;
168                    }
169                    TokenTree::Punct(punct)
170                        if punct.as_char() == '=' && state == Match::PunctEq =>
171                    {
172                        state = Match::LitVal;
173                    }
174                    TokenTree::Literal(lit) if state == Match::LitVal => {
175                        return Ok((lit.to_string().replace("\"", ""), Cursor::empty()));
176                    }
177                    _ => {
178                        // on first incorrect token reset
179                        state = Match::NameToken;
180                    }
181                }
182                rest = next;
183            }
184            Err(cursor.error("End reached"))
185        });
186        result
187            .map(|lit| Self(lit))
188            .map_err(|_| input.error("Not found"))
189    }
190}