swiftide_agents/tools/
mcp.rs

//! Add tools provided by an MCP server to an agent.
//!
//! Uses the `rmcp` crate to connect to an MCP server, list its available tools, and invoke them.
//!
//! Supports any transport that the `rmcp` crate supports.
6use std::borrow::Cow;
7use std::sync::Arc;
8
9use anyhow::{Context as _, Result};
10use async_trait::async_trait;
11use rmcp::RoleClient;
12use rmcp::model::{ClientInfo, Implementation, InitializeRequestParam};
13use rmcp::service::RunningService;
14use rmcp::transport::IntoTransport;
15use rmcp::{ServiceExt, model::CallToolRequestParam};
16use schemars::Schema;
17use serde::{Deserialize, Serialize};
18use swiftide_core::CommandError;
19use swiftide_core::chat_completion::ToolCall;
20use swiftide_core::{
21    Tool, ToolBox,
22    chat_completion::{ToolSpec, errors::ToolError},
23};
24use tokio::sync::RwLock;
25
26/// A filter to apply to the available tools
27#[derive(Clone, Debug, Serialize, Deserialize)]
28pub enum ToolFilter {
29    Blacklist(Vec<String>),
30    Whitelist(Vec<String>),
31}
32
33/// Connects to an MCP server and provides tools at runtime to the agent.
34///
35/// WARN: The rmcp has a quirky feature to serve from `()`. This does not work; serve from
36/// `ClientInfo` instead, or from the transport and `Swiftide` will handle the rest.
37#[derive(Clone)]
38pub struct McpToolbox {
39    service: Arc<RwLock<Option<RunningService<RoleClient, InitializeRequestParam>>>>,
40
41    /// Optional human readable name for the toolbox
42    name: Option<String>,
43
44    filter: Arc<Option<ToolFilter>>,
45}
46
47impl McpToolbox {
48    /// Blacklist tools by name, the agent will not be able to use these tools
49    pub fn with_blacklist<ITEM: Into<String>, I: IntoIterator<Item = ITEM>>(
50        &mut self,
51        blacklist: I,
52    ) -> &mut Self {
53        let list = blacklist.into_iter().map(Into::into).collect::<Vec<_>>();
54        self.filter = Some(ToolFilter::Blacklist(list)).into();
55        self
56    }
57
58    /// Whitelist tools by name, the agent will only be able to use these tools
59    pub fn with_whitelist<ITEM: Into<String>, I: IntoIterator<Item = ITEM>>(
60        &mut self,
61        blacklist: I,
62    ) -> &mut Self {
63        let list = blacklist.into_iter().map(Into::into).collect::<Vec<_>>();
64        self.filter = Some(ToolFilter::Whitelist(list)).into();
65        self
66    }
67
68    /// Apply a custom filter to the tools
69    pub fn with_filter(&mut self, filter: ToolFilter) -> &mut Self {
70        self.filter = Some(filter).into();
71        self
72    }
73
74    /// Apply an optional name to the toolbox
75    pub fn with_name(&mut self, name: impl Into<String>) -> &mut Self {
76        self.name = Some(name.into());
77        self
78    }
79
80    pub fn name(&self) -> &str {
81        self.name.as_deref().unwrap_or("MCP Toolbox")
82    }
83
84    /// Create a new toolbox from a transport
85    ///
86    /// # Errors
87    ///
88    /// Errors if the transport fails to connect
89    pub async fn try_from_transport<
90        E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
91        A,
92    >(
93        transport: impl IntoTransport<RoleClient, E, A>,
94    ) -> Result<Self> {
95        let info = Self::default_client_info();
96        let service = Arc::new(RwLock::new(Some(info.serve(transport).await?)));
97
98        Ok(Self {
99            service,
100            filter: None.into(),
101            name: None,
102        })
103    }
104
105    /// Create a new toolbox from a running service
106    pub fn from_running_service(
107        service: RunningService<RoleClient, InitializeRequestParam>,
108    ) -> Self {
109        Self {
110            service: Arc::new(RwLock::new(Some(service))),
111            filter: None.into(),
112            name: None,
113        }
114    }
115
116    fn default_client_info() -> ClientInfo {
117        ClientInfo {
118            client_info: Implementation {
119                name: "swiftide".into(),
120                version: env!("CARGO_PKG_VERSION").into(),
121            },
122            ..Default::default()
123        }
124    }
125
126    /// Disconnects from the MCP server if it is running
127    ///
128    /// If it is not running, an Ok is returned and it logs a tracing message
129    ///
130    /// # Errors
131    ///
132    /// Errors if the service is running but cannot be stopped
133    pub async fn cancel(&mut self) -> Result<()> {
134        let mut lock = self.service.write().await;
135        let Some(service) = std::mem::take(&mut *lock) else {
136            tracing::warn!("mcp server is not running");
137            return Ok(());
138        };
139
140        tracing::debug!(name = self.name(), "Stopping mcp server");
141
142        service
143            .cancel()
144            .await
145            .context("failed to stop mcp server")?;
146
147        Ok(())
148    }
149}
150
151#[async_trait]
152impl ToolBox for McpToolbox {
153    #[tracing::instrument(skip_all)]
154    async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
155        let Some(service) = &*self.service.read().await else {
156            anyhow::bail!("No service available");
157        };
158        tracing::debug!(name = self.name(), "Connecting to mcp server");
159        let peer_info = service.peer_info();
160        tracing::debug!(?peer_info, name = self.name(), "Connected to mcp server");
161
162        tracing::debug!(name = self.name(), "Listing tools from mcp server");
163        let tools = service
164            .list_all_tools()
165            .await
166            .context("Failed to list tools")?;
167
168        let filter = self.filter.as_ref().clone();
169        let mut server_name = peer_info
170            .map_or("mcp", |info| info.server_info.name.as_str())
171            .trim()
172            .to_owned();
173        if server_name.is_empty() {
174            server_name = "mcp".into();
175        }
176
177        let tools = tools
178            .into_iter()
179            .filter(|tool| match &filter {
180                Some(ToolFilter::Blacklist(blacklist)) => {
181                    !blacklist.iter().any(|blocked| blocked == &tool.name)
182                }
183                Some(ToolFilter::Whitelist(whitelist)) => {
184                    whitelist.iter().any(|allowed| allowed == &tool.name)
185                }
186                None => true,
187            })
188            .map(|tool| {
189                let schema_value = tool.schema_as_json_value();
190                tracing::trace!(
191                    schema = ?schema_value,
192                    "Parsing tool input schema for {}",
193                    tool.name
194                );
195
196                let mut tool_spec_builder = ToolSpec::builder();
197                let registered_name = format!("{}:{}", server_name, tool.name);
198                tool_spec_builder.name(registered_name.clone());
199                tool_spec_builder.description(tool.description.unwrap_or_default());
200
201                match schema_value {
202                    serde_json::Value::Null => {}
203                    value => {
204                        let schema: Schema = serde_json::from_value(value)
205                            .context("Failed to parse tool input schema")?;
206                        tool_spec_builder.parameters_schema(schema);
207                    }
208                }
209
210                let tool_spec = tool_spec_builder
211                    .build()
212                    .context("Failed to build tool spec")?;
213                Ok(Box::new(McpTool {
214                    client: Arc::clone(&self.service),
215                    registered_name,
216                    server_tool_name: tool.name.into(),
217                    tool_spec,
218                }) as Box<dyn Tool>)
219            })
220            .collect::<Result<Vec<_>>>()
221            .context("Failed to build mcp tool specs")?;
222        Ok(tools)
223    }
224
225    fn name(&self) -> Cow<'_, str> {
226        self.name().into()
227    }
228}
229
230#[derive(Clone)]
231struct McpTool {
232    client: Arc<RwLock<Option<RunningService<RoleClient, InitializeRequestParam>>>>,
233    registered_name: String,
234    server_tool_name: String,
235    tool_spec: ToolSpec,
236}
237
238#[async_trait]
239impl Tool for McpTool {
240    async fn invoke(
241        &self,
242        _agent_context: &dyn swiftide_core::AgentContext,
243        tool_call: &ToolCall,
244    ) -> Result<
245        swiftide_core::chat_completion::ToolOutput,
246        swiftide_core::chat_completion::errors::ToolError,
247    > {
248        let args = match tool_call.args() {
249            Some(args) => Some(serde_json::from_str(args).map_err(ToolError::WrongArguments)?),
250            None => None,
251        };
252
253        let request = CallToolRequestParam {
254            name: self.server_tool_name.clone().into(),
255            arguments: args,
256        };
257
258        let Some(service) = &*self.client.read().await else {
259            return Err(
260                CommandError::ExecutorError(anyhow::anyhow!("mcp server is not running")).into(),
261            );
262        };
263
264        tracing::debug!(request = ?request, tool = self.name().as_ref(), "Invoking mcp tool");
265        let response = service
266            .call_tool(request)
267            .await
268            .context("Failed to call tool")?;
269
270        tracing::debug!(response = ?response, tool = self.name().as_ref(), "Received response from mcp tool");
271        let Some(content) = response.content else {
272            if response.is_error.unwrap_or(false) {
273                return Err(ToolError::Unknown(anyhow::anyhow!(
274                    "Error received from mcp tool without content"
275                )));
276            }
277
278            return Ok("Tool executed successfully".into());
279        };
280        let content = content
281            .into_iter()
282            .filter_map(|c| c.as_text().map(|t| t.text.clone()))
283            .collect::<Vec<_>>()
284            .join("\n");
285
286        if let Some(error) = response.is_error
287            && error
288        {
289            return Err(ToolError::Unknown(anyhow::anyhow!(
290                "Failed to execute mcp tool: {content}"
291            )));
292        }
293
294        Ok(content.into())
295    }
296
297    fn name(&self) -> std::borrow::Cow<'_, str> {
298        self.registered_name.as_str().into()
299    }
300
301    fn tool_spec(&self) -> ToolSpec {
302        self.tool_spec.clone()
303    }
304}
305
#[cfg(test)]
mod tests {
    use super::*;
    use copied_from_rmcp::Calculator;
    use rmcp::serve_server;
    use tokio::net::{UnixListener, UnixStream};

