Skip to main content

wasefire_wire_derive/
lib.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16
17use proc_macro2::{Span, TokenStream};
18use quote::quote;
19use syn::parse_quote;
20use syn::spanned::Spanned;
21
22#[cfg(test)]
23mod tests;
24
25// TODO: Use {parse_}quote_spanned.
26
27#[proc_macro_derive(Wire, attributes(wire))]
28pub fn derive_wire(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
29    let item = syn::parse_macro_input!(item as syn::DeriveInput);
30    print_item(derive_item(&item)).into()
31}
32
33#[proc_macro_attribute]
34pub fn internal_wire(
35    attr: proc_macro::TokenStream, item: proc_macro::TokenStream,
36) -> proc_macro::TokenStream {
37    let item = syn::parse_macro_input!(item as syn::DeriveInput);
38    if !attr.is_empty() {
39        return syn::Error::new(item.span(), "unexpected attribute").into_compile_error().into();
40    }
41    print_item(derive_item(&item)).into()
42}
43
44fn print_item(item: syn::Result<syn::ItemImpl>) -> TokenStream {
45    match item {
46        #[cfg(feature = "_dev")]
47        Ok(_) => TokenStream::new(),
48        #[cfg(not(feature = "_dev"))]
49        Ok(x) => quote!(#x),
50        Err(e) => e.into_compile_error(),
51    }
52}
53
54fn derive_item(item: &syn::DeriveInput) -> syn::Result<syn::ItemImpl> {
55    let mut attrs = Attrs::parse(&item.attrs, AttrsKind::Item)?;
56    let wire = match attrs.crate_.take() {
57        Attr::Absent => syn::Ident::new("wasefire_wire", Span::call_site()).into(),
58        Attr::Present(_, x) => x,
59    };
60    let param = match attrs.param.take() {
61        Attr::Absent => {
62            let first = item.generics.lifetimes().next();
63            first.map_or_else(
64                || syn::Lifetime::new("'wire", Span::call_site()),
65                |x| x.lifetime.clone(),
66            )
67        }
68        Attr::Present(_, x) => x,
69    };
70    let type_param = syn::Lifetime::new(&format!("'_{}", param.ident), param.span());
71    let (_, parameters, _) = item.generics.split_for_impl();
72    let mut generics = item.generics.clone();
73    if generics.lifetimes().all(|x| x.lifetime != param) {
74        generics.params.insert(0, parse_quote!(#param));
75    }
76    if let Attr::Present(_, where_) = attrs.where_.take() {
77        generics.make_where_clause().predicates.extend(where_);
78    }
79    let (generics, _, where_clause) = generics.split_for_impl();
80    let (schema, encode, decode) = match &item.data {
81        syn::Data::Struct(x) => derive_struct(&mut attrs, &wire, &param, &item.ident, x)?,
82        syn::Data::Enum(x) => derive_enum(&mut attrs, &wire, &item.ident, x)?,
83        syn::Data::Union(_) => {
84            return Err(syn::Error::new(item.span(), "unions are not supported"));
85        }
86    };
87    if let Attr::Present(span, _) = attrs.range.take() {
88        return Err(syn::Error::new(span, "range is only supported for enums"));
89    }
90    if let Attr::Present(span, _) = attrs.refine.take() {
91        return Err(syn::Error::new(span, "refine is only supported for structs with one field"));
92    }
93    let ident = &item.ident;
94    let statics: Vec<syn::Type> = match attrs.static_.take() {
95        Attr::Absent => Vec::new(),
96        Attr::Present(_, xs) => xs.into_iter().map(|x| parse_quote!(#x)).collect(),
97    };
98    let mut visitor = MakeType { src: &param, dst: &type_param, statics: &statics };
99    let mut type_: syn::Type = parse_quote!(#ident #parameters);
100    syn::visit_mut::visit_type_mut(&mut visitor, &mut type_);
101    let schema = match cfg!(feature = "schema") {
102        false => None,
103        true => Some(quote! {
104            fn schema(rules: &mut #wire::internal::Rules) {
105                #(#schema)*
106            }
107        }),
108    };
109    let impl_wire: syn::ItemImpl = parse_quote! {
110        #[automatically_derived]
111        impl #generics #wire::internal::Wire<#param> for #ident #parameters #where_clause {
112            type Type<#type_param> = #type_;
113            #schema
114            fn encode(
115                &self, writer: &mut #wire::internal::Writer<#param>
116            ) -> #wire::internal::Result<()> {
117                #(#encode)*
118            }
119            fn decode(
120                reader: &mut #wire::internal::Reader<#param>
121            ) -> #wire::internal::Result<Self> {
122                #(#decode)*
123            }
124        }
125    };
126    Ok(impl_wire)
127}
128
129fn derive_struct(
130    attrs: &mut Attrs, wire: &syn::Path, param: &syn::Lifetime, name: &syn::Ident,
131    item: &syn::DataStruct,
132) -> syn::Result<(Vec<syn::Stmt>, Vec<syn::Stmt>, Vec<syn::Stmt>)> {
133    let mut types = Types::default();
134    if item.fields.len() == 1
135        && let Attr::Present(_, refine) = attrs.refine.take()
136    {
137        let inner = item.fields.iter().next().unwrap();
138        let ty = &inner.ty;
139        let schema = parse_quote! {
140            rules.alias::<
141                Self::Type<'static>, <#ty as #wire::internal::Wire<#param>>::Type<'static>>();
142        };
143        let encode: syn::Expr = match &inner.ident {
144            None => parse_quote!(self.0.encode(writer)),
145            Some(name) => parse_quote!(self.#name.encode(writer)),
146        };
147        let decode: syn::Expr = parse_quote!(#refine(<#ty>::decode(reader)?));
148        return Ok((
149            vec![schema],
150            vec![syn::Stmt::Expr(encode, None)],
151            vec![syn::Stmt::Expr(decode, None)],
152        ));
153    }
154    let (mut schema, mut encode, decode) =
155        derive_fields(wire, &parse_quote!(#name), &item.fields, &mut types)?;
156    if cfg!(feature = "schema") {
157        types.stmts(&mut schema, wire, "struct_", "fields");
158    }
159    if !encode.is_empty() {
160        let path = parse_quote!(#name);
161        let encode_pat = fields_pat(&path, &item.fields, false);
162        encode.insert(0, parse_quote!(let #encode_pat = self;));
163    }
164    Ok((schema, encode, decode))
165}
166
167fn derive_enum(
168    attrs: &mut Attrs, wire: &syn::Path, name: &syn::Ident, item: &syn::DataEnum,
169) -> syn::Result<(Vec<syn::Stmt>, Vec<syn::Stmt>, Vec<syn::Stmt>)> {
170    let mut schema = Vec::new();
171    let mut types = Types::default();
172    let mut encode = Vec::<syn::Arm>::new();
173    let mut decode = Vec::<syn::Arm>::new();
174    let mut tags = Tags::default();
175    match item.variants.len() {
176        _ if !cfg!(feature = "schema") => (),
177        0 => schema.push(parse_quote!(let variants = #wire::internal::Vec::new();)),
178        n => schema.push(parse_quote!(let mut variants = #wire::internal::Vec::with_capacity(#n);)),
179    }
180    for variant in &item.variants {
181        let mut attrs = Attrs::parse(&variant.attrs, AttrsKind::Variant)?;
182        let tag = match attrs.tag.take() {
183            Attr::Present(span, tag) => tags.use_(span, tag)?,
184            Attr::Absent => tags.next(),
185        };
186        let ident = &variant.ident;
187        let path = parse_quote!(#name::#ident);
188        let (variant_schema, variant_encode, variant_decode) =
189            derive_fields(wire, &path, &variant.fields, &mut types)?;
190        let pat_encode = fields_pat(&path, &variant.fields, true);
191        let ident_schema = format!("{ident}");
192        if cfg!(feature = "schema") {
193            schema.push(parse_quote!({
194                #(#variant_schema)*
195                variants.push((#ident_schema, #tag, fields));
196            }));
197        }
198        encode.push(parse_quote!(#pat_encode => {
199            #wire::internal::encode_tag(#tag, writer)?;
200            #(#variant_encode)*
201        }));
202        decode.push(parse_quote!(#tag => { #(#variant_decode)* }));
203    }
204    if let Attr::Present(span, range) = attrs.range.take()
205        && (tags.used.len() as u32 != range || tags.used.iter().any(|x| range <= *x))
206    {
207        return Err(syn::Error::new(span, "tags don't form a range"));
208    }
209    if cfg!(feature = "schema") {
210        types.stmts(&mut schema, wire, "enum_", "variants");
211    }
212    let encode = parse_quote!(match *self { #(#encode)* });
213    let decode = parse_quote! {
214        let tag = #wire::internal::decode_tag(reader)?;
215        match tag {
216            #(#decode)*
217            _ => Err(#wire::internal::INVALID_TAG),
218        }
219    };
220    Ok((schema, encode, decode))
221}
222
223fn derive_fields<'a>(
224    wire: &syn::Path, name: &syn::Path, fields: &'a syn::Fields, types: &mut Types<'a>,
225) -> syn::Result<(Vec<syn::Stmt>, Vec<syn::Stmt>, Vec<syn::Stmt>)> {
226    let mut schema = Vec::new();
227    let mut encode = Vec::new();
228    let mut decode = Vec::new();
229    match fields.len() {
230        _ if !cfg!(feature = "schema") => (),
231        0 => schema.push(parse_quote!(let fields = #wire::internal::Vec::new();)),
232        n => schema.push(parse_quote!(let mut fields = #wire::internal::Vec::with_capacity(#n);)),
233    }
234    for (i, field) in fields.iter().enumerate() {
235        let _ = Attrs::parse(&field.attrs, AttrsKind::Invalid)?;
236        let name = field_name(i, field);
237        let ty = &field.ty;
238        if cfg!(feature = "schema") {
239            let name_str: syn::Expr = match &field.ident {
240                Some(x) => {
241                    let x = format!("{x}");
242                    parse_quote!(Some(#x))
243                }
244                None => parse_quote!(None),
245            };
246            schema.push(parse_quote!(fields.push((#name_str, #wire::internal::type_id::<#ty>()));));
247            types.insert(ty);
248        }
249        encode.push(parse_quote!(<#ty as #wire::internal::Wire>::encode(#name, writer)?;));
250        decode.push(parse_quote!(let #name = <#ty as #wire::internal::Wire>::decode(reader)?;));
251    }
252    encode.push(syn::Stmt::Expr(parse_quote!(Ok(())), None));
253    let fields_pat = fields_pat(name, fields, false);
254    decode.push(syn::Stmt::Expr(parse_quote!(Ok(#fields_pat)), None));
255    Ok((schema, encode, decode))
256}
257
258fn fields_pat(name: &syn::Path, fields: &syn::Fields, ref_: bool) -> syn::Pat {
259    let ref_: Option<syn::Token![ref]> = ref_.then_some(parse_quote!(ref));
260    let names = fields.iter().enumerate().map(|(i, field)| field_name(i, field));
261    match fields {
262        syn::Fields::Named(_) => parse_quote!(#name { #(#ref_ #names),* }),
263        syn::Fields::Unnamed(_) => parse_quote!(#name(#(#ref_ #names),*)),
264        syn::Fields::Unit => parse_quote!(#name),
265    }
266}
267
268fn field_name(i: usize, field: &syn::Field) -> syn::Ident {
269    match &field.ident {
270        Some(x) => x.clone(),
271        None => syn::Ident::new(&format!("x{i}"), field.span()),
272    }
273}
274
275struct MakeType<'a> {
276    src: &'a syn::Lifetime,
277    dst: &'a syn::Lifetime,
278    statics: &'a [syn::Type],
279}
280
281impl syn::visit_mut::VisitMut for MakeType<'_> {
282    fn visit_generic_argument_mut(&mut self, x: &mut syn::GenericArgument) {
283        match x {
284            syn::GenericArgument::Lifetime(a) if a == self.src => *a = self.dst.clone(),
285            syn::GenericArgument::Type(t) if !self.statics.contains(t) => {
286                let a = &self.dst;
287                *x = parse_quote!(#t::Type<#a>)
288            }
289            _ => (),
290        }
291    }
292}
293
294#[derive(Default)]
295struct Tags {
296    used: HashSet<u32>,
297    next: u32,
298}
299
300impl Tags {
301    fn use_(&mut self, span: Span, tag: u32) -> syn::Result<u32> {
302        if !self.used.insert(tag) {
303            return Err(syn::Error::new(span, "duplicate tag"));
304        }
305        Ok(self.update_next(tag))
306    }
307
308    fn next(&mut self) -> u32 {
309        while !self.used.insert(self.next) {
310            self.next = self.next.wrapping_add(1);
311        }
312        self.update_next(self.next)
313    }
314
315    fn update_next(&mut self, tag: u32) -> u32 {
316        self.next = tag.wrapping_add(1);
317        tag
318    }
319}
320
321#[derive(Default)]
322struct Types<'a>(Vec<&'a syn::Type>);
323
324impl<'a> Types<'a> {
325    fn insert(&mut self, x: &'a syn::Type) {
326        if !self.0.contains(&x) {
327            self.0.push(x);
328        }
329    }
330
331    fn stmts(&self, schema: &mut Vec<syn::Stmt>, wire: &syn::Path, fun: &str, var: &str) {
332        let types: Vec<syn::Stmt> =
333            self.0.iter().map(|ty| parse_quote!(#wire::internal::schema::<#ty>(rules);)).collect();
334        let fun = syn::Ident::new(fun, Span::call_site());
335        let var = syn::Ident::new(var, Span::call_site());
336        schema.push(parse_quote! {
337            if rules.#fun::<Self::Type<'static>>(#var) {
338                #(#types)*
339            }
340        });
341    }
342}
343
/// Location of a list of attributes, which determines the `#[wire(...)]` options it may use.
enum AttrsKind {
    /// Attributes of an enum variant.
    Variant,
    /// Attributes of the struct or enum being derived.
    Item,
    /// Attributes of a field, where no option is allowed.
    Invalid,
}
349
/// Identifies a `#[wire(...)]` option whose placement is validated by `Attrs::check_kind`.
#[derive(Eq, PartialEq)]
enum AttrKind {
    /// `crate = <path>`
    Crate,
    /// `param = <lifetime>`
    Param,
    /// `where = <predicate>`
    Where,
    /// `static = <ident>`
    Static,
    /// `range = <u32>`
    Range,
    /// `tag = <u32>`
    Tag,
}
359
360#[derive(Default)]
361enum Attr<T> {
362    #[default]
363    Absent,
364    Present(Span, T),
365}
366
367impl<T> Attr<T> {
368    fn span(&self) -> Option<Span> {
369        match self {
370            Attr::Absent => None,
371            Attr::Present(x, _) => Some(*x),
372        }
373    }
374
375    fn set(&mut self, span: Span, value: T) -> syn::Result<()> {
376        match self {
377            Attr::Absent => Ok(*self = Attr::Present(span, value)),
378            Attr::Present(other, _) => {
379                let mut error = syn::Error::new(span, "attribute already defined");
380                error.combine(syn::Error::new(*other, "first attribute definition"));
381                Err(error)
382            }
383        }
384    }
385
386    fn take(&mut self) -> Self {
387        std::mem::take(self)
388    }
389}
390
391impl<T> Attr<Vec<T>> {
392    fn push(&mut self, span: Span, value: T) {
393        match self {
394            Attr::Absent => *self = Attr::Present(span, vec![value]),
395            Attr::Present(_, values) => values.push(value),
396        }
397    }
398}
399
400#[derive(Default)]
401struct Attrs {
402    crate_: Attr<syn::Path>,
403    param: Attr<syn::Lifetime>,
404    where_: Attr<Vec<syn::WherePredicate>>,
405    tag: Attr<u32>,
406    static_: Attr<Vec<syn::Ident>>,
407    range: Attr<u32>,
408    refine: Attr<syn::Path>,
409}
410
411impl Attrs {
412    fn parse(attrs: &[syn::Attribute], kind: AttrsKind) -> syn::Result<Self> {
413        let mut result = Attrs::default();
414        for attr in attrs {
415            result.parse_attr(attr)?;
416        }
417        result.check_kind(kind)?;
418        Ok(result)
419    }
420
421    fn parse_attr(&mut self, attr: &syn::Attribute) -> syn::Result<()> {
422        if !attr.path().is_ident("wire") {
423            return Ok(());
424        }
425        attr.parse_nested_meta(|meta| {
426            if meta.path.is_ident("crate") {
427                // #[wire(crate = <path>)]
428                self.crate_.set(attr.span(), meta.value()?.parse()?)?;
429            }
430            if meta.path.is_ident("param") {
431                // #[wire(param = <lifetime>)]
432                self.param.set(attr.span(), meta.value()?.parse()?)?;
433            }
434            if meta.path.is_ident("where") {
435                // #[wire(where = <where_predicate>)]
436                self.where_.push(attr.span(), meta.value()?.parse()?);
437            }
438            if meta.path.is_ident("tag") {
439                // #[wire(tag = <u32>)]
440                let tag: syn::LitInt = meta.value()?.parse()?;
441                self.tag.set(attr.span(), tag.base10_parse()?)?;
442            }
443            if meta.path.is_ident("static") {
444                // #[wire(static = <ident>)]
445                self.static_.push(attr.span(), meta.value()?.parse()?);
446            }
447            if meta.path.is_ident("range") {
448                // #[wire(range = <u32>)]
449                let range: syn::LitInt = meta.value()?.parse()?;
450                self.range.set(attr.span(), range.base10_parse()?)?;
451            }
452            if meta.path.is_ident("refine") {
453                // #[wire(refine = <path>)]
454                self.refine.set(attr.span(), meta.value()?.parse()?)?;
455            }
456            Ok(())
457        })
458    }
459
460    fn check_kind(&self, kind: AttrsKind) -> syn::Result<()> {
461        let expected: &[AttrKind] = match kind {
462            AttrsKind::Item => &[
463                AttrKind::Crate,
464                AttrKind::Param,
465                AttrKind::Where,
466                AttrKind::Static,
467                AttrKind::Range,
468            ],
469            AttrsKind::Variant => &[AttrKind::Tag],
470            AttrsKind::Invalid => &[],
471        };
472        let check = |name, actual, expected: &[AttrKind]| {
473            if let Some(actual) = actual
474                && !expected.contains(&name)
475            {
476                return Err(syn::Error::new(actual, "unexpected attribute"));
477            }
478            Ok(())
479        };
480        check(AttrKind::Crate, self.crate_.span(), expected)?;
481        check(AttrKind::Param, self.param.span(), expected)?;
482        check(AttrKind::Where, self.where_.span(), expected)?;
483        check(AttrKind::Tag, self.tag.span(), expected)?;
484        check(AttrKind::Static, self.static_.span(), expected)?;
485        check(AttrKind::Range, self.range.span(), expected)?;
486        Ok(())
487    }
488}