1extern crate proc_macro;
26
27use proc_macro::TokenStream;
28use quote::quote;
29use syn::{
30 parse_macro_input, parse_quote, Data, DeriveInput, GenericParam, ItemFn, WherePredicate,
31};
32
33mod input_expr;
34mod input_params;
35mod into_shape;
36mod sp1_operation_builder;
37
38#[proc_macro_derive(AlignedBorrow)]
39pub fn aligned_borrow_derive(input: TokenStream) -> TokenStream {
40 let ast = parse_macro_input!(input as DeriveInput);
41 let name = &ast.ident;
42
43 let type_generic = ast
45 .generics
46 .params
47 .iter()
48 .map(|param| match param {
49 GenericParam::Type(type_param) => &type_param.ident,
50 _ => panic!("Expected first generic to be a type"),
51 })
52 .next()
53 .expect("Expected at least one generic");
54
55 let non_first_generics = ast
58 .generics
59 .params
60 .iter()
61 .skip(1)
62 .filter_map(|param| match param {
63 GenericParam::Type(type_param) => Some(&type_param.ident),
64 GenericParam::Const(const_param) => Some(&const_param.ident),
65 _ => None,
66 })
67 .collect::<Vec<_>>();
68
69 let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl();
72
73 let methods = quote! {
74 impl #impl_generics core::borrow::Borrow<#name #type_generics> for [#type_generic] #where_clause {
75 fn borrow(&self) -> &#name #type_generics {
76 debug_assert_eq!(self.len(), std::mem::size_of::<#name<u8 #(, #non_first_generics)*>>());
77 let (prefix, shorts, _suffix) = unsafe { self.align_to::<#name #type_generics>() };
78 debug_assert!(prefix.is_empty(), "Alignment should match");
79 debug_assert_eq!(shorts.len(), 1);
80 &shorts[0]
81 }
82 }
83
84 impl #impl_generics core::borrow::BorrowMut<#name #type_generics> for [#type_generic] #where_clause {
85 fn borrow_mut(&mut self) -> &mut #name #type_generics {
86 debug_assert_eq!(self.len(), std::mem::size_of::<#name<u8 #(, #non_first_generics)*>>());
87 let (prefix, shorts, _suffix) = unsafe { self.align_to_mut::<#name #type_generics>() };
88 debug_assert!(prefix.is_empty(), "Alignment should match");
89 debug_assert_eq!(shorts.len(), 1);
90 &mut shorts[0]
91 }
92 }
93 };
94
95 TokenStream::from(methods)
96}
97
98#[proc_macro_derive(
99 MachineAir,
100 attributes(execution_record_path, program_path, builder_path, eval_trait_bound)
101)]
102pub fn machine_air_derive(input: TokenStream) -> TokenStream {
103 let ast: syn::DeriveInput = syn::parse(input).unwrap();
104
105 let name = &ast.ident;
106 let generics = &ast.generics;
107 let execution_record_path = find_execution_record_path(&ast.attrs);
108 let program_path = find_program_path(&ast.attrs);
109 let builder_path = find_builder_path(&ast.attrs);
110 let eval_trait_bound = find_eval_trait_bound(&ast.attrs);
111 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
112
113 match &ast.data {
114 Data::Struct(_) => unimplemented!("Structs are not supported yet"),
115 Data::Enum(e) => {
116 let variants = e
117 .variants
118 .iter()
119 .map(|variant| {
120 let variant_name = &variant.ident;
121
122 let mut fields = variant.fields.iter();
123 let field = fields.next().unwrap();
124 assert!(fields.next().is_none(), "Only one field is supported");
125 (variant_name, field)
126 })
127 .collect::<Vec<_>>();
128
129 let width_arms = variants.iter().map(|(variant_name, field)| {
130 let field_ty = &field.ty;
131 quote! {
132 #name::#variant_name(x) => <#field_ty as slop_air::BaseAir<F>>::width(x)
133 }
134 });
135
136 let base_air = quote! {
137 impl #impl_generics slop_air::BaseAir<F> for #name #ty_generics #where_clause {
138 fn width(&self) -> usize {
139 match self {
140 #(#width_arms,)*
141 }
142 }
143
144 fn preprocessed_trace(&self) -> Option<slop_matrix::dense::RowMajorMatrix<F>> {
145 unreachable!("A machine air should use the preprocessed trace from the `MachineAir` trait")
146 }
147 }
148 };
149
150 let name_arms = variants.iter().map(|(variant_name, field)| {
151 let field_ty = &field.ty;
152 quote! {
153 #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::name(x)
154 }
155 });
156
157 let column_names_arms = variants.iter().map(|(variant_name, field)| {
158 let field_ty = &field.ty;
159 quote! {
160 #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::column_names(x)
161 }
162 });
163
164 let preprocessed_num_rows_arms = variants.iter().map(|(variant_name, field)| {
165 let field_ty = &field.ty;
166 quote! {
167 #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::preprocessed_num_rows(x, program)
168 }
169 });
170
171 let preprocessed_width_arms = variants.iter().map(|(variant_name, field)| {
172 let field_ty = &field.ty;
173 quote! {
174 #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::preprocessed_width(x)
175 }
176 });
177
178 let generate_preprocessed_trace_arms = variants.iter().map(|(variant_name, field)| {
179 let field_ty = &field.ty;
180 quote! {
181 #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::generate_preprocessed_trace(x, program)
182 }
183 });
184
185 let generate_preprocessed_trace_into_arms = variants.iter().map(|(variant_name, field)| {
186 let field_ty = &field.ty;
187 quote! {
188 #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::generate_preprocessed_trace_into(x, program, buffer)
189 }
190 });
191
192 let generate_trace_arms = variants.iter().map(|(variant_name, field)| {
193 let field_ty = &field.ty;
194 quote! {
195 #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::generate_trace(x, input, output)
196 }
197 });
198
199 let generate_trace_into_arms = variants.iter().map(|(variant_name, field)| {
200 let field_ty = &field.ty;
201 quote! {
202 #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::generate_trace_into(x, input, output, buffer)
203 }
204 });
205
206 let generate_dependencies_arms = variants.iter().map(|(variant_name, field)| {
207 let field_ty = &field.ty;
208 quote! {
209 #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::generate_dependencies(x, input, output)
210 }
211 });
212
213 let included_arms = variants.iter().map(|(variant_name, field)| {
214 let field_ty = &field.ty;
215 quote! {
216 #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::included(x, shard)
217 }
218 });
219
220 let num_rows_arms = variants.iter().map(|(variant_name, field)| {
221 let field_ty = &field.ty;
222 quote! {
223 #name::#variant_name(x) => <#field_ty as sp1_hypercube::air::MachineAir<F>>::num_rows(x, input)
224 }
225 });
226
227 let machine_air = quote! {
228 impl #impl_generics sp1_hypercube::air::MachineAir<F> for #name #ty_generics #where_clause {
229 type Record = #execution_record_path;
230
231 type Program = #program_path;
232
233 fn name(&self) -> &'static str {
234 match self {
235 #(#name_arms,)*
236 }
237 }
238
239 fn column_names(&self) -> Vec<String> {
240 match self {
241 #(#column_names_arms,)*
242 }
243 }
244
245 fn preprocessed_width(&self) -> usize {
246 match self {
247 #(#preprocessed_width_arms,)*
248 }
249 }
250
251 fn preprocessed_num_rows(&self, program: &#program_path,) -> Option<usize> {
252 match self {
253 #(#preprocessed_num_rows_arms,)*
254 }
255 }
256
257 fn generate_preprocessed_trace(
258 &self,
259 program: &#program_path,
260 ) -> Option<slop_matrix::dense::RowMajorMatrix<F>> {
261 match self {
262 #(#generate_preprocessed_trace_arms,)*
263 }
264 }
265
266 fn generate_preprocessed_trace_into(
267 &self,
268 program: &#program_path,
269 buffer: &mut [MaybeUninit<F>],
270 ) {
271 match self {
272 #(#generate_preprocessed_trace_into_arms,)*
273 }
274 }
275
276 fn generate_trace(
277 &self,
278 input: &#execution_record_path,
279 output: &mut #execution_record_path,
280 ) -> slop_matrix::dense::RowMajorMatrix<F> {
281 match self {
282 #(#generate_trace_arms,)*
283 }
284 }
285
286 fn generate_trace_into(
287 &self,
288 input: &#execution_record_path,
289 output: &mut #execution_record_path,
290 buffer: &mut [MaybeUninit<F>],
291 ){
292 match self {
293 #(#generate_trace_into_arms,)*
294 }
295 }
296
297 fn generate_dependencies(
298 &self,
299 input: &#execution_record_path,
300 output: &mut #execution_record_path,
301 ) {
302 match self {
303 #(#generate_dependencies_arms,)*
304 }
305 }
306
307 fn included(&self, shard: &Self::Record) -> bool {
308 match self {
309 #(#included_arms,)*
310 }
311 }
312
313 fn num_rows(&self, input: &Self::Record) -> Option<usize> {
314 match self {
315 #(#num_rows_arms,)*
316 }
317 }
318 }
319 };
320
321 let eval_arms = variants.iter().map(|(variant_name, field)| {
322 let field_ty = &field.ty;
323 quote! {
324 #name::#variant_name(x) => <#field_ty as slop_air::Air<AB>>::eval(x, builder)
325 }
326 });
327
328 let generics = &ast.generics;
330 let mut new_generics = generics.clone();
331 new_generics
332 .params
333 .push(syn::parse_quote! { AB: slop_air::PairBuilder + #builder_path });
334
335 let (air_impl_generics, _, _) = new_generics.split_for_impl();
336
337 let mut new_generics = generics.clone();
338 let where_clause = new_generics.make_where_clause();
339 if let Some(eval_trait_bound) = eval_trait_bound {
340 let predicate: WherePredicate = syn::parse_str(&eval_trait_bound).unwrap();
341 where_clause.predicates.push(predicate);
342 }
343
344 let air = quote! {
345 impl #air_impl_generics slop_air::Air<AB> for #name #ty_generics #where_clause {
346 fn eval(&self, builder: &mut AB) {
347 match self {
348 #(#eval_arms,)*
349 }
350 }
351 }
352 };
353
354 quote! {
355 #base_air
356
357 #machine_air
358
359 #air
360 }
361 .into()
362 }
363 Data::Union(_) => unimplemented!("Unions are not supported"),
364 }
365}
366
367#[proc_macro_attribute]
368pub fn cycle_tracker(_attr: TokenStream, item: TokenStream) -> TokenStream {
369 let input = parse_macro_input!(item as ItemFn);
370 let visibility = &input.vis;
371 let name = &input.sig.ident;
372 let inputs = &input.sig.inputs;
373 let output = &input.sig.output;
374 let block = &input.block;
375 let generics = &input.sig.generics;
376 let where_clause = &input.sig.generics.where_clause;
377
378 let result = quote! {
379 #visibility fn #name #generics (#inputs) #output #where_clause {
380 eprintln!("cycle-tracker-start: {}", stringify!(#name));
381 let result = (|| #block)();
382 eprintln!("cycle-tracker-end: {}", stringify!(#name));
383 result
384 }
385 };
386
387 result.into()
388}
389
390#[proc_macro_attribute]
391pub fn cycle_tracker_recursion(_attr: TokenStream, item: TokenStream) -> TokenStream {
392 let input = parse_macro_input!(item as ItemFn);
393 let visibility = &input.vis;
394 let name = &input.sig.ident;
395 let inputs = &input.sig.inputs;
396 let output = &input.sig.output;
397 let block = &input.block;
398 let generics = &input.sig.generics;
399 let where_clause = &input.sig.generics.where_clause;
400
401 let result = quote! {
402 #visibility fn #name #generics (#inputs) #output #where_clause {
403 sp1_recursion_compiler::circuit::CircuitV2Builder::cycle_tracker_v2_enter(builder, stringify!(#name));
404 let result = #block;
405 sp1_recursion_compiler::circuit::CircuitV2Builder::cycle_tracker_v2_exit(builder);
406 result
407 }
408 };
409
410 result.into()
411}
412
413fn find_execution_record_path(attrs: &[syn::Attribute]) -> syn::Path {
414 for attr in attrs {
415 if attr.path.is_ident("execution_record_path") {
416 if let Ok(syn::Meta::NameValue(meta)) = attr.parse_meta() {
417 if let syn::Lit::Str(lit_str) = &meta.lit {
418 if let Ok(path) = lit_str.parse::<syn::Path>() {
419 return path;
420 }
421 }
422 }
423 }
424 }
425 parse_quote!(sp1_core_executor::ExecutionRecord)
426}
427
428fn find_program_path(attrs: &[syn::Attribute]) -> syn::Path {
429 for attr in attrs {
430 if attr.path.is_ident("program_path") {
431 if let Ok(syn::Meta::NameValue(meta)) = attr.parse_meta() {
432 if let syn::Lit::Str(lit_str) = &meta.lit {
433 if let Ok(path) = lit_str.parse::<syn::Path>() {
434 return path;
435 }
436 }
437 }
438 }
439 }
440 parse_quote!(sp1_core_executor::Program)
441}
442
443fn find_builder_path(attrs: &[syn::Attribute]) -> syn::Path {
444 for attr in attrs {
445 if attr.path.is_ident("builder_path") {
446 if let Ok(syn::Meta::NameValue(meta)) = attr.parse_meta() {
447 if let syn::Lit::Str(lit_str) = &meta.lit {
448 if let Ok(path) = lit_str.parse::<syn::Path>() {
449 return path;
450 }
451 }
452 }
453 }
454 }
455 parse_quote!(crate::air::SP1CoreAirBuilder<F = F>)
456}
457
458fn find_eval_trait_bound(attrs: &[syn::Attribute]) -> Option<String> {
459 for attr in attrs {
460 if attr.path.is_ident("eval_trait_bound") {
461 if let Ok(syn::Meta::NameValue(meta)) = attr.parse_meta() {
462 if let syn::Lit::Str(lit_str) = &meta.lit {
463 return Some(lit_str.value());
464 }
465 }
466 }
467 }
468
469 None
470}
471
472#[proc_macro_derive(IntoShape)]
473pub fn into_shape_derive(input: TokenStream) -> TokenStream {
474 into_shape::into_shape_derive(input)
475}
476
477#[proc_macro_derive(InputExpr)]
478pub fn input_expr_derive(input: TokenStream) -> TokenStream {
479 input_expr::input_expr_derive(input)
480}
481
482#[proc_macro_derive(InputParams, attributes(picus))]
483pub fn input_params_derive(input: TokenStream) -> TokenStream {
484 input_params::input_params_derive(input)
485}
486
487#[proc_macro_derive(SP1OperationBuilder)]
488pub fn sp1_operation_builder_derive(input: TokenStream) -> TokenStream {
489 sp1_operation_builder::sp1_operation_builder_derive(input)
490}