rust_mcp_transport/
stdio.rs1use async_trait::async_trait;
2use futures::Stream;
3use rust_mcp_schema::schema_utils::{McpMessage, RpcMessage};
4use rust_mcp_schema::RequestId;
5use std::collections::HashMap;
6use std::pin::Pin;
7use std::sync::Arc;
8use tokio::process::Command;
9use tokio::sync::watch::Sender;
10use tokio::sync::{watch, Mutex};
11
12use crate::error::{GenericWatchSendError, TransportError, TransportResult};
13use crate::mcp_stream::MCPStream;
14use crate::message_dispatcher::MessageDispatcher;
15use crate::transport::Transport;
16use crate::{IoStream, McpDispatch, TransportOptions};
17
18pub struct StdioTransport {
26 command: Option<String>,
27 args: Option<Vec<String>>,
28 env: Option<HashMap<String, String>>,
29 options: TransportOptions,
30 shutdown_tx: tokio::sync::RwLock<Option<Sender<bool>>>,
31 is_shut_down: Mutex<bool>,
32}
33
34impl StdioTransport {
35 pub fn new(options: TransportOptions) -> TransportResult<Self> {
48 Ok(Self {
49 args: None,
51 command: None,
52 env: None,
53 options,
54 shutdown_tx: tokio::sync::RwLock::new(None),
55 is_shut_down: Mutex::new(false),
56 })
57 }
58
59 pub fn create_with_server_launch<C: Into<String>>(
74 command: C,
75 args: Vec<String>,
76 env: Option<HashMap<String, String>>,
77 options: TransportOptions,
78 ) -> TransportResult<Self> {
79 Ok(Self {
80 args: Some(args),
82 command: Some(command.into()),
83 env,
84 options,
85 shutdown_tx: tokio::sync::RwLock::new(None),
86 is_shut_down: Mutex::new(false),
87 })
88 }
89
90 fn launch_commands(&self) -> (String, Vec<std::string::String>) {
97 #[cfg(windows)]
98 {
99 let command = "cmd.exe".to_string();
100 let mut command_args = vec!["/c".to_string(), self.command.clone().unwrap_or_default()];
101 command_args.extend(self.args.clone().unwrap_or_default());
102 (command, command_args)
103 }
104
105 #[cfg(unix)]
106 {
107 let command = self.command.clone().unwrap_or_default();
108 let command_args = self.args.clone().unwrap_or_default();
109 (command, command_args)
110 }
111 }
112}
113
114#[async_trait]
115impl<R, S> Transport<R, S> for StdioTransport
116where
117 R: RpcMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
118 S: McpMessage + Clone + Send + Sync + serde::Serialize + 'static,
119{
120 async fn start(
134 &self,
135 ) -> TransportResult<(
136 Pin<Box<dyn Stream<Item = R> + Send>>,
137 MessageDispatcher<R>,
138 IoStream,
139 )>
140 where
141 MessageDispatcher<R>: McpDispatch<R, S>,
142 {
143 let (shutdown_tx, shutdown_rx) = watch::channel(false);
144
145 let mut lock = self.shutdown_tx.write().await;
146 *lock = Some(shutdown_tx);
147
148 if self.command.is_some() {
149 let (command_name, command_args) = self.launch_commands();
150
151 let mut command = Command::new(command_name);
152 command
153 .envs(self.env.as_ref().unwrap_or(&HashMap::new()))
154 .args(&command_args)
155 .stdout(std::process::Stdio::piped())
156 .stdin(std::process::Stdio::piped())
157 .stderr(std::process::Stdio::piped())
158 .kill_on_drop(true);
159
160 #[cfg(windows)]
161 command.creation_flags(0x08000000); #[cfg(unix)]
164 command.process_group(0);
165
166 let mut process = command.spawn().map_err(TransportError::StdioError)?;
167
168 let stdin = process
169 .stdin
170 .take()
171 .ok_or_else(|| TransportError::FromString("Unable to retrieve stdin.".into()))?;
172
173 let stdout = process
174 .stdout
175 .take()
176 .ok_or_else(|| TransportError::FromString("Unable to retrieve stdout.".into()))?;
177
178 let stderr = process
179 .stderr
180 .take()
181 .ok_or_else(|| TransportError::FromString("Unable to retrieve stderr.".into()))?;
182
183 let pending_requests: Arc<Mutex<HashMap<RequestId, tokio::sync::oneshot::Sender<R>>>> =
184 Arc::new(Mutex::new(HashMap::new()));
185 let pending_requests_clone = Arc::clone(&pending_requests);
186
187 tokio::spawn(async move {
188 let _ = process.wait().await;
189 let mut pending_requests = pending_requests.lock().await;
191 pending_requests.clear();
192 });
193
194 let (stream, sender, error_stream) = MCPStream::create(
195 Box::pin(stdout),
196 Mutex::new(Box::pin(stdin)),
197 IoStream::Readable(Box::pin(stderr)),
198 pending_requests_clone,
199 self.options.timeout,
200 shutdown_rx,
201 );
202
203 Ok((stream, sender, error_stream))
204 } else {
205 let pending_requests: Arc<Mutex<HashMap<RequestId, tokio::sync::oneshot::Sender<R>>>> =
206 Arc::new(Mutex::new(HashMap::new()));
207 let (stream, sender, error_stream) = MCPStream::create(
208 Box::pin(tokio::io::stdin()),
209 Mutex::new(Box::pin(tokio::io::stdout())),
210 IoStream::Writable(Box::pin(tokio::io::stderr())),
211 pending_requests,
212 self.options.timeout,
213 shutdown_rx,
214 );
215
216 Ok((stream, sender, error_stream))
217 }
218 }
219
220 async fn is_shut_down(&self) -> bool {
222 let result = self.is_shut_down.lock().await;
223 *result
224 }
225
226 async fn shut_down(&self) -> TransportResult<()> {
236 let lock = self.shutdown_tx.write().await;
237 if let Some(tx) = lock.as_ref() {
238 tx.send(true).map_err(GenericWatchSendError::new)?;
239 let mut lock = self.is_shut_down.lock().await;
240 *lock = true
241 }
242 Ok(())
243 }
244}