// unipotato_macros/lib.rs

use proc_macro::TokenStream;
use quote::quote;
use syn::ext::IdentExt;
use syn::{parse_macro_input, ItemFn, LitStr};

5fn http_method_impl(method: &str, args: TokenStream, input: TokenStream) -> TokenStream {
6    let path = parse_macro_input!(args as LitStr);
7    let path_value = path.value();
8    let func = parse_macro_input!(input as ItemFn);
9    
10    let func_name = &func.sig.ident;
11    let func_vis = &func.vis;
12    let func_block = &func.block;
13    let func_inputs = &func.sig.inputs;
14    let func_output = &func.sig.output;
15    let func_asyncness = &func.sig.asyncness;
16
17    let method_upper = method.to_uppercase();
18    let method_ident = syn::Ident::new(&method_upper, proc_macro2::Span::call_site());
19    
20    let route_info_name = syn::Ident::new(
21        &format!("__ROUTE_INFO_{}", func_name.to_string().to_uppercase()),
22        proc_macro2::Span::call_site()
23    );
24
25    // Convert path with params like "/users/<id>" to regex pattern "/users/[^/]+"
26    let pattern = convert_path_to_pattern(&path_value);
27
28    let expanded = quote! {
29        #func_vis #func_asyncness fn #func_name(#func_inputs) #func_output {
30            #func_block
31        }
32        
33        // Store route metadata for lazy registration
34        #[doc(hidden)]
35        #[allow(non_camel_case_types)]
36        pub struct #route_info_name;
37
38        #[cfg(not(test))]
39        impl #route_info_name {
40            pub fn register() {
41                let handler: ::unipotato::route::Handler = ::std::sync::Arc::new(|req: ::unipotato::Request| {
42                    ::std::boxed::Box::pin(async move {
43                        #func_name(req).await
44                    })
45                });
46                ::unipotato::route::register_route(
47                    ::unipotato::hyper::Method::#method_ident,
48                    #path_value.to_string(),
49                    #pattern.to_string(),
50                    handler
51                );
52            }
53        }
54        
55        #[cfg(test)]
56        impl #route_info_name {
57            pub fn register() {
58                // No-op in test mode
59            }
60        }
61    };
62
63    TokenStream::from(expanded)
64}
65
/// Converts a route path such as `"/users/<id>/posts/<post_id>"` into an anchored regex
/// pattern such as `"^/users/[^/]+/posts/[^/]+$"`.
///
/// * Each `<name>` placeholder becomes `[^/]+` (one or more non-slash characters), and the
///   placeholder name itself is discarded.
/// * Regex metacharacters in the literal parts of the path are backslash-escaped so they
///   match literally.
/// * A `<` with no closing `>` consumes the rest of the path and still yields one `[^/]+`.
/// * The result is always wrapped in `^` … `$`, so an empty path yields `"^$"`.
fn convert_path_to_pattern(path: &str) -> String {
    // Characters with special meaning in a regex that must be escaped in literal text.
    const REGEX_META: &[char] = &[
        '.', '*', '+', '?', '(', ')', '[', ']', '{', '}', '|', '^', '$', '\\',
    ];

    // Output is at least the path plus the two anchors; placeholders may grow it further.
    let mut pattern = String::with_capacity(path.len() + 2);
    pattern.push('^');

    let mut chars = path.chars();
    while let Some(c) = chars.next() {
        if c == '<' {
            // Skip the parameter name up to and including the closing '>'. `by_ref` lets
            // the inner loop advance the same iterator the outer loop is reading from.
            for inner in chars.by_ref() {
                if inner == '>' {
                    break;
                }
            }
            // Replace <param> with a pattern that matches any non-slash characters.
            pattern.push_str("[^/]+");
        } else {
            if REGEX_META.contains(&c) {
                pattern.push('\\');
            }
            pattern.push(c);
        }
    }

    pattern.push('$');
    pattern
}

96#[proc_macro_attribute]
97pub fn get(args: TokenStream, input: TokenStream) -> TokenStream {
98    http_method_impl("get", args, input)
99}
100
101#[proc_macro_attribute]
102pub fn post(args: TokenStream, input: TokenStream) -> TokenStream {
103    http_method_impl("post", args, input)
104}
105
106#[proc_macro_attribute]
107pub fn put(args: TokenStream, input: TokenStream) -> TokenStream {
108    http_method_impl("put", args, input)
109}
110
111#[proc_macro_attribute]
112pub fn delete(args: TokenStream, input: TokenStream) -> TokenStream {
113    http_method_impl("delete", args, input)
114}
115
116#[proc_macro_attribute]
117pub fn patch(args: TokenStream, input: TokenStream) -> TokenStream {
118    http_method_impl("patch", args, input)
119}