Skip to main content

smcp_computer/mcp_clients/
sse_client.rs

1/**
2* 文件名: sse_client
3* 作者: JQQ
4* 创建日期: 2025/12/15
5* 最后修改日期: 2025/12/15
6* 版权: 2023 JQQ. All rights reserved.
7* 依赖: tokio, reqwest, eventsource-client, serde_json
8* 描述: SSE类型的MCP客户端实现
9*/
10use super::base_client::BaseMCPClient;
11use super::model::*;
12use super::{ResourceCache, SubscriptionManager};
13use crate::desktop::window_uri::{is_window_uri, WindowURI};
14use async_trait::async_trait;
15use es::Client as EsClient;
16use eventsource_client as es;
17use futures::stream::{Stream, StreamExt};
18use serde_json;
19use std::pin::Pin;
20use std::sync::Arc;
21use std::time::Duration;
22use tokio::sync::{mpsc, Mutex};
23use tracing::{debug, error, info, warn};
24
/// MCP client that communicates with a server over the SSE transport.
pub struct SseMCPClient {
    /// Base client holding the server parameters (URL, headers) and the lifecycle state.
    base: BaseMCPClient<SseServerParameters>,
    /// HTTP client used to POST JSON-RPC messages to the server's endpoint.
    http_client: reqwest::Client,
    /// Sender feeding outgoing JSON-RPC messages to the background SSE task (`None` until connected).
    request_tx: Arc<Mutex<Option<mpsc::UnboundedSender<serde_json::Value>>>>,
    /// Receiver yielding messages forwarded by the background SSE task (`None` until connected).
    response_rx: Arc<Mutex<Option<mpsc::UnboundedReceiver<serde_json::Value>>>>,
    /// Session ID taken from the `sessionId` field of the `initialize` result, when present.
    session_id: Arc<Mutex<Option<String>>>,
    /// POST endpoint URL announced by the server through the SSE `endpoint` event.
    endpoint_url: Arc<Mutex<Option<String>>>,
    /// Tracks which resource URIs this client has subscribed to.
    subscription_manager: SubscriptionManager,
    /// Cache of resource data keyed by URI.
    resource_cache: ResourceCache,
    /// Sender used to publish resource update notifications (installed by `subscribe_to_updates`).
    update_tx: Arc<Mutex<Option<mpsc::UnboundedSender<ResourceUpdate>>>>,
}
46
/// Notification describing an update to a resource.
#[derive(Debug, Clone)]
pub struct ResourceUpdate {
    /// URI of the resource that changed.
    pub uri: String,
    /// New data for the resource.
    pub data: serde_json::Value,
    /// Version number of the update.
    pub version: u64,
}
57
58impl std::fmt::Debug for SseMCPClient {
59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        f.debug_struct("SseMCPClient")
61            .field("url", &self.base.params.url)
62            .field("headers", &self.base.params.headers)
63            .field("state", &self.base.state())
64            .finish()
65    }
66}
67
68impl SseMCPClient {
69    /// 创建新的SSE客户端 / Create new SSE client
70    pub fn new(params: SseServerParameters) -> Self {
71        let http_client = reqwest::Client::builder()
72            .timeout(std::time::Duration::from_secs(30))
73            .build()
74            .expect("Failed to create HTTP client");
75
76        Self {
77            base: BaseMCPClient::new(params),
78            http_client,
79            request_tx: Arc::new(Mutex::new(None)),
80            response_rx: Arc::new(Mutex::new(None)),
81            session_id: Arc::new(Mutex::new(None)),
82            endpoint_url: Arc::new(Mutex::new(None)),
83            subscription_manager: SubscriptionManager::new(),
84            resource_cache: ResourceCache::new(Duration::from_secs(60)), // 默认 60 秒 TTL
85            update_tx: Arc::new(Mutex::new(None)),
86        }
87    }
88
89    /// 发送JSON-RPC请求 / Send JSON-RPC request
90    async fn send_request(
91        &self,
92        method: &str,
93        params: Option<serde_json::Value>,
94    ) -> Result<serde_json::Value, MCPClientError> {
95        let mut request_body = serde_json::json!({
96            "jsonrpc": "2.0",
97            "method": method,
98        });
99
100        if let Some(p) = params {
101            request_body["params"] = p;
102        }
103
104        // 通知类消息不需要等待响应 / Notifications don't need a response
105        let is_notification = method.starts_with("notifications/");
106
107        // 仅非 notification 时添加请求ID / Only add request ID for non-notifications
108        if !is_notification {
109            let request_id = std::time::SystemTime::now()
110                .duration_since(std::time::UNIX_EPOCH)
111                .unwrap()
112                .as_secs() as i64;
113            request_body["id"] = serde_json::Value::Number(serde_json::Number::from(request_id));
114        }
115
116        debug!("Sending SSE request: {}", request_body);
117
118        // 通过SSE发送请求 / Send request via SSE
119        let tx = self.request_tx.lock().await;
120        if let Some(ref tx) = *tx {
121            tx.send(request_body.clone()).map_err(|e| {
122                MCPClientError::ConnectionError(format!("Failed to send request: {}", e))
123            })?;
124        } else {
125            return Err(MCPClientError::ConnectionError(
126                "SSE connection not established".to_string(),
127            ));
128        }
129        drop(tx);
130
131        if is_notification {
132            return Ok(serde_json::json!({}));
133        }
134
135        // 等待响应(带超时)/ Wait for response (with timeout)
136        let mut rx = self.response_rx.lock().await;
137        if let Some(ref mut receiver) = *rx {
138            match tokio::time::timeout(Duration::from_secs(30), receiver.recv()).await {
139                Ok(Some(response)) => {
140                    debug!("Received SSE response: {}", response);
141                    Ok(response)
142                }
143                Ok(None) => Err(MCPClientError::ConnectionError(
144                    "Response channel closed".to_string(),
145                )),
146                Err(_) => Err(MCPClientError::TimeoutError(
147                    "SSE response timed out after 30s".to_string(),
148                )),
149            }
150        } else {
151            Err(MCPClientError::ConnectionError(
152                "Response channel not established".to_string(),
153            ))
154        }
155    }
156
157    /// 启动SSE连接 / Start SSE connection
158    async fn start_sse_connection(&self) -> Result<(), MCPClientError> {
159        let url = &self.base.params.url;
160
161        // 直接使用原始 URL,不拼接 ?events=true
162        let mut builder = es::ClientBuilder::for_url(url)
163            .map_err(|e| MCPClientError::ConnectionError(format!("Invalid SSE URL: {:?}", e)))?;
164
165        // 添加headers / Add headers
166        for (key, value) in &self.base.params.headers {
167            builder = builder.header(key, value).map_err(|e| {
168                MCPClientError::ConnectionError(format!("Failed to add header {}: {:?}", key, e))
169            })?;
170        }
171
172        let es_client = builder.build();
173
174        // 创建通信通道 / Create communication channels
175        let (request_tx, request_rx) = mpsc::unbounded_channel::<serde_json::Value>();
176        let (response_tx, response_rx) = mpsc::unbounded_channel::<serde_json::Value>();
177
178        *self.request_tx.lock().await = Some(request_tx);
179        *self.response_rx.lock().await = Some(response_rx);
180
181        // 克隆资源缓存和更新通知发送器,用于 SSE 事件处理
182        let resource_cache = self.resource_cache.clone();
183        let update_tx = self.update_tx.clone();
184        let endpoint_url = self.endpoint_url.clone();
185        let base_url = url.clone();
186        let http_client = self.http_client.clone();
187        let headers = self.base.params.headers.clone();
188
189        // 启动SSE事件处理任务 / Start SSE event handling task
190        let stream: Pin<Box<dyn Stream<Item = Result<es::SSE, es::Error>> + Send + Sync>> =
191            es_client.stream();
192
193        tokio::spawn(async move {
194            let mut stream = Box::pin(stream);
195            let mut request_rx = Box::pin(request_rx);
196
197            loop {
198                tokio::select! {
199                    // 处理SSE事件 / Handle SSE events
200                    Some(event_result) = stream.next() => {
201                        match event_result {
202                            Ok(event) => {
203                                debug!("Received SSE event: {:?}", event);
204
205                                match event {
206                                    es::SSE::Event(event_data) => {
207                                        // 处理 endpoint 事件 / Handle endpoint event
208                                        if event_data.event_type == "endpoint" {
209                                            let endpoint = event_data.data.trim();
210                                            // 支持相对路径 / Support relative paths
211                                            let resolved = resolve_endpoint_url(&base_url, endpoint);
212                                            info!("SSE endpoint resolved: {}", resolved);
213                                            *endpoint_url.lock().await = Some(resolved);
214                                            continue;
215                                        }
216
217                                        if let Ok(value) = serde_json::from_str::<serde_json::Value>(&event_data.data) {
218                                            // 区分消息类型 / Distinguish message types
219
220                                            // 检查是否是资源更新通知
221                                            if let Some(method) = value.get("method").and_then(|m| m.as_str()) {
222                                                if method == "resources/update" || method.contains("update") {
223                                                    debug!("Received resource update notification");
224
225                                                    // 提取 URI 和数据
226                                                    if let Some(params) = value.get("params") {
227                                                        if let Some(uri) = params.get("uri").and_then(|u| u.as_str()) {
228                                                            // 刷新缓存
229                                                            if let Some(data) = params.get("data") {
230                                                                let _ = resource_cache.refresh(uri, data.clone()).await;
231
232                                                                // 发送更新通知
233                                                                if let Some(tx) = update_tx.lock().await.as_ref() {
234                                                                    let _ = tx.send(ResourceUpdate {
235                                                                        uri: uri.to_string(),
236                                                                        data: data.clone(),
237                                                                        version: 1,
238                                                                    });
239                                                                }
240                                                            }
241                                                        }
242                                                    }
243                                                } else {
244                                                    // 其他通知,也发送到 response channel
245                                                    let _ = response_tx.send(value);
246                                                }
247                                            } else {
248                                                // JSON-RPC 响应
249                                                let _ = response_tx.send(value);
250                                            }
251                                        }
252                                    }
253                                    es::SSE::Comment(_) => {
254                                        debug!("Received SSE comment");
255                                    }
256                                    es::SSE::Connected(_) => {
257                                        debug!("SSE connection established");
258                                    }
259                                }
260                            }
261                            Err(e) => {
262                                error!("SSE event error: {:?}", e);
263                                break;
264                            }
265                        }
266                    }
267
268                    // 处理请求发送 / Handle request sending via HTTP POST
269                    Some(request) = request_rx.recv() => {
270                        debug!("Sending request via HTTP POST: {}", request);
271
272                        let post_url = match endpoint_url.lock().await.clone() {
273                            Some(url) => url,
274                            None => {
275                                error!("No endpoint URL available for POST request");
276                                continue;
277                            }
278                        };
279
280                        let mut req = http_client.post(&post_url)
281                            .header("Content-Type", "application/json");
282
283                        // 添加用户配置的 headers
284                        for (key, value) in &headers {
285                            req = req.header(key, value);
286                        }
287
288                        match req.json(&request).send().await {
289                            Ok(resp) => {
290                                if resp.status().is_success() {
291                                    // 检查 Content-Type,如果是 JSON 直接解析为响应
292                                    let ct = resp.headers()
293                                        .get("content-type")
294                                        .and_then(|v| v.to_str().ok())
295                                        .unwrap_or("")
296                                        .to_string();
297
298                                    if ct.contains("application/json") {
299                                        match resp.json::<serde_json::Value>().await {
300                                            Ok(json_resp) => {
301                                                let _ = response_tx.send(json_resp);
302                                            }
303                                            Err(e) => {
304                                                error!("Failed to parse POST JSON response: {}", e);
305                                                let error_json = serde_json::json!({
306                                                    "jsonrpc": "2.0",
307                                                    "error": {
308                                                        "code": -32603,
309                                                        "message": format!("Failed to parse POST JSON response: {}", e)
310                                                    }
311                                                });
312                                                let _ = response_tx.send(error_json);
313                                            }
314                                        }
315                                    }
316                                    // 如果是 SSE 响应,数据会通过 SSE stream 返回,无需在此处理
317                                } else {
318                                    let status = resp.status();
319                                    error!("POST request failed with status: {}", status);
320                                    let error_json = serde_json::json!({
321                                        "jsonrpc": "2.0",
322                                        "error": {
323                                            "code": -32603,
324                                            "message": format!("POST request failed with status: {}", status)
325                                        }
326                                    });
327                                    let _ = response_tx.send(error_json);
328                                }
329                            }
330                            Err(e) => {
331                                // Log full error chain
332                                let mut error_msg = format!("Failed to send POST request: {}", e);
333                                {
334                                    use std::error::Error as StdError;
335                                    let mut source = e.source();
336                                    while let Some(cause) = source {
337                                        error_msg.push_str(&format!("\n  Caused by: {}", cause));
338                                        source = cause.source();
339                                    }
340                                }
341                                error!("{}", error_msg);
342                                let error_json = serde_json::json!({
343                                    "jsonrpc": "2.0",
344                                    "error": {
345                                        "code": -32603,
346                                        "message": error_msg
347                                    }
348                                });
349                                let _ = response_tx.send(error_json);
350                            }
351                        }
352                    }
353                }
354            }
355        });
356
357        Ok(())
358    }
359
360    /// 初始化会话 / Initialize session
361    async fn initialize_session(&self) -> Result<(), MCPClientError> {
362        let params = serde_json::json!({
363            "protocolVersion": "2024-11-05",
364            "capabilities": {
365                "tools": {},
366                "resources": {}
367            },
368            "clientInfo": {
369                "name": "a2c-smcp-rust",
370                "version": "0.1.0"
371            }
372        });
373
374        let response = self.send_request("initialize", Some(params)).await?;
375
376        // 检查响应 / Check response
377        if let Some(error) = response.get("error") {
378            return Err(MCPClientError::ProtocolError(format!(
379                "Initialize error: {}",
380                error
381            )));
382        }
383
384        if let Some(result) = response.get("result") {
385            if let Some(session_id) = result.get("sessionId").and_then(|v| v.as_str()) {
386                *self.session_id.lock().await = Some(session_id.to_string());
387            }
388        }
389
390        // 发送initialized通知 / Send initialized notification
391        self.send_request("notifications/initialized", Some(serde_json::json!({})))
392            .await?;
393
394        info!("SSE session initialized successfully");
395        Ok(())
396    }
397
398    // ========== 订阅管理 API / Subscription Management API ==========
399
    /// Reports whether the given resource URI is currently subscribed.
    ///
    /// Delegates to the subscription manager.
    pub async fn is_subscribed(&self, uri: &str) -> bool {
        self.subscription_manager.is_subscribed(uri).await
    }
404
    /// Returns the URIs of all current subscriptions.
    ///
    /// Delegates to the subscription manager.
    pub async fn get_subscriptions(&self) -> Vec<String> {
        self.subscription_manager.get_subscriptions().await
    }
409
    /// Returns the number of current subscriptions.
    ///
    /// Delegates to the subscription manager.
    pub async fn subscription_count(&self) -> usize {
        self.subscription_manager.subscription_count().await
    }
414
415    // ========== 资源缓存 API / Resource Cache API ==========
416
    /// Returns the cached data for the given resource URI, if the cache yields any.
    ///
    /// Delegates to the resource cache.
    pub async fn get_cached_resource(&self, uri: &str) -> Option<serde_json::Value> {
        self.resource_cache.get(uri).await
    }
421
    /// Reports whether the given resource URI is present in the cache.
    ///
    /// Delegates to the resource cache.
    pub async fn has_cache(&self, uri: &str) -> bool {
        self.resource_cache.contains(uri).await
    }
426
    /// Returns the size of the resource cache.
    ///
    /// Delegates to the resource cache.
    pub async fn cache_size(&self) -> usize {
        self.resource_cache.size().await
    }
431
    /// Cleans up expired cache entries.
    ///
    /// Delegates to the resource cache; the returned count is presumably the number of
    /// entries removed — TODO confirm against `ResourceCache::cleanup_expired`.
    pub async fn cleanup_cache(&self) -> usize {
        self.resource_cache.cleanup_expired().await
    }
436
    /// Returns the URIs of all cached resources.
    ///
    /// Delegates to the resource cache.
    pub async fn cache_keys(&self) -> Vec<String> {
        self.resource_cache.keys().await
    }
441
442    // ========== 资源更新订阅 API / Resource Update Subscription API ==========
443
444    /// 订阅资源更新通知
445    ///
446    /// 返回一个 receiver,可以用于接收资源更新通知
447    pub async fn subscribe_to_updates(&self) -> mpsc::UnboundedReceiver<ResourceUpdate> {
448        let (tx, rx) = mpsc::unbounded_channel();
449        *self.update_tx.lock().await = Some(tx);
450        rx
451    }
452}
453
454/// 解析 endpoint URL,支持相对路径
455fn resolve_endpoint_url(base_url: &str, endpoint: &str) -> String {
456    // 如果是绝对 URL 直接返回
457    if endpoint.starts_with("http://") || endpoint.starts_with("https://") {
458        return endpoint.to_string();
459    }
460    // 用 url crate 解析相对路径
461    if let Ok(base) = url::Url::parse(base_url) {
462        if let Ok(resolved) = base.join(endpoint) {
463            return resolved.to_string();
464        }
465    }
466    // fallback: 简单拼接
467    format!("{}{}", base_url.trim_end_matches('/'), endpoint)
468}
469
470#[async_trait]
471impl MCPClientProtocol for SseMCPClient {
    /// Returns the current client state as reported by the base client.
    fn state(&self) -> ClientState {
        self.base.state()
    }
475
476    async fn connect(&self) -> Result<(), MCPClientError> {
477        // 检查是否可以连接 / Check if can connect
478        if !self.base.can_connect().await {
479            return Err(MCPClientError::ConnectionError(format!(
480                "Cannot connect in state: {}",
481                self.base.get_state().await
482            )));
483        }
484
485        // 启动SSE连接 / Start SSE connection
486        self.start_sse_connection().await?;
487
488        // 等待 endpoint URL 就绪 / Wait for endpoint URL to be ready
489        let deadline = tokio::time::Instant::now() + Duration::from_secs(10);
490        loop {
491            if self.endpoint_url.lock().await.is_some() {
492                break;
493            }
494            if tokio::time::Instant::now() >= deadline {
495                return Err(MCPClientError::TimeoutError(
496                    "Timed out waiting for SSE endpoint event".to_string(),
497                ));
498            }
499            tokio::time::sleep(Duration::from_millis(100)).await;
500        }
501
502        // 初始化会话 / Initialize session
503        self.initialize_session().await?;
504
505        // 更新状态 / Update state
506        self.base.update_state(ClientState::Connected).await;
507        info!("SSE client connected successfully");
508
509        Ok(())
510    }
511
512    async fn disconnect(&self) -> Result<(), MCPClientError> {
513        // 检查是否可以断开 / Check if can disconnect
514        if !self.base.can_disconnect().await {
515            return Err(MCPClientError::ConnectionError(format!(
516                "Cannot disconnect in state: {}",
517                self.base.get_state().await
518            )));
519        }
520
521        // 尝试优雅关闭 / Try graceful shutdown
522        if let Err(e) = self.send_request("shutdown", None).await {
523            warn!("Failed to send shutdown request: {}", e);
524        }
525
526        // 发送exit通知 / Send exit notification
527        if let Err(e) = self.send_request("exit", None).await {
528            warn!("Failed to send exit notification: {}", e);
529        }
530
531        // 关闭SSE连接 / Close SSE connection
532        *self.request_tx.lock().await = None;
533
534        // 清理会话ID / Clear session ID
535        *self.session_id.lock().await = None;
536
537        // 清理 endpoint URL
538        *self.endpoint_url.lock().await = None;
539
540        // 更新状态 / Update state
541        self.base.update_state(ClientState::Disconnected).await;
542        info!("SSE client disconnected successfully");
543
544        Ok(())
545    }
546
547    async fn list_tools(&self) -> Result<Vec<Tool>, MCPClientError> {
548        if self.base.get_state().await != ClientState::Connected {
549            return Err(MCPClientError::ConnectionError("Not connected".to_string()));
550        }
551
552        let response = self
553            .send_request("tools/list", Some(serde_json::json!({})))
554            .await?;
555
556        if let Some(error) = response.get("error") {
557            return Err(MCPClientError::ProtocolError(format!(
558                "List tools error: {}",
559                error
560            )));
561        }
562
563        if let Some(result) = response.get("result") {
564            if let Some(tools) = result.get("tools").and_then(|v| v.as_array()) {
565                let mut tool_list = Vec::new();
566                for tool in tools {
567                    if let Ok(parsed_tool) = serde_json::from_value::<Tool>(tool.clone()) {
568                        tool_list.push(parsed_tool);
569                    }
570                }
571                return Ok(tool_list);
572            }
573        }
574
575        Ok(vec![])
576    }
577
578    async fn call_tool(
579        &self,
580        tool_name: &str,
581        params: serde_json::Value,
582    ) -> Result<CallToolResult, MCPClientError> {
583        if self.base.get_state().await != ClientState::Connected {
584            return Err(MCPClientError::ConnectionError("Not connected".to_string()));
585        }
586
587        let call_params = serde_json::json!({
588            "name": tool_name,
589            "arguments": params
590        });
591
592        let response = self.send_request("tools/call", Some(call_params)).await?;
593
594        if let Some(error) = response.get("error") {
595            return Err(MCPClientError::ProtocolError(format!(
596                "Call tool error: {}",
597                error
598            )));
599        }
600
601        if let Some(result) = response.get("result") {
602            let call_result: CallToolResult = serde_json::from_value(result.clone())?;
603            return Ok(call_result);
604        }
605
606        Err(MCPClientError::ProtocolError(
607            "Invalid response".to_string(),
608        ))
609    }
610
611    async fn list_windows(&self) -> Result<Vec<Resource>, MCPClientError> {
612        if self.base.get_state().await != ClientState::Connected {
613            return Err(MCPClientError::ConnectionError("Not connected".to_string()));
614        }
615
616        // SSE 客户端目前不支持分页,直接获取所有资源
617        let response = self
618            .send_request("resources/list", Some(serde_json::json!({})))
619            .await?;
620
621        if let Some(error) = response.get("error") {
622            return Err(MCPClientError::ProtocolError(format!(
623                "List resources error: {}",
624                error
625            )));
626        }
627
628        let mut all_resources = Vec::new();
629        if let Some(result) = response.get("result") {
630            if let Some(resources) = result.get("resources").and_then(|v| v.as_array()) {
631                for resource in resources {
632                    if let Ok(parsed_resource) =
633                        serde_json::from_value::<Resource>(resource.clone())
634                    {
635                        all_resources.push(parsed_resource);
636                    }
637                }
638            }
639        }
640
641        // 过滤 window:// 资源并按 priority 排序
642        let mut filtered_resources: Vec<(Resource, i32)> = Vec::new();
643
644        for resource in all_resources {
645            if !is_window_uri(&resource.uri) {
646                continue;
647            }
648
649            let priority = if let Ok(uri) = WindowURI::new(&resource.uri) {
650                uri.priority().unwrap_or(0)
651            } else {
652                0
653            };
654
655            filtered_resources.push((resource, priority));
656        }
657
658        filtered_resources.sort_by(|a, b| b.1.cmp(&a.1));
659
660        Ok(filtered_resources.into_iter().map(|(r, _)| r).collect())
661    }
662
663    async fn get_window_detail(
664        &self,
665        resource: Resource,
666    ) -> Result<ReadResourceResult, MCPClientError> {
667        if self.base.get_state().await != ClientState::Connected {
668            return Err(MCPClientError::ConnectionError("Not connected".to_string()));
669        }
670
671        let params = serde_json::json!({
672            "uri": resource.uri
673        });
674
675        let response = self.send_request("resources/read", Some(params)).await?;
676
677        if let Some(error) = response.get("error") {
678            return Err(MCPClientError::ProtocolError(format!(
679                "Read resource error: {}",
680                error
681            )));
682        }
683
684        if let Some(result) = response.get("result") {
685            let read_result: ReadResourceResult = serde_json::from_value(result.clone())?;
686            return Ok(read_result);
687        }
688
689        Err(MCPClientError::ProtocolError(
690            "Invalid response".to_string(),
691        ))
692    }
693
694    async fn subscribe_window(&self, resource: Resource) -> Result<(), MCPClientError> {
695        if self.base.get_state().await != ClientState::Connected {
696            return Err(MCPClientError::ConnectionError("Not connected".to_string()));
697        }
698
699        let params = serde_json::json!({
700            "uri": resource.uri
701        });
702
703        let response = self
704            .send_request("resources/subscribe", Some(params))
705            .await?;
706
707        if let Some(error) = response.get("error") {
708            return Err(MCPClientError::ProtocolError(format!(
709                "Subscribe resource error: {}",
710                error
711            )));
712        }
713
714        let _ = self
715            .subscription_manager
716            .add_subscription(resource.uri.clone())
717            .await;
718
719        match self.get_window_detail(resource.clone()).await {
720            Ok(result) => {
721                if !result.contents.is_empty() {
722                    if let Ok(json_value) = serde_json::to_value(&result.contents[0]) {
723                        self.resource_cache
724                            .set(resource.uri.clone(), json_value, None)
725                            .await;
726                        info!("Subscribed and cached: {}", resource.uri);
727                    }
728                }
729            }
730            Err(e) => {
731                warn!("Failed to fetch resource data after subscription: {:?}", e);
732            }
733        }
734
735        Ok(())
736    }
737
738    async fn unsubscribe_window(&self, resource: Resource) -> Result<(), MCPClientError> {
739        if self.base.get_state().await != ClientState::Connected {
740            return Err(MCPClientError::ConnectionError("Not connected".to_string()));
741        }
742
743        let params = serde_json::json!({
744            "uri": resource.uri
745        });
746
747        let response = self
748            .send_request("resources/unsubscribe", Some(params))
749            .await?;
750
751        if let Some(error) = response.get("error") {
752            return Err(MCPClientError::ProtocolError(format!(
753                "Unsubscribe resource error: {}",
754                error
755            )));
756        }
757
758        let _ = self
759            .subscription_manager
760            .remove_subscription(&resource.uri)
761            .await;
762
763        self.resource_cache.remove(&resource.uri).await;
764        info!("Unsubscribed and removed cache: {}", resource.uri);
765
766        Ok(())
767    }
768}
769
#[cfg(test)]
mod tests {
    // Unit tests for `SseMCPClient`.
    //
    // None of these tests talk to a real server: they exercise URL
    // resolution, initial field state, and the state/connection guards of
    // the public operations.
    use super::*;
    use serde_json::json;
    use std::collections::HashMap;

