webhook_flows_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::{Ident, Span};
3use quote::{quote, ToTokens};
4
5use syn::parse::Parser;
6
7use rand::distributions::Alphanumeric;
8use rand::{thread_rng, Rng};
9
10// syn::AttributeArgs does not implement syn::Parse
11type AttributeArgs = syn::punctuated::Punctuated<syn::Meta, syn::Token![,]>;
12
13fn parse_methods(args: TokenStream) -> syn::Result<Vec<String>> {
14    let mut idents = vec![];
15    if let Ok(args) = AttributeArgs::parse_terminated.parse2(args.into()) {
16        for arg in args {
17            let ident = match arg {
18                syn::Meta::NameValue(namevalue) => namevalue
19                    .path
20                    .get_ident()
21                    .ok_or_else(|| {
22                        syn::Error::new_spanned(&namevalue, "Must have specified ident")
23                    })?
24                    .to_string()
25                    .to_uppercase(),
26                syn::Meta::Path(path) => path
27                    .get_ident()
28                    .ok_or_else(|| syn::Error::new_spanned(&path, "Must have specified ident"))?
29                    .to_string()
30                    .to_uppercase(),
31                other => {
32                    return Err(syn::Error::new_spanned(
33                        other,
34                        "Unknown attribute inside the macro",
35                    ))
36                }
37            };
38            match ident.as_str() {
39                "GET" | "HEAD" | "POST" | "PUT" | "DELETE" | "OPTIONS" | "TRACE" | "PATCH" => {}
40                name => {
41                    let msg = format!(
42                            "Unknown method attribute {} is specified; expected one of: `GET`, `HEAD`, `POST`, `PUT`, `DELETE`, `OPTIONS`, `TRACE`, `PATCH`",
43                            name,
44                        );
45                    return Err(syn::Error::new_spanned(name, msg));
46                }
47            }
48            idents.push(ident);
49        }
50    }
51    Ok(idents)
52}
53
54#[proc_macro_attribute]
55pub fn request_handler(args: TokenStream, item: TokenStream) -> TokenStream {
56    let ast: syn::ItemFn = syn::parse(item).unwrap();
57    let func_ident = ast.sig.ident.clone();
58
59    let gen = match ast.sig.inputs.len() {
60        0 => {
61            quote! {
62                #[no_mangle]
63                #[tokio::main(flavor = "current_thread")]
64                pub async fn __webhook__on_request_received() {
65                    #func_ident().await;
66                }
67            }
68        }
69        4 => {
70            let extern_mod_name = format!("webhook_flows_macros_{}", rand_ident());
71            let extern_mod_ident = Ident::new(&extern_mod_name, Span::call_site());
72            let request_fn_name = format!("__request_{}", rand_ident());
73            let request_fn_ident = Ident::new(&request_fn_name, Span::call_site());
74
75            let gen = quote! {
76                mod #extern_mod_ident {
77                    extern "C" {
78                        pub fn get_event_headers_length() -> i32;
79                        pub fn get_event_headers(p: *mut u8) -> i32;
80                        pub fn get_event_query_length() -> i32;
81                        pub fn get_event_query(p: *mut u8) -> i32;
82                        pub fn get_event_subpath_length() -> i32;
83                        pub fn get_event_subpath(p: *mut u8) -> i32;
84                        pub fn get_event_body_length() -> i32;
85                        pub fn get_event_body(p: *mut u8) -> i32;
86                    }
87                }
88
89                fn #request_fn_ident() -> Option<(Vec<(String, String)>, String, HashMap<String, Value>, Vec<u8>)> {
90                    unsafe {
91                        let l = #extern_mod_ident::get_event_headers_length();
92                        let mut event_headers = Vec::<u8>::with_capacity(l as usize);
93                        let c = #extern_mod_ident::get_event_headers(event_headers.as_mut_ptr());
94                        assert!(c == l);
95                        event_headers.set_len(c as usize);
96                        let event_headers = serde_json::from_slice(&event_headers).unwrap();
97
98                        let l = #extern_mod_ident::get_event_query_length();
99                        let mut event_query = Vec::<u8>::with_capacity(l as usize);
100                        let c = #extern_mod_ident::get_event_query(event_query.as_mut_ptr());
101                        assert!(c == l);
102                        event_query.set_len(c as usize);
103                        let event_query = serde_json::from_slice(&event_query).unwrap();
104
105                        let l = #extern_mod_ident::get_event_subpath_length();
106                        let mut event_subpath = Vec::<u8>::with_capacity(l as usize);
107                        let c = #extern_mod_ident::get_event_subpath(event_subpath.as_mut_ptr());
108                        assert!(c == l);
109                        event_subpath.set_len(c as usize);
110                        let event_subpath = String::from_utf8_lossy(&event_subpath).into_owned();
111
112                        let l = #extern_mod_ident::get_event_body_length();
113                        let mut event_body = Vec::<u8>::with_capacity(l as usize);
114                        let c = #extern_mod_ident::get_event_body(event_body.as_mut_ptr());
115                        assert!(c == l);
116                        event_body.set_len(c as usize);
117
118                        Some((event_headers, event_subpath, event_query, event_body))
119                    }
120                }
121            };
122
123            let methods = parse_methods(args).unwrap();
124
125            match methods.len() > 0 {
126                true => {
127                    let mut q = quote! {};
128                    for m in methods.iter() {
129                        let fn_name = format!("__webhook__on_request_received_{}", m);
130                        let fn_ident = Ident::new(&fn_name, Span::call_site());
131                        q = quote! {
132                            #q
133                            #[no_mangle]
134                            #[tokio::main(flavor = "current_thread")]
135                            pub async fn #fn_ident() {
136                                if let Some((headers, subpath, qry, body)) = #request_fn_ident() {
137                                    #func_ident(headers, subpath, qry, body).await;
138                                }
139                            }
140                        };
141                    }
142                    quote! {
143                        #gen
144                        #q
145                    }
146                }
147                false => {
148                    quote! {
149                        #gen
150
151                        #[no_mangle]
152                        #[tokio::main(flavor = "current_thread")]
153                        pub async fn __webhook__on_request_received() {
154                            if let Some((headers, subpath, qry, body)) = #request_fn_ident() {
155                                #func_ident(headers, subpath, qry, body).await;
156                            }
157                        }
158                    }
159                }
160            }
161        }
162        _ => {
163            panic!("Not compatible fn");
164        }
165    };
166
167    let ori_run_str = ast.to_token_stream().to_string();
168    let x = gen.to_string() + &ori_run_str;
169    x.parse().unwrap()
170}
171
172fn rand_ident() -> String {
173    thread_rng()
174        .sample_iter(&Alphanumeric)
175        .take(3)
176        .map(char::from)
177        .collect::<String>()
178        .to_lowercase()
179}