tarpc_lib/server/
mod.rs

1// Copyright 2018 Google LLC
2//
3// Use of this source code is governed by an MIT-style
4// license that can be found in the LICENSE file or at
5// https://opensource.org/licenses/MIT.
6
7//! Provides a server that concurrently handles many connections sending multiplexed requests.
8
9use crate::{
10    context, util::Compact, util::TimeUntil, ClientMessage, PollIo, Request, Response, ServerError,
11    Transport,
12};
13use fnv::FnvHashMap;
14use futures::{
15    channel::mpsc,
16    future::{AbortHandle, AbortRegistration, Abortable},
17    prelude::*,
18    ready,
19    stream::Fuse,
20    task::{Context, Poll},
21};
22use humantime::format_rfc3339;
23use log::{debug, trace};
24use pin_utils::{unsafe_pinned, unsafe_unpinned};
25use std::{fmt, hash::Hash, io, marker::PhantomData, pin::Pin, time::SystemTime};
26use tokio_timer::{timeout, Timeout};
27
28mod filter;
29#[cfg(test)]
30mod testing;
31mod throttle;
32
33pub use self::{
34    filter::ChannelFilter,
35    throttle::{Throttler, ThrottlerStream},
36};
37
/// Manages clients, serving multiplexed requests over each connection.
///
/// `Req` and `Resp` are the service's request and response types. They appear only in a
/// `PhantomData` marker, so a `Server` stores nothing but its [`Config`].
#[derive(Debug)]
pub struct Server<Req, Resp> {
    /// Settings cloned into every channel this server creates.
    config: Config,
    /// Ties the server to its request/response types without storing values of them.
    ghost: PhantomData<(Req, Resp)>,
}
44
45impl<Req, Resp> Default for Server<Req, Resp> {
46    fn default() -> Self {
47        new(Config::default())
48    }
49}
50
/// Settings that control the behavior of the server.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot use a struct literal.
/// It must start from `Config::default()` and then modify the public fields.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct Config {
    /// The number of responses per client that can be buffered server-side before being sent.
    /// `pending_response_buffer` controls the buffer size of the channel that a server's
    /// response tasks use to send responses to the client handler task. Defaults to `100`.
    pub pending_response_buffer: usize,
}
60
61impl Default for Config {
62    fn default() -> Self {
63        Config {
64            pending_response_buffer: 100,
65        }
66    }
67}
68
69impl Config {
70    /// Returns a channel backed by `transport` and configured with `self`.
71    pub fn channel<Req, Resp, T>(self, transport: T) -> BaseChannel<Req, Resp, T>
72    where
73        T: Transport<Response<Resp>, ClientMessage<Req>>,
74    {
75        BaseChannel::new(self, transport)
76    }
77}
78
79/// Returns a new server with configuration specified `config`.
80pub fn new<Req, Resp>(config: Config) -> Server<Req, Resp> {
81    Server {
82        config,
83        ghost: PhantomData,
84    }
85}
86
87impl<Req, Resp> Server<Req, Resp> {
88    /// Returns the config for this server.
89    pub fn config(&self) -> &Config {
90        &self.config
91    }
92
93    /// Returns a stream of server channels.
94    pub fn incoming<S, T>(self, listener: S) -> impl Stream<Item = BaseChannel<Req, Resp, T>>
95    where
96        S: Stream<Item = T>,
97        T: Transport<Response<Resp>, ClientMessage<Req>>,
98    {
99        listener.map(move |t| BaseChannel::new(self.config.clone(), t))
100    }
101}
102
/// Turns one request into a future of its response.
///
/// Conceptually this is `Fn(context::Context, Req) -> impl Future<Output = Resp>`. Any
/// `FnOnce(context::Context, Req) -> Fut + Clone` closure implements it through the blanket
/// impl in this module.
pub trait Serve<Req>: Sized + Clone {
    /// Type of response.
    type Resp;

    /// Type of response future.
    type Fut: Future<Output = Self::Resp>;

