1use crate::config::mcp::{
7 McpAllowListConfig, McpClientConfig, McpProviderConfig, McpTransportConfig,
8};
9use anyhow::{Context, Result};
10use async_trait::async_trait;
11use parking_lot::RwLock;
12use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
13use rmcp::{
14 ServiceExt,
15 handler::client::ClientHandler,
16 model::{
17 CallToolRequestParam, CallToolResult, CancelledNotificationParam, ClientCapabilities,
18 ClientInfo, CreateElicitationRequestParam, CreateElicitationResult, ElicitationAction,
19 Implementation, ListToolsResult, LoggingLevel, LoggingMessageNotificationParam,
20 ProgressNotificationParam, ResourceUpdatedNotificationParam, RootsCapabilities,
21 },
22 service::{NotificationContext, RequestContext, RoleClient},
23 transport::TokioChildProcess,
24};
25use serde_json::{Map, Value};
26use std::collections::HashMap;
27use std::future;
28use std::process::Stdio;
29use std::sync::Arc;
30use tokio::io::{AsyncBufReadExt, BufReader};
31use tokio::process::Command;
32use tokio::sync::Mutex;
33use tracing::{Level, debug, error, info, warn};
34
35#[derive(Clone)]
36struct LoggingClientHandler {
37 provider_name: String,
38 info: ClientInfo,
39}
40
41impl LoggingClientHandler {
42 fn new(provider_name: &str) -> Self {
43 let mut info = ClientInfo::default();
44 info.capabilities = ClientCapabilities {
45 roots: Some(RootsCapabilities {
46 list_changed: Some(true),
47 }),
48 ..ClientCapabilities::default()
49 };
50 info.client_info = Implementation {
51 name: "vtcode".to_string(),
52 title: Some("VT Code MCP client".to_string()),
53 version: env!("CARGO_PKG_VERSION").to_string(),
54 icons: None,
55 website_url: Some("https://github.com/modelcontextprotocol".to_string()),
56 };
57
58 Self {
59 provider_name: provider_name.to_string(),
60 info,
61 }
62 }
63
64 fn handle_logging(&self, params: LoggingMessageNotificationParam) {
65 let LoggingMessageNotificationParam {
66 level,
67 logger,
68 data,
69 } = params;
70 let payload = data;
71 let logger_name = logger.unwrap_or_else(|| "".to_string());
72 let summary = payload
73 .get("message")
74 .and_then(Value::as_str)
75 .map(str::to_owned)
76 .unwrap_or_else(|| payload.to_string());
77
78 match level {
79 LoggingLevel::Debug => debug!(
80 provider = self.provider_name.as_str(),
81 logger = logger_name.as_str(),
82 summary = %summary,
83 payload = ?payload,
84 "MCP provider log"
85 ),
86 LoggingLevel::Info | LoggingLevel::Notice => info!(
87 provider = self.provider_name.as_str(),
88 logger = logger_name.as_str(),
89 summary = %summary,
90 payload = ?payload,
91 "MCP provider log"
92 ),
93 LoggingLevel::Warning => warn!(
94 provider = self.provider_name.as_str(),
95 logger = logger_name.as_str(),
96 summary = %summary,
97 payload = ?payload,
98 "MCP provider warning"
99 ),
100 LoggingLevel::Error
101 | LoggingLevel::Critical
102 | LoggingLevel::Alert
103 | LoggingLevel::Emergency => error!(
104 provider = self.provider_name.as_str(),
105 logger = logger_name.as_str(),
106 summary = %summary,
107 payload = ?payload,
108 "MCP provider error"
109 ),
110 }
111 }
112}
113
114impl ClientHandler for LoggingClientHandler {
115 fn create_elicitation(
116 &self,
117 request: CreateElicitationRequestParam,
118 _context: RequestContext<RoleClient>,
119 ) -> impl std::future::Future<Output = Result<CreateElicitationResult, rmcp::ErrorData>> + Send + '_
120 {
121 let CreateElicitationRequestParam { message, .. } = request;
122 info!(
123 provider = self.provider_name.as_str(),
124 message = message.as_str(),
125 "MCP provider requested elicitation; declining"
126 );
127 future::ready(Ok(CreateElicitationResult {
128 action: ElicitationAction::Decline,
129 content: None,
130 }))
131 }
132
133 fn on_cancelled(
134 &self,
135 params: CancelledNotificationParam,
136 _context: NotificationContext<RoleClient>,
137 ) -> impl std::future::Future<Output = ()> + Send + '_ {
138 info!(
139 provider = self.provider_name.as_str(),
140 request_id = %params.request_id,
141 reason = ?params.reason,
142 "MCP provider cancelled request"
143 );
144 future::ready(())
145 }
146
147 fn on_progress(
148 &self,
149 params: ProgressNotificationParam,
150 _context: NotificationContext<RoleClient>,
151 ) -> impl std::future::Future<Output = ()> + Send + '_ {
152 info!(
153 provider = self.provider_name.as_str(),
154 progress_token = ?params.progress_token,
155 progress = params.progress,
156 total = ?params.total,
157 message = ?params.message,
158 "MCP provider progress update"
159 );
160 future::ready(())
161 }
162
163 fn on_resource_updated(
164 &self,
165 params: ResourceUpdatedNotificationParam,
166 _context: NotificationContext<RoleClient>,
167 ) -> impl std::future::Future<Output = ()> + Send + '_ {
168 info!(
169 provider = self.provider_name.as_str(),
170 uri = params.uri.as_str(),
171 "MCP provider resource updated"
172 );
173 future::ready(())
174 }
175
176 fn on_resource_list_changed(
177 &self,
178 _context: NotificationContext<RoleClient>,
179 ) -> impl std::future::Future<Output = ()> + Send + '_ {
180 info!(
181 provider = self.provider_name.as_str(),
182 "MCP provider resource list changed"
183 );
184 future::ready(())
185 }
186
187 fn on_tool_list_changed(
188 &self,
189 _context: NotificationContext<RoleClient>,
190 ) -> impl std::future::Future<Output = ()> + Send + '_ {
191 info!(
192 provider = self.provider_name.as_str(),
193 "MCP provider tool list changed"
194 );
195 future::ready(())
196 }
197
198 fn on_prompt_list_changed(
199 &self,
200 _context: NotificationContext<RoleClient>,
201 ) -> impl std::future::Future<Output = ()> + Send + '_ {
202 info!(
203 provider = self.provider_name.as_str(),
204 "MCP provider prompt list changed"
205 );
206 future::ready(())
207 }
208
209 fn on_logging_message(
210 &self,
211 params: LoggingMessageNotificationParam,
212 _context: NotificationContext<RoleClient>,
213 ) -> impl std::future::Future<Output = ()> + Send + '_ {
214 self.handle_logging(params);
215 future::ready(())
216 }
217
218 fn get_info(&self) -> ClientInfo {
219 self.info.clone()
220 }
221}
222
223pub struct McpClient {
225 config: McpClientConfig,
226 pub providers: HashMap<String, Arc<McpProvider>>,
227 active_connections: Arc<Mutex<HashMap<String, Arc<RunningMcpService>>>>,
228 allowlist: Arc<RwLock<McpAllowListConfig>>,
229 tool_provider_index: Arc<RwLock<HashMap<String, String>>>,
230}
231
232impl McpClient {
233 pub fn new(config: McpClientConfig) -> Self {
235 let allowlist = Arc::new(RwLock::new(config.allowlist.clone()));
236 Self {
237 config,
238 providers: HashMap::new(),
239 active_connections: Arc::new(Mutex::new(HashMap::new())),
240 allowlist,
241 tool_provider_index: Arc::new(RwLock::new(HashMap::new())),
242 }
243 }
244
245 fn record_tool_provider(&self, provider: &str, tool: &str) {
246 debug!("Recording tool '{}' -> provider '{}'", tool, provider);
247 self.tool_provider_index
248 .write()
249 .insert(tool.to_string(), provider.to_string());
250 }
251
252 pub fn provider_for_tool(&self, tool_name: &str) -> Option<String> {
254 let index = self.tool_provider_index.read();
255 if let Some(provider) = index.get(tool_name) {
256 if self.providers.contains_key(provider) {
258 debug!("Found tool '{}' in provider '{}'", tool_name, provider);
259 Some(provider.clone())
260 } else {
261 debug!(
262 "Tool '{}' references non-existent provider '{}'",
263 tool_name, provider
264 );
265 None
266 }
267 } else {
268 debug!("Tool '{}' not found in provider index", tool_name);
269 None
270 }
271 }
272
273 pub fn update_allowlist(&self, allowlist: McpAllowListConfig) {
275 *self.allowlist.write() = allowlist;
276 }
277
278 pub fn current_allowlist(&self) -> McpAllowListConfig {
280 self.allowlist.read().clone()
281 }
282
283 fn format_tool_result(
284 provider_name: &str,
285 tool_name: &str,
286 result: CallToolResult,
287 ) -> Result<Value> {
288 let is_error = result.is_error.unwrap_or(false);
289 let text_summary = result
290 .content
291 .iter()
292 .find_map(|content| content.as_text().map(|text| text.text.clone()));
293
294 if is_error {
295 let detail = result
296 .structured_content
297 .as_ref()
298 .and_then(|value| value.get("message").and_then(Value::as_str))
299 .map(str::to_owned)
300 .or_else(|| {
301 result
302 .structured_content
303 .as_ref()
304 .map(|value| value.to_string())
305 })
306 .or(text_summary)
307 .unwrap_or_else(|| "Unknown MCP tool error".to_string());
308
309 return Err(anyhow::anyhow!(
310 "MCP tool '{}' on provider '{}' reported an error: {}",
311 tool_name,
312 provider_name,
313 detail
314 ));
315 }
316
317 let mut payload = Map::new();
318 payload.insert("provider".into(), Value::String(provider_name.to_string()));
319 payload.insert("tool".into(), Value::String(tool_name.to_string()));
320
321 if let Some(meta) = result.meta {
322 if let Ok(meta_value) = serde_json::to_value(&meta) {
323 if !meta_value.is_null() {
324 payload.insert("meta".into(), meta_value);
325 }
326 }
327 }
328
329 if let Some(structured) = result.structured_content {
330 match structured {
331 Value::Object(mut object) => {
332 object
333 .entry("provider")
334 .or_insert_with(|| Value::String(provider_name.to_string()));
335 object
336 .entry("tool")
337 .or_insert_with(|| Value::String(tool_name.to_string()));
338
339 if let Some(meta_value) = payload.remove("meta") {
340 object.entry("meta").or_insert(meta_value);
341 }
342
343 return Ok(Value::Object(object));
344 }
345 other => {
346 payload.insert("structured_content".into(), other);
347 }
348 }
349 }
350
351 if let Some(summary) = text_summary {
352 payload.insert("message".into(), Value::String(summary));
353 }
354
355 if !result.content.is_empty() {
356 if let Ok(content_value) = serde_json::to_value(&result.content) {
357 payload.insert("content".into(), content_value);
358 }
359 }
360
361 Ok(Value::Object(payload))
362 }
363
364 pub async fn initialize(&mut self) -> Result<()> {
366 if !self.config.enabled {
367 info!("MCP client is disabled in configuration");
368 return Ok(());
369 }
370
371 info!(
372 "Initializing MCP client with {} configured providers",
373 self.config.providers.len()
374 );
375
376 for provider_config in &self.config.providers {
377 if provider_config.enabled {
378 info!("Initializing MCP provider '{}'", provider_config.name);
379
380 match McpProvider::new(provider_config.clone()).await {
381 Ok(provider) => {
382 let provider = Arc::new(provider);
383 self.providers
384 .insert(provider_config.name.clone(), provider);
385 info!(
386 "Successfully initialized MCP provider '{}'",
387 provider_config.name
388 );
389 self.audit_log(
390 Some(provider_config.name.as_str()),
391 "mcp.provider_initialized",
392 Level::INFO,
393 format!("Provider '{}' initialized", provider_config.name),
394 );
395 }
396 Err(e) => {
397 error!(
398 "Failed to initialize MCP provider '{}': {}",
399 provider_config.name, e
400 );
401 self.audit_log(
402 Some(provider_config.name.as_str()),
403 "mcp.provider_initialization_failed",
404 Level::WARN,
405 format!(
406 "Failed to initialize provider '{}' due to error: {}",
407 provider_config.name, e
408 ),
409 );
410 continue;
412 }
413 }
414 } else {
415 debug!(
416 "MCP provider '{}' is disabled, skipping",
417 provider_config.name
418 );
419 }
420 }
421
422 info!(
423 "MCP client initialization complete. Active providers: {}",
424 self.providers.len()
425 );
426
427 Ok(())
432 }
433
434 async fn kill_remaining_mcp_processes(&self) {
436 debug!("Checking for remaining MCP provider processes to clean up");
437
438 let process_cleanup_attempts = tokio::time::timeout(
441 tokio::time::Duration::from_secs(5),
442 self.attempt_process_cleanup(),
443 )
444 .await;
445
446 match process_cleanup_attempts {
447 Ok(Ok(cleaned_count)) => {
448 if cleaned_count > 0 {
449 info!(
450 "Cleaned up {} remaining MCP provider processes",
451 cleaned_count
452 );
453 self.audit_log(
454 None,
455 "mcp.process_cleanup",
456 Level::INFO,
457 format!(
458 "Cleaned up {} remaining MCP provider processes",
459 cleaned_count
460 ),
461 );
462 } else {
463 debug!("No remaining MCP provider processes to clean up");
464 }
465 }
466 Ok(Err(e)) => {
467 warn!("Error during MCP process cleanup (non-critical): {}", e);
468 self.audit_log(
469 None,
470 "mcp.process_cleanup_error",
471 Level::WARN,
472 format!("Error during MCP process cleanup: {}", e),
473 );
474 }
475 Err(_) => {
476 warn!("MCP process cleanup timed out (non-critical)");
477 self.audit_log(
478 None,
479 "mcp.process_cleanup_timeout",
480 Level::WARN,
481 "MCP process cleanup timed out".to_string(),
482 );
483 }
484 }
485 }
486
487 async fn attempt_process_cleanup(&self) -> Result<usize> {
489 use tokio::process::Command as TokioCommand;
490
491 let mut cleaned_count = 0;
492
493 let current_pid = std::process::id();
495
496 for provider_config in &self.config.providers {
499 if !provider_config.enabled {
500 continue;
501 }
502
503 let provider_name = &provider_config.name;
504 debug!("Attempting cleanup for MCP provider '{}'", provider_name);
505
506 let mut provider_cleaned = 0;
508
509 if let Ok(output) = TokioCommand::new("pgrep")
511 .args(["-f", &format!("mcp-server-{}", provider_name)])
512 .output()
513 .await
514 {
515 if output.status.success() {
516 let pids = String::from_utf8_lossy(&output.stdout);
517 for pid_str in pids.lines() {
518 if let Ok(pid) = pid_str.trim().parse::<u32>() {
519 if pid != current_pid && pid > 0 {
520 if self.kill_process_gracefully(pid).await {
521 provider_cleaned += 1;
522 }
523 }
524 }
525 }
526 }
527 }
528
529 if provider_cleaned == 0 {
531 if let Ok(output) = TokioCommand::new("ps").args(["aux"]).output().await {
532 if output.status.success() {
533 let processes = String::from_utf8_lossy(&output.stdout);
534 for line in processes.lines() {
535 if line.contains(provider_name)
537 && (line.contains("mcp")
538 || line.contains("node")
539 || line.contains("python"))
540 {
541 let parts: Vec<&str> = line.split_whitespace().collect();
543 if let Some(pid_str) = parts.first() {
544 if let Ok(pid) = pid_str.parse::<u32>() {
545 if pid != current_pid && pid > 0 {
546 if self.kill_process_gracefully(pid).await {
547 provider_cleaned += 1;
548 }
549 }
550 }
551 }
552 }
553 }
554 }
555 }
556 }
557
558 if provider_cleaned > 0 {
559 debug!(
560 "Cleaned up {} processes for MCP provider '{}'",
561 provider_cleaned, provider_name
562 );
563 cleaned_count += provider_cleaned;
564 self.tool_provider_index.write().clear();
566 }
567 }
568
569 Ok(cleaned_count)
570 }
571
572 async fn kill_process_gracefully(&self, pid: u32) -> bool {
574 debug!("Killing process {} gracefully", pid);
575
576 let _ = tokio::process::Command::new("kill")
578 .args(["-TERM", &pid.to_string()])
579 .output()
580 .await;
581
582 tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
584
585 if let Ok(output) = tokio::process::Command::new("kill")
587 .args(["-0", &pid.to_string()]) .output()
589 .await
590 {
591 if output.status.success() {
592 debug!("Process {} still running, force killing", pid);
594 let _ = tokio::process::Command::new("kill")
595 .args(["-KILL", &pid.to_string()])
596 .output()
597 .await;
598 true
599 } else {
600 debug!("Process {} already terminated", pid);
602 true
603 }
604 } else {
605 debug!("Process {} check failed, assuming terminated", pid);
607 true
608 }
609 }
610
611 pub async fn cleanup_dead_providers(&self) -> Result<()> {
613 let mut dead_providers = Vec::new();
614
615 for (provider_name, provider) in &self.providers {
616 let provider_health_check = tokio::time::timeout(
618 tokio::time::Duration::from_secs(2),
619 provider.has_tool("ping"),
620 )
621 .await;
622
623 match provider_health_check {
624 Ok(Ok(_)) => {
625 debug!("MCP provider '{}' is healthy", provider_name);
627 }
628 Ok(Err(e)) => {
629 let error_msg = e.to_string();
630 if error_msg.contains("No such process") || error_msg.contains("ESRCH") {
631 warn!(
632 "MCP provider '{}' has terminated process, marking for cleanup",
633 provider_name
634 );
635 dead_providers.push(provider_name.clone());
636 } else {
637 debug!(
638 "MCP provider '{}' returned error but process may be alive: {}",
639 provider_name, e
640 );
641 }
642 }
643 Err(_timeout) => {
644 warn!(
645 "MCP provider '{}' health check timed out, may be unresponsive",
646 provider_name
647 );
648 }
650 }
651 }
652
653 if !dead_providers.is_empty() {
656 warn!(
657 "Found {} dead MCP providers: {:?}",
658 dead_providers.len(),
659 dead_providers
660 );
661 }
662
663 Ok(())
664 }
665
666 pub async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
668 if !self.config.enabled {
669 debug!("MCP client is disabled, returning empty tool list");
670 return Ok(Vec::new());
671 }
672
673 if self.providers.is_empty() {
674 debug!("No MCP providers configured, returning empty tool list");
675 return Ok(Vec::new());
676 }
677
678 let mut all_tools = Vec::new();
679 let mut errors = Vec::new();
680
681 let allowlist_snapshot = self.allowlist.read().clone();
682
683 for (provider_name, provider) in &self.providers {
684 let provider_id = provider_name.as_str();
685 match tokio::time::timeout(tokio::time::Duration::from_secs(15), provider.list_tools())
686 .await
687 {
688 Ok(Ok(tools)) => {
689 debug!(
690 "Provider '{}' has {} tools",
691 provider_name,
692 tools.tools.len()
693 );
694
695 for tool in tools.tools {
696 let tool_name = tool.name.as_ref();
697
698 if allowlist_snapshot.is_tool_allowed(provider_id, tool_name) {
699 self.record_tool_provider(provider_id, tool_name);
700 all_tools.push(McpToolInfo {
701 name: tool_name.to_string(),
702 description: tool.description.unwrap_or_default().to_string(),
703 provider: provider_name.clone(),
704 input_schema: serde_json::to_value(&*tool.input_schema)
705 .unwrap_or(Value::Null),
706 });
707 } else {
708 self.audit_log(
709 Some(provider_id),
710 "mcp.tool_filtered",
711 Level::DEBUG,
712 format!(
713 "Filtered tool '{}' from provider '{}' due to allow list",
714 tool_name, provider_id
715 ),
716 );
717 }
718 }
719 }
720 Ok(Err(e)) => {
721 let error_msg = e.to_string();
722 if error_msg.contains("No such process")
723 || error_msg.contains("ESRCH")
724 || error_msg.contains("EPIPE")
725 || error_msg.contains("Broken pipe")
726 || error_msg.contains("write EPIPE")
727 {
728 debug!(
729 "MCP provider '{}' process/pipe terminated during tool listing (normal during shutdown): {}",
730 provider_name, e
731 );
732 } else {
733 warn!(
734 "Failed to list tools for provider '{}': {}",
735 provider_name, e
736 );
737 }
738 let error_msg = format!(
739 "Failed to list tools for provider '{}': {}",
740 provider_name, e
741 );
742 errors.push(error_msg);
743 }
744 Err(_timeout) => {
745 warn!("MCP provider '{}' tool listing timed out", provider_name);
746 let error_msg =
747 format!("Tool listing timeout for provider '{}'", provider_name);
748 errors.push(error_msg);
749 }
750 }
751 }
752
753 if !errors.is_empty() {
754 warn!(
755 "Encountered {} errors while listing MCP tools: {:?}",
756 errors.len(),
757 errors
758 );
759 }
760
761 info!(
762 "Found {} total MCP tools across all providers",
763 all_tools.len()
764 );
765 Ok(all_tools)
766 }
767
768 pub async fn execute_tool(&self, tool_name: &str, args: Value) -> Result<Value> {
770 if !self.config.enabled {
771 return Err(anyhow::anyhow!("MCP client is disabled"));
772 }
773
774 if self.providers.is_empty() {
775 return Err(anyhow::anyhow!("No MCP providers configured"));
776 }
777
778 let tool_name_owned = tool_name.to_string();
779 debug!("Executing MCP tool '{}' with args: {}", tool_name, args);
780
781 let provider_name = {
783 let mut found_provider = None;
784 let mut provider_errors = Vec::new();
785
786 for (name, provider) in &self.providers {
787 match provider.has_tool(&tool_name_owned).await {
788 Ok(true) => {
789 found_provider = Some(name.clone());
790 break;
791 }
792 Ok(false) => continue,
793 Err(e) => {
794 let error_msg = format!(
795 "Error checking tool availability for provider '{}': {}",
796 name, e
797 );
798 warn!("{}", error_msg);
799 provider_errors.push(error_msg);
800 }
801 }
802 }
803
804 found_provider.ok_or_else(|| {
805 let error_msg = format!(
806 "Tool '{}' not found in any MCP provider. Provider errors: {:?}",
807 tool_name, provider_errors
808 );
809 anyhow::anyhow!(error_msg)
810 })?
811 };
812
813 debug!("Found tool '{}' in provider '{}'", tool_name, provider_name);
814
815 if !self
816 .allowlist
817 .read()
818 .is_tool_allowed(provider_name.as_str(), tool_name)
819 {
820 let message = format!(
821 "Tool '{}' from provider '{}' is not permitted by the MCP allow list",
822 tool_name, provider_name
823 );
824 self.audit_log(
825 Some(provider_name.as_str()),
826 "mcp.tool_denied",
827 Level::WARN,
828 message.as_str(),
829 );
830 return Err(anyhow::anyhow!(message));
831 }
832
833 self.record_tool_provider(provider_name.as_str(), tool_name);
834
835 let provider = self.providers.get(&provider_name).ok_or_else(|| {
836 anyhow::anyhow!("Provider '{}' not found after discovery", provider_name)
837 })?;
838
839 let connection = match self.get_or_create_connection(provider).await {
841 Ok(conn) => conn,
842 Err(e) => {
843 error!(
844 "Failed to establish connection to provider '{}': {}",
845 provider_name, e
846 );
847 return Err(e);
848 }
849 };
850
851 match connection
853 .call_tool(CallToolRequestParam {
854 name: tool_name_owned.into(),
855 arguments: args.as_object().cloned(),
856 })
857 .await
858 {
859 Ok(result) => match Self::format_tool_result(provider_name.as_str(), tool_name, result)
860 {
861 Ok(serialized) => {
862 info!(
863 "Successfully executed MCP tool '{}' via provider '{}'",
864 tool_name, provider_name
865 );
866 self.audit_log(
867 Some(provider_name.as_str()),
868 "mcp.tool_execution",
869 Level::INFO,
870 format!(
871 "Successfully executed MCP tool '{}' via provider '{}'",
872 tool_name, provider_name
873 ),
874 );
875 Ok(serialized)
876 }
877 Err(err) => {
878 let err_message = err.to_string();
879 warn!(
880 "MCP tool '{}' via provider '{}' returned an error payload: {}",
881 tool_name, provider_name, err_message
882 );
883 self.audit_log(
884 Some(provider_name.as_str()),
885 "mcp.tool_failed",
886 Level::WARN,
887 format!(
888 "MCP tool '{}' via provider '{}' returned an error payload: {}",
889 tool_name, provider_name, err_message
890 ),
891 );
892 Err(err)
893 }
894 },
895 Err(e) => {
896 let error_message = e.to_string();
897
898 error!(
899 "MCP tool '{}' failed on provider '{}': {}",
900 tool_name, provider_name, error_message
901 );
902 self.audit_log(
903 Some(provider_name.as_str()),
904 "mcp.tool_failed",
905 Level::WARN,
906 format!(
907 "MCP tool '{}' failed on provider '{}': {}",
908 tool_name, provider_name, error_message
909 ),
910 );
911
912 if error_message.contains("EPIPE")
914 || error_message.contains("Broken pipe")
915 || error_message.contains("write EPIPE")
916 || error_message.contains("No such process")
917 || error_message.contains("ESRCH")
918 {
919 let mut connections = self.active_connections.lock().await;
921 connections.remove(&provider_name);
922 self.tool_provider_index
924 .write()
925 .retain(|_, provider| provider != &provider_name);
926
927 return Err(anyhow::anyhow!(
928 "MCP provider '{}' disconnected unexpectedly while executing '{}'. The provider process may have terminated. Try re-running the command to restart the provider.",
929 provider_name,
930 tool_name
931 ));
932 } else if error_message.contains("timeout") || error_message.contains("Timeout") {
933 let mut connections = self.active_connections.lock().await;
935 connections.remove(&provider_name);
936
937 return Err(anyhow::anyhow!(
938 "MCP tool '{}' execution timed out on provider '{}'. The provider may be unresponsive. Try re-running the command.",
939 tool_name,
940 provider_name
941 ));
942 } else if error_message.contains("permission")
943 || error_message.contains("Permission denied")
944 {
945 return Err(anyhow::anyhow!(
946 "Permission denied executing MCP tool '{}' on provider '{}': {}",
947 tool_name,
948 provider_name,
949 error_message
950 ));
951 } else if error_message.contains("network")
952 || error_message.contains("Connection refused")
953 {
954 return Err(anyhow::anyhow!(
955 "Network error executing MCP tool '{}' on provider '{}': {}",
956 tool_name,
957 provider_name,
958 error_message
959 ));
960 }
961
962 Err(anyhow::anyhow!(
963 "MCP tool execution failed: {}",
964 error_message
965 ))
966 }
967 }
968 }
969
970 async fn get_or_create_connection(
972 &self,
973 provider: &McpProvider,
974 ) -> Result<Arc<RunningMcpService>> {
975 let provider_name = &provider.config.name;
976 debug!("Getting connection for MCP provider '{}'", provider_name);
977
978 let mut connections = self.active_connections.lock().await;
979
980 if !connections.contains_key(provider_name) {
981 debug!("Creating new connection for provider '{}'", provider_name);
982
983 match tokio::time::timeout(tokio::time::Duration::from_secs(30), provider.connect())
985 .await
986 {
987 Ok(Ok(connection)) => {
988 let connection = Arc::new(connection);
989 connections.insert(provider_name.clone(), Arc::clone(&connection));
990 debug!(
991 "Successfully created connection for provider '{}'",
992 provider_name
993 );
994 Ok(connection)
995 }
996 Ok(Err(e)) => {
997 let error_msg = e.to_string();
998 if error_msg.contains("HTTP MCP server support") {
999 warn!(
1000 "Provider '{}' uses HTTP transport which is not fully implemented: {}",
1001 provider_name, e
1002 );
1003 Err(anyhow::anyhow!(
1004 "HTTP MCP transport not fully implemented for provider '{}'. Consider using stdio transport instead.",
1005 provider_name
1006 ))
1007 } else if error_msg.contains("command not found")
1008 || error_msg.contains("No such file")
1009 {
1010 error!("Command not found for provider '{}': {}", provider_name, e);
1011 Err(anyhow::anyhow!(
1012 "Command '{}' not found for MCP provider '{}'. Please ensure the MCP server is installed and accessible.",
1013 self.config
1014 .providers
1015 .iter()
1016 .find(|p| p.name == *provider_name)
1017 .map(|p| match &p.transport {
1018 McpTransportConfig::Stdio(stdio) => stdio.command.as_str(),
1019 _ => "unknown",
1020 })
1021 .unwrap_or("unknown"),
1022 provider_name
1023 ))
1024 } else if error_msg.contains("permission")
1025 || error_msg.contains("Permission denied")
1026 {
1027 error!(
1028 "Permission denied creating connection for provider '{}': {}",
1029 provider_name, e
1030 );
1031 Err(anyhow::anyhow!(
1032 "Permission denied executing MCP server for provider '{}': {}",
1033 provider_name,
1034 error_msg
1035 ))
1036 } else {
1037 error!(
1038 "Failed to create connection for provider '{}': {}",
1039 provider_name, e
1040 );
1041 Err(anyhow::anyhow!(
1042 "Failed to create connection for MCP provider '{}': {}",
1043 provider_name,
1044 error_msg
1045 ))
1046 }
1047 }
1048 Err(_timeout) => {
1049 error!(
1050 "Connection creation timed out for provider '{}' after {} seconds",
1051 provider_name, 30
1052 );
1053 Err(anyhow::anyhow!(
1054 "Connection creation timed out for MCP provider '{}' after {} seconds. The provider may be slow to start or unresponsive.",
1055 provider_name,
1056 30
1057 ))
1058 }
1059 }
1060 } else {
1061 let existing_connection = connections.get(provider_name).unwrap().clone();
1063
1064 if let Err(e) = self
1066 .validate_connection(provider_name, &existing_connection)
1067 .await
1068 {
1069 debug!(
1070 "Existing connection for provider '{}' is unhealthy, creating new one: {}",
1071 provider_name, e
1072 );
1073
1074 connections.remove(provider_name);
1076
1077 match tokio::time::timeout(tokio::time::Duration::from_secs(30), provider.connect())
1079 .await
1080 {
1081 Ok(Ok(connection)) => {
1082 let connection = Arc::new(connection);
1083 connections.insert(provider_name.clone(), Arc::clone(&connection));
1084 debug!(
1085 "Successfully created new connection for provider '{}'",
1086 provider_name
1087 );
1088 Ok(connection)
1089 }
1090 Ok(Err(e)) => {
1091 error!(
1092 "Failed to create replacement connection for provider '{}': {}",
1093 provider_name, e
1094 );
1095 Err(e)
1096 }
1097 Err(_timeout) => {
1098 error!(
1099 "Replacement connection creation timed out for provider '{}'",
1100 provider_name
1101 );
1102 Err(anyhow::anyhow!(
1103 "Replacement connection timeout for provider '{}'",
1104 provider_name
1105 ))
1106 }
1107 }
1108 } else {
1109 debug!(
1110 "Reusing existing healthy connection for provider '{}'",
1111 provider_name
1112 );
1113 Ok(existing_connection)
1114 }
1115 }
1116 }
1117
1118 async fn validate_connection(
1120 &self,
1121 provider_name: &str,
1122 connection: &RunningMcpService,
1123 ) -> Result<()> {
1124 debug!(
1125 "Validating connection health for provider '{}'",
1126 provider_name
1127 );
1128
1129 match tokio::time::timeout(
1132 tokio::time::Duration::from_secs(2),
1133 connection.list_tools(Default::default()),
1134 )
1135 .await
1136 {
1137 Ok(Ok(_)) => {
1138 debug!(
1139 "Connection health check passed for provider '{}'",
1140 provider_name
1141 );
1142 Ok(())
1143 }
1144 Ok(Err(e)) => {
1145 let error_msg = e.to_string();
1146 debug!(
1147 "Connection health check failed for provider '{}': {}",
1148 provider_name, error_msg
1149 );
1150 Err(anyhow::anyhow!(
1151 "Connection health check failed for provider '{}': {}",
1152 provider_name,
1153 error_msg
1154 ))
1155 }
1156 Err(_) => {
1157 debug!(
1158 "Connection health check timed out for provider '{}'",
1159 provider_name
1160 );
1161 Err(anyhow::anyhow!(
1162 "Connection health check timed out for provider '{}'",
1163 provider_name
1164 ))
1165 }
1166 }
1167 }
1168
1169 fn audit_log(
1170 &self,
1171 provider: Option<&str>,
1172 channel: &str,
1173 level: Level,
1174 message: impl AsRef<str>,
1175 ) {
1176 let logging_allowed = {
1177 let allowlist = self.allowlist.read();
1178 allowlist.is_logging_channel_allowed(provider, channel)
1179 };
1180
1181 if !logging_allowed {
1182 return;
1183 }
1184
1185 let msg = message.as_ref();
1186 match level {
1187 Level::ERROR => error!(target: "mcp", "[{}] {}", channel, msg),
1188 Level::WARN => warn!(target: "mcp", "[{}] {}", channel, msg),
1189 Level::INFO => info!(target: "mcp", "[{}] {}", channel, msg),
1190 Level::DEBUG => debug!(target: "mcp", "[{}] {}", channel, msg),
1191 _ => debug!(target: "mcp", "[{}] {}", channel, msg),
1192 }
1193 }
1194
1195 pub async fn shutdown(&self) -> Result<()> {
1197 info!("Shutting down MCP client and all provider connections");
1198
1199 let mut connections = self.active_connections.lock().await;
1200
1201 if connections.is_empty() {
1202 info!("No active MCP connections to shutdown");
1203 return Ok(());
1204 }
1205
1206 info!(
1207 "Shutting down {} MCP provider connections",
1208 connections.len()
1209 );
1210
1211 let cancellation_tokens: Vec<(String, rmcp::service::RunningServiceCancellationToken)> =
1212 connections
1213 .iter()
1214 .map(|(provider_name, connection)| {
1215 debug!(
1216 "Initiating graceful shutdown for MCP provider '{}'",
1217 provider_name
1218 );
1219 (provider_name.clone(), connection.cancellation_token())
1220 })
1221 .collect();
1222
1223 for (provider_name, token) in cancellation_tokens {
1224 debug!(
1225 "Cancelling MCP provider '{}' via cancellation token",
1226 provider_name
1227 );
1228 token.cancel();
1229 }
1230
1231 let shutdown_timeout = tokio::time::Duration::from_secs(5);
1233 let shutdown_start = std::time::Instant::now();
1234
1235 while shutdown_start.elapsed() < shutdown_timeout && !connections.is_empty() {
1237 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
1238
1239 connections.retain(|_, connection| {
1241 Arc::strong_count(connection) > 1 });
1244 }
1245
1246 let remaining_count = connections.len();
1248 if remaining_count > 0 {
1249 warn!(
1250 "{} MCP provider connections did not shutdown gracefully within timeout, forcing shutdown",
1251 remaining_count
1252 );
1253 }
1254
1255 let drained_connections: Vec<_> = connections.drain().collect();
1257 drop(connections);
1258
1259 for (provider_name, connection) in drained_connections {
1260 debug!("Force shutting down MCP provider '{}'", provider_name);
1261
1262 if let Ok(connection) = Arc::try_unwrap(connection) {
1263 debug!(
1264 "Awaiting MCP provider '{}' task cancellation after graceful request",
1265 provider_name
1266 );
1267
1268 match connection.cancel().await {
1269 Ok(quit_reason) => {
1270 debug!(
1271 "MCP provider '{}' cancellation completed with reason: {:?}",
1272 provider_name, quit_reason
1273 );
1274 }
1275 Err(err) => {
1276 debug!(
1277 "MCP provider '{}' cancellation join error (non-critical): {}",
1278 provider_name, err
1279 );
1280 }
1281 }
1282 } else {
1283 debug!(
1284 "Additional references exist for MCP provider '{}'; dropping without awaiting",
1285 provider_name
1286 );
1287 }
1288 }
1289
1290 tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
1292
1293 self.kill_remaining_mcp_processes().await;
1296
1297 info!("MCP client shutdown complete");
1298 Ok(())
1299 }
1300}
1301
1302#[derive(Debug, Clone)]
1304pub struct McpToolInfo {
1305 pub name: String,
1306 pub description: String,
1307 pub provider: String,
1308 pub input_schema: Value,
1309}
1310
1311pub struct McpProvider {
1313 config: McpProviderConfig,
1314 tools_cache: Arc<Mutex<Option<ListToolsResult>>>,
1315}
1316
1317impl McpProvider {
1318 pub async fn new(config: McpProviderConfig) -> Result<Self> {
1320 Ok(Self {
1321 config,
1322 tools_cache: Arc::new(Mutex::new(None)),
1323 })
1324 }
1325
1326 pub async fn list_tools(&self) -> Result<ListToolsResult> {
1328 let provider_name = &self.config.name;
1329 debug!("Listing tools for MCP provider '{}'", provider_name);
1330
1331 {
1333 let cache = self.tools_cache.lock().await;
1334 if let Some(cached) = cache.as_ref() {
1335 debug!("Using cached tools for provider '{}'", provider_name);
1336 return Ok(cached.clone());
1337 }
1338 }
1339
1340 debug!("Connecting to provider '{}' to fetch tools", provider_name);
1341
1342 match self.connect().await {
1344 Ok(connection) => {
1345 match connection.list_tools(Default::default()).await {
1346 Ok(tools) => {
1347 debug!(
1348 "Found {} tools for provider '{}'",
1349 tools.tools.len(),
1350 provider_name
1351 );
1352
1353 {
1355 let mut cache = self.tools_cache.lock().await;
1356 *cache = Some(tools.clone());
1357 }
1358
1359 Ok(tools)
1360 }
1361 Err(e) => {
1362 error!(
1363 "Failed to list tools for provider '{}': {}",
1364 provider_name, e
1365 );
1366 Err(anyhow::anyhow!("Failed to list tools: {}", e))
1367 }
1368 }
1369 }
1370 Err(e) => {
1371 error!("Failed to connect to provider '{}': {}", provider_name, e);
1372 Err(e)
1373 }
1374 }
1375 }
1376
1377 pub async fn has_tool(&self, tool_name: &str) -> Result<bool> {
1379 let provider_name = &self.config.name;
1380 debug!(
1381 "Checking if provider '{}' has tool '{}'",
1382 provider_name, tool_name
1383 );
1384
1385 match tokio::time::timeout(tokio::time::Duration::from_secs(10), self.list_tools()).await {
1386 Ok(Ok(tools)) => {
1387 let has_tool = tools.tools.iter().any(|tool| tool.name == tool_name);
1388 debug!(
1389 "Provider '{}' {} tool '{}'",
1390 provider_name,
1391 if has_tool { "has" } else { "does not have" },
1392 tool_name
1393 );
1394 Ok(has_tool)
1395 }
1396 Ok(Err(e)) => {
1397 let error_msg = e.to_string();
1398 if error_msg.contains("No such process")
1399 || error_msg.contains("ESRCH")
1400 || error_msg.contains("EPIPE")
1401 || error_msg.contains("Broken pipe")
1402 || error_msg.contains("write EPIPE")
1403 {
1404 debug!(
1405 "MCP provider '{}' process/pipe terminated during tool check (normal during shutdown): {}",
1406 provider_name, e
1407 );
1408 } else {
1409 warn!(
1410 "Failed to check tool availability for provider '{}': {}",
1411 provider_name, e
1412 );
1413 }
1414 Err(e)
1415 }
1416 Err(_timeout) => {
1417 warn!("MCP provider '{}' tool check timed out", provider_name);
1418 Err(anyhow::anyhow!("Tool availability check timeout"))
1419 }
1420 }
1421 }
1422
1423 async fn connect(&self) -> Result<RunningMcpService> {
1425 let provider_name = &self.config.name;
1426 info!("Connecting to MCP provider '{}'", provider_name);
1427
1428 match &self.config.transport {
1429 McpTransportConfig::Stdio(stdio_config) => {
1430 debug!("Using stdio transport for provider '{}'", provider_name);
1431 self.connect_stdio(stdio_config).await
1432 }
1433 McpTransportConfig::Http(http_config) => {
1434 debug!("Using HTTP transport for provider '{}'", provider_name);
1435 self.connect_http(http_config).await
1436 }
1437 }
1438 }
1439
1440 async fn connect_http(
1442 &self,
1443 config: &crate::config::mcp::McpHttpServerConfig,
1444 ) -> Result<RunningMcpService> {
1445 let provider_name = &self.config.name;
1446 debug!(
1447 "Setting up HTTP connection for provider '{}'",
1448 provider_name
1449 );
1450
1451 let mut headers = HeaderMap::new();
1453 headers.insert("Content-Type", "application/json".parse().unwrap());
1454
1455 if let Some(api_key_env) = &config.api_key_env {
1457 if let Ok(api_key) = std::env::var(api_key_env) {
1458 headers.insert(
1459 "Authorization",
1460 format!("Bearer {}", api_key).parse().unwrap(),
1461 );
1462 } else {
1463 warn!(
1464 "API key environment variable '{}' not found for provider '{}'",
1465 api_key_env, provider_name
1466 );
1467 }
1468 }
1469
1470 for (key, value) in &config.headers {
1472 if let (Ok(header_name), Ok(header_value)) =
1473 (key.parse::<HeaderName>(), value.parse::<HeaderValue>())
1474 {
1475 headers.insert(header_name, header_value);
1476 }
1477 }
1478
1479 let client = reqwest::Client::builder()
1480 .default_headers(headers)
1481 .timeout(std::time::Duration::from_secs(30))
1482 .build()
1483 .context("Failed to build HTTP client")?;
1484
1485 debug!(
1487 "Testing HTTP MCP server connectivity at '{}'",
1488 config.endpoint
1489 );
1490
1491 match client.get(&config.endpoint).send().await {
1492 Ok(response) => {
1493 let status = response.status();
1494 if status.is_success() {
1495 debug!(
1496 "HTTP MCP server at '{}' is reachable (status: {})",
1497 config.endpoint, status
1498 );
1499
1500 let mcp_endpoint = if config.endpoint.ends_with('/') {
1503 format!("{}mcp", config.endpoint)
1504 } else {
1505 format!("{}/mcp", config.endpoint)
1506 };
1507
1508 debug!("Attempting to connect to MCP endpoint: {}", mcp_endpoint);
1509
1510 match client.get(&mcp_endpoint).send().await {
1512 Ok(mcp_response) => {
1513 if mcp_response.status().is_success() {
1514 debug!(
1515 "MCP endpoint '{}' is available (status: {})",
1516 mcp_endpoint,
1517 mcp_response.status()
1518 );
1519
1520 Err(anyhow::anyhow!(
1523 "HTTP MCP server detected at '{}' but full streamable HTTP implementation is required. \
1524 MCP endpoint is available at '{}'. \
1525 Consider using stdio transport or implement HTTP streaming support with Server-Sent Events.",
1526 config.endpoint,
1527 mcp_endpoint
1528 ))
1529 } else {
1530 debug!(
1531 "MCP endpoint '{}' returned status: {}",
1532 mcp_endpoint,
1533 mcp_response.status()
1534 );
1535 Err(anyhow::anyhow!(
1536 "HTTP MCP server at '{}' does not support MCP protocol. \
1537 Expected MCP endpoint at '{}' but got status: {}. \
1538 Consider using stdio transport instead.",
1539 config.endpoint,
1540 mcp_endpoint,
1541 mcp_response.status()
1542 ))
1543 }
1544 }
1545 Err(e) => {
1546 let error_msg = e.to_string();
1547 debug!(
1548 "Failed to connect to MCP endpoint '{}': {}",
1549 mcp_endpoint, error_msg
1550 );
1551
1552 Err(anyhow::anyhow!(
1553 "HTTP MCP server at '{}' does not properly support MCP protocol. \
1554 Could not connect to MCP endpoint at '{}': {}. \
1555 Consider using stdio transport instead.",
1556 config.endpoint,
1557 mcp_endpoint,
1558 error_msg
1559 ))
1560 }
1561 }
1562 } else {
1563 Err(anyhow::anyhow!(
1564 "HTTP MCP server returned error status: {} at endpoint: {}",
1565 status,
1566 config.endpoint
1567 ))
1568 }
1569 }
1570 Err(e) => {
1571 let error_msg = e.to_string();
1572 if error_msg.contains("dns") || error_msg.contains("Name resolution") {
1573 Err(anyhow::anyhow!(
1574 "HTTP MCP server DNS resolution failed for '{}': {}",
1575 config.endpoint,
1576 e
1577 ))
1578 } else if error_msg.contains("Connection refused") || error_msg.contains("connect")
1579 {
1580 Err(anyhow::anyhow!(
1581 "HTTP MCP server connection failed for '{}': {}",
1582 config.endpoint,
1583 e
1584 ))
1585 } else {
1586 Err(anyhow::anyhow!(
1587 "HTTP MCP server error for '{}': {}",
1588 config.endpoint,
1589 e
1590 ))
1591 }
1592 }
1593 }
1594 }
1595
1596 async fn connect_stdio(
1598 &self,
1599 config: &crate::config::mcp::McpStdioServerConfig,
1600 ) -> Result<RunningMcpService> {
1601 let provider_name = &self.config.name;
1602 debug!(
1603 "Setting up stdio connection for provider '{}'",
1604 provider_name
1605 );
1606
1607 debug!("Command: {} with args: {:?}", config.command, config.args);
1608
1609 let command_label = config.command.clone();
1610 let mut command = Command::new(&config.command);
1611 command.args(&config.args);
1612
1613 if let Some(working_dir) = &config.working_directory {
1615 debug!("Using working directory: {}", working_dir);
1616 command.current_dir(working_dir);
1617 }
1618
1619 if !self.config.env.is_empty() {
1621 debug!(
1622 "Setting environment variables for provider '{}'",
1623 provider_name
1624 );
1625 command.envs(&self.config.env);
1626 }
1627
1628 #[cfg(unix)]
1630 {
1631 #[allow(unused_imports)]
1632 use std::os::unix::process::CommandExt;
1633 command.process_group(0);
1634 }
1635
1636 debug!(
1637 "Creating TokioChildProcess for provider '{}'",
1638 provider_name
1639 );
1640
1641 match TokioChildProcess::builder(command)
1642 .stderr(Stdio::piped())
1643 .spawn()
1644 {
1645 Ok((child_process, stderr)) => {
1646 debug!(
1647 "Successfully created child process for provider '{}'",
1648 provider_name
1649 );
1650
1651 if let Some(stderr) = stderr {
1652 let provider = provider_name.to_string();
1653 let command_name = command_label.clone();
1654 tokio::spawn(async move {
1655 let mut reader = BufReader::new(stderr).lines();
1656 loop {
1657 match reader.next_line().await {
1658 Ok(Some(line)) => {
1659 if line.trim().is_empty() {
1660 continue;
1661 }
1662 info!(
1663 provider = provider.as_str(),
1664 command = command_name.as_str(),
1665 line = line.as_str(),
1666 "MCP provider stderr output"
1667 );
1668 }
1669 Ok(None) => break,
1670 Err(error) => {
1671 warn!(
1672 provider = provider.as_str(),
1673 command = command_name.as_str(),
1674 error = %error,
1675 "Failed to read MCP provider stderr"
1676 );
1677 break;
1678 }
1679 }
1680 }
1681 });
1682 }
1683
1684 let handler = LoggingClientHandler::new(provider_name);
1686
1687 match tokio::time::timeout(
1688 tokio::time::Duration::from_secs(30),
1689 handler.serve(child_process),
1690 )
1691 .await
1692 {
1693 Ok(Ok(connection)) => {
1694 info!(
1695 "Successfully established connection to MCP provider '{}'",
1696 provider_name
1697 );
1698 Ok(connection)
1699 }
1700 Ok(Err(e)) => {
1701 let error_msg = e.to_string();
1703 if error_msg.contains("No such process")
1704 || error_msg.contains("ESRCH")
1705 || error_msg.contains("EPIPE")
1706 || error_msg.contains("Broken pipe")
1707 || error_msg.contains("write EPIPE")
1708 {
1709 debug!(
1710 "MCP provider '{}' pipe/process error during connection (normal during shutdown): {}",
1711 provider_name, e
1712 );
1713 Err(anyhow::anyhow!("MCP provider connection terminated: {}", e))
1714 } else {
1715 error!(
1716 "Failed to establish MCP connection for provider '{}': {}",
1717 provider_name, e
1718 );
1719 Err(anyhow::anyhow!("Failed to serve MCP connection: {}", e))
1720 }
1721 }
1722 Err(_timeout) => {
1723 warn!(
1724 "MCP provider '{}' connection timed out after 30 seconds",
1725 provider_name
1726 );
1727 Err(anyhow::anyhow!("MCP provider connection timeout"))
1728 }
1729 }
1730 }
1731 Err(e) => {
1732 let error_msg = e.to_string();
1734 if error_msg.contains("No such process") || error_msg.contains("ESRCH") {
1735 error!(
1736 "Failed to create child process for provider '{}' - process may have terminated: {}",
1737 provider_name, e
1738 );
1739 } else {
1740 error!(
1741 "Failed to create child process for provider '{}': {}",
1742 provider_name, e
1743 );
1744 }
1745 Err(anyhow::anyhow!("Failed to create child process: {}", e))
1746 }
1747 }
1748 }
1749}
1750
1751type RunningMcpService =
1753 rmcp::service::RunningService<rmcp::service::RoleClient, LoggingClientHandler>;
1754
/// Point-in-time snapshot of the MCP client's configuration and connections.
#[derive(Debug, Clone)]
pub struct McpClientStatus {
    /// Whether MCP support is enabled in the client configuration.
    pub enabled: bool,
    /// Number of providers registered with the client.
    pub provider_count: usize,
    /// Number of entries in the active-connection map at snapshot time.
    pub active_connections: usize,
    /// Names of the registered providers.
    pub configured_providers: Vec<String>,
}
1763
1764impl McpClient {
1765 pub fn get_status(&self) -> McpClientStatus {
1767 McpClientStatus {
1768 enabled: self.config.enabled,
1769 provider_count: self.providers.len(),
1770 active_connections: self
1771 .active_connections
1772 .try_lock()
1773 .map(|connections| connections.len())
1774 .unwrap_or(0),
1775 configured_providers: self.providers.keys().cloned().collect(),
1776 }
1777 }
1778}
1779
1780#[async_trait]
1782pub trait McpToolExecutor: Send + Sync {
1783 async fn execute_mcp_tool(&self, tool_name: &str, args: Value) -> Result<Value>;
1785
1786 async fn list_mcp_tools(&self) -> Result<Vec<McpToolInfo>>;
1788
1789 async fn has_mcp_tool(&self, tool_name: &str) -> Result<bool>;
1791
1792 fn get_status(&self) -> McpClientStatus;
1794}
1795
1796#[async_trait]
1797impl McpToolExecutor for McpClient {
1798 async fn execute_mcp_tool(&self, tool_name: &str, args: Value) -> Result<Value> {
1799 self.execute_tool(tool_name, args).await
1800 }
1801
1802 async fn list_mcp_tools(&self) -> Result<Vec<McpToolInfo>> {
1803 self.list_tools().await
1804 }
1805
1806 async fn has_mcp_tool(&self, tool_name: &str) -> Result<bool> {
1807 if self.providers.is_empty() {
1808 return Ok(false);
1809 }
1810
1811 let mut provider_errors = Vec::new();
1812
1813 for (provider_name, provider) in &self.providers {
1814 let provider_id = provider_name.as_str();
1815 match provider.has_tool(tool_name).await {
1816 Ok(true) => {
1817 if self
1818 .allowlist
1819 .read()
1820 .is_tool_allowed(provider_id, tool_name)
1821 {
1822 self.record_tool_provider(provider_id, tool_name);
1823 return Ok(true);
1824 }
1825
1826 self.audit_log(
1827 Some(provider_id),
1828 "mcp.tool_denied",
1829 Level::DEBUG,
1830 format!(
1831 "Tool '{}' exists on provider '{}' but is blocked by allow list",
1832 tool_name, provider_id
1833 ),
1834 );
1835 }
1836 Ok(false) => continue,
1837 Err(e) => {
1838 let error_msg = format!("Error checking provider '{}': {}", provider_name, e);
1839 warn!("{}", error_msg);
1840 provider_errors.push(error_msg);
1841 }
1842 }
1843 }
1844
1845 if !provider_errors.is_empty() {
1846 debug!(
1847 "Encountered {} errors while checking tool availability: {:?}",
1848 provider_errors.len(),
1849 provider_errors
1850 );
1851
1852 let summary = provider_errors.join("; ");
1853 return Err(anyhow::anyhow!(
1854 "Failed to confirm MCP tool '{}' availability. {}",
1855 tool_name,
1856 summary
1857 ));
1858 }
1859
1860 Ok(false)
1861 }
1862
1863 fn get_status(&self) -> McpClientStatus {
1864 self.get_status()
1865 }
1866}
1867
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::mcp::{McpStdioServerConfig, McpTransportConfig};
    use rmcp::model::{Content, Meta};
    use serde_json::json;

    /// Tool description shared by the `McpToolInfo` tests.
    fn sample_tool_info() -> McpToolInfo {
        McpToolInfo {
            name: "test_tool".to_string(),
            description: "A test tool".to_string(),
            provider: "test_provider".to_string(),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "input": {"type": "string"}
                }
            }),
        }
    }

    #[test]
    fn test_mcp_client_creation() {
        let client = McpClient::new(McpClientConfig::default());
        assert!(!client.config.enabled);
        assert!(client.providers.is_empty());
    }

    #[test]
    fn test_mcp_tool_info() {
        let tool_info = sample_tool_info();
        assert_eq!(tool_info.name, "test_tool");
        assert_eq!(tool_info.provider, "test_provider");
    }

    #[test]
    fn test_provider_config() {
        let stdio = McpStdioServerConfig {
            command: "echo".to_string(),
            args: vec!["hello".to_string()],
            working_directory: None,
        };
        let config = McpProviderConfig {
            name: "test".to_string(),
            transport: McpTransportConfig::Stdio(stdio),
            env: HashMap::new(),
            enabled: true,
            max_concurrent_requests: 3,
        };

        assert_eq!(config.name, "test");
        assert!(config.enabled);
        assert_eq!(config.max_concurrent_requests, 3);
    }

    #[test]
    fn test_tool_info_creation() {
        let tool_info = sample_tool_info();
        assert_eq!(tool_info.name, "test_tool");
        assert_eq!(tool_info.provider, "test_provider");
    }

    #[test]
    fn test_format_tool_result_success() {
        let mut meta = Meta::new();
        meta.insert("query".to_string(), Value::String("tokio".to_string()));

        let mut result = CallToolResult::structured(json!({
            "value": 42,
            "status": "ok"
        }));
        result.meta = Some(meta);

        let serialized = McpClient::format_tool_result("test", "demo", result).unwrap();

        assert_eq!(
            serialized.get("provider").and_then(Value::as_str),
            Some("test")
        );
        assert_eq!(serialized.get("tool").and_then(Value::as_str), Some("demo"));
        assert_eq!(serialized.get("status").and_then(Value::as_str), Some("ok"));
        assert_eq!(serialized.get("value").and_then(Value::as_i64), Some(42));

        let query = serialized
            .get("meta")
            .and_then(Value::as_object)
            .and_then(|map| map.get("query"))
            .and_then(Value::as_str);
        assert_eq!(query, Some("tokio"));
    }

    #[test]
    fn test_format_tool_result_error_detection() {
        let failing = CallToolResult::structured_error(json!({
            "message": "something went wrong"
        }));

        let message = McpClient::format_tool_result("test", "demo", failing)
            .unwrap_err()
            .to_string();
        assert!(message.contains("something went wrong"));
    }

    #[test]
    fn test_format_tool_result_error_from_text_content() {
        let failing = CallToolResult::error(vec![Content::text("plain failure")]);

        let message = McpClient::format_tool_result("test", "demo", failing)
            .unwrap_err()
            .to_string();
        assert!(message.contains("plain failure"));
    }
}