1extern crate proc_macro;
26
27use proc_macro::TokenStream;
28use quote::quote;
29use syn::{
30 parse_macro_input, parse_quote, Data, DeriveInput, GenericParam, ItemFn, WherePredicate,
31};
32
33#[proc_macro_derive(AlignedBorrow)]
34pub fn aligned_borrow_derive(input: TokenStream) -> TokenStream {
35 let ast = parse_macro_input!(input as DeriveInput);
36 let name = &ast.ident;
37
38 let type_generic = ast
40 .generics
41 .params
42 .iter()
43 .map(|param| match param {
44 GenericParam::Type(type_param) => &type_param.ident,
45 _ => panic!("Expected first generic to be a type"),
46 })
47 .next()
48 .expect("Expected at least one generic");
49
50 let non_first_generics = ast
53 .generics
54 .params
55 .iter()
56 .skip(1)
57 .filter_map(|param| match param {
58 GenericParam::Type(type_param) => Some(&type_param.ident),
59 GenericParam::Const(const_param) => Some(&const_param.ident),
60 _ => None,
61 })
62 .collect::<Vec<_>>();
63
64 let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl();
67
68 let methods = quote! {
69 impl #impl_generics core::borrow::Borrow<#name #type_generics> for [#type_generic] #where_clause {
70 fn borrow(&self) -> &#name #type_generics {
71 debug_assert_eq!(self.len(), std::mem::size_of::<#name<u8 #(, #non_first_generics)*>>());
72 let (prefix, shorts, _suffix) = unsafe { self.align_to::<#name #type_generics>() };
73 debug_assert!(prefix.is_empty(), "Alignment should match");
74 debug_assert_eq!(shorts.len(), 1);
75 &shorts[0]
76 }
77 }
78
79 impl #impl_generics core::borrow::BorrowMut<#name #type_generics> for [#type_generic] #where_clause {
80 fn borrow_mut(&mut self) -> &mut #name #type_generics {
81 debug_assert_eq!(self.len(), std::mem::size_of::<#name<u8 #(, #non_first_generics)*>>());
82 let (prefix, shorts, _suffix) = unsafe { self.align_to_mut::<#name #type_generics>() };
83 debug_assert!(prefix.is_empty(), "Alignment should match");
84 debug_assert_eq!(shorts.len(), 1);
85 &mut shorts[0]
86 }
87 }
88 };
89
90 TokenStream::from(methods)
91}
92
93#[proc_macro_derive(
94 MachineAir,
95 attributes(sp1_core_path, execution_record_path, program_path, builder_path, eval_trait_bound)
96)]
97pub fn machine_air_derive(input: TokenStream) -> TokenStream {
98 let ast: syn::DeriveInput = syn::parse(input).unwrap();
99
100 let name = &ast.ident;
101 let generics = &ast.generics;
102 let execution_record_path = find_execution_record_path(&ast.attrs);
103 let program_path = find_program_path(&ast.attrs);
104 let builder_path = find_builder_path(&ast.attrs);
105 let eval_trait_bound = find_eval_trait_bound(&ast.attrs);
106 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
107
108 match &ast.data {
109 Data::Struct(_) => unimplemented!("Structs are not supported yet"),
110 Data::Enum(e) => {
111 let variants = e
112 .variants
113 .iter()
114 .map(|variant| {
115 let variant_name = &variant.ident;
116
117 let mut fields = variant.fields.iter();
118 let field = fields.next().unwrap();
119 assert!(fields.next().is_none(), "Only one field is supported");
120 (variant_name, field)
121 })
122 .collect::<Vec<_>>();
123
124 let width_arms = variants.iter().map(|(variant_name, field)| {
125 let field_ty = &field.ty;
126 quote! {
127 #name::#variant_name(x) => <#field_ty as p3_air::BaseAir<F>>::width(x)
128 }
129 });
130
131 let base_air = quote! {
132 impl #impl_generics p3_air::BaseAir<F> for #name #ty_generics #where_clause {
133 fn width(&self) -> usize {
134 match self {
135 #(#width_arms,)*
136 }
137 }
138
139 fn preprocessed_trace(&self) -> Option<p3_matrix::dense::RowMajorMatrix<F>> {
140 unreachable!("A machine air should use the preprocessed trace from the `MachineAir` trait")
141 }
142 }
143 };
144
145 let name_arms = variants.iter().map(|(variant_name, field)| {
146 let field_ty = &field.ty;
147 quote! {
148 #name::#variant_name(x) => <#field_ty as sp1_stark::air::MachineAir<F>>::name(x)
149 }
150 });
151
152 let preprocessed_width_arms = variants.iter().map(|(variant_name, field)| {
153 let field_ty = &field.ty;
154 quote! {
155 #name::#variant_name(x) => <#field_ty as sp1_stark::air::MachineAir<F>>::preprocessed_width(x)
156 }
157 });
158
159 let generate_preprocessed_trace_arms = variants.iter().map(|(variant_name, field)| {
160 let field_ty = &field.ty;
161 quote! {
162 #name::#variant_name(x) => <#field_ty as sp1_stark::air::MachineAir<F>>::generate_preprocessed_trace(x, program)
163 }
164 });
165
166 let generate_trace_arms = variants.iter().map(|(variant_name, field)| {
167 let field_ty = &field.ty;
168 quote! {
169 #name::#variant_name(x) => <#field_ty as sp1_stark::air::MachineAir<F>>::generate_trace(x, input, output)
170 }
171 });
172
173 let generate_dependencies_arms = variants.iter().map(|(variant_name, field)| {
174 let field_ty = &field.ty;
175 quote! {
176 #name::#variant_name(x) => <#field_ty as sp1_stark::air::MachineAir<F>>::generate_dependencies(x, input, output)
177 }
178 });
179
180 let included_arms = variants.iter().map(|(variant_name, field)| {
181 let field_ty = &field.ty;
182 quote! {
183 #name::#variant_name(x) => <#field_ty as sp1_stark::air::MachineAir<F>>::included(x, shard)
184 }
185 });
186
187 let commit_scope_arms = variants.iter().map(|(variant_name, field)| {
188 let field_ty = &field.ty;
189 quote! {
190 #name::#variant_name(x) => <#field_ty as sp1_stark::air::MachineAir<F>>::commit_scope(x)
191 }
192 });
193
194 let local_only_arms = variants.iter().map(|(variant_name, field)| {
195 let field_ty = &field.ty;
196 quote! {
197 #name::#variant_name(x) => <#field_ty as sp1_stark::air::MachineAir<F>>::local_only(x)
198 }
199 });
200
201 let machine_air = quote! {
202 impl #impl_generics sp1_stark::air::MachineAir<F> for #name #ty_generics #where_clause {
203 type Record = #execution_record_path;
204
205 type Program = #program_path;
206
207 fn name(&self) -> String {
208 match self {
209 #(#name_arms,)*
210 }
211 }
212
213 fn preprocessed_width(&self) -> usize {
214 match self {
215 #(#preprocessed_width_arms,)*
216 }
217 }
218
219 fn generate_preprocessed_trace(
220 &self,
221 program: &#program_path,
222 ) -> Option<p3_matrix::dense::RowMajorMatrix<F>> {
223 match self {
224 #(#generate_preprocessed_trace_arms,)*
225 }
226 }
227
228 fn generate_trace(
229 &self,
230 input: &#execution_record_path,
231 output: &mut #execution_record_path,
232 ) -> p3_matrix::dense::RowMajorMatrix<F> {
233 match self {
234 #(#generate_trace_arms,)*
235 }
236 }
237
238 fn generate_dependencies(
239 &self,
240 input: &#execution_record_path,
241 output: &mut #execution_record_path,
242 ) {
243 match self {
244 #(#generate_dependencies_arms,)*
245 }
246 }
247
248 fn included(&self, shard: &Self::Record) -> bool {
249 match self {
250 #(#included_arms,)*
251 }
252 }
253
254 fn commit_scope(&self) -> InteractionScope {
255 match self {
256 #(#commit_scope_arms,)*
257 }
258 }
259
260 fn local_only(&self) -> bool {
261 match self {
262 #(#local_only_arms,)*
263 }
264 }
265 }
266 };
267
268 let eval_arms = variants.iter().map(|(variant_name, field)| {
269 let field_ty = &field.ty;
270 quote! {
271 #name::#variant_name(x) => <#field_ty as p3_air::Air<AB>>::eval(x, builder)
272 }
273 });
274
275 let generics = &ast.generics;
277 let mut new_generics = generics.clone();
278 new_generics.params.push(syn::parse_quote! { AB: p3_air::PairBuilder + #builder_path });
279
280 let (air_impl_generics, _, _) = new_generics.split_for_impl();
281
282 let mut new_generics = generics.clone();
283 let where_clause = new_generics.make_where_clause();
284 if eval_trait_bound.is_some() {
285 let predicate: WherePredicate = syn::parse_str(&eval_trait_bound.unwrap()).unwrap();
286 where_clause.predicates.push(predicate);
287 }
288
289 let air = quote! {
290 impl #air_impl_generics p3_air::Air<AB> for #name #ty_generics #where_clause {
291 fn eval(&self, builder: &mut AB) {
292 match self {
293 #(#eval_arms,)*
294 }
295 }
296 }
297 };
298
299 quote! {
300 #base_air
301
302 #machine_air
303
304 #air
305 }
306 .into()
307 }
308 Data::Union(_) => unimplemented!("Unions are not supported"),
309 }
310}
311
312#[proc_macro_attribute]
313pub fn cycle_tracker(_attr: TokenStream, item: TokenStream) -> TokenStream {
314 let input = parse_macro_input!(item as ItemFn);
315 let visibility = &input.vis;
316 let name = &input.sig.ident;
317 let inputs = &input.sig.inputs;
318 let output = &input.sig.output;
319 let block = &input.block;
320 let generics = &input.sig.generics;
321 let where_clause = &input.sig.generics.where_clause;
322
323 let result = quote! {
324 #visibility fn #name #generics (#inputs) #output #where_clause {
325 eprintln!("cycle-tracker-start: {}", stringify!(#name));
326 let result = (|| #block)();
327 eprintln!("cycle-tracker-end: {}", stringify!(#name));
328 result
329 }
330 };
331
332 result.into()
333}
334
335#[proc_macro_attribute]
336pub fn cycle_tracker_recursion(_attr: TokenStream, item: TokenStream) -> TokenStream {
337 let input = parse_macro_input!(item as ItemFn);
338 let visibility = &input.vis;
339 let name = &input.sig.ident;
340 let inputs = &input.sig.inputs;
341 let output = &input.sig.output;
342 let block = &input.block;
343 let generics = &input.sig.generics;
344 let where_clause = &input.sig.generics.where_clause;
345
346 let result = quote! {
347 #visibility fn #name #generics (#inputs) #output #where_clause {
348 sp1_recursion_compiler::circuit::CircuitV2Builder::cycle_tracker_v2_enter(builder, stringify!(#name));
349 let result = #block;
350 sp1_recursion_compiler::circuit::CircuitV2Builder::cycle_tracker_v2_exit(builder);
351 result
352 }
353 };
354
355 result.into()
356}
357
358fn find_execution_record_path(attrs: &[syn::Attribute]) -> syn::Path {
359 for attr in attrs {
360 if attr.path.is_ident("execution_record_path") {
361 if let Ok(syn::Meta::NameValue(meta)) = attr.parse_meta() {
362 if let syn::Lit::Str(lit_str) = &meta.lit {
363 if let Ok(path) = lit_str.parse::<syn::Path>() {
364 return path;
365 }
366 }
367 }
368 }
369 }
370 parse_quote!(sp1_core_executor::ExecutionRecord)
371}
372
373fn find_program_path(attrs: &[syn::Attribute]) -> syn::Path {
374 for attr in attrs {
375 if attr.path.is_ident("program_path") {
376 if let Ok(syn::Meta::NameValue(meta)) = attr.parse_meta() {
377 if let syn::Lit::Str(lit_str) = &meta.lit {
378 if let Ok(path) = lit_str.parse::<syn::Path>() {
379 return path;
380 }
381 }
382 }
383 }
384 }
385 parse_quote!(sp1_core_executor::Program)
386}
387
388fn find_builder_path(attrs: &[syn::Attribute]) -> syn::Path {
389 for attr in attrs {
390 if attr.path.is_ident("builder_path") {
391 if let Ok(syn::Meta::NameValue(meta)) = attr.parse_meta() {
392 if let syn::Lit::Str(lit_str) = &meta.lit {
393 if let Ok(path) = lit_str.parse::<syn::Path>() {
394 return path;
395 }
396 }
397 }
398 }
399 }
400 parse_quote!(crate::air::SP1CoreAirBuilder<F = F>)
401}
402
403fn find_eval_trait_bound(attrs: &[syn::Attribute]) -> Option<String> {
404 for attr in attrs {
405 if attr.path.is_ident("eval_trait_bound") {
406 if let Ok(syn::Meta::NameValue(meta)) = attr.parse_meta() {
407 if let syn::Lit::Str(lit_str) = &meta.lit {
408 return Some(lit_str.value());
409 }
410 }
411 }
412 }
413
414 None
415}