turbomcp_cli/
transport.rs

1//! Transport factory and auto-detection
2
3use crate::cli::{Connection, TransportKind};
4use crate::error::{CliError, CliResult};
5use std::collections::HashMap;
6use std::time::Duration;
7use turbomcp_client::Client;
8use turbomcp_protocol::types::Tool;
9
10#[cfg(feature = "stdio")]
11use turbomcp_transport::child_process::{ChildProcessConfig, ChildProcessTransport};
12
13#[cfg(feature = "tcp")]
14use turbomcp_transport::tcp::TcpTransportBuilder;
15
16#[cfg(feature = "unix")]
17use turbomcp_transport::unix::UnixTransportBuilder;
18
19#[cfg(feature = "http")]
20use turbomcp_transport::streamable_http_client::{
21    StreamableHttpClientConfig, StreamableHttpClientTransport,
22};
23
24#[cfg(feature = "websocket")]
25use turbomcp_transport::{WebSocketBidirectionalConfig, WebSocketBidirectionalTransport};
26
27/// Wrapper for unified client operations, hiding transport implementation details
28pub struct UnifiedClient {
29    inner: ClientInner,
30}
31
/// The concrete client held by a [`UnifiedClient`], one variant per transport.
///
/// Each variant only exists when the Cargo feature of the same transport is
/// enabled, so the set of variants depends on the build configuration.
enum ClientInner {
    /// Client over a child-process transport (`stdio` feature).
    #[cfg(feature = "stdio")]
    Stdio(Client<ChildProcessTransport>),
    /// Client over a TCP transport (`tcp` feature).
    #[cfg(feature = "tcp")]
    Tcp(Client<turbomcp_transport::tcp::TcpTransport>),
    /// Client over a Unix-socket transport (`unix` feature).
    #[cfg(feature = "unix")]
    Unix(Client<turbomcp_transport::unix::UnixTransport>),
    /// Client over the streamable HTTP transport (`http` feature).
    #[cfg(feature = "http")]
    Http(Client<StreamableHttpClientTransport>),
    /// Client over the bidirectional WebSocket transport (`websocket` feature).
    #[cfg(feature = "websocket")]
    WebSocket(Client<WebSocketBidirectionalTransport>),
}
44
45impl UnifiedClient {
46    pub async fn initialize(&self) -> CliResult<turbomcp_client::InitializeResult> {
47        match &self.inner {
48            #[cfg(feature = "stdio")]
49            ClientInner::Stdio(client) => Ok(client.initialize().await?),
50            #[cfg(feature = "tcp")]
51            ClientInner::Tcp(client) => Ok(client.initialize().await?),
52            #[cfg(feature = "unix")]
53            ClientInner::Unix(client) => Ok(client.initialize().await?),
54            #[cfg(feature = "http")]
55            ClientInner::Http(client) => Ok(client.initialize().await?),
56            #[cfg(feature = "websocket")]
57            ClientInner::WebSocket(client) => Ok(client.initialize().await?),
58        }
59    }
60
61    pub async fn list_tools(&self) -> CliResult<Vec<Tool>> {
62        match &self.inner {
63            #[cfg(feature = "stdio")]
64            ClientInner::Stdio(client) => Ok(client.list_tools().await?),
65            #[cfg(feature = "tcp")]
66            ClientInner::Tcp(client) => Ok(client.list_tools().await?),
67            #[cfg(feature = "unix")]
68            ClientInner::Unix(client) => Ok(client.list_tools().await?),
69            #[cfg(feature = "http")]
70            ClientInner::Http(client) => Ok(client.list_tools().await?),
71            #[cfg(feature = "websocket")]
72            ClientInner::WebSocket(client) => Ok(client.list_tools().await?),
73        }
74    }
75
76    pub async fn call_tool(
77        &self,
78        name: &str,
79        arguments: Option<HashMap<String, serde_json::Value>>,
80    ) -> CliResult<serde_json::Value> {
81        let result = match &self.inner {
82            #[cfg(feature = "stdio")]
83            ClientInner::Stdio(client) => client.call_tool(name, arguments).await?,
84            #[cfg(feature = "tcp")]
85            ClientInner::Tcp(client) => client.call_tool(name, arguments).await?,
86            #[cfg(feature = "unix")]
87            ClientInner::Unix(client) => client.call_tool(name, arguments).await?,
88            #[cfg(feature = "http")]
89            ClientInner::Http(client) => client.call_tool(name, arguments).await?,
90            #[cfg(feature = "websocket")]
91            ClientInner::WebSocket(client) => client.call_tool(name, arguments).await?,
92        };
93
94        // Serialize CallToolResult to JSON for CLI display
95        Ok(serde_json::to_value(result)?)
96    }
97
98    pub async fn list_resources(&self) -> CliResult<Vec<turbomcp_protocol::types::Resource>> {
99        match &self.inner {
100            #[cfg(feature = "stdio")]
101            ClientInner::Stdio(client) => Ok(client.list_resources().await?),
102            #[cfg(feature = "tcp")]
103            ClientInner::Tcp(client) => Ok(client.list_resources().await?),
104            #[cfg(feature = "unix")]
105            ClientInner::Unix(client) => Ok(client.list_resources().await?),
106            #[cfg(feature = "http")]
107            ClientInner::Http(client) => Ok(client.list_resources().await?),
108            #[cfg(feature = "websocket")]
109            ClientInner::WebSocket(client) => Ok(client.list_resources().await?),
110        }
111    }
112
113    pub async fn read_resource(
114        &self,
115        uri: &str,
116    ) -> CliResult<turbomcp_protocol::types::ReadResourceResult> {
117        match &self.inner {
118            #[cfg(feature = "stdio")]
119            ClientInner::Stdio(client) => Ok(client.read_resource(uri).await?),
120            #[cfg(feature = "tcp")]
121            ClientInner::Tcp(client) => Ok(client.read_resource(uri).await?),
122            #[cfg(feature = "unix")]
123            ClientInner::Unix(client) => Ok(client.read_resource(uri).await?),
124            #[cfg(feature = "http")]
125            ClientInner::Http(client) => Ok(client.read_resource(uri).await?),
126            #[cfg(feature = "websocket")]
127            ClientInner::WebSocket(client) => Ok(client.read_resource(uri).await?),
128        }
129    }
130
131    pub async fn list_resource_templates(&self) -> CliResult<Vec<String>> {
132        match &self.inner {
133            #[cfg(feature = "stdio")]
134            ClientInner::Stdio(client) => Ok(client.list_resource_templates().await?),
135            #[cfg(feature = "tcp")]
136            ClientInner::Tcp(client) => Ok(client.list_resource_templates().await?),
137            #[cfg(feature = "unix")]
138            ClientInner::Unix(client) => Ok(client.list_resource_templates().await?),
139            #[cfg(feature = "http")]
140            ClientInner::Http(client) => Ok(client.list_resource_templates().await?),
141            #[cfg(feature = "websocket")]
142            ClientInner::WebSocket(client) => Ok(client.list_resource_templates().await?),
143        }
144    }
145
146    pub async fn subscribe(&self, uri: &str) -> CliResult<turbomcp_protocol::types::EmptyResult> {
147        match &self.inner {
148            #[cfg(feature = "stdio")]
149            ClientInner::Stdio(client) => Ok(client.subscribe(uri).await?),
150            #[cfg(feature = "tcp")]
151            ClientInner::Tcp(client) => Ok(client.subscribe(uri).await?),
152            #[cfg(feature = "unix")]
153            ClientInner::Unix(client) => Ok(client.subscribe(uri).await?),
154            #[cfg(feature = "http")]
155            ClientInner::Http(client) => Ok(client.subscribe(uri).await?),
156            #[cfg(feature = "websocket")]
157            ClientInner::WebSocket(client) => Ok(client.subscribe(uri).await?),
158        }
159    }
160
161    pub async fn unsubscribe(&self, uri: &str) -> CliResult<turbomcp_protocol::types::EmptyResult> {
162        match &self.inner {
163            #[cfg(feature = "stdio")]
164            ClientInner::Stdio(client) => Ok(client.unsubscribe(uri).await?),
165            #[cfg(feature = "tcp")]
166            ClientInner::Tcp(client) => Ok(client.unsubscribe(uri).await?),
167            #[cfg(feature = "unix")]
168            ClientInner::Unix(client) => Ok(client.unsubscribe(uri).await?),
169            #[cfg(feature = "http")]
170            ClientInner::Http(client) => Ok(client.unsubscribe(uri).await?),
171            #[cfg(feature = "websocket")]
172            ClientInner::WebSocket(client) => Ok(client.unsubscribe(uri).await?),
173        }
174    }
175
176    pub async fn list_prompts(&self) -> CliResult<Vec<turbomcp_protocol::types::Prompt>> {
177        match &self.inner {
178            #[cfg(feature = "stdio")]
179            ClientInner::Stdio(client) => Ok(client.list_prompts().await?),
180            #[cfg(feature = "tcp")]
181            ClientInner::Tcp(client) => Ok(client.list_prompts().await?),
182            #[cfg(feature = "unix")]
183            ClientInner::Unix(client) => Ok(client.list_prompts().await?),
184            #[cfg(feature = "http")]
185            ClientInner::Http(client) => Ok(client.list_prompts().await?),
186            #[cfg(feature = "websocket")]
187            ClientInner::WebSocket(client) => Ok(client.list_prompts().await?),
188        }
189    }
190
191    pub async fn get_prompt(
192        &self,
193        name: &str,
194        arguments: Option<HashMap<String, serde_json::Value>>,
195    ) -> CliResult<turbomcp_protocol::types::GetPromptResult> {
196        match &self.inner {
197            #[cfg(feature = "stdio")]
198            ClientInner::Stdio(client) => Ok(client.get_prompt(name, arguments).await?),
199            #[cfg(feature = "tcp")]
200            ClientInner::Tcp(client) => Ok(client.get_prompt(name, arguments).await?),
201            #[cfg(feature = "unix")]
202            ClientInner::Unix(client) => Ok(client.get_prompt(name, arguments).await?),
203            #[cfg(feature = "http")]
204            ClientInner::Http(client) => Ok(client.get_prompt(name, arguments).await?),
205            #[cfg(feature = "websocket")]
206            ClientInner::WebSocket(client) => Ok(client.get_prompt(name, arguments).await?),
207        }
208    }
209
210    pub async fn complete_prompt(
211        &self,
212        prompt_name: &str,
213        argument_name: &str,
214        argument_value: &str,
215        context: Option<turbomcp_protocol::types::CompletionContext>,
216    ) -> CliResult<turbomcp_protocol::types::CompletionResponse> {
217        match &self.inner {
218            #[cfg(feature = "stdio")]
219            ClientInner::Stdio(client) => Ok(client
220                .complete_prompt(prompt_name, argument_name, argument_value, context)
221                .await?),
222            #[cfg(feature = "tcp")]
223            ClientInner::Tcp(client) => Ok(client
224                .complete_prompt(prompt_name, argument_name, argument_value, context)
225                .await?),
226            #[cfg(feature = "unix")]
227            ClientInner::Unix(client) => Ok(client
228                .complete_prompt(prompt_name, argument_name, argument_value, context)
229                .await?),
230            #[cfg(feature = "http")]
231            ClientInner::Http(client) => Ok(client
232                .complete_prompt(prompt_name, argument_name, argument_value, context)
233                .await?),
234            #[cfg(feature = "websocket")]
235            ClientInner::WebSocket(client) => Ok(client
236                .complete_prompt(prompt_name, argument_name, argument_value, context)
237                .await?),
238        }
239    }
240
241    pub async fn complete_resource(
242        &self,
243        resource_uri: &str,
244        argument_name: &str,
245        argument_value: &str,
246        context: Option<turbomcp_protocol::types::CompletionContext>,
247    ) -> CliResult<turbomcp_protocol::types::CompletionResponse> {
248        match &self.inner {
249            #[cfg(feature = "stdio")]
250            ClientInner::Stdio(client) => Ok(client
251                .complete_resource(resource_uri, argument_name, argument_value, context)
252                .await?),
253            #[cfg(feature = "tcp")]
254            ClientInner::Tcp(client) => Ok(client
255                .complete_resource(resource_uri, argument_name, argument_value, context)
256                .await?),
257            #[cfg(feature = "unix")]
258            ClientInner::Unix(client) => Ok(client
259                .complete_resource(resource_uri, argument_name, argument_value, context)
260                .await?),
261            #[cfg(feature = "http")]
262            ClientInner::Http(client) => Ok(client
263                .complete_resource(resource_uri, argument_name, argument_value, context)
264                .await?),
265            #[cfg(feature = "websocket")]
266            ClientInner::WebSocket(client) => Ok(client
267                .complete_resource(resource_uri, argument_name, argument_value, context)
268                .await?),
269        }
270    }
271
272    pub async fn ping(&self) -> CliResult<()> {
273        match &self.inner {
274            #[cfg(feature = "stdio")]
275            ClientInner::Stdio(client) => {
276                client.ping().await?;
277                Ok(())
278            }
279            #[cfg(feature = "tcp")]
280            ClientInner::Tcp(client) => {
281                client.ping().await?;
282                Ok(())
283            }
284            #[cfg(feature = "unix")]
285            ClientInner::Unix(client) => {
286                client.ping().await?;
287                Ok(())
288            }
289            #[cfg(feature = "http")]
290            ClientInner::Http(client) => {
291                client.ping().await?;
292                Ok(())
293            }
294            #[cfg(feature = "websocket")]
295            ClientInner::WebSocket(client) => {
296                client.ping().await?;
297                Ok(())
298            }
299        }
300    }
301
302    pub async fn set_log_level(&self, level: turbomcp_protocol::types::LogLevel) -> CliResult<()> {
303        match &self.inner {
304            #[cfg(feature = "stdio")]
305            ClientInner::Stdio(client) => {
306                client.set_log_level(level).await?;
307                Ok(())
308            }
309            #[cfg(feature = "tcp")]
310            ClientInner::Tcp(client) => {
311                client.set_log_level(level).await?;
312                Ok(())
313            }
314            #[cfg(feature = "unix")]
315            ClientInner::Unix(client) => {
316                client.set_log_level(level).await?;
317                Ok(())
318            }
319            #[cfg(feature = "http")]
320            ClientInner::Http(client) => {
321                client.set_log_level(level).await?;
322                Ok(())
323            }
324            #[cfg(feature = "websocket")]
325            ClientInner::WebSocket(client) => {
326                client.set_log_level(level).await?;
327                Ok(())
328            }
329        }
330    }
331}
332
333/// Create a unified client that hides transport type complexity from the executor
334pub async fn create_client(conn: &Connection) -> CliResult<UnifiedClient> {
335    let transport_kind = determine_transport(conn);
336
337    match transport_kind {
338        #[cfg(feature = "stdio")]
339        TransportKind::Stdio => {
340            let transport = create_stdio_transport(conn)?;
341            Ok(UnifiedClient {
342                inner: ClientInner::Stdio(Client::new(transport)),
343            })
344        }
345        #[cfg(not(feature = "stdio"))]
346        TransportKind::Stdio => {
347            Err(CliError::NotSupported(
348                "STDIO transport is not enabled (missing 'stdio' feature)".to_string(),
349            ))
350        }
351        #[cfg(feature = "http")]
352        TransportKind::Http => {
353            let transport = create_http_transport(conn).await?;
354            Ok(UnifiedClient {
355                inner: ClientInner::Http(Client::new(transport)),
356            })
357        }
358        #[cfg(not(feature = "http"))]
359        TransportKind::Http => {
360            Err(CliError::NotSupported(
361                "HTTP transport is not enabled. Rebuild with --features http or --features all"
362                    .to_string(),
363            ))
364        }
365        #[cfg(feature = "websocket")]
366        TransportKind::Ws => {
367            let transport = create_websocket_transport(conn).await?;
368            Ok(UnifiedClient {
369                inner: ClientInner::WebSocket(Client::new(transport)),
370            })
371        }
372        #[cfg(not(feature = "websocket"))]
373        TransportKind::Ws => {
374            Err(CliError::NotSupported(
375                "WebSocket transport is not enabled. Rebuild with --features websocket or --features all"
376                    .to_string(),
377            ))
378        }
379        #[cfg(feature = "tcp")]
380        TransportKind::Tcp => {
381            let transport = create_tcp_transport(conn).await?;
382            Ok(UnifiedClient {
383                inner: ClientInner::Tcp(Client::new(transport)),
384            })
385        }
386        #[cfg(not(feature = "tcp"))]
387        TransportKind::Tcp => {
388            Err(CliError::NotSupported(
389                "TCP transport is not enabled (missing 'tcp' feature)".to_string(),
390            ))
391        }
392        #[cfg(feature = "unix")]
393        TransportKind::Unix => {
394            let transport = create_unix_transport(conn).await?;
395            Ok(UnifiedClient {
396                inner: ClientInner::Unix(Client::new(transport)),
397            })
398        }
399        #[cfg(not(feature = "unix"))]
400        TransportKind::Unix => {
401            Err(CliError::NotSupported(
402                "Unix socket transport is not enabled (missing 'unix' feature)".to_string(),
403            ))
404        }
405    }
406}
407
408/// Determine transport type from connection config
409pub fn determine_transport(conn: &Connection) -> TransportKind {
410    // Use explicit transport if provided
411    if let Some(transport) = &conn.transport {
412        return transport.clone();
413    }
414
415    // Auto-detect based on URL/command patterns
416    let url = &conn.url;
417
418    if conn.command.is_some() {
419        return TransportKind::Stdio;
420    }
421
422    if url.starts_with("tcp://") {
423        return TransportKind::Tcp;
424    }
425
426    if url.starts_with("unix://") || url.starts_with("/") {
427        return TransportKind::Unix;
428    }
429
430    if url.starts_with("ws://") || url.starts_with("wss://") {
431        return TransportKind::Ws;
432    }
433
434    if url.starts_with("http://") || url.starts_with("https://") {
435        return TransportKind::Http;
436    }
437
438    // Default to STDIO for executable paths
439    TransportKind::Stdio
440}
441
/// Build a child-process (STDIO) transport for `conn`.
///
/// The command line is taken from `conn.command` when present and from
/// `conn.url` otherwise. It is split on whitespace: the first token is the
/// program and the remaining tokens are its arguments (no quoting or escaping is
/// interpreted).
///
/// # Errors
/// Returns `CliError::InvalidArguments` when the command line has no tokens.
#[cfg(feature = "stdio")]
fn create_stdio_transport(conn: &Connection) -> CliResult<ChildProcessTransport> {
    let command_line = match conn.command.as_deref() {
        Some(explicit) => explicit,
        None => conn.url.as_str(),
    };

    let mut tokens = command_line.split_whitespace().map(String::from);
    let command = tokens.next().ok_or_else(|| {
        CliError::InvalidArguments("No command specified for STDIO transport".to_string())
    })?;
    let args: Vec<String> = tokens.collect();

    Ok(ChildProcessTransport::new(ChildProcessConfig {
        command,
        args,
        working_directory: None,
        environment: None,
        startup_timeout: Duration::from_secs(conn.timeout),
        shutdown_timeout: Duration::from_secs(5),
        max_message_size: 10 * 1024 * 1024, // 10 MiB
        buffer_size: 8192,                  // 8 KiB
        kill_on_drop: true,                 // child is killed when the transport is dropped
    }))
}
475
/// Build a TCP transport for `conn`.
///
/// `conn.url` must have the form `tcp://<host>:<port>`. The part after the scheme
/// is first parsed as a literal socket address (`127.0.0.1:8080`, `[::1]:8080`).
/// If that fails it is resolved as `hostname:port` (for example
/// `localhost:8080`) and the first resolved address is used.
///
/// # Errors
/// Returns `CliError::InvalidArguments` when the URL does not start with
/// `tcp://`, or when the address can neither be parsed nor resolved.
#[cfg(feature = "tcp")]
async fn create_tcp_transport(
    conn: &Connection,
) -> CliResult<turbomcp_transport::tcp::TcpTransport> {
    let url = &conn.url;

    let addr_str = url
        .strip_prefix("tcp://")
        .ok_or_else(|| CliError::InvalidArguments(format!("Invalid TCP URL: {}", url)))?;

    let socket_addr = match addr_str.parse::<std::net::SocketAddr>() {
        Ok(literal) => literal,
        // `SocketAddr::from_str` only understands literal IPs, so hostnames such as
        // `localhost:8080` land here. Resolution through the standard library is a
        // blocking call; it runs once while the CLI is setting up its connection.
        Err(_) => std::net::ToSocketAddrs::to_socket_addrs(addr_str)
            .map_err(|e| {
                CliError::InvalidArguments(format!("Invalid address '{}': {}", addr_str, e))
            })?
            .next()
            .ok_or_else(|| {
                CliError::InvalidArguments(format!(
                    "Invalid address '{}': host resolved to no addresses",
                    addr_str
                ))
            })?,
    };

    let transport = TcpTransportBuilder::new().remote_addr(socket_addr).build();

    Ok(transport)
}
497
/// Build a client-side Unix-socket transport for `conn`.
///
/// The socket path is `conn.url` with a leading `unix://` removed when present;
/// a URL without that prefix is used as the path unchanged.
#[cfg(feature = "unix")]
async fn create_unix_transport(
    conn: &Connection,
) -> CliResult<turbomcp_transport::unix::UnixTransport> {
    let socket_path = match conn.url.strip_prefix("unix://") {
        Some(rest) => rest,
        None => conn.url.as_str(),
    };

    Ok(UnixTransportBuilder::new_client()
        .socket_path(socket_path)
        .build())
}
509
/// Build a streamable-HTTP transport for `conn`.
///
/// `conn.url` is used verbatim (scheme included) as the base URL, the endpoint
/// path is fixed to `/mcp`, and `conn.timeout` (seconds) becomes the request
/// timeout. All other settings come from `StreamableHttpClientConfig::default()`.
#[cfg(feature = "http")]
async fn create_http_transport(conn: &Connection) -> CliResult<StreamableHttpClientTransport> {
    let config = StreamableHttpClientConfig {
        // The URL is passed through unchanged. The earlier code stripped the scheme
        // and immediately re-added it, which produced exactly this value.
        base_url: conn.url.clone(),
        endpoint_path: "/mcp".to_string(),
        timeout: Duration::from_secs(conn.timeout),
        ..Default::default()
    };

    Ok(StreamableHttpClientTransport::new(config))
}
533
/// Build a bidirectional WebSocket transport for `conn`.
///
/// # Errors
/// Returns `CliError::InvalidArguments` when `conn.url` starts with neither
/// `ws://` nor `wss://`, and `CliError::ConnectionFailed` (carrying the
/// transport's error text) when constructing the transport fails.
#[cfg(feature = "websocket")]
async fn create_websocket_transport(
    conn: &Connection,
) -> CliResult<WebSocketBidirectionalTransport> {
    let target = &conn.url;

    let is_websocket_url = ["ws://", "wss://"]
        .iter()
        .any(|scheme| target.starts_with(*scheme));
    if !is_websocket_url {
        return Err(CliError::InvalidArguments(format!(
            "Invalid WebSocket URL: {} (must start with ws:// or wss://)",
            target
        )));
    }

    let config = WebSocketBidirectionalConfig::client(target.clone());
    match WebSocketBidirectionalTransport::new(config).await {
        Ok(transport) => Ok(transport),
        Err(e) => Err(CliError::ConnectionFailed(e.to_string())),
    }
}
555
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Connection` fixture for transport detection; `auth` is left unset
    /// and `timeout` is 30.
    fn connection(
        transport: Option<TransportKind>,
        url: &str,
        command: Option<&str>,
    ) -> Connection {
        Connection {
            transport,
            url: url.to_string(),
            command: command.map(String::from),
            auth: None,
            timeout: 30,
        }
    }

    #[test]
    fn test_determine_transport() {
        // STDIO detection: an unrecognised URL is treated as an executable path.
        let conn = connection(None, "./my-server", None);
        assert_eq!(determine_transport(&conn), TransportKind::Stdio);

        // Command override: a command wins over the URL scheme.
        let conn = connection(None, "http://localhost", Some("python server.py"));
        assert_eq!(determine_transport(&conn), TransportKind::Stdio);

        // TCP detection
        let conn = connection(None, "tcp://localhost:8080", None);
        assert_eq!(determine_transport(&conn), TransportKind::Tcp);

        // Unix detection via an absolute path
        let conn = connection(None, "/tmp/mcp.sock", None);
        assert_eq!(determine_transport(&conn), TransportKind::Unix);

        // Explicit override: the configured transport wins over the URL scheme.
        let conn = connection(Some(TransportKind::Tcp), "http://localhost", None);
        assert_eq!(determine_transport(&conn), TransportKind::Tcp);
    }

    #[test]
    fn test_determine_transport_url_schemes() {
        // Unix detection via the explicit scheme
        let conn = connection(None, "unix:///tmp/mcp.sock", None);
        assert_eq!(determine_transport(&conn), TransportKind::Unix);

        // WebSocket detection, plain and TLS
        let conn = connection(None, "ws://localhost:9000", None);
        assert_eq!(determine_transport(&conn), TransportKind::Ws);
        let conn = connection(None, "wss://example.com/mcp", None);
        assert_eq!(determine_transport(&conn), TransportKind::Ws);

        // HTTP detection, plain and TLS
        let conn = connection(None, "http://localhost:8080", None);
        assert_eq!(determine_transport(&conn), TransportKind::Http);
        let conn = connection(None, "https://example.com", None);
        assert_eq!(determine_transport(&conn), TransportKind::Http);
    }
}