sbor_derive_common/
encode.rs

1use proc_macro2::{Span, TokenStream};
2use quote::quote;
3use syn::*;
4
5use crate::utils::*;
6
/// Debug-logging helper for the derive implementation.
///
/// Accepts the same comma-separated arguments as `println!`. The print statement is
/// only compiled in when the crate's `trace` feature is enabled; otherwise the macro
/// expands to an empty block evaluating to `()`.
macro_rules! trace {
    ($($fmt_arg:expr),*) => {
        {
            #[cfg(feature = "trace")]
            println!($($fmt_arg),*);
        }
    };
}
13
14pub fn handle_encode(
15    input: TokenStream,
16    context_custom_value_kind: Option<&'static str>,
17) -> Result<TokenStream> {
18    trace!("handle_encode() starts");
19
20    let parsed: DeriveInput = parse2(input)?;
21
22    let output = match get_derive_strategy(&parsed.attrs)? {
23        DeriveStrategy::Normal => handle_normal_encode(parsed, context_custom_value_kind)?,
24        DeriveStrategy::Transparent => {
25            handle_transparent_encode(parsed, context_custom_value_kind)?
26        }
27        DeriveStrategy::DeriveAs {
28            as_type, as_ref, ..
29        } => handle_encode_as(parsed, context_custom_value_kind, &as_type, &as_ref)?,
30    };
31
32    #[cfg(feature = "trace")]
33    crate::utils::print_generated_code("Encode", &output);
34
35    trace!("handle_encode() finishes");
36    Ok(output)
37}
38
39pub fn handle_transparent_encode(
40    parsed: DeriveInput,
41    context_custom_value_kind: Option<&'static str>,
42) -> Result<TokenStream> {
43    let output = match &parsed.data {
44        Data::Struct(s) => {
45            let single_field = process_fields(&s.fields)?
46                .unique_unskipped_field()
47                .ok_or_else(|| Error::new(
48                    Span::call_site(),
49                    "The transparent attribute is only supported for structs with a single unskipped field.",
50                ))?;
51            handle_encode_as(
52                parsed,
53                context_custom_value_kind,
54                single_field.field_type(),
55                &single_field.self_field_reference(),
56            )?
57        }
58        Data::Enum(_) => {
59            return Err(Error::new(Span::call_site(), "The transparent attribute is only supported for structs with a single unskipped field."));
60        }
61        Data::Union(_) => {
62            return Err(Error::new(Span::call_site(), "Union is not supported!"));
63        }
64    };
65
66    Ok(output)
67}
68
69pub fn handle_encode_as(
70    parsed: DeriveInput,
71    context_custom_value_kind: Option<&'static str>,
72    as_type: &Type,
73    as_ref_code: &TokenStream,
74) -> Result<TokenStream> {
75    let DeriveInput {
76        attrs,
77        ident,
78        generics,
79        ..
80    } = parsed;
81    let (impl_generics, ty_generics, where_clause, custom_value_kind_generic, encoder_generic) =
82        build_encode_generics(&generics, &attrs, context_custom_value_kind)?;
83
84    // NOTE: The `: &#as_type` is not strictly needed for the code to compile,
85    // but it is useful to sanity check that the user has provided the correct implementation.
86    // If they have not, they should get a nice and clear error message.
87    let output = quote! {
88        impl #impl_generics sbor::Encode <#custom_value_kind_generic, #encoder_generic> for #ident #ty_generics #where_clause {
89            #[inline]
90            fn encode_value_kind(&self, encoder: &mut #encoder_generic) -> Result<(), sbor::EncodeError> {
91                use sbor::{self, Encode};
92                let as_ref: &#as_type = #as_ref_code;
93                as_ref.encode_value_kind(encoder)
94            }
95
96            #[inline]
97            fn encode_body(&self, encoder: &mut #encoder_generic) -> Result<(), sbor::EncodeError> {
98                use sbor::{self, Encode};
99                let as_ref: &#as_type = #as_ref_code;
100                as_ref.encode_body(encoder)
101            }
102        }
103    };
104
105    Ok(output)
106}
107
108pub fn handle_normal_encode(
109    parsed: DeriveInput,
110    context_custom_value_kind: Option<&'static str>,
111) -> Result<TokenStream> {
112    let DeriveInput {
113        attrs,
114        ident,
115        data,
116        generics,
117        ..
118    } = parsed;
119    let (impl_generics, ty_generics, where_clause, custom_value_kind_generic, encoder_generic) =
120        build_encode_generics(&generics, &attrs, context_custom_value_kind)?;
121
122    let output = match data {
123        Data::Struct(s) => {
124            let fields_data = process_fields(&s.fields)?;
125            let unskipped_field_count = fields_data.unskipped_field_count();
126            let unskipped_self_field_references = fields_data.unskipped_self_field_references();
127            quote! {
128                impl #impl_generics sbor::Encode <#custom_value_kind_generic, #encoder_generic> for #ident #ty_generics #where_clause {
129                    #[inline]
130                    fn encode_value_kind(&self, encoder: &mut #encoder_generic) -> Result<(), sbor::EncodeError> {
131                        encoder.write_value_kind(sbor::ValueKind::Tuple)
132                    }
133
134                    #[inline]
135                    fn encode_body(&self, encoder: &mut #encoder_generic) -> Result<(), sbor::EncodeError> {
136                        use sbor::{self, Encode};
137                        encoder.write_size(#unskipped_field_count)?;
138                        #(encoder.encode(#unskipped_self_field_references)?;)*
139                        Ok(())
140                    }
141                }
142            }
143        }
144        Data::Enum(DataEnum { variants, .. }) => {
145            let EnumVariantsData {
146                source_variants, ..
147            } = process_enum_variants(&attrs, &variants)?;
148            let match_arms = source_variants
149                .iter()
150                .map(|source_variant| {
151                    Ok(match source_variant {
152                        SourceVariantData::Reachable(VariantData {
153                            variant_name,
154                            discriminator,
155                            fields_handling: FieldsHandling::Standard(fields_data),
156                            ..
157                        }) => {
158                            let unskipped_field_count = fields_data.unskipped_field_count();
159                            let fields_unpacking = fields_data.fields_unpacking();
160                            let unskipped_unpacking_variable_names = fields_data.unskipped_unpacking_variable_names();
161                            quote! {
162                                Self::#variant_name #fields_unpacking => {
163                                    encoder.write_discriminator(#discriminator)?;
164                                    encoder.write_size(#unskipped_field_count)?;
165                                    #(encoder.encode(#unskipped_unpacking_variable_names)?;)*
166                                }
167                            }
168                        }
169                        SourceVariantData::Reachable(VariantData {
170                            variant_name,
171                            discriminator,
172                            fields_handling: FieldsHandling::Flatten { unique_field, fields_data, },
173                            ..
174                        }) => {
175                            let fields_unpacking = fields_data.fields_unpacking();
176                            let field_type = unique_field.field_type();
177                            let unpacking_field_name = unique_field.variable_name_from_unpacking();
178                            let tuple_assertion = output_flatten_type_is_sbor_tuple_assertion(
179                                &custom_value_kind_generic,
180                                field_type,
181                            );
182                            quote! {
183                                Self::#variant_name #fields_unpacking => {
184                                    // Flatten is only valid if the single child type is an SBOR tuple, so do a
185                                    // zero-cost assertion on this so the user gets a good error message if they
186                                    // misuse this.
187                                    #tuple_assertion
188                                    // We make use of the fact that an enum body encodes as (discriminator, fields_count, ..fields)
189                                    // And a tuple body encodes as (fields_count, ..fields)
190                                    // So we can flatten by encoding the discriminator and then running `encode_body` on the child tuple
191                                    encoder.write_discriminator(#discriminator)?;
192                                    <#field_type as sbor::Encode <#custom_value_kind_generic, #encoder_generic>>::encode_body(
193                                        #unpacking_field_name,
194                                        encoder
195                                    )?;
196                                }
197                            }
198                        }
199                        SourceVariantData::Unreachable(UnreachableVariantData {
200                            variant_name,
201                            fields_data,
202                            ..
203                        }) => {
204                            let empty_fields_unpacking = fields_data.empty_fields_unpacking();
205                            let panic_message =
206                                format!("Variant {} ignored as unreachable", variant_name.to_string());
207                            quote! {
208                                Self::#variant_name #empty_fields_unpacking => panic!(#panic_message),
209                            }
210                        }
211                    })
212                })
213                .collect::<Result<Vec<_>>>()?;
214
215            let encode_content = if match_arms.len() == 0 {
216                quote! {}
217            } else {
218                quote! {
219                    use sbor::{self, Encode};
220
221                    match self {
222                        #(#match_arms)*
223                    }
224                }
225            };
226            quote! {
227                impl #impl_generics sbor::Encode <#custom_value_kind_generic, #encoder_generic> for #ident #ty_generics #where_clause {
228                    #[inline]
229                    fn encode_value_kind(&self, encoder: &mut #encoder_generic) -> Result<(), sbor::EncodeError> {
230                        encoder.write_value_kind(sbor::ValueKind::Enum)
231                    }
232
233                    #[inline]
234                    fn encode_body(&self, encoder: &mut #encoder_generic) -> Result<(), sbor::EncodeError> {
235                        #encode_content
236                        Ok(())
237                    }
238                }
239            }
240        }
241        Data::Union(_) => {
242            return Err(Error::new(Span::call_site(), "Union is not supported!"));
243        }
244    };
245
246    #[cfg(feature = "trace")]
247    crate::utils::print_generated_code("Encode", &output);
248
249    trace!("handle_encode() finishes");
250    Ok(output)
251}
252
#[cfg(test)]
mod tests {
    use proc_macro2::TokenStream;
    use std::str::FromStr;

