sh_layer4/mcp_bridge/
transport.rs1use async_trait::async_trait;
6use std::sync::Arc;
7use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
8#[cfg(unix)]
9use tokio::net::UnixStream;
10use tokio::net::{TcpListener, TcpStream};
11use tokio::sync::Mutex;
12
13use super::protocol::McpMessage;
14use anyhow::{anyhow, Result};
15
/// Describes which kind of transport to use for reaching an MCP peer, together with
/// the parameters that transport needs.
#[derive(Debug, Clone)]
pub enum McpTransportType {
    /// A child process spoken to over its stdin/stdout; the fields mirror the
    /// parameters of [`StdioTransport::start`].
    Stdio {
        /// Program to launch.
        command: String,
        /// Command-line arguments passed to `command`.
        args: Vec<String>,
    },
    /// A TCP connection.
    Tcp {
        /// Address string, in the form accepted by [`TcpTransport::connect`].
        addr: String,
    },
    /// A Unix domain socket connection (only compiled on Unix targets).
    #[cfg(unix)]
    Unix {
        /// Socket path, in the form accepted by [`UnixSocketTransport::connect`].
        path: String,
    },
}
38
39#[async_trait]
41pub trait McpTransport: Send + Sync {
42 async fn send(&self, message: &McpMessage) -> Result<()>;
44
45 async fn receive(&self) -> Result<Option<McpMessage>>;
47
48 async fn close(&self) -> Result<()>;
50}
51
52pub struct StdioTransport {
54 process: Arc<Mutex<Option<tokio::process::Child>>>,
56 stdin: Arc<Mutex<Option<tokio::process::ChildStdin>>>,
58 stdout: Arc<Mutex<Option<tokio::io::BufReader<tokio::process::ChildStdout>>>>,
60}
61
62impl StdioTransport {
63 pub fn new(_command: &str, _args: &[String]) -> Result<Self> {
65 Ok(Self {
66 process: Arc::new(Mutex::new(None)),
67 stdin: Arc::new(Mutex::new(None)),
68 stdout: Arc::new(Mutex::new(None)),
69 })
70 }
71
72 pub async fn start(&self, command: &str, args: &[String]) -> Result<()> {
74 use std::process::Stdio;
75
76 let mut cmd = tokio::process::Command::new(command);
77 cmd.args(args)
78 .stdin(Stdio::piped())
79 .stdout(Stdio::piped())
80 .stderr(Stdio::null());
81
82 let mut child = cmd.spawn()?;
83
84 let stdin = child
85 .stdin
86 .take()
87 .ok_or_else(|| anyhow!("Failed to open stdin"))?;
88 let stdout = child
89 .stdout
90 .take()
91 .ok_or_else(|| anyhow!("Failed to open stdout"))?;
92
93 *self.stdin.lock().await = Some(stdin);
94 *self.stdout.lock().await = Some(BufReader::new(stdout));
95 *self.process.lock().await = Some(child);
96
97 Ok(())
98 }
99}
100
101#[async_trait]
102impl McpTransport for StdioTransport {
103 async fn send(&self, message: &McpMessage) -> Result<()> {
104 let mut stdin_guard = self.stdin.lock().await;
105 let stdin = stdin_guard
106 .as_mut()
107 .ok_or_else(|| anyhow!("Transport not started"))?;
108
109 let json = serde_json::to_string(message)?;
110 let frame = format!("Content-Length: {}\r\n\r\n{}", json.len(), json);
111 stdin.write_all(frame.as_bytes()).await?;
112 stdin.flush().await?;
113 Ok(())
114 }
115
116 async fn receive(&self) -> Result<Option<McpMessage>> {
117 let mut stdout_guard = self.stdout.lock().await;
118 let stdout = stdout_guard
119 .as_mut()
120 .ok_or_else(|| anyhow!("Transport not started"))?;
121
122 let mut header_buf = vec![0u8; 1024];
124 let mut total_read = 0;
125
126 loop {
127 let n = stdout.read(&mut header_buf[total_read..]).await?;
128 if n == 0 {
129 return Ok(None); }
131 total_read += n;
132
133 if let Some(pos) = find_header_end(&header_buf[..total_read]) {
135 let header = String::from_utf8_lossy(&header_buf[..pos]);
136 let content_length = parse_content_length(&header)?;
137
138 let header_size = pos + 4; let body_size = content_length;
141 let mut body_buf = vec![0u8; body_size];
142
143 let already_read = total_read - header_size;
145 if already_read > 0 {
146 body_buf[..already_read].copy_from_slice(&header_buf[header_size..total_read]);
147 }
148
149 if already_read < body_size {
151 stdout.read_exact(&mut body_buf[already_read..]).await?;
152 }
153
154 let message: McpMessage = serde_json::from_slice(&body_buf)?;
155 return Ok(Some(message));
156 }
157
158 if total_read >= header_buf.len() {
159 return Err(anyhow!("Header too large"));
160 }
161 }
162 }
163
164 async fn close(&self) -> Result<()> {
165 let mut process_guard = self.process.lock().await;
166 if let Some(mut process) = process_guard.take() {
167 process.kill().await?;
168 }
169 Ok(())
170 }
171}
172
173pub struct TcpTransport {
175 stream: Arc<Mutex<Option<TcpStream>>>,
177 listener: Arc<Mutex<Option<TcpListener>>>,
179}
180
181impl TcpTransport {
182 pub async fn connect(addr: &str) -> Result<Self> {
184 let stream = TcpStream::connect(addr).await?;
185 Ok(Self {
186 stream: Arc::new(Mutex::new(Some(stream))),
187 listener: Arc::new(Mutex::new(None)),
188 })
189 }
190
191 pub async fn bind(addr: &str) -> Result<Self> {
193 let listener = TcpListener::bind(addr).await?;
194 Ok(Self {
195 stream: Arc::new(Mutex::new(None)),
196 listener: Arc::new(Mutex::new(Some(listener))),
197 })
198 }
199
200 pub async fn accept(&self) -> Result<()> {
202 let mut listener_guard = self.listener.lock().await;
203 let listener = listener_guard
204 .as_mut()
205 .ok_or_else(|| anyhow!("Not in server mode"))?;
206
207 let (stream, _) = listener.accept().await?;
208 *self.stream.lock().await = Some(stream);
209 Ok(())
210 }
211}
212
213#[async_trait]
214impl McpTransport for TcpTransport {
215 async fn send(&self, message: &McpMessage) -> Result<()> {
216 let mut stream_guard = self.stream.lock().await;
217 let stream = stream_guard
218 .as_mut()
219 .ok_or_else(|| anyhow!("Not connected"))?;
220
221 let json = serde_json::to_string(message)?;
222 let frame = format!("Content-Length: {}\r\n\r\n{}", json.len(), json);
223 stream.write_all(frame.as_bytes()).await?;
224 stream.flush().await?;
225 Ok(())
226 }
227
228 async fn receive(&self) -> Result<Option<McpMessage>> {
229 let mut stream_guard = self.stream.lock().await;
230 let stream = stream_guard
231 .as_mut()
232 .ok_or_else(|| anyhow!("Not connected"))?;
233
234 let mut header_buf = vec![0u8; 1024];
236 let mut total_read = 0;
237
238 loop {
239 let n = stream.read(&mut header_buf[total_read..]).await?;
240 if n == 0 {
241 return Ok(None); }
243 total_read += n;
244
245 if let Some(pos) = find_header_end(&header_buf[..total_read]) {
247 let header = String::from_utf8_lossy(&header_buf[..pos]);
248 let content_length = parse_content_length(&header)?;
249
250 let header_size = pos + 4; let body_size = content_length;
253 let mut body_buf = vec![0u8; body_size];
254
255 let already_read = total_read - header_size;
257 if already_read > 0 {
258 body_buf[..already_read].copy_from_slice(&header_buf[header_size..total_read]);
259 }
260
261 if already_read < body_size {
263 stream.read_exact(&mut body_buf[already_read..]).await?;
264 }
265
266 let message: McpMessage = serde_json::from_slice(&body_buf)?;
267 return Ok(Some(message));
268 }
269
270 if total_read >= header_buf.len() {
271 return Err(anyhow!("Header too large"));
272 }
273 }
274 }
275
276 async fn close(&self) -> Result<()> {
277 let mut stream_guard = self.stream.lock().await;
278 stream_guard.take();
279 Ok(())
280 }
281}
282
/// Locates the blank line (`\r\n\r\n`) that ends a header block.
///
/// Returns the index at which the first such sequence starts, or `None` when `buf`
/// contains no complete terminator (including when `buf` is shorter than four bytes).
fn find_header_end(buf: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";
    buf.windows(TERMINATOR.len())
        .position(|window| window == TERMINATOR)
}
292
293fn parse_content_length(header: &str) -> Result<usize> {
295 for line in header.lines() {
296 if let Some(value) = line.strip_prefix("Content-Length:") {
297 return Ok(value.trim().parse()?);
298 }
299 }
300 Err(anyhow!("Content-Length header not found"))
301}
302
303#[cfg(unix)]
305pub struct UnixSocketTransport {
306 stream: Arc<Mutex<Option<UnixStream>>>,
308}
309
310#[cfg(unix)]
311impl UnixSocketTransport {
312 pub async fn connect(path: &str) -> Result<Self> {
314 let stream = UnixStream::connect(path).await?;
315 Ok(Self {
316 stream: Arc::new(Mutex::new(Some(stream))),
317 })
318 }
319}
320
321#[cfg(unix)]
322#[async_trait]
323impl McpTransport for UnixSocketTransport {
324 async fn send(&self, message: &McpMessage) -> Result<()> {
325 let mut stream_guard = self.stream.lock().await;
326 let stream = stream_guard
327 .as_mut()
328 .ok_or_else(|| anyhow!("Not connected"))?;
329
330 let json = serde_json::to_string(message)?;
331 let frame = format!("Content-Length: {}\r\n\r\n{}", json.len(), json);
332 stream.write_all(frame.as_bytes()).await?;
333 stream.flush().await?;
334 Ok(())
335 }
336
337 async fn receive(&self) -> Result<Option<McpMessage>> {
338 let mut stream_guard = self.stream.lock().await;
339 let stream = stream_guard
340 .as_mut()
341 .ok_or_else(|| anyhow!("Not connected"))?;
342
343 let mut header_buf = vec![0u8; 1024];
345 let mut total_read = 0;
346
347 loop {
348 let n = stream.read(&mut header_buf[total_read..]).await?;
349 if n == 0 {
350 return Ok(None); }
352 total_read += n;
353
354 if let Some(pos) = find_header_end(&header_buf[..total_read]) {
356 let header = String::from_utf8_lossy(&header_buf[..pos]);
357 let content_length = parse_content_length(&header)?;
358
359 let header_size = pos + 4; let body_size = content_length;
362 let mut body_buf = vec![0u8; body_size];
363
364 let already_read = total_read - header_size;
366 if already_read > 0 {
367 body_buf[..already_read].copy_from_slice(&header_buf[header_size..total_read]);
368 }
369
370 if already_read < body_size {
372 stream.read_exact(&mut body_buf[already_read..]).await?;
373 }
374
375 let message: McpMessage = serde_json::from_slice(&body_buf)?;
376 return Ok(Some(message));
377 }
378
379 if total_read >= header_buf.len() {
380 return Err(anyhow!("Header too large"));
381 }
382 }
383 }
384
385 async fn close(&self) -> Result<()> {
386 let mut stream_guard = self.stream.lock().await;
387 stream_guard.take();
388 Ok(())
389 }
390}
391
392pub struct MemoryTransport {
394 messages: Arc<Mutex<Vec<McpMessage>>>,
395 position: Arc<Mutex<usize>>,
396}
397
398impl Default for MemoryTransport {
399 fn default() -> Self {
400 Self::new()
401 }
402}
403
404impl MemoryTransport {
405 pub fn new() -> Self {
406 Self {
407 messages: Arc::new(Mutex::new(Vec::new())),
408 position: Arc::new(Mutex::new(0)),
409 }
410 }
411
412 pub async fn push(&self, message: McpMessage) {
413 self.messages.lock().await.push(message);
414 }
415}
416
417#[async_trait]
418impl McpTransport for MemoryTransport {
419 async fn send(&self, message: &McpMessage) -> Result<()> {
420 self.messages.lock().await.push(message.clone());
421 Ok(())
422 }
423
424 async fn receive(&self) -> Result<Option<McpMessage>> {
425 let messages = self.messages.lock().await;
426 let mut pos = self.position.lock().await;
427
428 if *pos < messages.len() {
429 let message = messages[*pos].clone();
430 *pos += 1;
431 Ok(Some(message))
432 } else {
433 Ok(None)
434 }
435 }
436
437 async fn close(&self) -> Result<()> {
438 Ok(())
439 }
440}
441
#[cfg(test)]
mod tests {
    use super::super::protocol::{McpRequest, RequestId};
    use super::*;

