starknet_core_derive/
lib.rs1#![deny(missing_docs)]
4
5use proc_macro::TokenStream;
6use proc_macro2::Span;
7use quote::quote;
8use syn::{
9 parse::{Error as ParseError, Parse, ParseStream},
10 parse_macro_input, DeriveInput, Fields, LitInt, LitStr, Meta, Token,
11};
12
13#[derive(Default)]
14struct Args {
15 core: Option<LitStr>,
16}
17
18impl Args {
19 fn merge(&mut self, other: Self) {
20 if let Some(core) = other.core {
21 if self.core.is_some() {
22 panic!("starknet attribute `core` defined more than once");
23 } else {
24 self.core = Some(core);
25 }
26 }
27 }
28}
29
30impl Parse for Args {
31 fn parse(input: ParseStream<'_>) -> Result<Self, ParseError> {
32 let mut core: Option<LitStr> = None;
33
34 while !input.is_empty() {
35 let lookahead = input.lookahead1();
36 if lookahead.peek(kw::core) {
37 let _ = input.parse::<kw::core>()?;
38 let _ = input.parse::<Token![=]>()?;
39 let value = input.parse::<LitStr>()?;
40
41 match core {
42 Some(_) => {
43 return Err(ParseError::new(
44 Span::call_site(),
45 "starknet attribute `core` defined more than once",
46 ))
47 }
48 None => {
49 core = Some(value);
50 }
51 }
52 } else {
53 return Err(lookahead.error());
54 }
55 }
56
57 Ok(Self { core })
58 }
59}
60
61mod kw {
62 syn::custom_keyword!(core);
63}
64
65#[proc_macro_derive(Encode, attributes(starknet))]
67pub fn derive_encode(input: TokenStream) -> TokenStream {
68 let input: DeriveInput = parse_macro_input!(input);
69 let ident = &input.ident;
70
71 let core = derive_core_path(&input);
72
73 let impl_block = match input.data {
74 syn::Data::Struct(data) => {
75 let field_impls = data.fields.iter().enumerate().map(|(ind_field, field)| {
76 let field_ident = match &field.ident {
77 Some(field_ident) => quote! { self.#field_ident },
78 None => {
79 let ind_field = syn::Index::from(ind_field);
80 quote! { self.#ind_field }
81 }
82 };
83 let field_type = &field.ty;
84
85 quote! {
86 <#field_type as #core::codec::Encode>::encode(&#field_ident, writer)?;
87 }
88 });
89
90 quote! {
91 #(#field_impls)*
92 }
93 }
94 syn::Data::Enum(data) => {
95 let variant_impls =
96 data.variants
97 .iter()
98 .enumerate()
99 .map(|(ind_variant, variant)| {
100 let variant_ident = &variant.ident;
101 let ind_variant = int_to_felt(ind_variant, &core);
102
103 match &variant.fields {
104 Fields::Named(fields_named) => {
105 let names = fields_named
106 .named
107 .iter()
108 .map(|field| field.ident.as_ref().unwrap());
109
110 let field_impls = fields_named.named.iter().map(|field| {
111 let field_ident = field.ident.as_ref().unwrap();
112 let field_type = &field.ty;
113
114 quote! {
115 <#field_type as #core::codec::Encode>
116 ::encode(#field_ident, writer)?;
117 }
118 });
119
120 quote! {
121 Self::#variant_ident { #(#names),* } => {
122 writer.write(#ind_variant);
123 #(#field_impls)*
124 },
125 }
126 }
127 Fields::Unnamed(fields_unnamed) => {
128 let names = fields_unnamed.unnamed.iter().enumerate().map(
129 |(ind_field, _)| {
130 syn::Ident::new(
131 &format!("field_{}", ind_field),
132 Span::call_site(),
133 )
134 },
135 );
136
137 let field_impls = fields_unnamed.unnamed.iter().enumerate().map(
138 |(ind_field, field)| {
139 let field_ident = syn::Ident::new(
140 &format!("field_{}", ind_field),
141 Span::call_site(),
142 );
143 let field_type = &field.ty;
144
145 quote! {
146 <#field_type as #core::codec::Encode>
147 ::encode(#field_ident, writer)?;
148 }
149 },
150 );
151
152 quote! {
153 Self::#variant_ident( #(#names),* ) => {
154 writer.write(#ind_variant);
155 #(#field_impls)*
156 },
157 }
158 }
159 Fields::Unit => {
160 quote! {
161 Self::#variant_ident => {
162 writer.write(#ind_variant);
163 },
164 }
165 }
166 }
167 });
168
169 quote! {
170 match self {
171 #(#variant_impls)*
172 }
173 }
174 }
175 syn::Data::Union(_) => panic!("union type not supported"),
176 };
177
178 quote! {
179 #[automatically_derived]
180 impl #core::codec::Encode for #ident {
181 fn encode<W: #core::codec::FeltWriter>(&self, writer: &mut W)
182 -> ::core::result::Result<(), #core::codec::Error> {
183 #impl_block
184
185 Ok(())
186 }
187 }
188 }
189 .into()
190}
191
192#[proc_macro_derive(Decode, attributes(starknet))]
194pub fn derive_decode(input: TokenStream) -> TokenStream {
195 let input: DeriveInput = parse_macro_input!(input);
196 let ident = &input.ident;
197
198 let core = derive_core_path(&input);
199
200 let impl_block = match input.data {
201 syn::Data::Struct(data) => match &data.fields {
202 Fields::Named(fields_named) => {
203 let field_impls = fields_named.named.iter().map(|field| {
204 let field_ident = &field.ident;
205 let field_type = &field.ty;
206
207 quote! {
208 #field_ident: <#field_type as #core::codec::Decode>
209 ::decode_iter(iter)?,
210 }
211 });
212
213 quote! {
214 Ok(Self {
215 #(#field_impls)*
216 })
217 }
218 }
219 Fields::Unnamed(fields_unnamed) => {
220 let field_impls = fields_unnamed.unnamed.iter().map(|field| {
221 let field_type = &field.ty;
222 quote! {
223 <#field_type as #core::codec::Decode>::decode_iter(iter)?
224 }
225 });
226
227 quote! {
228 Ok(Self (
229 #(#field_impls),*
230 ))
231 }
232 }
233 Fields::Unit => {
234 quote! {
235 Ok(Self)
236 }
237 }
238 },
239 syn::Data::Enum(data) => {
240 let variant_impls = data
241 .variants
242 .iter()
243 .enumerate()
244 .map(|(ind_variant, variant)| {
245 let variant_ident = &variant.ident;
246 let ind_variant = int_to_felt(ind_variant, &core);
247
248 let decode_impl = match &variant.fields {
249 Fields::Named(fields_named) => {
250 let field_impls = fields_named.named.iter().map(|field| {
251 let field_ident = field.ident.as_ref().unwrap();
252 let field_type = &field.ty;
253
254 quote! {
255 #field_ident: <#field_type as #core::codec::Decode>
256 ::decode_iter(iter)?,
257 }
258 });
259
260 quote! {
261 return Ok(Self::#variant_ident {
262 #(#field_impls)*
263 });
264 }
265 }
266 Fields::Unnamed(fields_unnamed) => {
267 let field_impls = fields_unnamed.unnamed.iter().map(|field| {
268 let field_type = &field.ty;
269
270 quote! {
271 <#field_type as #core::codec::Decode>::decode_iter(iter)?
272 }
273 });
274
275 quote! {
276 return Ok(Self::#variant_ident( #(#field_impls),* ));
277 }
278 }
279 Fields::Unit => {
280 quote! {
281 return Ok(Self::#variant_ident);
282 }
283 }
284 };
285
286 quote! {
287 if tag == &#ind_variant {
288 #decode_impl
289 }
290 }
291 });
292
293 let ident = ident.to_string();
294
295 quote! {
296 let tag = iter.next().ok_or_else(#core::codec::Error::input_exhausted)?;
297
298 #(#variant_impls)*
299
300 Err(#core::codec::Error::unknown_enum_tag(tag, #ident))
301 }
302 }
303 syn::Data::Union(_) => panic!("union type not supported"),
304 };
305
306 quote! {
307 #[automatically_derived]
308 impl<'a> #core::codec::Decode<'a> for #ident {
309 fn decode_iter<T>(iter: &mut T) -> ::core::result::Result<Self, #core::codec::Error>
310 where
311 T: core::iter::Iterator<Item = &'a #core::types::Felt>
312 {
313 #impl_block
314 }
315 }
316 }
317 .into()
318}
319
320fn derive_core_path(input: &DeriveInput) -> proc_macro2::TokenStream {
322 let mut attr_args = Args::default();
323
324 for attr in &input.attrs {
325 if !attr.meta.path().is_ident("starknet") {
326 continue;
327 }
328
329 match &attr.meta {
330 Meta::Path(_) => {}
331 Meta::List(meta_list) => {
332 let args: Args = meta_list
333 .parse_args()
334 .expect("unable to parse starknet attribute args");
335
336 attr_args.merge(args);
337 }
338 Meta::NameValue(_) => panic!("starknet attribute must not be name-value"),
339 }
340 }
341
342 attr_args.core.map_or_else(
343 || {
344 #[cfg(not(feature = "import_from_starknet"))]
345 quote! {
346 ::starknet_core
347 }
348
349 #[cfg(feature = "import_from_starknet")]
352 quote! {
353 ::starknet::core
354 }
355 },
356 |id| id.parse().expect("unable to parse core crate path"),
357 )
358}
359
360fn int_to_felt(int: usize, core: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
362 match int {
363 0 => quote! { #core::types::Felt::ZERO },
364 1 => quote! { #core::types::Felt::ONE },
365 2 => quote! { #core::types::Felt::TWO },
366 3 => quote! { #core::types::Felt::THREE },
367 _ => {
369 let literal = LitInt::new(&int.to_string(), Span::call_site());
370 quote! { #core::types::Felt::from(#literal) }
371 }
372 }
373}