Skip to main content

tower_mcp/
jsonrpc.rs

1//! JSON-RPC 2.0 service layer
2//!
3//! Provides a Tower [`Layer`] and [`Service`] for JSON-RPC framing of MCP requests.
4//!
5//! - [`JsonRpcLayer`] - Tower layer for [`ServiceBuilder`](tower::ServiceBuilder) composition
6//! - [`JsonRpcService`] - Tower service wrapping an MCP router
7//!
8//! The service handles:
9//! - Single request processing
10//! - Batch request processing (concurrent execution)
11//! - JSON-RPC version validation
12//! - Error conversion to JSON-RPC error responses
13
14use std::future::Future;
15use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use tower::Layer;
19use tower_service::Service;
20
21use crate::error::{Error, JsonRpcError, Result};
22use crate::protocol::{
23    JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, JsonRpcResponseMessage, McpRequest,
24};
25use crate::router::{Extensions, RouterRequest, RouterResponse};
26
/// [`Layer`] that wraps an MCP service in a [`JsonRpcService`], adding JSON-RPC 2.0 framing.
///
/// Use it to slot JSON-RPC handling into a [`ServiceBuilder`](tower::ServiceBuilder)
/// stack next to other tower middleware.
///
/// # Example
///
/// ```rust
/// use tower::ServiceBuilder;
/// use tower_mcp::{JsonRpcLayer, McpRouter};
///
/// let router = McpRouter::new().server_info("my-server", "1.0.0");
///
/// // The layer turns the router into a JSON-RPC-speaking service.
/// let service = ServiceBuilder::new()
///     .layer(JsonRpcLayer::new())
///     .service(router);
/// ```
#[derive(Clone, Copy, Debug, Default)]
pub struct JsonRpcLayer {
    // Private unit field: callers outside this module cannot use a struct
    // literal and must go through `new` / `Default`.
    _priv: (),
}

