Skip to main content

tower_mcp/
jsonrpc.rs

1//! JSON-RPC 2.0 service layer
2//!
3//! Provides a Tower [`Layer`] and [`Service`] for JSON-RPC framing of MCP requests.
4//!
5//! - [`JsonRpcLayer`] - Tower layer for [`ServiceBuilder`](tower::ServiceBuilder) composition
6//! - [`JsonRpcService`] - Tower service wrapping an MCP router
7//!
8//! The service handles:
9//! - Single request processing
10//! - Batch request processing (concurrent execution)
11//! - JSON-RPC version validation
12//! - Error conversion to JSON-RPC error responses
13
14use std::future::Future;
15use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use tower::{Layer, ServiceExt};
19use tower_service::Service;
20
21use crate::error::{Error, JsonRpcError, Result};
22use crate::protocol::{
23    JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, JsonRpcResponseMessage, McpRequest,
24};
25use crate::router::{Extensions, RouterRequest, RouterResponse};
26
/// A [`Layer`] that wraps an MCP service in JSON-RPC 2.0 framing.
///
/// Applying the layer to an inner service yields a [`JsonRpcService`] around
/// it, which is the usual way to stack JSON-RPC handling together with other
/// tower middleware through [`ServiceBuilder`](tower::ServiceBuilder).
///
/// The layer holds no configuration of its own.
///
/// # Example
///
/// ```rust
/// use tower::ServiceBuilder;
/// use tower_mcp::{McpRouter, JsonRpcLayer, JsonRpcService};
///
/// let router = McpRouter::new().server_info("my-server", "1.0.0");
///
/// // Compose with ServiceBuilder
/// let service = ServiceBuilder::new()
///     .layer(JsonRpcLayer::new())
///     .service(router);
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct JsonRpcLayer {
    // Private unit field: the struct cannot be built with a literal outside
    // this module, so callers go through `new` or `default`.
    _priv: (),
}

