tarpc_lib/client/
channel.rs

// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
6
7use crate::{
8    context,
9    util::{Compact, TimeUntil},
10    ClientMessage, PollIo, Request, Response, Transport,
11};
12use fnv::FnvHashMap;
13use futures::{
14    channel::{mpsc, oneshot},
15    prelude::*,
16    ready,
17    stream::Fuse,
18    task::Context,
19    Poll,
20};
21use log::{debug, info, trace};
22use pin_utils::{unsafe_pinned, unsafe_unpinned};
23use std::{
24    io,
25    marker::Unpin,
26    pin::Pin,
27    sync::{
28        atomic::{AtomicU64, Ordering},
29        Arc,
30    },
31};
32use tokio_timer::{timeout, Timeout};
33use trace::SpanId;
34
35use super::{Config, NewClient};
36
37/// Handles communication from the client to request dispatch.
38#[derive(Debug)]
39pub struct Channel<Req, Resp> {
40    to_dispatch: mpsc::Sender<DispatchRequest<Req, Resp>>,
41    /// Channel to send a cancel message to the dispatcher.
42    cancellation: RequestCancellation,
43    /// The ID to use for the next request to stage.
44    next_request_id: Arc<AtomicU64>,
45}
46
47impl<Req, Resp> Clone for Channel<Req, Resp> {
48    fn clone(&self) -> Self {
49        Self {
50            to_dispatch: self.to_dispatch.clone(),
51            cancellation: self.cancellation.clone(),
52            next_request_id: self.next_request_id.clone(),
53        }
54    }
55}
56
57/// A future returned by [`Channel::send`] that resolves to a server response.
58#[derive(Debug)]
59#[must_use = "futures do nothing unless polled"]
60struct Send<'a, Req, Resp> {
61    fut: MapOkDispatchResponse<SendMapErrConnectionReset<'a, Req, Resp>, Resp>,
62}
63
64type SendMapErrConnectionReset<'a, Req, Resp> = MapErrConnectionReset<
65    futures::sink::Send<'a, mpsc::Sender<DispatchRequest<Req, Resp>>, DispatchRequest<Req, Resp>>,
66>;
67
68impl<'a, Req, Resp> Send<'a, Req, Resp> {
69    unsafe_pinned!(
70        fut: MapOkDispatchResponse<
71            MapErrConnectionReset<
72                futures::sink::Send<
73                    'a,
74                    mpsc::Sender<DispatchRequest<Req, Resp>>,
75                    DispatchRequest<Req, Resp>,
76                >,
77            >,
78            Resp,
79        >
80    );
81}
82
83impl<'a, Req, Resp> Future for Send<'a, Req, Resp> {
84    type Output = io::Result<DispatchResponse<Resp>>;
85
86    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
87        self.as_mut().fut().poll(cx)
88    }
89}
90
91/// A future returned by [`Channel::call`] that resolves to a server response.
92#[derive(Debug)]
93#[must_use = "futures do nothing unless polled"]
94pub struct Call<'a, Req, Resp> {
95    fut: AndThenIdent<Send<'a, Req, Resp>, DispatchResponse<Resp>>,
96}
97
98impl<'a, Req, Resp> Call<'a, Req, Resp> {
99    unsafe_pinned!(fut: AndThenIdent<Send<'a, Req, Resp>, DispatchResponse<Resp>>);
100}
101
102impl<'a, Req, Resp> Future for Call<'a, Req, Resp> {
103    type Output = io::Result<Resp>;
104
105    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
106        self.as_mut().fut().poll(cx)
107    }
108}
109
110impl<Req, Resp> Channel<Req, Resp> {
111    /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
112    /// resolves when the request is sent (not when the response is received).
113    fn send(&mut self, mut ctx: context::Context, request: Req) -> Send<Req, Resp> {
114        // Convert the context to the call context.
115        ctx.trace_context.parent_id = Some(ctx.trace_context.span_id);
116        ctx.trace_context.span_id = SpanId::random(&mut rand::thread_rng());
117
118        let timeout = ctx.deadline.time_until();
119        trace!(
120            "[{}] Queuing request with timeout {:?}.",
121            ctx.trace_id(),
122            timeout,
123        );
124
125        let (response_completion, response) = oneshot::channel();
126        let cancellation = self.cancellation.clone();
127        let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
128        Send {
129            fut: MapOkDispatchResponse::new(
130                MapErrConnectionReset::new(self.to_dispatch.send(DispatchRequest {
131                    ctx,
132                    request_id,
133                    request,
134                    response_completion,
135                })),
136                DispatchResponse {
137                    response: Timeout::new(response, timeout),
138                    complete: false,
139                    request_id,
140                    cancellation,
141                    ctx,
142                },
143            ),
144        }
145    }
146
147    /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
148    /// resolves to the response.
149    pub fn call(&mut self, context: context::Context, request: Req) -> Call<Req, Resp> {
150        Call {
151            fut: AndThenIdent::new(self.send(context, request)),
152        }
153    }
154}
155
156/// A server response that is completed by request dispatch when the corresponding response
157/// arrives off the wire.
158#[derive(Debug)]
159struct DispatchResponse<Resp> {
160    response: Timeout<oneshot::Receiver<Response<Resp>>>,
161    ctx: context::Context,
162    complete: bool,
163    cancellation: RequestCancellation,
164    request_id: u64,
165}
166
167impl<Resp> DispatchResponse<Resp> {
168    unsafe_pinned!(ctx: context::Context);
169}
170
171impl<Resp> Future for DispatchResponse<Resp> {
172    type Output = io::Result<Resp>;
173
174    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<Resp>> {
175        let resp = ready!(self.response.poll_unpin(cx));
176
177        Poll::Ready(match resp {
178            Ok(resp) => {
179                self.complete = true;
180                match resp {
181                    Ok(resp) => Ok(resp.message?),
182                    Err(oneshot::Canceled) => {
183                        // The oneshot is Canceled when the dispatch task ends. In that case,
184                        // there's nothing listening on the other side, so there's no point in
185                        // propagating cancellation.
186                        Err(io::Error::from(io::ErrorKind::ConnectionReset))
187                    }
188                }
189            }
190            Err(timeout::Elapsed { .. }) => Err(io::Error::new(
191                io::ErrorKind::TimedOut,
192                "Client dropped expired request.".to_string(),
193            )),
194        })
195    }
196}
197
198// Cancels the request when dropped, if not already complete.
199impl<Resp> Drop for DispatchResponse<Resp> {
200    fn drop(&mut self) {
201        if !self.complete {
202            // The receiver needs to be closed to handle the edge case that the request has not
203            // yet been received by the dispatch task. It is possible for the cancel message to
204            // arrive before the request itself, in which case the request could get stuck in the
205            // dispatch map forever if the server never responds (e.g. if the server dies while
206            // responding). Even if the server does respond, it will have unnecessarily done work
207            // for a client no longer waiting for a response. To avoid this, the dispatch task
208            // checks if the receiver is closed before inserting the request in the map. By
209            // closing the receiver before sending the cancel message, it is guaranteed that if the
210            // dispatch task misses an early-arriving cancellation message, then it will see the
211            // receiver as closed.
212            self.response.get_mut().close();
213            self.cancellation.cancel(self.request_id);
214        }
215    }
216}
217
218/// Returns a channel and dispatcher that manages the lifecycle of requests initiated by the
219/// channel.
220pub fn new<Req, Resp, C>(
221    config: Config,
222    transport: C,
223) -> NewClient<Channel<Req, Resp>, RequestDispatch<Req, Resp, C>>
224where
225    C: Transport<ClientMessage<Req>, Response<Resp>>,
226{
227    let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer);
228    let (cancellation, canceled_requests) = cancellations();
229    let canceled_requests = canceled_requests.fuse();
230
231    NewClient {
232        client: Channel {
233            to_dispatch,
234            cancellation,
235            next_request_id: Arc::new(AtomicU64::new(0)),
236        },
237        dispatch: RequestDispatch {
238            config,
239            canceled_requests,
240            transport: transport.fuse(),
241            in_flight_requests: FnvHashMap::default(),
242            pending_requests: pending_requests.fuse(),
243        },
244    }
245}
246
247/// Handles the lifecycle of requests, writing requests to the wire, managing cancellations,
248/// and dispatching responses to the appropriate channel.
249#[derive(Debug)]
250pub struct RequestDispatch<Req, Resp, C> {
251    /// Writes requests to the wire and reads responses off the wire.
252    transport: Fuse<C>,
253    /// Requests waiting to be written to the wire.
254    pending_requests: Fuse<mpsc::Receiver<DispatchRequest<Req, Resp>>>,
255    /// Requests that were dropped.
256    canceled_requests: Fuse<CanceledRequests>,
257    /// Requests already written to the wire that haven't yet received responses.
258    in_flight_requests: FnvHashMap<u64, InFlightData<Resp>>,
259    /// Configures limits to prevent unlimited resource usage.
260    config: Config,
261}
262
263impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
264where
265    C: Transport<ClientMessage<Req>, Response<Resp>>,
266{
267    unsafe_pinned!(in_flight_requests: FnvHashMap<u64, InFlightData<Resp>>);
268    unsafe_pinned!(canceled_requests: Fuse<CanceledRequests>);
269    unsafe_pinned!(pending_requests: Fuse<mpsc::Receiver<DispatchRequest<Req, Resp>>>);
270    unsafe_pinned!(transport: Fuse<C>);
271
272    fn pump_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
273        Poll::Ready(match ready!(self.as_mut().transport().poll_next(cx)?) {
274            Some(response) => {
275                self.complete(response);
276                Some(Ok(()))
277            }
278            None => None,
279        })
280    }
281
282    fn pump_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
283        enum ReceiverStatus {
284            NotReady,
285            Closed,
286        }
287
288        let pending_requests_status = match self.as_mut().poll_next_request(cx)? {
289            Poll::Ready(Some(dispatch_request)) => {
290                self.as_mut().write_request(dispatch_request)?;
291                return Poll::Ready(Some(Ok(())));
292            }
293            Poll::Ready(None) => ReceiverStatus::Closed,
294            Poll::Pending => ReceiverStatus::NotReady,
295        };
296
297        let canceled_requests_status = match self.as_mut().poll_next_cancellation(cx)? {
298            Poll::Ready(Some((context, request_id))) => {
299                self.as_mut().write_cancel(context, request_id)?;
300                return Poll::Ready(Some(Ok(())));
301            }
302            Poll::Ready(None) => ReceiverStatus::Closed,
303            Poll::Pending => ReceiverStatus::NotReady,
304        };
305
306        match (pending_requests_status, canceled_requests_status) {
307            (ReceiverStatus::Closed, ReceiverStatus::Closed) => {
308                ready!(self.as_mut().transport().poll_flush(cx)?);
309                Poll::Ready(None)
310            }
311            (ReceiverStatus::NotReady, _) | (_, ReceiverStatus::NotReady) => {
312                // No more messages to process, so flush any messages buffered in the transport.
313                ready!(self.as_mut().transport().poll_flush(cx)?);
314
315                // Even if we fully-flush, we return Pending, because we have no more requests
316                // or cancellations right now.
317                Poll::Pending
318            }
319        }
320    }
321
322    /// Yields the next pending request, if one is ready to be sent.
323    fn poll_next_request(
324        mut self: Pin<&mut Self>,
325        cx: &mut Context<'_>,
326    ) -> PollIo<DispatchRequest<Req, Resp>> {
327        if self.as_mut().in_flight_requests().len() >= self.config.max_in_flight_requests {
328            info!(
329                "At in-flight request capacity ({}/{}).",
330                self.as_mut().in_flight_requests().len(),
331                self.config.max_in_flight_requests
332            );
333
334            // No need to schedule a wakeup, because timers and responses are responsible
335            // for clearing out in-flight requests.
336            return Poll::Pending;
337        }
338
339        while let Poll::Pending = self.as_mut().transport().poll_ready(cx)? {
340            // We can't yield a request-to-be-sent before the transport is capable of buffering it.
341            ready!(self.as_mut().transport().poll_flush(cx)?);
342        }
343
344        loop {
345            match ready!(self.as_mut().pending_requests().poll_next_unpin(cx)) {
346                Some(request) => {
347                    if request.response_completion.is_canceled() {
348                        trace!(
349                            "[{}] Request canceled before being sent.",
350                            request.ctx.trace_id()
351                        );
352                        continue;
353                    }
354
355                    return Poll::Ready(Some(Ok(request)));
356                }
357                None => return Poll::Ready(None),
358            }
359        }
360    }
361
362    /// Yields the next pending cancellation, and, if one is ready, cancels the associated request.
363    fn poll_next_cancellation(
364        mut self: Pin<&mut Self>,
365        cx: &mut Context<'_>,
366    ) -> PollIo<(context::Context, u64)> {
367        while let Poll::Pending = self.as_mut().transport().poll_ready(cx)? {
368            ready!(self.as_mut().transport().poll_flush(cx)?);
369        }
370
371        loop {
372            let cancellation = self.as_mut().canceled_requests().poll_next_unpin(cx);
373            match ready!(cancellation) {
374                Some(request_id) => {
375                    if let Some(in_flight_data) =
376                        self.as_mut().in_flight_requests().remove(&request_id)
377                    {
378                        self.as_mut().in_flight_requests().compact(0.1);
379                        debug!("[{}] Removed request.", in_flight_data.ctx.trace_id());
380                        return Poll::Ready(Some(Ok((in_flight_data.ctx, request_id))));
381                    }
382                }
383                None => return Poll::Ready(None),
384            }
385        }
386    }
387
388    fn write_request(
389        mut self: Pin<&mut Self>,
390        dispatch_request: DispatchRequest<Req, Resp>,
391    ) -> io::Result<()> {
392        let request_id = dispatch_request.request_id;
393        let request = ClientMessage::Request(Request {
394            id: request_id,
395            message: dispatch_request.request,
396            context: context::Context {
397                deadline: dispatch_request.ctx.deadline,
398                trace_context: dispatch_request.ctx.trace_context,
399            },
400        });
401        self.as_mut().transport().start_send(request)?;
402        self.as_mut().in_flight_requests().insert(
403            request_id,
404            InFlightData {
405                ctx: dispatch_request.ctx,
406                response_completion: dispatch_request.response_completion,
407            },
408        );
409        Ok(())
410    }
411
412    fn write_cancel(
413        mut self: Pin<&mut Self>,
414        context: context::Context,
415        request_id: u64,
416    ) -> io::Result<()> {
417        let trace_id = *context.trace_id();
418        let cancel = ClientMessage::Cancel {
419            trace_context: context.trace_context,
420            request_id,
421        };
422        self.as_mut().transport().start_send(cancel)?;
423        trace!("[{}] Cancel message sent.", trace_id);
424        Ok(())
425    }
426
427    /// Sends a server response to the client task that initiated the associated request.
428    fn complete(mut self: Pin<&mut Self>, response: Response<Resp>) -> bool {
429        if let Some(in_flight_data) = self
430            .as_mut()
431            .in_flight_requests()
432            .remove(&response.request_id)
433        {
434            self.as_mut().in_flight_requests().compact(0.1);
435
436            trace!("[{}] Received response.", in_flight_data.ctx.trace_id());
437            let _ = in_flight_data.response_completion.send(response);
438            return true;
439        }
440
441        debug!(
442            "No in-flight request found for request_id = {}.",
443            response.request_id
444        );
445
446        // If the response completion was absent, then the request was already canceled.
447        false
448    }
449}
450
451impl<Req, Resp, C> Future for RequestDispatch<Req, Resp, C>
452where
453    C: Transport<ClientMessage<Req>, Response<Resp>>,
454{
455    type Output = io::Result<()>;
456
457    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
458        loop {
459            match (self.as_mut().pump_read(cx)?, self.as_mut().pump_write(cx)?) {
460                (read, Poll::Ready(None)) => {
461                    if self.as_mut().in_flight_requests().is_empty() {
462                        info!("Shutdown: write half closed, and no requests in flight.");
463                        return Poll::Ready(Ok(()));
464                    }
465                    info!(
466                        "Shutdown: write half closed, and {} requests in flight.",
467                        self.as_mut().in_flight_requests().len()
468                    );
469                    match read {
470                        Poll::Ready(Some(())) => continue,
471                        _ => return Poll::Pending,
472                    }
473                }
474                (Poll::Ready(Some(())), _) | (_, Poll::Ready(Some(()))) => {}
475                _ => return Poll::Pending,
476            }
477        }
478    }
479}
480
481/// A server-bound request sent from a [`Channel`] to request dispatch, which will then manage
482/// the lifecycle of the request.
483#[derive(Debug)]
484struct DispatchRequest<Req, Resp> {
485    ctx: context::Context,
486    request_id: u64,
487    request: Req,
488    response_completion: oneshot::Sender<Response<Resp>>,
489}
490
491#[derive(Debug)]
492struct InFlightData<Resp> {
493    ctx: context::Context,
494    response_completion: oneshot::Sender<Response<Resp>>,
495}
496
497/// Sends request cancellation signals.
498#[derive(Debug, Clone)]
499struct RequestCancellation(mpsc::UnboundedSender<u64>);
500
501/// A stream of IDs of requests that have been canceled.
502#[derive(Debug)]
503struct CanceledRequests(mpsc::UnboundedReceiver<u64>);
504
505/// Returns a channel to send request cancellation messages.
506fn cancellations() -> (RequestCancellation, CanceledRequests) {
507    // Unbounded because messages are sent in the drop fn. This is fine, because it's still
508    // bounded by the number of in-flight requests. Additionally, each request has a clone
509    // of the sender, so the bounded channel would have the same behavior,
510    // since it guarantees a slot.
511    let (tx, rx) = mpsc::unbounded();
512    (RequestCancellation(tx), CanceledRequests(rx))
513}
514
515impl RequestCancellation {
516    /// Cancels the request with ID `request_id`.
517    fn cancel(&mut self, request_id: u64) {
518        let _ = self.0.unbounded_send(request_id);
519    }
520}
521
522impl Stream for CanceledRequests {
523    type Item = u64;
524
525    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<u64>> {
526        self.0.poll_next_unpin(cx)
527    }
528}
529
530#[derive(Debug)]
531#[must_use = "futures do nothing unless polled"]
532struct MapErrConnectionReset<Fut> {
533    future: Fut,
534    finished: Option<()>,
535}
536
537impl<Fut> MapErrConnectionReset<Fut> {
538    unsafe_pinned!(future: Fut);
539    unsafe_unpinned!(finished: Option<()>);
540
541    fn new(future: Fut) -> MapErrConnectionReset<Fut> {
542        MapErrConnectionReset {
543            future,
544            finished: Some(()),
545        }
546    }
547}
548
549impl<Fut: Unpin> Unpin for MapErrConnectionReset<Fut> {}
550
551impl<Fut> Future for MapErrConnectionReset<Fut>
552where
553    Fut: TryFuture,
554{
555    type Output = io::Result<Fut::Ok>;
556
557    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
558        match self.as_mut().future().try_poll(cx) {
559            Poll::Pending => Poll::Pending,
560            Poll::Ready(result) => {
561                self.finished().take().expect(
562                    "MapErrConnectionReset must not be polled after it returned `Poll::Ready`",
563                );
564                Poll::Ready(result.map_err(|_| io::Error::from(io::ErrorKind::ConnectionReset)))
565            }
566        }
567    }
568}
569
570#[derive(Debug)]
571#[must_use = "futures do nothing unless polled"]
572struct MapOkDispatchResponse<Fut, Resp> {
573    future: Fut,
574    response: Option<DispatchResponse<Resp>>,
575}
576
577impl<Fut, Resp> MapOkDispatchResponse<Fut, Resp> {
578    unsafe_pinned!(future: Fut);
579    unsafe_unpinned!(response: Option<DispatchResponse<Resp>>);
580
581    fn new(future: Fut, response: DispatchResponse<Resp>) -> MapOkDispatchResponse<Fut, Resp> {
582        MapOkDispatchResponse {
583            future,
584            response: Some(response),
585        }
586    }
587}
588
589impl<Fut: Unpin, Resp> Unpin for MapOkDispatchResponse<Fut, Resp> {}
590
591impl<Fut, Resp> Future for MapOkDispatchResponse<Fut, Resp>
592where
593    Fut: TryFuture,
594{
595    type Output = Result<DispatchResponse<Resp>, Fut::Error>;
596
597    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
598        match self.as_mut().future().try_poll(cx) {
599            Poll::Pending => Poll::Pending,
600            Poll::Ready(result) => {
601                let response = self
602                    .as_mut()
603                    .response()
604                    .take()
605                    .expect("MapOk must not be polled after it returned `Poll::Ready`");
606                Poll::Ready(result.map(|_| response))
607            }
608        }
609    }
610}
611
612#[derive(Debug)]
613#[must_use = "futures do nothing unless polled"]
614struct AndThenIdent<Fut1, Fut2> {
615    try_chain: TryChain<Fut1, Fut2>,
616}
617
618impl<Fut1, Fut2> AndThenIdent<Fut1, Fut2>
619where
620    Fut1: TryFuture<Ok = Fut2>,
621    Fut2: TryFuture,
622{
623    unsafe_pinned!(try_chain: TryChain<Fut1, Fut2>);
624
625    /// Creates a new `Then`.
626    fn new(future: Fut1) -> AndThenIdent<Fut1, Fut2> {
627        AndThenIdent {
628            try_chain: TryChain::new(future),
629        }
630    }
631}
632
633impl<Fut1, Fut2> Future for AndThenIdent<Fut1, Fut2>
634where
635    Fut1: TryFuture<Ok = Fut2>,
636    Fut2: TryFuture<Error = Fut1::Error>,
637{
638    type Output = Result<Fut2::Ok, Fut2::Error>;
639
640    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
641        self.try_chain().poll(cx, |result| match result {
642            Ok(ok) => TryChainAction::Future(ok),
643            Err(err) => TryChainAction::Output(Err(err)),
644        })
645    }
646}
647
648#[must_use = "futures do nothing unless polled"]
649#[derive(Debug)]
650enum TryChain<Fut1, Fut2> {
651    First(Fut1),
652    Second(Fut2),
653    Empty,
654}
655
656enum TryChainAction<Fut2>
657where
658    Fut2: TryFuture,
659{
660    Future(Fut2),
661    Output(Result<Fut2::Ok, Fut2::Error>),
662}
663
664impl<Fut1, Fut2> TryChain<Fut1, Fut2>
665where
666    Fut1: TryFuture<Ok = Fut2>,
667    Fut2: TryFuture,
668{
669    fn new(fut1: Fut1) -> TryChain<Fut1, Fut2> {
670        TryChain::First(fut1)
671    }
672
673    fn poll<F>(
674        self: Pin<&mut Self>,
675        cx: &mut Context<'_>,
676        f: F,
677    ) -> Poll<Result<Fut2::Ok, Fut2::Error>>
678    where
679        F: FnOnce(Result<Fut1::Ok, Fut1::Error>) -> TryChainAction<Fut2>,
680    {
681        let mut f = Some(f);
682
683        // Safe to call `get_unchecked_mut` because we won't move the futures.
684        let this = unsafe { Pin::get_unchecked_mut(self) };
685
686        loop {
687            let output = match this {
688                TryChain::First(fut1) => {
689                    // Poll the first future
690                    match unsafe { Pin::new_unchecked(fut1) }.try_poll(cx) {
691                        Poll::Pending => return Poll::Pending,
692                        Poll::Ready(output) => output,
693                    }
694                }
695                TryChain::Second(fut2) => {
696                    // Poll the second future
697                    return unsafe { Pin::new_unchecked(fut2) }.try_poll(cx);
698                }
699                TryChain::Empty => {
700                    panic!("future must not be polled after it returned `Poll::Ready`");
701                }
702            };
703
704            *this = TryChain::Empty; // Drop fut1
705            let f = f.take().unwrap();
706            match f(output) {
707                TryChainAction::Future(fut2) => *this = TryChain::Second(fut2),
708                TryChainAction::Output(output) => return Poll::Ready(output),
709            }
710        }
711    }
712}
713
#[cfg(test)]
mod tests {
    use super::{
        cancellations, CanceledRequests, Channel, DispatchResponse, RequestCancellation,
        RequestDispatch,
    };
    use crate::{
        client::Config,
        context,
        transport::{self, channel::UnboundedChannel},
        ClientMessage, Response,
    };
    use fnv::FnvHashMap;
    use futures::{
        channel::{mpsc, oneshot},
        prelude::*,
        task::Context,
        Poll,
    };
    use futures_test::task::noop_waker_ref;
    use std::time::Duration;
    use std::{pin::Pin, sync::atomic::AtomicU64, sync::Arc};
    use tokio::runtime::current_thread;
    use tokio_timer::Timeout;