    use super::*;

    /// Parses `source`, runs `handle_encode` on it with no context custom value kind,
    /// and asserts the generated tokens render identically to `expected`.
    fn assert_encodes_to(source: &str, expected: TokenStream) {
        let input = TokenStream::from_str(source).unwrap();
        let generated = handle_encode(input, None).unwrap();
        assert_eq!(generated.to_string(), expected.to_string());
    }

    #[test]
    fn test_encode_struct() {
        let expected = quote! {
            impl <E: sbor::Encoder<X>, X: sbor::CustomValueKind > sbor::Encode<X, E> for Test {
                #[inline]
                fn encode_value_kind(&self, encoder: &mut E) -> Result<(), sbor::EncodeError> {
                    encoder.write_value_kind(sbor::ValueKind::Tuple)
                }

                #[inline]
                fn encode_body(&self, encoder: &mut E) -> Result<(), sbor::EncodeError> {
                    use sbor::{self, Encode};
                    encoder.write_size(1usize)?;
                    encoder.encode(&self.a)?;
                    Ok(())
                }
            }
        };
        assert_encodes_to("struct Test {a: u32}", expected);
    }

    #[test]
    fn test_encode_enum() {
        let expected = quote! {
            impl <E: sbor::Encoder<X>, X: sbor::CustomValueKind > sbor::Encode<X, E> for Test {
                #[inline]
                fn encode_value_kind(&self, encoder: &mut E) -> Result<(), sbor::EncodeError> {
                    encoder.write_value_kind(sbor::ValueKind::Enum)
                }

                #[inline]
                fn encode_body(&self, encoder: &mut E) -> Result<(), sbor::EncodeError> {
                    use sbor::{self, Encode};
                    match self {
                        Self::A => {
                            encoder.write_discriminator(0u8)?;
                            encoder.write_size(0usize)?;
                        }
                        Self::B(a0) => {
                            encoder.write_discriminator(1u8)?;
                            encoder.write_size(1usize)?;
                            encoder.encode(a0)?;
                        }
                        Self::C { x, .. } => {
                            encoder.write_discriminator(2u8)?;
                            encoder.write_size(1usize)?;
                            encoder.encode(x)?;
                        }
                    }
                    Ok(())
                }
            }
        };
        assert_encodes_to("enum Test {A, B (u32), C {x: u8}}", expected);
    }

