rust_mcp_transport/
stdio.rs

1use crate::schema::schema_utils::{
2    ClientMessage, ClientMessages, MessageFromServer, SdkError, ServerMessage, ServerMessages,
3};
4use crate::schema::RequestId;
5use async_trait::async_trait;
6use serde::de::DeserializeOwned;
7use std::collections::HashMap;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::process::Command;
12use tokio::sync::oneshot::Sender;
13use tokio::sync::{oneshot, Mutex};
14use tokio::task::JoinHandle;
15
16use crate::error::{TransportError, TransportResult};
17use crate::mcp_stream::MCPStream;
18use crate::message_dispatcher::MessageDispatcher;
19use crate::transport::Transport;
20use crate::utils::CancellationTokenSource;
21use crate::{IoStream, McpDispatch, TransportDispatcher, TransportOptions};
22
23/// Implements a standard I/O transport for MCP communication.
24///
25/// This module provides the `StdioTransport` struct, which serves as a transport layer for the
26/// Model Context Protocol (MCP) using standard input/output (stdio). It supports both client-side
27/// and server-side communication by optionally launching a subprocess or using the current
28/// process's stdio streams. The transport handles message streaming, dispatching, and shutdown
29/// operations, integrating with the MCP runtime ecosystem.
30pub struct StdioTransport<R>
31where
32    R: Clone + Send + Sync + DeserializeOwned + 'static,
33{
34    command: Option<String>,
35    args: Option<Vec<String>>,
36    env: Option<HashMap<String, String>>,
37    options: TransportOptions,
38    shutdown_source: tokio::sync::RwLock<Option<CancellationTokenSource>>,
39    is_shut_down: Mutex<bool>,
40    message_sender: Arc<tokio::sync::RwLock<Option<MessageDispatcher<R>>>>,
41    error_stream: tokio::sync::RwLock<Option<IoStream>>,
42    pending_requests: Arc<Mutex<HashMap<RequestId, tokio::sync::oneshot::Sender<R>>>>,
43}
44
45impl<R> StdioTransport<R>
46where
47    R: Clone + Send + Sync + DeserializeOwned + 'static,
48{
49    /// Creates a new `StdioTransport` instance for MCP Server.
50    ///
51    /// This constructor configures the transport to use the current process's stdio streams,
52    ///
53    /// # Arguments
54    /// * `options` - Configuration options for the transport, including timeout settings.
55    ///
56    /// # Returns
57    /// A `TransportResult` containing the initialized `StdioTransport` instance.
58    ///
59    /// # Errors
60    /// Currently, this method does not fail, but it returns a `TransportResult` for API consistency.
61    pub fn new(options: TransportOptions) -> TransportResult<Self> {
62        Ok(Self {
63            // when transport is used for MCP Server, we do not need a command
64            args: None,
65            command: None,
66            env: None,
67            options,
68            shutdown_source: tokio::sync::RwLock::new(None),
69            is_shut_down: Mutex::new(false),
70            message_sender: Arc::new(tokio::sync::RwLock::new(None)),
71            error_stream: tokio::sync::RwLock::new(None),
72            pending_requests: Arc::new(Mutex::new(HashMap::new())),
73        })
74    }
75
76    /// Creates a new `StdioTransport` instance with a subprocess for MCP Client use.
77    ///
78    /// This constructor configures the transport to launch a MCP Server with a specified command
79    /// arguments and optional environment variables
80    ///
81    /// # Arguments
82    /// * `command` - The command to execute (e.g., "rust-mcp-filesystem").
83    /// * `args` - Arguments to pass to the command. (e.g., "~/Documents").
84    /// * `env` - Optional environment variables for the subprocess.
85    /// * `options` - Configuration options for the transport, including timeout settings.
86    ///
87    /// # Returns
88    /// A `TransportResult` containing the initialized `StdioTransport` instance, ready to launch
89    /// the MCP server on `start`.
90    pub fn create_with_server_launch<C: Into<String>>(
91        command: C,
92        args: Vec<String>,
93        env: Option<HashMap<String, String>>,
94        options: TransportOptions,
95    ) -> TransportResult<Self> {
96        Ok(Self {
97            // when transport is used for MCP Server, we do not need a command
98            args: Some(args),
99            command: Some(command.into()),
100            env,
101            options,
102            shutdown_source: tokio::sync::RwLock::new(None),
103            is_shut_down: Mutex::new(false),
104            message_sender: Arc::new(tokio::sync::RwLock::new(None)),
105            error_stream: tokio::sync::RwLock::new(None),
106            pending_requests: Arc::new(Mutex::new(HashMap::new())),
107        })
108    }
109
110    /// Retrieves the command and arguments for launching the subprocess.
111    ///
112    /// Adjusts the command based on the platform: on Windows, wraps it with `cmd.exe /c`.
113    ///
114    /// # Returns
115    /// A tuple of the command string and its arguments.
116    fn launch_commands(&self) -> (String, Vec<std::string::String>) {
117        #[cfg(windows)]
118        {
119            let command = "cmd.exe".to_string();
120            let mut command_args = vec!["/c".to_string(), self.command.clone().unwrap_or_default()];
121            command_args.extend(self.args.clone().unwrap_or_default());
122            (command, command_args)
123        }
124
125        #[cfg(unix)]
126        {
127            let command = self.command.clone().unwrap_or_default();
128            let command_args = self.args.clone().unwrap_or_default();
129            (command, command_args)
130        }
131    }
132
133    pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher<R>) {
134        let mut lock = self.message_sender.write().await;
135        *lock = Some(sender);
136    }
137
138    pub(crate) async fn set_error_stream(
139        &self,
140        error_stream: Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>,
141    ) {
142        let mut lock = self.error_stream.write().await;
143        *lock = Some(IoStream::Writable(error_stream));
144    }
145}
146
147#[async_trait]
148impl<R, S, M, OR, OM> Transport<R, S, M, OR, OM> for StdioTransport<M>
149where
150    R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
151    S: Clone + Send + Sync + serde::Serialize + 'static,
152    M: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
153    OR: Clone + Send + Sync + serde::Serialize + 'static,
154    OM: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
155{
156    /// Starts the transport, initializing streams and the message dispatcher.
157    ///
158    /// If configured with a command (MCP Client), launches the MCP server and connects its stdio streams.
159    /// Otherwise, uses the current process's stdio for server-side communication.
160    ///
161    /// # Returns
162    /// A `TransportResult` containing:
163    /// - A pinned stream of incoming messages.
164    /// - A `MessageDispatcher<R>` for sending messages.
165    /// - An `IoStream` for stderr (readable) or stdout (writable) depending on the mode.
166    ///
167    /// # Errors
168    /// Returns a `TransportError` if the subprocess fails to spawn or stdio streams cannot be accessed.
169    async fn start(&self) -> TransportResult<tokio_stream::wrappers::ReceiverStream<R>>
170    where
171        MessageDispatcher<M>: McpDispatch<R, OR, M, OM>,
172    {
173        // Create CancellationTokenSource and token
174        let (cancellation_source, cancellation_token) = CancellationTokenSource::new();
175        let mut lock = self.shutdown_source.write().await;
176        *lock = Some(cancellation_source);
177
178        if self.command.is_some() {
179            let (command_name, command_args) = self.launch_commands();
180
181            let mut command = Command::new(command_name);
182            command
183                .envs(self.env.as_ref().unwrap_or(&HashMap::new()))
184                .args(&command_args)
185                .stdout(std::process::Stdio::piped())
186                .stdin(std::process::Stdio::piped())
187                .stderr(std::process::Stdio::piped())
188                .kill_on_drop(true);
189
190            #[cfg(windows)]
191            command.creation_flags(0x08000000); // https://learn.microsoft.com/en-us/windows/win32/procthread/process-creation-flags
192
193            #[cfg(unix)]
194            command.process_group(0);
195
196            let mut process = command.spawn().map_err(TransportError::StdioError)?;
197
198            let stdin = process
199                .stdin
200                .take()
201                .ok_or_else(|| TransportError::FromString("Unable to retrieve stdin.".into()))?;
202
203            let stdout = process
204                .stdout
205                .take()
206                .ok_or_else(|| TransportError::FromString("Unable to retrieve stdout.".into()))?;
207
208            let stderr = process
209                .stderr
210                .take()
211                .ok_or_else(|| TransportError::FromString("Unable to retrieve stderr.".into()))?;
212
213            let pending_requests_clone1 = self.pending_requests.clone();
214            let pending_requests_clone2 = self.pending_requests.clone();
215
216            tokio::spawn(async move {
217                let _ = process.wait().await;
218                // clean up pending requests to cancel waiting tasks
219                let mut pending_requests = pending_requests_clone1.lock().await;
220                pending_requests.clear();
221            });
222
223            let (stream, sender, error_stream) = MCPStream::create(
224                Box::pin(stdout),
225                Mutex::new(Box::pin(stdin)),
226                IoStream::Readable(Box::pin(stderr)),
227                pending_requests_clone2,
228                self.options.timeout,
229                cancellation_token,
230            );
231
232            self.set_message_sender(sender).await;
233
234            if let IoStream::Writable(error_stream) = error_stream {
235                self.set_error_stream(error_stream).await;
236            }
237
238            Ok(stream)
239        } else {
240            let pending_requests: Arc<Mutex<HashMap<RequestId, tokio::sync::oneshot::Sender<M>>>> =
241                Arc::new(Mutex::new(HashMap::new()));
242            let (stream, sender, error_stream) = MCPStream::create(
243                Box::pin(tokio::io::stdin()),
244                Mutex::new(Box::pin(tokio::io::stdout())),
245                IoStream::Writable(Box::pin(tokio::io::stderr())),
246                pending_requests,
247                self.options.timeout,
248                cancellation_token,
249            );
250
251            self.set_message_sender(sender).await;
252
253            if let IoStream::Writable(error_stream) = error_stream {
254                self.set_error_stream(error_stream).await;
255            }
256            Ok(stream)
257        }
258    }
259
260    async fn pending_request_tx(&self, request_id: &RequestId) -> Option<Sender<M>> {
261        let mut pending_requests = self.pending_requests.lock().await;
262        pending_requests.remove(request_id)
263    }
264
265    /// Checks if the transport has been shut down.
266    async fn is_shut_down(&self) -> bool {
267        let result = self.is_shut_down.lock().await;
268        *result
269    }
270
271    fn message_sender(&self) -> Arc<tokio::sync::RwLock<Option<MessageDispatcher<M>>>> {
272        self.message_sender.clone() as _
273    }
274
275    fn error_stream(&self) -> &tokio::sync::RwLock<Option<IoStream>> {
276        &self.error_stream as _
277    }
278
279    async fn consume_string_payload(&self, _payload: &str) -> TransportResult<()> {
280        Err(TransportError::FromString(
281            "Invalid invocation of consume_string_payload() function in StdioTransport".to_string(),
282        ))
283    }
284
285    async fn keep_alive(
286        &self,
287        _interval: Duration,
288        _disconnect_tx: oneshot::Sender<()>,
289    ) -> TransportResult<JoinHandle<()>> {
290        Err(TransportError::FromString(
291            "Invalid invocation of keep_alive() function for StdioTransport".to_string(),
292        ))
293    }
294
295    // Shuts down the transport, terminating any subprocess and signaling closure.
296    ///
297    /// Sends a shutdown signal via the watch channel and kills the subprocess if present.
298    ///
299    /// # Returns
300    /// A `TransportResult` indicating success or failure.
301    ///
302    /// # Errors
303    /// Returns a `TransportError` if the shutdown signal fails or the process cannot be killed.
304    async fn shut_down(&self) -> TransportResult<()> {
305        // Trigger cancellation
306        let mut cancellation_lock = self.shutdown_source.write().await;
307        if let Some(source) = cancellation_lock.as_ref() {
308            source.cancel()?;
309        }
310        *cancellation_lock = None; // Clear cancellation_source
311
312        // Mark as shut down
313        let mut is_shut_down_lock = self.is_shut_down.lock().await;
314        *is_shut_down_lock = true;
315        Ok(())
316    }
317}
318
319#[async_trait]
320impl McpDispatch<ClientMessages, ServerMessages, ClientMessage, ServerMessage>
321    for StdioTransport<ClientMessage>
322{
323    async fn send_message(
324        &self,
325        message: ServerMessages,
326        request_timeout: Option<Duration>,
327    ) -> TransportResult<Option<ClientMessages>> {
328        let sender = self.message_sender.read().await;
329        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
330        sender.send_message(message, request_timeout).await
331    }
332
333    async fn send(
334        &self,
335        message: ServerMessage,
336        request_timeout: Option<Duration>,
337    ) -> TransportResult<Option<ClientMessage>> {
338        let sender = self.message_sender.read().await;
339        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
340        sender.send(message, request_timeout).await
341    }
342
343    async fn send_batch(
344        &self,
345        message: Vec<ServerMessage>,
346        request_timeout: Option<Duration>,
347    ) -> TransportResult<Option<Vec<ClientMessage>>> {
348        let sender = self.message_sender.read().await;
349        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
350        sender.send_batch(message, request_timeout).await
351    }
352
353    async fn write_str(&self, payload: &str) -> TransportResult<()> {
354        let sender = self.message_sender.read().await;
355        let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
356        sender.write_str(payload).await
357    }
358}
359
360impl
361    TransportDispatcher<
362        ClientMessages,
363        MessageFromServer,
364        ClientMessage,
365        ServerMessages,
366        ServerMessage,
367    > for StdioTransport<ClientMessage>
368{
369}