1use super::model::*;
11use super::utils::client_factory;
12use super::vrl_runtime::VrlRuntime;
13use crate::errors::ComputerError;
14use serde_json::Value;
15use std::collections::{HashMap, HashSet};
16use std::sync::Arc;
17use std::sync::Arc as StdArc;
18use tokio::sync::{watch, RwLock};
19use tracing::{debug, error, info, warn};
20
21#[derive(Debug, thiserror::Error)]
23#[error("Tool '{tool_name}' exists in multiple servers: {servers:?}")]
24pub struct ToolNameDuplicatedError {
25 pub tool_name: String,
26 pub servers: Vec<String>,
27}
28
29pub struct MCPServerManager {
31 servers_config: Arc<RwLock<HashMap<ServerName, MCPServerConfig>>>,
33 active_clients: Arc<RwLock<HashMap<ServerName, StdArc<dyn MCPClientProtocol>>>>,
35 tool_mapping: Arc<RwLock<HashMap<ToolName, ServerName>>>,
37 alias_mapping: Arc<RwLock<HashMap<String, (ServerName, ToolName)>>>,
39 disabled_tools: Arc<RwLock<HashSet<ToolName>>>,
41 auto_reconnect: Arc<RwLock<bool>>,
43 auto_connect: Arc<RwLock<bool>>,
45 state_notifier: watch::Sender<ManagerState>,
47 health_check_config: Arc<RwLock<HealthCheckConfig>>,
49 reconnect_policy: Arc<RwLock<ReconnectPolicy>>,
51 health_monitor_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
53 retry_counts: Arc<RwLock<HashMap<ServerName, u32>>>,
55}
56
/// Lifecycle state published through the manager's watch channel.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ManagerState {
    /// Freshly constructed, or closed again.
    Uninitialized,
    /// Configurations loaded and no servers running (also set after stopping all).
    Initialized,
    /// Set after all enabled servers were started.
    Running,
    /// Error state (not set anywhere in this file; presumably set elsewhere — verify).
    Error,
}
69
70impl MCPServerManager {
71 pub fn new() -> Self {
73 let (state_tx, _) = watch::channel(ManagerState::Uninitialized);
74
75 Self {
76 servers_config: Arc::new(RwLock::new(HashMap::new())),
77 active_clients: Arc::new(RwLock::new(HashMap::new())),
78 tool_mapping: Arc::new(RwLock::new(HashMap::new())),
79 alias_mapping: Arc::new(RwLock::new(HashMap::new())),
80 disabled_tools: Arc::new(RwLock::new(HashSet::new())),
81 auto_reconnect: Arc::new(RwLock::new(true)),
82 auto_connect: Arc::new(RwLock::new(false)),
83 state_notifier: state_tx,
84 health_check_config: Arc::new(RwLock::new(HealthCheckConfig::default())),
85 reconnect_policy: Arc::new(RwLock::new(ReconnectPolicy::default())),
86 health_monitor_handle: Arc::new(RwLock::new(None)),
87 retry_counts: Arc::new(RwLock::new(HashMap::new())),
88 }
89 }
90
91 pub fn with_config(
93 health_check_config: HealthCheckConfig,
94 reconnect_policy: ReconnectPolicy,
95 ) -> Self {
96 let (state_tx, _) = watch::channel(ManagerState::Uninitialized);
97
98 Self {
99 servers_config: Arc::new(RwLock::new(HashMap::new())),
100 active_clients: Arc::new(RwLock::new(HashMap::new())),
101 tool_mapping: Arc::new(RwLock::new(HashMap::new())),
102 alias_mapping: Arc::new(RwLock::new(HashMap::new())),
103 disabled_tools: Arc::new(RwLock::new(HashSet::new())),
104 auto_reconnect: Arc::new(RwLock::new(reconnect_policy.enabled)),
105 auto_connect: Arc::new(RwLock::new(false)),
106 state_notifier: state_tx,
107 health_check_config: Arc::new(RwLock::new(health_check_config)),
108 reconnect_policy: Arc::new(RwLock::new(reconnect_policy)),
109 health_monitor_handle: Arc::new(RwLock::new(None)),
110 retry_counts: Arc::new(RwLock::new(HashMap::new())),
111 }
112 }
113
114 pub fn get_state_notifier(&self) -> watch::Receiver<ManagerState> {
116 self.state_notifier.subscribe()
117 }
118
119 async fn update_state(&self, state: ManagerState) {
121 let _ = self.state_notifier.send(state);
122 }
123
124 pub async fn initialize(&self, servers: Vec<MCPServerConfig>) -> Result<(), ComputerError> {
126 self.stop_all().await?;
128
129 self.clear_all().await;
131
132 {
134 let mut configs = self.servers_config.write().await;
135 for server in servers {
136 configs.insert(server.name().to_string(), server);
137 }
138 }
139
140 self.refresh_tool_mapping().await?;
142
143 self.update_state(ManagerState::Initialized).await;
145
146 info!("Manager initialized successfully");
147 Ok(())
148 }
149
150 pub async fn add_or_update_server(&self, config: MCPServerConfig) -> Result<(), ComputerError> {
152 let server_name = config.name().to_string();
153
154 let is_active = {
156 let clients = self.active_clients.read().await;
157 clients.contains_key(&server_name)
158 };
159
160 if is_active {
161 let auto_reconnect = *self.auto_reconnect.read().await;
162 if auto_reconnect {
163 self.restart_server(&server_name).await?;
165 } else {
166 return Err(ComputerError::InvalidConfiguration(format!(
167 "Server {} is active. Stop it before updating config",
168 server_name
169 )));
170 }
171 }
172
173 {
175 let mut configs = self.servers_config.write().await;
176 configs.insert(server_name.clone(), config);
177 }
178
179 let auto_connect = *self.auto_connect.read().await;
181 if auto_connect && !is_active {
182 self.start_client(&server_name).await?;
183 }
184
185 self.refresh_tool_mapping().await?;
187
188 Ok(())
189 }
190
191 pub async fn remove_server(&self, server_name: &str) -> Result<(), ComputerError> {
193 self.stop_client(server_name).await?;
195
196 {
198 let mut configs = self.servers_config.write().await;
199 configs.remove(server_name);
200 }
201
202 self.refresh_tool_mapping().await?;
204
205 Ok(())
206 }
207
208 pub async fn start_all(&self) -> Result<(), ComputerError> {
210 let configs = self.servers_config.read().await;
211 let server_names: Vec<String> = configs
212 .iter()
213 .filter(|(_, config)| !config.disabled())
214 .map(|(name, _)| name.clone())
215 .collect();
216
217 drop(configs);
218
219 for server_name in server_names {
220 self.start_client(&server_name).await?;
221 }
222
223 self.update_state(ManagerState::Running).await;
225
226 info!("All servers started successfully");
227 Ok(())
228 }
229
230 pub async fn start_client(&self, server_name: &str) -> Result<(), ComputerError> {
232 let config = {
234 let configs = self.servers_config.read().await;
235 configs.get(server_name).cloned()
236 };
237
238 let config = config.ok_or_else(|| {
239 ComputerError::InvalidConfiguration(format!("Unknown server: {}", server_name))
240 })?;
241
242 if config.disabled() {
243 return Err(ComputerError::InvalidConfiguration(format!(
244 "Cannot start disabled server: {}",
245 server_name
246 )));
247 }
248
249 {
251 let clients = self.active_clients.read().await;
252 if clients.contains_key(server_name) {
253 return Ok(()); }
255 }
256
257 let client = client_factory(config);
259
260 client.connect().await.map_err(|e| {
262 ComputerError::ConnectionError(format!("Failed to connect to {}: {}", server_name, e))
263 })?;
264
265 {
267 let mut clients = self.active_clients.write().await;
268 clients.insert(server_name.to_string(), client);
269 }
270
271 self.refresh_tool_mapping().await?;
273
274 info!("Client {} started successfully", server_name);
275 Ok(())
276 }
277
278 pub async fn stop_client(&self, server_name: &str) -> Result<(), ComputerError> {
280 let mut client = {
282 let mut clients = self.active_clients.write().await;
283 clients.remove(server_name)
284 };
285
286 if let Some(ref mut c) = client {
288 c.disconnect().await.map_err(|e| {
289 ComputerError::ConnectionError(format!(
290 "Failed to disconnect from {}: {}",
291 server_name, e
292 ))
293 })?;
294 }
295
296 self.refresh_tool_mapping().await?;
298
299 info!("Client {} stopped successfully", server_name);
300 Ok(())
301 }
302
303 async fn restart_server(&self, server_name: &str) -> Result<(), ComputerError> {
305 self.stop_client(server_name).await?;
306
307 let enabled = {
309 let configs = self.servers_config.read().await;
310 configs
311 .get(server_name)
312 .map(|c| !c.disabled())
313 .unwrap_or(false)
314 };
315
316 if enabled {
317 self.start_client(server_name).await?;
318 }
319
320 Ok(())
321 }
322
323 pub async fn stop_all(&self) -> Result<(), ComputerError> {
325 let server_names: Vec<String> = {
326 let clients = self.active_clients.read().await;
327 clients.keys().cloned().collect()
328 };
329
330 for server_name in server_names {
331 self.stop_client(&server_name).await?;
332 }
333
334 self.update_state(ManagerState::Initialized).await;
336
337 info!("All servers stopped successfully");
338 Ok(())
339 }
340
341 async fn clear_all(&self) {
343 self.servers_config.write().await.clear();
344 self.active_clients.write().await.clear();
345 self.tool_mapping.write().await.clear();
346 self.alias_mapping.write().await.clear();
347 self.disabled_tools.write().await.clear();
348 }
349
350 pub async fn close(&self) -> Result<(), ComputerError> {
352 self.stop_all().await?;
353 self.clear_all().await;
354 self.update_state(ManagerState::Uninitialized).await;
355 info!("Manager closed successfully");
356 Ok(())
357 }
358
359 async fn refresh_tool_mapping(&self) -> Result<(), ComputerError> {
361 self.tool_mapping.write().await.clear();
363 self.alias_mapping.write().await.clear();
364 self.disabled_tools.write().await.clear();
365
366 let mut tool_sources: HashMap<ToolName, Vec<ServerName>> = HashMap::new();
368
369 let clients = self.active_clients.read().await;
371 let configs = self.servers_config.read().await;
372
373 for (server_name, client) in clients.iter() {
374 let config = match configs.get(server_name) {
375 Some(c) => c,
376 None => continue,
377 };
378
379 match client.list_tools().await {
381 Ok(tools) => {
382 for tool in tools {
383 let original_tool_name = tool.name.clone();
384
385 let tool_meta = self.merged_tool_meta(config, &original_tool_name);
387
388 let display_name = tool_meta
390 .and_then(|meta| meta.alias)
391 .unwrap_or_else(|| original_tool_name.clone());
392
393 if display_name != original_tool_name {
395 let mut alias_map = self.alias_mapping.write().await;
396 alias_map.insert(
397 display_name.clone(),
398 (server_name.clone(), original_tool_name.clone()),
399 );
400 }
401
402 tool_sources
404 .entry(display_name.clone())
405 .or_default()
406 .push(server_name.clone());
407
408 let forbidden_tools = config.forbidden_tools();
410 if forbidden_tools.contains(&display_name)
411 || forbidden_tools.contains(&original_tool_name)
412 {
413 let mut disabled = self.disabled_tools.write().await;
414 disabled.insert(display_name);
415 }
416 }
417 }
418 Err(e) => {
419 error!("Error listing tools for {}: {}", server_name, e);
420 }
421 }
422 }
423
424 for (tool, sources) in tool_sources {
426 if sources.len() > 1 {
427 warn!("Tool '{}' exists in multiple servers: {:?}", tool, sources);
428 let suggestion =
429 "Please use the 'alias' feature in ToolMeta to resolve conflicts. \
430 Each tool should have a unique name or alias across all servers.";
431 return Err(ComputerError::InvalidConfiguration(format!(
432 "Tool '{}' exists in multiple servers: {:?}\n{}",
433 tool, sources, suggestion
434 )));
435 }
436 let mut mapping = self.tool_mapping.write().await;
437 mapping.insert(tool, sources[0].clone());
438 }
439
440 debug!("Tool mapping refreshed successfully");
441 Ok(())
442 }
443
444 pub async fn validate_tool_call(
446 &self,
447 tool_name: &str,
448 _parameters: &serde_json::Value,
449 ) -> Result<(ServerName, ToolName), ComputerError> {
450 let disabled = self.disabled_tools.read().await;
452 if disabled.contains(tool_name) {
453 return Err(ComputerError::PermissionError(format!(
454 "Tool '{}' is disabled by configuration",
455 tool_name
456 )));
457 }
458
459 let server_name = {
461 let mapping = self.tool_mapping.read().await;
462 mapping.get(tool_name).cloned()
463 };
464
465 let server_name = server_name.ok_or_else(|| {
466 ComputerError::InvalidConfiguration(format!(
467 "Tool '{}' not found in any active server",
468 tool_name
469 ))
470 })?;
471
472 let original_tool_name = {
474 let alias_map = self.alias_mapping.read().await;
475 if let Some((_, original)) = alias_map.get(tool_name) {
476 original.clone()
477 } else {
478 tool_name.to_string()
479 }
480 };
481
482 Ok((server_name, original_tool_name))
483 }
484
485 pub async fn call_tool(
487 &self,
488 server_name: &str,
489 tool_name: &str,
490 parameters: serde_json::Value,
491 timeout: Option<std::time::Duration>,
492 ) -> Result<CallToolResult, ComputerError> {
493 let client = {
495 let clients = self.active_clients.read().await;
496 clients
497 .get(server_name)
498 .ok_or_else(|| {
499 ComputerError::InvalidConfiguration(format!(
500 "Server '{}' for tool '{}' is not active",
501 server_name, tool_name
502 ))
503 })?
504 .clone()
505 };
506
507 let result = if let Some(timeout) = timeout {
509 tokio::time::timeout(timeout, client.call_tool(tool_name, parameters))
510 .await
511 .map_err(|_| ComputerError::TimeoutError("Tool execution timed out".to_string()))?
512 } else {
513 client.call_tool(tool_name, parameters).await
514 };
515
516 let mut result = result
517 .map_err(|e| ComputerError::ProtocolError(format!("Tool execution failed: {}", e)))?;
518
519 let config = {
521 let configs = self.servers_config.read().await;
522 configs.get(server_name).cloned()
523 };
524
525 if let Some(config) = config {
526 if let Some(tool_meta) = self.merged_tool_meta(&config, tool_name) {
527 if result.meta.is_none() {
528 result.meta = Some(std::collections::HashMap::new());
529 }
530 if let Some(ref mut meta) = result.meta {
531 meta.insert(
532 A2C_TOOL_META.to_string(),
533 serde_json::to_value(tool_meta).unwrap(),
534 );
535 }
536 }
537
538 if let Some(vrl_script) = config.vrl() {
540 let parameters = serde_json::json!({});
543
544 let mut event = serde_json::to_value(&result).unwrap_or_default();
546 if let Value::Object(ref mut map) = event {
547 map.insert(
548 "tool_name".to_string(),
549 Value::String(tool_name.to_string()),
550 );
551 map.insert("parameters".to_string(), parameters);
552 }
553
554 let mut runtime = VrlRuntime::new();
556 match runtime.run(vrl_script, event, "UTC") {
557 Ok(vrl_result) => {
558 if result.meta.is_none() {
560 result.meta = Some(std::collections::HashMap::new());
561 }
562 if let Some(ref mut meta) = result.meta {
563 if let Ok(transformed_json) =
565 serde_json::to_string(&vrl_result.processed_event)
566 {
567 meta.insert(
568 A2C_VRL_TRANSFORMED.to_string(),
569 Value::String(transformed_json),
570 );
571 }
572 }
573 debug!(
574 "VRL转换成功 / VRL transformation succeeded for tool '{}'",
575 tool_name
576 );
577 }
578 Err(e) => {
579 warn!(
580 "VRL转换失败 / VRL transformation failed for tool '{}': {}. 原始结果将正常返回 / Original result will be returned normally.",
581 tool_name, e
582 );
583 }
584 }
585 }
586 }
587
588 Ok(result)
589 }
590
591 pub async fn execute_tool(
593 &self,
594 tool_name: &str,
595 parameters: serde_json::Value,
596 timeout: Option<std::time::Duration>,
597 ) -> Result<CallToolResult, ComputerError> {
598 let (server_name, original_tool_name) =
599 self.validate_tool_call(tool_name, ¶meters).await?;
600 self.call_tool(&server_name, &original_tool_name, parameters, timeout)
601 .await
602 }
603
604 pub async fn get_server_status(&self) -> Vec<(String, bool, String)> {
606 let configs = self.servers_config.read().await;
607 let clients = self.active_clients.read().await;
608
609 configs
610 .keys()
611 .map(|name| {
612 let is_active = clients.contains_key(name);
613 let state = if is_active {
614 clients
615 .get(name)
616 .map(|c| c.state().to_string())
617 .unwrap_or_else(|| "unknown".to_string())
618 } else {
619 "pending".to_string()
620 };
621 (name.clone(), is_active, state)
622 })
623 .collect()
624 }
625
626 pub async fn get_server_configs(&self) -> serde_json::Value {
631 let configs = self.servers_config.read().await;
632 let clients = self.active_clients.read().await;
633
634 let mut result = serde_json::Map::new();
635
636 for (name, config) in configs.iter() {
637 let is_active = clients.contains_key(name);
638 let state = if is_active {
639 clients
640 .get(name)
641 .map(|c| c.state().to_string())
642 .unwrap_or_else(|| "unknown".to_string())
643 } else {
644 "pending".to_string()
645 };
646
647 let mut server_info = serde_json::Map::new();
649
650 let server_type = match config {
652 MCPServerConfig::Stdio(_) => "stdio",
653 MCPServerConfig::Sse(_) => "sse",
654 MCPServerConfig::Http(_) => "http",
655 };
656 server_info.insert(
657 "type".to_string(),
658 serde_json::Value::String(server_type.to_string()),
659 );
660
661 server_info.insert("status".to_string(), serde_json::Value::String(state));
663 server_info.insert("is_active".to_string(), serde_json::Value::Bool(is_active));
664 server_info.insert(
665 "disabled".to_string(),
666 serde_json::Value::Bool(config.disabled()),
667 );
668
669 let forbidden_tools: Vec<serde_json::Value> = config
671 .forbidden_tools()
672 .iter()
673 .map(|t| serde_json::Value::String(t.clone()))
674 .collect();
675 server_info.insert(
676 "forbidden_tools".to_string(),
677 serde_json::Value::Array(forbidden_tools),
678 );
679
680 if let Ok(tool_meta_json) = serde_json::to_value(config.tool_meta()) {
682 server_info.insert("tool_meta".to_string(), tool_meta_json);
683 }
684
685 if let Some(default_meta) = config.default_tool_meta() {
687 if let Ok(default_meta_json) = serde_json::to_value(default_meta) {
688 server_info.insert("default_tool_meta".to_string(), default_meta_json);
689 }
690 }
691
692 if let Some(vrl) = config.vrl() {
694 server_info.insert(
695 "vrl".to_string(),
696 serde_json::Value::String(vrl.to_string()),
697 );
698 }
699
700 match config {
702 MCPServerConfig::Stdio(stdio_config) => {
703 if let Ok(params_json) = serde_json::to_value(&stdio_config.server_parameters) {
704 server_info.insert("server_parameters".to_string(), params_json);
705 }
706 }
707 MCPServerConfig::Sse(sse_config) => {
708 if let Ok(params_json) = serde_json::to_value(&sse_config.server_parameters) {
709 server_info.insert("server_parameters".to_string(), params_json);
710 }
711 }
712 MCPServerConfig::Http(http_config) => {
713 if let Ok(params_json) = serde_json::to_value(&http_config.server_parameters) {
714 server_info.insert("server_parameters".to_string(), params_json);
715 }
716 }
717 }
718
719 result.insert(name.clone(), serde_json::Value::Object(server_info));
720 }
721
722 serde_json::Value::Object(result)
723 }
724
725 pub async fn list_available_tools(&self) -> Vec<Tool> {
727 let mut tools = Vec::new();
728 let mapping = self.tool_mapping.read().await;
729 let alias_map = self.alias_mapping.read().await;
730
731 for (display_name, server_name) in mapping.iter() {
732 let client = {
733 let clients = self.active_clients.read().await;
734 clients.get(server_name).cloned()
735 };
736
737 if let Some(client) = client {
738 let original_name = alias_map
740 .get(display_name)
741 .map(|(_, original)| original.clone())
742 .unwrap_or_else(|| display_name.clone());
743
744 if let Ok(tool_list) = client.list_tools().await {
746 if let Some(tool) = tool_list.into_iter().find(|t| t.name == original_name) {
747 let mut display_tool = tool;
749 display_tool.name = display_name.clone();
750
751 let config = {
753 let configs = self.servers_config.read().await;
754 configs.get(server_name).cloned()
755 };
756 if let Some(config) = config {
757 if let Some(tool_meta) = self.merged_tool_meta(&config, &original_name)
758 {
759 if display_tool.meta.is_none() {
760 display_tool.meta = Some(HashMap::new());
761 }
762 if let Some(ref mut meta) = display_tool.meta {
763 meta.insert(
764 A2C_TOOL_META.to_string(),
765 serde_json::to_value(tool_meta).unwrap(),
766 );
767 }
768 }
769 }
770
771 tools.push(display_tool);
772 }
773 }
774 }
775 }
776
777 tools
778 }
779
780 fn merged_tool_meta(&self, config: &MCPServerConfig, tool_name: &str) -> Option<ToolMeta> {
782 let specific = config.tool_meta().get(tool_name);
783 let default = config.default_tool_meta();
784
785 match (specific, default) {
786 (None, None) => None,
787 (Some(s), None) => Some(s.clone()),
788 (None, Some(d)) => Some(d.clone()),
789 (Some(s), Some(d)) => {
790 let mut merged = d.clone();
792 if s.auto_apply.is_some() {
793 merged.auto_apply = s.auto_apply;
794 }
795 if s.alias.is_some() {
796 merged.alias = s.alias.clone();
797 }
798 if s.tags.is_some() {
799 merged.tags = s.tags.clone();
800 }
801 if s.ret_object_mapper.is_some() {
802 merged.ret_object_mapper = s.ret_object_mapper.clone();
803 }
804 Some(merged)
805 }
806 }
807 }
808
809 pub async fn enable_auto_connect(&self) {
811 *self.auto_connect.write().await = true;
812 }
813
814 pub async fn disable_auto_connect(&self) {
816 *self.auto_connect.write().await = false;
817 }
818
819 pub async fn enable_auto_reconnect(&self) {
821 *self.auto_reconnect.write().await = true;
822 }
823
824 pub async fn disable_auto_reconnect(&self) {
826 *self.auto_reconnect.write().await = false;
827 }
828
829 pub async fn set_health_check_config(&self, config: HealthCheckConfig) {
831 *self.health_check_config.write().await = config;
832 }
833
834 pub async fn get_health_check_config(&self) -> HealthCheckConfig {
836 self.health_check_config.read().await.clone()
837 }
838
839 pub async fn set_reconnect_policy(&self, policy: ReconnectPolicy) {
841 *self.reconnect_policy.write().await = policy;
842 }
843
844 pub async fn get_reconnect_policy(&self) -> ReconnectPolicy {
846 self.reconnect_policy.read().await.clone()
847 }
848
849 pub async fn start_health_monitor(&self) {
853 self.stop_health_monitor().await;
855
856 let health_config = self.health_check_config.clone();
857 let reconnect_policy = self.reconnect_policy.clone();
858 let active_clients = self.active_clients.clone();
859 let _servers_config = self.servers_config.clone();
860 let retry_counts = self.retry_counts.clone();
861 let auto_reconnect = self.auto_reconnect.clone();
862
863 let handle = tokio::spawn(async move {
864 loop {
865 let config = health_config.read().await.clone();
866 if !config.enabled {
867 tokio::time::sleep(std::time::Duration::from_secs(10)).await;
870 continue;
871 }
872
873 let clients: Vec<(String, StdArc<dyn MCPClientProtocol>)> = {
875 let clients_guard = active_clients.read().await;
876 clients_guard
877 .iter()
878 .map(|(k, v)| (k.clone(), v.clone()))
879 .collect()
880 };
881
882 for (server_name, client) in clients {
884 let check_result = tokio::time::timeout(
885 std::time::Duration::from_secs(config.timeout_secs),
886 client.health_check(),
887 )
888 .await;
889
890 let is_healthy = match check_result {
891 Ok(result) => result.is_healthy,
892 Err(_) => {
893 warn!("Health check timed out for server: {}", server_name);
894 false
895 }
896 };
897
898 if !is_healthy {
899 warn!("Server {} is unhealthy", server_name);
900
901 let should_reconnect = *auto_reconnect.read().await;
903 if !should_reconnect {
904 continue;
905 }
906
907 let policy = reconnect_policy.read().await.clone();
908 let mut retries = retry_counts.write().await;
909 let retry_count = retries.entry(server_name.clone()).or_insert(0);
910
911 if policy.should_retry(*retry_count) {
912 let delay = policy.calculate_delay(*retry_count);
913 info!(
914 "Attempting to reconnect {} (retry {}/{}), delay {:?}",
915 server_name,
916 *retry_count + 1,
917 if policy.max_retries == 0 {
918 "∞".to_string()
919 } else {
920 policy.max_retries.to_string()
921 },
922 delay
923 );
924
925 tokio::time::sleep(delay).await;
926
927 if let Err(e) = client.disconnect().await {
929 warn!("Failed to disconnect {}: {}", server_name, e);
930 }
931
932 match client.connect().await {
933 Ok(_) => {
934 info!("Successfully reconnected to {}", server_name);
935 *retry_count = 0;
937 }
938 Err(e) => {
939 error!("Failed to reconnect to {}: {}", server_name, e);
940 *retry_count += 1;
941 }
942 }
943 } else {
944 error!(
945 "Max retries ({}) reached for server {}. Giving up.",
946 policy.max_retries, server_name
947 );
948 }
950 } else {
951 let mut retries = retry_counts.write().await;
953 retries.remove(&server_name);
954 debug!("Server {} is healthy", server_name);
955 }
956 }
957
958 tokio::time::sleep(std::time::Duration::from_secs(config.interval_secs)).await;
960 }
961 });
962
963 *self.health_monitor_handle.write().await = Some(handle);
964 info!("Health monitor started");
965 }
966
967 pub async fn stop_health_monitor(&self) {
969 if let Some(handle) = self.health_monitor_handle.write().await.take() {
970 handle.abort();
971 info!("Health monitor stopped");
972 }
973 }
974
975 pub async fn check_server_health(&self, server_name: &str) -> Option<HealthCheckResult> {
977 let clients = self.active_clients.read().await;
978 if let Some(client) = clients.get(server_name) {
979 let config = self.health_check_config.read().await;
980 let result = tokio::time::timeout(
981 std::time::Duration::from_secs(config.timeout_secs),
982 client.health_check(),
983 )
984 .await;
985
986 match result {
987 Ok(health_result) => Some(health_result),
988 Err(_) => Some(HealthCheckResult {
989 is_healthy: false,
990 checked_at: std::time::Instant::now(),
991 error: Some("Health check timed out".to_string()),
992 response_time_ms: None,
993 }),
994 }
995 } else {
996 None
997 }
998 }
999
1000 pub async fn check_all_health(&self) -> HashMap<String, HealthCheckResult> {
1002 let mut results = HashMap::new();
1003 let clients: Vec<(String, StdArc<dyn MCPClientProtocol>)> = {
1004 let clients_guard = self.active_clients.read().await;
1005 clients_guard
1006 .iter()
1007 .map(|(k, v)| (k.clone(), v.clone()))
1008 .collect()
1009 };
1010
1011 let config = self.health_check_config.read().await.clone();
1012
1013 for (server_name, client) in clients {
1014 let result = tokio::time::timeout(
1015 std::time::Duration::from_secs(config.timeout_secs),
1016 client.health_check(),
1017 )
1018 .await;
1019
1020 let health_result = match result {
1021 Ok(hr) => hr,
1022 Err(_) => HealthCheckResult {
1023 is_healthy: false,
1024 checked_at: std::time::Instant::now(),
1025 error: Some("Health check timed out".to_string()),
1026 response_time_ms: None,
1027 },
1028 };
1029
1030 results.insert(server_name, health_result);
1031 }
1032
1033 results
1034 }
1035
1036 pub async fn get_retry_counts(&self) -> HashMap<String, u32> {
1038 self.retry_counts.read().await.clone()
1039 }
1040
1041 pub async fn reset_retry_count(&self, server_name: &str) {
1043 self.retry_counts.write().await.remove(server_name);
1044 }
1045
1046 pub async fn reset_all_retry_counts(&self) {
1048 self.retry_counts.write().await.clear();
1049 }
1050}
1051
1052impl Default for MCPServerManager {
1053 fn default() -> Self {
1054 Self::new()
1055 }
1056}
1057
1058#[cfg(test)]
1059mod tests {
1060 use super::*;
1061 use std::collections::HashMap;
1062 use tokio::time::{sleep, Duration};
1063
1064 #[tokio::test]
1065 async fn test_manager_creation() {
1066 let manager = MCPServerManager::new();
1067 let status = manager.get_server_status().await;
1068 assert!(status.is_empty());
1069 }
1070
1071 #[tokio::test]
1072 async fn test_manager_initialization() {
1073 let manager = MCPServerManager::new();
1074
1075 let configs = vec![
1077 MCPServerConfig::Stdio(StdioServerConfig {
1079 name: "test_stdio".to_string(),
1080 disabled: false,
1081 forbidden_tools: vec![],
1082 tool_meta: HashMap::new(),
1083 default_tool_meta: None,
1084 vrl: None,
1085 server_parameters: StdioServerParameters {
1086 command: "echo".to_string(),
1087 args: vec!["hello".to_string()],
1088 env: HashMap::new(),
1089 cwd: None,
1090 },
1091 }),
1092 MCPServerConfig::Http(HttpServerConfig {
1094 name: "test_http".to_string(),
1095 disabled: true, forbidden_tools: vec![],
1097 tool_meta: HashMap::new(),
1098 default_tool_meta: None,
1099 vrl: None,
1100 server_parameters: HttpServerParameters {
1101 url: "http://localhost:8080".to_string(),
1102 headers: HashMap::new(),
1103 },
1104 }),
1105 ];
1106
1107 let result = manager.initialize(configs).await;
1109 assert!(result.is_ok());
1110
1111 let status = manager.get_server_status().await;
1113 assert_eq!(status.len(), 2);
1114
1115 let stdio_status = status
1117 .iter()
1118 .find(|(name, _, _)| name == "test_stdio")
1119 .unwrap();
1120 assert!(!stdio_status.1); let http_status = status
1123 .iter()
1124 .find(|(name, _, _)| name == "test_http")
1125 .unwrap();
1126 assert!(!http_status.1); }
1128
1129 #[tokio::test]
1130 async fn test_add_server() {
1131 let manager = MCPServerManager::new();
1132
1133 let config = MCPServerConfig::Stdio(StdioServerConfig {
1135 name: "test_server".to_string(),
1136 disabled: false,
1137 forbidden_tools: vec![],
1138 tool_meta: HashMap::new(),
1139 default_tool_meta: None,
1140 vrl: None,
1141 server_parameters: StdioServerParameters {
1142 command: "echo".to_string(),
1143 args: vec![],
1144 env: HashMap::new(),
1145 cwd: None,
1146 },
1147 });
1148
1149 let result = manager.add_or_update_server(config).await;
1150 assert!(result.is_ok());
1151
1152 let status = manager.get_server_status().await;
1154 assert_eq!(status.len(), 1);
1155 assert_eq!(status[0].0, "test_server");
1156 }
1157
1158 #[tokio::test]
1159 async fn test_remove_server() {
1160 let manager = MCPServerManager::new();
1161
1162 let config = MCPServerConfig::Stdio(StdioServerConfig {
1164 name: "test_server".to_string(),
1165 disabled: false,
1166 forbidden_tools: vec![],
1167 tool_meta: HashMap::new(),
1168 default_tool_meta: None,
1169 vrl: None,
1170 server_parameters: StdioServerParameters {
1171 command: "echo".to_string(),
1172 args: vec![],
1173 env: HashMap::new(),
1174 cwd: None,
1175 },
1176 });
1177
1178 manager.add_or_update_server(config).await.unwrap();
1179
1180 let result = manager.remove_server("test_server").await;
1182 assert!(result.is_ok());
1183
1184 let status = manager.get_server_status().await;
1186 assert!(status.is_empty());
1187 }
1188
1189 #[tokio::test]
1190 async fn test_tool_conflict_detection() {
1191 let manager = MCPServerManager::new();
1192
1193 let configs = vec![
1195 MCPServerConfig::Stdio(StdioServerConfig {
1197 name: "server1".to_string(),
1198 disabled: false,
1199 forbidden_tools: vec![],
1200 tool_meta: HashMap::new(),
1201 default_tool_meta: None,
1202 vrl: None,
1203 server_parameters: StdioServerParameters {
1204 command: "echo".to_string(),
1205 args: vec!["server1".to_string()],
1206 env: HashMap::new(),
1207 cwd: None,
1208 },
1209 }),
1210 MCPServerConfig::Stdio(StdioServerConfig {
1212 name: "server2".to_string(),
1213 disabled: false,
1214 forbidden_tools: vec![],
1215 tool_meta: HashMap::new(),
1216 default_tool_meta: None,
1217 vrl: None,
1218 server_parameters: StdioServerParameters {
1219 command: "echo".to_string(),
1220 args: vec!["server2".to_string()],
1221 env: HashMap::new(),
1222 cwd: None,
1223 },
1224 }),
1225 ];
1226
1227 let result = manager.initialize(configs).await;
1229 assert!(result.is_ok());
1230
1231 let _result = manager.start_all().await;
1233 sleep(Duration::from_millis(200)).await;
1238 }
1239
1240 #[tokio::test]
1241 async fn test_health_check_config() {
1242 let manager = MCPServerManager::new();
1243
1244 let config = manager.get_health_check_config().await;
1246 assert_eq!(config.interval_secs, 30);
1247 assert_eq!(config.timeout_secs, 5);
1248 assert!(config.enabled);
1249
1250 let new_config = HealthCheckConfig {
1252 interval_secs: 60,
1253 timeout_secs: 10,
1254 enabled: false,
1255 };
1256 manager.set_health_check_config(new_config.clone()).await;
1257
1258 let updated = manager.get_health_check_config().await;
1259 assert_eq!(updated.interval_secs, 60);
1260 assert_eq!(updated.timeout_secs, 10);
1261 assert!(!updated.enabled);
1262 }
1263
1264 #[tokio::test]
1265 async fn test_reconnect_policy() {
1266 let manager = MCPServerManager::new();
1267
1268 let policy = manager.get_reconnect_policy().await;
1270 assert!(policy.enabled);
1271 assert_eq!(policy.max_retries, 5);
1272 assert_eq!(policy.initial_delay_ms, 1000);
1273 assert_eq!(policy.max_delay_ms, 30000);
1274 assert_eq!(policy.backoff_factor, 2.0);
1275
1276 assert_eq!(policy.calculate_delay(0).as_millis(), 1000);
1278 assert_eq!(policy.calculate_delay(1).as_millis(), 2000);
1279 assert_eq!(policy.calculate_delay(2).as_millis(), 4000);
1280 assert_eq!(policy.calculate_delay(3).as_millis(), 8000);
1281
1282 assert!(policy.should_retry(0));
1284 assert!(policy.should_retry(4));
1285 assert!(!policy.should_retry(5)); let infinite_policy = ReconnectPolicy {
1289 enabled: true,
1290 max_retries: 0,
1291 ..Default::default()
1292 };
1293 assert!(infinite_policy.should_retry(100));
1294 }
1295
1296 #[tokio::test]
1297 async fn test_retry_counts() {
1298 let manager = MCPServerManager::new();
1299
1300 let counts = manager.get_retry_counts().await;
1302 assert!(counts.is_empty());
1303
1304 {
1306 manager
1307 .retry_counts
1308 .write()
1309 .await
1310 .insert("server1".to_string(), 3);
1311 manager
1312 .retry_counts
1313 .write()
1314 .await
1315 .insert("server2".to_string(), 5);
1316 }
1317
1318 let counts = manager.get_retry_counts().await;
1319 assert_eq!(counts.get("server1"), Some(&3));
1320 assert_eq!(counts.get("server2"), Some(&5));
1321
1322 manager.reset_retry_count("server1").await;
1324 let counts = manager.get_retry_counts().await;
1325 assert!(!counts.contains_key("server1"));
1326 assert_eq!(counts.get("server2"), Some(&5));
1327
1328 manager.reset_all_retry_counts().await;
1330 let counts = manager.get_retry_counts().await;
1331 assert!(counts.is_empty());
1332 }
1333
1334 #[tokio::test]
1335 async fn test_manager_with_custom_config() {
1336 let health_config = HealthCheckConfig {
1337 interval_secs: 15,
1338 timeout_secs: 3,
1339 enabled: true,
1340 };
1341 let reconnect_policy = ReconnectPolicy {
1342 enabled: true,
1343 max_retries: 10,
1344 initial_delay_ms: 500,
1345 max_delay_ms: 60000,
1346 backoff_factor: 1.5,
1347 };
1348
1349 let manager =
1350 MCPServerManager::with_config(health_config.clone(), reconnect_policy.clone());
1351
1352 let got_health = manager.get_health_check_config().await;
1353 assert_eq!(got_health.interval_secs, 15);
1354 assert_eq!(got_health.timeout_secs, 3);
1355
1356 let got_reconnect = manager.get_reconnect_policy().await;
1357 assert_eq!(got_reconnect.max_retries, 10);
1358 assert_eq!(got_reconnect.initial_delay_ms, 500);
1359 }
1360
1361 #[tokio::test]
1362 async fn test_merged_tool_meta() {
1363 let manager = MCPServerManager::new();
1364
1365 let config = MCPServerConfig::Stdio(StdioServerConfig {
1367 name: "s".to_string(),
1368 disabled: false,
1369 forbidden_tools: vec![],
1370 tool_meta: HashMap::from([(
1371 "tool_a".to_string(),
1372 ToolMeta {
1373 auto_apply: Some(true),
1374 alias: None,
1375 tags: Some(vec!["tag1".to_string()]),
1376 ret_object_mapper: None,
1377 },
1378 )]),
1379 default_tool_meta: None,
1380 vrl: None,
1381 server_parameters: StdioServerParameters {
1382 command: "echo".to_string(),
1383 args: vec![],
1384 env: HashMap::new(),
1385 cwd: None,
1386 },
1387 });
1388 let meta = manager.merged_tool_meta(&config, "tool_a").unwrap();
1389 assert_eq!(meta.auto_apply, Some(true));
1390 assert_eq!(meta.tags, Some(vec!["tag1".to_string()]));
1391
1392 let config = MCPServerConfig::Stdio(StdioServerConfig {
1394 name: "s".to_string(),
1395 disabled: false,
1396 forbidden_tools: vec![],
1397 tool_meta: HashMap::new(),
1398 default_tool_meta: Some(ToolMeta {
1399 auto_apply: Some(false),
1400 alias: None,
1401 tags: Some(vec!["default_tag".to_string()]),
1402 ret_object_mapper: None,
1403 }),
1404 vrl: None,
1405 server_parameters: StdioServerParameters {
1406 command: "echo".to_string(),
1407 args: vec![],
1408 env: HashMap::new(),
1409 cwd: None,
1410 },
1411 });
1412 let meta = manager.merged_tool_meta(&config, "any_tool").unwrap();
1413 assert_eq!(meta.auto_apply, Some(false));
1414 assert_eq!(meta.tags, Some(vec!["default_tag".to_string()]));
1415
1416 let config = MCPServerConfig::Stdio(StdioServerConfig {
1418 name: "s".to_string(),
1419 disabled: false,
1420 forbidden_tools: vec![],
1421 tool_meta: HashMap::from([(
1422 "tool_a".to_string(),
1423 ToolMeta {
1424 auto_apply: Some(true),
1425 alias: None,
1426 tags: None,
1427 ret_object_mapper: None,
1428 },
1429 )]),
1430 default_tool_meta: Some(ToolMeta {
1431 auto_apply: Some(false),
1432 alias: Some("default_alias".to_string()),
1433 tags: Some(vec!["default_tag".to_string()]),
1434 ret_object_mapper: None,
1435 }),
1436 vrl: None,
1437 server_parameters: StdioServerParameters {
1438 command: "echo".to_string(),
1439 args: vec![],
1440 env: HashMap::new(),
1441 cwd: None,
1442 },
1443 });
1444 let meta = manager.merged_tool_meta(&config, "tool_a").unwrap();
1445 assert_eq!(meta.auto_apply, Some(true)); assert_eq!(meta.alias, Some("default_alias".to_string())); assert_eq!(meta.tags, Some(vec!["default_tag".to_string()])); let config = MCPServerConfig::Stdio(StdioServerConfig {
1451 name: "s".to_string(),
1452 disabled: false,
1453 forbidden_tools: vec![],
1454 tool_meta: HashMap::new(),
1455 default_tool_meta: None,
1456 vrl: None,
1457 server_parameters: StdioServerParameters {
1458 command: "echo".to_string(),
1459 args: vec![],
1460 env: HashMap::new(),
1461 cwd: None,
1462 },
1463 });
1464 assert!(manager.merged_tool_meta(&config, "tool_a").is_none());
1465 }
1466}