1use std::borrow::Cow;
7use std::sync::Arc;
8
9use anyhow::{Context as _, Result};
10use async_trait::async_trait;
11use rmcp::RoleClient;
12use rmcp::model::{ClientInfo, Implementation, InitializeRequestParam};
13use rmcp::service::RunningService;
14use rmcp::transport::IntoTransport;
15use rmcp::{ServiceExt, model::CallToolRequestParam};
16use schemars::Schema;
17use serde::{Deserialize, Serialize};
18use swiftide_core::CommandError;
19use swiftide_core::chat_completion::ToolCall;
20use swiftide_core::{
21 Tool, ToolBox,
22 chat_completion::{ToolSpec, errors::ToolError},
23};
24use tokio::sync::RwLock;
25
26#[derive(Clone, Debug, Serialize, Deserialize)]
28pub enum ToolFilter {
29 Blacklist(Vec<String>),
30 Whitelist(Vec<String>),
31}
32
33#[derive(Clone)]
38pub struct McpToolbox {
39 service: Arc<RwLock<Option<RunningService<RoleClient, InitializeRequestParam>>>>,
40
41 name: Option<String>,
43
44 filter: Arc<Option<ToolFilter>>,
45}
46
47impl McpToolbox {
48 pub fn with_blacklist<ITEM: Into<String>, I: IntoIterator<Item = ITEM>>(
50 &mut self,
51 blacklist: I,
52 ) -> &mut Self {
53 let list = blacklist.into_iter().map(Into::into).collect::<Vec<_>>();
54 self.filter = Some(ToolFilter::Blacklist(list)).into();
55 self
56 }
57
58 pub fn with_whitelist<ITEM: Into<String>, I: IntoIterator<Item = ITEM>>(
60 &mut self,
61 blacklist: I,
62 ) -> &mut Self {
63 let list = blacklist.into_iter().map(Into::into).collect::<Vec<_>>();
64 self.filter = Some(ToolFilter::Whitelist(list)).into();
65 self
66 }
67
68 pub fn with_filter(&mut self, filter: ToolFilter) -> &mut Self {
70 self.filter = Some(filter).into();
71 self
72 }
73
74 pub fn with_name(&mut self, name: impl Into<String>) -> &mut Self {
76 self.name = Some(name.into());
77 self
78 }
79
80 pub fn name(&self) -> &str {
81 self.name.as_deref().unwrap_or("MCP Toolbox")
82 }
83
84 pub async fn try_from_transport<
90 E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
91 A,
92 >(
93 transport: impl IntoTransport<RoleClient, E, A>,
94 ) -> Result<Self> {
95 let info = Self::default_client_info();
96 let service = Arc::new(RwLock::new(Some(info.serve(transport).await?)));
97
98 Ok(Self {
99 service,
100 filter: None.into(),
101 name: None,
102 })
103 }
104
105 pub fn from_running_service(
107 service: RunningService<RoleClient, InitializeRequestParam>,
108 ) -> Self {
109 Self {
110 service: Arc::new(RwLock::new(Some(service))),
111 filter: None.into(),
112 name: None,
113 }
114 }
115
116 fn default_client_info() -> ClientInfo {
117 ClientInfo {
118 client_info: Implementation {
119 name: "swiftide".into(),
120 version: env!("CARGO_PKG_VERSION").into(),
121 },
122 ..Default::default()
123 }
124 }
125
126 pub async fn cancel(&mut self) -> Result<()> {
134 let mut lock = self.service.write().await;
135 let Some(service) = std::mem::take(&mut *lock) else {
136 tracing::warn!("mcp server is not running");
137 return Ok(());
138 };
139
140 tracing::debug!(name = self.name(), "Stopping mcp server");
141
142 service
143 .cancel()
144 .await
145 .context("failed to stop mcp server")?;
146
147 Ok(())
148 }
149}
150
151#[async_trait]
152impl ToolBox for McpToolbox {
153 #[tracing::instrument(skip_all)]
154 async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
155 let Some(service) = &*self.service.read().await else {
156 anyhow::bail!("No service available");
157 };
158 tracing::debug!(name = self.name(), "Connecting to mcp server");
159 let peer_info = service.peer_info();
160 tracing::debug!(?peer_info, name = self.name(), "Connected to mcp server");
161
162 tracing::debug!(name = self.name(), "Listing tools from mcp server");
163 let tools = service
164 .list_all_tools()
165 .await
166 .context("Failed to list tools")?;
167
168 let filter = self.filter.as_ref().clone();
169 let mut server_name = peer_info
170 .map_or("mcp", |info| info.server_info.name.as_str())
171 .trim()
172 .to_owned();
173 if server_name.is_empty() {
174 server_name = "mcp".into();
175 }
176
177 let tools = tools
178 .into_iter()
179 .filter(|tool| match &filter {
180 Some(ToolFilter::Blacklist(blacklist)) => {
181 !blacklist.iter().any(|blocked| blocked == &tool.name)
182 }
183 Some(ToolFilter::Whitelist(whitelist)) => {
184 whitelist.iter().any(|allowed| allowed == &tool.name)
185 }
186 None => true,
187 })
188 .map(|tool| {
189 let schema_value = tool.schema_as_json_value();
190 tracing::trace!(
191 schema = ?schema_value,
192 "Parsing tool input schema for {}",
193 tool.name
194 );
195
196 let mut tool_spec_builder = ToolSpec::builder();
197 let registered_name = format!("{}:{}", server_name, tool.name);
198 tool_spec_builder.name(registered_name.clone());
199 tool_spec_builder.description(tool.description.unwrap_or_default());
200
201 match schema_value {
202 serde_json::Value::Null => {}
203 value => {
204 let schema: Schema = serde_json::from_value(value)
205 .context("Failed to parse tool input schema")?;
206 tool_spec_builder.parameters_schema(schema);
207 }
208 }
209
210 let tool_spec = tool_spec_builder
211 .build()
212 .context("Failed to build tool spec")?;
213 Ok(Box::new(McpTool {
214 client: Arc::clone(&self.service),
215 registered_name,
216 server_tool_name: tool.name.into(),
217 tool_spec,
218 }) as Box<dyn Tool>)
219 })
220 .collect::<Result<Vec<_>>>()
221 .context("Failed to build mcp tool specs")?;
222 Ok(tools)
223 }
224
225 fn name(&self) -> Cow<'_, str> {
226 self.name().into()
227 }
228}
229
230#[derive(Clone)]
231struct McpTool {
232 client: Arc<RwLock<Option<RunningService<RoleClient, InitializeRequestParam>>>>,
233 registered_name: String,
234 server_tool_name: String,
235 tool_spec: ToolSpec,
236}
237
238#[async_trait]
239impl Tool for McpTool {
240 async fn invoke(
241 &self,
242 _agent_context: &dyn swiftide_core::AgentContext,
243 tool_call: &ToolCall,
244 ) -> Result<
245 swiftide_core::chat_completion::ToolOutput,
246 swiftide_core::chat_completion::errors::ToolError,
247 > {
248 let args = match tool_call.args() {
249 Some(args) => Some(serde_json::from_str(args).map_err(ToolError::WrongArguments)?),
250 None => None,
251 };
252
253 let request = CallToolRequestParam {
254 name: self.server_tool_name.clone().into(),
255 arguments: args,
256 };
257
258 let Some(service) = &*self.client.read().await else {
259 return Err(
260 CommandError::ExecutorError(anyhow::anyhow!("mcp server is not running")).into(),
261 );
262 };
263
264 tracing::debug!(request = ?request, tool = self.name().as_ref(), "Invoking mcp tool");
265 let response = service
266 .call_tool(request)
267 .await
268 .context("Failed to call tool")?;
269
270 tracing::debug!(response = ?response, tool = self.name().as_ref(), "Received response from mcp tool");
271 let Some(content) = response.content else {
272 if response.is_error.unwrap_or(false) {
273 return Err(ToolError::Unknown(anyhow::anyhow!(
274 "Error received from mcp tool without content"
275 )));
276 }
277
278 return Ok("Tool executed successfully".into());
279 };
280 let content = content
281 .into_iter()
282 .filter_map(|c| c.as_text().map(|t| t.text.clone()))
283 .collect::<Vec<_>>()
284 .join("\n");
285
286 if let Some(error) = response.is_error
287 && error
288 {
289 return Err(ToolError::Unknown(anyhow::anyhow!(
290 "Failed to execute mcp tool: {content}"
291 )));
292 }
293
294 Ok(content.into())
295 }
296
297 fn name(&self) -> std::borrow::Cow<'_, str> {
298 self.registered_name.as_str().into()
299 }
300
301 fn tool_spec(&self) -> ToolSpec {
302 self.tool_spec.clone()
303 }
304}
305
306#[cfg(test)]
307mod tests {
308 use super::*;
309 use copied_from_rmcp::Calculator;
310 use rmcp::serve_server;
311 use tokio::net::{UnixListener, UnixStream};
312
313 const SOCKET_PATH: &str = "/tmp/swiftide-mcp.sock";
314 const EXPECTED_PREFIX: &str = "rmcp";
315
316 #[allow(clippy::similar_names)]
317 #[test_log::test(tokio::test(flavor = "multi_thread"))]
318 async fn test_socket() {
319 let _ = std::fs::remove_file(SOCKET_PATH);
320
321 match UnixListener::bind(SOCKET_PATH) {
322 Ok(unix_listener) => {
323 println!("Server successfully listening on {SOCKET_PATH}");
324 tokio::spawn(server(unix_listener));
325 }
326 Err(e) => {
327 println!("Unable to bind to {SOCKET_PATH}: {e}");
328 }
329 }
330
331 let client = client().await.unwrap();
332
333 let t = client.available_tools().await.unwrap();
334 assert_eq!(client.available_tools().await.unwrap().len(), 3);
335
336 let mut names = t.iter().map(|t| t.name().into_owned()).collect::<Vec<_>>();
337 names.sort();
338 assert_eq!(
339 names,
340 [
341 format!("{EXPECTED_PREFIX}:optional"),
342 format!("{EXPECTED_PREFIX}:sub"),
343 format!("{EXPECTED_PREFIX}:sum")
344 ]
345 );
346
347 let sum_name = format!("{EXPECTED_PREFIX}:sum");
348 let sum_tool = t.iter().find(|t| t.name().as_ref() == sum_name).unwrap();
349 let mut builder = ToolCall::builder()
350 .id("some")
351 .args(r#"{"b": "hello"}"#)
352 .name("test")
353 .name("test")
354 .to_owned();
355
356 assert_eq!(sum_tool.tool_spec().name, sum_name);
357
358 let tool_call = builder.args(r#"{"a": 10, "b": 20}"#).build().unwrap();
359
360 let result = sum_tool
361 .invoke(&(), &tool_call)
362 .await
363 .unwrap()
364 .content()
365 .unwrap()
366 .to_string();
367 assert_eq!(result, "30");
368
369 let sub_name = format!("{EXPECTED_PREFIX}:sub");
370 let sub_tool = t.iter().find(|t| t.name().as_ref() == sub_name).unwrap();
371 assert_eq!(sub_tool.tool_spec().name, sub_name);
372
373 let tool_call = builder.args(r#"{"a": 10, "b": 20}"#).build().unwrap();
374
375 let result = sub_tool
376 .invoke(&(), &tool_call)
377 .await
378 .unwrap()
379 .content()
380 .unwrap()
381 .to_string();
382 assert_eq!(result, "-10");
383
384 let optional_name = format!("{EXPECTED_PREFIX}:optional");
386 let optional_tool = t
387 .iter()
388 .find(|t| t.name().as_ref() == optional_name)
389 .unwrap();
390 assert_eq!(optional_tool.tool_spec().name, optional_name);
391 let spec = optional_tool.tool_spec();
392 let schema = spec
393 .parameters_schema
394 .expect("optional tool should expose a schema");
395 let schema_json = serde_json::to_value(schema).unwrap();
396 assert_eq!(
397 schema_json
398 .get("properties")
399 .and_then(|props| props.get("text"))
400 .and_then(|prop| prop.get("type"))
401 .and_then(serde_json::Value::as_str),
402 Some("string")
403 );
404
405 let tool_call = builder.args(r#"{"text": "hello"}"#).build().unwrap();
406
407 let result = optional_tool
408 .invoke(&(), &tool_call)
409 .await
410 .unwrap()
411 .content()
412 .unwrap()
413 .to_string();
414 assert_eq!(result, "hello");
415
416 let tool_call = builder.args(r#"{"text": null}"#).build().unwrap();
417 let result = optional_tool
418 .invoke(&(), &tool_call)
419 .await
420 .unwrap()
421 .content()
422 .unwrap()
423 .to_string();
424 assert_eq!(result, "");
425
426 let _ = std::fs::remove_file(SOCKET_PATH);
428 }
429
430 async fn server(unix_listener: UnixListener) -> anyhow::Result<()> {
431 while let Ok((stream, addr)) = unix_listener.accept().await {
432 println!("Client connected: {addr:?}");
433 tokio::spawn(async move {
434 match serve_server(Calculator::new(), stream).await {
435 Ok(server) => {
436 println!("Server initialized successfully");
437 if let Err(e) = server.waiting().await {
438 println!("Error while server waiting: {e:?}");
439 }
440 }
441 Err(e) => println!("Server initialization failed: {e:?}"),
442 }
443
444 anyhow::Ok(())
445 });
446 }
447 Ok(())
448 }
449
450 async fn client() -> anyhow::Result<McpToolbox> {
451 println!("Client connecting to {SOCKET_PATH}");
452 let stream = UnixStream::connect(SOCKET_PATH).await?;
453
454 let client = McpToolbox::try_from_transport(stream).await?;
456 println!("Client connected and initialized successfully");
457
458 Ok(client)
459 }
460
461 #[allow(clippy::unused_self)]
462 mod copied_from_rmcp {
463 use rmcp::{
464 ErrorData as McpError, ServerHandler,
465 handler::server::tool::{Parameters, ToolRouter},
466 model::{CallToolResult, Content, ServerCapabilities, ServerInfo},
467 schemars, tool, tool_handler,
468 };
469
470 #[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
471 pub struct Request {
472 pub a: i32,
473 pub b: i32,
474 }
475
476 #[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
477 pub struct OptRequest {
478 pub text: Option<String>,
479 }
480
481 #[derive(Debug, Clone)]
482 pub struct Calculator {
483 tool_router: ToolRouter<Self>,
484 }
485
486 #[rmcp::tool_router]
487 impl Calculator {
488 pub fn new() -> Self {
489 Self {
490 tool_router: Self::tool_router(),
491 }
492 }
493
494 #[allow(clippy::unnecessary_wraps)]
495 #[tool(description = "Calculate the sum of two numbers")]
496 fn sum(
497 &self,
498 Parameters(Request { a, b }): Parameters<Request>,
499 ) -> Result<CallToolResult, McpError> {
500 Ok(CallToolResult::success(vec![Content::text(
501 (a + b).to_string(),
502 )]))
503 }
504
505 #[allow(clippy::unnecessary_wraps)]
506 #[tool(description = "Calculate the sum of two numbers")]
507 fn sub(
508 &self,
509 Parameters(Request { a, b }): Parameters<Request>,
510 ) -> Result<CallToolResult, McpError> {
511 Ok(CallToolResult::success(vec![Content::text(
512 (a - b).to_string(),
513 )]))
514 }
515
516 #[allow(clippy::unnecessary_wraps)]
517 #[tool(description = "Optional echo")]
518 fn optional(
519 &self,
520 Parameters(OptRequest { text }): Parameters<OptRequest>,
521 ) -> Result<CallToolResult, McpError> {
522 Ok(CallToolResult::success(vec![Content::text(
523 text.unwrap_or_default(),
524 )]))
525 }
526 }
527
528 #[tool_handler]
529 impl ServerHandler for Calculator {
530 fn get_info(&self) -> ServerInfo {
531 ServerInfo {
532 instructions: Some("A simple calculator".into()),
533 capabilities: ServerCapabilities::builder().enable_tools().build(),
534 ..Default::default()
535 }
536 }
537 }
538 }
539}