1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde_json::Value;
6
7#[cfg(feature = "streams")]
8use futures::{Stream, StreamExt};
9#[cfg(feature = "streams")]
10use std::pin::Pin;
11
12use crate::{
13 error::JsonRpcError,
14 notification::JsonRpcNotification,
15 request::{JsonRpcRequest, RequestParams},
16 response::{JsonRpcMessage, ResponseResult},
17};
18
19#[derive(Debug, Clone)]
22pub struct SessionContext {
23 pub session_id: String,
25 pub metadata: HashMap<String, Value>,
27 pub broadcaster: Option<Arc<dyn std::any::Any + Send + Sync>>,
29 pub timestamp: u64,
31}
32
33#[async_trait]
35pub trait JsonRpcHandler: Send + Sync {
36 type Error: std::error::Error + Send + Sync + 'static;
38
39 async fn handle(
42 &self,
43 method: &str,
44 params: Option<RequestParams>,
45 session_context: Option<SessionContext>,
46 ) -> Result<Value, Self::Error>;
47
48 async fn handle_notification(
50 &self,
51 method: &str,
52 params: Option<RequestParams>,
53 session_context: Option<SessionContext>,
54 ) -> Result<(), Self::Error> {
55 let _ = (method, params, session_context);
57 Ok(())
58 }
59
60 fn supported_methods(&self) -> Vec<String> {
62 vec![]
63 }
64}
65
66pub struct FunctionHandler<F, N, E>
68where
69 E: std::error::Error + Send + Sync + 'static,
70 F: Fn(
71 &str,
72 Option<RequestParams>,
73 Option<SessionContext>,
74 ) -> futures::future::BoxFuture<'static, Result<Value, E>>
75 + Send
76 + Sync,
77 N: Fn(
78 &str,
79 Option<RequestParams>,
80 Option<SessionContext>,
81 ) -> futures::future::BoxFuture<'static, Result<(), E>>
82 + Send
83 + Sync,
84{
85 handler_fn: F,
86 notification_fn: Option<N>,
87 methods: Vec<String>,
88}
89
90impl<F, N, E> FunctionHandler<F, N, E>
91where
92 E: std::error::Error + Send + Sync + 'static,
93 F: Fn(
94 &str,
95 Option<RequestParams>,
96 Option<SessionContext>,
97 ) -> futures::future::BoxFuture<'static, Result<Value, E>>
98 + Send
99 + Sync,
100 N: Fn(
101 &str,
102 Option<RequestParams>,
103 Option<SessionContext>,
104 ) -> futures::future::BoxFuture<'static, Result<(), E>>
105 + Send
106 + Sync,
107{
108 pub fn new(handler_fn: F) -> Self {
109 Self {
110 handler_fn,
111 notification_fn: None,
112 methods: vec![],
113 }
114 }
115
116 pub fn with_notification_handler(mut self, notification_fn: N) -> Self {
117 self.notification_fn = Some(notification_fn);
118 self
119 }
120
121 pub fn with_methods(mut self, methods: Vec<String>) -> Self {
122 self.methods = methods;
123 self
124 }
125}
126
127#[async_trait]
128impl<F, N, E> JsonRpcHandler for FunctionHandler<F, N, E>
129where
130 E: std::error::Error + Send + Sync + 'static,
131 F: Fn(
132 &str,
133 Option<RequestParams>,
134 Option<SessionContext>,
135 ) -> futures::future::BoxFuture<'static, Result<Value, E>>
136 + Send
137 + Sync,
138 N: Fn(
139 &str,
140 Option<RequestParams>,
141 Option<SessionContext>,
142 ) -> futures::future::BoxFuture<'static, Result<(), E>>
143 + Send
144 + Sync,
145{
146 type Error = E;
147
148 async fn handle(
149 &self,
150 method: &str,
151 params: Option<RequestParams>,
152 session_context: Option<SessionContext>,
153 ) -> Result<Value, Self::Error> {
154 (self.handler_fn)(method, params, session_context).await
155 }
156
157 async fn handle_notification(
158 &self,
159 method: &str,
160 params: Option<RequestParams>,
161 session_context: Option<SessionContext>,
162 ) -> Result<(), Self::Error> {
163 if let Some(ref notification_fn) = self.notification_fn {
164 (notification_fn)(method, params, session_context).await
165 } else {
166 Ok(())
167 }
168 }
169
170 fn supported_methods(&self) -> Vec<String> {
171 self.methods.clone()
172 }
173}
174
175pub trait ToJsonRpcError: std::error::Error + Send + Sync + 'static {
177 fn to_error_object(&self) -> crate::error::JsonRpcErrorObject;
179}
180
181pub struct JsonRpcDispatcher<E>
183where
184 E: ToJsonRpcError,
185{
186 pub handlers: HashMap<String, Arc<dyn JsonRpcHandler<Error = E>>>,
187 pub default_handler: Option<Arc<dyn JsonRpcHandler<Error = E>>>,
188}
189
190impl<E> JsonRpcDispatcher<E>
191where
192 E: ToJsonRpcError,
193{
194 pub fn new() -> Self {
195 Self {
196 handlers: HashMap::new(),
197 default_handler: None,
198 }
199 }
200
201 pub fn register_method<H>(&mut self, method: String, handler: H)
203 where
204 H: JsonRpcHandler<Error = E> + 'static,
205 {
206 self.handlers.insert(method, Arc::new(handler));
207 }
208
209 pub fn register_methods<H>(&mut self, methods: Vec<String>, handler: H)
211 where
212 H: JsonRpcHandler<Error = E> + 'static,
213 {
214 let handler_arc = Arc::new(handler);
215 for method in methods {
216 self.handlers.insert(method, handler_arc.clone());
217 }
218 }
219
220 pub fn set_default_handler<H>(&mut self, handler: H)
222 where
223 H: JsonRpcHandler<Error = E> + 'static,
224 {
225 self.default_handler = Some(Arc::new(handler));
226 }
227
228 pub async fn handle_request_with_context(
230 &self,
231 request: JsonRpcRequest,
232 session_context: SessionContext,
233 ) -> JsonRpcMessage {
234 let handler = self
235 .handlers
236 .get(&request.method)
237 .or(self.default_handler.as_ref());
238
239 match handler {
240 Some(handler) => {
241 match handler
242 .handle(&request.method, request.params, Some(session_context))
243 .await
244 {
245 Ok(result) => {
246 JsonRpcMessage::success(request.id, ResponseResult::Success(result))
247 }
248 Err(domain_error) => {
249 let error_object = domain_error.to_error_object();
251 let rpc_error = JsonRpcError::new(Some(request.id.clone()), error_object);
252 JsonRpcMessage::error(rpc_error)
253 }
254 }
255 }
256 None => {
257 let error = JsonRpcError::method_not_found(request.id.clone(), &request.method);
258 JsonRpcMessage::error(error)
259 }
260 }
261 }
262
263 pub async fn handle_request(&self, request: JsonRpcRequest) -> JsonRpcMessage {
265 let handler = self
266 .handlers
267 .get(&request.method)
268 .or(self.default_handler.as_ref());
269
270 match handler {
271 Some(handler) => {
272 match handler.handle(&request.method, request.params, None).await {
273 Ok(result) => {
274 JsonRpcMessage::success(request.id, ResponseResult::Success(result))
275 }
276 Err(domain_error) => {
277 let error_object = domain_error.to_error_object();
279 let rpc_error = JsonRpcError::new(Some(request.id.clone()), error_object);
280 JsonRpcMessage::error(rpc_error)
281 }
282 }
283 }
284 None => {
285 let error = JsonRpcError::method_not_found(request.id.clone(), &request.method);
286 JsonRpcMessage::error(error)
287 }
288 }
289 }
290
291 pub async fn handle_notification(&self, notification: JsonRpcNotification) -> Result<(), E> {
293 let handler = self
294 .handlers
295 .get(¬ification.method)
296 .or(self.default_handler.as_ref());
297
298 match handler {
299 Some(handler) => {
300 handler
301 .handle_notification(¬ification.method, notification.params, None)
302 .await
303 }
304 None => {
305 Ok(())
307 }
308 }
309 }
310
311 pub async fn handle_notification_with_context(
313 &self,
314 notification: JsonRpcNotification,
315 session_context: Option<SessionContext>,
316 ) -> Result<(), E> {
317 let handler = self
318 .handlers
319 .get(¬ification.method)
320 .or(self.default_handler.as_ref());
321
322 match handler {
323 Some(handler) => {
324 handler
325 .handle_notification(¬ification.method, notification.params, session_context)
326 .await
327 }
328 None => {
329 Ok(())
331 }
332 }
333 }
334
335 pub fn registered_methods(&self) -> Vec<String> {
337 self.handlers.keys().cloned().collect()
338 }
339}
340
341impl<E> Default for JsonRpcDispatcher<E>
342where
343 E: ToJsonRpcError,
344{
345 fn default() -> Self {
346 Self::new()
347 }
348}
349
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{JsonRpcRequest, RequestId};
    use serde_json::json;

    #[derive(thiserror::Error, Debug)]
    enum TestError {
        #[error("Test error: {0}")]
        TestError(String),
        #[error("Unknown method: {0}")]
        UnknownMethod(String),
    }

    impl ToJsonRpcError for TestError {
        fn to_error_object(&self) -> crate::error::JsonRpcErrorObject {
            use crate::error::JsonRpcErrorObject;
            match self {
                TestError::TestError(msg) => JsonRpcErrorObject::internal_error(Some(msg.clone())),
                TestError::UnknownMethod(method) => JsonRpcErrorObject::method_not_found(method),
            }
        }
    }

    /// Succeeds on `add`, fails with a domain error on `error`, rejects anything else.
    struct TestHandler;

    #[async_trait]
    impl JsonRpcHandler for TestHandler {
        type Error = TestError;

        async fn handle(
            &self,
            method: &str,
            _params: Option<RequestParams>,
            _session_context: Option<SessionContext>,
        ) -> Result<Value, Self::Error> {
            match method {
                "add" => Ok(json!({"result": "addition"})),
                "error" => Err(TestError::TestError("test error".to_string())),
                _ => Err(TestError::UnknownMethod(method.to_string())),
            }
        }

        fn supported_methods(&self) -> Vec<String> {
            vec!["add".to_string(), "error".to_string()]
        }
    }

    /// Notification callback type used to pin `N` when no notification handler is installed.
    type NotifyFn = fn(
        &str,
        Option<RequestParams>,
        Option<SessionContext>,
    ) -> futures::future::BoxFuture<'static, Result<(), TestError>>;

    /// Request callback for `FunctionHandler` tests: echoes the method name back.
    fn echo_method(
        method: &str,
        _params: Option<RequestParams>,
        _session_context: Option<SessionContext>,
    ) -> futures::future::BoxFuture<'static, Result<Value, TestError>> {
        // The returned future must be 'static, so copy the borrowed name first.
        let method = method.to_string();
        Box::pin(async move { Ok(json!({ "method": method })) })
    }

    #[tokio::test]
    async fn test_dispatcher_success() {
        let mut dispatcher: JsonRpcDispatcher<TestError> = JsonRpcDispatcher::new();
        dispatcher.register_method("add".to_string(), TestHandler);

        let request = JsonRpcRequest::new_no_params(RequestId::Number(1), "add".to_string());

        let response = dispatcher.handle_request(request).await;
        assert_eq!(response.id(), Some(&RequestId::Number(1)));
        assert!(!response.is_error());
    }

    #[tokio::test]
    async fn test_dispatcher_method_not_found() {
        let dispatcher: JsonRpcDispatcher<TestError> = JsonRpcDispatcher::new();

        let request = JsonRpcRequest::new_no_params(RequestId::Number(1), "unknown".to_string());

        let response = dispatcher.handle_request(request).await;
        assert_eq!(response.id(), Some(&RequestId::Number(1)));
        assert!(response.is_error());
    }

    #[tokio::test]
    async fn test_dispatcher_handler_error() {
        let mut dispatcher: JsonRpcDispatcher<TestError> = JsonRpcDispatcher::new();
        dispatcher.register_method("error".to_string(), TestHandler);

        let request = JsonRpcRequest::new_no_params(RequestId::Number(2), "error".to_string());

        let response = dispatcher.handle_request(request).await;
        assert_eq!(response.id(), Some(&RequestId::Number(2)));
        assert!(response.is_error());
    }

    #[tokio::test]
    async fn test_dispatcher_default_handler_fallback() {
        let mut dispatcher: JsonRpcDispatcher<TestError> = JsonRpcDispatcher::new();
        dispatcher.set_default_handler(TestHandler);

        // "add" has no explicit registration, so only the default handler can answer it.
        let request = JsonRpcRequest::new_no_params(RequestId::Number(3), "add".to_string());

        let response = dispatcher.handle_request(request).await;
        assert!(!response.is_error());
        assert!(dispatcher.registered_methods().is_empty());
    }

    #[test]
    fn test_register_methods_lists_every_name() {
        let mut dispatcher: JsonRpcDispatcher<TestError> = JsonRpcDispatcher::default();
        dispatcher.register_methods(vec!["add".to_string(), "error".to_string()], TestHandler);

        // HashMap iteration order is unspecified, so compare sorted.
        let mut methods = dispatcher.registered_methods();
        methods.sort();
        assert_eq!(methods, vec!["add".to_string(), "error".to_string()]);
    }

    #[tokio::test]
    async fn test_handler_direct_call() {
        let handler = TestHandler;
        let result = handler.handle("add", None, None).await.unwrap();
        assert_eq!(result["result"], "addition");
    }

    #[tokio::test]
    async fn test_function_handler() {
        let handler: FunctionHandler<_, NotifyFn, TestError> =
            FunctionHandler::new(echo_method).with_methods(vec!["echo".to_string()]);
        assert_eq!(handler.supported_methods(), vec!["echo".to_string()]);

        let result = handler.handle("echo", None, None).await.unwrap();
        assert_eq!(result["method"], "echo");

        // Without a notification callback, notifications are accepted as no-ops.
        assert!(handler.handle_notification("ping", None, None).await.is_ok());
    }
}
429
/// Streaming dispatch. Handlers may emit several frames (progress, partial results) for a
/// single request, and methods fall back to ordinary [`JsonRpcHandler`]s when no streaming
/// handler is registered.
#[cfg(feature = "streams")]
pub mod streaming {
    use super::*;

