switchboard_starknet_macros/
lib.rs

1extern crate proc_macro;
2
3mod params;
4mod utils;
5
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{ FnArg, ItemFn, Result as SynResult, ReturnType, Type };
9
10#[proc_macro_attribute]
11pub fn switchboard_function(attr: TokenStream, item: TokenStream) -> TokenStream {
12    // Parse the macro parameters to set a timeout
13    let macro_params = match syn::parse::<params::SwitchboardStarknetFunctionArgs>(attr.clone()) {
14        Ok(args) => args,
15        Err(err) => {
16            let e = syn::Error::new_spanned(
17                err.to_compile_error(),
18                format!("Failed to parse macro parameters: {:?}", err)
19            );
20
21            return e.to_compile_error().into();
22        }
23    };
24
25    // Try to build the token stream, return errors if failed
26    match build_token_stream(macro_params, item) {
27        Ok(token_stream) => token_stream,
28        Err(err) => err.to_compile_error().into(),
29    }
30}
31
32/// Validates whether the first param is StarknetFunctionRunner
33fn validate_function_runner_param(input: &ItemFn) -> SynResult<()> {
34    // Extract the first parameter of the function.
35    let first_param_type = input.sig.inputs
36        .iter()
37        .next()
38        .ok_or_else(|| {
39            syn::Error::new_spanned(
40                &input.sig,
41                "The switchboard_function must take at least one parameter"
42            )
43        })?;
44
45    let typed_arg = match first_param_type {
46        FnArg::Typed(typed) => { typed }
47        _ => {
48            return Err(syn::Error::new_spanned(first_param_type, "Expected a typed parameter"));
49        }
50    };
51
52    let is_function_runner = if let Type::Path(type_path) = &*typed_arg.ty {
53        &type_path.path.segments.last().unwrap().ident == "StarknetFunctionRunner"
54    } else {
55        false
56    };
57
58    if !is_function_runner {
59        return Err(syn::Error::new_spanned(&typed_arg.ty, "Parameter must be StarknetFunctionRunner"));
60    }
61
62    Ok(())
63}
64
65/// Helper function to validate the return type is a Result with the correct arguments.
66fn validate_function_return_type(input: &ItemFn) -> SynResult<()> {
67    let ty = match &input.sig.output {
68        ReturnType::Type(_, ty) => ty,
69        ReturnType::Default => {
70            return Err(
71                syn::Error::new_spanned(&input.sig.output, "Function does not have a return type")
72            );
73        }
74    };
75
76    let (ok_type, err_type) = utils
77        ::extract_result_args(ty)
78        .ok_or_else(|| {
79            syn::Error::new_spanned(&input.sig.output, "Return type must be a Result")
80        })?;
81
82    // Validate the inner Vec type
83    let inner_vec_type = utils
84        ::extract_inner_type_from_vec(ok_type)
85        .ok_or_else(|| {
86            syn::Error::new_spanned(
87                &input.sig.output,
88                "Ok variant of Result must be a Vec<Call>"
89            )
90        })?;
91
92    if !matches!(inner_vec_type, Type::Path(t) if t.path.is_ident("Call")) {
93        return Err(
94            syn::Error::new_spanned(
95                &input.sig.output,
96                "Ok variant of Result must be a Vec<Call>"
97            )
98        );
99    }
100
101    // Validate the error type
102    let error_type_path_segments = match err_type {
103        Type::Path(type_path) => &type_path.path.segments,
104        _ => {
105            return Err(syn::Error::new_spanned(err_type, "Error type must be a path type"));
106        }
107    };
108
109    // Check if the error type is SbFunctionError or switchboard_common::SbFunctionError
110    let is_sb_function_error = match error_type_path_segments.last() {
111        Some(last_segment) if last_segment.ident == "SbFunctionError" => true,
112        Some(last_segment) if last_segment.ident == "Error" => {
113            // If the last segment is "Error", check the preceding segment for "switchboard_common"
114            error_type_path_segments.len() > 1 &&
115                error_type_path_segments[error_type_path_segments.len() - 2].ident ==
116                    "switchboard_common"
117        }
118        _ => false,
119    };
120
121    if !is_sb_function_error {
122        return Err(
123            syn::Error::new_spanned(
124                &input.sig.output,
125                "The error variant in the Result return type should be SbFunctionError"
126            )
127        );
128    }
129
130    Ok(())
131}
132
133fn validate_second_parameter(input: &ItemFn) -> SynResult<()> {
134    let second_param = input.sig.inputs
135        .iter()
136        .nth(1)
137        .ok_or_else(|| {
138            syn::Error::new_spanned(&input.sig, "The switchboard_function must take two parameters")
139        })?;
140
141    let typed_arg = match second_param {
142        FnArg::Typed(typed) => typed,
143        _ => {
144            return Err(syn::Error::new_spanned(second_param, "Expected a typed second parameter"));
145        }
146    };
147
148    // Use the utility function to extract the inner type from a Vec
149    let inner_type = utils
150        ::extract_inner_type_from_vec(&typed_arg.ty)
151        .ok_or_else(||
152            syn::Error::new_spanned(&typed_arg.ty, "The second parameter must be of type Vec<FieldElement>")
153        )?;
154
155    // Ensure the inner type of the Vec is u8
156    if let Type::Path(type_path) = inner_type {
157        if !type_path.path.is_ident("FieldElement") {
158            return Err(
159                syn::Error::new_spanned(
160                    &typed_arg.ty,
161                    "The second parameter must be of type Vec<FieldElement>"
162                )
163            );
164        }
165    } else {
166        return Err(
167            syn::Error::new_spanned(&typed_arg.ty, "The second parameter must be of type Vec<FieldElement>")
168        );
169    }
170
171    Ok(())
172}
173
174fn build_token_stream(
175    _params: params::SwitchboardStarknetFunctionArgs,
176    item: TokenStream
177) -> SynResult<TokenStream> {
178    let input: ItemFn = syn::parse(item.clone())?;
179    let function_name = &input.sig.ident;
180
181    // Validate that there's exactly one input of the correct type
182    if input.sig.inputs.len() != 2 {
183        return Err(
184            syn::Error::new_spanned(
185                &input.sig,
186                "The switchboard_function must take exactly one parameter of type 'Arc<StarknetFunctionRunner>' and 'Vec<FieldElement>'"
187            )
188        );
189    }
190
191    validate_function_return_type(&input)?;
192
193    // Validate input parameters
194    validate_function_runner_param(&input)?;
195    validate_second_parameter(&input)?;
196
197    let expanded =
198        quote! {
199
200            // Include the original function definition
201            #input
202
203            pub type SwitchboardFunctionResult<T> = std::result::Result<T, SbFunctionError>;
204
205            /// Run an async function and catch any panics
206            pub async fn run_switchboard_function<F, T>(
207                logic: F,
208            ) -> SwitchboardFunctionResult<()>
209            where
210                F: Fn(StarknetFunctionRunner, Vec<FieldElement>) -> T + Send + 'static,
211                T: futures::Future<Output = SwitchboardFunctionResult<Vec<Call>>>
212                    + Send,
213            {
214                // Initialize the runner
215                let mut runner = StarknetFunctionRunner::new().map_err(|_e| SbFunctionError::FunctionResultEmitError)?;
216                match logic(runner.clone(), runner.params.clone()).await {
217                    Ok(calls) => {
218                        runner
219                            .emit(calls)
220                            .map_err(|_e| SbFunctionError::FunctionResultEmitError)?;
221                        Ok(())
222                    }
223                    Err(e) => {
224                        println!("Error: Switchboard function failed with error code: {:?}", e);
225                        let mut err_code = 199;
226                        if let SbFunctionError::FunctionError(code) = e {
227                            err_code = code;
228                        }
229                        runner
230                            .emit_error(err_code)
231                            .map_err(|_e| SbFunctionError::FunctionResultEmitError)?;
232                        Ok(())
233                    }
234                }
235            }
236
237            #[tokio::main(worker_threads = 12)]
238            async fn main() -> SwitchboardFunctionResult<()> {
239                run_switchboard_function(#function_name).await?;
240                Ok(())
241            }
242    };
243
244    Ok(TokenStream::from(expanded))
245}
246
247#[proc_macro_attribute]
248pub fn sb_error(_attr: TokenStream, item: TokenStream) -> TokenStream {
249    let input = syn::parse_macro_input!(item as syn::DeriveInput);
250
251    let name = &input.ident;
252    let expanded = quote! {
253        #[derive(Clone, Copy, Debug, PartialEq)]
254        #[repr(u8)]
255        #input
256
257        impl From<#name> for SbFunctionError {
258            fn from(item: #name) -> Self {
259                SbFunctionError::FunctionError(item as u8 + 1)
260            }
261        }
262
263        impl From<#name> for u8 {
264            fn from(item: #name) -> Self {
265                item as u8 + 1
266            }
267        }
268
269        impl std::fmt::Display for #name {
270            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
271                write!(f, "{:?}", self)
272            }
273        }
274
275        impl std::error::Error for #name {}
276    };
277
278    TokenStream::from(expanded)
279}