1use crate::{
10 context, util::Compact, util::TimeUntil, ClientMessage, PollIo, Request, Response, ServerError,
11 Transport,
12};
13use fnv::FnvHashMap;
14use futures::{
15 channel::mpsc,
16 future::{AbortHandle, AbortRegistration, Abortable},
17 prelude::*,
18 ready,
19 stream::Fuse,
20 task::{Context, Poll},
21};
22use humantime::format_rfc3339;
23use log::{debug, trace};
24use pin_utils::{unsafe_pinned, unsafe_unpinned};
25use std::{fmt, hash::Hash, io, marker::PhantomData, pin::Pin, time::SystemTime};
26use tokio_timer::{timeout, Timeout};
27
28mod filter;
29#[cfg(test)]
30mod testing;
31mod throttle;
32
33pub use self::{
34 filter::ChannelFilter,
35 throttle::{Throttler, ThrottlerStream},
36};
37
/// Holds the [`Config`] used to build a [`BaseChannel`] for every incoming transport.
///
/// `Req` and `Resp` are the request and response types being served; they exist only at the
/// type level (via `PhantomData`), no values of them are stored.
#[derive(Debug)]
pub struct Server<Req, Resp> {
    /// Settings cloned into each channel created by `Server::incoming`.
    config: Config,
    /// Ties the request/response types to the server without storing them.
    ghost: PhantomData<(Req, Resp)>,
}
44
45impl<Req, Resp> Default for Server<Req, Resp> {
46 fn default() -> Self {
47 new(Config::default())
48 }
49}
50
/// Per-channel server settings.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct Config {
    /// Capacity of the `mpsc` channel (created in `Channel::respond_with`) that queues finished
    /// responses until they are written to the transport. Defaults to 100.
    pub pending_response_buffer: usize,
}
60
61impl Default for Config {
62 fn default() -> Self {
63 Config {
64 pending_response_buffer: 100,
65 }
66 }
67}
68
69impl Config {
70 pub fn channel<Req, Resp, T>(self, transport: T) -> BaseChannel<Req, Resp, T>
72 where
73 T: Transport<Response<Resp>, ClientMessage<Req>>,
74 {
75 BaseChannel::new(self, transport)
76 }
77}
78
79pub fn new<Req, Resp>(config: Config) -> Server<Req, Resp> {
81 Server {
82 config,
83 ghost: PhantomData,
84 }
85}
86
87impl<Req, Resp> Server<Req, Resp> {
88 pub fn config(&self) -> &Config {
90 &self.config
91 }
92
93 pub fn incoming<S, T>(self, listener: S) -> impl Stream<Item = BaseChannel<Req, Resp, T>>
95 where
96 S: Stream<Item = T>,
97 T: Transport<Response<Resp>, ClientMessage<Req>>,
98 {
99 listener.map(move |t| BaseChannel::new(self.config.clone(), t))
100 }
101}
102
/// A request handler: turns a request and its context into a future of the response.
///
/// Implementors must be `Clone`; the server clones the handler for each incoming request and
/// calls [`Serve::serve`] on the clone.
pub trait Serve<Req>: Sized + Clone {
    /// Response type produced for each request.
    type Resp;

    /// Future returned by [`Serve::serve`]; resolves to the response.
    type Fut: Future<Output = Self::Resp>;

