1use crate::Result;
4use async_trait::async_trait;
5use rs_utcp::providers::mcp::McpProvider;
6use rs_utcp::transports::mcp::McpTransport as RsUtcpMcpTransport;
7use rs_utcp::transports::ClientTransport;
8use serde_json::Value;
9use std::collections::HashMap;
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::sync::Arc;
12use thulp_core::{
13 Error, GetPromptResult, Prompt, PromptListResult, PromptMessage, Resource, ResourceContents,
14 ResourceListResult, ResourceTemplate, ResourceTemplateListResult, ToolCall, ToolDefinition,
15 ToolResult, Transport as CoreTransport,
16};
17
/// MCP client transport built on top of the rs-utcp MCP transport.
///
/// Tool listing and tool calls are delegated to `inner`. Resource and prompt
/// requests are sent as raw JSON-RPC over HTTP through `http_client`, so they
/// only work when the provider has an HTTP URL.
pub struct McpTransport {
    /// Underlying rs-utcp transport used for provider registration and tool calls.
    inner: RsUtcpMcpTransport,
    /// Provider configuration (HTTP URL or stdio command), shared via `Arc`.
    provider: Arc<McpProvider>,
    /// Set by a successful `connect`, cleared by a successful `disconnect`.
    connected: bool,
    /// HTTP client used for the raw JSON-RPC requests (resources, prompts).
    http_client: reqwest::Client,
    /// Counter that supplies ids for raw JSON-RPC requests.
    request_id: AtomicU64,
}
31
32impl McpTransport {
33 pub fn new_http(name: String, url: String) -> Self {
35 let provider = Arc::new(McpProvider::new(name, url, None));
36 let inner = RsUtcpMcpTransport::new();
37
38 Self {
39 inner,
40 provider,
41 connected: false,
42 http_client: reqwest::Client::new(),
43 request_id: AtomicU64::new(1000),
44 }
45 }
46
47 pub fn new_stdio(name: String, command: String, args: Option<Vec<String>>) -> Self {
49 let provider = Arc::new(McpProvider::new_stdio(name, command, args, None));
50 let inner = RsUtcpMcpTransport::new();
51
52 Self {
53 inner,
54 provider,
55 connected: false,
56 http_client: reqwest::Client::new(),
57 request_id: AtomicU64::new(1000),
58 }
59 }
60
61 pub fn new_stdio_with_env(
63 name: String,
64 command: String,
65 args: Option<Vec<String>>,
66 env: HashMap<String, String>,
67 ) -> Self {
68 let provider = Arc::new(McpProvider::new_stdio(name, command, args, Some(env)));
69 let inner = RsUtcpMcpTransport::new();
70
71 Self {
72 inner,
73 provider,
74 connected: false,
75 http_client: reqwest::Client::new(),
76 request_id: AtomicU64::new(1000),
77 }
78 }
79
80 pub fn new() -> Self {
82 Self::new_http("default".to_string(), "http://localhost:8080".to_string())
83 }
84
85 fn next_request_id(&self) -> u64 {
86 self.request_id.fetch_add(1, Ordering::Relaxed)
87 }
88
89 async fn raw_jsonrpc(&self, method: &str, params: Value) -> Result<Value> {
90 let url = self
91 .provider
92 .url
93 .as_deref()
94 .ok_or_else(|| Error::ExecutionFailed("resources require HTTP transport".into()))?;
95
96 let request = serde_json::json!({
97 "jsonrpc": "2.0",
98 "method": method,
99 "params": params,
100 "id": self.next_request_id(),
101 });
102
103 let resp = self
104 .http_client
105 .post(url)
106 .json(&request)
107 .send()
108 .await
109 .map_err(|e| Error::ExecutionFailed(format!("HTTP request failed: {e}")))?;
110
111 if !resp.status().is_success() {
112 return Err(Error::ExecutionFailed(format!(
113 "MCP server returned {}",
114 resp.status()
115 )));
116 }
117
118 let body: Value = resp
119 .json()
120 .await
121 .map_err(|e| Error::ExecutionFailed(format!("invalid JSON response: {e}")))?;
122
123 if let Some(err) = body.get("error") {
124 return Err(Error::ExecutionFailed(format!("JSON-RPC error: {err}")));
125 }
126
127 Ok(body.get("result").cloned().unwrap_or(Value::Null))
128 }
129
130 pub async fn list_resources(&self) -> Result<ResourceListResult> {
132 if !self.connected {
133 return Err(Error::ExecutionFailed("not connected".into()));
134 }
135 let result = self.raw_jsonrpc("resources/list", serde_json::json!({})).await?;
136 let resources: Vec<Resource> = serde_json::from_value(
137 result.get("resources").cloned().unwrap_or(Value::Array(vec![])),
138 )
139 .map_err(|e| Error::ExecutionFailed(format!("failed to parse resources: {e}")))?;
140 let next_cursor = result
141 .get("nextCursor")
142 .and_then(|v| v.as_str())
143 .map(String::from);
144 Ok(ResourceListResult {
145 resources,
146 next_cursor,
147 })
148 }
149
150 pub async fn read_resource(&self, uri: &str) -> Result<ResourceContents> {
152 if !self.connected {
153 return Err(Error::ExecutionFailed("not connected".into()));
154 }
155 let result = self
156 .raw_jsonrpc("resources/read", serde_json::json!({ "uri": uri }))
157 .await?;
158 let contents_arr = result
159 .get("contents")
160 .and_then(|v| v.as_array())
161 .cloned()
162 .unwrap_or_default();
163 if let Some(first) = contents_arr.first() {
164 let text = first.get("text").and_then(|v| v.as_str()).map(String::from);
165 let blob = first.get("blob").and_then(|v| v.as_str()).map(String::from);
166 let mime_type = first
167 .get("mimeType")
168 .and_then(|v| v.as_str())
169 .map(String::from);
170 Ok(ResourceContents {
171 uri: first
172 .get("uri")
173 .and_then(|v| v.as_str())
174 .unwrap_or(uri)
175 .to_string(),
176 text,
177 blob,
178 mime_type,
179 })
180 } else {
181 Ok(ResourceContents::text(uri, ""))
182 }
183 }
184
185 pub async fn list_resource_templates(&self) -> Result<ResourceTemplateListResult> {
187 if !self.connected {
188 return Err(Error::ExecutionFailed("not connected".into()));
189 }
190 let result = self
191 .raw_jsonrpc("resources/templates/list", serde_json::json!({}))
192 .await?;
193 let templates: Vec<ResourceTemplate> = serde_json::from_value(
194 result
195 .get("resourceTemplates")
196 .cloned()
197 .unwrap_or(Value::Array(vec![])),
198 )
199 .map_err(|e| Error::ExecutionFailed(format!("failed to parse templates: {e}")))?;
200 let next_cursor = result
201 .get("nextCursor")
202 .and_then(|v| v.as_str())
203 .map(String::from);
204 Ok(ResourceTemplateListResult {
205 resource_templates: templates,
206 next_cursor,
207 })
208 }
209
210 pub async fn subscribe_resource(&self, uri: &str) -> Result<()> {
212 if !self.connected {
213 return Err(Error::ExecutionFailed("not connected".into()));
214 }
215 self.raw_jsonrpc("resources/subscribe", serde_json::json!({ "uri": uri }))
216 .await?;
217 Ok(())
218 }
219
220 pub async fn unsubscribe_resource(&self, uri: &str) -> Result<()> {
222 if !self.connected {
223 return Err(Error::ExecutionFailed("not connected".into()));
224 }
225 self.raw_jsonrpc("resources/unsubscribe", serde_json::json!({ "uri": uri }))
226 .await?;
227 Ok(())
228 }
229
230 pub async fn list_prompts(&self) -> Result<PromptListResult> {
232 if !self.connected {
233 return Err(Error::ExecutionFailed("not connected".into()));
234 }
235 let result = self.raw_jsonrpc("prompts/list", serde_json::json!({})).await?;
236 let prompts: Vec<Prompt> = serde_json::from_value(
237 result.get("prompts").cloned().unwrap_or(Value::Array(vec![])),
238 )
239 .map_err(|e| Error::ExecutionFailed(format!("failed to parse prompts: {e}")))?;
240 let next_cursor = result
241 .get("nextCursor")
242 .and_then(|v| v.as_str())
243 .map(String::from);
244 Ok(PromptListResult {
245 prompts,
246 next_cursor,
247 })
248 }
249
250 pub async fn get_prompt(
252 &self,
253 name: &str,
254 arguments: HashMap<String, String>,
255 ) -> Result<GetPromptResult> {
256 if !self.connected {
257 return Err(Error::ExecutionFailed("not connected".into()));
258 }
259 let params = serde_json::json!({
260 "name": name,
261 "arguments": arguments,
262 });
263 let result = self.raw_jsonrpc("prompts/get", params).await?;
264 let description = result
265 .get("description")
266 .and_then(|v| v.as_str())
267 .map(String::from);
268 let messages: Vec<PromptMessage> = serde_json::from_value(
269 result.get("messages").cloned().unwrap_or(Value::Array(vec![])),
270 )
271 .map_err(|e| Error::ExecutionFailed(format!("failed to parse prompt messages: {e}")))?;
272 Ok(GetPromptResult {
273 description,
274 messages,
275 })
276 }
277}
278
279impl Default for McpTransport {
280 fn default() -> Self {
281 Self::new()
282 }
283}
284
285#[async_trait]
286impl CoreTransport for McpTransport {
287 async fn connect(&mut self) -> Result<()> {
288 let _tools = self
290 .inner
291 .register_tool_provider(&*self.provider)
292 .await
293 .map_err(|e| Error::ExecutionFailed(format!("Failed to register provider: {}", e)))?;
294
295 self.connected = true;
296 Ok(())
297 }
298
299 async fn disconnect(&mut self) -> Result<()> {
300 self.inner
301 .deregister_tool_provider(&*self.provider)
302 .await
303 .map_err(|e| Error::ExecutionFailed(format!("Failed to deregister provider: {}", e)))?;
304 self.connected = false;
305 Ok(())
306 }
307
308 fn is_connected(&self) -> bool {
309 self.connected
310 }
311
312 async fn list_tools(&self) -> Result<Vec<ToolDefinition>> {
313 if !self.connected {
314 return Err(Error::ExecutionFailed("not connected".to_string()));
315 }
316
317 let tools = self
319 .inner
320 .register_tool_provider(&*self.provider)
321 .await
322 .map_err(|e| Error::ExecutionFailed(format!("Failed to list tools: {}", e)))?;
323
324 let mut definitions = Vec::new();
326 for tool in tools {
327 let inputs_json = serde_json::to_value(&tool.inputs).map_err(|e| {
329 Error::ExecutionFailed(format!("Failed to serialize inputs: {}", e))
330 })?;
331
332 let parameters =
334 ToolDefinition::parse_mcp_input_schema(&inputs_json).unwrap_or_default();
335
336 definitions.push(ToolDefinition {
337 name: tool.name,
338 description: tool.description,
339 parameters,
340 });
341 }
342
343 Ok(definitions)
344 }
345
346 async fn call(&self, call: &ToolCall) -> Result<ToolResult> {
347 if !self.connected {
348 return Err(Error::ExecutionFailed("not connected".to_string()));
349 }
350
351 let args: HashMap<String, Value> = match &call.arguments {
353 Value::Object(map) => map.iter().map(|(k, v)| (k.clone(), v.clone())).collect(),
354 _ => HashMap::new(),
355 };
356
357 let result = self
359 .inner
360 .call_tool(&call.tool, args, &*self.provider)
361 .await
362 .map_err(|e| Error::ExecutionFailed(format!("Tool call failed: {}", e)))?;
363
364 Ok(ToolResult::success(result))
365 }
366}
367
#[cfg(test)]
mod tests {
    // Offline unit tests: construction, connection guards, stdio rejection of
    // HTTP-only methods, and the shapes of `ToolCall` arguments.

