1use std::collections::HashMap;
7use std::process::Stdio;
8use std::sync::Arc;
9use std::time::Duration;
10
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13use serde_json::json;
14use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
15use tokio::process::{Command, Child};
16use tokio::sync::RwLock;
17use tokio::time::timeout;
18
19use crate::types::errors::{Result, StrandsError};
20use crate::types::tools::ToolSpec;
21use crate::types::{ToolResultContent, ToolResultStatus};
22
23use super::{AgentTool, ToolContext, ToolResult2};
24
25#[derive(Debug, Clone)]
27pub struct MCPServerConfig {
28 pub name: String,
30 pub transport: MCPTransport,
32 pub timeout_secs: u64,
34}
35
/// Transport over which an MCP server is reached.
#[derive(Debug, Clone)]
pub enum MCPTransport {
    /// Spawn a child process and exchange messages over its stdin/stdout.
    Stdio {
        /// Program to execute.
        command: String,
        /// Arguments passed to the program.
        args: Vec<String>,
        /// Environment variables set for the child process.
        env: HashMap<String, String>,
    },
    /// Reach the server over HTTP.
    Sse {
        /// Base URL of the server.
        url: String,
        /// Headers added to every HTTP request.
        headers: HashMap<String, String>,
    },
}
51
/// Allow/reject lists deciding which server tools are exposed.
#[derive(Debug, Clone, Default)]
pub struct ToolFilters {
    /// Names that may be exposed; an empty list allows every name.
    pub allowed: Vec<String>,
    /// Names that are never exposed, even when also listed in `allowed`.
    pub rejected: Vec<String>,
}

