1use std::{marker::PhantomData, pin::pin, sync::Arc};
2
3use futures::future::Either;
4use fxhash::FxHashMap;
5use rmcp::{
6 ErrorData, ServerHandler,
7 handler::server::tool::cached_schema_for_type,
8 model::{CallToolResult, ListToolsResult, Tool},
9};
10use sacp::{BoxFuture, ByteStreams, Component};
11
12mod tool;
13use schemars::JsonSchema;
14use serde::{Serialize, de::DeserializeOwned};
15use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
16pub use tool::*;
17
18use crate::McpContext;
19
20#[derive(Clone, Default)]
22pub struct McpServer {
23 instructions: Option<String>,
24 tool_models: Vec<rmcp::model::Tool>,
25 tools: FxHashMap<String, Arc<dyn ErasedMcpTool>>,
26}
27
28impl McpServer {
29 pub fn new() -> Self {
31 Self::default()
32 }
33
34 pub fn instructions(mut self, instructions: impl ToString) -> Self {
36 self.instructions = Some(instructions.to_string());
37 self
38 }
39
40 pub fn tool(mut self, tool: impl McpTool + 'static) -> Self {
42 let tool_model = make_tool_model(&tool);
43 self.tool_models.push(tool_model);
44 self.tools.insert(tool.name(), make_erased_mcp_tool(tool));
45 self
46 }
47
48 pub fn tool_fn<P, R, F, H>(
64 self,
65 name: impl ToString,
66 description: impl ToString,
67 func: F,
68 to_future_hack: H,
69 ) -> Self
70 where
71 P: JsonSchema + DeserializeOwned + 'static + Send,
72 R: JsonSchema + Serialize + 'static + Send,
73 F: AsyncFn(P, McpContext) -> Result<R, sacp::Error> + Send + Sync + 'static,
74 H: Fn(&F, P, McpContext) -> BoxFuture<'_, Result<R, sacp::Error>> + Send + Sync + 'static,
75 {
76 struct ToolFnTool<P, R, F, H> {
77 name: String,
78 description: String,
79 func: F,
80 to_future_hack: H,
81 phantom: PhantomData<fn(P) -> R>,
82 }
83
84 impl<P, R, F, H> McpTool for ToolFnTool<P, R, F, H>
85 where
86 P: JsonSchema + DeserializeOwned + 'static + Send,
87 R: JsonSchema + Serialize + 'static + Send,
88 F: AsyncFn(P, McpContext) -> Result<R, sacp::Error> + Send + Sync + 'static,
89 H: Fn(&F, P, McpContext) -> BoxFuture<'_, Result<R, sacp::Error>>
90 + Send
91 + Sync
92 + 'static,
93 {
94 type Input = P;
95 type Output = R;
96
97 fn name(&self) -> String {
98 self.name.clone()
99 }
100
101 fn description(&self) -> String {
102 self.description.clone()
103 }
104
105 async fn call_tool(&self, params: P, cx: McpContext) -> Result<R, sacp::Error> {
106 (self.to_future_hack)(&self.func, params, cx).await
107 }
108 }
109
110 self.tool(ToolFnTool {
111 name: name.to_string(),
112 description: description.to_string(),
113 func,
114 to_future_hack,
115 phantom: PhantomData::<fn(P) -> R>,
116 })
117 }
118
119 pub(crate) fn new_connection(&self, mcp_cx: McpContext) -> McpServerConnection {
122 McpServerConnection {
123 service: self.clone(),
124 mcp_cx,
125 }
126 }
127}
128
/// Per-connection MCP handler.
///
/// Pairs a snapshot of an [`McpServer`] with the [`McpContext`] that is cloned into every
/// tool invocation on this connection.
pub(crate) struct McpServerConnection {
    /// Instructions, tool models and tool implementations served on this connection.
    service: McpServer,
    /// Context handed (by clone) to each tool call.
    mcp_cx: McpContext,
}
134
135impl Component for McpServerConnection {
136 async fn serve(self, client: impl Component) -> Result<(), sacp::Error> {
137 let (mcp_server_stream, mcp_client_stream) = tokio::io::duplex(8192);
139 let (mcp_server_read, mcp_server_write) = tokio::io::split(mcp_server_stream);
140 let (mcp_client_read, mcp_client_write) = tokio::io::split(mcp_client_stream);
141
142 let byte_streams =
144 ByteStreams::new(mcp_client_write.compat_write(), mcp_client_read.compat());
145
146 tokio::spawn(async move {
148 let _ = byte_streams.serve(client).await;
149 });
150
151 let running_server = rmcp::ServiceExt::serve(self, (mcp_server_read, mcp_server_write))
153 .await
154 .map_err(sacp::Error::into_internal_error)?;
155
156 running_server
158 .waiting()
159 .await
160 .map(|_quit_reason| ())
161 .map_err(sacp::Error::into_internal_error)
162 }
163}
164
165impl ServerHandler for McpServerConnection {
166 async fn call_tool(
167 &self,
168 request: rmcp::model::CallToolRequestParam,
169 context: rmcp::service::RequestContext<rmcp::RoleServer>,
170 ) -> Result<CallToolResult, ErrorData> {
171 let Some(tool) = self.service.tools.get(&request.name[..]) else {
173 return Err(rmcp::model::ErrorData::invalid_params(
174 format!("tool `{}` not found", request.name),
175 None,
176 ));
177 };
178
179 let serde_value = serde_json::to_value(request.arguments).expect("valid json");
181
182 match futures::future::select(
184 tool.call_tool(serde_value, self.mcp_cx.clone()),
185 pin!(context.ct.cancelled()),
186 )
187 .await
188 {
189 Either::Left((m, _)) => match m {
191 Ok(result) => Ok(CallToolResult::structured(result)),
192 Err(error) => Err(to_rmcp_error(error)),
193 },
194
195 Either::Right(((), _)) => {
197 Err(rmcp::ErrorData::internal_error("operation cancelled", None))
198 }
199 }
200 }
201
202 async fn list_tools(
203 &self,
204 _request: Option<rmcp::model::PaginatedRequestParam>,
205 _context: rmcp::service::RequestContext<rmcp::RoleServer>,
206 ) -> Result<rmcp::model::ListToolsResult, ErrorData> {
207 Ok(ListToolsResult::with_all_items(
209 self.service.tool_models.clone(),
210 ))
211 }
212
213 fn get_info(&self) -> rmcp::model::ServerInfo {
214 rmcp::model::ServerInfo {
216 protocol_version: rmcp::model::ProtocolVersion::default(),
217 capabilities: rmcp::model::ServerCapabilities::builder()
218 .enable_tools()
219 .build(),
220 server_info: rmcp::model::Implementation::default(),
221 instructions: self.service.instructions.clone(),
222 }
223 }
224}
225
/// Object-safe form of [`McpTool`] that accepts and returns raw JSON.
///
/// This lets tools with different `Input`/`Output` types be stored together as
/// `Arc<dyn ErasedMcpTool>` in one map.
trait ErasedMcpTool: Send + Sync {
    /// Runs the tool with JSON `input` and the connection's `context`.
    ///
    /// Returns the tool's output as JSON, or the error that occurred.
    fn call_tool(
        &self,
        input: serde_json::Value,
        context: McpContext,
    ) -> BoxFuture<'_, Result<serde_json::Value, sacp::Error>>;
}
234
235fn make_tool_model<M: McpTool>(tool: &M) -> Tool {
237 rmcp::model::Tool {
238 name: tool.name().into(),
239 title: tool.title(),
240 description: Some(tool.description().into()),
241 input_schema: cached_schema_for_type::<M::Input>(),
242 output_schema: Some(cached_schema_for_type::<M::Output>()),
243 annotations: None,
244 icons: None,
245 }
246}
247
248fn make_erased_mcp_tool<'s, M: McpTool + 's>(tool: M) -> Arc<dyn ErasedMcpTool + 's> {
250 struct ErasedMcpToolImpl<M: McpTool> {
251 tool: M,
252 }
253
254 impl<M: McpTool> ErasedMcpTool for ErasedMcpToolImpl<M> {
255 fn call_tool(
256 &self,
257 input: serde_json::Value,
258 context: McpContext,
259 ) -> BoxFuture<'_, Result<serde_json::Value, sacp::Error>> {
260 Box::pin(async move {
261 let input = serde_json::from_value(input).map_err(sacp::util::internal_error)?;
262 serde_json::to_value(self.tool.call_tool(input, context).await?)
263 .map_err(sacp::util::internal_error)
264 })
265 }
266 }
267
268 Arc::new(ErasedMcpToolImpl { tool })
269}
270
271fn to_rmcp_error(error: sacp::Error) -> rmcp::ErrorData {
273 rmcp::ErrorData {
274 code: rmcp::model::ErrorCode(error.code),
275 message: error.message.into(),
276 data: error.data,
277 }
278}