// sacp_proxy/mcp_server_builder.rs

1use std::{marker::PhantomData, pin::pin, sync::Arc};
2
3use futures::future::Either;
4use fxhash::FxHashMap;
5use rmcp::{
6    ErrorData, ServerHandler,
7    handler::server::tool::cached_schema_for_type,
8    model::{CallToolResult, ListToolsResult, Tool},
9};
10use sacp::{BoxFuture, ByteStreams, Component};
11
12mod tool;
13use schemars::JsonSchema;
14use serde::{Serialize, de::DeserializeOwned};
15use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
16pub use tool::*;
17
18use crate::McpContext;
19
20/// Our MCP server implementation.
21#[derive(Clone, Default)]
22pub struct McpServer {
23    instructions: Option<String>,
24    tool_models: Vec<rmcp::model::Tool>,
25    tools: FxHashMap<String, Arc<dyn ErasedMcpTool>>,
26}
27
28impl McpServer {
29    /// Create an empty server with no content.
30    pub fn new() -> Self {
31        Self::default()
32    }
33
34    /// Set the server instructions that are provided to the client.
35    pub fn instructions(mut self, instructions: impl ToString) -> Self {
36        self.instructions = Some(instructions.to_string());
37        self
38    }
39
40    /// Add a tool to the server.
41    pub fn tool(mut self, tool: impl McpTool + 'static) -> Self {
42        let tool_model = make_tool_model(&tool);
43        self.tool_models.push(tool_model);
44        self.tools.insert(tool.name(), make_erased_mcp_tool(tool));
45        self
46    }
47
48    /// Convenience wrapper for defining a tool without having to create a struct.
49    ///
50    /// # Parameters
51    ///
52    /// * `name`: The name of the tool.
53    /// * `description`: The description of the tool.
54    /// * `func`: The function that implements the tool. Use an async closure like `async |args, cx| { .. }`.
55    /// * `to_future_hack`: A function that converts the tool function into a future.
56    ///   You should always write `|t, args, cx| Box::pin(t(args, cx))`.
57    ///   This is needed to sidestep current Rust language limitations.
58    ///
59    /// # Examples
60    ///
61    /// ```rust
62    /// ```
63    pub fn tool_fn<P, R, F, H>(
64        self,
65        name: impl ToString,
66        description: impl ToString,
67        func: F,
68        to_future_hack: H,
69    ) -> Self
70    where
71        P: JsonSchema + DeserializeOwned + 'static + Send,
72        R: JsonSchema + Serialize + 'static + Send,
73        F: AsyncFn(P, McpContext) -> Result<R, sacp::Error> + Send + Sync + 'static,
74        H: Fn(&F, P, McpContext) -> BoxFuture<'_, Result<R, sacp::Error>> + Send + Sync + 'static,
75    {
76        struct ToolFnTool<P, R, F, H> {
77            name: String,
78            description: String,
79            func: F,
80            to_future_hack: H,
81            phantom: PhantomData<fn(P) -> R>,
82        }
83
84        impl<P, R, F, H> McpTool for ToolFnTool<P, R, F, H>
85        where
86            P: JsonSchema + DeserializeOwned + 'static + Send,
87            R: JsonSchema + Serialize + 'static + Send,
88            F: AsyncFn(P, McpContext) -> Result<R, sacp::Error> + Send + Sync + 'static,
89            H: Fn(&F, P, McpContext) -> BoxFuture<'_, Result<R, sacp::Error>>
90                + Send
91                + Sync
92                + 'static,
93        {
94            type Input = P;
95            type Output = R;
96
97            fn name(&self) -> String {
98                self.name.clone()
99            }
100
101            fn description(&self) -> String {
102                self.description.clone()
103            }
104
105            async fn call_tool(&self, params: P, cx: McpContext) -> Result<R, sacp::Error> {
106                (self.to_future_hack)(&self.func, params, cx).await
107            }
108        }
109
110        self.tool(ToolFnTool {
111            name: name.to_string(),
112            description: description.to_string(),
113            func,
114            to_future_hack,
115            phantom: PhantomData::<fn(P) -> R>,
116        })
117    }
118
119    /// Create a connection to communicate with this server given the MCP context.
120    /// This is pub(crate) because it is only used internally by the MCP server registry.
121    pub(crate) fn new_connection(&self, mcp_cx: McpContext) -> McpServerConnection {
122        McpServerConnection {
123            service: self.clone(),
124            mcp_cx,
125        }
126    }
127}
128
129/// An MCP server instance connected to the ACP framework.
130pub(crate) struct McpServerConnection {
131    service: McpServer,
132    mcp_cx: McpContext,
133}
134
135impl Component for McpServerConnection {
136    async fn serve(self, client: impl Component) -> Result<(), sacp::Error> {
137        // Create tokio byte streams that rmcp expects
138        let (mcp_server_stream, mcp_client_stream) = tokio::io::duplex(8192);
139        let (mcp_server_read, mcp_server_write) = tokio::io::split(mcp_server_stream);
140        let (mcp_client_read, mcp_client_write) = tokio::io::split(mcp_client_stream);
141
142        // Create ByteStreams component for the client side
143        let byte_streams =
144            ByteStreams::new(mcp_client_write.compat_write(), mcp_client_read.compat());
145
146        // Spawn task to connect byte_streams to the provided client
147        tokio::spawn(async move {
148            let _ = byte_streams.serve(client).await;
149        });
150
151        // Run the rmcp server with the server side of the duplex stream
152        let running_server = rmcp::ServiceExt::serve(self, (mcp_server_read, mcp_server_write))
153            .await
154            .map_err(sacp::Error::into_internal_error)?;
155
156        // Wait for the server to finish
157        running_server
158            .waiting()
159            .await
160            .map(|_quit_reason| ())
161            .map_err(sacp::Error::into_internal_error)
162    }
163}
164
165impl ServerHandler for McpServerConnection {
166    async fn call_tool(
167        &self,
168        request: rmcp::model::CallToolRequestParam,
169        context: rmcp::service::RequestContext<rmcp::RoleServer>,
170    ) -> Result<CallToolResult, ErrorData> {
171        // Lookup the tool definition, erroring if not found
172        let Some(tool) = self.service.tools.get(&request.name[..]) else {
173            return Err(rmcp::model::ErrorData::invalid_params(
174                format!("tool `{}` not found", request.name),
175                None,
176            ));
177        };
178
179        // Convert input into JSON
180        let serde_value = serde_json::to_value(request.arguments).expect("valid json");
181
182        // Execute the user's tool, unless cancellation occurs
183        match futures::future::select(
184            tool.call_tool(serde_value, self.mcp_cx.clone()),
185            pin!(context.ct.cancelled()),
186        )
187        .await
188        {
189            // If completed successfully
190            Either::Left((m, _)) => match m {
191                Ok(result) => Ok(CallToolResult::structured(result)),
192                Err(error) => Err(to_rmcp_error(error)),
193            },
194
195            // If cancelled
196            Either::Right(((), _)) => {
197                Err(rmcp::ErrorData::internal_error("operation cancelled", None))
198            }
199        }
200    }
201
202    async fn list_tools(
203        &self,
204        _request: Option<rmcp::model::PaginatedRequestParam>,
205        _context: rmcp::service::RequestContext<rmcp::RoleServer>,
206    ) -> Result<rmcp::model::ListToolsResult, ErrorData> {
207        // Just return all tools
208        Ok(ListToolsResult::with_all_items(
209            self.service.tool_models.clone(),
210        ))
211    }
212
213    fn get_info(&self) -> rmcp::model::ServerInfo {
214        // Basic server info
215        rmcp::model::ServerInfo {
216            protocol_version: rmcp::model::ProtocolVersion::default(),
217            capabilities: rmcp::model::ServerCapabilities::builder()
218                .enable_tools()
219                .build(),
220            server_info: rmcp::model::Implementation::default(),
221            instructions: self.service.instructions.clone(),
222        }
223    }
224}
225
226/// Erased version of the MCP tool trait that is dyn-compatible.
227trait ErasedMcpTool: Send + Sync {
228    fn call_tool(
229        &self,
230        input: serde_json::Value,
231        context: McpContext,
232    ) -> BoxFuture<'_, Result<serde_json::Value, sacp::Error>>;
233}
234
235//// Create an `rmcp` tool mode from our [`McpTool`] trait.
236fn make_tool_model<M: McpTool>(tool: &M) -> Tool {
237    rmcp::model::Tool {
238        name: tool.name().into(),
239        title: tool.title(),
240        description: Some(tool.description().into()),
241        input_schema: cached_schema_for_type::<M::Input>(),
242        output_schema: Some(cached_schema_for_type::<M::Output>()),
243        annotations: None,
244        icons: None,
245    }
246}
247
248/// Create a [`ErasedMcpTool`] from a [`McpTool`], erasing the type details.
249fn make_erased_mcp_tool<'s, M: McpTool + 's>(tool: M) -> Arc<dyn ErasedMcpTool + 's> {
250    struct ErasedMcpToolImpl<M: McpTool> {
251        tool: M,
252    }
253
254    impl<M: McpTool> ErasedMcpTool for ErasedMcpToolImpl<M> {
255        fn call_tool(
256            &self,
257            input: serde_json::Value,
258            context: McpContext,
259        ) -> BoxFuture<'_, Result<serde_json::Value, sacp::Error>> {
260            Box::pin(async move {
261                let input = serde_json::from_value(input).map_err(sacp::util::internal_error)?;
262                serde_json::to_value(self.tool.call_tool(input, context).await?)
263                    .map_err(sacp::util::internal_error)
264            })
265        }
266    }
267
268    Arc::new(ErasedMcpToolImpl { tool })
269}
270
271/// Convert a [`sacp::Error`] into an [`rmcp::ErrorData`].
272fn to_rmcp_error(error: sacp::Error) -> rmcp::ErrorData {
273    rmcp::ErrorData {
274        code: rmcp::model::ErrorCode(error.code),
275        message: error.message.into(),
276        data: error.data,
277    }
278}