impl ToolFilters {
    /// Returns `true` when `tool_name` passes both lists.
    ///
    /// Names are compared for exact equality. A name passes when `allowed` is
    /// empty or contains it, and `rejected` does not contain it.
    pub fn should_include(&self, tool_name: &str) -> bool {
        let listed_in = |names: &[String]| names.iter().any(|name| name == tool_name);

        let permitted = self.allowed.is_empty() || listed_in(&self.allowed);
        permitted && !listed_in(&self.rejected)
    }
}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct MCPToolSpec {
77 pub name: String,
79 pub description: Option<String>,
81 #[serde(rename = "inputSchema")]
83 pub input_schema: serde_json::Value,
84 #[serde(rename = "outputSchema")]
86 pub output_schema: Option<serde_json::Value>,
87}
88
89#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct MCPToolResult {
92 pub status: String,
94 #[serde(rename = "toolUseId")]
96 pub tool_use_id: String,
97 pub content: Vec<MCPResultContent>,
99 #[serde(rename = "structuredContent")]
101 pub structured_content: Option<serde_json::Value>,
102 pub metadata: Option<serde_json::Value>,
104}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
108#[serde(untagged)]
109pub enum MCPResultContent {
110 Text { text: String },
112 Image { image: MCPImageContent },
114}
115
116#[derive(Debug, Clone, Serialize, Deserialize)]
118pub struct MCPImageContent {
119 pub format: String,
121 pub source: MCPImageSource,
123}
124
125#[derive(Debug, Clone, Serialize, Deserialize)]
127#[serde(untagged)]
128pub enum MCPImageSource {
129 Bytes { bytes: Vec<u8> },
131 Url { url: String },
133}
134
135#[async_trait]
137pub trait ToolProvider: Send + Sync {
138 async fn load_tools(&self) -> Result<Vec<Arc<dyn AgentTool>>>;
140
141 fn add_consumer(&self, consumer_id: &str);
143
144 fn remove_consumer(&self, consumer_id: &str);
146}
147
/// Lifecycle state of an `MCPClient` connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// Not connected; the initial state and the state after `disconnect`.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// The last connection attempt succeeded.
    Connected,
    /// The last connection attempt failed.
    Failed,
}
160
161#[derive(Clone)]
163pub(crate) struct StdioHandles {
164 stdin: Arc<tokio::sync::Mutex<tokio::process::ChildStdin>>,
165 stdout: Arc<tokio::sync::Mutex<BufReader<tokio::process::ChildStdout>>>,
166 timeout_secs: u64,
167}
168
169pub struct MCPClient {
171 config: MCPServerConfig,
172 tools: RwLock<HashMap<String, Arc<MCPAgentTool>>>,
173 state: RwLock<ConnectionState>,
174 consumers: RwLock<std::collections::HashSet<String>>,
175 filters: Option<ToolFilters>,
176 prefix: Option<String>,
177 stdio_process: RwLock<Option<Child>>,
178 stdio_handles: RwLock<Option<StdioHandles>>,
179}
180
181impl MCPClient {
182 pub fn new(config: MCPServerConfig) -> Self {
184 Self {
185 config,
186 tools: RwLock::new(HashMap::new()),
187 state: RwLock::new(ConnectionState::Disconnected),
188 consumers: RwLock::new(std::collections::HashSet::new()),
189 filters: None,
190 prefix: None,
191 stdio_process: RwLock::new(None),
192 stdio_handles: RwLock::new(None),
193 }
194 }
195
196 pub fn stdio(name: impl Into<String>, command: impl Into<String>, args: Vec<String>) -> Self {
198 Self::new(MCPServerConfig {
199 name: name.into(),
200 transport: MCPTransport::Stdio {
201 command: command.into(),
202 args,
203 env: HashMap::new(),
204 },
205 timeout_secs: 30,
206 })
207 }
208
209 pub fn sse(name: impl Into<String>, url: impl Into<String>) -> Self {
211 Self::new(MCPServerConfig {
212 name: name.into(),
213 transport: MCPTransport::Sse {
214 url: url.into(),
215 headers: HashMap::new(),
216 },
217 timeout_secs: 30,
218 })
219 }
220
221 pub fn with_filters(mut self, filters: ToolFilters) -> Self {
223 self.filters = Some(filters);
224 self
225 }
226
227 pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
229 self.prefix = Some(prefix.into());
230 self
231 }
232
233 pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
235 self.config.timeout_secs = timeout_secs;
236 self
237 }
238
239 pub fn name(&self) -> &str {
241 &self.config.name
242 }
243
244 pub async fn is_connected(&self) -> bool {
246 *self.state.read().await == ConnectionState::Connected
247 }
248
249 pub async fn connection_state(&self) -> ConnectionState {
251 *self.state.read().await
252 }
253
254 pub async fn connect(&self) -> Result<()> {
256 {
257 let mut state = self.state.write().await;
258 if *state == ConnectionState::Connected {
259 return Ok(());
260 }
261 *state = ConnectionState::Connecting;
262 }
263
264 let result = self.do_connect().await;
265
266 {
267 let mut state = self.state.write().await;
268 *state = if result.is_ok() {
269 ConnectionState::Connected
270 } else {
271 ConnectionState::Failed
272 };
273 }
274
275 result
276 }
277
278 async fn do_connect(&self) -> Result<()> {
279 match &self.config.transport {
280 MCPTransport::Stdio { command, args, env } => {
281 self.connect_stdio(command, args, env).await
282 }
283 MCPTransport::Sse { url, headers } => {
284 self.connect_sse(url, headers).await
285 }
286 }
287 }
288
289 async fn connect_stdio(
290 &self,
291 command: &str,
292 args: &[String],
293 env: &HashMap<String, String>,
294 ) -> Result<()> {
295 use tracing::debug;
296
297 debug!(
298 command = %command,
299 args = ?args,
300 "Starting MCP stdio transport"
301 );
302
303 let mut cmd = Command::new(command);
304 cmd.args(args);
305 cmd.envs(env);
306 cmd.stdin(Stdio::piped());
307 cmd.stdout(Stdio::piped());
308 cmd.stderr(Stdio::piped());
309
310 let mut child = cmd.spawn().map_err(|e| {
311 StrandsError::MCPClientInitializationError {
312 message: format!("Failed to spawn MCP server process: {}", e),
313 }
314 })?;
315
316 let stdin = child.stdin.take().ok_or_else(|| {
317 StrandsError::MCPClientInitializationError {
318 message: "Failed to acquire stdin handle".to_string(),
319 }
320 })?;
321
322 let stdout = child.stdout.take().ok_or_else(|| {
323 StrandsError::MCPClientInitializationError {
324 message: "Failed to acquire stdout handle".to_string(),
325 }
326 })?;
327
328 let stdin_handle = Arc::new(tokio::sync::Mutex::new(stdin));
329 let stdout_handle = Arc::new(tokio::sync::Mutex::new(BufReader::new(stdout)));
330
331 {
332 let mut process = self.stdio_process.write().await;
333 *process = Some(child);
334 }
335
336 {
337 let mut handles = self.stdio_handles.write().await;
338 *handles = Some(StdioHandles {
339 stdin: stdin_handle.clone(),
340 stdout: stdout_handle.clone(),
341 timeout_secs: self.config.timeout_secs,
342 });
343 }
344
345 let mut line_buf = String::new();
346
347 let init_request = json!({
348 "jsonrpc": "2.0",
349 "id": 1,
350 "method": "initialize",
351 "params": {
352 "protocolVersion": "2024-11-05",
353 "capabilities": {},
354 "clientInfo": {
355 "name": "strands-rs",
356 "version": "0.1.0"
357 }
358 }
359 });
360
361 let init_json = serde_json::to_string(&init_request).map_err(|e| {
362 StrandsError::MCPClientInitializationError {
363 message: format!("Failed to serialize init request: {}", e),
364 }
365 })?;
366
367 {
368 let mut stdin_guard = stdin_handle.lock().await;
369 stdin_guard
370 .write_all(format!("{}\n", init_json).as_bytes())
371 .await
372 .map_err(|e| StrandsError::MCPClientInitializationError {
373 message: format!("Failed to write init request: {}", e),
374 })?;
375
376 stdin_guard.flush().await.map_err(|e| {
377 StrandsError::MCPClientInitializationError {
378 message: format!("Failed to flush stdin: {}", e),
379 }
380 })?;
381 }
382
383 let read_result = {
384 let mut stdout_guard = stdout_handle.lock().await;
385 timeout(
386 Duration::from_secs(self.config.timeout_secs),
387 stdout_guard.read_line(&mut line_buf),
388 )
389 .await
390 };
391
392 match read_result {
393 Ok(Ok(0)) | Err(_) => {
394 return Err(StrandsError::MCPClientInitializationError {
395 message: "Timeout or EOF while waiting for initialize response".to_string(),
396 });
397 }
398 Ok(Ok(_)) => {
399 let init_response: serde_json::Value = serde_json::from_str(&line_buf)
400 .map_err(|e| StrandsError::MCPClientInitializationError {
401 message: format!("Failed to parse init response: {}", e),
402 })?;
403
404 debug!(response = ?init_response, "Received initialize response");
405 }
406 Ok(Err(e)) => {
407 return Err(StrandsError::MCPClientInitializationError {
408 message: format!("Failed to read init response: {}", e),
409 });
410 }
411 }
412
413 let initialized_notification = json!({
414 "jsonrpc": "2.0",
415 "method": "notifications/initialized"
416 });
417
418 let initialized_json = serde_json::to_string(&initialized_notification).map_err(|e| {
419 StrandsError::MCPClientInitializationError {
420 message: format!("Failed to serialize initialized notification: {}", e),
421 }
422 })?;
423
424 {
425 let mut stdin_guard = stdin_handle.lock().await;
426 stdin_guard
427 .write_all(format!("{}\n", initialized_json).as_bytes())
428 .await
429 .map_err(|e| StrandsError::MCPClientInitializationError {
430 message: format!("Failed to write initialized notification: {}", e),
431 })?;
432
433 stdin_guard.flush().await.map_err(|e| {
434 StrandsError::MCPClientInitializationError {
435 message: format!("Failed to flush stdin: {}", e),
436 }
437 })?;
438 }
439
440 let tools_list_request = json!({
441 "jsonrpc": "2.0",
442 "id": 2,
443 "method": "tools/list"
444 });
445
446 let tools_list_json = serde_json::to_string(&tools_list_request).map_err(|e| {
447 StrandsError::MCPClientInitializationError {
448 message: format!("Failed to serialize tools/list request: {}", e),
449 }
450 })?;
451
452 line_buf.clear();
453 {
454 let mut stdin_guard = stdin_handle.lock().await;
455 stdin_guard
456 .write_all(format!("{}\n", tools_list_json).as_bytes())
457 .await
458 .map_err(|e| StrandsError::MCPClientInitializationError {
459 message: format!("Failed to write tools/list request: {}", e),
460 })?;
461
462 stdin_guard.flush().await.map_err(|e| {
463 StrandsError::MCPClientInitializationError {
464 message: format!("Failed to flush stdin: {}", e),
465 }
466 })?;
467 }
468
469 let read_result = {
470 let mut stdout_guard = stdout_handle.lock().await;
471 timeout(
472 Duration::from_secs(self.config.timeout_secs),
473 stdout_guard.read_line(&mut line_buf),
474 )
475 .await
476 };
477
478 match read_result {
479 Ok(Ok(0)) | Err(_) => {
480 return Err(StrandsError::MCPClientInitializationError {
481 message: "Timeout or EOF while waiting for tools/list response".to_string(),
482 });
483 }
484 Ok(Ok(_)) => {
485 let tools_response: serde_json::Value = serde_json::from_str(&line_buf)
486 .map_err(|e| StrandsError::MCPClientInitializationError {
487 message: format!("Failed to parse tools/list response: {}", e),
488 })?;
489
490 debug!(response = ?tools_response, "Received tools/list response");
491
492 if let Some(result) = tools_response.get("result") {
493 if let Some(tools) = result.get("tools").and_then(|t| t.as_array()) {
494 let mut tools_map = self.tools.write().await;
495 for tool_value in tools {
496 if let Ok(tool_spec) = serde_json::from_value::<MCPToolSpec>(tool_value.clone()) {
497 let tool_name = if let Some(prefix) = &self.prefix {
498 format!("{}_{}", prefix, tool_spec.name)
499 } else {
500 tool_spec.name.clone()
501 };
502
503 if let Some(ref filters) = self.filters {
504 if !filters.should_include(&tool_spec.name) {
505 continue;
506 }
507 }
508
509 let handles = self.stdio_handles.read().await.clone();
510 let mcp_tool = Arc::new(MCPAgentTool::new_stdio(
511 tool_spec.clone(),
512 handles,
513 self.prefix.clone(),
514 ));
515
516 tools_map.insert(tool_name, mcp_tool);
517 }
518 }
519 }
520 }
521 }
522 Ok(Err(e)) => {
523 return Err(StrandsError::MCPClientInitializationError {
524 message: format!("Failed to read tools/list response: {}", e),
525 });
526 }
527 }
528
529 let tool_count = self.tools.read().await.len();
530 debug!(
531 tool_count = tool_count,
532 "MCP stdio transport connected and tools loaded"
533 );
534
535 Ok(())
536 }
537
538 async fn connect_sse(&self, url: &str, headers: &HashMap<String, String>) -> Result<()> {
539 use reqwest::Client;
540
541 let client = Client::builder()
542 .timeout(Duration::from_secs(self.config.timeout_secs))
543 .build()
544 .map_err(|e| StrandsError::NetworkError(e.to_string()))?;
545
546 let mut request = client.get(format!("{}/tools/list", url.trim_end_matches('/')));
547
548 for (key, value) in headers {
549 request = request.header(key, value);
550 }
551
552 let response = request
553 .send()
554 .await
555 .map_err(|e| StrandsError::NetworkError(e.to_string()))?;
556
557 if !response.status().is_success() {
558 return Err(StrandsError::NetworkError(format!(
559 "MCP server returned status: {}",
560 response.status()
561 )));
562 }
563
564 #[derive(Deserialize)]
565 struct ListToolsResponse {
566 tools: Vec<MCPToolSpec>,
567 }
568
569 let list_response: ListToolsResponse = response
570 .json()
571 .await
572 .map_err(|e| StrandsError::NetworkError(format!("Failed to parse response: {e}")))?;
573
574 let mut tools = self.tools.write().await;
575 tools.clear();
576
577 for mcp_spec in list_response.tools {
578 let tool_name = if let Some(ref prefix) = self.prefix {
579 format!("{}_{}", prefix, mcp_spec.name)
580 } else {
581 mcp_spec.name.clone()
582 };
583
584 if let Some(ref filters) = self.filters {
585 if !filters.should_include(&tool_name) {
586 continue;
587 }
588 }
589
590 let agent_tool = MCPAgentTool::new(
591 mcp_spec,
592 url.to_string(),
593 headers.clone(),
594 self.config.timeout_secs,
595 self.prefix.clone(),
596 );
597
598 tools.insert(tool_name, Arc::new(agent_tool));
599 }
600
601 Ok(())
602 }
603
604 pub async fn disconnect(&self) -> Result<()> {
606 let mut state = self.state.write().await;
607 *state = ConnectionState::Disconnected;
608
609 let mut tools = self.tools.write().await;
610 tools.clear();
611
612 Ok(())
613 }
614
615 pub async fn tools(&self) -> Vec<Arc<dyn AgentTool>> {
617 let tools = self.tools.read().await;
618 tools.values().map(|t| t.clone() as Arc<dyn AgentTool>).collect()
619 }
620
621 pub async fn call_tool(
623 &self,
624 tool_use_id: &str,
625 name: &str,
626 arguments: &serde_json::Value,
627 ) -> Result<MCPToolResult> {
628 if !self.is_connected().await {
629 return Err(StrandsError::ConfigurationError {
630 message: "MCP client is not connected".to_string(),
631 });
632 }
633
634 let tools = self.tools.read().await;
635 let tool = tools.get(name).ok_or_else(|| StrandsError::ToolNotFound {
636 tool_name: name.to_string(),
637 })?;
638
639 tool.call_mcp(tool_use_id, arguments).await
640 }
641}
642
643#[async_trait]
644impl ToolProvider for MCPClient {
645 async fn load_tools(&self) -> Result<Vec<Arc<dyn AgentTool>>> {
646 if !self.is_connected().await {
647 self.connect().await?;
648 }
649 Ok(self.tools().await)
650 }
651
652 fn add_consumer(&self, consumer_id: &str) {
653 if let Ok(mut consumers) = self.consumers.try_write() {
654 consumers.insert(consumer_id.to_string());
655 }
656 }
657
658 fn remove_consumer(&self, consumer_id: &str) {
659 if let Ok(mut consumers) = self.consumers.try_write() {
660 consumers.remove(consumer_id);
661 }
662 }
663}
664
665pub struct MCPAgentTool {
667 mcp_spec: MCPToolSpec,
668 server_url: String,
669 headers: HashMap<String, String>,
670 timeout_secs: u64,
671 name_override: Option<String>,
672 stdio_handles: Option<StdioHandles>,
673}
674
675impl MCPAgentTool {
676 pub fn new(
678 mcp_spec: MCPToolSpec,
679 server_url: String,
680 headers: HashMap<String, String>,
681 timeout_secs: u64,
682 prefix: Option<String>,
683 ) -> Self {
684 let name_override = prefix.map(|p| format!("{}_{}", p, mcp_spec.name));
685 Self {
686 mcp_spec,
687 server_url,
688 headers,
689 timeout_secs,
690 name_override,
691 stdio_handles: None,
692 }
693 }
694
695 pub(crate) fn new_stdio(
697 mcp_spec: MCPToolSpec,
698 stdio_handles: Option<StdioHandles>,
699 prefix: Option<String>,
700 ) -> Self {
701 let name_override = prefix.map(|p| format!("{}_{}", p, mcp_spec.name));
702 let timeout_secs = stdio_handles.as_ref().map(|h| h.timeout_secs).unwrap_or(30);
703 Self {
704 mcp_spec,
705 server_url: String::new(),
706 headers: HashMap::new(),
707 timeout_secs,
708 name_override,
709 stdio_handles,
710 }
711 }
712
713 pub async fn call_mcp(
715 &self,
716 tool_use_id: &str,
717 arguments: &serde_json::Value,
718 ) -> Result<MCPToolResult> {
719 if let Some(ref handles) = self.stdio_handles {
720 return self.call_mcp_stdio(tool_use_id, arguments, handles).await;
721 }
722
723 self.call_mcp_sse(tool_use_id, arguments).await
724 }
725
726 async fn call_mcp_stdio(
728 &self,
729 tool_use_id: &str,
730 arguments: &serde_json::Value,
731 handles: &StdioHandles,
732 ) -> Result<MCPToolResult> {
733 use std::sync::atomic::{AtomicU64, Ordering};
734 static REQUEST_ID: AtomicU64 = AtomicU64::new(1000);
735
736 let request_id = REQUEST_ID.fetch_add(1, Ordering::SeqCst);
737
738 let call_request = json!({
739 "jsonrpc": "2.0",
740 "id": request_id,
741 "method": "tools/call",
742 "params": {
743 "name": self.mcp_spec.name,
744 "arguments": arguments
745 }
746 });
747
748 let request_json = serde_json::to_string(&call_request).map_err(|e| {
749 StrandsError::ToolProviderError {
750 message: format!("Failed to serialize tool call request: {}", e),
751 }
752 })?;
753
754 {
755 let mut stdin_guard = handles.stdin.lock().await;
756 stdin_guard
757 .write_all(format!("{}\n", request_json).as_bytes())
758 .await
759 .map_err(|e| StrandsError::ToolProviderError {
760 message: format!("Failed to write tool call request: {}", e),
761 })?;
762
763 stdin_guard.flush().await.map_err(|e| {
764 StrandsError::ToolProviderError {
765 message: format!("Failed to flush stdin: {}", e),
766 }
767 })?;
768 }
769
770 let mut line_buf = String::new();
771 let read_result = {
772 let mut stdout_guard = handles.stdout.lock().await;
773 timeout(
774 Duration::from_secs(handles.timeout_secs),
775 stdout_guard.read_line(&mut line_buf),
776 )
777 .await
778 };
779
780 match read_result {
781 Ok(Ok(0)) | Err(_) => {
782 return Ok(MCPToolResult {
783 status: "error".to_string(),
784 tool_use_id: tool_use_id.to_string(),
785 content: vec![MCPResultContent::Text {
786 text: "Timeout or EOF while waiting for tool call response".to_string(),
787 }],
788 structured_content: None,
789 metadata: None,
790 });
791 }
792 Ok(Ok(_)) => {
793 let response: serde_json::Value = serde_json::from_str(&line_buf).map_err(|e| {
794 StrandsError::ToolProviderError {
795 message: format!("Failed to parse tool call response: {}", e),
796 }
797 })?;
798
799 if let Some(error) = response.get("error") {
800 return Ok(MCPToolResult {
801 status: "error".to_string(),
802 tool_use_id: tool_use_id.to_string(),
803 content: vec![MCPResultContent::Text {
804 text: format!("MCP error: {}", error),
805 }],
806 structured_content: None,
807 metadata: None,
808 });
809 }
810
811 if let Some(result) = response.get("result") {
812 #[derive(Deserialize)]
813 struct CallToolResult {
814 content: Vec<MCPResultContent>,
815 #[serde(rename = "isError")]
816 is_error: Option<bool>,
817 #[serde(rename = "structuredContent")]
818 structured_content: Option<serde_json::Value>,
819 #[serde(rename = "meta")]
820 metadata: Option<serde_json::Value>,
821 }
822
823 if let Ok(call_result) = serde_json::from_value::<CallToolResult>(result.clone()) {
824 return Ok(MCPToolResult {
825 status: if call_result.is_error.unwrap_or(false) {
826 "error"
827 } else {
828 "success"
829 }
830 .to_string(),
831 tool_use_id: tool_use_id.to_string(),
832 content: call_result.content,
833 structured_content: call_result.structured_content,
834 metadata: call_result.metadata,
835 });
836 }
837 }
838
839 Ok(MCPToolResult {
840 status: "error".to_string(),
841 tool_use_id: tool_use_id.to_string(),
842 content: vec![MCPResultContent::Text {
843 text: "Invalid response format from MCP server".to_string(),
844 }],
845 structured_content: None,
846 metadata: None,
847 })
848 }
849 Ok(Err(e)) => {
850 return Err(StrandsError::ToolProviderError {
851 message: format!("Failed to read tool call response: {}", e),
852 });
853 }
854 }
855 }
856
857 async fn call_mcp_sse(
859 &self,
860 tool_use_id: &str,
861 arguments: &serde_json::Value,
862 ) -> Result<MCPToolResult> {
863 use reqwest::Client;
864
865 let client = Client::builder()
866 .timeout(Duration::from_secs(self.timeout_secs))
867 .build()
868 .map_err(|e| StrandsError::NetworkError(e.to_string()))?;
869
870 #[derive(Serialize)]
871 struct CallToolRequest<'a> {
872 name: &'a str,
873 arguments: &'a serde_json::Value,
874 }
875
876 let request_body = CallToolRequest {
877 name: &self.mcp_spec.name,
878 arguments,
879 };
880
881 let mut request = client
882 .post(format!("{}/tools/call", self.server_url.trim_end_matches('/')))
883 .json(&request_body);
884
885 for (key, value) in &self.headers {
886 request = request.header(key, value);
887 }
888
889 let response = request
890 .send()
891 .await
892 .map_err(|e| StrandsError::NetworkError(e.to_string()))?;
893
894 if !response.status().is_success() {
895 return Ok(MCPToolResult {
896 status: "error".to_string(),
897 tool_use_id: tool_use_id.to_string(),
898 content: vec![MCPResultContent::Text {
899 text: format!("MCP server returned status: {}", response.status()),
900 }],
901 structured_content: None,
902 metadata: None,
903 });
904 }
905
906 #[derive(Deserialize)]
907 struct CallToolResponse {
908 content: Vec<MCPResultContent>,
909 #[serde(rename = "isError")]
910 is_error: Option<bool>,
911 #[serde(rename = "structuredContent")]
912 structured_content: Option<serde_json::Value>,
913 #[serde(rename = "meta")]
914 metadata: Option<serde_json::Value>,
915 }
916
917 let call_response: CallToolResponse = response
918 .json()
919 .await
920 .map_err(|e| StrandsError::NetworkError(format!("Failed to parse response: {e}")))?;
921
922 Ok(MCPToolResult {
923 status: if call_response.is_error.unwrap_or(false) {
924 "error"
925 } else {
926 "success"
927 }
928 .to_string(),
929 tool_use_id: tool_use_id.to_string(),
930 content: call_response.content,
931 structured_content: call_response.structured_content,
932 metadata: call_response.metadata,
933 })
934 }
935
936}
937
938#[async_trait]
939impl AgentTool for MCPAgentTool {
940 fn name(&self) -> &str {
941 self.name_override.as_deref().unwrap_or(&self.mcp_spec.name)
942 }
943
944 fn description(&self) -> &str {
945 self.mcp_spec.description.as_deref().unwrap_or("MCP tool")
946 }
947
948 fn tool_spec(&self) -> ToolSpec {
949 let description = self
950 .mcp_spec
951 .description
952 .clone()
953 .unwrap_or_else(|| format!("Tool which performs {}", self.mcp_spec.name));
954
955 let mut spec = ToolSpec::new(self.name(), &description)
956 .with_input_schema(self.mcp_spec.input_schema.clone());
957
958 if let Some(ref output_schema) = self.mcp_spec.output_schema {
959 spec = spec.with_output_schema(output_schema.clone());
960 }
961
962 spec
963 }
964
965 fn tool_type(&self) -> &str {
966 "mcp"
967 }
968
969 async fn invoke(
970 &self,
971 input: serde_json::Value,
972 _context: &ToolContext,
973 ) -> std::result::Result<ToolResult2, String> {
974 use reqwest::Client;
975
976 let client = Client::builder()
977 .timeout(Duration::from_secs(self.timeout_secs))
978 .build()
979 .map_err(|e| e.to_string())?;
980
981 #[derive(Serialize)]
982 struct CallToolRequest<'a> {
983 name: &'a str,
984 arguments: &'a serde_json::Value,
985 }
986
987 let request_body = CallToolRequest {
988 name: &self.mcp_spec.name,
989 arguments: &input,
990 };
991
992 let mut request = client
993 .post(format!("{}/tools/call", self.server_url.trim_end_matches('/')))
994 .json(&request_body);
995
996 for (key, value) in &self.headers {
997 request = request.header(key, value);
998 }
999
1000 let response = request.send().await.map_err(|e| e.to_string())?;
1001
1002 if !response.status().is_success() {
1003 return Err(format!("MCP server returned status: {}", response.status()));
1004 }
1005
1006 #[derive(Deserialize)]
1007 struct CallToolResponse {
1008 content: Vec<MCPResultContent>,
1009 #[serde(rename = "isError")]
1010 is_error: Option<bool>,
1011 }
1012
1013 let call_response: CallToolResponse = response.json().await.map_err(|e| e.to_string())?;
1014
1015 let content: Vec<ToolResultContent> = call_response
1016 .content
1017 .into_iter()
1018 .map(|c| match c {
1019 MCPResultContent::Text { text } => ToolResultContent::text(text),
1020 MCPResultContent::Image { image } => ToolResultContent::json(serde_json::json!({
1021 "type": "image",
1022 "format": image.format,
1023 })),
1024 })
1025 .collect();
1026
1027 let status = if call_response.is_error.unwrap_or(false) {
1028 ToolResultStatus::Error
1029 } else {
1030 ToolResultStatus::Success
1031 };
1032
1033 Ok(ToolResult2 { status, content })
1034 }
1035}
1036
#[cfg(test)]
mod tests {
    use super::*;

