swiftide_agents/tools/mcp.rs

//! Add tools provided by an MCP server to an agent.
//!
//! Uses the `rmcp` crate to connect to an MCP server, list the tools it exposes, and invoke them.
//!
//! Supports any transport that the `rmcp` crate supports.
6use std::borrow::Cow;
7use std::{collections::HashMap, sync::Arc};
8
9use anyhow::{Context as _, Result};
10use async_trait::async_trait;
11use rmcp::RoleClient;
12use rmcp::model::{ClientInfo, Implementation, InitializeRequestParam};
13use rmcp::service::RunningService;
14use rmcp::transport::IntoTransport;
15use rmcp::{ServiceExt, model::CallToolRequestParam};
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use swiftide_core::CommandError;
19use swiftide_core::chat_completion::ToolCall;
20use swiftide_core::{
21    Tool, ToolBox,
22    chat_completion::{ParamSpec, ParamType, ToolSpec, errors::ToolError},
23};
24use tokio::sync::RwLock;
25
26/// A filter to apply to the available tools
27#[derive(Clone, Debug, Serialize, Deserialize)]
28pub enum ToolFilter {
29    Blacklist(Vec<String>),
30    Whitelist(Vec<String>),
31}
32
33/// Connects to an MCP server and provides tools at runtime to the agent.
34///
35/// WARN: The rmcp has a quirky feature to serve from `()`. This does not work; serve from
36/// `ClientInfo` instead, or from the transport and `Swiftide` will handle the rest.
37#[derive(Clone)]
38pub struct McpToolbox {
39    service: Arc<RwLock<Option<RunningService<RoleClient, InitializeRequestParam>>>>,
40
41    /// Optional human readable name for the toolbox
42    name: Option<String>,
43
44    filter: Arc<Option<ToolFilter>>,
45}
46
47impl McpToolbox {
48    /// Blacklist tools by name, the agent will not be able to use these tools
49    pub fn with_blacklist<ITEM: Into<String>, I: IntoIterator<Item = ITEM>>(
50        &mut self,
51        blacklist: I,
52    ) -> &mut Self {
53        let list = blacklist.into_iter().map(Into::into).collect::<Vec<_>>();
54        self.filter = Some(ToolFilter::Blacklist(list)).into();
55        self
56    }
57
58    /// Whitelist tools by name, the agent will only be able to use these tools
59    pub fn with_whitelist<ITEM: Into<String>, I: IntoIterator<Item = ITEM>>(
60        &mut self,
61        blacklist: I,
62    ) -> &mut Self {
63        let list = blacklist.into_iter().map(Into::into).collect::<Vec<_>>();
64        self.filter = Some(ToolFilter::Whitelist(list)).into();
65        self
66    }
67
68    /// Apply a custom filter to the tools
69    pub fn with_filter(&mut self, filter: ToolFilter) -> &mut Self {
70        self.filter = Some(filter).into();
71        self
72    }
73
74    /// Apply an optional name to the toolbox
75    pub fn with_name(&mut self, name: impl Into<String>) -> &mut Self {
76        self.name = Some(name.into());
77        self
78    }
79
80    pub fn name(&self) -> &str {
81        self.name.as_deref().unwrap_or("MCP Toolbox")
82    }
83
84    /// Create a new toolbox from a transport
85    ///
86    /// # Errors
87    ///
88    /// Errors if the transport fails to connect
89    pub async fn try_from_transport<
90        E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
91        A,
92    >(
93        transport: impl IntoTransport<RoleClient, E, A>,
94    ) -> Result<Self> {
95        let info = Self::default_client_info();
96        let service = Arc::new(RwLock::new(Some(info.serve(transport).await?)));
97
98        Ok(Self {
99            service,
100            filter: None.into(),
101            name: None,
102        })
103    }
104
105    /// Create a new toolbox from a running service
106    pub fn from_running_service(
107        service: RunningService<RoleClient, InitializeRequestParam>,
108    ) -> Self {
109        Self {
110            service: Arc::new(RwLock::new(Some(service))),
111            filter: None.into(),
112            name: None,
113        }
114    }
115
116    fn default_client_info() -> ClientInfo {
117        ClientInfo {
118            client_info: Implementation {
119                name: "swiftide".into(),
120                version: env!("CARGO_PKG_VERSION").into(),
121            },
122            ..Default::default()
123        }
124    }
125
126    /// Disconnects from the MCP server if it is running
127    ///
128    /// If it is not running, an Ok is returned and it logs a tracing message
129    ///
130    /// # Errors
131    ///
132    /// Errors if the service is running but cannot be stopped
133    pub async fn cancel(&mut self) -> Result<()> {
134        let mut lock = self.service.write().await;
135        let Some(service) = std::mem::take(&mut *lock) else {
136            tracing::warn!("mcp server is not running");
137            return Ok(());
138        };
139
140        tracing::debug!(name = self.name(), "Stopping mcp server");
141
142        service
143            .cancel()
144            .await
145            .context("failed to stop mcp server")?;
146
147        Ok(())
148    }
149}
150
151#[derive(Deserialize, Debug)]
152struct ToolInputSchema {
153    #[serde(rename = "type")]
154    #[allow(dead_code)]
155    pub type_: String, // This _must_ be object
156    pub properties: Option<HashMap<String, Value>>,
157    pub required: Option<Vec<String>>,
158}
159
160#[async_trait]
161impl ToolBox for McpToolbox {
162    #[tracing::instrument(skip_all)]
163    async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
164        let Some(service) = &*self.service.read().await else {
165            anyhow::bail!("No service available");
166        };
167        tracing::debug!(name = self.name(), "Connecting to mcp server");
168        let peer_info = service.peer_info();
169        tracing::debug!(?peer_info, name = self.name(), "Connected to mcp server");
170
171        tracing::debug!(name = self.name(), "Listing tools from mcp server");
172        let tools = service
173            .list_all_tools()
174            .await
175            .context("Failed to list tools")?;
176
177        let tools = tools
178            .into_iter()
179            .map(|t| {
180                let schema: ToolInputSchema = serde_json::from_value(t.schema_as_json_value())
181                    .context("Failed to parse tool input schema")?;
182
183                tracing::trace!(?schema, "Parsing tool input schema for {}", t.name);
184
185                let mut tool_spec = ToolSpec::builder()
186                    .name(t.name.clone())
187                    .description(t.description.unwrap_or_default())
188                    .to_owned();
189                let mut parameters = Vec::new();
190
191                if let Some(mut p) = schema.properties {
192                    for (name, value) in &mut p {
193                        let param = ParamSpec::builder()
194                            .name(name)
195                            .description(
196                                value
197                                    .get("description")
198                                    .and_then(Value::as_str)
199                                    .unwrap_or(""),
200                            )
201                            .ty(value
202                                .get_mut("type")
203                                .and_then(|t| serde_json::from_value(t.take()).ok())
204                                .unwrap_or(ParamType::String))
205                            .required(schema.required.as_ref().is_some_and(|r| r.contains(name)))
206                            .build()
207                            .context("Failed to build parameters for mcp tool")?;
208
209                        parameters.push(param);
210                    }
211                }
212
213                tool_spec.parameters(parameters);
214                let tool_spec = tool_spec.build().context("Failed to build tool spec")?;
215
216                Ok(Box::new(McpTool {
217                    client: Arc::clone(&self.service),
218                    tool_name: t.name.into(),
219                    tool_spec,
220                }) as Box<dyn Tool>)
221            })
222            .collect::<Result<Vec<_>>>()
223            .context("Failed to build mcp tool specs")?;
224
225        if let Some(filter) = self.filter.as_ref() {
226            match filter {
227                ToolFilter::Blacklist(blacklist) => {
228                    let blacklist = blacklist.iter().map(String::as_str).collect::<Vec<_>>();
229                    Ok(tools
230                        .into_iter()
231                        .filter(|t| !blacklist.contains(&t.name().as_ref()))
232                        .collect())
233                }
234                ToolFilter::Whitelist(whitelist) => {
235                    let whitelist = whitelist.iter().map(String::as_str).collect::<Vec<_>>();
236                    Ok(tools
237                        .into_iter()
238                        .filter(|t| whitelist.contains(&t.name().as_ref()))
239                        .collect())
240                }
241            }
242        } else {
243            Ok(tools)
244        }
245    }
246
247    fn name(&self) -> Cow<'_, str> {
248        self.name().into()
249    }
250}
251
252#[derive(Clone)]
253struct McpTool {
254    client: Arc<RwLock<Option<RunningService<RoleClient, InitializeRequestParam>>>>,
255    tool_name: String,
256    tool_spec: ToolSpec,
257}
258
259#[async_trait]
260impl Tool for McpTool {
261    async fn invoke(
262        &self,
263        _agent_context: &dyn swiftide_core::AgentContext,
264        tool_call: &ToolCall,
265    ) -> Result<
266        swiftide_core::chat_completion::ToolOutput,
267        swiftide_core::chat_completion::errors::ToolError,
268    > {
269        let args = match tool_call.args() {
270            Some(args) => Some(serde_json::from_str(args).map_err(ToolError::WrongArguments)?),
271            None => None,
272        };
273
274        let request = CallToolRequestParam {
275            name: self.tool_name.clone().into(),
276            arguments: args,
277        };
278
279        let Some(service) = &*self.client.read().await else {
280            return Err(
281                CommandError::ExecutorError(anyhow::anyhow!("mcp server is not running")).into(),
282            );
283        };
284
285        tracing::debug!(request = ?request, tool = self.name().as_ref(), "Invoking mcp tool");
286        let response = service
287            .call_tool(request)
288            .await
289            .context("Failed to call tool")?;
290
291        tracing::debug!(response = ?response, tool = self.name().as_ref(), "Received response from mcp tool");
292        let Some(content) = response.content else {
293            if response.is_error.unwrap_or(false) {
294                return Err(ToolError::Unknown(anyhow::anyhow!(
295                    "Error received from mcp tool without content"
296                )));
297            }
298
299            return Ok("Tool executed successfully".into());
300        };
301        let content = content
302            .into_iter()
303            .filter_map(|c| c.as_text().map(|t| t.text.to_string()))
304            .collect::<Vec<_>>()
305            .join("\n");
306
307        if let Some(error) = response.is_error
308            && error
309        {
310            return Err(ToolError::Unknown(anyhow::anyhow!(
311                "Failed to execute mcp tool: {content}"
312            )));
313        }
314
315        Ok(content.into())
316    }
317
318    fn name(&self) -> std::borrow::Cow<'_, str> {
319        self.tool_name.as_str().into()
320    }
321
322    fn tool_spec(&self) -> ToolSpec {
323        self.tool_spec.clone()
324    }
325}
326
// The test talks to an in-process server over a Unix domain socket, so it only builds on unix.
#[cfg(all(test, unix))]
mod tests {
    use super::*;
    use copied_from_rmcp::Calculator;
    use rmcp::serve_server;
    use serde_json::json;
    use std::path::{Path, PathBuf};
    use tokio::net::{UnixListener, UnixStream};

