smcp_computer/mcp_clients/
http_client.rs1use super::base_client::BaseMCPClient;
11use super::model::*;
12use super::{ResourceCache, SubscriptionManager};
13use crate::desktop::window_uri::{is_window_uri, WindowURI};
14use async_trait::async_trait;
15use reqwest::Client;
16use serde_json;
17use std::time::Duration;
18use tracing::{debug, info, warn};
19
20pub struct HttpMCPClient {
22 base: BaseMCPClient<HttpServerParameters>,
24 http_client: Client,
26 session_id: std::sync::Arc<tokio::sync::Mutex<Option<String>>>,
28 subscription_manager: SubscriptionManager,
30 resource_cache: ResourceCache,
32}
33
34impl std::fmt::Debug for HttpMCPClient {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 f.debug_struct("HttpMCPClient")
37 .field("url", &self.base.params.url)
38 .field("headers", &self.base.params.headers)
39 .field("state", &self.base.state())
40 .finish()
41 }
42}
43
44impl HttpMCPClient {
45 pub fn new(params: HttpServerParameters) -> Self {
47 let http_client = Client::builder()
48 .timeout(std::time::Duration::from_secs(30))
49 .build()
50 .expect("Failed to create HTTP client");
51
52 Self {
53 base: BaseMCPClient::new(params),
54 http_client,
55 session_id: std::sync::Arc::new(tokio::sync::Mutex::new(None)),
56 subscription_manager: SubscriptionManager::new(),
57 resource_cache: ResourceCache::new(Duration::from_secs(60)), }
59 }
60
61 async fn send_request(
63 &self,
64 method: &str,
65 params: Option<serde_json::Value>,
66 ) -> Result<serde_json::Value, MCPClientError> {
67 let url = &self.base.params.url;
68
69 let mut request_body = serde_json::json!({
70 "jsonrpc": "2.0",
71 "method": method,
72 });
73
74 if let Some(p) = params {
75 request_body["params"] = p;
76 }
77
78 request_body["id"] = serde_json::Value::Number(serde_json::Number::from(
80 std::time::SystemTime::now()
81 .duration_since(std::time::UNIX_EPOCH)
82 .unwrap()
83 .as_secs() as i64,
84 ));
85
86 debug!("Sending HTTP request to {}: {}", url, request_body);
87
88 let mut request = self.http_client.post(url);
89
90 for (key, value) in &self.base.params.headers {
92 request = request.header(key, value);
93 }
94
95 request = request.header("Content-Type", "application/json");
97 request = request.header("Accept", "application/json, text/event-stream");
98
99 if let Some(ref sid) = *self.session_id.lock().await {
101 request = request.header("Mcp-Session-Id", sid.as_str());
102 }
103
104 let response =
105 request.json(&request_body).send().await.map_err(|e| {
106 MCPClientError::ConnectionError(format!("HTTP request failed: {}", e))
107 })?;
108
109 let status = response.status();
110
111 if status == reqwest::StatusCode::ACCEPTED || status == reqwest::StatusCode::NO_CONTENT {
113 return Ok(serde_json::json!({}));
114 }
115
116 if !status.is_success() {
117 return Err(MCPClientError::ConnectionError(format!(
118 "HTTP error: {}",
119 status
120 )));
121 }
122
123 if let Some(sid) = response
125 .headers()
126 .get("mcp-session-id")
127 .and_then(|v| v.to_str().ok())
128 {
129 *self.session_id.lock().await = Some(sid.to_string());
130 }
131
132 let content_type = response
134 .headers()
135 .get("content-type")
136 .and_then(|v| v.to_str().ok())
137 .unwrap_or("")
138 .to_string();
139
140 let response_body: serde_json::Value = if content_type.contains("text/event-stream") {
141 let text = response.text().await.map_err(|e| {
143 MCPClientError::ProtocolError(format!("Failed to read SSE response: {}", e))
144 })?;
145 parse_sse_response(&text)?
146 } else {
147 response.json().await.map_err(|e| {
148 MCPClientError::ProtocolError(format!("Failed to parse response: {}", e))
149 })?
150 };
151
152 debug!("Received HTTP response: {}", response_body);
153
154 Ok(response_body)
155 }
156
157 async fn initialize_session(&self) -> Result<(), MCPClientError> {
159 let params = serde_json::json!({
160 "protocolVersion": "2024-11-05",
161 "capabilities": {
162 "tools": {},
163 "resources": {}
164 },
165 "clientInfo": {
166 "name": "a2c-smcp-rust",
167 "version": "0.1.0"
168 }
169 });
170
171 let response = self.send_request("initialize", Some(params)).await?;
172
173 if let Some(error) = response.get("error") {
175 return Err(MCPClientError::ProtocolError(format!(
176 "Initialize error: {}",
177 error
178 )));
179 }
180
181 if let Some(result) = response.get("result") {
182 if self.session_id.lock().await.is_none() {
184 if let Some(session_id) = result.get("sessionId").and_then(|v| v.as_str()) {
185 *self.session_id.lock().await = Some(session_id.to_string());
186 }
187 }
188 }
189
190 self.send_request("notifications/initialized", None).await?;
192
193 info!("HTTP session initialized successfully");
194 Ok(())
195 }
196
197 pub async fn is_subscribed(&self, uri: &str) -> bool {
201 self.subscription_manager.is_subscribed(uri).await
202 }
203
204 pub async fn get_subscriptions(&self) -> Vec<String> {
206 self.subscription_manager.get_subscriptions().await
207 }
208
209 pub async fn subscription_count(&self) -> usize {
211 self.subscription_manager.subscription_count().await
212 }
213
214 pub async fn get_cached_resource(&self, uri: &str) -> Option<serde_json::Value> {
218 self.resource_cache.get(uri).await
219 }
220
221 pub async fn has_cache(&self, uri: &str) -> bool {
223 self.resource_cache.contains(uri).await
224 }
225
226 pub async fn cache_size(&self) -> usize {
228 self.resource_cache.size().await
229 }
230
231 pub async fn cleanup_cache(&self) -> usize {
233 self.resource_cache.cleanup_expired().await
234 }
235
236 pub async fn cache_keys(&self) -> Vec<String> {
238 self.resource_cache.keys().await
239 }
240
241 pub async fn clear_cache(&self) {
243 self.resource_cache.clear().await
244 }
245}
246
247#[async_trait]
248impl MCPClientProtocol for HttpMCPClient {
249 fn state(&self) -> ClientState {
250 self.base.state()
251 }
252
253 async fn connect(&self) -> Result<(), MCPClientError> {
254 if !self.base.can_connect().await {
256 return Err(MCPClientError::ConnectionError(format!(
257 "Cannot connect in state: {}",
258 self.base.get_state().await
259 )));
260 }
261
262 self.initialize_session().await?;
264
265 self.base.update_state(ClientState::Connected).await;
267 info!("HTTP client connected successfully");
268
269 Ok(())
270 }
271
272 async fn disconnect(&self) -> Result<(), MCPClientError> {
273 if !self.base.can_disconnect().await {
275 return Err(MCPClientError::ConnectionError(format!(
276 "Cannot disconnect in state: {}",
277 self.base.get_state().await
278 )));
279 }
280
281 if let Err(e) = self.send_request("shutdown", None).await {
283 warn!("Failed to send shutdown request: {}", e);
284 }
285
286 if let Err(e) = self.send_request("exit", None).await {
288 warn!("Failed to send exit notification: {}", e);
289 }
290
291 *self.session_id.lock().await = None;
293
294 self.base.update_state(ClientState::Disconnected).await;
296 info!("HTTP client disconnected successfully");
297
298 Ok(())
299 }
300
301 async fn list_tools(&self) -> Result<Vec<Tool>, MCPClientError> {
302 if self.base.get_state().await != ClientState::Connected {
303 return Err(MCPClientError::ConnectionError("Not connected".to_string()));
304 }
305
306 let response = self.send_request("tools/list", None).await?;
307
308 if let Some(error) = response.get("error") {
309 return Err(MCPClientError::ProtocolError(format!(
310 "List tools error: {}",
311 error
312 )));
313 }
314
315 if let Some(result) = response.get("result") {
316 if let Some(tools) = result.get("tools").and_then(|v| v.as_array()) {
317 let mut tool_list = Vec::new();
318 for tool in tools {
319 if let Ok(parsed_tool) = serde_json::from_value::<Tool>(tool.clone()) {
320 tool_list.push(parsed_tool);
321 }
322 }
323 return Ok(tool_list);
324 }
325 }
326
327 Ok(vec![])
328 }
329
330 async fn call_tool(
331 &self,
332 tool_name: &str,
333 params: serde_json::Value,
334 ) -> Result<CallToolResult, MCPClientError> {
335 if self.base.get_state().await != ClientState::Connected {
336 return Err(MCPClientError::ConnectionError("Not connected".to_string()));
337 }
338
339 let call_params = serde_json::json!({
340 "name": tool_name,
341 "arguments": params
342 });
343
344 let response = self.send_request("tools/call", Some(call_params)).await?;
345
346 if let Some(error) = response.get("error") {
347 return Err(MCPClientError::ProtocolError(format!(
348 "Call tool error: {}",
349 error
350 )));
351 }
352
353 if let Some(result) = response.get("result") {
354 let call_result: CallToolResult = serde_json::from_value(result.clone())?;
355 return Ok(call_result);
356 }
357
358 Err(MCPClientError::ProtocolError(
359 "Invalid response".to_string(),
360 ))
361 }
362
363 async fn list_windows(&self) -> Result<Vec<Resource>, MCPClientError> {
364 if self.base.get_state().await != ClientState::Connected {
365 return Err(MCPClientError::ConnectionError("Not connected".to_string()));
366 }
367
368 let mut all_resources = Vec::new();
370 let mut cursor: Option<String> = None;
371
372 loop {
373 let params = cursor.as_ref().map(|c| serde_json::json!({ "cursor": c }));
374
375 let response = self.send_request("resources/list", params).await?;
376
377 if let Some(error) = response.get("error") {
378 return Err(MCPClientError::ProtocolError(format!(
379 "List resources error: {}",
380 error
381 )));
382 }
383
384 if let Some(result) = response.get("result") {
385 if let Some(resources) = result.get("resources").and_then(|v| v.as_array()) {
387 for resource in resources {
388 if let Ok(parsed_resource) =
389 serde_json::from_value::<Resource>(resource.clone())
390 {
391 all_resources.push(parsed_resource);
392 }
393 }
394 }
395
396 cursor = result
398 .get("nextCursor")
399 .and_then(|v| v.as_str())
400 .map(|s| s.to_string());
401
402 if cursor.is_none() {
403 break;
404 }
405 } else {
406 break;
407 }
408 }
409
410 let mut filtered_resources: Vec<(Resource, i32)> = Vec::new();
412
413 for resource in all_resources {
414 if !is_window_uri(&resource.uri) {
415 continue;
416 }
417
418 let priority = if let Ok(uri) = WindowURI::new(&resource.uri) {
420 uri.priority().unwrap_or(0)
421 } else {
422 0
423 };
424
425 filtered_resources.push((resource, priority));
426 }
427
428 filtered_resources.sort_by(|a, b| b.1.cmp(&a.1));
430
431 Ok(filtered_resources.into_iter().map(|(r, _)| r).collect())
433 }
434
435 async fn get_window_detail(
436 &self,
437 resource: Resource,
438 ) -> Result<ReadResourceResult, MCPClientError> {
439 if self.base.get_state().await != ClientState::Connected {
440 return Err(MCPClientError::ConnectionError("Not connected".to_string()));
441 }
442
443 let params = serde_json::json!({
444 "uri": resource.uri
445 });
446
447 let response = self.send_request("resources/read", Some(params)).await?;
448
449 if let Some(error) = response.get("error") {
450 return Err(MCPClientError::ProtocolError(format!(
451 "Read resource error: {}",
452 error
453 )));
454 }
455
456 if let Some(result) = response.get("result") {
457 let read_result: ReadResourceResult = serde_json::from_value(result.clone())?;
458 return Ok(read_result);
459 }
460
461 Err(MCPClientError::ProtocolError(
462 "Invalid response".to_string(),
463 ))
464 }
465
466 async fn subscribe_window(&self, resource: Resource) -> Result<(), MCPClientError> {
467 if self.base.get_state().await != ClientState::Connected {
468 return Err(MCPClientError::ConnectionError("Not connected".to_string()));
469 }
470
471 let params = serde_json::json!({
472 "uri": resource.uri
473 });
474
475 let response = self
476 .send_request("resources/subscribe", Some(params))
477 .await?;
478
479 if let Some(error) = response.get("error") {
480 return Err(MCPClientError::ProtocolError(format!(
481 "Subscribe resource error: {}",
482 error
483 )));
484 }
485
486 let _ = self
488 .subscription_manager
489 .add_subscription(resource.uri.clone())
490 .await;
491
492 match self.get_window_detail(resource.clone()).await {
494 Ok(result) => {
495 if !result.contents.is_empty() {
496 if let Ok(json_value) = serde_json::to_value(&result.contents[0]) {
497 self.resource_cache
498 .set(resource.uri.clone(), json_value, None)
499 .await;
500 info!("Subscribed and cached: {}", resource.uri);
501 }
502 }
503 }
504 Err(e) => {
505 warn!("Failed to fetch resource data after subscription: {:?}", e);
506 }
507 }
508
509 Ok(())
510 }
511
512 async fn unsubscribe_window(&self, resource: Resource) -> Result<(), MCPClientError> {
513 if self.base.get_state().await != ClientState::Connected {
514 return Err(MCPClientError::ConnectionError("Not connected".to_string()));
515 }
516
517 let params = serde_json::json!({
518 "uri": resource.uri
519 });
520
521 let response = self
522 .send_request("resources/unsubscribe", Some(params))
523 .await?;
524
525 if let Some(error) = response.get("error") {
526 return Err(MCPClientError::ProtocolError(format!(
527 "Unsubscribe resource error: {}",
528 error
529 )));
530 }
531
532 let _ = self
534 .subscription_manager
535 .remove_subscription(&resource.uri)
536 .await;
537
538 self.resource_cache.remove(&resource.uri).await;
540 info!("Unsubscribed and removed cache: {}", resource.uri);
541
542 Ok(())
543 }
544}
545
546fn parse_sse_response(text: &str) -> Result<serde_json::Value, MCPClientError> {
549 let mut last_json = None;
550 for line in text.lines() {
551 if let Some(data) = line.strip_prefix("data:") {
552 let data = data.trim();
553 if !data.is_empty() {
554 if let Ok(value) = serde_json::from_str::<serde_json::Value>(data) {
555 if value.get("result").is_some() || value.get("error").is_some() {
557 return Ok(value);
558 }
559 last_json = Some(value);
560 }
561 }
562 }
563 }
564 last_json.ok_or_else(|| {
565 MCPClientError::ProtocolError(format!(
566 "No JSON-RPC message found in SSE response: {}",
567 text.chars().take(200).collect::<String>()
568 ))
569 })
570}
571
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::collections::HashMap;

    /// Endpoint used by every test. The tests assume nothing listens here, so
    /// any real network call fails with a connection error.
    const LOCAL_URL: &str = "http://localhost:8080";

    /// Parameters pointing at `LOCAL_URL` with no extra headers.
    fn local_params() -> HttpServerParameters {
        HttpServerParameters {
            url: LOCAL_URL.to_string(),
            headers: HashMap::new(),
        }
    }

    /// Fresh client built from `local_params()`.
    fn local_client() -> HttpMCPClient {
        HttpMCPClient::new(local_params())
    }

    /// True when `outcome` is an `Err(MCPClientError::ConnectionError(_))`.
    fn is_connection_error<T>(outcome: &Result<T, MCPClientError>) -> bool {
        matches!(outcome, Err(MCPClientError::ConnectionError(_)))
    }

    #[test]
    fn test_parse_sse_response_basic() {
        let stream =
            "event: message\ndata: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[]}}\n\n";
        let message = parse_sse_response(stream).unwrap();
        assert!(message.get("result").is_some());
    }

    #[test]
    fn test_parse_sse_response_multiple_data_lines() {
        let stream = "data: {\"jsonrpc\":\"2.0\",\"method\":\"ping\"}\n\ndata: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"ok\":true}}\n\n";
        let message = parse_sse_response(stream).unwrap();
        assert_eq!(message["result"]["ok"], json!(true));
    }

    #[test]
    fn test_parse_sse_response_no_data() {
        assert!(parse_sse_response("event: endpoint\n: comment\n\n").is_err());
    }

    #[tokio::test]
    async fn test_http_client_creation() {
        let client = local_client();
        assert_eq!(client.state(), ClientState::Initialized);
        assert_eq!(client.base.params.url, LOCAL_URL);
    }

    #[tokio::test]
    async fn test_http_client_with_headers() {
        let headers: HashMap<String, String> = [
            ("Authorization", "Bearer token123"),
            ("Content-Type", "application/json"),
        ]
        .into_iter()
        .map(|(name, value)| (name.to_string(), value.to_string()))
        .collect();

        let client = HttpMCPClient::new(HttpServerParameters {
            url: LOCAL_URL.to_string(),
            headers,
        });

        assert_eq!(
            client.base.params.headers.get("Authorization"),
            Some(&"Bearer token123".to_string())
        );
    }

    #[tokio::test]
    async fn test_session_id_management() {
        let client = local_client();

        assert!(client.session_id.lock().await.is_none());

        *client.session_id.lock().await = Some("session123".to_string());
        assert_eq!(client.session_id.lock().await.as_deref(), Some("session123"));
    }

    #[tokio::test]
    async fn test_send_request_format() {
        let client = local_client();
        let outcome = client
            .send_request("test/method", Some(json!({"param1": "value1"})))
            .await;
        assert!(is_connection_error(&outcome));
    }

    #[tokio::test]
    async fn test_connect_state_checks() {
        let client = local_client();
        client.base.update_state(ClientState::Connected).await;
        assert!(is_connection_error(&client.connect().await));
    }

    #[tokio::test]
    async fn test_disconnect_state_checks() {
        let client = local_client();
        assert!(is_connection_error(&client.disconnect().await));
    }

    #[tokio::test]
    async fn test_list_tools_requires_connection() {
        let client = local_client();
        assert!(is_connection_error(&client.list_tools().await));
    }

    #[tokio::test]
    async fn test_call_tool_requires_connection() {
        let client = local_client();
        assert!(is_connection_error(&client.call_tool("test_tool", json!({})).await));
    }

    #[tokio::test]
    async fn test_list_windows_requires_connection() {
        let client = local_client();
        assert!(is_connection_error(&client.list_windows().await));
    }

    #[tokio::test]
    async fn test_get_window_detail_requires_connection() {
        let client = local_client();
        let window = Resource {
            uri: "window://123".to_string(),
            name: "Test Window".to_string(),
            description: None,
            mime_type: None,
        };
        assert!(is_connection_error(&client.get_window_detail(window).await));
    }

    #[tokio::test]
    async fn test_initialize_session_request_format() {
        let client = local_client();
        assert!(client.initialize_session().await.is_err());
    }

    #[tokio::test]
    async fn test_disconnect_cleanup() {
        let client = local_client();
        *client.session_id.lock().await = Some("session123".to_string());
        client.base.update_state(ClientState::Connected).await;

        // Network steps fail against LOCAL_URL; only the local cleanup matters here.
        let _ = client.disconnect().await;

        assert!(client.session_id.lock().await.is_none());
        assert_eq!(client.base.get_state().await, ClientState::Disconnected);
    }

    #[tokio::test]
    async fn test_error_handling_in_list_tools() {
        let client = local_client();
        client.base.update_state(ClientState::Connected).await;
        assert!(client.list_tools().await.is_err());
    }

    #[tokio::test]
    async fn test_error_handling_in_call_tool() {
        let client = local_client();
        client.base.update_state(ClientState::Connected).await;
        let outcome = client
            .call_tool("test_tool", json!({"param": "value"}))
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_http_client_debug_format() {
        let rendered = format!("{:?}", local_client());
        assert!(rendered.contains("HttpMCPClient"));
    }
}