static_map_macro/
lib.rs

1extern crate proc_macro;
2
3use std::iter::once;
4
5use proc_macro2::{Span, TokenStream};
6use quote::{quote, ToTokens};
7use syn::{
8    parse, parse_quote, punctuated::Punctuated, token::Comma, Arm, Data, DeriveInput, Expr,
9    ExprMatch, Field, FieldValue, Fields, GenericParam, Generics, Ident, Item, ItemImpl, Lit,
10    LitStr, Pat, PatLit, Token, Type,
11};
12
13use self::util::ItemImplExt;
14
15mod util;
16
/// How a generated iterator hands out the struct's field values.
enum Mode {
    /// Values are moved out (`T`); the method generated for this mode takes `self`.
    Value,
    /// Values are borrowed immutably (`&'a T`).
    Ref,
    /// Values are borrowed mutably (`&'a mut T`).
    MutRef,
}
22
23#[proc_macro_derive(StaticMap)]
24pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
25    let input = parse::<DeriveInput>(input).expect("failed to parse input as DeriveInput");
26    let name = input.ident.clone();
27
28    let fields = match input.data {
29        Data::Struct(s) => {
30            if s.fields.is_empty() {
31                panic!("StaticMap: failed to detect type because there's no field")
32            }
33
34            match s.fields {
35                Fields::Named(named) => named.named,
36                _ => panic!("StaticMap: failed to detect type because there's no field"),
37            }
38        }
39        _ => panic!("StaticMap can only be applied to structs"),
40    };
41    let len = fields.len();
42    let data_type = fields.first().unwrap().ty.clone();
43
44    let (_impl_generics, ty_generics, _where_clause) = input.generics.split_for_impl();
45
46    let mut tts = TokenStream::new();
47
48    let type_name = parse_quote!(#name #ty_generics);
49
50    {
51        // IntoIterator
52
53        let make = |m: Mode| {
54            let arr: Punctuated<_, Token![;]> = fields
55                .iter()
56                .map(|f| -> Expr {
57                    //
58                    let name = f.ident.as_ref().unwrap();
59                    let mode = match m {
60                        Mode::Value => quote!(),
61                        Mode::Ref => quote!(&),
62                        Mode::MutRef => quote!(&mut),
63                    };
64                    let value = f.ident.as_ref().unwrap();
65
66                    parse_quote!(
67                        v.push((stringify!(#name), #mode self.#value))
68                    )
69                })
70                .collect();
71
72            arr
73        };
74
75        let body = make(Mode::Value);
76
77        let item: ItemImpl = parse_quote!(
78            impl IntoIterator for #name {
79                type IntoIter = st_map::arrayvec::IntoIter<(&'static str, #data_type), #len>;
80                type Item = (&'static str, #data_type);
81
82                fn into_iter(self) -> Self::IntoIter {
83                    let mut v: st_map::arrayvec::ArrayVec<_, #len> = Default::default();
84
85                    #body;
86
87                    v.into_iter()
88                }
89            }
90        );
91
92        item.with_generics(input.generics.clone())
93            .to_tokens(&mut tts);
94    }
95
96    {
97        // Iterators
98
99        let mut items = vec![];
100
101        items.extend(make_iterator(
102            &type_name,
103            &data_type,
104            &Ident::new(&format!("{name}RefIter"), Span::call_site()),
105            &fields,
106            &input.generics,
107            Mode::Ref,
108        ));
109        items.extend(make_iterator(
110            &type_name,
111            &data_type,
112            &Ident::new(&format!("{name}MutIter"), Span::call_site()),
113            &fields,
114            &input.generics,
115            Mode::MutRef,
116        ));
117
118        for item in items {
119            item.to_tokens(&mut tts);
120        }
121    }
122
123    {
124        // std::ops::Index
125        let body = ExprMatch {
126            attrs: Default::default(),
127            match_token: Default::default(),
128            expr: parse_quote!(v),
129            brace_token: Default::default(),
130            arms: fields
131                .iter()
132                .map(|f| {
133                    let variant = &f.ident;
134                    //
135                    Arm {
136                        attrs: Default::default(),
137                        pat: Pat::Lit(PatLit {
138                            attrs: Default::default(),
139                            lit: Lit::Str(LitStr::new(
140                                &f.ident.as_ref().unwrap().to_string(),
141                                Span::call_site(),
142                            )),
143                        }),
144                        guard: None,
145                        fat_arrow_token: Default::default(),
146                        body: parse_quote!(&self.#variant),
147                        comma: Some(Default::default()),
148                    }
149                })
150                .chain(once(parse_quote!(
151                    _ => panic!("Unknown key: {}", v),
152                )))
153                .collect(),
154        };
155
156        let item: ItemImpl = parse_quote!(
157            impl<'a, K: ?Sized + ::std::borrow::Borrow<str>> ::std::ops::Index<&'a K> for #name {
158                type Output = #data_type;
159                fn index(&self, v: &K) -> &Self::Output {
160                    use std::borrow::Borrow;
161                    let v: &str = v.borrow();
162                    #body
163                }
164            }
165        );
166        item.with_generics(input.generics.clone())
167            .to_tokens(&mut tts);
168    }
169
170    {
171        assert!(
172            input.generics.params.is_empty() || input.generics.params.len() == 1,
173            "StaticMap should have zero or one generic argument"
174        );
175
176        let map_fields: Punctuated<_, Token![,]> = fields
177            .iter()
178            .map(|f| -> FieldValue {
179                let f = f.ident.as_ref().unwrap();
180                let f_str = f.to_string();
181                parse_quote!(
182                    #f: op(#f_str, self.#f)
183                )
184            })
185            .collect();
186
187        // map(), map_value()
188        let item = if input.generics.params.is_empty() {
189            quote!(
190                impl #name {
191                    pub fn map(self, mut op: impl FnMut(&'static str, #data_type) -> #data_type) -> #name {
192                        #name { #map_fields }
193                    }
194
195                    #[inline]
196                    pub fn map_value(self, mut op: impl FnMut(#data_type) -> #data_type) -> #name {
197                        self.map(|_, v| op(v))
198                    }
199                }
200            )
201        } else if match input.generics.params.first().as_ref().unwrap() {
202            GenericParam::Type(ty) => ty.bounds.is_empty(),
203            _ => false,
204        } {
205            quote!(
206                impl<T> #name<T> {
207                    pub fn map<N>(self, mut op: impl FnMut(&'static str, #data_type) -> N) -> #name<N> {
208                        #name { #map_fields }
209                    }
210
211                    #[inline]
212                    pub fn map_value<N>(self, mut op: impl FnMut(#data_type) -> N) -> #name<N> {
213                        self.map(|_, v| op(v))
214                    }
215                }
216            )
217        } else {
218            let bound = match input.generics.params.first().as_ref().unwrap() {
219                GenericParam::Type(ty) => &ty.bounds,
220                _ => unimplemented!("Generic parameters other than type parameter"),
221            };
222
223            quote!(
224                impl<#data_type: #bound> #name<#data_type> {
225                    pub fn map<N: #bound>(
226                        self,
227                        mut op: impl FnMut(&'static str, #data_type) -> N,
228                    ) -> #name<N> {
229                        #name { #map_fields }
230                    }
231
232                    #[inline]
233                    pub fn map_value<N: #bound>(self, mut op: impl FnMut(#data_type) -> N) -> #name<N> {
234                        self.map(|_, v| op(v))
235                    }
236                }
237            )
238        };
239
240        item.to_tokens(&mut tts);
241    }
242
243    tts.into()
244}
245
246fn make_iterator(
247    type_name: &Type,
248    data_type: &Type,
249    iter_type_name: &Ident,
250    fields: &Punctuated<Field, Comma>,
251    generic: &Generics,
252    mode: Mode,
253) -> Vec<Item> {
254    let len = fields.len();
255
256    let (impl_generics, _, _) = generic.split_for_impl();
257
258    let where_clause = generic.where_clause.clone();
259
260    let type_generic = {
261        let type_generic = generic.params.last();
262        match type_generic {
263            Some(GenericParam::Type(t)) => {
264                let param_name = t.ident.clone();
265                let bounds = if t.bounds.is_empty() {
266                    quote!()
267                } else {
268                    let b = &t.bounds;
269                    quote!(: #b)
270                };
271
272                match mode {
273                    Mode::Value => quote!(<#param_name #bounds>),
274                    Mode::Ref => quote!(<'a, #param_name #bounds>),
275                    Mode::MutRef => quote!(<'a, #param_name #bounds>),
276                }
277            }
278            _ => match mode {
279                Mode::Value => quote!(),
280                Mode::Ref => quote!(<'a>),
281                Mode::MutRef => quote!(<'a>),
282            },
283        }
284    };
285
286    let generic_arg_for_method = {
287        let type_generic = generic.params.last();
288        match type_generic {
289            Some(GenericParam::Type(t)) => {
290                let param_name = t.ident.clone();
291
292                quote!(<#param_name>)
293            }
294            _ => quote!(),
295        }
296    };
297
298    let generic = {
299        let type_generic = generic.params.last();
300        match type_generic {
301            Some(GenericParam::Type(t)) => {
302                let param_name = t.ident.clone();
303
304                match mode {
305                    Mode::Value => quote!(<#param_name>),
306                    Mode::Ref => quote!(<'a, #param_name>),
307                    Mode::MutRef => quote!(<'a, #param_name>),
308                }
309            }
310            _ => match mode {
311                Mode::Value => quote!(),
312                Mode::Ref => quote!(<'a>),
313                Mode::MutRef => quote!(<'a>),
314            },
315        }
316    };
317
318    let lifetime = match mode {
319        Mode::Value => quote!(),
320        Mode::Ref => quote!(&'a),
321        Mode::MutRef => quote!(&'a mut),
322    };
323
324    let arms = fields
325        .iter()
326        .enumerate()
327        .map(|(idx, f)| {
328            let pat = idx + 1;
329
330            let name = f.ident.as_ref().unwrap();
331            let name_str = name.to_string();
332            match mode {
333                Mode::Value => quote!(#pat => Some((#name_str, self.data.#name))),
334                Mode::Ref => quote!(#pat => Some((#name_str, &self.data.#name))),
335                Mode::MutRef => quote!(#pat => Some((#name_str, unsafe {
336                    std::mem::transmute::<&mut _, &'a mut _>(&mut self.data.#name)
337                }))),
338            }
339        })
340        .collect::<Punctuated<_, Comma>>();
341
342    let iter_type = parse_quote!(
343        pub struct #iter_type_name #type_generic {
344            cur_index: usize,
345            data: #lifetime #type_name,
346        }
347    );
348    let mut iter_impl: ItemImpl = parse_quote!(
349        impl #type_generic Iterator for #iter_type_name #generic {
350            type Item = (&'static str, #lifetime #data_type);
351
352            fn next(&mut self) -> Option<Self::Item> {
353                self.cur_index += 1;
354                match self.cur_index {
355                    #arms,
356
357                    _ => None
358                }
359            }
360
361            fn size_hint(&self) -> (usize, Option<usize>) {
362                let len = #len - self.cur_index;
363                (len, Some(len))
364            }
365        }
366    );
367    iter_impl.generics.where_clause = where_clause;
368
369    let impl_for_method = {
370        let (recv, method_name) = match mode {
371            Mode::Value => (quote!(self), quote!(into_iter)),
372            Mode::Ref => (quote!(&self), quote!(iter)),
373            Mode::MutRef => (quote!(&mut self), quote!(iter_mut)),
374        };
375
376        parse_quote! {
377            impl #impl_generics #type_name {
378                pub fn #method_name(#recv) -> #iter_type_name #generic_arg_for_method {
379                    #iter_type_name {
380                        cur_index: 0,
381                        data: self,
382                    }
383                }
384            }
385        }
386    };
387
388    vec![iter_type, Item::Impl(iter_impl), impl_for_method]
389}