impl JsonRpcLayer {
    /// Build a `JsonRpcLayer`; equivalent to [`JsonRpcLayer::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
56
57impl<S> Layer<S> for JsonRpcLayer {
58    type Service = JsonRpcService<S>;
59
60    fn layer(&self, inner: S) -> Self::Service {
61        JsonRpcService::new(inner)
62    }
63}
64
65/// Service that handles JSON-RPC framing.
66///
67/// Wraps an MCP service and handles JSON-RPC request/response conversion.
68/// Supports both single requests and batch requests.
69///
70/// Can be created directly via [`JsonRpcService::new`] or through the
71/// [`JsonRpcLayer`] for [`ServiceBuilder`](tower::ServiceBuilder) composition.
72///
73/// # Example
74///
75/// ```rust
76/// use tower_mcp::{McpRouter, JsonRpcService};
77///
78/// let router = McpRouter::new().server_info("my-server", "1.0.0");
79/// let service = JsonRpcService::new(router);
80/// ```
81pub struct JsonRpcService<S> {
82    inner: S,
83    extensions: Extensions,
84}
85
86impl<S> JsonRpcService<S> {
87    /// Create a new JSON-RPC service wrapping the given inner service
88    pub fn new(inner: S) -> Self {
89        Self {
90            inner,
91            extensions: Extensions::new(),
92        }
93    }
94
95    /// Set extensions to inject into every `RouterRequest` created by this service.
96    ///
97    /// This is used by transports to bridge data (e.g., `TokenClaims`) from the
98    /// HTTP/WebSocket layer into the MCP request pipeline.
99    pub fn with_extensions(mut self, ext: Extensions) -> Self {
100        self.extensions = ext;
101        self
102    }
103
104    /// Process a single JSON-RPC request
105    pub async fn call_single(&mut self, req: JsonRpcRequest) -> Result<JsonRpcResponse>
106    where
107        S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
108            + Clone
109            + Send
110            + 'static,
111        S::Future: Send,
112    {
113        process_single_request(self.inner.clone(), req, self.extensions.clone()).await
114    }
115
116    /// Process a batch of JSON-RPC requests concurrently
117    pub async fn call_batch(
118        &mut self,
119        requests: Vec<JsonRpcRequest>,
120    ) -> Result<Vec<JsonRpcResponse>>
121    where
122        S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
123            + Clone
124            + Send
125            + 'static,
126        S::Future: Send,
127    {
128        if requests.is_empty() {
129            return Err(Error::JsonRpc(JsonRpcError::invalid_request(
130                "Empty batch request",
131            )));
132        }
133
134        // Process all requests concurrently
135        let futures: Vec<_> = requests
136            .into_iter()
137            .map(|req| {
138                let inner = self.inner.clone();
139                let extensions = self.extensions.clone();
140                let req_id = req.id.clone();
141                async move {
142                    match process_single_request(inner, req, extensions).await {
143                        Ok(resp) => resp,
144                        Err(e) => {
145                            // Convert errors to error responses instead of dropping
146                            JsonRpcResponse::error(
147                                Some(req_id),
148                                JsonRpcError::internal_error(e.to_string()),
149                            )
150                        }
151                    }
152                }
153            })
154            .collect();
155
156        let results: Vec<JsonRpcResponse> = futures::future::join_all(futures).await;
157
158        // Results will never be empty since we converted all errors to responses
159        Ok(results)
160    }
161
162    /// Process a JSON-RPC message (single or batch)
163    pub async fn call_message(&mut self, msg: JsonRpcMessage) -> Result<JsonRpcResponseMessage>
164    where
165        S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
166            + Clone
167            + Send
168            + 'static,
169        S::Future: Send,
170    {
171        match msg {
172            JsonRpcMessage::Single(req) => {
173                let response = self.call_single(req).await?;
174                Ok(JsonRpcResponseMessage::Single(response))
175            }
176            JsonRpcMessage::Batch(requests) => {
177                let responses = self.call_batch(requests).await?;
178                Ok(JsonRpcResponseMessage::Batch(responses))
179            }
180            _ => Ok(JsonRpcResponseMessage::Single(JsonRpcResponse::error(
181                None,
182                JsonRpcError::invalid_request("Unsupported message type"),
183            ))),
184        }
185    }
186}
187
188impl<S> Clone for JsonRpcService<S>
189where
190    S: Clone,
191{
192    fn clone(&self) -> Self {
193        Self {
194            inner: self.inner.clone(),
195            extensions: self.extensions.clone(),
196        }
197    }
198}
199
200impl<S> Service<JsonRpcRequest> for JsonRpcService<S>
201where
202    S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
203        + Clone
204        + Send
205        + 'static,
206    S::Future: Send,
207{
208    type Response = JsonRpcResponse;
209    type Error = Error;
210    type Future =
211        Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
212
213    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
214        self.inner.poll_ready(cx).map_err(|_| unreachable!())
215    }
216
217    fn call(&mut self, req: JsonRpcRequest) -> Self::Future {
218        let inner = self.inner.clone();
219        let extensions = self.extensions.clone();
220        Box::pin(async move {
221            // Parse the MCP request from JSON-RPC
222            let mcp_request = McpRequest::from_jsonrpc(&req)?;
223
224            // Create router request
225            let router_req = RouterRequest {
226                id: req.id,
227                inner: mcp_request,
228                extensions,
229            };
230
231            // Call the inner service (oneshot handles poll_ready)
232            let response = inner.oneshot(router_req).await.unwrap(); // Infallible
233
234            // Convert to JSON-RPC response
235            Ok(response.into_jsonrpc())
236        })
237    }
238}
239
240/// Service implementation for JSON-RPC batch requests
241impl<S> Service<JsonRpcMessage> for JsonRpcService<S>
242where
243    S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
244        + Clone
245        + Send
246        + 'static,
247    S::Future: Send,
248{
249    type Response = JsonRpcResponseMessage;
250    type Error = Error;
251    type Future =
252        Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
253
254    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
255        self.inner.poll_ready(cx).map_err(|_| unreachable!())
256    }
257
258    fn call(&mut self, msg: JsonRpcMessage) -> Self::Future {
259        let inner = self.inner.clone();
260        let extensions = self.extensions.clone();
261        Box::pin(async move {
262            match msg {
263                JsonRpcMessage::Single(req) => {
264                    let response = process_single_request(inner, req, extensions).await?;
265                    Ok(JsonRpcResponseMessage::Single(response))
266                }
267                JsonRpcMessage::Batch(requests) => {
268                    if requests.is_empty() {
269                        // Empty batch is an invalid request per JSON-RPC spec
270                        return Ok(JsonRpcResponseMessage::Single(JsonRpcResponse::error(
271                            None,
272                            JsonRpcError::invalid_request("Empty batch request"),
273                        )));
274                    }
275
276                    // Process all requests concurrently
277                    let futures: Vec<_> = requests
278                        .into_iter()
279                        .map(|req| {
280                            let inner = inner.clone();
281                            let extensions = extensions.clone();
282                            let req_id = req.id.clone();
283                            async move {
284                                match process_single_request(inner, req, extensions).await {
285                                    Ok(resp) => resp,
286                                    Err(e) => {
287                                        // Convert errors to error responses instead of dropping
288                                        JsonRpcResponse::error(
289                                            Some(req_id),
290                                            JsonRpcError::internal_error(e.to_string()),
291                                        )
292                                    }
293                                }
294                            }
295                        })
296                        .collect();
297
298                    let results: Vec<JsonRpcResponse> = futures::future::join_all(futures).await;
299
300                    // Empty results only possible if input was empty (already handled above)
301                    if results.is_empty() {
302                        return Ok(JsonRpcResponseMessage::Single(JsonRpcResponse::error(
303                            None,
304                            JsonRpcError::internal_error("All batch requests failed"),
305                        )));
306                    }
307
308                    Ok(JsonRpcResponseMessage::Batch(results))
309                }
310                _ => Ok(JsonRpcResponseMessage::Single(JsonRpcResponse::error(
311                    None,
312                    JsonRpcError::invalid_request("Unsupported message type"),
313                ))),
314            }
315        })
316    }
317}
318
319/// Helper function to process a single JSON-RPC request
320async fn process_single_request<S>(
321    inner: S,
322    req: JsonRpcRequest,
323    extensions: Extensions,
324) -> std::result::Result<JsonRpcResponse, Error>
325where
326    S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
327        + Send
328        + 'static,
329    S::Future: Send,
330{
331    // Validate JSON-RPC version
332    if let Err(e) = req.validate() {
333        return Ok(JsonRpcResponse::error(Some(req.id), e));
334    }
335
336    // Parse the MCP request from JSON-RPC
337    let mcp_request = match McpRequest::from_jsonrpc(&req) {
338        Ok(r) => r,
339        Err(e) => {
340            return Ok(JsonRpcResponse::error(
341                Some(req.id),
342                JsonRpcError::invalid_params(e.to_string()),
343            ));
344        }
345    };
346
347    // Create router request
348    let router_req = RouterRequest {
349        id: req.id,
350        inner: mcp_request,
351        extensions,
352    };
353
354    // Call the inner service (oneshot handles poll_ready)
355    let response = inner.oneshot(router_req).await.unwrap(); // Infallible
356
357    // Convert to JSON-RPC response
358    Ok(response.into_jsonrpc())
359}
360
#[cfg(test)]
mod tests {
    use super::*;
    use crate::McpRouter;
    use crate::tool::ToolBuilder;
    use schemars::JsonSchema;
    use serde::Deserialize;

