steer_core/tools/mcp/
backend.rs

1//! MCP backend implementation using the official rmcp crate
2//!
3//! This module provides the ToolBackend implementation for MCP servers.
4
5use crate::tools::mcp::error::McpError;
6use async_trait::async_trait;
7use rmcp::transport::ConfigureCommandExt;
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::process::Stdio;
12use std::sync::Arc;
13use tokio::process::Command;
14use tokio::sync::RwLock;
15use tracing::{debug, error, info};
16
17use tokio::net::TcpStream;
18#[cfg(unix)]
19use tokio::net::UnixStream;
20
21use crate::api::ToolCall;
22use crate::session::state::ToolFilter;
23use crate::tools::{BackendMetadata, ExecutionContext, ToolBackend};
24use steer_tools::{
25    InputSchema, ToolError, ToolSchema,
26    result::{ExternalResult, ToolResult},
27};
28
29use rmcp::{
30    model::{CallToolRequestParam, Tool},
31    service::{RoleClient, RunningService, ServiceExt},
32    transport::{SseClientTransport, StreamableHttpClientTransport, TokioChildProcess},
33};
34
35/// MCP transport configuration
36#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonSchema)]
37#[serde(tag = "type", rename_all = "snake_case")]
38pub enum McpTransport {
39    /// Standard I/O transport (child process)
40    Stdio { command: String, args: Vec<String> },
41    /// TCP transport
42    Tcp { host: String, port: u16 },
43    /// Unix domain socket transport
44    #[cfg(unix)]
45    Unix { path: String },
46    /// Server-Sent Events transport
47    Sse {
48        url: String,
49        #[serde(skip_serializing_if = "Option::is_none")]
50        headers: Option<HashMap<String, String>>,
51    },
52    /// HTTP streamable transport
53    Http {
54        url: String,
55        #[serde(skip_serializing_if = "Option::is_none")]
56        headers: Option<HashMap<String, String>>,
57    },
58}
59
60/// Tool backend for executing tools via MCP servers
61pub struct McpBackend {
62    server_name: String,
63    transport: McpTransport,
64    tool_filter: ToolFilter,
65    client: Arc<RwLock<Option<RunningService<RoleClient, ()>>>>,
66    tools: Arc<RwLock<HashMap<String, Tool>>>,
67}
68
69impl McpBackend {
70    /// Create a new MCP backend
71    pub async fn new(
72        server_name: String,
73        transport: McpTransport,
74        tool_filter: ToolFilter,
75    ) -> Result<Self, McpError> {
76        info!(
77            "Creating MCP backend '{}' with transport: {:?}",
78            server_name, transport
79        );
80
81        let client = match &transport {
82            McpTransport::Stdio { command, args } => {
83                let (transport, stderr) =
84                    TokioChildProcess::builder(Command::new(command).configure(|cmd| {
85                        cmd.args(args);
86                    }))
87                    .stderr(Stdio::piped())
88                    .spawn()
89                    .map_err(|e| {
90                        error!("Failed to create MCP process: {}", e);
91                        McpError::ConnectionFailed {
92                            server_name: server_name.clone(),
93                            message: format!("Failed to create MCP process: {e}"),
94                        }
95                    })?;
96
97                if let Some(stderr) = stderr {
98                    let server_name_for_logging = server_name.clone();
99                    tokio::spawn(async move {
100                        use tokio::io::{AsyncBufReadExt, BufReader};
101                        let mut reader = BufReader::new(stderr);
102                        let mut line = String::new();
103
104                        while let Ok(len) = reader.read_line(&mut line).await {
105                            if len == 0 {
106                                break;
107                            }
108                            debug!(
109                                target: "mcp_server",
110                                "[{}] {}",
111                                server_name_for_logging,
112                                line.trim()
113                            );
114                            line.clear();
115                        }
116                    });
117                }
118
119                ().serve(transport).await.map_err(|e| {
120                    error!("Failed to serve MCP: {}", e);
121                    McpError::ServeFailed {
122                        transport: "stdio".to_string(),
123                        message: format!("Failed to serve MCP: {e}"),
124                    }
125                })?
126            }
127            McpTransport::Tcp { host, port } => {
128                let stream = TcpStream::connect((host.as_str(), *port))
129                    .await
130                    .map_err(|e| {
131                        error!("Failed to connect to TCP MCP server: {}", e);
132                        McpError::ConnectionFailed {
133                            server_name: server_name.clone(),
134                            message: format!("Failed to connect to {host}:{port} - {e}"),
135                        }
136                    })?;
137
138                ().serve(stream).await.map_err(|e| {
139                    error!("Failed to serve MCP over TCP: {}", e);
140                    McpError::ServeFailed {
141                        transport: "tcp".to_string(),
142                        message: format!("Failed to serve MCP over TCP: {e}"),
143                    }
144                })?
145            }
146            #[cfg(unix)]
147            McpTransport::Unix { path } => {
148                let stream = UnixStream::connect(path).await.map_err(|e| {
149                    error!("Failed to connect to Unix socket MCP server: {}", e);
150                    McpError::ConnectionFailed {
151                        server_name: server_name.clone(),
152                        message: format!("Failed to connect to Unix socket {path} - {e}"),
153                    }
154                })?;
155
156                ().serve(stream).await.map_err(|e| {
157                    error!("Failed to serve MCP over Unix socket: {}", e);
158                    McpError::ServeFailed {
159                        transport: "unix".to_string(),
160                        message: format!("Failed to serve MCP over Unix socket: {e}"),
161                    }
162                })?
163            }
164            McpTransport::Sse { url, headers } => {
165                // Use the dedicated SSE client transport for SSE connections
166                if headers.is_some() && !headers.as_ref().unwrap().is_empty() {
167                    info!(
168                        "SSE transport with custom headers requested; headers may not be applied"
169                    );
170                }
171
172                let transport = SseClientTransport::start(url.clone()).await.map_err(|e| {
173                    error!("Failed to start SSE transport: {}", e);
174                    McpError::ConnectionFailed {
175                        server_name: server_name.clone(),
176                        message: format!("Failed to start SSE transport: {e}"),
177                    }
178                })?;
179
180                ().serve(transport).await.map_err(|e| {
181                    error!("Failed to serve MCP over SSE: {}", e);
182                    McpError::ServeFailed {
183                        transport: "sse".to_string(),
184                        message: format!("Failed to serve MCP over SSE: {e}"),
185                    }
186                })?
187            }
188            McpTransport::Http { url, headers } => {
189                // Use the simpler from_uri method
190                let transport = StreamableHttpClientTransport::from_uri(url.clone());
191
192                if headers.is_some() && !headers.as_ref().unwrap().is_empty() {
193                    info!(
194                        "HTTP transport with custom headers requested; headers may not be applied"
195                    );
196                }
197
198                ().serve(transport).await.map_err(|e| {
199                    error!("Failed to serve MCP over HTTP: {}", e);
200                    McpError::ServeFailed {
201                        transport: "http".to_string(),
202                        message: format!("Failed to serve MCP over HTTP: {e}"),
203                    }
204                })?
205            }
206        };
207
208        let server_info = client.peer_info();
209        info!("Connected to server: {server_info:#?}");
210
211        debug!("Attempting to list tools from MCP server '{}'", server_name);
212
213        let list_tools_timeout = std::time::Duration::from_secs(10);
214        let tool_list =
215            tokio::time::timeout(list_tools_timeout, client.list_tools(Default::default()))
216                .await
217                .map_err(|_| McpError::ListToolsTimeout {
218                    server_name: server_name.clone(),
219                })?
220                .map_err(|e| McpError::ListToolsFailed {
221                    message: format!("Failed to list tools: {e}"),
222                })?;
223
224        // Process the tools
225        let mut tools = HashMap::new();
226        for tool in tool_list.tools {
227            tools.insert(tool.name.to_string(), tool);
228        }
229
230        info!(
231            "Discovered {} tools from MCP server '{}': {}",
232            tools.len(),
233            server_name,
234            tools
235                .keys()
236                .map(|k| k.to_string())
237                .collect::<Vec<_>>()
238                .join(", ")
239        );
240
241        let backend = Self {
242            server_name,
243            transport,
244            tool_filter,
245            client: Arc::new(RwLock::new(Some(client))),
246            tools: Arc::new(RwLock::new(tools)),
247        };
248
249        Ok(backend)
250    }
251
252    /// Apply tool filter to determine if a tool should be included
253    fn should_include_tool(&self, tool_name: &str) -> bool {
254        match &self.tool_filter {
255            ToolFilter::All => true,
256            ToolFilter::Include(included) => included.contains(&tool_name.to_string()),
257            ToolFilter::Exclude(excluded) => !excluded.contains(&tool_name.to_string()),
258        }
259    }
260
261    fn mcp_tool_to_schema(&self, tool: &Tool) -> ToolSchema {
262        let description = match &tool.description {
263            Some(desc) if !desc.is_empty() => desc.to_string(),
264            _ => format!(
265                "Tool '{}' from MCP server '{}'",
266                tool.name, self.server_name
267            ),
268        };
269
270        // Convert Arc<Map> to InputSchema
271        let properties = (*tool.input_schema).clone();
272        let required = properties
273            .get("required")
274            .and_then(|v| v.as_array())
275            .map(|arr| {
276                arr.iter()
277                    .filter_map(|v| v.as_str().map(String::from))
278                    .collect()
279            })
280            .unwrap_or_default();
281
282        let input_schema = InputSchema {
283            properties: properties
284                .get("properties")
285                .and_then(|v| v.as_object())
286                .cloned()
287                .unwrap_or_default(),
288            required,
289            schema_type: "object".to_string(),
290        };
291
292        ToolSchema {
293            name: format!("mcp__{}__{}", self.server_name, tool.name),
294            description,
295            input_schema,
296        }
297    }
298}
299
300#[async_trait]
301impl ToolBackend for McpBackend {
302    async fn execute(
303        &self,
304        tool_call: &ToolCall,
305        _context: &ExecutionContext,
306    ) -> Result<ToolResult, ToolError> {
307        // Get the service
308        let service_guard = self.client.read().await;
309        let service = service_guard
310            .as_ref()
311            .ok_or_else(|| ToolError::execution("mcp", "MCP service not initialized"))?;
312
313        // Extract the actual tool name (remove mcp_servername_ prefix)
314        let prefix = format!("mcp__{}__", self.server_name);
315        let actual_tool_name = if tool_call.name.starts_with(&prefix) {
316            &tool_call.name[prefix.len()..]
317        } else {
318            &tool_call.name
319        };
320
321        debug!(
322            "Executing tool '{}' via MCP server '{}'",
323            actual_tool_name, self.server_name
324        );
325
326        // Convert parameters to a Map if it's an object
327        let arguments = if let Some(obj) = tool_call.parameters.as_object() {
328            Some(obj.clone())
329        } else if tool_call.parameters.is_null() {
330            None
331        } else {
332            return Err(ToolError::invalid_params(
333                &tool_call.name,
334                "Parameters must be an object",
335            ));
336        };
337
338        // Execute the tool
339        let result = service
340            .call_tool(CallToolRequestParam {
341                name: actual_tool_name.to_string().into(),
342                arguments,
343            })
344            .await
345            .map_err(|e| {
346                ToolError::execution(&tool_call.name, format!("Tool execution failed: {e}"))
347            })?;
348
349        // Convert result to string
350        let output = result
351            .content
352            .into_iter()
353            .map(|content| {
354                // Access the raw content
355                match &content.raw {
356                    rmcp::model::RawContent::Text(text_content) => text_content.text.to_string(),
357                    rmcp::model::RawContent::Image { .. } => "[Image content]".to_string(),
358                    rmcp::model::RawContent::Resource { .. } => "[Resource content]".to_string(),
359                    rmcp::model::RawContent::Audio { .. } => "[Audio content]".to_string(),
360                }
361            })
362            .collect::<Vec<_>>()
363            .join("\n");
364
365        // Return as external tool result
366        Ok(ToolResult::External(ExternalResult {
367            tool_name: tool_call.name.clone(),
368            payload: output,
369        }))
370    }
371
372    async fn supported_tools(&self) -> Vec<String> {
373        let tools = self.tools.read().await;
374        tools
375            .keys()
376            .filter(|tool_name| self.should_include_tool(tool_name))
377            .map(|tool_name| format!("mcp__{}__{}", self.server_name, tool_name))
378            .collect()
379    }
380
381    async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
382        let tools = self.tools.read().await;
383        tools
384            .values()
385            .filter(|tool| self.should_include_tool(&tool.name))
386            .map(|tool| self.mcp_tool_to_schema(tool))
387            .collect()
388    }
389
390    fn metadata(&self) -> BackendMetadata {
391        let mut metadata = BackendMetadata::new(self.server_name.clone(), "MCP".to_string());
392
393        match &self.transport {
394            McpTransport::Stdio { command, args } => {
395                metadata = metadata
396                    .with_info("transport".to_string(), "stdio".to_string())
397                    .with_info("command".to_string(), command.clone())
398                    .with_info("args".to_string(), args.join(" "));
399            }
400            McpTransport::Tcp { host, port } => {
401                metadata = metadata
402                    .with_info("transport".to_string(), "tcp".to_string())
403                    .with_info("host".to_string(), host.clone())
404                    .with_info("port".to_string(), port.to_string());
405            }
406            #[cfg(unix)]
407            McpTransport::Unix { path } => {
408                metadata = metadata
409                    .with_info("transport".to_string(), "unix".to_string())
410                    .with_info("path".to_string(), path.clone());
411            }
412            McpTransport::Sse { url, .. } => {
413                metadata = metadata
414                    .with_info("transport".to_string(), "sse".to_string())
415                    .with_info("url".to_string(), url.clone());
416            }
417            McpTransport::Http { url, .. } => {
418                metadata = metadata
419                    .with_info("transport".to_string(), "http".to_string())
420                    .with_info("url".to_string(), url.clone());
421            }
422        }
423
424        metadata
425    }
426
427    async fn health_check(&self) -> bool {
428        // Check if service is connected
429        let service_guard = self.client.read().await;
430        service_guard.is_some()
431    }
432
433    async fn requires_approval(&self, _tool_name: &str) -> Result<bool, ToolError> {
434        // MCP tools generally require approval unless we have specific information
435        // In the future, we could query the MCP server for tool metadata
436        Ok(true)
437    }
438}
439
440impl Drop for McpBackend {
441    fn drop(&mut self) {
442        // Schedule cleanup in a detached task
443        let service = self.client.clone();
444
445        tokio::spawn(async move {
446            if let Some(service) = service.write().await.take() {
447                let _ = service.cancel().await;
448            }
449        });
450    }
451}
452
#[cfg(test)]
mod tests {
    use super::*;

