sp1_recursion_derive/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{parse_macro_input, Data, DeriveInput, Fields};
6
7#[proc_macro_derive(DslVariable)]
8pub fn derive_variable(input: TokenStream) -> TokenStream {
9    let input = parse_macro_input!(input as DeriveInput);
10    let name = input.ident; // Struct name
11
12    let gen = match input.data {
13        Data::Struct(data) => match data.fields {
14            Fields::Named(fields) => {
15                let fields_init = fields.named.iter().map(|f| {
16                    let fname = &f.ident;
17                    let ftype = &f.ty;
18                    let ftype_str = quote! { #ftype }.to_string();
19                    if ftype_str.contains("Array") {
20                        quote! {
21                            #fname: Array::Dyn(builder.uninit(), builder.uninit()),
22                        }
23                    } else {
24                        quote! {
25                            #fname: <#ftype as Variable<C>>::uninit(builder),
26                        }
27                    }
28                });
29
30                let fields_assign = fields.named.iter().map(|f| {
31                    let fname = &f.ident;
32                    quote! {
33                        self.#fname.assign(src.#fname.into(), builder);
34                    }
35                });
36
37                let fields_assert_eq = fields.named.iter().map(|f| {
38                    let fname = &f.ident;
39                    let ftype = &f.ty;
40                    quote! {
41                        <#ftype as Variable<C>>::assert_eq(lhs.#fname, rhs.#fname, builder);
42                    }
43                });
44
45                let fields_assert_ne = fields.named.iter().map(|f| {
46                    let fname = &f.ident;
47                    let ftype = &f.ty;
48                    quote! {
49                        <#ftype as Variable<C>>::assert_ne(lhs.#fname, rhs.#fname, builder);
50                    }
51                });
52
53                let field_sizes = fields.named.iter().map(|f| {
54                    let ftype = &f.ty;
55                    quote! {
56                        <#ftype as MemVariable<C>>::size_of()
57                    }
58                });
59
60                let field_loads = fields.named.iter().map(|f| {
61                    let fname = &f.ident;
62                    let ftype = &f.ty;
63                    quote! {
64                        {
65                            // let address = builder.eval(ptr + Usize::Const(offset));
66                            self.#fname.load(ptr, index, builder);
67                            index.offset += <#ftype as MemVariable<C>>::size_of();
68                        }
69                    }
70                });
71
72                let field_stores = fields.named.iter().map(|f| {
73                    let fname = &f.ident;
74                    let ftype = &f.ty;
75                    quote! {
76                        {
77                            // let address = builder.eval(ptr + Usize::Const(offset));
78                            self.#fname.store(ptr, index, builder);
79                            index.offset += <#ftype as MemVariable<C>>::size_of();
80                        }
81                    }
82                });
83
84                quote! {
85                    impl<C: Config> Variable<C> for #name<C> {
86                        type Expression = Self;
87
88                        fn uninit(builder: &mut Builder<C>) -> Self {
89                            Self {
90                                #(#fields_init)*
91                            }
92                        }
93
94                        fn assign(&self, src: Self::Expression, builder: &mut Builder<C>) {
95                            #(#fields_assign)*
96                        }
97
98                        fn assert_eq(
99                            lhs: impl Into<Self::Expression>,
100                            rhs: impl Into<Self::Expression>,
101                            builder: &mut Builder<C>,
102                        ) {
103                            let lhs = lhs.into();
104                            let rhs = rhs.into();
105                            #(#fields_assert_eq)*
106                        }
107
108                        fn assert_ne(
109                            lhs: impl Into<Self::Expression>,
110                            rhs: impl Into<Self::Expression>,
111                            builder: &mut Builder<C>,
112                        ) {
113                            let lhs = lhs.into();
114                            let rhs = rhs.into();
115                            #(#fields_assert_ne)*
116                        }
117                    }
118
119                    impl<C: Config> MemVariable<C> for #name<C> {
120                        fn size_of() -> usize {
121                            let mut size = 0;
122                            #(size += #field_sizes;)*
123                            size
124                        }
125
126                        fn load(&self, ptr: Ptr<<C as Config>::N>,
127                            index: MemIndex<<C as Config>::N>,
128                            builder: &mut Builder<C>) {
129                            let mut index = index;
130                            #(#field_loads)*
131                        }
132
133                        fn store(&self, ptr: Ptr<<C as Config>::N>,
134                                 index: MemIndex<<C as Config>::N>,
135                                builder: &mut Builder<C>) {
136                            let mut index = index;
137                            #(#field_stores)*
138                        }
139                    }
140                }
141            }
142            _ => unimplemented!(),
143        },
144        _ => unimplemented!(),
145    };
146
147    gen.into()
148}