turul_mcp_json_rpc_server/
async.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde_json::Value;
6
7#[cfg(feature = "streams")]
8use futures::{Stream, StreamExt};
9#[cfg(feature = "streams")]
10use std::pin::Pin;
11
12use crate::{
13    error::JsonRpcError,
14    notification::JsonRpcNotification,
15    request::{JsonRpcRequest, RequestParams},
16    response::{JsonRpcMessage, ResponseResult},
17};
18
19/// Minimal session context for JSON-RPC handlers
20/// This provides basic session information without circular dependencies
21#[derive(Debug, Clone)]
22pub struct SessionContext {
23    /// Unique session identifier
24    pub session_id: String,
25    /// Session metadata
26    pub metadata: HashMap<String, Value>,
27    /// Optional broadcaster for session notifications
28    pub broadcaster: Option<Arc<dyn std::any::Any + Send + Sync>>,
29    /// Session timestamp (Unix milliseconds)
30    pub timestamp: u64,
31}
32
33/// Trait for handling JSON-RPC method calls
34#[async_trait]
35pub trait JsonRpcHandler: Send + Sync {
36    /// The error type returned by this handler
37    type Error: std::error::Error + Send + Sync + 'static;
38
39    /// Handle a JSON-RPC method call with optional session context
40    /// Returns domain errors only - dispatcher handles conversion to JSON-RPC errors
41    async fn handle(
42        &self,
43        method: &str,
44        params: Option<RequestParams>,
45        session_context: Option<SessionContext>,
46    ) -> Result<Value, Self::Error>;
47
48    /// Handle a JSON-RPC notification with optional session context (optional - default does nothing)
49    async fn handle_notification(
50        &self,
51        method: &str,
52        params: Option<RequestParams>,
53        session_context: Option<SessionContext>,
54    ) -> Result<(), Self::Error> {
55        // Default implementation - ignore notifications
56        let _ = (method, params, session_context);
57        Ok(())
58    }
59
60    /// List supported methods (optional - used for introspection)
61    fn supported_methods(&self) -> Vec<String> {
62        vec![]
63    }
64}
65
66/// A simple function-based handler
67pub struct FunctionHandler<F, N, E>
68where
69    E: std::error::Error + Send + Sync + 'static,
70    F: Fn(
71            &str,
72            Option<RequestParams>,
73            Option<SessionContext>,
74        ) -> futures::future::BoxFuture<'static, Result<Value, E>>
75        + Send
76        + Sync,
77    N: Fn(
78            &str,
79            Option<RequestParams>,
80            Option<SessionContext>,
81        ) -> futures::future::BoxFuture<'static, Result<(), E>>
82        + Send
83        + Sync,
84{
85    handler_fn: F,
86    notification_fn: Option<N>,
87    methods: Vec<String>,
88}
89
90impl<F, N, E> FunctionHandler<F, N, E>
91where
92    E: std::error::Error + Send + Sync + 'static,
93    F: Fn(
94            &str,
95            Option<RequestParams>,
96            Option<SessionContext>,
97        ) -> futures::future::BoxFuture<'static, Result<Value, E>>
98        + Send
99        + Sync,
100    N: Fn(
101            &str,
102            Option<RequestParams>,
103            Option<SessionContext>,
104        ) -> futures::future::BoxFuture<'static, Result<(), E>>
105        + Send
106        + Sync,
107{
108    pub fn new(handler_fn: F) -> Self {
109        Self {
110            handler_fn,
111            notification_fn: None,
112            methods: vec![],
113        }
114    }
115
116    pub fn with_notification_handler(mut self, notification_fn: N) -> Self {
117        self.notification_fn = Some(notification_fn);
118        self
119    }
120
121    pub fn with_methods(mut self, methods: Vec<String>) -> Self {
122        self.methods = methods;
123        self
124    }
125}
126
127#[async_trait]
128impl<F, N, E> JsonRpcHandler for FunctionHandler<F, N, E>
129where
130    E: std::error::Error + Send + Sync + 'static,
131    F: Fn(
132            &str,
133            Option<RequestParams>,
134            Option<SessionContext>,
135        ) -> futures::future::BoxFuture<'static, Result<Value, E>>
136        + Send
137        + Sync,
138    N: Fn(
139            &str,
140            Option<RequestParams>,
141            Option<SessionContext>,
142        ) -> futures::future::BoxFuture<'static, Result<(), E>>
143        + Send
144        + Sync,
145{
146    type Error = E;
147
148    async fn handle(
149        &self,
150        method: &str,
151        params: Option<RequestParams>,
152        session_context: Option<SessionContext>,
153    ) -> Result<Value, Self::Error> {
154        (self.handler_fn)(method, params, session_context).await
155    }
156
157    async fn handle_notification(
158        &self,
159        method: &str,
160        params: Option<RequestParams>,
161        session_context: Option<SessionContext>,
162    ) -> Result<(), Self::Error> {
163        if let Some(ref notification_fn) = self.notification_fn {
164            (notification_fn)(method, params, session_context).await
165        } else {
166            Ok(())
167        }
168    }
169
170    fn supported_methods(&self) -> Vec<String> {
171        self.methods.clone()
172    }
173}
174
175/// Trait for errors that can be converted to JSON-RPC error objects
176pub trait ToJsonRpcError: std::error::Error + Send + Sync + 'static {
177    /// Convert this error to a JSON-RPC error object
178    fn to_error_object(&self) -> crate::error::JsonRpcErrorObject;
179}
180
181/// JSON-RPC method dispatcher with specific error type
182pub struct JsonRpcDispatcher<E>
183where
184    E: ToJsonRpcError,
185{
186    pub handlers: HashMap<String, Arc<dyn JsonRpcHandler<Error = E>>>,
187    pub default_handler: Option<Arc<dyn JsonRpcHandler<Error = E>>>,
188}
189
190impl<E> JsonRpcDispatcher<E>
191where
192    E: ToJsonRpcError,
193{
194    pub fn new() -> Self {
195        Self {
196            handlers: HashMap::new(),
197            default_handler: None,
198        }
199    }
200
201    /// Register a handler for a specific method
202    pub fn register_method<H>(&mut self, method: String, handler: H)
203    where
204        H: JsonRpcHandler<Error = E> + 'static,
205    {
206        self.handlers.insert(method, Arc::new(handler));
207    }
208
209    /// Register a handler for multiple methods
210    pub fn register_methods<H>(&mut self, methods: Vec<String>, handler: H)
211    where
212        H: JsonRpcHandler<Error = E> + 'static,
213    {
214        let handler_arc = Arc::new(handler);
215        for method in methods {
216            self.handlers.insert(method, handler_arc.clone());
217        }
218    }
219
220    /// Set a default handler for unregistered methods
221    pub fn set_default_handler<H>(&mut self, handler: H)
222    where
223        H: JsonRpcHandler<Error = E> + 'static,
224    {
225        self.default_handler = Some(Arc::new(handler));
226    }
227
228    /// Process a JSON-RPC request with session context and return a response
229    pub async fn handle_request_with_context(
230        &self,
231        request: JsonRpcRequest,
232        session_context: SessionContext,
233    ) -> JsonRpcMessage {
234        let handler = self
235            .handlers
236            .get(&request.method)
237            .or(self.default_handler.as_ref());
238
239        match handler {
240            Some(handler) => {
241                match handler
242                    .handle(&request.method, request.params, Some(session_context))
243                    .await
244                {
245                    Ok(result) => {
246                        JsonRpcMessage::success(request.id, ResponseResult::Success(result))
247                    }
248                    Err(domain_error) => {
249                        // Convert domain error to JSON-RPC error using type-safe conversion
250                        let error_object = domain_error.to_error_object();
251                        let rpc_error = JsonRpcError::new(Some(request.id.clone()), error_object);
252                        JsonRpcMessage::error(rpc_error)
253                    }
254                }
255            }
256            None => {
257                let error = JsonRpcError::method_not_found(request.id.clone(), &request.method);
258                JsonRpcMessage::error(error)
259            }
260        }
261    }
262
263    /// Process a JSON-RPC request and return a response (backward compatibility - no session context)
264    pub async fn handle_request(&self, request: JsonRpcRequest) -> JsonRpcMessage {
265        let handler = self
266            .handlers
267            .get(&request.method)
268            .or(self.default_handler.as_ref());
269
270        match handler {
271            Some(handler) => {
272                match handler.handle(&request.method, request.params, None).await {
273                    Ok(result) => {
274                        JsonRpcMessage::success(request.id, ResponseResult::Success(result))
275                    }
276                    Err(domain_error) => {
277                        // Convert domain error to JSON-RPC error using type-safe conversion
278                        let error_object = domain_error.to_error_object();
279                        let rpc_error = JsonRpcError::new(Some(request.id.clone()), error_object);
280                        JsonRpcMessage::error(rpc_error)
281                    }
282                }
283            }
284            None => {
285                let error = JsonRpcError::method_not_found(request.id.clone(), &request.method);
286                JsonRpcMessage::error(error)
287            }
288        }
289    }
290
291    /// Process a JSON-RPC notification
292    pub async fn handle_notification(&self, notification: JsonRpcNotification) -> Result<(), E> {
293        let handler = self
294            .handlers
295            .get(&notification.method)
296            .or(self.default_handler.as_ref());
297
298        match handler {
299            Some(handler) => {
300                handler
301                    .handle_notification(&notification.method, notification.params, None)
302                    .await
303            }
304            None => {
305                // Notifications don't return errors, just ignore unknown methods
306                Ok(())
307            }
308        }
309    }
310
311    /// Process a JSON-RPC notification with session context
312    pub async fn handle_notification_with_context(
313        &self,
314        notification: JsonRpcNotification,
315        session_context: Option<SessionContext>,
316    ) -> Result<(), E> {
317        let handler = self
318            .handlers
319            .get(&notification.method)
320            .or(self.default_handler.as_ref());
321
322        match handler {
323            Some(handler) => {
324                handler
325                    .handle_notification(&notification.method, notification.params, session_context)
326                    .await
327            }
328            None => {
329                // Notifications don't return errors, just ignore unknown methods
330                Ok(())
331            }
332        }
333    }
334
335    /// Get all registered methods
336    pub fn registered_methods(&self) -> Vec<String> {
337        self.handlers.keys().cloned().collect()
338    }
339}
340
341impl<E> Default for JsonRpcDispatcher<E>
342where
343    E: ToJsonRpcError,
344{
345    fn default() -> Self {
346        Self::new()
347    }
348}
349
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{JsonRpcRequest, RequestId};
    use serde_json::json;

    #[derive(thiserror::Error, Debug)]
    enum TestError {
        #[error("Test error: {0}")]
        TestError(String),
        #[error("Unknown method: {0}")]
        UnknownMethod(String),
    }

    impl ToJsonRpcError for TestError {
        fn to_error_object(&self) -> crate::error::JsonRpcErrorObject {
            use crate::error::JsonRpcErrorObject;
            match self {
                TestError::TestError(msg) => JsonRpcErrorObject::internal_error(Some(msg.clone())),
                TestError::UnknownMethod(method) => JsonRpcErrorObject::method_not_found(method),
            }
        }
    }

    /// Handler that answers `add`, fails on `error`, and rejects everything else.
    struct TestHandler;

    #[async_trait]
    impl JsonRpcHandler for TestHandler {
        type Error = TestError;

        async fn handle(
            &self,
            method: &str,
            _params: Option<RequestParams>,
            _session_context: Option<SessionContext>,
        ) -> Result<Value, Self::Error> {
            match method {
                "add" => Ok(json!({"result": "addition"})),
                "error" => Err(TestError::TestError("test error".to_string())),
                _ => Err(TestError::UnknownMethod(method.to_string())),
            }
        }

        fn supported_methods(&self) -> Vec<String> {
            vec!["add".to_string(), "error".to_string()]
        }
    }

    /// Session context with fixed, recognisable values.
    fn test_session_context() -> SessionContext {
        SessionContext {
            session_id: "test-session".to_string(),
            metadata: HashMap::new(),
            broadcaster: None,
            timestamp: 0,
        }
    }

    /// Method callback for `FunctionHandler` that echoes the method name back.
    fn echo_method(
        method: &str,
        _params: Option<RequestParams>,
        _session_context: Option<SessionContext>,
    ) -> futures::future::BoxFuture<'static, Result<Value, TestError>> {
        // The future must be 'static, so copy the borrowed method name first.
        let method = method.to_string();
        Box::pin(async move { Ok::<Value, TestError>(json!({ "method": method })) })
    }

    /// Notification callback for `FunctionHandler` that accepts everything.
    fn accept_notification(
        _method: &str,
        _params: Option<RequestParams>,
        _session_context: Option<SessionContext>,
    ) -> futures::future::BoxFuture<'static, Result<(), TestError>> {
        Box::pin(async { Ok::<(), TestError>(()) })
    }

    #[tokio::test]
    async fn test_dispatcher_success() {
        let mut dispatcher: JsonRpcDispatcher<TestError> = JsonRpcDispatcher::new();
        dispatcher.register_method("add".to_string(), TestHandler);

        let request = JsonRpcRequest::new_no_params(RequestId::Number(1), "add".to_string());

        let response = dispatcher.handle_request(request).await;
        assert_eq!(response.id(), Some(&RequestId::Number(1)));
        assert!(!response.is_error());
    }

    #[tokio::test]
    async fn test_dispatcher_method_not_found() {
        let dispatcher: JsonRpcDispatcher<TestError> = JsonRpcDispatcher::new();

        let request = JsonRpcRequest::new_no_params(RequestId::Number(1), "unknown".to_string());

        let response = dispatcher.handle_request(request).await;
        assert_eq!(response.id(), Some(&RequestId::Number(1)));
        assert!(response.is_error());
    }

    #[tokio::test]
    async fn test_dispatcher_handler_error_becomes_error_response() {
        let mut dispatcher: JsonRpcDispatcher<TestError> = JsonRpcDispatcher::new();
        dispatcher.register_method("error".to_string(), TestHandler);

        let request = JsonRpcRequest::new_no_params(RequestId::Number(2), "error".to_string());

        let response = dispatcher.handle_request(request).await;
        assert_eq!(response.id(), Some(&RequestId::Number(2)));
        assert!(response.is_error());
    }

    #[tokio::test]
    async fn test_dispatcher_default_handler_fallback() {
        let mut dispatcher: JsonRpcDispatcher<TestError> = JsonRpcDispatcher::new();
        dispatcher.set_default_handler(TestHandler);

        let request = JsonRpcRequest::new_no_params(RequestId::Number(3), "add".to_string());

        let response = dispatcher.handle_request(request).await;
        assert!(!response.is_error());
        // The default handler is not an explicit registration.
        assert!(dispatcher.registered_methods().is_empty());
    }

    #[tokio::test]
    async fn test_dispatcher_with_context_and_shared_handler() {
        let mut dispatcher: JsonRpcDispatcher<TestError> = JsonRpcDispatcher::new();
        dispatcher.register_methods(vec!["add".to_string(), "error".to_string()], TestHandler);

        let request = JsonRpcRequest::new_no_params(RequestId::Number(4), "add".to_string());

        let response = dispatcher
            .handle_request_with_context(request, test_session_context())
            .await;
        assert_eq!(response.id(), Some(&RequestId::Number(4)));
        assert!(!response.is_error());

        let mut methods = dispatcher.registered_methods();
        methods.sort();
        assert_eq!(methods, vec!["add".to_string(), "error".to_string()]);
    }

    #[tokio::test]
    async fn test_handler_direct_call() {
        let handler = TestHandler;
        let result = handler.handle("add", None, None).await.unwrap();
        assert_eq!(result["result"], "addition");
    }

    #[tokio::test]
    async fn test_function_handler() {
        let handler = FunctionHandler::new(echo_method)
            .with_notification_handler(accept_notification)
            .with_methods(vec!["echo".to_string()]);

        let result = handler.handle("echo", None, None).await.unwrap();
        assert_eq!(result["method"], "echo");
        assert!(handler.handle_notification("echo", None, None).await.is_ok());
        assert_eq!(handler.supported_methods(), vec!["echo".to_string()]);
    }
}
429
430// ============================================================================
431// 🚀 STREAMING DISPATCHER - MCP 2025-06-18 Support
432// ============================================================================
433
#[cfg(feature = "streams")]
pub mod streaming {
    use super::*;