    use super::*;
    use serde_json::json;

    #[test]
    fn transport_new_http() {
        let transport =
            McpTransport::new_http("test".to_string(), "http://localhost:8080".to_string());
        assert!(!transport.is_connected());
    }

    #[test]
    fn transport_new_stdio() {
        let transport = McpTransport::new_stdio(
            "test".to_string(),
            "test-cmd".to_string(),
            Some(vec!["--arg1".to_string()]),
        );
        assert!(!transport.is_connected());
    }

    #[test]
    fn transport_new_default() {
        let transport = McpTransport::new();
        assert!(!transport.is_connected());
    }

    #[test]
    fn transport_default_impl() {
        let transport1 = McpTransport::default();
        let transport2 = McpTransport::new();
        assert!(!transport1.is_connected());
        assert!(!transport2.is_connected());
    }

    #[test]
    fn transport_connect_disconnect() {
        // Exercises the connection flag only; a real connect needs a live server.
        let mut transport =
            McpTransport::new_http("test".to_string(), "http://localhost:9999".to_string());
        assert!(!transport.is_connected());

        transport.connected = true;
        assert!(transport.is_connected());

        transport.connected = false;
        assert!(!transport.is_connected());
    }

    #[test]
    fn test_argument_conversion() {
        let call = ToolCall::builder("test_tool")
            .arg_str("string_param", "value")
            .arg_int("int_param", 42)
            .arg_bool("bool_param", true)
            .build();

        assert!(call.arguments.is_object());
        assert_eq!(call.arguments["string_param"], "value");
        assert_eq!(call.arguments["int_param"], 42);
        assert_eq!(call.arguments["bool_param"], true);
    }

