switchboard_node_macros/
lib.rs

1extern crate proc_macro;
2
3mod params;
4mod utils;
5use utils::*;
6
7use proc_macro::TokenStream;
8use quote::quote;
9use syn::{Error, ItemFn, Result as SynResult, ReturnType, Type};
10
11#[proc_macro_attribute]
12pub fn routine(attr: TokenStream, item: TokenStream) -> TokenStream {
13    // Parse the macro parameters to set a timeout
14    let macro_params = match syn::parse::<params::SwitchboardRoutineArgs>(attr.clone()) {
15        Ok(args) => args,
16        Err(err) => {
17            let e = syn::Error::new_spanned(
18                err.to_compile_error(),
19                format!("Failed to parse macro parameters: {:?}", err),
20            );
21
22            return e.to_compile_error().into();
23        }
24    };
25
26    // Try to build the token stream, return errors if failed
27    match build_token_stream(macro_params, item) {
28        Ok(token_stream) => token_stream,
29        Err(err) => err.to_compile_error().into(),
30    }
31}
32
33fn verify_sb_result_return_type(ty: &Type) -> Result<(), Error> {
34    let (ok_type, err_type) = extract_result_args(ty).unwrap();
35
36    // verify ok_type
37    if !is_empty_tuple_type(ok_type) {
38        return Err(Error::new_spanned(
39            ty,
40            format!(
41                "Function must return `Result<(), SbError>`, found: {:?}",
42                quote! { # ok_type }
43            ),
44        ));
45    }
46
47    // verify err_type
48    if !is_sb_error_type(err_type) {
49        return Err(Error::new_spanned(
50            ty,
51            format!(
52                "Function must return `Result<(), SbError>`, found error type: {:?}",
53                quote! { # err_type }
54            ),
55        ));
56    }
57
58    Ok(())
59}
60
61fn build_token_stream(
62    params: params::SwitchboardRoutineArgs,
63    item: TokenStream,
64) -> SynResult<TokenStream> {
65    let input: ItemFn = syn::parse(item.clone())?;
66
67    // Decompose the input function
68    let fn_name = input.sig.ident;
69    let fn_block = input.block;
70    let fn_inputs = input.sig.inputs;
71    let fn_return_type = input.sig.output;
72
73    let interval = params.interval;
74    let skip_first_tick = params.skip_first_tick;
75
76    let expanded_token_stream = match &fn_return_type {
77        // fn_output is Result<(), SbError>
78        ReturnType::Type(_, ty) => {
79            verify_sb_result_return_type(ty)?;
80
81            // let closure_type = quote! { FnMut() -> std::future::Future<Output = #ty> };
82
83            // Everything is good, build the token stream
84            quote! {
85                async fn #fn_name(#fn_inputs)  {
86                    let async_fn: std::boxed::Box<
87                        dyn Fn() -> std::pin::Pin<std::boxed::Box<dyn std::future::Future<Output = #ty> + Send>> + Send + Sync
88                    > = std::boxed::Box::new(move || {  // use 'move' here
89                        std::boxed::Box::pin(async  { #fn_block })  // and here
90                    });
91
92                    // Create the interval
93                    let mut interval: Interval = interval(Duration::from_secs(std::cmp::max(1, #interval)));
94                    interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
95
96                    if #skip_first_tick {
97                        interval.tick().await;
98                    }
99
100                    loop {
101                        interval.tick().await;
102                        trace!("Running routine {}", stringify!(#fn_name));
103
104                        // Run custom async fn here
105                        if let Err(err) = async_fn().await {
106                            error!("Error in routine {}: {:?}", stringify!(#fn_name), err);
107                        }
108                    }
109                }
110            }
111        }
112        // fn_output is ()
113        _ => {
114            // Everything is good, build the token stream
115            quote! {
116                async fn #fn_name(#fn_inputs) {
117                    // Store async_fn logic in a closure
118                    let async_fn: std::boxed::Box<
119                        dyn Fn() -> std::pin::Pin<std::boxed::Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync
120                    > = std::boxed::Box::new(move || {
121                        std::boxed::Box::pin(async { #fn_block })
122                    });
123
124                    // Create the interval
125                    let mut interval: Interval = interval(Duration::from_secs(std::cmp::max(1, #interval)));
126                    interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
127
128                    if #skip_first_tick {
129                        interval.tick().await;
130                    }
131
132                    loop {
133                        interval.tick().await;
134                        trace!("Running routine {}", stringify!(#fn_name));
135
136                        // Run custom async fn here
137                        async_fn().await;
138                    }
139                }
140            }
141        }
142    };
143
144    Ok(TokenStream::from(expanded_token_stream))
145}