    /// One chunk of a progressive (streamed) JSON-RPC response.
    ///
    /// `FinalResult` and `Error` are terminal frames (see
    /// [`JsonRpcFrame::is_terminal`]). `StreamingJsonRpcDispatcher` stops
    /// forwarding a handler's stream after the first terminal frame.
    #[derive(Debug, Clone)]
    pub enum JsonRpcFrame {
        /// Progress update for a request, with an optional progress token.
        Progress {
            request_id: crate::types::RequestId,
            progress: Value,
            progress_token: Option<String>,
        },
        /// Partial result chunk.
        PartialResult {
            request_id: crate::types::RequestId,
            data: Value,
        },
        /// Final result (ends the stream).
        FinalResult {
            request_id: crate::types::RequestId,
            result: Value,
        },
        /// Error result (ends the stream).
        Error {
            request_id: crate::types::RequestId,
            error: crate::error::JsonRpcErrorObject,
        },
        /// Notification frame (does not end the stream).
        Notification {
            method: String,
            params: Option<Value>,
        },
    }

    impl JsonRpcFrame {
        /// Render this frame as a JSON value tagged `"jsonrpc": "2.0"`.
        ///
        /// - `Progress`: `id` plus `_meta.progress`, and `_meta.progressToken` when a token is set.
        /// - `PartialResult`: `id`, `result`, and `_meta.partial = true`.
        /// - `FinalResult`: `id` and `result`.
        /// - `Error`: `id` and an `error` object with `code`, `message`, and `data`
        ///   (`data` is `null` when absent).
        /// - `Notification`: `method`, plus `params` only when present.
        pub fn to_json(&self) -> Value {
            match self {
                JsonRpcFrame::Progress {
                    request_id,
                    progress,
                    progress_token,
                } => {
                    // NOTE(review): a message with an `id` but neither `result` nor
                    // `error` is not a standard JSON-RPC 2.0 response; confirm the
                    // consumer expects this shape.
                    let mut obj = serde_json::json!({
                        "jsonrpc": "2.0",
                        "id": request_id,
                        "_meta": {
                            "progress": progress
                        }
                    });

                    if let Some(token) = progress_token {
                        obj["_meta"]["progressToken"] = Value::String(token.clone());
                    }

                    obj
                }
                JsonRpcFrame::PartialResult { request_id, data } => {
                    serde_json::json!({
                        "jsonrpc": "2.0",
                        "id": request_id,
                        "_meta": {
                            "partial": true
                        },
                        "result": data
                    })
                }
                JsonRpcFrame::FinalResult { request_id, result } => {
                    serde_json::json!({
                        "jsonrpc": "2.0",
                        "id": request_id,
                        "result": result
                    })
                }
                JsonRpcFrame::Error { request_id, error } => {
                    serde_json::json!({
                        "jsonrpc": "2.0",
                        "id": request_id,
                        "error": {
                            "code": error.code,
                            "message": &error.message,
                            "data": &error.data
                        }
                    })
                }
                JsonRpcFrame::Notification { method, params } => {
                    let mut obj = serde_json::json!({
                        "jsonrpc": "2.0",
                        "method": method
                    });

                    if let Some(params) = params {
                        obj["params"] = params.clone();
                    }

                    obj
                }
            }
        }

