rust_mcp_transport/
stdio.rs1use async_trait::async_trait;
2use futures::Stream;
3use rust_mcp_schema::schema_utils::{MCPMessage, RPCMessage};
4use std::collections::HashMap;
5use std::pin::Pin;
6use tokio::process::{Child, Command};
7use tokio::sync::watch::Sender;
8use tokio::sync::{watch, Mutex};
9
10use crate::error::{GenericWatchSendError, TransportError, TransportResult};
11use crate::mcp_stream::MCPStream;
12use crate::message_dispatcher::MessageDispatcher;
13use crate::transport::Transport;
14use crate::{IoStream, McpDispatch, TransportOptions};
15
16pub struct StdioTransport {
24 command: Option<String>,
25 args: Option<Vec<String>>,
26 env: Option<HashMap<String, String>>,
27 process: Mutex<Option<Child>>,
28 options: TransportOptions,
29 shutdown_tx: tokio::sync::RwLock<Option<Sender<bool>>>,
30 is_shut_down: Mutex<bool>,
31}
32
33impl StdioTransport {
34 pub fn new(options: TransportOptions) -> TransportResult<Self> {
47 Ok(Self {
48 args: None,
50 command: None,
51 env: None,
52 process: Mutex::new(None),
53 options,
54 shutdown_tx: tokio::sync::RwLock::new(None),
55 is_shut_down: Mutex::new(false),
56 })
57 }
58
59 pub fn create_with_server_launch<C: Into<String>>(
74 command: C,
75 args: Vec<String>,
76 env: Option<HashMap<String, String>>,
77 options: TransportOptions,
78 ) -> TransportResult<Self> {
79 Ok(Self {
80 args: Some(args),
82 command: Some(command.into()),
83 env,
84 process: Mutex::new(None),
85 options,
86 shutdown_tx: tokio::sync::RwLock::new(None),
87 is_shut_down: Mutex::new(false),
88 })
89 }
90
91 async fn set_process(&self, value: Child) -> TransportResult<()> {
93 let mut process = self.process.lock().await;
94 *process = Some(value);
95 Ok(())
96 }
97
98 fn launch_commands(&self) -> (String, Vec<std::string::String>) {
105 #[cfg(windows)]
106 {
107 let command = "cmd.exe".to_string();
108 let mut command_args = vec!["/c".to_string(), self.command.clone().unwrap_or_default()];
109 command_args.extend(self.args.clone().unwrap_or_default());
110 (command, command_args)
111 }
112
113 #[cfg(unix)]
114 {
115 let command = self.command.clone().unwrap_or_default();
116 let command_args = self.args.clone().unwrap_or_default();
117 (command, command_args)
118 }
119 }
120}
121
122#[async_trait]
123impl<R, S> Transport<R, S> for StdioTransport
124where
125 R: RPCMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
126 S: MCPMessage + Clone + Send + Sync + serde::Serialize + 'static,
127{
128 async fn start(
142 &self,
143 ) -> TransportResult<(
144 Pin<Box<dyn Stream<Item = R> + Send>>,
145 MessageDispatcher<R>,
146 IoStream,
147 )>
148 where
149 MessageDispatcher<R>: McpDispatch<R, S>,
150 {
151 let (shutdown_tx, shutdown_rx) = watch::channel(false);
152
153 let mut lock = self.shutdown_tx.write().await;
154 *lock = Some(shutdown_tx);
155
156 if self.command.is_some() {
157 let (command_name, command_args) = self.launch_commands();
158
159 let mut command = Command::new(command_name);
160 command
161 .envs(self.env.as_ref().unwrap_or(&HashMap::new()))
162 .args(&command_args)
163 .stdout(std::process::Stdio::piped())
164 .stdin(std::process::Stdio::piped())
165 .stderr(std::process::Stdio::piped())
166 .kill_on_drop(true);
167
168 #[cfg(windows)]
169 command.creation_flags(0x08000000); #[cfg(unix)]
172 command.process_group(0);
173
174 let mut process = command.spawn().map_err(TransportError::StdioError)?;
175
176 let stdin = process
177 .stdin
178 .take()
179 .ok_or_else(|| TransportError::FromString("Unable to retrieve stdin.".into()))?;
180
181 let stdout = process
182 .stdout
183 .take()
184 .ok_or_else(|| TransportError::FromString("Unable to retrieve stdout.".into()))?;
185
186 let stderr = process
187 .stderr
188 .take()
189 .ok_or_else(|| TransportError::FromString("Unable to retrieve stderr.".into()))?;
190
191 self.set_process(process).await.unwrap();
192
193 let (stream, sender, error_stream) = MCPStream::create(
194 Box::pin(stdout),
195 Mutex::new(Box::pin(stdin)),
196 IoStream::Readable(Box::pin(stderr)),
197 self.options.timeout,
198 shutdown_rx,
199 );
200
201 Ok((stream, sender, error_stream))
202 } else {
203 let (stream, sender, error_stream) = MCPStream::create(
204 Box::pin(tokio::io::stdin()),
205 Mutex::new(Box::pin(tokio::io::stdout())),
206 IoStream::Writable(Box::pin(tokio::io::stderr())),
207 self.options.timeout,
208 shutdown_rx,
209 );
210
211 Ok((stream, sender, error_stream))
212 }
213 }
214
215 async fn is_shut_down(&self) -> bool {
217 let result = self.is_shut_down.lock().await;
218 *result
219 }
220
221 async fn shut_down(&self) -> TransportResult<()> {
231 let lock = self.shutdown_tx.write().await;
232 if let Some(tx) = lock.as_ref() {
233 tx.send(true).map_err(GenericWatchSendError::new)?;
234 let mut lock = self.is_shut_down.lock().await;
235 *lock = true
236 }
237
238 let mut process = self.process.lock().await;
239 if let Some(p) = process.as_mut() {
240 p.kill().await?;
241 p.wait().await?;
242 }
243 Ok(())
244 }
245}