    /// One message produced while serving a request, or a standalone notification.
    ///
    /// `FinalResult` and `Error` are terminal (see [`JsonRpcFrame::is_terminal`]).
    /// [`StreamingJsonRpcDispatcher`] does not cut a handler's stream short after a terminal
    /// frame; ending the stream is the handler's job.
    #[derive(Debug, Clone)]
    pub enum JsonRpcFrame {
        /// Progress update for `request_id`.
        Progress {
            request_id: crate::types::RequestId,
            /// Emitted under `_meta.progress`.
            progress: Value,
            /// Emitted under `_meta.progressToken` when present.
            progress_token: Option<String>,
        },
        /// Intermediate piece of the result, emitted with `_meta.partial = true`.
        PartialResult {
            request_id: crate::types::RequestId,
            data: Value,
        },
        /// Successful final result (terminal).
        FinalResult {
            request_id: crate::types::RequestId,
            result: Value,
        },
        /// Failure outcome (terminal).
        Error {
            request_id: crate::types::RequestId,
            error: crate::error::JsonRpcErrorObject,
        },
        /// Notification without an `id`; `params` is omitted from the JSON when `None`.
        Notification {
            method: String,
            params: Option<Value>,
        },
    }

    impl JsonRpcFrame {
        /// Serializes the frame as a JSON-RPC 2.0 message object.
        ///
        /// `Progress` and `PartialResult` carry their extra information in a `_meta` member.
        /// For `Error`, `data` is always present and becomes `null` when it is `None`.
        ///
        /// NOTE(review): a `Progress` frame has an `id` but neither `result` nor `error`. It is
        /// therefore not a valid JSON-RPC 2.0 response by itself; confirm peers expect this shape.
        pub fn to_json(&self) -> Value {
            match self {
                JsonRpcFrame::Progress {
                    request_id,
                    progress,
                    progress_token,
                } => {
                    let mut obj = serde_json::json!({
                        "jsonrpc": "2.0",
                        "id": request_id,
                        "_meta": {
                            "progress": progress
                        }
                    });

                    if let Some(token) = progress_token {
                        obj["_meta"]["progressToken"] = Value::String(token.clone());
                    }

                    obj
                }
                JsonRpcFrame::PartialResult { request_id, data } => {
                    serde_json::json!({
                        "jsonrpc": "2.0",
                        "id": request_id,
                        "_meta": {
                            "partial": true
                        },
                        "result": data
                    })
                }
                JsonRpcFrame::FinalResult { request_id, result } => {
                    serde_json::json!({
                        "jsonrpc": "2.0",
                        "id": request_id,
                        "result": result
                    })
                }
                JsonRpcFrame::Error { request_id, error } => {
                    serde_json::json!({
                        "jsonrpc": "2.0",
                        "id": request_id,
                        "error": {
                            "code": error.code,
                            "message": &error.message,
                            "data": &error.data
                        }
                    })
                }
                JsonRpcFrame::Notification { method, params } => {
                    let mut obj = serde_json::json!({
                        "jsonrpc": "2.0",
                        "method": method
                    });

                    if let Some(params) = params {
                        obj["params"] = params.clone();
                    }

                    obj
                }
            }
        }