        /// Whether this frame ends the stream (`FinalResult` or `Error`).
        pub fn is_terminal(&self) -> bool {
            matches!(
                self,
                JsonRpcFrame::FinalResult { .. } | JsonRpcFrame::Error { .. }
            )
        }
    }

    /// Handler able to answer a request with a stream of frames.
    #[async_trait]
    pub trait StreamingJsonRpcHandler: Send + Sync {
        /// Domain error produced by this handler.
        type Error: std::error::Error + Send + Sync + 'static;

        /// Handle `method` and return a stream of frames for `request_id`.
        ///
        /// `Err` items are converted into `JsonRpcFrame::Error` by the dispatcher.
        async fn handle_streaming(
            &self,
            method: &str,
            params: Option<crate::request::RequestParams>,
            session_context: Option<SessionContext>,
            request_id: crate::types::RequestId,
        ) -> Pin<Box<dyn Stream<Item = Result<JsonRpcFrame, Self::Error>> + Send>>;

        /// Handle a notification (non-streaming). The default ignores it and succeeds.
        async fn handle_notification(
            &self,
            method: &str,
            params: Option<crate::request::RequestParams>,
            session_context: Option<SessionContext>,
        ) -> Result<(), Self::Error> {
            // Default implementation - ignore notifications
            let _ = (method, params, session_context);
            Ok(())
        }

