1use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::fmt;
13use thiserror::Error;
14
15pub use rmcp::model::{
17 Annotated, CallToolResult, Content, RawContent, RawResource, RawTextContent,
18 ReadResourceResult, Resource, ResourceContents, Tool, ToolAnnotations,
19};
20
/// Metadata key named `a2c_tool_meta`.
///
/// NOTE(review): presumably the key under which a serialized [`ToolMeta`] is attached to a
/// tool's metadata — confirm against the code that reads/writes it.
pub const A2C_TOOL_META: &str = "a2c_tool_meta";

/// Metadata key named `a2c_vrl_transformed`.
///
/// NOTE(review): presumably marks results that were rewritten by a server's `vrl` program —
/// confirm against the code that sets it.
pub const A2C_VRL_TRANSFORMED: &str = "a2c_vrl_transformed";

/// Name identifying a configured MCP server.
pub type ServerName = String;

/// Name identifying a tool exposed by an MCP server.
pub type ToolName = String;
28
29#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
31pub struct ToolMeta {
32 #[serde(skip_serializing_if = "Option::is_none")]
34 pub auto_apply: Option<bool>,
35 #[serde(skip_serializing_if = "Option::is_none")]
37 pub alias: Option<String>,
38 #[serde(skip_serializing_if = "Option::is_none")]
40 pub tags: Option<Vec<String>>,
41 #[serde(skip_serializing_if = "Option::is_none")]
43 pub ret_object_mapper: Option<HashMap<String, String>>,
44}
45
46impl ToolMeta {
47 pub fn new() -> Self {
49 Self {
50 auto_apply: None,
51 alias: None,
52 tags: None,
53 ret_object_mapper: None,
54 }
55 }
56}
57
58impl Default for ToolMeta {
59 fn default() -> Self {
60 Self::new()
61 }
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
66#[serde(tag = "type")]
67pub enum MCPServerConfig {
68 #[serde(alias = "stdio", alias = "STDIO")]
70 Stdio(StdioServerConfig),
71 #[serde(alias = "sse", alias = "SSE")]
73 Sse(SseServerConfig),
74 #[serde(alias = "http", alias = "HTTP")]
76 Http(HttpServerConfig),
77}
78
79impl MCPServerConfig {
80 pub fn name(&self) -> &str {
82 match self {
83 MCPServerConfig::Stdio(config) => &config.name,
84 MCPServerConfig::Sse(config) => &config.name,
85 MCPServerConfig::Http(config) => &config.name,
86 }
87 }
88
89 pub fn disabled(&self) -> bool {
91 match self {
92 MCPServerConfig::Stdio(config) => config.disabled,
93 MCPServerConfig::Sse(config) => config.disabled,
94 MCPServerConfig::Http(config) => config.disabled,
95 }
96 }
97
98 pub fn forbidden_tools(&self) -> &[String] {
100 match self {
101 MCPServerConfig::Stdio(config) => &config.forbidden_tools,
102 MCPServerConfig::Sse(config) => &config.forbidden_tools,
103 MCPServerConfig::Http(config) => &config.forbidden_tools,
104 }
105 }
106
107 pub fn tool_meta(&self) -> &HashMap<ToolName, ToolMeta> {
109 match self {
110 MCPServerConfig::Stdio(config) => &config.tool_meta,
111 MCPServerConfig::Sse(config) => &config.tool_meta,
112 MCPServerConfig::Http(config) => &config.tool_meta,
113 }
114 }
115
116 pub fn default_tool_meta(&self) -> Option<&ToolMeta> {
118 match self {
119 MCPServerConfig::Stdio(config) => config.default_tool_meta.as_ref(),
120 MCPServerConfig::Sse(config) => config.default_tool_meta.as_ref(),
121 MCPServerConfig::Http(config) => config.default_tool_meta.as_ref(),
122 }
123 }
124
125 pub fn vrl(&self) -> Option<&str> {
127 match self {
128 MCPServerConfig::Stdio(config) => config.vrl.as_deref(),
129 MCPServerConfig::Sse(config) => config.vrl.as_deref(),
130 MCPServerConfig::Http(config) => config.vrl.as_deref(),
131 }
132 }
133}
134
135#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
137pub struct StdioServerConfig {
138 pub name: ServerName,
140 #[serde(default)]
142 pub disabled: bool,
143 #[serde(default)]
145 pub forbidden_tools: Vec<ToolName>,
146 #[serde(default)]
148 pub tool_meta: HashMap<ToolName, ToolMeta>,
149 pub default_tool_meta: Option<ToolMeta>,
151 #[serde(skip_serializing_if = "Option::is_none")]
153 pub vrl: Option<String>,
154 pub server_parameters: StdioServerParameters,
156}
157
158#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
160pub struct SseServerConfig {
161 pub name: ServerName,
163 #[serde(default)]
165 pub disabled: bool,
166 #[serde(default)]
168 pub forbidden_tools: Vec<ToolName>,
169 #[serde(default)]
171 pub tool_meta: HashMap<ToolName, ToolMeta>,
172 pub default_tool_meta: Option<ToolMeta>,
174 #[serde(skip_serializing_if = "Option::is_none")]
176 pub vrl: Option<String>,
177 pub server_parameters: SseServerParameters,
179}
180
181#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
183pub struct HttpServerConfig {
184 pub name: ServerName,
186 #[serde(default)]
188 pub disabled: bool,
189 #[serde(default)]
191 pub forbidden_tools: Vec<ToolName>,
192 #[serde(default)]
194 pub tool_meta: HashMap<ToolName, ToolMeta>,
195 pub default_tool_meta: Option<ToolMeta>,
197 #[serde(skip_serializing_if = "Option::is_none")]
199 pub vrl: Option<String>,
200 pub server_parameters: HttpServerParameters,
202}
203
204fn null_to_empty_map<'de, D>(deserializer: D) -> Result<HashMap<String, String>, D::Error>
205where
206 D: serde::Deserializer<'de>,
207{
208 let opt = Option::<HashMap<String, String>>::deserialize(deserializer)?;
209 Ok(opt.unwrap_or_default())
210}
211
212#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
214pub struct StdioServerParameters {
215 pub command: String,
217 #[serde(default)]
219 pub args: Vec<String>,
220 #[serde(default, deserialize_with = "null_to_empty_map")]
222 pub env: HashMap<String, String>,
223 #[serde(skip_serializing_if = "Option::is_none")]
225 pub cwd: Option<String>,
226}
227
228#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
230pub struct SseServerParameters {
231 pub url: String,
233 #[serde(default)]
235 pub headers: HashMap<String, String>,
236}
237
238#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
240pub struct HttpServerParameters {
241 pub url: String,
243 #[serde(default)]
245 pub headers: HashMap<String, String>,
246}
247
248#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
250#[serde(tag = "type")]
251pub enum MCPServerInput {
252 PromptString(PromptStringInput),
254 PickString(PickStringInput),
256 Command(CommandInput),
258}
259
260impl MCPServerInput {
261 pub fn id(&self) -> &str {
263 match self {
264 MCPServerInput::PromptString(input) => &input.id,
265 MCPServerInput::PickString(input) => &input.id,
266 MCPServerInput::Command(input) => &input.id,
267 }
268 }
269
270 pub fn description(&self) -> &str {
272 match self {
273 MCPServerInput::PromptString(input) => &input.description,
274 MCPServerInput::PickString(input) => &input.description,
275 MCPServerInput::Command(input) => &input.description,
276 }
277 }
278
279 pub fn default(&self) -> Option<serde_json::Value> {
281 match self {
282 MCPServerInput::PromptString(input) => input
283 .default
284 .as_ref()
285 .map(|s| serde_json::Value::String(s.clone())),
286 MCPServerInput::PickString(input) => input
287 .default
288 .as_ref()
289 .map(|s| serde_json::Value::String(s.clone())),
290 MCPServerInput::Command(_input) => {
291 None
294 }
295 }
296 }
297}
298
299#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
301pub struct PromptStringInput {
302 pub id: String,
304 pub description: String,
306 #[serde(skip_serializing_if = "Option::is_none")]
308 pub default: Option<String>,
309 #[serde(skip_serializing_if = "Option::is_none")]
311 pub password: Option<bool>,
312}
313
314#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
316pub struct PickStringInput {
317 pub id: String,
319 pub description: String,
321 #[serde(default)]
323 pub options: Vec<String>,
324 #[serde(skip_serializing_if = "Option::is_none")]
326 pub default: Option<String>,
327}
328
329#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
331pub struct CommandInput {
332 pub id: String,
334 pub description: String,
336 pub command: String,
338 #[serde(skip_serializing_if = "Option::is_none")]
340 pub args: Option<HashMap<String, String>>,
341}
342
343#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
345pub struct HealthCheckConfig {
346 #[serde(default = "default_health_check_interval")]
348 pub interval_secs: u64,
349 #[serde(default = "default_health_check_timeout")]
351 pub timeout_secs: u64,
352 #[serde(default = "default_health_check_enabled")]
354 pub enabled: bool,
355}
356
/// Serde default for `HealthCheckConfig::interval_secs`: 30 seconds.
fn default_health_check_interval() -> u64 {
    30_u64
}

/// Serde default for `HealthCheckConfig::timeout_secs`: 5 seconds.
fn default_health_check_timeout() -> u64 {
    5_u64
}

/// Serde default for `HealthCheckConfig::enabled`: enabled.
fn default_health_check_enabled() -> bool {
    !false
}
368
369impl Default for HealthCheckConfig {
370 fn default() -> Self {
371 Self {
372 interval_secs: 30,
373 timeout_secs: 5,
374 enabled: true,
375 }
376 }
377}
378
379#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
381pub struct ReconnectPolicy {
382 #[serde(default = "default_reconnect_enabled")]
384 pub enabled: bool,
385 #[serde(default = "default_max_retries")]
387 pub max_retries: u32,
388 #[serde(default = "default_initial_delay_ms")]
390 pub initial_delay_ms: u64,
391 #[serde(default = "default_max_delay_ms")]
393 pub max_delay_ms: u64,
394 #[serde(default = "default_backoff_factor")]
396 pub backoff_factor: f64,
397}
398
/// Serde default for `ReconnectPolicy::enabled`.
fn default_reconnect_enabled() -> bool {
    !false
}

/// Serde default for `ReconnectPolicy::max_retries`.
fn default_max_retries() -> u32 {
    5_u32
}

/// Serde default for `ReconnectPolicy::initial_delay_ms` (one second).
fn default_initial_delay_ms() -> u64 {
    1_000_u64
}

/// Serde default for `ReconnectPolicy::max_delay_ms` (thirty seconds).
fn default_max_delay_ms() -> u64 {
    30_000_u64
}

/// Serde default for `ReconnectPolicy::backoff_factor` (doubling).
fn default_backoff_factor() -> f64 {
    2.0_f64
}
418
419impl Default for ReconnectPolicy {
420 fn default() -> Self {
421 Self {
422 enabled: true,
423 max_retries: 5,
424 initial_delay_ms: 1000,
425 max_delay_ms: 30000,
426 backoff_factor: 2.0,
427 }
428 }
429}
430
431impl ReconnectPolicy {
432 pub fn calculate_delay(&self, retry_count: u32) -> std::time::Duration {
434 let delay_ms = (self.initial_delay_ms as f64 * self.backoff_factor.powi(retry_count as i32))
435 .min(self.max_delay_ms as f64) as u64;
436 std::time::Duration::from_millis(delay_ms)
437 }
438
439 pub fn should_retry(&self, retry_count: u32) -> bool {
441 self.enabled && (self.max_retries == 0 || retry_count < self.max_retries)
442 }
443}
444
/// Outcome of a single health check.
#[derive(Debug, Clone)]
pub struct HealthCheckResult {
    /// `true` when the probe succeeded.
    pub is_healthy: bool,
    /// Timestamp recorded for the check (the default `health_check` records its start time).
    pub checked_at: std::time::Instant,
    /// Failure description, `None` on success.
    pub error: Option<String>,
    /// Probe round-trip time in milliseconds (the default `health_check` sets it only on
    /// success).
    pub response_time_ms: Option<u64>,
}
457
458#[async_trait::async_trait]
460pub trait MCPClientProtocol: Send + Sync {
461 fn state(&self) -> ClientState;
463
464 async fn connect(&self) -> Result<(), MCPClientError>;
466
467 async fn disconnect(&self) -> Result<(), MCPClientError>;
469
470 async fn list_tools(&self) -> Result<Vec<Tool>, MCPClientError>;
472
473 async fn call_tool(
475 &self,
476 tool_name: &str,
477 params: serde_json::Value,
478 ) -> Result<CallToolResult, MCPClientError>;
479
480 async fn list_windows(&self) -> Result<Vec<Resource>, MCPClientError>;
482
483 async fn get_window_detail(
485 &self,
486 resource: Resource,
487 ) -> Result<ReadResourceResult, MCPClientError>;
488
489 async fn subscribe_window(&self, resource: Resource) -> Result<(), MCPClientError>;
491
492 async fn unsubscribe_window(&self, resource: Resource) -> Result<(), MCPClientError>;
494
495 async fn health_check(&self) -> HealthCheckResult {
499 let start = std::time::Instant::now();
500
501 if self.state() != ClientState::Connected {
503 return HealthCheckResult {
504 is_healthy: false,
505 checked_at: start,
506 error: Some(format!("Client state is {:?}, not Connected", self.state())),
507 response_time_ms: None,
508 };
509 }
510
511 match tokio::time::timeout(std::time::Duration::from_secs(5), self.list_tools()).await {
513 Ok(Ok(_)) => {
514 let elapsed = start.elapsed();
515 HealthCheckResult {
516 is_healthy: true,
517 checked_at: start,
518 error: None,
519 response_time_ms: Some(elapsed.as_millis() as u64),
520 }
521 }
522 Ok(Err(e)) => HealthCheckResult {
523 is_healthy: false,
524 checked_at: start,
525 error: Some(format!("Health check failed: {}", e)),
526 response_time_ms: None,
527 },
528 Err(_) => HealthCheckResult {
529 is_healthy: false,
530 checked_at: start,
531 error: Some("Health check timed out".to_string()),
532 response_time_ms: None,
533 },
534 }
535 }
536}
537
/// Connection state reported by `MCPClientProtocol::state`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ClientState {
    /// Created but not connected (presumably the state before the first `connect`).
    Initialized,
    /// Connected; the only state in which the default health check probes the server.
    Connected,
    /// Not connected (presumably after `disconnect` or a dropped connection).
    Disconnected,
    /// Error state.
    Error,
}