    #[test]
    fn dispatch_response_cancels_on_timeout() {
        let (_response_completion, response) = oneshot::channel();
        let (cancellation, mut canceled_requests) = cancellations();
        let resp = DispatchResponse::<u64> {
            // Timeout in the past should cause resp to error out when polled.
            response: Timeout::new(response, Duration::from_secs(0)),
            complete: false,
            request_id: 3,
            cancellation,
            ctx: context::current(),
        };
        {
            pin_utils::pin_mut!(resp);
            let timer = tokio_timer::Timer::default();
            let handle = timer.handle();
            let _guard = tokio_timer::set_default(&handle);

            let _ = resp
                .as_mut()
                .poll(&mut Context::from_waker(noop_waker_ref()));
            // End of block should cause resp.drop() to run, which should send a cancel message.
        }
        assert_eq!(canceled_requests.0.try_next().unwrap(), Some(3));
    }

    #[test]
    fn stage_request() {
        let (mut dispatch, mut channel, _server_channel) = set_up();
        let dispatch = Pin::new(&mut dispatch);
        let cx = &mut Context::from_waker(noop_waker_ref());

        let _resp = send_request(&mut channel, "hi");

        let req = dispatch.poll_next_request(cx).ready();
        assert!(req.is_some());

        let req = req.unwrap();
        assert_eq!(req.request_id, 0);
        assert_eq!(req.request, "hi".to_string());
    }