        /// Method names advertised for introspection; empty by default.
        fn supported_methods(&self) -> Vec<String> {
            vec![]
        }
    }

    /// Dispatcher that prefers streaming handlers and falls back to regular ones.
    ///
    /// Lookup order is: streaming handler for the method, then fallback handler
    /// for the method, then the default handler.
    pub struct StreamingJsonRpcDispatcher<E>
    where
        E: ToJsonRpcError,
    {
        streaming_handlers: HashMap<String, Arc<dyn StreamingJsonRpcHandler<Error = E>>>,
        fallback_handlers: HashMap<String, Arc<dyn JsonRpcHandler<Error = E>>>,
        default_handler: Option<Arc<dyn JsonRpcHandler<Error = E>>>,
    }

    impl<E> StreamingJsonRpcDispatcher<E>
    where
        E: ToJsonRpcError,
    {
        /// Create a dispatcher with no handlers registered.
        pub fn new() -> Self {
            Self {
                streaming_handlers: HashMap::new(),
                fallback_handlers: HashMap::new(),
                default_handler: None,
            }
        }

        /// Register a streaming handler for `method`, replacing any previous one.
        pub fn register_streaming_method<H>(&mut self, method: String, handler: H)
        where
            H: StreamingJsonRpcHandler<Error = E> + 'static,
        {
            self.streaming_handlers.insert(method, Arc::new(handler));
        }

        /// Register a non-streaming handler for `method`, replacing any previous one.
        pub fn register_fallback_method<H>(&mut self, method: String, handler: H)
        where
            H: JsonRpcHandler<Error = E> + 'static,
        {
            self.fallback_handlers.insert(method, Arc::new(handler));
        }

        /// Set the non-streaming handler used for otherwise unregistered methods.
        pub fn set_default_handler<H>(&mut self, handler: H)
        where
            H: JsonRpcHandler<Error = E> + 'static,
        {
            self.default_handler = Some(Arc::new(handler));
        }

        /// Process a JSON-RPC request and return its frames as a stream.
        ///
        /// - Streaming handlers: the handler's frames are forwarded until the
        ///   first terminal frame (inclusive), after which the stream ends.
        /// - Regular handlers: yield exactly one `FinalResult` or `Error` frame.
        /// - Unknown methods: yield a single method-not-found `Error` frame.
        pub async fn handle_request_streaming(
            &self,
            request: crate::request::JsonRpcRequest,
            session_context: SessionContext,
        ) -> Pin<Box<dyn Stream<Item = JsonRpcFrame> + Send>> {
            // 1. A dedicated streaming handler takes precedence.
            if let Some(streaming_handler) = self.streaming_handlers.get(&request.method) {
                let results = streaming_handler
                    .handle_streaming(
                        &request.method,
                        request.params,
                        Some(session_context),
                        request.id.clone(),
                    )
                    .await;
                return stop_after_terminal(results_to_frames(results, request.id));
            }

            // 2. Otherwise wrap a regular handler as a single-frame stream.
            if let Some(fallback_handler) = self
                .fallback_handlers
                .get(&request.method)
                .or(self.default_handler.as_ref())
            {
                // The returned stream is 'static, so it needs owned copies.
                let handler = Arc::clone(fallback_handler);
                let method = request.method;
                let params = request.params;
                let request_id = request.id;

                return Box::pin(futures::stream::once(async move {
                    match handler.handle(&method, params, Some(session_context)).await {
                        Ok(result) => JsonRpcFrame::FinalResult { request_id, result },
                        Err(domain_error) => JsonRpcFrame::Error {
                            request_id,
                            error: domain_error.to_error_object(),
                        },
                    }
                }));
            }

            // 3. No handler at all. Use the shared constructor instead of a
            //    hand-built object so the error stays consistent with the rest
            //    of the crate.
            let error = crate::error::JsonRpcErrorObject::method_not_found(&request.method);
            let request_id = request.id;
            Box::pin(futures::stream::once(async move {
                JsonRpcFrame::Error { request_id, error }
            }))
        }

        /// Process a JSON-RPC notification without session context.
        ///
        /// # Errors
        /// Returns the handler's domain error unchanged; unknown methods yield `Ok(())`.
        pub async fn handle_notification(
            &self,
            notification: crate::notification::JsonRpcNotification,
        ) -> Result<(), E> {
            self.handle_notification_with_context(notification, None)
                .await
        }

        /// Process a JSON-RPC notification, passing `session_context` to the handler.
        ///
        /// Streaming handlers take precedence, then fallback/default handlers.
        ///
        /// # Errors
        /// Returns the handler's domain error unchanged; unknown methods yield `Ok(())`.
        pub async fn handle_notification_with_context(
            &self,
            notification: crate::notification::JsonRpcNotification,
            session_context: Option<SessionContext>,
        ) -> Result<(), E> {
            if let Some(streaming_handler) = self.streaming_handlers.get(&notification.method) {
                return streaming_handler
                    .handle_notification(&notification.method, notification.params, session_context)
                    .await;
            }

            match self
                .fallback_handlers
                .get(&notification.method)
                .or(self.default_handler.as_ref())
            {
                Some(fallback_handler) => {
                    fallback_handler
                        .handle_notification(
                            &notification.method,
                            notification.params,
                            session_context,
                        )
                        .await
                }
                // Ignore unknown notifications.
                None => Ok(()),
            }
        }
    }

