// smcp_computer/mcp_clients/model.rs

/**
 * 文件名: model
 * 作者: JQQ
 * 创建日期: 2025/12/15
 * 最后修改日期: 2025/12/15
 * 版权: 2023 JQQ. All rights reserved.
 * 依赖: serde, serde_json, thiserror, rmcp, async-trait, tokio
 * 描述: MCP客户端相关的数据模型定义
 */
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::fmt;
13use thiserror::Error;
14
15// Re-export MCP protocol types from rmcp
16pub use rmcp::model::{
17    Annotated, CallToolResult, Content, RawContent, RawResource, RawTextContent,
18    ReadResourceResult, Resource, ResourceContents, Tool, ToolAnnotations,
19};
20
/// Identifier string for A2C tool metadata (`"a2c_tool_meta"`).
pub const A2C_TOOL_META: &str = "a2c_tool_meta";
/// Identifier string flagging VRL-transformed output (`"a2c_vrl_transformed"`).
/// NOTE(review): presumably used as a metadata key — confirm at the call sites.
pub const A2C_VRL_TRANSFORMED: &str = "a2c_vrl_transformed";

/// Name of an MCP server (plain `String` alias, used for readability).
pub type ServerName = String;
/// Name of a tool exposed by an MCP server (plain `String` alias, used for readability).
pub type ToolName = String;
28
29/// MCP工具元数据 / MCP tool metadata
30#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
31pub struct ToolMeta {
32    /// 是否自动使用 / Whether to auto-apply
33    #[serde(skip_serializing_if = "Option::is_none")]
34    pub auto_apply: Option<bool>,
35    /// 工具别名 / Tool alias
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub alias: Option<String>,
38    /// 工具标签 / Tool tags
39    #[serde(skip_serializing_if = "Option::is_none")]
40    pub tags: Option<Vec<String>>,
41    /// 返回值字段映射 / Return value field mapping
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub ret_object_mapper: Option<HashMap<String, String>>,
44}
45
46impl ToolMeta {
47    /// 创建空的工具元数据 / Create empty tool metadata
48    pub fn new() -> Self {
49        Self {
50            auto_apply: None,
51            alias: None,
52            tags: None,
53            ret_object_mapper: None,
54        }
55    }
56}
57
58impl Default for ToolMeta {
59    fn default() -> Self {
60        Self::new()
61    }
62}
63
64/// MCP服务器配置基类 / Base MCP server configuration
65#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
66#[serde(tag = "type")]
67pub enum MCPServerConfig {
68    /// STDIO类型服务器 / STDIO type server
69    #[serde(alias = "stdio", alias = "STDIO")]
70    Stdio(StdioServerConfig),
71    /// SSE类型服务器 / SSE type server
72    #[serde(alias = "sse", alias = "SSE")]
73    Sse(SseServerConfig),
74    /// HTTP类型服务器 / HTTP type server
75    #[serde(alias = "http", alias = "HTTP")]
76    Http(HttpServerConfig),
77}
78
79impl MCPServerConfig {
80    /// 获取服务器名称 / Get server name
81    pub fn name(&self) -> &str {
82        match self {
83            MCPServerConfig::Stdio(config) => &config.name,
84            MCPServerConfig::Sse(config) => &config.name,
85            MCPServerConfig::Http(config) => &config.name,
86        }
87    }
88
89    /// 获取是否禁用标志 / Get disabled flag
90    pub fn disabled(&self) -> bool {
91        match self {
92            MCPServerConfig::Stdio(config) => config.disabled,
93            MCPServerConfig::Sse(config) => config.disabled,
94            MCPServerConfig::Http(config) => config.disabled,
95        }
96    }
97
98    /// 获取禁用工具列表 / Get forbidden tools list
99    pub fn forbidden_tools(&self) -> &[String] {
100        match self {
101            MCPServerConfig::Stdio(config) => &config.forbidden_tools,
102            MCPServerConfig::Sse(config) => &config.forbidden_tools,
103            MCPServerConfig::Http(config) => &config.forbidden_tools,
104        }
105    }
106
107    /// 获取工具元数据映射 / Get tool metadata mapping
108    pub fn tool_meta(&self) -> &HashMap<ToolName, ToolMeta> {
109        match self {
110            MCPServerConfig::Stdio(config) => &config.tool_meta,
111            MCPServerConfig::Sse(config) => &config.tool_meta,
112            MCPServerConfig::Http(config) => &config.tool_meta,
113        }
114    }
115
116    /// 获取默认工具元数据 / Get default tool metadata
117    pub fn default_tool_meta(&self) -> Option<&ToolMeta> {
118        match self {
119            MCPServerConfig::Stdio(config) => config.default_tool_meta.as_ref(),
120            MCPServerConfig::Sse(config) => config.default_tool_meta.as_ref(),
121            MCPServerConfig::Http(config) => config.default_tool_meta.as_ref(),
122        }
123    }
124
125    /// 获取VRL脚本 / Get VRL script
126    pub fn vrl(&self) -> Option<&str> {
127        match self {
128            MCPServerConfig::Stdio(config) => config.vrl.as_deref(),
129            MCPServerConfig::Sse(config) => config.vrl.as_deref(),
130            MCPServerConfig::Http(config) => config.vrl.as_deref(),
131        }
132    }
133}
134
135/// STDIO服务器配置 / STDIO server configuration
136#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
137pub struct StdioServerConfig {
138    /// 服务器名称 / Server name
139    pub name: ServerName,
140    /// 是否禁用 / Whether disabled
141    #[serde(default)]
142    pub disabled: bool,
143    /// 禁用工具列表 / Forbidden tools list
144    #[serde(default)]
145    pub forbidden_tools: Vec<ToolName>,
146    /// 工具元数据 / Tool metadata
147    #[serde(default)]
148    pub tool_meta: HashMap<ToolName, ToolMeta>,
149    /// 默认工具元数据 / Default tool metadata
150    pub default_tool_meta: Option<ToolMeta>,
151    /// VRL脚本 / VRL script
152    #[serde(skip_serializing_if = "Option::is_none")]
153    pub vrl: Option<String>,
154    /// STDIO服务器参数 / STDIO server parameters
155    pub server_parameters: StdioServerParameters,
156}
157
158/// SSE服务器配置 / SSE server configuration
159#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
160pub struct SseServerConfig {
161    /// 服务器名称 / Server name
162    pub name: ServerName,
163    /// 是否禁用 / Whether disabled
164    #[serde(default)]
165    pub disabled: bool,
166    /// 禁用工具列表 / Forbidden tools list
167    #[serde(default)]
168    pub forbidden_tools: Vec<ToolName>,
169    /// 工具元数据 / Tool metadata
170    #[serde(default)]
171    pub tool_meta: HashMap<ToolName, ToolMeta>,
172    /// 默认工具元数据 / Default tool metadata
173    pub default_tool_meta: Option<ToolMeta>,
174    /// VRL脚本 / VRL script
175    #[serde(skip_serializing_if = "Option::is_none")]
176    pub vrl: Option<String>,
177    /// SSE服务器参数 / SSE server parameters
178    pub server_parameters: SseServerParameters,
179}
180
181/// HTTP服务器配置 / HTTP server configuration
182#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
183pub struct HttpServerConfig {
184    /// 服务器名称 / Server name
185    pub name: ServerName,
186    /// 是否禁用 / Whether disabled
187    #[serde(default)]
188    pub disabled: bool,
189    /// 禁用工具列表 / Forbidden tools list
190    #[serde(default)]
191    pub forbidden_tools: Vec<ToolName>,
192    /// 工具元数据 / Tool metadata
193    #[serde(default)]
194    pub tool_meta: HashMap<ToolName, ToolMeta>,
195    /// 默认工具元数据 / Default tool metadata
196    pub default_tool_meta: Option<ToolMeta>,
197    /// VRL脚本 / VRL script
198    #[serde(skip_serializing_if = "Option::is_none")]
199    pub vrl: Option<String>,
200    /// HTTP服务器参数 / HTTP server parameters
201    pub server_parameters: HttpServerParameters,
202}
203
204fn null_to_empty_map<'de, D>(deserializer: D) -> Result<HashMap<String, String>, D::Error>
205where
206    D: serde::Deserializer<'de>,
207{
208    let opt = Option::<HashMap<String, String>>::deserialize(deserializer)?;
209    Ok(opt.unwrap_or_default())
210}
211
212/// STDIO服务器参数 / STDIO server parameters
213#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
214pub struct StdioServerParameters {
215    /// 命令 / Command
216    pub command: String,
217    /// 参数 / Arguments
218    #[serde(default)]
219    pub args: Vec<String>,
220    /// 环境变量 / Environment variables
221    #[serde(default, deserialize_with = "null_to_empty_map")]
222    pub env: HashMap<String, String>,
223    /// 工作目录 / Working directory
224    #[serde(skip_serializing_if = "Option::is_none")]
225    pub cwd: Option<String>,
226}
227
228/// SSE服务器参数 / SSE server parameters
229#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
230pub struct SseServerParameters {
231    /// URL / URL
232    pub url: String,
233    /// Headers / Headers
234    #[serde(default)]
235    pub headers: HashMap<String, String>,
236}
237
238/// HTTP服务器参数 / HTTP server parameters
239#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
240pub struct HttpServerParameters {
241    /// URL / URL
242    pub url: String,
243    /// Headers / Headers
244    #[serde(default)]
245    pub headers: HashMap<String, String>,
246}
247
248/// MCP服务器输入项基类 / Base MCP server input configuration
249#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
250#[serde(tag = "type")]
251pub enum MCPServerInput {
252    /// 字符串输入 / String input
253    PromptString(PromptStringInput),
254    /// 选择输入 / Pick string input
255    PickString(PickStringInput),
256    /// 命令输入 / Command input
257    Command(CommandInput),
258}
259
260impl MCPServerInput {
261    /// 获取输入ID / Get input ID
262    pub fn id(&self) -> &str {
263        match self {
264            MCPServerInput::PromptString(input) => &input.id,
265            MCPServerInput::PickString(input) => &input.id,
266            MCPServerInput::Command(input) => &input.id,
267        }
268    }
269
270    /// 获取输入描述 / Get input description
271    pub fn description(&self) -> &str {
272        match self {
273            MCPServerInput::PromptString(input) => &input.description,
274            MCPServerInput::PickString(input) => &input.description,
275            MCPServerInput::Command(input) => &input.description,
276        }
277    }
278
279    /// 获取默认值 / Get default value
280    pub fn default(&self) -> Option<serde_json::Value> {
281        match self {
282            MCPServerInput::PromptString(input) => input
283                .default
284                .as_ref()
285                .map(|s| serde_json::Value::String(s.clone())),
286            MCPServerInput::PickString(input) => input
287                .default
288                .as_ref()
289                .map(|s| serde_json::Value::String(s.clone())),
290            MCPServerInput::Command(_input) => {
291                // Command 类型不支持默认值
292                // Command type doesn't support default values
293                None
294            }
295        }
296    }
297}
298
299/// 字符串输入类型 / String input type
300#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
301pub struct PromptStringInput {
302    /// 输入ID / Input ID
303    pub id: String,
304    /// 描述 / Description
305    pub description: String,
306    /// 默认值 / Default value
307    #[serde(skip_serializing_if = "Option::is_none")]
308    pub default: Option<String>,
309    /// 是否为密码 / Whether password
310    #[serde(skip_serializing_if = "Option::is_none")]
311    pub password: Option<bool>,
312}
313
314/// 选择输入类型 / Pick string input type
315#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
316pub struct PickStringInput {
317    /// 输入ID / Input ID
318    pub id: String,
319    /// 描述 / Description
320    pub description: String,
321    /// 选项 / Options
322    #[serde(default)]
323    pub options: Vec<String>,
324    /// 默认值 / Default value
325    #[serde(skip_serializing_if = "Option::is_none")]
326    pub default: Option<String>,
327}
328
329/// 命令输入类型 / Command input type
330#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
331pub struct CommandInput {
332    /// 输入ID / Input ID
333    pub id: String,
334    /// 描述 / Description
335    pub description: String,
336    /// 命令 / Command
337    pub command: String,
338    /// 参数 / Arguments
339    #[serde(skip_serializing_if = "Option::is_none")]
340    pub args: Option<HashMap<String, String>>,
341}
342
343/// 健康检查配置 / Health check configuration
344#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
345pub struct HealthCheckConfig {
346    /// 健康检查间隔(秒)/ Health check interval in seconds
347    #[serde(default = "default_health_check_interval")]
348    pub interval_secs: u64,
349    /// 超时时间(秒)/ Timeout in seconds
350    #[serde(default = "default_health_check_timeout")]
351    pub timeout_secs: u64,
352    /// 是否启用健康检查 / Whether to enable health check
353    #[serde(default = "default_health_check_enabled")]
354    pub enabled: bool,
355}
356
357fn default_health_check_interval() -> u64 {
358    30
359}
360
361fn default_health_check_timeout() -> u64 {
362    5
363}
364
365fn default_health_check_enabled() -> bool {
366    true
367}
368
369impl Default for HealthCheckConfig {
370    fn default() -> Self {
371        Self {
372            interval_secs: 30,
373            timeout_secs: 5,
374            enabled: true,
375        }
376    }
377}
378
379/// 重连策略 / Reconnect policy
380#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
381pub struct ReconnectPolicy {
382    /// 是否启用自动重连 / Whether to enable auto reconnect
383    #[serde(default = "default_reconnect_enabled")]
384    pub enabled: bool,
385    /// 最大重试次数(0表示无限重试)/ Max retry count (0 means infinite)
386    #[serde(default = "default_max_retries")]
387    pub max_retries: u32,
388    /// 初始延迟时间(毫秒)/ Initial delay in milliseconds
389    #[serde(default = "default_initial_delay_ms")]
390    pub initial_delay_ms: u64,
391    /// 最大延迟时间(毫秒)/ Max delay in milliseconds
392    #[serde(default = "default_max_delay_ms")]
393    pub max_delay_ms: u64,
394    /// 退避因子(延迟时间乘数)/ Backoff factor (delay multiplier)
395    #[serde(default = "default_backoff_factor")]
396    pub backoff_factor: f64,
397}
398
399fn default_reconnect_enabled() -> bool {
400    true
401}
402
403fn default_max_retries() -> u32 {
404    5
405}
406
407fn default_initial_delay_ms() -> u64 {
408    1000
409}
410
411fn default_max_delay_ms() -> u64 {
412    30000
413}
414
415fn default_backoff_factor() -> f64 {
416    2.0
417}
418
419impl Default for ReconnectPolicy {
420    fn default() -> Self {
421        Self {
422            enabled: true,
423            max_retries: 5,
424            initial_delay_ms: 1000,
425            max_delay_ms: 30000,
426            backoff_factor: 2.0,
427        }
428    }
429}
430
431impl ReconnectPolicy {
432    /// 计算下次重试的延迟时间 / Calculate delay for next retry
433    pub fn calculate_delay(&self, retry_count: u32) -> std::time::Duration {
434        let delay_ms = (self.initial_delay_ms as f64 * self.backoff_factor.powi(retry_count as i32))
435            .min(self.max_delay_ms as f64) as u64;
436        std::time::Duration::from_millis(delay_ms)
437    }
438
439    /// 检查是否应该继续重试 / Check if should continue retry
440    pub fn should_retry(&self, retry_count: u32) -> bool {
441        self.enabled && (self.max_retries == 0 || retry_count < self.max_retries)
442    }
443}
444
/// Outcome of a single health check.
#[derive(Debug, Clone)]
pub struct HealthCheckResult {
    /// `true` when the check passed.
    pub is_healthy: bool,
    /// Instant at which the check started.
    pub checked_at: std::time::Instant,
    /// Description of the failure, when there was one.
    pub error: Option<String>,
    /// Round-trip time in milliseconds, when measured.
    pub response_time_ms: Option<u64>,
}
457
458/// MCP客户端协议trait / MCP client protocol trait
459#[async_trait::async_trait]
460pub trait MCPClientProtocol: Send + Sync {
461    /// 获取客户端状态 / Get client state
462    fn state(&self) -> ClientState;
463
464    /// 连接MCP服务器 / Connect to MCP server
465    async fn connect(&self) -> Result<(), MCPClientError>;
466
467    /// 断开连接 / Disconnect
468    async fn disconnect(&self) -> Result<(), MCPClientError>;
469
470    /// 获取可用工具列表 / Get available tools list
471    async fn list_tools(&self) -> Result<Vec<Tool>, MCPClientError>;
472
473    /// 调用工具 / Call tool
474    async fn call_tool(
475        &self,
476        tool_name: &str,
477        params: serde_json::Value,
478    ) -> Result<CallToolResult, MCPClientError>;
479
480    /// 列出窗口资源 / List window resources
481    async fn list_windows(&self) -> Result<Vec<Resource>, MCPClientError>;
482
483    /// 获取窗口详情 / Get window detail
484    async fn get_window_detail(
485        &self,
486        resource: Resource,
487    ) -> Result<ReadResourceResult, MCPClientError>;
488
489    /// 订阅窗口资源更新 / Subscribe to window resource updates
490    async fn subscribe_window(&self, resource: Resource) -> Result<(), MCPClientError>;
491
492    /// 取消订阅窗口资源更新 / Unsubscribe from window resource updates
493    async fn unsubscribe_window(&self, resource: Resource) -> Result<(), MCPClientError>;
494
495    /// 执行健康检查 / Perform health check
496    /// 默认实现通过检查状态和尝试 list_tools 来验证连接
497    /// Default implementation checks state and tries list_tools to verify connection
498    async fn health_check(&self) -> HealthCheckResult {
499        let start = std::time::Instant::now();
500
501        // 首先检查状态 / First check state
502        if self.state() != ClientState::Connected {
503            return HealthCheckResult {
504                is_healthy: false,
505                checked_at: start,
506                error: Some(format!("Client state is {:?}, not Connected", self.state())),
507                response_time_ms: None,
508            };
509        }
510
511        // 尝试调用 list_tools 来验证连接 / Try calling list_tools to verify connection
512        match tokio::time::timeout(std::time::Duration::from_secs(5), self.list_tools()).await {
513            Ok(Ok(_)) => {
514                let elapsed = start.elapsed();
515                HealthCheckResult {
516                    is_healthy: true,
517                    checked_at: start,
518                    error: None,
519                    response_time_ms: Some(elapsed.as_millis() as u64),
520                }
521            }
522            Ok(Err(e)) => HealthCheckResult {
523                is_healthy: false,
524                checked_at: start,
525                error: Some(format!("Health check failed: {}", e)),
526                response_time_ms: None,
527            },
528            Err(_) => HealthCheckResult {
529                is_healthy: false,
530                checked_at: start,
531                error: Some("Health check timed out".to_string()),
532                response_time_ms: None,
533            },
534        }
535    }
536}
537
/// Connection state of an MCP client.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ClientState {
    /// The client has been initialized.
    Initialized,
    /// The client is connected.
    Connected,
    /// The client is disconnected.
    Disconnected,
    /// The client is in an error state.
    Error,
}

