// turul_mcp_json_rpc_server/async.rs
use std::collections::HashMap;
use std::sync::Arc;

use async_trait::async_trait;
use serde_json::Value;

use crate::{
    error::{JsonRpcError, JsonRpcProcessingError},
    notification::JsonRpcNotification,
    request::{JsonRpcRequest, RequestParams},
    response::{JsonRpcResponse, ResponseResult},
};

14#[derive(Debug, Clone)]
17pub struct SessionContext {
18 pub session_id: String,
20 pub metadata: HashMap<String, Value>,
22 pub broadcaster: Option<Arc<dyn std::any::Any + Send + Sync>>,
24 pub timestamp: u64,
26}
27
28pub type JsonRpcResult<T> = Result<T, JsonRpcProcessingError>;
30
31#[async_trait]
33pub trait JsonRpcHandler: Send + Sync {
34 async fn handle(&self, method: &str, params: Option<RequestParams>, session_context: Option<SessionContext>) -> JsonRpcResult<Value>;
36
37 async fn handle_notification(&self, method: &str, params: Option<RequestParams>, session_context: Option<SessionContext>) -> JsonRpcResult<()> {
39 let _ = (method, params, session_context);
41 Ok(())
42 }
43
44 fn supported_methods(&self) -> Vec<String> {
46 vec![]
47 }
48}
49
/// Adapts plain closures into a JSON-RPC handler.
///
/// `F` is invoked for requests. `N` is an optional notification closure.
/// The closure-signature bounds live on the `impl` blocks rather than on the
/// struct, so the type can be named without restating the full signatures.
pub struct FunctionHandler<F, N> {
    /// Closure invoked for every request routed to this handler.
    handler_fn: F,
    /// Closure invoked for notifications; `None` means notifications are
    /// accepted and ignored.
    notification_fn: Option<N>,
    /// Method names reported by `supported_methods`.
    methods: Vec<String>,
}

