1use crate::tools::mcp::error::McpError;
6use async_trait::async_trait;
7use rmcp::transport::ConfigureCommandExt;
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::process::Stdio;
12use std::sync::Arc;
13use tokio::process::Command;
14use tokio::sync::RwLock;
15use tracing::{debug, error, info};
16
17use tokio::net::TcpStream;
18#[cfg(unix)]
19use tokio::net::UnixStream;
20
21use crate::api::ToolCall;
22use crate::session::state::ToolFilter;
23use crate::tools::{BackendMetadata, ExecutionContext, ToolBackend};
24use steer_tools::{
25 InputSchema, ToolError, ToolSchema,
26 result::{ExternalResult, ToolResult},
27};
28
29use rmcp::{
30 model::{CallToolRequestParam, Tool},
31 service::{RoleClient, RunningService, ServiceExt},
32 transport::{SseClientTransport, StreamableHttpClientTransport, TokioChildProcess},
33};
34
/// Configuration describing how to reach an MCP server.
///
/// Serialized as an internally tagged object: the `type` field holds the snake_case
/// variant name, e.g. `{"type":"stdio","command":"node","args":["server.js"]}` or
/// `{"type":"tcp","host":"127.0.0.1","port":8080}`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum McpTransport {
    /// Spawn `command` with `args` as a child process and speak MCP over its stdio.
    Stdio { command: String, args: Vec<String> },
    /// Connect to an MCP server over a TCP socket at `host:port`.
    Tcp { host: String, port: u16 },
    /// Connect to an MCP server over the Unix domain socket at `path`.
    /// Only compiled on Unix platforms.
    #[cfg(unix)]
    Unix { path: String },
    /// Connect to an MCP server using the SSE client transport at `url`.
    Sse {
        /// Endpoint URL of the SSE server.
        url: String,
        /// Optional extra HTTP headers; omitted from serialized output when `None`.
        /// NOTE: `McpBackend::new` only checks these to emit a log message; they are
        /// not passed to the SSE transport.
        #[serde(skip_serializing_if = "Option::is_none")]
        headers: Option<HashMap<String, String>>,
    },
    /// Connect to an MCP server using the streamable HTTP client transport at `url`.
    Http {
        /// Endpoint URL of the HTTP server.
        url: String,
        /// Optional extra HTTP headers; omitted from serialized output when `None`.
        /// NOTE: `McpBackend::new` only checks these to emit a log message; they are
        /// not passed to the HTTP transport.
        #[serde(skip_serializing_if = "Option::is_none")]
        headers: Option<HashMap<String, String>>,
    },
}
59
/// A tool backend that exposes the tools of a single MCP server.
///
/// Tools are discovered once in `McpBackend::new` and advertised under the name
/// `mcp__<server_name>__<tool_name>`.
pub struct McpBackend {
    /// Server name used as part of exposed tool names and in log messages.
    server_name: String,
    /// Transport configuration used to connect; also reported through `metadata()`.
    transport: McpTransport,
    /// Include/exclude filter applied to the discovered (unprefixed) tool names.
    tool_filter: ToolFilter,
    /// The running MCP client session; taken (set to `None`) during `Drop` for shutdown.
    client: Arc<RwLock<Option<RunningService<RoleClient, ()>>>>,
    /// Tools returned by the server's `list_tools`, keyed by their unprefixed name.
    tools: Arc<RwLock<HashMap<String, Tool>>>,
}
68
69impl McpBackend {
70 pub async fn new(
72 server_name: String,
73 transport: McpTransport,
74 tool_filter: ToolFilter,
75 ) -> Result<Self, McpError> {
76 info!(
77 "Creating MCP backend '{}' with transport: {:?}",
78 server_name, transport
79 );
80
81 let client = match &transport {
82 McpTransport::Stdio { command, args } => {
83 let (transport, stderr) =
84 TokioChildProcess::builder(Command::new(command).configure(|cmd| {
85 cmd.args(args);
86 }))
87 .stderr(Stdio::piped())
88 .spawn()
89 .map_err(|e| {
90 error!("Failed to create MCP process: {}", e);
91 McpError::ConnectionFailed {
92 server_name: server_name.clone(),
93 message: format!("Failed to create MCP process: {e}"),
94 }
95 })?;
96
97 if let Some(stderr) = stderr {
98 let server_name_for_logging = server_name.clone();
99 tokio::spawn(async move {
100 use tokio::io::{AsyncBufReadExt, BufReader};
101 let mut reader = BufReader::new(stderr);
102 let mut line = String::new();
103
104 while let Ok(len) = reader.read_line(&mut line).await {
105 if len == 0 {
106 break;
107 }
108 debug!(
109 target: "mcp_server",
110 "[{}] {}",
111 server_name_for_logging,
112 line.trim()
113 );
114 line.clear();
115 }
116 });
117 }
118
119 ().serve(transport).await.map_err(|e| {
120 error!("Failed to serve MCP: {}", e);
121 McpError::ServeFailed {
122 transport: "stdio".to_string(),
123 message: format!("Failed to serve MCP: {e}"),
124 }
125 })?
126 }
127 McpTransport::Tcp { host, port } => {
128 let stream = TcpStream::connect((host.as_str(), *port))
129 .await
130 .map_err(|e| {
131 error!("Failed to connect to TCP MCP server: {}", e);
132 McpError::ConnectionFailed {
133 server_name: server_name.clone(),
134 message: format!("Failed to connect to {host}:{port} - {e}"),
135 }
136 })?;
137
138 ().serve(stream).await.map_err(|e| {
139 error!("Failed to serve MCP over TCP: {}", e);
140 McpError::ServeFailed {
141 transport: "tcp".to_string(),
142 message: format!("Failed to serve MCP over TCP: {e}"),
143 }
144 })?
145 }
146 #[cfg(unix)]
147 McpTransport::Unix { path } => {
148 let stream = UnixStream::connect(path).await.map_err(|e| {
149 error!("Failed to connect to Unix socket MCP server: {}", e);
150 McpError::ConnectionFailed {
151 server_name: server_name.clone(),
152 message: format!("Failed to connect to Unix socket {path} - {e}"),
153 }
154 })?;
155
156 ().serve(stream).await.map_err(|e| {
157 error!("Failed to serve MCP over Unix socket: {}", e);
158 McpError::ServeFailed {
159 transport: "unix".to_string(),
160 message: format!("Failed to serve MCP over Unix socket: {e}"),
161 }
162 })?
163 }
164 McpTransport::Sse { url, headers } => {
165 if headers.is_some() && !headers.as_ref().unwrap().is_empty() {
167 info!(
168 "SSE transport with custom headers requested; headers may not be applied"
169 );
170 }
171
172 let transport = SseClientTransport::start(url.clone()).await.map_err(|e| {
173 error!("Failed to start SSE transport: {}", e);
174 McpError::ConnectionFailed {
175 server_name: server_name.clone(),
176 message: format!("Failed to start SSE transport: {e}"),
177 }
178 })?;
179
180 ().serve(transport).await.map_err(|e| {
181 error!("Failed to serve MCP over SSE: {}", e);
182 McpError::ServeFailed {
183 transport: "sse".to_string(),
184 message: format!("Failed to serve MCP over SSE: {e}"),
185 }
186 })?
187 }
188 McpTransport::Http { url, headers } => {
189 let transport = StreamableHttpClientTransport::from_uri(url.clone());
191
192 if headers.is_some() && !headers.as_ref().unwrap().is_empty() {
193 info!(
194 "HTTP transport with custom headers requested; headers may not be applied"
195 );
196 }
197
198 ().serve(transport).await.map_err(|e| {
199 error!("Failed to serve MCP over HTTP: {}", e);
200 McpError::ServeFailed {
201 transport: "http".to_string(),
202 message: format!("Failed to serve MCP over HTTP: {e}"),
203 }
204 })?
205 }
206 };
207
208 let server_info = client.peer_info();
209 info!("Connected to server: {server_info:#?}");
210
211 debug!("Attempting to list tools from MCP server '{}'", server_name);
212
213 let list_tools_timeout = std::time::Duration::from_secs(10);
214 let tool_list =
215 tokio::time::timeout(list_tools_timeout, client.list_tools(Default::default()))
216 .await
217 .map_err(|_| McpError::ListToolsTimeout {
218 server_name: server_name.clone(),
219 })?
220 .map_err(|e| McpError::ListToolsFailed {
221 message: format!("Failed to list tools: {e}"),
222 })?;
223
224 let mut tools = HashMap::new();
226 for tool in tool_list.tools {
227 tools.insert(tool.name.to_string(), tool);
228 }
229
230 info!(
231 "Discovered {} tools from MCP server '{}': {}",
232 tools.len(),
233 server_name,
234 tools
235 .keys()
236 .map(|k| k.to_string())
237 .collect::<Vec<_>>()
238 .join(", ")
239 );
240
241 let backend = Self {
242 server_name,
243 transport,
244 tool_filter,
245 client: Arc::new(RwLock::new(Some(client))),
246 tools: Arc::new(RwLock::new(tools)),
247 };
248
249 Ok(backend)
250 }
251
252 fn should_include_tool(&self, tool_name: &str) -> bool {
254 match &self.tool_filter {
255 ToolFilter::All => true,
256 ToolFilter::Include(included) => included.contains(&tool_name.to_string()),
257 ToolFilter::Exclude(excluded) => !excluded.contains(&tool_name.to_string()),
258 }
259 }
260
261 fn mcp_tool_to_schema(&self, tool: &Tool) -> ToolSchema {
262 let description = match &tool.description {
263 Some(desc) if !desc.is_empty() => desc.to_string(),
264 _ => format!(
265 "Tool '{}' from MCP server '{}'",
266 tool.name, self.server_name
267 ),
268 };
269
270 let properties = (*tool.input_schema).clone();
272 let required = properties
273 .get("required")
274 .and_then(|v| v.as_array())
275 .map(|arr| {
276 arr.iter()
277 .filter_map(|v| v.as_str().map(String::from))
278 .collect()
279 })
280 .unwrap_or_default();
281
282 let input_schema = InputSchema {
283 properties: properties
284 .get("properties")
285 .and_then(|v| v.as_object())
286 .cloned()
287 .unwrap_or_default(),
288 required,
289 schema_type: "object".to_string(),
290 };
291
292 ToolSchema {
293 name: format!("mcp__{}__{}", self.server_name, tool.name),
294 description,
295 input_schema,
296 }
297 }
298}
299
300#[async_trait]
301impl ToolBackend for McpBackend {
302 async fn execute(
303 &self,
304 tool_call: &ToolCall,
305 _context: &ExecutionContext,
306 ) -> Result<ToolResult, ToolError> {
307 let service_guard = self.client.read().await;
309 let service = service_guard
310 .as_ref()
311 .ok_or_else(|| ToolError::execution("mcp", "MCP service not initialized"))?;
312
313 let prefix = format!("mcp__{}__", self.server_name);
315 let actual_tool_name = if tool_call.name.starts_with(&prefix) {
316 &tool_call.name[prefix.len()..]
317 } else {
318 &tool_call.name
319 };
320
321 debug!(
322 "Executing tool '{}' via MCP server '{}'",
323 actual_tool_name, self.server_name
324 );
325
326 let arguments = if let Some(obj) = tool_call.parameters.as_object() {
328 Some(obj.clone())
329 } else if tool_call.parameters.is_null() {
330 None
331 } else {
332 return Err(ToolError::invalid_params(
333 &tool_call.name,
334 "Parameters must be an object",
335 ));
336 };
337
338 let result = service
340 .call_tool(CallToolRequestParam {
341 name: actual_tool_name.to_string().into(),
342 arguments,
343 })
344 .await
345 .map_err(|e| {
346 ToolError::execution(&tool_call.name, format!("Tool execution failed: {e}"))
347 })?;
348
349 let output = result
351 .content
352 .into_iter()
353 .map(|content| {
354 match &content.raw {
356 rmcp::model::RawContent::Text(text_content) => text_content.text.to_string(),
357 rmcp::model::RawContent::Image { .. } => "[Image content]".to_string(),
358 rmcp::model::RawContent::Resource { .. } => "[Resource content]".to_string(),
359 rmcp::model::RawContent::Audio { .. } => "[Audio content]".to_string(),
360 }
361 })
362 .collect::<Vec<_>>()
363 .join("\n");
364
365 Ok(ToolResult::External(ExternalResult {
367 tool_name: tool_call.name.clone(),
368 payload: output,
369 }))
370 }
371
372 async fn supported_tools(&self) -> Vec<String> {
373 let tools = self.tools.read().await;
374 tools
375 .keys()
376 .filter(|tool_name| self.should_include_tool(tool_name))
377 .map(|tool_name| format!("mcp__{}__{}", self.server_name, tool_name))
378 .collect()
379 }
380
381 async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
382 let tools = self.tools.read().await;
383 tools
384 .values()
385 .filter(|tool| self.should_include_tool(&tool.name))
386 .map(|tool| self.mcp_tool_to_schema(tool))
387 .collect()
388 }
389
390 fn metadata(&self) -> BackendMetadata {
391 let mut metadata = BackendMetadata::new(self.server_name.clone(), "MCP".to_string());
392
393 match &self.transport {
394 McpTransport::Stdio { command, args } => {
395 metadata = metadata
396 .with_info("transport".to_string(), "stdio".to_string())
397 .with_info("command".to_string(), command.clone())
398 .with_info("args".to_string(), args.join(" "));
399 }
400 McpTransport::Tcp { host, port } => {
401 metadata = metadata
402 .with_info("transport".to_string(), "tcp".to_string())
403 .with_info("host".to_string(), host.clone())
404 .with_info("port".to_string(), port.to_string());
405 }
406 #[cfg(unix)]
407 McpTransport::Unix { path } => {
408 metadata = metadata
409 .with_info("transport".to_string(), "unix".to_string())
410 .with_info("path".to_string(), path.clone());
411 }
412 McpTransport::Sse { url, .. } => {
413 metadata = metadata
414 .with_info("transport".to_string(), "sse".to_string())
415 .with_info("url".to_string(), url.clone());
416 }
417 McpTransport::Http { url, .. } => {
418 metadata = metadata
419 .with_info("transport".to_string(), "http".to_string())
420 .with_info("url".to_string(), url.clone());
421 }
422 }
423
424 metadata
425 }
426
427 async fn health_check(&self) -> bool {
428 let service_guard = self.client.read().await;
430 service_guard.is_some()
431 }
432
433 async fn requires_approval(&self, _tool_name: &str) -> Result<bool, ToolError> {
434 Ok(true)
437 }
438}
439
440impl Drop for McpBackend {
441 fn drop(&mut self) {
442 let service = self.client.clone();
444
445 tokio::spawn(async move {
446 if let Some(service) = service.write().await.take() {
447 let _ = service.cancel().await;
448 }
449 });
450 }
451}
452
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `transport` to its compact JSON string form.
    fn to_json(transport: &McpTransport) -> String {
        serde_json::to_string(transport).unwrap()
    }

    /// Parses a JSON string into an `McpTransport`.
    fn from_json(json: &str) -> McpTransport {
        serde_json::from_str(json).unwrap()
    }

    #[test]
    fn test_tool_name_extraction() {
        let prefix = "mcp__test__";
        let full_name = "mcp__test__some_tool";
        // Falls back to the full name when the prefix is absent.
        let actual_name = full_name.strip_prefix(prefix).unwrap_or(full_name);

        assert_eq!(actual_name, "some_tool");
    }

    #[test]
    fn test_mcp_transport_serialization() {
        let stdio_json = to_json(&McpTransport::Stdio {
            command: "python".to_string(),
            args: vec!["-m".to_string(), "test_server".to_string()],
        });
        for fragment in ["\"type\":\"stdio\"", "\"command\":\"python\""] {
            assert!(stdio_json.contains(fragment));
        }

        let tcp_json = to_json(&McpTransport::Tcp {
            host: "localhost".to_string(),
            port: 3000,
        });
        for fragment in ["\"type\":\"tcp\"", "\"host\":\"localhost\"", "\"port\":3000"] {
            assert!(tcp_json.contains(fragment));
        }

        #[cfg(unix)]
        {
            let unix_json = to_json(&McpTransport::Unix {
                path: "/tmp/test.sock".to_string(),
            });
            for fragment in ["\"type\":\"unix\"", "\"path\":\"/tmp/test.sock\""] {
                assert!(unix_json.contains(fragment));
            }
        }
    }

    #[test]
    fn test_mcp_transport_deserialization() {
        let McpTransport::Stdio { command, args } =
            from_json(r#"{"type":"stdio","command":"node","args":["server.js"]}"#)
        else {
            unreachable!("Stdio transport");
        };
        assert_eq!(command, "node");
        assert_eq!(args, vec!["server.js"]);

        let McpTransport::Tcp { host, port } =
            from_json(r#"{"type":"tcp","host":"127.0.0.1","port":8080}"#)
        else {
            unreachable!("TCP transport");
        };
        assert_eq!(host, "127.0.0.1");
        assert_eq!(port, 8080);
    }
}