    /// Base URL used by every test that needs server parameters.
    const TEST_URL: &str = "http://localhost:8081";
    /// POST endpoint value used when exercising the `endpoint_url` field.
    const TEST_ENDPOINT: &str = "http://localhost:8081/messages";

    /// Builds server parameters for `url` with the given headers.
    fn params_for(url: &str, headers: HashMap<String, String>) -> SseServerParameters {
        SseServerParameters {
            url: url.to_string(),
            headers,
        }
    }

    /// Builds a client pointing at `TEST_URL` with no extra headers.
    fn new_client() -> SseMCPClient {
        SseMCPClient::new(params_for(TEST_URL, HashMap::new()))
    }

    /// Asserts that `result` is `Err(MCPClientError::ConnectionError(_))`.
    ///
    /// `#[track_caller]` makes a failure point at the calling test rather
    /// than at this helper.
    #[track_caller]
    fn assert_connection_error<T>(result: Result<T, MCPClientError>) {
        assert!(
            matches!(result, Err(MCPClientError::ConnectionError(_))),
            "expected Err(MCPClientError::ConnectionError(_))"
        );
    }

    /// Starts the SSE connection for `url` and checks that both channel
    /// halves have been stored on the client.
    async fn assert_start_creates_channels(url: &str) {
        let client = SseMCPClient::new(params_for(url, HashMap::new()));

        let result = client.start_sse_connection().await;
        assert!(result.is_ok());

        // Verify that the channels were created.
        assert!(client.request_tx.lock().await.is_some());
        assert!(client.response_rx.lock().await.is_some());
    }

