tower_mcp/transport/
service.rs1use std::convert::Infallible;
19use std::fmt;
20use std::future::Future;
21use std::pin::Pin;
22use std::sync::Arc;
23use std::task::{Context, Poll};
24
25use tower::util::BoxCloneService;
26use tower_service::Service;
27
28use crate::error::JsonRpcError;
29use crate::router::{McpRouter, RouterRequest, RouterResponse};
30
31pub type McpBoxService = BoxCloneService<RouterRequest, RouterResponse, Infallible>;
38
39pub type ServiceFactory = Arc<dyn Fn(McpRouter) -> McpBoxService + Send + Sync>;
46
47pub fn identity_factory() -> ServiceFactory {
51 Arc::new(|router: McpRouter| BoxCloneService::new(router))
52}
53
54pub struct CatchError<S> {
64 inner: S,
65}
66
67impl<S> CatchError<S> {
68 pub fn new(inner: S) -> Self {
70 Self { inner }
71 }
72}
73
74impl<S: Clone> Clone for CatchError<S> {
75 fn clone(&self) -> Self {
76 Self {
77 inner: self.inner.clone(),
78 }
79 }
80}
81
82impl<S: fmt::Debug> fmt::Debug for CatchError<S> {
83 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
84 f.debug_struct("CatchError")
85 .field("inner", &self.inner)
86 .finish()
87 }
88}
89
90impl<S> Service<RouterRequest> for CatchError<S>
91where
92 S: Service<RouterRequest, Response = RouterResponse> + Clone + Send + 'static,
93 S::Error: fmt::Display + Send,
94 S::Future: Send,
95{
96 type Response = RouterResponse;
97 type Error = Infallible;
98 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
99
100 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
101 self.inner.poll_ready(cx).map_err(|_| unreachable!())
102 }
103
104 fn call(&mut self, req: RouterRequest) -> Self::Future {
105 let request_id = req.id.clone();
108 let fut = self.inner.call(req);
109
110 Box::pin(async move {
111 match fut.await {
112 Ok(response) => Ok(response),
113 Err(err) => Ok(RouterResponse {
114 id: request_id,
115 inner: Err(JsonRpcError::internal_error(err.to_string())),
116 }),
117 }
118 })
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::RequestId;

    /// Minimal router shared by every test in this module.
    fn sample_router() -> McpRouter {
        McpRouter::new().server_info("test", "1.0.0")
    }

    #[test]
    fn test_identity_factory_produces_service() {
        let make_service = identity_factory();
        let _built = make_service(sample_router());
    }

    #[tokio::test]
    async fn test_catch_error_passes_through_success() {
        let mut wrapped = CatchError::new(sample_router());
        let ping = RouterRequest {
            id: RequestId::Number(1),
            inner: crate::protocol::McpRequest::Ping,
            extensions: crate::router::Extensions::new(),
        };

        // The error type is `Infallible`, so the `Err` arm can never be taken.
        let response = match Service::call(&mut wrapped, ping).await {
            Ok(response) => response,
            Err(never) => match never {},
        };
        assert!(response.inner.is_ok());
    }

    #[test]
    fn test_catch_error_clone() {
        let original = CatchError::new(sample_router());
        let _duplicate = Clone::clone(&original);
    }

    #[test]
    fn test_catch_error_debug() {
        let wrapped = CatchError::new(sample_router());
        let rendered = format!("{:?}", wrapped);
        assert!(rendered.contains("CatchError"));
    }
}