1use crate::tools::mcp::error::McpError;
6use async_trait::async_trait;
7use rmcp::transport::ConfigureCommandExt;
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::process::Stdio;
12use std::sync::Arc;
13use steer_tools::ToolCall;
14use tokio::net::TcpStream;
15#[cfg(unix)]
16use tokio::net::UnixStream;
17use tokio::process::Command;
18use tokio::sync::RwLock;
19use tracing::{debug, error, info};
20
21use crate::session::state::ToolFilter;
22use crate::tools::{BackendMetadata, ExecutionContext, ToolBackend};
23use steer_tools::{
24 InputSchema, ToolError, ToolSchema,
25 result::{ExternalResult, ToolResult},
26};
27
28use rmcp::{
29 model::{CallToolRequestParam, Tool},
30 service::{RoleClient, RunningService, ServiceExt},
31 transport::{SseClientTransport, StreamableHttpClientTransport, TokioChildProcess},
32};
33
34#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonSchema)]
36#[serde(tag = "type", rename_all = "snake_case")]
37pub enum McpTransport {
38 Stdio { command: String, args: Vec<String> },
40 Tcp { host: String, port: u16 },
42 #[cfg(unix)]
44 Unix { path: String },
45 Sse {
47 url: String,
48 #[serde(skip_serializing_if = "Option::is_none")]
49 headers: Option<HashMap<String, String>>,
50 },
51 Http {
53 url: String,
54 #[serde(skip_serializing_if = "Option::is_none")]
55 headers: Option<HashMap<String, String>>,
56 },
57}
58
/// A [`ToolBackend`] that forwards tool calls to a single MCP server.
///
/// Tools are exposed under the name `mcp__<server_name>__<tool_name>`.
pub struct McpBackend {
    /// Server name used in log messages and in the `mcp__<server>__` tool-name prefix.
    server_name: String,
    /// Transport configuration this backend was created with; reported by `metadata()`.
    transport: McpTransport,
    /// Selects which of the server's tools are exposed (all / include-list / exclude-list).
    tool_filter: ToolFilter,
    /// The running MCP client. Set to `Some` on construction; `Drop` takes it out to cancel it.
    client: Arc<RwLock<Option<RunningService<RoleClient, ()>>>>,
    /// Tools discovered at construction time, keyed by their unprefixed server-side name.
    tools: Arc<RwLock<HashMap<String, Tool>>>,
}
67
68impl McpBackend {
69 pub async fn new(
71 server_name: String,
72 transport: McpTransport,
73 tool_filter: ToolFilter,
74 ) -> Result<Self, McpError> {
75 info!(
76 "Creating MCP backend '{}' with transport: {:?}",
77 server_name, transport
78 );
79
80 let client = match &transport {
81 McpTransport::Stdio { command, args } => {
82 let (transport, stderr) =
83 TokioChildProcess::builder(Command::new(command).configure(|cmd| {
84 cmd.args(args);
85 }))
86 .stderr(Stdio::piped())
87 .spawn()
88 .map_err(|e| {
89 error!("Failed to create MCP process: {}", e);
90 McpError::ConnectionFailed {
91 server_name: server_name.clone(),
92 message: format!("Failed to create MCP process: {e}"),
93 }
94 })?;
95
96 if let Some(stderr) = stderr {
97 let server_name_for_logging = server_name.clone();
98 tokio::spawn(async move {
99 use tokio::io::{AsyncBufReadExt, BufReader};
100 let mut reader = BufReader::new(stderr);
101 let mut line = String::new();
102
103 while let Ok(len) = reader.read_line(&mut line).await {
104 if len == 0 {
105 break;
106 }
107 debug!(
108 target: "mcp_server",
109 "[{}] {}",
110 server_name_for_logging,
111 line.trim()
112 );
113 line.clear();
114 }
115 });
116 }
117
118 ().serve(transport).await.map_err(|e| {
119 error!("Failed to serve MCP: {}", e);
120 McpError::ServeFailed {
121 transport: "stdio".to_string(),
122 message: format!("Failed to serve MCP: {e}"),
123 }
124 })?
125 }
126 McpTransport::Tcp { host, port } => {
127 let stream = TcpStream::connect((host.as_str(), *port))
128 .await
129 .map_err(|e| {
130 error!("Failed to connect to TCP MCP server: {}", e);
131 McpError::ConnectionFailed {
132 server_name: server_name.clone(),
133 message: format!("Failed to connect to {host}:{port} - {e}"),
134 }
135 })?;
136
137 ().serve(stream).await.map_err(|e| {
138 error!("Failed to serve MCP over TCP: {}", e);
139 McpError::ServeFailed {
140 transport: "tcp".to_string(),
141 message: format!("Failed to serve MCP over TCP: {e}"),
142 }
143 })?
144 }
145 #[cfg(unix)]
146 McpTransport::Unix { path } => {
147 let stream = UnixStream::connect(path).await.map_err(|e| {
148 error!("Failed to connect to Unix socket MCP server: {}", e);
149 McpError::ConnectionFailed {
150 server_name: server_name.clone(),
151 message: format!("Failed to connect to Unix socket {path} - {e}"),
152 }
153 })?;
154
155 ().serve(stream).await.map_err(|e| {
156 error!("Failed to serve MCP over Unix socket: {}", e);
157 McpError::ServeFailed {
158 transport: "unix".to_string(),
159 message: format!("Failed to serve MCP over Unix socket: {e}"),
160 }
161 })?
162 }
163 McpTransport::Sse { url, headers } => {
164 if headers.is_some() && !headers.as_ref().unwrap().is_empty() {
166 info!(
167 "SSE transport with custom headers requested; headers may not be applied"
168 );
169 }
170
171 let transport = SseClientTransport::start(url.clone()).await.map_err(|e| {
172 error!("Failed to start SSE transport: {}", e);
173 McpError::ConnectionFailed {
174 server_name: server_name.clone(),
175 message: format!("Failed to start SSE transport: {e}"),
176 }
177 })?;
178
179 ().serve(transport).await.map_err(|e| {
180 error!("Failed to serve MCP over SSE: {}", e);
181 McpError::ServeFailed {
182 transport: "sse".to_string(),
183 message: format!("Failed to serve MCP over SSE: {e}"),
184 }
185 })?
186 }
187 McpTransport::Http { url, headers } => {
188 let transport = StreamableHttpClientTransport::from_uri(url.clone());
190
191 if headers.is_some() && !headers.as_ref().unwrap().is_empty() {
192 info!(
193 "HTTP transport with custom headers requested; headers may not be applied"
194 );
195 }
196
197 ().serve(transport).await.map_err(|e| {
198 error!("Failed to serve MCP over HTTP: {}", e);
199 McpError::ServeFailed {
200 transport: "http".to_string(),
201 message: format!("Failed to serve MCP over HTTP: {e}"),
202 }
203 })?
204 }
205 };
206
207 let server_info = client.peer_info();
208 info!("Connected to server: {server_info:#?}");
209
210 debug!("Attempting to list tools from MCP server '{}'", server_name);
211
212 let list_tools_timeout = std::time::Duration::from_secs(10);
213 let tool_list =
214 tokio::time::timeout(list_tools_timeout, client.list_tools(Default::default()))
215 .await
216 .map_err(|_| McpError::ListToolsTimeout {
217 server_name: server_name.clone(),
218 })?
219 .map_err(|e| McpError::ListToolsFailed {
220 message: format!("Failed to list tools: {e}"),
221 })?;
222
223 let mut tools = HashMap::new();
225 for tool in tool_list.tools {
226 tools.insert(tool.name.to_string(), tool);
227 }
228
229 info!(
230 "Discovered {} tools from MCP server '{}': {}",
231 tools.len(),
232 server_name,
233 tools
234 .keys()
235 .map(|k| k.to_string())
236 .collect::<Vec<_>>()
237 .join(", ")
238 );
239
240 let backend = Self {
241 server_name,
242 transport,
243 tool_filter,
244 client: Arc::new(RwLock::new(Some(client))),
245 tools: Arc::new(RwLock::new(tools)),
246 };
247
248 Ok(backend)
249 }
250
251 fn should_include_tool(&self, tool_name: &str) -> bool {
253 match &self.tool_filter {
254 ToolFilter::All => true,
255 ToolFilter::Include(included) => included.contains(&tool_name.to_string()),
256 ToolFilter::Exclude(excluded) => !excluded.contains(&tool_name.to_string()),
257 }
258 }
259
260 fn mcp_tool_to_schema(&self, tool: &Tool) -> ToolSchema {
261 let description = match &tool.description {
262 Some(desc) if !desc.is_empty() => desc.to_string(),
263 _ => format!(
264 "Tool '{}' from MCP server '{}'",
265 tool.name, self.server_name
266 ),
267 };
268
269 let properties = (*tool.input_schema).clone();
271 let required = properties
272 .get("required")
273 .and_then(|v| v.as_array())
274 .map(|arr| {
275 arr.iter()
276 .filter_map(|v| v.as_str().map(String::from))
277 .collect()
278 })
279 .unwrap_or_default();
280
281 let input_schema = InputSchema {
282 properties: properties
283 .get("properties")
284 .and_then(|v| v.as_object())
285 .cloned()
286 .unwrap_or_default(),
287 required,
288 schema_type: "object".to_string(),
289 };
290
291 ToolSchema {
292 name: format!("mcp__{}__{}", self.server_name, tool.name),
293 description,
294 input_schema,
295 }
296 }
297}
298
299#[async_trait]
300impl ToolBackend for McpBackend {
301 async fn execute(
302 &self,
303 tool_call: &ToolCall,
304 _context: &ExecutionContext,
305 ) -> Result<ToolResult, ToolError> {
306 let service_guard = self.client.read().await;
308 let service = service_guard
309 .as_ref()
310 .ok_or_else(|| ToolError::execution("mcp", "MCP service not initialized"))?;
311
312 let prefix = format!("mcp__{}__", self.server_name);
314 let actual_tool_name = if tool_call.name.starts_with(&prefix) {
315 &tool_call.name[prefix.len()..]
316 } else {
317 &tool_call.name
318 };
319
320 debug!(
321 "Executing tool '{}' via MCP server '{}'",
322 actual_tool_name, self.server_name
323 );
324
325 let arguments = if let Some(obj) = tool_call.parameters.as_object() {
327 Some(obj.clone())
328 } else if tool_call.parameters.is_null() {
329 None
330 } else {
331 return Err(ToolError::invalid_params(
332 &tool_call.name,
333 "Parameters must be an object",
334 ));
335 };
336
337 let result = service
339 .call_tool(CallToolRequestParam {
340 name: actual_tool_name.to_string().into(),
341 arguments,
342 })
343 .await
344 .map_err(|e| {
345 ToolError::execution(&tool_call.name, format!("Tool execution failed: {e}"))
346 })?;
347
348 let output = result
350 .content
351 .into_iter()
352 .flat_map(|annotated_contents| annotated_contents.into_iter())
353 .map(|annotated| {
354 match annotated.raw {
356 rmcp::model::RawContent::Text(text_content) => text_content.text.to_string(),
357 rmcp::model::RawContent::Image { .. } => "[Image content]".to_string(),
358 rmcp::model::RawContent::Resource { .. } => "[Resource content]".to_string(),
359 rmcp::model::RawContent::Audio { .. } => "[Audio content]".to_string(),
360 }
361 })
362 .collect::<Vec<_>>()
363 .join("\n");
364
365 Ok(ToolResult::External(ExternalResult {
367 tool_name: tool_call.name.clone(),
368 payload: output,
369 }))
370 }
371
372 async fn supported_tools(&self) -> Vec<String> {
373 let tools = self.tools.read().await;
374 tools
375 .keys()
376 .filter(|tool_name| self.should_include_tool(tool_name))
377 .map(|tool_name| format!("mcp__{}__{}", self.server_name, tool_name))
378 .collect()
379 }
380
381 async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
382 let tools = self.tools.read().await;
383 tools
384 .values()
385 .filter(|tool| self.should_include_tool(&tool.name))
386 .map(|tool| self.mcp_tool_to_schema(tool))
387 .collect()
388 }
389
390 fn metadata(&self) -> BackendMetadata {
391 let mut metadata = BackendMetadata::new(self.server_name.clone(), "MCP".to_string());
392
393 match &self.transport {
394 McpTransport::Stdio { command, args } => {
395 metadata = metadata
396 .with_info("transport".to_string(), "stdio".to_string())
397 .with_info("command".to_string(), command.clone())
398 .with_info("args".to_string(), args.join(" "));
399 }
400 McpTransport::Tcp { host, port } => {
401 metadata = metadata
402 .with_info("transport".to_string(), "tcp".to_string())
403 .with_info("host".to_string(), host.clone())
404 .with_info("port".to_string(), port.to_string());
405 }
406 #[cfg(unix)]
407 McpTransport::Unix { path } => {
408 metadata = metadata
409 .with_info("transport".to_string(), "unix".to_string())
410 .with_info("path".to_string(), path.clone());
411 }
412 McpTransport::Sse { url, .. } => {
413 metadata = metadata
414 .with_info("transport".to_string(), "sse".to_string())
415 .with_info("url".to_string(), url.clone());
416 }
417 McpTransport::Http { url, .. } => {
418 metadata = metadata
419 .with_info("transport".to_string(), "http".to_string())
420 .with_info("url".to_string(), url.clone());
421 }
422 }
423
424 metadata
425 }
426
427 async fn health_check(&self) -> bool {
428 let service_guard = self.client.read().await;
430 service_guard.is_some()
431 }
432
433 async fn requires_approval(&self, _tool_name: &str) -> Result<bool, ToolError> {
434 Ok(true)
437 }
438}
439
440impl Drop for McpBackend {
441 fn drop(&mut self) {
442 let service = self.client.clone();
444
445 tokio::spawn(async move {
446 if let Some(service) = service.write().await.take() {
447 let _ = service.cancel().await;
448 }
449 });
450 }
451}
452
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tool_name_extraction() {
        let prefix = "mcp__test__";
        let full_name = "mcp__test__some_tool";
        let actual_name = if let Some(stripped) = full_name.strip_prefix(prefix) {
            stripped
        } else {
            full_name
        };

        assert_eq!(actual_name, "some_tool");
    }

    #[test]
    fn test_mcp_transport_serialization() {
        let stdio = McpTransport::Stdio {
            command: "python".to_string(),
            args: vec!["-m".to_string(), "test_server".to_string()],
        };
        let json = serde_json::to_string(&stdio).unwrap();
        assert!(json.contains("\"type\":\"stdio\""));
        assert!(json.contains("\"command\":\"python\""));

        let tcp = McpTransport::Tcp {
            host: "localhost".to_string(),
            port: 3000,
        };
        let json = serde_json::to_string(&tcp).unwrap();
        assert!(json.contains("\"type\":\"tcp\""));
        assert!(json.contains("\"host\":\"localhost\""));
        assert!(json.contains("\"port\":3000"));

        #[cfg(unix)]
        {
            let unix = McpTransport::Unix {
                path: "/tmp/test.sock".to_string(),
            };
            let json = serde_json::to_string(&unix).unwrap();
            assert!(json.contains("\"type\":\"unix\""));
            assert!(json.contains("\"path\":\"/tmp/test.sock\""));
        }
    }

    #[test]
    fn test_mcp_transport_deserialization() {
        let json = r#"{"type":"stdio","command":"node","args":["server.js"]}"#;
        let transport: McpTransport = serde_json::from_str(json).unwrap();
        match transport {
            McpTransport::Stdio { command, args } => {
                assert_eq!(command, "node");
                assert_eq!(args, vec!["server.js"]);
            }
            other => panic!("expected Stdio transport, got {other:?}"),
        }

        let json = r#"{"type":"tcp","host":"127.0.0.1","port":8080}"#;
        let transport: McpTransport = serde_json::from_str(json).unwrap();
        match transport {
            McpTransport::Tcp { host, port } => {
                assert_eq!(host, "127.0.0.1");
                assert_eq!(port, 8080);
            }
            other => panic!("expected Tcp transport, got {other:?}"),
        }

        #[cfg(unix)]
        {
            let json = r#"{"type":"unix","path":"/tmp/test.sock"}"#;
            let transport: McpTransport = serde_json::from_str(json).unwrap();
            assert_eq!(
                transport,
                McpTransport::Unix {
                    path: "/tmp/test.sock".to_string(),
                }
            );
        }
    }

    #[test]
    fn test_remote_transport_headers_are_optional() {
        // A missing `headers` key deserializes to `None`...
        let json = r#"{"type":"sse","url":"http://localhost:8080/sse"}"#;
        let sse: McpTransport = serde_json::from_str(json).unwrap();
        assert_eq!(
            sse,
            McpTransport::Sse {
                url: "http://localhost:8080/sse".to_string(),
                headers: None,
            }
        );

        // ...and `None` is omitted again on serialization.
        let json = serde_json::to_string(&sse).unwrap();
        assert!(json.contains("\"type\":\"sse\""));
        assert!(!json.contains("headers"));

        // Present headers survive a full round trip.
        let mut headers = HashMap::new();
        headers.insert("Authorization".to_string(), "Bearer token".to_string());
        let http = McpTransport::Http {
            url: "http://localhost:8080/mcp".to_string(),
            headers: Some(headers),
        };
        let json = serde_json::to_string(&http).unwrap();
        assert!(json.contains("\"type\":\"http\""));
        let round_tripped: McpTransport = serde_json::from_str(&json).unwrap();
        assert_eq!(round_tripped, http);
    }
}