    /// Arguments accepted by the `add` test tool.
    #[derive(Debug, Deserialize, JsonSchema)]
    struct AddInput {
        a: i32,
        b: i32,
    }

    /// Router with server info and a single `add` tool registered.
    fn create_test_router() -> McpRouter {
        let router = McpRouter::new().server_info("test-server", "1.0.0");

        let sum_tool = ToolBuilder::new("add")
            .description("Add two numbers")
            .handler(|input: AddInput| async move {
                let total = input.a + input.b;
                Ok(crate::CallToolResult::text(total.to_string()))
            })
            .build();

        router.tool(sum_tool)
    }

    /// The `initialize` request (id 1) every test sends first.
    fn initialize_request() -> JsonRpcRequest {
        JsonRpcRequest::new(1, "initialize").with_params(serde_json::json!({
            "protocolVersion": "2025-11-25",
            "capabilities": {},
            "clientInfo": { "name": "test", "version": "1.0" }
        }))
    }

    /// Assert that `resp` is a successful `tools/list` result with exactly one tool.
    fn assert_lists_one_tool(resp: JsonRpcResponse) {
        match resp {
            JsonRpcResponse::Result(r) => {
                let tools = r.result.get("tools").unwrap().as_array().unwrap();
                assert_eq!(tools.len(), 1);
            }
            JsonRpcResponse::Error(e) => panic!("Expected result, got error: {:?}", e),
            _ => panic!("unexpected response variant"),
        }
    }

