scale_encode_derive/
lib.rs

1// Copyright (C) 2023 Parity Technologies (UK) Ltd. (admin@parity.io)
2// This file is a part of the scale-encode crate.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8//         http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16// https://github.com/rust-lang/rust-clippy/issues/12643.
17// related to `darling::default` attribute expansion
18#![allow(clippy::manual_unwrap_or_default)]
19
20use darling::FromAttributes;
21use proc_macro2::TokenStream as TokenStream2;
22use quote::{format_ident, quote};
23use syn::{parse_macro_input, punctuated::Punctuated, DeriveInput};
24
/// Name of the top-level attribute (`#[encode_as_type(...)]`) read by `TopLevelAttrs::parse`.
const ATTR_NAME: &str = "encode_as_type";
27
28// Macro docs in main crate; don't add any docs here.
29#[proc_macro_derive(EncodeAsType, attributes(encode_as_type, codec))]
30pub fn derive_macro(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
31    let input = parse_macro_input!(input as DeriveInput);
32
33    // parse top level attrs.
34    let attrs = match TopLevelAttrs::parse(&input.attrs) {
35        Ok(attrs) => attrs,
36        Err(e) => return e.write_errors().into(),
37    };
38
39    derive_with_attrs(attrs, input).into()
40}
41
42fn derive_with_attrs(attrs: TopLevelAttrs, input: DeriveInput) -> TokenStream2 {
43    // what type is the derive macro declared on?
44    match &input.data {
45        syn::Data::Enum(details) => generate_enum_impl(attrs, &input, details),
46        syn::Data::Struct(details) => generate_struct_impl(attrs, &input, details),
47        syn::Data::Union(_) => syn::Error::new(
48            input.ident.span(),
49            "Unions are not supported by the EncodeAsType macro",
50        )
51        .into_compile_error(),
52    }
53}
54
55fn generate_enum_impl(
56    attrs: TopLevelAttrs,
57    input: &DeriveInput,
58    details: &syn::DataEnum,
59) -> TokenStream2 {
60    let path_to_scale_encode = &attrs.crate_path;
61    let path_to_type: syn::Path = input.ident.clone().into();
62    let (impl_generics, ty_generics, where_clause) = handle_generics(&attrs, &input.generics);
63
64    // For each variant we want to spit out a match arm.
65    let match_arms = details.variants.iter().map(|variant| {
66        let variant_name = &variant.ident;
67        let variant_name_str = variant_name.to_string();
68
69        let (matcher, composite) =
70            fields_to_matcher_and_composite(path_to_scale_encode, &variant.fields);
71        quote!(
72            Self::#variant_name #matcher => {
73                #path_to_scale_encode::Variant { name: #variant_name_str, fields: #composite }
74                    .encode_variant_as_type_to(
75                        __encode_as_type_type_id,
76                        __encode_as_type_types,
77                        __encode_as_type_out
78                    )
79            }
80        )
81    });
82
83    quote!(
84        impl #impl_generics #path_to_scale_encode::EncodeAsType for #path_to_type #ty_generics #where_clause {
85            #[allow(unused_variables)]
86            fn encode_as_type_to<ScaleEncodeResolver: #path_to_scale_encode::TypeResolver>(
87                &self,
88                // long variable names to prevent conflict with struct field names:
89                __encode_as_type_type_id: ScaleEncodeResolver::TypeId,
90                __encode_as_type_types: &ScaleEncodeResolver,
91                __encode_as_type_out: &mut #path_to_scale_encode::Vec<u8>
92            ) -> Result<(), #path_to_scale_encode::Error> {
93                match self {
94                    #( #match_arms, )*
95                    // This will never be encountered, but in case the enum has no variants
96                    // the compiler will still want something to be spat out here:
97                    _ => unreachable!()
98                }
99            }
100        }
101    )
102}
103
104fn generate_struct_impl(
105    attrs: TopLevelAttrs,
106    input: &DeriveInput,
107    details: &syn::DataStruct,
108) -> TokenStream2 {
109    let path_to_scale_encode = &attrs.crate_path;
110    let path_to_type: syn::Path = input.ident.clone().into();
111    let (impl_generics, ty_generics, where_clause) = handle_generics(&attrs, &input.generics);
112
113    let (matcher, composite) =
114        fields_to_matcher_and_composite(path_to_scale_encode, &details.fields);
115
116    quote!(
117        impl #impl_generics #path_to_scale_encode::EncodeAsType for #path_to_type #ty_generics #where_clause {
118            #[allow(unused_variables)]
119            fn encode_as_type_to<ScaleEncodeResolver: #path_to_scale_encode::TypeResolver>(
120                &self,
121                // long variable names to prevent conflict with struct field names:
122                __encode_as_type_type_id: ScaleEncodeResolver::TypeId,
123                __encode_as_type_types: &ScaleEncodeResolver,
124                __encode_as_type_out: &mut #path_to_scale_encode::Vec<u8>
125            ) -> Result<(), #path_to_scale_encode::Error> {
126                let #path_to_type #matcher = self;
127                #composite.encode_composite_as_type_to(
128                    __encode_as_type_type_id,
129                    __encode_as_type_types,
130                    __encode_as_type_out
131                )
132            }
133        }
134        impl #impl_generics #path_to_scale_encode::EncodeAsFields for #path_to_type #ty_generics #where_clause {
135            #[allow(unused_variables)]
136            fn encode_as_fields_to<ScaleEncodeResolver: #path_to_scale_encode::TypeResolver>(
137                &self,
138                // long variable names to prevent conflict with struct field names:
139                __encode_as_type_fields: &mut dyn #path_to_scale_encode::FieldIter<'_, ScaleEncodeResolver::TypeId>,
140                __encode_as_type_types: &ScaleEncodeResolver,
141                __encode_as_type_out: &mut #path_to_scale_encode::Vec<u8>
142            ) -> Result<(), #path_to_scale_encode::Error> {
143                let #path_to_type #matcher = self;
144                #composite.encode_composite_fields_to(
145                    __encode_as_type_fields,
146                    __encode_as_type_types,
147                    __encode_as_type_out
148                )
149            }
150        }
151    )
152}
153
154fn handle_generics<'a>(
155    attrs: &TopLevelAttrs,
156    generics: &'a syn::Generics,
157) -> (
158    syn::ImplGenerics<'a>,
159    syn::TypeGenerics<'a>,
160    syn::WhereClause,
161) {
162    let path_to_crate = &attrs.crate_path;
163    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
164
165    let mut where_clause = where_clause.cloned().unwrap_or(syn::parse_quote!(where));
166
167    if let Some(where_predicates) = &attrs.trait_bounds {
168        // if custom trait bounds are given, append those to the where clause.
169        where_clause.predicates.extend(where_predicates.clone());
170    } else {
171        // else, append our default EncodeAsType bounds to the where clause.
172        for param in generics.type_params() {
173            let ty = &param.ident;
174            where_clause
175                .predicates
176                .push(syn::parse_quote!(#ty: #path_to_crate::EncodeAsType))
177        }
178    }
179
180    (impl_generics, ty_generics, where_clause)
181}
182
183fn fields_to_matcher_and_composite(
184    path_to_scale_encode: &syn::Path,
185    fields: &syn::Fields,
186) -> (TokenStream2, TokenStream2) {
187    match fields {
188        syn::Fields::Named(fields) => {
189            let match_body = fields.named.iter().map(|f| {
190                let field_name = &f.ident;
191                quote!(#field_name)
192            });
193            let tuple_body = fields.named
194                .iter()
195                .filter(|f| !should_skip(&f.attrs))
196                .map(|f| {
197                    let field_name_str = f.ident.as_ref().unwrap().to_string();
198                    let field_name = &f.ident;
199                    quote!((Some(#field_name_str), #path_to_scale_encode::CompositeField::new(#field_name)))
200                });
201
202            (
203                quote!({#( #match_body ),*}),
204                quote!(#path_to_scale_encode::Composite::new([#( #tuple_body ),*].into_iter())),
205            )
206        }
207        syn::Fields::Unnamed(fields) => {
208            let field_idents = fields
209                .unnamed
210                .iter()
211                .enumerate()
212                .map(|(idx, f)| (format_ident!("_{idx}"), f));
213
214            let match_body = field_idents.clone().map(|(i, _)| quote!(#i));
215            let tuple_body = field_idents
216                .filter(|(_, f)| !should_skip(&f.attrs))
217                .map(|(i, _)| quote!((None as Option<&'static str>, #path_to_scale_encode::CompositeField::new(#i))));
218
219            (
220                quote!((#( #match_body ),*)),
221                quote!(#path_to_scale_encode::Composite::new([#( #tuple_body ),*].into_iter())),
222            )
223        }
224        syn::Fields::Unit => (
225            quote!(),
226            quote!(#path_to_scale_encode::Composite::new(([] as [(Option<&'static str>, #path_to_scale_encode::CompositeField<_>);0]).into_iter())),
227        ),
228    }
229}
230
231struct TopLevelAttrs {
232    // path to the scale_encode crate, in case it's not a top level dependency.
233    crate_path: syn::Path,
234    // allow custom trait bounds to be used instead of the defaults.
235    trait_bounds: Option<Punctuated<syn::WherePredicate, syn::Token!(,)>>,
236}
237
238impl TopLevelAttrs {
239    fn parse(attrs: &[syn::Attribute]) -> darling::Result<Self> {
240        use darling::FromMeta;
241
242        #[derive(FromMeta)]
243        struct TopLevelAttrsInner {
244            #[darling(default)]
245            crate_path: Option<syn::Path>,
246            #[darling(default)]
247            trait_bounds: Option<Punctuated<syn::WherePredicate, syn::Token!(,)>>,
248        }
249
250        let mut res = TopLevelAttrs {
251            crate_path: syn::parse_quote!(::scale_encode),
252            trait_bounds: None,
253        };
254
255        // look at each top level attr. parse any for encode_as_type.
256        for attr in attrs {
257            if !attr.path().is_ident(ATTR_NAME) {
258                continue;
259            }
260            let meta = &attr.meta;
261            let parsed_attrs = TopLevelAttrsInner::from_meta(meta)?;
262
263            res.trait_bounds = parsed_attrs.trait_bounds;
264            if let Some(crate_path) = parsed_attrs.crate_path {
265                res.crate_path = crate_path;
266            }
267        }
268
269        Ok(res)
270    }
271}
272
273// Checks if the attributes contain `skip`.
274//
275// NOTE: Since we only care about `skip` at the moment, we just expose this helper,
276// but if we add more attrs we can expose `FieldAttrs` properly:
277fn should_skip(attrs: &[syn::Attribute]) -> bool {
278    #[derive(FromAttributes, Default)]
279    #[darling(attributes(encode_as_type, codec))]
280    struct FieldAttrs {
281        #[darling(default)]
282        skip: bool,
283    }
284
285    FieldAttrs::from_attributes(attrs).unwrap_or_default().skip
286}