Skip to main content

smcp/
lib.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use uuid::Uuid;
4
/// Namespace path under which the SMCP protocol operates.
pub const SMCP_NAMESPACE: &str = "/smcp";
7
/// Standard error codes carried in `ErrorDetail::code`.
///
/// Values below 1000 reuse the matching HTTP status codes. SMCP-specific
/// codes use the `40xx` range for tool calls and `41xx` for room management.
pub mod error_codes {
    // ---- General (HTTP-aligned) ----

    /// The request was malformed or invalid.
    pub const BAD_REQUEST: i32 = 400;
    /// The caller is not authenticated.
    pub const UNAUTHORIZED: i32 = 401;
    /// The caller is not permitted to perform the operation.
    pub const FORBIDDEN: i32 = 403;
    /// The requested resource does not exist.
    pub const NOT_FOUND: i32 = 404;
    /// The operation timed out.
    pub const TIMEOUT: i32 = 408;
    /// An unexpected internal failure.
    pub const INTERNAL_ERROR: i32 = 500;

    // ---- Tool calls (40xx) ----

    /// No tool with the requested name exists.
    pub const TOOL_NOT_FOUND: i32 = 4001;
    /// The tool exists but is disabled.
    pub const TOOL_DISABLED: i32 = 4002;
    /// The tool was invoked but its execution failed.
    pub const TOOL_EXECUTION_FAILED: i32 = 4003;
    /// The tool did not finish within its timeout.
    pub const TOOL_TIMEOUT: i32 = 4004;
    /// The tool requires confirmation before it may run.
    pub const TOOL_REQUIRES_CONFIRMATION: i32 = 4005;

    // ---- Room management (41xx) ----