    impl<E> Default for StreamingJsonRpcDispatcher<E>
    where
        E: ToJsonRpcError,
    {
        fn default() -> Self {
            Self::new()
        }
    }

    /// Turn a handler's `Result` stream into plain frames, converting each
    /// domain error into an `Error` frame for `request_id`.
    fn results_to_frames<E>(
        results: Pin<Box<dyn Stream<Item = Result<JsonRpcFrame, E>> + Send>>,
        request_id: crate::types::RequestId,
    ) -> Pin<Box<dyn Stream<Item = JsonRpcFrame> + Send>>
    where
        E: ToJsonRpcError,
    {
        Box::pin(results.map(move |result| match result {
            Ok(frame) => frame,
            Err(domain_error) => JsonRpcFrame::Error {
                request_id: request_id.clone(),
                error: domain_error.to_error_object(),
            },
        }))
    }

    /// Forward `frames` up to and including the first terminal frame.
    ///
    /// After that frame the stream ends, and the source is dropped without
    /// being polled again. This enforces the "ends the stream" contract of
    /// `FinalResult` / `Error`.
    fn stop_after_terminal(
        frames: Pin<Box<dyn Stream<Item = JsonRpcFrame> + Send>>,
    ) -> Pin<Box<dyn Stream<Item = JsonRpcFrame> + Send>> {
        Box::pin(futures::stream::unfold(Some(frames), |source| async move {
            let mut source = match source {
                Some(source) => source,
                // A terminal frame was already yielded.
                None => return None,
            };
            match source.next().await {
                Some(frame) => {
                    let remaining = if frame.is_terminal() { None } else { Some(source) };
                    Some((frame, remaining))
                }
                None => None,
            }
        }))
    }
}