    #[test]
    fn test_resolve_endpoint_url_absolute() {
        let result = resolve_endpoint_url(
            "http://localhost:8081/sse",
            "https://other.example.com/messages",
        );
        assert_eq!(result, "https://other.example.com/messages");
    }

    #[test]
    fn test_resolve_endpoint_url_relative() {
        let result = resolve_endpoint_url("http://localhost:8081/sse", "/messages");
        assert_eq!(result, "http://localhost:8081/messages");
    }

    #[test]
    fn test_resolve_endpoint_url_relative_path() {
        let result = resolve_endpoint_url("http://localhost:8081/api/sse", "messages");
        assert_eq!(result, "http://localhost:8081/api/messages");
    }

    #[tokio::test]
    async fn test_sse_client_creation() {
        let client = new_client();
        assert_eq!(client.state(), ClientState::Initialized);
        assert_eq!(client.base.params.url, TEST_URL);
    }

    #[tokio::test]
    async fn test_sse_client_with_headers() {
        let headers = HashMap::from([
            ("Authorization".to_string(), "Bearer token123".to_string()),
            ("Accept".to_string(), "text/event-stream".to_string()),
        ]);

        let client = SseMCPClient::new(params_for(TEST_URL, headers));
        assert_eq!(
            client
                .base
                .params
                .headers
                .get("Authorization")
                .map(String::as_str),
            Some("Bearer token123")
        );
    }