    #[test]
    fn test_skip() {
        let expected = quote! {
            impl <E: sbor::Encoder<X>, X: sbor::CustomValueKind > sbor::Encode<X, E> for Test {
                #[inline]
                fn encode_value_kind(&self, encoder: &mut E) -> Result<(), sbor::EncodeError> {
                    encoder.write_value_kind(sbor::ValueKind::Tuple)
                }

                #[inline]
                fn encode_body(&self, encoder: &mut E) -> Result<(), sbor::EncodeError> {
                    use sbor::{self, Encode};
                    encoder.write_size(0usize)?;
                    Ok(())
                }
            }
        };
        assert_encodes_to("struct Test {#[sbor(skip)] a: u32}", expected);
    }

    #[test]
    fn test_encode_generic() {
        // The user's `E` clashes with the default encoder generic name, so the
        // derive is expected to pick `E0` instead.
        let expected = quote! {
            impl <T, E: Clashing, E0: sbor::Encoder<X>, X: sbor::CustomValueKind > sbor::Encode<X, E0> for Test<T, E >
            where
                T: sbor::Encode<X, E0>,
                E: sbor::Encode<X, E0>
            {
                #[inline]
                fn encode_value_kind(&self, encoder: &mut E0) -> Result<(), sbor::EncodeError> {
                    encoder.write_value_kind(sbor::ValueKind::Tuple)
                }

                #[inline]
                fn encode_body(&self, encoder: &mut E0) -> Result<(), sbor::EncodeError> {
                    use sbor::{self, Encode};
                    encoder.write_size(2usize)?;
                    encoder.encode(&self.a)?;
                    encoder.encode(&self.b)?;
                    Ok(())
                }
            }
        };
        assert_encodes_to("struct Test<T, E: Clashing> { a: T, b: E, }", expected);
    }