    /// Mirrors the prefix handling used when executing a tool: the
    /// `mcp__<server>__` prefix is removed when present, and a name without
    /// the prefix is passed through unchanged.
    #[test]
    fn test_tool_name_extraction() {
        let prefix = "mcp__test__";

        let full_name = "mcp__test__some_tool";
        assert_eq!(full_name.strip_prefix(prefix).unwrap_or(full_name), "some_tool");

        let bare_name = "some_tool";
        assert_eq!(bare_name.strip_prefix(prefix).unwrap_or(bare_name), "some_tool");
    }

    /// Each variant serializes with a snake_case `type` tag next to its fields.
    #[test]
    fn test_mcp_transport_serialization() {
        // Test stdio transport
        let stdio = McpTransport::Stdio {
            command: "python".to_string(),
            args: vec!["-m".to_string(), "test_server".to_string()],
        };
        let json = serde_json::to_string(&stdio).unwrap();
        assert!(json.contains("\"type\":\"stdio\""));
        assert!(json.contains("\"command\":\"python\""));

        // Test TCP transport
        let tcp = McpTransport::Tcp {
            host: "localhost".to_string(),
            port: 3000,
        };
        let json = serde_json::to_string(&tcp).unwrap();
        assert!(json.contains("\"type\":\"tcp\""));
        assert!(json.contains("\"host\":\"localhost\""));
        assert!(json.contains("\"port\":3000"));

        // Test Unix transport
        #[cfg(unix)]
        {
            let unix = McpTransport::Unix {
                path: "/tmp/test.sock".to_string(),
            };
            let json = serde_json::to_string(&unix).unwrap();
            assert!(json.contains("\"type\":\"unix\""));
            assert!(json.contains("\"path\":\"/tmp/test.sock\""));
        }
    }