    /// The room already has an agent.
    pub const ROOM_FULL: i32 = 4101;
    /// The requested room does not exist.
    pub const ROOM_NOT_FOUND: i32 = 4102;
    /// The session is not in any room.
    pub const NOT_IN_ROOM: i32 = 4103;
    /// An access targeted a room other than the caller's.
    /// NOTE(review): exact semantics are not shown in this file; confirm.
    pub const CROSS_ROOM_ACCESS: i32 = 4104;
}
31
32/// 错误详情结构 / Error detail structure
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct ErrorDetail {
35    /// 错误码 / Error code
36    pub code: i32,
37    /// 人类可读的错误描述 / Human readable error message
38    pub message: String,
39    /// 结构化调试信息(可选)/ Structured debug info (optional)
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub details: Option<HashMap<String, serde_json::Value>>,
42}
43
44impl ErrorDetail {
45    /// 创建新的错误详情 / Create new error detail
46    pub fn new(code: i32, message: impl Into<String>) -> Self {
47        Self {
48            code,
49            message: message.into(),
50            details: None,
51        }
52    }
53
54    /// 添加详情字段 / Add detail field
55    pub fn with_detail(
56        mut self,
57        key: impl Into<String>,
58        value: impl Into<serde_json::Value>,
59    ) -> Self {
60        let details = self.details.get_or_insert_with(HashMap::new);
61        details.insert(key.into(), value.into());
62        self
63    }
64
65    /// 添加多个详情字段 / Add multiple detail fields
66    pub fn with_details(mut self, details: HashMap<String, serde_json::Value>) -> Self {
67        self.details = Some(details);
68        self
69    }
70}
71
72/// 标准错误响应格式 / Standard error response format
73/// 格式: { "error": { "code": i32, "message": str, "details": object? } }
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct ErrorResponse {
76    /// 错误详情 / Error detail
77    pub error: ErrorDetail,
78}
79
80impl ErrorResponse {
81    /// 创建新的错误响应 / Create new error response
82    pub fn new(code: i32, message: impl Into<String>) -> Self {
83        Self {
84            error: ErrorDetail::new(code, message),
85        }
86    }
87
88    /// 添加详情字段 / Add detail field
89    pub fn with_detail(
90        mut self,
91        key: impl Into<String>,
92        value: impl Into<serde_json::Value>,
93    ) -> Self {
94        self.error = self.error.with_detail(key, value);
95        self
96    }
97
98    // 便捷构造方法 / Convenience constructors
99
100    /// Bad Request 错误 / Bad Request error
101    pub fn bad_request(message: impl Into<String>) -> Self {
102        Self::new(error_codes::BAD_REQUEST, message)
103    }
104
105    /// Unauthorized 错误 / Unauthorized error
106    pub fn unauthorized(message: impl Into<String>) -> Self {
107        Self::new(error_codes::UNAUTHORIZED, message)
108    }
109
110    /// Forbidden 错误 / Forbidden error
111    pub fn forbidden(message: impl Into<String>) -> Self {
112        Self::new(error_codes::FORBIDDEN, message)
113    }
114
115    /// Not Found 错误 / Not Found error
116    pub fn not_found(message: impl Into<String>) -> Self {
117        Self::new(error_codes::NOT_FOUND, message)
118    }
119
120    /// Timeout 错误 / Timeout error
121    pub fn timeout(message: impl Into<String>) -> Self {
122        Self::new(error_codes::TIMEOUT, message)
123    }
124
125    /// Internal Error 错误 / Internal Error error
126    pub fn internal_error(message: impl Into<String>) -> Self {
127        Self::new(error_codes::INTERNAL_ERROR, message)
128    }
129
130    /// Tool Not Found 错误 / Tool Not Found error
131    pub fn tool_not_found(tool_name: impl Into<String>) -> Self {
132        let name = tool_name.into();
133        Self::new(
134            error_codes::TOOL_NOT_FOUND,
135            format!("Tool '{}' not found", name),
136        )
137        .with_detail("tool_name", serde_json::Value::String(name))
138    }
139
140    /// Tool Execution Failed 错误 / Tool Execution Failed error
141    pub fn tool_execution_failed(message: impl Into<String>) -> Self {
142        Self::new(error_codes::TOOL_EXECUTION_FAILED, message)
143    }
144
145    /// Tool Timeout 错误 / Tool Timeout error
146    pub fn tool_timeout(timeout_secs: u64) -> Self {
147        Self::new(
148            error_codes::TOOL_TIMEOUT,
149            format!("Tool execution timed out after {} seconds", timeout_secs),
150        )
151        .with_detail(
152            "timeout",
153            serde_json::Value::Number(serde_json::Number::from(timeout_secs)),
154        )
155    }
156
157    /// Room Full 错误 / Room Full error
158    pub fn room_full(office_id: impl Into<String>) -> Self {
159        let id = office_id.into();
160        Self::new(
161            error_codes::ROOM_FULL,
162            format!("Room '{}' already has an agent", id),
163        )
164        .with_detail("office_id", serde_json::Value::String(id))
165    }
166
167    /// Not In Room 错误 / Not In Room error
168    pub fn not_in_room() -> Self {
169        Self::new(error_codes::NOT_IN_ROOM, "Session is not in any room")
170    }
171}
172
/// Names of the events used by the SMCP protocol.
///
/// Every name has the form `"<group>:<action>"`, where the group is `client`,
/// `server` or `notify`. The `notify:` prefix is also exported as
/// `NOTIFY_PREFIX`.
pub mod events {
    // ---- client:* ----

    /// Client request: fetch the tool list.
    pub const CLIENT_GET_TOOLS: &str = "client:get_tools";
    /// Client request: fetch the configuration.
    pub const CLIENT_GET_CONFIG: &str = "client:get_config";
    /// Client request: fetch desktop information.
    pub const CLIENT_GET_DESKTOP: &str = "client:get_desktop";
    /// Client request: invoke a tool.
    pub const CLIENT_TOOL_CALL: &str = "client:tool_call";

    // ---- server:* ----

    /// Server request: join an office.
    pub const SERVER_JOIN_OFFICE: &str = "server:join_office";
    /// Server request: leave an office.
    pub const SERVER_LEAVE_OFFICE: &str = "server:leave_office";
    /// Server request: update the configuration.
    pub const SERVER_UPDATE_CONFIG: &str = "server:update_config";
    /// Server request: update the tool list.
    pub const SERVER_UPDATE_TOOL_LIST: &str = "server:update_tool_list";
    /// Server request: update the desktop.
    pub const SERVER_UPDATE_DESKTOP: &str = "server:update_desktop";
    /// Server request: cancel a tool call.
    pub const SERVER_TOOL_CALL_CANCEL: &str = "server:tool_call_cancel";
    /// Server request: list the participants of a room.
    pub const SERVER_LIST_ROOM: &str = "server:list_room";