    #[test]
    fn test_encode_struct_with_custom_value_kind() {
        let expected = quote! {
            impl <E: sbor::Encoder<NoCustomValueKind> > sbor::Encode<NoCustomValueKind, E> for Test {
                #[inline]
                fn encode_value_kind(&self, encoder: &mut E) -> Result<(), sbor::EncodeError> {
                    encoder.write_value_kind(sbor::ValueKind::Tuple)
                }

                #[inline]
                fn encode_body(&self, encoder: &mut E) -> Result<(), sbor::EncodeError> {
                    use sbor::{self, Encode};
                    encoder.write_size(0usize)?;
                    Ok(())
                }
            }
        };
        assert_encodes_to(
            "#[sbor(custom_value_kind = \"NoCustomValueKind\")] struct Test {#[sbor(skip)] a: u32}",
            expected,
        );
    }

    #[test]
    fn test_custom_value_kind_canonical_path() {
        let expected = quote! {
            impl <E: sbor::Encoder<sbor::basic::NoCustomValueKind> > sbor::Encode<sbor::basic::NoCustomValueKind, E> for Test {
                #[inline]
                fn encode_value_kind(&self, encoder: &mut E) -> Result<(), sbor::EncodeError> {
                    encoder.write_value_kind(sbor::ValueKind::Tuple)
                }

                #[inline]
                fn encode_body(&self, encoder: &mut E) -> Result<(), sbor::EncodeError> {
                    use sbor::{self, Encode};
                    encoder.write_size(0usize)?;
                    Ok(())
                }
            }
        };
        assert_encodes_to(
            "#[sbor(custom_value_kind = \"sbor::basic::NoCustomValueKind\")] struct Test {#[sbor(skip)] a: u32}",
            expected,
        );
    }
}