// steer_core/tools/mcp_backend.rs

//! MCP backend implementation using the official rmcp crate.
//!
//! This module provides the `ToolBackend` implementation for MCP servers.

5use async_trait::async_trait;
6use rmcp::transport::ConfigureCommandExt;
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::process::Stdio;
11use std::sync::Arc;
12use tokio::process::Command;
13use tokio::sync::RwLock;
14use tracing::{debug, error, info};
15
16use tokio::net::TcpStream;
17#[cfg(unix)]
18use tokio::net::UnixStream;
19
20use crate::api::ToolCall;
21use crate::session::state::ToolFilter;
22use crate::tools::{BackendMetadata, ExecutionContext, ToolBackend};
23use steer_tools::{
24    InputSchema, ToolError, ToolSchema,
25    result::{ExternalResult, ToolResult},
26};
27
28use rmcp::{
29    model::{CallToolRequestParam, Tool},
30    service::{RoleClient, RunningService, ServiceExt},
31    transport::{SseClientTransport, StreamableHttpClientTransport, TokioChildProcess},
32};
33
34/// MCP transport configuration
35#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonSchema)]
36#[serde(tag = "type", rename_all = "snake_case")]
37pub enum McpTransport {
38    /// Standard I/O transport (child process)
39    Stdio { command: String, args: Vec<String> },
40    /// TCP transport
41    Tcp { host: String, port: u16 },
42    /// Unix domain socket transport
43    #[cfg(unix)]
44    Unix { path: String },
45    /// Server-Sent Events transport
46    Sse {
47        url: String,
48        #[serde(skip_serializing_if = "Option::is_none")]
49        headers: Option<HashMap<String, String>>,
50    },
51    /// HTTP streamable transport
52    Http {
53        url: String,
54        #[serde(skip_serializing_if = "Option::is_none")]
55        headers: Option<HashMap<String, String>>,
56    },
57}
58
59/// Tool backend for executing tools via MCP servers
60pub struct McpBackend {
61    server_name: String,
62    transport: McpTransport,
63    tool_filter: ToolFilter,
64    client: Arc<RwLock<Option<RunningService<RoleClient, ()>>>>,
65    tools: Arc<RwLock<HashMap<String, Tool>>>,
66}
67
68impl McpBackend {
69    /// Create a new MCP backend
70    pub async fn new(
71        server_name: String,
72        transport: McpTransport,
73        tool_filter: ToolFilter,
74    ) -> Result<Self, ToolError> {
75        info!(
76            "Creating MCP backend '{}' with transport: {:?}",
77            server_name, transport
78        );
79
80        let client = match &transport {
81            McpTransport::Stdio { command, args } => {
82                let (transport, stderr) =
83                    TokioChildProcess::builder(Command::new(command).configure(|cmd| {
84                        cmd.args(args);
85                    }))
86                    .stderr(Stdio::piped())
87                    .spawn()
88                    .map_err(|e| {
89                        error!("Failed to create MCP process: {}", e);
90                        ToolError::mcp_connection_failed(
91                            &server_name,
92                            format!("Failed to create MCP process: {e}"),
93                        )
94                    })?;
95
96                if let Some(stderr) = stderr {
97                    let server_name_for_logging = server_name.clone();
98                    tokio::spawn(async move {
99                        use tokio::io::{AsyncBufReadExt, BufReader};
100                        let mut reader = BufReader::new(stderr);
101                        let mut line = String::new();
102
103                        while let Ok(len) = reader.read_line(&mut line).await {
104                            if len == 0 {
105                                break;
106                            }
107                            debug!(
108                                target: "mcp_server",
109                                "[{}] {}",
110                                server_name_for_logging,
111                                line.trim()
112                            );
113                            line.clear();
114                        }
115                    });
116                }
117
118                ().serve(transport).await.map_err(|e| {
119                    error!("Failed to serve MCP: {}", e);
120                    ToolError::mcp_connection_failed(
121                        &server_name,
122                        format!("Failed to serve MCP: {e}"),
123                    )
124                })?
125            }
126            McpTransport::Tcp { host, port } => {
127                let stream = TcpStream::connect((host.as_str(), *port))
128                    .await
129                    .map_err(|e| {
130                        error!("Failed to connect to TCP MCP server: {}", e);
131                        ToolError::mcp_connection_failed(
132                            &server_name,
133                            format!("Failed to connect to {host}:{port} - {e}"),
134                        )
135                    })?;
136
137                ().serve(stream).await.map_err(|e| {
138                    error!("Failed to serve MCP over TCP: {}", e);
139                    ToolError::mcp_connection_failed(
140                        &server_name,
141                        format!("Failed to serve MCP over TCP: {e}"),
142                    )
143                })?
144            }
145            #[cfg(unix)]
146            McpTransport::Unix { path } => {
147                let stream = UnixStream::connect(path).await.map_err(|e| {
148                    error!("Failed to connect to Unix socket MCP server: {}", e);
149                    ToolError::mcp_connection_failed(
150                        &server_name,
151                        format!("Failed to connect to Unix socket {path} - {e}"),
152                    )
153                })?;
154
155                ().serve(stream).await.map_err(|e| {
156                    error!("Failed to serve MCP over Unix socket: {}", e);
157                    ToolError::mcp_connection_failed(
158                        &server_name,
159                        format!("Failed to serve MCP over Unix socket: {e}"),
160                    )
161                })?
162            }
163            McpTransport::Sse { url, headers } => {
164                // Use the dedicated SSE client transport for SSE connections
165                if headers.is_some() && !headers.as_ref().unwrap().is_empty() {
166                    info!(
167                        "SSE transport with custom headers requested; headers may not be applied"
168                    );
169                }
170
171                let transport = SseClientTransport::start(url.clone()).await.map_err(|e| {
172                    error!("Failed to start SSE transport: {}", e);
173                    ToolError::mcp_connection_failed(
174                        &server_name,
175                        format!("Failed to start SSE transport: {e}"),
176                    )
177                })?;
178
179                ().serve(transport).await.map_err(|e| {
180                    error!("Failed to serve MCP over SSE: {}", e);
181                    ToolError::mcp_connection_failed(
182                        &server_name,
183                        format!("Failed to serve MCP over SSE: {e}"),
184                    )
185                })?
186            }
187            McpTransport::Http { url, headers } => {
188                // Use the simpler from_uri method
189                let transport = StreamableHttpClientTransport::from_uri(url.clone());
190
191                if headers.is_some() && !headers.as_ref().unwrap().is_empty() {
192                    info!(
193                        "HTTP transport with custom headers requested; headers may not be applied"
194                    );
195                }
196
197                ().serve(transport).await.map_err(|e| {
198                    error!("Failed to serve MCP over HTTP: {}", e);
199                    ToolError::mcp_connection_failed(
200                        &server_name,
201                        format!("Failed to serve MCP over HTTP: {e}"),
202                    )
203                })?
204            }
205        };
206
207        let server_info = client.peer_info();
208        info!("Connected to server: {server_info:#?}");
209
210        debug!("Attempting to list tools from MCP server '{}'", server_name);
211
212        let list_tools_timeout = std::time::Duration::from_secs(10);
213        let tool_list =
214            tokio::time::timeout(list_tools_timeout, client.list_tools(Default::default()))
215                .await
216                .map_err(|_| {
217                    ToolError::mcp_connection_failed(
218                        &server_name,
219                        "Timeout listing tools".to_string(),
220                    )
221                })?
222                .map_err(|e| {
223                    ToolError::mcp_connection_failed(
224                        &server_name,
225                        format!("Failed to list tools: {e}"),
226                    )
227                })?;
228
229        // Process the tools
230        let mut tools = HashMap::new();
231        for tool in tool_list.tools {
232            tools.insert(tool.name.to_string(), tool);
233        }
234
235        info!(
236            "Discovered {} tools from MCP server '{}': {}",
237            tools.len(),
238            server_name,
239            tools
240                .keys()
241                .map(|k| k.to_string())
242                .collect::<Vec<_>>()
243                .join(", ")
244        );
245
246        let backend = Self {
247            server_name,
248            transport,
249            tool_filter,
250            client: Arc::new(RwLock::new(Some(client))),
251            tools: Arc::new(RwLock::new(tools)),
252        };
253
254        Ok(backend)
255    }
256
257    /// Apply tool filter to determine if a tool should be included
258    fn should_include_tool(&self, tool_name: &str) -> bool {
259        match &self.tool_filter {
260            ToolFilter::All => true,
261            ToolFilter::Include(included) => included.contains(&tool_name.to_string()),
262            ToolFilter::Exclude(excluded) => !excluded.contains(&tool_name.to_string()),
263        }
264    }
265
266    fn mcp_tool_to_schema(&self, tool: &Tool) -> ToolSchema {
267        let description = match &tool.description {
268            Some(desc) if !desc.is_empty() => desc.to_string(),
269            _ => format!(
270                "Tool '{}' from MCP server '{}'",
271                tool.name, self.server_name
272            ),
273        };
274
275        // Convert Arc<Map> to InputSchema
276        let properties = (*tool.input_schema).clone();
277        let required = properties
278            .get("required")
279            .and_then(|v| v.as_array())
280            .map(|arr| {
281                arr.iter()
282                    .filter_map(|v| v.as_str().map(String::from))
283                    .collect()
284            })
285            .unwrap_or_default();
286
287        let input_schema = InputSchema {
288            properties: properties
289                .get("properties")
290                .and_then(|v| v.as_object())
291                .cloned()
292                .unwrap_or_default(),
293            required,
294            schema_type: "object".to_string(),
295        };
296
297        ToolSchema {
298            name: format!("mcp__{}__{}", self.server_name, tool.name),
299            description,
300            input_schema,
301        }
302    }
303}
304
305#[async_trait]
306impl ToolBackend for McpBackend {
307    async fn execute(
308        &self,
309        tool_call: &ToolCall,
310        _context: &ExecutionContext,
311    ) -> Result<ToolResult, ToolError> {
312        // Get the service
313        let service_guard = self.client.read().await;
314        let service = service_guard
315            .as_ref()
316            .ok_or_else(|| ToolError::execution("mcp", "MCP service not initialized"))?;
317
318        // Extract the actual tool name (remove mcp_servername_ prefix)
319        let prefix = format!("mcp__{}__", self.server_name);
320        let actual_tool_name = if tool_call.name.starts_with(&prefix) {
321            &tool_call.name[prefix.len()..]
322        } else {
323            &tool_call.name
324        };
325
326        debug!(
327            "Executing tool '{}' via MCP server '{}'",
328            actual_tool_name, self.server_name
329        );
330
331        // Convert parameters to a Map if it's an object
332        let arguments = if let Some(obj) = tool_call.parameters.as_object() {
333            Some(obj.clone())
334        } else if tool_call.parameters.is_null() {
335            None
336        } else {
337            return Err(ToolError::invalid_params(
338                &tool_call.name,
339                "Parameters must be an object",
340            ));
341        };
342
343        // Execute the tool
344        let result = service
345            .call_tool(CallToolRequestParam {
346                name: actual_tool_name.to_string().into(),
347                arguments,
348            })
349            .await
350            .map_err(|e| {
351                ToolError::execution(&tool_call.name, format!("Tool execution failed: {e}"))
352            })?;
353
354        // Convert result to string
355        let output = result
356            .content
357            .into_iter()
358            .map(|content| {
359                // Access the raw content
360                match &content.raw {
361                    rmcp::model::RawContent::Text(text_content) => text_content.text.to_string(),
362                    rmcp::model::RawContent::Image { .. } => "[Image content]".to_string(),
363                    rmcp::model::RawContent::Resource { .. } => "[Resource content]".to_string(),
364                    rmcp::model::RawContent::Audio { .. } => "[Audio content]".to_string(),
365                }
366            })
367            .collect::<Vec<_>>()
368            .join("\n");
369
370        // Return as external tool result
371        Ok(ToolResult::External(ExternalResult {
372            tool_name: tool_call.name.clone(),
373            payload: output,
374        }))
375    }
376
377    async fn supported_tools(&self) -> Vec<String> {
378        let tools = self.tools.read().await;
379        tools
380            .keys()
381            .filter(|tool_name| self.should_include_tool(tool_name))
382            .map(|tool_name| format!("mcp__{}__{}", self.server_name, tool_name))
383            .collect()
384    }
385
386    async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
387        let tools = self.tools.read().await;
388        tools
389            .values()
390            .filter(|tool| self.should_include_tool(&tool.name))
391            .map(|tool| self.mcp_tool_to_schema(tool))
392            .collect()
393    }
394
395    fn metadata(&self) -> BackendMetadata {
396        let mut metadata = BackendMetadata::new(self.server_name.clone(), "MCP".to_string());
397
398        match &self.transport {
399            McpTransport::Stdio { command, args } => {
400                metadata = metadata
401                    .with_info("transport".to_string(), "stdio".to_string())
402                    .with_info("command".to_string(), command.clone())
403                    .with_info("args".to_string(), args.join(" "));
404            }
405            McpTransport::Tcp { host, port } => {
406                metadata = metadata
407                    .with_info("transport".to_string(), "tcp".to_string())
408                    .with_info("host".to_string(), host.clone())
409                    .with_info("port".to_string(), port.to_string());
410            }
411            #[cfg(unix)]
412            McpTransport::Unix { path } => {
413                metadata = metadata
414                    .with_info("transport".to_string(), "unix".to_string())
415                    .with_info("path".to_string(), path.clone());
416            }
417            McpTransport::Sse { url, .. } => {
418                metadata = metadata
419                    .with_info("transport".to_string(), "sse".to_string())
420                    .with_info("url".to_string(), url.clone());
421            }
422            McpTransport::Http { url, .. } => {
423                metadata = metadata
424                    .with_info("transport".to_string(), "http".to_string())
425                    .with_info("url".to_string(), url.clone());
426            }
427        }
428
429        metadata
430    }
431
432    async fn health_check(&self) -> bool {
433        // Check if service is connected
434        let service_guard = self.client.read().await;
435        service_guard.is_some()
436    }
437
438    async fn requires_approval(&self, _tool_name: &str) -> Result<bool, ToolError> {
439        // MCP tools generally require approval unless we have specific information
440        // In the future, we could query the MCP server for tool metadata
441        Ok(true)
442    }
443}
444
445impl Drop for McpBackend {
446    fn drop(&mut self) {
447        // Schedule cleanup in a detached task
448        let service = self.client.clone();
449
450        tokio::spawn(async move {
451            if let Some(service) = service.write().await.take() {
452                let _ = service.cancel().await;
453            }
454        });
455    }
456}
457
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tool_name_extraction() {
        let prefix = "mcp__test__";
        let full_name = "mcp__test__some_tool";
        let actual_name = full_name.strip_prefix(prefix).unwrap_or(full_name);

        assert_eq!(actual_name, "some_tool");
    }

    #[test]
    fn test_mcp_transport_serialization() {
        // Test stdio transport
        let stdio = McpTransport::Stdio {
            command: "python".to_string(),
            args: vec!["-m".to_string(), "test_server".to_string()],
        };
        let json = serde_json::to_string(&stdio).unwrap();
        assert!(json.contains("\"type\":\"stdio\""));
        assert!(json.contains("\"command\":\"python\""));

        // Test TCP transport
        let tcp = McpTransport::Tcp {
            host: "localhost".to_string(),
            port: 3000,
        };
        let json = serde_json::to_string(&tcp).unwrap();
        assert!(json.contains("\"type\":\"tcp\""));
        assert!(json.contains("\"host\":\"localhost\""));
        assert!(json.contains("\"port\":3000"));

        // Test Unix transport
        #[cfg(unix)]
        {
            let unix = McpTransport::Unix {
                path: "/tmp/test.sock".to_string(),
            };
            let json = serde_json::to_string(&unix).unwrap();
            assert!(json.contains("\"type\":\"unix\""));
            assert!(json.contains("\"path\":\"/tmp/test.sock\""));
        }
    }

    #[test]
    fn test_mcp_transport_deserialization() {
        // Test stdio transport
        let json = r#"{"type":"stdio","command":"node","args":["server.js"]}"#;
        let transport: McpTransport = serde_json::from_str(json).unwrap();
        match transport {
            McpTransport::Stdio { command, args } => {
                assert_eq!(command, "node");
                assert_eq!(args, vec!["server.js"]);
            }
            other => panic!("expected Stdio transport, got {other:?}"),
        }

        // Test TCP transport
        let json = r#"{"type":"tcp","host":"127.0.0.1","port":8080}"#;
        let transport: McpTransport = serde_json::from_str(json).unwrap();
        match transport {
            McpTransport::Tcp { host, port } => {
                assert_eq!(host, "127.0.0.1");
                assert_eq!(port, 8080);
            }
            other => panic!("expected TCP transport, got {other:?}"),
        }
    }

    #[test]
    fn test_remote_transport_headers_are_optional() {
        // `headers: None` is omitted when serializing...
        let http = McpTransport::Http {
            url: "http://localhost:8080/mcp".to_string(),
            headers: None,
        };
        let json = serde_json::to_string(&http).unwrap();
        assert!(json.contains("\"type\":\"http\""));
        assert!(!json.contains("headers"));

        // ...and may be omitted when deserializing.
        let sse: McpTransport =
            serde_json::from_str(r#"{"type":"sse","url":"http://localhost:9000/sse"}"#).unwrap();
        assert_eq!(
            sse,
            McpTransport::Sse {
                url: "http://localhost:9000/sse".to_string(),
                headers: None,
            }
        );
    }

    #[test]
    fn test_remote_transport_round_trip_with_headers() {
        let headers = std::collections::HashMap::from([(
            "Authorization".to_string(),
            "Bearer token".to_string(),
        )]);
        let original = McpTransport::Sse {
            url: "https://example.com/sse".to_string(),
            headers: Some(headers),
        };

        let json = serde_json::to_string(&original).unwrap();
        assert!(json.contains("\"type\":\"sse\""));

        let parsed: McpTransport = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, original);
    }
}