    #[tokio::test]
    async fn test_session_id_management() {
        let client = new_client();

        // The session ID starts out empty.
        assert!(client.session_id.lock().await.is_none());

        // Set the session ID and read it back.
        *client.session_id.lock().await = Some("session123".to_string());
        assert_eq!(
            client.session_id.lock().await.as_deref(),
            Some("session123")
        );
    }

    #[tokio::test]
    async fn test_endpoint_url_management() {
        let client = new_client();

        // The endpoint URL starts out empty.
        assert!(client.endpoint_url.lock().await.is_none());

        // Set the endpoint URL and read it back.
        *client.endpoint_url.lock().await = Some(TEST_ENDPOINT.to_string());
        assert_eq!(
            client.endpoint_url.lock().await.as_deref(),
            Some(TEST_ENDPOINT)
        );
    }

    #[tokio::test]
    async fn test_send_request_without_connection() {
        let client = new_client();

        let result = client
            .send_request("test/method", Some(json!({"param1": "value1"})))
            .await;
        assert_connection_error(result);
    }

    #[tokio::test]
    async fn test_connect_state_checks() {
        let client = new_client();

        // Connecting while already in the Connected state must fail.
        client.base.update_state(ClientState::Connected).await;
        assert_connection_error(client.connect().await);
    }

