tarpc/
server.rs

1// Copyright 2018 Google LLC
2//
3// Use of this source code is governed by an MIT-style
4// license that can be found in the LICENSE file or at
5// https://opensource.org/licenses/MIT.
6
7//! Provides a server that concurrently handles many connections sending multiplexed requests.
8
9use crate::{
10    cancellations::{cancellations, CanceledRequests, RequestCancellation},
11    context::{self, SpanExt},
12    trace,
13    util::TimeUntil,
14    ChannelError, ClientMessage, Request, RequestName, Response, ServerError, Transport,
15};
16use ::tokio::sync::mpsc;
17use futures::{
18    future::{AbortRegistration, Abortable},
19    prelude::*,
20    ready,
21    stream::Fuse,
22    task::*,
23};
24use in_flight_requests::{AlreadyExistsError, InFlightRequests};
25use pin_project::pin_project;
26use std::{
27    convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin, sync::Arc, time::SystemTime,
28};
29use tracing::{info_span, instrument::Instrument, Span};
30
31mod in_flight_requests;
32pub mod request_hook;
33#[cfg(test)]
34mod testing;
35
36/// Provides functionality to apply server limits.
37pub mod limits;
38
39/// Provides helper methods for streams of Channels.
40pub mod incoming;
41
/// Settings that control the behavior of [channels](Channel).
#[derive(Clone, Debug)]
pub struct Config {
    /// Controls the buffer size of the in-process channel over which a server's handlers send
    /// responses to the [`Channel`]. In other words, this is the number of responses that can sit
    /// in the outbound queue before request handlers begin blocking.
    ///
    /// Must be greater than zero: [`Channel::requests`] uses this value as the capacity of a
    /// bounded `tokio::sync::mpsc` channel, and tokio panics when asked to create a bounded
    /// channel with a capacity of zero.
    ///
    /// Defaults to `100`.
    pub pending_response_buffer: usize,
}

impl Default for Config {
    /// Returns a configuration whose `pending_response_buffer` is `100`.
    fn default() -> Self {
        Config {
            pending_response_buffer: 100,
        }
    }
}
58
59impl Config {
60    /// Returns a channel backed by `transport` and configured with `self`.
61    pub fn channel<Req, Resp, T>(self, transport: T) -> BaseChannel<Req, Resp, T>
62    where
63        T: Transport<Response<Resp>, ClientMessage<Req>>,
64    {
65        BaseChannel::new(self, transport)
66    }
67}
68
69/// Equivalent to a `FnOnce(Req) -> impl Future<Output = Resp>`.
70#[allow(async_fn_in_trait)]
71pub trait Serve {
72    /// Type of request.
73    type Req: RequestName;
74
75    /// Type of response.
76    type Resp;
77
78    /// Responds to a single request.
79    async fn serve(self, ctx: context::Context, req: Self::Req) -> Result<Self::Resp, ServerError>;
80}
81
/// A [`Serve`] implementation backed by a closure; construct one with [`serve`].
pub struct ServeFn<Req, Resp, F> {
    f: F,
    data: PhantomData<fn(Req) -> Resp>,
}

// Hand-written rather than derived: `#[derive(Debug)]` would require `Req: Debug` and
// `Resp: Debug` even though no value of either type is stored (only a `PhantomData`).
// The output is identical to what the derive produced.
impl<Req, Resp, F> fmt::Debug for ServeFn<Req, Resp, F>
where
    F: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ServeFn")
            .field("f", &self.f)
            .field("data", &self.data)
            .finish()
    }
}

// Hand-written for the same reason: a derive would add spurious `Req: Clone`/`Resp: Clone`
// bounds.
impl<Req, Resp, F> Clone for ServeFn<Req, Resp, F>
where
    F: Clone,
{
    fn clone(&self) -> Self {
        Self {
            f: self.f.clone(),
            data: PhantomData,
        }
    }
}