    /// `headers` is omitted from the JSON when it is `None` and emitted when set.
    #[test]
    fn test_mcp_transport_header_serialization() {
        let sse = McpTransport::Sse {
            url: "http://localhost:3000/sse".to_string(),
            headers: None,
        };
        let json = serde_json::to_string(&sse).unwrap();
        assert!(json.contains("\"type\":\"sse\""));
        assert!(json.contains("\"url\":\"http://localhost:3000/sse\""));
        assert!(!json.contains("headers"));

        let mut headers = HashMap::new();
        headers.insert("x-api-key".to_string(), "secret".to_string());
        let http = McpTransport::Http {
            url: "http://localhost:3000/mcp".to_string(),
            headers: Some(headers),
        };
        let json = serde_json::to_string(&http).unwrap();
        assert!(json.contains("\"type\":\"http\""));
        assert!(json.contains("\"headers\":{\"x-api-key\":\"secret\"}"));
    }

    /// Tagged JSON objects deserialize into the matching variant.
    #[test]
    fn test_mcp_transport_deserialization() {
        // Test stdio transport
        let json = r#"{"type":"stdio","command":"node","args":["server.js"]}"#;
        let transport: McpTransport = serde_json::from_str(json).unwrap();
        match transport {
            McpTransport::Stdio { command, args } => {
                assert_eq!(command, "node");
                assert_eq!(args, vec!["server.js"]);
            }
            other => panic!("expected Stdio transport, got {other:?}"),
        }

        // Test TCP transport
        let json = r#"{"type":"tcp","host":"127.0.0.1","port":8080}"#;
        let transport: McpTransport = serde_json::from_str(json).unwrap();
        match transport {
            McpTransport::Tcp { host, port } => {
                assert_eq!(host, "127.0.0.1");
                assert_eq!(port, 8080);
            }
            other => panic!("expected TCP transport, got {other:?}"),
        }

        // A missing `headers` field deserializes to `None`.
        let json = r#"{"type":"http","url":"http://localhost:3000/mcp"}"#;
        let transport: McpTransport = serde_json::from_str(json).unwrap();
        assert_eq!(
            transport,
            McpTransport::Http {
                url: "http://localhost:3000/mcp".to_string(),
                headers: None,
            }
        );
    }
}
527}