    #[tokio::test]
    async fn test_disconnect_state_checks() {
        let client = new_client();

        // Disconnecting while not connected must fail.
        assert_connection_error(client.disconnect().await);
    }

    #[tokio::test]
    async fn test_list_tools_requires_connection() {
        let client = new_client();
        assert_connection_error(client.list_tools().await);
    }

    #[tokio::test]
    async fn test_call_tool_requires_connection() {
        let client = new_client();
        assert_connection_error(client.call_tool("test_tool", json!({})).await);
    }

    #[tokio::test]
    async fn test_list_windows_requires_connection() {
        let client = new_client();
        assert_connection_error(client.list_windows().await);
    }

    #[tokio::test]
    async fn test_get_window_detail_requires_connection() {
        let client = new_client();
        let resource = make_resource("window://123", "Test Window", None, None);
        assert_connection_error(client.get_window_detail(resource).await);
    }

    #[tokio::test]
    async fn test_start_sse_connection_url_formatting() {
        assert_start_creates_channels(TEST_URL).await;
    }

    #[tokio::test]
    async fn test_start_sse_connection_url_formatting_with_query() {
        assert_start_creates_channels("http://localhost:8081?param=value").await;
    }

    #[tokio::test]
    async fn test_disconnect_cleanup() {
        let client = new_client();

        // Populate the session ID and endpoint URL.
        *client.session_id.lock().await = Some("session123".to_string());
        *client.endpoint_url.lock().await = Some(TEST_ENDPOINT.to_string());

        // Pretend the client is connected so that disconnect proceeds.
        client.base.update_state(ClientState::Connected).await;

        // The result is deliberately ignored: only the cleanup is under test.
        let _ = client.disconnect().await;

        // Verify the cleanup.
        assert!(client.session_id.lock().await.is_none());
        assert!(client.endpoint_url.lock().await.is_none());
        assert_eq!(client.base.get_state().await, ClientState::Disconnected);
    }

    #[tokio::test]
    async fn test_request_response_channels() {
        let client = new_client();

        // A freshly created client has no channels yet.
        assert!(client.request_tx.lock().await.is_none());
        assert!(client.response_rx.lock().await.is_none());
    }

    #[tokio::test]
    async fn test_initialize_session_request_format() {
        let client = new_client();

        let result = client.initialize_session().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_error_handling_in_list_tools() {
        let client = new_client();
        client.base.update_state(ClientState::Connected).await;

        let result = client.list_tools().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_error_handling_in_call_tool() {
        let client = new_client();
        client.base.update_state(ClientState::Connected).await;

        let result = client
            .call_tool("test_tool", json!({"param": "value"}))
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_sse_client_debug_format() {
        let client = new_client();

        let debug_str = format!("{:?}", client);
        assert!(debug_str.contains("SseMCPClient"));
    }
}