1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
use proc_macro::TokenStream;
use proc_macro2::{Ident, Span};
use quote::{quote, ToTokens};

use syn::parse::Parser;

use rand::distributions::Alphanumeric;
use rand::{thread_rng, Rng};

// syn::AttributeArgs does not implement syn::Parse
type AttributeArgs = syn::punctuated::Punctuated<syn::Meta, syn::Token![,]>;

fn parse_methods(args: TokenStream) -> syn::Result<Vec<String>> {
    let mut idents = vec![];
    if let Ok(args) = AttributeArgs::parse_terminated.parse2(args.into()) {
        for arg in args {
            let ident = match arg {
                syn::Meta::NameValue(namevalue) => namevalue
                    .path
                    .get_ident()
                    .ok_or_else(|| {
                        syn::Error::new_spanned(&namevalue, "Must have specified ident")
                    })?
                    .to_string()
                    .to_uppercase(),
                syn::Meta::Path(path) => path
                    .get_ident()
                    .ok_or_else(|| syn::Error::new_spanned(&path, "Must have specified ident"))?
                    .to_string()
                    .to_uppercase(),
                other => {
                    return Err(syn::Error::new_spanned(
                        other,
                        "Unknown attribute inside the macro",
                    ))
                }
            };
            match ident.as_str() {
                "GET" | "HEAD" | "POST" | "PUT" | "DELETE" | "OPTIONS" | "TRACE" | "PATCH" => {}
                name => {
                    let msg = format!(
                            "Unknown method attribute {} is specified; expected one of: `GET`, `HEAD`, `POST`, `PUT`, `DELETE`, `OPTIONS`, `TRACE`, `PATCH`",
                            name,
                        );
                    return Err(syn::Error::new_spanned(name, msg));
                }
            }
            idents.push(ident);
        }
    }
    Ok(idents)
}

#[proc_macro_attribute]
pub fn request_handler(args: TokenStream, item: TokenStream) -> TokenStream {
    let ast: syn::ItemFn = syn::parse(item).unwrap();
    let func_ident = ast.sig.ident.clone();

    let extern_mod_name = format!("webhook_flows_macros_{}", rand_ident());
    let extern_mod_ident = Ident::new(&extern_mod_name, Span::call_site());
    let request_fn_name = format!("__request_{}", rand_ident());
    let request_fn_ident = Ident::new(&request_fn_name, Span::call_site());

    let gen = quote! {
        mod #extern_mod_ident {
            extern "C" {
                pub fn get_event_headers_length() -> i32;
                pub fn get_event_headers(p: *mut u8) -> i32;
                pub fn get_event_query_length() -> i32;
                pub fn get_event_query(p: *mut u8) -> i32;
                pub fn get_event_subpath_length() -> i32;
                pub fn get_event_subpath(p: *mut u8) -> i32;
                pub fn get_event_body_length() -> i32;
                pub fn get_event_body(p: *mut u8) -> i32;
            }
        }

        fn #request_fn_ident() -> Option<(Vec<(String, String)>, String, HashMap<String, Value>, Vec<u8>)> {
            unsafe {
                let l = #extern_mod_ident::get_event_headers_length();
                let mut event_headers = Vec::<u8>::with_capacity(l as usize);
                let c = #extern_mod_ident::get_event_headers(event_headers.as_mut_ptr());
                assert!(c == l);
                event_headers.set_len(c as usize);
                let event_headers = serde_json::from_slice(&event_headers).unwrap();

                let l = #extern_mod_ident::get_event_query_length();
                let mut event_query = Vec::<u8>::with_capacity(l as usize);
                let c = #extern_mod_ident::get_event_query(event_query.as_mut_ptr());
                assert!(c == l);
                event_query.set_len(c as usize);
                let event_query = serde_json::from_slice(&event_query).unwrap();

                let l = #extern_mod_ident::get_event_subpath_length();
                let mut event_subpath = Vec::<u8>::with_capacity(l as usize);
                let c = #extern_mod_ident::get_event_subpath(event_subpath.as_mut_ptr());
                assert!(c == l);
                event_subpath.set_len(c as usize);
                let event_subpath = String::from_utf8_lossy(&event_subpath).into_owned();

                let l = #extern_mod_ident::get_event_body_length();
                let mut event_body = Vec::<u8>::with_capacity(l as usize);
                let c = #extern_mod_ident::get_event_body(event_body.as_mut_ptr());
                assert!(c == l);
                event_body.set_len(c as usize);

                Some((event_headers, event_subpath, event_query, event_body))
            }
        }
    };

    let methods = parse_methods(args).unwrap();

    let gen = match methods.len() > 0 {
        true => {
            let mut q = quote! {};
            for m in methods.iter() {
                let fn_name = format!("__webhook__on_request_received_{}", m);
                let fn_ident = Ident::new(&fn_name, Span::call_site());
                q = quote! {
                    #q
                    #[no_mangle]
                    #[tokio::main(flavor = "current_thread")]
                    pub async fn #fn_ident() {
                        if let Some((headers, subpath, qry, body)) = #request_fn_ident() {
                            #func_ident(headers, subpath, qry, body).await;
                        }
                    }
                };
            }
            quote! {
                #gen
                #q
            }
        }
        false => {
            quote! {
                #gen

                #[no_mangle]
                #[tokio::main(flavor = "current_thread")]
                pub async fn __webhook__on_request_received() {
                    if let Some((headers, subpath, qry, body)) = #request_fn_ident() {
                        #func_ident(headers, subpath, qry, body).await;
                    }
                }
            }
        }
    };

    let ori_run_str = ast.to_token_stream().to_string();
    let x = gen.to_string() + &ori_run_str;
    x.parse().unwrap()
}

/// Number of random characters in the suffix produced by [`rand_ident`].
///
/// Generated names only need to be unique within the module that invokes
/// `#[request_handler]`. The previous 3 lowercase alphanumerics (36^3 = 46 656 values) made
/// a clash between two invocations in the same module, and hence a spurious
/// "defined multiple times" build error, a real possibility. 8 characters (36^8 ≈ 2.8e12
/// values) make it negligible.
const RAND_IDENT_LEN: usize = 8;

/// Returns a random, lowercase alphanumeric suffix of [`RAND_IDENT_LEN`] characters.
///
/// The suffix is used to build unique names for the items generated by `request_handler`.
/// It is lowercased, presumably to keep those module/function names snake_case.
fn rand_ident() -> String {
    thread_rng()
        .sample_iter(&Alphanumeric)
        .take(RAND_IDENT_LEN)
        .map(char::from)
        .collect::<String>()
        .to_lowercase()
}