1use super::model::*;
11use super::utils::client_factory;
12use super::vrl_runtime::VrlRuntime;
13use crate::errors::ComputerError;
14use serde_json::Value;
15use std::collections::{HashMap, HashSet};
16use std::sync::Arc;
17use std::sync::Arc as StdArc;
18use tokio::sync::{watch, RwLock};
19use tracing::{debug, error, info, warn};
20
/// Error describing a tool (display) name exported by more than one server.
///
/// NOTE(review): this type is not constructed anywhere in this file —
/// `refresh_tool_mapping` reports conflicts via
/// `ComputerError::InvalidConfiguration` instead. Confirm external callers
/// before removing it or wiring it into the conflict path.
#[derive(Debug, thiserror::Error)]
#[error("Tool '{tool_name}' exists in multiple servers: {servers:?}")]
pub struct ToolNameDuplicatedError {
    /// The conflicting tool (display) name.
    pub tool_name: String,
    /// Names of every server that exports the tool.
    pub servers: Vec<String>,
}
28
/// Central manager for a set of MCP servers: configuration, live client
/// connections, tool-name routing, and background health monitoring.
///
/// Every field is `Arc<RwLock<…>>` so the manager can be shared across async
/// tasks and mutated concurrently.
pub struct MCPServerManager {
    /// Configured servers keyed by server name (includes disabled ones).
    servers_config: Arc<RwLock<HashMap<ServerName, MCPServerConfig>>>,
    /// Currently connected clients keyed by server name.
    active_clients: Arc<RwLock<HashMap<ServerName, StdArc<dyn MCPClientProtocol>>>>,
    /// Display tool name -> owning server.
    tool_mapping: Arc<RwLock<HashMap<ToolName, ServerName>>>,
    /// Alias (display name) -> (server, original tool name).
    alias_mapping: Arc<RwLock<HashMap<String, (ServerName, ToolName)>>>,
    /// Tools blocked via a server's `forbidden_tools` configuration.
    disabled_tools: Arc<RwLock<HashSet<ToolName>>>,
    /// Whether unhealthy servers are reconnected automatically.
    auto_reconnect: Arc<RwLock<bool>>,
    /// Whether newly added/updated servers are started immediately.
    auto_connect: Arc<RwLock<bool>>,
    /// Broadcasts manager lifecycle state changes to subscribers.
    state_notifier: watch::Sender<ManagerState>,
    /// Settings for the periodic health-check loop.
    health_check_config: Arc<RwLock<HealthCheckConfig>>,
    /// Backoff/retry policy applied when reconnecting unhealthy servers.
    reconnect_policy: Arc<RwLock<ReconnectPolicy>>,
    /// Handle of the spawned health-monitor task, if one is running.
    health_monitor_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
    /// Per-server consecutive reconnect-attempt counters.
    retry_counts: Arc<RwLock<HashMap<ServerName, u32>>>,
}
56
/// Lifecycle states broadcast through the manager's `watch` channel.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ManagerState {
    /// No configuration loaded yet (also the state after `close`).
    Uninitialized,
    /// Configuration loaded; servers not (all) running.
    Initialized,
    /// `start_all` completed; servers are running.
    Running,
    /// An unrecoverable error occurred.
    Error,
}
69
70impl MCPServerManager {
71 pub fn new() -> Self {
73 let (state_tx, _) = watch::channel(ManagerState::Uninitialized);
74
75 Self {
76 servers_config: Arc::new(RwLock::new(HashMap::new())),
77 active_clients: Arc::new(RwLock::new(HashMap::new())),
78 tool_mapping: Arc::new(RwLock::new(HashMap::new())),
79 alias_mapping: Arc::new(RwLock::new(HashMap::new())),
80 disabled_tools: Arc::new(RwLock::new(HashSet::new())),
81 auto_reconnect: Arc::new(RwLock::new(true)),
82 auto_connect: Arc::new(RwLock::new(false)),
83 state_notifier: state_tx,
84 health_check_config: Arc::new(RwLock::new(HealthCheckConfig::default())),
85 reconnect_policy: Arc::new(RwLock::new(ReconnectPolicy::default())),
86 health_monitor_handle: Arc::new(RwLock::new(None)),
87 retry_counts: Arc::new(RwLock::new(HashMap::new())),
88 }
89 }
90
91 pub fn with_config(
93 health_check_config: HealthCheckConfig,
94 reconnect_policy: ReconnectPolicy,
95 ) -> Self {
96 let (state_tx, _) = watch::channel(ManagerState::Uninitialized);
97
98 Self {
99 servers_config: Arc::new(RwLock::new(HashMap::new())),
100 active_clients: Arc::new(RwLock::new(HashMap::new())),
101 tool_mapping: Arc::new(RwLock::new(HashMap::new())),
102 alias_mapping: Arc::new(RwLock::new(HashMap::new())),
103 disabled_tools: Arc::new(RwLock::new(HashSet::new())),
104 auto_reconnect: Arc::new(RwLock::new(reconnect_policy.enabled)),
105 auto_connect: Arc::new(RwLock::new(false)),
106 state_notifier: state_tx,
107 health_check_config: Arc::new(RwLock::new(health_check_config)),
108 reconnect_policy: Arc::new(RwLock::new(reconnect_policy)),
109 health_monitor_handle: Arc::new(RwLock::new(None)),
110 retry_counts: Arc::new(RwLock::new(HashMap::new())),
111 }
112 }
113
    /// Subscribe to lifecycle state changes.
    ///
    /// The returned receiver always holds the most recently broadcast
    /// `ManagerState`.
    pub fn get_state_notifier(&self) -> watch::Receiver<ManagerState> {
        self.state_notifier.subscribe()
    }
118
    /// Broadcast a new lifecycle state to all subscribers.
    ///
    /// Send errors (no live receivers) are deliberately ignored.
    async fn update_state(&self, state: ManagerState) {
        let _ = self.state_notifier.send(state);
    }
123
124 pub async fn initialize(&self, servers: Vec<MCPServerConfig>) -> Result<(), ComputerError> {
126 self.stop_all().await?;
128
129 self.clear_all().await;
131
132 {
134 let mut configs = self.servers_config.write().await;
135 for server in servers {
136 configs.insert(server.name().to_string(), server);
137 }
138 }
139
140 self.refresh_tool_mapping().await?;
142
143 self.update_state(ManagerState::Initialized).await;
145
146 info!("Manager initialized successfully");
147 Ok(())
148 }
149
150 pub async fn add_or_update_server(&self, config: MCPServerConfig) -> Result<(), ComputerError> {
152 let server_name = config.name().to_string();
153
154 let is_active = {
156 let clients = self.active_clients.read().await;
157 clients.contains_key(&server_name)
158 };
159
160 if is_active {
161 let auto_reconnect = *self.auto_reconnect.read().await;
162 if auto_reconnect {
163 self.restart_server(&server_name).await?;
165 } else {
166 return Err(ComputerError::InvalidConfiguration(format!(
167 "Server {} is active. Stop it before updating config",
168 server_name
169 )));
170 }
171 }
172
173 {
175 let mut configs = self.servers_config.write().await;
176 configs.insert(server_name.clone(), config);
177 }
178
179 let auto_connect = *self.auto_connect.read().await;
181 if auto_connect && !is_active {
182 self.start_client(&server_name).await?;
183 }
184
185 self.refresh_tool_mapping().await?;
187
188 Ok(())
189 }
190
191 pub async fn remove_server(&self, server_name: &str) -> Result<(), ComputerError> {
193 self.stop_client(server_name).await?;
195
196 {
198 let mut configs = self.servers_config.write().await;
199 configs.remove(server_name);
200 }
201
202 self.refresh_tool_mapping().await?;
204
205 Ok(())
206 }
207
208 pub async fn start_all(&self) -> Result<(), ComputerError> {
210 let configs = self.servers_config.read().await;
211 let server_names: Vec<String> = configs
212 .iter()
213 .filter(|(_, config)| !config.disabled())
214 .map(|(name, _)| name.clone())
215 .collect();
216
217 drop(configs);
218
219 for server_name in server_names {
220 self.start_client(&server_name).await?;
221 }
222
223 self.update_state(ManagerState::Running).await;
225
226 info!("All servers started successfully");
227 Ok(())
228 }
229
230 pub async fn start_client(&self, server_name: &str) -> Result<(), ComputerError> {
232 let config = {
234 let configs = self.servers_config.read().await;
235 configs.get(server_name).cloned()
236 };
237
238 let config = config.ok_or_else(|| {
239 ComputerError::InvalidConfiguration(format!("Unknown server: {}", server_name))
240 })?;
241
242 if config.disabled() {
243 return Err(ComputerError::InvalidConfiguration(format!(
244 "Cannot start disabled server: {}",
245 server_name
246 )));
247 }
248
249 {
251 let clients = self.active_clients.read().await;
252 if clients.contains_key(server_name) {
253 return Ok(()); }
255 }
256
257 let client = client_factory(config);
259
260 client.connect().await.map_err(|e| {
262 ComputerError::ConnectionError(format!("Failed to connect to {}: {}", server_name, e))
263 })?;
264
265 {
267 let mut clients = self.active_clients.write().await;
268 clients.insert(server_name.to_string(), client);
269 }
270
271 self.refresh_tool_mapping().await?;
273
274 info!("Client {} started successfully", server_name);
275 Ok(())
276 }
277
278 pub async fn stop_client(&self, server_name: &str) -> Result<(), ComputerError> {
280 let mut client = {
282 let mut clients = self.active_clients.write().await;
283 clients.remove(server_name)
284 };
285
286 if let Some(ref mut c) = client {
288 c.disconnect().await.map_err(|e| {
289 ComputerError::ConnectionError(format!(
290 "Failed to disconnect from {}: {}",
291 server_name, e
292 ))
293 })?;
294 }
295
296 self.refresh_tool_mapping().await?;
298
299 info!("Client {} stopped successfully", server_name);
300 Ok(())
301 }
302
303 async fn restart_server(&self, server_name: &str) -> Result<(), ComputerError> {
305 self.stop_client(server_name).await?;
306
307 let enabled = {
309 let configs = self.servers_config.read().await;
310 configs
311 .get(server_name)
312 .map(|c| !c.disabled())
313 .unwrap_or(false)
314 };
315
316 if enabled {
317 self.start_client(server_name).await?;
318 }
319
320 Ok(())
321 }
322
323 pub async fn stop_all(&self) -> Result<(), ComputerError> {
325 let server_names: Vec<String> = {
326 let clients = self.active_clients.read().await;
327 clients.keys().cloned().collect()
328 };
329
330 for server_name in server_names {
331 self.stop_client(&server_name).await?;
332 }
333
334 self.update_state(ManagerState::Initialized).await;
336
337 info!("All servers stopped successfully");
338 Ok(())
339 }
340
341 async fn clear_all(&self) {
343 self.servers_config.write().await.clear();
344 self.active_clients.write().await.clear();
345 self.tool_mapping.write().await.clear();
346 self.alias_mapping.write().await.clear();
347 self.disabled_tools.write().await.clear();
348 }
349
350 pub async fn close(&self) -> Result<(), ComputerError> {
352 self.stop_all().await?;
353 self.clear_all().await;
354 self.update_state(ManagerState::Uninitialized).await;
355 info!("Manager closed successfully");
356 Ok(())
357 }
358
    /// Rebuild `tool_mapping`, `alias_mapping` and `disabled_tools` from the
    /// currently active clients.
    ///
    /// # Errors
    /// `InvalidConfiguration` if the same display name is exported by more
    /// than one server (aliases via `ToolMeta` are the suggested fix).
    async fn refresh_tool_mapping(&self) -> Result<(), ComputerError> {
        // Mappings are derived state — start from a clean slate.
        self.tool_mapping.write().await.clear();
        self.alias_mapping.write().await.clear();
        self.disabled_tools.write().await.clear();

        // display name -> every server exporting it (for conflict detection).
        let mut tool_sources: HashMap<ToolName, Vec<ServerName>> = HashMap::new();

        let clients = self.active_clients.read().await;
        let configs = self.servers_config.read().await;

        for (server_name, client) in clients.iter() {
            // Skip clients whose configuration has disappeared (e.g. a
            // removal that is mid-flight).
            let config = match configs.get(server_name) {
                Some(c) => c,
                None => continue,
            };

            match client.list_tools().await {
                Ok(tools) => {
                    for tool in tools {
                        let original_tool_name = tool.name.clone();

                        // Per-tool meta merged with the server-wide default.
                        let tool_meta = self.merged_tool_meta(config, &original_tool_name);

                        // The alias (if any) becomes the externally visible
                        // display name.
                        let display_name = tool_meta
                            .and_then(|meta| meta.alias)
                            .unwrap_or_else(|| original_tool_name.clone());

                        if display_name != original_tool_name {
                            let mut alias_map = self.alias_mapping.write().await;
                            alias_map.insert(
                                display_name.clone(),
                                (server_name.clone(), original_tool_name.clone()),
                            );
                        }

                        tool_sources
                            .entry(display_name.clone())
                            .or_default()
                            .push(server_name.clone());

                        // Honour the server's forbidden-tools list under both
                        // the display name and the original name.
                        let forbidden_tools = config.forbidden_tools();
                        if forbidden_tools.contains(&display_name)
                            || forbidden_tools.contains(&original_tool_name)
                        {
                            let mut disabled = self.disabled_tools.write().await;
                            disabled.insert(display_name);
                        }
                    }
                }
                Err(e) => {
                    // A failing server only loses its own tools; the refresh
                    // continues for the others.
                    error!("Error listing tools for {}: {}", server_name, e);
                }
            }
        }

        for (tool, sources) in tool_sources {
            if sources.len() > 1 {
                // NOTE(review): returning here leaves `tool_mapping` only
                // partially populated; callers treat the whole refresh as
                // failed, so this is acceptable but worth confirming.
                warn!("Tool '{}' exists in multiple servers: {:?}", tool, sources);
                let suggestion =
                    "Please use the 'alias' feature in ToolMeta to resolve conflicts. \
                     Each tool should have a unique name or alias across all servers.";
                return Err(ComputerError::InvalidConfiguration(format!(
                    "Tool '{}' exists in multiple servers: {:?}\n{}",
                    tool, sources, suggestion
                )));
            }
            let mut mapping = self.tool_mapping.write().await;
            mapping.insert(tool, sources[0].clone());
        }

        debug!("Tool mapping refreshed successfully");
        Ok(())
    }
443
444 pub async fn validate_tool_call(
446 &self,
447 tool_name: &str,
448 _parameters: &serde_json::Value,
449 ) -> Result<(ServerName, ToolName), ComputerError> {
450 let disabled = self.disabled_tools.read().await;
452 if disabled.contains(tool_name) {
453 return Err(ComputerError::PermissionError(format!(
454 "Tool '{}' is disabled by configuration",
455 tool_name
456 )));
457 }
458
459 let server_name = {
461 let mapping = self.tool_mapping.read().await;
462 mapping.get(tool_name).cloned()
463 };
464
465 let server_name = server_name.ok_or_else(|| {
466 ComputerError::InvalidConfiguration(format!(
467 "Tool '{}' not found in any active server",
468 tool_name
469 ))
470 })?;
471
472 let original_tool_name = {
474 let alias_map = self.alias_mapping.read().await;
475 if let Some((_, original)) = alias_map.get(tool_name) {
476 original.clone()
477 } else {
478 tool_name.to_string()
479 }
480 };
481
482 Ok((server_name, original_tool_name))
483 }
484
    /// Invoke `tool_name` on `server_name` with an optional timeout, then
    /// post-process the result: attach merged tool meta and, if the server
    /// configures a VRL script, attach the transformed result to the meta.
    ///
    /// # Errors
    /// `InvalidConfiguration` if the server is not active, `TimeoutError` on
    /// deadline expiry, `ProtocolError` when the call itself fails.
    pub async fn call_tool(
        &self,
        server_name: &str,
        tool_name: &str,
        parameters: serde_json::Value,
        timeout: Option<std::time::Duration>,
    ) -> Result<CallToolResult, ComputerError> {
        // Clone the Arc'd client out so no lock is held during the call.
        let client = {
            let clients = self.active_clients.read().await;
            clients
                .get(server_name)
                .ok_or_else(|| {
                    ComputerError::InvalidConfiguration(format!(
                        "Server '{}' for tool '{}' is not active",
                        server_name, tool_name
                    ))
                })?
                .clone()
        };

        // Enforce the caller-supplied deadline, if any.
        let result = if let Some(timeout) = timeout {
            tokio::time::timeout(timeout, client.call_tool(tool_name, parameters))
                .await
                .map_err(|_| ComputerError::TimeoutError("Tool execution timed out".to_string()))?
        } else {
            client.call_tool(tool_name, parameters).await
        };

        let mut result = result
            .map_err(|e| ComputerError::ProtocolError(format!("Tool execution failed: {}", e)))?;

        let config = {
            let configs = self.servers_config.read().await;
            configs.get(server_name).cloned()
        };

        if let Some(config) = config {
            // Attach the merged ToolMeta (if any) under the A2C_TOOL_META key.
            if let Some(tool_meta) = self.merged_tool_meta(&config, tool_name) {
                if result.meta.is_none() {
                    result.meta = Some(std::collections::HashMap::new());
                }
                if let Some(ref mut meta) = result.meta {
                    meta.insert(
                        A2C_TOOL_META.to_string(),
                        // ToolMeta is a plain serde-serializable struct;
                        // serialization is not expected to fail here.
                        serde_json::to_value(tool_meta).unwrap(),
                    );
                }
            }

            // Optional VRL post-processing of the raw result.
            if let Some(vrl_script) = config.vrl() {
                // NOTE(review): the real call parameters were moved into
                // `call_tool` above, so the VRL event always receives an
                // EMPTY `parameters` object — confirm whether the actual
                // arguments should be forwarded instead.
                let parameters = serde_json::json!({});

                // Build the VRL input event: the serialized result plus
                // `tool_name`/`parameters` fields.
                let mut event = serde_json::to_value(&result).unwrap_or_default();
                if let Value::Object(ref mut map) = event {
                    map.insert(
                        "tool_name".to_string(),
                        Value::String(tool_name.to_string()),
                    );
                    map.insert("parameters".to_string(), parameters);
                }

                let mut runtime = VrlRuntime::new();
                match runtime.run(vrl_script, event, "UTC") {
                    Ok(vrl_result) => {
                        // Attach the transformed event as a JSON string; the
                        // original result body is left untouched.
                        if result.meta.is_none() {
                            result.meta = Some(std::collections::HashMap::new());
                        }
                        if let Some(ref mut meta) = result.meta {
                            if let Ok(transformed_json) =
                                serde_json::to_string(&vrl_result.processed_event)
                            {
                                meta.insert(
                                    A2C_VRL_TRANSFORMED.to_string(),
                                    Value::String(transformed_json),
                                );
                            }
                        }
                        debug!(
                            "VRL转换成功 / VRL transformation succeeded for tool '{}'",
                            tool_name
                        );
                    }
                    Err(e) => {
                        // VRL failure is non-fatal: log and return the raw
                        // result unchanged.
                        warn!(
                            "VRL转换失败 / VRL transformation failed for tool '{}': {}. 原始结果将正常返回 / Original result will be returned normally.",
                            tool_name, e
                        );
                    }
                }
            }
        }

        Ok(result)
    }
590
591 pub async fn execute_tool(
593 &self,
594 tool_name: &str,
595 parameters: serde_json::Value,
596 timeout: Option<std::time::Duration>,
597 ) -> Result<CallToolResult, ComputerError> {
598 let (server_name, original_tool_name) =
599 self.validate_tool_call(tool_name, ¶meters).await?;
600 self.call_tool(&server_name, &original_tool_name, parameters, timeout)
601 .await
602 }
603
604 pub async fn get_server_status(&self) -> Vec<(String, bool, String)> {
606 let configs = self.servers_config.read().await;
607 let clients = self.active_clients.read().await;
608
609 configs
610 .keys()
611 .map(|name| {
612 let is_active = clients.contains_key(name);
613 let state = if is_active {
614 clients
615 .get(name)
616 .map(|c| c.state().to_string())
617 .unwrap_or_else(|| "unknown".to_string())
618 } else {
619 "pending".to_string()
620 };
621 (name.clone(), is_active, state)
622 })
623 .collect()
624 }
625
    /// Serialize every server's configuration and runtime status into one
    /// JSON object keyed by server name.
    ///
    /// Each entry carries: transport `type`, `status`, `is_active`,
    /// `disabled`, `forbidden_tools`, `tool_meta`, optional
    /// `default_tool_meta`, optional `vrl`, and the transport-specific
    /// `server_parameters`.
    pub async fn get_server_configs(&self) -> serde_json::Value {
        let configs = self.servers_config.read().await;
        let clients = self.active_clients.read().await;

        let mut result = serde_json::Map::new();

        for (name, config) in configs.iter() {
            // Runtime status mirrors `get_server_status`.
            let is_active = clients.contains_key(name);
            let state = if is_active {
                clients
                    .get(name)
                    .map(|c| c.state().to_string())
                    .unwrap_or_else(|| "unknown".to_string())
            } else {
                "pending".to_string()
            };

            let mut server_info = serde_json::Map::new();

            // Transport kind, derived from the config variant.
            let server_type = match config {
                MCPServerConfig::Stdio(_) => "stdio",
                MCPServerConfig::Sse(_) => "sse",
                MCPServerConfig::Http(_) => "http",
            };
            server_info.insert(
                "type".to_string(),
                serde_json::Value::String(server_type.to_string()),
            );

            server_info.insert("status".to_string(), serde_json::Value::String(state));
            server_info.insert("is_active".to_string(), serde_json::Value::Bool(is_active));
            server_info.insert(
                "disabled".to_string(),
                serde_json::Value::Bool(config.disabled()),
            );

            let forbidden_tools: Vec<serde_json::Value> = config
                .forbidden_tools()
                .iter()
                .map(|t| serde_json::Value::String(t.clone()))
                .collect();
            server_info.insert(
                "forbidden_tools".to_string(),
                serde_json::Value::Array(forbidden_tools),
            );

            // Serialization failures silently omit the field rather than
            // failing the whole snapshot.
            if let Ok(tool_meta_json) = serde_json::to_value(config.tool_meta()) {
                server_info.insert("tool_meta".to_string(), tool_meta_json);
            }

            if let Some(default_meta) = config.default_tool_meta() {
                if let Ok(default_meta_json) = serde_json::to_value(default_meta) {
                    server_info.insert("default_tool_meta".to_string(), default_meta_json);
                }
            }

            if let Some(vrl) = config.vrl() {
                server_info.insert(
                    "vrl".to_string(),
                    serde_json::Value::String(vrl.to_string()),
                );
            }

            // Transport-specific connection parameters.
            match config {
                MCPServerConfig::Stdio(stdio_config) => {
                    if let Ok(params_json) = serde_json::to_value(&stdio_config.server_parameters) {
                        server_info.insert("server_parameters".to_string(), params_json);
                    }
                }
                MCPServerConfig::Sse(sse_config) => {
                    if let Ok(params_json) = serde_json::to_value(&sse_config.server_parameters) {
                        server_info.insert("server_parameters".to_string(), params_json);
                    }
                }
                MCPServerConfig::Http(http_config) => {
                    if let Ok(params_json) = serde_json::to_value(&http_config.server_parameters) {
                        server_info.insert("server_parameters".to_string(), params_json);
                    }
                }
            }

            result.insert(name.clone(), serde_json::Value::Object(server_info));
        }

        serde_json::Value::Object(result)
    }
724
725 pub async fn list_available_tools(&self) -> Vec<Tool> {
727 let mut tools = Vec::new();
728 let mapping = self.tool_mapping.read().await;
729 let alias_map = self.alias_mapping.read().await;
730
731 for (display_name, server_name) in mapping.iter() {
732 let client = {
733 let clients = self.active_clients.read().await;
734 clients.get(server_name).cloned()
735 };
736
737 if let Some(client) = client {
738 let original_name = alias_map
740 .get(display_name)
741 .map(|(_, original)| original.clone())
742 .unwrap_or_else(|| display_name.clone());
743
744 if let Ok(tool_list) = client.list_tools().await {
746 if let Some(tool) = tool_list.into_iter().find(|t| t.name == original_name) {
747 let mut display_tool = tool;
749 display_tool.name = display_name.clone();
750 tools.push(display_tool);
751 }
752 }
753 }
754 }
755
756 tools
757 }
758
759 fn merged_tool_meta(&self, config: &MCPServerConfig, tool_name: &str) -> Option<ToolMeta> {
761 let specific = config.tool_meta().get(tool_name);
762 let default = config.default_tool_meta();
763
764 match (specific, default) {
765 (None, None) => None,
766 (Some(s), None) => Some(s.clone()),
767 (None, Some(d)) => Some(d.clone()),
768 (Some(s), Some(d)) => {
769 let mut merged = d.clone();
771 if s.auto_apply.is_some() {
772 merged.auto_apply = s.auto_apply;
773 }
774 if s.alias.is_some() {
775 merged.alias = s.alias.clone();
776 }
777 if s.tags.is_some() {
778 merged.tags = s.tags.clone();
779 }
780 if s.ret_object_mapper.is_some() {
781 merged.ret_object_mapper = s.ret_object_mapper.clone();
782 }
783 Some(merged)
784 }
785 }
786 }
787
    /// Start servers automatically when they are added or updated.
    pub async fn enable_auto_connect(&self) {
        *self.auto_connect.write().await = true;
    }
792
    /// Stop starting servers automatically on add/update.
    pub async fn disable_auto_connect(&self) {
        *self.auto_connect.write().await = false;
    }
797
    /// Let the health monitor reconnect unhealthy servers automatically.
    pub async fn enable_auto_reconnect(&self) {
        *self.auto_reconnect.write().await = true;
    }
802
    /// Stop the health monitor from reconnecting unhealthy servers.
    pub async fn disable_auto_reconnect(&self) {
        *self.auto_reconnect.write().await = false;
    }
807
    /// Replace the health-check settings; the monitor picks them up on its
    /// next loop iteration.
    pub async fn set_health_check_config(&self, config: HealthCheckConfig) {
        *self.health_check_config.write().await = config;
    }
812
    /// Return a copy of the current health-check settings.
    pub async fn get_health_check_config(&self) -> HealthCheckConfig {
        self.health_check_config.read().await.clone()
    }
817
    /// Replace the reconnect backoff/retry policy.
    ///
    /// Note: this does NOT update the separate `auto_reconnect` flag.
    pub async fn set_reconnect_policy(&self, policy: ReconnectPolicy) {
        *self.reconnect_policy.write().await = policy;
    }
822
    /// Return a copy of the current reconnect policy.
    pub async fn get_reconnect_policy(&self) -> ReconnectPolicy {
        self.reconnect_policy.read().await.clone()
    }
827
828 pub async fn start_health_monitor(&self) {
832 self.stop_health_monitor().await;
834
835 let health_config = self.health_check_config.clone();
836 let reconnect_policy = self.reconnect_policy.clone();
837 let active_clients = self.active_clients.clone();
838 let _servers_config = self.servers_config.clone();
839 let retry_counts = self.retry_counts.clone();
840 let auto_reconnect = self.auto_reconnect.clone();
841
842 let handle = tokio::spawn(async move {
843 loop {
844 let config = health_config.read().await.clone();
845 if !config.enabled {
846 tokio::time::sleep(std::time::Duration::from_secs(10)).await;
849 continue;
850 }
851
852 let clients: Vec<(String, StdArc<dyn MCPClientProtocol>)> = {
854 let clients_guard = active_clients.read().await;
855 clients_guard
856 .iter()
857 .map(|(k, v)| (k.clone(), v.clone()))
858 .collect()
859 };
860
861 for (server_name, client) in clients {
863 let check_result = tokio::time::timeout(
864 std::time::Duration::from_secs(config.timeout_secs),
865 client.health_check(),
866 )
867 .await;
868
869 let is_healthy = match check_result {
870 Ok(result) => result.is_healthy,
871 Err(_) => {
872 warn!("Health check timed out for server: {}", server_name);
873 false
874 }
875 };
876
877 if !is_healthy {
878 warn!("Server {} is unhealthy", server_name);
879
880 let should_reconnect = *auto_reconnect.read().await;
882 if !should_reconnect {
883 continue;
884 }
885
886 let policy = reconnect_policy.read().await.clone();
887 let mut retries = retry_counts.write().await;
888 let retry_count = retries.entry(server_name.clone()).or_insert(0);
889
890 if policy.should_retry(*retry_count) {
891 let delay = policy.calculate_delay(*retry_count);
892 info!(
893 "Attempting to reconnect {} (retry {}/{}), delay {:?}",
894 server_name,
895 *retry_count + 1,
896 if policy.max_retries == 0 {
897 "∞".to_string()
898 } else {
899 policy.max_retries.to_string()
900 },
901 delay
902 );
903
904 tokio::time::sleep(delay).await;
905
906 if let Err(e) = client.disconnect().await {
908 warn!("Failed to disconnect {}: {}", server_name, e);
909 }
910
911 match client.connect().await {
912 Ok(_) => {
913 info!("Successfully reconnected to {}", server_name);
914 *retry_count = 0;
916 }
917 Err(e) => {
918 error!("Failed to reconnect to {}: {}", server_name, e);
919 *retry_count += 1;
920 }
921 }
922 } else {
923 error!(
924 "Max retries ({}) reached for server {}. Giving up.",
925 policy.max_retries, server_name
926 );
927 }
929 } else {
930 let mut retries = retry_counts.write().await;
932 retries.remove(&server_name);
933 debug!("Server {} is healthy", server_name);
934 }
935 }
936
937 tokio::time::sleep(std::time::Duration::from_secs(config.interval_secs)).await;
939 }
940 });
941
942 *self.health_monitor_handle.write().await = Some(handle);
943 info!("Health monitor started");
944 }
945
    /// Abort the background health-monitor task, if one is running.
    pub async fn stop_health_monitor(&self) {
        if let Some(handle) = self.health_monitor_handle.write().await.take() {
            // abort() is fire-and-forget: the task is cancelled at its next
            // await point and the handle is not joined.
            handle.abort();
            info!("Health monitor stopped");
        }
    }
953
954 pub async fn check_server_health(&self, server_name: &str) -> Option<HealthCheckResult> {
956 let clients = self.active_clients.read().await;
957 if let Some(client) = clients.get(server_name) {
958 let config = self.health_check_config.read().await;
959 let result = tokio::time::timeout(
960 std::time::Duration::from_secs(config.timeout_secs),
961 client.health_check(),
962 )
963 .await;
964
965 match result {
966 Ok(health_result) => Some(health_result),
967 Err(_) => Some(HealthCheckResult {
968 is_healthy: false,
969 checked_at: std::time::Instant::now(),
970 error: Some("Health check timed out".to_string()),
971 response_time_ms: None,
972 }),
973 }
974 } else {
975 None
976 }
977 }
978
979 pub async fn check_all_health(&self) -> HashMap<String, HealthCheckResult> {
981 let mut results = HashMap::new();
982 let clients: Vec<(String, StdArc<dyn MCPClientProtocol>)> = {
983 let clients_guard = self.active_clients.read().await;
984 clients_guard
985 .iter()
986 .map(|(k, v)| (k.clone(), v.clone()))
987 .collect()
988 };
989
990 let config = self.health_check_config.read().await.clone();
991
992 for (server_name, client) in clients {
993 let result = tokio::time::timeout(
994 std::time::Duration::from_secs(config.timeout_secs),
995 client.health_check(),
996 )
997 .await;
998
999 let health_result = match result {
1000 Ok(hr) => hr,
1001 Err(_) => HealthCheckResult {
1002 is_healthy: false,
1003 checked_at: std::time::Instant::now(),
1004 error: Some("Health check timed out".to_string()),
1005 response_time_ms: None,
1006 },
1007 };
1008
1009 results.insert(server_name, health_result);
1010 }
1011
1012 results
1013 }
1014
    /// Snapshot the per-server reconnect-retry counters.
    pub async fn get_retry_counts(&self) -> HashMap<String, u32> {
        self.retry_counts.read().await.clone()
    }
1019
    /// Clear the reconnect-retry counter for one server.
    pub async fn reset_retry_count(&self, server_name: &str) {
        self.retry_counts.write().await.remove(server_name);
    }
1024
    /// Clear every server's reconnect-retry counter.
    pub async fn reset_all_retry_counts(&self) {
        self.retry_counts.write().await.clear();
    }
1029}
1030
/// `Default` mirrors `MCPServerManager::new()`.
impl Default for MCPServerManager {
    fn default() -> Self {
        Self::new()
    }
}
1036
1037#[cfg(test)]
1038mod tests {
1039 use super::*;
1040 use std::collections::HashMap;
1041 use tokio::time::{sleep, Duration};
1042
    // A fresh manager exposes no configured servers.
    #[tokio::test]
    async fn test_manager_creation() {
        let manager = MCPServerManager::new();
        let status = manager.get_server_status().await;
        assert!(status.is_empty());
    }
1049
    // Initializing with two configs (one disabled) registers both without
    // starting either — `initialize` only stores configuration.
    #[tokio::test]
    async fn test_manager_initialization() {
        let manager = MCPServerManager::new();

        let configs = vec![
            MCPServerConfig::Stdio(StdioServerConfig {
                name: "test_stdio".to_string(),
                disabled: false,
                forbidden_tools: vec![],
                tool_meta: HashMap::new(),
                default_tool_meta: None,
                vrl: None,
                server_parameters: StdioServerParameters {
                    command: "echo".to_string(),
                    args: vec!["hello".to_string()],
                    env: HashMap::new(),
                    cwd: None,
                },
            }),
            MCPServerConfig::Http(HttpServerConfig {
                name: "test_http".to_string(),
                disabled: true,
                forbidden_tools: vec![],
                tool_meta: HashMap::new(),
                default_tool_meta: None,
                vrl: None,
                server_parameters: HttpServerParameters {
                    url: "http://localhost:8080".to_string(),
                    headers: HashMap::new(),
                },
            }),
        ];

        let result = manager.initialize(configs).await;
        assert!(result.is_ok());

        let status = manager.get_server_status().await;
        assert_eq!(status.len(), 2);

        // Neither server is active: nothing was started.
        let stdio_status = status
            .iter()
            .find(|(name, _, _)| name == "test_stdio")
            .unwrap();
        assert!(!stdio_status.1);
        let http_status = status
            .iter()
            .find(|(name, _, _)| name == "test_http")
            .unwrap();
        assert!(!http_status.1);
    }
1107
    // Adding a server registers its configuration without starting it
    // (auto_connect defaults to off).
    #[tokio::test]
    async fn test_add_server() {
        let manager = MCPServerManager::new();

        let config = MCPServerConfig::Stdio(StdioServerConfig {
            name: "test_server".to_string(),
            disabled: false,
            forbidden_tools: vec![],
            tool_meta: HashMap::new(),
            default_tool_meta: None,
            vrl: None,
            server_parameters: StdioServerParameters {
                command: "echo".to_string(),
                args: vec![],
                env: HashMap::new(),
                cwd: None,
            },
        });

        let result = manager.add_or_update_server(config).await;
        assert!(result.is_ok());

        let status = manager.get_server_status().await;
        assert_eq!(status.len(), 1);
        assert_eq!(status[0].0, "test_server");
    }
1136
    // Removing a previously added server leaves no configuration behind.
    #[tokio::test]
    async fn test_remove_server() {
        let manager = MCPServerManager::new();

        let config = MCPServerConfig::Stdio(StdioServerConfig {
            name: "test_server".to_string(),
            disabled: false,
            forbidden_tools: vec![],
            tool_meta: HashMap::new(),
            default_tool_meta: None,
            vrl: None,
            server_parameters: StdioServerParameters {
                command: "echo".to_string(),
                args: vec![],
                env: HashMap::new(),
                cwd: None,
            },
        });

        manager.add_or_update_server(config).await.unwrap();

        let result = manager.remove_server("test_server").await;
        assert!(result.is_ok());

        let status = manager.get_server_status().await;
        assert!(status.is_empty());
    }
1167
    // Two servers are registered and started; `echo` is not a real MCP
    // server, so start_all's result is intentionally ignored.
    // NOTE(review): this test asserts nothing after the sleep — presumably
    // it only exercises the no-crash path; confirm intent.
    #[tokio::test]
    async fn test_tool_conflict_detection() {
        let manager = MCPServerManager::new();

        let configs = vec![
            MCPServerConfig::Stdio(StdioServerConfig {
                name: "server1".to_string(),
                disabled: false,
                forbidden_tools: vec![],
                tool_meta: HashMap::new(),
                default_tool_meta: None,
                vrl: None,
                server_parameters: StdioServerParameters {
                    command: "echo".to_string(),
                    args: vec!["server1".to_string()],
                    env: HashMap::new(),
                    cwd: None,
                },
            }),
            MCPServerConfig::Stdio(StdioServerConfig {
                name: "server2".to_string(),
                disabled: false,
                forbidden_tools: vec![],
                tool_meta: HashMap::new(),
                default_tool_meta: None,
                vrl: None,
                server_parameters: StdioServerParameters {
                    command: "echo".to_string(),
                    args: vec!["server2".to_string()],
                    env: HashMap::new(),
                    cwd: None,
                },
            }),
        ];

        let result = manager.initialize(configs).await;
        assert!(result.is_ok());

        let _result = manager.start_all().await;
        sleep(Duration::from_millis(200)).await;
    }
1218
    // Default health-check settings, then a round-trip through the setter.
    #[tokio::test]
    async fn test_health_check_config() {
        let manager = MCPServerManager::new();

        let config = manager.get_health_check_config().await;
        assert_eq!(config.interval_secs, 30);
        assert_eq!(config.timeout_secs, 5);
        assert!(config.enabled);

        let new_config = HealthCheckConfig {
            interval_secs: 60,
            timeout_secs: 10,
            enabled: false,
        };
        manager.set_health_check_config(new_config.clone()).await;

        let updated = manager.get_health_check_config().await;
        assert_eq!(updated.interval_secs, 60);
        assert_eq!(updated.timeout_secs, 10);
        assert!(!updated.enabled);
    }
1242
    // Default policy values, exponential backoff, retry cutoff, and the
    // max_retries == 0 "retry forever" convention.
    #[tokio::test]
    async fn test_reconnect_policy() {
        let manager = MCPServerManager::new();

        let policy = manager.get_reconnect_policy().await;
        assert!(policy.enabled);
        assert_eq!(policy.max_retries, 5);
        assert_eq!(policy.initial_delay_ms, 1000);
        assert_eq!(policy.max_delay_ms, 30000);
        assert_eq!(policy.backoff_factor, 2.0);

        // Delay doubles per attempt: 1s, 2s, 4s, 8s.
        assert_eq!(policy.calculate_delay(0).as_millis(), 1000);
        assert_eq!(policy.calculate_delay(1).as_millis(), 2000);
        assert_eq!(policy.calculate_delay(2).as_millis(), 4000);
        assert_eq!(policy.calculate_delay(3).as_millis(), 8000);

        assert!(policy.should_retry(0));
        assert!(policy.should_retry(4));
        assert!(!policy.should_retry(5));

        // max_retries == 0 means unlimited retries.
        let infinite_policy = ReconnectPolicy {
            enabled: true,
            max_retries: 0,
            ..Default::default()
        };
        assert!(infinite_policy.should_retry(100));
    }
1274
    // Retry counters: seeded directly into the private map, then read,
    // reset individually, and reset in bulk.
    #[tokio::test]
    async fn test_retry_counts() {
        let manager = MCPServerManager::new();

        let counts = manager.get_retry_counts().await;
        assert!(counts.is_empty());

        // Seed counters directly (white-box access to the private field).
        {
            manager
                .retry_counts
                .write()
                .await
                .insert("server1".to_string(), 3);
            manager
                .retry_counts
                .write()
                .await
                .insert("server2".to_string(), 5);
        }

        let counts = manager.get_retry_counts().await;
        assert_eq!(counts.get("server1"), Some(&3));
        assert_eq!(counts.get("server2"), Some(&5));

        manager.reset_retry_count("server1").await;
        let counts = manager.get_retry_counts().await;
        assert!(!counts.contains_key("server1"));
        assert_eq!(counts.get("server2"), Some(&5));

        manager.reset_all_retry_counts().await;
        let counts = manager.get_retry_counts().await;
        assert!(counts.is_empty());
    }
1312
    // `with_config` stores the supplied health-check and reconnect settings.
    #[tokio::test]
    async fn test_manager_with_custom_config() {
        let health_config = HealthCheckConfig {
            interval_secs: 15,
            timeout_secs: 3,
            enabled: true,
        };
        let reconnect_policy = ReconnectPolicy {
            enabled: true,
            max_retries: 10,
            initial_delay_ms: 500,
            max_delay_ms: 60000,
            backoff_factor: 1.5,
        };

        let manager =
            MCPServerManager::with_config(health_config.clone(), reconnect_policy.clone());

        let got_health = manager.get_health_check_config().await;
        assert_eq!(got_health.interval_secs, 15);
        assert_eq!(got_health.timeout_secs, 3);

        let got_reconnect = manager.get_reconnect_policy().await;
        assert_eq!(got_reconnect.max_retries, 10);
        assert_eq!(got_reconnect.initial_delay_ms, 500);
    }
1339}