rust_mcp_transport/
stdio.rs

1use crate::schema::schema_utils::{
2    ClientMessage, ClientMessages, MessageFromClient, MessageFromServer, SdkError, ServerMessage,
3    ServerMessages,
4};
5use crate::schema::RequestId;
6use async_trait::async_trait;
7use serde::de::DeserializeOwned;
8use std::collections::HashMap;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::process::Command;
13use tokio::sync::oneshot::Sender;
14use tokio::sync::{oneshot, Mutex};
15use tokio::task::JoinHandle;
16
17use crate::error::{TransportError, TransportResult};
18use crate::mcp_stream::MCPStream;
19use crate::message_dispatcher::MessageDispatcher;
20use crate::transport::Transport;
21use crate::utils::CancellationTokenSource;
22use crate::{IoStream, McpDispatch, TransportDispatcher, TransportOptions};
23
24/// Implements a standard I/O transport for MCP communication.
25///
26/// This module provides the `StdioTransport` struct, which serves as a transport layer for the
27/// Model Context Protocol (MCP) using standard input/output (stdio). It supports both client-side
28/// and server-side communication by optionally launching a subprocess or using the current
29/// process's stdio streams. The transport handles message streaming, dispatching, and shutdown
30/// operations, integrating with the MCP runtime ecosystem.
31pub struct StdioTransport<R>
32where
33    R: Clone + Send + Sync + DeserializeOwned + 'static,
34{
35    command: Option<String>,
36    args: Option<Vec<String>>,
37    env: Option<HashMap<String, String>>,
38    options: TransportOptions,
39    shutdown_source: tokio::sync::RwLock<Option<CancellationTokenSource>>,
40    is_shut_down: Mutex<bool>,
41    message_sender: Arc<tokio::sync::RwLock<Option<MessageDispatcher<R>>>>,
42    error_stream: tokio::sync::RwLock<Option<IoStream>>,
43    pending_requests: Arc<Mutex<HashMap<RequestId, tokio::sync::oneshot::Sender<R>>>>,
44}
45
46impl<R> StdioTransport<R>
47where
48    R: Clone + Send + Sync + DeserializeOwned + 'static,
49{
50    /// Creates a new `StdioTransport` instance for MCP Server.
51    ///
52    /// This constructor configures the transport to use the current process's stdio streams,
53    ///
54    /// # Arguments
55    /// * `options` - Configuration options for the transport, including timeout settings.
56    ///
57    /// # Returns
58    /// A `TransportResult` containing the initialized `StdioTransport` instance.
59    ///
60    /// # Errors
61    /// Currently, this method does not fail, but it returns a `TransportResult` for API consistency.
62    pub fn new(options: TransportOptions) -> TransportResult<Self> {
63        Ok(Self {
64            // when transport is used for MCP Server, we do not need a command
65            args: None,
66            command: None,
67            env: None,
68            options,
69            shutdown_source: tokio::sync::RwLock::new(None),
70            is_shut_down: Mutex::new(false),
71            message_sender: Arc::new(tokio::sync::RwLock::new(None)),
72            error_stream: tokio::sync::RwLock::new(None),
73            pending_requests: Arc::new(Mutex::new(HashMap::new())),
74        })
75    }
76
77    /// Creates a new `StdioTransport` instance with a subprocess for MCP Client use.
78    ///
79    /// This constructor configures the transport to launch a MCP Server with a specified command
80    /// arguments and optional environment variables
81    ///
82    /// # Arguments
83    /// * `command` - The command to execute (e.g., "rust-mcp-filesystem").
84    /// * `args` - Arguments to pass to the command. (e.g., "~/Documents").
85    /// * `env` - Optional environment variables for the subprocess.
86    /// * `options` - Configuration options for the transport, including timeout settings.
87    ///
88    /// # Returns
89    /// A `TransportResult` containing the initialized `StdioTransport` instance, ready to launch
90    /// the MCP server on `start`.
91    pub fn create_with_server_launch<C: Into<String>>(
92        command: C,
93        args: Vec<String>,
94        env: Option<HashMap<String, String>>,
95        options: TransportOptions,
96    ) -> TransportResult<Self> {
97        Ok(Self {
98            // when transport is used for MCP Server, we do not need a command
99            args: Some(args),
100            command: Some(command.into()),
101            env,
102            options,
103            shutdown_source: tokio::sync::RwLock::new(None),
104            is_shut_down: Mutex::new(false),
105            message_sender: Arc::new(tokio::sync::RwLock::new(None)),
106            error_stream: tokio::sync::RwLock::new(None),
107            pending_requests: Arc::new(Mutex::new(HashMap::new())),
108        })
109    }
110
111    /// Retrieves the command and arguments for launching the subprocess.
112    ///
113    /// Adjusts the command based on the platform: on Windows, wraps it with `cmd.exe /c`.
114    ///
115    /// # Returns
116    /// A tuple of the command string and its arguments.
117    fn launch_commands(&self) -> (String, Vec<std::string::String>) {
118        #[cfg(windows)]
119        {
120            let command = "cmd.exe".to_string();
121            let mut command_args = vec!["/c".to_string(), self.command.clone().unwrap_or_default()];
122            command_args.extend(self.args.clone().unwrap_or_default());
123            (command, command_args)
124        }
125
126        #[cfg(unix)]
127        {
128            let command = self.command.clone().unwrap_or_default();
129            let command_args = self.args.clone().unwrap_or_default();
130            (command, command_args)
131        }
132    }
133
134    pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher<R>) {
135        let mut lock = self.message_sender.write().await;
136        *lock = Some(sender);
137    }
138
139    pub(crate) async fn set_error_stream(
140        &self,
141        error_stream: Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>,
142    ) {
143        let mut lock = self.error_stream.write().await;
144        *lock = Some(IoStream::Writable(error_stream));
145    }
146}
147
148#[async_trait]
149impl<R, S, M, OR, OM> Transport<R, S, M, OR, OM> for StdioTransport<M>
150where
151    R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
152    S: Clone + Send + Sync + serde::Serialize + 'static,
153    M: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
154    OR: Clone + Send + Sync + serde::Serialize + 'static,
155    OM: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
156{
157    /// Starts the transport, initializing streams and the message dispatcher.
158    ///
159    /// If configured with a command (MCP Client), launches the MCP server and connects its stdio streams.
160    /// Otherwise, uses the current process's stdio for server-side communication.
161    ///
162    /// # Returns
163    /// A `TransportResult` containing:
164    /// - A pinned stream of incoming messages.
165    /// - A `MessageDispatcher<R>` for sending messages.
166    /// - An `IoStream` for stderr (readable) or stdout (writable) depending on the mode.
167    ///
168    /// # Errors
169    /// Returns a `TransportError` if the subprocess fails to spawn or stdio streams cannot be accessed.
170    async fn start(&self) -> TransportResult<tokio_stream::wrappers::ReceiverStream<R>>
171    where
172        MessageDispatcher<M>: McpDispatch<R, OR, M, OM>,
173    {
174        // Create CancellationTokenSource and token
175        let (cancellation_source, cancellation_token) = CancellationTokenSource::new();
176        let mut lock = self.shutdown_source.write().await;
177        *lock = Some(cancellation_source);
178
179        if self.command.is_some() {
180            let (command_name, command_args) = self.launch_commands();
181
182            let mut command = Command::new(command_name);
183            command
184                .envs(self.env.as_ref().unwrap_or(&HashMap::new()))
185                .args(&command_args)
186                .stdout(std::process::Stdio::piped())
187                .stdin(std::process::Stdio::piped())
188                .stderr(std::process::Stdio::piped())
189                .kill_on_drop(true);
190
191            #[cfg(windows)]
192            command.creation_flags(0x08000000); // https://learn.microsoft.com/en-us/windows/win32/procthread/process-creation-flags
193
194            #[cfg(unix)]
195            command.process_group(0);
196
197            let mut process = command.spawn().map_err(TransportError::Io)?;
198
199            let stdin = process
200                .stdin
201                .take()
202                .ok_or_else(|| TransportError::Internal("Unable to retrieve stdin.".into()))?;
203
204            let stdout = process
205                .stdout
206                .take()
207                .ok_or_else(|| TransportError::Internal("Unable to retrieve stdout.".into()))?;
208
209            let stderr = process
210                .stderr
211                .take()
212                .ok_or_else(|| TransportError::Internal("Unable to retrieve stderr.".into()))?;
213
214            let pending_requests_clone = self.pending_requests.clone();
215
216            tokio::spawn(async move {
217                let _ = process.wait().await;
218                // clean up pending requests to cancel waiting tasks
219                let mut pending_requests = pending_requests_clone.lock().await;
220                pending_requests.clear();
221            });
222
223            let (stream, sender, error_stream) = MCPStream::create(
224                Box::pin(stdout),
225                Mutex::new(Box::pin(stdin)),
226                IoStream::Readable(Box::pin(stderr)),
227                self.pending_requests.clone(),
228                self.options.timeout,
229                cancellation_token,
230            );
231
232            self.set_message_sender(sender).await;
233
234            if let IoStream::Writable(error_stream) = error_stream {
235                self.set_error_stream(error_stream).await;
236            }
237
238            Ok(stream)
239        } else {
240            let (stream, sender, error_stream) = MCPStream::create(
241                Box::pin(tokio::io::stdin()),
242                Mutex::new(Box::pin(tokio::io::stdout())),
243                IoStream::Writable(Box::pin(tokio::io::stderr())),
244                self.pending_requests.clone(),
245                self.options.timeout,
246                cancellation_token,
247            );
248
249            self.set_message_sender(sender).await;
250
251            if let IoStream::Writable(error_stream) = error_stream {
252                self.set_error_stream(error_stream).await;
253            }
254            Ok(stream)
255        }
256    }
257
258    async fn pending_request_tx(&self, request_id: &RequestId) -> Option<Sender<M>> {
259        let mut pending_requests = self.pending_requests.lock().await;
260        pending_requests.remove(request_id)
261    }
262
263    /// Checks if the transport has been shut down.
264    async fn is_shut_down(&self) -> bool {
265        let result = self.is_shut_down.lock().await;
266        *result
267    }
268
269    fn message_sender(&self) -> Arc<tokio::sync::RwLock<Option<MessageDispatcher<M>>>> {
270        self.message_sender.clone() as _
271    }
272
273    fn error_stream(&self) -> &tokio::sync::RwLock<Option<IoStream>> {
274        &self.error_stream as _
275    }
276
277    async fn consume_string_payload(&self, _payload: &str) -> TransportResult<()> {
278        Err(TransportError::Internal(
279            "Invalid invocation of consume_string_payload() function in StdioTransport".to_string(),
280        ))
281    }
282
283    async fn keep_alive(
284        &self,
285        _interval: Duration,
286        _disconnect_tx: oneshot::Sender<()>,
287    ) -> TransportResult<JoinHandle<()>> {
288        Err(TransportError::Internal(
289            "Invalid invocation of keep_alive() function for StdioTransport".to_string(),
290        ))
291    }
292
293    // Shuts down the transport, terminating any subprocess and signaling closure.
294    ///
295    /// Sends a shutdown signal via the watch channel and kills the subprocess if present.
296    ///
297    /// # Returns
298    /// A `TransportResult` indicating success or failure.
299    ///
300    /// # Errors
301    /// Returns a `TransportError` if the shutdown signal fails or the process cannot be killed.
302    async fn shut_down(&self) -> TransportResult<()> {
303        // Trigger cancellation
304        let mut cancellation_lock = self.shutdown_source.write().await;
305        if let Some(source) = cancellation_lock.as_ref() {
306            source.cancel()?;
307        }
308        *cancellation_lock = None; // Clear cancellation_source
309
310        // Mark as shut down
311        let mut is_shut_down_lock = self.is_shut_down.lock().await;
312        *is_shut_down_lock = true;
313        Ok(())
314    }
315}
316
317#[async_trait]
318impl McpDispatch<ClientMessages, ServerMessages, ClientMessage, ServerMessage>
319    for StdioTransport<ClientMessage>
320{
321    async fn send_message(
322        &self,
323        message: ServerMessages,
324        request_timeout: Option<Duration>,
325    ) -> TransportResult<Option<ClientMessages>> {
326        let sender = self.message_sender.read().await;
327        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
328        sender.send_message(message, request_timeout).await
329    }
330
331    async fn send(
332        &self,
333        message: ServerMessage,
334        request_timeout: Option<Duration>,
335    ) -> TransportResult<Option<ClientMessage>> {
336        let sender = self.message_sender.read().await;
337        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
338        sender.send(message, request_timeout).await
339    }
340
341    async fn send_batch(
342        &self,
343        message: Vec<ServerMessage>,
344        request_timeout: Option<Duration>,
345    ) -> TransportResult<Option<Vec<ClientMessage>>> {
346        let sender = self.message_sender.read().await;
347        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
348        sender.send_batch(message, request_timeout).await
349    }
350
351    async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> {
352        let sender = self.message_sender.read().await;
353        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
354        sender.write_str(payload, skip_store).await
355    }
356}
357
358impl
359    TransportDispatcher<
360        ClientMessages,
361        MessageFromServer,
362        ClientMessage,
363        ServerMessages,
364        ServerMessage,
365    > for StdioTransport<ClientMessage>
366{
367}
368
369#[async_trait]
370impl McpDispatch<ServerMessages, ClientMessages, ServerMessage, ClientMessage>
371    for StdioTransport<ServerMessage>
372{
373    async fn send_message(
374        &self,
375        message: ClientMessages,
376        request_timeout: Option<Duration>,
377    ) -> TransportResult<Option<ServerMessages>> {
378        let sender = self.message_sender.read().await;
379        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
380        sender.send_message(message, request_timeout).await
381    }
382
383    async fn send(
384        &self,
385        message: ClientMessage,
386        request_timeout: Option<Duration>,
387    ) -> TransportResult<Option<ServerMessage>> {
388        let sender = self.message_sender.read().await;
389        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
390        sender.send(message, request_timeout).await
391    }
392
393    async fn send_batch(
394        &self,
395        message: Vec<ClientMessage>,
396        request_timeout: Option<Duration>,
397    ) -> TransportResult<Option<Vec<ServerMessage>>> {
398        let sender = self.message_sender.read().await;
399        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
400        sender.send_batch(message, request_timeout).await
401    }
402
403    async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> {
404        let sender = self.message_sender.read().await;
405        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
406        sender.write_str(payload, skip_store).await
407    }
408}
409
410impl
411    TransportDispatcher<
412        ServerMessages,
413        MessageFromClient,
414        ServerMessage,
415        ClientMessages,
416        ClientMessage,
417    > for StdioTransport<ServerMessage>
418{
419}