    #[test]
    fn test_argument_conversion_nested() {
        let call = ToolCall::builder("test_tool")
            .arg("nested", json!({"key": "value", "number": 123}))
            .build();

        assert!(call.arguments.is_object());
        assert_eq!(call.arguments["nested"]["key"], "value");
        assert_eq!(call.arguments["nested"]["number"], 123);
    }

    #[test]
    fn test_argument_conversion_array() {
        let call = ToolCall::builder("test_tool")
            .arg("items", json!([1, 2, 3, 4, 5]))
            .build();

        assert!(call.arguments.is_object());
        assert_eq!(call.arguments["items"], json!([1, 2, 3, 4, 5]));
    }

    #[tokio::test]
    async fn test_list_tools_when_disconnected() {
        let transport = McpTransport::new();

        let result = transport.list_tools().await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not connected"));
    }

    #[tokio::test]
    async fn test_call_when_disconnected() {
        let transport = McpTransport::new();
        let call = ToolCall::new("test_tool");

        let result = transport.call(&call).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not connected"));
    }

    #[test]
    fn test_empty_arguments() {
        let call = ToolCall::new("test_tool");
        assert!(call.arguments.is_object());
        assert_eq!(call.arguments.as_object().unwrap().len(), 0);
    }

    #[test]
    fn test_special_characters_in_tool_name() {
        let call = ToolCall::new("test-tool_v2.0");
        assert_eq!(call.tool, "test-tool_v2.0");
    }