impl<Req, Resp, F> Copy for ServeFn<Req, Resp, F> where F: Copy {}
102
103/// Creates a [`Serve`] wrapper around a `FnOnce(context::Context, Req) -> impl Future<Output =
104/// Result<Resp, ServerError>>`.
105pub fn serve<Req, Resp, Fut, F>(f: F) -> ServeFn<Req, Resp, F>
106where
107    F: FnOnce(context::Context, Req) -> Fut,
108    Fut: Future<Output = Result<Resp, ServerError>>,
109{
110    ServeFn {
111        f,
112        data: PhantomData,
113    }
114}
115
116impl<Req, Resp, Fut, F> Serve for ServeFn<Req, Resp, F>
117where
118    Req: RequestName,
119    F: FnOnce(context::Context, Req) -> Fut,
120    Fut: Future<Output = Result<Resp, ServerError>>,
121{
122    type Req = Req;
123    type Resp = Resp;
124
125    async fn serve(self, ctx: context::Context, req: Req) -> Result<Resp, ServerError> {
126        (self.f)(ctx, req).await
127    }
128}
129
130/// BaseChannel is the standard implementation of a [`Channel`].
131///
132/// BaseChannel manages a [`Transport`](Transport) of client [`messages`](ClientMessage) and
133/// implements a [`Stream`] of [requests](TrackedRequest). See the [`Channel`] documentation for
134/// how to use channels.
135///
136/// Besides requests, the other type of client message handled by `BaseChannel` is [cancellation
137/// messages](ClientMessage::Cancel). `BaseChannel` does not allow direct access to cancellation
138/// messages. Instead, it internally handles them by cancelling corresponding requests (removing
139/// the corresponding in-flight requests and aborting their handlers).
140#[pin_project]
141pub struct BaseChannel<Req, Resp, T> {
142    config: Config,
143    /// Writes responses to the wire and reads requests off the wire.
144    #[pin]
145    transport: Fuse<T>,
146    /// In-flight requests that were dropped by the server before completion.
147    #[pin]
148    canceled_requests: CanceledRequests,
149    /// Notifies `canceled_requests` when a request is canceled.
150    request_cancellation: RequestCancellation,
151    /// Holds data necessary to clean up in-flight requests.
152    in_flight_requests: InFlightRequests,
153    /// Types the request and response.
154    ghost: PhantomData<(fn() -> Req, fn(Resp))>,
155}
156
157impl<Req, Resp, T> BaseChannel<Req, Resp, T>
158where
159    T: Transport<Response<Resp>, ClientMessage<Req>>,
160{
161    /// Creates a new channel backed by `transport` and configured with `config`.
162    pub fn new(config: Config, transport: T) -> Self {
163        let (request_cancellation, canceled_requests) = cancellations();
164        BaseChannel {
165            config,
166            transport: transport.fuse(),
167            canceled_requests,
168            request_cancellation,
169            in_flight_requests: InFlightRequests::default(),
170            ghost: PhantomData,
171        }
172    }
173
174    /// Creates a new channel backed by `transport` and configured with the defaults.
175    pub fn with_defaults(transport: T) -> Self {
176        Self::new(Config::default(), transport)
177    }
178
179    /// Returns the inner transport over which messages are sent and received.
180    pub fn get_ref(&self) -> &T {
181        self.transport.get_ref()
182    }
183
184    /// Returns the inner transport over which messages are sent and received.
185    pub fn get_pin_ref(self: Pin<&mut Self>) -> Pin<&mut T> {
186        self.project().transport.get_pin_mut()
187    }
188
189    fn in_flight_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests {
190        self.as_mut().project().in_flight_requests
191    }
192
193    fn canceled_requests_pin_mut<'a>(
194        self: &'a mut Pin<&mut Self>,
195    ) -> Pin<&'a mut CanceledRequests> {
196        self.as_mut().project().canceled_requests
197    }
198
199    fn transport_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse<T>> {
200        self.as_mut().project().transport
201    }
202
203    fn start_request(
204        mut self: Pin<&mut Self>,
205        mut request: Request<Req>,
206    ) -> Result<TrackedRequest<Req>, AlreadyExistsError> {
207        let span = info_span!(
208            "RPC",
209            rpc.trace_id = %request.context.trace_id(),
210            rpc.deadline = %humantime::format_rfc3339(SystemTime::now() + request.context.deadline.time_until()),
211            otel.kind = "server",
212            otel.name = tracing::field::Empty,
213        );
214        span.set_context(&request.context);
215        request.context.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| {
216            tracing::trace!(
217                "OpenTelemetry subscriber not installed; making unsampled \
218                            child context."
219            );
220            request.context.trace_context.new_child()
221        });
222        let entered = span.enter();
223        tracing::info!("ReceiveRequest");
224        let start = self.in_flight_requests_mut().start_request(
225            request.id,
226            request.context.deadline,
227            span.clone(),
228        );
229        match start {
230            Ok(abort_registration) => {
231                drop(entered);
232                Ok(TrackedRequest {
233                    abort_registration,
234                    span,
235                    response_guard: ResponseGuard {
236                        request_id: request.id,
237                        request_cancellation: self.request_cancellation.clone(),
238                        cancel: false,
239                    },
240                    request,
241                })
242            }
243            Err(AlreadyExistsError) => {
244                tracing::trace!("DuplicateRequest");
245                Err(AlreadyExistsError)
246            }
247        }
248    }
249}
250
251impl<Req, Resp, T> fmt::Debug for BaseChannel<Req, Resp, T> {
252    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
253        write!(f, "BaseChannel")
254    }
255}
256
257/// A request tracked by a [`Channel`].
258#[derive(Debug)]
259pub struct TrackedRequest<Req> {
260    /// The request sent by the client.
261    pub request: Request<Req>,
262    /// A registration to abort a future when the [`Channel`] that produced this request stops
263    /// tracking it.
264    pub abort_registration: AbortRegistration,
265    /// A span representing the server processing of this request.
266    pub span: Span,
267    /// An inert response guard. Becomes active in an InFlightRequest.
268    pub response_guard: ResponseGuard,
269}
270
271/// The server end of an open connection with a client, receiving requests from, and sending
272/// responses to, the client. `Channel` is a [`Transport`] with request lifecycle management.
273///
274/// The ways to use a Channel, in order of simplest to most complex, is:
275/// 1. [`Channel::execute`] - Requires the `tokio1` feature. This method is best for those who
276///    do not have specific scheduling needs and whose services are `Send + 'static`.
277/// 2. [`Channel::requests`] - This method is best for those who need direct access to individual
278///    requests, or are not using `tokio`, or want control over [futures](Future) scheduling.
279///    [`Requests`] is a stream of [`InFlightRequests`](InFlightRequest), each which has an
280///    [`execute`](InFlightRequest::execute) method. If using `execute`, request processing will
281///    automatically cease when either the request deadline is reached or when a corresponding
282///    cancellation message is received by the Channel.
283/// 3. [`Stream::next`](futures::stream::StreamExt::next) /
284///    [`Sink::send`](futures::sink::SinkExt::send) - A user is free to manually read requests
285///    from, and send responses into, a Channel in lieu of the previous methods. Channels stream
286///    [`TrackedRequests`](TrackedRequest), which, in addition to the request itself, contains the
287///    server [`Span`], request lifetime [`AbortRegistration`], and an inert [`ResponseGuard`].
288///    Wrapping response logic in an [`Abortable`] future using the abort registration will ensure
289///    that the response does not execute longer than the request deadline. The `Channel` itself
290///    will clean up request state once either the deadline expires, or the response guard is
291///    dropped, or a response is sent.
292///
293/// Channels must be implemented using the decorator pattern: the only way to create a
294/// `TrackedRequest` is to get one from another `Channel`. Ultimately, all `TrackedRequests` are
295/// created by [`BaseChannel`].
296pub trait Channel
297where
298    Self: Transport<Response<<Self as Channel>::Resp>, TrackedRequest<<Self as Channel>::Req>>,
299{
300    /// Type of request item.
301    type Req;
302
303    /// Type of response sink item.
304    type Resp;
305
306    /// The wrapped transport.
307    type Transport;
308
309    /// Configuration of the channel.
310    fn config(&self) -> &Config;
311
312    /// Returns the number of in-flight requests over this channel.
313    fn in_flight_requests(&self) -> usize;
314
315    /// Returns the transport underlying the channel.
316    fn transport(&self) -> &Self::Transport;
317
318    /// Caps the number of concurrent requests to `limit`. An error will be returned for requests
319    /// over the concurrency limit.
320    ///
321    /// Note that this is a very
322    /// simplistic throttling heuristic. It is easy to set a number that is too low for the
323    /// resources available to the server. For production use cases, a more advanced throttler is
324    /// likely needed.
325    fn max_concurrent_requests(
326        self,
327        limit: usize,
328    ) -> limits::requests_per_channel::MaxRequests<Self>
329    where
330        Self: Sized,
331    {
332        limits::requests_per_channel::MaxRequests::new(self, limit)
333    }
334
335    /// Returns a stream of requests that automatically handle request cancellation and response
336    /// routing.
337    ///
338    /// This is a terminal operation. After calling `requests`, the channel cannot be retrieved,
339    /// and the only way to complete requests is via [`Requests::execute`] or
340    /// [`InFlightRequest::execute`].
341    ///
342    /// # Example
343    ///
344    /// ```rust
345    /// use tarpc::{
346    ///     context,
347    ///     client::{self, NewClient},
348    ///     server::{self, BaseChannel, Channel, serve},
349    ///     transport,
350    /// };
351    /// use futures::prelude::*;
352    ///
353    /// #[tokio::main]
354    /// async fn main() {
355    ///     let (tx, rx) = transport::channel::unbounded();
356    ///     let server = BaseChannel::new(server::Config::default(), rx);
357    ///     let NewClient { client, dispatch } = client::new(client::Config::default(), tx);
358    ///     tokio::spawn(dispatch);
359    ///
360    ///     let mut requests = server.requests();
361    ///     tokio::spawn(async move {
362    ///         while let Some(Ok(request)) = requests.next().await {
363    ///             tokio::spawn(request.execute(serve(|_, i| async move { Ok(i + 1) })));
364    ///         }
365    ///     });
366    ///     assert_eq!(client.call(context::current(), 1).await.unwrap(), 2);
367    /// }
368    /// ```
369    fn requests(self) -> Requests<Self>
370    where
371        Self: Sized,
372    {
373        let (responses_tx, responses) = mpsc::channel(self.config().pending_response_buffer);
374
375        Requests {
376            channel: self,
377            pending_responses: responses,
378            responses_tx,
379        }
380    }
381
382    /// Returns a stream of request execution futures. Each future represents an in-flight request
383    /// being responded to by the server. The futures must be awaited or spawned to complete their
384    /// requests.
385    ///
386    /// # Example
387    ///
388    /// ```rust
389    /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport};
390    /// use futures::prelude::*;
391    /// use tracing_subscriber::prelude::*;
392    ///
393    /// # #[cfg(not(feature = "tokio1"))]
394    /// # fn main() {}
395    /// # #[cfg(feature = "tokio1")]
396    /// #[tokio::main]
397    /// async fn main() {
398    ///     let (tx, rx) = transport::channel::unbounded();
399    ///     let client = client::new(client::Config::default(), tx).spawn();
400    ///     let channel = BaseChannel::with_defaults(rx);
401    ///     tokio::spawn(
402    ///         channel.execute(serve(|_, i: i32| async move { Ok(i + 1) }))
403    ///            .for_each(|response| async move {
404    ///                tokio::spawn(response);
405    ///            }));
406    ///     assert_eq!(
407    ///         client.call(context::current(), 1).await.unwrap(),
408    ///         2);
409    /// }
410    /// ```
411    fn execute<S>(self, serve: S) -> impl Stream<Item = impl Future<Output = ()>>
412    where
413        Self: Sized,
414        Self::Req: RequestName,
415        S: Serve<Req = Self::Req, Resp = Self::Resp> + Clone,
416    {
417        self.requests().execute(serve)
418    }
419}
420
421impl<Req, Resp, T> Stream for BaseChannel<Req, Resp, T>
422where
423    T: Transport<Response<Resp>, ClientMessage<Req>>,
424{
425    type Item = Result<TrackedRequest<Req>, ChannelError<T::Error>>;
426
427    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
428        #[derive(Clone, Copy, Debug)]
429        enum ReceiverStatus {
430            Ready,
431            Pending,
432            Closed,
433        }
434
435        impl ReceiverStatus {
436            fn combine(self, other: Self) -> Self {
437                use ReceiverStatus::*;
438                match (self, other) {
439                    (Ready, _) | (_, Ready) => Ready,
440                    (Closed, Closed) => Closed,
441                    (Pending, Closed) | (Closed, Pending) | (Pending, Pending) => Pending,
442                }
443            }
444        }
445
446        use ReceiverStatus::*;
447
448        loop {
449            let cancellation_status = match self.canceled_requests_pin_mut().poll_recv(cx) {
450                Poll::Ready(Some(request_id)) => {
451                    if let Some(span) = self.in_flight_requests_mut().remove_request(request_id) {
452                        let _entered = span.enter();
453                        tracing::info!("ResponseCancelled");
454                    }
455                    Ready
456                }
457                // Pending cancellations don't block Channel closure, because all they do is ensure
458                // the Channel's internal state is cleaned up. But Channel closure also cleans up
459                // the Channel state, so there's no reason to wait on a cancellation before
460                // closing.
461                //
462                // Ready(None) can't happen, since `self` holds a Cancellation.
463                Poll::Pending | Poll::Ready(None) => Closed,
464            };
465
466            let expiration_status = match self.in_flight_requests_mut().poll_expired(cx) {
467                // No need to send a response, since the client wouldn't be waiting for one
468                // anymore.
469                Poll::Ready(Some(_)) => Ready,
470                Poll::Ready(None) => Closed,
471                Poll::Pending => Pending,
472            };
473
474            let request_status = match self
475                .transport_pin_mut()
476                .poll_next(cx)
477                .map_err(|e| ChannelError::Read(Arc::new(e)))?
478            {
479                Poll::Ready(Some(message)) => match message {
480                    ClientMessage::Request(request) => {
481                        match self.as_mut().start_request(request) {
482                            Ok(request) => return Poll::Ready(Some(Ok(request))),
483                            Err(AlreadyExistsError) => {
484                                // Instead of closing the channel if a duplicate request is sent,
485                                // just ignore it, since it's already being processed. Note that we
486                                // cannot return Poll::Pending here, since nothing has scheduled a
487                                // wakeup yet.
488                                continue;
489                            }
490                        }
491                    }
492                    ClientMessage::Cancel {
493                        trace_context,
494                        request_id,
495                    } => {
496                        if !self.in_flight_requests_mut().cancel_request(request_id) {
497                            tracing::trace!(
498                                rpc.trace_id = %trace_context.trace_id,
499                                "Received cancellation, but response handler is already complete.",
500                            );
501                        }
502                        Ready
503                    }
504                },
505                Poll::Ready(None) => Closed,
506                Poll::Pending => Pending,
507            };
508
509            let status = cancellation_status
510                .combine(expiration_status)
511                .combine(request_status);
512
513            tracing::trace!(
514                "Cancellations: {cancellation_status:?}, \
515                Expired requests: {expiration_status:?}, \
516                Inbound: {request_status:?}, \
517                Overall: {status:?}",
518            );
519            match status {
520                Ready => continue,
521                Closed => return Poll::Ready(None),
522                Pending => return Poll::Pending,
523            }
524        }
525    }
526}
527
528impl<Req, Resp, T> Sink<Response<Resp>> for BaseChannel<Req, Resp, T>
529where
530    T: Transport<Response<Resp>, ClientMessage<Req>>,
531    T::Error: Error,
532{
533    type Error = ChannelError<T::Error>;
534
535    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
536        self.project()
537            .transport
538            .poll_ready(cx)
539            .map_err(|e| ChannelError::Ready(Arc::new(e)))
540    }
541
542    fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> Result<(), Self::Error> {
543        if let Some(span) = self
544            .in_flight_requests_mut()
545            .remove_request(response.request_id)
546        {
547            let _entered = span.enter();
548            tracing::info!("SendResponse");
549            self.project()
550                .transport
551                .start_send(response)
552                .map_err(|e| ChannelError::Write(Arc::new(e)))
553        } else {
554            // If the request isn't tracked anymore, there's no need to send the response.
555            Ok(())
556        }
557    }
558
559    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
560        tracing::trace!("poll_flush");
561        self.project()
562            .transport
563            .poll_flush(cx)
564            .map_err(|e| ChannelError::Flush(Arc::new(e)))
565    }
566
567    fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
568        self.project()
569            .transport
570            .poll_close(cx)
571            .map_err(|e| ChannelError::Close(Arc::new(e)))
572    }
573}
574
575impl<Req, Resp, T> AsRef<T> for BaseChannel<Req, Resp, T> {
576    fn as_ref(&self) -> &T {
577        self.transport.get_ref()
578    }
579}
580
581impl<Req, Resp, T> Channel for BaseChannel<Req, Resp, T>
582where
583    T: Transport<Response<Resp>, ClientMessage<Req>>,
584{
585    type Req = Req;
586    type Resp = Resp;
587    type Transport = T;
588
589    fn config(&self) -> &Config {
590        &self.config
591    }
592
593    fn in_flight_requests(&self) -> usize {
594        self.in_flight_requests.len()
595    }
596
597    fn transport(&self) -> &Self::Transport {
598        self.get_ref()
599    }
600}
601
602/// A stream of requests coming over a channel. `Requests` also drives the sending of responses, so
603/// it must be continually polled to ensure progress.
604#[pin_project]
605pub struct Requests<C>
606where
607    C: Channel,
608{
609    #[pin]
610    channel: C,
611    /// Responses waiting to be written to the wire.
612    pending_responses: mpsc::Receiver<Response<C::Resp>>,
613    /// Handed out to request handlers to fan in responses.
614    responses_tx: mpsc::Sender<Response<C::Resp>>,
615}
616
617impl<C> Requests<C>
618where
619    C: Channel,
620{
621    /// Returns a reference to the inner channel over which messages are sent and received.
622    pub fn channel(&self) -> &C {
623        &self.channel
624    }
625
626    /// Returns the inner channel over which messages are sent and received.
627    pub fn channel_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut C> {
628        self.as_mut().project().channel
629    }
630
631    /// Returns the inner channel over which messages are sent and received.
632    pub fn pending_responses_mut<'a>(
633        self: &'a mut Pin<&mut Self>,
634    ) -> &'a mut mpsc::Receiver<Response<C::Resp>> {
635        self.as_mut().project().pending_responses
636    }
637
638    fn pump_read(
639        mut self: Pin<&mut Self>,
640        cx: &mut Context<'_>,
641    ) -> Poll<Option<Result<InFlightRequest<C::Req, C::Resp>, C::Error>>> {
642        self.channel_pin_mut().poll_next(cx).map_ok(
643            |TrackedRequest {
644                 request,
645                 abort_registration,
646                 span,
647                 mut response_guard,
648             }| {
649                // The response guard becomes active once in an InFlightRequest.
650                response_guard.cancel = true;
651                {
652                    let _entered = span.enter();
653                    tracing::info!("BeginRequest");
654                }
655                InFlightRequest {
656                    request,
657                    abort_registration,
658                    span,
659                    response_guard,
660                    response_tx: self.responses_tx.clone(),
661                }
662            },
663        )
664    }
665
666    fn pump_write(
667        mut self: Pin<&mut Self>,
668        cx: &mut Context<'_>,
669        read_half_closed: bool,
670    ) -> Poll<Option<Result<(), C::Error>>> {
671        match self.as_mut().poll_next_response(cx)? {
672            Poll::Ready(Some(response)) => {
673                // A Ready result from poll_next_response means the Channel is ready to be written
674                // to. Therefore, we can call start_send without worry of a full buffer.
675                self.channel_pin_mut().start_send(response)?;
676                Poll::Ready(Some(Ok(())))
677            }
678            Poll::Ready(None) => {
679                // Shutdown can't be done before we finish pumping out remaining responses.
680                ready!(self.channel_pin_mut().poll_flush(cx)?);
681                Poll::Ready(None)
682            }
683            Poll::Pending => {
684                // No more requests to process, so flush any requests buffered in the transport.
685                ready!(self.channel_pin_mut().poll_flush(cx)?);
686
687                // Being here means there are no staged requests and all written responses are
688                // fully flushed. So, if the read half is closed and there are no in-flight
689                // requests, then we can close the write half.
690                if read_half_closed && self.channel.in_flight_requests() == 0 {
691                    Poll::Ready(None)
692                } else {
693                    Poll::Pending
694                }
695            }
696        }
697    }
698
699    /// Yields a response ready to be written to the Channel sink.
700    ///
701    /// Note that a response will only be yielded if the Channel is *ready* to be written to (i.e.
702    /// start_send would succeed).
703    fn poll_next_response(
704        mut self: Pin<&mut Self>,
705        cx: &mut Context<'_>,
706    ) -> Poll<Option<Result<Response<C::Resp>, C::Error>>> {
707        ready!(self.ensure_writeable(cx)?);
708
709        match ready!(self.pending_responses_mut().poll_recv(cx)) {
710            Some(response) => Poll::Ready(Some(Ok(response))),
711            None => {
712                // This branch likely won't happen, since the Requests stream is holding a Sender.
713                Poll::Ready(None)
714            }
715        }
716    }
717
718    /// Returns Ready if writing a message to the Channel would not fail due to a full buffer. If
719    /// the Channel is not ready to be written to, flushes it until it is ready.
720    fn ensure_writeable<'a>(
721        self: &'a mut Pin<&mut Self>,
722        cx: &mut Context<'_>,
723    ) -> Poll<Option<Result<(), C::Error>>> {
724        while self.channel_pin_mut().poll_ready(cx)?.is_pending() {
725            ready!(self.channel_pin_mut().poll_flush(cx)?);
726        }
727        Poll::Ready(Some(Ok(())))
728    }
729
730    /// Returns a stream of request execution futures. Each future represents an in-flight request
731    /// being responded to by the server. The futures must be awaited or spawned to complete their
732    /// requests.
733    ///
734    /// If the channel encounters an error, the stream is terminated and the error is logged.
735    ///
736    /// # Example
737    ///
738    /// ```rust
739    /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport};
740    /// use futures::prelude::*;
741    ///
742    /// # #[cfg(not(feature = "tokio1"))]
743    /// # fn main() {}
744    /// # #[cfg(feature = "tokio1")]
745    /// #[tokio::main]
746    /// async fn main() {
747    ///     let (tx, rx) = transport::channel::unbounded();
748    ///     let requests = BaseChannel::new(server::Config::default(), rx).requests();
749    ///     let client = client::new(client::Config::default(), tx).spawn();
750    ///     tokio::spawn(
751    ///         requests.execute(serve(|_, i| async move { Ok(i + 1) }))
752    ///            .for_each(|response| async move {
753    ///                tokio::spawn(response);
754    ///            }));
755    ///     assert_eq!(client.call(context::current(), 1).await.unwrap(), 2);
756    /// }
757    /// ```
758    pub fn execute<S>(self, serve: S) -> impl Stream<Item = impl Future<Output = ()>>
759    where
760        C::Req: RequestName,
761        S: Serve<Req = C::Req, Resp = C::Resp> + Clone,
762    {
763        self.take_while(|result| {
764            if let Err(e) = result {
765                tracing::warn!("Requests stream errored out: {}", e);
766            }
767            futures::future::ready(result.is_ok())
768        })
769        .filter_map(|result| async move { result.ok() })
770        .map(move |request| {
771            let serve = serve.clone();
772            request.execute(serve)
773        })
774    }
775}
776
777impl<C> fmt::Debug for Requests<C>
778where
779    C: Channel,
780{
781    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
782        write!(fmt, "Requests")
783    }
784}
785
786/// A fail-safe to ensure requests are properly canceled if request processing is aborted before
787/// completing.
788#[derive(Debug)]
789pub struct ResponseGuard {
790    request_cancellation: RequestCancellation,
791    request_id: u64,
792    cancel: bool,
793}
794
795impl Drop for ResponseGuard {
796    fn drop(&mut self) {
797        if self.cancel {
798            self.request_cancellation.cancel(self.request_id);
799        }
800    }
801}
802
803/// A request produced by [Channel::requests].
804///
805/// If dropped without calling [`execute`](InFlightRequest::execute), a cancellation message will
806/// be sent to the Channel to clean up associated request state.
807#[derive(Debug)]
808pub struct InFlightRequest<Req, Res> {
809    request: Request<Req>,
810    abort_registration: AbortRegistration,
811    response_guard: ResponseGuard,
812    span: Span,
813    response_tx: mpsc::Sender<Response<Res>>,
814}
815
816impl<Req, Res> InFlightRequest<Req, Res> {
817    /// Returns a reference to the request.
818    pub fn get(&self) -> &Request<Req> {
819        &self.request
820    }
821
822    /// Returns a [future](Future) that executes the request using the given [service
823    /// function](Serve). The service function's output is automatically sent back to the [Channel]
824    /// that yielded this request. The request will be executed in the scope of this request's
825    /// context.
826    ///
827    /// The returned future will stop executing when the first of the following conditions is met:
828    ///
829    /// 1. The channel that yielded this request receives a [cancellation
830    ///    message](ClientMessage::Cancel) for this request.
831    /// 2. The request [deadline](crate::context::Context::deadline) is reached.
832    /// 3. The service function completes.
833    ///
834    /// If the returned Future is dropped before completion, a cancellation message will be sent to
835    /// the Channel to clean up associated request state.
836    ///
837    /// # Example
838    ///
839    /// ```rust
840    /// use tarpc::{
841    ///     context,
842    ///     client::{self, NewClient},
843    ///     server::{self, BaseChannel, Channel, serve},
844    ///     transport,
845    /// };
846    /// use futures::prelude::*;
847    ///
848    /// #[tokio::main]
849    /// async fn main() {
850    ///     let (tx, rx) = transport::channel::unbounded();
851    ///     let server = BaseChannel::new(server::Config::default(), rx);
852    ///     let NewClient { client, dispatch } = client::new(client::Config::default(), tx);
853    ///     tokio::spawn(dispatch);
854    ///
855    ///     tokio::spawn(async move {
856    ///         let mut requests = server.requests();
857    ///         while let Some(Ok(in_flight_request)) = requests.next().await {
858    ///             in_flight_request.execute(serve(|_, i| async move { Ok(i + 1) })).await;
859    ///         }
860    ///
861    ///     });
862    ///     assert_eq!(client.call(context::current(), 1).await.unwrap(), 2);
863    /// }
864    /// ```
865    ///
866    pub async fn execute<S>(self, serve: S)
867    where
868        Req: RequestName,
869        S: Serve<Req = Req, Resp = Res>,
870    {
871        let Self {
872            response_tx,
873            mut response_guard,
874            abort_registration,
875            span,
876            request:
877                Request {
878                    context,
879                    message,
880                    id: request_id,
881                },
882        } = self;
883        span.record("otel.name", message.name());
884        let _ = Abortable::new(
885            async move {
886                let message = serve.serve(context, message).await;
887                tracing::info!("CompleteRequest");
888                let response = Response {
889                    request_id,
890                    message,
891                };
892                let _ = response_tx.send(response).await;
893                tracing::info!("BufferResponse");
894            },
895            abort_registration,
896        )
897        .instrument(span)
898        .await;
899        // Request processing has completed, meaning either the channel canceled the request or
900        // a request was sent back to the channel. Either way, the channel will clean up the
901        // request data, so the request does not need to be canceled.
902        response_guard.cancel = false;
903    }
904}
905
/// Renders `e` followed by every error in its `source()` chain, outermost first, joined by
/// `": "` into a single line suitable for logging.
fn print_err(e: &(dyn Error + 'static)) -> String {
    let mut rendered = String::new();
    for (depth, cause) in std::iter::successors(Some(e), |&err| err.source()).enumerate() {
        if depth > 0 {
            rendered.push_str(": ");
        }
        rendered.push_str(&cause.to_string());
    }
    rendered
}
912
impl<C> Stream for Requests<C>
where
    C: Channel,
{
    /// Each item is either a newly read request, ready to be executed, or a channel error
    /// encountered while reading requests or writing responses.
    type Item = Result<InFlightRequest<C::Req, C::Resp>, C::Error>;

    /// Drives both directions of the channel on every poll: reads incoming requests and writes
    /// buffered responses. Errors from either side are logged at trace level and yielded as
    /// `Ready(Some(Err(_)))`.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            // `?` on `Poll<Option<Result<..>>>` returns `Ready(Some(Err(e)))` from this function.
            let read = self.as_mut().pump_read(cx).map_err(|e| {
                tracing::trace!("read: {}", print_err(&e));
                e
            })?;
            // Tell the writer whether the read side is exhausted.
            let read_closed = matches!(read, Poll::Ready(None));
            let write = self.as_mut().pump_write(cx, read_closed).map_err(|e| {
                tracing::trace!("write: {}", print_err(&e));
                e
            })?;
            // Arm order matters: a ready request is yielded even if the write side also made
            // progress, and write progress alone loops back to poll both sides again.
            match (read, write) {
                // Both sides are finished: the stream ends.
                (Poll::Ready(None), Poll::Ready(None)) => {
                    tracing::trace!("read: Poll::Ready(None), write: Poll::Ready(None)");
                    return Poll::Ready(None);
                }
                // A new request is available; hand it to the caller.
                (Poll::Ready(Some(request_handler)), _) => {
                    tracing::trace!("read: Poll::Ready(Some), write: _");
                    return Poll::Ready(Some(Ok(request_handler)));
                }
                // A response was written; keep looping to make further progress.
                (_, Poll::Ready(Some(()))) => {
                    tracing::trace!("read: _, write: Poll::Ready(Some)");
                }
                // Every remaining combination has at least one side pending and nothing to
                // yield. NOTE(review): presumably the pending pump(s) registered `cx`'s waker.
                (read @ Poll::Pending, write) | (read, write @ Poll::Pending) => {
                    tracing::trace!(
                        "read pending: {}, write pending: {}",
                        read.is_pending(),
                        write.is_pending()
                    );
                    return Poll::Pending;
                }
            }
        }
    }
}
954
#[cfg(test)]
mod tests {
    use super::{
        in_flight_requests::AlreadyExistsError,
        request_hook::{AfterRequest, BeforeRequest, RequestHook},
        serve, BaseChannel, Channel, Config, Requests, Serve,
    };
    use crate::{
        context, trace,
        transport::channel::{self, UnboundedChannel},
        ClientMessage, Request, Response, ServerError,
    };
    use assert_matches::assert_matches;
    use futures::{
        future::{pending, AbortRegistration, Abortable, Aborted},
        prelude::*,
        Future,
    };
    use futures_test::task::noop_context;
    use std::{
        io,
        pin::Pin,
        task::Poll,
        time::{Duration, Instant},
    };

