rust_mcp_transport/
stdio.rs1use async_trait::async_trait;
2use futures::Stream;
3use rust_mcp_schema::schema_utils::{McpMessage, RpcMessage};
4use rust_mcp_schema::RequestId;
5use std::collections::HashMap;
6use std::pin::Pin;
7use std::sync::Arc;
8use tokio::process::Command;
9use tokio::sync::Mutex;
10
11use crate::error::{TransportError, TransportResult};
12use crate::mcp_stream::MCPStream;
13use crate::message_dispatcher::MessageDispatcher;
14use crate::transport::Transport;
15use crate::utils::CancellationTokenSource;
16use crate::{IoStream, McpDispatch, TransportOptions};
17
18pub struct StdioTransport {
26    command: Option<String>,
27    args: Option<Vec<String>>,
28    env: Option<HashMap<String, String>>,
29    options: TransportOptions,
30    shutdown_source: tokio::sync::RwLock<Option<CancellationTokenSource>>,
31    is_shut_down: Mutex<bool>,
32}
33
34impl StdioTransport {
35    pub fn new(options: TransportOptions) -> TransportResult<Self> {
48        Ok(Self {
49            args: None,
51            command: None,
52            env: None,
53            options,
54            shutdown_source: tokio::sync::RwLock::new(None),
55            is_shut_down: Mutex::new(false),
56        })
57    }
58
59    pub fn create_with_server_launch<C: Into<String>>(
74        command: C,
75        args: Vec<String>,
76        env: Option<HashMap<String, String>>,
77        options: TransportOptions,
78    ) -> TransportResult<Self> {
79        Ok(Self {
80            args: Some(args),
82            command: Some(command.into()),
83            env,
84            options,
85            shutdown_source: tokio::sync::RwLock::new(None),
86            is_shut_down: Mutex::new(false),
87        })
88    }
89
90    fn launch_commands(&self) -> (String, Vec<std::string::String>) {
97        #[cfg(windows)]
98        {
99            let command = "cmd.exe".to_string();
100            let mut command_args = vec!["/c".to_string(), self.command.clone().unwrap_or_default()];
101            command_args.extend(self.args.clone().unwrap_or_default());
102            (command, command_args)
103        }
104
105        #[cfg(unix)]
106        {
107            let command = self.command.clone().unwrap_or_default();
108            let command_args = self.args.clone().unwrap_or_default();
109            (command, command_args)
110        }
111    }
112}
113
114#[async_trait]
115impl<R, S> Transport<R, S> for StdioTransport
116where
117    R: RpcMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
118    S: McpMessage + Clone + Send + Sync + serde::Serialize + 'static,
119{
120    async fn start(
134        &self,
135    ) -> TransportResult<(
136        Pin<Box<dyn Stream<Item = R> + Send>>,
137        MessageDispatcher<R>,
138        IoStream,
139    )>
140    where
141        MessageDispatcher<R>: McpDispatch<R, S>,
142    {
143        let (cancellation_source, cancellation_token) = CancellationTokenSource::new();
145        let mut lock = self.shutdown_source.write().await;
146        *lock = Some(cancellation_source);
147
148        if self.command.is_some() {
149            let (command_name, command_args) = self.launch_commands();
150
151            let mut command = Command::new(command_name);
152            command
153                .envs(self.env.as_ref().unwrap_or(&HashMap::new()))
154                .args(&command_args)
155                .stdout(std::process::Stdio::piped())
156                .stdin(std::process::Stdio::piped())
157                .stderr(std::process::Stdio::piped())
158                .kill_on_drop(true);
159
160            #[cfg(windows)]
161            command.creation_flags(0x08000000); #[cfg(unix)]
164            command.process_group(0);
165
166            let mut process = command.spawn().map_err(TransportError::StdioError)?;
167
168            let stdin = process
169                .stdin
170                .take()
171                .ok_or_else(|| TransportError::FromString("Unable to retrieve stdin.".into()))?;
172
173            let stdout = process
174                .stdout
175                .take()
176                .ok_or_else(|| TransportError::FromString("Unable to retrieve stdout.".into()))?;
177
178            let stderr = process
179                .stderr
180                .take()
181                .ok_or_else(|| TransportError::FromString("Unable to retrieve stderr.".into()))?;
182
183            let pending_requests: Arc<Mutex<HashMap<RequestId, tokio::sync::oneshot::Sender<R>>>> =
184                Arc::new(Mutex::new(HashMap::new()));
185            let pending_requests_clone = Arc::clone(&pending_requests);
186
187            tokio::spawn(async move {
188                let _ = process.wait().await;
189                let mut pending_requests = pending_requests.lock().await;
191                pending_requests.clear();
192            });
193
194            let (stream, sender, error_stream) = MCPStream::create(
195                Box::pin(stdout),
196                Mutex::new(Box::pin(stdin)),
197                IoStream::Readable(Box::pin(stderr)),
198                pending_requests_clone,
199                self.options.timeout,
200                cancellation_token,
201            );
202
203            Ok((stream, sender, error_stream))
204        } else {
205            let pending_requests: Arc<Mutex<HashMap<RequestId, tokio::sync::oneshot::Sender<R>>>> =
206                Arc::new(Mutex::new(HashMap::new()));
207            let (stream, sender, error_stream) = MCPStream::create(
208                Box::pin(tokio::io::stdin()),
209                Mutex::new(Box::pin(tokio::io::stdout())),
210                IoStream::Writable(Box::pin(tokio::io::stderr())),
211                pending_requests,
212                self.options.timeout,
213                cancellation_token,
214            );
215
216            Ok((stream, sender, error_stream))
217        }
218    }
219
220    async fn is_shut_down(&self) -> bool {
222        let result = self.is_shut_down.lock().await;
223        *result
224    }
225
226    async fn shut_down(&self) -> TransportResult<()> {
236        let mut cancellation_lock = self.shutdown_source.write().await;
238        if let Some(source) = cancellation_lock.as_ref() {
239            source.cancel()?;
240        }
241        *cancellation_lock = None; let mut is_shut_down_lock = self.is_shut_down.lock().await;
245        *is_shut_down_lock = true;
246        Ok(())
247    }
248}