impl JsonRpcLayer {
    /// Build a `JsonRpcLayer`; identical to `JsonRpcLayer::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
56
57impl<S> Layer<S> for JsonRpcLayer {
58    type Service = JsonRpcService<S>;
59
60    fn layer(&self, inner: S) -> Self::Service {
61        JsonRpcService::new(inner)
62    }
63}
64
65/// Service that handles JSON-RPC framing.
66///
67/// Wraps an MCP service and handles JSON-RPC request/response conversion.
68/// Supports both single requests and batch requests.
69///
70/// Can be created directly via [`JsonRpcService::new`] or through the
71/// [`JsonRpcLayer`] for [`ServiceBuilder`](tower::ServiceBuilder) composition.
72///
73/// # Example
74///
75/// ```rust
76/// use tower_mcp::{McpRouter, JsonRpcService};
77///
78/// let router = McpRouter::new().server_info("my-server", "1.0.0");
79/// let service = JsonRpcService::new(router);
80/// ```
81pub struct JsonRpcService<S> {
82    inner: S,
83    extensions: Extensions,
84}
85
86impl<S> JsonRpcService<S> {
87    /// Create a new JSON-RPC service wrapping the given inner service
88    pub fn new(inner: S) -> Self {
89        Self {
90            inner,
91            extensions: Extensions::new(),
92        }
93    }
94
95    /// Set extensions to inject into every `RouterRequest` created by this service.
96    ///
97    /// This is used by transports to bridge data (e.g., `TokenClaims`) from the
98    /// HTTP/WebSocket layer into the MCP request pipeline.
99    pub fn with_extensions(mut self, ext: Extensions) -> Self {
100        self.extensions = ext;
101        self
102    }
103
104    /// Process a single JSON-RPC request
105    pub async fn call_single(&mut self, req: JsonRpcRequest) -> Result<JsonRpcResponse>
106    where
107        S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
108            + Clone
109            + Send
110            + 'static,
111        S::Future: Send,
112    {
113        process_single_request(self.inner.clone(), req, self.extensions.clone()).await
114    }
115
116    /// Process a batch of JSON-RPC requests concurrently
117    pub async fn call_batch(
118        &mut self,
119        requests: Vec<JsonRpcRequest>,
120    ) -> Result<Vec<JsonRpcResponse>>
121    where
122        S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
123            + Clone
124            + Send
125            + 'static,
126        S::Future: Send,
127    {
128        if requests.is_empty() {
129            return Err(Error::JsonRpc(JsonRpcError::invalid_request(
130                "Empty batch request",
131            )));
132        }
133
134        // Process all requests concurrently
135        let futures: Vec<_> = requests
136            .into_iter()
137            .map(|req| {
138                let inner = self.inner.clone();
139                let extensions = self.extensions.clone();
140                let req_id = req.id.clone();
141                async move {
142                    match process_single_request(inner, req, extensions).await {
143                        Ok(resp) => resp,
144                        Err(e) => {
145                            // Convert errors to error responses instead of dropping
146                            JsonRpcResponse::error(
147                                Some(req_id),
148                                JsonRpcError::internal_error(e.to_string()),
149                            )
150                        }
151                    }
152                }
153            })
154            .collect();
155
156        let results: Vec<JsonRpcResponse> = futures::future::join_all(futures).await;
157
158        // Results will never be empty since we converted all errors to responses
159        Ok(results)
160    }
161
162    /// Process a JSON-RPC message (single or batch)
163    pub async fn call_message(&mut self, msg: JsonRpcMessage) -> Result<JsonRpcResponseMessage>
164    where
165        S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
166            + Clone
167            + Send
168            + 'static,
169        S::Future: Send,
170    {
171        match msg {
172            JsonRpcMessage::Single(req) => {
173                let response = self.call_single(req).await?;
174                Ok(JsonRpcResponseMessage::Single(response))
175            }
176            JsonRpcMessage::Batch(requests) => {
177                let responses = self.call_batch(requests).await?;
178                Ok(JsonRpcResponseMessage::Batch(responses))
179            }
180        }
181    }
182}
183
184impl<S> Clone for JsonRpcService<S>
185where
186    S: Clone,
187{
188    fn clone(&self) -> Self {
189        Self {
190            inner: self.inner.clone(),
191            extensions: self.extensions.clone(),
192        }
193    }
194}
195
196impl<S> Service<JsonRpcRequest> for JsonRpcService<S>
197where
198    S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
199        + Clone
200        + Send
201        + 'static,
202    S::Future: Send,
203{
204    type Response = JsonRpcResponse;
205    type Error = Error;
206    type Future =
207        Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
208
209    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
210        self.inner.poll_ready(cx).map_err(|_| unreachable!())
211    }
212
213    fn call(&mut self, req: JsonRpcRequest) -> Self::Future {
214        let mut inner = self.inner.clone();
215        let extensions = self.extensions.clone();
216        Box::pin(async move {
217            // Parse the MCP request from JSON-RPC
218            let mcp_request = McpRequest::from_jsonrpc(&req)?;
219
220            // Create router request
221            let router_req = RouterRequest {
222                id: req.id,
223                inner: mcp_request,
224                extensions,
225            };
226
227            // Call the inner service
228            let response = inner.call(router_req).await.unwrap(); // Infallible
229
230            // Convert to JSON-RPC response
231            Ok(response.into_jsonrpc())
232        })
233    }
234}
235
236/// Service implementation for JSON-RPC batch requests
237impl<S> Service<JsonRpcMessage> for JsonRpcService<S>
238where
239    S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
240        + Clone
241        + Send
242        + 'static,
243    S::Future: Send,
244{
245    type Response = JsonRpcResponseMessage;
246    type Error = Error;
247    type Future =
248        Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
249
250    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
251        self.inner.poll_ready(cx).map_err(|_| unreachable!())
252    }
253
254    fn call(&mut self, msg: JsonRpcMessage) -> Self::Future {
255        let inner = self.inner.clone();
256        let extensions = self.extensions.clone();
257        Box::pin(async move {
258            match msg {
259                JsonRpcMessage::Single(req) => {
260                    let response = process_single_request(inner, req, extensions).await?;
261                    Ok(JsonRpcResponseMessage::Single(response))
262                }
263                JsonRpcMessage::Batch(requests) => {
264                    if requests.is_empty() {
265                        // Empty batch is an invalid request per JSON-RPC spec
266                        return Ok(JsonRpcResponseMessage::Single(JsonRpcResponse::error(
267                            None,
268                            JsonRpcError::invalid_request("Empty batch request"),
269                        )));
270                    }
271
272                    // Process all requests concurrently
273                    let futures: Vec<_> = requests
274                        .into_iter()
275                        .map(|req| {
276                            let inner = inner.clone();
277                            let extensions = extensions.clone();
278                            let req_id = req.id.clone();
279                            async move {
280                                match process_single_request(inner, req, extensions).await {
281                                    Ok(resp) => resp,
282                                    Err(e) => {
283                                        // Convert errors to error responses instead of dropping
284                                        JsonRpcResponse::error(
285                                            Some(req_id),
286                                            JsonRpcError::internal_error(e.to_string()),
287                                        )
288                                    }
289                                }
290                            }
291                        })
292                        .collect();
293
294                    let results: Vec<JsonRpcResponse> = futures::future::join_all(futures).await;
295
296                    // Empty results only possible if input was empty (already handled above)
297                    if results.is_empty() {
298                        return Ok(JsonRpcResponseMessage::Single(JsonRpcResponse::error(
299                            None,
300                            JsonRpcError::internal_error("All batch requests failed"),
301                        )));
302                    }
303
304                    Ok(JsonRpcResponseMessage::Batch(results))
305                }
306            }
307        })
308    }
309}
310
311/// Helper function to process a single JSON-RPC request
312async fn process_single_request<S>(
313    mut inner: S,
314    req: JsonRpcRequest,
315    extensions: Extensions,
316) -> std::result::Result<JsonRpcResponse, Error>
317where
318    S: Service<RouterRequest, Response = RouterResponse, Error = std::convert::Infallible>
319        + Send
320        + 'static,
321    S::Future: Send,
322{
323    // Validate JSON-RPC version
324    if let Err(e) = req.validate() {
325        return Ok(JsonRpcResponse::error(Some(req.id), e));
326    }
327
328    // Parse the MCP request from JSON-RPC
329    let mcp_request = match McpRequest::from_jsonrpc(&req) {
330        Ok(r) => r,
331        Err(e) => {
332            return Ok(JsonRpcResponse::error(
333                Some(req.id),
334                JsonRpcError::invalid_params(e.to_string()),
335            ));
336        }
337    };
338
339    // Create router request
340    let router_req = RouterRequest {
341        id: req.id,
342        inner: mcp_request,
343        extensions,
344    };
345
346    // Call the inner service
347    let response = inner.call(router_req).await.unwrap(); // Infallible
348
349    // Convert to JSON-RPC response
350    Ok(response.into_jsonrpc())
351}
352
#[cfg(test)]
mod tests {
    use super::*;
    use crate::McpRouter;
    use crate::tool::ToolBuilder;
    use schemars::JsonSchema;
    use serde::Deserialize;

