tower_duplex/
lib.rs

//! A [`tower::Service`] that implements a server and a client simultaneously over a
//! bi-directional channel. As a server it is able to process RPC calls from a remote client,
//! and as a client it is capable of making RPC calls into a remote server. It is very
//! convenient in a system that requires asynchronous communication in both directions.
5use std::collections::VecDeque;
6use std::future::Future;
7use std::marker::PhantomData;
8use std::sync::atomic::AtomicUsize;
9use std::sync::atomic::Ordering::SeqCst;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12
13use futures::future::pending;
14use futures::stream::{FuturesOrdered, FuturesUnordered};
15use futures::{Sink, SinkExt, StreamExt, TryStream, TryStreamExt};
16use serde::{Deserialize, Serialize};
17use tokio::io::{AsyncRead, AsyncWrite};
18use tokio::sync::{mpsc, oneshot};
19use tower::Service;
20
21mod codec;
22mod serialize;
23#[cfg(test)]
24mod test;
25
/// A wrapper for an RPC request or response sent over the duplex channel. Each value carries a
/// `u8` tag used to demultiplex responses and match them to the correct requests: a request is
/// tagged by the side that issues it, and the response to it echoes the same tag back.
#[derive(Debug)]
pub enum DuplexValue<Request, Response> {
    /// An RPC call issued by the sender of this value, as `(tag, request)`.
    Request(u8, Request),
    /// The result of an earlier `Request` carrying the same tag, as `(tag, response)`.
    Response(u8, Response),
}
33
/// A [`tower::Service`] that implements a server and a client simultaneously over a bi-directional
/// channel. As a server it is able to process RPC calls from a remote client, and as a client it is
/// capable of making RPC calls into a remote server. It is very convenient in a system that
/// requires asynchronous communication in both directions.
///
/// Created together with its [`DuplexClient`] handle by [`DuplexService::new_pair`].
pub struct DuplexService<Request, Response, S: Service<ServiceRequest>, ServiceRequest> {
    // Local calls submitted through the client handles, each paired with the channel on which
    // its response is delivered.
    calls: mpsc::UnboundedReceiver<(Request, oneshot::Sender<Response>)>,
    // The local service that handles requests arriving from the remote side.
    service: S,
    // In-flight counter shared with the client handles: the server loop increments it for each
    // local call it receives and decrements it when a response arrives.
    load: Arc<AtomicUsize>,
    // Ties the otherwise unused `ServiceRequest` type parameter to the struct.
    _p: PhantomData<ServiceRequest>,
}
44
45/// A client side handle to [`DuplexService`], used for RPC calls to the remote server.
46#[derive(Clone)]
47pub struct DuplexClient<Request, Response> {
48    sender: mpsc::UnboundedSender<(Request, oneshot::Sender<Response>)>,
49    load: Arc<AtomicUsize>,
50}
51
/// Reasons the server loop of a [`DuplexService`] stops.
///
/// `E1` is the error type of the incoming stream and `E2` the error type of the outgoing sink.
#[derive(Debug)]
pub enum DuplexError<E1, E2> {
    /// The incoming stream ended, i.e. the remote side closed the connection.
    RemoteHangUp,
    /// Receiving from the remote side failed.
    RemoteError(E1),
    /// Sending to the remote side failed.
    SendError(E2),
}

impl<E1: std::fmt::Display, E2: std::fmt::Display> std::fmt::Display for DuplexError<E1, E2> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            DuplexError::RemoteHangUp => f.write_str("remote side hung up"),
            DuplexError::RemoteError(err) => write!(f, "failed to receive from remote: {err}"),
            DuplexError::SendError(err) => write!(f, "failed to send to remote: {err}"),
        }
    }
}

impl<E1, E2> std::error::Error for DuplexError<E1, E2>
where
    E1: std::error::Error + 'static,
    E2: std::error::Error + 'static,
{
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            DuplexError::RemoteHangUp => None,
            DuplexError::RemoteError(err) => Some(err),
            DuplexError::SendError(err) => Some(err),
        }
    }
}
57
/// Tagger generates unique tag values for RPC calls, required to match requests to responses when
/// multiplexing requests. The used tags are tracked in a 128-bit bitmask, so at most 128 tags can
/// be outstanding at once, which is also the maximum number of concurrent RPC calls in flight.
struct Tagger {
    // Bit `n` is set while tag `n` is in use.
    bitmask: u128,
}

impl Tagger {
    fn new() -> Self {
        Tagger { bitmask: 0 }
    }