    #[test]
    fn test_unicode_in_arguments() {
        let call = ToolCall::builder("test_tool")
            .arg_str("message", "Hello 世界 🌍")
            .build();

        assert_eq!(call.arguments["message"], "Hello 世界 🌍");
    }

    #[test]
    fn test_large_argument_values() {
        let large_string = "x".repeat(10000);
        let call = ToolCall::builder("test_tool")
            .arg_str("data", &large_string)
            .build();

        assert_eq!(call.arguments["data"].as_str().unwrap().len(), 10000);
    }

    #[test]
    fn test_null_arguments() {
        let call = ToolCall::builder("test_tool")
            .arg("null_param", json!(null))
            .build();

        assert!(call.arguments["null_param"].is_null());
    }

    #[test]
    fn test_mixed_type_arguments() {
        let call = ToolCall::builder("test_tool")
            .arg_str("string", "value")
            .arg_int("int", 42)
            .arg_bool("bool", true)
            .arg("null", json!(null))
            .arg("array", json!([1, 2, 3]))
            .arg("object", json!({"key": "value"}))
            .build();

        assert_eq!(call.arguments["string"], "value");
        assert_eq!(call.arguments["int"], 42);
        assert_eq!(call.arguments["bool"], true);
        assert!(call.arguments["null"].is_null());
        assert!(call.arguments["array"].is_array());
        assert!(call.arguments["object"].is_object());
    }

    #[test]
    fn test_stdio_transport_creation() {
        let transport = McpTransport::new_stdio(
            "echo-server".to_string(),
            "npx".to_string(),
            Some(vec![
                "-y".to_string(),
                "@modelcontextprotocol/server-echo".to_string(),
            ]),
        );

        assert!(!transport.is_connected());
    }

    #[test]
    fn test_http_transport_with_https() {
        let transport = McpTransport::new_http(
            "secure".to_string(),
            "https://api.example.com/mcp".to_string(),
        );

        assert!(!transport.is_connected());
    }

    #[test]
    fn test_transport_creation_with_empty_name() {
        let transport = McpTransport::new_http("".to_string(), "http://localhost:8080".to_string());
        assert!(!transport.is_connected());
    }

    #[test]
    fn test_deeply_nested_arguments() {
        let nested = json!({
            "level1": {
                "level2": {
                    "level3": {
                        "level4": {
                            "level5": "deep"
                        }
                    }
                }
            }
        });

        let call = ToolCall::builder("test_tool")
            .arg("nested", nested.clone())
            .build();

        assert_eq!(call.arguments["nested"], nested);
        assert_eq!(
            call.arguments["nested"]["level1"]["level2"]["level3"]["level4"]["level5"],
            "deep"
        );
    }

