rust_mcp_transport/
stdio.rs

1use async_trait::async_trait;
2use futures::Stream;
3use rust_mcp_schema::schema_utils::{McpMessage, RpcMessage};
4use rust_mcp_schema::RequestId;
5use std::collections::HashMap;
6use std::pin::Pin;
7use std::sync::Arc;
8use tokio::process::Command;
9use tokio::sync::watch::Sender;
10use tokio::sync::{watch, Mutex};
11
12use crate::error::{GenericWatchSendError, TransportError, TransportResult};
13use crate::mcp_stream::MCPStream;
14use crate::message_dispatcher::MessageDispatcher;
15use crate::transport::Transport;
16use crate::{IoStream, McpDispatch, TransportOptions};
17
18/// Implements a standard I/O transport for MCP communication.
19///
20/// This module provides the `StdioTransport` struct, which serves as a transport layer for the
21/// Model Context Protocol (MCP) using standard input/output (stdio). It supports both client-side
22/// and server-side communication by optionally launching a subprocess or using the current
23/// process's stdio streams. The transport handles message streaming, dispatching, and shutdown
24/// operations, integrating with the MCP runtime ecosystem.
25pub struct StdioTransport {
26    command: Option<String>,
27    args: Option<Vec<String>>,
28    env: Option<HashMap<String, String>>,
29    options: TransportOptions,
30    shutdown_tx: tokio::sync::RwLock<Option<Sender<bool>>>,
31    is_shut_down: Mutex<bool>,
32}
33
34impl StdioTransport {
35    /// Creates a new `StdioTransport` instance for MCP Server.
36    ///
37    /// This constructor configures the transport to use the current process's stdio streams,
38    ///
39    /// # Arguments
40    /// * `options` - Configuration options for the transport, including timeout settings.
41    ///
42    /// # Returns
43    /// A `TransportResult` containing the initialized `StdioTransport` instance.
44    ///
45    /// # Errors
46    /// Currently, this method does not fail, but it returns a `TransportResult` for API consistency.
47    pub fn new(options: TransportOptions) -> TransportResult<Self> {
48        Ok(Self {
49            // when transport is used for MCP Server, we do not need a command
50            args: None,
51            command: None,
52            env: None,
53            options,
54            shutdown_tx: tokio::sync::RwLock::new(None),
55            is_shut_down: Mutex::new(false),
56        })
57    }
58
59    /// Creates a new `StdioTransport` instance with a subprocess for MCP Client use.
60    ///
61    /// This constructor configures the transport to launch a MCP Server with a specified command
62    /// arguments and optional environment variables
63    ///
64    /// # Arguments
65    /// * `command` - The command to execute (e.g., "rust-mcp-filesystem").
66    /// * `args` - Arguments to pass to the command. (e.g., "~/Documents").
67    /// * `env` - Optional environment variables for the subprocess.
68    /// * `options` - Configuration options for the transport, including timeout settings.
69    ///
70    /// # Returns
71    /// A `TransportResult` containing the initialized `StdioTransport` instance, ready to launch
72    /// the MCP server on `start`.
73    pub fn create_with_server_launch<C: Into<String>>(
74        command: C,
75        args: Vec<String>,
76        env: Option<HashMap<String, String>>,
77        options: TransportOptions,
78    ) -> TransportResult<Self> {
79        Ok(Self {
80            // when transport is used for MCP Server, we do not need a command
81            args: Some(args),
82            command: Some(command.into()),
83            env,
84            options,
85            shutdown_tx: tokio::sync::RwLock::new(None),
86            is_shut_down: Mutex::new(false),
87        })
88    }
89
90    /// Retrieves the command and arguments for launching the subprocess.
91    ///
92    /// Adjusts the command based on the platform: on Windows, wraps it with `cmd.exe /c`.
93    ///
94    /// # Returns
95    /// A tuple of the command string and its arguments.
96    fn launch_commands(&self) -> (String, Vec<std::string::String>) {
97        #[cfg(windows)]
98        {
99            let command = "cmd.exe".to_string();
100            let mut command_args = vec!["/c".to_string(), self.command.clone().unwrap_or_default()];
101            command_args.extend(self.args.clone().unwrap_or_default());
102            (command, command_args)
103        }
104
105        #[cfg(unix)]
106        {
107            let command = self.command.clone().unwrap_or_default();
108            let command_args = self.args.clone().unwrap_or_default();
109            (command, command_args)
110        }
111    }
112}
113
114#[async_trait]
115impl<R, S> Transport<R, S> for StdioTransport
116where
117    R: RpcMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
118    S: McpMessage + Clone + Send + Sync + serde::Serialize + 'static,
119{
120    /// Starts the transport, initializing streams and the message dispatcher.
121    ///
122    /// If configured with a command (MCP Client), launches the MCP server and connects its stdio streams.
123    /// Otherwise, uses the current process's stdio for server-side communication.
124    ///
125    /// # Returns
126    /// A `TransportResult` containing:
127    /// - A pinned stream of incoming messages.
128    /// - A `MessageDispatcher<R>` for sending messages.
129    /// - An `IoStream` for stderr (readable) or stdout (writable) depending on the mode.
130    ///
131    /// # Errors
132    /// Returns a `TransportError` if the subprocess fails to spawn or stdio streams cannot be accessed.
133    async fn start(
134        &self,
135    ) -> TransportResult<(
136        Pin<Box<dyn Stream<Item = R> + Send>>,
137        MessageDispatcher<R>,
138        IoStream,
139    )>
140    where
141        MessageDispatcher<R>: McpDispatch<R, S>,
142    {
143        let (shutdown_tx, shutdown_rx) = watch::channel(false);
144
145        let mut lock = self.shutdown_tx.write().await;
146        *lock = Some(shutdown_tx);
147
148        if self.command.is_some() {
149            let (command_name, command_args) = self.launch_commands();
150
151            let mut command = Command::new(command_name);
152            command
153                .envs(self.env.as_ref().unwrap_or(&HashMap::new()))
154                .args(&command_args)
155                .stdout(std::process::Stdio::piped())
156                .stdin(std::process::Stdio::piped())
157                .stderr(std::process::Stdio::piped())
158                .kill_on_drop(true);
159
160            #[cfg(windows)]
161            command.creation_flags(0x08000000); // https://learn.microsoft.com/en-us/windows/win32/procthread/process-creation-flags
162
163            #[cfg(unix)]
164            command.process_group(0);
165
166            let mut process = command.spawn().map_err(TransportError::StdioError)?;
167
168            let stdin = process
169                .stdin
170                .take()
171                .ok_or_else(|| TransportError::FromString("Unable to retrieve stdin.".into()))?;
172
173            let stdout = process
174                .stdout
175                .take()
176                .ok_or_else(|| TransportError::FromString("Unable to retrieve stdout.".into()))?;
177
178            let stderr = process
179                .stderr
180                .take()
181                .ok_or_else(|| TransportError::FromString("Unable to retrieve stderr.".into()))?;
182
183            let pending_requests: Arc<Mutex<HashMap<RequestId, tokio::sync::oneshot::Sender<R>>>> =
184                Arc::new(Mutex::new(HashMap::new()));
185            let pending_requests_clone = Arc::clone(&pending_requests);
186
187            tokio::spawn(async move {
188                let _ = process.wait().await;
189                // clean up pending requests to cancel waiting tasks
190                let mut pending_requests = pending_requests.lock().await;
191                pending_requests.clear();
192            });
193
194            let (stream, sender, error_stream) = MCPStream::create(
195                Box::pin(stdout),
196                Mutex::new(Box::pin(stdin)),
197                IoStream::Readable(Box::pin(stderr)),
198                pending_requests_clone,
199                self.options.timeout,
200                shutdown_rx,
201            );
202
203            Ok((stream, sender, error_stream))
204        } else {
205            let pending_requests: Arc<Mutex<HashMap<RequestId, tokio::sync::oneshot::Sender<R>>>> =
206                Arc::new(Mutex::new(HashMap::new()));
207            let (stream, sender, error_stream) = MCPStream::create(
208                Box::pin(tokio::io::stdin()),
209                Mutex::new(Box::pin(tokio::io::stdout())),
210                IoStream::Writable(Box::pin(tokio::io::stderr())),
211                pending_requests,
212                self.options.timeout,
213                shutdown_rx,
214            );
215
216            Ok((stream, sender, error_stream))
217        }
218    }
219
220    /// Checks if the transport has been shut down.
221    async fn is_shut_down(&self) -> bool {
222        let result = self.is_shut_down.lock().await;
223        *result
224    }
225
226    // Shuts down the transport, terminating any subprocess and signaling closure.
227    ///
228    /// Sends a shutdown signal via the watch channel and kills the subprocess if present.
229    ///
230    /// # Returns
231    /// A `TransportResult` indicating success or failure.
232    ///
233    /// # Errors
234    /// Returns a `TransportError` if the shutdown signal fails or the process cannot be killed.
235    async fn shut_down(&self) -> TransportResult<()> {
236        let lock = self.shutdown_tx.write().await;
237        if let Some(tx) = lock.as_ref() {
238            tx.send(true).map_err(GenericWatchSendError::new)?;
239            let mut lock = self.is_shut_down.lock().await;
240            *lock = true
241        }
242        Ok(())
243    }
244}