Skip to main content

sh_layer4/mcp_bridge/
transport.rs

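//! MCP transport layer.
//!
//! Provides stdio, TCP, and Unix-socket (Unix only) transports that exchange
//! `Content-Length`-framed JSON messages, plus an in-memory transport for tests.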
4
5use async_trait::async_trait;
6use std::sync::Arc;
7use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
8#[cfg(unix)]
9use tokio::net::UnixStream;
10use tokio::net::{TcpListener, TcpStream};
11use tokio::sync::Mutex;
12
13use super::protocol::McpMessage;
14use anyhow::{anyhow, Result};
15
/// Describes which kind of channel an MCP connection runs over.
#[derive(Debug, Clone)]
pub enum McpTransportType {
    /// Standard input/output of a child process.
    Stdio {
        /// Program to launch.
        command: String,
        /// Command-line arguments for the program.
        args: Vec<String>,
    },
    /// A TCP socket.
    Tcp {
        /// Socket address of the peer (format is not validated here).
        addr: String,
    },
    /// A Unix domain socket; only available on Unix targets.
    #[cfg(unix)]
    Unix {
        /// Filesystem path of the socket.
        path: String,
    },
}
38
39/// MCP 传输 trait
40#[async_trait]
41pub trait McpTransport: Send + Sync {
42    /// 发送消息
43    async fn send(&self, message: &McpMessage) -> Result<()>;
44
45    /// 接收消息
46    async fn receive(&self) -> Result<Option<McpMessage>>;
47
48    /// 关闭传输
49    async fn close(&self) -> Result<()>;
50}
51
52/// Stdio 传输实现
53pub struct StdioTransport {
54    /// 子进程
55    process: Arc<Mutex<Option<tokio::process::Child>>>,
56    /// 标准输入
57    stdin: Arc<Mutex<Option<tokio::process::ChildStdin>>>,
58    /// 标准输出读取器
59    stdout: Arc<Mutex<Option<tokio::io::BufReader<tokio::process::ChildStdout>>>>,
60}
61
62impl StdioTransport {
63    /// 创建新的 Stdio 传输
64    pub fn new(_command: &str, _args: &[String]) -> Result<Self> {
65        Ok(Self {
66            process: Arc::new(Mutex::new(None)),
67            stdin: Arc::new(Mutex::new(None)),
68            stdout: Arc::new(Mutex::new(None)),
69        })
70    }
71
72    /// 启动子进程
73    pub async fn start(&self, command: &str, args: &[String]) -> Result<()> {
74        use std::process::Stdio;
75
76        let mut cmd = tokio::process::Command::new(command);
77        cmd.args(args)
78            .stdin(Stdio::piped())
79            .stdout(Stdio::piped())
80            .stderr(Stdio::null());
81
82        let mut child = cmd.spawn()?;
83
84        let stdin = child
85            .stdin
86            .take()
87            .ok_or_else(|| anyhow!("Failed to open stdin"))?;
88        let stdout = child
89            .stdout
90            .take()
91            .ok_or_else(|| anyhow!("Failed to open stdout"))?;
92
93        *self.stdin.lock().await = Some(stdin);
94        *self.stdout.lock().await = Some(BufReader::new(stdout));
95        *self.process.lock().await = Some(child);
96
97        Ok(())
98    }
99}
100
101#[async_trait]
102impl McpTransport for StdioTransport {
103    async fn send(&self, message: &McpMessage) -> Result<()> {
104        let mut stdin_guard = self.stdin.lock().await;
105        let stdin = stdin_guard
106            .as_mut()
107            .ok_or_else(|| anyhow!("Transport not started"))?;
108
109        let json = serde_json::to_string(message)?;
110        let frame = format!("Content-Length: {}\r\n\r\n{}", json.len(), json);
111        stdin.write_all(frame.as_bytes()).await?;
112        stdin.flush().await?;
113        Ok(())
114    }
115
116    async fn receive(&self) -> Result<Option<McpMessage>> {
117        let mut stdout_guard = self.stdout.lock().await;
118        let stdout = stdout_guard
119            .as_mut()
120            .ok_or_else(|| anyhow!("Transport not started"))?;
121
122        // 读取 Content-Length 头
123        let mut header_buf = vec![0u8; 1024];
124        let mut total_read = 0;
125
126        loop {
127            let n = stdout.read(&mut header_buf[total_read..]).await?;
128            if n == 0 {
129                return Ok(None); // 连接关闭
130            }
131            total_read += n;
132
133            // 查找 \r\n\r\n 分隔符
134            if let Some(pos) = find_header_end(&header_buf[..total_read]) {
135                let header = String::from_utf8_lossy(&header_buf[..pos]);
136                let content_length = parse_content_length(&header)?;
137
138                // 读取消息体
139                let header_size = pos + 4; // 包含 \r\n\r\n
140                let body_size = content_length;
141                let mut body_buf = vec![0u8; body_size];
142
143                // 处理已经读取的部分
144                let already_read = total_read - header_size;
145                if already_read > 0 {
146                    body_buf[..already_read].copy_from_slice(&header_buf[header_size..total_read]);
147                }
148
149                // 读取剩余部分
150                if already_read < body_size {
151                    stdout.read_exact(&mut body_buf[already_read..]).await?;
152                }
153
154                let message: McpMessage = serde_json::from_slice(&body_buf)?;
155                return Ok(Some(message));
156            }
157
158            if total_read >= header_buf.len() {
159                return Err(anyhow!("Header too large"));
160            }
161        }
162    }
163
164    async fn close(&self) -> Result<()> {
165        let mut process_guard = self.process.lock().await;
166        if let Some(mut process) = process_guard.take() {
167            process.kill().await?;
168        }
169        Ok(())
170    }
171}
172
173/// TCP 传输实现
174pub struct TcpTransport {
175    /// 连接流
176    stream: Arc<Mutex<Option<TcpStream>>>,
177    /// 服务器监听器 (服务端模式)
178    listener: Arc<Mutex<Option<TcpListener>>>,
179}
180
181impl TcpTransport {
182    /// 创建客户端连接
183    pub async fn connect(addr: &str) -> Result<Self> {
184        let stream = TcpStream::connect(addr).await?;
185        Ok(Self {
186            stream: Arc::new(Mutex::new(Some(stream))),
187            listener: Arc::new(Mutex::new(None)),
188        })
189    }
190
191    /// 创建服务器监听
192    pub async fn bind(addr: &str) -> Result<Self> {
193        let listener = TcpListener::bind(addr).await?;
194        Ok(Self {
195            stream: Arc::new(Mutex::new(None)),
196            listener: Arc::new(Mutex::new(Some(listener))),
197        })
198    }
199
200    /// 接受客户端连接 (服务端模式)
201    pub async fn accept(&self) -> Result<()> {
202        let mut listener_guard = self.listener.lock().await;
203        let listener = listener_guard
204            .as_mut()
205            .ok_or_else(|| anyhow!("Not in server mode"))?;
206
207        let (stream, _) = listener.accept().await?;
208        *self.stream.lock().await = Some(stream);
209        Ok(())
210    }
211}
212
213#[async_trait]
214impl McpTransport for TcpTransport {
215    async fn send(&self, message: &McpMessage) -> Result<()> {
216        let mut stream_guard = self.stream.lock().await;
217        let stream = stream_guard
218            .as_mut()
219            .ok_or_else(|| anyhow!("Not connected"))?;
220
221        let json = serde_json::to_string(message)?;
222        let frame = format!("Content-Length: {}\r\n\r\n{}", json.len(), json);
223        stream.write_all(frame.as_bytes()).await?;
224        stream.flush().await?;
225        Ok(())
226    }
227
228    async fn receive(&self) -> Result<Option<McpMessage>> {
229        let mut stream_guard = self.stream.lock().await;
230        let stream = stream_guard
231            .as_mut()
232            .ok_or_else(|| anyhow!("Not connected"))?;
233
234        // 读取 Content-Length 头
235        let mut header_buf = vec![0u8; 1024];
236        let mut total_read = 0;
237
238        loop {
239            let n = stream.read(&mut header_buf[total_read..]).await?;
240            if n == 0 {
241                return Ok(None); // 连接关闭
242            }
243            total_read += n;
244
245            // 查找 \r\n\r\n 分隔符
246            if let Some(pos) = find_header_end(&header_buf[..total_read]) {
247                let header = String::from_utf8_lossy(&header_buf[..pos]);
248                let content_length = parse_content_length(&header)?;
249
250                // 读取消息体
251                let header_size = pos + 4; // 包含 \r\n\r\n
252                let body_size = content_length;
253                let mut body_buf = vec![0u8; body_size];
254
255                // 处理已经读取的部分
256                let already_read = total_read - header_size;
257                if already_read > 0 {
258                    body_buf[..already_read].copy_from_slice(&header_buf[header_size..total_read]);
259                }
260
261                // 读取剩余部分
262                if already_read < body_size {
263                    stream.read_exact(&mut body_buf[already_read..]).await?;
264                }
265
266                let message: McpMessage = serde_json::from_slice(&body_buf)?;
267                return Ok(Some(message));
268            }
269
270            if total_read >= header_buf.len() {
271                return Err(anyhow!("Header too large"));
272            }
273        }
274    }
275
276    async fn close(&self) -> Result<()> {
277        let mut stream_guard = self.stream.lock().await;
278        stream_guard.take();
279        Ok(())
280    }
281}
282
/// Locates the blank line (`\r\n\r\n`) that terminates an HTTP-style header block.
///
/// Returns the byte offset at which the first `\r\n\r\n` sequence begins, or
/// `None` if `buf` does not contain one.
fn find_header_end(buf: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";
    buf.windows(TERMINATOR.len())
        .position(|window| window == TERMINATOR)
}
292
293/// 解析 Content-Length 头
294fn parse_content_length(header: &str) -> Result<usize> {
295    for line in header.lines() {
296        if let Some(value) = line.strip_prefix("Content-Length:") {
297            return Ok(value.trim().parse()?);
298        }
299    }
300    Err(anyhow!("Content-Length header not found"))
301}
302
303/// Unix Socket 传输实现 (仅 Unix 系统)
304#[cfg(unix)]
305pub struct UnixSocketTransport {
306    /// Unix socket 流
307    stream: Arc<Mutex<Option<UnixStream>>>,
308}
309
310#[cfg(unix)]
311impl UnixSocketTransport {
312    /// 连接到 Unix socket
313    pub async fn connect(path: &str) -> Result<Self> {
314        let stream = UnixStream::connect(path).await?;
315        Ok(Self {
316            stream: Arc::new(Mutex::new(Some(stream))),
317        })
318    }
319}
320
321#[cfg(unix)]
322#[async_trait]
323impl McpTransport for UnixSocketTransport {
324    async fn send(&self, message: &McpMessage) -> Result<()> {
325        let mut stream_guard = self.stream.lock().await;
326        let stream = stream_guard
327            .as_mut()
328            .ok_or_else(|| anyhow!("Not connected"))?;
329
330        let json = serde_json::to_string(message)?;
331        let frame = format!("Content-Length: {}\r\n\r\n{}", json.len(), json);
332        stream.write_all(frame.as_bytes()).await?;
333        stream.flush().await?;
334        Ok(())
335    }
336
337    async fn receive(&self) -> Result<Option<McpMessage>> {
338        let mut stream_guard = self.stream.lock().await;
339        let stream = stream_guard
340            .as_mut()
341            .ok_or_else(|| anyhow!("Not connected"))?;
342
343        // 读取 Content-Length 头
344        let mut header_buf = vec![0u8; 1024];
345        let mut total_read = 0;
346
347        loop {
348            let n = stream.read(&mut header_buf[total_read..]).await?;
349            if n == 0 {
350                return Ok(None); // 连接关闭
351            }
352            total_read += n;
353
354            // 查找 \r\n\r\n 分隔符
355            if let Some(pos) = find_header_end(&header_buf[..total_read]) {
356                let header = String::from_utf8_lossy(&header_buf[..pos]);
357                let content_length = parse_content_length(&header)?;
358
359                // 读取消息体
360                let header_size = pos + 4; // 包含 \r\n\r\n
361                let body_size = content_length;
362                let mut body_buf = vec![0u8; body_size];
363
364                // 处理已经读取的部分
365                let already_read = total_read - header_size;
366                if already_read > 0 {
367                    body_buf[..already_read].copy_from_slice(&header_buf[header_size..total_read]);
368                }
369
370                // 读取剩余部分
371                if already_read < body_size {
372                    stream.read_exact(&mut body_buf[already_read..]).await?;
373                }
374
375                let message: McpMessage = serde_json::from_slice(&body_buf)?;
376                return Ok(Some(message));
377            }
378
379            if total_read >= header_buf.len() {
380                return Err(anyhow!("Header too large"));
381            }
382        }
383    }
384
385    async fn close(&self) -> Result<()> {
386        let mut stream_guard = self.stream.lock().await;
387        stream_guard.take();
388        Ok(())
389    }
390}
391
392/// 内存传输 (用于测试)
393pub struct MemoryTransport {
394    messages: Arc<Mutex<Vec<McpMessage>>>,
395    position: Arc<Mutex<usize>>,
396}
397
398impl Default for MemoryTransport {
399    fn default() -> Self {
400        Self::new()
401    }
402}
403
404impl MemoryTransport {
405    pub fn new() -> Self {
406        Self {
407            messages: Arc::new(Mutex::new(Vec::new())),
408            position: Arc::new(Mutex::new(0)),
409        }
410    }
411
412    pub async fn push(&self, message: McpMessage) {
413        self.messages.lock().await.push(message);
414    }
415}
416
417#[async_trait]
418impl McpTransport for MemoryTransport {
419    async fn send(&self, message: &McpMessage) -> Result<()> {
420        self.messages.lock().await.push(message.clone());
421        Ok(())
422    }
423
424    async fn receive(&self) -> Result<Option<McpMessage>> {
425        let messages = self.messages.lock().await;
426        let mut pos = self.position.lock().await;
427
428        if *pos < messages.len() {
429            let message = messages[*pos].clone();
430            *pos += 1;
431            Ok(Some(message))
432        } else {
433            Ok(None)
434        }
435    }
436
437    async fn close(&self) -> Result<()> {
438        Ok(())
439    }
440}
441
#[cfg(test)]
mod tests {
    use super::super::protocol::{McpRequest, RequestId};
    use super::*;

    /// Builds a request message with a fixed id and the given method name.
    fn request(method: &str) -> McpMessage {
        McpMessage::Request(McpRequest {
            id: RequestId::Number(1),
            method: method.to_string(),
            params: None,
        })
    }

    #[test]
    fn test_parse_content_length() {
        let header = "Content-Length: 42\r\n";
        let len = parse_content_length(header).unwrap();
        assert_eq!(len, 42);
    }

    #[test]
    fn test_parse_content_length_among_other_headers() {
        let header = "Content-Type: application/json\r\nContent-Length: 7";
        assert_eq!(parse_content_length(header).unwrap(), 7);
    }

    #[test]
    fn test_parse_content_length_missing_or_invalid() {
        assert!(parse_content_length("Content-Type: application/json").is_err());
        assert!(parse_content_length("Content-Length: abc").is_err());
        assert!(parse_content_length("").is_err());
    }

    #[test]
    fn test_find_header_end() {
        let buf = b"Content-Length: 10\r\n\r\n";
        let pos = find_header_end(buf).unwrap();
        // "\r\n\r\n" starts right after "Content-Length: 10" (18 bytes).
        assert_eq!(pos, 18);
    }

    #[test]
    fn test_find_header_end_incomplete() {
        assert_eq!(find_header_end(b"Content-Length: 10\r\n"), None);
        assert_eq!(find_header_end(b""), None);
    }

    #[tokio::test]
    async fn test_memory_transport() {
        let transport = MemoryTransport::new();

        transport.send(&request("test")).await.unwrap();
        let received = transport.receive().await.unwrap();
        assert!(received.is_some());
    }

    #[tokio::test]
    async fn test_memory_transport_is_fifo_and_exhausts() {
        let transport = MemoryTransport::new();
        transport.send(&request("first")).await.unwrap();
        transport.push(request("second")).await;

        for expected in &["first", "second"] {
            match transport.receive().await.unwrap() {
                Some(McpMessage::Request(req)) => assert_eq!(req.method, *expected),
                _ => panic!("expected request `{}`", expected),
            }
        }
        assert!(transport.receive().await.unwrap().is_none());
    }
}