    fn block_on<F: Future>(f: F) -> F::Output {
        current_thread::Runtime::new().unwrap().block_on(f)
    }

    // Regression test for https://github.com/google/tarpc/issues/220
    #[test]
    fn stage_request_channel_dropped_doesnt_panic() {
        let (mut dispatch, mut channel, mut server_channel) = set_up();
        let mut dispatch = Pin::new(&mut dispatch);
        let cx = &mut Context::from_waker(noop_waker_ref());

        let _ = send_request(&mut channel, "hi");
        drop(channel);

        assert!(dispatch.as_mut().poll(cx).is_ready());
        send_response(
            &mut server_channel,
            Response {
                request_id: 0,
                message: Ok("hello".into()),
            },
        );
        block_on(dispatch).unwrap();
    }

    #[test]
    fn stage_request_response_future_dropped_is_canceled_before_sending() {
        let (mut dispatch, mut channel, _server_channel) = set_up();
        let dispatch = Pin::new(&mut dispatch);
        let cx = &mut Context::from_waker(noop_waker_ref());

        let _ = send_request(&mut channel, "hi");

        // Drop the channel so polling returns none if no requests are currently ready.
        drop(channel);
        // Test that a request future dropped before it's processed by dispatch will cause the request
        // to not be added to the in-flight request map.
        assert!(dispatch.poll_next_request(cx).ready().is_none());
    }

