1use proc_macro2::{Span, TokenStream};
2use quote::quote;
3use syn::*;
4
5use crate::utils::*;
6
/// Debug tracing helper: forwards its arguments to `println!` only when the
/// crate is built with the `trace` feature; otherwise it expands to an empty
/// block.
macro_rules! trace {
    ( $( $part:expr ),* ) => {
        {
            #[cfg(feature = "trace")]
            println!( $( $part ),* );
        }
    };
}
13
14pub fn handle_encode(
15 input: TokenStream,
16 context_custom_value_kind: Option<&'static str>,
17) -> Result<TokenStream> {
18 trace!("handle_encode() starts");
19
20 let parsed: DeriveInput = parse2(input)?;
21
22 let output = match get_derive_strategy(&parsed.attrs)? {
23 DeriveStrategy::Normal => handle_normal_encode(parsed, context_custom_value_kind)?,
24 DeriveStrategy::Transparent => {
25 handle_transparent_encode(parsed, context_custom_value_kind)?
26 }
27 DeriveStrategy::DeriveAs {
28 as_type, as_ref, ..
29 } => handle_encode_as(parsed, context_custom_value_kind, &as_type, &as_ref)?,
30 };
31
32 #[cfg(feature = "trace")]
33 crate::utils::print_generated_code("Encode", &output);
34
35 trace!("handle_encode() finishes");
36 Ok(output)
37}
38
39pub fn handle_transparent_encode(
40 parsed: DeriveInput,
41 context_custom_value_kind: Option<&'static str>,
42) -> Result<TokenStream> {
43 let output = match &parsed.data {
44 Data::Struct(s) => {
45 let single_field = process_fields(&s.fields)?
46 .unique_unskipped_field()
47 .ok_or_else(|| Error::new(
48 Span::call_site(),
49 "The transparent attribute is only supported for structs with a single unskipped field.",
50 ))?;
51 handle_encode_as(
52 parsed,
53 context_custom_value_kind,
54 single_field.field_type(),
55 &single_field.self_field_reference(),
56 )?
57 }
58 Data::Enum(_) => {
59 return Err(Error::new(Span::call_site(), "The transparent attribute is only supported for structs with a single unskipped field."));
60 }
61 Data::Union(_) => {
62 return Err(Error::new(Span::call_site(), "Union is not supported!"));
63 }
64 };
65
66 Ok(output)
67}
68
69pub fn handle_encode_as(
70 parsed: DeriveInput,
71 context_custom_value_kind: Option<&'static str>,
72 as_type: &Type,
73 as_ref_code: &TokenStream,
74) -> Result<TokenStream> {
75 let DeriveInput {
76 attrs,
77 ident,
78 generics,
79 ..
80 } = parsed;
81 let (impl_generics, ty_generics, where_clause, custom_value_kind_generic, encoder_generic) =
82 build_encode_generics(&generics, &attrs, context_custom_value_kind)?;
83
84 let output = quote! {
88 impl #impl_generics sbor::Encode <#custom_value_kind_generic, #encoder_generic> for #ident #ty_generics #where_clause {
89 #[inline]
90 fn encode_value_kind(&self, encoder: &mut #encoder_generic) -> Result<(), sbor::EncodeError> {
91 use sbor::{self, Encode};
92 let as_ref: &#as_type = #as_ref_code;
93 as_ref.encode_value_kind(encoder)
94 }
95
96 #[inline]
97 fn encode_body(&self, encoder: &mut #encoder_generic) -> Result<(), sbor::EncodeError> {
98 use sbor::{self, Encode};
99 let as_ref: &#as_type = #as_ref_code;
100 as_ref.encode_body(encoder)
101 }
102 }
103 };
104
105 Ok(output)
106}
107
108pub fn handle_normal_encode(
109 parsed: DeriveInput,
110 context_custom_value_kind: Option<&'static str>,
111) -> Result<TokenStream> {
112 let DeriveInput {
113 attrs,
114 ident,
115 data,
116 generics,
117 ..
118 } = parsed;
119 let (impl_generics, ty_generics, where_clause, custom_value_kind_generic, encoder_generic) =
120 build_encode_generics(&generics, &attrs, context_custom_value_kind)?;
121
122 let output = match data {
123 Data::Struct(s) => {
124 let fields_data = process_fields(&s.fields)?;
125 let unskipped_field_count = fields_data.unskipped_field_count();
126 let unskipped_self_field_references = fields_data.unskipped_self_field_references();
127 quote! {
128 impl #impl_generics sbor::Encode <#custom_value_kind_generic, #encoder_generic> for #ident #ty_generics #where_clause {
129 #[inline]
130 fn encode_value_kind(&self, encoder: &mut #encoder_generic) -> Result<(), sbor::EncodeError> {
131 encoder.write_value_kind(sbor::ValueKind::Tuple)
132 }
133
134 #[inline]
135 fn encode_body(&self, encoder: &mut #encoder_generic) -> Result<(), sbor::EncodeError> {
136 use sbor::{self, Encode};
137 encoder.write_size(#unskipped_field_count)?;
138 #(encoder.encode(#unskipped_self_field_references)?;)*
139 Ok(())
140 }
141 }
142 }
143 }
144 Data::Enum(DataEnum { variants, .. }) => {
145 let EnumVariantsData {
146 source_variants, ..
147 } = process_enum_variants(&attrs, &variants)?;
148 let match_arms = source_variants
149 .iter()
150 .map(|source_variant| {
151 Ok(match source_variant {
152 SourceVariantData::Reachable(VariantData {
153 variant_name,
154 discriminator,
155 fields_handling: FieldsHandling::Standard(fields_data),
156 ..
157 }) => {
158 let unskipped_field_count = fields_data.unskipped_field_count();
159 let fields_unpacking = fields_data.fields_unpacking();
160 let unskipped_unpacking_variable_names = fields_data.unskipped_unpacking_variable_names();
161 quote! {
162 Self::#variant_name #fields_unpacking => {
163 encoder.write_discriminator(#discriminator)?;
164 encoder.write_size(#unskipped_field_count)?;
165 #(encoder.encode(#unskipped_unpacking_variable_names)?;)*
166 }
167 }
168 }
169 SourceVariantData::Reachable(VariantData {
170 variant_name,
171 discriminator,
172 fields_handling: FieldsHandling::Flatten { unique_field, fields_data, },
173 ..
174 }) => {
175 let fields_unpacking = fields_data.fields_unpacking();
176 let field_type = unique_field.field_type();
177 let unpacking_field_name = unique_field.variable_name_from_unpacking();
178 let tuple_assertion = output_flatten_type_is_sbor_tuple_assertion(
179 &custom_value_kind_generic,
180 field_type,
181 );
182 quote! {
183 Self::#variant_name #fields_unpacking => {
184 #tuple_assertion
188 encoder.write_discriminator(#discriminator)?;
192 <#field_type as sbor::Encode <#custom_value_kind_generic, #encoder_generic>>::encode_body(
193 #unpacking_field_name,
194 encoder
195 )?;
196 }
197 }
198 }
199 SourceVariantData::Unreachable(UnreachableVariantData {
200 variant_name,
201 fields_data,
202 ..
203 }) => {
204 let empty_fields_unpacking = fields_data.empty_fields_unpacking();
205 let panic_message =
206 format!("Variant {} ignored as unreachable", variant_name.to_string());
207 quote! {
208 Self::#variant_name #empty_fields_unpacking => panic!(#panic_message),
209 }
210 }
211 })
212 })
213 .collect::<Result<Vec<_>>>()?;
214
215 let encode_content = if match_arms.len() == 0 {
216 quote! {}
217 } else {
218 quote! {
219 use sbor::{self, Encode};
220
221 match self {
222 #(#match_arms)*
223 }
224 }
225 };
226 quote! {
227 impl #impl_generics sbor::Encode <#custom_value_kind_generic, #encoder_generic> for #ident #ty_generics #where_clause {
228 #[inline]
229 fn encode_value_kind(&self, encoder: &mut #encoder_generic) -> Result<(), sbor::EncodeError> {
230 encoder.write_value_kind(sbor::ValueKind::Enum)
231 }
232
233 #[inline]
234 fn encode_body(&self, encoder: &mut #encoder_generic) -> Result<(), sbor::EncodeError> {
235 #encode_content
236 Ok(())
237 }
238 }
239 }
240 }
241 Data::Union(_) => {
242 return Err(Error::new(Span::call_site(), "Union is not supported!"));
243 }
244 };
245
246 #[cfg(feature = "trace")]
247 crate::utils::print_generated_code("Encode", &output);
248
249 trace!("handle_encode() finishes");
250 Ok(output)
251}
252
#[cfg(test)]
mod tests {
    use proc_macro2::TokenStream;
    use std::str::FromStr;