    // Returns a pinned `BaseChannel` over an unbounded in-memory transport, plus the client-side
    // half of that transport, which tests use to feed messages to the channel.
    fn test_channel<Req, Resp>() -> (
        Pin<Box<BaseChannel<Req, Resp, UnboundedChannel<ClientMessage<Req>, Response<Resp>>>>>,
        UnboundedChannel<Response<Resp>, ClientMessage<Req>>,
    ) {
        let (tx, rx) = crate::transport::channel::unbounded();
        (Box::pin(BaseChannel::new(Config::default(), rx)), tx)
    }

    // Like `test_channel`, but wraps the channel in its `Requests` stream.
    fn test_requests<Req, Resp>() -> (
        Pin<
            Box<
                Requests<
                    BaseChannel<Req, Resp, UnboundedChannel<ClientMessage<Req>, Response<Resp>>>,
                >,
            >,
        >,
        UnboundedChannel<Response<Resp>, ClientMessage<Req>>,
    ) {
        let (tx, rx) = crate::transport::channel::unbounded();
        (
            Box::pin(BaseChannel::new(Config::default(), rx).requests()),
            tx,
        )
    }

    // Like `test_requests`, but the transport is bounded to `capacity` messages so tests can
    // fill its buffer.
    fn test_bounded_requests<Req, Resp>(
        capacity: usize,
    ) -> (
        Pin<
            Box<
                Requests<
                    BaseChannel<Req, Resp, channel::Channel<ClientMessage<Req>, Response<Resp>>>,
                >,
            >,
        >,
        channel::Channel<Response<Resp>, ClientMessage<Req>>,
    ) {
        let (tx, rx) = crate::transport::channel::bounded(capacity);
        // Add 1 because capacity 0 is not supported (but is supported by transport::channel::bounded).
        let config = Config {
            pending_response_buffer: capacity + 1,
        };
        (Box::pin(BaseChannel::new(config, rx).requests()), tx)
    }