    #[test]
    fn test_argument_with_numbers() {
        // 1.25 is exactly representable and, unlike 3.14159, does not trip
        // clippy's deny-by-default `approx_constant` lint.
        let call = ToolCall::builder("test_tool")
            .arg_int("positive", 42)
            .arg_int("negative", -42)
            .arg_int("zero", 0)
            .arg("float", json!(1.25))
            .arg("scientific", json!(1.5e10))
            .build();

        assert_eq!(call.arguments["positive"], 42);
        assert_eq!(call.arguments["negative"], -42);
        assert_eq!(call.arguments["zero"], 0);
        assert_eq!(call.arguments["float"], 1.25);
        assert_eq!(call.arguments["scientific"], 1.5e10);
    }

    #[test]
    fn test_argument_with_special_json_values() {
        let call = ToolCall::builder("test_tool")
            .arg("empty_string", json!(""))
            .arg("empty_array", json!([]))
            .arg("empty_object", json!({}))
            .arg("boolean_true", json!(true))
            .arg("boolean_false", json!(false))
            .build();

        assert_eq!(call.arguments["empty_string"], "");
        assert_eq!(call.arguments["empty_array"], json!([]));
        assert_eq!(call.arguments["empty_object"], json!({}));
        assert_eq!(call.arguments["boolean_true"], true);
        assert_eq!(call.arguments["boolean_false"], false);
    }

    #[tokio::test]
    async fn test_list_resources_when_disconnected() {
        let transport = McpTransport::new();
        let result = transport.list_resources().await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not connected"));
    }

    #[tokio::test]
    async fn test_read_resource_when_disconnected() {
        let transport = McpTransport::new();
        let result = transport.read_resource("file:///test.txt").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not connected"));
    }

    #[tokio::test]
    async fn test_list_resource_templates_when_disconnected() {
        let transport = McpTransport::new();
        let result = transport.list_resource_templates().await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not connected"));
    }

    #[tokio::test]
    async fn test_subscribe_resource_when_disconnected() {
        let transport = McpTransport::new();
        let result = transport.subscribe_resource("file:///test.txt").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not connected"));
    }

    #[tokio::test]
    async fn test_unsubscribe_resource_when_disconnected() {
        let transport = McpTransport::new();
        let result = transport.unsubscribe_resource("file:///test.txt").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not connected"));
    }

    #[tokio::test]
    async fn test_stdio_transport_rejects_resources() {
        let mut transport = McpTransport::new_stdio("test".to_string(), "echo".to_string(), None);
        transport.connected = true;
        let result = transport.list_resources().await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("HTTP transport"));
    }

    #[tokio::test]
    async fn test_list_prompts_when_disconnected() {
        let transport = McpTransport::new();
        let result = transport.list_prompts().await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not connected"));
    }

    #[tokio::test]
    async fn test_get_prompt_when_disconnected() {
        let transport = McpTransport::new();
        let result = transport.get_prompt("test", HashMap::new()).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not connected"));
    }

    #[tokio::test]
    async fn test_stdio_transport_rejects_prompts() {
        let mut transport = McpTransport::new_stdio("test".to_string(), "echo".to_string(), None);
        transport.connected = true;
        let result = transport.list_prompts().await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("HTTP transport"));
    }

    #[test]
    fn transport_new_stdio_with_env() {
        let mut env = HashMap::new();
        env.insert("API_KEY".to_string(), "secret123".to_string());
        env.insert("DEBUG".to_string(), "1".to_string());

        let transport = McpTransport::new_stdio_with_env(
            "test-server".to_string(),
            "node".to_string(),
            Some(vec!["server.js".to_string()]),
            env,
        );

        assert!(!transport.is_connected());
        let env_vars = transport.provider.env_vars.as_ref().unwrap();
        assert_eq!(env_vars.get("API_KEY").unwrap(), "secret123");
        assert_eq!(env_vars.get("DEBUG").unwrap(), "1");
        assert_eq!(env_vars.len(), 2);
    }

    #[test]
    fn transport_new_stdio_with_empty_env() {
        let env = HashMap::new();
        let transport =
            McpTransport::new_stdio_with_env("test".to_string(), "echo".to_string(), None, env);

        assert!(!transport.is_connected());
        let env_vars = transport.provider.env_vars.as_ref().unwrap();
        assert!(env_vars.is_empty());
    }

    #[test]
    fn transport_new_stdio_has_no_env() {
        let transport = McpTransport::new_stdio("test".to_string(), "echo".to_string(), None);

        assert!(transport.provider.env_vars.is_none());
    }

    #[test]
    fn test_request_id_increments() {
        let transport = McpTransport::new();
        let id1 = transport.next_request_id();
        let id2 = transport.next_request_id();
        assert_eq!(id2, id1 + 1);
    }
}