sov_zk_cycle_macros/
lib.rs

1#![deny(missing_docs)]
2#![doc = include_str!("../README.md")]
3extern crate proc_macro;
4
5use proc_macro::TokenStream;
6use quote::quote;
7use syn::{parse_macro_input, ItemFn};
8
9/// This macro is used to annotate functions that we want to track the number of riscV cycles being
10/// generated inside the VM. The purpose of the this macro is to measure how many cycles a rust
11/// function takes because prover time is directly proportional to the number of riscv cycles
12/// generated. It does this by making use of a risc0 provided function
13/// ```rust,ignore
14/// risc0_zkvm::guest::env::get_cycle_count
15/// ```
16/// The macro essentially generates new function with the same name by wrapping the body with a get_cycle_count
17/// at the beginning and end of the function, subtracting it and then emitting it out using the
18/// a custom syscall that is generated when the prover is run with the `bench` feature.
19/// `send_recv_slice` is used to communicate and pass a slice to the syscall that we defined.
20/// The handler for the syscall can be seen in adapters/risc0/src/host.rs and adapters/risc0/src/metrics.rs
21#[proc_macro_attribute]
22pub fn cycle_tracker(_attr: TokenStream, item: TokenStream) -> TokenStream {
23    let input = parse_macro_input!(item as ItemFn);
24
25    match wrap_function(input) {
26        Ok(ok) => ok,
27        Err(err) => err.to_compile_error().into(),
28    }
29}
30
31fn wrap_function(input: ItemFn) -> Result<TokenStream, syn::Error> {
32    let visibility = &input.vis;
33    let name = &input.sig.ident;
34    let inputs = &input.sig.inputs;
35    let output = &input.sig.output;
36    let block = &input.block;
37    let generics = &input.sig.generics;
38    let where_clause = &input.sig.generics.where_clause;
39    let risc0_zkvm = syn::Ident::new("risc0_zkvm", proc_macro2::Span::call_site());
40    let risc0_zkvm_platform =
41        syn::Ident::new("risc0_zkvm_platform", proc_macro2::Span::call_site());
42
43    let result = quote! {
44        #visibility fn #name #generics (#inputs) #output #where_clause {
45            let before = #risc0_zkvm::guest::env::get_cycle_count();
46            let result = (|| #block)();
47            let after = #risc0_zkvm::guest::env::get_cycle_count();
48
49            // simple serialization to avoid pulling in bincode or other libs
50            let tuple = (stringify!(#name).to_string(), (after - before) as u64);
51            let mut serialized = Vec::new();
52            serialized.extend(tuple.0.as_bytes());
53            serialized.push(0);
54            let size_bytes = tuple.1.to_ne_bytes();
55            serialized.extend(&size_bytes);
56
57            // calculate the syscall name.
58            let cycle_string = String::from("cycle_metrics\0");
59            let metrics_syscall_name = unsafe {
60                #risc0_zkvm_platform::syscall::SyscallName::from_bytes_with_nul(cycle_string.as_ptr())
61            };
62
63            #risc0_zkvm::guest::env::send_recv_slice::<u8,u8>(metrics_syscall_name, &serialized);
64            result
65        }
66    };
67    Ok(result.into())
68}