    // Wraps `req` in a client request message with id 0 and the current context.
    fn fake_request<Req>(req: Req) -> ClientMessage<Req> {
        ClientMessage::Request(Request {
            context: context::current(),
            id: 0,
            message: req,
        })
    }

    // Returns a future that never completes on its own, so it resolves only with
    // `Err(Aborted)` once the abort handle paired with `abort_registration` fires.
    fn test_abortable(
        abort_registration: AbortRegistration,
    ) -> impl Future<Output = Result<(), Aborted>> {
        Abortable::new(pending(), abort_registration)
    }

    #[tokio::test]
    async fn test_serve() {
        let serve = serve(|_, i| async move { Ok(i) });
        assert_matches!(serve.serve(context::current(), 7).await, Ok(7));
    }

    #[tokio::test]
    async fn serve_before_mutates_context() -> anyhow::Result<()> {
        struct SetDeadline(Instant);
        impl<Req> BeforeRequest<Req> for SetDeadline {
            async fn before(
                &mut self,
                ctx: &mut context::Context,
                _: &Req,
            ) -> Result<(), ServerError> {
                ctx.deadline = self.0;
                Ok(())
            }
        }

        let some_time = Instant::now() + Duration::from_secs(37);
        let some_other_time = Instant::now() + Duration::from_secs(83);

        let serve = serve(move |ctx: context::Context, i| async move {
            assert_eq!(ctx.deadline, some_time);
            Ok(i)
        });
        let deadline_hook = serve.before(SetDeadline(some_time));
        let mut ctx = context::current();
        ctx.deadline = some_other_time;
        deadline_hook.serve(ctx, 7).await?;
        Ok(())
    }