    #[derive(Debug, Deserialize, JsonSchema)]
    struct AddInput {
        a: i32,
        b: i32,
    }

    fn create_test_router() -> McpRouter {
        let add_tool = ToolBuilder::new("add")
            .description("Add two numbers")
            .handler(|input: AddInput| async move {
                Ok(crate::CallToolResult::text(format!(
                    "{}",
                    input.a + input.b
                )))
            })
            .build()
            .unwrap();

        McpRouter::new()
            .server_info("test-server", "1.0.0")
            .tool(add_tool)
    }

    #[tokio::test]
    async fn test_jsonrpc_service() {
        let router = create_test_router();
        let mut service = JsonRpcService::new(router.clone());

        // Initialize first
        let init_req = JsonRpcRequest::new(1, "initialize").with_params(serde_json::json!({
            "protocolVersion": "2025-11-25",
            "capabilities": {},
            "clientInfo": { "name": "test", "version": "1.0" }
        }));
        let resp = service.call_single(init_req).await.unwrap();
        assert!(matches!(resp, JsonRpcResponse::Result(_)));

        // Mark as initialized
        router.handle_notification(crate::protocol::McpNotification::Initialized);

        // Now list tools
        let req = JsonRpcRequest::new(2, "tools/list").with_params(serde_json::json!({}));
        let resp = service.call_single(req).await.unwrap();

        match resp {
            JsonRpcResponse::Result(r) => {
                let tools = r.result.get("tools").unwrap().as_array().unwrap();
                assert_eq!(tools.len(), 1);
            }
            JsonRpcResponse::Error(e) => panic!("Expected result, got error: {:?}", e),
        }
    }