    // ---- notify:* ----

    /// Notification: a tool call was cancelled.
    pub const NOTIFY_TOOL_CALL_CANCEL: &str = "notify:tool_call_cancel";
    /// Notification: someone entered an office.
    pub const NOTIFY_ENTER_OFFICE: &str = "notify:enter_office";
    /// Notification: someone left an office.
    pub const NOTIFY_LEAVE_OFFICE: &str = "notify:leave_office";
    /// Notification: the configuration changed.
    pub const NOTIFY_UPDATE_CONFIG: &str = "notify:update_config";
    /// Notification: the tool list changed.
    pub const NOTIFY_UPDATE_TOOL_LIST: &str = "notify:update_tool_list";
    /// Notification: the desktop changed.
    pub const NOTIFY_UPDATE_DESKTOP: &str = "notify:update_desktop";

    /// Prefix shared by every `notify:*` event name.
    pub const NOTIFY_PREFIX: &str = "notify:";
}
215
216/// 请求ID,使用UUID确保全局唯一性
217#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
218pub struct ReqId(pub String);
219
220impl ReqId {
221    /// 生成新的请求ID(使用hex格式以匹配Python的uuid.uuid4().hex)
222    pub fn new() -> Self {
223        Self(Uuid::new_v4().simple().to_string())
224    }
225
226    /// 从字符串创建请求ID
227    pub fn from_string(s: String) -> Self {
228        Self(s)
229    }
230
231    /// 获取请求ID的字符串引用
232    pub fn as_str(&self) -> &str {
233        &self.0
234    }
235}
236
237impl Default for ReqId {
238    fn default() -> Self {
239        Self::new()
240    }
241}
242
243/// 角色类型
244#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
245#[serde(rename_all = "lowercase")]
246pub enum Role {
247    Agent,
248    Computer,
249}
250
251impl std::fmt::Display for Role {
252    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253        match self {
254            Role::Agent => write!(f, "agent"),
255            Role::Computer => write!(f, "computer"),
256        }
257    }
258}
259
260/// 用户信息
261#[derive(Debug, Clone, Serialize, Deserialize)]
262pub struct UserInfo {
263    pub name: String,
264    pub role: Role,
265}
266
267/// 工具调用请求
268#[derive(Debug, Clone, Serialize, Deserialize)]
269pub struct ToolCallReq {
270    #[serde(flatten)]
271    pub base: AgentCallData,
272    pub computer: String,
273    pub tool_name: String,
274    pub params: serde_json::Value,
275    pub timeout: i32,
276}
277
278/// 获取计算机配置请求
279#[derive(Debug, Clone, Serialize, Deserialize)]
280pub struct GetComputerConfigReq {
281    #[serde(flatten)]
282    pub base: AgentCallData,
283    pub computer: String,
284}
285
286/// 更新计算机配置请求
287#[derive(Debug, Clone, Serialize, Deserialize)]
288pub struct UpdateComputerConfigReq {
289    pub computer: String,
290}
291
292/// 获取计算机配置返回
293#[derive(Debug, Clone, Serialize, Deserialize)]
294pub struct GetComputerConfigRet {
295    #[serde(skip_serializing_if = "Option::is_none")]
296    pub inputs: Option<Vec<serde_json::Value>>,
297    pub servers: serde_json::Value,
298}
299
300/// 工具调用返回(符合 MCP CallToolResult 标准)
301#[derive(Debug, Clone, Serialize, Deserialize)]
302pub struct ToolCallRet {
303    #[serde(skip_serializing_if = "Option::is_none")]
304    pub content: Option<Vec<serde_json::Value>>,
305    #[serde(rename = "isError", skip_serializing_if = "Option::is_none")]
306    pub is_error: Option<bool>,
307    #[serde(skip_serializing_if = "Option::is_none")]
308    pub req_id: Option<ReqId>,
309}
310
311/// 获取工具请求
312#[derive(Debug, Clone, Serialize, Deserialize)]
313pub struct GetToolsReq {
314    #[serde(flatten)]
315    pub base: AgentCallData,
316    pub computer: String,
317}
318
319/// SMCP工具定义
320#[derive(Debug, Clone, Serialize, Deserialize)]
321pub struct SMCPTool {
322    pub name: String,
323    pub description: String,
324    pub params_schema: serde_json::Value,
325    pub return_schema: Option<serde_json::Value>,
326    #[serde(skip_serializing_if = "Option::is_none")]
327    pub meta: Option<serde_json::Value>,
328}
329
330/// 获取工具返回
331#[derive(Debug, Clone, Serialize, Deserialize)]
332pub struct GetToolsRet {
333    pub tools: Vec<SMCPTool>,
334    pub req_id: ReqId,
335}
336
337/// 代理调用数据(基类)
338#[derive(Debug, Clone, Serialize, Deserialize)]
339pub struct AgentCallData {
340    pub agent: String,
341    pub req_id: ReqId,
342}
343
344/// 进入办公室请求
345#[derive(Debug, Clone, Serialize, Deserialize)]
346pub struct EnterOfficeReq {
347    pub role: Role,
348    pub name: String,
349    pub office_id: String,
350}
351
352/// 离开办公室请求
353#[derive(Debug, Clone, Serialize, Deserialize)]
354pub struct LeaveOfficeReq {
355    pub office_id: String,
356}
357
358/// 获取桌面请求
359#[derive(Debug, Clone, Serialize, Deserialize)]
360pub struct GetDesktopReq {
361    #[serde(flatten)]
362    pub base: AgentCallData,
363    pub computer: String,
364    #[serde(skip_serializing_if = "Option::is_none")]
365    pub desktop_size: Option<i32>,
366    #[serde(skip_serializing_if = "Option::is_none")]
367    pub window: Option<String>,
368}
369
370/// 桌面类型别名
371pub type Desktop = String;
372
373/// 获取桌面返回
374#[derive(Debug, Clone, Serialize, Deserialize)]
375pub struct GetDesktopRet {
376    #[serde(skip_serializing_if = "Option::is_none")]
377    pub desktops: Option<Vec<Desktop>>,
378    pub req_id: ReqId,
379}
380
381/// 列出房间请求
382#[derive(Debug, Clone, Serialize, Deserialize)]
383pub struct ListRoomReq {
384    #[serde(flatten)]
385    pub base: AgentCallData,
386    pub office_id: String,
387}
388
389/// 会话信息
390#[derive(Debug, Clone, Serialize, Deserialize)]
391pub struct SessionInfo {
392    pub sid: String,
393    pub name: String,
394    pub role: Role,
395    pub office_id: String,
396}
397
398/// 列出房间返回
399#[derive(Debug, Clone, Serialize, Deserialize)]
400pub struct ListRoomRet {
401    pub sessions: Vec<SessionInfo>,
402    pub req_id: ReqId,
403}
404
405/// 进入办公室通知
406#[derive(Debug, Clone, Serialize, Deserialize)]
407pub struct EnterOfficeNotification {
408    pub office_id: String,
409    #[serde(skip_serializing_if = "Option::is_none")]
410    pub computer: Option<String>,
411    #[serde(skip_serializing_if = "Option::is_none")]
412    pub agent: Option<String>,
413}
414
415/// 离开办公室通知
416#[derive(Debug, Clone, Serialize, Deserialize)]
417pub struct LeaveOfficeNotification {
418    pub office_id: String,
419    #[serde(skip_serializing_if = "Option::is_none")]
420    pub computer: Option<String>,
421    #[serde(skip_serializing_if = "Option::is_none")]
422    pub agent: Option<String>,
423}
424
425/// 更新MCP配置通知
426#[derive(Debug, Clone, Serialize, Deserialize)]
427pub struct UpdateMCPConfigNotification {
428    pub computer: String,
429}
430
431/// 更新工具列表通知
432#[derive(Debug, Clone, Serialize, Deserialize)]
433pub struct UpdateToolListNotification {
434    pub computer: String,
435}
436
437/// 通知类型枚举
438#[derive(Debug, Clone, Serialize, Deserialize)]
439#[serde(tag = "type")]
440pub enum Notification {
441    ToolCallCancel,
442    EnterOffice(EnterOfficeNotification),
443    LeaveOffice(LeaveOfficeNotification),
444    UpdateMCPConfig(UpdateMCPConfigNotification),
445    UpdateToolList(UpdateToolListNotification),
446    UpdateDesktop,
447}
448
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_req_id_helpers() {
        let req_id = ReqId::new();
        assert!(!req_id.as_str().is_empty());

        let req_id2 = ReqId::from_string("abc".to_string());
        assert_eq!(req_id2.as_str(), "abc");

        let req_id3 = ReqId::default();
        assert!(!req_id3.as_str().is_empty());
    }

    #[test]
    fn test_req_id_new_is_uuid_hex() {
        // `ReqId::new` mirrors Python's `uuid.uuid4().hex`:
        // 32 lowercase hex digits, no hyphens.
        let req_id = ReqId::new();
        assert_eq!(req_id.as_str().len(), 32);
        assert!(req_id
            .as_str()
            .chars()
            .all(|c| c.is_ascii_digit() || ('a'..='f').contains(&c)));
        // Fresh IDs are random, so two consecutive ones differ.
        assert_ne!(ReqId::new(), ReqId::new());
    }

    #[test]
    fn test_req_id_serializes_as_plain_string() {
        let json = serde_json::to_string(&ReqId::from_string("abc".to_string())).unwrap();
        assert_eq!(json, "\"abc\"");
    }

    #[test]
    fn test_role_serde_lowercase() {
        let json = serde_json::to_string(&Role::Agent).unwrap();
        assert_eq!(json, "\"agent\"");

        let de: Role = serde_json::from_str("\"computer\"").unwrap();
        assert!(matches!(de, Role::Computer));
    }

    #[test]
    fn test_role_display_matches_serde() {
        assert_eq!(Role::Agent.to_string(), "agent");
        assert_eq!(Role::Computer.to_string(), "computer");
    }

    #[test]
    fn test_notification_serde() {
        let n = Notification::EnterOffice(EnterOfficeNotification {
            office_id: "office1".to_string(),
            computer: Some("c1".to_string()),
            agent: None,
        });

        let json = serde_json::to_string(&n).unwrap();
        let de: Notification = serde_json::from_str(&json).unwrap();
        match de {
            Notification::EnterOffice(p) => {
                assert_eq!(p.office_id, "office1");
                assert_eq!(p.computer.as_deref(), Some("c1"));
                assert!(p.agent.is_none());
            }
            _ => panic!("unexpected notification"),
        }
    }

    #[test]
    fn test_tool_call_ret_mcp_format() {
        // Successful tool-call result in MCP `CallToolResult` format.
        let success_ret = ToolCallRet {
            content: Some(vec![serde_json::json!({
                "type": "text",
                "text": "Operation completed successfully"
            })]),
            is_error: Some(false),
            req_id: Some(ReqId::from_string("test123".to_string())),
        };

        let json = serde_json::to_string(&success_ret).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();

        // The MCP fields are present with the expected values.
        assert!(parsed.get("content").is_some());
        assert_eq!(parsed.get("isError").unwrap(), false);
        assert_eq!(parsed.get("req_id").unwrap().as_str().unwrap(), "test123");

        // The field name is camelCase (`isError`), not `is_error`.
        assert!(json.contains("isError"));
        assert!(!json.contains("is_error"));

        // No legacy Rust-style fields (checked by key, not by substring).
        assert!(parsed.get("success").is_none());
        assert!(parsed.get("result").is_none());
        assert!(parsed.get("error").is_none());
    }

    #[test]
    fn test_tool_call_ret_error_format() {
        // Failed tool-call result.
        let error_ret = ToolCallRet {
            content: Some(vec![serde_json::json!({
                "type": "text",
                "text": "Tool execution failed"
            })]),
            is_error: Some(true),
            req_id: None,
        };

        let json = serde_json::to_string(&error_ret).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();

        assert!(parsed.get("content").is_some());
        assert_eq!(parsed.get("isError").unwrap(), true);
        assert!(parsed.get("req_id").is_none());

        // No legacy Rust-style fields.
        assert!(parsed.get("success").is_none());
        assert!(parsed.get("result").is_none());
        assert!(parsed.get("error").is_none());
    }

    #[test]
    fn test_tool_call_ret_minimal() {
        // With every field `None`, the result serializes as an empty object.
        let minimal_ret = ToolCallRet {
            content: None,
            is_error: None,
            req_id: None,
        };

        let json = serde_json::to_string(&minimal_ret).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed, serde_json::json!({}));
    }

    #[test]
    fn test_tool_call_ret_roundtrip() {
        // Serialize -> deserialize preserves every field.
        let original = ToolCallRet {
            content: Some(vec![serde_json::json!({
                "type": "text",
                "text": "Test result"
            })]),
            is_error: Some(false),
            req_id: Some(ReqId::new()),
        };

        let json = serde_json::to_string(&original).unwrap();
        let deserialized: ToolCallRet = serde_json::from_str(&json).unwrap();

        assert_eq!(original.content, deserialized.content);
        assert_eq!(original.is_error, deserialized.is_error);
        assert_eq!(original.req_id, deserialized.req_id);
    }

    #[test]
    fn test_error_response_format() {
        // Standard shape: { "error": { "code": 404, "message": "..." } }
        let error_resp = ErrorResponse::new(404, "Resource not found");

        let json = serde_json::to_string(&error_resp).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();

        let error = parsed.get("error").expect("`error` key must be present");
        assert_eq!(error.get("code").unwrap(), 404);
        assert_eq!(error.get("message").unwrap(), "Resource not found");
        // `details` is omitted entirely when there are none.
        assert!(error.get("details").is_none());
    }

    #[test]
    fn test_error_response_with_details() {
        // A constructor that attaches structured details.
        let error_resp = ErrorResponse::tool_not_found("my_tool");

        let json = serde_json::to_string(&error_resp).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();

        let error = parsed.get("error").unwrap();
        assert_eq!(error.get("code").unwrap(), error_codes::TOOL_NOT_FOUND);
        assert!(error
            .get("message")
            .unwrap()
            .as_str()
            .unwrap()
            .contains("my_tool"));
        let details = error.get("details").expect("`details` must be present");
        assert_eq!(details.get("tool_name").unwrap(), "my_tool");
    }

    #[test]
    fn test_error_detail_with_detail_accumulates() {
        // Successive `with_detail` calls add entries to the same map.
        let detail = ErrorDetail::new(error_codes::BAD_REQUEST, "bad")
            .with_detail("a", 1_i64)
            .with_detail("b", "two");
        let details = detail.details.as_ref().unwrap();
        assert_eq!(details.len(), 2);
        assert_eq!(details["a"], 1_i64);
        assert_eq!(details["b"], "two");
    }

    #[test]
    fn test_tool_timeout_records_seconds() {
        let resp = ErrorResponse::tool_timeout(30);
        assert_eq!(
            resp.error.message,
            "Tool execution timed out after 30 seconds"
        );
        assert_eq!(resp.error.details.unwrap()["timeout"], 30_u64);
    }

    #[test]
    fn test_error_response_convenience_constructors() {
        assert_eq!(ErrorResponse::bad_request("test").error.code, 400);
        assert_eq!(ErrorResponse::unauthorized("test").error.code, 401);
        assert_eq!(ErrorResponse::forbidden("test").error.code, 403);
        assert_eq!(ErrorResponse::not_found("test").error.code, 404);
        assert_eq!(ErrorResponse::timeout("test").error.code, 408);
        assert_eq!(ErrorResponse::internal_error("test").error.code, 500);
        assert_eq!(ErrorResponse::tool_not_found("t").error.code, 4001);
        assert_eq!(ErrorResponse::tool_execution_failed("t").error.code, 4003);
        assert_eq!(ErrorResponse::tool_timeout(30).error.code, 4004);
        assert_eq!(ErrorResponse::room_full("office1").error.code, 4101);
        assert_eq!(ErrorResponse::not_in_room().error.code, 4103);
    }

    #[test]
    fn test_error_response_roundtrip() {
        // Serialize -> deserialize preserves code, message and details.
        let original = ErrorResponse::new(500, "Internal error")
            .with_detail("trace_id", serde_json::Value::String("abc123".to_string()));

        let json = serde_json::to_string(&original).unwrap();
        let deserialized: ErrorResponse = serde_json::from_str(&json).unwrap();

        assert_eq!(original.error.code, deserialized.error.code);
        assert_eq!(original.error.message, deserialized.error.message);
        assert_eq!(
            original.error.details.as_ref().unwrap().get("trace_id"),
            deserialized.error.details.as_ref().unwrap().get("trace_id")
        );
    }
}