1use crate::schema::schema_utils::{
2 ClientMessage, ClientMessages, MessageFromClient, MessageFromServer, SdkError, ServerMessage,
3 ServerMessages,
4};
5use crate::schema::RequestId;
6use async_trait::async_trait;
7use serde::de::DeserializeOwned;
8use std::collections::HashMap;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::process::Command;
13use tokio::sync::oneshot::Sender;
14use tokio::sync::{oneshot, Mutex};
15use tokio::task::JoinHandle;
16
17use crate::error::{TransportError, TransportResult};
18use crate::mcp_stream::MCPStream;
19use crate::message_dispatcher::MessageDispatcher;
20use crate::transport::Transport;
21use crate::utils::CancellationTokenSource;
22use crate::{IoStream, McpDispatch, TransportDispatcher, TransportOptions};
23
24pub struct StdioTransport<R>
32where
33 R: Clone + Send + Sync + DeserializeOwned + 'static,
34{
35 command: Option<String>,
36 args: Option<Vec<String>>,
37 env: Option<HashMap<String, String>>,
38 options: TransportOptions,
39 shutdown_source: tokio::sync::RwLock<Option<CancellationTokenSource>>,
40 is_shut_down: Mutex<bool>,
41 message_sender: Arc<tokio::sync::RwLock<Option<MessageDispatcher<R>>>>,
42 error_stream: tokio::sync::RwLock<Option<IoStream>>,
43 pending_requests: Arc<Mutex<HashMap<RequestId, tokio::sync::oneshot::Sender<R>>>>,
44}
45
46impl<R> StdioTransport<R>
47where
48 R: Clone + Send + Sync + DeserializeOwned + 'static,
49{
50 pub fn new(options: TransportOptions) -> TransportResult<Self> {
63 Ok(Self {
64 args: None,
66 command: None,
67 env: None,
68 options,
69 shutdown_source: tokio::sync::RwLock::new(None),
70 is_shut_down: Mutex::new(false),
71 message_sender: Arc::new(tokio::sync::RwLock::new(None)),
72 error_stream: tokio::sync::RwLock::new(None),
73 pending_requests: Arc::new(Mutex::new(HashMap::new())),
74 })
75 }
76
77 pub fn create_with_server_launch<C: Into<String>>(
92 command: C,
93 args: Vec<String>,
94 env: Option<HashMap<String, String>>,
95 options: TransportOptions,
96 ) -> TransportResult<Self> {
97 Ok(Self {
98 args: Some(args),
100 command: Some(command.into()),
101 env,
102 options,
103 shutdown_source: tokio::sync::RwLock::new(None),
104 is_shut_down: Mutex::new(false),
105 message_sender: Arc::new(tokio::sync::RwLock::new(None)),
106 error_stream: tokio::sync::RwLock::new(None),
107 pending_requests: Arc::new(Mutex::new(HashMap::new())),
108 })
109 }
110
111 fn launch_commands(&self) -> (String, Vec<std::string::String>) {
118 #[cfg(windows)]
119 {
120 let command = "cmd.exe".to_string();
121 let mut command_args = vec!["/c".to_string(), self.command.clone().unwrap_or_default()];
122 command_args.extend(self.args.clone().unwrap_or_default());
123 (command, command_args)
124 }
125
126 #[cfg(unix)]
127 {
128 let command = self.command.clone().unwrap_or_default();
129 let command_args = self.args.clone().unwrap_or_default();
130 (command, command_args)
131 }
132 }
133
134 pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher<R>) {
135 let mut lock = self.message_sender.write().await;
136 *lock = Some(sender);
137 }
138
139 pub(crate) async fn set_error_stream(
140 &self,
141 error_stream: Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>,
142 ) {
143 let mut lock = self.error_stream.write().await;
144 *lock = Some(IoStream::Writable(error_stream));
145 }
146}
147
148#[async_trait]
149impl<R, S, M, OR, OM> Transport<R, S, M, OR, OM> for StdioTransport<M>
150where
151 R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
152 S: Clone + Send + Sync + serde::Serialize + 'static,
153 M: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
154 OR: Clone + Send + Sync + serde::Serialize + 'static,
155 OM: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
156{
157 async fn start(&self) -> TransportResult<tokio_stream::wrappers::ReceiverStream<R>>
171 where
172 MessageDispatcher<M>: McpDispatch<R, OR, M, OM>,
173 {
174 let (cancellation_source, cancellation_token) = CancellationTokenSource::new();
176 let mut lock = self.shutdown_source.write().await;
177 *lock = Some(cancellation_source);
178
179 if self.command.is_some() {
180 let (command_name, command_args) = self.launch_commands();
181
182 let mut command = Command::new(command_name);
183 command
184 .envs(self.env.as_ref().unwrap_or(&HashMap::new()))
185 .args(&command_args)
186 .stdout(std::process::Stdio::piped())
187 .stdin(std::process::Stdio::piped())
188 .stderr(std::process::Stdio::piped())
189 .kill_on_drop(true);
190
191 #[cfg(windows)]
192 command.creation_flags(0x08000000); #[cfg(unix)]
195 command.process_group(0);
196
197 let mut process = command.spawn().map_err(TransportError::Io)?;
198
199 let stdin = process
200 .stdin
201 .take()
202 .ok_or_else(|| TransportError::Internal("Unable to retrieve stdin.".into()))?;
203
204 let stdout = process
205 .stdout
206 .take()
207 .ok_or_else(|| TransportError::Internal("Unable to retrieve stdout.".into()))?;
208
209 let stderr = process
210 .stderr
211 .take()
212 .ok_or_else(|| TransportError::Internal("Unable to retrieve stderr.".into()))?;
213
214 let pending_requests_clone = self.pending_requests.clone();
215
216 tokio::spawn(async move {
217 let _ = process.wait().await;
218 let mut pending_requests = pending_requests_clone.lock().await;
220 pending_requests.clear();
221 });
222
223 let (stream, sender, error_stream) = MCPStream::create(
224 Box::pin(stdout),
225 Mutex::new(Box::pin(stdin)),
226 IoStream::Readable(Box::pin(stderr)),
227 self.pending_requests.clone(),
228 self.options.timeout,
229 cancellation_token,
230 );
231
232 self.set_message_sender(sender).await;
233
234 if let IoStream::Writable(error_stream) = error_stream {
235 self.set_error_stream(error_stream).await;
236 }
237
238 Ok(stream)
239 } else {
240 let (stream, sender, error_stream) = MCPStream::create(
241 Box::pin(tokio::io::stdin()),
242 Mutex::new(Box::pin(tokio::io::stdout())),
243 IoStream::Writable(Box::pin(tokio::io::stderr())),
244 self.pending_requests.clone(),
245 self.options.timeout,
246 cancellation_token,
247 );
248
249 self.set_message_sender(sender).await;
250
251 if let IoStream::Writable(error_stream) = error_stream {
252 self.set_error_stream(error_stream).await;
253 }
254 Ok(stream)
255 }
256 }
257
258 async fn pending_request_tx(&self, request_id: &RequestId) -> Option<Sender<M>> {
259 let mut pending_requests = self.pending_requests.lock().await;
260 pending_requests.remove(request_id)
261 }
262
263 async fn is_shut_down(&self) -> bool {
265 let result = self.is_shut_down.lock().await;
266 *result
267 }
268
269 fn message_sender(&self) -> Arc<tokio::sync::RwLock<Option<MessageDispatcher<M>>>> {
270 self.message_sender.clone() as _
271 }
272
273 fn error_stream(&self) -> &tokio::sync::RwLock<Option<IoStream>> {
274 &self.error_stream as _
275 }
276
277 async fn consume_string_payload(&self, _payload: &str) -> TransportResult<()> {
278 Err(TransportError::Internal(
279 "Invalid invocation of consume_string_payload() function in StdioTransport".to_string(),
280 ))
281 }
282
283 async fn keep_alive(
284 &self,
285 _interval: Duration,
286 _disconnect_tx: oneshot::Sender<()>,
287 ) -> TransportResult<JoinHandle<()>> {
288 Err(TransportError::Internal(
289 "Invalid invocation of keep_alive() function for StdioTransport".to_string(),
290 ))
291 }
292
293 async fn shut_down(&self) -> TransportResult<()> {
303 let mut cancellation_lock = self.shutdown_source.write().await;
305 if let Some(source) = cancellation_lock.as_ref() {
306 source.cancel()?;
307 }
308 *cancellation_lock = None; let mut is_shut_down_lock = self.is_shut_down.lock().await;
312 *is_shut_down_lock = true;
313 Ok(())
314 }
315}
316
317#[async_trait]
318impl McpDispatch<ClientMessages, ServerMessages, ClientMessage, ServerMessage>
319 for StdioTransport<ClientMessage>
320{
321 async fn send_message(
322 &self,
323 message: ServerMessages,
324 request_timeout: Option<Duration>,
325 ) -> TransportResult<Option<ClientMessages>> {
326 let sender = self.message_sender.read().await;
327 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
328 sender.send_message(message, request_timeout).await
329 }
330
331 async fn send(
332 &self,
333 message: ServerMessage,
334 request_timeout: Option<Duration>,
335 ) -> TransportResult<Option<ClientMessage>> {
336 let sender = self.message_sender.read().await;
337 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
338 sender.send(message, request_timeout).await
339 }
340
341 async fn send_batch(
342 &self,
343 message: Vec<ServerMessage>,
344 request_timeout: Option<Duration>,
345 ) -> TransportResult<Option<Vec<ClientMessage>>> {
346 let sender = self.message_sender.read().await;
347 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
348 sender.send_batch(message, request_timeout).await
349 }
350
351 async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> {
352 let sender = self.message_sender.read().await;
353 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
354 sender.write_str(payload, skip_store).await
355 }
356}
357
358impl
359 TransportDispatcher<
360 ClientMessages,
361 MessageFromServer,
362 ClientMessage,
363 ServerMessages,
364 ServerMessage,
365 > for StdioTransport<ClientMessage>
366{
367}
368
369#[async_trait]
370impl McpDispatch<ServerMessages, ClientMessages, ServerMessage, ClientMessage>
371 for StdioTransport<ServerMessage>
372{
373 async fn send_message(
374 &self,
375 message: ClientMessages,
376 request_timeout: Option<Duration>,
377 ) -> TransportResult<Option<ServerMessages>> {
378 let sender = self.message_sender.read().await;
379 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
380 sender.send_message(message, request_timeout).await
381 }
382
383 async fn send(
384 &self,
385 message: ClientMessage,
386 request_timeout: Option<Duration>,
387 ) -> TransportResult<Option<ServerMessage>> {
388 let sender = self.message_sender.read().await;
389 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
390 sender.send(message, request_timeout).await
391 }
392
393 async fn send_batch(
394 &self,
395 message: Vec<ClientMessage>,
396 request_timeout: Option<Duration>,
397 ) -> TransportResult<Option<Vec<ServerMessage>>> {
398 let sender = self.message_sender.read().await;
399 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
400 sender.send_batch(message, request_timeout).await
401 }
402
403 async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> {
404 let sender = self.message_sender.read().await;
405 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
406 sender.write_str(payload, skip_store).await
407 }
408}
409
410impl
411 TransportDispatcher<
412 ServerMessages,
413 MessageFromClient,
414 ServerMessage,
415 ClientMessages,
416 ClientMessage,
417 > for StdioTransport<ServerMessage>
418{
419}