rust_mcp_transport/
stdio.rs

1use crate::schema::schema_utils::{
2    ClientMessage, ClientMessages, MessageFromClient, MessageFromServer, SdkError, ServerMessage,
3    ServerMessages,
4};
5use crate::schema::RequestId;
6use async_trait::async_trait;
7use serde::de::DeserializeOwned;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::process::Command;
12use tokio::sync::oneshot::Sender;
13use tokio::sync::{oneshot, Mutex};
14use tokio::task::JoinHandle;
15
16use crate::error::{TransportError, TransportResult};
17use crate::mcp_stream::MCPStream;
18use crate::message_dispatcher::MessageDispatcher;
19use crate::transport::Transport;
20use crate::utils::CancellationTokenSource;
21use crate::{IoStream, McpDispatch, TransportDispatcher, TransportOptions};
22
/// Implements a standard I/O transport for MCP communication.
///
/// `StdioTransport` serves as a transport layer for the Model Context Protocol (MCP) using
/// standard input/output (stdio). It supports both client-side and server-side communication:
/// when a command is configured, `start` launches that command as a subprocess and talks to it
/// over the child's stdin/stdout; otherwise it uses the current process's own stdio streams.
/// The transport handles message streaming, dispatching, and shutdown operations.
///
/// The type parameter `R` is the message type carried by the stored dispatcher and by the
/// pending-request table.
pub struct StdioTransport<R>
where
    R: Clone + Send + Sync + DeserializeOwned + 'static,
{
    /// Program to launch on `start`. `None` when the transport was built with `new`
    /// (current-process stdio); `Some` when built with `create_with_server_launch`.
    command: Option<String>,
    /// Arguments passed to `command` when the subprocess is spawned.
    args: Option<Vec<String>>,
    /// Extra environment variables handed to the spawned subprocess, if any.
    env: Option<HashMap<String, String>>,
    /// Transport configuration; `options.timeout` is forwarded to the message stream on `start`.
    options: TransportOptions,
    /// Cancellation source created by `start`; cancelled and cleared by `shut_down`.
    shutdown_source: tokio::sync::RwLock<Option<CancellationTokenSource>>,
    /// Set to `true` once `shut_down` has run.
    is_shut_down: Mutex<bool>,
    /// Dispatcher used to send outgoing messages. `None` until `start` installs it.
    message_sender: Arc<tokio::sync::RwLock<Option<MessageDispatcher<R>>>>,
    /// Error stream produced on `start`. `None` until `start` installs it.
    error_stream: tokio::sync::RwLock<Option<IoStream>>,
    /// One-shot senders keyed by request id, shared with the message stream. Entries are
    /// removed by `pending_request_tx`, and the whole table is cleared when a launched
    /// subprocess exits.
    pending_requests: Arc<Mutex<HashMap<RequestId, tokio::sync::oneshot::Sender<R>>>>,
}
44
45impl<R> StdioTransport<R>
46where
47    R: Clone + Send + Sync + DeserializeOwned + 'static,
48{
49    /// Creates a new `StdioTransport` instance for MCP Server.
50    ///
51    /// This constructor configures the transport to use the current process's stdio streams,
52    ///
53    /// # Arguments
54    /// * `options` - Configuration options for the transport, including timeout settings.
55    ///
56    /// # Returns
57    /// A `TransportResult` containing the initialized `StdioTransport` instance.
58    ///
59    /// # Errors
60    /// Currently, this method does not fail, but it returns a `TransportResult` for API consistency.
61    pub fn new(options: TransportOptions) -> TransportResult<Self> {
62        Ok(Self {
63            // when transport is used for MCP Server, we do not need a command
64            args: None,
65            command: None,
66            env: None,
67            options,
68            shutdown_source: tokio::sync::RwLock::new(None),
69            is_shut_down: Mutex::new(false),
70            message_sender: Arc::new(tokio::sync::RwLock::new(None)),
71            error_stream: tokio::sync::RwLock::new(None),
72            pending_requests: Arc::new(Mutex::new(HashMap::new())),
73        })
74    }
75
76    /// Creates a new `StdioTransport` instance with a subprocess for MCP Client use.
77    ///
78    /// This constructor configures the transport to launch a MCP Server with a specified command
79    /// arguments and optional environment variables
80    ///
81    /// # Arguments
82    /// * `command` - The command to execute (e.g., "rust-mcp-filesystem").
83    /// * `args` - Arguments to pass to the command. (e.g., "~/Documents").
84    /// * `env` - Optional environment variables for the subprocess.
85    /// * `options` - Configuration options for the transport, including timeout settings.
86    ///
87    /// # Returns
88    /// A `TransportResult` containing the initialized `StdioTransport` instance, ready to launch
89    /// the MCP server on `start`.
90    pub fn create_with_server_launch<C: Into<String>>(
91        command: C,
92        args: Vec<String>,
93        env: Option<HashMap<String, String>>,
94        options: TransportOptions,
95    ) -> TransportResult<Self> {
96        Ok(Self {
97            // when transport is used for MCP Server, we do not need a command
98            args: Some(args),
99            command: Some(command.into()),
100            env,
101            options,
102            shutdown_source: tokio::sync::RwLock::new(None),
103            is_shut_down: Mutex::new(false),
104            message_sender: Arc::new(tokio::sync::RwLock::new(None)),
105            error_stream: tokio::sync::RwLock::new(None),
106            pending_requests: Arc::new(Mutex::new(HashMap::new())),
107        })
108    }
109
110    /// Retrieves the command and arguments for launching the subprocess.
111    ///
112    /// Adjusts the command based on the platform: on Windows, wraps it with `cmd.exe /c`.
113    ///
114    /// # Returns
115    /// A tuple of the command string and its arguments.
116    fn launch_commands(&self) -> (String, Vec<std::string::String>) {
117        #[cfg(windows)]
118        {
119            let command = "cmd.exe".to_string();
120            let mut command_args = vec!["/c".to_string(), self.command.clone().unwrap_or_default()];
121            command_args.extend(self.args.clone().unwrap_or_default());
122            (command, command_args)
123        }
124
125        #[cfg(unix)]
126        {
127            let command = self.command.clone().unwrap_or_default();
128            let command_args = self.args.clone().unwrap_or_default();
129            (command, command_args)
130        }
131    }
132
133    pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher<R>) {
134        let mut lock = self.message_sender.write().await;
135        *lock = Some(sender);
136    }
137
138    pub(crate) async fn set_error_stream(&self, error_stream: IoStream) {
139        let mut lock = self.error_stream.write().await;
140        *lock = Some(error_stream);
141    }
142}
143
144#[async_trait]
145impl<R, S, M, OR, OM> Transport<R, S, M, OR, OM> for StdioTransport<M>
146where
147    R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
148    S: Clone + Send + Sync + serde::Serialize + 'static,
149    M: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
150    OR: Clone + Send + Sync + serde::Serialize + 'static,
151    OM: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
152{
153    /// Starts the transport, initializing streams and the message dispatcher.
154    ///
155    /// If configured with a command (MCP Client), launches the MCP server and connects its stdio streams.
156    /// Otherwise, uses the current process's stdio for server-side communication.
157    ///
158    /// # Returns
159    /// A `TransportResult` containing:
160    /// - A pinned stream of incoming messages.
161    /// - A `MessageDispatcher<R>` for sending messages.
162    /// - An `IoStream` for stderr (readable) or stdout (writable) depending on the mode.
163    ///
164    /// # Errors
165    /// Returns a `TransportError` if the subprocess fails to spawn or stdio streams cannot be accessed.
166    async fn start(&self) -> TransportResult<tokio_stream::wrappers::ReceiverStream<R>>
167    where
168        MessageDispatcher<M>: McpDispatch<R, OR, M, OM>,
169    {
170        // Create CancellationTokenSource and token
171        let (cancellation_source, cancellation_token) = CancellationTokenSource::new();
172        let mut lock = self.shutdown_source.write().await;
173        *lock = Some(cancellation_source);
174
175        if self.command.is_some() {
176            let (command_name, command_args) = self.launch_commands();
177
178            let mut command = Command::new(command_name);
179            command
180                .envs(self.env.as_ref().unwrap_or(&HashMap::new()))
181                .args(&command_args)
182                .stdout(std::process::Stdio::piped())
183                .stdin(std::process::Stdio::piped())
184                .stderr(std::process::Stdio::piped())
185                .kill_on_drop(true);
186
187            #[cfg(windows)]
188            command.creation_flags(0x08000000); // https://learn.microsoft.com/en-us/windows/win32/procthread/process-creation-flags
189
190            #[cfg(unix)]
191            command.process_group(0);
192
193            let mut process = command.spawn().map_err(TransportError::Io)?;
194
195            let stdin = process
196                .stdin
197                .take()
198                .ok_or_else(|| TransportError::Internal("Unable to retrieve stdin.".into()))?;
199
200            let stdout = process
201                .stdout
202                .take()
203                .ok_or_else(|| TransportError::Internal("Unable to retrieve stdout.".into()))?;
204
205            let stderr = process
206                .stderr
207                .take()
208                .ok_or_else(|| TransportError::Internal("Unable to retrieve stderr.".into()))?;
209
210            let pending_requests_clone = self.pending_requests.clone();
211
212            tokio::spawn(async move {
213                let _ = process.wait().await;
214                // clean up pending requests to cancel waiting tasks
215                let mut pending_requests = pending_requests_clone.lock().await;
216                pending_requests.clear();
217            });
218
219            let (stream, sender, error_stream) = MCPStream::create(
220                Box::pin(stdout),
221                Mutex::new(Box::pin(stdin)),
222                IoStream::Readable(Box::pin(stderr)),
223                self.pending_requests.clone(),
224                self.options.timeout,
225                cancellation_token,
226            );
227
228            self.set_message_sender(sender).await;
229            self.set_error_stream(error_stream).await;
230
231            Ok(stream)
232        } else {
233            let (stream, sender, error_stream) = MCPStream::create(
234                Box::pin(tokio::io::stdin()),
235                Mutex::new(Box::pin(tokio::io::stdout())),
236                IoStream::Writable(Box::pin(tokio::io::stderr())),
237                self.pending_requests.clone(),
238                self.options.timeout,
239                cancellation_token,
240            );
241
242            self.set_message_sender(sender).await;
243            self.set_error_stream(error_stream).await;
244            Ok(stream)
245        }
246    }
247
248    async fn pending_request_tx(&self, request_id: &RequestId) -> Option<Sender<M>> {
249        let mut pending_requests = self.pending_requests.lock().await;
250        pending_requests.remove(request_id)
251    }
252
253    /// Checks if the transport has been shut down.
254    async fn is_shut_down(&self) -> bool {
255        let result = self.is_shut_down.lock().await;
256        *result
257    }
258
259    fn message_sender(&self) -> Arc<tokio::sync::RwLock<Option<MessageDispatcher<M>>>> {
260        self.message_sender.clone() as _
261    }
262
263    fn error_stream(&self) -> &tokio::sync::RwLock<Option<IoStream>> {
264        &self.error_stream as _
265    }
266
267    async fn consume_string_payload(&self, _payload: &str) -> TransportResult<()> {
268        Err(TransportError::Internal(
269            "Invalid invocation of consume_string_payload() function in StdioTransport".to_string(),
270        ))
271    }
272
273    async fn keep_alive(
274        &self,
275        _interval: Duration,
276        _disconnect_tx: oneshot::Sender<()>,
277    ) -> TransportResult<JoinHandle<()>> {
278        Err(TransportError::Internal(
279            "Invalid invocation of keep_alive() function for StdioTransport".to_string(),
280        ))
281    }
282
283    // Shuts down the transport, terminating any subprocess and signaling closure.
284    ///
285    /// Sends a shutdown signal via the watch channel and kills the subprocess if present.
286    ///
287    /// # Returns
288    /// A `TransportResult` indicating success or failure.
289    ///
290    /// # Errors
291    /// Returns a `TransportError` if the shutdown signal fails or the process cannot be killed.
292    async fn shut_down(&self) -> TransportResult<()> {
293        // Trigger cancellation
294        let mut cancellation_lock = self.shutdown_source.write().await;
295        if let Some(source) = cancellation_lock.as_ref() {
296            source.cancel()?;
297        }
298        *cancellation_lock = None; // Clear cancellation_source
299
300        // Mark as shut down
301        let mut is_shut_down_lock = self.is_shut_down.lock().await;
302        *is_shut_down_lock = true;
303        Ok(())
304    }
305}
306
307#[async_trait]
308impl McpDispatch<ClientMessages, ServerMessages, ClientMessage, ServerMessage>
309    for StdioTransport<ClientMessage>
310{
311    async fn send_message(
312        &self,
313        message: ServerMessages,
314        request_timeout: Option<Duration>,
315    ) -> TransportResult<Option<ClientMessages>> {
316        let sender = self.message_sender.read().await;
317        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
318        sender.send_message(message, request_timeout).await
319    }
320
321    async fn send(
322        &self,
323        message: ServerMessage,
324        request_timeout: Option<Duration>,
325    ) -> TransportResult<Option<ClientMessage>> {
326        let sender = self.message_sender.read().await;
327        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
328        sender.send(message, request_timeout).await
329    }
330
331    async fn send_batch(
332        &self,
333        message: Vec<ServerMessage>,
334        request_timeout: Option<Duration>,
335    ) -> TransportResult<Option<Vec<ClientMessage>>> {
336        let sender = self.message_sender.read().await;
337        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
338        sender.send_batch(message, request_timeout).await
339    }
340
341    async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> {
342        let sender = self.message_sender.read().await;
343        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
344        sender.write_str(payload, skip_store).await
345    }
346}
347
// Empty impl: `StdioTransport<ClientMessage>` overrides nothing from `TransportDispatcher`,
// so it relies entirely on what the trait itself provides.
// NOTE(review): presumably `TransportDispatcher` just combines `Transport` and `McpDispatch`
// for these type arguments; confirm against the trait definition.
impl
    TransportDispatcher<
        ClientMessages,
        MessageFromServer,
        ClientMessage,
        ServerMessages,
        ServerMessage,
    > for StdioTransport<ClientMessage>
{
}
358
359#[async_trait]
360impl McpDispatch<ServerMessages, ClientMessages, ServerMessage, ClientMessage>
361    for StdioTransport<ServerMessage>
362{
363    async fn send_message(
364        &self,
365        message: ClientMessages,
366        request_timeout: Option<Duration>,
367    ) -> TransportResult<Option<ServerMessages>> {
368        let sender = self.message_sender.read().await;
369        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
370        sender.send_message(message, request_timeout).await
371    }
372
373    async fn send(
374        &self,
375        message: ClientMessage,
376        request_timeout: Option<Duration>,
377    ) -> TransportResult<Option<ServerMessage>> {
378        let sender = self.message_sender.read().await;
379        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
380        sender.send(message, request_timeout).await
381    }
382
383    async fn send_batch(
384        &self,
385        message: Vec<ClientMessage>,
386        request_timeout: Option<Duration>,
387    ) -> TransportResult<Option<Vec<ServerMessage>>> {
388        let sender = self.message_sender.read().await;
389        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
390        sender.send_batch(message, request_timeout).await
391    }
392
393    async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> {
394        let sender = self.message_sender.read().await;
395        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
396        sender.write_str(payload, skip_store).await
397    }
398}
399
// Empty impl: `StdioTransport<ServerMessage>` overrides nothing from `TransportDispatcher`,
// so it relies entirely on what the trait itself provides.
// NOTE(review): presumably `TransportDispatcher` just combines `Transport` and `McpDispatch`
// for these type arguments; confirm against the trait definition.
impl
    TransportDispatcher<
        ServerMessages,
        MessageFromClient,
        ServerMessage,
        ClientMessages,
        ClientMessage,
    > for StdioTransport<ServerMessage>
{
}