tower_duplex/lib.rs
//! A [`tower::Service`] that implements a server and a client simultaneously over a
//! bi-directional channel. As a server it is able to process RPC calls from a remote client,
//! and as a client it is capable of making RPC calls into a remote server. It is very
//! convenient in a system that requires asynchronous communication in both directions.
5use std::collections::VecDeque;
6use std::future::Future;
7use std::marker::PhantomData;
8use std::sync::atomic::AtomicUsize;
9use std::sync::atomic::Ordering::SeqCst;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12
13use futures::future::pending;
14use futures::stream::{FuturesOrdered, FuturesUnordered};
15use futures::{Sink, SinkExt, StreamExt, TryStream, TryStreamExt};
16use serde::{Deserialize, Serialize};
17use tokio::io::{AsyncRead, AsyncWrite};
18use tokio::sync::{mpsc, oneshot};
19use tower::Service;
20
21mod codec;
22mod serialize;
23#[cfg(test)]
24mod test;
25
/// One frame exchanged over the duplex channel: either an RPC request or an RPC response.
///
/// Every frame carries a `u8` tag. A response reuses the tag of the request it answers, which
/// lets the receiving side demultiplex responses and route each one back to its caller.
#[derive(Debug)]
pub enum DuplexValue<Request, Response> {
    /// A call issued by the peer, tagged so the answer can be matched back to it.
    Request(u8, Request),
    /// The answer to a call, carrying the tag of the request it answers.
    Response(u8, Response),
}
33
34/// A [`tower::Service`] that implements a server and a client simultaneously over a bi-directional
35/// channel. As a server it is able to process RPC calls from a remote client, and as a client it is
36/// capable of making RPC calls into a remote server. It is very convinient in a system that
37/// requires asynchronous communication in both directions.
38pub struct DuplexService<Request, Response, S: Service<ServiceRequest>, ServiceRequest> {
39 calls: mpsc::UnboundedReceiver<(Request, oneshot::Sender<Response>)>,
40 service: S,
41 load: Arc<AtomicUsize>,
42 _p: PhantomData<ServiceRequest>,
43}
44
45/// A client side handle to [`DuplexService`], used for RPC calls to the remote server.
46#[derive(Clone)]
47pub struct DuplexClient<Request, Response> {
48 sender: mpsc::UnboundedSender<(Request, oneshot::Sender<Response>)>,
49 load: Arc<AtomicUsize>,
50}
51
/// Errors that terminate the [`DuplexService`] server loop.
///
/// `E1` is the error type of the receiving half of the channel, `E2` the error type of the
/// sending half.
#[derive(Debug)]
pub enum DuplexError<E1, E2> {
    /// The incoming stream from the remote end ended.
    RemoteHangUp,
    /// Reading from the remote end yielded an error.
    RemoteError(E1),
    /// Writing to (or flushing) the remote end failed.
    SendError(E2),
}

impl<E1: std::fmt::Display, E2: std::fmt::Display> std::fmt::Display for DuplexError<E1, E2> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            DuplexError::RemoteHangUp => f.write_str("remote end hung up"),
            DuplexError::RemoteError(err) => write!(f, "failed to receive from remote: {err}"),
            DuplexError::SendError(err) => write!(f, "failed to send to remote: {err}"),
        }
    }
}

impl<E1, E2> std::error::Error for DuplexError<E1, E2>
where
    E1: std::error::Error + 'static,
    E2: std::error::Error + 'static,
{
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            DuplexError::RemoteHangUp => None,
            DuplexError::RemoteError(err) => Some(err),
            DuplexError::SendError(err) => Some(err),
        }
    }
}
57
/// Tagger generates unique tag values for RPC calls, required to match requests to responses when
/// multiplexing requests. The used tags are tracked in a 128-bit mask, so at most 128 tags can be
/// in use at once, which is also the maximum number of concurrent RPC calls in flight.
struct Tagger {
    /// Bit `n` is set while tag `n` is in use.
    bitmask: u128,
}

impl Tagger {
    fn new() -> Self {
        Tagger { bitmask: 0 }
    }

    /// Get the smallest available tag value, or `None` if all 128 tags are in use.
    fn get_tag(&mut self) -> Option<u8> {
        // The number of trailing ones is the index of the lowest unset bit. When every bit is
        // set it is 128, for which `checked_shl` yields `None`.
        let r = self.bitmask.trailing_ones() as u8;
        self.bitmask |= 1u128.checked_shl(r as _)?;
        Some(r)
    }

