1use crate::{
10 cancellations::{cancellations, CanceledRequests, RequestCancellation},
11 context::{self, SpanExt},
12 trace,
13 util::TimeUntil,
14 ChannelError, ClientMessage, Request, RequestName, Response, ServerError, Transport,
15};
16use ::tokio::sync::mpsc;
17use futures::{
18 future::{AbortRegistration, Abortable},
19 prelude::*,
20 ready,
21 stream::Fuse,
22 task::*,
23};
24use in_flight_requests::{AlreadyExistsError, InFlightRequests};
25use pin_project::pin_project;
26use std::{
27 convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin, sync::Arc, time::SystemTime,
28};
29use tracing::{info_span, instrument::Instrument, Span};
30
31mod in_flight_requests;
32pub mod request_hook;
33#[cfg(test)]
34mod testing;
35
36pub mod limits;
38
39pub mod incoming;
41
/// Settings that control the behavior of a [`BaseChannel`] and the [`Requests`] stream built on
/// top of it.
#[derive(Clone, Debug)]
pub struct Config {
    /// Capacity of the in-process queue through which request handlers hand their responses back
    /// to the channel (see [`Channel::requests`]). When the queue is full, a handler trying to
    /// queue another response waits for room.
    pub pending_response_buffer: usize,
}

impl Default for Config {
    /// Returns a configuration whose `pending_response_buffer` is 100.
    fn default() -> Self {
        const DEFAULT_PENDING_RESPONSE_BUFFER: usize = 100;
        Self {
            pending_response_buffer: DEFAULT_PENDING_RESPONSE_BUFFER,
        }
    }
}
58
59impl Config {
60 pub fn channel<Req, Resp, T>(self, transport: T) -> BaseChannel<Req, Resp, T>
62 where
63 T: Transport<Response<Resp>, ClientMessage<Req>>,
64 {
65 BaseChannel::new(self, transport)
66 }
67}
68
69#[allow(async_fn_in_trait)]
71pub trait Serve {
72 type Req: RequestName;
74
75 type Resp;
77
78 async fn serve(self, ctx: context::Context, req: Self::Req) -> Result<Self::Resp, ServerError>;
80}
81
/// A [`Serve`] implementation backed by a closure; build one with [`serve`].
#[derive(Debug)]
pub struct ServeFn<Req, Resp, F> {
    f: F,
    // `fn(Req) -> Resp` records the request/response types without storing values of them.
    data: PhantomData<fn(Req) -> Resp>,
}

// Written by hand instead of derived so that only the closure `F` — not `Req` or `Resp` —
// has to be `Clone`.
impl<Req, Resp, F: Clone> Clone for ServeFn<Req, Resp, F> {
    fn clone(&self) -> Self {
        let f = self.f.clone();
        ServeFn {
            f,
            data: PhantomData,
        }
    }
}