    /// Handles a single request, consuming this handler value.
    fn serve(self, ctx: context::Context, req: Req) -> Self::Fut;
}
114
115impl<Req, Resp, Fut, F> Serve<Req> for F
116where
117 F: FnOnce(context::Context, Req) -> Fut + Clone,
118 Fut: Future<Output = Resp>,
119{
120 type Resp = Resp;
121 type Fut = Fut;
122
123 fn serve(self, ctx: context::Context, req: Req) -> Self::Fut {
124 self(ctx, req)
125 }
126}
127
128pub trait Handler<C>
130where
131 Self: Sized + Stream<Item = C>,
132 C: Channel,
133{
134 fn max_channels_per_key<K, KF>(self, n: u32, keymaker: KF) -> filter::ChannelFilter<Self, K, KF>
136 where
137 K: fmt::Display + Eq + Hash + Clone + Unpin,
138 KF: Fn(&C) -> K,
139 {
140 ChannelFilter::new(self, n, keymaker)
141 }
142
143 fn max_concurrent_requests_per_channel(self, n: usize) -> ThrottlerStream<Self> {
145 ThrottlerStream::new(self, n)
146 }
147
148 #[cfg(feature = "tokio1")]
150 fn respond_with<S>(self, server: S) -> Running<Self, S>
151 where
152 S: Serve<C::Req, Resp = C::Resp>,
153 {
154 Running {
155 incoming: self,
156 server,
157 }
158 }
159}
160
161impl<S, C> Handler<C> for S
162where
163 S: Sized + Stream<Item = C>,
164 C: Channel,
165{
166}
167
/// A channel over a single client transport. It yields incoming requests, applies cancel
/// messages, and tracks which requests are still awaiting a response.
#[derive(Debug)]
pub struct BaseChannel<Req, Resp, T> {
    /// Server settings for this channel.
    config: Config,
    /// The underlying transport, wrapped in `Fuse` so polling after it ends keeps yielding
    /// `None`.
    transport: Fuse<T>,
    /// Abort handles of started requests whose responses have not been sent yet, keyed by
    /// request ID.
    in_flight_requests: FnvHashMap<u64, AbortHandle>,
    /// Ties the request/response types to the channel without storing them.
    ghost: PhantomData<(Req, Resp)>,
}
179
impl<Req, Resp, T> BaseChannel<Req, Resp, T> {
    // Generates `in_flight_requests(self: Pin<&mut Self>) -> &mut FnvHashMap<..>`, treating the
    // map as not structurally pinned so it can be mutated freely through a pinned channel.
    unsafe_unpinned!(in_flight_requests: FnvHashMap<u64, AbortHandle>);
}
183
184impl<Req, Resp, T> BaseChannel<Req, Resp, T>
185where
186 T: Transport<Response<Resp>, ClientMessage<Req>>,
187{
188 pub fn new(config: Config, transport: T) -> Self {
190 BaseChannel {
191 config,
192 transport: transport.fuse(),
193 in_flight_requests: FnvHashMap::default(),
194 ghost: PhantomData,
195 }
196 }
197
198 pub fn with_defaults(transport: T) -> Self {
200 Self::new(Config::default(), transport)
201 }
202
203 pub fn get_ref(&self) -> &T {
205 self.transport.get_ref()
206 }
207
208 pub fn transport<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut T> {
210 unsafe { self.map_unchecked_mut(|me| me.transport.get_mut()) }
211 }
212
213 fn cancel_request(mut self: Pin<&mut Self>, trace_context: &trace::Context, request_id: u64) {
214 if let Some(cancel_handle) = self.as_mut().in_flight_requests().remove(&request_id) {
217 self.as_mut().in_flight_requests().compact(0.1);
218
219 cancel_handle.abort();
220 let remaining = self.as_mut().in_flight_requests().len();
221 trace!(
222 "[{}] Request canceled. In-flight requests = {}",
223 trace_context.trace_id,
224 remaining,
225 );
226 } else {
227 trace!(
228 "[{}] Received cancellation, but response handler \
229 is already complete.",
230 trace_context.trace_id,
231 );
232 }
233 }
234}
235
236pub trait Channel
244where
245 Self: Transport<Response<<Self as Channel>::Resp>, Request<<Self as Channel>::Req>>,
246{
247 type Req;
249
250 type Resp;
252
253 fn config(&self) -> &Config;
255
256 fn in_flight_requests(self: Pin<&mut Self>) -> usize;
258
259 fn max_concurrent_requests(self, n: usize) -> Throttler<Self>
261 where
262 Self: Sized,
263 {
264 Throttler::new(self, n)
265 }
266
267 fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration;
271
272 fn respond_with<S>(self, server: S) -> ClientHandler<Self, S>
275 where
276 S: Serve<Self::Req, Resp = Self::Resp>,
277 Self: Sized,
278 {
279 let (responses_tx, responses) = mpsc::channel(self.config().pending_response_buffer);
280 let responses = responses.fuse();
281
282 ClientHandler {
283 channel: self,
284 server,
285 pending_responses: responses,
286 responses_tx,
287 }
288 }
289}
290
291impl<Req, Resp, T> Stream for BaseChannel<Req, Resp, T>
292where
293 T: Transport<Response<Resp>, ClientMessage<Req>>,
294{
295 type Item = io::Result<Request<Req>>;
296
297 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
298 loop {
299 match ready!(self.as_mut().transport().poll_next(cx)?) {
300 Some(message) => match message {
301 ClientMessage::Request(request) => {
302 return Poll::Ready(Some(Ok(request)));
303 }
304 ClientMessage::Cancel {
305 trace_context,
306 request_id,
307 } => {
308 self.as_mut().cancel_request(&trace_context, request_id);
309 }
310 },
311 None => return Poll::Ready(None),
312 }
313 }
314 }
315}
316
317impl<Req, Resp, T> Sink<Response<Resp>> for BaseChannel<Req, Resp, T>
318where
319 T: Transport<Response<Resp>, ClientMessage<Req>>,
320{
321 type Error = io::Error;
322
323 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
324 self.transport().poll_ready(cx)
325 }
326
327 fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> Result<(), Self::Error> {
328 if self
329 .as_mut()
330 .in_flight_requests()
331 .remove(&response.request_id)
332 .is_some()
333 {
334 self.as_mut().in_flight_requests().compact(0.1);
335 }
336
337 self.transport().start_send(response)
338 }
339
340 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
341 self.transport().poll_flush(cx)
342 }
343
344 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
345 self.transport().poll_close(cx)
346 }
347}
348
349impl<Req, Resp, T> AsRef<T> for BaseChannel<Req, Resp, T> {
350 fn as_ref(&self) -> &T {
351 self.transport.get_ref()
352 }
353}
354
355impl<Req, Resp, T> Channel for BaseChannel<Req, Resp, T>
356where
357 T: Transport<Response<Resp>, ClientMessage<Req>>,
358{
359 type Req = Req;
360 type Resp = Resp;
361
362 fn config(&self) -> &Config {
363 &self.config
364 }
365
366 fn in_flight_requests(mut self: Pin<&mut Self>) -> usize {
367 self.as_mut().in_flight_requests().len()
368 }
369
370 fn start_request(self: Pin<&mut Self>, request_id: u64) -> AbortRegistration {
371 let (abort_handle, abort_registration) = AbortHandle::new_pair();
372 assert!(self
373 .in_flight_requests()
374 .insert(request_id, abort_handle)
375 .is_none());
376 abort_registration
377 }
378}
379
/// Drives one channel: reads requests from it, turns each into a [`RequestHandler`], and writes
/// back the responses those handlers send.
#[derive(Debug)]
pub struct ClientHandler<C, S>
where
    C: Channel,
{
    /// Channel that requests are read from and responses are written to.
    channel: C,
    /// Receives `(context, response)` pairs sent by request handlers, waiting to be written.
    pending_responses: Fuse<mpsc::Receiver<(context::Context, Response<C::Resp>)>>,
    /// Sender half cloned into each request handler. While this sender is alive, the
    /// receiver above is not expected to report end-of-stream.
    responses_tx: mpsc::Sender<(context::Context, Response<C::Resp>)>,
    /// Request handler, cloned for each incoming request.
    server: S,
}
394
impl<C, S> ClientHandler<C, S>
where
    C: Channel,
{
    // Pin projections generated by `pin_utils`. `channel`, `pending_responses` and
    // `responses_tx` are polled, so they are exposed as `Pin<&mut _>`. `server` is only
    // cloned, so it is exposed as a plain `&mut S`.
    // NOTE(review): soundness relies on `ClientHandler` not moving the pinned fields (e.g. in
    // a Drop impl) and not implementing `Unpin` unconditionally — confirm.
    unsafe_pinned!(channel: C);
    unsafe_pinned!(pending_responses: Fuse<mpsc::Receiver<(context::Context, Response<C::Resp>)>>);
    unsafe_pinned!(responses_tx: mpsc::Sender<(context::Context, Response<C::Resp>)>);
    unsafe_unpinned!(server: S);
}
406
impl<C, S> ClientHandler<C, S>
where
    C: Channel,
    S: Serve<C::Req, Resp = C::Resp>,
{
    /// Reads the next request from the channel and builds its [`RequestHandler`].
    /// `Ready(None)` means the channel's read half has ended.
    fn pump_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> PollIo<RequestHandler<S::Fut, C::Resp>> {
        match ready!(self.as_mut().channel().poll_next(cx)?) {
            Some(request) => Poll::Ready(Some(Ok(self.handle_request(request)))),
            None => Poll::Ready(None),
        }
    }

    /// Writes at most one pending response to the channel.
    ///
    /// Returns `Ready(Some(Ok(())))` when a response was staged with `start_send`. Returns
    /// `Ready(None)` once nothing more will be written: either the response queue ended, or
    /// the read half is closed (`read_half_closed`) with no requests left in flight. In both
    /// cases the channel is flushed first.
    fn pump_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        read_half_closed: bool,
    ) -> PollIo<()> {
        match self.as_mut().poll_next_response(cx)? {
            Poll::Ready(Some((ctx, response))) => {
                trace!(
                    "[{}] Staging response. In-flight requests = {}.",
                    ctx.trace_id(),
                    self.as_mut().channel().in_flight_requests(),
                );
                self.as_mut().channel().start_send(response)?;
                Poll::Ready(Some(Ok(())))
            }
            Poll::Ready(None) => {
                // No more responses can arrive; make sure everything staged is written out.
                ready!(self.as_mut().channel().poll_flush(cx)?);
                Poll::Ready(None)
            }
            Poll::Pending => {
                // Nothing to send right now: use the idle time to flush what was staged.
                ready!(self.as_mut().channel().poll_flush(cx)?);

                // No new requests can arrive and none are outstanding, so the write half is done.
                if read_half_closed && self.as_mut().channel().in_flight_requests() == 0 {
                    Poll::Ready(None)
                } else {
                    Poll::Pending
                }
            }
        }
    }

    /// Waits until the channel can accept another response, then takes the next finished
    /// response from the queue.
    fn poll_next_response(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> PollIo<(context::Context, Response<C::Resp>)> {
        // Flush while the channel is not ready for another item.
        // NOTE(review): if `poll_ready` stays Pending while `poll_flush` keeps returning Ready,
        // this loop spins without yielding — confirm transports always make progress on flush.
        while let Poll::Pending = self.as_mut().channel().poll_ready(cx)? {
            ready!(self.as_mut().channel().poll_flush(cx)?);
        }

        match ready!(self.as_mut().pending_responses().poll_next(cx)) {
            Some((ctx, response)) => Poll::Ready(Some(Ok((ctx, response)))),
            None => {
                Poll::Ready(None)
            }
        }
    }

    /// Builds the future that serves one request.
    ///
    /// The handler is cloned and invoked with the request, bounded by a timeout equal to the
    /// time remaining until the request's deadline. The result is wired to send its response
    /// into `responses_tx`. The request is then registered with the channel so that a cancel
    /// message can abort it.
    fn handle_request(
        mut self: Pin<&mut Self>,
        request: Request<C::Req>,
    ) -> RequestHandler<S::Fut, C::Resp> {
        let request_id = request.id;
        let deadline = request.context.deadline;
        let timeout = deadline.time_until();
        trace!(
            "[{}] Received request with deadline {} (timeout {:?}).",
            request.context.trace_id(),
            format_rfc3339(deadline),
            timeout,
        );
        let ctx = request.context;
        let request = request.message;

        let response = self.as_mut().server().clone().serve(ctx, request);
        let response = Resp {
            state: RespState::PollResp,
            request_id,
            ctx,
            deadline,
            f: Timeout::new(response, timeout),
            response: None,
            response_tx: self.as_mut().responses_tx().clone(),
        };
        let abort_registration = self.as_mut().channel().start_request(request_id);
        RequestHandler {
            resp: Abortable::new(response, abort_registration),
        }
    }
}
508
/// Future that runs one request's handler and sends the response back to its
/// `ClientHandler`. It resolves to `()` whether it finished normally or was aborted.
#[derive(Debug)]
pub struct RequestHandler<F, R> {
    /// The response future, wrapped so that a cancel message for this request aborts it.
    resp: Abortable<Resp<F, R>>,
}
514
impl<F, R> RequestHandler<F, R> {
    // Pinned projection to the wrapped future so it can be polled in place.
    unsafe_pinned!(resp: Abortable<Resp<F, R>>);
}
518
519impl<F, R> Future for RequestHandler<F, R>
520where
521 F: Future<Output = R>,
522{
523 type Output = ();
524
525 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
526 let _ = ready!(self.resp().poll(cx));
527 Poll::Ready(())
528 }
529}
530
/// Future that waits for one request's handler and forwards its response to the owning
/// `ClientHandler` through `response_tx`.
#[derive(Debug)]
struct Resp<F, R> {
    /// Current phase of `Resp::poll`.
    state: RespState,
    /// ID of the request this response answers.
    request_id: u64,
    /// Context of the request; sent along with the response and used in log messages.
    ctx: context::Context,
    /// Absolute deadline taken from the request context; reported when the timeout fires.
    deadline: SystemTime,
    /// The handler's future, bounded by a timeout of the time remaining until `deadline`.
    f: Timeout<F>,
    /// The finished response, held between producing it and sending it.
    response: Option<Response<R>>,
    /// Sender into the `ClientHandler`'s queue of pending responses.
    response_tx: mpsc::Sender<(context::Context, Response<R>)>,
}
541
/// Phase of a `Resp` future.
#[derive(Debug)]
enum RespState {
    /// Waiting for the handler future (wrapped in a timeout) to produce a result.
    PollResp,
    /// Waiting for capacity on the response sender, then sending the response.
    PollReady,
    /// Waiting for the sent response to be flushed.
    PollFlush,
}
548
impl<F, R> Resp<F, R> {
    // `f` and `response_tx` are polled, so they get pinned projections. `response` and
    // `state` are plain data, so they are exposed as `&mut`.
    unsafe_pinned!(f: Timeout<F>);
    unsafe_pinned!(response_tx: mpsc::Sender<(context::Context, Response<R>)>);
    unsafe_unpinned!(response: Option<Response<R>>);
    unsafe_unpinned!(state: RespState);
}
555
556impl<F, R> Future for Resp<F, R>
557where
558 F: Future<Output = R>,
559{
560 type Output = ();
561
562 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
563 loop {
564 match self.as_mut().state() {
565 RespState::PollResp => {
566 let result = ready!(self.as_mut().f().poll(cx));
567 *self.as_mut().response() = Some(Response {
568 request_id: self.request_id,
569 message: match result {
570 Ok(message) => Ok(message),
571 Err(timeout::Elapsed { .. }) => {
572 debug!(
573 "[{}] Response did not complete before deadline of {}s.",
574 self.ctx.trace_id(),
575 format_rfc3339(self.deadline)
576 );
577 Err(ServerError {
580 kind: io::ErrorKind::TimedOut,
581 detail: Some(format!(
582 "Response did not complete before deadline of {}s.",
583 format_rfc3339(self.deadline)
584 )),
585 })
586 }
587 },
588 });
589 *self.as_mut().state() = RespState::PollReady;
590 }
591 RespState::PollReady => {
592 let ready = ready!(self.as_mut().response_tx().poll_ready(cx));
593 if ready.is_err() {
594 return Poll::Ready(());
595 }
596 let resp = (self.ctx, self.as_mut().response().take().unwrap());
597 if self.as_mut().response_tx().start_send(resp).is_err() {
598 return Poll::Ready(());
599 }
600 *self.as_mut().state() = RespState::PollFlush;
601 }
602 RespState::PollFlush => {
603 let ready = ready!(self.as_mut().response_tx().poll_flush(cx));
604 if ready.is_err() {
605 return Poll::Ready(());
606 }
607 return Poll::Ready(());
608 }
609 }
610 }
611 }
612}
613
614impl<C, S> Stream for ClientHandler<C, S>
615where
616 C: Channel,
617 S: Serve<C::Req, Resp = C::Resp>,
618{
619 type Item = io::Result<RequestHandler<S::Fut, C::Resp>>;
620
621 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
622 loop {
623 let read = self.as_mut().pump_read(cx)?;
624 let read_closed = if let Poll::Ready(None) = read {
625 true
626 } else {
627 false
628 };
629 match (read, self.as_mut().pump_write(cx, read_closed)?) {
630 (Poll::Ready(None), Poll::Ready(None)) => {
631 return Poll::Ready(None);
632 }
633 (Poll::Ready(Some(request_handler)), _) => {
634 return Poll::Ready(Some(Ok(request_handler)));
635 }
636 (_, Poll::Ready(Some(()))) => {}
637 _ => {
638 return Poll::Pending;
639 }
640 }
641 }
642 }
643}
644
645impl<C, S> ClientHandler<C, S>
648where
649 C: Channel + 'static,
650 C::Req: Send + 'static,
651 C::Resp: Send + 'static,
652 S: Serve<C::Req, Resp = C::Resp> + Send + 'static,
653 S::Fut: Send + 'static,
654{
655 #[cfg(feature = "tokio1")]
658 pub fn execute(self) -> impl Future<Output = ()> {
659 use log::info;
660
661 self.try_for_each(|request_handler| {
662 async {
663 tokio::spawn(request_handler);
664 Ok(())
665 }
666 })
667 .unwrap_or_else(|e| info!("ClientHandler errored out: {}", e))
668 }
669}
670
/// Future that accepts channels from `incoming` and spawns a task serving each one with a
/// clone of `server`. It completes when `incoming` ends.
#[derive(Debug)]
#[cfg(feature = "tokio1")]
pub struct Running<St, Se> {
    /// Stream of incoming channels.
    incoming: St,
    /// Request handler, cloned for each channel.
    server: Se,
}
679
#[cfg(feature = "tokio1")]
impl<St, Se> Running<St, Se> {
    // `incoming` is polled, so it gets a pinned projection. `server` is only cloned, so it is
    // exposed as a plain `&mut Se`.
    unsafe_pinned!(incoming: St);
    unsafe_unpinned!(server: Se);
}
685
#[cfg(feature = "tokio1")]
impl<St, C, Se> Future for Running<St, Se>
where
    St: Sized + Stream<Item = C>,
    C: Channel + Send + 'static,
    C::Req: Send + 'static,
    C::Resp: Send + 'static,
    Se: Serve<C::Req, Resp = C::Resp> + Send + 'static + Clone,
    Se::Fut: Send + 'static,
{
    type Output = ();

    /// Spawns a serving task for every channel yielded by `incoming`. Completes (after logging)
    /// once `incoming` is exhausted.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        use log::info;

        loop {
            match ready!(self.as_mut().incoming().poll_next(cx)) {
                Some(channel) => {
                    let server = self.as_mut().server().clone();
                    tokio::spawn(channel.respond_with(server).execute());
                }
                None => break,
            }
        }

        info!("Server shutting down.");
        Poll::Ready(())
    }
}
711}