Skip to main content

sp1_derive/
lib.rs

1// The `aligned_borrow_derive` macro is taken from valida-xyz/valida under MIT license
2//
3// The MIT License (MIT)
4//
5// Copyright (c) 2023 The Valida Authors
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25extern crate proc_macro;
26
27use proc_macro::TokenStream;
28use quote::quote;
29use syn::{
30    parse_macro_input, parse_quote, Data, DeriveInput, GenericParam, ItemFn, WherePredicate,
31};
32
33mod input_expr;
34mod input_params;
35mod into_shape;
36mod sp1_operation_builder;
37
38#[proc_macro_derive(AlignedBorrow)]
39pub fn aligned_borrow_derive(input: TokenStream) -> TokenStream {
40    let ast = parse_macro_input!(input as DeriveInput);
41    let name = &ast.ident;
42
43    // Get first generic which must be type (ex. `T`) for input <T, N: NumLimbs, const M: usize>
44    let type_generic = ast
45        .generics
46        .params
47        .iter()
48        .map(|param| match param {
49            GenericParam::Type(type_param) => &type_param.ident,
50            _ => panic!("Expected first generic to be a type"),
51        })
52        .next()
53        .expect("Expected at least one generic");
54
55    // Get generics after the first (ex. `N: NumLimbs, const M: usize`)
56    // We need this because when we assert the size, we want to substitute u8 for T.
57    let non_first_generics = ast
58        .generics
59        .params
60        .iter()
61        .skip(1)
62        .filter_map(|param| match param {
63            GenericParam::Type(type_param) => Some(&type_param.ident),
64            GenericParam::Const(const_param) => Some(&const_param.ident),
65            _ => None,
66        })
67        .collect::<Vec<_>>();
68
69    // Get impl generics (`<T, N: NumLimbs, const M: usize>`), type generics (`<T, N>`), where
70    // clause (`where T: Clone`)
71    let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl();
72
73    let methods = quote! {
74        impl #impl_generics core::borrow::Borrow<#name #type_generics> for [#type_generic] #where_clause {
75            fn borrow(&self) -> &#name #type_generics {
76                debug_assert_eq!(self.len(), std::mem::size_of::<#name<u8 #(, #non_first_generics)*>>());
77                let (prefix, shorts, _suffix) = unsafe { self.align_to::<#name #type_generics>() };
78                debug_assert!(prefix.is_empty(), "Alignment should match");
79                debug_assert_eq!(shorts.len(), 1);
80                &shorts[0]
81            }
82        }
83
84        impl #impl_generics core::borrow::BorrowMut<#name #type_generics> for [#type_generic] #where_clause {
85            fn borrow_mut(&mut self) -> &mut #name #type_generics {
86                debug_assert_eq!(self.len(), std::mem::size_of::<#name<u8 #(, #non_first_generics)*>>());
87                let (prefix, shorts, _suffix) = unsafe { self.align_to_mut::<#name #type_generics>() };
88                debug_assert!(prefix.is_empty(), "Alignment should match");
89                debug_assert_eq!(shorts.len(), 1);
90                &mut shorts[0]
91            }
92        }
93    };
94
95    TokenStream::from(methods)
96}
97
98#[proc_macro_derive(
99    MachineAir,
100    attributes(execution_record_path, program_path, builder_path, eval_trait_bound)
101)]
102pub fn machine_air_derive(input: TokenStream) -> TokenStream {
103    let ast: syn::DeriveInput = syn::parse(input).unwrap();
104
105    let name = &ast.ident;
106    let generics = &ast.generics;
107    let execution_record_path = find_execution_record_path(&ast.attrs);
108    let program_path = find_program_path(&ast.attrs);
109    let builder_path = find_builder_path(&ast.attrs);
110    let eval_trait_bound = find_eval_trait_bound(&ast.attrs);
111    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
112
113    match &ast.data {
114        Data::Struct(_) => unimplemented!("Structs are not supported yet"),
115        Data::Enum(e) => {
116            let variants = e
117                .variants
118                .iter()
119                .map(|variant| {
120                    let variant_name = &variant.ident;
121
122                    let mut fields = variant.fields.iter();
123                    let field = fields.next().unwrap();
124                    assert!(fields.next().is_none(), "Only one field is supported");
125                    (variant_name, field)
126                })
127                .collect::<Vec<_>>();
128
129            let width_arms = variants.iter().map(|(variant_name, field)| {
130                let field_ty = &field.ty;
131                quote! {
132                    #name::#variant_name(x) => <#field_ty as slop_air::BaseAir<F>>::width(x)
133                }
134            });
135
136            let base_air = quote! {
137                impl #impl_generics slop_air::BaseAir<F> for #name #ty_generics #where_clause {
138                    fn width(&self) -> usize {
139                        match self {
140                            #(#width_arms,)*
141                        }
142                    }
143
144                    fn preprocessed_trace(&self) -> Option<slop_matrix::dense::RowMajorMatrix<F>> {
145                        unreachable!("A machine air should use the preprocessed trace from the `MachineAir` trait")
146                    }
147                }
148            };
149
150            let name_arms = variants.iter().map(|(variant_name, field)| {
151                let field_ty = &field.ty;
152                quote! {
153                    #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::name(x)
154                }
155            });
156
157            let preprocessed_num_rows_arms = variants.iter().map(|(variant_name, field)| {
158                let field_ty = &field.ty;
159                quote! {
160                    #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::preprocessed_num_rows(x, program)
161                }
162            });
163
164            let preprocessed_width_arms = variants.iter().map(|(variant_name, field)| {
165                let field_ty = &field.ty;
166                quote! {
167                    #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::preprocessed_width(x)
168                }
169            });
170
171            let generate_preprocessed_trace_arms = variants.iter().map(|(variant_name, field)| {
172                let field_ty = &field.ty;
173                quote! {
174                    #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::generate_preprocessed_trace(x, program)
175                }
176            });
177
178            let generate_preprocessed_trace_into_arms = variants.iter().map(|(variant_name, field)| {
179                let field_ty = &field.ty;
180                quote! {
181                    #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::generate_preprocessed_trace_into(x, program, buffer)
182                }
183            });
184
185            let generate_trace_arms = variants.iter().map(|(variant_name, field)| {
186                let field_ty = &field.ty;
187                quote! {
188                    #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::generate_trace(x, input, output)
189                }
190            });
191
192            let generate_trace_into_arms = variants.iter().map(|(variant_name, field)| {
193                let field_ty = &field.ty;
194                quote! {
195                    #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::generate_trace_into(x, input, output, buffer)
196                }
197            });
198
199            let generate_dependencies_arms = variants.iter().map(|(variant_name, field)| {
200                let field_ty = &field.ty;
201                quote! {
202                    #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::generate_dependencies(x, input, output)
203                }
204            });
205
206            let included_arms = variants.iter().map(|(variant_name, field)| {
207                let field_ty = &field.ty;
208                quote! {
209                    #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::included(x, shard)
210                }
211            });
212
213            let num_rows_arms = variants.iter().map(|(variant_name, field)| {
214                let field_ty = &field.ty;
215                quote! {
216                    #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::num_rows(x, input)
217                }
218            });
219
220            let machine_air = quote! {
221                impl #impl_generics sp1_hypercube::air::MachineAir<F> for #name #ty_generics #where_clause {
222                    type Record = #execution_record_path;
223
224                    type Program = #program_path;
225
226                    fn name(&self) -> &'static str {
227                        match self {
228                            #(#name_arms,)*
229                        }
230                    }
231
232                    fn preprocessed_width(&self) -> usize {
233                        match self {
234                            #(#preprocessed_width_arms,)*
235                        }
236                    }
237
238                    fn preprocessed_num_rows(&self, program: &#program_path,) -> Option<usize> {
239                        match self {
240                            #(#preprocessed_num_rows_arms,)*
241                        }
242                    }
243
244                    fn generate_preprocessed_trace(
245                        &self,
246                        program: &#program_path,
247                    ) -> Option<slop_matrix::dense::RowMajorMatrix<F>> {
248                        match self {
249                            #(#generate_preprocessed_trace_arms,)*
250                        }
251                    }
252
253                    fn generate_preprocessed_trace_into(
254                        &self,
255                        program: &#program_path,
256                        buffer: &mut [MaybeUninit<F>],
257                    ) {
258                        match self {
259                            #(#generate_preprocessed_trace_into_arms,)*
260                        }
261                    }
262
263                    fn generate_trace(
264                        &self,
265                        input: &#execution_record_path,
266                        output: &mut #execution_record_path,
267                    ) -> slop_matrix::dense::RowMajorMatrix<F> {
268                        match self {
269                            #(#generate_trace_arms,)*
270                        }
271                    }
272
273                    fn generate_trace_into(
274                        &self,
275                        input: &#execution_record_path,
276                        output: &mut #execution_record_path,
277                        buffer: &mut [MaybeUninit<F>],
278                    ){
279                        match self {
280                            #(#generate_trace_into_arms,)*
281                        }
282                    }
283
284                    fn generate_dependencies(
285                        &self,
286                        input: &#execution_record_path,
287                        output: &mut #execution_record_path,
288                    ) {
289                        match self {
290                            #(#generate_dependencies_arms,)*
291                        }
292                    }
293
294                    fn included(&self, shard: &Self::Record) -> bool {
295                        match self {
296                            #(#included_arms,)*
297                        }
298                    }
299
300                    fn num_rows(&self, input: &Self::Record) -> Option<usize> {
301                        match self {
302                            #(#num_rows_arms,)*
303                    }
304                }
305                }
306            };
307
308            let eval_arms = variants.iter().map(|(variant_name, field)| {
309                let field_ty = &field.ty;
310                quote! {
311                    #name::#variant_name(x) => <#field_ty as slop_air::Air<AB>>::eval(x, builder)
312                }
313            });
314
315            // Attach an extra generic AB : crate::air::SP1AirBuilder to the generics of the enum
316            let generics = &ast.generics;
317            let mut new_generics = generics.clone();
318            new_generics
319                .params
320                .push(syn::parse_quote! { AB: slop_air::PairBuilder + #builder_path });
321
322            let (air_impl_generics, _, _) = new_generics.split_for_impl();
323
324            let mut new_generics = generics.clone();
325            let where_clause = new_generics.make_where_clause();
326            if let Some(eval_trait_bound) = eval_trait_bound {
327                let predicate: WherePredicate = syn::parse_str(&eval_trait_bound).unwrap();
328                where_clause.predicates.push(predicate);
329            }
330
331            let air = quote! {
332                impl #air_impl_generics slop_air::Air<AB> for #name #ty_generics #where_clause {
333                    fn eval(&self, builder: &mut AB) {
334                        match self {
335                            #(#eval_arms,)*
336                        }
337                    }
338                }
339            };
340
341            quote! {
342                #base_air
343
344                #machine_air
345
346                #air
347            }
348            .into()
349        }
350        Data::Union(_) => unimplemented!("Unions are not supported"),
351    }
352}
353
354#[proc_macro_attribute]
355pub fn cycle_tracker(_attr: TokenStream, item: TokenStream) -> TokenStream {
356    let input = parse_macro_input!(item as ItemFn);
357    let visibility = &input.vis;
358    let name = &input.sig.ident;
359    let inputs = &input.sig.inputs;
360    let output = &input.sig.output;
361    let block = &input.block;
362    let generics = &input.sig.generics;
363    let where_clause = &input.sig.generics.where_clause;
364
365    let result = quote! {
366        #visibility fn #name #generics (#inputs) #output #where_clause {
367            eprintln!("cycle-tracker-start: {}", stringify!(#name));
368            let result = (|| #block)();
369            eprintln!("cycle-tracker-end: {}", stringify!(#name));
370            result
371        }
372    };
373
374    result.into()
375}
376
377#[proc_macro_attribute]
378pub fn cycle_tracker_recursion(_attr: TokenStream, item: TokenStream) -> TokenStream {
379    let input = parse_macro_input!(item as ItemFn);
380    let visibility = &input.vis;
381    let name = &input.sig.ident;
382    let inputs = &input.sig.inputs;
383    let output = &input.sig.output;
384    let block = &input.block;
385    let generics = &input.sig.generics;
386    let where_clause = &input.sig.generics.where_clause;
387
388    let result = quote! {
389        #visibility fn #name #generics (#inputs) #output #where_clause {
390            sp1_recursion_compiler::circuit::CircuitV2Builder::cycle_tracker_v2_enter(builder, stringify!(#name));
391            let result = #block;
392            sp1_recursion_compiler::circuit::CircuitV2Builder::cycle_tracker_v2_exit(builder);
393            result
394        }
395    };
396
397    result.into()
398}
399
400fn find_execution_record_path(attrs: &[syn::Attribute]) -> syn::Path {
401    for attr in attrs {
402        if attr.path.is_ident("execution_record_path") {
403            if let Ok(syn::Meta::NameValue(meta)) = attr.parse_meta() {
404                if let syn::Lit::Str(lit_str) = &meta.lit {
405                    if let Ok(path) = lit_str.parse::<syn::Path>() {
406                        return path;
407                    }
408                }
409            }
410        }
411    }
412    parse_quote!(sp1_core_executor::ExecutionRecord)
413}
414
415fn find_program_path(attrs: &[syn::Attribute]) -> syn::Path {
416    for attr in attrs {
417        if attr.path.is_ident("program_path") {
418            if let Ok(syn::Meta::NameValue(meta)) = attr.parse_meta() {
419                if let syn::Lit::Str(lit_str) = &meta.lit {
420                    if let Ok(path) = lit_str.parse::<syn::Path>() {
421                        return path;
422                    }
423                }
424            }
425        }
426    }
427    parse_quote!(sp1_core_executor::Program)
428}
429
430fn find_builder_path(attrs: &[syn::Attribute]) -> syn::Path {
431    for attr in attrs {
432        if attr.path.is_ident("builder_path") {
433            if let Ok(syn::Meta::NameValue(meta)) = attr.parse_meta() {
434                if let syn::Lit::Str(lit_str) = &meta.lit {
435                    if let Ok(path) = lit_str.parse::<syn::Path>() {
436                        return path;
437                    }
438                }
439            }
440        }
441    }
442    parse_quote!(crate::air::SP1CoreAirBuilder<F = F>)
443}
444
445fn find_eval_trait_bound(attrs: &[syn::Attribute]) -> Option<String> {
446    for attr in attrs {
447        if attr.path.is_ident("eval_trait_bound") {
448            if let Ok(syn::Meta::NameValue(meta)) = attr.parse_meta() {
449                if let syn::Lit::Str(lit_str) = &meta.lit {
450                    return Some(lit_str.value());
451                }
452            }
453        }
454    }
455
456    None
457}
458
459#[proc_macro_derive(IntoShape)]
460pub fn into_shape_derive(input: TokenStream) -> TokenStream {
461    into_shape::into_shape_derive(input)
462}
463
464#[proc_macro_derive(InputExpr)]
465pub fn input_expr_derive(input: TokenStream) -> TokenStream {
466    input_expr::input_expr_derive(input)
467}
468
469#[proc_macro_derive(InputParams, attributes(picus))]
470pub fn input_params_derive(input: TokenStream) -> TokenStream {
471    input_params::input_params_derive(input)
472}
473
474#[proc_macro_derive(SP1OperationBuilder)]
475pub fn sp1_operation_builder_derive(input: TokenStream) -> TokenStream {
476    sp1_operation_builder::sp1_operation_builder_derive(input)
477}