1use std::borrow::Cow;
7use std::{collections::HashMap, sync::Arc};
8
9use anyhow::{Context as _, Result};
10use async_trait::async_trait;
11use rmcp::RoleClient;
12use rmcp::model::{ClientInfo, Implementation, InitializeRequestParam};
13use rmcp::service::RunningService;
14use rmcp::transport::IntoTransport;
15use rmcp::{ServiceExt, model::CallToolRequestParam};
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use swiftide_core::CommandError;
19use swiftide_core::chat_completion::ToolCall;
20use swiftide_core::{
21 Tool, ToolBox,
22 chat_completion::{ParamSpec, ParamType, ToolSpec, errors::ToolError},
23};
24use tokio::sync::RwLock;
25
/// Restricts which of an MCP server's tools are exposed by [`McpToolbox`].
///
/// Names are compared by exact string equality against the tool names reported by the server.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum ToolFilter {
    /// Expose every tool except the listed ones.
    Blacklist(Vec<String>),
    /// Expose only the listed tools.
    Whitelist(Vec<String>),
}
32
/// A [`ToolBox`] backed by a client connection to an MCP server.
///
/// Cloning is cheap: all clones (and every tool handed out by `available_tools`) share the
/// same running client service.
#[derive(Clone)]
pub struct McpToolbox {
    /// The running MCP client service; becomes `None` once `cancel` has taken it out.
    service: Arc<RwLock<Option<RunningService<RoleClient, InitializeRequestParam>>>>,

    /// Optional display name; `name()` falls back to "MCP Toolbox" when unset.
    name: Option<String>,