    /// Get the smallest available tag value, or `None` when all 128 tags are in use.
    fn get_tag(&mut self) -> Option<u8> {
        // The number of trailing ones is the index of the lowest unset bit. It is 128 only when
        // every tag is taken, in which case the checked shift fails and we return `None`.
        let tag = self.bitmask.trailing_ones() as u8;
        self.bitmask |= 1u128.checked_shl(tag.into())?;
        Some(tag)
    }

    /// Return the tag after we finished using it. It is very important to release tags, otherwise
    /// the system will run out of tags very quickly.
    ///
    /// Tags outside `0..128` can never have been handed out and are ignored. Tags are echoed back
    /// by the remote peer, and a plain `1 << tag` would panic on overflow in debug builds and, in
    /// release builds, silently free an unrelated in-use tag.
    fn release_tag(&mut self, tag: u8) {
        if let Some(bit) = 1u128.checked_shl(tag.into()) {
            self.bitmask &= !bit;
        }
    }

    /// Check if tagger can't allocate any more tags
    fn full(&self) -> bool {
        self.bitmask == u128::MAX
    }
}
90
91/// Feed the queued packets into our sink, then flush. This is implemented as a standalone function
92/// so we can use it in a select loop with FuturesOrdered (that implement StreamExt and therefore
93/// cancellation safe). Once done the function returns the owned sink, so more writes can be
94/// scheduled on it.
95async fn do_send<I, S: Sink<I> + Unpin>(items: Vec<I>, sender: Option<S>) -> Result<S, S::Error> {
96    let mut sender = match sender {
97        Some(sender) => sender,
98        None => pending().await,
99    };
100
101    for item in items {
102        sender.feed(item).await?;
103    }
104
105    sender.flush().await?;
106
107    Ok(sender)
108}
109
110impl<Request, Response, S: Service<ServiceRequest>, ServiceRequest>
111    DuplexService<Request, Response, S, ServiceRequest>
112{
113    // This is a nice workaround for initiating an array with a non Copy type
114    // https://github.com/rust-lang/rust/issues/44796#issuecomment-967747810
115    const INIT_ARR: Option<oneshot::Sender<Response>> = None;
116
117    /// Create a new server instance, with an associated client handle. The server stops if all the
118    /// client handles are dropped. To start the server use the [`run`] or [`run_with`] methods.
119    ///
120    /// # Example
121    ///
122    /// ```
123    /// use core::task::{Context, Poll};
124    ///
125    /// use tower_duplex::DuplexService;
126    ///
127    /// /// A Service that converts requests to lower or upper case
128    /// enum ChangeCase {
129    ///     ToLower,
130    ///     ToUpper,
131    /// }
132    ///
133    /// impl tower::Service<String> for ChangeCase {
134    ///     type Response = String;
135    ///     type Error = ();
136    ///     type Future = std::pin::Pin<
137    ///         Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
138    ///     >;
139    ///
140    ///     fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
141    ///         Poll::Ready(Ok(()))
142    ///     }
143    ///
144    ///     fn call(&mut self, req: String) -> Self::Future {
145    ///         let to_upper = matches!(self, ChangeCase::ToUpper);
146    ///         Box::pin(async move {
147    ///             if to_upper {
148    ///                 Ok(req.to_uppercase())
149    ///             } else {
150    ///                 Ok(req.to_lowercase())
151    ///             }
152    ///         })
153    ///     }
154    /// }
155    ///
156    /// let (server, client): (DuplexService<String, String, _, _>, _) =
157    ///     DuplexService::new_pair(ChangeCase::ToUpper);
158    /// ```
159    pub fn new_pair(service: S) -> (Self, DuplexClient<Request, Response>) {
160        let load = Arc::new(AtomicUsize::new(0));
161        let (calls_sender, calls) = mpsc::unbounded_channel();
162        (
163            DuplexService {
164                calls,
165                service,
166                load: load.clone(),
167                _p: Default::default(),
168            },
169            DuplexClient {
170                sender: calls_sender,
171                load,
172            },
173        )
174    }
175
176    /// Run the server loop with the provided [`TryStream`] and [`Sink`]. The server loop
177    /// serves remote RPC calls, and handles local RPC calls from client handles.
178    pub async fn run_with<Rcv, Snd, RcvErr, SndErr>(
179        mut self,
180        mut receiver: Rcv,
181        sender: Snd,
182    ) -> Result<(), DuplexError<RcvErr, SndErr>>
183    where
184        Rcv: TryStream<Ok = DuplexValue<ServiceRequest, Response>, Error = RcvErr> + Unpin,
185        Snd: Sink<
186                DuplexValue<Request, <<S as Service<ServiceRequest>>::Future as Future>::Output>,
187                Error = SndErr,
188            > + Unpin,
189    {
190        // A list of pending calls to the inner service
191        let mut pending_calls = FuturesUnordered::new();
192        // A list of pening RPC return channels
193        let mut pending_rpcs: [Option<oneshot::Sender<Response>>; 128] = [Self::INIT_ARR; 128];
194
195        let mut tagger = Tagger::new();
196
197        let mut sender = Some(sender);
198
199        // Send items that will be sent in the next send op
200        let mut pending_send = Vec::new();
201        // Send items that didn't get a tag and will be sent once there is room in the send queue
202        let mut sending_queue = VecDeque::new();
203        let mut send_fut = FuturesOrdered::new();
204
205        loop {
206            while !sending_queue.is_empty() && !tagger.full() {
207                // Move items to the from queue to pending
208                let tag = tagger.get_tag().expect("Tagger not full");
209                let (request, result_sender) = sending_queue.pop_front().expect("Queue not empty");
210                tracing::trace!("New RPC call {tag}");
211                pending_send.push(DuplexValue::Request(tag, request));
212                pending_rpcs[tag as usize] = Some(result_sender);
213            }
214
215            if sender.is_some() && !pending_send.is_empty() {
216                // We got something to send out, and we are not already sending anything
217                tracing::trace!("Flushing send buffer");
218                let to_send = pending_send.split_off(0);
219                send_fut.push(do_send(to_send, sender.take()));
220            }
221
222            let DuplexService {
223                service,
224                calls,
225                load,
226                ..
227            } = &mut self;
228            tokio::select! {
229                response = receiver.try_next() => {
230                    match response {
231                        Ok(Some(DuplexValue::Request(tag, req))) => {
232                            tracing::trace!("New request {tag}");
233                            // This is a request from a remote client (or server/client) to perform an RPC.
234                            // Generate the future for the RPC and push it along with the tag to the list of executing calls.
235                            let fut = service.call(req);
236                            pending_calls.push(async move { (fut.await, tag) });
237                        }
238                        Ok(Some(DuplexValue::Response(tag, res))) => {
239                            tracing::trace!("Response for {tag}");
240                            // This is a response from the remote server (or server/client) to an RPC we initiated.
241                            // Match the tag of the response to a channel on which to send the response.
242                            match pending_rpcs.get_mut(tag as usize).and_then(Option::take) {
243                                None => tracing::error!("No matching request for response {tag}"),
244                                Some(chan) => if chan.send(res).is_err() {
245                                    tracing::debug!("Channel for response {tag} went away");
246                                },
247                            }
248                            tagger.release_tag(tag);
249                            load.fetch_sub(1, SeqCst);
250                        }
251                        Err(err) => return Err(DuplexError::RemoteError(err)),
252                        Ok(None) => return Err(DuplexError::RemoteHangUp)
253                    }
254                }
255                Some((request, result_sender)) = calls.recv() => {
256                    // This is a request from our own client handle to perform an RPC call.
257                    if let Some(tag) = tagger.get_tag() {
258                        tracing::trace!("New RPC call {tag}");
259                        pending_send.push(DuplexValue::Request(tag, request));
260                        pending_rpcs[tag as usize] = Some(result_sender);
261                    } else {
262                        tracing::trace!("Queued RPC call");
263                        sending_queue.push_back((request, result_sender));
264                    }
265                    load.fetch_add(1, SeqCst);
266                }
267                Some((result, tag)) = pending_calls.next() => {
268                    // One of the executed calls is finished, send it back to the remote client (or server/client).
269                    tracing::trace!("Call {tag} finished");
270                    pending_send.push(DuplexValue::Response(tag, result));
271                }
272                Some(send_result) = send_fut.next() => {
273                    match send_result {
274                        Ok(enc) => {
275                            // Return the codec to its place, so we can issue the next send operation
276                            sender.replace(enc);
277                        },
278                        Err(err) => return Err(DuplexError::SendError(err)),
279                    }
280                }
281            }
282        }
283    }
284}
285
286impl<Request, Response, S: Service<ServiceRequest>, ServiceRequest>
287    DuplexService<Request, Response, S, ServiceRequest>
288where
289    for<'de> ServiceRequest: Serialize + Deserialize<'de>,
290    for<'de> <<S as Service<ServiceRequest>>::Future as Future>::Output:
291        Serialize + Deserialize<'de>,
292    for<'de> Request: Serialize + Deserialize<'de>,
293    for<'de> Response: Serialize + Deserialize<'de>,
294{
295    /// Run the server loop with the provided [`AsyncRead`] and [`AsyncWrite`]. The server loop
296    /// serves remote RPC calls, and handles local RPC calls from client handles.
297    ///
298    /// # Example
299    ///
300    /// ```
301    /// use core::task::{Context, Poll};
302    ///
303    /// use tokio::sync::mpsc;
304    /// use tower::Service;
305    /// use tower_duplex::DuplexService;
306    ///
307    /// /// A Service that converts requests to lower or upper case
308    /// enum ChangeCase {
309    ///     ToLower,
310    ///     ToUpper,
311    /// }
312    ///
313    /// impl tower::Service<String> for ChangeCase {
314    ///     type Response = String;
315    ///     type Error = ();
316    ///     type Future = std::pin::Pin<
317    ///         Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
318    ///     >;
319    ///
320    ///     fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
321    ///         Poll::Ready(Ok(()))
322    ///     }
323    ///
324    ///     fn call(&mut self, req: String) -> Self::Future {
325    ///         let to_upper = matches!(self, ChangeCase::ToUpper);
326    ///         Box::pin(async move {
327    ///             if to_upper {
328    ///                 Ok(req.to_uppercase())
329    ///             } else {
330    ///                 Ok(req.to_lowercase())
331    ///             }
332    ///         })
333    ///     }
334    /// }
335    ///
336    /// #[tokio::main]
337    /// async fn main() {
338    ///     // `server1` handles serves requests from `server2` and converts strings to upper case.
339    ///     // It also forwards requests from `client1` to `server2`.
340    ///     let (server1, mut client1): (DuplexService<String, Result<String, ()>, _, _>, _) =
341    ///         DuplexService::new_pair(ChangeCase::ToUpper);
342    ///     // `server2` handles serves requests from `server1` and converts strings to lower case.
343    ///     // It also forwards requests from `client2` to `server1`.
344    ///     let (server2, mut client2): (DuplexService<String, Result<String, ()>, _, _>, _) =
345    ///         DuplexService::new_pair(ChangeCase::ToLower);
346    ///
347    ///     let ((r1, w1), (r2, w2)) = tokio::net::UnixStream::pair()
348    ///         .map(|(a, b)| (a.into_split(), b.into_split()))
349    ///         .unwrap();
350    ///
351    ///     tokio::spawn(server1.run(r1, w1));
352    ///     tokio::spawn(server2.run(r2, w2));
353    ///
354    ///     assert_eq!(
355    ///         client2.call("String".to_string()).await.unwrap().unwrap(),
356    ///         "STRING"
357    ///     );
358    ///
359    ///     assert_eq!(
360    ///         client1.call("String".to_string()).await.unwrap().unwrap(),
361    ///         "string"
362    ///     );
363    /// }
364    /// ```
365    pub async fn run<R: AsyncRead + Unpin, W: AsyncWrite + Unpin>(
366        self,
367        reader: R,
368        writer: W,
369    ) -> Result<(), DuplexError<std::io::Error, std::io::Error>> {
370        let decoder: codec::FrameCodec<DuplexValue<ServiceRequest, Response>> = Default::default();
371        let encoder: codec::FrameCodec<
372            DuplexValue<Request, <<S as Service<ServiceRequest>>::Future as Future>::Output>,
373        > = Default::default();
374
375        let frame_reader = tokio_util::codec::FramedRead::new(reader, decoder);
376        let frame_writer = tokio_util::codec::FramedWrite::new(writer, encoder);
377
378        self.run_with(frame_reader, frame_writer).await
379    }
380}
381
382impl<Request, Response> Service<Request> for DuplexClient<Request, Response> {
383    type Response = Response;
384
385    type Error = oneshot::error::RecvError;
386
387    type Future = oneshot::Receiver<Response>;
388
389    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
390        Poll::Ready(Ok(()))
391    }
392
393    fn call(&mut self, req: Request) -> Self::Future {
394        let (reseponse_send, response_recv) = oneshot::channel();
395        // We ignore the send error here, because if send fails it just means the service has
396        // stopped, in which case our oneshot will immediately get dropped and an error returned
397        // from the future
398        let _ = self.sender.send((req, reseponse_send));
399        response_recv
400    }
401}
402
403impl<Request, Response> tower::load::Load for DuplexClient<Request, Response> {
404    type Metric = usize;
405
406    fn load(&self) -> Self::Metric {
407        self.load.load(SeqCst)
408    }
409}
410
411impl<Request, Response> DuplexClient<Request, Response> {
412    /// A call that does not require a mutable reference
413    pub fn do_call(
414        &self,
415        req: Request,
416    ) -> impl Future<Output = Result<Response, oneshot::error::RecvError>> {
417        let (reseponse_send, response_recv) = oneshot::channel();
418        // We ignore the send error here, because if send fails it just means the service has
419        // stopped, in which case our oneshot will immediately get dropped and an error returned
420        // from the future
421        let _ = self.sender.send((req, reseponse_send));
422        response_recv
423    }
424}