tasm_object_derive/
lib.rs

1//! This crate provides a derive macro for the `TasmObject` and `TasmStruct`
2//! traits.
3//!
4//! Example usage:
5//! ```no_compile
6//! #[derive(BFieldCodec, TasmObject)]
7//! struct Foo<T: BFieldCodec> {
8//!     t_list: Vec<T>,
9//! }
10//! ```
11//!
12//! Note: An implementation of `BFieldCodec` is required, else compilation will
13//! fail. It is recommended to derive `BFieldCodec`.
14//!
15//! ### Known limitations
16//!
17//!  - Ignoring fields in tuple structs is currently not supported. Consider
18//!    using a struct with named fields instead.
19//!
20//!    ```no_compile
21//!    #[derive(BFieldCodec, TasmObject)]
22//!    struct Foo(#[tasm_object(ignore)] u32);
23//!    //         ~~~~~~~~~~~~~~~~~~~~~~
24//!    //         currently unsupported in tuple structs
25//!    ```
26
27extern crate proc_macro;
28
29use proc_macro2::TokenStream;
30use quote::quote;
31use syn::DeriveInput;
32
33/// Derives both, `TasmObject` and `TasmStruct` for structs.
34#[proc_macro_derive(TasmObject, attributes(tasm_object))]
35pub fn tasm_object_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
36    let ast = syn::parse(input).unwrap();
37    impl_tasm_object_derive_macro(ast).into()
38}
39
40// To follow along with the more involved functions, consider struct `Foo`:
41//
42// #[derive(BFieldCodec, TasmObject)]
43// struct Foo {
44//     a: Vec<u32>,
45//     b: XFieldElement,
46//     c: Vec<Digest>,
47// }
48//
49// The encoding of
50// Foo {
51//     a: vec![40, 41, 42],
52//     b: xfe!([43, 44, 45]),
53//     c: vec![Digest::new(bfe_array![46, 47, 48, 49, 50])]
54// }
// looks like this:
56//
57//     ╭──────── c ────────╮  ╭─── b ──╮     ╭──── a ────╮
58//  6, 1, 46, 47, 48, 49, 50, 43, 44, 45, 4, 3, 40, 41, 42
59//  ↑  ↑                       ↑          ↑  ↑
60//  | *c                      *b          | *a
61// *c_si                                 *a_si
62//
// The abbreviation `si` means “size indicator”. Every dynamically-sized field
// is preceded by its size indicator.
65//
66// A pointer `*foo` is equal to pointer `*c_si`.
67//
68#[derive(Clone)]
69struct ParsedStruct {
70    /// The names of all relevant fields. Notably, ignored fields are _not_ included.
71    ///
72    /// Reversed compared to the field declaration order. That is, for struct `Foo`:
73    ///
74    /// ```no_compile
75    /// #[derive(BFieldCodec, TasmObject)]
76    /// struct Foo {
77    ///     a: Vec<u32>,
78    ///     b: XFieldElement,
79    /// }
80    /// ```
81    ///
82    /// element `self.field_names[0]` is `b`, element `self.field_names[1]` is `a`.
83    field_names: Vec<syn::Ident>,
84
85    /// The types of all relevant fields. Notably, ignored fields are _not_ included.
86    ///
87    /// The order of the entries mimics the order of field
88    /// [`field_names`](Self::field_names).
89    field_types: Vec<syn::Type>,
90
91    ignored_fields: Vec<syn::Field>,
92
93    /// The rust code to assemble the struct that's annotated with the derive macro.
94    /// Variables with identifiers equal to the struct field identifiers and of the
95    /// appropriate type must be in scope.
96    struct_builder: TokenStream,
97}
98
99impl ParsedStruct {
100    fn new(ast: &DeriveInput) -> Self {
101        let syn::Data::Struct(syn::DataStruct { fields, .. }) = &ast.data else {
102            panic!("expected a struct")
103        };
104
105        match fields {
106            syn::Fields::Named(fields) => Self::parse_struct_with_named_fields(fields),
107            syn::Fields::Unnamed(fields) => Self::parse_tuple_struct(fields),
108            syn::Fields::Unit => Self::unit(),
109        }
110    }
111
112    fn parse_struct_with_named_fields(fields: &syn::FieldsNamed) -> Self {
113        let (ignored_fields, fields): (Vec<_>, Vec<_>) = fields
114            .named
115            .iter()
116            .cloned()
117            .rev()
118            .partition(Self::field_is_ignored);
119
120        let fields = fields.into_iter();
121        let field_names = fields.clone().map(|f| f.ident.unwrap()).collect::<Vec<_>>();
122        let field_types = fields.map(|f| f.ty).collect::<Vec<_>>();
123
124        let fields = field_names.iter();
125        let ignored = ignored_fields.iter().map(|f| f.ident.clone());
126        let struct_builder =
127            quote! { Self { #( #fields ,)* #( #ignored : Default::default(), )* } };
128
129        Self {
130            field_names,
131            field_types,
132            ignored_fields,
133            struct_builder,
134        }
135    }
136
137    fn parse_tuple_struct(fields: &syn::FieldsUnnamed) -> Self {
138        // for now, ignoring fields in tuple structs is unsupported
139        let (field_names, field_types): (Vec<_>, Vec<_>) = fields
140            .unnamed
141            .iter()
142            .cloned()
143            .enumerate()
144            .map(|(i, f)| (quote::format_ident!("field_{i}"), f.ty))
145            .rev()
146            .unzip();
147
148        let fields_in_declared_order = field_names.iter().rev();
149        let struct_builder = quote! { Self( #( #fields_in_declared_order ),* ) };
150
151        Self {
152            field_names,
153            field_types,
154            ignored_fields: vec![],
155            struct_builder,
156        }
157    }
158
159    fn unit() -> Self {
160        Self {
161            field_names: vec![],
162            field_types: vec![],
163            ignored_fields: vec![],
164            struct_builder: quote! { Self },
165        }
166    }
167
168    fn field_is_ignored(field: &syn::Field) -> bool {
169        let field_name = field.ident.as_ref().unwrap();
170        let mut relevant_attributes = field
171            .attrs
172            .iter()
173            .filter(|attr| attr.path().is_ident("tasm_object"));
174
175        let Some(attribute) = relevant_attributes.next() else {
176            return false;
177        };
178        if relevant_attributes.next().is_some() {
179            panic!("field `{field_name}` must have at most 1 `tasm_object` attribute");
180        }
181
182        let parse_ignore = attribute.parse_nested_meta(|meta| match meta.path.get_ident() {
183            Some(ident) if ident == "ignore" => Ok(()),
184            Some(ident) => panic!("unknown identifier `{ident}` for field `{field_name}`"),
185            _ => unreachable!(),
186        });
187        parse_ignore.is_ok()
188    }
189
190    /// Allows writing shorter paths while staying hygienic.
191    fn type_aliases() -> TokenStream {
192        quote! {
193            type Instruction = crate::triton_vm::isa::instruction::LabelledInstruction;
194            type AssertionContext = crate::triton_vm::isa::instruction::AssertionContext;
195            type TypeHint = crate::triton_vm::isa::instruction::TypeHint;
196
197            type AnInstruction =
198                crate::triton_vm::isa::instruction::AnInstruction<::std::string::String>;
199
200            type N = crate::triton_vm::isa::op_stack::NumberOfWords;
201            type ST = crate::triton_vm::isa::op_stack::OpStackElement;
202            type BFE = crate::triton_vm::prelude::BFieldElement;
203
204        }
205    }
206
    /// Generate the rust code for
    /// `TasmObject::compute_size_and_assert_valid_size_indicator(…)`, which will
    /// then generate tasm code that
    /// - assumes the stack is in the state `_ *struct`
    /// - leaves the stack in the state `_ struct_size`
    ///
    /// Special cases:
    /// - A struct without relevant fields always has size 0.
    /// - A struct whose `BFieldCodec::static_length()` is `Some` simply gets
    ///   that static length pushed.
    ///
    /// Fields are processed in encoding order, i.e., in the order of
    /// [`field_names`](Self::field_names):
    /// - Statically-sized fields only add their static length to an
    ///   accumulator. Consecutive static lengths are emitted as a single
    ///   `addi`.
    /// - For every dynamically-sized field, the generated tasm asserts that the
    ///   size indicator is smaller than `Self::MAX_OFFSET`. It also asserts
    ///   that the indicator equals the size computed recursively by the field
    ///   type's own `compute_size_and_assert_valid_size_indicator`.
    ///
    /// The field that comes last in the encoding (the first declared field) is
    /// handled separately, because the pointer does not need to be advanced
    /// past it.
    ///
    /// NOTE(review): the generated code refers to a variable `library`. It is
    /// presumably the parameter of the generated trait method; confirm at the
    /// call site.
    fn generate_code_for_fn_compute_size_and_assert_valid_size_indicator(&self) -> TokenStream {
        debug_assert_eq!(self.field_types.len(), self.field_names.len());

        let mut fields = self
            .field_names
            .iter()
            .map(|n| n.to_string())
            .zip(&self.field_types);

        let type_aliases = Self::type_aliases();
        // `field_names` is reversed, so the last entry is the first declared
        // field, which is the one that comes last in the encoding.
        let Some((_, first_field_ty)) = fields.next_back() else {
            // no relevant fields: drop the pointer, size is 0
            return quote! {
                #type_aliases

                [
                    Instruction::Instruction(AnInstruction::Pop(N::N1)),
                    Instruction::Instruction(AnInstruction::Push(BFE::new(0))),
                ]
                .to_vec()
            };
        };

        let accumulator_type_hint = quote! {
            Instruction::TypeHint(
                TypeHint {
                    starting_index: 0,
                    length: 1,
                    type_name: ::std::option::Option::None,
                    variable_name: ::std::string::String::from("size_acc"),
                }
            )
        };
        let mut rust = quote! {
            #type_aliases

            // statically-sized structs need no memory access at all
            if let Some(size) =
                <Self as crate::twenty_first::math::bfield_codec::BFieldCodec>::static_length()
            {
                return [
                    Instruction::Instruction(AnInstruction::Pop(N::N1)),
                    Instruction::Instruction(AnInstruction::Push(BFE::from(size))),
                ]
                .to_vec();
            }

            // accumulates successive static lengths; minimize number of static-length jumps
            let mut static_jump_accumulator = BFE::new(0);
            let mut instructions = [
                // _ *struct
                Instruction::Instruction(AnInstruction::Push(BFE::new(0))),
                #accumulator_type_hint,
                Instruction::Instruction(AnInstruction::Place(ST::ST1)),
                // _ acc *struct
            ].to_vec();
        };

        for (field_name, field_ty) in fields {
            let field_ptr_hint = Self::top_of_stack_pointer_type_hint(&field_name);
            rust.extend(quote! {
                if let Some(size) =
                    <#field_ty as crate::twenty_first::math::bfield_codec::BFieldCodec>
                        ::static_length()
                {
                    static_jump_accumulator += BFE::from(size);
                } else {
                    // flush pending static lengths into both the pointer and the accumulator
                    if static_jump_accumulator != BFE::new(0) {
                        instructions.extend([
                            // _ acc_up_to_some_field *some_field
                            Instruction::Instruction(AnInstruction::AddI(static_jump_accumulator)),
                            // _ acc_up_to_some_field *field_si
                            Instruction::Instruction(AnInstruction::Pick(ST::ST1)),
                            // _ *field_si acc_up_to_some_field
                            Instruction::Instruction(AnInstruction::AddI(static_jump_accumulator)),
                            // _ *field_si acc
                            Instruction::Instruction(AnInstruction::Place(ST::ST1)),
                            // _ acc *field_si
                        ]);
                        static_jump_accumulator = BFE::new(0);
                    }

                    instructions.extend([
                        // _ acc *field_si
                        Instruction::Instruction(AnInstruction::ReadMem(N::N1)),
                        // _ acc field_size (*field_si - 1)
                        Instruction::Instruction(AnInstruction::AddI(BFE::new(2))),
                        #field_ptr_hint,
                        // _ acc field_size *field
                        Instruction::Instruction(AnInstruction::Push(BFE::from(Self::MAX_OFFSET))),
                        Instruction::Instruction(AnInstruction::Dup(ST::ST2)),
                        Instruction::Instruction(AnInstruction::Lt),
                        Instruction::Instruction(AnInstruction::Assert),
                        Instruction::AssertionContext(AssertionContext::ID(180)),
                        // _ acc field_size *field
                        Instruction::Instruction(AnInstruction::Dup(ST::ST0)),
                        // _ acc field_size *field *field
                    ]);
                    instructions.extend(
                        <#field_ty as crate::tasm_lib::structure::tasm_object::TasmObject>
                            ::compute_size_and_assert_valid_size_indicator(library)
                    );
                    instructions.extend([
                        // _ acc field_size *field computed_field_size
                        Instruction::Instruction(AnInstruction::Dup(ST::ST2)),
                        // _ acc field_size *field computed_field_size field_size
                        Instruction::Instruction(AnInstruction::Eq),
                        // _ acc field_size *field (computed_field_size == field_size)
                        Instruction::Instruction(AnInstruction::Assert),
                        Instruction::AssertionContext(AssertionContext::ID(181)),
                        // _ acc field_size *field
                        Instruction::Instruction(AnInstruction::Dup(ST::ST1)),
                        // _ acc field_size *field field_size
                        Instruction::Instruction(AnInstruction::Add),
                        // _ acc field_size *next_field_or_next_field_si
                        Instruction::Instruction(AnInstruction::Place(ST::ST2)),
                        // _ *next_field_or_next_field_si acc field_size
                        Instruction::Instruction(AnInstruction::Add),
                        // _ *next_field_or_next_field_si (acc + field_size)
                        Instruction::Instruction(AnInstruction::AddI(BFE::new(1))),
                        // _ *next_field_or_next_field_si (acc + field_size + 1)
                        Instruction::Instruction(AnInstruction::Pick(ST::ST1)),
                        // _ (acc + field_size + 1) *next_field_or_next_field_si
                    ]);
                }
            });
        }

        // Last field in the encoding: the pointer is consumed, not advanced.
        rust.extend(quote!(
            if let Some(size) =
                <#first_field_ty as crate::twenty_first::math::bfield_codec::BFieldCodec>
                    ::static_length()
            {
                static_jump_accumulator += BFE::from(size);
                instructions.extend([
                    // _ acc *some_field
                    Instruction::Instruction(AnInstruction::Pop(N::N1)),
                    // _ acc
                    Instruction::Instruction(AnInstruction::AddI(static_jump_accumulator)),
                    // _ final_size
                ]);
            } else {
                // jump the pointer over pending static fields to reach the size indicator
                if static_jump_accumulator != BFE::new(0) {
                    instructions.push(Instruction::Instruction(AnInstruction::AddI(
                        static_jump_accumulator
                    )));
                }

                instructions.extend([
                    // _ acc *field_si
                    Instruction::Instruction(AnInstruction::ReadMem(N::N1)),
                    // _ acc field_size (*field_si - 1)
                    Instruction::Instruction(AnInstruction::AddI(BFE::new(2))),
                    // _ acc field_size *field
                    Instruction::Instruction(AnInstruction::Push(BFE::from(Self::MAX_OFFSET))),
                    Instruction::Instruction(AnInstruction::Dup(ST::ST2)),
                    Instruction::Instruction(AnInstruction::Lt),
                    Instruction::Instruction(AnInstruction::Assert),
                    Instruction::AssertionContext(AssertionContext::ID(180)),
                    // _ acc field_size *field
                ]);
                instructions.extend(
                    <#first_field_ty as crate::tasm_lib::structure::tasm_object::TasmObject>
                        ::compute_size_and_assert_valid_size_indicator(library)
                );
                instructions.extend([
                    // _ acc field_size computed_field_size
                    Instruction::Instruction(AnInstruction::Dup(ST::ST1)),
                    // _ acc field_size computed_field_size field_size
                    Instruction::Instruction(AnInstruction::Eq),
                    // _ acc field_size (computed_field_size == field_size)
                    Instruction::Instruction(AnInstruction::Assert),
                    Instruction::AssertionContext(AssertionContext::ID(181)),
                    // _ acc field_size
                    Instruction::Instruction(AnInstruction::Add),
                    Instruction::Instruction(AnInstruction::AddI(BFE::new(1))),
                    // _ acc
                ]);

                // the same pending static lengths also count towards the total size
                if static_jump_accumulator != BFE::new(0) {
                    instructions.push(Instruction::Instruction(AnInstruction::AddI(
                        static_jump_accumulator
                    )));
                }
            }

            instructions
        ));

        rust
    }
394
395    /// Generate the rust code for `TasmStruct::get_field(…)`, which will then
396    /// generate tasm code to get a field.
397    ///
398    /// For example, calling `TasmStruct::get_field("field_i")` will generate
399    /// tasm code that
400    /// - assumes the stack is in the state `_ *struct`
401    /// - leaves the stack in the state `_ *field_i`
402    fn generate_code_for_fn_get_field(&self, struct_name: &syn::Ident) -> TokenStream {
403        debug_assert_eq!(self.field_types.len(), self.field_names.len());
404
405        let mut fields = self
406            .field_names
407            .iter()
408            .map(|n| n.to_string())
409            .zip(&self.field_types);
410
411        let Some(first_field) = fields.next_back() else {
412            let struct_name = struct_name.to_string();
413            return quote!(panic!("type `{}` has no fields", #struct_name););
414        };
415
416        let type_aliases = Self::type_aliases();
417        let mut rust = quote! {
418            #type_aliases
419            // accumulates successive static lengths; minimize number of static-length jumps
420            let mut static_jump_accumulator = BFE::new(0);
421            let mut instructions = ::std::vec::Vec::new();
422        };
423
424        for (field_name, field_ty) in fields {
425            let field_ptr_hint = Self::top_of_stack_pointer_type_hint(&field_name);
426            rust.extend(quote!(
427                if field_name == #field_name {
428                    if <#field_ty as crate::twenty_first::math::bfield_codec::BFieldCodec>
429                        ::static_length().is_none()
430                    {
431                        // shift pointer from size indicator to actual field
432                        static_jump_accumulator += BFE::new(1);
433                    }
434                    if static_jump_accumulator != BFE::new(0) {
435                        instructions.push(Instruction::Instruction(AnInstruction::AddI(
436                            static_jump_accumulator
437                        )));
438                    }
439                    instructions.push(#field_ptr_hint);
440                    return instructions;
441                }
442
443                if let Some(size) =
444                    <#field_ty as crate::twenty_first::math::bfield_codec::BFieldCodec>
445                        ::static_length()
446                {
447                    static_jump_accumulator += BFE::from(size);
448                } else {
449                    if static_jump_accumulator != BFE::new(0) {
450                        instructions.push(Instruction::Instruction(AnInstruction::AddI(
451                            static_jump_accumulator
452                        )));
453                        static_jump_accumulator = BFE::new(0);
454                    }
455
456                    instructions.extend([
457                        // _ *field_si
458                        Instruction::Instruction(AnInstruction::ReadMem(N::N1)),
459                        // _ field_size (*field_si - 1)
460                        Instruction::Instruction(AnInstruction::AddI(BFE::new(2))),
461                        #field_ptr_hint,
462                        // _ field_size *field
463                        Instruction::Instruction(AnInstruction::Pick(ST::ST1)),
464                        // _ *field field_size
465                        Instruction::Instruction(AnInstruction::Push(BFE::from(Self::MAX_OFFSET))),
466                        Instruction::Instruction(AnInstruction::Dup(ST::ST1)),
467                        Instruction::Instruction(AnInstruction::Lt),
468                        Instruction::Instruction(AnInstruction::Assert),
469                        Instruction::AssertionContext(AssertionContext::ID(184)),
470                        // _ *field field_size
471                        Instruction::Instruction(AnInstruction::Add),
472                        // _ *next_field_or_next_field_si
473                    ]);
474                }
475            ));
476        }
477
478        let (first_field_name, first_field_ty) = first_field;
479        let struct_name = struct_name.to_string();
480        let first_field_type_hint = Self::top_of_stack_pointer_type_hint(&first_field_name);
481        rust.extend(quote!(
482            if field_name != #first_field_name {
483                panic!("unknown field name `{field_name}` for type `{}`", #struct_name);
484            }
485            if <#first_field_ty as crate::twenty_first::math::bfield_codec::BFieldCodec>
486                ::static_length().is_none()
487            {
488                // shift pointer from size indicator to actual field
489                static_jump_accumulator += BFE::new(1);
490            }
491            if static_jump_accumulator != BFE::new(0) {
492                instructions.push(Instruction::Instruction(AnInstruction::AddI(
493                    static_jump_accumulator
494                )));
495            }
496            instructions.push(#first_field_type_hint);
497            instructions
498        ));
499
500        rust
501    }
502
    /// Generate the rust code for `TasmStruct::get_field_with_size(…)`, which will
    /// then generate tasm code to get a field and its size.
    ///
    /// For example, calling `TasmStruct::get_field_with_size("field_i")` will
    /// generate tasm code that
    /// - assumes the stack is in the state `_ *struct`
    /// - leaves the stack in the state `_ *field_i size_of_field_i`
    ///
    /// How the size is obtained:
    /// - If `field_i` is statically sized, its static length is pushed as a
    ///   constant.
    /// - Otherwise, the size is read from the field's size indicator and
    ///   asserted to be smaller than `Self::MAX_OFFSET`.
    ///
    /// Fields in front of `field_i` (in encoding order) are skipped in the same
    /// way. Requesting an unknown field panics at the time the generated rust
    /// code runs. So does any request on a struct without relevant fields.
    fn generate_code_for_fn_get_field_with_size(&self, struct_name: &syn::Ident) -> TokenStream {
        debug_assert_eq!(self.field_types.len(), self.field_names.len());

        let mut fields = self
            .field_names
            .iter()
            .map(|n| n.to_string())
            .zip(&self.field_types);
        // `field_names` is reversed, so the last entry is the first declared
        // field, which is the one that comes last in the encoding.
        let Some(first_field) = fields.next_back() else {
            let struct_name = struct_name.to_string();
            return quote!(panic!("type `{}` has no fields", #struct_name););
        };

        let type_aliases = Self::type_aliases();
        let mut rust = quote! {
            #type_aliases
            // accumulates successive static lengths; minimize number of static-length jumps
            let mut static_jump_accumulator = BFE::new(0);
            let mut instructions = ::std::vec::Vec::new();
        };

        for (field_name, field_ty) in fields {
            let field_ptr_hint = Self::top_of_stack_pointer_type_hint(&field_name);
            rust.extend(quote!(
                if field_name == #field_name {
                    // The pointer is not advanced past a size indicator here:
                    // for dynamically-sized fields, the indicator is read below.
                    if static_jump_accumulator != BFE::new(0) {
                        instructions.push(Instruction::Instruction(AnInstruction::AddI(
                            static_jump_accumulator
                        )));
                    }
                    if let Some(size) =
                        <#field_ty as crate::twenty_first::math::bfield_codec::BFieldCodec>
                            ::static_length()
                    {
                        instructions.extend([
                            #field_ptr_hint,
                            Instruction::Instruction(AnInstruction::Push(BFE::from(size))),
                        ]);
                    } else {
                        let max_offset = BFE::from(Self::MAX_OFFSET);
                        instructions.extend([
                            // _ *field_si
                            Instruction::Instruction(AnInstruction::ReadMem(N::N1)),
                            // _ field_size (*field_si - 1)
                            Instruction::Instruction(AnInstruction::AddI(BFE::new(2))),
                            #field_ptr_hint,
                            // _ field_size *field
                            Instruction::Instruction(AnInstruction::Pick(ST::ST1)),
                            // _ *field field_size
                            Instruction::Instruction(AnInstruction::Push(max_offset)),
                            Instruction::Instruction(AnInstruction::Dup(ST::ST1)),
                            Instruction::Instruction(AnInstruction::Lt),
                            Instruction::Instruction(AnInstruction::Assert),
                            Instruction::AssertionContext(AssertionContext::ID(185)),
                            // _ *field field_size
                        ]);
                    }
                    return instructions;
                }

                // not the requested field: skip over it
                if let Some(size) =
                    <#field_ty as crate::twenty_first::math::bfield_codec::BFieldCodec>
                        ::static_length()
                {
                    static_jump_accumulator += BFE::from(size);
                } else {
                    if static_jump_accumulator != BFE::new(0) {
                        instructions.push(Instruction::Instruction(AnInstruction::AddI(
                            static_jump_accumulator
                        )));
                        static_jump_accumulator = BFE::new(0);
                    }

                    instructions.extend([
                        // _ *field_si
                        Instruction::Instruction(AnInstruction::ReadMem(N::N1)),
                        // _ field_size (*field_si - 1)
                        Instruction::Instruction(AnInstruction::AddI(BFE::new(2))),
                        #field_ptr_hint,
                        // _ field_size *field
                        Instruction::Instruction(AnInstruction::Pick(ST::ST1)),
                        // _ *field field_size
                        Instruction::Instruction(AnInstruction::Push(BFE::from(Self::MAX_OFFSET))),
                        Instruction::Instruction(AnInstruction::Dup(ST::ST1)),
                        Instruction::Instruction(AnInstruction::Lt),
                        Instruction::Instruction(AnInstruction::Assert),
                        Instruction::AssertionContext(AssertionContext::ID(185)),
                        // _ *field field_size
                        Instruction::Instruction(AnInstruction::Add),
                        // _ *next_field_or_next_field_si
                    ]);
                }
            ));
        }

        // Only the last field in the encoding is left; any other name is unknown.
        let (first_field_name, first_field_ty) = first_field;
        let struct_name = struct_name.to_string();
        let first_field_type_hint = Self::top_of_stack_pointer_type_hint(&first_field_name);
        rust.extend(quote!(
            if field_name != #first_field_name {
                panic!("unknown field name `{field_name}` for type `{}`", #struct_name);
            }
            if static_jump_accumulator != BFE::new(0) {
                instructions.push(Instruction::Instruction(AnInstruction::AddI(
                    static_jump_accumulator
                )));
            }
            if let Some(size) =
                <#first_field_ty as crate::twenty_first::math::bfield_codec::BFieldCodec>
                    ::static_length()
            {
                instructions.extend([
                    #first_field_type_hint,
                    Instruction::Instruction(AnInstruction::Push(BFE::from(size))),
                ]);
            } else {
                instructions.extend([
                    // _ *field_si
                    Instruction::Instruction(AnInstruction::ReadMem(N::N1)),
                    // _ field_size (*field_si - 1)
                    Instruction::Instruction(AnInstruction::AddI(BFE::new(2))),
                    #first_field_type_hint,
                    // _ field_size *field
                    Instruction::Instruction(AnInstruction::Pick(ST::ST1)),
                    // _ *field field_size
                    Instruction::Instruction(AnInstruction::Push(BFE::from(Self::MAX_OFFSET))),
                    Instruction::Instruction(AnInstruction::Dup(ST::ST1)),
                    Instruction::Instruction(AnInstruction::Lt),
                    Instruction::Instruction(AnInstruction::Assert),
                    Instruction::AssertionContext(AssertionContext::ID(185)),
                    // _ *field field_size
                ]);
            }

            instructions
        ));

        rust
    }
649
650    /// Generate the rust code for `TasmStruct::destructure()`, which will then
651    /// generate tasm code that
652    /// - assumes the stack is in the state `_ *struct`
653    /// - leaves the stack in the state `_ *field_n *field_(n-1) … *field_0`
654    fn generate_code_for_fn_destructure(&self) -> TokenStream {
655        debug_assert_eq!(self.field_types.len(), self.field_names.len());
656
657        let mut fields = self
658            .field_names
659            .iter()
660            .map(|n| n.to_string())
661            .zip(&self.field_types);
662
663        let type_aliases = Self::type_aliases();
664        let Some(first_field) = fields.next_back() else {
665            return quote! {
666                #type_aliases
667                [Instruction::Instruction(AnInstruction::Pop(N::N1))].to_vec()
668            };
669        };
670
671        let mut rust = quote! {
672            #type_aliases
673            let mut instructions = ::std::vec::Vec::new();
674        };
675
676        for (field_name, field_ty) in fields {
677            let field_ptr_hint = Self::top_of_stack_pointer_type_hint(&field_name);
678            rust.extend(quote!(
679            if let Some(size) =
680                <#field_ty as crate::twenty_first::math::bfield_codec::BFieldCodec>::static_length()
681            {
682                instructions.extend([
683                    #field_ptr_hint,
684                    // _ *field
685                    Instruction::Instruction(AnInstruction::Dup(ST::ST0)),
686                    // _ *field *field
687                    Instruction::Instruction(AnInstruction::AddI(BFE::from(size))),
688                    // _ *field *next_field_or_next_field_si
689                ]);
690            } else {
691                instructions.extend([
692                    // _ *field_si
693                    Instruction::Instruction(AnInstruction::ReadMem(N::N1)),
694                    // _ field_size (*field_si - 1)
695                    Instruction::Instruction(AnInstruction::AddI(BFE::new(2))),
696                    #field_ptr_hint,
697                    // _ field_size *field
698                    Instruction::Instruction(AnInstruction::Dup(ST::ST0)),
699                    // _ field_size *field *field
700                    Instruction::Instruction(AnInstruction::Pick(ST::ST2)),
701                    // _ *field *field field_size
702                    Instruction::Instruction(AnInstruction::Push(BFE::from(Self::MAX_OFFSET))),
703                    Instruction::Instruction(AnInstruction::Dup(ST::ST1)),
704                    Instruction::Instruction(AnInstruction::Lt),
705                    Instruction::Instruction(AnInstruction::Assert),
706                    Instruction::AssertionContext(AssertionContext::ID(183)),
707                    // _ *field *field field_size
708                    Instruction::Instruction(AnInstruction::Add),
709                    // _ *field *next_field_or_next_field_si
710                ]);
711            }
712        ));
713        }
714
715        let (first_field_name, first_field_ty) = first_field;
716        let first_field_type_hint = Self::top_of_stack_pointer_type_hint(&first_field_name);
717        rust.extend(quote!(
718            if <#first_field_ty as crate::twenty_first::math::bfield_codec::BFieldCodec>
719                ::static_length().is_some()
720            {
721                instructions.push(#first_field_type_hint);
722            } else {
723                instructions.extend([
724                    // _ *field_si
725                    Instruction::Instruction(AnInstruction::ReadMem(N::N1)),
726                    // _ field_size (*field_si - 1)
727                    Instruction::Instruction(AnInstruction::AddI(BFE::new(2))),
728                    #first_field_type_hint,
729                    // _ field_size *field
730                    Instruction::Instruction(AnInstruction::Push(BFE::from(Self::MAX_OFFSET))),
731                    // _ field_size *field max_offset
732                    Instruction::Instruction(AnInstruction::Pick(ST::ST2)),
733                    // _ *field max_offset field_size
734                    Instruction::Instruction(AnInstruction::Lt),
735                    // _ *field (field_size < max_offset)
736                    Instruction::Instruction(AnInstruction::Assert),
737                    Instruction::AssertionContext(AssertionContext::ID(183)),
738                    // _ *field
739                ]);
740            }
741            instructions
742        ));
743
744        rust
745    }
746
747    fn top_of_stack_pointer_type_hint(field_name: &str) -> TokenStream {
748        quote!(
749            Instruction::TypeHint(
750                TypeHint {
751                    starting_index: 0,
752                    length: 1,
753                    type_name: ::std::option::Option::Some("Pointer".to_string()),
754                    variable_name: ::std::string::String::from(#field_name),
755                }
756            )
757        )
758    }
759}
760
761fn impl_tasm_object_derive_macro(ast: DeriveInput) -> TokenStream {
762    let parsed_struct = ParsedStruct::new(&ast);
763    let name = &ast.ident;
764    let name_as_string = ast.ident.to_string();
765
766    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
767    let ignored_field_types = parsed_struct.ignored_fields.iter().map(|f| f.ty.clone());
768    let new_where_clause = if let Some(old_where_clause) = where_clause {
769        quote! { #old_where_clause #(#ignored_field_types : Default ),* }
770    } else {
771        quote! { where #(#ignored_field_types : Default ),* }
772    };
773
774    let field_decoders = parsed_struct
775        .field_names
776        .iter()
777        .cloned()
778        .zip(parsed_struct.field_types.iter().cloned())
779        .map(|(field_name, field_ty)| field_decoder(field_name, field_ty));
780    let struct_builder = parsed_struct.struct_builder.clone();
781
782    let code_for_fn_compute_size_and_assert_valid_size_indicator =
783        parsed_struct.generate_code_for_fn_compute_size_and_assert_valid_size_indicator();
784    let code_for_fn_get_field = parsed_struct.generate_code_for_fn_get_field(&ast.ident);
785    let code_for_fn_get_field_with_size =
786        parsed_struct.generate_code_for_fn_get_field_with_size(&ast.ident);
787    let code_for_fn_destructure = parsed_struct.generate_code_for_fn_destructure();
788
789    quote! {
790        impl #impl_generics crate::tasm_lib::structure::tasm_object::TasmObject
791        for #name #ty_generics #new_where_clause {
792            fn label_friendly_name() -> String {
793                #name_as_string.to_owned()
794            }
795
796            fn compute_size_and_assert_valid_size_indicator(
797                library: &mut crate::tasm_lib::library::Library
798            ) -> ::std::vec::Vec<crate::triton_vm::isa::instruction::LabelledInstruction> {
799                #code_for_fn_compute_size_and_assert_valid_size_indicator
800            }
801
802            fn decode_iter<Itr: Iterator<Item = crate::triton_vm::prelude::BFieldElement>>(
803                iterator: &mut Itr
804            ) -> ::std::result::Result<
805                    ::std::boxed::Box<Self>,
806                    ::std::boxed::Box<
807                        dyn ::std::error::Error
808                        + ::core::marker::Send
809                        + ::core::marker::Sync
810                    >
811                >
812            {
813                #( #field_decoders )*
814                ::std::result::Result::Ok(::std::boxed::Box::new(#struct_builder))
815            }
816        }
817
818        impl #impl_generics crate::tasm_lib::structure::tasm_object::TasmStruct
819        for #name #ty_generics #new_where_clause {
820            fn get_field(
821                field_name: &str
822            ) -> ::std::vec::Vec<crate::triton_vm::isa::instruction::LabelledInstruction> {
823                #code_for_fn_get_field
824            }
825
826            fn get_field_with_size(
827                field_name: &str
828            ) -> ::std::vec::Vec<crate::triton_vm::isa::instruction::LabelledInstruction> {
829                #code_for_fn_get_field_with_size
830            }
831
832            fn destructure(
833            ) -> ::std::vec::Vec<crate::triton_vm::isa::instruction::LabelledInstruction> {
834                #code_for_fn_destructure
835            }
836        }
837    }
838}
839
840fn field_decoder(field_name: syn::Ident, field_type: syn::Type) -> TokenStream {
841    quote! {
842        let length = if let Some(static_length) =
843            <#field_type as crate::twenty_first::math::bfield_codec::BFieldCodec>::static_length()
844        {
845            static_length
846        } else {
847            iterator.next().ok_or("iterator exhausted")?.try_into()?
848        };
849        let sequence = (0..length)
850            .map(|_| iterator.next())
851            .collect::<::std::option::Option<::std::vec::Vec<_>>>()
852            .ok_or("iterator exhausted")?;
853        let #field_name : #field_type =
854            *crate::twenty_first::math::bfield_codec::BFieldCodec::decode(&sequence)?;
855    }
856}
857
#[cfg(test)]
mod tests {
    use syn::parse_quote;

    use super::*;

    /// Runs the derive on `ast`; any panic raised by the derive propagates.
    fn expand(ast: DeriveInput) -> TokenStream {
        impl_tasm_object_derive_macro(ast)
    }

    /// Parses the expansion as a complete Rust file (panicking if it is not one)
    /// and prints it in pretty-printed form.
    fn print_expansion(rust_code: &TokenStream) {
        let file: syn::File = parse_quote!(#rust_code);
        println!("{}", prettyplease::unparse(&file));
    }

    #[test]
    fn unit_struct() {
        let _expansion = expand(parse_quote! {
            #[derive(TasmObject)]
            struct UnitStruct;
        });
    }

    #[test]
    fn tuple_struct() {
        let _expansion = expand(parse_quote! {
            #[derive(TasmObject)]
            struct TupleStruct(u64, u32);
        });
    }

    #[test]
    fn struct_with_named_fields() {
        let _expansion = expand(parse_quote! {
            #[derive(TasmObject)]
            struct StructWithNamedFields {
                field1: u64,
                field2: u32,
                #[bfield_codec(ignore)]
                ignored_field: bool,
            }
        });
    }

    #[test]
    #[should_panic(expected = "expected a struct")] // enums are not supported (yet?)
    fn enum_with_tuple_variants() {
        let _expansion = expand(parse_quote! {
            #[derive(TasmObject)]
            enum Enum {
                Variant1,
                Variant2(u64),
                Variant3(u64, u32),
                #[bfield_codec(ignore)]
                IgnoredVariant,
            }
        });
    }

    #[test]
    fn generic_tuple_struct() {
        let _expansion = expand(parse_quote! {
            #[derive(TasmObject)]
            struct TupleStruct<T>(T, (T, T));
        });
    }

    #[test]
    fn generic_struct_with_named_fields() {
        let _expansion = expand(parse_quote! {
            #[derive(TasmObject)]
            struct StructWithNamedFields<T> {
                field1: T,
                field2: (T, T),
                #[bfield_codec(ignore)]
                ignored_field: bool,
            }
        });
    }

    #[test]
    #[should_panic(expected = "expected a struct")] // enums are not supported (yet?)
    fn generic_enum() {
        let _expansion = expand(parse_quote! {
            #[derive(TasmObject)]
            enum Enum<T> {
                Variant1,
                Variant2(T),
                Variant3(T, T),
                #[bfield_codec(ignore)]
                IgnoredVariant,
            }
        });
    }

    #[test]
    fn struct_with_types_from_twenty_first() {
        let _expansion = expand(parse_quote! {
            #[derive(TasmObject)]
            struct WithComplexFields {
                pub digest: Digest,
                pub my_vec: Vec<BFieldElement>,
            }
        });
    }

    #[test]
    fn where_clause_with_trailing_comma() {
        // The where clause below deliberately ends in a comma after
        // `T: BFieldCodec`.
        let rust_code = expand(parse_quote! {
            #[derive(BFieldCodec, TasmObject)]
            struct Foo<T>
            where
                T: BFieldCodec, { }
        });
        print_expansion(&rust_code);
    }

    #[test]
    fn where_clause_with_trailing_comma_and_ignored_field() {
        // A trailing comma in the where clause, combined with an ignored field
        // whose type needs an extra bound.
        let rust_code = expand(parse_quote! {
            #[derive(BFieldCodec, TasmObject)]
            struct Foo<S, T>
            where
                T: BFieldCodec,
            {
                #[tasm_object(ignore)]
                s: S,
            }
        });
        print_expansion(&rust_code);
    }
}