// Likewise, a `ServeFn` is `Copy` whenever its closure is.
impl<Req, Resp, F: Copy> Copy for ServeFn<Req, Resp, F> {}
102
103pub fn serve<Req, Resp, Fut, F>(f: F) -> ServeFn<Req, Resp, F>
106where
107 F: FnOnce(context::Context, Req) -> Fut,
108 Fut: Future<Output = Result<Resp, ServerError>>,
109{
110 ServeFn {
111 f,
112 data: PhantomData,
113 }
114}
115
116impl<Req, Resp, Fut, F> Serve for ServeFn<Req, Resp, F>
117where
118 Req: RequestName,
119 F: FnOnce(context::Context, Req) -> Fut,
120 Fut: Future<Output = Result<Resp, ServerError>>,
121{
122 type Req = Req;
123 type Resp = Resp;
124
125 async fn serve(self, ctx: context::Context, req: Req) -> Result<Resp, ServerError> {
126 (self.f)(ctx, req).await
127 }
128}
129
/// A channel layered on a [`Transport`] that tracks the requests it has received but not yet
/// answered.
///
/// Incoming [`ClientMessage`]s are read from the transport: requests become
/// [`TrackedRequest`]s, and cancel messages cancel the matching in-flight request. Responses
/// written through its [`Sink`] implementation are forwarded to the transport only while their
/// request is still tracked.
#[pin_project]
pub struct BaseChannel<Req, Resp, T> {
    config: Config,
    /// The underlying transport, fused so that once it ends it keeps reporting end-of-stream.
    #[pin]
    transport: Fuse<T>,
    /// Yields the IDs of requests reported as canceled by a dropped, armed [`ResponseGuard`].
    #[pin]
    canceled_requests: CanceledRequests,
    /// The reporting side paired with `canceled_requests`; cloned into every [`ResponseGuard`].
    request_cancellation: RequestCancellation,
    /// Requests that were started and not yet answered, canceled, or expired.
    in_flight_requests: InFlightRequests,
    // `fn() -> Req` / `fn(Resp)` record the request and response types without storing them.
    ghost: PhantomData<(fn() -> Req, fn(Resp))>,
}
156
impl<Req, Resp, T> BaseChannel<Req, Resp, T>
where
    T: Transport<Response<Resp>, ClientMessage<Req>>,
{
    /// Creates a channel over `transport` configured by `config`.
    pub fn new(config: Config, transport: T) -> Self {
        // `request_cancellation` is handed (cloned) to each `ResponseGuard`; `canceled_requests`
        // is polled in `poll_next` to learn which requests were canceled.
        let (request_cancellation, canceled_requests) = cancellations();
        BaseChannel {
            config,
            transport: transport.fuse(),
            canceled_requests,
            request_cancellation,
            in_flight_requests: InFlightRequests::default(),
            ghost: PhantomData,
        }
    }

    /// Creates a channel over `transport` using [`Config::default`].
    pub fn with_defaults(transport: T) -> Self {
        Self::new(Config::default(), transport)
    }

    /// Returns a reference to the underlying transport.
    pub fn get_ref(&self) -> &T {
        self.transport.get_ref()
    }

    /// Returns a pinned mutable reference to the underlying transport.
    pub fn get_pin_ref(self: Pin<&mut Self>) -> Pin<&mut T> {
        self.project().transport.get_pin_mut()
    }

    // Projection helpers. Each reborrows the pinned channel via `as_mut()`, so the caller can
    // keep using it once the returned borrow ends.

    fn in_flight_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests {
        self.as_mut().project().in_flight_requests
    }

    fn canceled_requests_pin_mut<'a>(
        self: &'a mut Pin<&mut Self>,
    ) -> Pin<&'a mut CanceledRequests> {
        self.as_mut().project().canceled_requests
    }

    fn transport_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse<T>> {
        self.as_mut().project().transport
    }

    /// Starts tracking `request` and wraps it in a [`TrackedRequest`].
    ///
    /// Creates an `"RPC"` tracing span for the request and attaches the request context to it
    /// (via `SpanExt::set_context`). The request's trace context is then replaced with one
    /// derived from the span, or with a child of the existing trace context when that
    /// conversion fails. Finally the request ID and deadline are registered in the in-flight
    /// set. The returned [`ResponseGuard`] starts disarmed (`cancel: false`).
    ///
    /// # Errors
    ///
    /// Returns [`AlreadyExistsError`] if a request with the same ID is already in flight.
    fn start_request(
        mut self: Pin<&mut Self>,
        mut request: Request<Req>,
    ) -> Result<TrackedRequest<Req>, AlreadyExistsError> {
        let span = info_span!(
            "RPC",
            rpc.trace_id = %request.context.trace_id(),
            rpc.deadline = %humantime::format_rfc3339(SystemTime::now() + request.context.deadline.time_until()),
            otel.kind = "server",
            // Left empty here; `InFlightRequest::execute` records the request message's name.
            otel.name = tracing::field::Empty,
        );
        span.set_context(&request.context);
        request.context.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| {
            tracing::trace!(
                "OpenTelemetry subscriber not installed; making unsampled \
                 child context."
            );
            request.context.trace_context.new_child()
        });
        let entered = span.enter();
        tracing::info!("ReceiveRequest");
        let start = self.in_flight_requests_mut().start_request(
            request.id,
            request.context.deadline,
            span.clone(),
        );
        match start {
            Ok(abort_registration) => {
                // The enter guard borrows `span`; release it so `span` can be moved below.
                drop(entered);
                Ok(TrackedRequest {
                    abort_registration,
                    span,
                    response_guard: ResponseGuard {
                        request_id: request.id,
                        request_cancellation: self.request_cancellation.clone(),
                        cancel: false,
                    },
                    request,
                })
            }
            Err(AlreadyExistsError) => {
                tracing::trace!("DuplicateRequest");
                Err(AlreadyExistsError)
            }
        }
    }
}
250
251impl<Req, Resp, T> fmt::Debug for BaseChannel<Req, Resp, T> {
252 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
253 write!(f, "BaseChannel")
254 }
255}
256
/// A request that a [`BaseChannel`] has started tracking, as yielded by its [`Stream`]
/// implementation.
#[derive(Debug)]
pub struct TrackedRequest<Req> {
    /// The request received from the client.
    pub request: Request<Req>,
    /// Registration through which the channel can abort the request's handling (e.g. when the
    /// request is canceled or its deadline passes).
    pub abort_registration: AbortRegistration,
    /// The `"RPC"` tracing span created for this request.
    pub span: Span,
    /// Reports the request as canceled if dropped while armed; it is created disarmed.
    pub response_guard: ResponseGuard,
}
270
271pub trait Channel
297where
298 Self: Transport<Response<<Self as Channel>::Resp>, TrackedRequest<<Self as Channel>::Req>>,
299{
300 type Req;
302
303 type Resp;
305
306 type Transport;
308
309 fn config(&self) -> &Config;
311
312 fn in_flight_requests(&self) -> usize;
314
315 fn transport(&self) -> &Self::Transport;
317
318 fn max_concurrent_requests(
326 self,
327 limit: usize,
328 ) -> limits::requests_per_channel::MaxRequests<Self>
329 where
330 Self: Sized,
331 {
332 limits::requests_per_channel::MaxRequests::new(self, limit)
333 }
334
335 fn requests(self) -> Requests<Self>
370 where
371 Self: Sized,
372 {
373 let (responses_tx, responses) = mpsc::channel(self.config().pending_response_buffer);
374
375 Requests {
376 channel: self,
377 pending_responses: responses,
378 responses_tx,
379 }
380 }
381
382 fn execute<S>(self, serve: S) -> impl Stream<Item = impl Future<Output = ()>>
412 where
413 Self: Sized,
414 Self::Req: RequestName,
415 S: Serve<Req = Self::Req, Resp = Self::Resp> + Clone,
416 {
417 self.requests().execute(serve)
418 }
419}
420
impl<Req, Resp, T> Stream for BaseChannel<Req, Resp, T>
where
    T: Transport<Response<Resp>, ClientMessage<Req>>,
{
    type Item = Result<TrackedRequest<Req>, ChannelError<T::Error>>;

    /// Yields the next request read from the transport. While doing so, it processes reported
    /// cancellations, expired deadlines, and cancel messages from the client.
    ///
    /// Each loop iteration polls all three sources. It returns immediately when a new request
    /// is started. Otherwise it loops while any source made progress, ends the stream once
    /// every source is closed, and returns `Pending` otherwise. A transport read error is
    /// returned as [`ChannelError::Read`].
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        // Progress summary for one of the sources polled below.
        #[derive(Clone, Copy, Debug)]
        enum ReceiverStatus {
            // The source produced an item; polling again may make more progress.
            Ready,
            // The source has nothing available right now.
            Pending,
            // The source does not prevent the stream from ending.
            Closed,
        }

        impl ReceiverStatus {
            // Overall status: Ready if any source is Ready, Closed only if both are Closed,
            // otherwise Pending.
            fn combine(self, other: Self) -> Self {
                use ReceiverStatus::*;
                match (self, other) {
                    (Ready, _) | (_, Ready) => Ready,
                    (Closed, Closed) => Closed,
                    (Pending, Closed) | (Closed, Pending) | (Pending, Pending) => Pending,
                }
            }
        }

        use ReceiverStatus::*;

        loop {
            let cancellation_status = match self.canceled_requests_pin_mut().poll_recv(cx) {
                Poll::Ready(Some(request_id)) => {
                    if let Some(span) = self.in_flight_requests_mut().remove_request(request_id) {
                        let _entered = span.enter();
                        tracing::info!("ResponseCancelled");
                    }
                    Ready
                }
                // Cancellations only clean up internal state, so a pending cancellation should
                // not keep the channel open: report Closed and let the other sources decide.
                // (`Ready(None)` presumably cannot occur while `self` still holds
                // `request_cancellation`.)
                Poll::Pending | Poll::Ready(None) => Closed,
            };

            let expiration_status = match self.in_flight_requests_mut().poll_expired(cx) {
                // An expired request needs no response; just note that progress was made.
                Poll::Ready(Some(_)) => Ready,
                Poll::Ready(None) => Closed,
                Poll::Pending => Pending,
            };

            let request_status = match self
                .transport_pin_mut()
                .poll_next(cx)
                .map_err(|e| ChannelError::Read(Arc::new(e)))?
            {
                Poll::Ready(Some(message)) => match message {
                    ClientMessage::Request(request) => {
                        match self.as_mut().start_request(request) {
                            Ok(request) => return Poll::Ready(Some(Ok(request))),
                            Err(AlreadyExistsError) => {
                                // A duplicate of an in-flight request is ignored instead of
                                // closing the channel. Poll again rather than returning
                                // `Pending`: the transport just returned `Ready`, so no wakeup
                                // is guaranteed to be registered.
                                continue;
                            }
                        }
                    }
                    ClientMessage::Cancel {
                        trace_context,
                        request_id,
                    } => {
                        if !self.in_flight_requests_mut().cancel_request(request_id) {
                            tracing::trace!(
                                rpc.trace_id = %trace_context.trace_id,
                                "Received cancellation, but response handler is already complete.",
                            );
                        }
                        Ready
                    }
                },
                Poll::Ready(None) => Closed,
                Poll::Pending => Pending,
            };

            let status = cancellation_status
                .combine(expiration_status)
                .combine(request_status);

            tracing::trace!(
                "Cancellations: {cancellation_status:?}, \
                 Expired requests: {expiration_status:?}, \
                 Inbound: {request_status:?}, \
                 Overall: {status:?}",
            );
            match status {
                // Some source made progress; poll everything again.
                Ready => continue,
                Closed => return Poll::Ready(None),
                // The sources that returned `Pending` will wake this task.
                Pending => return Poll::Pending,
            }
        }
    }
}
527
528impl<Req, Resp, T> Sink<Response<Resp>> for BaseChannel<Req, Resp, T>
529where
530 T: Transport<Response<Resp>, ClientMessage<Req>>,
531 T::Error: Error,
532{
533 type Error = ChannelError<T::Error>;
534
535 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
536 self.project()
537 .transport
538 .poll_ready(cx)
539 .map_err(|e| ChannelError::Ready(Arc::new(e)))
540 }
541
542 fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> Result<(), Self::Error> {
543 if let Some(span) = self
544 .in_flight_requests_mut()
545 .remove_request(response.request_id)
546 {
547 let _entered = span.enter();
548 tracing::info!("SendResponse");
549 self.project()
550 .transport
551 .start_send(response)
552 .map_err(|e| ChannelError::Write(Arc::new(e)))
553 } else {
554 Ok(())
556 }
557 }
558
559 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
560 tracing::trace!("poll_flush");
561 self.project()
562 .transport
563 .poll_flush(cx)
564 .map_err(|e| ChannelError::Flush(Arc::new(e)))
565 }
566
567 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
568 self.project()
569 .transport
570 .poll_close(cx)
571 .map_err(|e| ChannelError::Close(Arc::new(e)))
572 }
573}
574
575impl<Req, Resp, T> AsRef<T> for BaseChannel<Req, Resp, T> {
576 fn as_ref(&self) -> &T {
577 self.transport.get_ref()
578 }
579}
580
581impl<Req, Resp, T> Channel for BaseChannel<Req, Resp, T>
582where
583 T: Transport<Response<Resp>, ClientMessage<Req>>,
584{
585 type Req = Req;
586 type Resp = Resp;
587 type Transport = T;
588
589 fn config(&self) -> &Config {
590 &self.config
591 }
592
593 fn in_flight_requests(&self) -> usize {
594 self.in_flight_requests.len()
595 }
596
597 fn transport(&self) -> &Self::Transport {
598 self.get_ref()
599 }
600}
601
/// A stream of [`InFlightRequest`]s read from a [`Channel`]. While it is polled, it also writes
/// the responses produced by request handlers back to that channel.
#[pin_project]
pub struct Requests<C>
where
    C: Channel,
{
    #[pin]
    channel: C,
    /// Responses produced by handlers that have not yet been written to `channel`.
    pending_responses: mpsc::Receiver<Response<C::Resp>>,
    /// Sending half of `pending_responses`; a clone goes into every [`InFlightRequest`].
    responses_tx: mpsc::Sender<Response<C::Resp>>,
}
616
impl<C> Requests<C>
where
    C: Channel,
{
    /// Returns the inner channel, over which requests are read and responses written.
    pub fn channel(&self) -> &C {
        &self.channel
    }

    /// Returns a pinned mutable reference to the inner channel.
    pub fn channel_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut C> {
        self.as_mut().project().channel
    }

    /// Returns the queue of responses waiting to be written to the channel.
    pub fn pending_responses_mut<'a>(
        self: &'a mut Pin<&mut Self>,
    ) -> &'a mut mpsc::Receiver<Response<C::Resp>> {
        self.as_mut().project().pending_responses
    }

    /// Polls the channel for its next request and converts it into an [`InFlightRequest`].
    ///
    /// The request's [`ResponseGuard`] is armed here. Dropping the returned request before its
    /// execution finishes therefore reports it as canceled.
    fn pump_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<InFlightRequest<C::Req, C::Resp>, C::Error>>> {
        self.channel_pin_mut().poll_next(cx).map_ok(
            |TrackedRequest {
                 request,
                 abort_registration,
                 span,
                 mut response_guard,
             }| {
                // From now on, dropping the request unfinished must cancel it.
                response_guard.cancel = true;
                {
                    let _entered = span.enter();
                    tracing::info!("BeginRequest");
                }
                InFlightRequest {
                    request,
                    abort_registration,
                    span,
                    response_guard,
                    response_tx: self.responses_tx.clone(),
                }
            },
        )
    }

    /// Writes at most one queued response to the channel, and flushes when nothing is left to
    /// write.
    ///
    /// Returns:
    /// - `Ready(Some(Ok(())))` after handing one response to the channel.
    /// - `Ready(None)` once the write half is done. That happens when the response queue has
    ///   ended and the channel is flushed, or when `read_half_closed` is set, the channel is
    ///   flushed, and no requests remain in flight.
    /// - `Pending` otherwise.
    ///
    /// Channel errors are propagated.
    fn pump_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        read_half_closed: bool,
    ) -> Poll<Option<Result<(), C::Error>>> {
        match self.as_mut().poll_next_response(cx)? {
            Poll::Ready(Some(response)) => {
                // `poll_next_response` only yields a response after the channel reported it is
                // ready, so `start_send` will not hit a full buffer.
                self.channel_pin_mut().start_send(response)?;
                Poll::Ready(Some(Ok(())))
            }
            Poll::Ready(None) => {
                // Finish flushing responses already written before reporting completion.
                ready!(self.channel_pin_mut().poll_flush(cx)?);
                Poll::Ready(None)
            }
            Poll::Pending => {
                // Nothing to write right now, so flush whatever the transport has buffered.
                ready!(self.channel_pin_mut().poll_flush(cx)?);

                // Everything written is flushed. If no more requests can arrive and none are
                // still being handled, no further responses can appear, so the write half is done.
                if read_half_closed && self.channel.in_flight_requests() == 0 {
                    Poll::Ready(None)
                } else {
                    Poll::Pending
                }
            }
        }
    }

    /// Returns the next queued response, but only once the channel can accept a write.
    ///
    /// Returns `Pending` until the channel is writable (flushing as needed) and a response is
    /// queued.
    fn poll_next_response(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Response<C::Resp>, C::Error>>> {
        ready!(self.ensure_writeable(cx)?);

        match ready!(self.pending_responses_mut().poll_recv(cx)) {
            Some(response) => Poll::Ready(Some(Ok(response))),
            None => {
                // The queue only reports its end once every sender is gone. Note that `self`
                // itself holds `responses_tx`.
                Poll::Ready(None)
            }
        }
    }

    /// Returns `Ready` once the channel can accept a response without a full buffer.
    ///
    /// While the channel is not ready, it is flushed to make room. Returns `Pending` if such a
    /// flush cannot complete yet.
    fn ensure_writeable<'a>(
        self: &'a mut Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<(), C::Error>>> {
        while self.channel_pin_mut().poll_ready(cx)?.is_pending() {
            ready!(self.channel_pin_mut().poll_flush(cx)?);
        }
        Poll::Ready(Some(Ok(())))
    }

    /// Returns a stream of futures, one per request. Each future handles its request with a
    /// clone of `serve` and queues the response.
    ///
    /// The stream ends when `self` ends, or at the first error from `self`; that error is
    /// logged at `warn` level. The futures are not driven here: the caller must poll or spawn
    /// them.
    pub fn execute<S>(self, serve: S) -> impl Stream<Item = impl Future<Output = ()>>
    where
        C::Req: RequestName,
        S: Serve<Req = C::Req, Resp = C::Resp> + Clone,
    {
        self.take_while(|result| {
            if let Err(e) = result {
                tracing::warn!("Requests stream errored out: {}", e);
            }
            futures::future::ready(result.is_ok())
        })
        // Every item reaching here is `Ok` (errors ended the stream above); this just unwraps it.
        .filter_map(|result| async move { result.ok() })
        .map(move |request| {
            let serve = serve.clone();
            request.execute(serve)
        })
    }
}
776
777impl<C> fmt::Debug for Requests<C>
778where
779 C: Channel,
780{
781 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
782 write!(fmt, "Requests")
783 }
784}
785
786#[derive(Debug)]
789pub struct ResponseGuard {
790 request_cancellation: RequestCancellation,
791 request_id: u64,
792 cancel: bool,
793}
794
795impl Drop for ResponseGuard {
796 fn drop(&mut self) {
797 if self.cancel {
798 self.request_cancellation.cancel(self.request_id);
799 }
800 }
801}
802
/// A request that has been read from a channel and is ready to be handled.
///
/// Its [`ResponseGuard`] is armed. Dropping the request before
/// [`InFlightRequest::execute`] completes therefore reports the request as canceled.
#[derive(Debug)]
pub struct InFlightRequest<Req, Res> {
    request: Request<Req>,
    /// Lets the channel abort `execute` (e.g. when the request is canceled or expires).
    abort_registration: AbortRegistration,
    response_guard: ResponseGuard,
    span: Span,
    /// Where `execute` queues the handler's response for the channel to write.
    response_tx: mpsc::Sender<Response<Res>>,
}
815
impl<Req, Res> InFlightRequest<Req, Res> {
    /// Returns a reference to the request.
    pub fn get(&self) -> &Request<Req> {
        &self.request
    }