    #[test]
    fn stage_request_response_future_dropped_is_canceled_after_sending() {
        let (mut dispatch, mut channel, _server_channel) = set_up();
        let cx = &mut Context::from_waker(noop_waker_ref());
        let mut dispatch = Pin::new(&mut dispatch);

        let req = send_request(&mut channel, "hi");

        assert!(dispatch.as_mut().pump_write(cx).ready().is_some());
        assert!(!dispatch.as_mut().in_flight_requests().is_empty());

        // Test that a request future dropped after it's processed by dispatch will cause the request
        // to be removed from the in-flight request map.
        drop(req);
        if let Poll::Ready(Some(_)) = dispatch.as_mut().poll_next_cancellation(cx).unwrap() {
            // ok
        } else {
            panic!("Expected request to be cancelled")
        };
        assert!(dispatch.in_flight_requests().is_empty());
    }

    #[test]
    fn stage_request_response_closed_skipped() {
        let (mut dispatch, mut channel, _server_channel) = set_up();
        let dispatch = Pin::new(&mut dispatch);
        let cx = &mut Context::from_waker(noop_waker_ref());

        // Test that a request future that's closed its receiver but not yet canceled its request --
        // i.e. still in `drop fn` -- will cause the request to not be added to the in-flight request
        // map.
        let mut resp = send_request(&mut channel, "hi");
        resp.response.get_mut().close();

        assert!(dispatch.poll_next_request(cx).is_pending());
    }

