1extern crate proc_macro;
26
27use proc_macro::TokenStream;
28use quote::quote;
29use syn::{
30 parse_macro_input, parse_quote, Data, DeriveInput, GenericParam, ItemFn, WherePredicate,
31};
32
33mod input_expr;
34mod input_params;
35mod into_shape;
36mod sp1_operation_builder;
37
38#[proc_macro_derive(AlignedBorrow)]
39pub fn aligned_borrow_derive(input: TokenStream) -> TokenStream {
40 let ast = parse_macro_input!(input as DeriveInput);
41 let name = &ast.ident;
42
43 let type_generic = ast
45 .generics
46 .params
47 .iter()
48 .map(|param| match param {
49 GenericParam::Type(type_param) => &type_param.ident,
50 _ => panic!("Expected first generic to be a type"),
51 })
52 .next()
53 .expect("Expected at least one generic");
54
55 let non_first_generics = ast
58 .generics
59 .params
60 .iter()
61 .skip(1)
62 .filter_map(|param| match param {
63 GenericParam::Type(type_param) => Some(&type_param.ident),
64 GenericParam::Const(const_param) => Some(&const_param.ident),
65 _ => None,
66 })
67 .collect::<Vec<_>>();
68
69 let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl();
72
73 let methods = quote! {
74 impl #impl_generics core::borrow::Borrow<#name #type_generics> for [#type_generic] #where_clause {
75 fn borrow(&self) -> &#name #type_generics {
76 debug_assert_eq!(self.len(), std::mem::size_of::<#name<u8 #(, #non_first_generics)*>>());
77 let (prefix, shorts, _suffix) = unsafe { self.align_to::<#name #type_generics>() };
78 debug_assert!(prefix.is_empty(), "Alignment should match");
79 debug_assert_eq!(shorts.len(), 1);
80 &shorts[0]
81 }
82 }
83
84 impl #impl_generics core::borrow::BorrowMut<#name #type_generics> for [#type_generic] #where_clause {
85 fn borrow_mut(&mut self) -> &mut #name #type_generics {
86 debug_assert_eq!(self.len(), std::mem::size_of::<#name<u8 #(, #non_first_generics)*>>());
87 let (prefix, shorts, _suffix) = unsafe { self.align_to_mut::<#name #type_generics>() };
88 debug_assert!(prefix.is_empty(), "Alignment should match");
89 debug_assert_eq!(shorts.len(), 1);
90 &mut shorts[0]
91 }
92 }
93 };
94
95 TokenStream::from(methods)
96}
97
98#[proc_macro_derive(
99 MachineAir,
100 attributes(execution_record_path, program_path, builder_path, eval_trait_bound)
101)]
102pub fn machine_air_derive(input: TokenStream) -> TokenStream {
103 let ast: syn::DeriveInput = syn::parse(input).unwrap();
104
105 let name = &ast.ident;
106 let generics = &ast.generics;
107 let execution_record_path = find_execution_record_path(&ast.attrs);
108 let program_path = find_program_path(&ast.attrs);
109 let builder_path = find_builder_path(&ast.attrs);
110 let eval_trait_bound = find_eval_trait_bound(&ast.attrs);
111 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
112
113 match &ast.data {
114 Data::Struct(_) => unimplemented!("Structs are not supported yet"),
115 Data::Enum(e) => {
116 let variants = e
117 .variants
118 .iter()
119 .map(|variant| {
120 let variant_name = &variant.ident;
121
122 let mut fields = variant.fields.iter();
123 let field = fields.next().unwrap();
124 assert!(fields.next().is_none(), "Only one field is supported");
125 (variant_name, field)
126 })
127 .collect::<Vec<_>>();
128
129 let width_arms = variants.iter().map(|(variant_name, field)| {
130 let field_ty = &field.ty;
131 quote! {
132 #name::#variant_name(x) => <#field_ty as slop_air::BaseAir<F>>::width(x)
133 }
134 });
135
136 let base_air = quote! {
137 impl #impl_generics slop_air::BaseAir<F> for #name #ty_generics #where_clause {
138 fn width(&self) -> usize {
139 match self {
140 #(#width_arms,)*
141 }
142 }
143
144 fn preprocessed_trace(&self) -> Option<slop_matrix::dense::RowMajorMatrix<F>> {
145 unreachable!("A machine air should use the preprocessed trace from the `MachineAir` trait")
146 }
147 }
148 };
149
150 let name_arms = variants.iter().map(|(variant_name, field)| {
151 let field_ty = &field.ty;
152 quote! {
153 #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::name(x)
154 }
155 });
156
157 let preprocessed_num_rows_arms = variants.iter().map(|(variant_name, field)| {
158 let field_ty = &field.ty;
159 quote! {
160 #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::preprocessed_num_rows(x, program)
161 }
162 });
163
164 let preprocessed_width_arms = variants.iter().map(|(variant_name, field)| {
165 let field_ty = &field.ty;
166 quote! {
167 #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::preprocessed_width(x)
168 }
169 });
170
171 let generate_preprocessed_trace_arms = variants.iter().map(|(variant_name, field)| {
172 let field_ty = &field.ty;
173 quote! {
174 #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::generate_preprocessed_trace(x, program)
175 }
176 });
177
178 let generate_preprocessed_trace_into_arms = variants.iter().map(|(variant_name, field)| {
179 let field_ty = &field.ty;
180 quote! {
181 #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::generate_preprocessed_trace_into(x, program, buffer)
182 }
183 });
184
185 let generate_trace_arms = variants.iter().map(|(variant_name, field)| {
186 let field_ty = &field.ty;
187 quote! {
188 #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::generate_trace(x, input, output)
189 }
190 });
191
192 let generate_trace_into_arms = variants.iter().map(|(variant_name, field)| {
193 let field_ty = &field.ty;
194 quote! {
195 #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::generate_trace_into(x, input, output, buffer)
196 }
197 });
198
199 let generate_dependencies_arms = variants.iter().map(|(variant_name, field)| {
200 let field_ty = &field.ty;
201 quote! {
202 #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::generate_dependencies(x, input, output)
203 }
204 });
205
206 let included_arms = variants.iter().map(|(variant_name, field)| {
207 let field_ty = &field.ty;
208 quote! {
209 #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::included(x, shard)
210 }
211 });
212
213 let num_rows_arms = variants.iter().map(|(variant_name, field)| {
214 let field_ty = &field.ty;
215 quote! {
216 #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::num_rows(x, input)
217 }
218 });
219
220 let machine_air = quote! {
221 impl #impl_generics sp1_hypercube::air::MachineAir<F> for #name #ty_generics #where_clause {
222 type Record = #execution_record_path;
223
224 type Program = #program_path;
225
226 fn name(&self) -> &'static str {
227 match self {
228 #(#name_arms,)*
229 }
230 }
231
232 fn preprocessed_width(&self) -> usize {
233 match self {
234 #(#preprocessed_width_arms,)*
235 }
236 }
237
238 fn preprocessed_num_rows(&self, program: &#program_path,) -> Option<usize> {
239 match self {
240 #(#preprocessed_num_rows_arms,)*
241 }
242 }
243
244 fn generate_preprocessed_trace(
245 &self,
246 program: &#program_path,
247 ) -> Option<slop_matrix::dense::RowMajorMatrix<F>> {
248 match self {
249 #(#generate_preprocessed_trace_arms,)*
250 }
251 }
252
253 fn generate_preprocessed_trace_into(
254 &self,
255 program: &#program_path,
256 buffer: &mut [MaybeUninit<F>],
257 ) {
258 match self {
259 #(#generate_preprocessed_trace_into_arms,)*
260 }
261 }
262
263 fn generate_trace(
264 &self,
265 input: &#execution_record_path,
266 output: &mut #execution_record_path,
267 ) -> slop_matrix::dense::RowMajorMatrix<F> {
268 match self {
269 #(#generate_trace_arms,)*
270 }
271 }
272
273 fn generate_trace_into(
274 &self,
275 input: &#execution_record_path,
276 output: &mut #execution_record_path,
277 buffer: &mut [MaybeUninit<F>],
278 ){
279 match self {
280 #(#generate_trace_into_arms,)*
281 }
282 }
283
284 fn generate_dependencies(
285 &self,
286 input: &#execution_record_path,
287 output: &mut #execution_record_path,
288 ) {
289 match self {
290 #(#generate_dependencies_arms,)*
291 }
292 }
293
294 fn included(&self, shard: &Self::Record) -> bool {
295 match self {
296 #(#included_arms,)*
297 }
298 }
299
300 fn num_rows(&self, input: &Self::Record) -> Option<usize> {
301 match self {
302 #(#num_rows_arms,)*
303 }
304 }
305 }
306 };
307
308 let eval_arms = variants.iter().map(|(variant_name, field)| {
309 let field_ty = &field.ty;
310 quote! {
311 #name::#variant_name(x) => <#field_ty as slop_air::Air<AB>>::eval(x, builder)
312 }
313 });
314
315 let generics = &ast.generics;
317 let mut new_generics = generics.clone();
318 new_generics
319 .params
320 .push(syn::parse_quote! { AB: slop_air::PairBuilder + #builder_path });
321
322 let (air_impl_generics, _, _) = new_generics.split_for_impl();
323
324 let mut new_generics = generics.clone();
325 let where_clause = new_generics.make_where_clause();
326 if let Some(eval_trait_bound) = eval_trait_bound {
327 let predicate: WherePredicate = syn::parse_str(&eval_trait_bound).unwrap();
328 where_clause.predicates.push(predicate);
329 }
330
331 let air = quote! {
332 impl #air_impl_generics slop_air::Air<AB> for #name #ty_generics #where_clause {
333 fn eval(&self, builder: &mut AB) {
334 match self {
335 #(#eval_arms,)*
336 }
337 }
338 }
339 };
340
341 quote! {
342 #base_air
343
344 #machine_air
345
346 #air
347 }
348 .into()
349 }
350 Data::Union(_) => unimplemented!("Unions are not supported"),
351 }
352}
353
354#[proc_macro_attribute]
355pub fn cycle_tracker(_attr: TokenStream, item: TokenStream) -> TokenStream {
356 let input = parse_macro_input!(item as ItemFn);
357 let visibility = &input.vis;
358 let name = &input.sig.ident;
359 let inputs = &input.sig.inputs;
360 let output = &input.sig.output;
361 let block = &input.block;
362 let generics = &input.sig.generics;
363 let where_clause = &input.sig.generics.where_clause;
364
365 let result = quote! {
366 #visibility fn #name #generics (#inputs) #output #where_clause {
367 eprintln!("cycle-tracker-start: {}", stringify!(#name));
368 let result = (|| #block)();
369 eprintln!("cycle-tracker-end: {}", stringify!(#name));
370 result
371 }
372 };
373
374 result.into()
375}
376
377#[proc_macro_attribute]
378pub fn cycle_tracker_recursion(_attr: TokenStream, item: TokenStream) -> TokenStream {
379 let input = parse_macro_input!(item as ItemFn);
380 let visibility = &input.vis;
381 let name = &input.sig.ident;
382 let inputs = &input.sig.inputs;
383 let output = &input.sig.output;
384 let block = &input.block;
385 let generics = &input.sig.generics;
386 let where_clause = &input.sig.generics.where_clause;
387
388 let result = quote! {
389 #visibility fn #name #generics (#inputs) #output #where_clause {
390 sp1_recursion_compiler::circuit::CircuitV2Builder::cycle_tracker_v2_enter(builder, stringify!(#name));
391 let result = #block;
392 sp1_recursion_compiler::circuit::CircuitV2Builder::cycle_tracker_v2_exit(builder);
393 result
394 }
395 };
396
397 result.into()
398}
399
400fn find_execution_record_path(attrs: &[syn::Attribute]) -> syn::Path {
401 for attr in attrs {
402 if attr.path.is_ident("execution_record_path") {
403 if let Ok(syn::Meta::NameValue(meta)) = attr.parse_meta() {
404 if let syn::Lit::Str(lit_str) = &meta.lit {
405 if let Ok(path) = lit_str.parse::<syn::Path>() {
406 return path;
407 }
408 }
409 }
410 }
411 }
412 parse_quote!(sp1_core_executor::ExecutionRecord)
413}
414
415fn find_program_path(attrs: &[syn::Attribute]) -> syn::Path {
416 for attr in attrs {
417 if attr.path.is_ident("program_path") {
418 if let Ok(syn::Meta::NameValue(meta)) = attr.parse_meta() {
419 if let syn::Lit::Str(lit_str) = &meta.lit {
420 if let Ok(path) = lit_str.parse::<syn::Path>() {
421 return path;
422 }
423 }
424 }
425 }
426 }
427 parse_quote!(sp1_core_executor::Program)
428}
429
430fn find_builder_path(attrs: &[syn::Attribute]) -> syn::Path {
431 for attr in attrs {
432 if attr.path.is_ident("builder_path") {
433 if let Ok(syn::Meta::NameValue(meta)) = attr.parse_meta() {
434 if let syn::Lit::Str(lit_str) = &meta.lit {
435 if let Ok(path) = lit_str.parse::<syn::Path>() {
436 return path;
437 }
438 }
439 }
440 }
441 }
442 parse_quote!(crate::air::SP1CoreAirBuilder<F = F>)
443}
444
445fn find_eval_trait_bound(attrs: &[syn::Attribute]) -> Option<String> {
446 for attr in attrs {
447 if attr.path.is_ident("eval_trait_bound") {
448 if let Ok(syn::Meta::NameValue(meta)) = attr.parse_meta() {
449 if let syn::Lit::Str(lit_str) = &meta.lit {
450 return Some(lit_str.value());
451 }
452 }
453 }
454 }
455
456 None
457}
458
459#[proc_macro_derive(IntoShape)]
460pub fn into_shape_derive(input: TokenStream) -> TokenStream {
461 into_shape::into_shape_derive(input)
462}
463
464#[proc_macro_derive(InputExpr)]
465pub fn input_expr_derive(input: TokenStream) -> TokenStream {
466 input_expr::input_expr_derive(input)
467}
468
469#[proc_macro_derive(InputParams, attributes(picus))]
470pub fn input_params_derive(input: TokenStream) -> TokenStream {
471 input_params::input_params_derive(input)
472}
473
474#[proc_macro_derive(SP1OperationBuilder)]
475pub fn sp1_operation_builder_derive(input: TokenStream) -> TokenStream {
476 sp1_operation_builder::sp1_operation_builder_derive(input)
477}