1use super::model::*;
11use super::utils::client_factory;
12use super::vrl_runtime::VrlRuntime;
13use crate::errors::ComputerError;
14use serde_json::Value;
15use std::collections::{HashMap, HashSet};
16use std::sync::Arc;
17use std::sync::Arc as StdArc;
18use tokio::sync::{watch, RwLock};
19use tracing::{debug, error, info, warn};
20
21#[derive(Debug, thiserror::Error)]
23#[error("Tool '{tool_name}' exists in multiple servers: {servers:?}")]
24pub struct ToolNameDuplicatedError {
25 pub tool_name: String,
26 pub servers: Vec<String>,
27}
28
29pub struct MCPServerManager {
31 servers_config: Arc<RwLock<HashMap<ServerName, MCPServerConfig>>>,
33 active_clients: Arc<RwLock<HashMap<ServerName, StdArc<dyn MCPClientProtocol>>>>,
35 tool_mapping: Arc<RwLock<HashMap<ToolName, ServerName>>>,
37 alias_mapping: Arc<RwLock<HashMap<String, (ServerName, ToolName)>>>,
39 disabled_tools: Arc<RwLock<HashSet<ToolName>>>,
41 auto_reconnect: Arc<RwLock<bool>>,
43 auto_connect: Arc<RwLock<bool>>,
45 state_notifier: watch::Sender<ManagerState>,
47 health_check_config: Arc<RwLock<HealthCheckConfig>>,
49 reconnect_policy: Arc<RwLock<ReconnectPolicy>>,
51 health_monitor_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
53 retry_counts: Arc<RwLock<HashMap<ServerName, u32>>>,
55}
56
/// Lifecycle state published through the manager's watch channel.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ManagerState {
    /// Freshly constructed, or after `close`.
    Uninitialized,
    /// Configs loaded (after `initialize`) or all clients stopped (`stop_all`).
    Initialized,
    /// Set by `start_all` once every enabled server has started.
    Running,
    /// Error state; not emitted by the code in this module.
    Error,
}
69
70impl MCPServerManager {
71 pub fn new() -> Self {
73 let (state_tx, _) = watch::channel(ManagerState::Uninitialized);
74
75 Self {
76 servers_config: Arc::new(RwLock::new(HashMap::new())),
77 active_clients: Arc::new(RwLock::new(HashMap::new())),
78 tool_mapping: Arc::new(RwLock::new(HashMap::new())),
79 alias_mapping: Arc::new(RwLock::new(HashMap::new())),
80 disabled_tools: Arc::new(RwLock::new(HashSet::new())),
81 auto_reconnect: Arc::new(RwLock::new(true)),
82 auto_connect: Arc::new(RwLock::new(false)),
83 state_notifier: state_tx,
84 health_check_config: Arc::new(RwLock::new(HealthCheckConfig::default())),
85 reconnect_policy: Arc::new(RwLock::new(ReconnectPolicy::default())),
86 health_monitor_handle: Arc::new(RwLock::new(None)),
87 retry_counts: Arc::new(RwLock::new(HashMap::new())),
88 }
89 }
90
91 pub fn with_config(
93 health_check_config: HealthCheckConfig,
94 reconnect_policy: ReconnectPolicy,
95 ) -> Self {
96 let (state_tx, _) = watch::channel(ManagerState::Uninitialized);
97
98 Self {
99 servers_config: Arc::new(RwLock::new(HashMap::new())),
100 active_clients: Arc::new(RwLock::new(HashMap::new())),
101 tool_mapping: Arc::new(RwLock::new(HashMap::new())),
102 alias_mapping: Arc::new(RwLock::new(HashMap::new())),
103 disabled_tools: Arc::new(RwLock::new(HashSet::new())),
104 auto_reconnect: Arc::new(RwLock::new(reconnect_policy.enabled)),
105 auto_connect: Arc::new(RwLock::new(false)),
106 state_notifier: state_tx,
107 health_check_config: Arc::new(RwLock::new(health_check_config)),
108 reconnect_policy: Arc::new(RwLock::new(reconnect_policy)),
109 health_monitor_handle: Arc::new(RwLock::new(None)),
110 retry_counts: Arc::new(RwLock::new(HashMap::new())),
111 }
112 }
113
114 pub fn get_state_notifier(&self) -> watch::Receiver<ManagerState> {
116 self.state_notifier.subscribe()
117 }
118
119 async fn update_state(&self, state: ManagerState) {
121 let _ = self.state_notifier.send(state);
122 }
123
124 pub async fn initialize(&self, servers: Vec<MCPServerConfig>) -> Result<(), ComputerError> {
126 self.stop_all().await?;
128
129 self.clear_all().await;
131
132 {
134 let mut configs = self.servers_config.write().await;
135 for server in servers {
136 configs.insert(server.name().to_string(), server);
137 }
138 }
139
140 self.refresh_tool_mapping().await?;
142
143 self.update_state(ManagerState::Initialized).await;
145
146 info!("Manager initialized successfully");
147 Ok(())
148 }
149
150 pub async fn add_or_update_server(&self, config: MCPServerConfig) -> Result<(), ComputerError> {
152 let server_name = config.name().to_string();
153
154 let is_active = {
156 let clients = self.active_clients.read().await;
157 clients.contains_key(&server_name)
158 };
159
160 if is_active {
161 let auto_reconnect = *self.auto_reconnect.read().await;
162 if auto_reconnect {
163 self.restart_server(&server_name).await?;
165 } else {
166 return Err(ComputerError::InvalidConfiguration(format!(
167 "Server {} is active. Stop it before updating config",
168 server_name
169 )));
170 }
171 }
172
173 {
175 let mut configs = self.servers_config.write().await;
176 configs.insert(server_name.clone(), config);
177 }
178
179 let auto_connect = *self.auto_connect.read().await;
181 if auto_connect && !is_active {
182 self.start_client(&server_name).await?;
183 }
184
185 self.refresh_tool_mapping().await?;
187
188 Ok(())
189 }
190
191 pub async fn remove_server(&self, server_name: &str) -> Result<(), ComputerError> {
193 self.stop_client(server_name).await?;
195
196 {
198 let mut configs = self.servers_config.write().await;
199 configs.remove(server_name);
200 }
201
202 self.refresh_tool_mapping().await?;
204
205 Ok(())
206 }
207
208 pub async fn start_all(&self) -> Result<(), ComputerError> {
210 let configs = self.servers_config.read().await;
211 let server_names: Vec<String> = configs
212 .iter()
213 .filter(|(_, config)| !config.disabled())
214 .map(|(name, _)| name.clone())
215 .collect();
216
217 drop(configs);
218
219 for server_name in server_names {
220 self.start_client(&server_name).await?;
221 }
222
223 self.update_state(ManagerState::Running).await;
225
226 info!("All servers started successfully");
227 Ok(())
228 }
229
230 pub async fn start_client(&self, server_name: &str) -> Result<(), ComputerError> {
232 let config = {
234 let configs = self.servers_config.read().await;
235 configs.get(server_name).cloned()
236 };
237
238 let config = config.ok_or_else(|| {
239 ComputerError::InvalidConfiguration(format!("Unknown server: {}", server_name))
240 })?;
241
242 if config.disabled() {
243 return Err(ComputerError::InvalidConfiguration(format!(
244 "Cannot start disabled server: {}",
245 server_name
246 )));
247 }
248
249 {
251 let clients = self.active_clients.read().await;
252 if clients.contains_key(server_name) {
253 return Ok(()); }
255 }
256
257 let client = client_factory(config);
259
260 client.connect().await.map_err(|e| {
262 ComputerError::ConnectionError(format!("Failed to connect to {}: {}", server_name, e))
263 })?;
264
265 {
267 let mut clients = self.active_clients.write().await;
268 clients.insert(server_name.to_string(), client);
269 }
270
271 self.refresh_tool_mapping().await?;
273
274 info!("Client {} started successfully", server_name);
275 Ok(())
276 }
277
278 pub async fn stop_client(&self, server_name: &str) -> Result<(), ComputerError> {
280 let mut client = {
282 let mut clients = self.active_clients.write().await;
283 clients.remove(server_name)
284 };
285
286 if let Some(ref mut c) = client {
288 c.disconnect().await.map_err(|e| {
289 ComputerError::ConnectionError(format!(
290 "Failed to disconnect from {}: {}",
291 server_name, e
292 ))
293 })?;
294 }
295
296 self.refresh_tool_mapping().await?;
298
299 info!("Client {} stopped successfully", server_name);
300 Ok(())
301 }
302
303 async fn restart_server(&self, server_name: &str) -> Result<(), ComputerError> {
305 self.stop_client(server_name).await?;
306
307 let enabled = {
309 let configs = self.servers_config.read().await;
310 configs
311 .get(server_name)
312 .map(|c| !c.disabled())
313 .unwrap_or(false)
314 };
315
316 if enabled {
317 self.start_client(server_name).await?;
318 }
319
320 Ok(())
321 }
322
323 pub async fn stop_all(&self) -> Result<(), ComputerError> {
325 let server_names: Vec<String> = {
326 let clients = self.active_clients.read().await;
327 clients.keys().cloned().collect()
328 };
329
330 for server_name in server_names {
331 self.stop_client(&server_name).await?;
332 }
333
334 self.update_state(ManagerState::Initialized).await;
336
337 info!("All servers stopped successfully");
338 Ok(())
339 }
340
341 async fn clear_all(&self) {
343 self.servers_config.write().await.clear();
344 self.active_clients.write().await.clear();
345 self.tool_mapping.write().await.clear();
346 self.alias_mapping.write().await.clear();
347 self.disabled_tools.write().await.clear();
348 }
349
350 pub async fn close(&self) -> Result<(), ComputerError> {
352 self.stop_all().await?;
353 self.clear_all().await;
354 self.update_state(ManagerState::Uninitialized).await;
355 info!("Manager closed successfully");
356 Ok(())
357 }
358
359 async fn refresh_tool_mapping(&self) -> Result<(), ComputerError> {
361 self.tool_mapping.write().await.clear();
363 self.alias_mapping.write().await.clear();
364 self.disabled_tools.write().await.clear();
365
366 let mut tool_sources: HashMap<ToolName, Vec<ServerName>> = HashMap::new();
368
369 let clients = self.active_clients.read().await;
371 let configs = self.servers_config.read().await;
372
373 for (server_name, client) in clients.iter() {
374 let config = match configs.get(server_name) {
375 Some(c) => c,
376 None => continue,
377 };
378
379 match client.list_tools().await {
381 Ok(tools) => {
382 for tool in tools {
383 let original_tool_name = tool.name.clone();
384
385 let tool_meta = self.merged_tool_meta(config, &original_tool_name);
387
388 let display_name = tool_meta
390 .and_then(|meta| meta.alias)
391 .unwrap_or_else(|| original_tool_name.to_string());
392
393 let original_tool_name_str = original_tool_name.to_string();
395 if display_name != original_tool_name_str {
396 let mut alias_map = self.alias_mapping.write().await;
397 alias_map.insert(
398 display_name.clone(),
399 (server_name.clone(), original_tool_name_str.clone()),
400 );
401 }
402
403 tool_sources
405 .entry(display_name.clone())
406 .or_default()
407 .push(server_name.clone());
408
409 let forbidden_tools = config.forbidden_tools();
411 if forbidden_tools.contains(&display_name)
412 || forbidden_tools.contains(&original_tool_name_str)
413 {
414 let mut disabled = self.disabled_tools.write().await;
415 disabled.insert(display_name);
416 }
417 }
418 }
419 Err(e) => {
420 error!("Error listing tools for {}: {}", server_name, e);
421 }
422 }
423 }
424
425 for (tool, sources) in tool_sources {
427 if sources.len() > 1 {
428 warn!("Tool '{}' exists in multiple servers: {:?}", tool, sources);
429 let suggestion =
430 "Please use the 'alias' feature in ToolMeta to resolve conflicts. \
431 Each tool should have a unique name or alias across all servers.";
432 return Err(ComputerError::InvalidConfiguration(format!(
433 "Tool '{}' exists in multiple servers: {:?}\n{}",
434 tool, sources, suggestion
435 )));
436 }
437 let mut mapping = self.tool_mapping.write().await;
438 mapping.insert(tool, sources[0].clone());
439 }
440
441 debug!("Tool mapping refreshed successfully");
442 Ok(())
443 }
444
445 pub async fn validate_tool_call(
447 &self,
448 tool_name: &str,
449 _parameters: &serde_json::Value,
450 ) -> Result<(ServerName, ToolName), ComputerError> {
451 let disabled = self.disabled_tools.read().await;
453 if disabled.contains(tool_name) {
454 return Err(ComputerError::PermissionError(format!(
455 "Tool '{}' is disabled by configuration",
456 tool_name
457 )));
458 }
459
460 let server_name = {
462 let mapping = self.tool_mapping.read().await;
463 mapping.get(tool_name).cloned()
464 };
465
466 let server_name = server_name.ok_or_else(|| {
467 ComputerError::InvalidConfiguration(format!(
468 "Tool '{}' not found in any active server",
469 tool_name
470 ))
471 })?;
472
473 let original_tool_name = {
475 let alias_map = self.alias_mapping.read().await;
476 if let Some((_, original)) = alias_map.get(tool_name) {
477 original.clone()
478 } else {
479 tool_name.to_string()
480 }
481 };
482
483 Ok((server_name, original_tool_name))
484 }
485
486 pub async fn call_tool(
488 &self,
489 server_name: &str,
490 tool_name: &str,
491 parameters: serde_json::Value,
492 timeout: Option<std::time::Duration>,
493 ) -> Result<CallToolResult, ComputerError> {
494 let client = {
496 let clients = self.active_clients.read().await;
497 clients
498 .get(server_name)
499 .ok_or_else(|| {
500 ComputerError::InvalidConfiguration(format!(
501 "Server '{}' for tool '{}' is not active",
502 server_name, tool_name
503 ))
504 })?
505 .clone()
506 };
507
508 let result = if let Some(timeout) = timeout {
510 tokio::time::timeout(timeout, client.call_tool(tool_name, parameters))
511 .await
512 .map_err(|_| ComputerError::TimeoutError("Tool execution timed out".to_string()))?
513 } else {
514 client.call_tool(tool_name, parameters).await
515 };
516
517 let mut result = result
518 .map_err(|e| ComputerError::ProtocolError(format!("Tool execution failed: {}", e)))?;
519
520 let config = {
522 let configs = self.servers_config.read().await;
523 configs.get(server_name).cloned()
524 };
525
526 if let Some(config) = config {
527 if let Some(tool_meta) = self.merged_tool_meta(&config, tool_name) {
528 if result.meta.is_none() {
529 result.meta = Some(rmcp::model::Meta::new());
530 }
531 if let Some(ref mut meta) = result.meta {
532 meta.insert(
533 A2C_TOOL_META.to_string(),
534 serde_json::to_value(tool_meta).unwrap(),
535 );
536 }
537 }
538
539 if let Some(vrl_script) = config.vrl() {
541 let parameters = serde_json::json!({});
544
545 let mut event = serde_json::to_value(&result).unwrap_or_default();
547 if let Value::Object(ref mut map) = event {
548 map.insert(
549 "tool_name".to_string(),
550 Value::String(tool_name.to_string()),
551 );
552 map.insert("parameters".to_string(), parameters);
553 }
554
555 let mut runtime = VrlRuntime::new();
557 match runtime.run(vrl_script, event, "UTC") {
558 Ok(vrl_result) => {
559 if result.meta.is_none() {
561 result.meta = Some(rmcp::model::Meta::new());
562 }
563 if let Some(ref mut meta) = result.meta {
564 if let Ok(transformed_json) =
566 serde_json::to_string(&vrl_result.processed_event)
567 {
568 meta.insert(
569 A2C_VRL_TRANSFORMED.to_string(),
570 Value::String(transformed_json),
571 );
572 }
573 }
574 debug!(
575 "VRL转换成功 / VRL transformation succeeded for tool '{}'",
576 tool_name
577 );
578 }
579 Err(e) => {
580 warn!(
581 "VRL转换失败 / VRL transformation failed for tool '{}': {}. 原始结果将正常返回 / Original result will be returned normally.",
582 tool_name, e
583 );
584 }
585 }
586 }
587 }
588
589 Ok(result)
590 }
591
592 pub async fn execute_tool(
594 &self,
595 tool_name: &str,
596 parameters: serde_json::Value,
597 timeout: Option<std::time::Duration>,
598 ) -> Result<CallToolResult, ComputerError> {
599 let (server_name, original_tool_name) =
600 self.validate_tool_call(tool_name, ¶meters).await?;
601 self.call_tool(&server_name, &original_tool_name, parameters, timeout)
602 .await
603 }
604
605 pub async fn get_server_status(&self) -> Vec<(String, bool, String)> {
607 let configs = self.servers_config.read().await;
608 let clients = self.active_clients.read().await;
609
610 configs
611 .keys()
612 .map(|name| {
613 let is_active = clients.contains_key(name);
614 let state = if is_active {
615 clients
616 .get(name)
617 .map(|c| c.state().to_string())
618 .unwrap_or_else(|| "unknown".to_string())
619 } else {
620 "pending".to_string()
621 };
622 (name.clone(), is_active, state)
623 })
624 .collect()
625 }
626
627 pub async fn get_server_configs(&self) -> serde_json::Value {
632 let configs = self.servers_config.read().await;
633 let clients = self.active_clients.read().await;
634
635 let mut result = serde_json::Map::new();
636
637 for (name, config) in configs.iter() {
638 let is_active = clients.contains_key(name);
639 let state = if is_active {
640 clients
641 .get(name)
642 .map(|c| c.state().to_string())
643 .unwrap_or_else(|| "unknown".to_string())
644 } else {
645 "pending".to_string()
646 };
647
648 let mut server_info = serde_json::Map::new();
650
651 let server_type = match config {
653 MCPServerConfig::Stdio(_) => "stdio",
654 MCPServerConfig::Sse(_) => "sse",
655 MCPServerConfig::Http(_) => "http",
656 };
657 server_info.insert(
658 "type".to_string(),
659 serde_json::Value::String(server_type.to_string()),
660 );
661
662 server_info.insert("status".to_string(), serde_json::Value::String(state));
664 server_info.insert("is_active".to_string(), serde_json::Value::Bool(is_active));
665 server_info.insert(
666 "disabled".to_string(),
667 serde_json::Value::Bool(config.disabled()),
668 );
669
670 let forbidden_tools: Vec<serde_json::Value> = config
672 .forbidden_tools()
673 .iter()
674 .map(|t| serde_json::Value::String(t.clone()))
675 .collect();
676 server_info.insert(
677 "forbidden_tools".to_string(),
678 serde_json::Value::Array(forbidden_tools),
679 );
680
681 if let Ok(tool_meta_json) = serde_json::to_value(config.tool_meta()) {
683 server_info.insert("tool_meta".to_string(), tool_meta_json);
684 }
685
686 if let Some(default_meta) = config.default_tool_meta() {
688 if let Ok(default_meta_json) = serde_json::to_value(default_meta) {
689 server_info.insert("default_tool_meta".to_string(), default_meta_json);
690 }
691 }
692
693 if let Some(vrl) = config.vrl() {
695 server_info.insert(
696 "vrl".to_string(),
697 serde_json::Value::String(vrl.to_string()),
698 );
699 }
700
701 match config {
703 MCPServerConfig::Stdio(stdio_config) => {
704 if let Ok(params_json) = serde_json::to_value(&stdio_config.server_parameters) {
705 server_info.insert("server_parameters".to_string(), params_json);
706 }
707 }
708 MCPServerConfig::Sse(sse_config) => {
709 if let Ok(params_json) = serde_json::to_value(&sse_config.server_parameters) {
710 server_info.insert("server_parameters".to_string(), params_json);
711 }
712 }
713 MCPServerConfig::Http(http_config) => {
714 if let Ok(params_json) = serde_json::to_value(&http_config.server_parameters) {
715 server_info.insert("server_parameters".to_string(), params_json);
716 }
717 }
718 }
719
720 result.insert(name.clone(), serde_json::Value::Object(server_info));
721 }
722
723 serde_json::Value::Object(result)
724 }
725
726 pub async fn list_available_tools(&self) -> Vec<Tool> {
728 let mut tools = Vec::new();
729 let mapping = self.tool_mapping.read().await;
730 let alias_map = self.alias_mapping.read().await;
731
732 for (display_name, server_name) in mapping.iter() {
733 let client = {
734 let clients = self.active_clients.read().await;
735 clients.get(server_name).cloned()
736 };
737
738 if let Some(client) = client {
739 let original_name = alias_map
741 .get(display_name)
742 .map(|(_, original)| original.clone())
743 .unwrap_or_else(|| display_name.clone());
744
745 if let Ok(tool_list) = client.list_tools().await {
747 if let Some(tool) = tool_list.into_iter().find(|t| t.name == original_name) {
748 let mut display_tool = tool;
750 display_tool.name = display_name.clone().into();
751
752 let config = {
754 let configs = self.servers_config.read().await;
755 configs.get(server_name).cloned()
756 };
757 if let Some(config) = config {
758 if let Some(tool_meta) = self.merged_tool_meta(&config, &original_name)
759 {
760 if display_tool.meta.is_none() {
761 display_tool.meta = Some(rmcp::model::Meta::new());
762 }
763 if let Some(ref mut meta) = display_tool.meta {
764 meta.insert(
765 A2C_TOOL_META.to_string(),
766 serde_json::to_value(tool_meta).unwrap(),
767 );
768 }
769 }
770 }
771
772 tools.push(display_tool);
773 }
774 }
775 }
776 }
777
778 tools
779 }
780
781 fn merged_tool_meta(&self, config: &MCPServerConfig, tool_name: &str) -> Option<ToolMeta> {
783 let specific = config.tool_meta().get(tool_name);
784 let default = config.default_tool_meta();
785
786 match (specific, default) {
787 (None, None) => None,
788 (Some(s), None) => Some(s.clone()),
789 (None, Some(d)) => Some(d.clone()),
790 (Some(s), Some(d)) => {
791 let mut merged = d.clone();
793 if s.auto_apply.is_some() {
794 merged.auto_apply = s.auto_apply;
795 }
796 if s.alias.is_some() {
797 merged.alias = s.alias.clone();
798 }
799 if s.tags.is_some() {
800 merged.tags = s.tags.clone();
801 }
802 if s.ret_object_mapper.is_some() {
803 merged.ret_object_mapper = s.ret_object_mapper.clone();
804 }
805 Some(merged)
806 }
807 }
808 }
809
810 pub async fn enable_auto_connect(&self) {
812 *self.auto_connect.write().await = true;
813 }
814
815 pub async fn disable_auto_connect(&self) {
817 *self.auto_connect.write().await = false;
818 }
819
820 pub async fn enable_auto_reconnect(&self) {
822 *self.auto_reconnect.write().await = true;
823 }
824
825 pub async fn disable_auto_reconnect(&self) {
827 *self.auto_reconnect.write().await = false;
828 }
829
830 pub async fn set_health_check_config(&self, config: HealthCheckConfig) {
832 *self.health_check_config.write().await = config;
833 }
834
835 pub async fn get_health_check_config(&self) -> HealthCheckConfig {
837 self.health_check_config.read().await.clone()
838 }
839
840 pub async fn set_reconnect_policy(&self, policy: ReconnectPolicy) {
842 *self.reconnect_policy.write().await = policy;
843 }
844
845 pub async fn get_reconnect_policy(&self) -> ReconnectPolicy {
847 self.reconnect_policy.read().await.clone()
848 }
849
850 pub async fn start_health_monitor(&self) {
854 self.stop_health_monitor().await;
856
857 let health_config = self.health_check_config.clone();
858 let reconnect_policy = self.reconnect_policy.clone();
859 let active_clients = self.active_clients.clone();
860 let _servers_config = self.servers_config.clone();
861 let retry_counts = self.retry_counts.clone();
862 let auto_reconnect = self.auto_reconnect.clone();
863
864 let handle = tokio::spawn(async move {
865 loop {
866 let config = health_config.read().await.clone();
867 if !config.enabled {
868 tokio::time::sleep(std::time::Duration::from_secs(10)).await;
871 continue;
872 }
873
874 let clients: Vec<(String, StdArc<dyn MCPClientProtocol>)> = {
876 let clients_guard = active_clients.read().await;
877 clients_guard
878 .iter()
879 .map(|(k, v)| (k.clone(), v.clone()))
880 .collect()
881 };
882
883 for (server_name, client) in clients {
885 let check_result = tokio::time::timeout(
886 std::time::Duration::from_secs(config.timeout_secs),
887 client.health_check(),
888 )
889 .await;
890
891 let is_healthy = match check_result {
892 Ok(result) => result.is_healthy,
893 Err(_) => {
894 warn!("Health check timed out for server: {}", server_name);
895 false
896 }
897 };
898
899 if !is_healthy {
900 warn!("Server {} is unhealthy", server_name);
901
902 let should_reconnect = *auto_reconnect.read().await;
904 if !should_reconnect {
905 continue;
906 }
907
908 let policy = reconnect_policy.read().await.clone();
909 let mut retries = retry_counts.write().await;
910 let retry_count = retries.entry(server_name.clone()).or_insert(0);
911
912 if policy.should_retry(*retry_count) {
913 let delay = policy.calculate_delay(*retry_count);
914 info!(
915 "Attempting to reconnect {} (retry {}/{}), delay {:?}",
916 server_name,
917 *retry_count + 1,
918 if policy.max_retries == 0 {
919 "∞".to_string()
920 } else {
921 policy.max_retries.to_string()
922 },
923 delay
924 );
925
926 tokio::time::sleep(delay).await;
927
928 if let Err(e) = client.disconnect().await {
930 warn!("Failed to disconnect {}: {}", server_name, e);
931 }
932
933 match client.connect().await {
934 Ok(_) => {
935 info!("Successfully reconnected to {}", server_name);
936 *retry_count = 0;
938 }
939 Err(e) => {
940 error!("Failed to reconnect to {}: {}", server_name, e);
941 *retry_count += 1;
942 }
943 }
944 } else {
945 error!(
946 "Max retries ({}) reached for server {}. Giving up.",
947 policy.max_retries, server_name
948 );
949 }
951 } else {
952 let mut retries = retry_counts.write().await;
954 retries.remove(&server_name);
955 debug!("Server {} is healthy", server_name);
956 }
957 }
958
959 tokio::time::sleep(std::time::Duration::from_secs(config.interval_secs)).await;
961 }
962 });
963
964 *self.health_monitor_handle.write().await = Some(handle);
965 info!("Health monitor started");
966 }
967
968 pub async fn stop_health_monitor(&self) {
970 if let Some(handle) = self.health_monitor_handle.write().await.take() {
971 handle.abort();
972 info!("Health monitor stopped");
973 }
974 }
975
976 pub async fn check_server_health(&self, server_name: &str) -> Option<HealthCheckResult> {
978 let clients = self.active_clients.read().await;
979 if let Some(client) = clients.get(server_name) {
980 let config = self.health_check_config.read().await;
981 let result = tokio::time::timeout(
982 std::time::Duration::from_secs(config.timeout_secs),
983 client.health_check(),
984 )
985 .await;
986
987 match result {
988 Ok(health_result) => Some(health_result),
989 Err(_) => Some(HealthCheckResult {
990 is_healthy: false,
991 checked_at: std::time::Instant::now(),
992 error: Some("Health check timed out".to_string()),
993 response_time_ms: None,
994 }),
995 }
996 } else {
997 None
998 }
999 }
1000
1001 pub async fn check_all_health(&self) -> HashMap<String, HealthCheckResult> {
1003 let mut results = HashMap::new();
1004 let clients: Vec<(String, StdArc<dyn MCPClientProtocol>)> = {
1005 let clients_guard = self.active_clients.read().await;
1006 clients_guard
1007 .iter()
1008 .map(|(k, v)| (k.clone(), v.clone()))
1009 .collect()
1010 };
1011
1012 let config = self.health_check_config.read().await.clone();
1013
1014 for (server_name, client) in clients {
1015 let result = tokio::time::timeout(
1016 std::time::Duration::from_secs(config.timeout_secs),
1017 client.health_check(),
1018 )
1019 .await;
1020
1021 let health_result = match result {
1022 Ok(hr) => hr,
1023 Err(_) => HealthCheckResult {
1024 is_healthy: false,
1025 checked_at: std::time::Instant::now(),
1026 error: Some("Health check timed out".to_string()),
1027 response_time_ms: None,
1028 },
1029 };
1030
1031 results.insert(server_name, health_result);
1032 }
1033
1034 results
1035 }
1036
1037 pub async fn get_retry_counts(&self) -> HashMap<String, u32> {
1039 self.retry_counts.read().await.clone()
1040 }
1041
1042 pub async fn reset_retry_count(&self, server_name: &str) {
1044 self.retry_counts.write().await.remove(server_name);
1045 }
1046
1047 pub async fn reset_all_retry_counts(&self) {
1049 self.retry_counts.write().await.clear();
1050 }
1051}
1052
1053impl Default for MCPServerManager {
1054 fn default() -> Self {
1055 Self::new()
1056 }
1057}
1058
1059#[cfg(test)]
1060mod tests {
1061 use super::*;
1062 use std::collections::HashMap;
1063 use tokio::time::{sleep, Duration};
1064
1065 #[tokio::test]
1066 async fn test_manager_creation() {
1067 let manager = MCPServerManager::new();
1068 let status = manager.get_server_status().await;
1069 assert!(status.is_empty());
1070 }
1071
1072 #[tokio::test]
1073 async fn test_manager_initialization() {
1074 let manager = MCPServerManager::new();
1075
1076 let configs = vec![
1078 MCPServerConfig::Stdio(StdioServerConfig {
1080 name: "test_stdio".to_string(),
1081 disabled: false,
1082 forbidden_tools: vec![],
1083 tool_meta: HashMap::new(),
1084 default_tool_meta: None,
1085 vrl: None,
1086 server_parameters: StdioServerParameters {
1087 command: "echo".to_string(),
1088 args: vec!["hello".to_string()],
1089 env: HashMap::new(),
1090 cwd: None,
1091 },
1092 }),
1093 MCPServerConfig::Http(HttpServerConfig {
1095 name: "test_http".to_string(),
1096 disabled: true, forbidden_tools: vec![],
1098 tool_meta: HashMap::new(),
1099 default_tool_meta: None,
1100 vrl: None,
1101 server_parameters: HttpServerParameters {
1102 url: "http://localhost:8080".to_string(),
1103 headers: HashMap::new(),
1104 },
1105 }),
1106 ];
1107
1108 let result = manager.initialize(configs).await;
1110 assert!(result.is_ok());
1111
1112 let status = manager.get_server_status().await;
1114 assert_eq!(status.len(), 2);
1115
1116 let stdio_status = status
1118 .iter()
1119 .find(|(name, _, _)| name == "test_stdio")
1120 .unwrap();
1121 assert!(!stdio_status.1); let http_status = status
1124 .iter()
1125 .find(|(name, _, _)| name == "test_http")
1126 .unwrap();
1127 assert!(!http_status.1); }
1129
1130 #[tokio::test]
1131 async fn test_add_server() {
1132 let manager = MCPServerManager::new();
1133
1134 let config = MCPServerConfig::Stdio(StdioServerConfig {
1136 name: "test_server".to_string(),
1137 disabled: false,
1138 forbidden_tools: vec![],
1139 tool_meta: HashMap::new(),
1140 default_tool_meta: None,
1141 vrl: None,
1142 server_parameters: StdioServerParameters {
1143 command: "echo".to_string(),
1144 args: vec![],
1145 env: HashMap::new(),
1146 cwd: None,
1147 },
1148 });
1149
1150 let result = manager.add_or_update_server(config).await;
1151 assert!(result.is_ok());
1152
1153 let status = manager.get_server_status().await;
1155 assert_eq!(status.len(), 1);
1156 assert_eq!(status[0].0, "test_server");
1157 }
1158
1159 #[tokio::test]
1160 async fn test_remove_server() {
1161 let manager = MCPServerManager::new();
1162
1163 let config = MCPServerConfig::Stdio(StdioServerConfig {
1165 name: "test_server".to_string(),
1166 disabled: false,
1167 forbidden_tools: vec![],
1168 tool_meta: HashMap::new(),
1169 default_tool_meta: None,
1170 vrl: None,
1171 server_parameters: StdioServerParameters {
1172 command: "echo".to_string(),
1173 args: vec![],
1174 env: HashMap::new(),
1175 cwd: None,
1176 },
1177 });
1178
1179 manager.add_or_update_server(config).await.unwrap();
1180
1181 let result = manager.remove_server("test_server").await;
1183 assert!(result.is_ok());
1184
1185 let status = manager.get_server_status().await;
1187 assert!(status.is_empty());
1188 }
1189
1190 #[tokio::test]
1191 async fn test_tool_conflict_detection() {
1192 let manager = MCPServerManager::new();
1193
1194 let configs = vec![
1196 MCPServerConfig::Stdio(StdioServerConfig {
1198 name: "server1".to_string(),
1199 disabled: false,
1200 forbidden_tools: vec![],
1201 tool_meta: HashMap::new(),
1202 default_tool_meta: None,
1203 vrl: None,
1204 server_parameters: StdioServerParameters {
1205 command: "echo".to_string(),
1206 args: vec!["server1".to_string()],
1207 env: HashMap::new(),
1208 cwd: None,
1209 },
1210 }),
1211 MCPServerConfig::Stdio(StdioServerConfig {
1213 name: "server2".to_string(),
1214 disabled: false,
1215 forbidden_tools: vec![],
1216 tool_meta: HashMap::new(),
1217 default_tool_meta: None,
1218 vrl: None,
1219 server_parameters: StdioServerParameters {
1220 command: "echo".to_string(),
1221 args: vec!["server2".to_string()],
1222 env: HashMap::new(),
1223 cwd: None,
1224 },
1225 }),
1226 ];
1227
1228 let result = manager.initialize(configs).await;
1230 assert!(result.is_ok());
1231
1232 let _result = manager.start_all().await;
1234 sleep(Duration::from_millis(200)).await;
1239 }
1240
1241 #[tokio::test]
1242 async fn test_health_check_config() {
1243 let manager = MCPServerManager::new();
1244
1245 let config = manager.get_health_check_config().await;
1247 assert_eq!(config.interval_secs, 30);
1248 assert_eq!(config.timeout_secs, 5);
1249 assert!(config.enabled);
1250
1251 let new_config = HealthCheckConfig {
1253 interval_secs: 60,
1254 timeout_secs: 10,
1255 enabled: false,
1256 };
1257 manager.set_health_check_config(new_config.clone()).await;
1258
1259 let updated = manager.get_health_check_config().await;
1260 assert_eq!(updated.interval_secs, 60);
1261 assert_eq!(updated.timeout_secs, 10);
1262 assert!(!updated.enabled);
1263 }
1264
1265 #[tokio::test]
1266 async fn test_reconnect_policy() {
1267 let manager = MCPServerManager::new();
1268
1269 let policy = manager.get_reconnect_policy().await;
1271 assert!(policy.enabled);
1272 assert_eq!(policy.max_retries, 5);
1273 assert_eq!(policy.initial_delay_ms, 1000);
1274 assert_eq!(policy.max_delay_ms, 30000);
1275 assert_eq!(policy.backoff_factor, 2.0);
1276
1277 assert_eq!(policy.calculate_delay(0).as_millis(), 1000);
1279 assert_eq!(policy.calculate_delay(1).as_millis(), 2000);
1280 assert_eq!(policy.calculate_delay(2).as_millis(), 4000);
1281 assert_eq!(policy.calculate_delay(3).as_millis(), 8000);
1282
1283 assert!(policy.should_retry(0));
1285 assert!(policy.should_retry(4));
1286 assert!(!policy.should_retry(5)); let infinite_policy = ReconnectPolicy {
1290 enabled: true,
1291 max_retries: 0,
1292 ..Default::default()
1293 };
1294 assert!(infinite_policy.should_retry(100));
1295 }
1296
1297 #[tokio::test]
1298 async fn test_retry_counts() {
1299 let manager = MCPServerManager::new();
1300
1301 let counts = manager.get_retry_counts().await;
1303 assert!(counts.is_empty());
1304
1305 {
1307 manager
1308 .retry_counts
1309 .write()
1310 .await
1311 .insert("server1".to_string(), 3);
1312 manager
1313 .retry_counts
1314 .write()
1315 .await
1316 .insert("server2".to_string(), 5);
1317 }
1318
1319 let counts = manager.get_retry_counts().await;
1320 assert_eq!(counts.get("server1"), Some(&3));
1321 assert_eq!(counts.get("server2"), Some(&5));
1322
1323 manager.reset_retry_count("server1").await;
1325 let counts = manager.get_retry_counts().await;
1326 assert!(!counts.contains_key("server1"));
1327 assert_eq!(counts.get("server2"), Some(&5));
1328
1329 manager.reset_all_retry_counts().await;
1331 let counts = manager.get_retry_counts().await;
1332 assert!(counts.is_empty());
1333 }
1334
1335 #[tokio::test]
1336 async fn test_manager_with_custom_config() {
1337 let health_config = HealthCheckConfig {
1338 interval_secs: 15,
1339 timeout_secs: 3,
1340 enabled: true,
1341 };
1342 let reconnect_policy = ReconnectPolicy {
1343 enabled: true,
1344 max_retries: 10,
1345 initial_delay_ms: 500,
1346 max_delay_ms: 60000,
1347 backoff_factor: 1.5,
1348 };
1349
1350 let manager =
1351 MCPServerManager::with_config(health_config.clone(), reconnect_policy.clone());
1352
1353 let got_health = manager.get_health_check_config().await;
1354 assert_eq!(got_health.interval_secs, 15);
1355 assert_eq!(got_health.timeout_secs, 3);
1356
1357 let got_reconnect = manager.get_reconnect_policy().await;
1358 assert_eq!(got_reconnect.max_retries, 10);
1359 assert_eq!(got_reconnect.initial_delay_ms, 500);
1360 }
1361
1362 #[tokio::test]
1363 async fn test_merged_tool_meta() {
1364 let manager = MCPServerManager::new();
1365
1366 let config = MCPServerConfig::Stdio(StdioServerConfig {
1368 name: "s".to_string(),
1369 disabled: false,
1370 forbidden_tools: vec![],
1371 tool_meta: HashMap::from([(
1372 "tool_a".to_string(),
1373 ToolMeta {
1374 auto_apply: Some(true),
1375 alias: None,
1376 tags: Some(vec!["tag1".to_string()]),
1377 ret_object_mapper: None,
1378 },
1379 )]),
1380 default_tool_meta: None,
1381 vrl: None,
1382 server_parameters: StdioServerParameters {
1383 command: "echo".to_string(),
1384 args: vec![],
1385 env: HashMap::new(),
1386 cwd: None,
1387 },
1388 });
1389 let meta = manager.merged_tool_meta(&config, "tool_a").unwrap();
1390 assert_eq!(meta.auto_apply, Some(true));
1391 assert_eq!(meta.tags, Some(vec!["tag1".to_string()]));
1392
1393 let config = MCPServerConfig::Stdio(StdioServerConfig {
1395 name: "s".to_string(),
1396 disabled: false,
1397 forbidden_tools: vec![],
1398 tool_meta: HashMap::new(),
1399 default_tool_meta: Some(ToolMeta {
1400 auto_apply: Some(false),
1401 alias: None,
1402 tags: Some(vec!["default_tag".to_string()]),
1403 ret_object_mapper: None,
1404 }),
1405 vrl: None,
1406 server_parameters: StdioServerParameters {
1407 command: "echo".to_string(),
1408 args: vec![],
1409 env: HashMap::new(),
1410 cwd: None,
1411 },
1412 });
1413 let meta = manager.merged_tool_meta(&config, "any_tool").unwrap();
1414 assert_eq!(meta.auto_apply, Some(false));
1415 assert_eq!(meta.tags, Some(vec!["default_tag".to_string()]));
1416
1417 let config = MCPServerConfig::Stdio(StdioServerConfig {
1419 name: "s".to_string(),
1420 disabled: false,
1421 forbidden_tools: vec![],
1422 tool_meta: HashMap::from([(
1423 "tool_a".to_string(),
1424 ToolMeta {
1425 auto_apply: Some(true),
1426 alias: None,
1427 tags: None,
1428 ret_object_mapper: None,
1429 },
1430 )]),
1431 default_tool_meta: Some(ToolMeta {
1432 auto_apply: Some(false),
1433 alias: Some("default_alias".to_string()),
1434 tags: Some(vec!["default_tag".to_string()]),
1435 ret_object_mapper: None,
1436 }),
1437 vrl: None,
1438 server_parameters: StdioServerParameters {
1439 command: "echo".to_string(),
1440 args: vec![],
1441 env: HashMap::new(),
1442 cwd: None,
1443 },
1444 });
1445 let meta = manager.merged_tool_meta(&config, "tool_a").unwrap();
1446 assert_eq!(meta.auto_apply, Some(true)); assert_eq!(meta.alias, Some("default_alias".to_string())); assert_eq!(meta.tags, Some(vec!["default_tag".to_string()])); let config = MCPServerConfig::Stdio(StdioServerConfig {
1452 name: "s".to_string(),
1453 disabled: false,
1454 forbidden_tools: vec![],
1455 tool_meta: HashMap::new(),
1456 default_tool_meta: None,
1457 vrl: None,
1458 server_parameters: StdioServerParameters {
1459 command: "echo".to_string(),
1460 args: vec![],
1461 env: HashMap::new(),
1462 cwd: None,
1463 },
1464 });
1465 assert!(manager.merged_tool_meta(&config, "tool_a").is_none());
1466 }
1467}