    /// Builds a request message with a fixed id and the given method name.
    fn sample_request(method: &str) -> McpMessage {
        McpMessage::Request(McpRequest {
            id: RequestId::Number(1),
            method: method.to_string(),
            params: None,
        })
    }

    #[test]
    fn test_parse_content_length() {
        let header = "Content-Length: 42\r\n";
        assert_eq!(parse_content_length(header).unwrap(), 42);
    }

    #[test]
    fn test_parse_content_length_among_other_headers() {
        let header = "Content-Type: application/json\r\nContent-Length: 7\r\n";
        assert_eq!(parse_content_length(header).unwrap(), 7);
    }

    #[test]
    fn test_parse_content_length_errors() {
        assert!(parse_content_length("").is_err());
        assert!(parse_content_length("Content-Type: application/json\r\n").is_err());
        assert!(parse_content_length("Content-Length: abc\r\n").is_err());
    }

    #[test]
    fn test_find_header_end() {
        // "Content-Length: 10" is 18 bytes, so the terminator starts at index 18.
        let buf = b"Content-Length: 10\r\n\r\n";
        assert_eq!(find_header_end(buf), Some(18));
    }

    #[test]
    fn test_find_header_end_boundaries() {
        assert_eq!(find_header_end(b""), None);
        assert_eq!(find_header_end(b"\r\n\r"), None);
        assert_eq!(find_header_end(b"\r\n\r\n"), Some(0));
        // The first terminator wins.
        assert_eq!(find_header_end(b"a\r\n\r\nb\r\n\r\n"), Some(1));
    }

    #[tokio::test]
    async fn test_memory_transport() {
        let transport = MemoryTransport::new();
        transport.send(&sample_request("test")).await.unwrap();
        let received = transport.receive().await.unwrap();
        assert!(received.is_some());
    }

    #[tokio::test]
    async fn test_memory_transport_is_fifo_and_drains() {
        let transport = MemoryTransport::new();
        assert!(transport.receive().await.unwrap().is_none());

        transport.push(sample_request("first")).await;
        transport.send(&sample_request("second")).await.unwrap();

        for expected in ["first", "second"] {
            match transport.receive().await.unwrap() {
                Some(McpMessage::Request(req)) => assert_eq!(req.method, expected),
                _ => panic!("expected request {:?}", expected),
            }
        }
        assert!(transport.receive().await.unwrap().is_none());
    }
}
475}