wayk_proto_derive/
lib.rs

1#![no_std]
2
3extern crate alloc;
4extern crate proc_macro;
5extern crate proc_macro2;
6
7use alloc::vec::Vec;
8use proc_macro::TokenStream;
9use proc_macro2::{Span, TokenStream as TokenStream2};
10use quote::quote;
11use syn::{
12    punctuated::Punctuated, token::Add, Attribute, Data, Fields, Generics, Ident, Lifetime, LifetimeDef, Lit, Meta,
13    Type,
14};
15
16mod parsed {
17    use alloc::vec::Vec;
18
19    pub enum Type<'a> {
20        Struct(Struct<'a>),
21        FieldlessEnum(FieldlessEnum<'a>),
22        MetaEnum(MetaEnum<'a>),
23    }
24
25    pub struct Struct<'a> {
26        pub name: &'a syn::Ident,
27        pub generics: &'a syn::Generics,
28        pub fields: Vec<Field<'a>>,
29    }
30
31    pub struct Field<'a> {
32        pub decode_ignore: bool,
33        pub encode_ignore: bool,
34        pub name: &'a syn::Ident,
35        pub ty: &'a syn::Type,
36    }
37
38    pub struct FieldlessEnum<'a> {
39        pub name: &'a syn::Ident,
40        pub underlying_repr: syn::Ident,
41    }
42
43    pub struct MetaEnum<'a> {
44        pub name: &'a syn::Ident,
45        pub generics: &'a syn::Generics,
46        pub subtype_enum_ty: syn::Ident,
47        pub meta_variants: Vec<MetaVariant<'a>>,
48    }
49
50    pub struct MetaVariant<'a> {
51        pub decode_ignore: bool,
52        pub encode_ignore: bool,
53        pub name: &'a syn::Ident,
54        pub field_type: &'a syn::Type,
55    }
56}
57
58#[proc_macro_derive(Encode, attributes(meta_enum, encode_ignore))]
59pub fn encode_macro_derive(input: TokenStream) -> TokenStream {
60    let ast = syn::parse(input).expect("failed to parse input");
61    impl_trait(&ast, impl_encode)
62}
63
64fn impl_encode(ty: parsed::Type<'_>) -> TokenStream {
65    match ty {
66        parsed::Type::Struct(data) => {
67            let ty = data.name;
68            let (impl_generics, ty_generics, where_clause) = data.generics.split_for_impl();
69            let fields = data
70                .fields
71                .iter()
72                .filter(|field| !field.encode_ignore)
73                .map(|field| field.name)
74                .collect::<Vec<&Ident>>();
75
76            let expanded = quote! {
77                impl #impl_generics ::wayk_proto::serialization::Encode for #ty #ty_generics #where_clause {
78                    fn encoded_len(&self) -> usize {
79                        #(
80                            self.#fields.encoded_len()
81                        )+*
82                    }
83
84                    fn encode_into<W: ::std::io::Write>(&self, writer: &mut W) -> ::core::result::Result<(), ::wayk_proto::error::ProtoError> {
85                        use ::wayk_proto::error::{ProtoErrorKind, ProtoErrorResultExt};
86                        #(
87                            self.#fields.encode_into(writer)
88                                .chain(ProtoErrorKind::Encoding(stringify!(#ty)))
89                                .or_else_desc(|| format!("couldn't encode {}::{}", stringify!(#ty), stringify!(#fields)))?;
90                        )*
91                        Ok(())
92                    }
93                }
94            };
95
96            expanded.into()
97        }
98        parsed::Type::MetaEnum(data) => {
99            let ty = data.name;
100            let (impl_generics, ty_generics, where_clause) = data.generics.split_for_impl();
101
102            let variants: Vec<&Ident> = data
103                .meta_variants
104                .iter()
105                .filter(|variant| !variant.encode_ignore)
106                .map(|variant| variant.name)
107                .collect();
108
109            let expanded = quote! {
110                impl #impl_generics ::wayk_proto::serialization::Encode for #ty #ty_generics #where_clause {
111                    fn encoded_len(&self) -> usize {
112                        match self {
113                            #(
114                                Self::#variants(msg) => msg.encoded_len(),
115                            )*
116                        }
117                    }
118
119                    fn encode_into<W: ::std::io::Write>(&self, writer: &mut W) -> ::core::result::Result<(), ::wayk_proto::error::ProtoError> {
120                        use ::wayk_proto::error::{ProtoErrorKind, ProtoErrorResultExt};
121                        match self {
122                            #(
123                                Self::#variants(msg) => msg
124                                    .encode_into(writer)
125                                    .chain(ProtoErrorKind::Encoding(stringify!(#ty)))
126                                    .or_desc(concat!("couldn't encode ", stringify!(#variants)," message")),
127                            )*
128                        }
129                    }
130                }
131            };
132
133            expanded.into()
134        }
135        parsed::Type::FieldlessEnum(data) => {
136            let ty = data.name;
137            let underlying_repr = data.underlying_repr;
138
139            let expanded = quote! {
140                impl ::wayk_proto::serialization::Encode for #ty {
141                    fn encoded_len(&self) -> usize {
142                        ::core::mem::size_of::<#underlying_repr>()
143                    }
144
145                    fn encode_into<W: ::std::io::Write>(
146                        &self,
147                        writer: &mut W,
148                    ) -> ::core::result::Result<(), ::wayk_proto::error::ProtoError> {
149                        <#underlying_repr>::encode_into(&(*self as #underlying_repr), writer)
150                    }
151                }
152
153                impl #ty {
154                    fn to_primitive(&self) -> #underlying_repr {
155                        *self as #underlying_repr
156                    }
157                }
158            };
159
160            expanded.into()
161        }
162    }
163}
164
165#[proc_macro_derive(Decode, attributes(meta_enum, decode_ignore))]
166pub fn decode_macro_derive(input: TokenStream) -> TokenStream {
167    let ast = syn::parse(input).expect("failed to parse input");
168    impl_trait(&ast, impl_decode)
169}
170
171fn build_decode_impl_generics(generics: &Generics) -> TokenStream2 {
172    let decode_lt = {
173        let lt = Lifetime::new("'dec", Span::call_site());
174
175        let mut bounds = Punctuated::<Lifetime, Add>::new();
176        for bounded_lt in generics.lifetimes() {
177            bounds.push(bounded_lt.lifetime.clone());
178        }
179
180        let mut lt_def = LifetimeDef::new(lt);
181        lt_def.bounds = bounds;
182
183        lt_def
184    };
185
186    let lifetimes = generics.lifetimes();
187    let type_params = generics.type_params();
188
189    quote! {
190        <#decode_lt, #(#lifetimes),* #(#type_params)+*>
191    }
192}
193
194fn impl_decode(enc_dec_ty: parsed::Type<'_>) -> TokenStream {
195    match enc_dec_ty {
196        parsed::Type::Struct(data) => {
197            let ty = data.name;
198
199            let impl_generics = build_decode_impl_generics(data.generics);
200            let (_, ty_generics, where_clause) = data.generics.split_for_impl();
201
202            let fields_ty = data
203                .fields
204                .iter()
205                .filter(|field| !field.decode_ignore)
206                .map(|field| field.ty)
207                .collect::<Vec<&Type>>();
208            let fields = data
209                .fields
210                .iter()
211                .filter(|field| !field.decode_ignore)
212                .map(|field| field.name)
213                .collect::<Vec<&Ident>>();
214            let ignored_fields = data
215                .fields
216                .iter()
217                .filter(|field| field.decode_ignore)
218                .map(|field| field.name)
219                .collect::<Vec<&Ident>>();
220
221            let expanded = quote! {
222                impl #impl_generics ::wayk_proto::serialization::Decode<'dec> for #ty #ty_generics #where_clause {
223                    fn decode_from(cursor: &mut ::std::io::Cursor<&'dec [u8]>) -> ::core::result::Result<Self, ::wayk_proto::error::ProtoError> {
224                        use ::wayk_proto::error::{ProtoErrorResultExt, ProtoErrorKind};
225                        Ok(Self {
226                            #(
227                                #fields: <#fields_ty as ::wayk_proto::serialization::Decode>::decode_from(cursor)
228                                    .chain(ProtoErrorKind::Decoding(stringify!(#ty)))
229                                    .or_desc(concat!(
230                                        "couldn't decode ",
231                                        stringify!(#fields_ty),
232                                        " into ",
233                                        stringify!(#ty), "::", stringify!(#fields)
234                                    ))?,
235                            )*
236                            #(
237                                #ignored_fields: ::core::default::Default::default(),
238                            )*
239                        })
240                    }
241                }
242            };
243
244            expanded.into()
245        }
246        parsed::Type::MetaEnum(data) => {
247            let ty = data.name;
248            let generics = data.generics;
249            let subtype_enum_ty = &data.subtype_enum_ty;
250
251            let variants: Vec<&Ident> = data
252                .meta_variants
253                .iter()
254                .filter(|variant| !variant.decode_ignore)
255                .map(|variant| variant.name)
256                .collect();
257            let variants_field_ty: Vec<&Type> = data
258                .meta_variants
259                .iter()
260                .filter(|variant| !variant.decode_ignore)
261                .map(|variant| variant.field_type)
262                .collect();
263
264            let impl_generics = build_decode_impl_generics(generics);
265            let (_, ty_generics, where_clause) = generics.split_for_impl();
266
267            let expanded = quote! {
268                impl #impl_generics ::wayk_proto::serialization::Decode<'dec> for #ty #ty_generics #where_clause {
269                    fn decode_from(cursor: &mut ::std::io::Cursor<&'dec [u8]>) -> ::core::result::Result<Self, ::wayk_proto::error::ProtoError> {
270                        use ::wayk_proto::error::{ProtoErrorResultExt, ProtoErrorKind};
271                        use ::wayk_proto::serialization::Encode;
272                        use ::std::io::{Seek, SeekFrom};
273
274                        let subtype = <#subtype_enum_ty as ::wayk_proto::serialization::Decode>::decode_from(cursor)
275                            .chain(ProtoErrorKind::Decoding(stringify!(#ty)))
276                            .or_desc("couldn't decode subtype")?;
277                        cursor.seek(SeekFrom::Current(-(subtype.encoded_len() as i64)))
278                            .expect("seek back after subtype decoding failed"); // cannot fail
279
280                        match subtype {
281                            #(
282                                #subtype_enum_ty::#variants => <#variants_field_ty as ::wayk_proto::serialization::Decode>::decode_from(cursor)
283                                    .map(Self::#variants)
284                                    .chain(ProtoErrorKind::Decoding(stringify!(#ty)))
285                                    .or_desc(concat!(
286                                        "couldn't decode ",
287                                        stringify!(#ty),
288                                        " for subtype ",
289                                        stringify!(#variants)
290                                    )),
291                            )*
292                        }
293                    }
294                }
295            };
296
297            expanded.into()
298        }
299        parsed::Type::FieldlessEnum(data) => {
300            let ty = data.name;
301            let underlying_repr = data.underlying_repr;
302
303            let from_primitive = Ident::new(&alloc::format!("from_{}", underlying_repr), Span::call_site());
304
305            let expanded = quote! {
306                impl ::wayk_proto::serialization::Decode<'_> for #ty {
307                    fn decode_from(
308                        cursor: &mut ::std::io::Cursor<&[u8]>,
309                    ) -> ::core::result::Result<Self, ::wayk_proto::error::ProtoError> {
310                        use ::wayk_proto::error::{ProtoErrorKind, ProtoErrorResultExt};
311                        let v = #underlying_repr::decode_from(cursor)?;
312                        ::num::FromPrimitive::#from_primitive(v)
313                            .chain(ProtoErrorKind::Decoding(stringify!($ty)))
314                            .or_else_desc(||
315                                format!(concat!("no variant in ", stringify!(#ty), " for value {}"), v)
316                            )
317                    }
318                }
319            };
320
321            expanded.into()
322        }
323    }
324}
325
326fn find_attr<'a>(attrs: &'a [Attribute], name: &str) -> Option<&'a Attribute> {
327    attrs
328        .iter()
329        .find(|attr| attr.path.segments.iter().any(|seg| seg.ident == name))
330}
331
332fn impl_trait<F>(ast: &syn::DeriveInput, implementor: F) -> TokenStream
333where
334    F: FnOnce(parsed::Type<'_>) -> TokenStream,
335{
336    let ty = &ast.ident;
337    let generics = &ast.generics;
338    let enc_dec_type = match &ast.data {
339        Data::Struct(data) => {
340            if let Fields::Named(fields) = &data.fields {
341                let fields = fields
342                    .named
343                    .iter()
344                    .map(|field| parsed::Field {
345                        decode_ignore: find_attr(&field.attrs, "decode_ignore").is_some(),
346                        encode_ignore: find_attr(&field.attrs, "encode_ignore").is_some(),
347                        name: field.ident.as_ref().unwrap(),
348                        ty: &field.ty,
349                    })
350                    .collect();
351
352                parsed::Type::Struct(parsed::Struct {
353                    name: ty,
354                    generics,
355                    fields,
356                })
357            } else {
358                unimplemented!("currently only named fields are supported");
359            }
360        }
361        Data::Enum(data) => {
362            let meta_enum_attr = find_attr(&ast.attrs, "meta_enum");
363            let repr_attr = find_attr(&ast.attrs, "repr");
364            if let Some(meta_enum_attr) = meta_enum_attr {
365                let meta = meta_enum_attr
366                    .parse_meta()
367                    .expect("failed to parse `meta_enum` argument");
368                let subtype_enum_ty = if let Meta::NameValue(name) = meta {
369                    if let Lit::Str(s) = name.lit {
370                        Ident::new(&s.value(), Span::call_site())
371                    } else {
372                        panic!("wrong literal in `meta_enum` attribute parameter. Expected a string literal for the subtype enum.");
373                    }
374                } else {
375                    panic!(r#"wrong meta for `meta_enum`. Expected a name value (eg: meta_enum = "...")."#);
376                };
377
378                let mut meta_variants = Vec::new();
379                for variant in &data.variants {
380                    let variant = parsed::MetaVariant {
381                        decode_ignore: find_attr(&variant.attrs, "decode_ignore").is_some(),
382                        encode_ignore: find_attr(&variant.attrs, "encode_ignore").is_some(),
383                        name: &variant.ident,
384                        field_type: match &variant.fields {
385                            Fields::Unnamed(field) => &field.unnamed.first().unwrap().ty,
386                            Fields::Named(_) => panic!("named fields unsupported"),
387                            Fields::Unit => panic!("unexpected unit field"),
388                        },
389                    };
390
391                    meta_variants.push(variant);
392                }
393
394                parsed::Type::MetaEnum(parsed::MetaEnum {
395                    name: ty,
396                    generics,
397                    subtype_enum_ty,
398                    meta_variants,
399                })
400            } else if let Some(repr_attr) = repr_attr {
401                parsed::Type::FieldlessEnum(parsed::FieldlessEnum {
402                    name: ty,
403                    underlying_repr: repr_attr.parse_args().expect("couldn't parse repr type"),
404                })
405            } else {
406                panic!("meta_enum or repr attribute missing")
407            }
408        }
409        Data::Union(_) => unimplemented!("union"),
410    };
411
412    implementor(enc_dec_type)
413}