Skip to main content

server_less_rpc/
lib.rs

1//! Shared utilities for RPC-style macros (MCP, WebSocket, JSON-RPC).
2//!
3//! These macros use JSON-RPC-like dispatch:
4//! - Receive `{"method": "name", "params": {...}}`
5//! - Extract params from JSON
6//! - Call the method
7//! - Serialize result back to JSON
8
9use proc_macro2::TokenStream;
10use quote::quote;
11use server_less_parse::{MethodInfo, ParamInfo};
12
13/// Generate code to extract a parameter from a `serde_json::Value` args object.
14pub fn generate_param_extraction(param: &ParamInfo) -> TokenStream {
15    let name = &param.name;
16    let name_str = param.name.to_string();
17    let ty = &param.ty;
18
19    if param.is_optional {
20        // For Option<T>, extract inner value, return None if missing/null
21        quote! {
22            let #name: #ty = args.get(#name_str)
23                .and_then(|v| if v.is_null() { None } else {
24                    ::server_less::serde_json::from_value(v.clone()).ok()
25                });
26        }
27    } else {
28        // Required parameter - error if missing
29        quote! {
30            let __val = args.get(#name_str)
31                .ok_or_else(|| format!("Missing required parameter: {}", #name_str))?
32                .clone();
33            let #name: #ty = ::server_less::serde_json::from_value::<#ty>(__val)
34                .map_err(|e| format!("Invalid parameter {}: {}", #name_str, e))?;
35        }
36    }
37}
38
39/// Generate all param extractions for a method.
40pub fn generate_all_param_extractions(method: &MethodInfo) -> Vec<TokenStream> {
41    method
42        .params
43        .iter()
44        .map(generate_param_extraction)
45        .collect()
46}
47
48/// Generate param extractions for specific parameters only.
49///
50/// This allows filtering out framework-injected params (like Context)
51/// that shouldn't be extracted from JSON.
52pub fn generate_param_extractions_for(params: &[&ParamInfo]) -> Vec<TokenStream> {
53    params
54        .iter()
55        .map(|p| generate_param_extraction(p))
56        .collect()
57}
58
59/// Generate the method call expression.
60///
61/// Returns tokens for calling `self.method_name(arg1, arg2, ...)`.
62/// For async methods, returns an error (caller should handle async context).
63pub fn generate_method_call(method: &MethodInfo, handle_async: AsyncHandling) -> TokenStream {
64    let method_name = &method.name;
65    let arg_names: Vec<_> = method.params.iter().map(|p| &p.name).collect();
66
67    match (method.is_async, handle_async) {
68        (true, AsyncHandling::Error) => {
69            quote! {
70                return Err("Async methods not supported in sync context".to_string());
71            }
72        }
73        (true, AsyncHandling::Await) => {
74            quote! {
75                let result = self.#method_name(#(#arg_names),*).await;
76            }
77        }
78        (true, AsyncHandling::BlockOn) => {
79            quote! {
80                let result = ::tokio::runtime::Runtime::new()
81                    .expect("Failed to create Tokio runtime")
82                    .block_on(self.#method_name(#(#arg_names),*));
83            }
84        }
85        (false, _) => {
86            quote! {
87                let result = self.#method_name(#(#arg_names),*);
88            }
89        }
90    }
91}
92
93/// Generate method call with custom argument expressions.
94///
95/// This allows mixing framework-injected args (like `__ctx`) with
96/// params extracted from JSON.
97pub fn generate_method_call_with_args(
98    method: &MethodInfo,
99    arg_exprs: Vec<TokenStream>,
100    handle_async: AsyncHandling,
101) -> TokenStream {
102    let method_name = &method.name;
103
104    match (method.is_async, handle_async) {
105        (true, AsyncHandling::Error) => {
106            quote! {
107                return Err("Async methods not supported in sync context".to_string());
108            }
109        }
110        (true, AsyncHandling::Await) => {
111            quote! {
112                let result = self.#method_name(#(#arg_exprs),*).await;
113            }
114        }
115        (true, AsyncHandling::BlockOn) => {
116            quote! {
117                let result = ::tokio::runtime::Runtime::new()
118                    .expect("Failed to create Tokio runtime")
119                    .block_on(self.#method_name(#(#arg_exprs),*));
120            }
121        }
122        (false, _) => {
123            quote! {
124                let result = self.#method_name(#(#arg_exprs),*);
125            }
126        }
127    }
128}
129
/// How generated code should invoke `async` methods.
///
/// In the `generate_method_call*` helpers, synchronous methods are called directly
/// regardless of this setting.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AsyncHandling {
    /// Don't call the method; the generated code returns an `Err` instead.
    Error,
    /// `.await` the method; the generated code must live in an async context.
    Await,
    /// Drive the future on a freshly created `tokio::runtime::Runtime` via `block_on`.
    ///
    /// Intended for synchronous callers only: Tokio's `block_on` panics when called
    /// from within an async execution context.
    BlockOn,
}
140
141/// Generate response handling that converts the method result to JSON.
142///
143/// Handles:
144/// - `()` → `{"success": true}`
145/// - `Result<T, E>` → `Ok(T)` or `Err(message)`
146/// - `Option<T>` → `T` or `null`
147/// - `T` → serialized T
148pub fn generate_json_response(method: &MethodInfo) -> TokenStream {
149    let ret = &method.return_info;
150
151    if ret.is_unit {
152        quote! {
153            Ok(::server_less::serde_json::json!({"success": true}))
154        }
155    } else if ret.is_stream {
156        // Automatically collect streams into Vec for JSON serialization
157        quote! {
158            {
159                use ::server_less::futures::StreamExt;
160                let collected: Vec<_> = result.collect().await;
161                Ok(::server_less::serde_json::to_value(collected)
162                    .map_err(|e| format!("Serialization error: {}", e))?)
163            }
164        }
165    } else if ret.is_result {
166        quote! {
167            match result {
168                Ok(value) => Ok(::server_less::serde_json::to_value(value)
169                    .map_err(|e| format!("Serialization error: {}", e))?),
170                Err(err) => Err(format!("{:?}", err)),
171            }
172        }
173    } else if ret.is_option {
174        quote! {
175            match result {
176                Some(value) => Ok(::server_less::serde_json::to_value(value)
177                    .map_err(|e| format!("Serialization error: {}", e))?),
178                None => Ok(::server_less::serde_json::Value::Null),
179            }
180        }
181    } else {
182        // Plain T
183        quote! {
184            Ok(::server_less::serde_json::to_value(result)
185                .map_err(|e| format!("Serialization error: {}", e))?)
186        }
187    }
188}
189
190/// Generate a complete dispatch match arm for an RPC method.
191///
192/// Combines param extraction, method call, and response handling.
193pub fn generate_dispatch_arm(
194    method: &MethodInfo,
195    method_name_override: Option<&str>,
196    async_handling: AsyncHandling,
197) -> TokenStream {
198    let method_name_str = method_name_override
199        .map(String::from)
200        .unwrap_or_else(|| method.name.to_string());
201
202    // Methods that are async OR return streams require async context
203    let requires_async = method.is_async || method.return_info.is_stream;
204
205    // For methods requiring async with Error handling, return early
206    if requires_async && matches!(async_handling, AsyncHandling::Error) {
207        return quote! {
208            #method_name_str => {
209                return Err("Async methods and streaming methods not supported in sync context".to_string());
210            }
211        };
212    }
213
214    let param_extractions = generate_all_param_extractions(method);
215    let call = generate_method_call(method, async_handling);
216    let response = generate_json_response(method);
217
218    quote! {
219        #method_name_str => {
220            #(#param_extractions)*
221            #call
222            #response
223        }
224    }
225}
226
227/// Infer JSON schema type from Rust type.
228pub fn infer_json_type(ty: &syn::Type) -> &'static str {
229    let ty_str = quote!(#ty).to_string();
230
231    if ty_str.contains("String") || ty_str.contains("str") {
232        "string"
233    } else if ty_str.contains("i8")
234        || ty_str.contains("i16")
235        || ty_str.contains("i32")
236        || ty_str.contains("i64")
237        || ty_str.contains("u8")
238        || ty_str.contains("u16")
239        || ty_str.contains("u32")
240        || ty_str.contains("u64")
241        || ty_str.contains("isize")
242        || ty_str.contains("usize")
243    {
244        "integer"
245    } else if ty_str.contains("f32") || ty_str.contains("f64") {
246        "number"
247    } else if ty_str.contains("bool") {
248        "boolean"
249    } else if ty_str.contains("Vec") || ty_str.contains("[]") {
250        "array"
251    } else {
252        "object"
253    }
254}
255
256/// Generate JSON schema properties for method parameters.
257pub fn generate_param_schema(params: &[ParamInfo]) -> (Vec<TokenStream>, Vec<String>) {
258    let properties: Vec<_> = params
259        .iter()
260        .map(|p| {
261            let param_name = p.name.to_string();
262            let param_type = infer_json_type(&p.ty);
263            let description = format!("Parameter: {}", param_name);
264
265            quote! {
266                (#param_name, #param_type, #description)
267            }
268        })
269        .collect();
270
271    let required: Vec<_> = params
272        .iter()
273        .filter(|p| !p.is_optional)
274        .map(|p| p.name.to_string())
275        .collect();
276
277    (properties, required)
278}
279
280/// Generate JSON schema properties for specific parameters (e.g., excluding Context).
281pub fn generate_param_schema_for(params: &[&ParamInfo]) -> (Vec<TokenStream>, Vec<String>) {
282    let properties: Vec<_> = params
283        .iter()
284        .map(|p| {
285            let param_name = p.name.to_string();
286            let param_type = infer_json_type(&p.ty);
287            let description = format!("Parameter: {}", param_name);
288
289            quote! {
290                (#param_name, #param_type, #description)
291            }
292        })
293        .collect();
294
295    let required: Vec<_> = params
296        .iter()
297        .filter(|p| !p.is_optional)
298        .map(|p| p.name.to_string())
299        .collect();
300
301    (properties, required)
302}