    /// Responds to a single request, consuming this (cloneable) handler.
    fn serve(self, ctx: context::Context, req: Req) -> Self::Fut;
}
114
115impl<Req, Resp, Fut, F> Serve<Req> for F
116where
117    F: FnOnce(context::Context, Req) -> Fut + Clone,
118    Fut: Future<Output = Resp>,
119{
120    type Resp = Resp;
121    type Fut = Fut;
122
123    fn serve(self, ctx: context::Context, req: Req) -> Self::Fut {
124        self(ctx, req)
125    }
126}
127
128/// A utility trait enabling a stream to fluently chain a request handler.
129pub trait Handler<C>
130where
131    Self: Sized + Stream<Item = C>,
132    C: Channel,
133{
134    /// Enforces channel per-key limits.
135    fn max_channels_per_key<K, KF>(self, n: u32, keymaker: KF) -> filter::ChannelFilter<Self, K, KF>
136    where
137        K: fmt::Display + Eq + Hash + Clone + Unpin,
138        KF: Fn(&C) -> K,
139    {
140        ChannelFilter::new(self, n, keymaker)
141    }
142
143    /// Caps the number of concurrent requests per channel.
144    fn max_concurrent_requests_per_channel(self, n: usize) -> ThrottlerStream<Self> {
145        ThrottlerStream::new(self, n)
146    }
147
148    /// Responds to all requests with `server`.
149    #[cfg(feature = "tokio1")]
150    fn respond_with<S>(self, server: S) -> Running<Self, S>
151    where
152        S: Serve<C::Req, Resp = C::Resp>,
153    {
154        Running {
155            incoming: self,
156            server,
157        }
158    }
159}
160
// Blanket impl: every stream of channels gets the `Handler` combinators, which are all
// provided as default methods on the trait.
impl<S, C> Handler<C> for S
where
    S: Sized + Stream<Item = C>,
    C: Channel,
{
}
167
/// BaseChannel lifts a Transport to a Channel by tracking in-flight requests.
#[derive(Debug)]
pub struct BaseChannel<Req, Resp, T> {
    /// Settings this channel was created with.
    config: Config,
    /// Writes responses to the wire and reads requests off the wire.
    transport: Fuse<T>,
    /// Requests currently being handled, keyed by request ID. Each entry holds the handle
    /// used to abort that request if the client cancels it.
    in_flight_requests: FnvHashMap<u64, AbortHandle>,
    /// Types the request and response.
    ghost: PhantomData<(Req, Resp)>,
}
179
impl<Req, Resp, T> BaseChannel<Req, Resp, T> {
    // Unpinned projection: the in-flight map is only ever accessed through `&mut` and is
    // never treated as pinned, so handing out `&mut` from a pinned `BaseChannel` is fine.
    unsafe_unpinned!(in_flight_requests: FnvHashMap<u64, AbortHandle>);
}
183
184impl<Req, Resp, T> BaseChannel<Req, Resp, T>
185where
186    T: Transport<Response<Resp>, ClientMessage<Req>>,
187{
188    /// Creates a new channel backed by `transport` and configured with `config`.
189    pub fn new(config: Config, transport: T) -> Self {
190        BaseChannel {
191            config,
192            transport: transport.fuse(),
193            in_flight_requests: FnvHashMap::default(),
194            ghost: PhantomData,
195        }
196    }
197
198    /// Creates a new channel backed by `transport` and configured with the defaults.
199    pub fn with_defaults(transport: T) -> Self {
200        Self::new(Config::default(), transport)
201    }
202
203    /// Returns the inner transport.
204    pub fn get_ref(&self) -> &T {
205        self.transport.get_ref()
206    }
207
208    /// Returns the pinned inner transport.
209    pub fn transport<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut T> {
210        unsafe { self.map_unchecked_mut(|me| me.transport.get_mut()) }
211    }
212
213    fn cancel_request(mut self: Pin<&mut Self>, trace_context: &trace::Context, request_id: u64) {
214        // It's possible the request was already completed, so it's fine
215        // if this is None.
216        if let Some(cancel_handle) = self.as_mut().in_flight_requests().remove(&request_id) {
217            self.as_mut().in_flight_requests().compact(0.1);
218
219            cancel_handle.abort();
220            let remaining = self.as_mut().in_flight_requests().len();
221            trace!(
222                "[{}] Request canceled. In-flight requests = {}",
223                trace_context.trace_id,
224                remaining,
225            );
226        } else {
227            trace!(
228                "[{}] Received cancellation, but response handler \
229                 is already complete.",
230                trace_context.trace_id,
231            );
232        }
233    }
234}
235
236/// The server end of an open connection with a client, streaming in requests from, and sinking
237/// responses to, the client.
238///
239/// Channels are free to somewhat rely on the assumption that all in-flight requests are eventually
240/// either [cancelled](BaseChannel::cancel_request) or [responded to](Sink::start_send). Safety cannot
241/// rely on this assumption, but it is best for `Channel` users to always account for all outstanding
242/// requests.
243pub trait Channel
244where
245    Self: Transport<Response<<Self as Channel>::Resp>, Request<<Self as Channel>::Req>>,
246{
247    /// Type of request item.
248    type Req;
249
250    /// Type of response sink item.
251    type Resp;
252
253    /// Configuration of the channel.
254    fn config(&self) -> &Config;
255
256    /// Returns the number of in-flight requests over this channel.
257    fn in_flight_requests(self: Pin<&mut Self>) -> usize;
258
259    /// Caps the number of concurrent requests.
260    fn max_concurrent_requests(self, n: usize) -> Throttler<Self>
261    where
262        Self: Sized,
263    {
264        Throttler::new(self, n)
265    }
266
267    /// Tells the Channel that request with ID `request_id` is being handled.
268    /// The request will be tracked until a response with the same ID is sent
269    /// to the Channel.
270    fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration;
271
272    /// Respond to requests coming over the channel with `f`. Returns a future that drives the
273    /// responses and resolves when the connection is closed.
274    fn respond_with<S>(self, server: S) -> ClientHandler<Self, S>
275    where
276        S: Serve<Self::Req, Resp = Self::Resp>,
277        Self: Sized,
278    {
279        let (responses_tx, responses) = mpsc::channel(self.config().pending_response_buffer);
280        let responses = responses.fuse();
281
282        ClientHandler {
283            channel: self,
284            server,
285            pending_responses: responses,
286            responses_tx,
287        }
288    }
289}
290
291impl<Req, Resp, T> Stream for BaseChannel<Req, Resp, T>
292where
293    T: Transport<Response<Resp>, ClientMessage<Req>>,
294{
295    type Item = io::Result<Request<Req>>;
296
297    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
298        loop {
299            match ready!(self.as_mut().transport().poll_next(cx)?) {
300                Some(message) => match message {
301                    ClientMessage::Request(request) => {
302                        return Poll::Ready(Some(Ok(request)));
303                    }
304                    ClientMessage::Cancel {
305                        trace_context,
306                        request_id,
307                    } => {
308                        self.as_mut().cancel_request(&trace_context, request_id);
309                    }
310                },
311                None => return Poll::Ready(None),
312            }
313        }
314    }
315}
316
317impl<Req, Resp, T> Sink<Response<Resp>> for BaseChannel<Req, Resp, T>
318where
319    T: Transport<Response<Resp>, ClientMessage<Req>>,
320{
321    type Error = io::Error;
322
323    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
324        self.transport().poll_ready(cx)
325    }
326
327    fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> Result<(), Self::Error> {
328        if self
329            .as_mut()
330            .in_flight_requests()
331            .remove(&response.request_id)
332            .is_some()
333        {
334            self.as_mut().in_flight_requests().compact(0.1);
335        }
336
337        self.transport().start_send(response)
338    }
339
340    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
341        self.transport().poll_flush(cx)
342    }
343
344    fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
345        self.transport().poll_close(cx)
346    }
347}
348
349impl<Req, Resp, T> AsRef<T> for BaseChannel<Req, Resp, T> {
350    fn as_ref(&self) -> &T {
351        self.transport.get_ref()
352    }
353}
354
355impl<Req, Resp, T> Channel for BaseChannel<Req, Resp, T>
356where
357    T: Transport<Response<Resp>, ClientMessage<Req>>,
358{
359    type Req = Req;
360    type Resp = Resp;
361
362    fn config(&self) -> &Config {
363        &self.config
364    }
365
366    fn in_flight_requests(mut self: Pin<&mut Self>) -> usize {
367        self.as_mut().in_flight_requests().len()
368    }
369
370    fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
371        let (abort_handle, abort_registration) = AbortHandle::new_pair();
372        assert!(self
373            .in_flight_requests()
374            .insert(request_id, abort_handle)
375            .is_none());
376        abort_registration
377    }
378}
379
/// A running handler serving all requests coming over a channel.
///
/// As a stream it yields one [`RequestHandler`] per incoming request. While being polled it
/// also moves finished responses from its internal queue onto the channel.
#[derive(Debug)]
pub struct ClientHandler<C, S>
where
    C: Channel,
{
    channel: C,
    /// Responses waiting to be written to the wire.
    pending_responses: Fuse<mpsc::Receiver<(context::Context, Response<C::Resp>)>>,
    /// Handed out to request handlers to fan in responses.
    responses_tx: mpsc::Sender<(context::Context, Response<C::Resp>)>,
    /// The service; cloned once per incoming request.
    server: S,
}
394
impl<C, S> ClientHandler<C, S>
where
    C: Channel,
{
    // Pinned projections for the fields polled as stream/sink.
    unsafe_pinned!(channel: C);
    unsafe_pinned!(pending_responses: Fuse<mpsc::Receiver<(context::Context, Response<C::Resp>)>>);
    unsafe_pinned!(responses_tx: mpsc::Sender<(context::Context, Response<C::Resp>)>);
    // For this unpinned projection to be sound, the `server` field must stay private, and
    // code in this module must never construct a `Pin<&mut S>` pointing at it.
    unsafe_unpinned!(server: S);
}
406
407impl<C, S> ClientHandler<C, S>
408where
409    C: Channel,
410    S: Serve<C::Req, Resp = C::Resp>,
411{
412    fn pump_read(
413        mut self: Pin<&mut Self>,
414        cx: &mut Context<'_>,
415    ) -> PollIo<RequestHandler<S::Fut, C::Resp>> {
416        match ready!(self.as_mut().channel().poll_next(cx)?) {
417            Some(request) => Poll::Ready(Some(Ok(self.handle_request(request)))),
418            None => Poll::Ready(None),
419        }
420    }
421
422    fn pump_write(
423        mut self: Pin<&mut Self>,
424        cx: &mut Context<'_>,
425        read_half_closed: bool,
426    ) -> PollIo<()> {
427        match self.as_mut().poll_next_response(cx)? {
428            Poll::Ready(Some((ctx, response))) => {
429                trace!(
430                    "[{}] Staging response. In-flight requests = {}.",
431                    ctx.trace_id(),
432                    self.as_mut().channel().in_flight_requests(),
433                );
434                self.as_mut().channel().start_send(response)?;
435                Poll::Ready(Some(Ok(())))
436            }
437            Poll::Ready(None) => {
438                // Shutdown can't be done before we finish pumping out remaining responses.
439                ready!(self.as_mut().channel().poll_flush(cx)?);
440                Poll::Ready(None)
441            }
442            Poll::Pending => {
443                // No more requests to process, so flush any requests buffered in the transport.
444                ready!(self.as_mut().channel().poll_flush(cx)?);
445
446                // Being here means there are no staged requests and all written responses are
447                // fully flushed. So, if the read half is closed and there are no in-flight
448                // requests, then we can close the write half.
449                if read_half_closed && self.as_mut().channel().in_flight_requests() == 0 {
450                    Poll::Ready(None)
451                } else {
452                    Poll::Pending
453                }
454            }
455        }
456    }
457
458    fn poll_next_response(
459        mut self: Pin<&mut Self>,
460        cx: &mut Context<'_>,
461    ) -> PollIo<(context::Context, Response<C::Resp>)> {
462        // Ensure there's room to write a response.
463        while let Poll::Pending = self.as_mut().channel().poll_ready(cx)? {
464            ready!(self.as_mut().channel().poll_flush(cx)?);
465        }
466
467        match ready!(self.as_mut().pending_responses().poll_next(cx)) {
468            Some((ctx, response)) => Poll::Ready(Some(Ok((ctx, response)))),
469            None => {
470                // This branch likely won't happen, since the ClientHandler is holding a Sender.
471                Poll::Ready(None)
472            }
473        }
474    }
475
476    fn handle_request(
477        mut self: Pin<&mut Self>,
478        request: Request<C::Req>,
479    ) -> RequestHandler<S::Fut, C::Resp> {
480        let request_id = request.id;
481        let deadline = request.context.deadline;
482        let timeout = deadline.time_until();
483        trace!(
484            "[{}] Received request with deadline {} (timeout {:?}).",
485            request.context.trace_id(),
486            format_rfc3339(deadline),
487            timeout,
488        );
489        let ctx = request.context;
490        let request = request.message;
491
492        let response = self.as_mut().server().clone().serve(ctx, request);
493        let response = Resp {
494            state: RespState::PollResp,
495            request_id,
496            ctx,
497            deadline,
498            f: Timeout::new(response, timeout),
499            response: None,
500            response_tx: self.as_mut().responses_tx().clone(),
501        };
502        let abort_registration = self.as_mut().channel().start_request(request_id);
503        RequestHandler {
504            resp: Abortable::new(response, abort_registration),
505        }
506    }
507}
508
/// A future fulfilling a single client request.
///
/// It completes once the response has been delivered to the client handler's response queue.
/// It also completes as soon as the request is aborted, for example when the client cancels it.
#[derive(Debug)]
pub struct RequestHandler<F, R> {
    /// The response state machine, wrapped so that a cancellation can abort it.
    resp: Abortable<Resp<F, R>>,
}
514
impl<F, R> RequestHandler<F, R> {
    // Pinned projection: the wrapped future is polled in place.
    unsafe_pinned!(resp: Abortable<Resp<F, R>>);
}
518
519impl<F, R> Future for RequestHandler<F, R>
520where
521    F: Future<Output = R>,
522{
523    type Output = ();
524
525    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
526        let _ = ready!(self.resp().poll(cx));
527        Poll::Ready(())
528    }
529}
530
/// State machine behind a [`RequestHandler`].
///
/// It runs the service's response future under a deadline, then delivers the resulting
/// response to the client handler's response queue.
#[derive(Debug)]
struct Resp<F, R> {
    /// Which delivery step the future is in.
    state: RespState,
    /// ID of the request being answered; copied into the outgoing `Response`.
    request_id: u64,
    /// Context of the request; sent to the queue alongside the response.
    ctx: context::Context,
    /// The request's deadline; used in timeout diagnostics.
    deadline: SystemTime,
    /// The service's response future, bounded by the time remaining until `deadline`.
    f: Timeout<F>,
    /// The finished response, held between completing `f` and sending it.
    response: Option<Response<R>>,
    /// Sending half of the client handler's response queue.
    response_tx: mpsc::Sender<(context::Context, Response<R>)>,
}
541
/// Steps of [`Resp`]'s state machine, in order.
#[derive(Debug)]
enum RespState {
    /// Waiting for the deadline-bounded response future to finish.
    PollResp,
    /// Waiting for room in the response queue to send the response.
    PollReady,
    /// Response sent; waiting for the queue flush to finish.
    PollFlush,
}
548
impl<F, R> Resp<F, R> {
    // Pinned projections: the response future and the sender are polled in place.
    unsafe_pinned!(f: Timeout<F>);
    unsafe_pinned!(response_tx: mpsc::Sender<(context::Context, Response<R>)>);
    // Unpinned projections: plain data fields that are only read and overwritten.
    unsafe_unpinned!(response: Option<Response<R>>);
    unsafe_unpinned!(state: RespState);
}
555
556impl<F, R> Future for Resp<F, R>
557where
558    F: Future<Output = R>,
559{
560    type Output = ();
561
562    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
563        loop {
564            match self.as_mut().state() {
565                RespState::PollResp => {
566                    let result = ready!(self.as_mut().f().poll(cx));
567                    *self.as_mut().response() = Some(Response {
568                        request_id: self.request_id,
569                        message: match result {
570                            Ok(message) => Ok(message),
571                            Err(timeout::Elapsed { .. }) => {
572                                debug!(
573                                    "[{}] Response did not complete before deadline of {}s.",
574                                    self.ctx.trace_id(),
575                                    format_rfc3339(self.deadline)
576                                );
577                                // No point in responding, since the client will have dropped the
578                                // request.
579                                Err(ServerError {
580                                    kind: io::ErrorKind::TimedOut,
581                                    detail: Some(format!(
582                                        "Response did not complete before deadline of {}s.",
583                                        format_rfc3339(self.deadline)
584                                    )),
585                                })
586                            }
587                        },
588                    });
589                    *self.as_mut().state() = RespState::PollReady;
590                }
591                RespState::PollReady => {
592                    let ready = ready!(self.as_mut().response_tx().poll_ready(cx));
593                    if ready.is_err() {
594                        return Poll::Ready(());
595                    }
596                    let resp = (self.ctx, self.as_mut().response().take().unwrap());
597                    if self.as_mut().response_tx().start_send(resp).is_err() {
598                        return Poll::Ready(());
599                    }
600                    *self.as_mut().state() = RespState::PollFlush;
601                }
602                RespState::PollFlush => {
603                    let ready = ready!(self.as_mut().response_tx().poll_flush(cx));
604                    if ready.is_err() {
605                        return Poll::Ready(());
606                    }
607                    return Poll::Ready(());
608                }
609            }
610        }
611    }
612}
613
614impl<C, S> Stream for ClientHandler<C, S>
615where
616    C: Channel,
617    S: Serve<C::Req, Resp = C::Resp>,
618{
619    type Item = io::Result<RequestHandler<S::Fut, C::Resp>>;
620
621    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
622        loop {
623            let read = self.as_mut().pump_read(cx)?;
624            let read_closed = if let Poll::Ready(None) = read {
625                true
626            } else {
627                false
628            };
629            match (read, self.as_mut().pump_write(cx, read_closed)?) {
630                (Poll::Ready(None), Poll::Ready(None)) => {
631                    return Poll::Ready(None);
632                }
633                (Poll::Ready(Some(request_handler)), _) => {
634                    return Poll::Ready(Some(Ok(request_handler)));
635                }
636                (_, Poll::Ready(Some(()))) => {}
637                _ => {
638                    return Poll::Pending;
639                }
640            }
641        }
642    }
643}
644
645// Send + 'static execution helper methods.
646
647impl<C, S> ClientHandler<C, S>
648where
649    C: Channel + 'static,
650    C::Req: Send + 'static,
651    C::Resp: Send + 'static,
652    S: Serve<C::Req, Resp = C::Resp> + Send + 'static,
653    S::Fut: Send + 'static,
654{
655    /// Runs the client handler until completion by spawning each
656    /// request handler onto the default executor.
657    #[cfg(feature = "tokio1")]
658    pub fn execute(self) -> impl Future<Output = ()> {
659        use log::info;
660
661        self.try_for_each(|request_handler| {
662            async {
663                tokio::spawn(request_handler);
664                Ok(())
665            }
666        })
667        .unwrap_or_else(|e| info!("ClientHandler errored out: {}", e))
668    }
669}
670
/// A future that drives the server by spawning channels and request handlers on the default
/// executor.
///
/// Each channel from `incoming` gets its own task. The future completes when `incoming` ends.
#[derive(Debug)]
#[cfg(feature = "tokio1")]
pub struct Running<St, Se> {
    /// Stream of incoming channels.
    incoming: St,
    /// The service; cloned once per channel.
    server: Se,
}
679
#[cfg(feature = "tokio1")]
impl<St, Se> Running<St, Se> {
    // Pinned projection: the channel stream is polled in place.
    unsafe_pinned!(incoming: St);
    // Unpinned projection: the service is only cloned, never polled.
    unsafe_unpinned!(server: Se);
}
685
#[cfg(feature = "tokio1")]
impl<St, C, Se> Future for Running<St, Se>
where
    St: Sized + Stream<Item = C>,
    C: Channel + Send + 'static,
    C::Req: Send + 'static,
    C::Resp: Send + 'static,
    Se: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
    Se::Fut: Send + 'static,
{
    type Output = ();

    /// Spawns a task serving each incoming channel with a clone of the service.
    /// Completes once the channel stream ends.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        use log::info;

        loop {
            match ready!(self.as_mut().incoming().poll_next(cx)) {
                Some(channel) => {
                    let server = self.as_mut().server().clone();
                    tokio::spawn(channel.respond_with(server).execute());
                }
                None => break,
            }
        }
        info!("Server shutting down.");
        Poll::Ready(())
    }
}