serverless_fn_macro/
lib.rs1use proc_macro::TokenStream;
7use quote::quote;
8use syn::{FnArg, Pat, PatType, TypePath};
9
10#[proc_macro_attribute]
29pub fn serverless(_args: TokenStream, input: TokenStream) -> TokenStream {
30 let input_fn = syn::parse_macro_input!(input as syn::ItemFn);
31 let sig = &input_fn.sig;
32 let vis = &input_fn.vis;
33 let block = &input_fn.block;
34 let attrs = &input_fn.attrs;
35
36 let fn_name = &sig.ident;
37 let fn_name_str = fn_name.to_string();
38
39 let args: Vec<(syn::Ident, syn::Type)> = sig
41 .inputs
42 .iter()
43 .filter_map(|arg| {
44 if let FnArg::Typed(PatType { pat, ty, .. }) = arg {
45 if let Pat::Ident(pat_ident) = pat.as_ref() {
46 Some((pat_ident.ident.clone(), ty.as_ref().clone()))
47 } else {
48 None
49 }
50 } else {
51 None
52 }
53 })
54 .collect();
55
56 let arg_names: Vec<&syn::Ident> = args.iter().map(|(name, _)| name).collect();
57 let arg_types: Vec<&syn::Type> = args.iter().map(|(_, ty)| ty).collect();
58
59 let input_struct_name = syn::Ident::new(&format!("__{}_input", fn_name), fn_name.span());
61
62 let (return_type, success_type) = match &sig.output {
64 syn::ReturnType::Type(_, ty) => {
65 let return_ty = ty.as_ref().clone();
66 let success_ty = extract_ok_type(&return_ty).unwrap_or_else(|| return_ty.clone());
67 (return_ty, success_ty)
68 }
69 syn::ReturnType::Default => {
70 let unit_ty: syn::Type = syn::parse_str("()").expect("Failed to parse unit type");
71 (unit_ty.clone(), unit_ty)
72 }
73 };
74
75 let server_fn_name = syn::Ident::new(&format!("__{}_impl", fn_name), fn_name.span());
77
78 let registrar_struct_name =
80 syn::Ident::new(&format!("__{}_registrar", fn_name), fn_name.span());
81 let path_static = syn::Ident::new(
82 &format!("__{}_PATH", fn_name.to_string().to_uppercase()),
83 fn_name.span(),
84 );
85
86 let config = RegistrarConfig {
87 registrar_struct_name: ®istrar_struct_name,
88 path_static: &path_static,
89 fn_name_str: &fn_name_str,
90 server_fn_name: &server_fn_name,
91 input_struct_name: &input_struct_name,
92 arg_names: &arg_names,
93 arg_types: &arg_types,
94 _success_type: &success_type,
95 };
96
97 let registrar_impl = generate_registrar_impl(&config);
98
99 let output = quote! {
100 #(#attrs)*
102 #vis async fn #server_fn_name(#(#arg_names: #arg_types),*) -> #return_type #block
103
104 #[cfg(all(feature = "remote_call", not(feature = "local_call")))]
106 #(#attrs)*
107 #vis async fn #fn_name(#(#arg_names: #arg_types),*) -> #return_type {
108 use serverless_fn::transport::{get_default_transport, Transport};
109 use serverless_fn::serializer::{get_default_serializer, Serializer};
110 use serverless_fn::error::ServerlessError;
111 use serverless_fn::config::Config;
112 use serde::{Serialize, Deserialize};
113
114 #[allow(non_camel_case_types, missing_docs)]
115 #[derive(Serialize, Deserialize)]
116 struct #input_struct_name {
117 #(pub #arg_names: #arg_types,)*
118 }
119
120 let input = #input_struct_name {
121 #(#arg_names,)*
122 };
123
124 let config = Config::from_env();
125 let serializer = get_default_serializer();
126 let serialized_input = serializer.serialize(&input).map_err(ServerlessError::from)?;
127
128 let transport = get_default_transport(config.timeout(), config.retries());
129 let response_bytes = transport.call(#fn_name_str, serialized_input, None).await
130 .map_err(|e| ServerlessError::RemoteExecution(e.to_string()))?;
131
132 let output: #success_type = get_default_serializer()
133 .deserialize(&response_bytes)
134 .map_err(ServerlessError::from)?;
135
136 Ok(output)
137 }
138
139 #[cfg(any(feature = "local_call", not(feature = "remote_call")))]
141 #(#attrs)*
142 #vis async fn #fn_name(#(#arg_names: #arg_types),*) -> #return_type {
143 #server_fn_name(#(#arg_names),*).await
144 }
145
146 #registrar_impl
147 };
148
149 output.into()
150}
151
152fn extract_ok_type(ty: &syn::Type) -> Option<syn::Type> {
154 if let syn::Type::Path(TypePath { path, .. }) = ty
155 && let Some(segment) = path.segments.first()
156 && segment.ident == "Result"
157 && let syn::PathArguments::AngleBracketed(args) = &segment.arguments
158 && let Some(syn::GenericArgument::Type(ty)) = args.args.first()
159 {
160 return Some(ty.clone());
161 }
162 None
163}
164
165struct RegistrarConfig<'a> {
167 registrar_struct_name: &'a syn::Ident,
168 path_static: &'a syn::Ident,
169 fn_name_str: &'a str,
170 server_fn_name: &'a syn::Ident,
171 input_struct_name: &'a syn::Ident,
172 arg_names: &'a [&'a syn::Ident],
173 arg_types: &'a [&'a syn::Type],
174 _success_type: &'a syn::Type,
175}
176
177fn generate_registrar_impl(config: &RegistrarConfig<'_>) -> proc_macro2::TokenStream {
179 let RegistrarConfig {
180 registrar_struct_name,
181 path_static,
182 fn_name_str,
183 server_fn_name,
184 input_struct_name,
185 arg_names,
186 arg_types,
187 _success_type,
188 } = config;
189
190 quote! {
191 static #path_static: &str = ::std::concat!("/", #fn_name_str);
193
194 #[allow(non_camel_case_types, missing_docs)]
196 struct #registrar_struct_name;
197
198 impl ::serverless_fn::server::FunctionRegistry for #registrar_struct_name {
199 fn function_name(&self) -> &'static str {
200 #fn_name_str
201 }
202
203 fn function_path(&self) -> &'static str {
204 #path_static
205 }
206
207 fn register(&self, server: &mut ::serverless_fn::server::FunctionServer) {
208 use serde::{Deserialize, Serialize};
209 use serverless_fn::serializer::get_default_serializer;
210 use serverless_fn::error::ServerlessError;
211 use axum::extract::Json;
212 use axum::http::StatusCode;
213 use axum::body::Bytes;
214
215 #[derive(Serialize, Deserialize)]
216 #[allow(non_camel_case_types, missing_docs)]
217 struct #input_struct_name {
218 #(pub #arg_names: #arg_types,)*
219 }
220
221 server.register_http_route(#path_static, move |body: Bytes| async move {
222 let serializer = get_default_serializer();
223 let input: #input_struct_name = serializer.deserialize(
224 &body.to_vec()
225 ).map_err(|e| {
226 (StatusCode::BAD_REQUEST, e.to_string())
227 })?;
228
229 let result = #server_fn_name(#(input.#arg_names),*).await;
230
231 match result {
232 Ok(value) => {
233 let response_bytes = serializer.serialize(&value)
234 .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
235 Ok::<_, (StatusCode, String)>(axum::response::Response::builder()
236 .status(StatusCode::OK)
237 .header("content-type", "application/octet-stream")
238 .body(axum::body::Body::from(response_bytes))
239 .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?)
240 }
241 Err(e) => Err((
242 StatusCode::INTERNAL_SERVER_ERROR,
243 e.to_string()
244 )),
245 }
246 });
247 }
248 }
249
250 ::inventory::submit! {
252 &#registrar_struct_name as &'static dyn ::serverless_fn::server::FunctionRegistry
253 }
254 }
255}