1use std::convert::Infallible;
19use std::fmt;
20use std::future::Future;
21use std::pin::Pin;
22#[cfg(any(feature = "http", feature = "websocket"))]
23use std::sync::Arc;
24use std::task::{Context, Poll};
25
26use pin_project_lite::pin_project;
27
28use tower::util::BoxCloneService;
29use tower_service::Service;
30
31use crate::error::JsonRpcError;
32use crate::protocol::{McpRequest, RequestId};
33#[cfg(any(feature = "http", feature = "websocket"))]
34use crate::router::McpRouter;
35use crate::router::{RouterRequest, RouterResponse, ToolAnnotationsMap};
36
37pub type McpBoxService = BoxCloneService<RouterRequest, RouterResponse, Infallible>;
44
/// Shared, thread-safe factory that turns an [`McpRouter`] into a boxed [`McpBoxService`].
///
/// Only compiled when the `http` or `websocket` feature is enabled.
#[cfg(any(feature = "http", feature = "websocket"))]
pub(crate) type ServiceFactory = Arc<dyn Fn(McpRouter) -> McpBoxService + Sync + Send>;
53
/// Returns the default [`ServiceFactory`].
///
/// The produced service wraps the router in [`InjectAnnotations`] (fed with the router's own
/// tool annotation map) and boxes it; no other middleware is added.
#[cfg(any(feature = "http", feature = "websocket"))]
pub(crate) fn identity_factory() -> ServiceFactory {
    // A plain `fn` item implements `Fn + Send + Sync` and coerces into the boxed trait object.
    fn build(router: McpRouter) -> McpBoxService {
        let annotations = router.tool_annotations_map();
        let service = InjectAnnotations::new(router, annotations);
        BoxCloneService::new(service)
    }

    Arc::new(build)
}
65
66#[derive(Clone)]
73pub struct InjectAnnotations<S> {
74 inner: S,
75 annotations: ToolAnnotationsMap,
76}
77
78impl<S> InjectAnnotations<S> {
79 pub fn new(inner: S, annotations: ToolAnnotationsMap) -> Self {
81 Self { inner, annotations }
82 }
83}
84
85impl<S: fmt::Debug> fmt::Debug for InjectAnnotations<S> {
86 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87 f.debug_struct("InjectAnnotations")
88 .field("inner", &self.inner)
89 .finish()
90 }
91}
92
93impl<S> Service<RouterRequest> for InjectAnnotations<S>
94where
95 S: Service<RouterRequest, Response = RouterResponse>,
96{
97 type Response = RouterResponse;
98 type Error = S::Error;
99 type Future = S::Future;
100
101 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
102 self.inner.poll_ready(cx)
103 }
104
105 fn call(&mut self, mut req: RouterRequest) -> Self::Future {
106 if matches!(&req.inner, McpRequest::CallTool(_)) {
107 req.extensions.insert(self.annotations.clone());
108 }
109 self.inner.call(req)
110 }
111}
112
113pub struct CatchError<S> {
123 inner: S,
124}
125
126impl<S> CatchError<S> {
127 pub fn new(inner: S) -> Self {
129 Self { inner }
130 }
131}
132
133impl<S: Clone> Clone for CatchError<S> {
134 fn clone(&self) -> Self {
135 Self {
136 inner: self.inner.clone(),
137 }
138 }
139}
140
141impl<S: fmt::Debug> fmt::Debug for CatchError<S> {
142 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
143 f.debug_struct("CatchError")
144 .field("inner", &self.inner)
145 .finish()
146 }
147}
148
149pin_project! {
150 pub struct CatchErrorFuture<F> {
152 #[pin]
153 inner: F,
154 request_id: Option<RequestId>,
155 }
156}
157
158impl<F, E> Future for CatchErrorFuture<F>
159where
160 F: Future<Output = Result<RouterResponse, E>>,
161 E: fmt::Display,
162{
163 type Output = Result<RouterResponse, Infallible>;
164
165 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
166 let this = self.project();
167 match this.inner.poll(cx) {
168 Poll::Pending => Poll::Pending,
169 Poll::Ready(Ok(response)) => Poll::Ready(Ok(response)),
170 Poll::Ready(Err(err)) => {
171 let request_id = this.request_id.take().unwrap_or(RequestId::Number(0));
172 Poll::Ready(Ok(RouterResponse {
173 id: request_id,
174 inner: Err(JsonRpcError::internal_error(err.to_string())),
175 }))
176 }
177 }
178 }
179}
180
181impl<S> Service<RouterRequest> for CatchError<S>
182where
183 S: Service<RouterRequest, Response = RouterResponse> + Clone + Send + 'static,
184 S::Error: fmt::Display + Send,
185 S::Future: Send,
186{
187 type Response = RouterResponse;
188 type Error = Infallible;
189 type Future = CatchErrorFuture<S::Future>;
190
191 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
192 self.inner.poll_ready(cx).map_err(|_| unreachable!())
193 }
194
195 fn call(&mut self, req: RouterRequest) -> Self::Future {
196 let request_id = req.id.clone();
199 let fut = self.inner.call(req);
200
201 CatchErrorFuture {
202 inner: fut,
203 request_id: Some(request_id),
204 }
205 }
206}
207
208#[cfg(test)]
209mod tests {
210 use std::sync::Arc;
211
212 use super::*;
213 use crate::protocol::{CallToolParams, CallToolResult, RequestId, ToolAnnotations};
214 use crate::router::McpRouter;
215
216 #[test]
217 #[cfg(any(feature = "http", feature = "websocket"))]
218 fn test_identity_factory_produces_service() {
219 let router = McpRouter::new().server_info("test", "1.0.0");
220 let factory = identity_factory();
221 let _service = factory(router);
222 }
223
224 #[tokio::test]
225 async fn test_catch_error_passes_through_success() {
226 let router = McpRouter::new().server_info("test", "1.0.0");
227 let mut service = CatchError::new(router);
228
229 let req = RouterRequest {
230 id: RequestId::Number(1),
231 inner: crate::protocol::McpRequest::Ping,
232 extensions: crate::router::Extensions::new(),
233 };
234
235 let result = Service::call(&mut service, req).await;
236 assert!(result.is_ok());
237 let response = result.unwrap();
238 assert!(response.inner.is_ok());
239 }
240
241 #[test]
242 fn test_catch_error_clone() {
243 let router = McpRouter::new().server_info("test", "1.0.0");
244 let service = CatchError::new(router);
245 let _clone = service.clone();
246 }
247
248 #[test]
249 fn test_catch_error_debug() {
250 let router = McpRouter::new().server_info("test", "1.0.0");
251 let service = CatchError::new(router);
252 let debug = format!("{:?}", service);
253 assert!(debug.contains("CatchError"));
254 }
255
256 #[tokio::test]
257 async fn test_inject_annotations_for_call_tool() {
258 use crate::{CallToolResult, ToolBuilder};
259
260 let tool = ToolBuilder::new("read_data")
261 .description("Read some data")
262 .annotations(ToolAnnotations {
263 read_only_hint: true,
264 destructive_hint: false,
265 ..Default::default()
266 })
267 .handler(|_: serde_json::Value| async move { Ok(CallToolResult::text("ok")) })
268 .build();
269
270 let router = McpRouter::new().server_info("test", "1.0.0").tool(tool);
271 let annotations = router.tool_annotations_map();
272 let mut service = InjectAnnotations::new(router, annotations);
273
274 let req = RouterRequest {
275 id: RequestId::Number(1),
276 inner: McpRequest::CallTool(CallToolParams {
277 name: "read_data".to_string(),
278 arguments: serde_json::json!({}),
279 meta: None,
280 task: None,
281 }),
282 extensions: crate::router::Extensions::new(),
283 };
284
285 let result = Service::call(&mut service, req).await;
288 assert!(result.is_ok());
289 }
290
291 #[tokio::test]
292 async fn test_inject_annotations_skips_non_call_tool() {
293 let router = McpRouter::new().server_info("test", "1.0.0");
294 let annotations = router.tool_annotations_map();
295 let mut service = InjectAnnotations::new(router, annotations);
296
297 let req = RouterRequest {
298 id: RequestId::Number(1),
299 inner: McpRequest::Ping,
300 extensions: crate::router::Extensions::new(),
301 };
302
303 let result = Service::call(&mut service, req).await;
304 assert!(result.is_ok());
305 }
306
307 #[test]
308 fn test_tool_annotations_map_methods() {
309 use crate::ToolBuilder;
310
311 let read_tool = ToolBuilder::new("reader")
312 .description("Read-only tool")
313 .annotations(ToolAnnotations {
314 read_only_hint: true,
315 destructive_hint: false,
316 idempotent_hint: true,
317 ..Default::default()
318 })
319 .handler(|_: serde_json::Value| async move { Ok(CallToolResult::text("ok")) })
320 .build();
321
322 let write_tool = ToolBuilder::new("writer")
323 .description("Destructive tool")
324 .annotations(ToolAnnotations {
325 read_only_hint: false,
326 destructive_hint: true,
327 idempotent_hint: false,
328 ..Default::default()
329 })
330 .handler(|_: serde_json::Value| async move { Ok(CallToolResult::text("ok")) })
331 .build();
332
333 let plain_tool = ToolBuilder::new("plain")
334 .description("No annotations")
335 .handler(|_: serde_json::Value| async move { Ok(CallToolResult::text("ok")) })
336 .build();
337
338 let router = McpRouter::new()
339 .server_info("test", "1.0.0")
340 .tool(read_tool)
341 .tool(write_tool)
342 .tool(plain_tool);
343
344 let map = router.tool_annotations_map();
345
346 assert!(map.is_read_only("reader"));
348 assert!(!map.is_destructive("reader"));
349 assert!(map.is_idempotent("reader"));
350
351 assert!(!map.is_read_only("writer"));
353 assert!(map.is_destructive("writer"));
354 assert!(!map.is_idempotent("writer"));
355
356 assert!(!map.is_read_only("plain"));
358 assert!(map.is_destructive("plain")); assert!(!map.is_idempotent("plain"));
360
361 assert!(!map.is_read_only("nonexistent"));
363 assert!(map.is_destructive("nonexistent"));
364 assert!(!map.is_idempotent("nonexistent"));
365
366 assert!(map.get("reader").is_some());
368 assert!(map.get("writer").is_some());
369 assert!(map.get("plain").is_none());
370 assert!(map.get("nonexistent").is_none());
371 }
372
373 #[tokio::test]
374 async fn test_annotations_visible_in_middleware() {
375 use crate::ToolBuilder;
376 use crate::router::ToolAnnotationsMap;
377 use std::sync::atomic::{AtomicBool, Ordering};
378
379 #[derive(Clone)]
381 struct CheckAnnotations<S> {
382 inner: S,
383 found: Arc<AtomicBool>,
384 }
385
386 impl<S> Service<RouterRequest> for CheckAnnotations<S>
387 where
388 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>,
389 {
390 type Response = RouterResponse;
391 type Error = Infallible;
392 type Future = S::Future;
393
394 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
395 self.inner.poll_ready(cx)
396 }
397
398 fn call(&mut self, req: RouterRequest) -> Self::Future {
399 if let Some(map) = req.extensions.get::<ToolAnnotationsMap>()
400 && map.is_read_only("reader")
401 {
402 self.found.store(true, Ordering::SeqCst);
403 }
404 self.inner.call(req)
405 }
406 }
407
408 let tool = ToolBuilder::new("reader")
409 .description("A read-only tool")
410 .annotations(ToolAnnotations {
411 read_only_hint: true,
412 ..Default::default()
413 })
414 .handler(|_: serde_json::Value| async move { Ok(CallToolResult::text("ok")) })
415 .build();
416
417 let router = McpRouter::new().server_info("test", "1.0.0").tool(tool);
418 let annotations = router.tool_annotations_map();
419 let found = Arc::new(AtomicBool::new(false));
420
421 let inner = CheckAnnotations {
424 inner: router,
425 found: found.clone(),
426 };
427 let mut service = InjectAnnotations::new(inner, annotations);
428
429 let req = RouterRequest {
430 id: RequestId::Number(1),
431 inner: McpRequest::CallTool(CallToolParams {
432 name: "reader".to_string(),
433 arguments: serde_json::json!({}),
434 meta: None,
435 task: None,
436 }),
437 extensions: crate::router::Extensions::new(),
438 };
439
440 let result = Service::call(&mut service, req).await;
441 assert!(result.is_ok());
442 assert!(
443 found.load(Ordering::SeqCst),
444 "Middleware should see annotations in extensions"
445 );
446 }
447}