rust_mcp_transport/
stdio.rs

1use async_trait::async_trait;
2use futures::Stream;
3use rust_mcp_schema::schema_utils::{MCPMessage, RPCMessage};
4use std::collections::HashMap;
5use std::pin::Pin;
6use tokio::process::{Child, Command};
7use tokio::sync::watch::Sender;
8use tokio::sync::{watch, Mutex};
9
10use crate::error::{GenericWatchSendError, TransportError, TransportResult};
11use crate::mcp_stream::MCPStream;
12use crate::message_dispatcher::MessageDispatcher;
13use crate::transport::Transport;
14use crate::{IOStream, MCPDispatch, TransportOptions};
15
16/// Implements a standard I/O transport for MCP communication.
17///
18/// This module provides the `StdioTransport` struct, which serves as a transport layer for the
19/// Model Context Protocol (MCP) using standard input/output (stdio). It supports both client-side
20/// and server-side communication by optionally launching a subprocess or using the current
21/// process's stdio streams. The transport handles message streaming, dispatching, and shutdown
22/// operations, integrating with the MCP runtime ecosystem.
23pub struct StdioTransport {
24    command: Option<String>,
25    args: Option<Vec<String>>,
26    env: Option<HashMap<String, String>>,
27    process: Mutex<Option<Child>>,
28    options: TransportOptions,
29    shutdown_tx: tokio::sync::RwLock<Option<Sender<bool>>>,
30    is_shut_down: Mutex<bool>,
31}
32
33impl StdioTransport {
34    /// Creates a new `StdioTransport` instance for MCP Server.
35    ///
36    /// This constructor configures the transport to use the current process's stdio streams,
37    ///
38    /// # Arguments
39    /// * `options` - Configuration options for the transport, including timeout settings.
40    ///
41    /// # Returns
42    /// A `TransportResult` containing the initialized `StdioTransport` instance.
43    ///
44    /// # Errors
45    /// Currently, this method does not fail, but it returns a `TransportResult` for API consistency.
46    pub fn new(options: TransportOptions) -> TransportResult<Self> {
47        Ok(Self {
48            // when transport is used for MCP Server, we do not need a command
49            args: None,
50            command: None,
51            env: None,
52            process: Mutex::new(None),
53            options,
54            shutdown_tx: tokio::sync::RwLock::new(None),
55            is_shut_down: Mutex::new(false),
56        })
57    }
58
59    /// Creates a new `StdioTransport` instance with a subprocess for MCP Client use.
60    ///
61    /// This constructor configures the transport to launch a MCP Server with a specified command
62    /// arguments and optional environment variables
63    ///
64    /// # Arguments
65    /// * `command` - The command to execute (e.g., "rust-mcp-filesystem").
66    /// * `args` - Arguments to pass to the command. (e.g., "~/Documents").
67    /// * `env` - Optional environment variables for the subprocess.
68    /// * `options` - Configuration options for the transport, including timeout settings.
69    ///
70    /// # Returns
71    /// A `TransportResult` containing the initialized `StdioTransport` instance, ready to launch
72    /// the MCP server on `start`.
73    pub fn create_with_server_launch<C: Into<String>>(
74        command: C,
75        args: Vec<String>,
76        env: Option<HashMap<String, String>>,
77        options: TransportOptions,
78    ) -> TransportResult<Self> {
79        Ok(Self {
80            // when transport is used for MCP Server, we do not need a command
81            args: Some(args),
82            command: Some(command.into()),
83            env,
84            process: Mutex::new(None),
85            options,
86            shutdown_tx: tokio::sync::RwLock::new(None),
87            is_shut_down: Mutex::new(false),
88        })
89    }
90
91    /// Sets the subprocess handle for the transport.
92    async fn set_process(&self, value: Child) -> TransportResult<()> {
93        let mut process = self.process.lock().await;
94        *process = Some(value);
95        Ok(())
96    }
97
98    /// Retrieves the command and arguments for launching the subprocess.
99    ///
100    /// Adjusts the command based on the platform: on Windows, wraps it with `cmd.exe /c`.
101    ///
102    /// # Returns
103    /// A tuple of the command string and its arguments.
104    fn get_launch_commands(&self) -> (String, Vec<std::string::String>) {
105        #[cfg(windows)]
106        {
107            let command = "cmd.exe".to_string();
108            let mut command_args = vec!["/c".to_string(), self.command.clone().unwrap_or_default()];
109            command_args.extend(self.args.clone().unwrap_or_default());
110            (command, command_args)
111        }
112
113        #[cfg(unix)]
114        {
115            let command = self.command.clone().unwrap_or_default();
116            let command_args = self.args.clone().unwrap_or_default();
117            (command, command_args)
118        }
119    }
120}
121
122#[async_trait]
123impl<R, S> Transport<R, S> for StdioTransport
124where
125    R: RPCMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
126    S: MCPMessage + Clone + Send + Sync + serde::Serialize + 'static,
127{
128    /// Starts the transport, initializing streams and the message dispatcher.
129    ///
130    /// If configured with a command (MCP Client), launches the MCP server and connects its stdio streams.
131    /// Otherwise, uses the current process's stdio for server-side communication.
132    ///
133    /// # Returns
134    /// A `TransportResult` containing:
135    /// - A pinned stream of incoming messages.
136    /// - A `MessageDispatcher<R>` for sending messages.
137    /// - An `IOStream` for stderr (readable) or stdout (writable) depending on the mode.
138    ///
139    /// # Errors
140    /// Returns a `TransportError` if the subprocess fails to spawn or stdio streams cannot be accessed.
141    async fn start(
142        &self,
143    ) -> TransportResult<(
144        Pin<Box<dyn Stream<Item = R> + Send>>,
145        MessageDispatcher<R>,
146        IOStream,
147    )>
148    where
149        MessageDispatcher<R>: MCPDispatch<R, S>,
150    {
151        let (shutdown_tx, shutdown_rx) = watch::channel(false);
152
153        let mut lock = self.shutdown_tx.write().await;
154        *lock = Some(shutdown_tx);
155
156        if self.command.is_some() {
157            let (command_name, command_args) = self.get_launch_commands();
158
159            let mut command = Command::new(command_name);
160            command
161                .envs(self.env.as_ref().unwrap_or(&HashMap::new()))
162                .args(&command_args)
163                .stdout(std::process::Stdio::piped())
164                .stdin(std::process::Stdio::piped())
165                .stderr(std::process::Stdio::piped())
166                .kill_on_drop(true);
167
168            #[cfg(windows)]
169            command.creation_flags(0x08000000); // https://learn.microsoft.com/en-us/windows/win32/procthread/process-creation-flags
170
171            #[cfg(unix)]
172            command.process_group(0);
173
174            let mut process = command.spawn().map_err(TransportError::StdioError)?;
175
176            let stdin = process
177                .stdin
178                .take()
179                .ok_or_else(|| TransportError::FromString("Unable to retrieve stdin.".into()))?;
180
181            let stdout = process
182                .stdout
183                .take()
184                .ok_or_else(|| TransportError::FromString("Unable to retrieve stdout.".into()))?;
185
186            let stderr = process
187                .stderr
188                .take()
189                .ok_or_else(|| TransportError::FromString("Unable to retrieve stderr.".into()))?;
190
191            self.set_process(process).await.unwrap();
192
193            let (stream, sender, error_stream) = MCPStream::create(
194                Box::pin(stdout),
195                Mutex::new(Box::pin(stdin)),
196                IOStream::Readable(Box::pin(stderr)),
197                self.options.timeout,
198                shutdown_rx,
199            );
200
201            Ok((stream, sender, error_stream))
202        } else {
203            let (stream, sender, error_stream) = MCPStream::create(
204                Box::pin(tokio::io::stdin()),
205                Mutex::new(Box::pin(tokio::io::stdout())),
206                IOStream::Writable(Box::pin(tokio::io::stderr())),
207                self.options.timeout,
208                shutdown_rx,
209            );
210
211            Ok((stream, sender, error_stream))
212        }
213    }
214
215    /// Checks if the transport has been shut down.
216    async fn is_shut_down(&self) -> bool {
217        let result = self.is_shut_down.lock().await;
218        *result
219    }
220
221    // Shuts down the transport, terminating any subprocess and signaling closure.
222    ///
223    /// Sends a shutdown signal via the watch channel and kills the subprocess if present.
224    ///
225    /// # Returns
226    /// A `TransportResult` indicating success or failure.
227    ///
228    /// # Errors
229    /// Returns a `TransportError` if the shutdown signal fails or the process cannot be killed.
230    async fn shut_down(&self) -> TransportResult<()> {
231        let lock = self.shutdown_tx.write().await;
232        if let Some(tx) = lock.as_ref() {
233            tx.send(true).map_err(GenericWatchSendError::new)?;
234            let mut lock = self.is_shut_down.lock().await;
235            *lock = true
236        }
237
238        let mut process = self.process.lock().await;
239        if let Some(p) = process.as_mut() {
240            p.kill().await?;
241            p.wait().await?;
242        }
243        Ok(())
244    }
245}