        /// Returns `true` for frames that conclude a request (`FinalResult` or `Error`).
        pub fn is_terminal(&self) -> bool {
            matches!(
                self,
                JsonRpcFrame::FinalResult { .. } | JsonRpcFrame::Error { .. }
            )
        }
    }

    /// Handler that answers a request with a stream of [`JsonRpcFrame`]s.
    #[async_trait]
    pub trait StreamingJsonRpcHandler: Send + Sync {
        /// Domain error yielded by the stream. The dispatcher converts each one into an
        /// `Error` frame.
        type Error: std::error::Error + Send + Sync + 'static;

        /// Starts serving `method` and returns the stream of frames for `request_id`.
        ///
        /// The dispatcher turns every `Err` item into a `JsonRpcFrame::Error`. The stream
        /// should finish with a terminal frame.
        async fn handle_streaming(
            &self,
            method: &str,
            params: Option<crate::request::RequestParams>,
            session_context: Option<SessionContext>,
            request_id: crate::types::RequestId,
        ) -> Pin<Box<dyn Stream<Item = Result<JsonRpcFrame, Self::Error>> + Send>>;

        /// Serves a notification. The default implementation ignores it and succeeds.
        async fn handle_notification(
            &self,
            method: &str,
            params: Option<crate::request::RequestParams>,
            session_context: Option<SessionContext>,
        ) -> Result<(), Self::Error> {
            let _ = (method, params, session_context);
            Ok(())
        }