    /// Handles the request with `serve` and queues the response to be written to the channel.
    ///
    /// The handler runs inside the request's tracing span, with `otel.name` set to the request
    /// message's name. The handler is abortable: if the channel aborts the request, no response
    /// is queued. When this future completes (either way), the guard is disarmed, so the
    /// request is not reported as canceled. If this future is instead dropped before
    /// completing, the still-armed guard reports the request as canceled.
    pub async fn execute<S>(self, serve: S)
    where
        Req: RequestName,
        S: Serve<Req = Req, Resp = Res>,
    {
        let Self {
            response_tx,
            mut response_guard,
            abort_registration,
            span,
            request:
                Request {
                    context,
                    message,
                    id: request_id,
                },
        } = self;
        span.record("otel.name", message.name());
        let _ = Abortable::new(
            async move {
                let message = serve.serve(context, message).await;
                tracing::info!("CompleteRequest");
                let response = Response {
                    request_id,
                    message,
                };
                // A send failure (the receiving side is gone) is deliberately ignored.
                let _ = response_tx.send(response).await;
                tracing::info!("BufferResponse");
            },
            abort_registration,
        )
        .instrument(span)
        .await;
        // The request was either answered (the channel untracks it when writing the response)
        // or aborted by the channel itself, so it must not be reported as canceled.
        response_guard.cancel = false;
    }
}
905
906fn print_err(e: &(dyn Error + 'static)) -> String {
907 anyhow::Chain::new(e)
908 .map(|e| e.to_string())
909 .collect::<Vec<_>>()
910 .join(": ")
911}
912
impl<C> Stream for Requests<C>
where
    C: Channel,
{
    type Item = Result<InFlightRequest<C::Req, C::Resp>, C::Error>;

    /// Alternates between reading requests from the channel and writing queued responses to it.
    ///
    /// Behavior:
    /// - Yields a request as soon as one is read.
    /// - Keeps looping while responses are being written.
    /// - Ends once both the read half and the write half report they are finished.
    /// - Returns `Pending` otherwise.
    ///
    /// Read and write errors are logged at `trace` level and returned.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            let read = self.as_mut().pump_read(cx).map_err(|e| {
                tracing::trace!("read: {}", print_err(&e));
                e
            })?;
            // Once no more requests can arrive, `pump_write` may report the write half as done.
            let read_closed = matches!(read, Poll::Ready(None));
            let write = self.as_mut().pump_write(cx, read_closed).map_err(|e| {
                tracing::trace!("write: {}", print_err(&e));
                e
            })?;
            match (read, write) {
                (Poll::Ready(None), Poll::Ready(None)) => {
                    tracing::trace!("read: Poll::Ready(None), write: Poll::Ready(None)");
                    return Poll::Ready(None);
                }
                (Poll::Ready(Some(request_handler)), _) => {
                    tracing::trace!("read: Poll::Ready(Some), write: _");
                    return Poll::Ready(Some(Ok(request_handler)));
                }
                (_, Poll::Ready(Some(()))) => {
                    // A response was written; loop to try to make more progress.
                    tracing::trace!("read: _, write: Poll::Ready(Some)");
                }
                (read @ Poll::Pending, write) | (read, write @ Poll::Pending) => {
                    tracing::trace!(
                        "read pending: {}, write pending: {}",
                        read.is_pending(),
                        write.is_pending()
                    );
                    return Poll::Pending;
                }
            }
        }
    }
}
954
#[cfg(test)]
mod tests {
    use super::{
        in_flight_requests::AlreadyExistsError,
        request_hook::{AfterRequest, BeforeRequest, RequestHook},
        serve, BaseChannel, Channel, Config, Requests, Serve,
    };
    use crate::{
        context, trace,
        transport::channel::{self, UnboundedChannel},
        ClientMessage, Request, Response, ServerError,
    };
    use assert_matches::assert_matches;
    use futures::{
        future::{pending, AbortRegistration, Abortable, Aborted},
        prelude::*,
        Future,
    };
    use futures_test::task::noop_context;
    use std::{
        io,
        pin::Pin,
        task::Poll,
        time::{Duration, Instant},
    };