    /// Return the tag after we finished using it. It is very important to release tags,
    /// otherwise the system will run out of tags very quickly.
    ///
    /// Tags outside the valid range (`>= 128`, e.g. received from a misbehaving peer) are
    /// ignored: a plain `1 << tag` would overflow there (a panic in debug builds, and a wrapped
    /// shift clearing an unrelated live tag in release builds).
    fn release_tag(&mut self, tag: u8) {
        if let Some(bit) = 1u128.checked_shl(u32::from(tag)) {
            self.bitmask &= !bit;
        }
    }

    /// Check if tagger can't allocate any more tags
    fn full(&self) -> bool {
        self.bitmask == u128::MAX
    }
}
90
91/// Feed the queued packets into our sink, then flush. This is implemented as a standalone function
92/// so we can use it in a select loop with FuturesOrdered (that implement StreamExt and therefore
93/// cancellation safe). Once done the function returns the owned sink, so more writes can be
94/// scheduled on it.
95async fn do_send<I, S: Sink<I> + Unpin>(items: Vec<I>, sender: Option<S>) -> Result<S, S::Error> {
96 let mut sender = match sender {
97 Some(sender) => sender,
98 None => pending().await,
99 };
100
101 for item in items {
102 sender.feed(item).await?;
103 }
104
105 sender.flush().await?;
106
107 Ok(sender)
108}
109
110impl<Request, Response, S: Service<ServiceRequest>, ServiceRequest>
111 DuplexService<Request, Response, S, ServiceRequest>
112{
113 // This is a nice workaround for initiating an array with a non Copy type
114 // https://github.com/rust-lang/rust/issues/44796#issuecomment-967747810
115 const INIT_ARR: Option<oneshot::Sender<Response>> = None;
116
117 /// Create a new server instance, with an associated client handle. The server stops if all the
118 /// client handles are dropped. To start the server use the [`run`] or [`run_with`] methods.
119 ///
120 /// # Example
121 ///
122 /// ```
123 /// use core::task::{Context, Poll};
124 ///
125 /// use tower_duplex::DuplexService;
126 ///
127 /// /// A Service that converts requests to lower or upper case
128 /// enum ChangeCase {
129 /// ToLower,
130 /// ToUpper,
131 /// }
132 ///
133 /// impl tower::Service<String> for ChangeCase {
134 /// type Response = String;
135 /// type Error = ();
136 /// type Future = std::pin::Pin<
137 /// Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
138 /// >;
139 ///
140 /// fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
141 /// Poll::Ready(Ok(()))
142 /// }
143 ///
144 /// fn call(&mut self, req: String) -> Self::Future {
145 /// let to_upper = matches!(self, ChangeCase::ToUpper);
146 /// Box::pin(async move {
147 /// if to_upper {
148 /// Ok(req.to_uppercase())
149 /// } else {
150 /// Ok(req.to_lowercase())
151 /// }
152 /// })
153 /// }
154 /// }
155 ///
156 /// let (server, client): (DuplexService<String, String, _, _>, _) =
157 /// DuplexService::new_pair(ChangeCase::ToUpper);
158 /// ```
159 pub fn new_pair(service: S) -> (Self, DuplexClient<Request, Response>) {
160 let load = Arc::new(AtomicUsize::new(0));
161 let (calls_sender, calls) = mpsc::unbounded_channel();
162 (
163 DuplexService {
164 calls,
165 service,
166 load: load.clone(),
167 _p: Default::default(),
168 },
169 DuplexClient {
170 sender: calls_sender,
171 load,
172 },
173 )
174 }
175
176 /// Run the server loop with the provided [`TryStream`] and [`Sink`]. The server loop
177 /// serves remote RPC calls, and handles local RPC calls from client handles.
178 pub async fn run_with<Rcv, Snd, RcvErr, SndErr>(
179 mut self,
180 mut receiver: Rcv,
181 sender: Snd,
182 ) -> Result<(), DuplexError<RcvErr, SndErr>>
183 where
184 Rcv: TryStream<Ok = DuplexValue<ServiceRequest, Response>, Error = RcvErr> + Unpin,
185 Snd: Sink<
186 DuplexValue<Request, <<S as Service<ServiceRequest>>::Future as Future>::Output>,
187 Error = SndErr,
188 > + Unpin,
189 {
190 // A list of pending calls to the inner service
191 let mut pending_calls = FuturesUnordered::new();
192 // A list of pening RPC return channels
193 let mut pending_rpcs: [Option<oneshot::Sender<Response>>; 128] = [Self::INIT_ARR; 128];
194
195 let mut tagger = Tagger::new();
196
197 let mut sender = Some(sender);
198
199 // Send items that will be sent in the next send op
200 let mut pending_send = Vec::new();
201 // Send items that didn't get a tag and will be sent once there is room in the send queue
202 let mut sending_queue = VecDeque::new();
203 let mut send_fut = FuturesOrdered::new();
204
205 loop {
206 while !sending_queue.is_empty() && !tagger.full() {
207 // Move items to the from queue to pending
208 let tag = tagger.get_tag().expect("Tagger not full");
209 let (request, result_sender) = sending_queue.pop_front().expect("Queue not empty");
210 tracing::trace!("New RPC call {tag}");
211 pending_send.push(DuplexValue::Request(tag, request));
212 pending_rpcs[tag as usize] = Some(result_sender);
213 }
214
215 if sender.is_some() && !pending_send.is_empty() {
216 // We got something to send out, and we are not already sending anything
217 tracing::trace!("Flushing send buffer");
218 let to_send = pending_send.split_off(0);
219 send_fut.push(do_send(to_send, sender.take()));
220 }
221
222 let DuplexService {
223 service,
224 calls,
225 load,
226 ..
227 } = &mut self;
228 tokio::select! {
229 response = receiver.try_next() => {
230 match response {
231 Ok(Some(DuplexValue::Request(tag, req))) => {
232 tracing::trace!("New request {tag}");
233 // This is a request from a remote client (or server/client) to perform an RPC.
234 // Generate the future for the RPC and push it along with the tag to the list of executing calls.
235 let fut = service.call(req);
236 pending_calls.push(async move { (fut.await, tag) });
237 }
238 Ok(Some(DuplexValue::Response(tag, res))) => {
239 tracing::trace!("Response for {tag}");
240 // This is a response from the remote server (or server/client) to an RPC we initiated.
241 // Match the tag of the response to a channel on which to send the response.
242 match pending_rpcs.get_mut(tag as usize).and_then(Option::take) {
243 None => tracing::error!("No matching request for response {tag}"),
244 Some(chan) => if chan.send(res).is_err() {
245 tracing::debug!("Channel for response {tag} went away");
246 },
247 }
248 tagger.release_tag(tag);
249 load.fetch_sub(1, SeqCst);
250 }
251 Err(err) => return Err(DuplexError::RemoteError(err)),
252 Ok(None) => return Err(DuplexError::RemoteHangUp)
253 }
254 }
255 Some((request, result_sender)) = calls.recv() => {
256 // This is a request from our own client handle to perform an RPC call.
257 if let Some(tag) = tagger.get_tag() {
258 tracing::trace!("New RPC call {tag}");
259 pending_send.push(DuplexValue::Request(tag, request));
260 pending_rpcs[tag as usize] = Some(result_sender);
261 } else {
262 tracing::trace!("Queued RPC call");
263 sending_queue.push_back((request, result_sender));
264 }
265 load.fetch_add(1, SeqCst);
266 }
267 Some((result, tag)) = pending_calls.next() => {
268 // One of the executed calls is finished, send it back to the remote client (or server/client).
269 tracing::trace!("Call {tag} finished");
270 pending_send.push(DuplexValue::Response(tag, result));
271 }
272 Some(send_result) = send_fut.next() => {
273 match send_result {
274 Ok(enc) => {
275 // Return the codec to its place, so we can issue the next send operation
276 sender.replace(enc);
277 },
278 Err(err) => return Err(DuplexError::SendError(err)),
279 }
280 }
281 }
282 }
283 }
284}
285
286impl<Request, Response, S: Service<ServiceRequest>, ServiceRequest>
287 DuplexService<Request, Response, S, ServiceRequest>
288where
289 for<'de> ServiceRequest: Serialize + Deserialize<'de>,
290 for<'de> <<S as Service<ServiceRequest>>::Future as Future>::Output:
291 Serialize + Deserialize<'de>,
292 for<'de> Request: Serialize + Deserialize<'de>,
293 for<'de> Response: Serialize + Deserialize<'de>,
294{
295 /// Run the server loop with the provided [`AsyncRead`] and [`AsyncWrite`]. The server loop
296 /// serves remote RPC calls, and handles local RPC calls from client handles.
297 ///
298 /// # Example
299 ///
300 /// ```
301 /// use core::task::{Context, Poll};
302 ///
303 /// use tokio::sync::mpsc;
304 /// use tower::Service;
305 /// use tower_duplex::DuplexService;
306 ///
307 /// /// A Service that converts requests to lower or upper case
308 /// enum ChangeCase {
309 /// ToLower,
310 /// ToUpper,
311 /// }
312 ///
313 /// impl tower::Service<String> for ChangeCase {
314 /// type Response = String;
315 /// type Error = ();
316 /// type Future = std::pin::Pin<
317 /// Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
318 /// >;
319 ///
320 /// fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
321 /// Poll::Ready(Ok(()))
322 /// }
323 ///
324 /// fn call(&mut self, req: String) -> Self::Future {
325 /// let to_upper = matches!(self, ChangeCase::ToUpper);
326 /// Box::pin(async move {
327 /// if to_upper {
328 /// Ok(req.to_uppercase())
329 /// } else {
330 /// Ok(req.to_lowercase())
331 /// }
332 /// })
333 /// }
334 /// }
335 ///
336 /// #[tokio::main]
337 /// async fn main() {
338 /// // `server1` handles serves requests from `server2` and converts strings to upper case.
339 /// // It also forwards requests from `client1` to `server2`.
340 /// let (server1, mut client1): (DuplexService<String, Result<String, ()>, _, _>, _) =
341 /// DuplexService::new_pair(ChangeCase::ToUpper);
342 /// // `server2` handles serves requests from `server1` and converts strings to lower case.
343 /// // It also forwards requests from `client2` to `server1`.
344 /// let (server2, mut client2): (DuplexService<String, Result<String, ()>, _, _>, _) =
345 /// DuplexService::new_pair(ChangeCase::ToLower);
346 ///
347 /// let ((r1, w1), (r2, w2)) = tokio::net::UnixStream::pair()
348 /// .map(|(a, b)| (a.into_split(), b.into_split()))
349 /// .unwrap();
350 ///
351 /// tokio::spawn(server1.run(r1, w1));
352 /// tokio::spawn(server2.run(r2, w2));
353 ///
354 /// assert_eq!(
355 /// client2.call("String".to_string()).await.unwrap().unwrap(),
356 /// "STRING"
357 /// );
358 ///
359 /// assert_eq!(
360 /// client1.call("String".to_string()).await.unwrap().unwrap(),
361 /// "string"
362 /// );
363 /// }
364 /// ```
365 pub async fn run<R: AsyncRead + Unpin, W: AsyncWrite + Unpin>(
366 self,
367 reader: R,
368 writer: W,
369 ) -> Result<(), DuplexError<std::io::Error, std::io::Error>> {
370 let decoder: codec::FrameCodec<DuplexValue<ServiceRequest, Response>> = Default::default();
371 let encoder: codec::FrameCodec<
372 DuplexValue<Request, <<S as Service<ServiceRequest>>::Future as Future>::Output>,
373 > = Default::default();
374
375 let frame_reader = tokio_util::codec::FramedRead::new(reader, decoder);
376 let frame_writer = tokio_util::codec::FramedWrite::new(writer, encoder);
377
378 self.run_with(frame_reader, frame_writer).await
379 }
380}
381
382impl<Request, Response> Service<Request> for DuplexClient<Request, Response> {
383 type Response = Response;
384
385 type Error = oneshot::error::RecvError;
386
387 type Future = oneshot::Receiver<Response>;
388
389 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
390 Poll::Ready(Ok(()))
391 }
392
393 fn call(&mut self, req: Request) -> Self::Future {
394 let (reseponse_send, response_recv) = oneshot::channel();
395 // We ignore the send error here, because if send fails it just means the service has
396 // stopped, in which case our oneshot will immediately get dropped and an error returned
397 // from the future
398 let _ = self.sender.send((req, reseponse_send));
399 response_recv
400 }
401}
402
403impl<Request, Response> tower::load::Load for DuplexClient<Request, Response> {
404 type Metric = usize;
405
406 fn load(&self) -> Self::Metric {
407 self.load.load(SeqCst)
408 }
409}
410
411impl<Request, Response> DuplexClient<Request, Response> {
412 /// A call that does not require a mutable reference
413 pub fn do_call(
414 &self,
415 req: Request,
416 ) -> impl Future<Output = Result<Response, oneshot::error::RecvError>> {
417 let (reseponse_send, response_recv) = oneshot::channel();
418 // We ignore the send error here, because if send fails it just means the service has
419 // stopped, in which case our oneshot will immediately get dropped and an error returned
420 // from the future
421 let _ = self.sender.send((req, reseponse_send));
422 response_recv
423 }
424}