Skip to main content

steer_core/tools/mcp/
backend.rs

//! MCP backend implementation using the official rmcp crate.
//!
//! This module provides the `ToolBackend` implementation for MCP servers.
4
5use crate::tools::mcp::error::McpError;
6use async_trait::async_trait;
7use rmcp::transport::ConfigureCommandExt;
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::process::Stdio;
12use std::sync::Arc;
13use steer_tools::ToolCall;
14use tokio::net::TcpStream;
15#[cfg(unix)]
16use tokio::net::UnixStream;
17use tokio::process::Command;
18use tokio::sync::RwLock;
19use tracing::{debug, error, info};
20
21use crate::session::state::ToolFilter;
22use crate::tools::{BackendMetadata, ExecutionContext, ToolBackend};
23use steer_tools::{
24    InputSchema, ToolError, ToolSchema,
25    result::{ExternalResult, ToolResult},
26};
27
28use rmcp::{
29    model::{CallToolRequestParam, Tool},
30    service::{RoleClient, RunningService, ServiceExt},
31    transport::{SseClientTransport, StreamableHttpClientTransport, TokioChildProcess},
32};
33
34/// MCP transport configuration
35#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonSchema)]
36#[serde(tag = "type", rename_all = "snake_case")]
37pub enum McpTransport {
38    /// Standard I/O transport (child process)
39    Stdio { command: String, args: Vec<String> },
40    /// TCP transport
41    Tcp { host: String, port: u16 },
42    /// Unix domain socket transport
43    #[cfg(unix)]
44    Unix { path: String },
45    /// Server-Sent Events transport
46    Sse {
47        url: String,
48        #[serde(skip_serializing_if = "Option::is_none")]
49        headers: Option<HashMap<String, String>>,
50    },
51    /// HTTP streamable transport
52    Http {
53        url: String,
54        #[serde(skip_serializing_if = "Option::is_none")]
55        headers: Option<HashMap<String, String>>,
56    },
57}
58
59/// Tool backend for executing tools via MCP servers
60pub struct McpBackend {
61    server_name: String,
62    transport: McpTransport,
63    tool_filter: ToolFilter,
64    client: Arc<RwLock<Option<RunningService<RoleClient, ()>>>>,
65    tools: Arc<RwLock<HashMap<String, Tool>>>,
66}
67
68impl McpBackend {
69    /// Create a new MCP backend
70    pub async fn new(
71        server_name: String,
72        transport: McpTransport,
73        tool_filter: ToolFilter,
74    ) -> Result<Self, McpError> {
75        info!(
76            "Creating MCP backend '{}' with transport: {:?}",
77            server_name, transport
78        );
79
80        let client = match &transport {
81            McpTransport::Stdio { command, args } => {
82                let (transport, stderr) =
83                    TokioChildProcess::builder(Command::new(command).configure(|cmd| {
84                        cmd.args(args);
85                    }))
86                    .stderr(Stdio::piped())
87                    .spawn()
88                    .map_err(|e| {
89                        error!("Failed to create MCP process: {}", e);
90                        McpError::ConnectionFailed {
91                            server_name: server_name.clone(),
92                            message: format!("Failed to create MCP process: {e}"),
93                        }
94                    })?;
95
96                if let Some(stderr) = stderr {
97                    let server_name_for_logging = server_name.clone();
98                    tokio::spawn(async move {
99                        use tokio::io::{AsyncBufReadExt, BufReader};
100                        let mut reader = BufReader::new(stderr);
101                        let mut line = String::new();
102
103                        while let Ok(len) = reader.read_line(&mut line).await {
104                            if len == 0 {
105                                break;
106                            }
107                            debug!(
108                                target: "mcp_server",
109                                "[{}] {}",
110                                server_name_for_logging,
111                                line.trim()
112                            );
113                            line.clear();
114                        }
115                    });
116                }
117
118                ().serve(transport).await.map_err(|e| {
119                    error!("Failed to serve MCP: {}", e);
120                    McpError::ServeFailed {
121                        transport: "stdio".to_string(),
122                        message: format!("Failed to serve MCP: {e}"),
123                    }
124                })?
125            }
126            McpTransport::Tcp { host, port } => {
127                let stream = TcpStream::connect((host.as_str(), *port))
128                    .await
129                    .map_err(|e| {
130                        error!("Failed to connect to TCP MCP server: {}", e);
131                        McpError::ConnectionFailed {
132                            server_name: server_name.clone(),
133                            message: format!("Failed to connect to {host}:{port} - {e}"),
134                        }
135                    })?;
136
137                ().serve(stream).await.map_err(|e| {
138                    error!("Failed to serve MCP over TCP: {}", e);
139                    McpError::ServeFailed {
140                        transport: "tcp".to_string(),
141                        message: format!("Failed to serve MCP over TCP: {e}"),
142                    }
143                })?
144            }
145            #[cfg(unix)]
146            McpTransport::Unix { path } => {
147                let stream = UnixStream::connect(path).await.map_err(|e| {
148                    error!("Failed to connect to Unix socket MCP server: {}", e);
149                    McpError::ConnectionFailed {
150                        server_name: server_name.clone(),
151                        message: format!("Failed to connect to Unix socket {path} - {e}"),
152                    }
153                })?;
154
155                ().serve(stream).await.map_err(|e| {
156                    error!("Failed to serve MCP over Unix socket: {}", e);
157                    McpError::ServeFailed {
158                        transport: "unix".to_string(),
159                        message: format!("Failed to serve MCP over Unix socket: {e}"),
160                    }
161                })?
162            }
163            McpTransport::Sse { url, headers } => {
164                // Use the dedicated SSE client transport for SSE connections
165                if headers.as_ref().is_some_and(|h| !h.is_empty()) {
166                    info!(
167                        "SSE transport with custom headers requested; headers may not be applied"
168                    );
169                }
170
171                let transport = SseClientTransport::start(url.clone()).await.map_err(|e| {
172                    error!("Failed to start SSE transport: {}", e);
173                    McpError::ConnectionFailed {
174                        server_name: server_name.clone(),
175                        message: format!("Failed to start SSE transport: {e}"),
176                    }
177                })?;
178
179                ().serve(transport).await.map_err(|e| {
180                    error!("Failed to serve MCP over SSE: {}", e);
181                    McpError::ServeFailed {
182                        transport: "sse".to_string(),
183                        message: format!("Failed to serve MCP over SSE: {e}"),
184                    }
185                })?
186            }
187            McpTransport::Http { url, headers } => {
188                // Use the simpler from_uri method
189                let transport = StreamableHttpClientTransport::from_uri(url.clone());
190
191                if headers.as_ref().is_some_and(|h| !h.is_empty()) {
192                    info!(
193                        "HTTP transport with custom headers requested; headers may not be applied"
194                    );
195                }
196
197                ().serve(transport).await.map_err(|e| {
198                    error!("Failed to serve MCP over HTTP: {}", e);
199                    McpError::ServeFailed {
200                        transport: "http".to_string(),
201                        message: format!("Failed to serve MCP over HTTP: {e}"),
202                    }
203                })?
204            }
205        };
206
207        let server_info = client.peer_info();
208        info!("Connected to server: {server_info:#?}");
209
210        debug!("Attempting to list tools from MCP server '{}'", server_name);
211
212        let list_tools_timeout = std::time::Duration::from_secs(10);
213        let tool_list =
214            tokio::time::timeout(list_tools_timeout, client.list_tools(Default::default()))
215                .await
216                .map_err(|_| McpError::ListToolsTimeout {
217                    server_name: server_name.clone(),
218                })?
219                .map_err(|e| McpError::ListToolsFailed {
220                    message: format!("Failed to list tools: {e}"),
221                })?;
222
223        // Process the tools
224        let mut tools = HashMap::new();
225        for tool in tool_list.tools {
226            tools.insert(tool.name.to_string(), tool);
227        }
228
229        info!(
230            "Discovered {} tools from MCP server '{}': {}",
231            tools.len(),
232            server_name,
233            tools.keys().cloned().collect::<Vec<_>>().join(", ")
234        );
235
236        let backend = Self {
237            server_name,
238            transport,
239            tool_filter,
240            client: Arc::new(RwLock::new(Some(client))),
241            tools: Arc::new(RwLock::new(tools)),
242        };
243
244        Ok(backend)
245    }
246
247    /// Apply tool filter to determine if a tool should be included
248    fn should_include_tool(&self, tool_name: &str) -> bool {
249        match &self.tool_filter {
250            ToolFilter::All => true,
251            ToolFilter::Include(included) => included.contains(&tool_name.to_string()),
252            ToolFilter::Exclude(excluded) => !excluded.contains(&tool_name.to_string()),
253        }
254    }
255
256    pub fn has_tool(&self, tool_name: &str) -> bool {
257        let prefixed_name = format!("mcp__{}__", self.server_name);
258        if tool_name.starts_with(&prefixed_name) {
259            let actual_name = &tool_name[prefixed_name.len()..];
260            if let Ok(tools) = self.tools.try_read() {
261                return tools.contains_key(actual_name);
262            }
263        }
264        false
265    }
266
267    fn mcp_tool_to_schema(&self, tool: &Tool) -> ToolSchema {
268        let display_name = tool
269            .annotations
270            .as_ref()
271            .and_then(|annotations| annotations.title.as_deref())
272            .filter(|title| !title.trim().is_empty())
273            .map_or_else(
274                || format!("{}: {}", self.server_name, tool.name),
275                |title| format!("{}: {}", self.server_name, title),
276            );
277
278        let description = match &tool.description {
279            Some(desc) if !desc.is_empty() => desc.to_string(),
280            _ => format!(
281                "Tool '{}' from MCP server '{}'",
282                tool.name, self.server_name
283            ),
284        };
285
286        // Convert Arc<Map> to InputSchema
287        let properties = (*tool.input_schema).clone();
288        let input_schema = InputSchema::from(serde_json::Value::Object(properties.clone()));
289
290        ToolSchema {
291            name: format!("mcp__{}__{}", self.server_name, tool.name),
292            display_name,
293            description,
294            input_schema,
295        }
296    }
297}
298
299#[async_trait]
300impl ToolBackend for McpBackend {
301    async fn execute(
302        &self,
303        tool_call: &ToolCall,
304        _context: &ExecutionContext,
305    ) -> Result<ToolResult, ToolError> {
306        // Get the service
307        let service_guard = self.client.read().await;
308        let service = service_guard
309            .as_ref()
310            .ok_or_else(|| ToolError::execution("mcp", "MCP service not initialized"))?;
311
312        // Extract the actual tool name (remove mcp_servername_ prefix)
313        let prefix = format!("mcp__{}__", self.server_name);
314        let actual_tool_name = if tool_call.name.starts_with(&prefix) {
315            &tool_call.name[prefix.len()..]
316        } else {
317            &tool_call.name
318        };
319
320        debug!(
321            "Executing tool '{}' via MCP server '{}'",
322            actual_tool_name, self.server_name
323        );
324
325        // Convert parameters to a Map if it's an object
326        let arguments = if let Some(obj) = tool_call.parameters.as_object() {
327            Some(obj.clone())
328        } else if tool_call.parameters.is_null() {
329            None
330        } else {
331            return Err(ToolError::invalid_params(
332                &tool_call.name,
333                "Parameters must be an object",
334            ));
335        };
336
337        // Execute the tool
338        let result = service
339            .call_tool(CallToolRequestParam {
340                name: actual_tool_name.to_string().into(),
341                arguments,
342            })
343            .await
344            .map_err(|e| {
345                ToolError::execution(&tool_call.name, format!("Tool execution failed: {e}"))
346            })?;
347
348        // Convert result to string
349        let output = result
350            .content
351            .into_iter()
352            .flat_map(|annotated_contents| annotated_contents.into_iter())
353            .map(|annotated| {
354                // Access the raw content from the Annotated wrapper
355                match annotated.raw {
356                    rmcp::model::RawContent::Text(text_content) => text_content.text.clone(),
357                    rmcp::model::RawContent::Image { .. } => "[Image content]".to_string(),
358                    rmcp::model::RawContent::Resource { .. } => "[Resource content]".to_string(),
359                    rmcp::model::RawContent::Audio { .. } => "[Audio content]".to_string(),
360                }
361            })
362            .collect::<Vec<_>>()
363            .join("\n");
364
365        // Return as external tool result
366        Ok(ToolResult::External(ExternalResult {
367            tool_name: tool_call.name.clone(),
368            payload: output,
369        }))
370    }
371
372    async fn supported_tools(&self) -> Vec<String> {
373        let tools = self.tools.read().await;
374        tools
375            .keys()
376            .filter(|tool_name| self.should_include_tool(tool_name))
377            .map(|tool_name| format!("mcp__{}__{}", self.server_name, tool_name))
378            .collect()
379    }
380
381    async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
382        let tools = self.tools.read().await;
383        tools
384            .values()
385            .filter(|tool| self.should_include_tool(&tool.name))
386            .map(|tool| self.mcp_tool_to_schema(tool))
387            .collect()
388    }
389
390    fn metadata(&self) -> BackendMetadata {
391        let mut metadata = BackendMetadata::new(self.server_name.clone(), "MCP".to_string());
392
393        match &self.transport {
394            McpTransport::Stdio { command, args } => {
395                metadata = metadata
396                    .with_info("transport".to_string(), "stdio".to_string())
397                    .with_info("command".to_string(), command.clone())
398                    .with_info("args".to_string(), args.join(" "));
399            }
400            McpTransport::Tcp { host, port } => {
401                metadata = metadata
402                    .with_info("transport".to_string(), "tcp".to_string())
403                    .with_info("host".to_string(), host.clone())
404                    .with_info("port".to_string(), port.to_string());
405            }
406            #[cfg(unix)]
407            McpTransport::Unix { path } => {
408                metadata = metadata
409                    .with_info("transport".to_string(), "unix".to_string())
410                    .with_info("path".to_string(), path.clone());
411            }
412            McpTransport::Sse { url, .. } => {
413                metadata = metadata
414                    .with_info("transport".to_string(), "sse".to_string())
415                    .with_info("url".to_string(), url.clone());
416            }
417            McpTransport::Http { url, .. } => {
418                metadata = metadata
419                    .with_info("transport".to_string(), "http".to_string())
420                    .with_info("url".to_string(), url.clone());
421            }
422        }
423
424        metadata
425    }
426
427    async fn health_check(&self) -> bool {
428        let service_guard = self.client.read().await;
429        service_guard.is_some()
430    }
431
432    async fn requires_approval(&self, _tool_name: &str) -> Result<bool, ToolError> {
433        // MCP tools generally require approval unless we have specific information
434        // In the future, we could query the MCP server for tool metadata
435        Ok(true)
436    }
437}
438
439impl Drop for McpBackend {
440    fn drop(&mut self) {
441        // Schedule cleanup in a detached task
442        let service = self.client.clone();
443
444        tokio::spawn(async move {
445            if let Some(service) = service.write().await.take() {
446                let _ = service.cancel().await;
447            }
448        });
449    }
450}
451
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tool_name_extraction() {
        let prefix = "mcp__test__";
        let full_name = "mcp__test__some_tool";
        let actual_name = full_name.strip_prefix(prefix).unwrap_or(full_name);

        assert_eq!(actual_name, "some_tool");
    }

    #[test]
    fn test_mcp_transport_serialization() {
        // Stdio transport
        let stdio = McpTransport::Stdio {
            command: "python".to_string(),
            args: vec!["-m".to_string(), "test_server".to_string()],
        };
        let json = serde_json::to_string(&stdio).unwrap();
        assert!(json.contains("\"type\":\"stdio\""));
        assert!(json.contains("\"command\":\"python\""));

        // TCP transport
        let tcp = McpTransport::Tcp {
            host: "localhost".to_string(),
            port: 3000,
        };
        let json = serde_json::to_string(&tcp).unwrap();
        assert!(json.contains("\"type\":\"tcp\""));
        assert!(json.contains("\"host\":\"localhost\""));
        assert!(json.contains("\"port\":3000"));

        // Unix transport
        #[cfg(unix)]
        {
            let unix = McpTransport::Unix {
                path: "/tmp/test.sock".to_string(),
            };
            let json = serde_json::to_string(&unix).unwrap();
            assert!(json.contains("\"type\":\"unix\""));
            assert!(json.contains("\"path\":\"/tmp/test.sock\""));
        }
    }

    #[test]
    fn test_remote_transport_serialization() {
        // `headers: None` is skipped entirely on output.
        let sse = McpTransport::Sse {
            url: "http://localhost:8080/sse".to_string(),
            headers: None,
        };
        let json = serde_json::to_string(&sse).unwrap();
        assert!(json.contains("\"type\":\"sse\""));
        assert!(json.contains("\"url\":\"http://localhost:8080/sse\""));
        assert!(!json.contains("headers"));

        // Present headers are serialized and survive a round trip.
        let mut headers = HashMap::new();
        headers.insert("Authorization".to_string(), "Bearer token".to_string());
        let http = McpTransport::Http {
            url: "http://localhost:8080/mcp".to_string(),
            headers: Some(headers),
        };
        let json = serde_json::to_string(&http).unwrap();
        assert!(json.contains("\"type\":\"http\""));
        assert!(json.contains("\"Authorization\":\"Bearer token\""));
        let round_tripped: McpTransport = serde_json::from_str(&json).unwrap();
        assert_eq!(round_tripped, http);
    }

    #[test]
    fn test_mcp_transport_deserialization() {
        // Stdio transport
        let json = r#"{"type":"stdio","command":"node","args":["server.js"]}"#;
        let transport: McpTransport = serde_json::from_str(json).unwrap();
        match transport {
            McpTransport::Stdio { command, args } => {
                assert_eq!(command, "node");
                assert_eq!(args, vec!["server.js"]);
            }
            other => panic!("expected Stdio transport, got {other:?}"),
        }

        // TCP transport
        let json = r#"{"type":"tcp","host":"127.0.0.1","port":8080}"#;
        let transport: McpTransport = serde_json::from_str(json).unwrap();
        match transport {
            McpTransport::Tcp { host, port } => {
                assert_eq!(host, "127.0.0.1");
                assert_eq!(port, 8080);
            }
            other => panic!("expected TCP transport, got {other:?}"),
        }

        // A missing `headers` key deserializes to `None`.
        let json = r#"{"type":"sse","url":"http://localhost/sse"}"#;
        let transport: McpTransport = serde_json::from_str(json).unwrap();
        assert_eq!(
            transport,
            McpTransport::Sse {
                url: "http://localhost/sse".to_string(),
                headers: None,
            }
        );
    }
}