61impl<F, N> FunctionHandler<F, N>
62where
63 F: Fn(&str, Option<RequestParams>, Option<SessionContext>) -> futures::future::BoxFuture<'static, JsonRpcResult<Value>> + Send + Sync,
64 N: Fn(&str, Option<RequestParams>, Option<SessionContext>) -> futures::future::BoxFuture<'static, JsonRpcResult<()>> + Send + Sync,
65{
66 pub fn new(handler_fn: F) -> Self {
67 Self {
68 handler_fn,
69 notification_fn: None,
70 methods: vec![],
71 }
72 }
73
74 pub fn with_notification_handler(mut self, notification_fn: N) -> Self {
75 self.notification_fn = Some(notification_fn);
76 self
77 }
78
79 pub fn with_methods(mut self, methods: Vec<String>) -> Self {
80 self.methods = methods;
81 self
82 }
83}
84
85#[async_trait]
86impl<F, N> JsonRpcHandler for FunctionHandler<F, N>
87where
88 F: Fn(&str, Option<RequestParams>, Option<SessionContext>) -> futures::future::BoxFuture<'static, JsonRpcResult<Value>> + Send + Sync,
89 N: Fn(&str, Option<RequestParams>, Option<SessionContext>) -> futures::future::BoxFuture<'static, JsonRpcResult<()>> + Send + Sync,
90{
91 async fn handle(&self, method: &str, params: Option<RequestParams>, session_context: Option<SessionContext>) -> JsonRpcResult<Value> {
92 (self.handler_fn)(method, params, session_context).await
93 }
94
95 async fn handle_notification(&self, method: &str, params: Option<RequestParams>, session_context: Option<SessionContext>) -> JsonRpcResult<()> {
96 if let Some(ref notification_fn) = self.notification_fn {
97 (notification_fn)(method, params, session_context).await
98 } else {
99 Ok(())
100 }
101 }
102
103 fn supported_methods(&self) -> Vec<String> {
104 self.methods.clone()
105 }
106}
107
108pub struct JsonRpcDispatcher {
110 handlers: HashMap<String, Arc<dyn JsonRpcHandler>>,
111 default_handler: Option<Arc<dyn JsonRpcHandler>>,
112}
113
114impl JsonRpcDispatcher {
115 pub fn new() -> Self {
116 Self {
117 handlers: HashMap::new(),
118 default_handler: None,
119 }
120 }
121
122 pub fn register_method<H>(&mut self, method: String, handler: H)
124 where
125 H: JsonRpcHandler + 'static,
126 {
127 self.handlers.insert(method, Arc::new(handler));
128 }
129
130 pub fn register_methods<H>(&mut self, methods: Vec<String>, handler: H)
132 where
133 H: JsonRpcHandler + 'static,
134 {
135 let handler_arc = Arc::new(handler);
136 for method in methods {
137 self.handlers.insert(method, handler_arc.clone());
138 }
139 }
140
141 pub fn set_default_handler<H>(&mut self, handler: H)
143 where
144 H: JsonRpcHandler + 'static,
145 {
146 self.default_handler = Some(Arc::new(handler));
147 }
148
149 pub async fn handle_request_with_context(&self, request: JsonRpcRequest, session_context: SessionContext) -> JsonRpcResponse {
151 let handler = self.handlers.get(&request.method)
152 .or(self.default_handler.as_ref());
153
154 match handler {
155 Some(handler) => {
156 match handler.handle(&request.method, request.params, Some(session_context)).await {
157 Ok(result) => JsonRpcResponse::new(request.id, ResponseResult::Success(result)),
158 Err(err) => {
159 let rpc_error = err.to_rpc_error(Some(request.id.clone()));
160 JsonRpcResponse::new(request.id, ResponseResult::Success(
162 serde_json::json!({
163 "error": {
164 "code": rpc_error.error.code,
165 "message": rpc_error.error.message,
166 "data": rpc_error.error.data
167 }
168 })
169 ))
170 }
171 }
172 }
173 None => {
174 let error = JsonRpcError::method_not_found(request.id.clone(), &request.method);
175 JsonRpcResponse::new(request.id, ResponseResult::Success(
176 serde_json::json!({
177 "error": {
178 "code": error.error.code,
179 "message": error.error.message
180 }
181 })
182 ))
183 }
184 }
185 }
186
187 pub async fn handle_request(&self, request: JsonRpcRequest) -> JsonRpcResponse {
189 let handler = self.handlers.get(&request.method)
190 .or(self.default_handler.as_ref());
191
192 match handler {
193 Some(handler) => {
194 match handler.handle(&request.method, request.params, None).await {
195 Ok(result) => JsonRpcResponse::new(request.id, ResponseResult::Success(result)),
196 Err(err) => {
197 let rpc_error = err.to_rpc_error(Some(request.id.clone()));
198 JsonRpcResponse::new(request.id, ResponseResult::Success(
200 serde_json::json!({
201 "error": {
202 "code": rpc_error.error.code,
203 "message": rpc_error.error.message,
204 "data": rpc_error.error.data
205 }
206 })
207 ))
208 }
209 }
210 }
211 None => {
212 let error = JsonRpcError::method_not_found(request.id.clone(), &request.method);
213 JsonRpcResponse::new(request.id, ResponseResult::Success(
214 serde_json::json!({
215 "error": {
216 "code": error.error.code,
217 "message": error.error.message
218 }
219 })
220 ))
221 }
222 }
223 }
224
225 pub async fn handle_notification(&self, notification: JsonRpcNotification) -> JsonRpcResult<()> {
227 let handler = self.handlers.get(¬ification.method)
228 .or(self.default_handler.as_ref());
229
230 match handler {
231 Some(handler) => {
232 handler.handle_notification(¬ification.method, notification.params, None).await
233 }
234 None => {
235 Ok(())
237 }
238 }
239 }
240
241 pub async fn handle_notification_with_context(&self, notification: JsonRpcNotification, session_context: Option<SessionContext>) -> JsonRpcResult<()> {
243 let handler = self.handlers.get(¬ification.method)
244 .or(self.default_handler.as_ref());
245
246 match handler {
247 Some(handler) => {
248 handler.handle_notification(¬ification.method, notification.params, session_context).await
249 }
250 None => {
251 Ok(())
253 }
254 }
255 }
256
257 pub fn registered_methods(&self) -> Vec<String> {
259 self.handlers.keys().cloned().collect()
260 }
261}
262
263impl Default for JsonRpcDispatcher {
264 fn default() -> Self {
265 Self::new()
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{JsonRpcRequest, RequestId};
    use serde_json::json;

    /// Minimal handler: "add" succeeds, every other method returns a handler error.
    struct TestHandler;

    #[async_trait]
    impl JsonRpcHandler for TestHandler {
        async fn handle(&self, method: &str, _params: Option<RequestParams>, _session_context: Option<SessionContext>) -> JsonRpcResult<Value> {
            match method {
                "add" => Ok(json!({"result": "addition"})),
                "error" => Err(JsonRpcProcessingError::HandlerError("test error".to_string())),
                _ => Err(JsonRpcProcessingError::HandlerError("unknown method".to_string())),
            }
        }

        fn supported_methods(&self) -> Vec<String> {
            vec!["add".to_string(), "error".to_string()]
        }
    }

    /// Notification-closure type for `FunctionHandler`s that install none;
    /// `N` cannot be inferred from `FunctionHandler::new` alone.
    type NoNotification = fn(&str, Option<RequestParams>, Option<SessionContext>) -> futures::future::BoxFuture<'static, JsonRpcResult<()>>;

    #[tokio::test]
    async fn test_dispatcher_success() {
        let mut dispatcher = JsonRpcDispatcher::new();
        dispatcher.register_method("add".to_string(), TestHandler);

        let request = JsonRpcRequest::new_no_params(RequestId::Number(1), "add".to_string());

        let response = dispatcher.handle_request(request).await;
        assert_eq!(response.id, RequestId::Number(1));
    }

    #[tokio::test]
    async fn test_dispatcher_method_not_found() {
        let dispatcher = JsonRpcDispatcher::new();

        let request = JsonRpcRequest::new_no_params(RequestId::Number(1), "unknown".to_string());

        let response = dispatcher.handle_request(request).await;
        assert_eq!(response.id, RequestId::Number(1));
    }

    #[tokio::test]
    async fn test_dispatcher_default_handler() {
        let mut dispatcher = JsonRpcDispatcher::new();
        dispatcher.set_default_handler(TestHandler);

        // "add" is not explicitly registered, so the default handler serves it.
        let request = JsonRpcRequest::new_no_params(RequestId::Number(2), "add".to_string());

        let response = dispatcher.handle_request(request).await;
        assert_eq!(response.id, RequestId::Number(2));
    }

    #[test]
    fn test_registered_methods() {
        let mut dispatcher = JsonRpcDispatcher::new();
        dispatcher.register_methods(vec!["b".to_string(), "a".to_string()], TestHandler);

        // Sort locally so the assertion does not depend on HashMap ordering.
        let mut methods = dispatcher.registered_methods();
        methods.sort();
        assert_eq!(methods, vec!["a".to_string(), "b".to_string()]);
    }

    #[tokio::test]
    async fn test_direct_handler_call() {
        let handler = TestHandler;

        let result = handler.handle("add", None, None).await.unwrap();
        assert_eq!(result["result"], "addition");

        assert!(handler.handle("error", None, None).await.is_err());

        // The trait's default notification implementation succeeds.
        assert!(handler.handle_notification("anything", None, None).await.is_ok());
    }

    #[tokio::test]
    async fn test_function_handler() {
        let handler = FunctionHandler::<_, NoNotification>::new(
            |method: &str, _params: Option<RequestParams>, _ctx: Option<SessionContext>| {
                let method = method.to_string();
                let fut: futures::future::BoxFuture<'static, JsonRpcResult<Value>> =
                    Box::pin(async move { Ok(json!({ "echo": method })) });
                fut
            },
        )
        .with_methods(vec!["echo".to_string()]);

        assert_eq!(handler.supported_methods(), vec!["echo".to_string()]);

        let result = handler.handle("echo", None, None).await.unwrap();
        assert_eq!(result["echo"], "echo");

        // Without a notification closure, notifications are accepted and ignored.
        assert!(handler.handle_notification("echo", None, None).await.is_ok());
    }
}