switchboard_starknet_macros/
lib.rs1extern crate proc_macro;
2
3mod params;
4mod utils;
5
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{ FnArg, ItemFn, Result as SynResult, ReturnType, Type };
9
10#[proc_macro_attribute]
11pub fn switchboard_function(attr: TokenStream, item: TokenStream) -> TokenStream {
12 let macro_params = match syn::parse::<params::SwitchboardStarknetFunctionArgs>(attr.clone()) {
14 Ok(args) => args,
15 Err(err) => {
16 let e = syn::Error::new_spanned(
17 err.to_compile_error(),
18 format!("Failed to parse macro parameters: {:?}", err)
19 );
20
21 return e.to_compile_error().into();
22 }
23 };
24
25 match build_token_stream(macro_params, item) {
27 Ok(token_stream) => token_stream,
28 Err(err) => err.to_compile_error().into(),
29 }
30}
31
32fn validate_function_runner_param(input: &ItemFn) -> SynResult<()> {
34 let first_param_type = input.sig.inputs
36 .iter()
37 .next()
38 .ok_or_else(|| {
39 syn::Error::new_spanned(
40 &input.sig,
41 "The switchboard_function must take at least one parameter"
42 )
43 })?;
44
45 let typed_arg = match first_param_type {
46 FnArg::Typed(typed) => { typed }
47 _ => {
48 return Err(syn::Error::new_spanned(first_param_type, "Expected a typed parameter"));
49 }
50 };
51
52 let is_function_runner = if let Type::Path(type_path) = &*typed_arg.ty {
53 &type_path.path.segments.last().unwrap().ident == "StarknetFunctionRunner"
54 } else {
55 false
56 };
57
58 if !is_function_runner {
59 return Err(syn::Error::new_spanned(&typed_arg.ty, "Parameter must be StarknetFunctionRunner"));
60 }
61
62 Ok(())
63}
64
65fn validate_function_return_type(input: &ItemFn) -> SynResult<()> {
67 let ty = match &input.sig.output {
68 ReturnType::Type(_, ty) => ty,
69 ReturnType::Default => {
70 return Err(
71 syn::Error::new_spanned(&input.sig.output, "Function does not have a return type")
72 );
73 }
74 };
75
76 let (ok_type, err_type) = utils
77 ::extract_result_args(ty)
78 .ok_or_else(|| {
79 syn::Error::new_spanned(&input.sig.output, "Return type must be a Result")
80 })?;
81
82 let inner_vec_type = utils
84 ::extract_inner_type_from_vec(ok_type)
85 .ok_or_else(|| {
86 syn::Error::new_spanned(
87 &input.sig.output,
88 "Ok variant of Result must be a Vec<Call>"
89 )
90 })?;
91
92 if !matches!(inner_vec_type, Type::Path(t) if t.path.is_ident("Call")) {
93 return Err(
94 syn::Error::new_spanned(
95 &input.sig.output,
96 "Ok variant of Result must be a Vec<Call>"
97 )
98 );
99 }
100
101 let error_type_path_segments = match err_type {
103 Type::Path(type_path) => &type_path.path.segments,
104 _ => {
105 return Err(syn::Error::new_spanned(err_type, "Error type must be a path type"));
106 }
107 };
108
109 let is_sb_function_error = match error_type_path_segments.last() {
111 Some(last_segment) if last_segment.ident == "SbFunctionError" => true,
112 Some(last_segment) if last_segment.ident == "Error" => {
113 error_type_path_segments.len() > 1 &&
115 error_type_path_segments[error_type_path_segments.len() - 2].ident ==
116 "switchboard_common"
117 }
118 _ => false,
119 };
120
121 if !is_sb_function_error {
122 return Err(
123 syn::Error::new_spanned(
124 &input.sig.output,
125 "The error variant in the Result return type should be SbFunctionError"
126 )
127 );
128 }
129
130 Ok(())
131}
132
133fn validate_second_parameter(input: &ItemFn) -> SynResult<()> {
134 let second_param = input.sig.inputs
135 .iter()
136 .nth(1)
137 .ok_or_else(|| {
138 syn::Error::new_spanned(&input.sig, "The switchboard_function must take two parameters")
139 })?;
140
141 let typed_arg = match second_param {
142 FnArg::Typed(typed) => typed,
143 _ => {
144 return Err(syn::Error::new_spanned(second_param, "Expected a typed second parameter"));
145 }
146 };
147
148 let inner_type = utils
150 ::extract_inner_type_from_vec(&typed_arg.ty)
151 .ok_or_else(||
152 syn::Error::new_spanned(&typed_arg.ty, "The second parameter must be of type Vec<FieldElement>")
153 )?;
154
155 if let Type::Path(type_path) = inner_type {
157 if !type_path.path.is_ident("FieldElement") {
158 return Err(
159 syn::Error::new_spanned(
160 &typed_arg.ty,
161 "The second parameter must be of type Vec<FieldElement>"
162 )
163 );
164 }
165 } else {
166 return Err(
167 syn::Error::new_spanned(&typed_arg.ty, "The second parameter must be of type Vec<FieldElement>")
168 );
169 }
170
171 Ok(())
172}
173
174fn build_token_stream(
175 _params: params::SwitchboardStarknetFunctionArgs,
176 item: TokenStream
177) -> SynResult<TokenStream> {
178 let input: ItemFn = syn::parse(item.clone())?;
179 let function_name = &input.sig.ident;
180
181 if input.sig.inputs.len() != 2 {
183 return Err(
184 syn::Error::new_spanned(
185 &input.sig,
186 "The switchboard_function must take exactly one parameter of type 'Arc<StarknetFunctionRunner>' and 'Vec<FieldElement>'"
187 )
188 );
189 }
190
191 validate_function_return_type(&input)?;
192
193 validate_function_runner_param(&input)?;
195 validate_second_parameter(&input)?;
196
197 let expanded =
198 quote! {
199
200 #input
202
203 pub type SwitchboardFunctionResult<T> = std::result::Result<T, SbFunctionError>;
204
205 pub async fn run_switchboard_function<F, T>(
207 logic: F,
208 ) -> SwitchboardFunctionResult<()>
209 where
210 F: Fn(StarknetFunctionRunner, Vec<FieldElement>) -> T + Send + 'static,
211 T: futures::Future<Output = SwitchboardFunctionResult<Vec<Call>>>
212 + Send,
213 {
214 let mut runner = StarknetFunctionRunner::new().map_err(|_e| SbFunctionError::FunctionResultEmitError)?;
216 match logic(runner.clone(), runner.params.clone()).await {
217 Ok(calls) => {
218 runner
219 .emit(calls)
220 .map_err(|_e| SbFunctionError::FunctionResultEmitError)?;
221 Ok(())
222 }
223 Err(e) => {
224 println!("Error: Switchboard function failed with error code: {:?}", e);
225 let mut err_code = 199;
226 if let SbFunctionError::FunctionError(code) = e {
227 err_code = code;
228 }
229 runner
230 .emit_error(err_code)
231 .map_err(|_e| SbFunctionError::FunctionResultEmitError)?;
232 Ok(())
233 }
234 }
235 }
236
237 #[tokio::main(worker_threads = 12)]
238 async fn main() -> SwitchboardFunctionResult<()> {
239 run_switchboard_function(#function_name).await?;
240 Ok(())
241 }
242 };
243
244 Ok(TokenStream::from(expanded))
245}
246
247#[proc_macro_attribute]
248pub fn sb_error(_attr: TokenStream, item: TokenStream) -> TokenStream {
249 let input = syn::parse_macro_input!(item as syn::DeriveInput);
250
251 let name = &input.ident;
252 let expanded = quote! {
253 #[derive(Clone, Copy, Debug, PartialEq)]
254 #[repr(u8)]
255 #input
256
257 impl From<#name> for SbFunctionError {
258 fn from(item: #name) -> Self {
259 SbFunctionError::FunctionError(item as u8 + 1)
260 }
261 }
262
263 impl From<#name> for u8 {
264 fn from(item: #name) -> Self {
265 item as u8 + 1
266 }
267 }
268
269 impl std::fmt::Display for #name {
270 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
271 write!(f, "{:?}", self)
272 }
273 }
274
275 impl std::error::Error for #name {}
276 };
277
278 TokenStream::from(expanded)
279}