    #[tokio::test]
    async fn serve_before_and_after() -> anyhow::Result<()> {
        let _ = tracing_subscriber::fmt::try_init();

        struct PrintLatency {
            start: Instant,
        }
        impl PrintLatency {
            fn new() -> Self {
                Self {
                    start: Instant::now(),
                }
            }
        }
        impl<Req> BeforeRequest<Req> for PrintLatency {
            async fn before(
                &mut self,
                _: &mut context::Context,
                _: &Req,
            ) -> Result<(), ServerError> {
                self.start = Instant::now();
                Ok(())
            }
        }
        impl<Resp> AfterRequest<Resp> for PrintLatency {
            async fn after(&mut self, _: &mut context::Context, _: &mut Result<Resp, ServerError>) {
                tracing::info!("Elapsed: {:?}", self.start.elapsed());
            }
        }

        let serve = serve(move |_: context::Context, i| async move { Ok(i) });
        serve
            .before_and_after(PrintLatency::new())
            .serve(context::current(), 7)
            .await?;
        Ok(())
    }

    // A failing `before` hook must short-circuit: the wrapped service (which panics) never runs.
    #[tokio::test]
    async fn serve_before_error_aborts_request() -> anyhow::Result<()> {
        let serve = serve(|_, _| async { panic!("Shouldn't get here") });
        let deadline_hook = serve.before(|_: &mut context::Context, _: &i32| async {
            Err(ServerError::new(io::ErrorKind::Other, "oops".into()))
        });
        let resp: Result<i32, _> = deadline_hook.serve(context::current(), 7).await;
        assert_matches!(resp, Err(_));
        Ok(())
    }

