1use crate::schema::schema_utils::{
2 ClientMessage, ClientMessages, MessageFromServer, SdkError, ServerMessage, ServerMessages,
3};
4use crate::schema::RequestId;
5use async_trait::async_trait;
6use serde::de::DeserializeOwned;
7use std::collections::HashMap;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::process::Command;
12use tokio::sync::oneshot::Sender;
13use tokio::sync::{oneshot, Mutex};
14use tokio::task::JoinHandle;
15
16use crate::error::{TransportError, TransportResult};
17use crate::mcp_stream::MCPStream;
18use crate::message_dispatcher::MessageDispatcher;
19use crate::transport::Transport;
20use crate::utils::CancellationTokenSource;
21use crate::{IoStream, McpDispatch, TransportDispatcher, TransportOptions};
22
23pub struct StdioTransport<R>
31where
32 R: Clone + Send + Sync + DeserializeOwned + 'static,
33{
34 command: Option<String>,
35 args: Option<Vec<String>>,
36 env: Option<HashMap<String, String>>,
37 options: TransportOptions,
38 shutdown_source: tokio::sync::RwLock<Option<CancellationTokenSource>>,
39 is_shut_down: Mutex<bool>,
40 message_sender: Arc<tokio::sync::RwLock<Option<MessageDispatcher<R>>>>,
41 error_stream: tokio::sync::RwLock<Option<IoStream>>,
42 pending_requests: Arc<Mutex<HashMap<RequestId, tokio::sync::oneshot::Sender<R>>>>,
43}
44
45impl<R> StdioTransport<R>
46where
47 R: Clone + Send + Sync + DeserializeOwned + 'static,
48{
49 pub fn new(options: TransportOptions) -> TransportResult<Self> {
62 Ok(Self {
63 args: None,
65 command: None,
66 env: None,
67 options,
68 shutdown_source: tokio::sync::RwLock::new(None),
69 is_shut_down: Mutex::new(false),
70 message_sender: Arc::new(tokio::sync::RwLock::new(None)),
71 error_stream: tokio::sync::RwLock::new(None),
72 pending_requests: Arc::new(Mutex::new(HashMap::new())),
73 })
74 }
75
76 pub fn create_with_server_launch<C: Into<String>>(
91 command: C,
92 args: Vec<String>,
93 env: Option<HashMap<String, String>>,
94 options: TransportOptions,
95 ) -> TransportResult<Self> {
96 Ok(Self {
97 args: Some(args),
99 command: Some(command.into()),
100 env,
101 options,
102 shutdown_source: tokio::sync::RwLock::new(None),
103 is_shut_down: Mutex::new(false),
104 message_sender: Arc::new(tokio::sync::RwLock::new(None)),
105 error_stream: tokio::sync::RwLock::new(None),
106 pending_requests: Arc::new(Mutex::new(HashMap::new())),
107 })
108 }
109
110 fn launch_commands(&self) -> (String, Vec<std::string::String>) {
117 #[cfg(windows)]
118 {
119 let command = "cmd.exe".to_string();
120 let mut command_args = vec!["/c".to_string(), self.command.clone().unwrap_or_default()];
121 command_args.extend(self.args.clone().unwrap_or_default());
122 (command, command_args)
123 }
124
125 #[cfg(unix)]
126 {
127 let command = self.command.clone().unwrap_or_default();
128 let command_args = self.args.clone().unwrap_or_default();
129 (command, command_args)
130 }
131 }
132
133 pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher<R>) {
134 let mut lock = self.message_sender.write().await;
135 *lock = Some(sender);
136 }
137
138 pub(crate) async fn set_error_stream(
139 &self,
140 error_stream: Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>,
141 ) {
142 let mut lock = self.error_stream.write().await;
143 *lock = Some(IoStream::Writable(error_stream));
144 }
145}
146
147#[async_trait]
148impl<R, S, M, OR, OM> Transport<R, S, M, OR, OM> for StdioTransport<M>
149where
150 R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
151 S: Clone + Send + Sync + serde::Serialize + 'static,
152 M: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
153 OR: Clone + Send + Sync + serde::Serialize + 'static,
154 OM: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
155{
156 async fn start(&self) -> TransportResult<tokio_stream::wrappers::ReceiverStream<R>>
170 where
171 MessageDispatcher<M>: McpDispatch<R, OR, M, OM>,
172 {
173 let (cancellation_source, cancellation_token) = CancellationTokenSource::new();
175 let mut lock = self.shutdown_source.write().await;
176 *lock = Some(cancellation_source);
177
178 if self.command.is_some() {
179 let (command_name, command_args) = self.launch_commands();
180
181 let mut command = Command::new(command_name);
182 command
183 .envs(self.env.as_ref().unwrap_or(&HashMap::new()))
184 .args(&command_args)
185 .stdout(std::process::Stdio::piped())
186 .stdin(std::process::Stdio::piped())
187 .stderr(std::process::Stdio::piped())
188 .kill_on_drop(true);
189
190 #[cfg(windows)]
191 command.creation_flags(0x08000000); #[cfg(unix)]
194 command.process_group(0);
195
196 let mut process = command.spawn().map_err(TransportError::StdioError)?;
197
198 let stdin = process
199 .stdin
200 .take()
201 .ok_or_else(|| TransportError::FromString("Unable to retrieve stdin.".into()))?;
202
203 let stdout = process
204 .stdout
205 .take()
206 .ok_or_else(|| TransportError::FromString("Unable to retrieve stdout.".into()))?;
207
208 let stderr = process
209 .stderr
210 .take()
211 .ok_or_else(|| TransportError::FromString("Unable to retrieve stderr.".into()))?;
212
213 let pending_requests_clone1 = self.pending_requests.clone();
214 let pending_requests_clone2 = self.pending_requests.clone();
215
216 tokio::spawn(async move {
217 let _ = process.wait().await;
218 let mut pending_requests = pending_requests_clone1.lock().await;
220 pending_requests.clear();
221 });
222
223 let (stream, sender, error_stream) = MCPStream::create(
224 Box::pin(stdout),
225 Mutex::new(Box::pin(stdin)),
226 IoStream::Readable(Box::pin(stderr)),
227 pending_requests_clone2,
228 self.options.timeout,
229 cancellation_token,
230 );
231
232 self.set_message_sender(sender).await;
233
234 if let IoStream::Writable(error_stream) = error_stream {
235 self.set_error_stream(error_stream).await;
236 }
237
238 Ok(stream)
239 } else {
240 let pending_requests: Arc<Mutex<HashMap<RequestId, tokio::sync::oneshot::Sender<M>>>> =
241 Arc::new(Mutex::new(HashMap::new()));
242 let (stream, sender, error_stream) = MCPStream::create(
243 Box::pin(tokio::io::stdin()),
244 Mutex::new(Box::pin(tokio::io::stdout())),
245 IoStream::Writable(Box::pin(tokio::io::stderr())),
246 pending_requests,
247 self.options.timeout,
248 cancellation_token,
249 );
250
251 self.set_message_sender(sender).await;
252
253 if let IoStream::Writable(error_stream) = error_stream {
254 self.set_error_stream(error_stream).await;
255 }
256 Ok(stream)
257 }
258 }
259
260 async fn pending_request_tx(&self, request_id: &RequestId) -> Option<Sender<M>> {
261 let mut pending_requests = self.pending_requests.lock().await;
262 pending_requests.remove(request_id)
263 }
264
265 async fn is_shut_down(&self) -> bool {
267 let result = self.is_shut_down.lock().await;
268 *result
269 }
270
271 fn message_sender(&self) -> Arc<tokio::sync::RwLock<Option<MessageDispatcher<M>>>> {
272 self.message_sender.clone() as _
273 }
274
275 fn error_stream(&self) -> &tokio::sync::RwLock<Option<IoStream>> {
276 &self.error_stream as _
277 }
278
279 async fn consume_string_payload(&self, _payload: &str) -> TransportResult<()> {
280 Err(TransportError::FromString(
281 "Invalid invocation of consume_string_payload() function in StdioTransport".to_string(),
282 ))
283 }
284
285 async fn keep_alive(
286 &self,
287 _interval: Duration,
288 _disconnect_tx: oneshot::Sender<()>,
289 ) -> TransportResult<JoinHandle<()>> {
290 Err(TransportError::FromString(
291 "Invalid invocation of keep_alive() function for StdioTransport".to_string(),
292 ))
293 }
294
295 async fn shut_down(&self) -> TransportResult<()> {
305 let mut cancellation_lock = self.shutdown_source.write().await;
307 if let Some(source) = cancellation_lock.as_ref() {
308 source.cancel()?;
309 }
310 *cancellation_lock = None; let mut is_shut_down_lock = self.is_shut_down.lock().await;
314 *is_shut_down_lock = true;
315 Ok(())
316 }
317}
318
319#[async_trait]
320impl McpDispatch<ClientMessages, ServerMessages, ClientMessage, ServerMessage>
321 for StdioTransport<ClientMessage>
322{
323 async fn send_message(
324 &self,
325 message: ServerMessages,
326 request_timeout: Option<Duration>,
327 ) -> TransportResult<Option<ClientMessages>> {
328 let sender = self.message_sender.read().await;
329 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
330 sender.send_message(message, request_timeout).await
331 }
332
333 async fn send(
334 &self,
335 message: ServerMessage,
336 request_timeout: Option<Duration>,
337 ) -> TransportResult<Option<ClientMessage>> {
338 let sender = self.message_sender.read().await;
339 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
340 sender.send(message, request_timeout).await
341 }
342
343 async fn send_batch(
344 &self,
345 message: Vec<ServerMessage>,
346 request_timeout: Option<Duration>,
347 ) -> TransportResult<Option<Vec<ClientMessage>>> {
348 let sender = self.message_sender.read().await;
349 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
350 sender.send_batch(message, request_timeout).await
351 }
352
353 async fn write_str(&self, payload: &str) -> TransportResult<()> {
354 let sender = self.message_sender.read().await;
355 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
356 sender.write_str(payload).await
357 }
358}
359
360impl
361 TransportDispatcher<
362 ClientMessages,
363 MessageFromServer,
364 ClientMessage,
365 ServerMessages,
366 ServerMessage,
367 > for StdioTransport<ClientMessage>
368{
369}