starknet_core_derive/
lib.rs

1//! Procedural derive macros for the `starknet-core` crate.
2
3#![deny(missing_docs)]
4
5use proc_macro::TokenStream;
6use proc_macro2::Span;
7use quote::quote;
8use syn::{
9    parse::{Error as ParseError, Parse, ParseStream},
10    parse_macro_input, DeriveInput, Fields, LitInt, LitStr, Meta, Token,
11};
12
13#[derive(Default)]
14struct Args {
15    core: Option<LitStr>,
16}
17
18impl Args {
19    fn merge(&mut self, other: Self) {
20        if let Some(core) = other.core {
21            if self.core.is_some() {
22                panic!("starknet attribute `core` defined more than once");
23            } else {
24                self.core = Some(core);
25            }
26        }
27    }
28}
29
30impl Parse for Args {
31    fn parse(input: ParseStream<'_>) -> Result<Self, ParseError> {
32        let mut core: Option<LitStr> = None;
33
34        while !input.is_empty() {
35            let lookahead = input.lookahead1();
36            if lookahead.peek(kw::core) {
37                let _ = input.parse::<kw::core>()?;
38                let _ = input.parse::<Token![=]>()?;
39                let value = input.parse::<LitStr>()?;
40
41                match core {
42                    Some(_) => {
43                        return Err(ParseError::new(
44                            Span::call_site(),
45                            "starknet attribute `core` defined more than once",
46                        ))
47                    }
48                    None => {
49                        core = Some(value);
50                    }
51                }
52            } else {
53                return Err(lookahead.error());
54            }
55        }
56
57        Ok(Self { core })
58    }
59}
60
/// Custom keywords recognized inside `#[starknet(...)]` attributes.
mod kw {
    // Defines the `kw::core` token type that `Args::parse` peeks for and consumes.
    syn::custom_keyword!(core);
}
64
65/// Derives the `Encode` trait.
66#[proc_macro_derive(Encode, attributes(starknet))]
67pub fn derive_encode(input: TokenStream) -> TokenStream {
68    let input: DeriveInput = parse_macro_input!(input);
69    let ident = &input.ident;
70
71    let core = derive_core_path(&input);
72
73    let impl_block = match input.data {
74        syn::Data::Struct(data) => {
75            let field_impls = data.fields.iter().enumerate().map(|(ind_field, field)| {
76                let field_ident = match &field.ident {
77                    Some(field_ident) => quote! { self.#field_ident },
78                    None => {
79                        let ind_field = syn::Index::from(ind_field);
80                        quote! { self.#ind_field }
81                    }
82                };
83                let field_type = &field.ty;
84
85                quote! {
86                    <#field_type as #core::codec::Encode>::encode(&#field_ident, writer)?;
87                }
88            });
89
90            quote! {
91                #(#field_impls)*
92            }
93        }
94        syn::Data::Enum(data) => {
95            let variant_impls =
96                data.variants
97                    .iter()
98                    .enumerate()
99                    .map(|(ind_variant, variant)| {
100                        let variant_ident = &variant.ident;
101                        let ind_variant = int_to_felt(ind_variant, &core);
102
103                        match &variant.fields {
104                            Fields::Named(fields_named) => {
105                                let names = fields_named
106                                    .named
107                                    .iter()
108                                    .map(|field| field.ident.as_ref().unwrap());
109
110                                let field_impls = fields_named.named.iter().map(|field| {
111                                    let field_ident = field.ident.as_ref().unwrap();
112                                    let field_type = &field.ty;
113
114                                    quote! {
115                                        <#field_type as #core::codec::Encode>
116                                            ::encode(#field_ident, writer)?;
117                                    }
118                                });
119
120                                quote! {
121                                    Self::#variant_ident { #(#names),* } => {
122                                        writer.write(#ind_variant);
123                                        #(#field_impls)*
124                                    },
125                                }
126                            }
127                            Fields::Unnamed(fields_unnamed) => {
128                                let names = fields_unnamed.unnamed.iter().enumerate().map(
129                                    |(ind_field, _)| {
130                                        syn::Ident::new(
131                                            &format!("field_{}", ind_field),
132                                            Span::call_site(),
133                                        )
134                                    },
135                                );
136
137                                let field_impls = fields_unnamed.unnamed.iter().enumerate().map(
138                                    |(ind_field, field)| {
139                                        let field_ident = syn::Ident::new(
140                                            &format!("field_{}", ind_field),
141                                            Span::call_site(),
142                                        );
143                                        let field_type = &field.ty;
144
145                                        quote! {
146                                            <#field_type as #core::codec::Encode>
147                                                ::encode(#field_ident, writer)?;
148                                        }
149                                    },
150                                );
151
152                                quote! {
153                                    Self::#variant_ident( #(#names),* ) => {
154                                        writer.write(#ind_variant);
155                                        #(#field_impls)*
156                                    },
157                                }
158                            }
159                            Fields::Unit => {
160                                quote! {
161                                    Self::#variant_ident => {
162                                        writer.write(#ind_variant);
163                                    },
164                                }
165                            }
166                        }
167                    });
168
169            quote! {
170                match self {
171                    #(#variant_impls)*
172                }
173            }
174        }
175        syn::Data::Union(_) => panic!("union type not supported"),
176    };
177
178    quote! {
179        #[automatically_derived]
180        impl #core::codec::Encode for #ident {
181            fn encode<W: #core::codec::FeltWriter>(&self, writer: &mut W)
182                -> ::core::result::Result<(), #core::codec::Error> {
183                #impl_block
184
185                Ok(())
186            }
187        }
188    }
189    .into()
190}
191
192/// Derives the `Decode` trait.
193#[proc_macro_derive(Decode, attributes(starknet))]
194pub fn derive_decode(input: TokenStream) -> TokenStream {
195    let input: DeriveInput = parse_macro_input!(input);
196    let ident = &input.ident;
197
198    let core = derive_core_path(&input);
199
200    let impl_block = match input.data {
201        syn::Data::Struct(data) => match &data.fields {
202            Fields::Named(fields_named) => {
203                let field_impls = fields_named.named.iter().map(|field| {
204                    let field_ident = &field.ident;
205                    let field_type = &field.ty;
206
207                    quote! {
208                        #field_ident: <#field_type as #core::codec::Decode>
209                            ::decode_iter(iter)?,
210                    }
211                });
212
213                quote! {
214                    Ok(Self {
215                        #(#field_impls)*
216                    })
217                }
218            }
219            Fields::Unnamed(fields_unnamed) => {
220                let field_impls = fields_unnamed.unnamed.iter().map(|field| {
221                    let field_type = &field.ty;
222                    quote! {
223                        <#field_type as #core::codec::Decode>::decode_iter(iter)?
224                    }
225                });
226
227                quote! {
228                    Ok(Self (
229                        #(#field_impls),*
230                    ))
231                }
232            }
233            Fields::Unit => {
234                quote! {
235                    Ok(Self)
236                }
237            }
238        },
239        syn::Data::Enum(data) => {
240            let variant_impls = data
241                .variants
242                .iter()
243                .enumerate()
244                .map(|(ind_variant, variant)| {
245                    let variant_ident = &variant.ident;
246                    let ind_variant = int_to_felt(ind_variant, &core);
247
248                    let decode_impl = match &variant.fields {
249                        Fields::Named(fields_named) => {
250                            let field_impls = fields_named.named.iter().map(|field| {
251                                let field_ident = field.ident.as_ref().unwrap();
252                                let field_type = &field.ty;
253
254                                quote! {
255                                    #field_ident: <#field_type as #core::codec::Decode>
256                                        ::decode_iter(iter)?,
257                                }
258                            });
259
260                            quote! {
261                                return Ok(Self::#variant_ident {
262                                    #(#field_impls)*
263                                });
264                            }
265                        }
266                        Fields::Unnamed(fields_unnamed) => {
267                            let field_impls = fields_unnamed.unnamed.iter().map(|field| {
268                                let field_type = &field.ty;
269
270                                quote! {
271                                    <#field_type as #core::codec::Decode>::decode_iter(iter)?
272                                }
273                            });
274
275                            quote! {
276                                return Ok(Self::#variant_ident( #(#field_impls),* ));
277                            }
278                        }
279                        Fields::Unit => {
280                            quote! {
281                                return Ok(Self::#variant_ident);
282                            }
283                        }
284                    };
285
286                    quote! {
287                        if tag == &#ind_variant {
288                            #decode_impl
289                        }
290                    }
291                });
292
293            let ident = ident.to_string();
294
295            quote! {
296                let tag = iter.next().ok_or_else(#core::codec::Error::input_exhausted)?;
297
298                #(#variant_impls)*
299
300                Err(#core::codec::Error::unknown_enum_tag(tag, #ident))
301            }
302        }
303        syn::Data::Union(_) => panic!("union type not supported"),
304    };
305
306    quote! {
307        #[automatically_derived]
308        impl<'a> #core::codec::Decode<'a> for #ident {
309            fn decode_iter<T>(iter: &mut T) -> ::core::result::Result<Self, #core::codec::Error>
310            where
311                T: core::iter::Iterator<Item = &'a #core::types::Felt>
312            {
313                #impl_block
314            }
315        }
316    }
317    .into()
318}
319
320/// Determines the path to the `starknet-core` crate root.
321fn derive_core_path(input: &DeriveInput) -> proc_macro2::TokenStream {
322    let mut attr_args = Args::default();
323
324    for attr in &input.attrs {
325        if !attr.meta.path().is_ident("starknet") {
326            continue;
327        }
328
329        match &attr.meta {
330            Meta::Path(_) => {}
331            Meta::List(meta_list) => {
332                let args: Args = meta_list
333                    .parse_args()
334                    .expect("unable to parse starknet attribute args");
335
336                attr_args.merge(args);
337            }
338            Meta::NameValue(_) => panic!("starknet attribute must not be name-value"),
339        }
340    }
341
342    attr_args.core.map_or_else(
343        || {
344            #[cfg(not(feature = "import_from_starknet"))]
345            quote! {
346                ::starknet_core
347            }
348
349            // This feature is enabled by the `starknet` crate. When using `starknet` it's assumed
350            // that users would not have imported `starknet-core` directly.
351            #[cfg(feature = "import_from_starknet")]
352            quote! {
353                ::starknet::core
354            }
355        },
356        |id| id.parse().expect("unable to parse core crate path"),
357    )
358}
359
360/// Turns an integer into an optimal `TokenStream` that constructs a `Felt` with the same value.
361fn int_to_felt(int: usize, core: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
362    match int {
363        0 => quote! { #core::types::Felt::ZERO },
364        1 => quote! { #core::types::Felt::ONE },
365        2 => quote! { #core::types::Felt::TWO },
366        3 => quote! { #core::types::Felt::THREE },
367        // TODO: turn the number into Montgomery repr and use const ctor instead.
368        _ => {
369            let literal = LitInt::new(&int.to_string(), Span::call_site());
370            quote! { #core::types::Felt::from(#literal) }
371        }
372    }
373}