strands_agents/tools/
mcp.rs

1//! Model Context Protocol (MCP) tool provider.
2//!
3//! This module provides integration with MCP servers for accessing
4//! remote tools through the Model Context Protocol.
5
6use std::collections::HashMap;
7use std::process::Stdio;
8use std::sync::Arc;
9use std::time::Duration;
10
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13use serde_json::json;
14use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
15use tokio::process::{Command, Child};
16use tokio::sync::RwLock;
17use tokio::time::timeout;
18
19use crate::types::errors::{Result, StrandsError};
20use crate::types::tools::ToolSpec;
21use crate::types::{ToolResultContent, ToolResultStatus};
22
23use super::{AgentTool, ToolContext, ToolResult2};
24
25/// Configuration for an MCP server connection.
26#[derive(Debug, Clone)]
27pub struct MCPServerConfig {
28    /// Name identifier for this server.
29    pub name: String,
30    /// Transport configuration.
31    pub transport: MCPTransport,
32    /// Connection timeout in seconds.
33    pub timeout_secs: u64,
34}
35
/// How a client reaches its MCP server.
#[derive(Clone, Debug)]
pub enum MCPTransport {
    /// Launch the server as a child process and exchange newline-delimited
    /// messages over its standard input/output.
    Stdio {
        /// Executable to run.
        command: String,
        /// Arguments passed to `command`.
        args: Vec<String>,
        /// Extra environment variables set for the child process.
        env: HashMap<String, String>,
    },
    /// Talk to the server over HTTP rooted at `url`.
    Sse {
        /// Base URL of the server.
        url: String,
        /// Headers attached to the HTTP requests sent to the server.
        headers: HashMap<String, String>,
    },
}
51
/// Allow/deny lists deciding which MCP tools are exposed.
///
/// Entries are compared with tool names by exact string equality; no glob or
/// regular-expression matching is performed.
#[derive(Clone, Debug, Default)]
pub struct ToolFilters {
    /// When non-empty, only tools whose name appears here are kept.
    pub allowed: Vec<String>,
    /// Tools whose name appears here are always dropped, even if allowed.
    pub rejected: Vec<String>,
}

