techne_client/
stdio.rs

1use crate::mcp::client::Message;
2use crate::mcp::server;
3use crate::mcp::{self, Bytes};
4use crate::transport::{Channel, Transport};
5
6use futures::channel::mpsc;
7use futures::future::{self, BoxFuture};
8use futures::{FutureExt, SinkExt, StreamExt};
9use tokio::io::{self, AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
10use tokio::process;
11use tokio::task;
12
13use std::ffi::OsStr;
14
15pub struct Stdio {
16    _process: process::Child,
17    runner: mpsc::Sender<Action>,
18}
19
20impl Stdio {
21    pub fn run(
22        command: impl AsRef<OsStr>,
23        args: impl IntoIterator<Item = impl AsRef<OsStr>>,
24    ) -> io::Result<Self> {
25        let mut process = process::Command::new(command)
26            .args(args)
27            .stdin(std::process::Stdio::piped())
28            .stderr(std::process::Stdio::null()) // TODO
29            .stdout(std::process::Stdio::piped())
30            .kill_on_drop(true) // TODO: Graceful quitting
31            .spawn()?;
32
33        let input = process.stdin.take().expect("process must have stdin");
34        let output = process.stdout.take().expect("process must have stdout");
35
36        let (sender, receiver) = mpsc::channel(10);
37        drop(task::spawn(run(input, output, receiver)));
38
39        Ok(Self {
40            _process: process,
41            runner: sender,
42        })
43    }
44}
45
46impl Transport for Stdio {
47    fn listen(&self) -> BoxFuture<'static, io::Result<Channel>> {
48        let mut runner = self.runner.clone();
49
50        async move {
51            let (sender, receiver) = mpsc::channel(1);
52            let _ = runner.send(Action::Listen(sender)).await;
53
54            Ok(receiver)
55        }
56        .boxed()
57    }
58
59    fn send(&self, bytes: Bytes) -> BoxFuture<'static, io::Result<Channel>> {
60        let mut runner = self.runner.clone();
61
62        async move {
63            let (sender, receiver) = mpsc::channel(1);
64            let _ = runner.send(Action::Send(bytes, sender)).await;
65
66            Ok(receiver)
67        }
68        .boxed()
69    }
70}
71
72type Sender = mpsc::Sender<Bytes>;
73
74enum Action {
75    Listen(Sender),
76    Send(Bytes, Sender),
77}
78
79async fn run(
80    mut input: impl AsyncWrite + Unpin,
81    output: impl AsyncRead + Unpin,
82    mut receiver: mpsc::Receiver<Action>,
83) -> io::Result<()> {
84    use future::Either;
85
86    let mut output = BufReader::new(output);
87    let mut listeners = Vec::new();
88    let mut buffer = Vec::new();
89
90    loop {
91        let event = {
92            let next_line = Box::pin(output.read_until(0xA, &mut buffer));
93            let next_action = receiver.next().fuse();
94            let next_event = future::select(next_line, next_action);
95
96            match next_event.await {
97                Either::Left((line, _)) => Either::Left(line),
98                Either::Right((Some(action), _)) => Either::Right(action),
99                _ => return Ok(()),
100            }
101        };
102
103        match event {
104            Either::Right(Action::Listen(sender)) => {
105                listeners.push(sender);
106            }
107            Either::Left(Ok(n)) => {
108                if n == 0 {
109                    return Ok(());
110                }
111
112                let bytes = Bytes::from_owner(std::mem::take(&mut buffer));
113
114                for listener in &mut listeners {
115                    let _ = listener.send(bytes.clone()).await;
116                }
117            }
118            Either::Right(Action::Send(bytes, mut sender)) => {
119                write(&mut input, &bytes).await?;
120
121                let Ok(Message::Request(_)) = Message::<mcp::Ignored>::deserialize(&bytes) else {
122                    continue;
123                };
124
125                while let Ok(n) = output.read_until(0xA, &mut buffer).await {
126                    if n == 0 {
127                        return Ok(());
128                    }
129
130                    let bytes = Bytes::from_owner(std::mem::take(&mut buffer));
131                    let _ = sender.send(bytes.clone()).await;
132
133                    if let Ok(server::Message::Response(_)) =
134                        server::Message::<mcp::Ignored>::deserialize(&bytes)
135                    {
136                        break;
137                    }
138                }
139            }
140            _ => {
141                break;
142            }
143        }
144    }
145
146    Ok(())
147}
148
149async fn write(input: &mut (impl AsyncWrite + Unpin), data: &[u8]) -> io::Result<()> {
150    input.write_all(data).await?;
151    input.write_u8(0xA).await?;
152    input.flush().await
153}