ya_runtime_api/server/
client.rs

1use super::*;
2use futures::channel::oneshot;
3use futures::future::Shared;
4use futures::lock::Mutex;
5use futures::{FutureExt, SinkExt};
6use std::collections::HashMap;
7use std::fmt::Debug;
8use std::sync::atomic::{AtomicU64, Ordering::Relaxed};
9
/// Process-wide counter that supplies the initial `id` of every `proto::Request`
/// built by the `RuntimeService` methods in this file.
///
/// Note that `Client::call` overwrites `Request::id` with its own per-client
/// counter before the request is written, so the value taken from here is not
/// the id used to match the response.
static REQUEST_ID: AtomicU64 = AtomicU64::new(0);
11
12struct ClientInner<Out> {
13    ids: u64,
14    response_callbacks: HashMap<u64, futures::channel::oneshot::Sender<proto::Response>>,
15    output: Out,
16}
17
18struct Client<Out> {
19    inner: Mutex<ClientInner<Out>>,
20    pid: u32,
21    kill_cmd: std::sync::Mutex<Option<oneshot::Sender<()>>>,
22    status: Shared<oneshot::Receiver<i32>>,
23}
24
25impl<Out: Sink<proto::Request> + Unpin> Client<Out>
26where
27    Out::Error: Debug,
28{
29    fn new(
30        output: Out,
31        pid: u32,
32        kill_cmd: oneshot::Sender<()>,
33        status: Shared<oneshot::Receiver<i32>>,
34    ) -> Self {
35        let kill_cmd = std::sync::Mutex::new(Some(kill_cmd));
36        let inner = Mutex::new(ClientInner {
37            ids: 1,
38            response_callbacks: Default::default(),
39            output,
40        });
41        Client {
42            inner,
43            pid,
44            kill_cmd,
45            status,
46        }
47    }
48
49    async fn call(&self, mut param: proto::Request) -> proto::Response {
50        let (tx, rx) = futures::channel::oneshot::channel();
51        {
52            let mut inner = self.inner.lock().await;
53            inner.ids += 1;
54            let id = inner.ids;
55            param.id = id;
56            let _ = inner.response_callbacks.insert(id, tx);
57            log::debug!("sending request: {:?}", param);
58            if let Err(e) = SinkExt::send(&mut inner.output, param).await {
59                log::error!("Runtime client write error: {:?}", e);
60            }
61        }
62        log::debug!("waiting for response");
63        let response = rx.await.unwrap();
64        log::debug!("got response: {:?}", response);
65        response
66    }
67
68    async fn handle_response(&self, resp: proto::Response) {
69        if resp.event {
70            todo!()
71        }
72
73        if let Some(callback) = {
74            let mut inner = self.inner.lock().await;
75            inner.response_callbacks.remove(&resp.id)
76        } {
77            let _ = callback.send(resp);
78        }
79    }
80}
81
82impl<Out: Sink<proto::Request> + Unpin> RuntimeControl for Arc<Client<Out>> {
83    fn id(&self) -> u32 {
84        self.pid
85    }
86
87    fn stop(&self) {
88        if let Some(s) = self.kill_cmd.lock().unwrap().take() {
89            let _ = s.send(());
90        }
91    }
92
93    fn stopped(&self) -> BoxFuture<'_, i32> {
94        Box::pin(self.status.clone().then(|r| async move { r.unwrap_or(1) }))
95    }
96}
97
98impl<Out: Sink<proto::Request> + Unpin> RuntimeService for Arc<Client<Out>>
99where
100    Out::Error: Debug,
101{
102    fn hello(&self, version: &str) -> AsyncResponse<'_, String> {
103        let id = REQUEST_ID.fetch_add(1, Relaxed);
104        let request = proto::Request {
105            id,
106            command: Some(proto::request::Command::Hello(proto::request::Hello {
107                version: version.to_owned(),
108            })),
109        };
110        let fut = self.call(request);
111        async move {
112            match fut.await.command {
113                Some(proto::response::Command::Hello(hello)) => Ok(hello.version),
114                Some(proto::response::Command::Error(error)) => Err(error),
115                _ => panic!("invalid response"),
116            }
117        }
118        .boxed_local()
119    }
120
121    fn run_process(&self, run: RunProcess) -> AsyncResponse<RunProcessResp> {
122        let id = REQUEST_ID.fetch_add(1, Relaxed);
123        let request = proto::Request {
124            id,
125            command: Some(proto::request::Command::Run(run)),
126        };
127        let fut = self.call(request);
128        async move {
129            match fut.await.command {
130                Some(proto::response::Command::Run(run)) => Ok(run),
131                Some(proto::response::Command::Error(error)) => Err(error),
132                _ => panic!("invalid response"),
133            }
134        }
135        .boxed_local()
136    }
137
138    fn kill_process(&self, kill: KillProcess) -> AsyncResponse<()> {
139        let id = REQUEST_ID.fetch_add(1, Relaxed);
140        let request = proto::Request {
141            id,
142            command: Some(proto::request::Command::Kill(kill)),
143        };
144        let fut = self.call(request);
145        async move {
146            match fut.await.command {
147                Some(proto::response::Command::Kill(_kill)) => Ok(()),
148                Some(proto::response::Command::Error(error)) => Err(error),
149                _ => panic!("invalid response"),
150            }
151        }
152        .boxed_local()
153    }
154
155    fn create_network(&self, network: CreateNetwork) -> AsyncResponse<CreateNetworkResp> {
156        let id = REQUEST_ID.fetch_add(1, Relaxed);
157        let request = proto::Request {
158            id,
159            command: Some(proto::request::Command::Network(network)),
160        };
161        let fut = self.call(request);
162        async move {
163            match fut.await.command {
164                Some(proto::response::Command::Network(res)) => Ok(res),
165                Some(proto::response::Command::Error(error)) => Err(error),
166                _ => panic!("invalid response"),
167            }
168        }
169        .boxed_local()
170    }
171
172    fn shutdown(&self) -> AsyncResponse<'_, ()> {
173        let shutdown = proto::request::Shutdown::default();
174        let id = REQUEST_ID.fetch_add(1, Relaxed);
175
176        let request = proto::Request {
177            id,
178            command: Some(proto::request::Command::Shutdown(shutdown)),
179        };
180        let fut = self.call(request);
181        async move {
182            match fut.await.command {
183                Some(proto::response::Command::Shutdown(_shutdown)) => Ok(()),
184                Some(proto::response::Command::Error(error)) => Err(error),
185                _ => panic!("invalid response"),
186            }
187        }
188        .boxed_local()
189    }
190}
191
192// sends Request, recv Response
193pub async fn spawn(
194    mut command: process::Command,
195    event_handler: impl RuntimeHandler + Send + Sync + 'static,
196) -> Result<impl RuntimeService + RuntimeControl + Clone, anyhow::Error> {
197    command.stdin(Stdio::piped()).stdout(Stdio::piped());
198    command.kill_on_drop(true);
199    let mut child: process::Child = command.spawn()?;
200    let pid = child
201        .id()
202        .ok_or_else(|| anyhow::anyhow!("Missing child process PID"))?;
203    let stdin =
204        tokio_util::codec::FramedWrite::new(child.stdin.take().unwrap(), codec::Codec::default());
205    let stdout = child.stdout.take().unwrap();
206    let (kill_tx, kill_rx) = oneshot::channel();
207    let (status_tx, status_rx) = oneshot::channel();
208
209    let client = Arc::new(Client::new(stdin, pid, kill_tx, status_rx.shared()));
210    {
211        let client = client.clone();
212        let mut stdout =
213            tokio_util::codec::FramedRead::new(stdout, codec::Codec::<proto::Response>::default());
214        let pump = async move {
215            while let Some(Ok(it)) = stdout.next().await {
216                if it.event {
217                    handle_event(it, &event_handler).await;
218                } else {
219                    client.handle_response(it).await;
220                }
221            }
222        };
223        let _ = tokio::task::spawn(async move {
224            let code = tokio::select! {
225                r = child.wait() => map_return_code(r, pid),
226                _ = pump => {
227                    let _ = child.start_kill();
228                    map_return_code(child.wait().await, pid)
229                }
230                _ = kill_rx => {
231                    let _ = child.start_kill();
232                    map_return_code(child.wait().await, pid)
233                }
234            };
235            if status_tx.send(code).is_err() {
236                log::warn!("Unable to update process {} status: receiver is gone", pid);
237            }
238        });
239    }
240
241    async fn handle_event(response: proto::Response, handler: &impl RuntimeHandler) {
242        use proto::response::Command;
243        match response.command {
244            Some(Command::Status(status)) => {
245                handler.on_process_status(status).await;
246            }
247            Some(Command::RtStatus(status)) => {
248                handler.on_runtime_status(status).await;
249            }
250            cmd => log::warn!("invalid event: {:?}", cmd),
251        }
252    }
253
254    Ok(client)
255}
256
257fn map_return_code(result: std::io::Result<ExitStatus>, pid: u32) -> i32 {
258    result
259        .map(|e| match e.code() {
260            Some(code) => code,
261            None => {
262                log::warn!("Unable to kill process {}: {}", pid, e);
263                1
264            }
265        })
266        .unwrap_or_else(|e| {
267            log::warn!("Child process {} error: {}", pid, e);
268            1
269        })
270}