    /// Returns a pinned `BaseChannel` over an unbounded in-memory transport (default config),
    /// together with the client end of that transport.
    fn test_channel<Req, Resp>() -> (
        Pin<Box<BaseChannel<Req, Resp, UnboundedChannel<ClientMessage<Req>, Response<Resp>>>>>,
        UnboundedChannel<Response<Resp>, ClientMessage<Req>>,
    ) {
        let (tx, rx) = crate::transport::channel::unbounded();
        (Box::pin(BaseChannel::new(Config::default(), rx)), tx)
    }

    /// Like `test_channel`, but wraps the channel in a `Requests` stream.
    fn test_requests<Req, Resp>() -> (
        Pin<
            Box<
                Requests<
                    BaseChannel<Req, Resp, UnboundedChannel<ClientMessage<Req>, Response<Resp>>>,
                >,
            >,
        >,
        UnboundedChannel<Response<Resp>, ClientMessage<Req>>,
    ) {
        let (tx, rx) = crate::transport::channel::unbounded();
        (
            Box::pin(BaseChannel::new(Config::default(), rx).requests()),
            tx,
        )
    }

    /// Returns a `Requests` stream over a bounded in-memory transport created with `capacity`,
    /// together with the client end of that transport.
    fn test_bounded_requests<Req, Resp>(
        capacity: usize,
    ) -> (
        Pin<
            Box<
                Requests<
                    BaseChannel<Req, Resp, channel::Channel<ClientMessage<Req>, Response<Resp>>>,
                >,
            >,
        >,
        channel::Channel<Response<Resp>, ClientMessage<Req>>,
    ) {
        let (tx, rx) = crate::transport::channel::bounded(capacity);
        let config = Config {
            // `+ 1` keeps the response queue's capacity nonzero even when `capacity` is 0.
            pending_response_buffer: capacity + 1,
        };
        (Box::pin(BaseChannel::new(config, rx).requests()), tx)
    }

