1use super::model::*;
11use super::utils::client_factory;
12use super::vrl_runtime::VrlRuntime;
13use crate::errors::ComputerError;
14use serde_json::Value;
15use std::collections::{HashMap, HashSet};
16use std::sync::Arc;
17use std::sync::Arc as StdArc;
18use tokio::sync::{watch, RwLock};
19use tracing::{debug, error, info, warn};
20
/// Error describing a tool name that is exposed by more than one server.
///
/// The `Display` text comes from the `#[error]` attribute below and contains
/// the tool name followed by the debug-formatted list of servers.
#[derive(Debug, thiserror::Error)]
#[error("Tool '{tool_name}' exists in multiple servers: {servers:?}")]
pub struct ToolNameDuplicatedError {
    /// Name of the tool that was found more than once.
    pub tool_name: String,
    /// Names of the servers that expose the tool.
    pub servers: Vec<String>,
}
28
/// Owns the configuration, live clients and tool routing tables for a set of
/// MCP servers.
///
/// Every mutable piece of state sits behind its own `Arc<RwLock<_>>`, so all
/// methods take `&self`.
pub struct MCPServerManager {
    /// Registered server configurations, keyed by server name.
    servers_config: Arc<RwLock<HashMap<ServerName, MCPServerConfig>>>,
    /// Clients that were connected successfully, keyed by server name.
    active_clients: Arc<RwLock<HashMap<ServerName, StdArc<dyn MCPClientProtocol>>>>,
    /// Display name of a tool -> server that provides it.
    tool_mapping: Arc<RwLock<HashMap<ToolName, ServerName>>>,
    /// Alias (display name) -> (server, original tool name); only filled for
    /// tools whose display name differs from the original name.
    alias_mapping: Arc<RwLock<HashMap<String, (ServerName, ToolName)>>>,
    /// Display names of tools listed in a server's forbidden tools.
    disabled_tools: Arc<RwLock<HashSet<ToolName>>>,
    /// Whether active servers may be restarted on config update and
    /// reconnected by the health monitor.
    auto_reconnect: Arc<RwLock<bool>>,
    /// Whether newly added servers are started immediately.
    auto_connect: Arc<RwLock<bool>>,
    /// Sending half of the watch channel that publishes `ManagerState`.
    state_notifier: watch::Sender<ManagerState>,
    /// Settings read by the health monitor and the on-demand health checks.
    health_check_config: Arc<RwLock<HealthCheckConfig>>,
    /// Retry/backoff policy read by the health monitor.
    reconnect_policy: Arc<RwLock<ReconnectPolicy>>,
    /// Handle of the spawned health-monitor task, if one is running.
    health_monitor_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
    /// Reconnect attempts recorded per server by the health monitor.
    retry_counts: Arc<RwLock<HashMap<ServerName, u32>>>,
}
56
/// Coarse lifecycle state published through the manager's watch channel.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ManagerState {
    /// Initial state of the channel; also published by `close`.
    Uninitialized,
    /// Published by `initialize` and by `stop_all`.
    Initialized,
    /// Published by `start_all` once every enabled server has started.
    Running,
    /// NOTE(review): no code in this file publishes this variant; presumably
    /// reserved for callers elsewhere — confirm.
    Error,
}
69
70impl MCPServerManager {
71 pub fn new() -> Self {
73 let (state_tx, _) = watch::channel(ManagerState::Uninitialized);
74
75 Self {
76 servers_config: Arc::new(RwLock::new(HashMap::new())),
77 active_clients: Arc::new(RwLock::new(HashMap::new())),
78 tool_mapping: Arc::new(RwLock::new(HashMap::new())),
79 alias_mapping: Arc::new(RwLock::new(HashMap::new())),
80 disabled_tools: Arc::new(RwLock::new(HashSet::new())),
81 auto_reconnect: Arc::new(RwLock::new(true)),
82 auto_connect: Arc::new(RwLock::new(false)),
83 state_notifier: state_tx,
84 health_check_config: Arc::new(RwLock::new(HealthCheckConfig::default())),
85 reconnect_policy: Arc::new(RwLock::new(ReconnectPolicy::default())),
86 health_monitor_handle: Arc::new(RwLock::new(None)),
87 retry_counts: Arc::new(RwLock::new(HashMap::new())),
88 }
89 }
90
91 pub fn with_config(
93 health_check_config: HealthCheckConfig,
94 reconnect_policy: ReconnectPolicy,
95 ) -> Self {
96 let (state_tx, _) = watch::channel(ManagerState::Uninitialized);
97
98 Self {
99 servers_config: Arc::new(RwLock::new(HashMap::new())),
100 active_clients: Arc::new(RwLock::new(HashMap::new())),
101 tool_mapping: Arc::new(RwLock::new(HashMap::new())),
102 alias_mapping: Arc::new(RwLock::new(HashMap::new())),
103 disabled_tools: Arc::new(RwLock::new(HashSet::new())),
104 auto_reconnect: Arc::new(RwLock::new(reconnect_policy.enabled)),
105 auto_connect: Arc::new(RwLock::new(false)),
106 state_notifier: state_tx,
107 health_check_config: Arc::new(RwLock::new(health_check_config)),
108 reconnect_policy: Arc::new(RwLock::new(reconnect_policy)),
109 health_monitor_handle: Arc::new(RwLock::new(None)),
110 retry_counts: Arc::new(RwLock::new(HashMap::new())),
111 }
112 }
113
114 pub fn get_state_notifier(&self) -> watch::Receiver<ManagerState> {
116 self.state_notifier.subscribe()
117 }
118
119 async fn update_state(&self, state: ManagerState) {
121 let _ = self.state_notifier.send(state);
122 }
123
124 pub async fn initialize(&self, servers: Vec<MCPServerConfig>) -> Result<(), ComputerError> {
126 self.stop_all().await?;
128
129 self.clear_all().await;
131
132 {
134 let mut configs = self.servers_config.write().await;
135 for server in servers {
136 configs.insert(server.name().to_string(), server);
137 }
138 }
139
140 self.refresh_tool_mapping().await?;
142
143 self.update_state(ManagerState::Initialized).await;
145
146 info!("Manager initialized successfully");
147 Ok(())
148 }
149
150 pub async fn add_or_update_server(&self, config: MCPServerConfig) -> Result<(), ComputerError> {
152 let server_name = config.name().to_string();
153
154 let is_active = {
156 let clients = self.active_clients.read().await;
157 clients.contains_key(&server_name)
158 };
159
160 if is_active {
161 let auto_reconnect = *self.auto_reconnect.read().await;
162 if auto_reconnect {
163 self.restart_server(&server_name).await?;
165 } else {
166 return Err(ComputerError::InvalidConfiguration(format!(
167 "Server {} is active. Stop it before updating config",
168 server_name
169 )));
170 }
171 }
172
173 {
175 let mut configs = self.servers_config.write().await;
176 configs.insert(server_name.clone(), config);
177 }
178
179 let auto_connect = *self.auto_connect.read().await;
181 if auto_connect && !is_active {
182 self.start_client(&server_name).await?;
183 }
184
185 self.refresh_tool_mapping().await?;
187
188 Ok(())
189 }
190
191 pub async fn remove_server(&self, server_name: &str) -> Result<(), ComputerError> {
193 self.stop_client(server_name).await?;
195
196 {
198 let mut configs = self.servers_config.write().await;
199 configs.remove(server_name);
200 }
201
202 self.refresh_tool_mapping().await?;
204
205 Ok(())
206 }
207
208 pub async fn start_all(&self) -> Result<(), ComputerError> {
210 let configs = self.servers_config.read().await;
211 let server_names: Vec<String> = configs
212 .iter()
213 .filter(|(_, config)| !config.disabled())
214 .map(|(name, _)| name.clone())
215 .collect();
216
217 drop(configs);
218
219 for server_name in server_names {
220 self.start_client(&server_name).await?;
221 }
222
223 self.update_state(ManagerState::Running).await;
225
226 info!("All servers started successfully");
227 Ok(())
228 }
229
230 pub async fn start_client(&self, server_name: &str) -> Result<(), ComputerError> {
232 let config = {
234 let configs = self.servers_config.read().await;
235 configs.get(server_name).cloned()
236 };
237
238 let config = config.ok_or_else(|| {
239 ComputerError::InvalidConfiguration(format!("Unknown server: {}", server_name))
240 })?;
241
242 if config.disabled() {
243 return Err(ComputerError::InvalidConfiguration(format!(
244 "Cannot start disabled server: {}",
245 server_name
246 )));
247 }
248
249 {
251 let clients = self.active_clients.read().await;
252 if clients.contains_key(server_name) {
253 return Ok(()); }
255 }
256
257 let client = client_factory(config);
259
260 client.connect().await.map_err(|e| {
262 ComputerError::ConnectionError(format!("Failed to connect to {}: {}", server_name, e))
263 })?;
264
265 {
267 let mut clients = self.active_clients.write().await;
268 clients.insert(server_name.to_string(), client);
269 }
270
271 self.refresh_tool_mapping().await?;
273
274 info!("Client {} started successfully", server_name);
275 Ok(())
276 }
277
278 pub async fn stop_client(&self, server_name: &str) -> Result<(), ComputerError> {
280 let mut client = {
282 let mut clients = self.active_clients.write().await;
283 clients.remove(server_name)
284 };
285
286 if let Some(ref mut c) = client {
288 c.disconnect().await.map_err(|e| {
289 ComputerError::ConnectionError(format!(
290 "Failed to disconnect from {}: {}",
291 server_name, e
292 ))
293 })?;
294 }
295
296 self.refresh_tool_mapping().await?;
298
299 info!("Client {} stopped successfully", server_name);
300 Ok(())
301 }
302
303 async fn restart_server(&self, server_name: &str) -> Result<(), ComputerError> {
305 self.stop_client(server_name).await?;
306
307 let enabled = {
309 let configs = self.servers_config.read().await;
310 configs
311 .get(server_name)
312 .map(|c| !c.disabled())
313 .unwrap_or(false)
314 };
315
316 if enabled {
317 self.start_client(server_name).await?;
318 }
319
320 Ok(())
321 }
322
323 pub async fn stop_all(&self) -> Result<(), ComputerError> {
325 let server_names: Vec<String> = {
326 let clients = self.active_clients.read().await;
327 clients.keys().cloned().collect()
328 };
329
330 for server_name in server_names {
331 self.stop_client(&server_name).await?;
332 }
333
334 self.update_state(ManagerState::Initialized).await;
336
337 info!("All servers stopped successfully");
338 Ok(())
339 }
340
341 async fn clear_all(&self) {
343 self.servers_config.write().await.clear();
344 self.active_clients.write().await.clear();
345 self.tool_mapping.write().await.clear();
346 self.alias_mapping.write().await.clear();
347 self.disabled_tools.write().await.clear();
348 }
349
350 pub async fn close(&self) -> Result<(), ComputerError> {
352 self.stop_all().await?;
353 self.clear_all().await;
354 self.update_state(ManagerState::Uninitialized).await;
355 info!("Manager closed successfully");
356 Ok(())
357 }
358
359 async fn refresh_tool_mapping(&self) -> Result<(), ComputerError> {
361 self.tool_mapping.write().await.clear();
363 self.alias_mapping.write().await.clear();
364 self.disabled_tools.write().await.clear();
365
366 let mut tool_sources: HashMap<ToolName, Vec<ServerName>> = HashMap::new();
368
369 let clients = self.active_clients.read().await;
371 let configs = self.servers_config.read().await;
372
373 for (server_name, client) in clients.iter() {
374 let config = match configs.get(server_name) {
375 Some(c) => c,
376 None => continue,
377 };
378
379 match client.list_tools().await {
381 Ok(tools) => {
382 for tool in tools {
383 let original_tool_name = tool.name.clone();
384
385 let tool_meta = self.merged_tool_meta(config, &original_tool_name);
387
388 let display_name = tool_meta
390 .and_then(|meta| meta.alias)
391 .unwrap_or_else(|| original_tool_name.to_string());
392
393 let original_tool_name_str = original_tool_name.to_string();
395 if display_name != original_tool_name_str {
396 let mut alias_map = self.alias_mapping.write().await;
397 alias_map.insert(
398 display_name.clone(),
399 (server_name.clone(), original_tool_name_str.clone()),
400 );
401 }
402
403 tool_sources
405 .entry(display_name.clone())
406 .or_default()
407 .push(server_name.clone());
408
409 let forbidden_tools = config.forbidden_tools();
411 if forbidden_tools.contains(&display_name)
412 || forbidden_tools.contains(&original_tool_name_str)
413 {
414 let mut disabled = self.disabled_tools.write().await;
415 disabled.insert(display_name);
416 }
417 }
418 }
419 Err(e) => {
420 error!("Error listing tools for {}: {}", server_name, e);
421 }
422 }
423 }
424
425 for (tool, sources) in tool_sources {
427 if sources.len() > 1 {
428 warn!("Tool '{}' exists in multiple servers: {:?}", tool, sources);
429 let suggestion =
430 "Please use the 'alias' feature in ToolMeta to resolve conflicts. \
431 Each tool should have a unique name or alias across all servers.";
432 return Err(ComputerError::InvalidConfiguration(format!(
433 "Tool '{}' exists in multiple servers: {:?}\n{}",
434 tool, sources, suggestion
435 )));
436 }
437 let mut mapping = self.tool_mapping.write().await;
438 mapping.insert(tool, sources[0].clone());
439 }
440
441 debug!("Tool mapping refreshed successfully");
442 Ok(())
443 }
444
445 pub async fn validate_tool_call(
447 &self,
448 tool_name: &str,
449 _parameters: &serde_json::Value,
450 ) -> Result<(ServerName, ToolName), ComputerError> {
451 let disabled = self.disabled_tools.read().await;
453 if disabled.contains(tool_name) {
454 return Err(ComputerError::PermissionError(format!(
455 "Tool '{}' is disabled by configuration",
456 tool_name
457 )));
458 }
459
460 let server_name = {
462 let mapping = self.tool_mapping.read().await;
463 mapping.get(tool_name).cloned()
464 };
465
466 let server_name = server_name.ok_or_else(|| {
467 ComputerError::InvalidConfiguration(format!(
468 "Tool '{}' not found in any active server",
469 tool_name
470 ))
471 })?;
472
473 let original_tool_name = {
475 let alias_map = self.alias_mapping.read().await;
476 if let Some((_, original)) = alias_map.get(tool_name) {
477 original.clone()
478 } else {
479 tool_name.to_string()
480 }
481 };
482
483 Ok((server_name, original_tool_name))
484 }
485
486 pub async fn call_tool(
488 &self,
489 server_name: &str,
490 tool_name: &str,
491 parameters: serde_json::Value,
492 timeout: Option<std::time::Duration>,
493 ) -> Result<CallToolResult, ComputerError> {
494 let client = {
496 let clients = self.active_clients.read().await;
497 clients
498 .get(server_name)
499 .ok_or_else(|| {
500 ComputerError::InvalidConfiguration(format!(
501 "Server '{}' for tool '{}' is not active",
502 server_name, tool_name
503 ))
504 })?
505 .clone()
506 };
507
508 let result = if let Some(timeout) = timeout {
510 tokio::time::timeout(timeout, client.call_tool(tool_name, parameters))
511 .await
512 .map_err(|_| ComputerError::TimeoutError("Tool execution timed out".to_string()))?
513 } else {
514 client.call_tool(tool_name, parameters).await
515 };
516
517 let mut result = result
518 .map_err(|e| ComputerError::ProtocolError(format!("Tool execution failed: {}", e)))?;
519
520 let config = {
522 let configs = self.servers_config.read().await;
523 configs.get(server_name).cloned()
524 };
525
526 if let Some(config) = config {
527 if let Some(tool_meta) = self.merged_tool_meta(&config, tool_name) {
528 if result.meta.is_none() {
529 result.meta = Some(rmcp::model::Meta::new());
530 }
531 if let Some(ref mut meta) = result.meta {
532 meta.insert(
533 A2C_TOOL_META.to_string(),
534 serde_json::to_value(tool_meta).unwrap(),
535 );
536 }
537 }
538
539 if let Some(vrl_script) = config.vrl() {
541 let parameters = serde_json::json!({});
544
545 let mut event = serde_json::to_value(&result).unwrap_or_default();
547 if let Value::Object(ref mut map) = event {
548 map.insert(
549 "tool_name".to_string(),
550 Value::String(tool_name.to_string()),
551 );
552 map.insert("parameters".to_string(), parameters);
553 }
554
555 let mut runtime = VrlRuntime::new();
557 match runtime.run(vrl_script, event, "UTC") {
558 Ok(vrl_result) => {
559 if result.meta.is_none() {
561 result.meta = Some(rmcp::model::Meta::new());
562 }
563 if let Some(ref mut meta) = result.meta {
564 if let Ok(transformed_json) =
566 serde_json::to_string(&vrl_result.processed_event)
567 {
568 meta.insert(
569 A2C_VRL_TRANSFORMED.to_string(),
570 Value::String(transformed_json),
571 );
572 }
573 }
574 debug!(
575 "VRL转换成功 / VRL transformation succeeded for tool '{}'",
576 tool_name
577 );
578 }
579 Err(e) => {
580 warn!(
581 "VRL转换失败 / VRL transformation failed for tool '{}': {}. 原始结果将正常返回 / Original result will be returned normally.",
582 tool_name, e
583 );
584 }
585 }
586 }
587 }
588
589 Ok(result)
590 }
591
592 pub async fn execute_tool(
594 &self,
595 tool_name: &str,
596 parameters: serde_json::Value,
597 timeout: Option<std::time::Duration>,
598 ) -> Result<CallToolResult, ComputerError> {
599 let (server_name, original_tool_name) =
600 self.validate_tool_call(tool_name, ¶meters).await?;
601 self.call_tool(&server_name, &original_tool_name, parameters, timeout)
602 .await
603 }
604
605 pub async fn get_server_status(&self) -> Vec<(String, bool, String)> {
607 let configs = self.servers_config.read().await;
608 let clients = self.active_clients.read().await;
609
610 configs
611 .keys()
612 .map(|name| {
613 let is_active = clients.contains_key(name);
614 let state = if is_active {
615 clients
616 .get(name)
617 .map(|c| c.state().to_string())
618 .unwrap_or_else(|| "unknown".to_string())
619 } else {
620 "pending".to_string()
621 };
622 (name.clone(), is_active, state)
623 })
624 .collect()
625 }
626
627 pub async fn get_server_configs(&self) -> serde_json::Value {
632 let configs = self.servers_config.read().await;
633 let clients = self.active_clients.read().await;
634
635 let mut result = serde_json::Map::new();
636
637 for (name, config) in configs.iter() {
638 let is_active = clients.contains_key(name);
639 let state = if is_active {
640 clients
641 .get(name)
642 .map(|c| c.state().to_string())
643 .unwrap_or_else(|| "unknown".to_string())
644 } else {
645 "pending".to_string()
646 };
647
648 let mut server_info = serde_json::Map::new();
650
651 let server_type = match config {
653 MCPServerConfig::Stdio(_) => "stdio",
654 MCPServerConfig::Sse(_) => "sse",
655 MCPServerConfig::Http(_) => "http",
656 };
657 server_info.insert(
658 "type".to_string(),
659 serde_json::Value::String(server_type.to_string()),
660 );
661
662 server_info.insert("status".to_string(), serde_json::Value::String(state));
664 server_info.insert("is_active".to_string(), serde_json::Value::Bool(is_active));
665 server_info.insert(
666 "disabled".to_string(),
667 serde_json::Value::Bool(config.disabled()),
668 );
669
670 let forbidden_tools: Vec<serde_json::Value> = config
672 .forbidden_tools()
673 .iter()
674 .map(|t| serde_json::Value::String(t.clone()))
675 .collect();
676 server_info.insert(
677 "forbidden_tools".to_string(),
678 serde_json::Value::Array(forbidden_tools),
679 );
680
681 if let Ok(tool_meta_json) = serde_json::to_value(config.tool_meta()) {
683 server_info.insert("tool_meta".to_string(), tool_meta_json);
684 }
685
686 if let Some(default_meta) = config.default_tool_meta() {
688 if let Ok(default_meta_json) = serde_json::to_value(default_meta) {
689 server_info.insert("default_tool_meta".to_string(), default_meta_json);
690 }
691 }
692
693 if let Some(vrl) = config.vrl() {
695 server_info.insert(
696 "vrl".to_string(),
697 serde_json::Value::String(vrl.to_string()),
698 );
699 }
700
701 match config {
703 MCPServerConfig::Stdio(stdio_config) => {
704 if let Ok(params_json) = serde_json::to_value(&stdio_config.server_parameters) {
705 server_info.insert("server_parameters".to_string(), params_json);
706 }
707 }
708 MCPServerConfig::Sse(sse_config) => {
709 if let Ok(params_json) = serde_json::to_value(&sse_config.server_parameters) {
710 server_info.insert("server_parameters".to_string(), params_json);
711 }
712 }
713 MCPServerConfig::Http(http_config) => {
714 if let Ok(params_json) = serde_json::to_value(&http_config.server_parameters) {
715 server_info.insert("server_parameters".to_string(), params_json);
716 }
717 }
718 }
719
720 result.insert(name.clone(), serde_json::Value::Object(server_info));
721 }
722
723 serde_json::Value::Object(result)
724 }
725
726 pub async fn list_available_tools(&self) -> Vec<Tool> {
728 let mut tools = Vec::new();
729 let mapping = self.tool_mapping.read().await;
730 let alias_map = self.alias_mapping.read().await;
731
732 for (display_name, server_name) in mapping.iter() {
733 let client = {
734 let clients = self.active_clients.read().await;
735 clients.get(server_name).cloned()
736 };
737
738 if let Some(client) = client {
739 let original_name = alias_map
741 .get(display_name)
742 .map(|(_, original)| original.clone())
743 .unwrap_or_else(|| display_name.clone());
744
745 if let Ok(tool_list) = client.list_tools().await {
747 if let Some(tool) = tool_list.into_iter().find(|t| t.name == original_name) {
748 let mut display_tool = tool;
750 display_tool.name = display_name.clone().into();
751
752 let config = {
754 let configs = self.servers_config.read().await;
755 configs.get(server_name).cloned()
756 };
757 if let Some(config) = config {
758 if let Some(tool_meta) = self.merged_tool_meta(&config, &original_name)
759 {
760 if display_tool.meta.is_none() {
761 display_tool.meta = Some(rmcp::model::Meta::new());
762 }
763 if let Some(ref mut meta) = display_tool.meta {
764 meta.insert(
765 A2C_TOOL_META.to_string(),
766 serde_json::to_value(tool_meta).unwrap(),
767 );
768 }
769 }
770 }
771
772 tools.push(display_tool);
773 }
774 }
775 }
776 }
777
778 tools
779 }
780
781 pub async fn list_all_windows(&self, window_uri: Option<&str>) -> Vec<(ServerName, Resource)> {
785 let clients: Vec<(String, StdArc<dyn MCPClientProtocol>)> = {
786 let clients_guard = self.active_clients.read().await;
787 clients_guard
788 .iter()
789 .map(|(k, v)| (k.clone(), v.clone()))
790 .collect()
791 };
792
793 let mut results = Vec::new();
794 for (server_name, client) in clients {
795 match client.list_windows().await {
796 Ok(windows) => {
797 for resource in windows {
798 if let Some(uri_filter) = window_uri {
799 if resource.uri.as_str() != uri_filter {
800 continue;
801 }
802 }
803 results.push((server_name.clone(), resource));
804 }
805 }
806 Err(e) => {
807 warn!(
808 "Failed to list windows from server '{}': {}",
809 server_name, e
810 );
811 }
812 }
813 }
814 results
815 }
816
817 pub async fn get_windows_details(
821 &self,
822 window_uri: Option<&str>,
823 ) -> Vec<(ServerName, Resource, ReadResourceResult)> {
824 let windows = self.list_all_windows(window_uri).await;
825
826 let clients = self.active_clients.read().await;
827 let mut results = Vec::new();
828 for (server_name, resource) in windows {
829 if let Some(client) = clients.get(&server_name) {
830 match client.get_window_detail(resource.clone()).await {
831 Ok(detail) => {
832 results.push((server_name, resource, detail));
833 }
834 Err(e) => {
835 warn!(
836 "Failed to get window detail for '{}' from server '{}': {}",
837 resource.uri, server_name, e
838 );
839 }
840 }
841 }
842 }
843 results
844 }
845
846 pub async fn get_window_detail(
850 &self,
851 server_name: &str,
852 resource: Resource,
853 ) -> Result<ReadResourceResult, ComputerError> {
854 let client = {
855 let clients = self.active_clients.read().await;
856 clients.get(server_name).cloned().ok_or_else(|| {
857 ComputerError::InvalidState(format!("Server '{}' not connected", server_name))
858 })?
859 };
860 client
861 .get_window_detail(resource)
862 .await
863 .map_err(|e| ComputerError::ProtocolError(format!("Get window detail error: {}", e)))
864 }
865
866 fn merged_tool_meta(&self, config: &MCPServerConfig, tool_name: &str) -> Option<ToolMeta> {
868 let specific = config.tool_meta().get(tool_name);
869 let default = config.default_tool_meta();
870
871 match (specific, default) {
872 (None, None) => None,
873 (Some(s), None) => Some(s.clone()),
874 (None, Some(d)) => Some(d.clone()),
875 (Some(s), Some(d)) => {
876 let mut merged = d.clone();
878 if s.auto_apply.is_some() {
879 merged.auto_apply = s.auto_apply;
880 }
881 if s.alias.is_some() {
882 merged.alias = s.alias.clone();
883 }
884 if s.tags.is_some() {
885 merged.tags = s.tags.clone();
886 }
887 if s.ret_object_mapper.is_some() {
888 merged.ret_object_mapper = s.ret_object_mapper.clone();
889 }
890 Some(merged)
891 }
892 }
893 }
894
895 pub async fn enable_auto_connect(&self) {
897 *self.auto_connect.write().await = true;
898 }
899
900 pub async fn disable_auto_connect(&self) {
902 *self.auto_connect.write().await = false;
903 }
904
905 pub async fn enable_auto_reconnect(&self) {
907 *self.auto_reconnect.write().await = true;
908 }
909
910 pub async fn disable_auto_reconnect(&self) {
912 *self.auto_reconnect.write().await = false;
913 }
914
915 pub async fn set_health_check_config(&self, config: HealthCheckConfig) {
917 *self.health_check_config.write().await = config;
918 }
919
920 pub async fn get_health_check_config(&self) -> HealthCheckConfig {
922 self.health_check_config.read().await.clone()
923 }
924
925 pub async fn set_reconnect_policy(&self, policy: ReconnectPolicy) {
927 *self.reconnect_policy.write().await = policy;
928 }
929
930 pub async fn get_reconnect_policy(&self) -> ReconnectPolicy {
932 self.reconnect_policy.read().await.clone()
933 }
934
935 pub async fn start_health_monitor(&self) {
939 self.stop_health_monitor().await;
941
942 let health_config = self.health_check_config.clone();
943 let reconnect_policy = self.reconnect_policy.clone();
944 let active_clients = self.active_clients.clone();
945 let _servers_config = self.servers_config.clone();
946 let retry_counts = self.retry_counts.clone();
947 let auto_reconnect = self.auto_reconnect.clone();
948
949 let handle = tokio::spawn(async move {
950 loop {
951 let config = health_config.read().await.clone();
952 if !config.enabled {
953 tokio::time::sleep(std::time::Duration::from_secs(10)).await;
956 continue;
957 }
958
959 let clients: Vec<(String, StdArc<dyn MCPClientProtocol>)> = {
961 let clients_guard = active_clients.read().await;
962 clients_guard
963 .iter()
964 .map(|(k, v)| (k.clone(), v.clone()))
965 .collect()
966 };
967
968 for (server_name, client) in clients {
970 let check_result = tokio::time::timeout(
971 std::time::Duration::from_secs(config.timeout_secs),
972 client.health_check(),
973 )
974 .await;
975
976 let is_healthy = match check_result {
977 Ok(result) => result.is_healthy,
978 Err(_) => {
979 warn!("Health check timed out for server: {}", server_name);
980 false
981 }
982 };
983
984 if !is_healthy {
985 warn!("Server {} is unhealthy", server_name);
986
987 let should_reconnect = *auto_reconnect.read().await;
989 if !should_reconnect {
990 continue;
991 }
992
993 let policy = reconnect_policy.read().await.clone();
994 let mut retries = retry_counts.write().await;
995 let retry_count = retries.entry(server_name.clone()).or_insert(0);
996
997 if policy.should_retry(*retry_count) {
998 let delay = policy.calculate_delay(*retry_count);
999 info!(
1000 "Attempting to reconnect {} (retry {}/{}), delay {:?}",
1001 server_name,
1002 *retry_count + 1,
1003 if policy.max_retries == 0 {
1004 "∞".to_string()
1005 } else {
1006 policy.max_retries.to_string()
1007 },
1008 delay
1009 );
1010
1011 tokio::time::sleep(delay).await;
1012
1013 if let Err(e) = client.disconnect().await {
1015 warn!("Failed to disconnect {}: {}", server_name, e);
1016 }
1017
1018 match client.connect().await {
1019 Ok(_) => {
1020 info!("Successfully reconnected to {}", server_name);
1021 *retry_count = 0;
1023 }
1024 Err(e) => {
1025 error!("Failed to reconnect to {}: {}", server_name, e);
1026 *retry_count += 1;
1027 }
1028 }
1029 } else {
1030 error!(
1031 "Max retries ({}) reached for server {}. Giving up.",
1032 policy.max_retries, server_name
1033 );
1034 }
1036 } else {
1037 let mut retries = retry_counts.write().await;
1039 retries.remove(&server_name);
1040 debug!("Server {} is healthy", server_name);
1041 }
1042 }
1043
1044 tokio::time::sleep(std::time::Duration::from_secs(config.interval_secs)).await;
1046 }
1047 });
1048
1049 *self.health_monitor_handle.write().await = Some(handle);
1050 info!("Health monitor started");
1051 }
1052
1053 pub async fn stop_health_monitor(&self) {
1055 if let Some(handle) = self.health_monitor_handle.write().await.take() {
1056 handle.abort();
1057 info!("Health monitor stopped");
1058 }
1059 }
1060
1061 pub async fn check_server_health(&self, server_name: &str) -> Option<HealthCheckResult> {
1063 let clients = self.active_clients.read().await;
1064 if let Some(client) = clients.get(server_name) {
1065 let config = self.health_check_config.read().await;
1066 let result = tokio::time::timeout(
1067 std::time::Duration::from_secs(config.timeout_secs),
1068 client.health_check(),
1069 )
1070 .await;
1071
1072 match result {
1073 Ok(health_result) => Some(health_result),
1074 Err(_) => Some(HealthCheckResult {
1075 is_healthy: false,
1076 checked_at: std::time::Instant::now(),
1077 error: Some("Health check timed out".to_string()),
1078 response_time_ms: None,
1079 }),
1080 }
1081 } else {
1082 None
1083 }
1084 }
1085
1086 pub async fn check_all_health(&self) -> HashMap<String, HealthCheckResult> {
1088 let mut results = HashMap::new();
1089 let clients: Vec<(String, StdArc<dyn MCPClientProtocol>)> = {
1090 let clients_guard = self.active_clients.read().await;
1091 clients_guard
1092 .iter()
1093 .map(|(k, v)| (k.clone(), v.clone()))
1094 .collect()
1095 };
1096
1097 let config = self.health_check_config.read().await.clone();
1098
1099 for (server_name, client) in clients {
1100 let result = tokio::time::timeout(
1101 std::time::Duration::from_secs(config.timeout_secs),
1102 client.health_check(),
1103 )
1104 .await;
1105
1106 let health_result = match result {
1107 Ok(hr) => hr,
1108 Err(_) => HealthCheckResult {
1109 is_healthy: false,
1110 checked_at: std::time::Instant::now(),
1111 error: Some("Health check timed out".to_string()),
1112 response_time_ms: None,
1113 },
1114 };
1115
1116 results.insert(server_name, health_result);
1117 }
1118
1119 results
1120 }
1121
1122 pub async fn get_retry_counts(&self) -> HashMap<String, u32> {
1124 self.retry_counts.read().await.clone()
1125 }
1126
1127 pub async fn reset_retry_count(&self, server_name: &str) {
1129 self.retry_counts.write().await.remove(server_name);
1130 }
1131
1132 pub async fn reset_all_retry_counts(&self) {
1134 self.retry_counts.write().await.clear();
1135 }
1136}
1137
1138impl Default for MCPServerManager {
1139 fn default() -> Self {
1140 Self::new()
1141 }
1142}
1143
1144#[cfg(test)]
1145mod tests {
1146 use super::*;
1147 use std::collections::HashMap;
1148 use tokio::time::{sleep, Duration};
1149
1150 #[tokio::test]
1151 async fn test_manager_creation() {
1152 let manager = MCPServerManager::new();
1153 let status = manager.get_server_status().await;
1154 assert!(status.is_empty());
1155 }
1156
1157 #[tokio::test]
1158 async fn test_manager_initialization() {
1159 let manager = MCPServerManager::new();
1160
1161 let configs = vec![
1163 MCPServerConfig::Stdio(StdioServerConfig {
1165 name: "test_stdio".to_string(),
1166 disabled: false,
1167 forbidden_tools: vec![],
1168 tool_meta: HashMap::new(),
1169 default_tool_meta: None,
1170 vrl: None,
1171 server_parameters: StdioServerParameters {
1172 command: "echo".to_string(),
1173 args: vec!["hello".to_string()],
1174 env: HashMap::new(),
1175 cwd: None,
1176 },
1177 }),
1178 MCPServerConfig::Http(HttpServerConfig {
1180 name: "test_http".to_string(),
1181 disabled: true, forbidden_tools: vec![],
1183 tool_meta: HashMap::new(),
1184 default_tool_meta: None,
1185 vrl: None,
1186 server_parameters: HttpServerParameters {
1187 url: "http://localhost:8080".to_string(),
1188 headers: HashMap::new(),
1189 },
1190 }),
1191 ];
1192
1193 let result = manager.initialize(configs).await;
1195 assert!(result.is_ok());
1196
1197 let status = manager.get_server_status().await;
1199 assert_eq!(status.len(), 2);
1200
1201 let stdio_status = status
1203 .iter()
1204 .find(|(name, _, _)| name == "test_stdio")
1205 .unwrap();
1206 assert!(!stdio_status.1); let http_status = status
1209 .iter()
1210 .find(|(name, _, _)| name == "test_http")
1211 .unwrap();
1212 assert!(!http_status.1); }
1214
1215 #[tokio::test]
1216 async fn test_add_server() {
1217 let manager = MCPServerManager::new();
1218
1219 let config = MCPServerConfig::Stdio(StdioServerConfig {
1221 name: "test_server".to_string(),
1222 disabled: false,
1223 forbidden_tools: vec![],
1224 tool_meta: HashMap::new(),
1225 default_tool_meta: None,
1226 vrl: None,
1227 server_parameters: StdioServerParameters {
1228 command: "echo".to_string(),
1229 args: vec![],
1230 env: HashMap::new(),
1231 cwd: None,
1232 },
1233 });
1234
1235 let result = manager.add_or_update_server(config).await;
1236 assert!(result.is_ok());
1237
1238 let status = manager.get_server_status().await;
1240 assert_eq!(status.len(), 1);
1241 assert_eq!(status[0].0, "test_server");
1242 }
1243
1244 #[tokio::test]
1245 async fn test_remove_server() {
1246 let manager = MCPServerManager::new();
1247
1248 let config = MCPServerConfig::Stdio(StdioServerConfig {
1250 name: "test_server".to_string(),
1251 disabled: false,
1252 forbidden_tools: vec![],
1253 tool_meta: HashMap::new(),
1254 default_tool_meta: None,
1255 vrl: None,
1256 server_parameters: StdioServerParameters {
1257 command: "echo".to_string(),
1258 args: vec![],
1259 env: HashMap::new(),
1260 cwd: None,
1261 },
1262 });
1263
1264 manager.add_or_update_server(config).await.unwrap();
1265
1266 let result = manager.remove_server("test_server").await;
1268 assert!(result.is_ok());
1269
1270 let status = manager.get_server_status().await;
1272 assert!(status.is_empty());
1273 }
1274
1275 #[tokio::test]
1276 async fn test_tool_conflict_detection() {
1277 let manager = MCPServerManager::new();
1278
1279 let configs = vec![
1281 MCPServerConfig::Stdio(StdioServerConfig {
1283 name: "server1".to_string(),
1284 disabled: false,
1285 forbidden_tools: vec![],
1286 tool_meta: HashMap::new(),
1287 default_tool_meta: None,
1288 vrl: None,
1289 server_parameters: StdioServerParameters {
1290 command: "echo".to_string(),
1291 args: vec!["server1".to_string()],
1292 env: HashMap::new(),
1293 cwd: None,
1294 },
1295 }),
1296 MCPServerConfig::Stdio(StdioServerConfig {
1298 name: "server2".to_string(),
1299 disabled: false,
1300 forbidden_tools: vec![],
1301 tool_meta: HashMap::new(),
1302 default_tool_meta: None,
1303 vrl: None,
1304 server_parameters: StdioServerParameters {
1305 command: "echo".to_string(),
1306 args: vec!["server2".to_string()],
1307 env: HashMap::new(),
1308 cwd: None,
1309 },
1310 }),
1311 ];
1312
1313 let result = manager.initialize(configs).await;
1315 assert!(result.is_ok());
1316
1317 let _result = manager.start_all().await;
1319 sleep(Duration::from_millis(200)).await;
1324 }
1325
1326 #[tokio::test]
1327 async fn test_health_check_config() {
1328 let manager = MCPServerManager::new();
1329
1330 let config = manager.get_health_check_config().await;
1332 assert_eq!(config.interval_secs, 30);
1333 assert_eq!(config.timeout_secs, 5);
1334 assert!(config.enabled);
1335
1336 let new_config = HealthCheckConfig {
1338 interval_secs: 60,
1339 timeout_secs: 10,
1340 enabled: false,
1341 };
1342 manager.set_health_check_config(new_config.clone()).await;
1343
1344 let updated = manager.get_health_check_config().await;
1345 assert_eq!(updated.interval_secs, 60);
1346 assert_eq!(updated.timeout_secs, 10);
1347 assert!(!updated.enabled);
1348 }
1349
1350 #[tokio::test]
1351 async fn test_reconnect_policy() {
1352 let manager = MCPServerManager::new();
1353
1354 let policy = manager.get_reconnect_policy().await;
1356 assert!(policy.enabled);
1357 assert_eq!(policy.max_retries, 5);
1358 assert_eq!(policy.initial_delay_ms, 1000);
1359 assert_eq!(policy.max_delay_ms, 30000);
1360 assert_eq!(policy.backoff_factor, 2.0);
1361
1362 assert_eq!(policy.calculate_delay(0).as_millis(), 1000);
1364 assert_eq!(policy.calculate_delay(1).as_millis(), 2000);
1365 assert_eq!(policy.calculate_delay(2).as_millis(), 4000);
1366 assert_eq!(policy.calculate_delay(3).as_millis(), 8000);
1367
1368 assert!(policy.should_retry(0));
1370 assert!(policy.should_retry(4));
1371 assert!(!policy.should_retry(5)); let infinite_policy = ReconnectPolicy {
1375 enabled: true,
1376 max_retries: 0,
1377 ..Default::default()
1378 };
1379 assert!(infinite_policy.should_retry(100));
1380 }
1381
1382 #[tokio::test]
1383 async fn test_retry_counts() {
1384 let manager = MCPServerManager::new();
1385
1386 let counts = manager.get_retry_counts().await;
1388 assert!(counts.is_empty());
1389
1390 {
1392 manager
1393 .retry_counts
1394 .write()
1395 .await
1396 .insert("server1".to_string(), 3);
1397 manager
1398 .retry_counts
1399 .write()
1400 .await
1401 .insert("server2".to_string(), 5);
1402 }
1403
1404 let counts = manager.get_retry_counts().await;
1405 assert_eq!(counts.get("server1"), Some(&3));
1406 assert_eq!(counts.get("server2"), Some(&5));
1407
1408 manager.reset_retry_count("server1").await;
1410 let counts = manager.get_retry_counts().await;
1411 assert!(!counts.contains_key("server1"));
1412 assert_eq!(counts.get("server2"), Some(&5));
1413
1414 manager.reset_all_retry_counts().await;
1416 let counts = manager.get_retry_counts().await;
1417 assert!(counts.is_empty());
1418 }
1419
1420 #[tokio::test]
1421 async fn test_manager_with_custom_config() {
1422 let health_config = HealthCheckConfig {
1423 interval_secs: 15,
1424 timeout_secs: 3,
1425 enabled: true,
1426 };
1427 let reconnect_policy = ReconnectPolicy {
1428 enabled: true,
1429 max_retries: 10,
1430 initial_delay_ms: 500,
1431 max_delay_ms: 60000,
1432 backoff_factor: 1.5,
1433 };
1434
1435 let manager =
1436 MCPServerManager::with_config(health_config.clone(), reconnect_policy.clone());
1437
1438 let got_health = manager.get_health_check_config().await;
1439 assert_eq!(got_health.interval_secs, 15);
1440 assert_eq!(got_health.timeout_secs, 3);
1441
1442 let got_reconnect = manager.get_reconnect_policy().await;
1443 assert_eq!(got_reconnect.max_retries, 10);
1444 assert_eq!(got_reconnect.initial_delay_ms, 500);
1445 }
1446
1447 #[tokio::test]
1448 async fn test_merged_tool_meta() {
1449 let manager = MCPServerManager::new();
1450
1451 let config = MCPServerConfig::Stdio(StdioServerConfig {
1453 name: "s".to_string(),
1454 disabled: false,
1455 forbidden_tools: vec![],
1456 tool_meta: HashMap::from([(
1457 "tool_a".to_string(),
1458 ToolMeta {
1459 auto_apply: Some(true),
1460 alias: None,
1461 tags: Some(vec!["tag1".to_string()]),
1462 ret_object_mapper: None,
1463 },
1464 )]),
1465 default_tool_meta: None,
1466 vrl: None,
1467 server_parameters: StdioServerParameters {
1468 command: "echo".to_string(),
1469 args: vec![],
1470 env: HashMap::new(),
1471 cwd: None,
1472 },
1473 });
1474 let meta = manager.merged_tool_meta(&config, "tool_a").unwrap();
1475 assert_eq!(meta.auto_apply, Some(true));
1476 assert_eq!(meta.tags, Some(vec!["tag1".to_string()]));
1477
1478 let config = MCPServerConfig::Stdio(StdioServerConfig {
1480 name: "s".to_string(),
1481 disabled: false,
1482 forbidden_tools: vec![],
1483 tool_meta: HashMap::new(),
1484 default_tool_meta: Some(ToolMeta {
1485 auto_apply: Some(false),
1486 alias: None,
1487 tags: Some(vec!["default_tag".to_string()]),
1488 ret_object_mapper: None,
1489 }),
1490 vrl: None,
1491 server_parameters: StdioServerParameters {
1492 command: "echo".to_string(),
1493 args: vec![],
1494 env: HashMap::new(),
1495 cwd: None,
1496 },
1497 });
1498 let meta = manager.merged_tool_meta(&config, "any_tool").unwrap();
1499 assert_eq!(meta.auto_apply, Some(false));
1500 assert_eq!(meta.tags, Some(vec!["default_tag".to_string()]));
1501
1502 let config = MCPServerConfig::Stdio(StdioServerConfig {
1504 name: "s".to_string(),
1505 disabled: false,
1506 forbidden_tools: vec![],
1507 tool_meta: HashMap::from([(
1508 "tool_a".to_string(),
1509 ToolMeta {
1510 auto_apply: Some(true),
1511 alias: None,
1512 tags: None,
1513 ret_object_mapper: None,
1514 },
1515 )]),
1516 default_tool_meta: Some(ToolMeta {
1517 auto_apply: Some(false),
1518 alias: Some("default_alias".to_string()),
1519 tags: Some(vec!["default_tag".to_string()]),
1520 ret_object_mapper: None,
1521 }),
1522 vrl: None,
1523 server_parameters: StdioServerParameters {
1524 command: "echo".to_string(),
1525 args: vec![],
1526 env: HashMap::new(),
1527 cwd: None,
1528 },
1529 });
1530 let meta = manager.merged_tool_meta(&config, "tool_a").unwrap();
1531 assert_eq!(meta.auto_apply, Some(true)); assert_eq!(meta.alias, Some("default_alias".to_string())); assert_eq!(meta.tags, Some(vec!["default_tag".to_string()])); let config = MCPServerConfig::Stdio(StdioServerConfig {
1537 name: "s".to_string(),
1538 disabled: false,
1539 forbidden_tools: vec![],
1540 tool_meta: HashMap::new(),
1541 default_tool_meta: None,
1542 vrl: None,
1543 server_parameters: StdioServerParameters {
1544 command: "echo".to_string(),
1545 args: vec![],
1546 env: HashMap::new(),
1547 cwd: None,
1548 },
1549 });
1550 assert!(manager.merged_tool_meta(&config, "tool_a").is_none());
1551 }
1552
1553 #[tokio::test]
1554 async fn test_list_all_windows_empty_manager() {
1555 let manager = MCPServerManager::new();
1556 let windows = manager.list_all_windows(None).await;
1557 assert!(windows.is_empty());
1558 }
1559
1560 #[tokio::test]
1561 async fn test_get_window_detail_server_not_connected() {
1562 use super::make_resource;
1563 let manager = MCPServerManager::new();
1564 let resource = make_resource(
1565 "window://test/status",
1566 "Test",
1567 None,
1568 Some("text/plain".into()),
1569 );
1570 let result = manager.get_window_detail("unknown_server", resource).await;
1571 assert!(result.is_err());
1572 match result {
1573 Err(ComputerError::InvalidState(msg)) => {
1574 assert!(msg.contains("not connected"));
1575 }
1576 other => panic!("Expected InvalidState, got {:?}", other),
1577 }
1578 }
1579}