    #[tokio::test]
    async fn base_channel_start_send_duplicate_request_returns_error() {
        let (mut channel, _tx) = test_channel::<(), ()>();

        channel
            .as_mut()
            .start_request(Request {
                id: 0,
                context: context::current(),
                message: (),
            })
            .unwrap();
        assert_matches!(
            channel.as_mut().start_request(Request {
                id: 0,
                context: context::current(),
                message: ()
            }),
            Err(AlreadyExistsError)
        );
    }

    // Advancing the paused clock by 1000s pushes both requests past their deadlines, so a single
    // poll of the channel aborts both of them.
    #[tokio::test]
    async fn base_channel_poll_next_aborts_multiple_requests() {
        let (mut channel, _tx) = test_channel::<(), ()>();

        tokio::time::pause();
        let req0 = channel
            .as_mut()
            .start_request(Request {
                id: 0,
                context: context::current(),
                message: (),
            })
            .unwrap();
        let req1 = channel
            .as_mut()
            .start_request(Request {
                id: 1,
                context: context::current(),
                message: (),
            })
            .unwrap();
        tokio::time::advance(std::time::Duration::from_secs(1000)).await;

        assert_matches!(
            channel.as_mut().poll_next(&mut noop_context()),
            Poll::Pending
        );
        assert_matches!(test_abortable(req0.abort_registration).await, Err(Aborted));
        assert_matches!(test_abortable(req1.abort_registration).await, Err(Aborted));
    }

