trans_derive/
lib.rs

1#![recursion_limit = "256"]
2extern crate proc_macro;
3
4use quote::quote;
5
6use proc_macro2::TokenStream;
7
8fn field_schema_name(field: &syn::Field) -> syn::Ident {
9    let mut name = field.ident.clone().unwrap();
10    for attr in &field.attrs {
11        if let Ok(syn::Meta::List(syn::MetaList {
12            path: ref meta_path,
13            ref nested,
14            ..
15        })) = attr.parse_meta()
16        {
17            if meta_path.is_ident("trans") {
18                for inner in nested {
19                    match *inner {
20                        syn::NestedMeta::Meta(syn::Meta::NameValue(syn::MetaNameValue {
21                            path: ref meta_path,
22                            lit: syn::Lit::Str(ref lit),
23                            ..
24                        })) => {
25                            if meta_path.is_ident("rename") {
26                                name = syn::Ident::new(&lit.value(), lit.span());
27                            }
28                        }
29                        _ => panic!("Unexpected meta"),
30                    }
31                }
32            }
33        }
34    }
35    name
36}
37
38fn get_documentation(attrs: &[syn::Attribute]) -> proc_macro2::TokenStream {
39    use std::collections::HashMap;
40    let mut language_docs: HashMap<String, String> = HashMap::new();
41    for attr in attrs {
42        if let Ok(syn::Meta::NameValue(syn::MetaNameValue {
43            path,
44            lit: syn::Lit::Str(lit),
45            ..
46        })) = attr.parse_meta()
47        {
48            let doc = if path.is_ident("doc") {
49                Some(("en".to_owned(), lit.value()))
50            } else if path.is_ident("trans_doc") {
51                let text = lit.value();
52                let colon_pos = text
53                    .find(':')
54                    .expect("trans_doc should be in format \"LANG:TEXT\"");
55                Some((
56                    text[..colon_pos].to_owned(),
57                    text[colon_pos + 1..].to_owned(),
58                ))
59            } else {
60                None
61            };
62            if let Some((lang, text)) = doc {
63                let lang = lang.trim();
64                let text = text.trim();
65                let current_text = language_docs.entry(lang.to_owned()).or_default();
66                if !current_text.is_empty() {
67                    current_text.push(' ');
68                }
69                current_text.push_str(text);
70            }
71        }
72    }
73    let language_docs = language_docs.into_iter().map(|(language, text)| {
74        // let language = syn::LitStr::new(&language, proc_macro2::Span::call_site());
75        // let text = syn::LitStr::new(&text, proc_macro2::Span::call_site());
76        quote! {
77            trans::LanguageDocumentation {
78                language: #language.to_owned(),
79                text: #text.to_owned(),
80            }
81        }
82    });
83    quote! {
84        trans::Documentation {
85            languages: {
86                let mut result = Vec::new();
87                #(result.push(#language_docs);)*
88                result
89            },
90        }
91    }
92}
93
94#[proc_macro_derive(Trans, attributes(trans, trans_doc))]
95pub fn derive_trans(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
96    let input: TokenStream = input.into();
97    let result: TokenStream = {
98        let ast: syn::DeriveInput = syn::parse_str(&input.to_string()).unwrap();
99        let input_type = &ast.ident;
100        let generic_params: Vec<_> = ast
101            .generics
102            .type_params()
103            .map(|param| &param.ident)
104            .collect();
105        let generic_params = &generic_params;
106        let mut base_name =
107            syn::LitStr::new(&ast.ident.to_string(), proc_macro2::Span::call_site());
108        let mut magic: Option<syn::Expr> = None;
109        let mut generics_in_name = true;
110        for attr in &ast.attrs {
111            if let Ok(syn::Meta::List(syn::MetaList {
112                path: ref meta_path,
113                ref nested,
114                ..
115            })) = attr.parse_meta()
116            {
117                if meta_path.is_ident("trans") {
118                    for inner in nested {
119                        match *inner {
120                            syn::NestedMeta::Meta(syn::Meta::NameValue(syn::MetaNameValue {
121                                path: ref meta_path,
122                                lit: syn::Lit::Str(ref lit),
123                                ..
124                            })) => {
125                                if meta_path.is_ident("rename") {
126                                    base_name = lit.clone();
127                                } else if meta_path.is_ident("magic") {
128                                    magic = Some(syn::parse_str(&lit.value()).unwrap());
129                                }
130                            }
131                            syn::NestedMeta::Meta(syn::Meta::Path(ref meta_path)) => {
132                                if meta_path.is_ident("no_generics_in_name") {
133                                    generics_in_name = false;
134                                }
135                            }
136                            _ => panic!("Unexpected meta"),
137                        }
138                    }
139                }
140            }
141        }
142        let (magic_write, magic_read) = match &magic {
143            Some(magic) => (
144                quote! {
145                    <i32 as trans::Trans>::write_to(&#magic, writer)?;
146                },
147                quote! {
148                    assert_eq!(<i32 as trans::Trans>::read_from(reader)?, #magic);
149                },
150            ),
151            None => (quote! {}, quote! {}),
152        };
153        let magic_value = match &magic {
154            Some(expr) => quote! { Some(#expr) },
155            None => quote! { None },
156        };
157        let final_name = quote! {{
158            let mut name = #base_name.to_owned();
159            if #generics_in_name {
160                #(
161                    name += &trans::Schema::of::<#generic_params>().full_name().raw();
162                )*
163            }
164            name
165        }};
166        match ast.data {
167            syn::Data::Struct(syn::DataStruct { ref fields, .. }) => match fields {
168                syn::Fields::Named(_) => {
169                    let field_tys: Vec<_> = fields.iter().map(|field| &field.ty).collect();
170                    let field_tys = &field_tys;
171                    let field_names: Vec<_> = fields
172                        .iter()
173                        .map(|field| field.ident.as_ref().unwrap())
174                        .collect();
175                    let field_names = &field_names;
176                    let mut generics = ast.generics.clone();
177                    let extra_where_clauses = quote! {
178                        where
179                            #(#field_tys: trans::Trans,)*
180                            #(#generic_params: trans::Trans,)*
181                    };
182                    let extra_where_clauses: syn::WhereClause =
183                        syn::parse_str(&extra_where_clauses.to_string()).unwrap();
184                    generics
185                        .make_where_clause()
186                        .predicates
187                        .extend(extra_where_clauses.predicates);
188                    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
189                    let schema_fields = fields.iter().map(|field| {
190                        let documentation = get_documentation(&field.attrs);
191                        let schema_name = field_schema_name(field);
192                        let ty = &field.ty;
193                        quote! {
194                            trans::Field {
195                                documentation: #documentation,
196                                name: trans::Name::new(stringify!(#schema_name).to_owned()),
197                                schema: trans::Schema::of::<#ty>(),
198                            }
199                        }
200                    });
201                    let documentation = get_documentation(&ast.attrs);
202                    let expanded = quote! {
203                        impl #impl_generics trans::Trans for #input_type #ty_generics #where_clause {
204                            fn create_schema() -> trans::Schema {
205                                let name = #final_name;
206                                trans::Schema::Struct(trans::Struct {
207                                    documentation: #documentation,
208                                    name: trans::Name::new(name),
209                                    magic: #magic_value,
210                                    fields: vec![#(#schema_fields),*],
211                                })
212                            }
213                            fn write_to(&self, writer: &mut dyn std::io::Write) -> std::io::Result<()> {
214                                #magic_write
215                                #(trans::Trans::write_to(&self.#field_names, writer)?;)*
216                                Ok(())
217                            }
218                            fn read_from(reader: &mut dyn std::io::Read) -> std::io::Result<Self> {
219                                #magic_read
220                                Ok(Self {
221                                    #(#field_names: trans::Trans::read_from(reader)?),*
222                                })
223                            }
224                        }
225                    };
226                    expanded.into()
227                }
228                syn::Fields::Unnamed(_) => {
229                    if fields.iter().len() != 1 {
230                        panic!("Tuple structs other than newtype not supported");
231                    }
232                    if magic.is_some() {
233                        panic!("Magic with newtypes not supported");
234                    }
235                    let inner_ty = fields.iter().next().unwrap();
236                    let mut generics = ast.generics.clone();
237                    let extra_where_clauses = quote! {
238                        where #inner_ty: trans::Trans + 'static
239                    };
240                    let extra_where_clauses: syn::WhereClause =
241                        syn::parse_str(&extra_where_clauses.to_string()).unwrap();
242                    generics
243                        .make_where_clause()
244                        .predicates
245                        .extend(extra_where_clauses.predicates);
246                    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
247                    let expanded = quote! {
248                        impl #impl_generics trans::Trans for #input_type #ty_generics #where_clause {
249                            fn create_schema() -> trans::Schema {
250                                <#inner_ty as trans::Trans>::create_schema()
251                            }
252                            fn write_to(&self, writer: &mut dyn std::io::Write) -> std::io::Result<()> {
253                                trans::Trans::write_to(&self.0, writer)?;
254                                Ok(())
255                            }
256                            fn read_from(reader: &mut dyn std::io::Read) -> std::io::Result<Self> {
257                                Ok(Self(trans::Trans::read_from(reader)?))
258                            }
259                        }
260                    };
261                    expanded.into()
262                }
263                syn::Fields::Unit => panic!("Unit structs not supported"),
264            },
265            syn::Data::Enum(syn::DataEnum { ref variants, .. }) => {
266                let generics = ast.generics.clone();
267                let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
268                let read_write_impl = {
269                    let variant_writes = variants.iter().enumerate().map(|(tag, variant)| {
270                        let tag = tag as i32;
271                        let variant_name = &variant.ident;
272                        let field_names: Vec<_> = variant
273                            .fields
274                            .iter()
275                            .map(|field| {
276                                field
277                                    .ident
278                                    .as_ref()
279                                    .expect("Only named fields are supported")
280                            })
281                            .collect();
282                        let field_names = &field_names;
283                        let field_names_copy = field_names;
284                        quote! {
285                            #input_type::#variant_name { #(#field_names,)* } => {
286                                trans::Trans::write_to(&#tag, writer)?;
287                                #(trans::Trans::write_to(#field_names_copy, writer)?;)*
288                            }
289                        }
290                    });
291                    let variant_reads = variants.iter().enumerate().map(|(tag, variant)| {
292                        let tag = tag as i32;
293                        let variant_name = &variant.ident;
294                        let field_names = variant
295                            .fields
296                            .iter()
297                            .map(|field| field.ident.as_ref().unwrap());
298                        quote! {
299                            #tag => #input_type::#variant_name {
300                                #(#field_names: trans::Trans::read_from(reader)?,)*
301                            },
302                        }
303                    });
304                    quote! {
305                        fn write_to(&self, writer: &mut dyn std::io::Write) -> std::io::Result<()> {
306                            match self {
307                                #(#variant_writes)*
308                            }
309                            Ok(())
310                        }
311                        fn read_from(reader: &mut dyn std::io::Read) -> std::io::Result<Self> {
312                            Ok(match <i32 as trans::Trans>::read_from(reader)? {
313                                #(#variant_reads)*
314                                tag => {
315                                    return Err(std::io::Error::new(
316                                        std::io::ErrorKind::Other,
317                                        format!("Unexpected tag {:?}", tag)));
318                                }
319                            })
320                        }
321                    }
322                };
323                if variants.iter().all(|variant| {
324                    if let syn::Fields::Unit = variant.fields {
325                        true
326                    } else {
327                        false
328                    }
329                }) {
330                    let variants = variants.iter().map(|variant| {
331                        let name = &variant.ident;
332                        let documentation = get_documentation(&variant.attrs);
333                        quote! {
334                            trans::EnumVariant {
335                                name: trans::Name::new(stringify!(#name).to_owned()),
336                                documentation: #documentation,
337                            }
338                        }
339                    });
340                    let documentation = get_documentation(&ast.attrs);
341                    let expanded = quote! {
342                        impl #impl_generics trans::Trans for #input_type #ty_generics #where_clause {
343                            fn create_schema() -> trans::Schema {
344                                let base_name = #final_name;
345                                trans::Schema::Enum {
346                                    documentation: #documentation,
347                                    base_name: trans::Name::new(base_name),
348                                    variants: vec![#(#variants),*],
349                                }
350                            }
351                            #read_write_impl
352                        }
353                    };
354                    expanded.into()
355                } else {
356                    let variants = variants.iter().map(|variant| {
357                        let documentation = get_documentation(&variant.attrs);
358                        let variant_name = &variant.ident;
359                        let schema_fields = variant.fields.iter().map(|field| {
360                            let documentation = get_documentation(&field.attrs);
361                            let schema_name = field_schema_name(field);
362                            let ty = &field.ty;
363                            quote! {
364                                trans::Field {
365                                    documentation: #documentation,
366                                    name: trans::Name::new(stringify!(#schema_name).to_owned()),
367                                    schema: trans::Schema::of::<#ty>(),
368                                }
369                            }
370                        });
371                        quote! {
372                            trans::Struct {
373                                documentation: #documentation,
374                                name: trans::Name::new(stringify!(#variant_name).to_owned()),
375                                magic: None,
376                                fields: vec![
377                                    #(#schema_fields),*
378                                ],
379                            }
380                        }
381                    });
382                    let documentation = get_documentation(&ast.attrs);
383                    let expanded = quote! {
384                        impl #impl_generics trans::Trans for #input_type #ty_generics #where_clause {
385                            fn create_schema() -> trans::Schema {
386                                let base_name = #final_name;
387                                trans::Schema::OneOf {
388                                    documentation: #documentation,
389                                    base_name: trans::Name::new(base_name),
390                                    variants: vec![#(#variants),*],
391                                }
392                            }
393                            #read_write_impl
394                        }
395                    };
396                    expanded.into()
397                }
398            }
399            syn::Data::Union(_) => panic!("Unions not supported"),
400        }
401    };
402    result.into()
403}