    /// Socket path unique to this process, so concurrent test runs do not clobber each other.
    fn socket_path() -> PathBuf {
        PathBuf::from(format!("/tmp/swiftide-mcp-{}.sock", std::process::id()))
    }

    /// Removes the socket file on drop, so it is cleaned up even when the test panics.
    struct SocketCleanup(PathBuf);

    impl Drop for SocketCleanup {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[allow(clippy::similar_names)]
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_socket() {
        let socket_path = socket_path();
        let _ = std::fs::remove_file(&socket_path);
        let _cleanup = SocketCleanup(socket_path.clone());

        let unix_listener = UnixListener::bind(&socket_path)
            .unwrap_or_else(|e| panic!("Unable to bind to {}: {e}", socket_path.display()));
        println!("Server successfully listening on {}", socket_path.display());
        tokio::spawn(server(unix_listener));

        let client = client(&socket_path).await.unwrap();

        let t = client.available_tools().await.unwrap();
        assert_eq!(t.len(), 3);

        let mut names = t.iter().map(|t| t.name()).collect::<Vec<_>>();
        names.sort();
        assert_eq!(names, ["optional", "sub", "sum"]);

        let mut builder = ToolCall::builder().id("some").name("test").to_owned();

        let sum_tool = t.iter().find(|t| t.name() == "sum").unwrap();
        assert_eq!(sum_tool.tool_spec().name, "sum");

        let tool_call = builder.args(r#"{"a": 10, "b": 20}"#).build().unwrap();
        let result = sum_tool
            .invoke(&(), &tool_call)
            .await
            .unwrap()
            .content()
            .unwrap()
            .to_string();
        assert_eq!(result, "30");

        let sub_tool = t.iter().find(|t| t.name() == "sub").unwrap();
        assert_eq!(sub_tool.tool_spec().name, "sub");

        let tool_call = builder.args(r#"{"a": 10, "b": 20}"#).build().unwrap();
        let result = sub_tool
            .invoke(&(), &tool_call)
            .await
            .unwrap()
            .content()
            .unwrap()
            .to_string();
        assert_eq!(result, "-10");

        // The input schema type for the input param is ["string", "null"]
        let optional_tool = t.iter().find(|t| t.name() == "optional").unwrap();
        assert_eq!(optional_tool.tool_spec().name, "optional");
        assert_eq!(optional_tool.tool_spec().parameters.len(), 1);
        assert_eq!(
            serde_json::to_string(&optional_tool.tool_spec().parameters[0].ty).unwrap(),
            json!("string").to_string()
        );

        let tool_call = builder.args(r#"{"text": "hello"}"#).build().unwrap();
        let result = optional_tool
            .invoke(&(), &tool_call)
            .await
            .unwrap()
            .content()
            .unwrap()
            .to_string();
        assert_eq!(result, "hello");

        let tool_call = builder.args(r#"{"text": null}"#).build().unwrap();
        let result = optional_tool
            .invoke(&(), &tool_call)
            .await
            .unwrap()
            .content()
            .unwrap()
            .to_string();
        assert_eq!(result, "");
    }

    /// Accepts connections forever, serving a fresh `Calculator` on each.
    async fn server(unix_listener: UnixListener) -> anyhow::Result<()> {
        while let Ok((stream, addr)) = unix_listener.accept().await {
            println!("Client connected: {addr:?}");
            tokio::spawn(async move {
                match serve_server(Calculator::new(), stream).await {
                    Ok(server) => {
                        println!("Server initialized successfully");
                        if let Err(e) = server.waiting().await {
                            println!("Error while server waiting: {e:?}");
                        }
                    }
                    Err(e) => println!("Server initialization failed: {e:?}"),
                }

                anyhow::Ok(())
            });
        }
        Ok(())
    }

    /// Connects a toolbox to the test server listening on `socket_path`.
    async fn client(socket_path: &Path) -> anyhow::Result<McpToolbox> {
        println!("Client connecting to {}", socket_path.display());
        let stream = UnixStream::connect(socket_path).await?;

        // Hand over the raw transport; serving a client from `()` does not work.
        let client = McpToolbox::try_from_transport(stream).await?;
        println!("Client connected and initialized successfully");

        Ok(client)
    }

    #[allow(clippy::unused_self)]
    mod copied_from_rmcp {
        use rmcp::{
            ErrorData as McpError, ServerHandler,
            handler::server::tool::{Parameters, ToolRouter},
            model::{CallToolResult, Content, ServerCapabilities, ServerInfo},
            schemars, tool, tool_handler,
        };

        #[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
        pub struct Request {
            pub a: i32,
            pub b: i32,
        }

        #[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
        pub struct OptRequest {
            pub text: Option<String>,
        }

        #[derive(Debug, Clone)]
        pub struct Calculator {
            tool_router: ToolRouter<Self>,
        }

        #[rmcp::tool_router]
        impl Calculator {
            pub fn new() -> Self {
                Self {
                    tool_router: Self::tool_router(),
                }
            }

            #[allow(clippy::unnecessary_wraps)]
            #[tool(description = "Calculate the sum of two numbers")]
            fn sum(
                &self,
                Parameters(Request { a, b }): Parameters<Request>,
            ) -> Result<CallToolResult, McpError> {
                Ok(CallToolResult::success(vec![Content::text(
                    (a + b).to_string(),
                )]))
            }

            #[allow(clippy::unnecessary_wraps)]
            #[tool(description = "Calculate the difference of two numbers")]
            fn sub(
                &self,
                Parameters(Request { a, b }): Parameters<Request>,
            ) -> Result<CallToolResult, McpError> {
                Ok(CallToolResult::success(vec![Content::text(
                    (a - b).to_string(),
                )]))
            }

            #[allow(clippy::unnecessary_wraps)]
            #[tool(description = "Optional echo")]
            fn optional(
                &self,
                Parameters(OptRequest { text }): Parameters<OptRequest>,
            ) -> Result<CallToolResult, McpError> {
                Ok(CallToolResult::success(vec![Content::text(
                    text.unwrap_or_default(),
                )]))
            }
        }

        #[tool_handler]
        impl ServerHandler for Calculator {
            fn get_info(&self) -> ServerInfo {
                ServerInfo {
                    instructions: Some("A simple calculator".into()),
                    capabilities: ServerCapabilities::builder().enable_tools().build(),
                    ..Default::default()
                }
            }
        }
    }
}
538        }
539    }
540}