impl ToolFilters {
    /// Returns `true` when `tool_name` passes both lists.
    ///
    /// An empty `allowed` list admits every name; `rejected` is consulted
    /// afterwards and takes precedence.
    pub fn should_include(&self, tool_name: &str) -> bool {
        fn listed(names: &[String], candidate: &str) -> bool {
            names.iter().any(|entry| entry == candidate)
        }

        let admitted = self.allowed.is_empty() || listed(&self.allowed, tool_name);
        admitted && !listed(&self.rejected, tool_name)
    }
}
73
74/// Tool specification from MCP server.
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct MCPToolSpec {
77    /// Tool name.
78    pub name: String,
79    /// Tool description.
80    pub description: Option<String>,
81    /// Input schema.
82    #[serde(rename = "inputSchema")]
83    pub input_schema: serde_json::Value,
84    /// Output schema (optional).
85    #[serde(rename = "outputSchema")]
86    pub output_schema: Option<serde_json::Value>,
87}
88
89/// Result of an MCP tool call.
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct MCPToolResult {
92    /// Status of the call.
93    pub status: String,
94    /// Tool use ID.
95    #[serde(rename = "toolUseId")]
96    pub tool_use_id: String,
97    /// Result content.
98    pub content: Vec<MCPResultContent>,
99    /// Optional structured content.
100    #[serde(rename = "structuredContent")]
101    pub structured_content: Option<serde_json::Value>,
102    /// Optional metadata.
103    pub metadata: Option<serde_json::Value>,
104}
105
106/// Content from MCP tool result.
107#[derive(Debug, Clone, Serialize, Deserialize)]
108#[serde(untagged)]
109pub enum MCPResultContent {
110    /// Text content.
111    Text { text: String },
112    /// Image content.
113    Image { image: MCPImageContent },
114}
115
116/// Image content from MCP.
117#[derive(Debug, Clone, Serialize, Deserialize)]
118pub struct MCPImageContent {
119    /// Image format.
120    pub format: String,
121    /// Image source.
122    pub source: MCPImageSource,
123}
124
125/// Image source from MCP.
126#[derive(Debug, Clone, Serialize, Deserialize)]
127#[serde(untagged)]
128pub enum MCPImageSource {
129    /// Base64 encoded bytes.
130    Bytes { bytes: Vec<u8> },
131    /// URL reference.
132    Url { url: String },
133}
134
135/// Trait for tool providers with lifecycle management.
136#[async_trait]
137pub trait ToolProvider: Send + Sync {
138    /// Loads and returns the tools in this provider.
139    async fn load_tools(&self) -> Result<Vec<Arc<dyn AgentTool>>>;
140
141    /// Adds a consumer to this tool provider.
142    fn add_consumer(&self, consumer_id: &str);
143
144    /// Removes a consumer from this tool provider.
145    fn remove_consumer(&self, consumer_id: &str);
146}
147
/// Lifecycle of an MCP client's connection to its server.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// No connection: the initial state, and the state after a disconnect.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// The connection succeeded and tools were discovered.
    Connected,
    /// The most recent connection attempt failed.
    Failed,
}
160
161/// Shared stdio handles for MCP stdio transport.
162#[derive(Clone)]
163pub(crate) struct StdioHandles {
164    stdin: Arc<tokio::sync::Mutex<tokio::process::ChildStdin>>,
165    stdout: Arc<tokio::sync::Mutex<BufReader<tokio::process::ChildStdout>>>,
166    timeout_secs: u64,
167}
168
169/// An MCP client that provides tools from an MCP server.
170pub struct MCPClient {
171    config: MCPServerConfig,
172    tools: RwLock<HashMap<String, Arc<MCPAgentTool>>>,
173    state: RwLock<ConnectionState>,
174    consumers: RwLock<std::collections::HashSet<String>>,
175    filters: Option<ToolFilters>,
176    prefix: Option<String>,
177    stdio_process: RwLock<Option<Child>>,
178    stdio_handles: RwLock<Option<StdioHandles>>,
179}
180
181impl MCPClient {
182    /// Creates a new MCP client with the given configuration.
183    pub fn new(config: MCPServerConfig) -> Self {
184        Self {
185            config,
186            tools: RwLock::new(HashMap::new()),
187            state: RwLock::new(ConnectionState::Disconnected),
188            consumers: RwLock::new(std::collections::HashSet::new()),
189            filters: None,
190            prefix: None,
191            stdio_process: RwLock::new(None),
192            stdio_handles: RwLock::new(None),
193        }
194    }
195
196    /// Creates an MCP client with stdio transport.
197    pub fn stdio(name: impl Into<String>, command: impl Into<String>, args: Vec<String>) -> Self {
198        Self::new(MCPServerConfig {
199            name: name.into(),
200            transport: MCPTransport::Stdio {
201                command: command.into(),
202                args,
203                env: HashMap::new(),
204            },
205            timeout_secs: 30,
206        })
207    }
208
209    /// Creates an MCP client with SSE transport.
210    pub fn sse(name: impl Into<String>, url: impl Into<String>) -> Self {
211        Self::new(MCPServerConfig {
212            name: name.into(),
213            transport: MCPTransport::Sse {
214                url: url.into(),
215                headers: HashMap::new(),
216            },
217            timeout_secs: 30,
218        })
219    }
220
221    /// Sets tool filters for this client.
222    pub fn with_filters(mut self, filters: ToolFilters) -> Self {
223        self.filters = Some(filters);
224        self
225    }
226
227    /// Sets a prefix for tool names.
228    pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
229        self.prefix = Some(prefix.into());
230        self
231    }
232
233    /// Sets the connection timeout.
234    pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
235        self.config.timeout_secs = timeout_secs;
236        self
237    }
238
239    /// Returns the server name.
240    pub fn name(&self) -> &str {
241        &self.config.name
242    }
243
244    /// Checks if the client is connected.
245    pub async fn is_connected(&self) -> bool {
246        *self.state.read().await == ConnectionState::Connected
247    }
248
249    /// Returns the current connection state.
250    pub async fn connection_state(&self) -> ConnectionState {
251        *self.state.read().await
252    }
253
254    /// Connects to the MCP server and discovers available tools.
255    pub async fn connect(&self) -> Result<()> {
256        {
257            let mut state = self.state.write().await;
258            if *state == ConnectionState::Connected {
259                return Ok(());
260            }
261            *state = ConnectionState::Connecting;
262        }
263
264        let result = self.do_connect().await;
265
266        {
267            let mut state = self.state.write().await;
268            *state = if result.is_ok() {
269                ConnectionState::Connected
270            } else {
271                ConnectionState::Failed
272            };
273        }
274
275        result
276    }
277
278    async fn do_connect(&self) -> Result<()> {
279        match &self.config.transport {
280            MCPTransport::Stdio { command, args, env } => {
281                self.connect_stdio(command, args, env).await
282            }
283            MCPTransport::Sse { url, headers } => {
284                self.connect_sse(url, headers).await
285            }
286        }
287    }
288
289    async fn connect_stdio(
290        &self,
291        command: &str,
292        args: &[String],
293        env: &HashMap<String, String>,
294    ) -> Result<()> {
295        use tracing::debug;
296
297        debug!(
298            command = %command,
299            args = ?args,
300            "Starting MCP stdio transport"
301        );
302
303        let mut cmd = Command::new(command);
304        cmd.args(args);
305        cmd.envs(env);
306        cmd.stdin(Stdio::piped());
307        cmd.stdout(Stdio::piped());
308        cmd.stderr(Stdio::piped());
309
310        let mut child = cmd.spawn().map_err(|e| {
311            StrandsError::MCPClientInitializationError {
312                message: format!("Failed to spawn MCP server process: {}", e),
313            }
314        })?;
315
316        let stdin = child.stdin.take().ok_or_else(|| {
317            StrandsError::MCPClientInitializationError {
318                message: "Failed to acquire stdin handle".to_string(),
319            }
320        })?;
321
322        let stdout = child.stdout.take().ok_or_else(|| {
323            StrandsError::MCPClientInitializationError {
324                message: "Failed to acquire stdout handle".to_string(),
325            }
326        })?;
327
328        let stdin_handle = Arc::new(tokio::sync::Mutex::new(stdin));
329        let stdout_handle = Arc::new(tokio::sync::Mutex::new(BufReader::new(stdout)));
330
331        {
332            let mut process = self.stdio_process.write().await;
333            *process = Some(child);
334        }
335
336        {
337            let mut handles = self.stdio_handles.write().await;
338            *handles = Some(StdioHandles {
339                stdin: stdin_handle.clone(),
340                stdout: stdout_handle.clone(),
341                timeout_secs: self.config.timeout_secs,
342            });
343        }
344
345        let mut line_buf = String::new();
346
347        let init_request = json!({
348            "jsonrpc": "2.0",
349            "id": 1,
350            "method": "initialize",
351            "params": {
352                "protocolVersion": "2024-11-05",
353                "capabilities": {},
354                "clientInfo": {
355                    "name": "strands-rs",
356                    "version": "0.1.0"
357                }
358            }
359        });
360
361        let init_json = serde_json::to_string(&init_request).map_err(|e| {
362            StrandsError::MCPClientInitializationError {
363                message: format!("Failed to serialize init request: {}", e),
364            }
365        })?;
366
367        {
368            let mut stdin_guard = stdin_handle.lock().await;
369            stdin_guard
370                .write_all(format!("{}\n", init_json).as_bytes())
371                .await
372                .map_err(|e| StrandsError::MCPClientInitializationError {
373                    message: format!("Failed to write init request: {}", e),
374                })?;
375
376            stdin_guard.flush().await.map_err(|e| {
377                StrandsError::MCPClientInitializationError {
378                    message: format!("Failed to flush stdin: {}", e),
379                }
380            })?;
381        }
382
383        let read_result = {
384            let mut stdout_guard = stdout_handle.lock().await;
385            timeout(
386                Duration::from_secs(self.config.timeout_secs),
387                stdout_guard.read_line(&mut line_buf),
388            )
389            .await
390        };
391
392        match read_result {
393            Ok(Ok(0)) | Err(_) => {
394                return Err(StrandsError::MCPClientInitializationError {
395                    message: "Timeout or EOF while waiting for initialize response".to_string(),
396                });
397            }
398            Ok(Ok(_)) => {
399                let init_response: serde_json::Value = serde_json::from_str(&line_buf)
400                    .map_err(|e| StrandsError::MCPClientInitializationError {
401                        message: format!("Failed to parse init response: {}", e),
402                    })?;
403
404                debug!(response = ?init_response, "Received initialize response");
405            }
406            Ok(Err(e)) => {
407                return Err(StrandsError::MCPClientInitializationError {
408                    message: format!("Failed to read init response: {}", e),
409                });
410            }
411        }
412
413        let initialized_notification = json!({
414            "jsonrpc": "2.0",
415            "method": "notifications/initialized"
416        });
417
418        let initialized_json = serde_json::to_string(&initialized_notification).map_err(|e| {
419            StrandsError::MCPClientInitializationError {
420                message: format!("Failed to serialize initialized notification: {}", e),
421            }
422        })?;
423
424        {
425            let mut stdin_guard = stdin_handle.lock().await;
426            stdin_guard
427                .write_all(format!("{}\n", initialized_json).as_bytes())
428                .await
429                .map_err(|e| StrandsError::MCPClientInitializationError {
430                    message: format!("Failed to write initialized notification: {}", e),
431                })?;
432
433            stdin_guard.flush().await.map_err(|e| {
434                StrandsError::MCPClientInitializationError {
435                    message: format!("Failed to flush stdin: {}", e),
436                }
437            })?;
438        }
439
440        let tools_list_request = json!({
441            "jsonrpc": "2.0",
442            "id": 2,
443            "method": "tools/list"
444        });
445
446        let tools_list_json = serde_json::to_string(&tools_list_request).map_err(|e| {
447            StrandsError::MCPClientInitializationError {
448                message: format!("Failed to serialize tools/list request: {}", e),
449            }
450        })?;
451
452        line_buf.clear();
453        {
454            let mut stdin_guard = stdin_handle.lock().await;
455            stdin_guard
456                .write_all(format!("{}\n", tools_list_json).as_bytes())
457                .await
458                .map_err(|e| StrandsError::MCPClientInitializationError {
459                    message: format!("Failed to write tools/list request: {}", e),
460                })?;
461
462            stdin_guard.flush().await.map_err(|e| {
463                StrandsError::MCPClientInitializationError {
464                    message: format!("Failed to flush stdin: {}", e),
465                }
466            })?;
467        }
468
469        let read_result = {
470            let mut stdout_guard = stdout_handle.lock().await;
471            timeout(
472                Duration::from_secs(self.config.timeout_secs),
473                stdout_guard.read_line(&mut line_buf),
474            )
475            .await
476        };
477
478        match read_result {
479            Ok(Ok(0)) | Err(_) => {
480                return Err(StrandsError::MCPClientInitializationError {
481                    message: "Timeout or EOF while waiting for tools/list response".to_string(),
482                });
483            }
484            Ok(Ok(_)) => {
485                let tools_response: serde_json::Value = serde_json::from_str(&line_buf)
486                    .map_err(|e| StrandsError::MCPClientInitializationError {
487                        message: format!("Failed to parse tools/list response: {}", e),
488                    })?;
489
490                debug!(response = ?tools_response, "Received tools/list response");
491
492                if let Some(result) = tools_response.get("result") {
493                    if let Some(tools) = result.get("tools").and_then(|t| t.as_array()) {
494                        let mut tools_map = self.tools.write().await;
495                        for tool_value in tools {
496                            if let Ok(tool_spec) = serde_json::from_value::<MCPToolSpec>(tool_value.clone()) {
497                                let tool_name = if let Some(prefix) = &self.prefix {
498                                    format!("{}_{}", prefix, tool_spec.name)
499                                } else {
500                                    tool_spec.name.clone()
501                                };
502
503                                if let Some(ref filters) = self.filters {
504                                    if !filters.should_include(&tool_spec.name) {
505                                        continue;
506                                    }
507                                }
508
509                                let handles = self.stdio_handles.read().await.clone();
510                                let mcp_tool = Arc::new(MCPAgentTool::new_stdio(
511                                    tool_spec.clone(),
512                                    handles,
513                                    self.prefix.clone(),
514                                ));
515
516                                tools_map.insert(tool_name, mcp_tool);
517                            }
518                        }
519                    }
520                }
521            }
522            Ok(Err(e)) => {
523                return Err(StrandsError::MCPClientInitializationError {
524                    message: format!("Failed to read tools/list response: {}", e),
525                });
526            }
527        }
528
529        let tool_count = self.tools.read().await.len();
530        debug!(
531            tool_count = tool_count,
532            "MCP stdio transport connected and tools loaded"
533        );
534
535        Ok(())
536    }
537
538    async fn connect_sse(&self, url: &str, headers: &HashMap<String, String>) -> Result<()> {
539        use reqwest::Client;
540
541        let client = Client::builder()
542            .timeout(Duration::from_secs(self.config.timeout_secs))
543            .build()
544            .map_err(|e| StrandsError::NetworkError(e.to_string()))?;
545
546        let mut request = client.get(format!("{}/tools/list", url.trim_end_matches('/')));
547
548        for (key, value) in headers {
549            request = request.header(key, value);
550        }
551
552        let response = request
553            .send()
554            .await
555            .map_err(|e| StrandsError::NetworkError(e.to_string()))?;
556
557        if !response.status().is_success() {
558            return Err(StrandsError::NetworkError(format!(
559                "MCP server returned status: {}",
560                response.status()
561            )));
562        }
563
564        #[derive(Deserialize)]
565        struct ListToolsResponse {
566            tools: Vec<MCPToolSpec>,
567        }
568
569        let list_response: ListToolsResponse = response
570            .json()
571            .await
572            .map_err(|e| StrandsError::NetworkError(format!("Failed to parse response: {e}")))?;
573
574        let mut tools = self.tools.write().await;
575        tools.clear();
576
577        for mcp_spec in list_response.tools {
578            let tool_name = if let Some(ref prefix) = self.prefix {
579                format!("{}_{}", prefix, mcp_spec.name)
580            } else {
581                mcp_spec.name.clone()
582            };
583
584            if let Some(ref filters) = self.filters {
585                if !filters.should_include(&tool_name) {
586                    continue;
587                }
588            }
589
590            let agent_tool = MCPAgentTool::new(
591                mcp_spec,
592                url.to_string(),
593                headers.clone(),
594                self.config.timeout_secs,
595                self.prefix.clone(),
596            );
597
598            tools.insert(tool_name, Arc::new(agent_tool));
599        }
600
601        Ok(())
602    }
603
604    /// Disconnects from the MCP server.
605    pub async fn disconnect(&self) -> Result<()> {
606        let mut state = self.state.write().await;
607        *state = ConnectionState::Disconnected;
608
609        let mut tools = self.tools.write().await;
610        tools.clear();
611
612        Ok(())
613    }
614
615    /// Returns the list of available tools from the MCP server.
616    pub async fn tools(&self) -> Vec<Arc<dyn AgentTool>> {
617        let tools = self.tools.read().await;
618        tools.values().map(|t| t.clone() as Arc<dyn AgentTool>).collect()
619    }
620
621    /// Calls a tool on the MCP server.
622    pub async fn call_tool(
623        &self,
624        tool_use_id: &str,
625        name: &str,
626        arguments: &serde_json::Value,
627    ) -> Result<MCPToolResult> {
628        if !self.is_connected().await {
629            return Err(StrandsError::ConfigurationError {
630                message: "MCP client is not connected".to_string(),
631            });
632        }
633
634        let tools = self.tools.read().await;
635        let tool = tools.get(name).ok_or_else(|| StrandsError::ToolNotFound {
636            tool_name: name.to_string(),
637        })?;
638
639        tool.call_mcp(tool_use_id, arguments).await
640    }
641}
642
643#[async_trait]
644impl ToolProvider for MCPClient {
645    async fn load_tools(&self) -> Result<Vec<Arc<dyn AgentTool>>> {
646        if !self.is_connected().await {
647            self.connect().await?;
648        }
649        Ok(self.tools().await)
650    }
651
652    fn add_consumer(&self, consumer_id: &str) {
653        if let Ok(mut consumers) = self.consumers.try_write() {
654            consumers.insert(consumer_id.to_string());
655        }
656    }
657
658    fn remove_consumer(&self, consumer_id: &str) {
659        if let Ok(mut consumers) = self.consumers.try_write() {
660            consumers.remove(consumer_id);
661        }
662    }
663}
664
665/// An MCP tool wrapped as an AgentTool.
666pub struct MCPAgentTool {
667    mcp_spec: MCPToolSpec,
668    server_url: String,
669    headers: HashMap<String, String>,
670    timeout_secs: u64,
671    name_override: Option<String>,
672    stdio_handles: Option<StdioHandles>,
673}
674
675impl MCPAgentTool {
676    /// Creates a new MCP agent tool for SSE transport.
677    pub fn new(
678        mcp_spec: MCPToolSpec,
679        server_url: String,
680        headers: HashMap<String, String>,
681        timeout_secs: u64,
682        prefix: Option<String>,
683    ) -> Self {
684        let name_override = prefix.map(|p| format!("{}_{}", p, mcp_spec.name));
685        Self {
686            mcp_spec,
687            server_url,
688            headers,
689            timeout_secs,
690            name_override,
691            stdio_handles: None,
692        }
693    }
694
695    /// Creates a new MCP agent tool for stdio transport.
696    pub(crate) fn new_stdio(
697        mcp_spec: MCPToolSpec,
698        stdio_handles: Option<StdioHandles>,
699        prefix: Option<String>,
700    ) -> Self {
701        let name_override = prefix.map(|p| format!("{}_{}", p, mcp_spec.name));
702        let timeout_secs = stdio_handles.as_ref().map(|h| h.timeout_secs).unwrap_or(30);
703        Self {
704            mcp_spec,
705            server_url: String::new(),
706            headers: HashMap::new(),
707            timeout_secs,
708            name_override,
709            stdio_handles,
710        }
711    }
712
713    /// Calls the MCP server to execute this tool.
714    pub async fn call_mcp(
715        &self,
716        tool_use_id: &str,
717        arguments: &serde_json::Value,
718    ) -> Result<MCPToolResult> {
719        if let Some(ref handles) = self.stdio_handles {
720            return self.call_mcp_stdio(tool_use_id, arguments, handles).await;
721        }
722
723        self.call_mcp_sse(tool_use_id, arguments).await
724    }
725
726    /// Calls the MCP server via stdio transport.
727    async fn call_mcp_stdio(
728        &self,
729        tool_use_id: &str,
730        arguments: &serde_json::Value,
731        handles: &StdioHandles,
732    ) -> Result<MCPToolResult> {
733        use std::sync::atomic::{AtomicU64, Ordering};
734        static REQUEST_ID: AtomicU64 = AtomicU64::new(1000);
735
736        let request_id = REQUEST_ID.fetch_add(1, Ordering::SeqCst);
737
738        let call_request = json!({
739            "jsonrpc": "2.0",
740            "id": request_id,
741            "method": "tools/call",
742            "params": {
743                "name": self.mcp_spec.name,
744                "arguments": arguments
745            }
746        });
747
748        let request_json = serde_json::to_string(&call_request).map_err(|e| {
749            StrandsError::ToolProviderError {
750                message: format!("Failed to serialize tool call request: {}", e),
751            }
752        })?;
753
754        {
755            let mut stdin_guard = handles.stdin.lock().await;
756            stdin_guard
757                .write_all(format!("{}\n", request_json).as_bytes())
758                .await
759                .map_err(|e| StrandsError::ToolProviderError {
760                    message: format!("Failed to write tool call request: {}", e),
761                })?;
762
763            stdin_guard.flush().await.map_err(|e| {
764                StrandsError::ToolProviderError {
765                    message: format!("Failed to flush stdin: {}", e),
766                }
767            })?;
768        }
769
770        let mut line_buf = String::new();
771        let read_result = {
772            let mut stdout_guard = handles.stdout.lock().await;
773            timeout(
774                Duration::from_secs(handles.timeout_secs),
775                stdout_guard.read_line(&mut line_buf),
776            )
777            .await
778        };
779
780        match read_result {
781            Ok(Ok(0)) | Err(_) => {
782                return Ok(MCPToolResult {
783                    status: "error".to_string(),
784                    tool_use_id: tool_use_id.to_string(),
785                    content: vec![MCPResultContent::Text {
786                        text: "Timeout or EOF while waiting for tool call response".to_string(),
787                    }],
788                    structured_content: None,
789                    metadata: None,
790                });
791            }
792            Ok(Ok(_)) => {
793                let response: serde_json::Value = serde_json::from_str(&line_buf).map_err(|e| {
794                    StrandsError::ToolProviderError {
795                        message: format!("Failed to parse tool call response: {}", e),
796                    }
797                })?;
798
799                if let Some(error) = response.get("error") {
800                    return Ok(MCPToolResult {
801                        status: "error".to_string(),
802                        tool_use_id: tool_use_id.to_string(),
803                        content: vec![MCPResultContent::Text {
804                            text: format!("MCP error: {}", error),
805                        }],
806                        structured_content: None,
807                        metadata: None,
808                    });
809                }
810
811                if let Some(result) = response.get("result") {
812                    #[derive(Deserialize)]
813                    struct CallToolResult {
814                        content: Vec<MCPResultContent>,
815                        #[serde(rename = "isError")]
816                        is_error: Option<bool>,
817                        #[serde(rename = "structuredContent")]
818                        structured_content: Option<serde_json::Value>,
819                        #[serde(rename = "meta")]
820                        metadata: Option<serde_json::Value>,
821                    }
822
823                    if let Ok(call_result) = serde_json::from_value::<CallToolResult>(result.clone()) {
824                        return Ok(MCPToolResult {
825                            status: if call_result.is_error.unwrap_or(false) {
826                                "error"
827                            } else {
828                                "success"
829                            }
830                            .to_string(),
831                            tool_use_id: tool_use_id.to_string(),
832                            content: call_result.content,
833                            structured_content: call_result.structured_content,
834                            metadata: call_result.metadata,
835                        });
836                    }
837                }
838
839                Ok(MCPToolResult {
840                    status: "error".to_string(),
841                    tool_use_id: tool_use_id.to_string(),
842                    content: vec![MCPResultContent::Text {
843                        text: "Invalid response format from MCP server".to_string(),
844                    }],
845                    structured_content: None,
846                    metadata: None,
847                })
848            }
849            Ok(Err(e)) => {
850                return Err(StrandsError::ToolProviderError {
851                    message: format!("Failed to read tool call response: {}", e),
852                });
853            }
854        }
855    }
856
857    /// Calls the MCP server via SSE transport.
858    async fn call_mcp_sse(
859        &self,
860        tool_use_id: &str,
861        arguments: &serde_json::Value,
862    ) -> Result<MCPToolResult> {
863        use reqwest::Client;
864
865        let client = Client::builder()
866            .timeout(Duration::from_secs(self.timeout_secs))
867            .build()
868            .map_err(|e| StrandsError::NetworkError(e.to_string()))?;
869
870        #[derive(Serialize)]
871        struct CallToolRequest<'a> {
872            name: &'a str,
873            arguments: &'a serde_json::Value,
874        }
875
876        let request_body = CallToolRequest {
877            name: &self.mcp_spec.name,
878            arguments,
879        };
880
881        let mut request = client
882            .post(format!("{}/tools/call", self.server_url.trim_end_matches('/')))
883            .json(&request_body);
884
885        for (key, value) in &self.headers {
886            request = request.header(key, value);
887        }
888
889        let response = request
890            .send()
891            .await
892            .map_err(|e| StrandsError::NetworkError(e.to_string()))?;
893
894        if !response.status().is_success() {
895            return Ok(MCPToolResult {
896                status: "error".to_string(),
897                tool_use_id: tool_use_id.to_string(),
898                content: vec![MCPResultContent::Text {
899                    text: format!("MCP server returned status: {}", response.status()),
900                }],
901                structured_content: None,
902                metadata: None,
903            });
904        }
905
906        #[derive(Deserialize)]
907        struct CallToolResponse {
908            content: Vec<MCPResultContent>,
909            #[serde(rename = "isError")]
910            is_error: Option<bool>,
911            #[serde(rename = "structuredContent")]
912            structured_content: Option<serde_json::Value>,
913            #[serde(rename = "meta")]
914            metadata: Option<serde_json::Value>,
915        }
916
917        let call_response: CallToolResponse = response
918            .json()
919            .await
920            .map_err(|e| StrandsError::NetworkError(format!("Failed to parse response: {e}")))?;
921
922        Ok(MCPToolResult {
923            status: if call_response.is_error.unwrap_or(false) {
924                "error"
925            } else {
926                "success"
927            }
928            .to_string(),
929            tool_use_id: tool_use_id.to_string(),
930            content: call_response.content,
931            structured_content: call_response.structured_content,
932            metadata: call_response.metadata,
933        })
934    }
935
936}
937
938#[async_trait]
939impl AgentTool for MCPAgentTool {
940    fn name(&self) -> &str {
941        self.name_override.as_deref().unwrap_or(&self.mcp_spec.name)
942    }
943
944    fn description(&self) -> &str {
945        self.mcp_spec.description.as_deref().unwrap_or("MCP tool")
946    }
947
948    fn tool_spec(&self) -> ToolSpec {
949        let description = self
950            .mcp_spec
951            .description
952            .clone()
953            .unwrap_or_else(|| format!("Tool which performs {}", self.mcp_spec.name));
954
955        let mut spec = ToolSpec::new(self.name(), &description)
956            .with_input_schema(self.mcp_spec.input_schema.clone());
957
958        if let Some(ref output_schema) = self.mcp_spec.output_schema {
959            spec = spec.with_output_schema(output_schema.clone());
960        }
961
962        spec
963    }
964
965    fn tool_type(&self) -> &str {
966        "mcp"
967    }
968
969    async fn invoke(
970        &self,
971        input: serde_json::Value,
972        _context: &ToolContext,
973    ) -> std::result::Result<ToolResult2, String> {
974        use reqwest::Client;
975
976        let client = Client::builder()
977            .timeout(Duration::from_secs(self.timeout_secs))
978            .build()
979            .map_err(|e| e.to_string())?;
980
981        #[derive(Serialize)]
982        struct CallToolRequest<'a> {
983            name: &'a str,
984            arguments: &'a serde_json::Value,
985        }
986
987        let request_body = CallToolRequest {
988            name: &self.mcp_spec.name,
989            arguments: &input,
990        };
991
992        let mut request = client
993            .post(format!("{}/tools/call", self.server_url.trim_end_matches('/')))
994            .json(&request_body);
995
996        for (key, value) in &self.headers {
997            request = request.header(key, value);
998        }
999
1000        let response = request.send().await.map_err(|e| e.to_string())?;
1001
1002        if !response.status().is_success() {
1003            return Err(format!("MCP server returned status: {}", response.status()));
1004        }
1005
1006        #[derive(Deserialize)]
1007        struct CallToolResponse {
1008            content: Vec<MCPResultContent>,
1009            #[serde(rename = "isError")]
1010            is_error: Option<bool>,
1011        }
1012
1013        let call_response: CallToolResponse = response.json().await.map_err(|e| e.to_string())?;
1014
1015        let content: Vec<ToolResultContent> = call_response
1016            .content
1017            .into_iter()
1018            .map(|c| match c {
1019                MCPResultContent::Text { text } => ToolResultContent::text(text),
1020                MCPResultContent::Image { image } => ToolResultContent::json(serde_json::json!({
1021                    "type": "image",
1022                    "format": image.format,
1023                })),
1024            })
1025            .collect();
1026
1027        let status = if call_response.is_error.unwrap_or(false) {
1028            ToolResultStatus::Error
1029        } else {
1030            ToolResultStatus::Success
1031        };
1032
1033        Ok(ToolResult2 { status, content })
1034    }
1035}
1036
#[cfg(test)]
mod tests {
    use super::*;

