1use super::*;
2use futures::channel::oneshot;
3use futures::future::Shared;
4use futures::lock::Mutex;
5use futures::{FutureExt, SinkExt};
6use std::collections::HashMap;
7use std::fmt::Debug;
8use std::sync::atomic::{AtomicU64, Ordering::Relaxed};
9
/// Process-wide counter used to stamp outgoing `proto::Request`s with a provisional id.
///
/// Note: `Client::call` replaces the request's `id` with its own per-client counter before
/// sending, so the value drawn from here is not what goes over the wire.
static REQUEST_ID: AtomicU64 = std::sync::atomic::AtomicU64::new(0);
11
12struct ClientInner<Out> {
13 ids: u64,
14 response_callbacks: HashMap<u64, futures::channel::oneshot::Sender<proto::Response>>,
15 output: Out,
16}
17
18struct Client<Out> {
19 inner: Mutex<ClientInner<Out>>,
20 pid: u32,
21 kill_cmd: std::sync::Mutex<Option<oneshot::Sender<()>>>,
22 status: Shared<oneshot::Receiver<i32>>,
23}
24
25impl<Out: Sink<proto::Request> + Unpin> Client<Out>
26where
27 Out::Error: Debug,
28{
29 fn new(
30 output: Out,
31 pid: u32,
32 kill_cmd: oneshot::Sender<()>,
33 status: Shared<oneshot::Receiver<i32>>,
34 ) -> Self {
35 let kill_cmd = std::sync::Mutex::new(Some(kill_cmd));
36 let inner = Mutex::new(ClientInner {
37 ids: 1,
38 response_callbacks: Default::default(),
39 output,
40 });
41 Client {
42 inner,
43 pid,
44 kill_cmd,
45 status,
46 }
47 }
48
49 async fn call(&self, mut param: proto::Request) -> proto::Response {
50 let (tx, rx) = futures::channel::oneshot::channel();
51 {
52 let mut inner = self.inner.lock().await;
53 inner.ids += 1;
54 let id = inner.ids;
55 param.id = id;
56 let _ = inner.response_callbacks.insert(id, tx);
57 log::debug!("sending request: {:?}", param);
58 if let Err(e) = SinkExt::send(&mut inner.output, param).await {
59 log::error!("Runtime client write error: {:?}", e);
60 }
61 }
62 log::debug!("waiting for response");
63 let response = rx.await.unwrap();
64 log::debug!("got response: {:?}", response);
65 response
66 }
67
68 async fn handle_response(&self, resp: proto::Response) {
69 if resp.event {
70 todo!()
71 }
72
73 if let Some(callback) = {
74 let mut inner = self.inner.lock().await;
75 inner.response_callbacks.remove(&resp.id)
76 } {
77 let _ = callback.send(resp);
78 }
79 }
80}
81
82impl<Out: Sink<proto::Request> + Unpin> RuntimeControl for Arc<Client<Out>> {
83 fn id(&self) -> u32 {
84 self.pid
85 }
86
87 fn stop(&self) {
88 if let Some(s) = self.kill_cmd.lock().unwrap().take() {
89 let _ = s.send(());
90 }
91 }
92
93 fn stopped(&self) -> BoxFuture<'_, i32> {
94 Box::pin(self.status.clone().then(|r| async move { r.unwrap_or(1) }))
95 }
96}
97
98impl<Out: Sink<proto::Request> + Unpin> RuntimeService for Arc<Client<Out>>
99where
100 Out::Error: Debug,
101{
102 fn hello(&self, version: &str) -> AsyncResponse<'_, String> {
103 let id = REQUEST_ID.fetch_add(1, Relaxed);
104 let request = proto::Request {
105 id,
106 command: Some(proto::request::Command::Hello(proto::request::Hello {
107 version: version.to_owned(),
108 })),
109 };
110 let fut = self.call(request);
111 async move {
112 match fut.await.command {
113 Some(proto::response::Command::Hello(hello)) => Ok(hello.version),
114 Some(proto::response::Command::Error(error)) => Err(error),
115 _ => panic!("invalid response"),
116 }
117 }
118 .boxed_local()
119 }
120
121 fn run_process(&self, run: RunProcess) -> AsyncResponse<RunProcessResp> {
122 let id = REQUEST_ID.fetch_add(1, Relaxed);
123 let request = proto::Request {
124 id,
125 command: Some(proto::request::Command::Run(run)),
126 };
127 let fut = self.call(request);
128 async move {
129 match fut.await.command {
130 Some(proto::response::Command::Run(run)) => Ok(run),
131 Some(proto::response::Command::Error(error)) => Err(error),
132 _ => panic!("invalid response"),
133 }
134 }
135 .boxed_local()
136 }
137
138 fn kill_process(&self, kill: KillProcess) -> AsyncResponse<()> {
139 let id = REQUEST_ID.fetch_add(1, Relaxed);
140 let request = proto::Request {
141 id,
142 command: Some(proto::request::Command::Kill(kill)),
143 };
144 let fut = self.call(request);
145 async move {
146 match fut.await.command {
147 Some(proto::response::Command::Kill(_kill)) => Ok(()),
148 Some(proto::response::Command::Error(error)) => Err(error),
149 _ => panic!("invalid response"),
150 }
151 }
152 .boxed_local()
153 }
154
155 fn create_network(&self, network: CreateNetwork) -> AsyncResponse<CreateNetworkResp> {
156 let id = REQUEST_ID.fetch_add(1, Relaxed);
157 let request = proto::Request {
158 id,
159 command: Some(proto::request::Command::Network(network)),
160 };
161 let fut = self.call(request);
162 async move {
163 match fut.await.command {
164 Some(proto::response::Command::Network(res)) => Ok(res),
165 Some(proto::response::Command::Error(error)) => Err(error),
166 _ => panic!("invalid response"),
167 }
168 }
169 .boxed_local()
170 }
171
172 fn shutdown(&self) -> AsyncResponse<'_, ()> {
173 let shutdown = proto::request::Shutdown::default();
174 let id = REQUEST_ID.fetch_add(1, Relaxed);
175
176 let request = proto::Request {
177 id,
178 command: Some(proto::request::Command::Shutdown(shutdown)),
179 };
180 let fut = self.call(request);
181 async move {
182 match fut.await.command {
183 Some(proto::response::Command::Shutdown(_shutdown)) => Ok(()),
184 Some(proto::response::Command::Error(error)) => Err(error),
185 _ => panic!("invalid response"),
186 }
187 }
188 .boxed_local()
189 }
190}
191
192pub async fn spawn(
194 mut command: process::Command,
195 event_handler: impl RuntimeHandler + Send + Sync + 'static,
196) -> Result<impl RuntimeService + RuntimeControl + Clone, anyhow::Error> {
197 command.stdin(Stdio::piped()).stdout(Stdio::piped());
198 command.kill_on_drop(true);
199 let mut child: process::Child = command.spawn()?;
200 let pid = child
201 .id()
202 .ok_or_else(|| anyhow::anyhow!("Missing child process PID"))?;
203 let stdin =
204 tokio_util::codec::FramedWrite::new(child.stdin.take().unwrap(), codec::Codec::default());
205 let stdout = child.stdout.take().unwrap();
206 let (kill_tx, kill_rx) = oneshot::channel();
207 let (status_tx, status_rx) = oneshot::channel();
208
209 let client = Arc::new(Client::new(stdin, pid, kill_tx, status_rx.shared()));
210 {
211 let client = client.clone();
212 let mut stdout =
213 tokio_util::codec::FramedRead::new(stdout, codec::Codec::<proto::Response>::default());
214 let pump = async move {
215 while let Some(Ok(it)) = stdout.next().await {
216 if it.event {
217 handle_event(it, &event_handler).await;
218 } else {
219 client.handle_response(it).await;
220 }
221 }
222 };
223 let _ = tokio::task::spawn(async move {
224 let code = tokio::select! {
225 r = child.wait() => map_return_code(r, pid),
226 _ = pump => {
227 let _ = child.start_kill();
228 map_return_code(child.wait().await, pid)
229 }
230 _ = kill_rx => {
231 let _ = child.start_kill();
232 map_return_code(child.wait().await, pid)
233 }
234 };
235 if status_tx.send(code).is_err() {
236 log::warn!("Unable to update process {} status: receiver is gone", pid);
237 }
238 });
239 }
240
241 async fn handle_event(response: proto::Response, handler: &impl RuntimeHandler) {
242 use proto::response::Command;
243 match response.command {
244 Some(Command::Status(status)) => {
245 handler.on_process_status(status).await;
246 }
247 Some(Command::RtStatus(status)) => {
248 handler.on_runtime_status(status).await;
249 }
250 cmd => log::warn!("invalid event: {:?}", cmd),
251 }
252 }
253
254 Ok(client)
255}
256
257fn map_return_code(result: std::io::Result<ExitStatus>, pid: u32) -> i32 {
258 result
259 .map(|e| match e.code() {
260 Some(code) => code,
261 None => {
262 log::warn!("Unable to kill process {}: {}", pid, e);
263 1
264 }
265 })
266 .unwrap_or_else(|e| {
267 log::warn!("Child process {} error: {}", pid, e);
268 1
269 })
270}