impl fmt::Display for ClientState {
    /// Writes the state's lowercase label (e.g. `connected`), ignoring width/fill flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Initialized => "initialized",
            Self::Connected => "connected",
            Self::Disconnected => "disconnected",
            Self::Error => "error",
        };
        f.write_str(label)
    }
}
561
562/// MCP客户端错误 / MCP client error
563#[derive(Debug, Error)]
564pub enum MCPClientError {
565    /// 连接错误 / Connection error
566    #[error("Connection error: {0}")]
567    ConnectionError(String),
568    /// 协议错误 / Protocol error
569    #[error("Protocol error: {0}")]
570    ProtocolError(String),
571    /// IO错误 / IO error
572    #[error("IO error: {0}")]
573    IoError(#[from] std::io::Error),
574    /// JSON错误 / JSON error
575    #[error("JSON error: {0}")]
576    JsonError(#[from] serde_json::Error),
577    /// 超时错误 / Timeout error
578    #[error("Timeout error: {0}")]
579    TimeoutError(String),
580    /// 其他错误 / Other error
581    #[error("Other error: {0}")]
582    Other(String),
583}
584
585/// 便捷函数:创建 Resource / Convenience: create a Resource
586pub fn make_resource(
587    uri: impl Into<String>,
588    name: impl Into<String>,
589    description: Option<String>,
590    mime_type: Option<String>,
591) -> Resource {
592    use rmcp::model::AnnotateAble;
593    let mut raw = RawResource::new(uri, name);
594    raw.description = description;
595    raw.mime_type = mime_type;
596    raw.no_annotation()
597}
598
599/// 便捷函数:检查 CallToolResult 是否为错误 / Convenience: check if CallToolResult is error
600pub fn is_call_tool_error(result: &CallToolResult) -> bool {
601    result.is_error.unwrap_or(false)
602}
603
604/// 便捷函数:从 Content 中提取文本 / Convenience: extract text from Content
605pub fn content_as_text(content: &Content) -> Option<&str> {
606    content.as_text().map(|t| t.text.as_str())
607}
608
609/// 便捷函数:从 ResourceContents 中提取文本 / Convenience: extract text from ResourceContents
610pub fn resource_contents_as_text(rc: &ResourceContents) -> Option<&str> {
611    match rc {
612        ResourceContents::TextResourceContents { text, .. } => Some(text.as_str()),
613        _ => None,
614    }
615}
616
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_call_tool_error() {
        let success = CallToolResult::success(vec![Content::text("ok")]);
        let failure = CallToolResult::error(vec![Content::text("fail")]);

        assert!(!is_call_tool_error(&success));
        assert!(is_call_tool_error(&failure));
    }

    #[test]
    fn test_content_as_text() {
        let greeting = Content::text("hello");
        assert_eq!(content_as_text(&greeting), Some("hello"));
    }

    #[test]
    fn test_resource_contents_as_text() {
        let text_contents = ResourceContents::TextResourceContents {
            uri: "test://uri".to_string(),
            mime_type: None,
            text: "some text".to_string(),
            meta: None,
        };
        let blob_contents = ResourceContents::BlobResourceContents {
            uri: "test://uri".to_string(),
            mime_type: None,
            blob: "base64data".to_string(),
            meta: None,
        };

        assert_eq!(resource_contents_as_text(&text_contents), Some("some text"));
        // Blob contents have no textual form.
        assert_eq!(resource_contents_as_text(&blob_contents), None);
    }

    #[test]
    fn test_make_resource() {
        let resource = make_resource("window://test", "Test", Some("desc".into()), None);
        let raw = &resource.raw;

        assert_eq!(raw.uri, "window://test");
        assert_eq!(raw.name, "Test");
        assert_eq!(raw.description.as_deref(), Some("desc"));
        assert!(raw.mime_type.is_none());
    }
}