1use super::base_client::BaseMCPClient;
11use super::model::*;
12use super::{ResourceCache, SubscriptionManager};
13use crate::desktop::window_uri::{is_window_uri, WindowURI};
14use async_trait::async_trait;
15use es::Client as EsClient;
16use eventsource_client as es;
17use futures::stream::{Stream, StreamExt};
18use serde_json;
19use std::pin::Pin;
20use std::sync::Arc;
21use std::time::Duration;
22use tokio::sync::{mpsc, Mutex};
23use tracing::{debug, error, info, warn};
24
/// MCP client that talks to a server over HTTP + Server-Sent Events.
///
/// Server-to-client traffic arrives on a long-lived SSE stream handled by a background task;
/// client-to-server JSON-RPC messages are POSTed to the endpoint URL the server announces in its
/// `endpoint` SSE event.
pub struct SseMCPClient {
    /// Shared client plumbing: the server parameters (`params.url`, `params.headers`) and the
    /// connection state.
    base: BaseMCPClient<SseServerParameters>,
    /// HTTP client used for the POST requests (built with a 30 s timeout in `new`).
    http_client: reqwest::Client,
    /// Sending side of the queue drained by the background SSE task, which POSTs every message
    /// it receives; `None` until the SSE connection is started and again after disconnect.
    request_tx: Arc<Mutex<Option<mpsc::UnboundedSender<serde_json::Value>>>>,
    /// Receiver for the messages the background task forwards as responses; `None` until the
    /// SSE connection is started.
    response_rx: Arc<Mutex<Option<mpsc::UnboundedReceiver<serde_json::Value>>>>,
    /// `result.sessionId` from the `initialize` response, if the server sent one.
    session_id: Arc<Mutex<Option<String>>>,
    /// POST endpoint resolved from the server's `endpoint` SSE event.
    endpoint_url: Arc<Mutex<Option<String>>>,
    /// Tracks which resource URIs this client is subscribed to.
    subscription_manager: SubscriptionManager,
    /// Cached resource contents; created with a 60 s duration (presumably the entry TTL — TODO
    /// confirm against `ResourceCache`).
    resource_cache: ResourceCache,
    /// Optional sink for resource updates, installed by `subscribe_to_updates`.
    update_tx: Arc<Mutex<Option<mpsc::UnboundedSender<ResourceUpdate>>>>,
}
46
/// A resource-content change pushed by the server over the SSE stream.
///
/// Produced by the background SSE task for notifications whose `method` contains `update` and
/// whose `params` carry both `uri` and `data`; delivered through the receiver returned by
/// `SseMCPClient::subscribe_to_updates`.
#[derive(Debug, Clone)]
pub struct ResourceUpdate {
    /// URI of the updated resource (`params.uri` of the notification).
    pub uri: String,
    /// New resource payload (`params.data` of the notification), passed through unchanged.
    pub data: serde_json::Value,
    /// Version number of the update.
    /// NOTE(review): the SSE task currently always sends `1`, so this cannot order updates yet.
    pub version: u64,
}
57
58impl std::fmt::Debug for SseMCPClient {
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 f.debug_struct("SseMCPClient")
61 .field("url", &self.base.params.url)
62 .field("headers", &self.base.params.headers)
63 .field("state", &self.base.state())
64 .finish()
65 }
66}
67
68impl SseMCPClient {
69 pub fn new(params: SseServerParameters) -> Self {
71 let http_client = reqwest::Client::builder()
72 .timeout(std::time::Duration::from_secs(30))
73 .build()
74 .expect("Failed to create HTTP client");
75
76 Self {
77 base: BaseMCPClient::new(params),
78 http_client,
79 request_tx: Arc::new(Mutex::new(None)),
80 response_rx: Arc::new(Mutex::new(None)),
81 session_id: Arc::new(Mutex::new(None)),
82 endpoint_url: Arc::new(Mutex::new(None)),
83 subscription_manager: SubscriptionManager::new(),
84 resource_cache: ResourceCache::new(Duration::from_secs(60)), update_tx: Arc::new(Mutex::new(None)),
86 }
87 }
88
89 async fn send_request(
91 &self,
92 method: &str,
93 params: Option<serde_json::Value>,
94 ) -> Result<serde_json::Value, MCPClientError> {
95 let mut request_body = serde_json::json!({
96 "jsonrpc": "2.0",
97 "method": method,
98 });
99
100 if let Some(p) = params {
101 request_body["params"] = p;
102 }
103
104 let is_notification = method.starts_with("notifications/");
106
107 if !is_notification {
109 let request_id = std::time::SystemTime::now()
110 .duration_since(std::time::UNIX_EPOCH)
111 .unwrap()
112 .as_secs() as i64;
113 request_body["id"] = serde_json::Value::Number(serde_json::Number::from(request_id));
114 }
115
116 debug!("Sending SSE request: {}", request_body);
117
118 let tx = self.request_tx.lock().await;
120 if let Some(ref tx) = *tx {
121 tx.send(request_body.clone()).map_err(|e| {
122 MCPClientError::ConnectionError(format!("Failed to send request: {}", e))
123 })?;
124 } else {
125 return Err(MCPClientError::ConnectionError(
126 "SSE connection not established".to_string(),
127 ));
128 }
129 drop(tx);
130
131 if is_notification {
132 return Ok(serde_json::json!({}));
133 }
134
135 let mut rx = self.response_rx.lock().await;
137 if let Some(ref mut receiver) = *rx {
138 match tokio::time::timeout(Duration::from_secs(30), receiver.recv()).await {
139 Ok(Some(response)) => {
140 debug!("Received SSE response: {}", response);
141 Ok(response)
142 }
143 Ok(None) => Err(MCPClientError::ConnectionError(
144 "Response channel closed".to_string(),
145 )),
146 Err(_) => Err(MCPClientError::TimeoutError(
147 "SSE response timed out after 30s".to_string(),
148 )),
149 }
150 } else {
151 Err(MCPClientError::ConnectionError(
152 "Response channel not established".to_string(),
153 ))
154 }
155 }
156
157 async fn start_sse_connection(&self) -> Result<(), MCPClientError> {
159 let url = &self.base.params.url;
160
161 let mut builder = es::ClientBuilder::for_url(url)
163 .map_err(|e| MCPClientError::ConnectionError(format!("Invalid SSE URL: {:?}", e)))?;
164
165 for (key, value) in &self.base.params.headers {
167 builder = builder.header(key, value).map_err(|e| {
168 MCPClientError::ConnectionError(format!("Failed to add header {}: {:?}", key, e))
169 })?;
170 }
171
172 let es_client = builder.build();
173
174 let (request_tx, request_rx) = mpsc::unbounded_channel::<serde_json::Value>();
176 let (response_tx, response_rx) = mpsc::unbounded_channel::<serde_json::Value>();
177
178 *self.request_tx.lock().await = Some(request_tx);
179 *self.response_rx.lock().await = Some(response_rx);
180
181 let resource_cache = self.resource_cache.clone();
183 let update_tx = self.update_tx.clone();
184 let endpoint_url = self.endpoint_url.clone();
185 let base_url = url.clone();
186 let http_client = self.http_client.clone();
187 let headers = self.base.params.headers.clone();
188
189 let stream: Pin<Box<dyn Stream<Item = Result<es::SSE, es::Error>> + Send + Sync>> =
191 es_client.stream();
192
193 tokio::spawn(async move {
194 let mut stream = Box::pin(stream);
195 let mut request_rx = Box::pin(request_rx);
196
197 loop {
198 tokio::select! {
199 Some(event_result) = stream.next() => {
201 match event_result {
202 Ok(event) => {
203 debug!("Received SSE event: {:?}", event);
204
205 match event {
206 es::SSE::Event(event_data) => {
207 if event_data.event_type == "endpoint" {
209 let endpoint = event_data.data.trim();
210 let resolved = resolve_endpoint_url(&base_url, endpoint);
212 info!("SSE endpoint resolved: {}", resolved);
213 *endpoint_url.lock().await = Some(resolved);
214 continue;
215 }
216
217 if let Ok(value) = serde_json::from_str::<serde_json::Value>(&event_data.data) {
218 if let Some(method) = value.get("method").and_then(|m| m.as_str()) {
222 if method == "resources/update" || method.contains("update") {
223 debug!("Received resource update notification");
224
225 if let Some(params) = value.get("params") {
227 if let Some(uri) = params.get("uri").and_then(|u| u.as_str()) {
228 if let Some(data) = params.get("data") {
230 let _ = resource_cache.refresh(uri, data.clone()).await;
231
232 if let Some(tx) = update_tx.lock().await.as_ref() {
234 let _ = tx.send(ResourceUpdate {
235 uri: uri.to_string(),
236 data: data.clone(),
237 version: 1,
238 });
239 }
240 }
241 }
242 }
243 } else {
244 let _ = response_tx.send(value);
246 }
247 } else {
248 let _ = response_tx.send(value);
250 }
251 }
252 }
253 es::SSE::Comment(_) => {
254 debug!("Received SSE comment");
255 }
256 es::SSE::Connected(_) => {
257 debug!("SSE connection established");
258 }
259 }
260 }
261 Err(e) => {
262 error!("SSE event error: {:?}", e);
263 break;
264 }
265 }
266 }
267
268 Some(request) = request_rx.recv() => {
270 debug!("Sending request via HTTP POST: {}", request);
271
272 let post_url = match endpoint_url.lock().await.clone() {
273 Some(url) => url,
274 None => {
275 error!("No endpoint URL available for POST request");
276 continue;
277 }
278 };
279
280 let mut req = http_client.post(&post_url)
281 .header("Content-Type", "application/json");
282
283 for (key, value) in &headers {
285 req = req.header(key, value);
286 }
287
288 match req.json(&request).send().await {
289 Ok(resp) => {
290 if resp.status().is_success() {
291 let ct = resp.headers()
293 .get("content-type")
294 .and_then(|v| v.to_str().ok())
295 .unwrap_or("")
296 .to_string();
297
298 if ct.contains("application/json") {
299 match resp.json::<serde_json::Value>().await {
300 Ok(json_resp) => {
301 let _ = response_tx.send(json_resp);
302 }
303 Err(e) => {
304 error!("Failed to parse POST JSON response: {}", e);
305 let error_json = serde_json::json!({
306 "jsonrpc": "2.0",
307 "error": {
308 "code": -32603,
309 "message": format!("Failed to parse POST JSON response: {}", e)
310 }
311 });
312 let _ = response_tx.send(error_json);
313 }
314 }
315 }
316 } else {
318 let status = resp.status();
319 error!("POST request failed with status: {}", status);
320 let error_json = serde_json::json!({
321 "jsonrpc": "2.0",
322 "error": {
323 "code": -32603,
324 "message": format!("POST request failed with status: {}", status)
325 }
326 });
327 let _ = response_tx.send(error_json);
328 }
329 }
330 Err(e) => {
331 let mut error_msg = format!("Failed to send POST request: {}", e);
333 {
334 use std::error::Error as StdError;
335 let mut source = e.source();
336 while let Some(cause) = source {
337 error_msg.push_str(&format!("\n Caused by: {}", cause));
338 source = cause.source();
339 }
340 }
341 error!("{}", error_msg);
342 let error_json = serde_json::json!({
343 "jsonrpc": "2.0",
344 "error": {
345 "code": -32603,
346 "message": error_msg
347 }
348 });
349 let _ = response_tx.send(error_json);
350 }
351 }
352 }
353 }
354 }
355 });
356
357 Ok(())
358 }
359
360 async fn initialize_session(&self) -> Result<(), MCPClientError> {
362 let params = serde_json::json!({
363 "protocolVersion": "2024-11-05",
364 "capabilities": {
365 "tools": {},
366 "resources": {}
367 },
368 "clientInfo": {
369 "name": "a2c-smcp-rust",
370 "version": "0.1.0"
371 }
372 });
373
374 let response = self.send_request("initialize", Some(params)).await?;
375
376 if let Some(error) = response.get("error") {
378 return Err(MCPClientError::ProtocolError(format!(
379 "Initialize error: {}",
380 error
381 )));
382 }
383
384 if let Some(result) = response.get("result") {
385 if let Some(session_id) = result.get("sessionId").and_then(|v| v.as_str()) {
386 *self.session_id.lock().await = Some(session_id.to_string());
387 }
388 }
389
390 self.send_request("notifications/initialized", Some(serde_json::json!({})))
392 .await?;
393
394 info!("SSE session initialized successfully");
395 Ok(())
396 }
397
398 pub async fn is_subscribed(&self, uri: &str) -> bool {
402 self.subscription_manager.is_subscribed(uri).await
403 }
404
405 pub async fn get_subscriptions(&self) -> Vec<String> {
407 self.subscription_manager.get_subscriptions().await
408 }
409
410 pub async fn subscription_count(&self) -> usize {
412 self.subscription_manager.subscription_count().await
413 }
414
415 pub async fn get_cached_resource(&self, uri: &str) -> Option<serde_json::Value> {
419 self.resource_cache.get(uri).await
420 }
421
422 pub async fn has_cache(&self, uri: &str) -> bool {
424 self.resource_cache.contains(uri).await
425 }
426
427 pub async fn cache_size(&self) -> usize {
429 self.resource_cache.size().await
430 }
431
432 pub async fn cleanup_cache(&self) -> usize {
434 self.resource_cache.cleanup_expired().await
435 }
436
437 pub async fn cache_keys(&self) -> Vec<String> {
439 self.resource_cache.keys().await
440 }
441
442 pub async fn subscribe_to_updates(&self) -> mpsc::UnboundedReceiver<ResourceUpdate> {
448 let (tx, rx) = mpsc::unbounded_channel();
449 *self.update_tx.lock().await = Some(tx);
450 rx
451 }
452}
453
454fn resolve_endpoint_url(base_url: &str, endpoint: &str) -> String {
456 if endpoint.starts_with("http://") || endpoint.starts_with("https://") {
458 return endpoint.to_string();
459 }
460 if let Ok(base) = url::Url::parse(base_url) {
462 if let Ok(resolved) = base.join(endpoint) {
463 return resolved.to_string();
464 }
465 }
466 format!("{}{}", base_url.trim_end_matches('/'), endpoint)
468}
469
470#[async_trait]
471impl MCPClientProtocol for SseMCPClient {
472 fn state(&self) -> ClientState {
473 self.base.state()
474 }
475
476 async fn connect(&self) -> Result<(), MCPClientError> {
477 if !self.base.can_connect().await {
479 return Err(MCPClientError::ConnectionError(format!(
480 "Cannot connect in state: {}",
481 self.base.get_state().await
482 )));
483 }
484
485 self.start_sse_connection().await?;
487
488 let deadline = tokio::time::Instant::now() + Duration::from_secs(10);
490 loop {
491 if self.endpoint_url.lock().await.is_some() {
492 break;
493 }
494 if tokio::time::Instant::now() >= deadline {
495 return Err(MCPClientError::TimeoutError(
496 "Timed out waiting for SSE endpoint event".to_string(),
497 ));
498 }
499 tokio::time::sleep(Duration::from_millis(100)).await;
500 }
501
502 self.initialize_session().await?;
504
505 self.base.update_state(ClientState::Connected).await;
507 info!("SSE client connected successfully");
508
509 Ok(())
510 }
511
512 async fn disconnect(&self) -> Result<(), MCPClientError> {
513 if !self.base.can_disconnect().await {
515 return Err(MCPClientError::ConnectionError(format!(
516 "Cannot disconnect in state: {}",
517 self.base.get_state().await
518 )));
519 }
520
521 if let Err(e) = self.send_request("shutdown", None).await {
523 warn!("Failed to send shutdown request: {}", e);
524 }
525
526 if let Err(e) = self.send_request("exit", None).await {
528 warn!("Failed to send exit notification: {}", e);
529 }
530
531 *self.request_tx.lock().await = None;
533
534 *self.session_id.lock().await = None;
536
537 *self.endpoint_url.lock().await = None;
539
540 self.base.update_state(ClientState::Disconnected).await;
542 info!("SSE client disconnected successfully");
543
544 Ok(())
545 }
546
547 async fn list_tools(&self) -> Result<Vec<Tool>, MCPClientError> {
548 if self.base.get_state().await != ClientState::Connected {
549 return Err(MCPClientError::ConnectionError("Not connected".to_string()));
550 }
551
552 let response = self
553 .send_request("tools/list", Some(serde_json::json!({})))
554 .await?;
555
556 if let Some(error) = response.get("error") {
557 return Err(MCPClientError::ProtocolError(format!(
558 "List tools error: {}",
559 error
560 )));
561 }
562
563 if let Some(result) = response.get("result") {
564 if let Some(tools) = result.get("tools").and_then(|v| v.as_array()) {
565 let mut tool_list = Vec::new();
566 for tool in tools {
567 if let Ok(parsed_tool) = serde_json::from_value::<Tool>(tool.clone()) {
568 tool_list.push(parsed_tool);
569 }
570 }
571 return Ok(tool_list);
572 }
573 }
574
575 Ok(vec![])
576 }
577
578 async fn call_tool(
579 &self,
580 tool_name: &str,
581 params: serde_json::Value,
582 ) -> Result<CallToolResult, MCPClientError> {
583 if self.base.get_state().await != ClientState::Connected {
584 return Err(MCPClientError::ConnectionError("Not connected".to_string()));
585 }
586
587 let call_params = serde_json::json!({
588 "name": tool_name,
589 "arguments": params
590 });
591
592 let response = self.send_request("tools/call", Some(call_params)).await?;
593
594 if let Some(error) = response.get("error") {
595 return Err(MCPClientError::ProtocolError(format!(
596 "Call tool error: {}",
597 error
598 )));
599 }
600
601 if let Some(result) = response.get("result") {
602 let call_result: CallToolResult = serde_json::from_value(result.clone())?;
603 return Ok(call_result);
604 }
605
606 Err(MCPClientError::ProtocolError(
607 "Invalid response".to_string(),
608 ))
609 }
610
611 async fn list_windows(&self) -> Result<Vec<Resource>, MCPClientError> {
612 if self.base.get_state().await != ClientState::Connected {
613 return Err(MCPClientError::ConnectionError("Not connected".to_string()));
614 }
615
616 let response = self
618 .send_request("resources/list", Some(serde_json::json!({})))
619 .await?;
620
621 if let Some(error) = response.get("error") {
622 return Err(MCPClientError::ProtocolError(format!(
623 "List resources error: {}",
624 error
625 )));
626 }
627
628 let mut all_resources = Vec::new();
629 if let Some(result) = response.get("result") {
630 if let Some(resources) = result.get("resources").and_then(|v| v.as_array()) {
631 for resource in resources {
632 if let Ok(parsed_resource) =
633 serde_json::from_value::<Resource>(resource.clone())
634 {
635 all_resources.push(parsed_resource);
636 }
637 }
638 }
639 }
640
641 let mut filtered_resources: Vec<(Resource, i32)> = Vec::new();
643
644 for resource in all_resources {
645 if !is_window_uri(&resource.uri) {
646 continue;
647 }
648
649 let priority = if let Ok(uri) = WindowURI::new(&resource.uri) {
650 uri.priority().unwrap_or(0)
651 } else {
652 0
653 };
654
655 filtered_resources.push((resource, priority));
656 }
657
658 filtered_resources.sort_by(|a, b| b.1.cmp(&a.1));
659
660 Ok(filtered_resources.into_iter().map(|(r, _)| r).collect())
661 }
662
663 async fn get_window_detail(
664 &self,
665 resource: Resource,
666 ) -> Result<ReadResourceResult, MCPClientError> {
667 if self.base.get_state().await != ClientState::Connected {
668 return Err(MCPClientError::ConnectionError("Not connected".to_string()));
669 }
670
671 let params = serde_json::json!({
672 "uri": resource.uri
673 });
674
675 let response = self.send_request("resources/read", Some(params)).await?;
676
677 if let Some(error) = response.get("error") {
678 return Err(MCPClientError::ProtocolError(format!(
679 "Read resource error: {}",
680 error
681 )));
682 }
683
684 if let Some(result) = response.get("result") {
685 let read_result: ReadResourceResult = serde_json::from_value(result.clone())?;
686 return Ok(read_result);
687 }
688
689 Err(MCPClientError::ProtocolError(
690 "Invalid response".to_string(),
691 ))
692 }
693
694 async fn subscribe_window(&self, resource: Resource) -> Result<(), MCPClientError> {
695 if self.base.get_state().await != ClientState::Connected {
696 return Err(MCPClientError::ConnectionError("Not connected".to_string()));
697 }
698
699 let params = serde_json::json!({
700 "uri": resource.uri
701 });
702
703 let response = self
704 .send_request("resources/subscribe", Some(params))
705 .await?;
706
707 if let Some(error) = response.get("error") {
708 return Err(MCPClientError::ProtocolError(format!(
709 "Subscribe resource error: {}",
710 error
711 )));
712 }
713
714 let _ = self
715 .subscription_manager
716 .add_subscription(resource.uri.clone())
717 .await;
718
719 match self.get_window_detail(resource.clone()).await {
720 Ok(result) => {
721 if !result.contents.is_empty() {
722 if let Ok(json_value) = serde_json::to_value(&result.contents[0]) {
723 self.resource_cache
724 .set(resource.uri.clone(), json_value, None)
725 .await;
726 info!("Subscribed and cached: {}", resource.uri);
727 }
728 }
729 }
730 Err(e) => {
731 warn!("Failed to fetch resource data after subscription: {:?}", e);
732 }
733 }
734
735 Ok(())
736 }
737
738 async fn unsubscribe_window(&self, resource: Resource) -> Result<(), MCPClientError> {
739 if self.base.get_state().await != ClientState::Connected {
740 return Err(MCPClientError::ConnectionError("Not connected".to_string()));
741 }
742
743 let params = serde_json::json!({
744 "uri": resource.uri
745 });
746
747 let response = self
748 .send_request("resources/unsubscribe", Some(params))
749 .await?;
750
751 if let Some(error) = response.get("error") {
752 return Err(MCPClientError::ProtocolError(format!(
753 "Unsubscribe resource error: {}",
754 error
755 )));
756 }
757
758 let _ = self
759 .subscription_manager
760 .remove_subscription(&resource.uri)
761 .await;
762
763 self.resource_cache.remove(&resource.uri).await;
764 info!("Unsubscribed and removed cache: {}", resource.uri);
765
766 Ok(())
767 }
768}
769
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::collections::HashMap;

    /// SSE URL used by every test that does not need a special one.
    const LOCAL_URL: &str = "http://localhost:8081";

    /// Parameters pointing at `url`, with no extra headers.
    fn params_for(url: &str) -> SseServerParameters {
        SseServerParameters {
            url: url.to_string(),
            headers: HashMap::new(),
        }
    }

    /// A fresh, never-connected client for `LOCAL_URL`.
    fn local_client() -> SseMCPClient {
        SseMCPClient::new(params_for(LOCAL_URL))
    }

    /// Fails the test unless `result` is an `MCPClientError::ConnectionError`.
    fn assert_connection_error<T>(result: Result<T, MCPClientError>) {
        assert!(
            matches!(result, Err(MCPClientError::ConnectionError(_))),
            "expected MCPClientError::ConnectionError"
        );
    }

    /// Starts the SSE machinery for `url` and checks that both channels were installed.
    async fn assert_channels_installed(url: &str) {
        let client = SseMCPClient::new(params_for(url));
        assert!(client.start_sse_connection().await.is_ok());
        assert!(client.request_tx.lock().await.is_some());
        assert!(client.response_rx.lock().await.is_some());
    }

    #[test]
    fn test_resolve_endpoint_url_absolute() {
        assert_eq!(
            resolve_endpoint_url(
                "http://localhost:8081/sse",
                "https://other.example.com/messages",
            ),
            "https://other.example.com/messages"
        );
    }

    #[test]
    fn test_resolve_endpoint_url_relative() {
        assert_eq!(
            resolve_endpoint_url("http://localhost:8081/sse", "/messages"),
            "http://localhost:8081/messages"
        );
    }

    #[test]
    fn test_resolve_endpoint_url_relative_path() {
        assert_eq!(
            resolve_endpoint_url("http://localhost:8081/api/sse", "messages"),
            "http://localhost:8081/api/messages"
        );
    }

    #[tokio::test]
    async fn test_sse_client_creation() {
        let client = local_client();
        assert_eq!(client.state(), ClientState::Initialized);
        assert_eq!(client.base.params.url, LOCAL_URL);
    }

    #[tokio::test]
    async fn test_sse_client_with_headers() {
        let mut params = params_for(LOCAL_URL);
        params
            .headers
            .insert("Authorization".to_string(), "Bearer token123".to_string());
        params
            .headers
            .insert("Accept".to_string(), "text/event-stream".to_string());

        let client = SseMCPClient::new(params);
        assert_eq!(
            client
                .base
                .params
                .headers
                .get("Authorization")
                .map(String::as_str),
            Some("Bearer token123")
        );
    }

    #[tokio::test]
    async fn test_session_id_management() {
        let client = local_client();

        assert!(client.session_id.lock().await.is_none());

        *client.session_id.lock().await = Some("session123".to_string());
        assert_eq!(client.session_id.lock().await.as_deref(), Some("session123"));
    }

    #[tokio::test]
    async fn test_endpoint_url_management() {
        let client = local_client();

        assert!(client.endpoint_url.lock().await.is_none());

        *client.endpoint_url.lock().await = Some("http://localhost:8081/messages".to_string());
        assert_eq!(
            client.endpoint_url.lock().await.as_deref(),
            Some("http://localhost:8081/messages")
        );
    }

    #[tokio::test]
    async fn test_send_request_without_connection() {
        let client = local_client();
        let outcome = client
            .send_request("test/method", Some(json!({"param1": "value1"})))
            .await;
        assert_connection_error(outcome);
    }

    #[tokio::test]
    async fn test_connect_state_checks() {
        let client = local_client();
        client.base.update_state(ClientState::Connected).await;
        assert_connection_error(client.connect().await);
    }

    #[tokio::test]
    async fn test_disconnect_state_checks() {
        let client = local_client();
        assert_connection_error(client.disconnect().await);
    }

    #[tokio::test]
    async fn test_list_tools_requires_connection() {
        let client = local_client();
        assert_connection_error(client.list_tools().await);
    }

    #[tokio::test]
    async fn test_call_tool_requires_connection() {
        let client = local_client();
        assert_connection_error(client.call_tool("test_tool", json!({})).await);
    }

    #[tokio::test]
    async fn test_list_windows_requires_connection() {
        let client = local_client();
        assert_connection_error(client.list_windows().await);
    }

    #[tokio::test]
    async fn test_get_window_detail_requires_connection() {
        let client = local_client();
        let window = make_resource("window://123", "Test Window", None, None);
        assert_connection_error(client.get_window_detail(window).await);
    }

    #[tokio::test]
    async fn test_start_sse_connection_url_formatting() {
        assert_channels_installed(LOCAL_URL).await;
    }

    #[tokio::test]
    async fn test_start_sse_connection_url_formatting_with_query() {
        assert_channels_installed("http://localhost:8081?param=value").await;
    }

    #[tokio::test]
    async fn test_disconnect_cleanup() {
        let client = local_client();
        *client.session_id.lock().await = Some("session123".to_string());
        *client.endpoint_url.lock().await = Some("http://localhost:8081/messages".to_string());
        client.base.update_state(ClientState::Connected).await;

        let _ = client.disconnect().await;

        assert!(client.session_id.lock().await.is_none());
        assert!(client.endpoint_url.lock().await.is_none());
        assert_eq!(client.base.get_state().await, ClientState::Disconnected);
    }

    #[tokio::test]
    async fn test_request_response_channels() {
        let client = local_client();
        assert!(client.request_tx.lock().await.is_none());
        assert!(client.response_rx.lock().await.is_none());
    }

    #[tokio::test]
    async fn test_initialize_session_request_format() {
        let client = local_client();
        assert!(client.initialize_session().await.is_err());
    }

    #[tokio::test]
    async fn test_error_handling_in_list_tools() {
        let client = local_client();
        client.base.update_state(ClientState::Connected).await;
        assert!(client.list_tools().await.is_err());
    }

    #[tokio::test]
    async fn test_error_handling_in_call_tool() {
        let client = local_client();
        client.base.update_state(ClientState::Connected).await;
        let outcome = client
            .call_tool("test_tool", json!({"param": "value"}))
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_sse_client_debug_format() {
        let rendered = format!("{:?}", local_client());
        assert!(rendered.contains("SseMCPClient"));
    }
}