Skip to main content

tower_mcp/transport/
service.rs

1//! Service types for transport-level middleware support
2//!
3//! This module provides the types needed to apply tower middleware layers
4//! to MCP request processing within HTTP and WebSocket transports.
5//!
6//! The key type is [`ServiceFactory`], a function that takes an [`McpRouter`]
7//! and produces a boxed, middleware-wrapped service. Transports store this
8//! factory and use it when creating sessions.
9//!
10//! [`CatchError`] is a wrapper that converts middleware errors (e.g., timeouts)
11//! into [`RouterResponse`] errors, preserving the `Error = Infallible` contract
12//! that [`JsonRpcService`] requires.
13//!
14//! [`McpRouter`]: crate::router::McpRouter
15//! [`RouterResponse`]: crate::router::RouterResponse
16//! [`JsonRpcService`]: crate::jsonrpc::JsonRpcService
17
18use std::convert::Infallible;
19use std::fmt;
20use std::future::Future;
21use std::pin::Pin;
22use std::sync::Arc;
23use std::task::{Context, Poll};
24
25use tower::util::BoxCloneService;
26use tower_service::Service;
27
28use crate::error::JsonRpcError;
29use crate::router::{McpRouter, RouterRequest, RouterResponse};
30
/// A boxed, cloneable MCP service with `Error = Infallible`.
///
/// This is the service type that transports use internally after applying
/// middleware layers. It wraps any `Service<RouterRequest>` implementation
/// so that [`JsonRpcService`](crate::jsonrpc::JsonRpcService) can consume it
/// without knowing the concrete middleware stack.
///
/// The alias pins all three type parameters of [`BoxCloneService`]: requests
/// are [`RouterRequest`], responses are [`RouterResponse`], and the error
/// type is [`Infallible`]. A middleware stack whose error type is not
/// `Infallible` therefore has to be adapted (see [`CatchError`]) before it
/// can be boxed into this type.
pub type McpBoxService = BoxCloneService<RouterRequest, RouterResponse, Infallible>;
38
/// A factory function that produces a [`McpBoxService`] from an [`McpRouter`].
///
/// Transports store this factory and call it when creating new sessions.
/// The default factory (from [`identity_factory`]) returns the router as-is.
/// When `.layer()` is called on a transport, the factory wraps the router
/// with the given middleware and a [`CatchError`] adapter.
///
/// The factory is a type-erased `Fn` behind an [`Arc`], so copies of it share
/// one underlying closure, and the `Send + Sync` bounds allow it to be shared
/// across threads. It takes the [`McpRouter`] by value on every invocation.
pub type ServiceFactory = Arc<dyn Fn(McpRouter) -> McpBoxService + Send + Sync>;
46
47/// Create a [`ServiceFactory`] that returns the router unchanged.
48///
49/// This is the default factory used by transports when no `.layer()` is applied.
50pub fn identity_factory() -> ServiceFactory {
51    Arc::new(|router: McpRouter| BoxCloneService::new(router))
52}
53
54/// A service wrapper that catches errors from middleware and converts them
55/// into [`RouterResponse`] error values, maintaining the `Error = Infallible`
56/// contract required by [`JsonRpcService`](crate::jsonrpc::JsonRpcService).
57///
58/// When a middleware layer (e.g., `TimeoutLayer`) produces an error, this
59/// wrapper converts it into a JSON-RPC internal error response using the
60/// request ID from the original request. This allows error information to
61/// flow through the normal response path rather than requiring special
62/// error handling at the transport level.
63pub struct CatchError<S> {
64    inner: S,
65}
66
67impl<S> CatchError<S> {
68    /// Create a new `CatchError` wrapping the given service.
69    pub fn new(inner: S) -> Self {
70        Self { inner }
71    }
72}
73
74impl<S: Clone> Clone for CatchError<S> {
75    fn clone(&self) -> Self {
76        Self {
77            inner: self.inner.clone(),
78        }
79    }
80}
81
82impl<S: fmt::Debug> fmt::Debug for CatchError<S> {
83    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
84        f.debug_struct("CatchError")
85            .field("inner", &self.inner)
86            .finish()
87    }
88}
89
90impl<S> Service<RouterRequest> for CatchError<S>
91where
92    S: Service<RouterRequest, Response = RouterResponse> + Clone + Send + 'static,
93    S::Error: fmt::Display + Send,
94    S::Future: Send,
95{
96    type Response = RouterResponse;
97    type Error = Infallible;
98    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
99
100    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
101        self.inner.poll_ready(cx).map_err(|_| unreachable!())
102    }
103
104    fn call(&mut self, req: RouterRequest) -> Self::Future {
105        // Capture the request ID before passing the request to the inner service.
106        // We need this to build a proper JSON-RPC error response if the middleware fails.
107        let request_id = req.id.clone();
108        let fut = self.inner.call(req);
109
110        Box::pin(async move {
111            match fut.await {
112                Ok(response) => Ok(response),
113                Err(err) => Ok(RouterResponse {
114                    id: request_id,
115                    inner: Err(JsonRpcError::internal_error(err.to_string())),
116                }),
117            }
118        })
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::RequestId;

    /// Stand-in for a failing middleware layer: always ready, and every call
    /// resolves to an error.
    #[derive(Clone)]
    struct AlwaysFails;

    impl Service<RouterRequest> for AlwaysFails {
        type Response = RouterResponse;
        type Error = &'static str;
        type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, &'static str>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: RouterRequest) -> Self::Future {
            Box::pin(async { Err::<RouterResponse, &'static str>("middleware failed") })
        }
    }

    /// Build a `Ping` request with the given id and empty extensions.
    fn ping_request(id: RequestId) -> RouterRequest {
        RouterRequest {
            id,
            inner: crate::protocol::McpRequest::Ping,
            extensions: crate::router::Extensions::new(),
        }
    }

    #[test]
    fn test_identity_factory_produces_service() {
        let router = McpRouter::new().server_info("test", "1.0.0");
        let factory = identity_factory();
        let _service = factory(router);
    }

    #[tokio::test]
    async fn test_catch_error_passes_through_success() {
        let router = McpRouter::new().server_info("test", "1.0.0");
        let mut service = CatchError::new(router);

        let req = ping_request(RequestId::Number(1));

        let result = Service::call(&mut service, req).await;
        assert!(result.is_ok());
        let response = result.unwrap();
        assert!(response.inner.is_ok());
    }

    // The error path is the reason `CatchError` exists: an `Err` from the
    // inner service must surface as an `Ok` response carrying a JSON-RPC
    // error and the id of the request that triggered it.
    #[tokio::test]
    async fn test_catch_error_converts_call_error() {
        let mut service = CatchError::new(AlwaysFails);

        let req = ping_request(RequestId::Number(7));

        let result = Service::call(&mut service, req).await;
        assert!(result.is_ok());
        let response = result.unwrap();
        assert!(response.inner.is_err());
        assert!(matches!(response.id, RequestId::Number(7)));
    }

    #[test]
    fn test_catch_error_clone() {
        let router = McpRouter::new().server_info("test", "1.0.0");
        let service = CatchError::new(router);
        let _clone = service.clone();
    }

    #[test]
    fn test_catch_error_debug() {
        let router = McpRouter::new().server_info("test", "1.0.0");
        let service = CatchError::new(router);
        let debug = format!("{:?}", service);
        assert!(debug.contains("CatchError"));
    }
}