rust_mcp_transport/
stdio.rs1use async_trait::async_trait;
2use futures::Stream;
3use rust_mcp_schema::schema_utils::{McpMessage, RpcMessage};
4use rust_mcp_schema::RequestId;
5use std::collections::HashMap;
6use std::pin::Pin;
7use std::sync::Arc;
8use tokio::process::Command;
9use tokio::sync::Mutex;
10
11use crate::error::{TransportError, TransportResult};
12use crate::mcp_stream::MCPStream;
13use crate::message_dispatcher::MessageDispatcher;
14use crate::transport::Transport;
15use crate::utils::CancellationTokenSource;
16use crate::{IoStream, McpDispatch, TransportOptions};
17
18pub struct StdioTransport {
26 command: Option<String>,
27 args: Option<Vec<String>>,
28 env: Option<HashMap<String, String>>,
29 options: TransportOptions,
30 shutdown_source: tokio::sync::RwLock<Option<CancellationTokenSource>>,
31 is_shut_down: Mutex<bool>,
32}
33
34impl StdioTransport {
35 pub fn new(options: TransportOptions) -> TransportResult<Self> {
48 Ok(Self {
49 args: None,
51 command: None,
52 env: None,
53 options,
54 shutdown_source: tokio::sync::RwLock::new(None),
55 is_shut_down: Mutex::new(false),
56 })
57 }
58
59 pub fn create_with_server_launch<C: Into<String>>(
74 command: C,
75 args: Vec<String>,
76 env: Option<HashMap<String, String>>,
77 options: TransportOptions,
78 ) -> TransportResult<Self> {
79 Ok(Self {
80 args: Some(args),
82 command: Some(command.into()),
83 env,
84 options,
85 shutdown_source: tokio::sync::RwLock::new(None),
86 is_shut_down: Mutex::new(false),
87 })
88 }
89
90 fn launch_commands(&self) -> (String, Vec<std::string::String>) {
97 #[cfg(windows)]
98 {
99 let command = "cmd.exe".to_string();
100 let mut command_args = vec!["/c".to_string(), self.command.clone().unwrap_or_default()];
101 command_args.extend(self.args.clone().unwrap_or_default());
102 (command, command_args)
103 }
104
105 #[cfg(unix)]
106 {
107 let command = self.command.clone().unwrap_or_default();
108 let command_args = self.args.clone().unwrap_or_default();
109 (command, command_args)
110 }
111 }
112}
113
114#[async_trait]
115impl<R, S> Transport<R, S> for StdioTransport
116where
117 R: RpcMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
118 S: McpMessage + Clone + Send + Sync + serde::Serialize + 'static,
119{
120 async fn start(
134 &self,
135 ) -> TransportResult<(
136 Pin<Box<dyn Stream<Item = R> + Send>>,
137 MessageDispatcher<R>,
138 IoStream,
139 )>
140 where
141 MessageDispatcher<R>: McpDispatch<R, S>,
142 {
143 let (cancellation_source, cancellation_token) = CancellationTokenSource::new();
145 let mut lock = self.shutdown_source.write().await;
146 *lock = Some(cancellation_source);
147
148 if self.command.is_some() {
149 let (command_name, command_args) = self.launch_commands();
150
151 let mut command = Command::new(command_name);
152 command
153 .envs(self.env.as_ref().unwrap_or(&HashMap::new()))
154 .args(&command_args)
155 .stdout(std::process::Stdio::piped())
156 .stdin(std::process::Stdio::piped())
157 .stderr(std::process::Stdio::piped())
158 .kill_on_drop(true);
159
160 #[cfg(windows)]
161 command.creation_flags(0x08000000); #[cfg(unix)]
164 command.process_group(0);
165
166 let mut process = command.spawn().map_err(TransportError::StdioError)?;
167
168 let stdin = process
169 .stdin
170 .take()
171 .ok_or_else(|| TransportError::FromString("Unable to retrieve stdin.".into()))?;
172
173 let stdout = process
174 .stdout
175 .take()
176 .ok_or_else(|| TransportError::FromString("Unable to retrieve stdout.".into()))?;
177
178 let stderr = process
179 .stderr
180 .take()
181 .ok_or_else(|| TransportError::FromString("Unable to retrieve stderr.".into()))?;
182
183 let pending_requests: Arc<Mutex<HashMap<RequestId, tokio::sync::oneshot::Sender<R>>>> =
184 Arc::new(Mutex::new(HashMap::new()));
185 let pending_requests_clone = Arc::clone(&pending_requests);
186
187 tokio::spawn(async move {
188 let _ = process.wait().await;
189 let mut pending_requests = pending_requests.lock().await;
191 pending_requests.clear();
192 });
193
194 let (stream, sender, error_stream) = MCPStream::create(
195 Box::pin(stdout),
196 Mutex::new(Box::pin(stdin)),
197 IoStream::Readable(Box::pin(stderr)),
198 pending_requests_clone,
199 self.options.timeout,
200 cancellation_token,
201 );
202
203 Ok((stream, sender, error_stream))
204 } else {
205 let pending_requests: Arc<Mutex<HashMap<RequestId, tokio::sync::oneshot::Sender<R>>>> =
206 Arc::new(Mutex::new(HashMap::new()));
207 let (stream, sender, error_stream) = MCPStream::create(
208 Box::pin(tokio::io::stdin()),
209 Mutex::new(Box::pin(tokio::io::stdout())),
210 IoStream::Writable(Box::pin(tokio::io::stderr())),
211 pending_requests,
212 self.options.timeout,
213 cancellation_token,
214 );
215
216 Ok((stream, sender, error_stream))
217 }
218 }
219
220 async fn is_shut_down(&self) -> bool {
222 let result = self.is_shut_down.lock().await;
223 *result
224 }
225
226 async fn shut_down(&self) -> TransportResult<()> {
236 let mut cancellation_lock = self.shutdown_source.write().await;
238 if let Some(source) = cancellation_lock.as_ref() {
239 source.cancel()?;
240 }
241 *cancellation_lock = None; let mut is_shut_down_lock = self.is_shut_down.lock().await;
245 *is_shut_down_lock = true;
246 Ok(())
247 }
248}