    #[tokio::test]
    async fn test_batch_request() {
        let router = create_test_router();
        let mut service = JsonRpcService::new(router.clone());

        // Initialize first
        let init_req = JsonRpcRequest::new(1, "initialize").with_params(serde_json::json!({
            "protocolVersion": "2025-11-25",
            "capabilities": {},
            "clientInfo": { "name": "test", "version": "1.0" }
        }));
        service.call_single(init_req).await.unwrap();
        router.handle_notification(crate::protocol::McpNotification::Initialized);

        // Batch request
        let requests = vec![
            JsonRpcRequest::new(2, "tools/list").with_params(serde_json::json!({})),
            JsonRpcRequest::new(3, "tools/call").with_params(serde_json::json!({
                "name": "add",
                "arguments": { "a": 1, "b": 2 }
            })),
        ];

        let responses = service.call_batch(requests).await.unwrap();
        assert_eq!(responses.len(), 2);
    }

    #[tokio::test]
    async fn test_empty_batch_error() {
        let router = create_test_router();
        let mut service = JsonRpcService::new(router);

        let result = service.call_batch(vec![]).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_service_empty_batch_returns_error_response() {
        // Through the tower `Service` impl, an empty batch is answered with a single
        // JSON-RPC error response rather than an `Err`.
        let router = create_test_router();
        let mut service = JsonRpcService::new(router);

        let resp = Service::<JsonRpcMessage>::call(&mut service, JsonRpcMessage::Batch(vec![]))
            .await
            .unwrap();
        assert!(matches!(
            resp,
            JsonRpcResponseMessage::Single(JsonRpcResponse::Error(_))
        ));
    }

    #[tokio::test]
    async fn test_service_batch_message() {
        let router = create_test_router();
        let mut service = JsonRpcService::new(router.clone());

        let init_req = JsonRpcRequest::new(1, "initialize").with_params(serde_json::json!({
            "protocolVersion": "2025-11-25",
            "capabilities": {},
            "clientInfo": { "name": "test", "version": "1.0" }
        }));
        service.call_single(init_req).await.unwrap();
        router.handle_notification(crate::protocol::McpNotification::Initialized);

        let msg = JsonRpcMessage::Batch(vec![
            JsonRpcRequest::new(2, "tools/list").with_params(serde_json::json!({})),
            JsonRpcRequest::new(3, "tools/call").with_params(serde_json::json!({
                "name": "add",
                "arguments": { "a": 1, "b": 2 }
            })),
        ]);

        let resp = Service::<JsonRpcMessage>::call(&mut service, msg)
            .await
            .unwrap();
        match resp {
            JsonRpcResponseMessage::Batch(responses) => assert_eq!(responses.len(), 2),
            _ => panic!("Expected a batch response"),
        }
    }

    #[tokio::test]
    async fn test_jsonrpc_layer() {
        use tower::ServiceBuilder;

        let router = create_test_router();
        let router_clone = router.clone();

        // Build service using the layer via ServiceBuilder
        let mut service = ServiceBuilder::new()
            .layer(JsonRpcLayer::new())
            .service(router);

        // Initialize
        let init_req = JsonRpcRequest::new(1, "initialize").with_params(serde_json::json!({
            "protocolVersion": "2025-11-25",
            "capabilities": {},
            "clientInfo": { "name": "test", "version": "1.0" }
        }));
        let resp = Service::<JsonRpcRequest>::call(&mut service, init_req)
            .await
            .unwrap();
        assert!(matches!(resp, JsonRpcResponse::Result(_)));

        router_clone.handle_notification(crate::protocol::McpNotification::Initialized);

        // List tools through the layer-composed service
        let req = JsonRpcRequest::new(2, "tools/list").with_params(serde_json::json!({}));
        let resp = Service::<JsonRpcRequest>::call(&mut service, req)
            .await
            .unwrap();

        match resp {
            JsonRpcResponse::Result(r) => {
                let tools = r.result.get("tools").unwrap().as_array().unwrap();
                assert_eq!(tools.len(), 1);
            }
            JsonRpcResponse::Error(e) => panic!("Expected result, got error: {:?}", e),
        }
    }

    #[test]
    fn test_jsonrpc_layer_default() {
        // JsonRpcLayer implements Default
        let _layer = JsonRpcLayer::default();
    }

    #[test]
    fn test_jsonrpc_layer_clone() {
        // JsonRpcLayer implements Clone and Copy
        let layer = JsonRpcLayer::new();
        let _cloned = layer;
        let _copied = layer;
    }
}