        /// Method names this handler advertises; empty unless overridden.
        fn supported_methods(&self) -> Vec<String> {
            vec![]
        }
    }

    /// Dispatcher that prefers streaming handlers and falls back to ordinary ones.
    ///
    /// A method is looked up in three places, in order:
    /// 1. streaming handlers;
    /// 2. non-streaming fallback handlers;
    /// 3. the default handler.
    ///
    /// A non-streaming result is delivered as a single `FinalResult` or `Error` frame.
    pub struct StreamingJsonRpcDispatcher<E>
    where
        E: ToJsonRpcError,
    {
        streaming_handlers: HashMap<String, Arc<dyn StreamingJsonRpcHandler<Error = E>>>,
        fallback_handlers: HashMap<String, Arc<dyn JsonRpcHandler<Error = E>>>,
        default_handler: Option<Arc<dyn JsonRpcHandler<Error = E>>>,
    }

    impl<E> StreamingJsonRpcDispatcher<E>
    where
        E: ToJsonRpcError,
    {
        /// Creates a dispatcher with no handlers of any kind.
        pub fn new() -> Self {
            Self {
                streaming_handlers: HashMap::new(),
                fallback_handlers: HashMap::new(),
                default_handler: None,
            }
        }

        /// Registers a streaming `handler` for `method`, replacing any previous one.
        pub fn register_streaming_method<H>(&mut self, method: String, handler: H)
        where
            H: StreamingJsonRpcHandler<Error = E> + 'static,
        {
            self.streaming_handlers.insert(method, Arc::new(handler));
        }

        /// Registers a non-streaming `handler` for `method`, replacing any previous one.
        pub fn register_fallback_method<H>(&mut self, method: String, handler: H)
        where
            H: JsonRpcHandler<Error = E> + 'static,
        {
            self.fallback_handlers.insert(method, Arc::new(handler));
        }

        /// Installs the non-streaming handler used when no other registration matches.
        pub fn set_default_handler<H>(&mut self, handler: H)
        where
            H: JsonRpcHandler<Error = E> + 'static,
        {
            self.default_handler = Some(Arc::new(handler));
        }

        /// Non-streaming lookup: explicit fallback registration first, then the default.
        fn fallback_for(&self, method: &str) -> Option<&Arc<dyn JsonRpcHandler<Error = E>>> {
            self.fallback_handlers
                .get(method)
                .or(self.default_handler.as_ref())
        }

        /// Serves `request` and returns the frames to send back.
        ///
        /// The outcome depends on which handler is found:
        /// - streaming handler: its items are forwarded, and domain errors become `Error` frames;
        /// - non-streaming handler: a single `FinalResult` or `Error` frame is produced;
        /// - no handler: a single "method not found" `Error` frame is produced.
        pub async fn handle_request_streaming(
            &self,
            request: crate::request::JsonRpcRequest,
            session_context: SessionContext,
        ) -> Pin<Box<dyn Stream<Item = JsonRpcFrame> + Send>> {
            if let Some(streaming_handler) = self.streaming_handlers.get(&request.method) {
                // The map closure may run many times, so it owns one id and clones per error.
                let request_id = request.id.clone();
                let frames = streaming_handler
                    .handle_streaming(
                        &request.method,
                        request.params,
                        Some(session_context),
                        request.id,
                    )
                    .await;

                return Box::pin(frames.map(move |item| match item {
                    Ok(frame) => frame,
                    Err(domain_error) => JsonRpcFrame::Error {
                        request_id: request_id.clone(),
                        error: domain_error.to_error_object(),
                    },
                }));
            }

            // Take an owned handle so the returned 'static stream does not borrow `self`.
            let fallback = self.fallback_for(&request.method).map(Arc::clone);
            if let Some(handler) = fallback {
                let method = request.method;
                let params = request.params;
                let request_id = request.id;

                return Box::pin(futures::stream::once(async move {
                    match handler.handle(&method, params, Some(session_context)).await {
                        Ok(result) => JsonRpcFrame::FinalResult { request_id, result },
                        Err(domain_error) => JsonRpcFrame::Error {
                            request_id,
                            error: domain_error.to_error_object(),
                        },
                    }
                }));
            }

            let error = crate::error::JsonRpcErrorObject {
                code: crate::error_codes::METHOD_NOT_FOUND,
                message: format!("Method '{}' not found", request.method),
                data: None,
            };
            let request_id = request.id;

            Box::pin(futures::stream::once(async move {
                JsonRpcFrame::Error { request_id, error }
            }))
        }

        /// Routes `notification` using the same lookup order as requests.
        ///
        /// # Errors
        /// Propagates the selected handler's error. An unroutable notification returns `Ok(())`.
        pub async fn handle_notification(
            &self,
            notification: crate::notification::JsonRpcNotification,
        ) -> Result<(), E> {
            if let Some(streaming_handler) = self.streaming_handlers.get(&notification.method) {
                return streaming_handler
                    .handle_notification(&notification.method, notification.params, None)
                    .await;
            }

            match self.fallback_for(&notification.method) {
                Some(handler) => {
                    handler
                        .handle_notification(&notification.method, notification.params, None)
                        .await
                }
                // Notifications are never answered, so an unroutable one is dropped silently.
                None => Ok(()),
            }
        }
    }

    impl<E> Default for StreamingJsonRpcDispatcher<E>
    where
        E: ToJsonRpcError,
    {
        fn default() -> Self {
            Self::new()
        }
    }
}