/// Writes the state as a single lowercase word, e.g. `connected`.
impl fmt::Display for ClientState {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            Self::Initialized => "initialized",
            Self::Connected => "connected",
            Self::Disconnected => "disconnected",
            Self::Error => "error",
        };
        f.write_str(label)
    }
}
561
562#[derive(Debug, Error)]
564pub enum MCPClientError {
565 #[error("Connection error: {0}")]
567 ConnectionError(String),
568 #[error("Protocol error: {0}")]
570 ProtocolError(String),
571 #[error("IO error: {0}")]
573 IoError(#[from] std::io::Error),
574 #[error("JSON error: {0}")]
576 JsonError(#[from] serde_json::Error),
577 #[error("Timeout error: {0}")]
579 TimeoutError(String),
580 #[error("Other error: {0}")]
582 Other(String),
583}
584
585pub fn make_resource(
587 uri: impl Into<String>,
588 name: impl Into<String>,
589 description: Option<String>,
590 mime_type: Option<String>,
591) -> Resource {
592 use rmcp::model::AnnotateAble;
593 let mut raw = RawResource::new(uri, name);
594 raw.description = description;
595 raw.mime_type = mime_type;
596 raw.no_annotation()
597}
598
599pub fn is_call_tool_error(result: &CallToolResult) -> bool {
601 result.is_error.unwrap_or(false)
602}
603
604pub fn content_as_text(content: &Content) -> Option<&str> {
606 content.as_text().map(|t| t.text.as_str())
607}
608
609pub fn resource_contents_as_text(rc: &ResourceContents) -> Option<&str> {
611 match rc {
612 ResourceContents::TextResourceContents { text, .. } => Some(text.as_str()),
613 _ => None,
614 }
615}
616
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_is_call_tool_error() {
        let ok_result = CallToolResult::success(vec![Content::text("ok")]);
        assert!(!is_call_tool_error(&ok_result));

        let err_result = CallToolResult::error(vec![Content::text("fail")]);
        assert!(is_call_tool_error(&err_result));
    }

    #[test]
    fn test_content_as_text() {
        let content = Content::text("hello");
        assert_eq!(content_as_text(&content), Some("hello"));
    }

    #[test]
    fn test_resource_contents_as_text() {
        let rc = ResourceContents::TextResourceContents {
            uri: "test://uri".to_string(),
            mime_type: None,
            text: "some text".to_string(),
            meta: None,
        };
        assert_eq!(resource_contents_as_text(&rc), Some("some text"));

        let blob = ResourceContents::BlobResourceContents {
            uri: "test://uri".to_string(),
            mime_type: None,
            blob: "base64data".to_string(),
            meta: None,
        };
        assert_eq!(resource_contents_as_text(&blob), None);
    }

    #[test]
    fn test_make_resource() {
        let resource = make_resource("window://test", "Test", Some("desc".into()), None);
        assert_eq!(resource.raw.uri, "window://test");
        assert_eq!(resource.raw.name, "Test");
        assert_eq!(resource.raw.description, Some("desc".into()));
        assert!(resource.raw.mime_type.is_none());
    }

    #[test]
    fn test_client_state_display() {
        assert_eq!(ClientState::Initialized.to_string(), "initialized");
        assert_eq!(ClientState::Connected.to_string(), "connected");
        assert_eq!(ClientState::Disconnected.to_string(), "disconnected");
        assert_eq!(ClientState::Error.to_string(), "error");
    }

    #[test]
    fn test_reconnect_policy_calculate_delay() {
        let policy = ReconnectPolicy::default();
        assert_eq!(policy.calculate_delay(0), Duration::from_millis(1000));
        assert_eq!(policy.calculate_delay(1), Duration::from_millis(2000));
        assert_eq!(policy.calculate_delay(4), Duration::from_millis(16000));
        // 1000 * 2^5 = 32000 exceeds the 30000 ms cap.
        assert_eq!(policy.calculate_delay(5), Duration::from_millis(30000));
        assert_eq!(policy.calculate_delay(10), Duration::from_millis(30000));
    }

    #[test]
    fn test_reconnect_policy_should_retry() {
        let policy = ReconnectPolicy::default();
        assert!(policy.should_retry(0));
        assert!(policy.should_retry(4));
        assert!(!policy.should_retry(5));

        let unlimited = ReconnectPolicy {
            max_retries: 0,
            ..ReconnectPolicy::default()
        };
        assert!(unlimited.should_retry(1_000_000));

        let disabled = ReconnectPolicy {
            enabled: false,
            ..ReconnectPolicy::default()
        };
        assert!(!disabled.should_retry(0));
    }

    #[test]
    fn test_serde_defaults_match_default_impls() {
        let reconnect: ReconnectPolicy = serde_json::from_str("{}").unwrap();
        assert_eq!(reconnect, ReconnectPolicy::default());

        let health: HealthCheckConfig = serde_json::from_str("{}").unwrap();
        assert_eq!(health, HealthCheckConfig::default());
    }

    #[test]
    fn test_stdio_parameters_accept_null_env() {
        let params: StdioServerParameters =
            serde_json::from_str(r#"{"command":"run","env":null}"#).unwrap();
        assert_eq!(params.command, "run");
        assert!(params.args.is_empty());
        assert!(params.env.is_empty());
        assert!(params.cwd.is_none());
    }

    #[test]
    fn test_server_input_accessors() {
        let prompt: MCPServerInput = serde_json::from_str(
            r#"{"type":"PromptString","id":"key","description":"API key","default":"v"}"#,
        )
        .unwrap();
        assert_eq!(prompt.id(), "key");
        assert_eq!(prompt.description(), "API key");
        assert_eq!(prompt.default(), Some(serde_json::Value::String("v".into())));

        let command: MCPServerInput = serde_json::from_str(
            r#"{"type":"Command","id":"c","description":"d","command":"echo"}"#,
        )
        .unwrap();
        assert_eq!(command.default(), None);
    }
}