    #[tokio::test]
    async fn test_jsonrpc_service() {
        let router = create_test_router();
        let mut service = JsonRpcService::new(router.clone());

        // Handshake: initialize, then signal that initialization is complete.
        let init_resp = service.call_single(initialize_request()).await.unwrap();
        assert!(matches!(init_resp, JsonRpcResponse::Result(_)));
        router.handle_notification(crate::protocol::McpNotification::Initialized);

        // The registered tool must now show up in the listing.
        let list_req = JsonRpcRequest::new(2, "tools/list").with_params(serde_json::json!({}));
        let list_resp = service.call_single(list_req).await.unwrap();
        assert_lists_one_tool(list_resp);
    }

    #[tokio::test]
    async fn test_batch_request() {
        let router = create_test_router();
        let mut service = JsonRpcService::new(router.clone());

        // Handshake before sending the batch.
        service.call_single(initialize_request()).await.unwrap();
        router.handle_notification(crate::protocol::McpNotification::Initialized);

        // Two requests in one batch: a listing and a tool call.
        let list_req = JsonRpcRequest::new(2, "tools/list").with_params(serde_json::json!({}));
        let add_req = JsonRpcRequest::new(3, "tools/call").with_params(serde_json::json!({
            "name": "add",
            "arguments": { "a": 1, "b": 2 }
        }));

        let responses = service.call_batch(vec![list_req, add_req]).await.unwrap();
        assert_eq!(responses.len(), 2);
    }

    #[tokio::test]
    async fn test_empty_batch_error() {
        let mut service = JsonRpcService::new(create_test_router());

        // `call_batch` rejects an empty batch with an `Err`.
        let outcome = service.call_batch(vec![]).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_jsonrpc_layer() {
        use tower::ServiceBuilder;

        let router = create_test_router();
        let router_handle = router.clone();

        // Compose the service through the layer instead of calling `new` directly.
        let mut service = ServiceBuilder::new()
            .layer(JsonRpcLayer::new())
            .service(router);

        // Handshake through the `Service<JsonRpcRequest>` implementation.
        let init_resp = Service::<JsonRpcRequest>::call(&mut service, initialize_request())
            .await
            .unwrap();
        assert!(matches!(init_resp, JsonRpcResponse::Result(_)));

        router_handle.handle_notification(crate::protocol::McpNotification::Initialized);

        // List tools through the layer-composed service.
        let list_req = JsonRpcRequest::new(2, "tools/list").with_params(serde_json::json!({}));
        let list_resp = Service::<JsonRpcRequest>::call(&mut service, list_req)
            .await
            .unwrap();
        assert_lists_one_tool(list_resp);
    }

    #[test]
    fn test_jsonrpc_layer_default() {
        // Compiles only if `JsonRpcLayer: Default`.
        let _layer = JsonRpcLayer::default();
    }

    #[test]
    fn test_jsonrpc_layer_clone() {
        // Using `layer` twice by value compiles only if `JsonRpcLayer: Copy`
        // (and `Copy` requires `Clone`).
        let layer = JsonRpcLayer::new();
        let _cloned = layer;
        let _copied = layer;
    }
}