1#![no_std]
2
3extern crate alloc;
4extern crate proc_macro;
5extern crate proc_macro2;
6
7use alloc::vec::Vec;
8use proc_macro::TokenStream;
9use proc_macro2::{Span, TokenStream as TokenStream2};
10use quote::quote;
11use syn::{
12 punctuated::Punctuated, token::Add, Attribute, Data, Fields, Generics, Ident, Lifetime, LifetimeDef, Lit, Meta,
13 Type,
14};
15
16mod parsed {
17 use alloc::vec::Vec;
18
19 pub enum Type<'a> {
20 Struct(Struct<'a>),
21 FieldlessEnum(FieldlessEnum<'a>),
22 MetaEnum(MetaEnum<'a>),
23 }
24
25 pub struct Struct<'a> {
26 pub name: &'a syn::Ident,
27 pub generics: &'a syn::Generics,
28 pub fields: Vec<Field<'a>>,
29 }
30
31 pub struct Field<'a> {
32 pub decode_ignore: bool,
33 pub encode_ignore: bool,
34 pub name: &'a syn::Ident,
35 pub ty: &'a syn::Type,
36 }
37
38 pub struct FieldlessEnum<'a> {
39 pub name: &'a syn::Ident,
40 pub underlying_repr: syn::Ident,
41 }
42
43 pub struct MetaEnum<'a> {
44 pub name: &'a syn::Ident,
45 pub generics: &'a syn::Generics,
46 pub subtype_enum_ty: syn::Ident,
47 pub meta_variants: Vec<MetaVariant<'a>>,
48 }
49
50 pub struct MetaVariant<'a> {
51 pub decode_ignore: bool,
52 pub encode_ignore: bool,
53 pub name: &'a syn::Ident,
54 pub field_type: &'a syn::Type,
55 }
56}
57
58#[proc_macro_derive(Encode, attributes(meta_enum, encode_ignore))]
59pub fn encode_macro_derive(input: TokenStream) -> TokenStream {
60 let ast = syn::parse(input).expect("failed to parse input");
61 impl_trait(&ast, impl_encode)
62}
63
64fn impl_encode(ty: parsed::Type<'_>) -> TokenStream {
65 match ty {
66 parsed::Type::Struct(data) => {
67 let ty = data.name;
68 let (impl_generics, ty_generics, where_clause) = data.generics.split_for_impl();
69 let fields = data
70 .fields
71 .iter()
72 .filter(|field| !field.encode_ignore)
73 .map(|field| field.name)
74 .collect::<Vec<&Ident>>();
75
76 let expanded = quote! {
77 impl #impl_generics ::wayk_proto::serialization::Encode for #ty #ty_generics #where_clause {
78 fn encoded_len(&self) -> usize {
79 #(
80 self.#fields.encoded_len()
81 )+*
82 }
83
84 fn encode_into<W: ::std::io::Write>(&self, writer: &mut W) -> ::core::result::Result<(), ::wayk_proto::error::ProtoError> {
85 use ::wayk_proto::error::{ProtoErrorKind, ProtoErrorResultExt};
86 #(
87 self.#fields.encode_into(writer)
88 .chain(ProtoErrorKind::Encoding(stringify!(#ty)))
89 .or_else_desc(|| format!("couldn't encode {}::{}", stringify!(#ty), stringify!(#fields)))?;
90 )*
91 Ok(())
92 }
93 }
94 };
95
96 expanded.into()
97 }
98 parsed::Type::MetaEnum(data) => {
99 let ty = data.name;
100 let (impl_generics, ty_generics, where_clause) = data.generics.split_for_impl();
101
102 let variants: Vec<&Ident> = data
103 .meta_variants
104 .iter()
105 .filter(|variant| !variant.encode_ignore)
106 .map(|variant| variant.name)
107 .collect();
108
109 let expanded = quote! {
110 impl #impl_generics ::wayk_proto::serialization::Encode for #ty #ty_generics #where_clause {
111 fn encoded_len(&self) -> usize {
112 match self {
113 #(
114 Self::#variants(msg) => msg.encoded_len(),
115 )*
116 }
117 }
118
119 fn encode_into<W: ::std::io::Write>(&self, writer: &mut W) -> ::core::result::Result<(), ::wayk_proto::error::ProtoError> {
120 use ::wayk_proto::error::{ProtoErrorKind, ProtoErrorResultExt};
121 match self {
122 #(
123 Self::#variants(msg) => msg
124 .encode_into(writer)
125 .chain(ProtoErrorKind::Encoding(stringify!(#ty)))
126 .or_desc(concat!("couldn't encode ", stringify!(#variants)," message")),
127 )*
128 }
129 }
130 }
131 };
132
133 expanded.into()
134 }
135 parsed::Type::FieldlessEnum(data) => {
136 let ty = data.name;
137 let underlying_repr = data.underlying_repr;
138
139 let expanded = quote! {
140 impl ::wayk_proto::serialization::Encode for #ty {
141 fn encoded_len(&self) -> usize {
142 ::core::mem::size_of::<#underlying_repr>()
143 }
144
145 fn encode_into<W: ::std::io::Write>(
146 &self,
147 writer: &mut W,
148 ) -> ::core::result::Result<(), ::wayk_proto::error::ProtoError> {
149 <#underlying_repr>::encode_into(&(*self as #underlying_repr), writer)
150 }
151 }
152
153 impl #ty {
154 fn to_primitive(&self) -> #underlying_repr {
155 *self as #underlying_repr
156 }
157 }
158 };
159
160 expanded.into()
161 }
162 }
163}
164
165#[proc_macro_derive(Decode, attributes(meta_enum, decode_ignore))]
166pub fn decode_macro_derive(input: TokenStream) -> TokenStream {
167 let ast = syn::parse(input).expect("failed to parse input");
168 impl_trait(&ast, impl_decode)
169}
170
171fn build_decode_impl_generics(generics: &Generics) -> TokenStream2 {
172 let decode_lt = {
173 let lt = Lifetime::new("'dec", Span::call_site());
174
175 let mut bounds = Punctuated::<Lifetime, Add>::new();
176 for bounded_lt in generics.lifetimes() {
177 bounds.push(bounded_lt.lifetime.clone());
178 }
179
180 let mut lt_def = LifetimeDef::new(lt);
181 lt_def.bounds = bounds;
182
183 lt_def
184 };
185
186 let lifetimes = generics.lifetimes();
187 let type_params = generics.type_params();
188
189 quote! {
190 <#decode_lt, #(#lifetimes),* #(#type_params)+*>
191 }
192}
193
194fn impl_decode(enc_dec_ty: parsed::Type<'_>) -> TokenStream {
195 match enc_dec_ty {
196 parsed::Type::Struct(data) => {
197 let ty = data.name;
198
199 let impl_generics = build_decode_impl_generics(data.generics);
200 let (_, ty_generics, where_clause) = data.generics.split_for_impl();
201
202 let fields_ty = data
203 .fields
204 .iter()
205 .filter(|field| !field.decode_ignore)
206 .map(|field| field.ty)
207 .collect::<Vec<&Type>>();
208 let fields = data
209 .fields
210 .iter()
211 .filter(|field| !field.decode_ignore)
212 .map(|field| field.name)
213 .collect::<Vec<&Ident>>();
214 let ignored_fields = data
215 .fields
216 .iter()
217 .filter(|field| field.decode_ignore)
218 .map(|field| field.name)
219 .collect::<Vec<&Ident>>();
220
221 let expanded = quote! {
222 impl #impl_generics ::wayk_proto::serialization::Decode<'dec> for #ty #ty_generics #where_clause {
223 fn decode_from(cursor: &mut ::std::io::Cursor<&'dec [u8]>) -> ::core::result::Result<Self, ::wayk_proto::error::ProtoError> {
224 use ::wayk_proto::error::{ProtoErrorResultExt, ProtoErrorKind};
225 Ok(Self {
226 #(
227 #fields: <#fields_ty as ::wayk_proto::serialization::Decode>::decode_from(cursor)
228 .chain(ProtoErrorKind::Decoding(stringify!(#ty)))
229 .or_desc(concat!(
230 "couldn't decode ",
231 stringify!(#fields_ty),
232 " into ",
233 stringify!(#ty), "::", stringify!(#fields)
234 ))?,
235 )*
236 #(
237 #ignored_fields: ::core::default::Default::default(),
238 )*
239 })
240 }
241 }
242 };
243
244 expanded.into()
245 }
246 parsed::Type::MetaEnum(data) => {
247 let ty = data.name;
248 let generics = data.generics;
249 let subtype_enum_ty = &data.subtype_enum_ty;
250
251 let variants: Vec<&Ident> = data
252 .meta_variants
253 .iter()
254 .filter(|variant| !variant.decode_ignore)
255 .map(|variant| variant.name)
256 .collect();
257 let variants_field_ty: Vec<&Type> = data
258 .meta_variants
259 .iter()
260 .filter(|variant| !variant.decode_ignore)
261 .map(|variant| variant.field_type)
262 .collect();
263
264 let impl_generics = build_decode_impl_generics(generics);
265 let (_, ty_generics, where_clause) = generics.split_for_impl();
266
267 let expanded = quote! {
268 impl #impl_generics ::wayk_proto::serialization::Decode<'dec> for #ty #ty_generics #where_clause {
269 fn decode_from(cursor: &mut ::std::io::Cursor<&'dec [u8]>) -> ::core::result::Result<Self, ::wayk_proto::error::ProtoError> {
270 use ::wayk_proto::error::{ProtoErrorResultExt, ProtoErrorKind};
271 use ::wayk_proto::serialization::Encode;
272 use ::std::io::{Seek, SeekFrom};
273
274 let subtype = <#subtype_enum_ty as ::wayk_proto::serialization::Decode>::decode_from(cursor)
275 .chain(ProtoErrorKind::Decoding(stringify!(#ty)))
276 .or_desc("couldn't decode subtype")?;
277 cursor.seek(SeekFrom::Current(-(subtype.encoded_len() as i64)))
278 .expect("seek back after subtype decoding failed"); match subtype {
281 #(
282 #subtype_enum_ty::#variants => <#variants_field_ty as ::wayk_proto::serialization::Decode>::decode_from(cursor)
283 .map(Self::#variants)
284 .chain(ProtoErrorKind::Decoding(stringify!(#ty)))
285 .or_desc(concat!(
286 "couldn't decode ",
287 stringify!(#ty),
288 " for subtype ",
289 stringify!(#variants)
290 )),
291 )*
292 }
293 }
294 }
295 };
296
297 expanded.into()
298 }
299 parsed::Type::FieldlessEnum(data) => {
300 let ty = data.name;
301 let underlying_repr = data.underlying_repr;
302
303 let from_primitive = Ident::new(&alloc::format!("from_{}", underlying_repr), Span::call_site());
304
305 let expanded = quote! {
306 impl ::wayk_proto::serialization::Decode<'_> for #ty {
307 fn decode_from(
308 cursor: &mut ::std::io::Cursor<&[u8]>,
309 ) -> ::core::result::Result<Self, ::wayk_proto::error::ProtoError> {
310 use ::wayk_proto::error::{ProtoErrorKind, ProtoErrorResultExt};
311 let v = #underlying_repr::decode_from(cursor)?;
312 ::num::FromPrimitive::#from_primitive(v)
313 .chain(ProtoErrorKind::Decoding(stringify!($ty)))
314 .or_else_desc(||
315 format!(concat!("no variant in ", stringify!(#ty), " for value {}"), v)
316 )
317 }
318 }
319 };
320
321 expanded.into()
322 }
323 }
324}
325
326fn find_attr<'a>(attrs: &'a [Attribute], name: &str) -> Option<&'a Attribute> {
327 attrs
328 .iter()
329 .find(|attr| attr.path.segments.iter().any(|seg| seg.ident == name))
330}
331
332fn impl_trait<F>(ast: &syn::DeriveInput, implementor: F) -> TokenStream
333where
334 F: FnOnce(parsed::Type<'_>) -> TokenStream,
335{
336 let ty = &ast.ident;
337 let generics = &ast.generics;
338 let enc_dec_type = match &ast.data {
339 Data::Struct(data) => {
340 if let Fields::Named(fields) = &data.fields {
341 let fields = fields
342 .named
343 .iter()
344 .map(|field| parsed::Field {
345 decode_ignore: find_attr(&field.attrs, "decode_ignore").is_some(),
346 encode_ignore: find_attr(&field.attrs, "encode_ignore").is_some(),
347 name: field.ident.as_ref().unwrap(),
348 ty: &field.ty,
349 })
350 .collect();
351
352 parsed::Type::Struct(parsed::Struct {
353 name: ty,
354 generics,
355 fields,
356 })
357 } else {
358 unimplemented!("currently only named fields are supported");
359 }
360 }
361 Data::Enum(data) => {
362 let meta_enum_attr = find_attr(&ast.attrs, "meta_enum");
363 let repr_attr = find_attr(&ast.attrs, "repr");
364 if let Some(meta_enum_attr) = meta_enum_attr {
365 let meta = meta_enum_attr
366 .parse_meta()
367 .expect("failed to parse `meta_enum` argument");
368 let subtype_enum_ty = if let Meta::NameValue(name) = meta {
369 if let Lit::Str(s) = name.lit {
370 Ident::new(&s.value(), Span::call_site())
371 } else {
372 panic!("wrong literal in `meta_enum` attribute parameter. Expected a string literal for the subtype enum.");
373 }
374 } else {
375 panic!(r#"wrong meta for `meta_enum`. Expected a name value (eg: meta_enum = "...")."#);
376 };
377
378 let mut meta_variants = Vec::new();
379 for variant in &data.variants {
380 let variant = parsed::MetaVariant {
381 decode_ignore: find_attr(&variant.attrs, "decode_ignore").is_some(),
382 encode_ignore: find_attr(&variant.attrs, "encode_ignore").is_some(),
383 name: &variant.ident,
384 field_type: match &variant.fields {
385 Fields::Unnamed(field) => &field.unnamed.first().unwrap().ty,
386 Fields::Named(_) => panic!("named fields unsupported"),
387 Fields::Unit => panic!("unexpected unit field"),
388 },
389 };
390
391 meta_variants.push(variant);
392 }
393
394 parsed::Type::MetaEnum(parsed::MetaEnum {
395 name: ty,
396 generics,
397 subtype_enum_ty,
398 meta_variants,
399 })
400 } else if let Some(repr_attr) = repr_attr {
401 parsed::Type::FieldlessEnum(parsed::FieldlessEnum {
402 name: ty,
403 underlying_repr: repr_attr.parse_args().expect("couldn't parse repr type"),
404 })
405 } else {
406 panic!("meta_enum or repr attribute missing")
407 }
408 }
409 Data::Union(_) => unimplemented!("union"),
410 };
411
412 implementor(enc_dec_type)
413}