    #[tokio::test]
    async fn base_channel_poll_next_aborts_canceled_request() {
        let (mut channel, mut tx) = test_channel::<(), ()>();

        tokio::time::pause();
        let req = channel
            .as_mut()
            .start_request(Request {
                id: 0,
                context: context::current(),
                message: (),
            })
            .unwrap();

        tx.send(ClientMessage::Cancel {
            trace_context: trace::Context::default(),
            request_id: 0,
        })
        .await
        .unwrap();

        assert_matches!(
            channel.as_mut().poll_next(&mut noop_context()),
            Poll::Pending
        );

        assert_matches!(test_abortable(req.abort_registration).await, Err(Aborted));
    }

    // A closed transport does not end the channel while a request is still in flight.
    #[tokio::test]
    async fn base_channel_with_closed_transport_and_in_flight_request_returns_pending() {
        let (mut channel, tx) = test_channel::<(), ()>();

        tokio::time::pause();
        let _abort_registration = channel
            .as_mut()
            .start_request(Request {
                id: 0,
                context: context::current(),
                message: (),
            })
            .unwrap();

        drop(tx);
        assert_matches!(
            channel.as_mut().poll_next(&mut noop_context()),
            Poll::Pending
        );
    }

    #[tokio::test]
    async fn base_channel_with_closed_transport_and_no_in_flight_requests_returns_closed() {
        let (mut channel, tx) = test_channel::<(), ()>();
        drop(tx);
        assert_matches!(
            channel.as_mut().poll_next(&mut noop_context()),
            Poll::Ready(None)
        );
    }