    const EXPECTED_PREFIX: &str = "rmcp";

    /// Socket path unique to this test process, so concurrent test runs do not remove or bind
    /// each other's socket.
    fn socket_path() -> String {
        format!("/tmp/swiftide-mcp-{}.sock", std::process::id())
    }

    // `sum_name`, `sub_name` etc. are intentionally similar.
    #[allow(clippy::similar_names)]
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_socket() {
        let socket_path = socket_path();
        let _ = std::fs::remove_file(&socket_path);

        let unix_listener = UnixListener::bind(&socket_path)
            .unwrap_or_else(|e| panic!("Unable to bind to {socket_path}: {e}"));
        println!("Server successfully listening on {socket_path}");
        tokio::spawn(server(unix_listener));

        let client = client(&socket_path).await.unwrap();

        let t = client.available_tools().await.unwrap();
        assert_eq!(t.len(), 3);
        // Listing again must yield the same tools.
        assert_eq!(client.available_tools().await.unwrap().len(), t.len());

        let mut names = t.iter().map(|t| t.name().into_owned()).collect::<Vec<_>>();
        names.sort();
        assert_eq!(
            names,
            [
                format!("{EXPECTED_PREFIX}:optional"),
                format!("{EXPECTED_PREFIX}:sub"),
                format!("{EXPECTED_PREFIX}:sum")
            ]
        );

        let sum_name = format!("{EXPECTED_PREFIX}:sum");
        let sum_tool = t.iter().find(|t| t.name().as_ref() == sum_name).unwrap();
        // Arguments are set per call below.
        let mut builder = ToolCall::builder().id("some").name("test").to_owned();

        assert_eq!(sum_tool.tool_spec().name, sum_name);

        let tool_call = builder.args(r#"{"a": 10, "b": 20}"#).build().unwrap();

        let result = sum_tool
            .invoke(&(), &tool_call)
            .await
            .unwrap()
            .content()
            .unwrap()
            .to_string();
        assert_eq!(result, "30");

        let sub_name = format!("{EXPECTED_PREFIX}:sub");
        let sub_tool = t.iter().find(|t| t.name().as_ref() == sub_name).unwrap();
        assert_eq!(sub_tool.tool_spec().name, sub_name);

        let tool_call = builder.args(r#"{"a": 10, "b": 20}"#).build().unwrap();

        let result = sub_tool
            .invoke(&(), &tool_call)
            .await
            .unwrap()
            .content()
            .unwrap()
            .to_string();
        assert_eq!(result, "-10");

        // The input schema type for the input param is string with null allowed
        let optional_name = format!("{EXPECTED_PREFIX}:optional");
        let optional_tool = t
            .iter()
            .find(|t| t.name().as_ref() == optional_name)
            .unwrap();
        assert_eq!(optional_tool.tool_spec().name, optional_name);
        let spec = optional_tool.tool_spec();
        let schema = spec
            .parameters_schema
            .expect("optional tool should expose a schema");
        let schema_json = serde_json::to_value(schema).unwrap();
        assert_eq!(
            schema_json
                .get("properties")
                .and_then(|props| props.get("text"))
                .and_then(|prop| prop.get("type"))
                .and_then(serde_json::Value::as_str),
            Some("string")
        );

        let tool_call = builder.args(r#"{"text": "hello"}"#).build().unwrap();

        let result = optional_tool
            .invoke(&(), &tool_call)
            .await
            .unwrap()
            .content()
            .unwrap()
            .to_string();
        assert_eq!(result, "hello");

        let tool_call = builder.args(r#"{"text": null}"#).build().unwrap();
        let result = optional_tool
            .invoke(&(), &tool_call)
            .await
            .unwrap()
            .content()
            .unwrap()
            .to_string();
        assert_eq!(result, "");

        // Clean up socket file
        let _ = std::fs::remove_file(&socket_path);
    }

    /// Accepts connections forever, serving a fresh `Calculator` MCP server on each.
    async fn server(unix_listener: UnixListener) -> anyhow::Result<()> {
        while let Ok((stream, addr)) = unix_listener.accept().await {
            println!("Client connected: {addr:?}");
            tokio::spawn(async move {
                match serve_server(Calculator::new(), stream).await {
                    Ok(server) => {
                        println!("Server initialized successfully");
                        if let Err(e) = server.waiting().await {
                            println!("Error while server waiting: {e:?}");
                        }
                    }
                    Err(e) => println!("Server initialization failed: {e:?}"),
                }

                anyhow::Ok(())
            });
        }
        Ok(())
    }

    /// Connects to the test server and wraps the connection in an `McpToolbox`.
    async fn client(socket_path: &str) -> anyhow::Result<McpToolbox> {
        println!("Client connecting to {socket_path}");
        let stream = UnixStream::connect(socket_path).await?;

        // Serve from the transport directly; serving from `()` does not work (see `McpToolbox`).
        let client = McpToolbox::try_from_transport(stream).await?;
        println!("Client connected and initialized successfully");

        Ok(client)
    }

    // The tool handlers take `&self` without using it.
    #[allow(clippy::unused_self)]
    mod copied_from_rmcp {
        use rmcp::{
            ErrorData as McpError, ServerHandler,
            handler::server::tool::{Parameters, ToolRouter},
            model::{CallToolResult, Content, ServerCapabilities, ServerInfo},
            schemars, tool, tool_handler,
        };

        #[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
        pub struct Request {
            pub a: i32,
            pub b: i32,
        }

        #[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
        pub struct OptRequest {
            pub text: Option<String>,
        }

        #[derive(Debug, Clone)]
        pub struct Calculator {
            tool_router: ToolRouter<Self>,
        }

        #[rmcp::tool_router]
        impl Calculator {
            pub fn new() -> Self {
                Self {
                    tool_router: Self::tool_router(),
                }
            }

            #[allow(clippy::unnecessary_wraps)]
            #[tool(description = "Calculate the sum of two numbers")]
            fn sum(
                &self,
                Parameters(Request { a, b }): Parameters<Request>,
            ) -> Result<CallToolResult, McpError> {
                Ok(CallToolResult::success(vec![Content::text(
                    (a + b).to_string(),
                )]))
            }

            #[allow(clippy::unnecessary_wraps)]
            #[tool(description = "Calculate the difference of two numbers")]
            fn sub(
                &self,
                Parameters(Request { a, b }): Parameters<Request>,
            ) -> Result<CallToolResult, McpError> {
                Ok(CallToolResult::success(vec![Content::text(
                    (a - b).to_string(),
                )]))
            }

            #[allow(clippy::unnecessary_wraps)]
            #[tool(description = "Optional echo")]
            fn optional(
                &self,
                Parameters(OptRequest { text }): Parameters<OptRequest>,
            ) -> Result<CallToolResult, McpError> {
                Ok(CallToolResult::success(vec![Content::text(
                    text.unwrap_or_default(),
                )]))
            }
        }

        #[tool_handler]
        impl ServerHandler for Calculator {
            fn get_info(&self) -> ServerInfo {
                ServerInfo {
                    instructions: Some("A simple calculator".into()),
                    capabilities: ServerCapabilities::builder().enable_tools().build(),
                    ..Default::default()
                }
            }
        }
    }
}