webb_proposal_derive/
lib.rs

//! This crate provides the `Proposal` derive macro, which implements `ProposalTrait` for proposal structs and generates common inherent helpers for them.
2
3use ethers_core::abi::HumanReadableParser;
4use proc_macro::TokenStream;
5use quote::quote;
6use syn::parse::{Parse, ParseStream};
7use syn::{
8    parse_macro_input, DeriveInput, Fields, Generics, Ident, Lit, LitStr, Token,
9};
10
// `alloc` is only linked into this crate when it is built without its `std` feature.
// NOTE(review): this `cfg` is evaluated against *this* proc-macro crate's
// features, not those of the crate that invokes `#[derive(Proposal)]`. The
// `alloc::vec::Vec` path emitted by `derive_proposal_trait` must resolve in the
// invoking crate, so confirm this `extern crate` is actually needed here.
#[cfg(not(feature = "std"))]
#[doc(hidden)]
extern crate alloc;
14
15/// Args for the proposal derive macro
16struct Args {
17    function_sig: LitStr,
18}
19
20impl Parse for Args {
21    fn parse(input: ParseStream) -> syn::Result<Self> {
22        let ident: Ident = input.parse()?;
23        let _: Token!(=) = input.parse()?;
24        let function_sig: Lit = input.parse()?;
25        if ident != "function_sig" {
26            return Err(syn::Error::new(
27                ident.span(),
28                format!("expected `function_sig` but got {ident} instead!"),
29            ));
30        }
31        let function_sig = match function_sig {
32            Lit::Str(v) => v,
33            unknown => {
34                return Err(syn::Error::new(
35                    ident.span(),
36                    format!("expected a string literal but got `{unknown:?}` instead!"),
37                ));
38            }
39        };
40        let args = Self { function_sig };
41        Ok(args)
42    }
43}
44
45fn derive_proposal(input: DeriveInput) -> syn::Result<TokenStream> {
46    // Used in the quasi-quotation below as `#name`.
47    let name = input.ident;
48    let generics = input.generics;
49    let struct_fields = match input.data {
50        syn::Data::Struct(s) => s.fields,
51        _ => {
52            return Err(syn::Error::new(
53                name.span(),
54                "expected the proposal to be a struct but got something else instead!",
55            ))
56        }
57    };
58    make_sure_it_has_header_field(&struct_fields)?;
59    let common = impl_common_methods(&name, &generics, &struct_fields)?;
60    let attr = input
61        .attrs
62        .iter()
63        .find(|a| a.path().is_ident("proposal"))
64        .ok_or_else( ||
65            syn::Error::new(name.span(), r#"missing function signature! please add #[proposal(function_sig = "...")] on the proposal struct"#),
66        )?;
67    let args: Args = attr.parse_args()?;
68    let trait_impl = derive_proposal_trait(name, generics, &args)?;
69    let mut expanded = common;
70    expanded.extend(trait_impl);
71    Ok(expanded)
72}
73
74fn derive_proposal_trait(
75    name: Ident,
76    generics: Generics,
77    args: &Args,
78) -> syn::Result<TokenStream> {
79    let function_sig = &args.function_sig;
80    let sig = match HumanReadableParser::parse_function(&function_sig.value()) {
81        Ok(f) => f.short_signature(),
82        Err(e) => return Err(syn::Error::new(function_sig.span(), e)),
83    };
84    // Had to do this hack.
85    let computed_function_sig = {
86        let (s0, s1, s2, s3) = (sig[0], sig[1], sig[2], sig[3]);
87        quote! { [#s0, #s1, #s2, #s3] }
88    };
89    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
90    // Build the output, possibly using quasi-quotation
91    #[cfg(not(feature = "std"))]
92    let to_vec = quote! {
93        fn to_vec(&self) -> alloc::vec::Vec<u8> {
94            crate::to_vec(self).expect("never fails to serialize")
95        }
96    };
97
98    #[cfg(feature = "std")]
99    let to_vec = quote! {
100        fn to_vec(&self) -> Vec<u8> {
101            crate::to_vec(self).expect("never fails to serialize")
102        }
103    };
104    let expanded = quote! {
105        impl #impl_generics crate::ProposalTrait for #name #ty_generics #where_clause {
106            fn header(&self) -> crate::ProposalHeader {
107                self.header
108            }
109
110            fn function_sig() -> crate::FunctionSignature {
111                crate::FunctionSignature(#computed_function_sig)
112            }
113
114            #to_vec
115        }
116    };
117
118    Ok(expanded.into())
119}
120
121/// Makes sure of the following:
122/// 1. that the proposal has the header field.
123/// 2. the header field has the correct type.
124/// 3. the header field is the first field in the struct.
125fn make_sure_it_has_header_field(fields: &Fields) -> syn::Result<()> {
126    let header_field = fields.iter().next().ok_or_else(|| {
127        syn::Error::new_spanned(fields, "expected at least one field")
128    })?;
129    let header_field_name = header_field.ident.as_ref().ok_or_else(|| {
130        syn::Error::new_spanned(
131            header_field,
132            "expected the first field to have a name",
133        )
134    })?;
135    if header_field_name != "header" {
136        return Err(syn::Error::new_spanned(
137            header_field_name,
138            "expected the first field to be named `header`",
139        ));
140    }
141    let header_field_type = &header_field.ty;
142    if !matches!(header_field_type, syn::Type::Path(p) if p.path.is_ident("ProposalHeader"))
143    {
144        return Err(syn::Error::new_spanned(
145            header_field_type,
146            "expected the first field to be of type `ProposalHeader`",
147        ));
148    }
149
150    Ok(())
151}
152
153/// Implements the common methods for any proposal.
154/// The common methods are:
155/// - `LENGTH` const
156/// - `new` function
157/// - `header` getter
158/// - `[field_name]` getter
159fn impl_common_methods(
160    name: &Ident,
161    generics: &Generics,
162    fields: &Fields,
163) -> syn::Result<TokenStream> {
164    // LENGTH const
165    // 40 is the length of the proposal header.
166    // + ANY additional field using `core::mem::size_of::<T>()`
167    // so the final length is 40 + sum of all fields.
168    let length = fields
169        .iter()
170        .skip(1)
171        .map(|f| {
172            let ty = &f.ty;
173            quote! { core::mem::size_of::<#ty>() }
174        })
175        .fold(quote! { 40 }, |acc, f| quote! { #acc + #f });
176    let length = quote! {
177        /// The length of the proposal in bytes.
178        pub const LENGTH: usize = #length;
179    };
180
181    // new function
182    // The new function takes the header and all the fields.
183    // It returns Self.
184    let new = {
185        let header_field = fields.iter().next().unwrap();
186        let header_field_name = header_field.ident.as_ref().unwrap();
187        let header_field_type = &header_field.ty;
188        let fields_names = fields.iter().skip(1).map(|f| {
189            let field_name = f.ident.as_ref().unwrap();
190            quote! { #field_name }
191        });
192        let fields_with_types = fields.iter().skip(1).map(|f| {
193            let field_name = f.ident.as_ref().unwrap();
194            let field_type = &f.ty;
195            quote! { #field_name: #field_type }
196        });
197        quote! {
198            /// Creates a new proposal.
199            #[must_use]
200            pub const fn new(#header_field_name: #header_field_type, #(#fields_with_types),*) -> Self {
201                Self { #header_field_name, #(#fields_names),* }
202            }
203        }
204    };
205
206    // header getter
207    // returns ProposalHeader.
208    let header_getter = {
209        let header_field = fields.iter().next().unwrap();
210        let header_field_name = header_field.ident.as_ref().unwrap();
211        quote! {
212            /// Returns the header of the proposal.
213            #[must_use]
214            pub const fn header(&self) -> ProposalHeader {
215                self.#header_field_name
216            }
217        }
218    };
219
220    // field getters
221    // returns a ref to the field type.
222    let field_getters = fields.iter().skip(1).map(|f| {
223        let field_name = f.ident.as_ref().unwrap();
224        let field_type = &f.ty;
225        quote! {
226            /// Returns a reference to that field of the proposal.
227            #[must_use]
228            pub const fn #field_name(&self) -> &#field_type {
229                &self.#field_name
230            }
231        }
232    });
233
234    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
235    let expanded = quote! {
236        impl #impl_generics #name #ty_generics #where_clause {
237            #length
238            #new
239            #header_getter
240            #(#field_getters)*
241        }
242    };
243    Ok(expanded.into())
244}
245
246#[proc_macro_derive(Proposal, attributes(proposal))]
247pub fn derive(input: TokenStream) -> TokenStream {
248    // Parse the input tokens into a syntax tree
249    let input = parse_macro_input!(input as DeriveInput);
250    derive_proposal(input).unwrap_or_else(|err| err.to_compile_error().into())
251}