    /// Builds a request message with ID 0, the current context, and payload `req`.
    fn fake_request<Req>(req: Req) -> ClientMessage<Req> {
        ClientMessage::Request(Request {
            context: context::current(),
            id: 0,
            message: req,
        })
    }

    /// Returns a future that never completes by itself, so it resolves only if aborted through
    /// `abort_registration`.
    fn test_abortable(
        abort_registration: AbortRegistration,
    ) -> impl Future<Output = Result<(), Aborted>> {
        Abortable::new(pending(), abort_registration)
    }

    /// `serve` wraps a closure; here the closure echoes its input.
    #[tokio::test]
    async fn test_serve() {
        let serve = serve(|_, i| async move { Ok(i) });
        assert_matches!(serve.serve(context::current(), 7).await, Ok(7));
    }

    /// A `before` hook that overwrites the deadline is observed by the handler.
    #[tokio::test]
    async fn serve_before_mutates_context() -> anyhow::Result<()> {
        struct SetDeadline(Instant);
        impl<Req> BeforeRequest<Req> for SetDeadline {
            async fn before(
                &mut self,
                ctx: &mut context::Context,
                _: &Req,
            ) -> Result<(), ServerError> {
                ctx.deadline = self.0;
                Ok(())
            }
        }

        let some_time = Instant::now() + Duration::from_secs(37);
        let some_other_time = Instant::now() + Duration::from_secs(83);

        let serve = serve(move |ctx: context::Context, i| async move {
            assert_eq!(ctx.deadline, some_time);
            Ok(i)
        });
        let deadline_hook = serve.before(SetDeadline(some_time));
        let mut ctx = context::current();
        ctx.deadline = some_other_time;
        deadline_hook.serve(ctx, 7).await?;
        Ok(())
    }

