1use crate::cli::{Connection, TransportKind};
4use crate::error::{CliError, CliResult};
5use std::collections::HashMap;
6use std::time::Duration;
7use turbomcp_client::Client;
8use turbomcp_protocol::types::Tool;
9
10#[cfg(feature = "stdio")]
11use turbomcp_transport::child_process::{ChildProcessConfig, ChildProcessTransport};
12
13#[cfg(feature = "tcp")]
14use turbomcp_transport::tcp::TcpTransportBuilder;
15
16#[cfg(feature = "unix")]
17use turbomcp_transport::unix::UnixTransportBuilder;
18
19#[cfg(feature = "http")]
20use turbomcp_transport::streamable_http_client::{
21 StreamableHttpClientConfig, StreamableHttpClientTransport,
22};
23
24#[cfg(feature = "websocket")]
25use turbomcp_transport::{WebSocketBidirectionalConfig, WebSocketBidirectionalTransport};
26
27pub struct UnifiedClient {
29 inner: ClientInner,
30}
31
/// The concrete client behind a `UnifiedClient`, one variant per transport.
///
/// Every variant is gated on the Cargo feature of the same transport, so which variants
/// exist depends on how the crate was compiled.
enum ClientInner {
    /// Client over a spawned child process (`stdio` feature).
    #[cfg(feature = "stdio")]
    Stdio(Client<turbomcp_transport::child_process::ChildProcessTransport>),
    /// Client over a TCP connection (`tcp` feature).
    #[cfg(feature = "tcp")]
    Tcp(Client<turbomcp_transport::tcp::TcpTransport>),
    /// Client over a Unix domain socket (`unix` feature).
    #[cfg(feature = "unix")]
    Unix(Client<turbomcp_transport::unix::UnixTransport>),
    /// Client over Streamable HTTP (`http` feature).
    #[cfg(feature = "http")]
    Http(Client<turbomcp_transport::streamable_http_client::StreamableHttpClientTransport>),
    /// Client over a bidirectional WebSocket (`websocket` feature).
    #[cfg(feature = "websocket")]
    WebSocket(Client<turbomcp_transport::WebSocketBidirectionalTransport>),
}
44
45impl UnifiedClient {
46 pub async fn initialize(&self) -> CliResult<turbomcp_client::InitializeResult> {
47 match &self.inner {
48 #[cfg(feature = "stdio")]
49 ClientInner::Stdio(client) => Ok(client.initialize().await?),
50 #[cfg(feature = "tcp")]
51 ClientInner::Tcp(client) => Ok(client.initialize().await?),
52 #[cfg(feature = "unix")]
53 ClientInner::Unix(client) => Ok(client.initialize().await?),
54 #[cfg(feature = "http")]
55 ClientInner::Http(client) => Ok(client.initialize().await?),
56 #[cfg(feature = "websocket")]
57 ClientInner::WebSocket(client) => Ok(client.initialize().await?),
58 }
59 }
60
61 pub async fn list_tools(&self) -> CliResult<Vec<Tool>> {
62 match &self.inner {
63 #[cfg(feature = "stdio")]
64 ClientInner::Stdio(client) => Ok(client.list_tools().await?),
65 #[cfg(feature = "tcp")]
66 ClientInner::Tcp(client) => Ok(client.list_tools().await?),
67 #[cfg(feature = "unix")]
68 ClientInner::Unix(client) => Ok(client.list_tools().await?),
69 #[cfg(feature = "http")]
70 ClientInner::Http(client) => Ok(client.list_tools().await?),
71 #[cfg(feature = "websocket")]
72 ClientInner::WebSocket(client) => Ok(client.list_tools().await?),
73 }
74 }
75
76 pub async fn call_tool(
77 &self,
78 name: &str,
79 arguments: Option<HashMap<String, serde_json::Value>>,
80 ) -> CliResult<serde_json::Value> {
81 let result = match &self.inner {
82 #[cfg(feature = "stdio")]
83 ClientInner::Stdio(client) => client.call_tool(name, arguments).await?,
84 #[cfg(feature = "tcp")]
85 ClientInner::Tcp(client) => client.call_tool(name, arguments).await?,
86 #[cfg(feature = "unix")]
87 ClientInner::Unix(client) => client.call_tool(name, arguments).await?,
88 #[cfg(feature = "http")]
89 ClientInner::Http(client) => client.call_tool(name, arguments).await?,
90 #[cfg(feature = "websocket")]
91 ClientInner::WebSocket(client) => client.call_tool(name, arguments).await?,
92 };
93
94 Ok(serde_json::to_value(result)?)
96 }
97
98 pub async fn list_resources(&self) -> CliResult<Vec<turbomcp_protocol::types::Resource>> {
99 match &self.inner {
100 #[cfg(feature = "stdio")]
101 ClientInner::Stdio(client) => Ok(client.list_resources().await?),
102 #[cfg(feature = "tcp")]
103 ClientInner::Tcp(client) => Ok(client.list_resources().await?),
104 #[cfg(feature = "unix")]
105 ClientInner::Unix(client) => Ok(client.list_resources().await?),
106 #[cfg(feature = "http")]
107 ClientInner::Http(client) => Ok(client.list_resources().await?),
108 #[cfg(feature = "websocket")]
109 ClientInner::WebSocket(client) => Ok(client.list_resources().await?),
110 }
111 }
112
113 pub async fn read_resource(
114 &self,
115 uri: &str,
116 ) -> CliResult<turbomcp_protocol::types::ReadResourceResult> {
117 match &self.inner {
118 #[cfg(feature = "stdio")]
119 ClientInner::Stdio(client) => Ok(client.read_resource(uri).await?),
120 #[cfg(feature = "tcp")]
121 ClientInner::Tcp(client) => Ok(client.read_resource(uri).await?),
122 #[cfg(feature = "unix")]
123 ClientInner::Unix(client) => Ok(client.read_resource(uri).await?),
124 #[cfg(feature = "http")]
125 ClientInner::Http(client) => Ok(client.read_resource(uri).await?),
126 #[cfg(feature = "websocket")]
127 ClientInner::WebSocket(client) => Ok(client.read_resource(uri).await?),
128 }
129 }
130
131 pub async fn list_resource_templates(&self) -> CliResult<Vec<String>> {
132 match &self.inner {
133 #[cfg(feature = "stdio")]
134 ClientInner::Stdio(client) => Ok(client.list_resource_templates().await?),
135 #[cfg(feature = "tcp")]
136 ClientInner::Tcp(client) => Ok(client.list_resource_templates().await?),
137 #[cfg(feature = "unix")]
138 ClientInner::Unix(client) => Ok(client.list_resource_templates().await?),
139 #[cfg(feature = "http")]
140 ClientInner::Http(client) => Ok(client.list_resource_templates().await?),
141 #[cfg(feature = "websocket")]
142 ClientInner::WebSocket(client) => Ok(client.list_resource_templates().await?),
143 }
144 }
145
146 pub async fn subscribe(&self, uri: &str) -> CliResult<turbomcp_protocol::types::EmptyResult> {
147 match &self.inner {
148 #[cfg(feature = "stdio")]
149 ClientInner::Stdio(client) => Ok(client.subscribe(uri).await?),
150 #[cfg(feature = "tcp")]
151 ClientInner::Tcp(client) => Ok(client.subscribe(uri).await?),
152 #[cfg(feature = "unix")]
153 ClientInner::Unix(client) => Ok(client.subscribe(uri).await?),
154 #[cfg(feature = "http")]
155 ClientInner::Http(client) => Ok(client.subscribe(uri).await?),
156 #[cfg(feature = "websocket")]
157 ClientInner::WebSocket(client) => Ok(client.subscribe(uri).await?),
158 }
159 }
160
161 pub async fn unsubscribe(&self, uri: &str) -> CliResult<turbomcp_protocol::types::EmptyResult> {
162 match &self.inner {
163 #[cfg(feature = "stdio")]
164 ClientInner::Stdio(client) => Ok(client.unsubscribe(uri).await?),
165 #[cfg(feature = "tcp")]
166 ClientInner::Tcp(client) => Ok(client.unsubscribe(uri).await?),
167 #[cfg(feature = "unix")]
168 ClientInner::Unix(client) => Ok(client.unsubscribe(uri).await?),
169 #[cfg(feature = "http")]
170 ClientInner::Http(client) => Ok(client.unsubscribe(uri).await?),
171 #[cfg(feature = "websocket")]
172 ClientInner::WebSocket(client) => Ok(client.unsubscribe(uri).await?),
173 }
174 }
175
176 pub async fn list_prompts(&self) -> CliResult<Vec<turbomcp_protocol::types::Prompt>> {
177 match &self.inner {
178 #[cfg(feature = "stdio")]
179 ClientInner::Stdio(client) => Ok(client.list_prompts().await?),
180 #[cfg(feature = "tcp")]
181 ClientInner::Tcp(client) => Ok(client.list_prompts().await?),
182 #[cfg(feature = "unix")]
183 ClientInner::Unix(client) => Ok(client.list_prompts().await?),
184 #[cfg(feature = "http")]
185 ClientInner::Http(client) => Ok(client.list_prompts().await?),
186 #[cfg(feature = "websocket")]
187 ClientInner::WebSocket(client) => Ok(client.list_prompts().await?),
188 }
189 }
190
191 pub async fn get_prompt(
192 &self,
193 name: &str,
194 arguments: Option<HashMap<String, serde_json::Value>>,
195 ) -> CliResult<turbomcp_protocol::types::GetPromptResult> {
196 match &self.inner {
197 #[cfg(feature = "stdio")]
198 ClientInner::Stdio(client) => Ok(client.get_prompt(name, arguments).await?),
199 #[cfg(feature = "tcp")]
200 ClientInner::Tcp(client) => Ok(client.get_prompt(name, arguments).await?),
201 #[cfg(feature = "unix")]
202 ClientInner::Unix(client) => Ok(client.get_prompt(name, arguments).await?),
203 #[cfg(feature = "http")]
204 ClientInner::Http(client) => Ok(client.get_prompt(name, arguments).await?),
205 #[cfg(feature = "websocket")]
206 ClientInner::WebSocket(client) => Ok(client.get_prompt(name, arguments).await?),
207 }
208 }
209
210 pub async fn complete_prompt(
211 &self,
212 prompt_name: &str,
213 argument_name: &str,
214 argument_value: &str,
215 context: Option<turbomcp_protocol::types::CompletionContext>,
216 ) -> CliResult<turbomcp_protocol::types::CompletionResponse> {
217 match &self.inner {
218 #[cfg(feature = "stdio")]
219 ClientInner::Stdio(client) => Ok(client
220 .complete_prompt(prompt_name, argument_name, argument_value, context)
221 .await?),
222 #[cfg(feature = "tcp")]
223 ClientInner::Tcp(client) => Ok(client
224 .complete_prompt(prompt_name, argument_name, argument_value, context)
225 .await?),
226 #[cfg(feature = "unix")]
227 ClientInner::Unix(client) => Ok(client
228 .complete_prompt(prompt_name, argument_name, argument_value, context)
229 .await?),
230 #[cfg(feature = "http")]
231 ClientInner::Http(client) => Ok(client
232 .complete_prompt(prompt_name, argument_name, argument_value, context)
233 .await?),
234 #[cfg(feature = "websocket")]
235 ClientInner::WebSocket(client) => Ok(client
236 .complete_prompt(prompt_name, argument_name, argument_value, context)
237 .await?),
238 }
239 }
240
241 pub async fn complete_resource(
242 &self,
243 resource_uri: &str,
244 argument_name: &str,
245 argument_value: &str,
246 context: Option<turbomcp_protocol::types::CompletionContext>,
247 ) -> CliResult<turbomcp_protocol::types::CompletionResponse> {
248 match &self.inner {
249 #[cfg(feature = "stdio")]
250 ClientInner::Stdio(client) => Ok(client
251 .complete_resource(resource_uri, argument_name, argument_value, context)
252 .await?),
253 #[cfg(feature = "tcp")]
254 ClientInner::Tcp(client) => Ok(client
255 .complete_resource(resource_uri, argument_name, argument_value, context)
256 .await?),
257 #[cfg(feature = "unix")]
258 ClientInner::Unix(client) => Ok(client
259 .complete_resource(resource_uri, argument_name, argument_value, context)
260 .await?),
261 #[cfg(feature = "http")]
262 ClientInner::Http(client) => Ok(client
263 .complete_resource(resource_uri, argument_name, argument_value, context)
264 .await?),
265 #[cfg(feature = "websocket")]
266 ClientInner::WebSocket(client) => Ok(client
267 .complete_resource(resource_uri, argument_name, argument_value, context)
268 .await?),
269 }
270 }
271
272 pub async fn ping(&self) -> CliResult<()> {
273 match &self.inner {
274 #[cfg(feature = "stdio")]
275 ClientInner::Stdio(client) => {
276 client.ping().await?;
277 Ok(())
278 }
279 #[cfg(feature = "tcp")]
280 ClientInner::Tcp(client) => {
281 client.ping().await?;
282 Ok(())
283 }
284 #[cfg(feature = "unix")]
285 ClientInner::Unix(client) => {
286 client.ping().await?;
287 Ok(())
288 }
289 #[cfg(feature = "http")]
290 ClientInner::Http(client) => {
291 client.ping().await?;
292 Ok(())
293 }
294 #[cfg(feature = "websocket")]
295 ClientInner::WebSocket(client) => {
296 client.ping().await?;
297 Ok(())
298 }
299 }
300 }
301
302 pub async fn set_log_level(&self, level: turbomcp_protocol::types::LogLevel) -> CliResult<()> {
303 match &self.inner {
304 #[cfg(feature = "stdio")]
305 ClientInner::Stdio(client) => {
306 client.set_log_level(level).await?;
307 Ok(())
308 }
309 #[cfg(feature = "tcp")]
310 ClientInner::Tcp(client) => {
311 client.set_log_level(level).await?;
312 Ok(())
313 }
314 #[cfg(feature = "unix")]
315 ClientInner::Unix(client) => {
316 client.set_log_level(level).await?;
317 Ok(())
318 }
319 #[cfg(feature = "http")]
320 ClientInner::Http(client) => {
321 client.set_log_level(level).await?;
322 Ok(())
323 }
324 #[cfg(feature = "websocket")]
325 ClientInner::WebSocket(client) => {
326 client.set_log_level(level).await?;
327 Ok(())
328 }
329 }
330 }
331}
332
333pub async fn create_client(conn: &Connection) -> CliResult<UnifiedClient> {
335 let transport_kind = determine_transport(conn);
336
337 match transport_kind {
338 #[cfg(feature = "stdio")]
339 TransportKind::Stdio => {
340 let transport = create_stdio_transport(conn)?;
341 Ok(UnifiedClient {
342 inner: ClientInner::Stdio(Client::new(transport)),
343 })
344 }
345 #[cfg(not(feature = "stdio"))]
346 TransportKind::Stdio => {
347 Err(CliError::NotSupported(
348 "STDIO transport is not enabled (missing 'stdio' feature)".to_string(),
349 ))
350 }
351 #[cfg(feature = "http")]
352 TransportKind::Http => {
353 let transport = create_http_transport(conn).await?;
354 Ok(UnifiedClient {
355 inner: ClientInner::Http(Client::new(transport)),
356 })
357 }
358 #[cfg(not(feature = "http"))]
359 TransportKind::Http => {
360 Err(CliError::NotSupported(
361 "HTTP transport is not enabled. Rebuild with --features http or --features all"
362 .to_string(),
363 ))
364 }
365 #[cfg(feature = "websocket")]
366 TransportKind::Ws => {
367 let transport = create_websocket_transport(conn).await?;
368 Ok(UnifiedClient {
369 inner: ClientInner::WebSocket(Client::new(transport)),
370 })
371 }
372 #[cfg(not(feature = "websocket"))]
373 TransportKind::Ws => {
374 Err(CliError::NotSupported(
375 "WebSocket transport is not enabled. Rebuild with --features websocket or --features all"
376 .to_string(),
377 ))
378 }
379 #[cfg(feature = "tcp")]
380 TransportKind::Tcp => {
381 let transport = create_tcp_transport(conn).await?;
382 Ok(UnifiedClient {
383 inner: ClientInner::Tcp(Client::new(transport)),
384 })
385 }
386 #[cfg(not(feature = "tcp"))]
387 TransportKind::Tcp => {
388 Err(CliError::NotSupported(
389 "TCP transport is not enabled (missing 'tcp' feature)".to_string(),
390 ))
391 }
392 #[cfg(feature = "unix")]
393 TransportKind::Unix => {
394 let transport = create_unix_transport(conn).await?;
395 Ok(UnifiedClient {
396 inner: ClientInner::Unix(Client::new(transport)),
397 })
398 }
399 #[cfg(not(feature = "unix"))]
400 TransportKind::Unix => {
401 Err(CliError::NotSupported(
402 "Unix socket transport is not enabled (missing 'unix' feature)".to_string(),
403 ))
404 }
405 }
406}
407
408pub fn determine_transport(conn: &Connection) -> TransportKind {
410 if let Some(transport) = &conn.transport {
412 return transport.clone();
413 }
414
415 let url = &conn.url;
417
418 if conn.command.is_some() {
419 return TransportKind::Stdio;
420 }
421
422 if url.starts_with("tcp://") {
423 return TransportKind::Tcp;
424 }
425
426 if url.starts_with("unix://") || url.starts_with("/") {
427 return TransportKind::Unix;
428 }
429
430 if url.starts_with("ws://") || url.starts_with("wss://") {
431 return TransportKind::Ws;
432 }
433
434 if url.starts_with("http://") || url.starts_with("https://") {
435 return TransportKind::Http;
436 }
437
438 TransportKind::Stdio
440}
441
/// Builds a child-process transport that launches the configured command.
///
/// `conn.command` is used when present; otherwise `conn.url` is taken as the command line.
/// The line is split on whitespace only; quotes and escapes are NOT interpreted, so an
/// argument containing spaces cannot be expressed.
///
/// # Errors
/// `CliError::InvalidArguments` if the command line contains no tokens.
#[cfg(feature = "stdio")]
fn create_stdio_transport(conn: &Connection) -> CliResult<ChildProcessTransport> {
    let command_line = conn.command.as_deref().unwrap_or(&conn.url);

    let mut tokens = command_line.split_whitespace().map(String::from);
    let command = tokens.next().ok_or_else(|| {
        CliError::InvalidArguments("No command specified for STDIO transport".to_string())
    })?;
    let args = tokens.collect::<Vec<String>>();

    Ok(ChildProcessTransport::new(ChildProcessConfig {
        command,
        args,
        working_directory: None,
        environment: None,
        // The connection timeout (seconds) bounds process start-up.
        startup_timeout: Duration::from_secs(conn.timeout),
        shutdown_timeout: Duration::from_secs(5),
        // 10 * 1024 * 1024 = 10 MiB, assuming the field is measured in bytes.
        max_message_size: 10 * 1024 * 1024,
        buffer_size: 8192,
        kill_on_drop: true,
    }))
}
475
/// Builds a TCP client transport from a `tcp://host:port` URL.
///
/// `host` may be an IP literal (`127.0.0.1`, `[::1]`) or a host name such as `localhost`.
/// `SocketAddr`'s `FromStr` accepts only IP literals, so host names are resolved with
/// `resolve_tcp_addr`.
///
/// # Errors
/// `CliError::InvalidArguments` if the URL lacks the `tcp://` prefix, or if the address
/// cannot be parsed or resolved.
#[cfg(feature = "tcp")]
async fn create_tcp_transport(
    conn: &Connection,
) -> CliResult<turbomcp_transport::tcp::TcpTransport> {
    let url = &conn.url;

    let addr_str = url
        .strip_prefix("tcp://")
        .ok_or_else(|| CliError::InvalidArguments(format!("Invalid TCP URL: {}", url)))?;

    let socket_addr = resolve_tcp_addr(addr_str).map_err(|e| {
        CliError::InvalidArguments(format!("Invalid address '{}': {}", addr_str, e))
    })?;

    Ok(TcpTransportBuilder::new().remote_addr(socket_addr).build())
}