    fn set_up() -> (
        RequestDispatch<String, String, UnboundedChannel<Response<String>, ClientMessage<String>>>,
        Channel<String, String>,
        UnboundedChannel<ClientMessage<String>, Response<String>>,
    ) {
        let _ = env_logger::try_init();

        let (to_dispatch, pending_requests) = mpsc::channel(1);
        let (cancel_tx, canceled_requests) = mpsc::unbounded();
        let (client_channel, server_channel) = transport::channel::unbounded();

        let dispatch = RequestDispatch::<String, String, _> {
            transport: client_channel.fuse(),
            pending_requests: pending_requests.fuse(),
            canceled_requests: CanceledRequests(canceled_requests).fuse(),
            in_flight_requests: FnvHashMap::default(),
            config: Config::default(),
        };

        let cancellation = RequestCancellation(cancel_tx);
        let channel = Channel {
            to_dispatch,
            cancellation,
            next_request_id: Arc::new(AtomicU64::new(0)),
        };

        (dispatch, channel, server_channel)
    }

    fn send_request(
        channel: &mut Channel<String, String>,
        request: &str,
    ) -> DispatchResponse<String> {
        block_on(channel.send(context::current(), request.to_string())).unwrap()
    }

    fn send_response(
        channel: &mut UnboundedChannel<ClientMessage<String>, Response<String>>,
        response: Response<String>,
    ) {
        block_on(channel.send(response)).unwrap();
    }

    /// Helpers for unwrapping `Poll<Option<Result<T, E>>>` values in tests.
    trait PollTest {
        type T;
        /// Strips the `Result`, panicking with the error's message on `Err`.
        fn unwrap(self) -> Poll<Self::T>;
        /// Like `unwrap`, but also panics if the poll is `Pending`.
        fn ready(self) -> Self::T;
    }

    impl<T, E> PollTest for Poll<Option<Result<T, E>>>
    where
        E: ::std::fmt::Display,
    {
        type T = Option<T>;

        fn unwrap(self) -> Poll<Option<T>> {
            match self {
                Poll::Ready(Some(Ok(t))) => Poll::Ready(Some(t)),
                Poll::Ready(None) => Poll::Ready(None),
                // Use a literal format string; a non-literal panic payload is deprecated.
                Poll::Ready(Some(Err(e))) => panic!("{}", e),
                Poll::Pending => Poll::Pending,
            }
        }

        fn ready(self) -> Option<T> {
            match self {
                Poll::Ready(Some(Ok(t))) => Some(t),
                Poll::Ready(None) => None,
                Poll::Ready(Some(Err(e))) => panic!("{}", e),
                Poll::Pending => panic!("Pending"),
            }
        }
    }
}