    /// Combined `before`/`after` hooks (here: latency logging) run around the handler.
    #[tokio::test]
    async fn serve_before_and_after() -> anyhow::Result<()> {
        let _ = tracing_subscriber::fmt::try_init();

        struct PrintLatency {
            start: Instant,
        }
        impl PrintLatency {
            fn new() -> Self {
                Self {
                    start: Instant::now(),
                }
            }
        }
        impl<Req> BeforeRequest<Req> for PrintLatency {
            async fn before(
                &mut self,
                _: &mut context::Context,
                _: &Req,
            ) -> Result<(), ServerError> {
                self.start = Instant::now();
                Ok(())
            }
        }
        impl<Resp> AfterRequest<Resp> for PrintLatency {
            async fn after(&mut self, _: &mut context::Context, _: &mut Result<Resp, ServerError>) {
                tracing::info!("Elapsed: {:?}", self.start.elapsed());
            }
        }

        let serve = serve(move |_: context::Context, i| async move { Ok(i) });
        serve
            .before_and_after(PrintLatency::new())
            .serve(context::current(), 7)
            .await?;
        Ok(())
    }

    /// An error from a `before` hook is returned without running the handler (which would
    /// panic).
    #[tokio::test]
    async fn serve_before_error_aborts_request() -> anyhow::Result<()> {
        let serve = serve(|_, _| async { panic!("Shouldn't get here") });
        let deadline_hook = serve.before(|_: &mut context::Context, _: &i32| async {
            Err(ServerError::new(io::ErrorKind::Other, "oops".into()))
        });
        let resp: Result<i32, _> = deadline_hook.serve(context::current(), 7).await;
        assert_matches!(resp, Err(_));
        Ok(())
    }