    use super::*;

    /// Asserts that two token streams render to the same source text.
    fn assert_code_eq(actual: TokenStream, expected: TokenStream) {
        assert_eq!(actual.to_string(), expected.to_string());
    }

    /// Parses `source` and runs the `Encode` derive on it without a
    /// context-supplied custom value kind. Panics on any parse or derive error.
    fn derive_encode(source: &str) -> TokenStream {
        let input = TokenStream::from_str(source).unwrap();
        handle_encode(input, None).unwrap()
    }

    #[test]
    fn test_encode_struct() {
        assert_code_eq(
            derive_encode("struct Test {a: u32}"),
            quote! {
                impl <E: sbor::Encoder<X>, X: sbor::CustomValueKind > sbor::Encode<X, E> for Test {
                    #[inline]
                    fn encode_value_kind(&self, encoder: &mut E) -> Result<(), sbor::EncodeError> {
                        encoder.write_value_kind(sbor::ValueKind::Tuple)
                    }

                    #[inline]
                    fn encode_body(&self, encoder: &mut E) -> Result<(), sbor::EncodeError> {
                        use sbor::{self, Encode};
                        encoder.write_size(1usize)?;
                        encoder.encode(&self.a)?;
                        Ok(())
                    }
                }
            },
        );
    }

    #[test]
    fn test_encode_enum() {
        assert_code_eq(
            derive_encode("enum Test {A, B (u32), C {x: u8}}"),
            quote! {
                impl <E: sbor::Encoder<X>, X: sbor::CustomValueKind > sbor::Encode<X, E> for Test {
                    #[inline]
                    fn encode_value_kind(&self, encoder: &mut E) -> Result<(), sbor::EncodeError> {
                        encoder.write_value_kind(sbor::ValueKind::Enum)
                    }

                    #[inline]
                    fn encode_body(&self, encoder: &mut E) -> Result<(), sbor::EncodeError> {
                        use sbor::{self, Encode};
                        match self {
                            Self::A => {
                                encoder.write_discriminator(0u8)?;
                                encoder.write_size(0usize)?;
                            }
                            Self::B(a0) => {
                                encoder.write_discriminator(1u8)?;
                                encoder.write_size(1usize)?;
                                encoder.encode(a0)?;
                            }
                            Self::C { x, .. } => {
                                encoder.write_discriminator(2u8)?;
                                encoder.write_size(1usize)?;
                                encoder.encode(x)?;
                            }
                        }
                        Ok(())
                    }
                }
            },
        );
    }

