Skip to main content

sp1_derive/
lib.rs

1// The `aligned_borrow_derive` macro is taken from valida-xyz/valida under MIT license
2//
3// The MIT License (MIT)
4//
5// Copyright (c) 2023 The Valida Authors
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25extern crate proc_macro;
26
27use proc_macro::TokenStream;
28use quote::quote;
29use syn::{
30    parse_macro_input, parse_quote, Data, DeriveInput, GenericParam, ItemFn, WherePredicate,
31};
32
33mod input_expr;
34mod input_params;
35mod into_shape;
36mod sp1_operation_builder;
37
38#[proc_macro_derive(AlignedBorrow)]
39pub fn aligned_borrow_derive(input: TokenStream) -> TokenStream {
40    let ast = parse_macro_input!(input as DeriveInput);
41    let name = &ast.ident;
42
43    // Get first generic which must be type (ex. `T`) for input <T, N: NumLimbs, const M: usize>
44    let type_generic = ast
45        .generics
46        .params
47        .iter()
48        .map(|param| match param {
49            GenericParam::Type(type_param) => &type_param.ident,
50            _ => panic!("Expected first generic to be a type"),
51        })
52        .next()
53        .expect("Expected at least one generic");
54
55    // Get generics after the first (ex. `N: NumLimbs, const M: usize`)
56    // We need this because when we assert the size, we want to substitute u8 for T.
57    let non_first_generics = ast
58        .generics
59        .params
60        .iter()
61        .skip(1)
62        .filter_map(|param| match param {
63            GenericParam::Type(type_param) => Some(&type_param.ident),
64            GenericParam::Const(const_param) => Some(&const_param.ident),
65            _ => None,
66        })
67        .collect::<Vec<_>>();
68
69    // Get impl generics (`<T, N: NumLimbs, const M: usize>`), type generics (`<T, N>`), where
70    // clause (`where T: Clone`)
71    let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl();
72
73    let methods = quote! {
74        impl #impl_generics core::borrow::Borrow<#name #type_generics> for [#type_generic] #where_clause {
75            fn borrow(&self) -> &#name #type_generics {
76                debug_assert_eq!(self.len(), std::mem::size_of::<#name<u8 #(, #non_first_generics)*>>());
77                let (prefix, shorts, _suffix) = unsafe { self.align_to::<#name #type_generics>() };
78                debug_assert!(prefix.is_empty(), "Alignment should match");
79                debug_assert_eq!(shorts.len(), 1);
80                &shorts[0]
81            }
82        }
83
84        impl #impl_generics core::borrow::BorrowMut<#name #type_generics> for [#type_generic] #where_clause {
85            fn borrow_mut(&mut self) -> &mut #name #type_generics {
86                debug_assert_eq!(self.len(), std::mem::size_of::<#name<u8 #(, #non_first_generics)*>>());
87                let (prefix, shorts, _suffix) = unsafe { self.align_to_mut::<#name #type_generics>() };
88                debug_assert!(prefix.is_empty(), "Alignment should match");
89                debug_assert_eq!(shorts.len(), 1);
90                &mut shorts[0]
91            }
92        }
93    };
94
95    TokenStream::from(methods)
96}
97
98#[proc_macro_derive(
99    MachineAir,
100    attributes(execution_record_path, program_path, builder_path, eval_trait_bound)
101)]
102pub fn machine_air_derive(input: TokenStream) -> TokenStream {
103    let ast: syn::DeriveInput = syn::parse(input).unwrap();
104
105    let name = &ast.ident;
106    let generics = &ast.generics;
107    let execution_record_path = find_execution_record_path(&ast.attrs);
108    let program_path = find_program_path(&ast.attrs);
109    let builder_path = find_builder_path(&ast.attrs);
110    let eval_trait_bound = find_eval_trait_bound(&ast.attrs);
111    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
112
113    match &ast.data {
114        Data::Struct(_) => unimplemented!("Structs are not supported yet"),
115        Data::Enum(e) => {
116            let variants = e
117                .variants
118                .iter()
119                .map(|variant| {
120                    let variant_name = &variant.ident;
121
122                    let mut fields = variant.fields.iter();
123                    let field = fields.next().unwrap();
124                    assert!(fields.next().is_none(), "Only one field is supported");
125                    (variant_name, field)
126                })
127                .collect::<Vec<_>>();
128
129            let width_arms = variants.iter().map(|(variant_name, field)| {
130                let field_ty = &field.ty;
131                quote! {
132                    #name::#variant_name(x) => <#field_ty as slop_air::BaseAir<F>>::width(x)
133                }
134            });
135
136            let base_air = quote! {
137                impl #impl_generics slop_air::BaseAir<F> for #name #ty_generics #where_clause {
138                    fn width(&self) -> usize {
139                        match self {
140                            #(#width_arms,)*
141                        }
142                    }
143
144                    fn preprocessed_trace(&self) -> Option<slop_matrix::dense::RowMajorMatrix<F>> {
145                        unreachable!("A machine air should use the preprocessed trace from the `MachineAir` trait")
146                    }
147                }
148            };
149
150            let name_arms = variants.iter().map(|(variant_name, field)| {
151                let field_ty = &field.ty;
152                quote! {
153                    #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::name(x)
154                }
155            });
156
157            let column_names_arms = variants.iter().map(|(variant_name, field)| {
158                let field_ty = &field.ty;
159                quote! {
160                    #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::column_names(x)
161                }
162            });
163
164            let preprocessed_num_rows_arms = variants.iter().map(|(variant_name, field)| {
165                let field_ty = &field.ty;
166                quote! {
167                    #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::preprocessed_num_rows(x, program)
168                }
169            });
170
171            let preprocessed_width_arms = variants.iter().map(|(variant_name, field)| {
172                let field_ty = &field.ty;
173                quote! {
174                    #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::preprocessed_width(x)
175                }
176            });
177
178            let generate_preprocessed_trace_arms = variants.iter().map(|(variant_name, field)| {
179                let field_ty = &field.ty;
180                quote! {
181                    #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::generate_preprocessed_trace(x, program)
182                }
183            });
184
185            let generate_preprocessed_trace_into_arms = variants.iter().map(|(variant_name, field)| {
186                let field_ty = &field.ty;
187                quote! {
188                    #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::generate_preprocessed_trace_into(x, program, buffer)
189                }
190            });
191
192            let generate_trace_arms = variants.iter().map(|(variant_name, field)| {
193                let field_ty = &field.ty;
194                quote! {
195                    #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::generate_trace(x, input, output)
196                }
197            });
198
199            let generate_trace_into_arms = variants.iter().map(|(variant_name, field)| {
200                let field_ty = &field.ty;
201                quote! {
202                    #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::generate_trace_into(x, input, output, buffer)
203                }
204            });
205
206            let generate_dependencies_arms = variants.iter().map(|(variant_name, field)| {
207                let field_ty = &field.ty;
208                quote! {
209                    #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::generate_dependencies(x, input, output)
210                }
211            });
212
213            let included_arms = variants.iter().map(|(variant_name, field)| {
214                let field_ty = &field.ty;
215                quote! {
216                    #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::included(x, shard)
217                }
218            });
219
220            let num_rows_arms = variants.iter().map(|(variant_name, field)| {
221                let field_ty = &field.ty;
222                quote! {
223                    #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::num_rows(x, input)
224                }
225            });
226
227            let machine_air = quote! {
228                impl #impl_generics sp1_hypercube::air::MachineAir<F> for #name #ty_generics #where_clause {
229                    type Record = #execution_record_path;
230
231                    type Program = #program_path;
232
233                    fn name(&self) -> &'static str {
234                        match self {
235                            #(#name_arms,)*
236                        }
237                    }
238
239                    fn column_names(&self) -> Vec<String> {
240                        match self {
241                            #(#column_names_arms,)*
242                        }
243                    }
244
245                    fn preprocessed_width(&self) -> usize {
246                        match self {
247                            #(#preprocessed_width_arms,)*
248                        }
249                    }
250
251                    fn preprocessed_num_rows(&self, program: &#program_path,) -> Option<usize> {
252                        match self {
253                            #(#preprocessed_num_rows_arms,)*
254                        }
255                    }
256
257                    fn generate_preprocessed_trace(
258                        &self,
259                        program: &#program_path,
260                    ) -> Option<slop_matrix::dense::RowMajorMatrix<F>> {
261                        match self {
262                            #(#generate_preprocessed_trace_arms,)*
263                        }
264                    }
265
266                    fn generate_preprocessed_trace_into(
267                        &self,
268                        program: &#program_path,
269                        buffer: &mut [MaybeUninit<F>],
270                    ) {
271                        match self {
272                            #(#generate_preprocessed_trace_into_arms,)*
273                        }
274                    }
275
276                    fn generate_trace(
277                        &self,
278                        input: &#execution_record_path,
279                        output: &mut #execution_record_path,
280                    ) -> slop_matrix::dense::RowMajorMatrix<F> {
281                        match self {
282                            #(#generate_trace_arms,)*
283                        }
284                    }
285
286                    fn generate_trace_into(
287                        &self,
288                        input: &#execution_record_path,
289                        output: &mut #execution_record_path,
290                        buffer: &mut [MaybeUninit<F>],
291                    ){
292                        match self {
293                            #(#generate_trace_into_arms,)*
294                        }
295                    }
296
297                    fn generate_dependencies(
298                        &self,
299                        input: &#execution_record_path,
300                        output: &mut #execution_record_path,
301                    ) {
302                        match self {
303                            #(#generate_dependencies_arms,)*
304                        }
305                    }
306
307                    fn included(&self, shard: &Self::Record) -> bool {
308                        match self {
309                            #(#included_arms,)*
310                        }
311                    }
312
313                    fn num_rows(&self, input: &Self::Record) -> Option<usize> {
314                        match self {
315                            #(#num_rows_arms,)*
316                        }
317                    }
318                }
319            };
320
321            let eval_arms = variants.iter().map(|(variant_name, field)| {
322                let field_ty = &field.ty;
323                quote! {
324                    #name::#variant_name(x) => <#field_ty as slop_air::Air<AB>>::eval(x, builder)
325                }
326            });
327
328            // Attach an extra generic AB : crate::air::SP1AirBuilder to the generics of the enum
329            let generics = &ast.generics;
330            let mut new_generics = generics.clone();
331            new_generics
332                .params
333                .push(syn::parse_quote! { AB: slop_air::PairBuilder + #builder_path });
334
335            let (air_impl_generics, _, _) = new_generics.split_for_impl();
336
337            let mut new_generics = generics.clone();
338            let where_clause = new_generics.make_where_clause();
339            if let Some(eval_trait_bound) = eval_trait_bound {
340                let predicate: WherePredicate = syn::parse_str(&eval_trait_bound).unwrap();
341                where_clause.predicates.push(predicate);
342            }
343
344            let air = quote! {
345                impl #air_impl_generics slop_air::Air<AB> for #name #ty_generics #where_clause {
346                    fn eval(&self, builder: &mut AB) {
347                        match self {
348                            #(#eval_arms,)*
349                        }
350                    }
351                }
352            };
353
354            quote! {
355                #base_air
356
357                #machine_air
358
359                #air
360            }
361            .into()
362        }
363        Data::Union(_) => unimplemented!("Unions are not supported"),
364    }
365}
366
367#[proc_macro_attribute]
368pub fn cycle_tracker(_attr: TokenStream, item: TokenStream) -> TokenStream {
369    let input = parse_macro_input!(item as ItemFn);
370    let visibility = &input.vis;
371    let name = &input.sig.ident;
372    let inputs = &input.sig.inputs;
373    let output = &input.sig.output;
374    let block = &input.block;
375    let generics = &input.sig.generics;
376    let where_clause = &input.sig.generics.where_clause;
377
378    let result = quote! {
379        #visibility fn #name #generics (#inputs) #output #where_clause {
380            eprintln!("cycle-tracker-start: {}", stringify!(#name));
381            let result = (|| #block)();
382            eprintln!("cycle-tracker-end: {}", stringify!(#name));
383            result
384        }
385    };
386
387    result.into()
388}
389
390#[proc_macro_attribute]
391pub fn cycle_tracker_recursion(_attr: TokenStream, item: TokenStream) -> TokenStream {
392    let input = parse_macro_input!(item as ItemFn);
393    let visibility = &input.vis;
394    let name = &input.sig.ident;
395    let inputs = &input.sig.inputs;
396    let output = &input.sig.output;
397    let block = &input.block;
398    let generics = &input.sig.generics;
399    let where_clause = &input.sig.generics.where_clause;
400
401    let result = quote! {
402        #visibility fn #name #generics (#inputs) #output #where_clause {
403            sp1_recursion_compiler::circuit::CircuitV2Builder::cycle_tracker_v2_enter(builder, stringify!(#name));
404            let result = #block;
405            sp1_recursion_compiler::circuit::CircuitV2Builder::cycle_tracker_v2_exit(builder);
406            result
407        }
408    };
409
410    result.into()
411}
412
413fn find_execution_record_path(attrs: &[syn::Attribute]) -> syn::Path {
414    for attr in attrs {
415        if attr.path.is_ident("execution_record_path") {
416            if let Ok(syn::Meta::NameValue(meta)) = attr.parse_meta() {
417                if let syn::Lit::Str(lit_str) = &meta.lit {
418                    if let Ok(path) = lit_str.parse::<syn::Path>() {
419                        return path;
420                    }
421                }
422            }
423        }
424    }
425    parse_quote!(sp1_core_executor::ExecutionRecord)
426}
427
428fn find_program_path(attrs: &[syn::Attribute]) -> syn::Path {
429    for attr in attrs {
430        if attr.path.is_ident("program_path") {
431            if let Ok(syn::Meta::NameValue(meta)) = attr.parse_meta() {
432                if let syn::Lit::Str(lit_str) = &meta.lit {
433                    if let Ok(path) = lit_str.parse::<syn::Path>() {
434                        return path;
435                    }
436                }
437            }
438        }
439    }
440    parse_quote!(sp1_core_executor::Program)
441}
442
443fn find_builder_path(attrs: &[syn::Attribute]) -> syn::Path {
444    for attr in attrs {
445        if attr.path.is_ident("builder_path") {
446            if let Ok(syn::Meta::NameValue(meta)) = attr.parse_meta() {
447                if let syn::Lit::Str(lit_str) = &meta.lit {
448                    if let Ok(path) = lit_str.parse::<syn::Path>() {
449                        return path;
450                    }
451                }
452            }
453        }
454    }
455    parse_quote!(crate::air::SP1CoreAirBuilder<F = F>)
456}
457
458fn find_eval_trait_bound(attrs: &[syn::Attribute]) -> Option<String> {
459    for attr in attrs {
460        if attr.path.is_ident("eval_trait_bound") {
461            if let Ok(syn::Meta::NameValue(meta)) = attr.parse_meta() {
462                if let syn::Lit::Str(lit_str) = &meta.lit {
463                    return Some(lit_str.value());
464                }
465            }
466        }
467    }
468
469    None
470}
471
472#[proc_macro_derive(IntoShape)]
473pub fn into_shape_derive(input: TokenStream) -> TokenStream {
474    into_shape::into_shape_derive(input)
475}
476
477#[proc_macro_derive(InputExpr)]
478pub fn input_expr_derive(input: TokenStream) -> TokenStream {
479    input_expr::input_expr_derive(input)
480}
481
482#[proc_macro_derive(InputParams, attributes(picus))]
483pub fn input_params_derive(input: TokenStream) -> TokenStream {
484    input_params::input_params_derive(input)
485}
486
487#[proc_macro_derive(SP1OperationBuilder)]
488pub fn sp1_operation_builder_derive(input: TokenStream) -> TokenStream {
489    sp1_operation_builder::sp1_operation_builder_derive(input)
490}