    /// The stdio constructor keeps the given server name.
    #[test]
    fn test_mcp_client_creation() {
        let args = vec![String::from("hello")];
        let client = MCPClient::stdio("test", "echo", args);

        assert_eq!(client.name(), "test");
    }

    /// The SSE constructor stores the URL in an `Sse` transport.
    #[test]
    fn test_mcp_sse_client() {
        let client = MCPClient::sse("test", "http://localhost:8080");

        if let MCPTransport::Sse { url, .. } = client.config.transport {
            assert_eq!(url, "http://localhost:8080");
        } else {
            panic!("expected SSE transport");
        }
    }

    /// A name must be allowed and must not be rejected to be included.
    #[test]
    fn test_tool_filters() {
        let allowed = vec![String::from("tool_a"), String::from("tool_b")];
        let rejected = vec![String::from("tool_b")];
        let filters = ToolFilters { allowed, rejected };

        assert!(filters.should_include("tool_a"));
        assert!(!filters.should_include("tool_b"));
        assert!(!filters.should_include("tool_c"));
    }

    /// Builder methods store the prefix and override the timeout.
    #[test]
    fn test_mcp_client_with_options() {
        let client = MCPClient::sse("test", "http://localhost:8080")
            .with_prefix("my_prefix")
            .with_timeout(60)
            .with_filters(ToolFilters::default());

        assert_eq!(client.config.timeout_secs, 60);
        assert_eq!(client.prefix, Some(String::from("my_prefix")));
    }
}