    #[test]
    fn test_skip() {
        assert_code_eq(
            derive_encode("struct Test {#[sbor(skip)] a: u32}"),
            quote! {
                impl <E: sbor::Encoder<X>, X: sbor::CustomValueKind > sbor::Encode<X, E> for Test {
                    #[inline]
                    fn encode_value_kind(&self, encoder: &mut E) -> Result<(), sbor::EncodeError> {
                        encoder.write_value_kind(sbor::ValueKind::Tuple)
                    }

                    #[inline]
                    fn encode_body(&self, encoder: &mut E) -> Result<(), sbor::EncodeError> {
                        use sbor::{self, Encode};
                        encoder.write_size(0usize)?;
                        Ok(())
                    }
                }
            },
        );
    }

    #[test]
    fn test_encode_generic() {
        assert_code_eq(
            derive_encode("struct Test<T, E: Clashing> { a: T, b: E, }"),
            quote! {
                impl <T, E: Clashing, E0: sbor::Encoder<X>, X: sbor::CustomValueKind > sbor::Encode<X, E0> for Test<T, E >
                where
                    T: sbor::Encode<X, E0>,
                    E: sbor::Encode<X, E0>
                {
                    #[inline]
                    fn encode_value_kind(&self, encoder: &mut E0) -> Result<(), sbor::EncodeError> {
                        encoder.write_value_kind(sbor::ValueKind::Tuple)
                    }

                    #[inline]
                    fn encode_body(&self, encoder: &mut E0) -> Result<(), sbor::EncodeError> {
                        use sbor::{self, Encode};
                        encoder.write_size(2usize)?;
                        encoder.encode(&self.a)?;
                        encoder.encode(&self.b)?;
                        Ok(())
                    }
                }
            },
        );
    }

    #[test]
    fn test_encode_struct_with_custom_value_kind() {
        assert_code_eq(
            derive_encode(
                "#[sbor(custom_value_kind = \"NoCustomValueKind\")] struct Test {#[sbor(skip)] a: u32}",
            ),
            quote! {
                impl <E: sbor::Encoder<NoCustomValueKind> > sbor::Encode<NoCustomValueKind, E> for Test {
                    #[inline]
                    fn encode_value_kind(&self, encoder: &mut E) -> Result<(), sbor::EncodeError> {
                        encoder.write_value_kind(sbor::ValueKind::Tuple)
                    }

                    #[inline]
                    fn encode_body(&self, encoder: &mut E) -> Result<(), sbor::EncodeError> {
                        use sbor::{self, Encode};
                        encoder.write_size(0usize)?;
                        Ok(())
                    }
                }
            },
        );
    }

    #[test]
    fn test_custom_value_kind_canonical_path() {
        assert_code_eq(
            derive_encode(
                "#[sbor(custom_value_kind = \"sbor::basic::NoCustomValueKind\")] struct Test {#[sbor(skip)] a: u32}",
            ),
            quote! {
                impl <E: sbor::Encoder<sbor::basic::NoCustomValueKind> > sbor::Encode<sbor::basic::NoCustomValueKind, E> for Test {
                    #[inline]
                    fn encode_value_kind(&self, encoder: &mut E) -> Result<(), sbor::EncodeError> {
                        encoder.write_value_kind(sbor::ValueKind::Tuple)
                    }

                    #[inline]
                    fn encode_body(&self, encoder: &mut E) -> Result<(), sbor::EncodeError> {
                        use sbor::{self, Encode};
                        encoder.write_size(0usize)?;
                        Ok(())
                    }
                }
            },
        );
    }
}