    #[tokio::test]
    async fn base_channel_poll_next_yields_request() {
        let (mut channel, mut tx) = test_channel::<(), ()>();
        tx.send(fake_request(())).await.unwrap();

        assert_matches!(
            channel.as_mut().poll_next(&mut noop_context()),
            Poll::Ready(Some(Ok(_)))
        );
    }

    #[tokio::test]
    async fn base_channel_poll_next_aborts_request_and_yields_request() {
        let (mut channel, mut tx) = test_channel::<(), ()>();

        tokio::time::pause();
        let req = channel
            .as_mut()
            .start_request(Request {
                id: 0,
                context: context::current(),
                message: (),
            })
            .unwrap();
        tokio::time::advance(std::time::Duration::from_secs(1000)).await;

        tx.send(fake_request(())).await.unwrap();

        assert_matches!(
            channel.as_mut().poll_next(&mut noop_context()),
            Poll::Ready(Some(Ok(_)))
        );
        assert_matches!(test_abortable(req.abort_registration).await, Err(Aborted));
    }

    #[tokio::test]
    async fn base_channel_start_send_removes_in_flight_request() {
        let (mut channel, _tx) = test_channel::<(), ()>();

        channel
            .as_mut()
            .start_request(Request {
                id: 0,
                context: context::current(),
                message: (),
            })
            .unwrap();
        assert_eq!(channel.in_flight_requests(), 1);
        channel
            .as_mut()
            .start_send(Response {
                request_id: 0,
                message: Ok(()),
            })
            .unwrap();
        assert_eq!(channel.in_flight_requests(), 0);
    }

    // Dropping an un-executed request sends a cancellation that the channel processes on its
    // next poll, removing the in-flight entry.
    #[tokio::test]
    async fn in_flight_request_drop_cancels_request() {
        let (mut requests, mut tx) = test_requests::<(), ()>();
        tx.send(fake_request(())).await.unwrap();

        let request = match requests.as_mut().poll_next(&mut noop_context()) {
            Poll::Ready(Some(Ok(request))) => request,
            result => panic!("Unexpected result: {result:?}"),
        };
        drop(request);

        let poll = requests
            .as_mut()
            .channel_pin_mut()
            .poll_next(&mut noop_context());
        assert!(poll.is_pending());
        let in_flight_requests = requests.channel().in_flight_requests();
        assert_eq!(in_flight_requests, 0);
    }

    // Executing to completion disarms the response guard, so no cancellation is queued.
    #[tokio::test]
    async fn in_flight_requests_successful_execute_doesnt_cancel_request() {
        let (mut requests, mut tx) = test_requests::<(), ()>();
        tx.send(fake_request(())).await.unwrap();

        let request = match requests.as_mut().poll_next(&mut noop_context()) {
            Poll::Ready(Some(Ok(request))) => request,
            result => panic!("Unexpected result: {result:?}"),
        };
        request.execute(serve(|_, _| async { Ok(()) })).await;
        assert!(requests
            .as_mut()
            .channel_pin_mut()
            .canceled_requests
            .poll_recv(&mut noop_context())
            .is_pending());
    }

    #[tokio::test]
    async fn requests_poll_next_response_returns_pending_when_buffer_full() {
        let (mut requests, _tx) = test_bounded_requests::<(), ()>(0);

        // Response written to the transport.
        requests
            .as_mut()
            .channel_pin_mut()
            .start_request(Request {
                id: 0,
                context: context::current(),
                message: (),
            })
            .unwrap();
        requests
            .as_mut()
            .channel_pin_mut()
            .start_send(Response {
                request_id: 0,
                message: Ok(()),
            })
            .unwrap();

        // Response waiting to be written.
        requests
            .as_mut()
            .project()
            .responses_tx
            .send(Response {
                request_id: 1,
                message: Ok(()),
            })
            .await
            .unwrap();

        requests
            .as_mut()
            .channel_pin_mut()
            .start_request(Request {
                id: 1,
                context: context::current(),
                message: (),
            })
            .unwrap();

        assert_matches!(
            requests.as_mut().poll_next_response(&mut noop_context()),
            Poll::Pending
        );
    }

    #[tokio::test]
    async fn requests_pump_write_returns_pending_when_buffer_full() {
        let (mut requests, _tx) = test_bounded_requests::<(), ()>(0);

        // Response written to the transport.
        requests
            .as_mut()
            .channel_pin_mut()
            .start_request(Request {
                id: 0,
                context: context::current(),
                message: (),
            })
            .unwrap();
        requests
            .as_mut()
            .channel_pin_mut()
            .start_send(Response {
                request_id: 0,
                message: Ok(()),
            })
            .unwrap();

        // Response waiting to be written.
        requests
            .as_mut()
            .channel_pin_mut()
            .start_request(Request {
                id: 1,
                context: context::current(),
                message: (),
            })
            .unwrap();
        requests
            .as_mut()
            .project()
            .responses_tx
            .send(Response {
                request_id: 1,
                message: Ok(()),
            })
            .await
            .unwrap();

        assert_matches!(
            requests.as_mut().pump_write(&mut noop_context(), true),
            Poll::Pending
        );
        // Assert that the pending response was not polled while the channel was blocked.
        assert_matches!(
            requests.as_mut().pending_responses_mut().recv().await,
            Some(_)
        );
    }

    #[tokio::test]
    async fn requests_pump_read() {
        let (mut requests, mut tx) = test_requests::<(), ()>();

        // Request written to the transport.
        tx.send(fake_request(())).await.unwrap();

        assert_matches!(
            requests.as_mut().pump_read(&mut noop_context()),
            Poll::Ready(Some(Ok(_)))
        );
        assert_eq!(requests.channel.in_flight_requests(), 1);
    }
}
1449        assert_eq!(requests.channel.in_flight_requests(), 1);
1450    }
1451}