1use crate::mcp::client::Message;
2use crate::mcp::server;
3use crate::mcp::{self, Bytes};
4use crate::transport::{Channel, Transport};
5
6use futures::channel::mpsc;
7use futures::future::{self, BoxFuture};
8use futures::{FutureExt, SinkExt, StreamExt};
9use tokio::io::{self, AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
10use tokio::process;
11use tokio::task;
12
13use std::ffi::OsStr;
14
15pub struct Stdio {
16 _process: process::Child,
17 runner: mpsc::Sender<Action>,
18}
19
20impl Stdio {
21 pub fn run(
22 command: impl AsRef<OsStr>,
23 args: impl IntoIterator<Item = impl AsRef<OsStr>>,
24 ) -> io::Result<Self> {
25 let mut process = process::Command::new(command)
26 .args(args)
27 .stdin(std::process::Stdio::piped())
28 .stderr(std::process::Stdio::null()) .stdout(std::process::Stdio::piped())
30 .kill_on_drop(true) .spawn()?;
32
33 let input = process.stdin.take().expect("process must have stdin");
34 let output = process.stdout.take().expect("process must have stdout");
35
36 let (sender, receiver) = mpsc::channel(10);
37 drop(task::spawn(run(input, output, receiver)));
38
39 Ok(Self {
40 _process: process,
41 runner: sender,
42 })
43 }
44}
45
46impl Transport for Stdio {
47 fn listen(&self) -> BoxFuture<'static, io::Result<Channel>> {
48 let mut runner = self.runner.clone();
49
50 async move {
51 let (sender, receiver) = mpsc::channel(1);
52 let _ = runner.send(Action::Listen(sender)).await;
53
54 Ok(receiver)
55 }
56 .boxed()
57 }
58
59 fn send(&self, bytes: Bytes) -> BoxFuture<'static, io::Result<Channel>> {
60 let mut runner = self.runner.clone();
61
62 async move {
63 let (sender, receiver) = mpsc::channel(1);
64 let _ = runner.send(Action::Send(bytes, sender)).await;
65
66 Ok(receiver)
67 }
68 .boxed()
69 }
70}
71
72type Sender = mpsc::Sender<Bytes>;
73
74enum Action {
75 Listen(Sender),
76 Send(Bytes, Sender),
77}
78
79async fn run(
80 mut input: impl AsyncWrite + Unpin,
81 output: impl AsyncRead + Unpin,
82 mut receiver: mpsc::Receiver<Action>,
83) -> io::Result<()> {
84 use future::Either;
85
86 let mut output = BufReader::new(output);
87 let mut listeners = Vec::new();
88 let mut buffer = Vec::new();
89
90 loop {
91 let event = {
92 let next_line = Box::pin(output.read_until(0xA, &mut buffer));
93 let next_action = receiver.next().fuse();
94 let next_event = future::select(next_line, next_action);
95
96 match next_event.await {
97 Either::Left((line, _)) => Either::Left(line),
98 Either::Right((Some(action), _)) => Either::Right(action),
99 _ => return Ok(()),
100 }
101 };
102
103 match event {
104 Either::Right(Action::Listen(sender)) => {
105 listeners.push(sender);
106 }
107 Either::Left(Ok(n)) => {
108 if n == 0 {
109 return Ok(());
110 }
111
112 let bytes = Bytes::from_owner(std::mem::take(&mut buffer));
113
114 for listener in &mut listeners {
115 let _ = listener.send(bytes.clone()).await;
116 }
117 }
118 Either::Right(Action::Send(bytes, mut sender)) => {
119 write(&mut input, &bytes).await?;
120
121 let Ok(Message::Request(_)) = Message::<mcp::Ignored>::deserialize(&bytes) else {
122 continue;
123 };
124
125 while let Ok(n) = output.read_until(0xA, &mut buffer).await {
126 if n == 0 {
127 return Ok(());
128 }
129
130 let bytes = Bytes::from_owner(std::mem::take(&mut buffer));
131 let _ = sender.send(bytes.clone()).await;
132
133 if let Ok(server::Message::Response(_)) =
134 server::Message::<mcp::Ignored>::deserialize(&bytes)
135 {
136 break;
137 }
138 }
139 }
140 _ => {
141 break;
142 }
143 }
144 }
145
146 Ok(())
147}
148
149async fn write(input: &mut (impl AsyncWrite + Unpin), data: &[u8]) -> io::Result<()> {
150 input.write_all(data).await?;
151 input.write_u8(0xA).await?;
152 input.flush().await
153}