1use async_trait::async_trait;
6use rmcp::transport::ConfigureCommandExt;
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::process::Stdio;
11use std::sync::Arc;
12use tokio::process::Command;
13use tokio::sync::RwLock;
14use tracing::{debug, error, info};
15
16use tokio::net::TcpStream;
17#[cfg(unix)]
18use tokio::net::UnixStream;
19
20use crate::api::ToolCall;
21use crate::session::state::ToolFilter;
22use crate::tools::{BackendMetadata, ExecutionContext, ToolBackend};
23use steer_tools::{
24 InputSchema, ToolError, ToolSchema,
25 result::{ExternalResult, ToolResult},
26};
27
28use rmcp::{
29 model::{CallToolRequestParam, Tool},
30 service::{RoleClient, RunningService, ServiceExt},
31 transport::{SseClientTransport, StreamableHttpClientTransport, TokioChildProcess},
32};
33
/// How to reach an MCP server.
///
/// Serialized as an internally tagged object whose `type` field is the
/// snake_case variant name, e.g. `{"type":"stdio","command":"node","args":["server.js"]}`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum McpTransport {
    /// Spawn `command` with `args` as a child process and talk MCP over its stdio.
    Stdio { command: String, args: Vec<String> },
    /// Connect over TCP to `host:port`.
    Tcp { host: String, port: u16 },
    /// Connect over the Unix domain socket at `path` (only compiled on Unix targets).
    #[cfg(unix)]
    Unix { path: String },
    /// Connect using the SSE client transport at `url`.
    Sse {
        /// Endpoint URL passed to the SSE client transport.
        url: String,
        /// Optional extra HTTP headers.
        // NOTE: `McpBackend::new` does not currently pass these to the transport;
        // it only logs that they were requested.
        #[serde(skip_serializing_if = "Option::is_none")]
        headers: Option<HashMap<String, String>>,
    },
    /// Connect using the streamable HTTP client transport at `url`.
    Http {
        /// Endpoint URL passed to the streamable HTTP client transport.
        url: String,
        /// Optional extra HTTP headers.
        // NOTE: `McpBackend::new` does not currently pass these to the transport;
        // it only logs that they were requested.
        #[serde(skip_serializing_if = "Option::is_none")]
        headers: Option<HashMap<String, String>>,
    },
}
58
/// A tool backend that exposes the tools of a single MCP server.
///
/// Tools are advertised to callers as `mcp__{server_name}__{tool_name}`.
pub struct McpBackend {
    /// Server name used in log messages and in the advertised tool-name prefix.
    server_name: String,
    /// Connection settings; retained so `metadata()` can describe the transport.
    transport: McpTransport,
    /// Decides which discovered tools are advertised.
    tool_filter: ToolFilter,
    /// Live MCP client session. `Drop` takes it out (leaving `None`) and cancels it
    /// from a spawned task, hence the shared `Arc<RwLock<Option<..>>>` wrapper.
    client: Arc<RwLock<Option<RunningService<RoleClient, ()>>>>,
    /// Tools discovered at construction, keyed by the server's own (unprefixed) name.
    tools: Arc<RwLock<HashMap<String, Tool>>>,
}
67
68impl McpBackend {
69 pub async fn new(
71 server_name: String,
72 transport: McpTransport,
73 tool_filter: ToolFilter,
74 ) -> Result<Self, ToolError> {
75 info!(
76 "Creating MCP backend '{}' with transport: {:?}",
77 server_name, transport
78 );
79
80 let client = match &transport {
81 McpTransport::Stdio { command, args } => {
82 let (transport, stderr) =
83 TokioChildProcess::builder(Command::new(command).configure(|cmd| {
84 cmd.args(args);
85 }))
86 .stderr(Stdio::piped())
87 .spawn()
88 .map_err(|e| {
89 error!("Failed to create MCP process: {}", e);
90 ToolError::mcp_connection_failed(
91 &server_name,
92 format!("Failed to create MCP process: {e}"),
93 )
94 })?;
95
96 if let Some(stderr) = stderr {
97 let server_name_for_logging = server_name.clone();
98 tokio::spawn(async move {
99 use tokio::io::{AsyncBufReadExt, BufReader};
100 let mut reader = BufReader::new(stderr);
101 let mut line = String::new();
102
103 while let Ok(len) = reader.read_line(&mut line).await {
104 if len == 0 {
105 break;
106 }
107 debug!(
108 target: "mcp_server",
109 "[{}] {}",
110 server_name_for_logging,
111 line.trim()
112 );
113 line.clear();
114 }
115 });
116 }
117
118 ().serve(transport).await.map_err(|e| {
119 error!("Failed to serve MCP: {}", e);
120 ToolError::mcp_connection_failed(
121 &server_name,
122 format!("Failed to serve MCP: {e}"),
123 )
124 })?
125 }
126 McpTransport::Tcp { host, port } => {
127 let stream = TcpStream::connect((host.as_str(), *port))
128 .await
129 .map_err(|e| {
130 error!("Failed to connect to TCP MCP server: {}", e);
131 ToolError::mcp_connection_failed(
132 &server_name,
133 format!("Failed to connect to {host}:{port} - {e}"),
134 )
135 })?;
136
137 ().serve(stream).await.map_err(|e| {
138 error!("Failed to serve MCP over TCP: {}", e);
139 ToolError::mcp_connection_failed(
140 &server_name,
141 format!("Failed to serve MCP over TCP: {e}"),
142 )
143 })?
144 }
145 #[cfg(unix)]
146 McpTransport::Unix { path } => {
147 let stream = UnixStream::connect(path).await.map_err(|e| {
148 error!("Failed to connect to Unix socket MCP server: {}", e);
149 ToolError::mcp_connection_failed(
150 &server_name,
151 format!("Failed to connect to Unix socket {path} - {e}"),
152 )
153 })?;
154
155 ().serve(stream).await.map_err(|e| {
156 error!("Failed to serve MCP over Unix socket: {}", e);
157 ToolError::mcp_connection_failed(
158 &server_name,
159 format!("Failed to serve MCP over Unix socket: {e}"),
160 )
161 })?
162 }
163 McpTransport::Sse { url, headers } => {
164 if headers.is_some() && !headers.as_ref().unwrap().is_empty() {
166 info!(
167 "SSE transport with custom headers requested; headers may not be applied"
168 );
169 }
170
171 let transport = SseClientTransport::start(url.clone()).await.map_err(|e| {
172 error!("Failed to start SSE transport: {}", e);
173 ToolError::mcp_connection_failed(
174 &server_name,
175 format!("Failed to start SSE transport: {e}"),
176 )
177 })?;
178
179 ().serve(transport).await.map_err(|e| {
180 error!("Failed to serve MCP over SSE: {}", e);
181 ToolError::mcp_connection_failed(
182 &server_name,
183 format!("Failed to serve MCP over SSE: {e}"),
184 )
185 })?
186 }
187 McpTransport::Http { url, headers } => {
188 let transport = StreamableHttpClientTransport::from_uri(url.clone());
190
191 if headers.is_some() && !headers.as_ref().unwrap().is_empty() {
192 info!(
193 "HTTP transport with custom headers requested; headers may not be applied"
194 );
195 }
196
197 ().serve(transport).await.map_err(|e| {
198 error!("Failed to serve MCP over HTTP: {}", e);
199 ToolError::mcp_connection_failed(
200 &server_name,
201 format!("Failed to serve MCP over HTTP: {e}"),
202 )
203 })?
204 }
205 };
206
207 let server_info = client.peer_info();
208 info!("Connected to server: {server_info:#?}");
209
210 debug!("Attempting to list tools from MCP server '{}'", server_name);
211
212 let list_tools_timeout = std::time::Duration::from_secs(10);
213 let tool_list =
214 tokio::time::timeout(list_tools_timeout, client.list_tools(Default::default()))
215 .await
216 .map_err(|_| {
217 ToolError::mcp_connection_failed(
218 &server_name,
219 "Timeout listing tools".to_string(),
220 )
221 })?
222 .map_err(|e| {
223 ToolError::mcp_connection_failed(
224 &server_name,
225 format!("Failed to list tools: {e}"),
226 )
227 })?;
228
229 let mut tools = HashMap::new();
231 for tool in tool_list.tools {
232 tools.insert(tool.name.to_string(), tool);
233 }
234
235 info!(
236 "Discovered {} tools from MCP server '{}': {}",
237 tools.len(),
238 server_name,
239 tools
240 .keys()
241 .map(|k| k.to_string())
242 .collect::<Vec<_>>()
243 .join(", ")
244 );
245
246 let backend = Self {
247 server_name,
248 transport,
249 tool_filter,
250 client: Arc::new(RwLock::new(Some(client))),
251 tools: Arc::new(RwLock::new(tools)),
252 };
253
254 Ok(backend)
255 }
256
257 fn should_include_tool(&self, tool_name: &str) -> bool {
259 match &self.tool_filter {
260 ToolFilter::All => true,
261 ToolFilter::Include(included) => included.contains(&tool_name.to_string()),
262 ToolFilter::Exclude(excluded) => !excluded.contains(&tool_name.to_string()),
263 }
264 }
265
266 fn mcp_tool_to_schema(&self, tool: &Tool) -> ToolSchema {
267 let description = match &tool.description {
268 Some(desc) if !desc.is_empty() => desc.to_string(),
269 _ => format!(
270 "Tool '{}' from MCP server '{}'",
271 tool.name, self.server_name
272 ),
273 };
274
275 let properties = (*tool.input_schema).clone();
277 let required = properties
278 .get("required")
279 .and_then(|v| v.as_array())
280 .map(|arr| {
281 arr.iter()
282 .filter_map(|v| v.as_str().map(String::from))
283 .collect()
284 })
285 .unwrap_or_default();
286
287 let input_schema = InputSchema {
288 properties: properties
289 .get("properties")
290 .and_then(|v| v.as_object())
291 .cloned()
292 .unwrap_or_default(),
293 required,
294 schema_type: "object".to_string(),
295 };
296
297 ToolSchema {
298 name: format!("mcp__{}__{}", self.server_name, tool.name),
299 description,
300 input_schema,
301 }
302 }
303}
304
305#[async_trait]
306impl ToolBackend for McpBackend {
307 async fn execute(
308 &self,
309 tool_call: &ToolCall,
310 _context: &ExecutionContext,
311 ) -> Result<ToolResult, ToolError> {
312 let service_guard = self.client.read().await;
314 let service = service_guard
315 .as_ref()
316 .ok_or_else(|| ToolError::execution("mcp", "MCP service not initialized"))?;
317
318 let prefix = format!("mcp__{}__", self.server_name);
320 let actual_tool_name = if tool_call.name.starts_with(&prefix) {
321 &tool_call.name[prefix.len()..]
322 } else {
323 &tool_call.name
324 };
325
326 debug!(
327 "Executing tool '{}' via MCP server '{}'",
328 actual_tool_name, self.server_name
329 );
330
331 let arguments = if let Some(obj) = tool_call.parameters.as_object() {
333 Some(obj.clone())
334 } else if tool_call.parameters.is_null() {
335 None
336 } else {
337 return Err(ToolError::invalid_params(
338 &tool_call.name,
339 "Parameters must be an object",
340 ));
341 };
342
343 let result = service
345 .call_tool(CallToolRequestParam {
346 name: actual_tool_name.to_string().into(),
347 arguments,
348 })
349 .await
350 .map_err(|e| {
351 ToolError::execution(&tool_call.name, format!("Tool execution failed: {e}"))
352 })?;
353
354 let output = result
356 .content
357 .into_iter()
358 .map(|content| {
359 match &content.raw {
361 rmcp::model::RawContent::Text(text_content) => text_content.text.to_string(),
362 rmcp::model::RawContent::Image { .. } => "[Image content]".to_string(),
363 rmcp::model::RawContent::Resource { .. } => "[Resource content]".to_string(),
364 rmcp::model::RawContent::Audio { .. } => "[Audio content]".to_string(),
365 }
366 })
367 .collect::<Vec<_>>()
368 .join("\n");
369
370 Ok(ToolResult::External(ExternalResult {
372 tool_name: tool_call.name.clone(),
373 payload: output,
374 }))
375 }
376
377 async fn supported_tools(&self) -> Vec<String> {
378 let tools = self.tools.read().await;
379 tools
380 .keys()
381 .filter(|tool_name| self.should_include_tool(tool_name))
382 .map(|tool_name| format!("mcp__{}__{}", self.server_name, tool_name))
383 .collect()
384 }
385
386 async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
387 let tools = self.tools.read().await;
388 tools
389 .values()
390 .filter(|tool| self.should_include_tool(&tool.name))
391 .map(|tool| self.mcp_tool_to_schema(tool))
392 .collect()
393 }
394
395 fn metadata(&self) -> BackendMetadata {
396 let mut metadata = BackendMetadata::new(self.server_name.clone(), "MCP".to_string());
397
398 match &self.transport {
399 McpTransport::Stdio { command, args } => {
400 metadata = metadata
401 .with_info("transport".to_string(), "stdio".to_string())
402 .with_info("command".to_string(), command.clone())
403 .with_info("args".to_string(), args.join(" "));
404 }
405 McpTransport::Tcp { host, port } => {
406 metadata = metadata
407 .with_info("transport".to_string(), "tcp".to_string())
408 .with_info("host".to_string(), host.clone())
409 .with_info("port".to_string(), port.to_string());
410 }
411 #[cfg(unix)]
412 McpTransport::Unix { path } => {
413 metadata = metadata
414 .with_info("transport".to_string(), "unix".to_string())
415 .with_info("path".to_string(), path.clone());
416 }
417 McpTransport::Sse { url, .. } => {
418 metadata = metadata
419 .with_info("transport".to_string(), "sse".to_string())
420 .with_info("url".to_string(), url.clone());
421 }
422 McpTransport::Http { url, .. } => {
423 metadata = metadata
424 .with_info("transport".to_string(), "http".to_string())
425 .with_info("url".to_string(), url.clone());
426 }
427 }
428
429 metadata
430 }
431
432 async fn health_check(&self) -> bool {
433 let service_guard = self.client.read().await;
435 service_guard.is_some()
436 }
437
438 async fn requires_approval(&self, _tool_name: &str) -> Result<bool, ToolError> {
439 Ok(true)
442 }
443}
444
445impl Drop for McpBackend {
446 fn drop(&mut self) {
447 let service = self.client.clone();
449
450 tokio::spawn(async move {
451 if let Some(service) = service.write().await.take() {
452 let _ = service.cancel().await;
453 }
454 });
455 }
456}
457
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tool_name_extraction() {
        let prefix = "mcp__test__";

        let full_name = "mcp__test__some_tool";
        assert_eq!(full_name.strip_prefix(prefix).unwrap_or(full_name), "some_tool");

        // Names without the server prefix are passed through unchanged.
        let bare_name = "other_tool";
        assert_eq!(bare_name.strip_prefix(prefix).unwrap_or(bare_name), "other_tool");
    }

    #[test]
    fn test_mcp_transport_serialization() {
        let stdio = McpTransport::Stdio {
            command: "python".to_string(),
            args: vec!["-m".to_string(), "test_server".to_string()],
        };
        let json = serde_json::to_string(&stdio).unwrap();
        assert!(json.contains("\"type\":\"stdio\""));
        assert!(json.contains("\"command\":\"python\""));

        let tcp = McpTransport::Tcp {
            host: "localhost".to_string(),
            port: 3000,
        };
        let json = serde_json::to_string(&tcp).unwrap();
        assert!(json.contains("\"type\":\"tcp\""));
        assert!(json.contains("\"host\":\"localhost\""));
        assert!(json.contains("\"port\":3000"));

        #[cfg(unix)]
        {
            let unix = McpTransport::Unix {
                path: "/tmp/test.sock".to_string(),
            };
            let json = serde_json::to_string(&unix).unwrap();
            assert!(json.contains("\"type\":\"unix\""));
            assert!(json.contains("\"path\":\"/tmp/test.sock\""));
        }
    }

    #[test]
    fn test_mcp_transport_deserialization() {
        let json = r#"{"type":"stdio","command":"node","args":["server.js"]}"#;
        let transport: McpTransport = serde_json::from_str(json).unwrap();
        match transport {
            McpTransport::Stdio { command, args } => {
                assert_eq!(command, "node");
                assert_eq!(args, vec!["server.js"]);
            }
            other => panic!("expected Stdio transport, got {other:?}"),
        }

        let json = r#"{"type":"tcp","host":"127.0.0.1","port":8080}"#;
        let transport: McpTransport = serde_json::from_str(json).unwrap();
        match transport {
            McpTransport::Tcp { host, port } => {
                assert_eq!(host, "127.0.0.1");
                assert_eq!(port, 8080);
            }
            other => panic!("expected Tcp transport, got {other:?}"),
        }
    }

    #[test]
    fn test_remote_transport_headers() {
        // `None` headers are omitted from the serialized form.
        let sse = McpTransport::Sse {
            url: "http://localhost:8080/sse".to_string(),
            headers: None,
        };
        let json = serde_json::to_string(&sse).unwrap();
        assert!(json.contains("\"type\":\"sse\""));
        assert!(!json.contains("headers"), "unexpected headers key in {json}");

        // A missing `headers` key deserializes to `None`.
        let http: McpTransport =
            serde_json::from_str(r#"{"type":"http","url":"http://localhost:8080/mcp"}"#).unwrap();
        assert_eq!(
            http,
            McpTransport::Http {
                url: "http://localhost:8080/mcp".to_string(),
                headers: None,
            }
        );

        // Present headers survive a round trip.
        let with_headers = McpTransport::Http {
            url: "http://localhost:8080/mcp".to_string(),
            headers: Some(std::collections::HashMap::from([(
                "Authorization".to_string(),
                "Bearer token".to_string(),
            )])),
        };
        let round_tripped: McpTransport =
            serde_json::from_str(&serde_json::to_string(&with_headers).unwrap()).unwrap();
        assert_eq!(round_tripped, with_headers);
    }
}