/// Resolves a `host:port` string to a single socket address.
///
/// IP literals are parsed directly and never touch the resolver. Host names go through
/// `ToSocketAddrs`, which performs a *blocking* system lookup. That is tolerable for a
/// one-shot CLI connect, but it occupies the executor thread while resolving.
///
/// When a name resolves to several addresses, only the first is used. For `localhost` this may be
/// `::1` or `127.0.0.1`, depending on the system resolver configuration.
#[cfg(feature = "tcp")]
fn resolve_tcp_addr(addr: &str) -> std::io::Result<std::net::SocketAddr> {
    if let Ok(literal) = addr.parse::<std::net::SocketAddr>() {
        return Ok(literal);
    }
    std::net::ToSocketAddrs::to_socket_addrs(addr)?
        .next()
        .ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::NotFound,
                "host name resolved to no addresses",
            )
        })
}
497
/// Builds a Unix-domain-socket client transport.
///
/// Accepts `unix:///path/to.sock`. A URL without the `unix://` prefix is used verbatim as
/// the socket path.
#[cfg(feature = "unix")]
async fn create_unix_transport(
    conn: &Connection,
) -> CliResult<turbomcp_transport::unix::UnixTransport> {
    let socket_path = match conn.url.strip_prefix("unix://") {
        Some(stripped) => stripped,
        None => conn.url.as_str(),
    };

    Ok(UnixTransportBuilder::new_client()
        .socket_path(socket_path)
        .build())
}
509
/// Builds a Streamable-HTTP client transport for `conn.url`.
///
/// The URL is split into its origin (`scheme://authority`), used as `base_url`, and its path,
/// used as `endpoint_path`. When the URL has no path, or only `/`, the endpoint defaults to
/// `/mcp`. A URL that already names its endpoint (e.g. `http://host:8080/mcp`) therefore does
/// not get `/mcp` appended a second time.
/// NOTE(review): this assumes the transport builds the request URL from
/// `base_url` + `endpoint_path`. Confirm against `StreamableHttpClientTransport`.
///
/// `timeout` is set from `conn.timeout` (seconds). `conn.auth` is not consulted here.
#[cfg(feature = "http")]
async fn create_http_transport(conn: &Connection) -> CliResult<StreamableHttpClientTransport> {
    let (base_url, endpoint_path) = split_http_url(&conn.url);

    let config = StreamableHttpClientConfig {
        base_url,
        endpoint_path: endpoint_path.unwrap_or_else(|| "/mcp".to_string()),
        timeout: Duration::from_secs(conn.timeout),
        ..Default::default()
    };

    Ok(StreamableHttpClientTransport::new(config))
}

