Skip to main content

smcp_computer/
socketio_client.rs

1/*!
2* 文件名: socketio_client
3* 作者: JQQ
4* 创建日期: 2025/12/16
5* 最后修改日期: 2025/12/16
6* 版权: 2023 JQQ. All rights reserved.
7* 依赖: tf_rust_socketio, tokio, serde
8* 描述: SMCP Computer的Socket.IO客户端实现 / Socket.IO client implementation for SMCP Computer
9*/
10
11use crate::desktop::{organize_desktop, WindowInfo};
12use crate::errors::{ComputerError, ComputerResult};
13use crate::mcp_clients::manager::MCPServerManager;
14use crate::mcp_clients::model::MCPServerInput;
15use futures_util::FutureExt;
16use serde_json::Value;
17use smcp::{
18    events::{
19        CLIENT_GET_CONFIG, CLIENT_GET_DESKTOP, CLIENT_GET_TOOLS, CLIENT_TOOL_CALL,
20        SERVER_JOIN_OFFICE, SERVER_LEAVE_OFFICE, SERVER_UPDATE_CONFIG, SERVER_UPDATE_DESKTOP,
21        SERVER_UPDATE_TOOL_LIST,
22    },
23    GetComputerConfigReq, GetComputerConfigRet, GetDesktopReq, GetDesktopRet, GetToolsReq,
24    GetToolsRet, ToolCallReq, SMCP_NAMESPACE,
25};
26use std::collections::HashMap;
27use std::sync::Arc;
28use tf_rust_socketio::{
29    asynchronous::{Client, ClientBuilder},
30    Event, Payload, TransportType,
31};
32use tokio::sync::RwLock;
33use tracing::{debug, error, info};
34
35/// SMCP Computer Socket.IO客户端
36/// SMCP Computer Socket.IO client
37pub struct SmcpComputerClient {
38    /// Socket.IO客户端实例 / Socket.IO client instance
39    client: Client,
40    /// Computer名称 / Computer name
41    computer_name: String,
42    /// 当前所在的office ID / Current office ID
43    office_id: Arc<RwLock<Option<String>>>,
44    /// 输入定义映射 / Input definitions map
45    #[allow(dead_code)]
46    inputs: Arc<RwLock<HashMap<String, MCPServerInput>>>,
47}
48
49impl SmcpComputerClient {
50    /// 创建新的Socket.IO客户端
51    /// Create a new Socket.IO client
52    pub async fn new(
53        url: &str,
54        manager: Arc<RwLock<Option<MCPServerManager>>>,
55        computer_name: String,
56        auth_secret: Option<String>,
57        inputs: Arc<RwLock<HashMap<String, MCPServerInput>>>,
58        headers: Option<HashMap<String, String>>,
59    ) -> ComputerResult<Self> {
60        let office_id = Arc::new(RwLock::new(None));
61        let manager_clone = manager.clone();
62        let computer_name_clone = computer_name.clone();
63        let office_id_clone = office_id.clone();
64        let inputs_clone = inputs.clone();
65
66        // 使用ClientBuilder注册事件处理器
67        // Use ClientBuilder to register event handlers
68        let mut builder = ClientBuilder::new(url)
69            .namespace(SMCP_NAMESPACE)
70            .transport_type(TransportType::Websocket);
71
72        // 如果提供了认证密钥,添加到请求头
73        // If auth secret is provided, add to request headers
74        if let Some(secret) = auth_secret {
75            builder = builder.opening_header("x-api-key", secret.as_str());
76        }
77
78        // 添加自定义 HTTP headers / Add custom HTTP headers
79        // Safety: opening_header 底层使用 http::HeaderValue::from_bytes() 做 RFC 7230 校验,
80        // 会拒绝包含 \r\n 等控制字符的恶意输入,无需额外防御 header injection。
81        // Safety: opening_header internally uses http::HeaderValue::from_bytes() for RFC 7230
82        // validation, rejecting \r\n and other control characters. No extra injection defense needed.
83        if let Some(custom_headers) = headers {
84            for (key, value) in custom_headers {
85                builder = builder.opening_header(key.as_str(), value.as_str());
86            }
87        }
88
89        let client = builder
90            .on_any(move |event, payload, client| {
91                // 只处理自定义事件
92                // Only handle custom events
93                let event_str = match event {
94                    Event::Custom(s) => s,
95                    _ => return async {}.boxed(),
96                };
97
98                match event_str.as_str() {
99                    CLIENT_TOOL_CALL => {
100                        let manager = manager_clone.clone();
101                        let computer_name = computer_name_clone.clone();
102                        let office_id = office_id_clone.clone();
103                        let client_clone = client.clone();
104                        let payload_clone = payload.clone();
105
106                        async move {
107                            match Self::handle_tool_call_with_ack(
108                                payload,
109                                manager,
110                                computer_name,
111                                office_id,
112                                client_clone,
113                            )
114                            .await
115                            {
116                                Ok((ack_id, response)) => {
117                                    if let Some(id) = ack_id {
118                                        if let Err(e) = client.ack_with_id(id, response).await {
119                                            error!("Failed to send ack: {}", e);
120                                        }
121                                    }
122                                }
123                                Err(e) => {
124                                    error!("Error handling tool call: {}", e);
125                                    // 尝试返回错误响应 / Try to return error response
126                                    if let Ok((Some(id), _)) = Self::extract_ack_id(payload_clone) {
127                                        let error_response = serde_json::json!({
128                                            "isError": true,
129                                            "content": [],
130                                            "structuredContent": {
131                                                "error": e.to_string(),
132                                                "error_type": "ComputerError"
133                                            }
134                                        });
135                                        let _ = client.ack_with_id(id, error_response).await;
136                                    }
137                                }
138                            }
139                        }
140                        .boxed()
141                    }
142                    CLIENT_GET_TOOLS => {
143                        let manager = manager_clone.clone();
144                        let computer_name = computer_name_clone.clone();
145                        let office_id = office_id_clone.clone();
146                        let client_clone = client.clone();
147
148                        async move {
149                            match Self::handle_get_tools_with_ack(
150                                payload,
151                                manager,
152                                computer_name,
153                                office_id,
154                                client_clone,
155                            )
156                            .await
157                            {
158                                Ok((ack_id, response)) => {
159                                    if let Some(id) = ack_id {
160                                        if let Err(e) = client.ack_with_id(id, response).await {
161                                            error!("Failed to send ack: {}", e);
162                                        }
163                                    }
164                                }
165                                Err(e) => {
166                                    error!("Error handling get tools: {}", e);
167                                }
168                            }
169                        }
170                        .boxed()
171                    }
172                    CLIENT_GET_CONFIG => {
173                        let manager = manager_clone.clone();
174                        let computer_name = computer_name_clone.clone();
175                        let office_id = office_id_clone.clone();
176                        let client_clone = client.clone();
177                        let inputs = inputs_clone.clone();
178
179                        async move {
180                            match Self::handle_get_config_with_ack(
181                                payload,
182                                manager,
183                                computer_name,
184                                office_id,
185                                client_clone,
186                                inputs,
187                            )
188                            .await
189                            {
190                                Ok((ack_id, response)) => {
191                                    if let Some(id) = ack_id {
192                                        if let Err(e) = client.ack_with_id(id, response).await {
193                                            error!("Failed to send ack: {}", e);
194                                        }
195                                    }
196                                }
197                                Err(e) => {
198                                    error!("Error handling get config: {}", e);
199                                }
200                            }
201                        }
202                        .boxed()
203                    }
204                    CLIENT_GET_DESKTOP => {
205                        let manager = manager_clone.clone();
206                        let computer_name = computer_name_clone.clone();
207                        let office_id = office_id_clone.clone();
208                        let client_clone = client.clone();
209
210                        async move {
211                            match Self::handle_get_desktop_with_ack(
212                                payload,
213                                manager,
214                                computer_name,
215                                office_id,
216                                client_clone,
217                            )
218                            .await
219                            {
220                                Ok((ack_id, response)) => {
221                                    if let Some(id) = ack_id {
222                                        if let Err(e) = client.ack_with_id(id, response).await {
223                                            error!("Failed to send ack: {}", e);
224                                        }
225                                    }
226                                }
227                                Err(e) => {
228                                    error!("Error handling get desktop: {}", e);
229                                }
230                            }
231                        }
232                        .boxed()
233                    }
234                    _ => {
235                        debug!("Unhandled event: {}", event_str);
236                        async {}.boxed()
237                    }
238                }
239            })
240            .connect()
241            .await
242            .map_err(|e| ComputerError::SocketIoError(format!("Failed to connect: {}", e)))?;
243
244        // 等待一小段时间确保 Socket.IO namespace 连接完全建立
245        // Wait a short time to ensure Socket.IO namespace connection is fully established
246        // Socket.IO 有两个连接阶段:Transport 层和 Namespace 层
247        // Socket.IO has two connection phases: Transport layer and Namespace layer
248        // connect() 只保证 Transport 层连接,namespace 连接是异步的
249        // connect() only guarantees Transport layer connection, namespace connection is async
250        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
251
252        info!(
253            "Connected to SMCP server at {} with computer name: {}",
254            url, computer_name
255        );
256
257        Ok(Self {
258            client,
259            computer_name,
260            office_id,
261            inputs,
262        })
263    }
264
265    /// 加入Office(Socket.IO Room)
266    /// Join an Office (Socket.IO Room)
267    pub async fn join_office(&self, office_id: &str) -> ComputerResult<()> {
268        debug!("Joining office: {}", office_id);
269
270        // 先设置office_id
271        // Set office_id first
272        *self.office_id.write().await = Some(office_id.to_string());
273
274        let req_data = serde_json::json!({
275            "office_id": office_id,
276            "role": "computer",
277            "name": self.computer_name
278        });
279
280        // 使用call方法等待服务器响应
281        // Use call method to wait for server response
282        match self.call(SERVER_JOIN_OFFICE, req_data, Some(10)).await {
283            Ok(response) => {
284                // 服务器返回的是 (bool, Option<String>) 元组序列化后的数组
285                // Server returns serialized array of (bool, Option<String>) tuple
286                debug!("Join office response: {:?}", response);
287
288                // 检查响应是否包含嵌套数组
289                // Check if response contains nested array
290                let actual_response = if response.len() == 1 {
291                    if let Some(arr) = response.first().and_then(|v| v.as_array()) {
292                        arr.to_vec()
293                    } else {
294                        response
295                    }
296                } else {
297                    response
298                };
299
300                if !actual_response.is_empty() {
301                    if let Some(success) = actual_response.first().and_then(|v| v.as_bool()) {
302                        if success {
303                            info!("Successfully joined office: {}", office_id);
304                            Ok(())
305                        } else {
306                            // 加入失败,重置office_id / Reset office_id on failure
307                            *self.office_id.write().await = None;
308                            let error_msg = actual_response
309                                .get(1)
310                                .and_then(|v| v.as_str())
311                                .unwrap_or("Unknown error");
312                            Err(ComputerError::SocketIoError(format!(
313                                "Failed to join office: {}",
314                                error_msg
315                            )))
316                        }
317                    } else {
318                        *self.office_id.write().await = None;
319                        Err(ComputerError::SocketIoError(format!(
320                            "Invalid response format from server: {:?}",
321                            actual_response
322                        )))
323                    }
324                } else {
325                    *self.office_id.write().await = None;
326                    Err(ComputerError::SocketIoError(
327                        "Empty response from server".to_string(),
328                    ))
329                }
330            }
331            Err(e) => {
332                *self.office_id.write().await = None;
333                Err(e)
334            }
335        }
336    }
337
338    /// 获取当前Office ID / Get current Office ID
339    pub async fn get_current_office_id(&self) -> ComputerResult<String> {
340        let office_id = self.office_id.read().await;
341        match office_id.as_ref() {
342            Some(id) => Ok(id.clone()),
343            None => Err(ComputerError::InvalidState(
344                "Not currently in any office".to_string(),
345            )),
346        }
347    }
348
349    /// 离开Office
350    /// Leave an Office
351    pub async fn leave_office(&self, office_id: &str) -> ComputerResult<()> {
352        debug!("Leaving office: {}", office_id);
353
354        let req_data = serde_json::json!({
355            "office_id": office_id
356        });
357
358        self.emit(SERVER_LEAVE_OFFICE, req_data).await?;
359        *self.office_id.write().await = None;
360
361        info!("Left office: {}", office_id);
362        Ok(())
363    }
364
365    /// 发送配置更新通知
366    /// Emit config update notification
367    pub async fn emit_update_config(&self) -> ComputerResult<()> {
368        let office_id = self.office_id.read().await;
369        if office_id.is_some() {
370            let req_data = serde_json::json!({
371                "computer": self.computer_name
372            });
373            self.emit(SERVER_UPDATE_CONFIG, req_data).await?;
374            info!("Emitted config update notification");
375        }
376        Ok(())
377    }
378
379    /// 发送工具列表更新通知
380    /// Emit tool list update notification
381    pub async fn emit_update_tool_list(&self) -> ComputerResult<()> {
382        let office_id = self.office_id.read().await;
383        if office_id.is_some() {
384            let req_data = serde_json::json!({
385                "computer": self.computer_name
386            });
387            self.emit(SERVER_UPDATE_TOOL_LIST, req_data).await?;
388            info!("Emitted tool list update notification");
389        }
390        Ok(())
391    }
392
393    /// 发送桌面更新通知
394    /// Emit desktop update notification
395    pub async fn emit_update_desktop(&self) -> ComputerResult<()> {
396        let office_id = self.office_id.read().await;
397        if office_id.is_some() {
398            let req_data = serde_json::json!({
399                "computer": self.computer_name
400            });
401            self.emit(SERVER_UPDATE_DESKTOP, req_data).await?;
402            info!("Emitted desktop update notification");
403        }
404        Ok(())
405    }
406
407    /// 发送事件(不等待响应)
408    /// Emit event without waiting for response
409    async fn emit(&self, event: &str, data: Value) -> ComputerResult<()> {
410        // 检查事件名 policy / Check event name policy
411        if event.starts_with("notify:") || event.starts_with("client:") {
412            return Err(ComputerError::InvalidState(
413                format!(
414                    "Computer 不允许发送 notify:* 或 client:* 事件 / Computer cannot send notify:* or client:* events: {}",
415                    event
416                )
417            ));
418        }
419
420        debug!("Emitting event: {}", event);
421
422        self.client
423            .emit(event, Payload::Text(vec![data], None))
424            .await
425            .map_err(|e| ComputerError::SocketIoError(format!("Failed to emit {}: {}", event, e)))
426    }
427
428    /// 发送事件并等待响应
429    /// Emit event and wait for response
430    async fn call(
431        &self,
432        event: &str,
433        data: Value,
434        timeout_secs: Option<u64>,
435    ) -> ComputerResult<Vec<Value>> {
436        // 检查事件名 policy / Check event name policy
437        if event.starts_with("notify:") || event.starts_with("client:") {
438            return Err(ComputerError::InvalidState(
439                format!(
440                    "Computer 不允许发送 notify:* 或 client:* 事件 / Computer cannot send notify:* or client:* events: {}",
441                    event
442                )
443            ));
444        }
445
446        let timeout = std::time::Duration::from_secs(timeout_secs.unwrap_or(30));
447        debug!("Calling event: {} with timeout {:?}", event, timeout);
448
449        let (tx, rx) = tokio::sync::oneshot::channel();
450        let tx = Arc::new(std::sync::Mutex::new(Some(tx)));
451
452        let callback = move |payload: Payload, _client: Client| {
453            if let Some(tx_opt) = tx.try_lock().ok().and_then(|mut m| m.take()) {
454                let _ = tx_opt.send(payload);
455            }
456            async {}.boxed()
457        };
458
459        self.client
460            .emit_with_ack(event, Payload::Text(vec![data], None), timeout, callback)
461            .await
462            .map_err(|e| {
463                ComputerError::SocketIoError(format!("Failed to call {}: {}", event, e))
464            })?;
465
466        // 使用 tokio::time::timeout 来确保 rx.await 不会无限期等待
467        // Use tokio::time::timeout to ensure rx.await doesn't wait forever
468        match tokio::time::timeout(timeout, rx).await {
469            Ok(Ok(response)) => {
470                // 从响应中提取JSON数据 / Extract JSON data from response
471                match response {
472                    Payload::Text(values, _) => {
473                        debug!("Received response: {:?}", values);
474                        Ok(values)
475                    }
476                    #[allow(deprecated)]
477                    Payload::String(s, _) => {
478                        // 尝试解析字符串为JSON数组
479                        // Try to parse string as JSON array
480                        let parsed: Vec<Value> = serde_json::from_str(&s).map_err(|e| {
481                            ComputerError::SocketIoError(format!("Failed to parse response: {}", e))
482                        })?;
483                        debug!("Received parsed response: {:?}", parsed);
484                        Ok(parsed)
485                    }
486                    Payload::Binary(_, _) => Err(ComputerError::SocketIoError(
487                        "Binary response not supported".to_string(),
488                    )),
489                }
490            }
491            Ok(Err(_)) => {
492                error!("Channel closed while calling event: {}", event);
493                Err(ComputerError::SocketIoError(
494                    "Channel closed while waiting for response".to_string(),
495                ))
496            }
497            Err(_) => {
498                error!("Timeout while calling event: {}", event);
499                Err(ComputerError::SocketIoError(
500                    "Timeout while waiting for response".to_string(),
501                ))
502            }
503        }
504    }
505
506    /// 处理工具调用事件(带ACK响应)
507    /// Handle tool call event (with ACK response)
508    async fn handle_tool_call_with_ack(
509        payload: Payload,
510        manager: Arc<RwLock<Option<MCPServerManager>>>,
511        computer_name: String,
512        _office_id: Arc<RwLock<Option<String>>>,
513        _client: Client,
514    ) -> ComputerResult<(Option<i32>, Value)> {
515        let (ack_id, req) = Self::extract_ack_and_parse::<ToolCallReq>(payload)?;
516
517        // 验证 computer_name(Server 路由已保证请求来自同一 office,无需验证 agent 字段)
518        // Validate computer_name (Server routing ensures request is from same office, no need to validate agent field)
519        if computer_name != req.computer {
520            return Err(ComputerError::ValidationError(format!(
521                "Computer name mismatch: expected {}, got {}",
522                computer_name, req.computer
523            )));
524        }
525
526        // 执行工具调用 / Execute tool call
527        let result = {
528            let manager_guard = manager.read().await;
529            match manager_guard.as_ref() {
530                Some(mgr) => {
531                    mgr.execute_tool(
532                        &req.tool_name,
533                        req.params,
534                        Some(std::time::Duration::from_secs(req.timeout as u64)),
535                    )
536                    .await?
537                }
538                None => {
539                    return Err(ComputerError::InvalidState(
540                        "MCP Manager not initialized".to_string(),
541                    ));
542                }
543            }
544        };
545
546        let result_value =
547            serde_json::to_value(result).map_err(ComputerError::SerializationError)?;
548
549        info!("Tool call executed successfully: {}", req.tool_name);
550        Ok((ack_id, result_value))
551    }
552
553    /// 处理获取工具列表事件(带ACK响应)
554    /// Handle get tools event (with ACK response)
555    async fn handle_get_tools_with_ack(
556        payload: Payload,
557        manager: Arc<RwLock<Option<MCPServerManager>>>,
558        computer_name: String,
559        _office_id: Arc<RwLock<Option<String>>>,
560        _client: Client,
561    ) -> ComputerResult<(Option<i32>, Value)> {
562        let (ack_id, req) = Self::extract_ack_and_parse::<GetToolsReq>(payload)?;
563
564        // 验证 computer_name(Server 路由已保证请求来自同一 office,无需验证 agent 字段)
565        // Validate computer_name (Server routing ensures request is from same office, no need to validate agent field)
566        if computer_name != req.computer {
567            return Err(ComputerError::ValidationError(format!(
568                "Computer name mismatch: expected {}, got {}",
569                computer_name, req.computer
570            )));
571        }
572
573        // 获取工具列表 / Get tools list
574        let tools: Vec<smcp::SMCPTool> = {
575            let manager_guard = manager.read().await;
576            match manager_guard.as_ref() {
577                Some(mgr) => {
578                    // 转换Tool为SMCPTool
579                    // Convert Tool to SMCPTool
580                    let tool_list = mgr.list_available_tools().await;
581                    tool_list
582                        .into_iter()
583                        .map(convert_tool_to_smcp_tool)
584                        .collect()
585                }
586                None => {
587                    return Err(ComputerError::InvalidState(
588                        "MCP Manager not initialized".to_string(),
589                    ));
590                }
591            }
592        };
593
594        let response = GetToolsRet {
595            tools: tools.clone(),
596            req_id: req.base.req_id,
597        };
598
599        info!(
600            "Returned {} tools for agent {}",
601            tools.len(),
602            req.base.agent
603        );
604        Ok((ack_id, serde_json::to_value(response)?))
605    }
606
607    /// 处理获取配置事件(带ACK响应)
608    /// Handle get config event (with ACK response)
609    async fn handle_get_config_with_ack(
610        payload: Payload,
611        manager: Arc<RwLock<Option<MCPServerManager>>>,
612        computer_name: String,
613        _office_id: Arc<RwLock<Option<String>>>,
614        _client: Client,
615        inputs: Arc<RwLock<HashMap<String, MCPServerInput>>>,
616    ) -> ComputerResult<(Option<i32>, Value)> {
617        let (ack_id, req) = Self::extract_ack_and_parse::<GetComputerConfigReq>(payload)?;
618
619        // 验证 computer_name(Server 路由已保证请求来自同一 office,无需验证 agent 字段)
620        // Validate computer_name (Server routing ensures request is from same office, no need to validate agent field)
621        if computer_name != req.computer {
622            return Err(ComputerError::ValidationError(format!(
623                "Computer name mismatch: expected {}, got {}",
624                computer_name, req.computer
625            )));
626        }
627
628        // 获取配置 / Get config
629        let servers = {
630            let manager_guard = manager.read().await;
631            match manager_guard.as_ref() {
632                Some(mgr) => {
633                    // 获取完整服务器配置(不只是状态)
634                    // Get complete server configurations (not just status)
635                    mgr.get_server_configs().await
636                }
637                None => {
638                    return Err(ComputerError::InvalidState(
639                        "MCP Manager not initialized".to_string(),
640                    ));
641                }
642            }
643        };
644
645        // 获取输入定义 / Get input definitions
646        // 将 HashMap<String, MCPServerInput> 转换为 Vec<serde_json::Value>
647        // Convert HashMap<String, MCPServerInput> to Vec<serde_json::Value>
648        let inputs_data = {
649            let inputs_guard = inputs.read().await;
650            if inputs_guard.is_empty() {
651                None
652            } else {
653                let inputs_vec: Vec<serde_json::Value> = inputs_guard
654                    .values()
655                    .filter_map(|input| serde_json::to_value(input).ok())
656                    .collect();
657                if inputs_vec.is_empty() {
658                    None
659                } else {
660                    Some(inputs_vec)
661                }
662            }
663        };
664
665        let response = GetComputerConfigRet {
666            servers,
667            inputs: inputs_data,
668        };
669
670        info!("Returned config for agent {}", req.base.agent);
671        Ok((ack_id, serde_json::to_value(response)?))
672    }
673
674    /// 处理获取桌面事件(带ACK响应)
675    /// Handle get desktop event (with ACK response)
676    async fn handle_get_desktop_with_ack(
677        payload: Payload,
678        manager: Arc<RwLock<Option<MCPServerManager>>>,
679        computer_name: String,
680        _office_id: Arc<RwLock<Option<String>>>,
681        _client: Client,
682    ) -> ComputerResult<(Option<i32>, Value)> {
683        let (ack_id, req) = Self::extract_ack_and_parse::<GetDesktopReq>(payload)?;
684
685        // 验证 computer_name(Server 路由已保证请求来自同一 office,无需验证 agent 字段)
686        // Validate computer_name (Server routing ensures request is from same office, no need to validate agent field)
687        if computer_name != req.computer {
688            return Err(ComputerError::ValidationError(format!(
689                "Computer name mismatch: expected {}, got {}",
690                computer_name, req.computer
691            )));
692        }
693
694        // 获取桌面窗口信息 / Get desktop window info
695        let desktops = {
696            let mgr_guard = manager.read().await;
697            if let Some(mgr) = mgr_guard.as_ref() {
698                let raw_windows = mgr.get_windows_details(req.window.as_deref()).await;
699                let windows: Vec<WindowInfo> = raw_windows
700                    .into_iter()
701                    .map(|(server_name, resource, read_result)| {
702                        WindowInfo::new(server_name, resource, read_result)
703                    })
704                    .collect();
705                organize_desktop(windows, req.desktop_size.map(|s| s as usize), &[])
706            } else {
707                Vec::new()
708            }
709        };
710
711        let response = GetDesktopRet {
712            desktops: Some(desktops),
713            req_id: req.base.req_id,
714        };
715
716        info!("Returned desktop for agent {}", req.base.agent);
717        Ok((ack_id, serde_json::to_value(response)?))
718    }
719
720    /// 从payload中提取ack_id并解析数据
721    /// Extract ack_id from payload and parse data
722    fn extract_ack_and_parse<T: serde::de::DeserializeOwned>(
723        payload: Payload,
724    ) -> ComputerResult<(Option<i32>, T)> {
725        match payload {
726            Payload::Text(mut values, ack_id) => {
727                if let Some(value) = values.pop() {
728                    let req =
729                        serde_json::from_value(value).map_err(ComputerError::SerializationError)?;
730                    Ok((ack_id, req))
731                } else {
732                    Err(ComputerError::ProtocolError("Empty payload".to_string()))
733                }
734            }
735            #[allow(deprecated)]
736            Payload::String(s, ack_id) => {
737                let req = serde_json::from_str(&s).map_err(ComputerError::SerializationError)?;
738                Ok((ack_id, req))
739            }
740            Payload::Binary(_, _) => Err(ComputerError::SocketIoError(
741                "Binary payload not supported".to_string(),
742            )),
743        }
744    }
745
746    /// 仅提取ack_id(用于错误处理)
747    /// Extract ack_id only (for error handling)
748    fn extract_ack_id(payload: Payload) -> ComputerResult<(Option<i32>, ())> {
749        match payload {
750            Payload::Text(_, ack_id) => Ok((ack_id, ())),
751            #[allow(deprecated)]
752            Payload::String(_, ack_id) => Ok((ack_id, ())),
753            Payload::Binary(_, _) => Ok((None, ())),
754        }
755    }
756
757    /// 断开连接
758    /// Disconnect from server
759    pub async fn disconnect(self) -> ComputerResult<()> {
760        debug!("Disconnecting from server");
761        self.client
762            .disconnect()
763            .await
764            .map_err(|e| ComputerError::SocketIoError(format!("Failed to disconnect: {}", e)))?;
765        info!("Disconnected from server");
766        Ok(())
767    }
768
769    /// 获取当前office ID
770    /// Get current office ID
771    pub async fn get_office_id(&self) -> Option<String> {
772        self.office_id.read().await.clone()
773    }
774
775    /// 获取连接的 URL
776    /// Get connected URL
777    pub fn get_url(&self) -> String {
778        // 由于 tf_rust_socketio 的 Client 没有 uri() 方法,返回默认值
779        // Since tf_rust_socketio Client doesn't have uri() method, return default
780        "unknown".to_string()
781    }
782
783    /// 获取连接的 namespace
784    /// Get connected namespace
785    pub fn get_namespace(&self) -> String {
786        // 从 client 中获取 namespace,如果无法获取则返回默认值
787        // Get namespace from client, return default if unable to get
788        "/smcp".to_string()
789    }
790}
791
792/// 将内部 Tool 转换为协议类型 SMCPTool
793/// Convert internal Tool to protocol type SMCPTool
794pub(crate) fn convert_tool_to_smcp_tool(tool: crate::mcp_clients::model::Tool) -> smcp::SMCPTool {
795    let mut meta_map = serde_json::Map::new();
796
797    // 传递 tool.meta 中的所有键值(如 a2c_tool_meta)
798    // 值需要序列化为 JSON 字符串,与 Python SDK 对齐
799    if let Some(existing_meta) = &tool.meta {
800        for (k, v) in existing_meta.iter() {
801            let str_val = if v.is_string() {
802                v.as_str().unwrap().to_string()
803            } else {
804                serde_json::to_string(v).unwrap_or_default()
805            };
806            meta_map.insert(k.clone(), serde_json::Value::String(str_val));
807        }
808    }
809
810    // 添加 MCP_TOOL_ANNOTATION
811    if let Some(annotations) = &tool.annotations {
812        if let Ok(json_str) = serde_json::to_string(annotations) {
813            meta_map.insert(
814                "MCP_TOOL_ANNOTATION".to_string(),
815                serde_json::Value::String(json_str),
816            );
817        }
818    }
819
820    let meta = if meta_map.is_empty() {
821        None
822    } else {
823        Some(serde_json::Value::Object(meta_map))
824    };
825
826    let description = tool.description.as_deref().unwrap_or("").to_string();
827    let params_schema = tool.schema_as_json_value();
828    smcp::SMCPTool {
829        name: tool.name.to_string(),
830        description,
831        params_schema,
832        return_schema: None,
833        meta,
834    }
835}
836
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp_clients::model::{Tool, ToolAnnotations};
    use serde_json::json;

    /// Build a minimal `Tool` with an object input schema and the given meta/annotations.
    fn make_tool(
        meta: Option<serde_json::Map<String, serde_json::Value>>,
        annotations: Option<ToolAnnotations>,
    ) -> Tool {
        use std::sync::Arc;
        let input_schema: serde_json::Map<String, serde_json::Value> =
            serde_json::from_value(json!({"type": "object"})).unwrap();
        Tool {
            name: "test_tool".into(),
            title: None,
            description: Some("A test tool".into()),
            input_schema: Arc::new(input_schema),
            output_schema: None,
            annotations,
            icons: None,
            meta: meta.map(rmcp::model::Meta),
        }
    }

    /// Parse a meta entry (stored as JSON text) back into a JSON value.
    fn parse_meta_entry(
        meta_map: &serde_json::Map<String, serde_json::Value>,
        key: &str,
    ) -> serde_json::Value {
        let raw = meta_map[key]
            .as_str()
            .expect("meta values must be JSON strings");
        serde_json::from_str(raw).expect("meta value must be valid JSON text")
    }

    #[test]
    fn test_tool_to_smcp_tool_with_meta_and_annotations() {
        let mut meta = serde_json::Map::new();
        meta.insert(
            "a2c_tool_meta".to_string(),
            json!({"tags": ["browser"], "priority": 1}),
        );
        let annotations = ToolAnnotations {
            title: Some("Test".to_string()),
            read_only_hint: Some(false),
            destructive_hint: Some(false),
            idempotent_hint: None,
            open_world_hint: Some(false),
        };
        let smcp_tool = convert_tool_to_smcp_tool(make_tool(Some(meta), Some(annotations)));

        let meta_obj = smcp_tool.meta.unwrap();
        let meta_map = meta_obj.as_object().unwrap();
        assert!(meta_map.contains_key("a2c_tool_meta"));
        assert!(meta_map.contains_key("MCP_TOOL_ANNOTATION"));
        // Values should be JSON strings
        assert!(meta_map["a2c_tool_meta"].is_string());
        assert!(meta_map["MCP_TOOL_ANNOTATION"].is_string());
        // ...whose content round-trips to the original data.
        assert_eq!(
            parse_meta_entry(meta_map, "a2c_tool_meta"),
            json!({"tags": ["browser"], "priority": 1})
        );
        assert_eq!(parse_meta_entry(meta_map, "MCP_TOOL_ANNOTATION")["title"], "Test");
    }

    #[test]
    fn test_tool_to_smcp_tool_only_meta() {
        let mut meta = serde_json::Map::new();
        meta.insert("a2c_tool_meta".to_string(), json!({"tags": ["fs"]}));
        let smcp_tool = convert_tool_to_smcp_tool(make_tool(Some(meta), None));

        let meta_obj = smcp_tool.meta.unwrap();
        let meta_map = meta_obj.as_object().unwrap();
        assert_eq!(meta_map.len(), 1);
        assert!(meta_map.contains_key("a2c_tool_meta"));
        assert_eq!(
            parse_meta_entry(meta_map, "a2c_tool_meta"),
            json!({"tags": ["fs"]})
        );
    }

    #[test]
    fn test_tool_to_smcp_tool_only_annotations() {
        let annotations = ToolAnnotations {
            title: Some("My Tool".to_string()),
            read_only_hint: Some(true),
            destructive_hint: Some(false),
            idempotent_hint: None,
            open_world_hint: Some(false),
        };
        let smcp_tool = convert_tool_to_smcp_tool(make_tool(None, Some(annotations)));

        let meta_obj = smcp_tool.meta.unwrap();
        let meta_map = meta_obj.as_object().unwrap();
        assert_eq!(meta_map.len(), 1);
        assert!(meta_map.contains_key("MCP_TOOL_ANNOTATION"));
        assert_eq!(
            parse_meta_entry(meta_map, "MCP_TOOL_ANNOTATION")["title"],
            "My Tool"
        );
    }

    #[test]
    fn test_tool_to_smcp_tool_no_meta_no_annotations() {
        let smcp_tool = convert_tool_to_smcp_tool(make_tool(None, None));
        assert!(smcp_tool.meta.is_none());
        assert_eq!(smcp_tool.name, "test_tool");
        assert_eq!(smcp_tool.description, "A test tool");
        assert!(smcp_tool.return_schema.is_none());
    }

    #[test]
    fn test_tool_to_smcp_tool_missing_description_becomes_empty() {
        let mut tool = make_tool(None, None);
        tool.description = None;
        let smcp_tool = convert_tool_to_smcp_tool(tool);
        assert_eq!(smcp_tool.description, "");
    }

    #[test]
    fn test_tool_to_smcp_tool_string_value_not_double_serialized() {
        let mut meta = serde_json::Map::new();
        meta.insert(
            "simple_key".to_string(),
            serde_json::Value::String("already_a_string".to_string()),
        );
        let smcp_tool = convert_tool_to_smcp_tool(make_tool(Some(meta), None));

        let meta_obj = smcp_tool.meta.unwrap();
        let meta_map = meta_obj.as_object().unwrap();
        // Should be the raw string, not "\"already_a_string\""
        assert_eq!(meta_map["simple_key"].as_str().unwrap(), "already_a_string");
    }

    #[test]
    fn test_tool_to_smcp_tool_scalar_values_serialized_as_json_text() {
        let mut meta = serde_json::Map::new();
        meta.insert("count".to_string(), json!(42));
        meta.insert("enabled".to_string(), json!(true));
        meta.insert("nothing".to_string(), serde_json::Value::Null);
        let smcp_tool = convert_tool_to_smcp_tool(make_tool(Some(meta), None));

        let meta_obj = smcp_tool.meta.unwrap();
        let meta_map = meta_obj.as_object().unwrap();
        assert_eq!(meta_map.len(), 3);
        assert_eq!(meta_map["count"].as_str().unwrap(), "42");
        assert_eq!(meta_map["enabled"].as_str().unwrap(), "true");
        assert_eq!(meta_map["nothing"].as_str().unwrap(), "null");
    }
}