Skip to main content

tower_mcp/transport/
service.rs

//! Service types for transport-level middleware support.
//!
//! This module provides the types needed to apply tower middleware layers
//! to MCP request processing within HTTP and WebSocket transports.
//!
//! The key type is `ServiceFactory`, a function that takes an [`McpRouter`]
//! and produces a boxed, middleware-wrapped service. Transports store this
//! factory and use it when creating sessions.
//!
//! [`CatchError`] is a wrapper that converts middleware errors (e.g., timeouts)
//! into [`RouterResponse`] errors, preserving the `Error = Infallible` contract
//! that [`JsonRpcService`] requires.
//!
//! [`McpRouter`]: crate::router::McpRouter
//! [`RouterResponse`]: crate::router::RouterResponse
//! [`JsonRpcService`]: crate::jsonrpc::JsonRpcService

18use std::convert::Infallible;
19use std::fmt;
20use std::future::Future;
21use std::pin::Pin;
22#[cfg(any(feature = "http", feature = "websocket"))]
23use std::sync::Arc;
24use std::task::{Context, Poll};
25
26use pin_project_lite::pin_project;
27
28use tower::util::BoxCloneService;
29use tower_service::Service;
30
31use crate::error::JsonRpcError;
32use crate::protocol::{McpRequest, RequestId};
33#[cfg(any(feature = "http", feature = "websocket"))]
34use crate::router::McpRouter;
35use crate::router::{RouterRequest, RouterResponse, ToolAnnotationsMap};
36
37/// A boxed, cloneable MCP service with `Error = Infallible`.
38///
39/// This is the service type that transports use internally after applying
40/// middleware layers. It wraps any `Service<RouterRequest>` implementation
41/// so that [`JsonRpcService`](crate::jsonrpc::JsonRpcService) can consume it
42/// without knowing the concrete middleware stack.
43pub type McpBoxService = BoxCloneService<RouterRequest, RouterResponse, Infallible>;
44
/// A factory function that produces a [`McpBoxService`] from an [`McpRouter`].
///
/// Transports store this factory and call it when creating new sessions.
///
/// The default factory (from `identity_factory`) applies no middleware. It only
/// wraps the router in [`InjectAnnotations`], so tool annotations are still
/// placed into the extensions of `tools/call` requests.
///
/// When `.layer()` is called on a transport, the factory wraps the router
/// with the given middleware and a [`CatchError`] adapter.
///
/// Only available when the `http` or `websocket` feature is enabled.
#[cfg(any(feature = "http", feature = "websocket"))]
pub(crate) type ServiceFactory = Arc<dyn Fn(McpRouter) -> McpBoxService + Send + Sync>;
53
/// Build the default `ServiceFactory`, used by transports when no `.layer()`
/// has been applied.
///
/// No middleware is added. Each router is wrapped in [`InjectAnnotations`],
/// seeded with that router's own `tool_annotations_map()`, and the result is
/// boxed into a [`McpBoxService`].
#[cfg(any(feature = "http", feature = "websocket"))]
pub(crate) fn identity_factory() -> ServiceFactory {
    let factory = |router: McpRouter| -> McpBoxService {
        // Read the annotations before `router` is moved into the wrapper.
        let annotations = router.tool_annotations_map();
        BoxCloneService::new(InjectAnnotations::new(router, annotations))
    };
    Arc::new(factory)
}
65
66/// A service wrapper that injects [`ToolAnnotationsMap`] into request
67/// extensions for `tools/call` requests.
68///
69/// This allows middleware to inspect tool annotations (e.g., `read_only_hint`,
70/// `destructive_hint`) without needing direct access to the router.
71/// Transports apply this wrapper automatically.
72#[derive(Clone)]
73pub struct InjectAnnotations<S> {
74    inner: S,
75    annotations: ToolAnnotationsMap,
76}
77
78impl<S> InjectAnnotations<S> {
79    /// Create a new `InjectAnnotations` wrapping the given service.
80    pub fn new(inner: S, annotations: ToolAnnotationsMap) -> Self {
81        Self { inner, annotations }
82    }
83}
84
85impl<S: fmt::Debug> fmt::Debug for InjectAnnotations<S> {
86    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87        f.debug_struct("InjectAnnotations")
88            .field("inner", &self.inner)
89            .finish()
90    }
91}
92
93impl<S> Service<RouterRequest> for InjectAnnotations<S>
94where
95    S: Service<RouterRequest, Response = RouterResponse>,
96{
97    type Response = RouterResponse;
98    type Error = S::Error;
99    type Future = S::Future;
100
101    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
102        self.inner.poll_ready(cx)
103    }
104
105    fn call(&mut self, mut req: RouterRequest) -> Self::Future {
106        if matches!(&req.inner, McpRequest::CallTool(_)) {
107            req.extensions.insert(self.annotations.clone());
108        }
109        self.inner.call(req)
110    }
111}
112
113/// A service wrapper that catches errors from middleware and converts them
114/// into [`RouterResponse`] error values, maintaining the `Error = Infallible`
115/// contract required by [`JsonRpcService`](crate::jsonrpc::JsonRpcService).
116///
117/// When a middleware layer (e.g., `TimeoutLayer`) produces an error, this
118/// wrapper converts it into a JSON-RPC internal error response using the
119/// request ID from the original request. This allows error information to
120/// flow through the normal response path rather than requiring special
121/// error handling at the transport level.
122pub struct CatchError<S> {
123    inner: S,
124}
125
126impl<S> CatchError<S> {
127    /// Create a new `CatchError` wrapping the given service.
128    pub fn new(inner: S) -> Self {
129        Self { inner }
130    }
131}
132
133impl<S: Clone> Clone for CatchError<S> {
134    fn clone(&self) -> Self {
135        Self {
136            inner: self.inner.clone(),
137        }
138    }
139}
140
141impl<S: fmt::Debug> fmt::Debug for CatchError<S> {
142    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
143        f.debug_struct("CatchError")
144            .field("inner", &self.inner)
145            .finish()
146    }
147}
148
149pin_project! {
150    /// Future for [`CatchError`].
151    pub struct CatchErrorFuture<F> {
152        #[pin]
153        inner: F,
154        request_id: Option<RequestId>,
155    }
156}
157
158impl<F, E> Future for CatchErrorFuture<F>
159where
160    F: Future<Output = Result<RouterResponse, E>>,
161    E: fmt::Display,
162{
163    type Output = Result<RouterResponse, Infallible>;
164
165    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
166        let this = self.project();
167        match this.inner.poll(cx) {
168            Poll::Pending => Poll::Pending,
169            Poll::Ready(Ok(response)) => Poll::Ready(Ok(response)),
170            Poll::Ready(Err(err)) => {
171                let request_id = this.request_id.take().unwrap_or(RequestId::Number(0));
172                Poll::Ready(Ok(RouterResponse {
173                    id: request_id,
174                    inner: Err(JsonRpcError::internal_error(err.to_string())),
175                }))
176            }
177        }
178    }
179}
180
181impl<S> Service<RouterRequest> for CatchError<S>
182where
183    S: Service<RouterRequest, Response = RouterResponse> + Clone + Send + 'static,
184    S::Error: fmt::Display + Send,
185    S::Future: Send,
186{
187    type Response = RouterResponse;
188    type Error = Infallible;
189    type Future = CatchErrorFuture<S::Future>;
190
191    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
192        self.inner.poll_ready(cx).map_err(|_| unreachable!())
193    }
194
195    fn call(&mut self, req: RouterRequest) -> Self::Future {
196        // Capture the request ID before passing the request to the inner service.
197        // We need this to build a proper JSON-RPC error response if the middleware fails.
198        let request_id = req.id.clone();
199        let fut = self.inner.call(req);
200
201        CatchErrorFuture {
202            inner: fut,
203            request_id: Some(request_id),
204        }
205    }
206}
207
208#[cfg(test)]
209mod tests {
210    use std::sync::Arc;
211
212    use super::*;
213    use crate::protocol::{CallToolParams, CallToolResult, RequestId, ToolAnnotations};
214    use crate::router::McpRouter;
215
216    #[test]
217    #[cfg(any(feature = "http", feature = "websocket"))]
218    fn test_identity_factory_produces_service() {
219        let router = McpRouter::new().server_info("test", "1.0.0");
220        let factory = identity_factory();
221        let _service = factory(router);
222    }
223
224    #[tokio::test]
225    async fn test_catch_error_passes_through_success() {
226        let router = McpRouter::new().server_info("test", "1.0.0");
227        let mut service = CatchError::new(router);
228
229        let req = RouterRequest {
230            id: RequestId::Number(1),
231            inner: crate::protocol::McpRequest::Ping,
232            extensions: crate::router::Extensions::new(),
233        };
234
235        let result = Service::call(&mut service, req).await;
236        assert!(result.is_ok());
237        let response = result.unwrap();
238        assert!(response.inner.is_ok());
239    }
240
241    #[test]
242    fn test_catch_error_clone() {
243        let router = McpRouter::new().server_info("test", "1.0.0");
244        let service = CatchError::new(router);
245        let _clone = service.clone();
246    }
247
248    #[test]
249    fn test_catch_error_debug() {
250        let router = McpRouter::new().server_info("test", "1.0.0");
251        let service = CatchError::new(router);
252        let debug = format!("{:?}", service);
253        assert!(debug.contains("CatchError"));
254    }
255
256    #[tokio::test]
257    async fn test_inject_annotations_for_call_tool() {
258        use crate::{CallToolResult, ToolBuilder};
259
260        let tool = ToolBuilder::new("read_data")
261            .description("Read some data")
262            .annotations(ToolAnnotations {
263                read_only_hint: true,
264                destructive_hint: false,
265                ..Default::default()
266            })
267            .handler(|_: serde_json::Value| async move { Ok(CallToolResult::text("ok")) })
268            .build();
269
270        let router = McpRouter::new().server_info("test", "1.0.0").tool(tool);
271        let annotations = router.tool_annotations_map();
272        let mut service = InjectAnnotations::new(router, annotations);
273
274        let req = RouterRequest {
275            id: RequestId::Number(1),
276            inner: McpRequest::CallTool(CallToolParams {
277                name: "read_data".to_string(),
278                arguments: serde_json::json!({}),
279                meta: None,
280                task: None,
281            }),
282            extensions: crate::router::Extensions::new(),
283        };
284
285        // Verify the service processes the request (we can't inspect extensions
286        // after call, but we test the map is built correctly below)
287        let result = Service::call(&mut service, req).await;
288        assert!(result.is_ok());
289    }
290
291    #[tokio::test]
292    async fn test_inject_annotations_skips_non_call_tool() {
293        let router = McpRouter::new().server_info("test", "1.0.0");
294        let annotations = router.tool_annotations_map();
295        let mut service = InjectAnnotations::new(router, annotations);
296
297        let req = RouterRequest {
298            id: RequestId::Number(1),
299            inner: McpRequest::Ping,
300            extensions: crate::router::Extensions::new(),
301        };
302
303        let result = Service::call(&mut service, req).await;
304        assert!(result.is_ok());
305    }
306
307    #[test]
308    fn test_tool_annotations_map_methods() {
309        use crate::ToolBuilder;
310
311        let read_tool = ToolBuilder::new("reader")
312            .description("Read-only tool")
313            .annotations(ToolAnnotations {
314                read_only_hint: true,
315                destructive_hint: false,
316                idempotent_hint: true,
317                ..Default::default()
318            })
319            .handler(|_: serde_json::Value| async move { Ok(CallToolResult::text("ok")) })
320            .build();
321
322        let write_tool = ToolBuilder::new("writer")
323            .description("Destructive tool")
324            .annotations(ToolAnnotations {
325                read_only_hint: false,
326                destructive_hint: true,
327                idempotent_hint: false,
328                ..Default::default()
329            })
330            .handler(|_: serde_json::Value| async move { Ok(CallToolResult::text("ok")) })
331            .build();
332
333        let plain_tool = ToolBuilder::new("plain")
334            .description("No annotations")
335            .handler(|_: serde_json::Value| async move { Ok(CallToolResult::text("ok")) })
336            .build();
337
338        let router = McpRouter::new()
339            .server_info("test", "1.0.0")
340            .tool(read_tool)
341            .tool(write_tool)
342            .tool(plain_tool);
343
344        let map = router.tool_annotations_map();
345
346        // read-only tool
347        assert!(map.is_read_only("reader"));
348        assert!(!map.is_destructive("reader"));
349        assert!(map.is_idempotent("reader"));
350
351        // destructive tool
352        assert!(!map.is_read_only("writer"));
353        assert!(map.is_destructive("writer"));
354        assert!(!map.is_idempotent("writer"));
355
356        // tool without annotations: not in map, defaults apply
357        assert!(!map.is_read_only("plain"));
358        assert!(map.is_destructive("plain")); // default is true
359        assert!(!map.is_idempotent("plain"));
360
361        // nonexistent tool: same defaults as no annotations
362        assert!(!map.is_read_only("nonexistent"));
363        assert!(map.is_destructive("nonexistent"));
364        assert!(!map.is_idempotent("nonexistent"));
365
366        // get() returns None for plain and nonexistent
367        assert!(map.get("reader").is_some());
368        assert!(map.get("writer").is_some());
369        assert!(map.get("plain").is_none());
370        assert!(map.get("nonexistent").is_none());
371    }
372
373    #[tokio::test]
374    async fn test_annotations_visible_in_middleware() {
375        use crate::ToolBuilder;
376        use crate::router::ToolAnnotationsMap;
377        use std::sync::atomic::{AtomicBool, Ordering};
378
379        // A minimal middleware that checks for annotations in extensions.
380        #[derive(Clone)]
381        struct CheckAnnotations<S> {
382            inner: S,
383            found: Arc<AtomicBool>,
384        }
385
386        impl<S> Service<RouterRequest> for CheckAnnotations<S>
387        where
388            S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>,
389        {
390            type Response = RouterResponse;
391            type Error = Infallible;
392            type Future = S::Future;
393
394            fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
395                self.inner.poll_ready(cx)
396            }
397
398            fn call(&mut self, req: RouterRequest) -> Self::Future {
399                if let Some(map) = req.extensions.get::<ToolAnnotationsMap>()
400                    && map.is_read_only("reader")
401                {
402                    self.found.store(true, Ordering::SeqCst);
403                }
404                self.inner.call(req)
405            }
406        }
407
408        let tool = ToolBuilder::new("reader")
409            .description("A read-only tool")
410            .annotations(ToolAnnotations {
411                read_only_hint: true,
412                ..Default::default()
413            })
414            .handler(|_: serde_json::Value| async move { Ok(CallToolResult::text("ok")) })
415            .build();
416
417        let router = McpRouter::new().server_info("test", "1.0.0").tool(tool);
418        let annotations = router.tool_annotations_map();
419        let found = Arc::new(AtomicBool::new(false));
420
421        // InjectAnnotations is outer (runs first, injects into extensions),
422        // then CheckAnnotations sees the enriched request.
423        let inner = CheckAnnotations {
424            inner: router,
425            found: found.clone(),
426        };
427        let mut service = InjectAnnotations::new(inner, annotations);
428
429        let req = RouterRequest {
430            id: RequestId::Number(1),
431            inner: McpRequest::CallTool(CallToolParams {
432                name: "reader".to_string(),
433                arguments: serde_json::json!({}),
434                meta: None,
435                task: None,
436            }),
437            extensions: crate::router::Extensions::new(),
438        };
439
440        let result = Service::call(&mut service, req).await;
441        assert!(result.is_ok());
442        assert!(
443            found.load(Ordering::SeqCst),
444            "Middleware should see annotations in extensions"
445        );
446    }
447}