/// Splits `url` into `(origin, path)`; `path` is `None` when absent or exactly `/`.
///
/// The origin ends at the first `/` after the `scheme://` separator, or at the first `/` in
/// the string when there is no scheme. Everything from that `/` on, including any query
/// string, is returned verbatim as the path.
#[cfg(feature = "http")]
fn split_http_url(url: &str) -> (String, Option<String>) {
    let authority_start = url.find("://").map_or(0, |idx| idx + "://".len());
    match url[authority_start..].find('/') {
        Some(offset) => {
            // Both indices sit on ASCII characters, so the split is on a char boundary.
            let (origin, path) = url.split_at(authority_start + offset);
            let endpoint = if path == "/" {
                None
            } else {
                Some(path.to_string())
            };
            (origin.to_string(), endpoint)
        }
        None => (url.to_string(), None),
    }
}
533
/// Connects a bidirectional WebSocket client transport to `conn.url`.
///
/// # Errors
/// - `CliError::InvalidArguments` if the URL does not start with `ws://` or `wss://`.
/// - `CliError::ConnectionFailed` carrying the transport's error text if construction fails.
#[cfg(feature = "websocket")]
async fn create_websocket_transport(
    conn: &Connection,
) -> CliResult<WebSocketBidirectionalTransport> {
    let url = &conn.url;

    let has_ws_scheme = ["ws://", "wss://"]
        .iter()
        .any(|&scheme| url.starts_with(scheme));
    if !has_ws_scheme {
        return Err(CliError::InvalidArguments(format!(
            "Invalid WebSocket URL: {} (must start with ws:// or wss://)",
            url
        )));
    }

    let config = WebSocketBidirectionalConfig::client(url.clone());
    match WebSocketBidirectionalTransport::new(config).await {
        Ok(transport) => Ok(transport),
        Err(e) => Err(CliError::ConnectionFailed(e.to_string())),
    }
}
555
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Connection` with the fields `determine_transport` inspects; the rest get
    /// fixed defaults.
    fn connection(
        url: &str,
        command: Option<&str>,
        transport: Option<TransportKind>,
    ) -> Connection {
        Connection {
            transport,
            url: url.to_string(),
            command: command.map(str::to_string),
            auth: None,
            timeout: 30,
        }
    }

    /// Covers each precedence rule and every URL-prefix branch of `determine_transport`.
    #[test]
    fn test_determine_transport() {
        let cases = [
            // No recognised scheme and no command: fall back to STDIO.
            (connection("./my-server", None, None), TransportKind::Stdio),
            (connection("localhost:8080", None, None), TransportKind::Stdio),
            // An explicit command beats the URL scheme.
            (
                connection("http://localhost", Some("python server.py"), None),
                TransportKind::Stdio,
            ),
            (connection("tcp://localhost:8080", None, None), TransportKind::Tcp),
            (connection("/tmp/mcp.sock", None, None), TransportKind::Unix),
            (connection("unix:///tmp/mcp.sock", None, None), TransportKind::Unix),
            (connection("ws://localhost:8080", None, None), TransportKind::Ws),
            (connection("wss://example.com/mcp", None, None), TransportKind::Ws),
            (connection("http://localhost:8080/mcp", None, None), TransportKind::Http),
            (connection("https://example.com/mcp", None, None), TransportKind::Http),
            // An explicit transport beats both the URL and a command.
            (
                connection("http://localhost", None, Some(TransportKind::Tcp)),
                TransportKind::Tcp,
            ),
            (
                connection("tcp://localhost:1", Some("server"), Some(TransportKind::Http)),
                TransportKind::Http,
            ),
        ];

        for (conn, expected) in cases.iter() {
            assert_eq!(
                &determine_transport(conn),
                expected,
                "url = {:?}, command = {:?}",
                conn.url,
                conn.command
            );
        }
    }
}