1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{parse_macro_input, DeriveInput};

#[proc_macro_derive(EncodeAsType)]
pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input = parse_macro_input!(input as DeriveInput);

    // TODO: make this configurable:
    let path_to_scale_encode: syn::Path = syn::parse_quote!(scale_encode);

    // what type is the derive macro declared on?
    match input.data {
        syn::Data::Enum(details) => {
            derive_enum(&path_to_scale_encode, input.ident, details).into()
        },
        syn::Data::Struct(details) => {
            derive_struct(&path_to_scale_encode, input.ident, details).into()
        },
        syn::Data::Union(_) => {
            syn::Error::new(input.ident.span(), "Unions are not supported by the EncodeAsType derive macro")
                .into_compile_error()
                .into()
        }
    }
}

/// Generate an `EncodeAsType` impl for an enum.
///
/// Each variant becomes one match arm. The arm binds the variant's fields using the
/// matcher from `fields_to_matcher_and_composite`. It then encodes an
/// `__internal::Variant(name, composite)` against the requested type.
///
/// NOTE(review): this assumes `__internal::Variant` implements `EncodeAsType`; confirm
/// against the runtime crate. Generic parameters are not forwarded, so only non-generic
/// enums get a valid impl.
fn derive_enum(path_to_scale_encode: &syn::Path, name: syn::Ident, details: syn::DataEnum) -> TokenStream2 {
    let match_arms: Vec<TokenStream2> = details
        .variants
        .into_iter()
        .map(|variant| {
            let variant_name = variant.ident;
            // Strip any `r#` prefix so the encoded name matches the declared variant name.
            let variant_name_str = syn::ext::IdentExt::unraw(&variant_name).to_string();
            let (matcher, composite) =
                fields_to_matcher_and_composite(path_to_scale_encode, &variant.fields);
            // The variant must be qualified with `Self::`. Otherwise a bare unit-variant name
            // in pattern position is a catch-all binding, and other variants don't resolve.
            // The trait function is called by path because the trait need not be in scope
            // at the use site.
            quote!(
                Self::#variant_name #matcher => #path_to_scale_encode::EncodeAsType::encode_as_type(
                    &#path_to_scale_encode::__internal::Variant(#variant_name_str, #composite),
                    __encode_as_type_type_id,
                    __encode_as_type_types,
                    __encode_as_type_context,
                )
            )
        })
        .collect();

    // For an enum with no variants, `match self {}` is rejected because `&Empty` is not
    // treated as uninhabited. Matching on `*self` is accepted and moves nothing.
    let body = if match_arms.is_empty() {
        quote!(match *self {})
    } else {
        quote!(match self { #( #match_arms, )* })
    };

    // Parameter names carry a prefix so that field bindings introduced by the match
    // patterns (e.g. a field called `types`) cannot shadow them. The leading underscore
    // also keeps the unused-variable lint quiet for empty enums.
    quote!(
        impl #path_to_scale_encode::EncodeAsType for #name {
            fn encode_as_type(
                &self,
                __encode_as_type_type_id: u32,
                __encode_as_type_types: &#path_to_scale_encode::__internal::PortableRegistry,
                __encode_as_type_context: #path_to_scale_encode::Context,
            ) -> ::core::result::Result<Vec<u8>, #path_to_scale_encode::Error> {
                #body
            }
        }
    )
}

/// Generate an `EncodeAsType` impl for a struct (named, tuple, or unit).
///
/// The fields are bound by destructuring `self` with `Self #matcher`. The resulting
/// `__internal::Composite` is then encoded against the requested type.
///
/// NOTE(review): this assumes `__internal::Composite` implements `EncodeAsType`; confirm
/// against the runtime crate. Generic parameters are not forwarded, so only non-generic
/// structs get a valid impl.
fn derive_struct(path_to_scale_encode: &syn::Path, name: syn::Ident, details: syn::DataStruct) -> TokenStream2 {
    let (matcher, composite) = fields_to_matcher_and_composite(path_to_scale_encode, &details.fields);

    // `Self` must prefix the matcher. Without it, `{ a, b }` is not a pattern, `(a, b)`
    // would match a tuple rather than a tuple struct, and a unit struct would yield
    // `let = self;`.
    // Parameters are prefixed so a field binding (e.g. `types`) cannot shadow them.
    quote!(
        impl #path_to_scale_encode::EncodeAsType for #name {
            fn encode_as_type(
                &self,
                __encode_as_type_type_id: u32,
                __encode_as_type_types: &#path_to_scale_encode::__internal::PortableRegistry,
                __encode_as_type_context: #path_to_scale_encode::Context,
            ) -> ::core::result::Result<Vec<u8>, #path_to_scale_encode::Error> {
                let Self #matcher = self;
                #path_to_scale_encode::EncodeAsType::encode_as_type(
                    &#composite,
                    __encode_as_type_type_id,
                    __encode_as_type_types,
                    __encode_as_type_context,
                )
            }
        }
    )
}

fn fields_to_matcher_and_composite(path_to_scale_encode: &syn::Path, fields: &syn::Fields) -> (TokenStream2, TokenStream2) {
    match fields {
        syn::Fields::Named(fields) => {
            let match_body = fields.named
                .iter()
                .map(|f| {
                    let field_name = &f.ident;
                    quote!(#field_name)
                });
            let tuple_body = fields.named
                .iter()
                .map(|f| {
                    let field_name_str = f.ident.as_ref().unwrap().to_string();
                    let field_name = &f.ident;
                    quote!((Some(#field_name_str), &#field_name))
                });
            (
                quote!({#( #match_body ),*}),
                quote!(#path_to_scale_encode::__internal::Composite((#( #tuple_body ),*)))
            )
        },
        syn::Fields::Unnamed(fields) => {
            let match_body = fields.unnamed
                .iter()
                .map(|f| {
                    let field_name = &f.ident;
                    quote!(#field_name)
                });
            let tuple_body = fields.unnamed
                .iter()
                .map(|f| {
                    let field_name = &f.ident;
                    quote!((None as &'static str, &#field_name))
                });
            (
                quote!((#( #match_body ),*)),
                quote!(#path_to_scale_encode::__internal::Composite((#( #tuple_body ),*)))
            )
        },
        syn::Fields::Unit => {
            (
                quote!(),
                quote!(#path_to_scale_encode::__internal::Composite(()))
            )
        }
    }
}