    /// Starting a second request with an ID that is already in flight fails.
    #[tokio::test]
    async fn base_channel_start_send_duplicate_request_returns_error() {
        let (mut channel, _tx) = test_channel::<(), ()>();

        channel
            .as_mut()
            .start_request(Request {
                id: 0,
                context: context::current(),
                message: (),
            })
            .unwrap();
        assert_matches!(
            channel.as_mut().start_request(Request {
                id: 0,
                context: context::current(),
                message: ()
            }),
            Err(AlreadyExistsError)
        );
    }

    /// Once time advances past their deadlines, polling aborts every expired request.
    #[tokio::test]
    async fn base_channel_poll_next_aborts_multiple_requests() {
        let (mut channel, _tx) = test_channel::<(), ()>();

        tokio::time::pause();
        let req0 = channel
            .as_mut()
            .start_request(Request {
                id: 0,
                context: context::current(),
                message: (),
            })
            .unwrap();
        let req1 = channel
            .as_mut()
            .start_request(Request {
                id: 1,
                context: context::current(),
                message: (),
            })
            .unwrap();
        tokio::time::advance(std::time::Duration::from_secs(1000)).await;

        assert_matches!(
            channel.as_mut().poll_next(&mut noop_context()),
            Poll::Pending
        );
        assert_matches!(test_abortable(req0.abort_registration).await, Err(Aborted));
        assert_matches!(test_abortable(req1.abort_registration).await, Err(Aborted));
    }

    /// A cancel message from the client aborts the matching in-flight request.
    #[tokio::test]
    async fn base_channel_poll_next_aborts_canceled_request() {
        let (mut channel, mut tx) = test_channel::<(), ()>();

        tokio::time::pause();
        let req = channel
            .as_mut()
            .start_request(Request {
                id: 0,
                context: context::current(),
                message: (),
            })
            .unwrap();

        tx.send(ClientMessage::Cancel {
            trace_context: trace::Context::default(),
            request_id: 0,
        })
        .await
        .unwrap();

        assert_matches!(
            channel.as_mut().poll_next(&mut noop_context()),
            Poll::Pending
        );

        assert_matches!(test_abortable(req.abort_registration).await, Err(Aborted));
    }

    /// A closed transport does not end the channel while a request is still in flight.
    #[tokio::test]
    async fn base_channel_with_closed_transport_and_in_flight_request_returns_pending() {
        let (mut channel, tx) = test_channel::<(), ()>();

        tokio::time::pause();
        let _abort_registration = channel
            .as_mut()
            .start_request(Request {
                id: 0,
                context: context::current(),
                message: (),
            })
            .unwrap();

        drop(tx);
        assert_matches!(
            channel.as_mut().poll_next(&mut noop_context()),
            Poll::Pending
        );
    }

    /// A closed transport with nothing in flight ends the channel.
    #[tokio::test]
    async fn base_channel_with_closed_transport_and_no_in_flight_requests_returns_closed() {
        let (mut channel, tx) = test_channel::<(), ()>();
        drop(tx);
        assert_matches!(
            channel.as_mut().poll_next(&mut noop_context()),
            Poll::Ready(None)
        );
    }

    /// A request sent by the client is yielded by `poll_next`.
    #[tokio::test]
    async fn base_channel_poll_next_yields_request() {
        let (mut channel, mut tx) = test_channel::<(), ()>();
        tx.send(fake_request(())).await.unwrap();

        assert_matches!(
            channel.as_mut().poll_next(&mut noop_context()),
            Poll::Ready(Some(Ok(_)))
        );
    }

    /// A single poll both aborts an expired request and yields a newly sent one.
    #[tokio::test]
    async fn base_channel_poll_next_aborts_request_and_yields_request() {
        let (mut channel, mut tx) = test_channel::<(), ()>();

        tokio::time::pause();
        let req = channel
            .as_mut()
            .start_request(Request {
                id: 0,
                context: context::current(),
                message: (),
            })
            .unwrap();
        tokio::time::advance(std::time::Duration::from_secs(1000)).await;

        tx.send(fake_request(())).await.unwrap();

        assert_matches!(
            channel.as_mut().poll_next(&mut noop_context()),
            Poll::Ready(Some(Ok(_)))
        );
        assert_matches!(test_abortable(req.abort_registration).await, Err(Aborted));
    }

