1use crate::schema::schema_utils::{
2 ClientMessage, ClientMessages, MessageFromClient, MessageFromServer, SdkError, ServerMessage,
3 ServerMessages,
4};
5use crate::schema::RequestId;
6use async_trait::async_trait;
7use serde::de::DeserializeOwned;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::process::Command;
12use tokio::sync::oneshot::Sender;
13use tokio::sync::{oneshot, Mutex};
14use tokio::task::JoinHandle;
15
16use crate::error::{TransportError, TransportResult};
17use crate::mcp_stream::MCPStream;
18use crate::message_dispatcher::MessageDispatcher;
19use crate::transport::Transport;
20use crate::utils::CancellationTokenSource;
21use crate::{IoStream, McpDispatch, TransportDispatcher, TransportOptions};
22
23pub struct StdioTransport<R>
31where
32 R: Clone + Send + Sync + DeserializeOwned + 'static,
33{
34 command: Option<String>,
35 args: Option<Vec<String>>,
36 env: Option<HashMap<String, String>>,
37 options: TransportOptions,
38 shutdown_source: tokio::sync::RwLock<Option<CancellationTokenSource>>,
39 is_shut_down: Mutex<bool>,
40 message_sender: Arc<tokio::sync::RwLock<Option<MessageDispatcher<R>>>>,
41 error_stream: tokio::sync::RwLock<Option<IoStream>>,
42 pending_requests: Arc<Mutex<HashMap<RequestId, tokio::sync::oneshot::Sender<R>>>>,
43}
44
45impl<R> StdioTransport<R>
46where
47 R: Clone + Send + Sync + DeserializeOwned + 'static,
48{
49 pub fn new(options: TransportOptions) -> TransportResult<Self> {
62 Ok(Self {
63 args: None,
65 command: None,
66 env: None,
67 options,
68 shutdown_source: tokio::sync::RwLock::new(None),
69 is_shut_down: Mutex::new(false),
70 message_sender: Arc::new(tokio::sync::RwLock::new(None)),
71 error_stream: tokio::sync::RwLock::new(None),
72 pending_requests: Arc::new(Mutex::new(HashMap::new())),
73 })
74 }
75
76 pub fn create_with_server_launch<C: Into<String>>(
91 command: C,
92 args: Vec<String>,
93 env: Option<HashMap<String, String>>,
94 options: TransportOptions,
95 ) -> TransportResult<Self> {
96 Ok(Self {
97 args: Some(args),
99 command: Some(command.into()),
100 env,
101 options,
102 shutdown_source: tokio::sync::RwLock::new(None),
103 is_shut_down: Mutex::new(false),
104 message_sender: Arc::new(tokio::sync::RwLock::new(None)),
105 error_stream: tokio::sync::RwLock::new(None),
106 pending_requests: Arc::new(Mutex::new(HashMap::new())),
107 })
108 }
109
110 fn launch_commands(&self) -> (String, Vec<std::string::String>) {
117 #[cfg(windows)]
118 {
119 let command = "cmd.exe".to_string();
120 let mut command_args = vec!["/c".to_string(), self.command.clone().unwrap_or_default()];
121 command_args.extend(self.args.clone().unwrap_or_default());
122 (command, command_args)
123 }
124
125 #[cfg(unix)]
126 {
127 let command = self.command.clone().unwrap_or_default();
128 let command_args = self.args.clone().unwrap_or_default();
129 (command, command_args)
130 }
131 }
132
133 pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher<R>) {
134 let mut lock = self.message_sender.write().await;
135 *lock = Some(sender);
136 }
137
138 pub(crate) async fn set_error_stream(&self, error_stream: IoStream) {
139 let mut lock = self.error_stream.write().await;
140 *lock = Some(error_stream);
141 }
142}
143
144#[async_trait]
145impl<R, S, M, OR, OM> Transport<R, S, M, OR, OM> for StdioTransport<M>
146where
147 R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
148 S: Clone + Send + Sync + serde::Serialize + 'static,
149 M: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
150 OR: Clone + Send + Sync + serde::Serialize + 'static,
151 OM: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
152{
153 async fn start(&self) -> TransportResult<tokio_stream::wrappers::ReceiverStream<R>>
167 where
168 MessageDispatcher<M>: McpDispatch<R, OR, M, OM>,
169 {
170 let (cancellation_source, cancellation_token) = CancellationTokenSource::new();
172 let mut lock = self.shutdown_source.write().await;
173 *lock = Some(cancellation_source);
174
175 if self.command.is_some() {
176 let (command_name, command_args) = self.launch_commands();
177
178 let mut command = Command::new(command_name);
179 command
180 .envs(self.env.as_ref().unwrap_or(&HashMap::new()))
181 .args(&command_args)
182 .stdout(std::process::Stdio::piped())
183 .stdin(std::process::Stdio::piped())
184 .stderr(std::process::Stdio::piped())
185 .kill_on_drop(true);
186
187 #[cfg(windows)]
188 command.creation_flags(0x08000000); #[cfg(unix)]
191 command.process_group(0);
192
193 let mut process = command.spawn().map_err(TransportError::Io)?;
194
195 let stdin = process
196 .stdin
197 .take()
198 .ok_or_else(|| TransportError::Internal("Unable to retrieve stdin.".into()))?;
199
200 let stdout = process
201 .stdout
202 .take()
203 .ok_or_else(|| TransportError::Internal("Unable to retrieve stdout.".into()))?;
204
205 let stderr = process
206 .stderr
207 .take()
208 .ok_or_else(|| TransportError::Internal("Unable to retrieve stderr.".into()))?;
209
210 let pending_requests_clone = self.pending_requests.clone();
211
212 tokio::spawn(async move {
213 let _ = process.wait().await;
214 let mut pending_requests = pending_requests_clone.lock().await;
216 pending_requests.clear();
217 });
218
219 let (stream, sender, error_stream) = MCPStream::create(
220 Box::pin(stdout),
221 Mutex::new(Box::pin(stdin)),
222 IoStream::Readable(Box::pin(stderr)),
223 self.pending_requests.clone(),
224 self.options.timeout,
225 cancellation_token,
226 );
227
228 self.set_message_sender(sender).await;
229 self.set_error_stream(error_stream).await;
230
231 Ok(stream)
232 } else {
233 let (stream, sender, error_stream) = MCPStream::create(
234 Box::pin(tokio::io::stdin()),
235 Mutex::new(Box::pin(tokio::io::stdout())),
236 IoStream::Writable(Box::pin(tokio::io::stderr())),
237 self.pending_requests.clone(),
238 self.options.timeout,
239 cancellation_token,
240 );
241
242 self.set_message_sender(sender).await;
243 self.set_error_stream(error_stream).await;
244 Ok(stream)
245 }
246 }
247
248 async fn pending_request_tx(&self, request_id: &RequestId) -> Option<Sender<M>> {
249 let mut pending_requests = self.pending_requests.lock().await;
250 pending_requests.remove(request_id)
251 }
252
253 async fn is_shut_down(&self) -> bool {
255 let result = self.is_shut_down.lock().await;
256 *result
257 }
258
259 fn message_sender(&self) -> Arc<tokio::sync::RwLock<Option<MessageDispatcher<M>>>> {
260 self.message_sender.clone() as _
261 }
262
263 fn error_stream(&self) -> &tokio::sync::RwLock<Option<IoStream>> {
264 &self.error_stream as _
265 }
266
267 async fn consume_string_payload(&self, _payload: &str) -> TransportResult<()> {
268 Err(TransportError::Internal(
269 "Invalid invocation of consume_string_payload() function in StdioTransport".to_string(),
270 ))
271 }
272
273 async fn keep_alive(
274 &self,
275 _interval: Duration,
276 _disconnect_tx: oneshot::Sender<()>,
277 ) -> TransportResult<JoinHandle<()>> {
278 Err(TransportError::Internal(
279 "Invalid invocation of keep_alive() function for StdioTransport".to_string(),
280 ))
281 }
282
283 async fn shut_down(&self) -> TransportResult<()> {
293 let mut cancellation_lock = self.shutdown_source.write().await;
295 if let Some(source) = cancellation_lock.as_ref() {
296 source.cancel()?;
297 }
298 *cancellation_lock = None; let mut is_shut_down_lock = self.is_shut_down.lock().await;
302 *is_shut_down_lock = true;
303 Ok(())
304 }
305}
306
307#[async_trait]
308impl McpDispatch<ClientMessages, ServerMessages, ClientMessage, ServerMessage>
309 for StdioTransport<ClientMessage>
310{
311 async fn send_message(
312 &self,
313 message: ServerMessages,
314 request_timeout: Option<Duration>,
315 ) -> TransportResult<Option<ClientMessages>> {
316 let sender = self.message_sender.read().await;
317 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
318 sender.send_message(message, request_timeout).await
319 }
320
321 async fn send(
322 &self,
323 message: ServerMessage,
324 request_timeout: Option<Duration>,
325 ) -> TransportResult<Option<ClientMessage>> {
326 let sender = self.message_sender.read().await;
327 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
328 sender.send(message, request_timeout).await
329 }
330
331 async fn send_batch(
332 &self,
333 message: Vec<ServerMessage>,
334 request_timeout: Option<Duration>,
335 ) -> TransportResult<Option<Vec<ClientMessage>>> {
336 let sender = self.message_sender.read().await;
337 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
338 sender.send_batch(message, request_timeout).await
339 }
340
341 async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> {
342 let sender = self.message_sender.read().await;
343 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
344 sender.write_str(payload, skip_store).await
345 }
346}
347
348impl
349 TransportDispatcher<
350 ClientMessages,
351 MessageFromServer,
352 ClientMessage,
353 ServerMessages,
354 ServerMessage,
355 > for StdioTransport<ClientMessage>
356{
357}
358
359#[async_trait]
360impl McpDispatch<ServerMessages, ClientMessages, ServerMessage, ClientMessage>
361 for StdioTransport<ServerMessage>
362{
363 async fn send_message(
364 &self,
365 message: ClientMessages,
366 request_timeout: Option<Duration>,
367 ) -> TransportResult<Option<ServerMessages>> {
368 let sender = self.message_sender.read().await;
369 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
370 sender.send_message(message, request_timeout).await
371 }
372
373 async fn send(
374 &self,
375 message: ClientMessage,
376 request_timeout: Option<Duration>,
377 ) -> TransportResult<Option<ServerMessage>> {
378 let sender = self.message_sender.read().await;
379 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
380 sender.send(message, request_timeout).await
381 }
382
383 async fn send_batch(
384 &self,
385 message: Vec<ClientMessage>,
386 request_timeout: Option<Duration>,
387 ) -> TransportResult<Option<Vec<ServerMessage>>> {
388 let sender = self.message_sender.read().await;
389 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
390 sender.send_batch(message, request_timeout).await
391 }
392
393 async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> {
394 let sender = self.message_sender.read().await;
395 let sender = sender.as_ref().ok_or(SdkError::connection_closed())?;
396 sender.write_str(payload, skip_store).await
397 }
398}
399
400impl
401 TransportDispatcher<
402 ServerMessages,
403 MessageFromClient,
404 ServerMessage,
405 ClientMessages,
406 ClientMessage,
407 > for StdioTransport<ServerMessage>
408{
409}