    /// Endpoint shared by the SSE-based client tests.
    const SSE_URL: &str = "http://localhost:8080";

    /// A stdio client reports the name it was constructed with.
    #[test]
    fn test_mcp_client_creation() {
        let args = vec![String::from("hello")];
        let stdio_client = MCPClient::stdio("test", "echo", args);
        assert_eq!(stdio_client.name(), "test");
    }

    /// An SSE client stores its URL inside the `Sse` transport variant.
    #[test]
    fn test_mcp_sse_client() {
        let sse_client = MCPClient::sse("test", SSE_URL);
        if let MCPTransport::Sse { url, .. } = &sse_client.config.transport {
            assert_eq!(url.as_str(), SSE_URL);
        } else {
            panic!("expected SSE transport");
        }
    }

    /// A tool must be allowed and not rejected. Being listed in `rejected`
    /// excludes a tool even if it also appears in `allowed`.
    #[test]
    fn test_tool_filters() {
        fn owned(names: &[&str]) -> Vec<String> {
            names.iter().map(|name| name.to_string()).collect()
        }

        let filters = ToolFilters {
            allowed: owned(&["tool_a", "tool_b"]),
            rejected: owned(&["tool_b"]),
        };

        let cases = [("tool_a", true), ("tool_b", false), ("tool_c", false)];
        for &(tool, expected) in &cases {
            assert_eq!(filters.should_include(tool), expected);
        }
    }

    /// Builder options are reflected in the resulting client.
    #[test]
    fn test_mcp_client_with_options() {
        let configured = MCPClient::sse("test", SSE_URL)
            .with_prefix("my_prefix")
            .with_timeout(60)
            .with_filters(ToolFilters::default());

        assert_eq!(configured.config.timeout_secs, 60);
        assert_eq!(configured.prefix.as_deref(), Some("my_prefix"));
    }
}
1076        assert_eq!(client.prefix, Some("my_prefix".to_string()));
1077    }
1078}