    /// Sending the response for a request stops tracking that request.
    #[tokio::test]
    async fn base_channel_start_send_removes_in_flight_request() {
        let (mut channel, _tx) = test_channel::<(), ()>();

        channel
            .as_mut()
            .start_request(Request {
                id: 0,
                context: context::current(),
                message: (),
            })
            .unwrap();
        assert_eq!(channel.in_flight_requests(), 1);
        channel
            .as_mut()
            .start_send(Response {
                request_id: 0,
                message: Ok(()),
            })
            .unwrap();
        assert_eq!(channel.in_flight_requests(), 0);
    }

    /// Dropping an unexecuted `InFlightRequest` cancels it, so the channel stops tracking it.
    #[tokio::test]
    async fn in_flight_request_drop_cancels_request() {
        let (mut requests, mut tx) = test_requests::<(), ()>();
        tx.send(fake_request(())).await.unwrap();

        let request = match requests.as_mut().poll_next(&mut noop_context()) {
            Poll::Ready(Some(Ok(request))) => request,
            result => panic!("Unexpected result: {result:?}"),
        };
        drop(request);

        let poll = requests
            .as_mut()
            .channel_pin_mut()
            .poll_next(&mut noop_context());
        assert!(poll.is_pending());
        let in_flight_requests = requests.channel().in_flight_requests();
        assert_eq!(in_flight_requests, 0);
    }

    /// Executing a request to completion does not report it as canceled.
    #[tokio::test]
    async fn in_flight_requests_successful_execute_doesnt_cancel_request() {
        let (mut requests, mut tx) = test_requests::<(), ()>();
        tx.send(fake_request(())).await.unwrap();

        let request = match requests.as_mut().poll_next(&mut noop_context()) {
            Poll::Ready(Some(Ok(request))) => request,
            result => panic!("Unexpected result: {result:?}"),
        };
        request.execute(serve(|_, _| async { Ok(()) })).await;
        assert!(requests
            .as_mut()
            .channel_pin_mut()
            .canceled_requests
            .poll_recv(&mut noop_context())
            .is_pending());
    }

    /// With the zero-capacity transport already holding a response, the next queued response
    /// is not yielded.
    #[tokio::test]
    async fn requests_poll_next_response_returns_pending_when_buffer_full() {
        let (mut requests, _tx) = test_bounded_requests::<(), ()>(0);

        // Write the response for request 0 to the transport.
        requests
            .as_mut()
            .channel_pin_mut()
            .start_request(Request {
                id: 0,
                context: context::current(),
                message: (),
            })
            .unwrap();
        requests
            .as_mut()
            .channel_pin_mut()
            .start_send(Response {
                request_id: 0,
                message: Ok(()),
            })
            .unwrap();

        // Queue a response for request 1, which is then started.
        requests
            .as_mut()
            .project()
            .responses_tx
            .send(Response {
                request_id: 1,
                message: Ok(()),
            })
            .await
            .unwrap();

        requests
            .as_mut()
            .channel_pin_mut()
            .start_request(Request {
                id: 1,
                context: context::current(),
                message: (),
            })
            .unwrap();

        assert_matches!(
            requests.as_mut().poll_next_response(&mut noop_context()),
            Poll::Pending
        );
    }

    /// With the zero-capacity transport already holding a response, `pump_write` returns
    /// `Pending` and leaves the queued response in place.
    #[tokio::test]
    async fn requests_pump_write_returns_pending_when_buffer_full() {
        let (mut requests, _tx) = test_bounded_requests::<(), ()>(0);

        // Write the response for request 0 to the transport.
        requests
            .as_mut()
            .channel_pin_mut()
            .start_request(Request {
                id: 0,
                context: context::current(),
                message: (),
            })
            .unwrap();
        requests
            .as_mut()
            .channel_pin_mut()
            .start_send(Response {
                request_id: 0,
                message: Ok(()),
            })
            .unwrap();

        // Start request 1 and queue its response.
        requests
            .as_mut()
            .channel_pin_mut()
            .start_request(Request {
                id: 1,
                context: context::current(),
                message: (),
            })
            .unwrap();
        requests
            .as_mut()
            .project()
            .responses_tx
            .send(Response {
                request_id: 1,
                message: Ok(()),
            })
            .await
            .unwrap();

        assert_matches!(
            requests.as_mut().pump_write(&mut noop_context(), true),
            Poll::Pending
        );
        // The response for request 1 is still queued.
        assert_matches!(
            requests.as_mut().pending_responses_mut().recv().await,
            Some(_)
        );
    }

    /// `pump_read` yields a request sent by the client and leaves it tracked as in flight.
    #[tokio::test]
    async fn requests_pump_read() {
        let (mut requests, mut tx) = test_requests::<(), ()>();

        tx.send(fake_request(())).await.unwrap();

        assert_matches!(
            requests.as_mut().pump_read(&mut noop_context()),
            Poll::Ready(Some(Ok(_)))
        );
        assert_eq!(requests.channel.in_flight_requests(), 1);
    }
}