    /// Optional filter applied to the server's tool list in `available_tools`.
    filter: Arc<Option<ToolFilter>>,
}
46
47impl McpToolbox {
48 pub fn with_blacklist<ITEM: Into<String>, I: IntoIterator<Item = ITEM>>(
50 &mut self,
51 blacklist: I,
52 ) -> &mut Self {
53 let list = blacklist.into_iter().map(Into::into).collect::<Vec<_>>();
54 self.filter = Some(ToolFilter::Blacklist(list)).into();
55 self
56 }
57
58 pub fn with_whitelist<ITEM: Into<String>, I: IntoIterator<Item = ITEM>>(
60 &mut self,
61 blacklist: I,
62 ) -> &mut Self {
63 let list = blacklist.into_iter().map(Into::into).collect::<Vec<_>>();
64 self.filter = Some(ToolFilter::Whitelist(list)).into();
65 self
66 }
67
68 pub fn with_filter(&mut self, filter: ToolFilter) -> &mut Self {
70 self.filter = Some(filter).into();
71 self
72 }
73
74 pub fn with_name(&mut self, name: impl Into<String>) -> &mut Self {
76 self.name = Some(name.into());
77 self
78 }
79
80 pub fn name(&self) -> &str {
81 self.name.as_deref().unwrap_or("MCP Toolbox")
82 }
83
84 pub async fn try_from_transport<
90 E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
91 A,
92 >(
93 transport: impl IntoTransport<RoleClient, E, A>,
94 ) -> Result<Self> {
95 let info = Self::default_client_info();
96 let service = Arc::new(RwLock::new(Some(info.serve(transport).await?)));
97
98 Ok(Self {
99 service,
100 filter: None.into(),
101 name: None,
102 })
103 }
104
105 pub fn from_running_service(
107 service: RunningService<RoleClient, InitializeRequestParam>,
108 ) -> Self {
109 Self {
110 service: Arc::new(RwLock::new(Some(service))),
111 filter: None.into(),
112 name: None,
113 }
114 }
115
116 fn default_client_info() -> ClientInfo {
117 ClientInfo {
118 client_info: Implementation {
119 name: "swiftide".into(),
120 version: env!("CARGO_PKG_VERSION").into(),
121 },
122 ..Default::default()
123 }
124 }
125
126 pub async fn cancel(&mut self) -> Result<()> {
134 let mut lock = self.service.write().await;
135 let Some(service) = std::mem::take(&mut *lock) else {
136 tracing::warn!("mcp server is not running");
137 return Ok(());
138 };
139
140 tracing::debug!(name = self.name(), "Stopping mcp server");
141
142 service
143 .cancel()
144 .await
145 .context("failed to stop mcp server")?;
146
147 Ok(())
148 }
149}
150
/// The subset of an MCP tool's JSON input schema needed to build a [`ToolSpec`].
#[derive(Deserialize, Debug)]
struct ToolInputSchema {
    /// Top-level schema `type`. It must be present as a string for deserialization to
    /// succeed, but its value is never read — hence the `dead_code` allowance.
    #[serde(rename = "type")]
    #[allow(dead_code)]
    pub type_: String,
    /// Schemas of the tool's parameters, keyed by parameter name; each value is inspected
    /// for its `description` and `type` when building parameter specs.
    pub properties: Option<HashMap<String, Value>>,
    /// Names of the parameters that are required.
    pub required: Option<Vec<String>>,
}
159
160#[async_trait]
161impl ToolBox for McpToolbox {
162 #[tracing::instrument(skip_all)]
163 async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
164 let Some(service) = &*self.service.read().await else {
165 anyhow::bail!("No service available");
166 };
167 tracing::debug!(name = self.name(), "Connecting to mcp server");
168 let peer_info = service.peer_info();
169 tracing::debug!(?peer_info, name = self.name(), "Connected to mcp server");
170
171 tracing::debug!(name = self.name(), "Listing tools from mcp server");
172 let tools = service
173 .list_all_tools()
174 .await
175 .context("Failed to list tools")?;
176
177 let tools = tools
178 .into_iter()
179 .map(|t| {
180 let schema: ToolInputSchema = serde_json::from_value(t.schema_as_json_value())
181 .context("Failed to parse tool input schema")?;
182
183 tracing::trace!(?schema, "Parsing tool input schema for {}", t.name);
184
185 let mut tool_spec = ToolSpec::builder()
186 .name(t.name.clone())
187 .description(t.description.unwrap_or_default())
188 .to_owned();
189 let mut parameters = Vec::new();
190
191 if let Some(mut p) = schema.properties {
192 for (name, value) in &mut p {
193 let param = ParamSpec::builder()
194 .name(name)
195 .description(
196 value
197 .get("description")
198 .and_then(Value::as_str)
199 .unwrap_or(""),
200 )
201 .ty(value
202 .get_mut("type")
203 .and_then(|t| serde_json::from_value(t.take()).ok())
204 .unwrap_or(ParamType::String))
205 .required(schema.required.as_ref().is_some_and(|r| r.contains(name)))
206 .build()
207 .context("Failed to build parameters for mcp tool")?;
208
209 parameters.push(param);
210 }
211 }
212
213 tool_spec.parameters(parameters);
214 let tool_spec = tool_spec.build().context("Failed to build tool spec")?;
215
216 Ok(Box::new(McpTool {
217 client: Arc::clone(&self.service),
218 tool_name: t.name.into(),
219 tool_spec,
220 }) as Box<dyn Tool>)
221 })
222 .collect::<Result<Vec<_>>>()
223 .context("Failed to build mcp tool specs")?;
224
225 if let Some(filter) = self.filter.as_ref() {
226 match filter {
227 ToolFilter::Blacklist(blacklist) => {
228 let blacklist = blacklist.iter().map(String::as_str).collect::<Vec<_>>();
229 Ok(tools
230 .into_iter()
231 .filter(|t| !blacklist.contains(&t.name().as_ref()))
232 .collect())
233 }
234 ToolFilter::Whitelist(whitelist) => {
235 let whitelist = whitelist.iter().map(String::as_str).collect::<Vec<_>>();
236 Ok(tools
237 .into_iter()
238 .filter(|t| whitelist.contains(&t.name().as_ref()))
239 .collect())
240 }
241 }
242 } else {
243 Ok(tools)
244 }
245 }
246
247 fn name(&self) -> Cow<'_, str> {
248 self.name().into()
249 }
250}
251
/// A single tool exposed by an MCP server, invoked through the shared client service.
#[derive(Clone)]
struct McpTool {
    /// Client service shared with the `McpToolbox` that created this tool; `None` once the
    /// toolbox has been cancelled.
    client: Arc<RwLock<Option<RunningService<RoleClient, InitializeRequestParam>>>>,
    /// Name of the tool on the MCP server; sent in every `call_tool` request.
    tool_name: String,
    /// Spec built from the tool's input schema.
    tool_spec: ToolSpec,
}
258
259#[async_trait]
260impl Tool for McpTool {
261 async fn invoke(
262 &self,
263 _agent_context: &dyn swiftide_core::AgentContext,
264 tool_call: &ToolCall,
265 ) -> Result<
266 swiftide_core::chat_completion::ToolOutput,
267 swiftide_core::chat_completion::errors::ToolError,
268 > {
269 let args = match tool_call.args() {
270 Some(args) => Some(serde_json::from_str(args).map_err(ToolError::WrongArguments)?),
271 None => None,
272 };
273
274 let request = CallToolRequestParam {
275 name: self.tool_name.clone().into(),
276 arguments: args,
277 };
278
279 let Some(service) = &*self.client.read().await else {
280 return Err(
281 CommandError::ExecutorError(anyhow::anyhow!("mcp server is not running")).into(),
282 );
283 };
284
285 tracing::debug!(request = ?request, tool = self.name().as_ref(), "Invoking mcp tool");
286 let response = service
287 .call_tool(request)
288 .await
289 .context("Failed to call tool")?;
290
291 tracing::debug!(response = ?response, tool = self.name().as_ref(), "Received response from mcp tool");
292 let Some(content) = response.content else {
293 if response.is_error.unwrap_or(false) {
294 return Err(ToolError::Unknown(anyhow::anyhow!(
295 "Error received from mcp tool without content"
296 )));
297 }
298
299 return Ok("Tool executed successfully".into());
300 };
301 let content = content
302 .into_iter()
303 .filter_map(|c| c.as_text().map(|t| t.text.to_string()))
304 .collect::<Vec<_>>()
305 .join("\n");
306
307 if let Some(error) = response.is_error
308 && error
309 {
310 return Err(ToolError::Unknown(anyhow::anyhow!(
311 "Failed to execute mcp tool: {content}"
312 )));
313 }
314
315 Ok(content.into())
316 }
317
318 fn name(&self) -> std::borrow::Cow<'_, str> {
319 self.tool_name.as_str().into()
320 }
321
322 fn tool_spec(&self) -> ToolSpec {
323 self.tool_spec.clone()
324 }
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use copied_from_rmcp::Calculator;
    use rmcp::serve_server;
    use serde_json::json;
    use tokio::net::{UnixListener, UnixStream};

    /// Per-process socket path, so concurrent test processes (e.g. different feature sets)
    /// do not remove or bind over each other's socket.
    fn socket_path() -> String {
        format!("/tmp/swiftide-mcp-{}.sock", std::process::id())
    }

    #[allow(clippy::similar_names)]
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_socket() {
        let socket_path = socket_path();
        let _ = std::fs::remove_file(&socket_path);

        match UnixListener::bind(&socket_path) {
            Ok(unix_listener) => {
                println!("Server successfully listening on {socket_path}");
                tokio::spawn(server(unix_listener));
            }
            Err(e) => {
                println!("Unable to bind to {socket_path}: {e}");
            }
        }

        let client = client(&socket_path).await.unwrap();

        let t = client.available_tools().await.unwrap();
        assert_eq!(t.len(), 3);

        let mut names = t.iter().map(|t| t.name()).collect::<Vec<_>>();
        names.sort();
        assert_eq!(names, ["optional", "sub", "sum"]);

        // Shared builder; each call below only swaps the arguments.
        let mut builder = ToolCall::builder().id("some").name("test").to_owned();

        let sum_tool = t.iter().find(|t| t.name() == "sum").unwrap();
        assert_eq!(sum_tool.tool_spec().name, "sum");

        let tool_call = builder.args(r#"{"a": 10, "b": 20}"#).build().unwrap();

        let result = sum_tool
            .invoke(&(), &tool_call)
            .await
            .unwrap()
            .content()
            .unwrap()
            .to_string();
        assert_eq!(result, "30");

        let sub_tool = t.iter().find(|t| t.name() == "sub").unwrap();
        assert_eq!(sub_tool.tool_spec().name, "sub");

        let tool_call = builder.args(r#"{"a": 10, "b": 20}"#).build().unwrap();

        let result = sub_tool
            .invoke(&(), &tool_call)
            .await
            .unwrap()
            .content()
            .unwrap()
            .to_string();
        assert_eq!(result, "-10");

        let optional_tool = t.iter().find(|t| t.name() == "optional").unwrap();
        assert_eq!(optional_tool.tool_spec().name, "optional");
        assert_eq!(optional_tool.tool_spec().parameters.len(), 1);
        assert_eq!(
            serde_json::to_string(&optional_tool.tool_spec().parameters[0].ty).unwrap(),
            json!("string").to_string()
        );

        let tool_call = builder.args(r#"{"text": "hello"}"#).build().unwrap();

        let result = optional_tool
            .invoke(&(), &tool_call)
            .await
            .unwrap()
            .content()
            .unwrap()
            .to_string();
        assert_eq!(result, "hello");

        let tool_call = builder.args(r#"{"text": null}"#).build().unwrap();
        let result = optional_tool
            .invoke(&(), &tool_call)
            .await
            .unwrap()
            .content()
            .unwrap()
            .to_string();
        assert_eq!(result, "");

        let _ = std::fs::remove_file(&socket_path);
    }

    /// Accepts connections forever, serving a fresh `Calculator` MCP server on each.
    async fn server(unix_listener: UnixListener) -> anyhow::Result<()> {
        while let Ok((stream, addr)) = unix_listener.accept().await {
            println!("Client connected: {addr:?}");
            tokio::spawn(async move {
                match serve_server(Calculator::new(), stream).await {
                    Ok(server) => {
                        println!("Server initialized successfully");
                        if let Err(e) = server.waiting().await {
                            println!("Error while server waiting: {e:?}");
                        }
                    }
                    Err(e) => println!("Server initialization failed: {e:?}"),
                }

                anyhow::Ok(())
            });
        }
        Ok(())
    }

    /// Connects an `McpToolbox` to the test server listening on `socket_path`.
    async fn client(socket_path: &str) -> anyhow::Result<McpToolbox> {
        println!("Client connecting to {socket_path}");
        let stream = UnixStream::connect(socket_path).await?;

        let client = McpToolbox::try_from_transport(stream).await?;
        println!("Client connected and initialized successfully");

        Ok(client)
    }

    #[allow(clippy::unused_self)]
    mod copied_from_rmcp {
        use rmcp::{
            ErrorData as McpError, ServerHandler,
            handler::server::tool::{Parameters, ToolRouter},
            model::{CallToolResult, Content, ServerCapabilities, ServerInfo},
            schemars, tool, tool_handler,
        };

        #[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
        pub struct Request {
            pub a: i32,
            pub b: i32,
        }

        #[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
        pub struct OptRequest {
            pub text: Option<String>,
        }

        #[derive(Debug, Clone)]
        pub struct Calculator {
            tool_router: ToolRouter<Self>,
        }

        #[rmcp::tool_router]
        impl Calculator {
            pub fn new() -> Self {
                Self {
                    tool_router: Self::tool_router(),
                }
            }

            #[allow(clippy::unnecessary_wraps)]
            #[tool(description = "Calculate the sum of two numbers")]
            fn sum(
                &self,
                Parameters(Request { a, b }): Parameters<Request>,
            ) -> Result<CallToolResult, McpError> {
                Ok(CallToolResult::success(vec![Content::text(
                    (a + b).to_string(),
                )]))
            }

            #[allow(clippy::unnecessary_wraps)]
            #[tool(description = "Calculate the difference of two numbers")]
            fn sub(
                &self,
                Parameters(Request { a, b }): Parameters<Request>,
            ) -> Result<CallToolResult, McpError> {
                Ok(CallToolResult::success(vec![Content::text(
                    (a - b).to_string(),
                )]))
            }

            #[allow(clippy::unnecessary_wraps)]
            #[tool(description = "Optional echo")]
            fn optional(
                &self,
                Parameters(OptRequest { text }): Parameters<OptRequest>,
            ) -> Result<CallToolResult, McpError> {
                Ok(CallToolResult::success(vec![Content::text(
                    text.unwrap_or_default(),
                )]))
            }
        }

        #[tool_handler]
        impl ServerHandler for Calculator {
            fn get_info(&self) -> ServerInfo {
                ServerInfo {
                    instructions: Some("A simple calculator".into()),
                    capabilities: ServerCapabilities::builder().enable_tools().build(),
                    ..Default::default()
                }
            }
        }
    }
}