1use bytes::Bytes;
6use http::{Extensions, Method, Uri, Version};
7use http_body::Body;
8use http_body_util::BodyExt;
9use pin_project::pin_project;
10use std::{
11 fmt,
12 future::Future,
13 mem,
14 pin::Pin,
15 task::{Context, Poll},
16};
17use tonic::metadata::MetadataMap;
18use tonic::{body::BoxBody, Request, Status};
19use tower_layer::Layer;
20use tower_service::Service;
21
/// Boxed, thread-safe error type used for the inner service's errors and for response-body
/// errors before they are converted into a [`Status`].
pub type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
23
24pub trait AsyncInterceptor {
54 type Future: Future<Output = Result<Request<()>, Status>>;
56 fn call(&mut self, request: Request<()>) -> Self::Future;
58}
59
60impl<F, U> AsyncInterceptor for F
61where
62 F: FnMut(Request<()>) -> U,
63 U: Future<Output = Result<Request<()>, Status>>,
64{
65 type Future = U;
66
67 fn call(&mut self, request: Request<()>) -> Self::Future {
68 self(request)
69 }
70}
71
72pub fn async_interceptor<F>(f: F) -> AsyncInterceptorLayer<F>
76where
77 F: AsyncInterceptor,
78{
79 AsyncInterceptorLayer { f }
80}
81
82#[derive(Debug, Clone, Copy)]
87pub struct AsyncInterceptorLayer<F> {
88 f: F,
89}
90
91impl<S, F> Layer<S> for AsyncInterceptorLayer<F>
92where
93 S: Clone,
94 F: AsyncInterceptor + Clone,
95{
96 type Service = AsyncInterceptedService<S, F>;
97
98 fn layer(&self, service: S) -> Self::Service {
99 AsyncInterceptedService::new(service, self.f.clone())
100 }
101}
102
103fn boxed<B>(body: B) -> BoxBody
105where
106 B: Body<Data = Bytes> + Send + 'static,
107 B::Error: Into<Error>,
108{
109 body.map_err(|e| Status::from_error(e.into()))
110 .boxed_unsync()
111}
112
113#[derive(Debug)]
115struct DecomposedRequest<ReqBody> {
116 uri: Uri,
117 method: Method,
118 http_version: Version,
119 msg: ReqBody,
120}
121
122fn request_into_parts<Msg>(mut req: Request<Msg>) -> (MetadataMap, Extensions, Msg) {
124 let metadata = mem::take(req.metadata_mut());
126 let extensions = mem::take(req.extensions_mut());
127 (metadata, extensions, req.into_inner())
128}
129
130fn request_from_parts<Msg>(
132 msg: Msg,
133 metadata: MetadataMap,
134 extensions: Extensions,
135) -> Request<Msg> {
136 let mut req = Request::new(msg);
137 *req.metadata_mut() = metadata;
138 *req.extensions_mut() = extensions;
139 req
140}
141
142fn request_into_http<Msg>(
144 msg: Msg,
145 uri: http::Uri,
146 method: http::Method,
147 version: http::Version,
148 metadata: MetadataMap,
149 extensions: Extensions,
150) -> http::Request<Msg> {
151 let mut request = http::Request::new(msg);
152 *request.version_mut() = version;
153 *request.method_mut() = method;
154 *request.uri_mut() = uri;
155 *request.headers_mut() = metadata.into_headers();
156 *request.extensions_mut() = extensions;
157
158 request
159}
160
161fn decompose<ReqBody>(req: http::Request<ReqBody>) -> (DecomposedRequest<ReqBody>, Request<()>) {
170 let uri = req.uri().clone();
171 let method = req.method().clone();
172 let http_version = req.version();
173 let req = Request::from_http(req);
174 let (metadata, extensions, msg) = request_into_parts(req);
175
176 let dreq = DecomposedRequest {
177 uri,
178 method,
179 http_version,
180 msg,
181 };
182 let req_without_body = request_from_parts((), metadata, extensions);
183
184 (dreq, req_without_body)
185}
186
187fn recompose<ReqBody>(
189 dreq: DecomposedRequest<ReqBody>,
190 modified_req: Request<()>,
191) -> http::Request<ReqBody> {
192 let (metadata, extensions, _) = request_into_parts(modified_req);
193
194 request_into_http(
195 dreq.msg,
196 dreq.uri,
197 dreq.method,
198 dreq.http_version,
199 metadata,
200 extensions,
201 )
202}
203
204#[derive(Clone, Copy)]
208pub struct AsyncInterceptedService<S, F> {
209 inner: S,
210 f: F,
211}
212
213impl<S, F> AsyncInterceptedService<S, F> {
214 pub fn new(service: S, f: F) -> Self {
217 Self { inner: service, f }
218 }
219}
220
221impl<S, F> fmt::Debug for AsyncInterceptedService<S, F>
222where
223 S: fmt::Debug,
224{
225 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
226 f.debug_struct("AsyncInterceptedService")
227 .field("inner", &self.inner)
228 .field("f", &format_args!("{}", std::any::type_name::<F>()))
229 .finish()
230 }
231}
232
233impl<S, F, ReqBody, ResBody> Service<http::Request<ReqBody>> for AsyncInterceptedService<S, F>
234where
235 F: AsyncInterceptor + Clone,
236 S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>> + Clone,
237 S::Error: Into<Error>,
238 ReqBody: Default,
239 ResBody: Default + Body<Data = Bytes> + Send + 'static,
240 ResBody::Error: Into<Error>,
241{
242 type Response = http::Response<BoxBody>;
243 type Error = S::Error;
244 type Future = AsyncResponseFuture<S, F::Future, ReqBody>;
245
246 #[inline]
247 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
248 self.inner.poll_ready(cx)
249 }
250
251 fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
252 let clone = self.inner.clone();
256 let inner = std::mem::replace(&mut self.inner, clone);
257
258 AsyncResponseFuture::new(req, &mut self.f, inner)
259 }
260}
261
262impl<S, F> tonic::server::NamedService for AsyncInterceptedService<S, F>
264where
265 S: tonic::server::NamedService,
266{
267 const NAME: &'static str = S::NAME;
268}
269
270#[pin_project]
272#[derive(Debug)]
273pub struct ResponseFuture<F> {
274 #[pin]
275 kind: Kind<F>,
276}
277
278impl<F> ResponseFuture<F> {
279 fn future(future: F) -> Self {
280 Self {
281 kind: Kind::Future(future),
282 }
283 }
284
285 fn status(status: Status) -> Self {
286 Self {
287 kind: Kind::Status(Some(status)),
288 }
289 }
290}
291
292#[pin_project(project = KindProj)]
293#[derive(Debug)]
294enum Kind<F> {
295 Future(#[pin] F),
296 Status(Option<Status>),
297}
298
299impl<F, E, B> Future for ResponseFuture<F>
300where
301 F: Future<Output = Result<http::Response<B>, E>>,
302 E: Into<Error>,
303 B: Default + Body<Data = Bytes> + Send + 'static,
304 B::Error: Into<Error>,
305{
306 type Output = Result<http::Response<BoxBody>, E>;
307
308 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
309 match self.project().kind.project() {
310 KindProj::Future(future) => future
311 .poll(cx)
312 .map(|result| result.map(|res| res.map(boxed))),
313 KindProj::Status(status) => {
314 let response = status
315 .take()
316 .unwrap()
317 .into_http()
318 .map(|_| B::default())
319 .map(boxed);
320 Poll::Ready(Ok(response))
321 }
322 }
323 }
324}
325
326#[pin_project(project = PinnedOptionProj)]
327#[derive(Debug)]
328enum PinnedOption<F> {
329 Some(#[pin] F),
330 None,
331}
332
333#[pin_project(project = AsyncResponseFutureProj)]
338#[derive(Debug)]
339pub struct AsyncResponseFuture<S, I, ReqBody>
340where
341 S: Service<http::Request<ReqBody>>,
342 S::Error: Into<Error>,
343 I: Future<Output = Result<Request<()>, Status>>,
344{
345 #[pin]
346 interceptor_fut: PinnedOption<I>,
347 #[pin]
348 inner_fut: PinnedOption<ResponseFuture<S::Future>>,
349 inner: S,
350 dreq: DecomposedRequest<ReqBody>,
351}
352
353impl<S, I, ReqBody> AsyncResponseFuture<S, I, ReqBody>
354where
355 S: Service<http::Request<ReqBody>>,
356 S::Error: Into<Error>,
357 I: Future<Output = Result<Request<()>, Status>>,
358 ReqBody: Default,
359{
360 fn new<A: AsyncInterceptor<Future = I>>(
361 req: http::Request<ReqBody>,
362 interceptor: &mut A,
363 inner: S,
364 ) -> Self {
365 let (dreq, req_without_body) = decompose(req);
366 let interceptor_fut = interceptor.call(req_without_body);
367
368 AsyncResponseFuture {
369 interceptor_fut: PinnedOption::Some(interceptor_fut),
370 inner_fut: PinnedOption::None,
371 inner,
372 dreq,
373 }
374 }
375
376 fn create_inner_fut(
379 this: &mut AsyncResponseFutureProj<'_, S, I, ReqBody>,
380 intercepted_req: Result<Request<()>, Status>,
381 ) -> ResponseFuture<S::Future> {
382 match intercepted_req {
383 Ok(req) => {
384 let msg = mem::take(&mut this.dreq.msg);
388 let movable_dreq = DecomposedRequest {
389 uri: this.dreq.uri.clone(),
390 method: this.dreq.method.clone(),
391 http_version: this.dreq.http_version,
392 msg,
393 };
394 let modified_req_with_body = recompose(movable_dreq, req);
395
396 ResponseFuture::future(this.inner.call(modified_req_with_body))
397 }
398 Err(status) => ResponseFuture::status(status),
399 }
400 }
401}
402
403impl<S, I, ReqBody, ResBody> Future for AsyncResponseFuture<S, I, ReqBody>
404where
405 S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>>,
406 I: Future<Output = Result<Request<()>, Status>>,
407 S::Error: Into<Error>,
408 ReqBody: Default,
409 ResBody: Default + Body<Data = Bytes> + Send + 'static,
410 ResBody::Error: Into<Error>,
411{
412 type Output = Result<http::Response<BoxBody>, S::Error>;
413
414 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
415 let mut this = self.project();
416
417 if let PinnedOptionProj::Some(f) = this.interceptor_fut.as_mut().project() {
419 match f.poll(cx) {
420 Poll::Ready(intercepted_req) => {
421 let inner_fut = AsyncResponseFuture::<S, I, ReqBody>::create_inner_fut(
422 &mut this,
423 intercepted_req,
424 );
425 this.inner_fut.set(PinnedOption::Some(inner_fut));
427 this.interceptor_fut.set(PinnedOption::None);
428 }
429 Poll::Pending => return Poll::Pending,
430 }
431 }
432 let inner_fut = match this.inner_fut.project() {
434 PinnedOptionProj::None => panic!(),
435 PinnedOptionProj::Some(f) => f,
436 };
437
438 inner_fut.poll(cx)
439 }
440}
441
#[cfg(test)]
mod tests {
    use super::*;
    use http::StatusCode;
    use http_body_util::Empty;
    use std::future;
    use tower::ServiceExt;

    /// Finishes `builder` with an empty body.
    fn request_from(builder: http::request::Builder) -> http::Request<Empty<Bytes>> {
        builder.body(Empty::new()).unwrap()
    }

    /// The successful, empty-bodied response returned by every leaf service in these tests.
    fn ok_response() -> Result<http::Response<Empty<Bytes>>, Status> {
        Ok(http::Response::new(Empty::new()))
    }

    #[tokio::test]
    async fn propagates_added_extensions() {
        #[derive(Clone)]
        struct TestExtension {
            data: String,
        }
        let expected_data = "abc";

        let layer = async_interceptor(|mut req: Request<()>| {
            let extension = TestExtension {
                data: expected_data.to_owned(),
            };
            req.extensions_mut().insert(extension);
            future::ready(Ok(req))
        });

        let svc = layer.layer(tower::service_fn(
            |http_req: http::Request<Empty<Bytes>>| async {
                let req = Request::from_http(http_req);
                let extension = req.extensions().get::<TestExtension>();
                assert!(extension.is_some());
                assert_eq!(extension.unwrap().data, expected_data);
                ok_response()
            },
        ));

        let response = svc
            .oneshot(request_from(http::Request::builder()))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn propagates_added_metadata() {
        let key = "test_key";
        let value = "abc";

        let layer = async_interceptor(|mut req: Request<()>| {
            req.metadata_mut().insert(key, value.parse().unwrap());
            future::ready(Ok(req))
        });

        let svc = layer.layer(tower::service_fn(
            |http_req: http::Request<Empty<Bytes>>| async {
                let req = Request::from_http(http_req);
                let found = req.metadata().get(key);
                assert!(found.is_some());
                assert_eq!(found.unwrap(), value);
                ok_response()
            },
        ));

        let response = svc
            .oneshot(request_from(http::Request::builder()))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn doesnt_remove_headers_from_request() {
        let layer = async_interceptor(|request: Request<()>| {
            let user_agent = request
                .metadata()
                .get("user-agent")
                .expect("missing in interceptor");
            assert_eq!(user_agent, "test-tonic");
            future::ready(Ok(request))
        });

        let svc = layer.layer(tower::service_fn(
            |request: http::Request<Empty<Bytes>>| async move {
                let user_agent = request
                    .headers()
                    .get("user-agent")
                    .expect("missing in leaf service");
                assert_eq!(user_agent, "test-tonic");
                ok_response()
            },
        ));

        let request = request_from(http::Request::builder().header("user-agent", "test-tonic"));
        svc.oneshot(request).await.unwrap();
    }

    #[tokio::test]
    async fn handles_intercepted_status_as_response() {
        let message = "Blocked by the interceptor";
        let expected = Status::permission_denied(message).into_http();

        let layer = async_interceptor(|_: Request<()>| {
            future::ready(Err(Status::permission_denied(message)))
        });

        let svc = layer.layer(tower::service_fn(|_: http::Request<Empty<Bytes>>| async {
            ok_response()
        }));

        let response = svc
            .oneshot(request_from(http::Request::builder()))
            .await
            .unwrap();

        assert_eq!(expected.status(), response.status());
        assert_eq!(expected.version(), response.version());
        assert_eq!(expected.headers(), response.headers());
    }

    #[tokio::test]
    async fn doesnt_change_http_method() {
        let layer = async_interceptor(|request: Request<()>| future::ready(Ok(request)));

        let svc = layer.layer(tower::service_fn(
            |request: http::Request<Empty<Bytes>>| async move {
                assert_eq!(request.method(), http::Method::OPTIONS);
                ok_response()
            },
        ));

        let request = request_from(http::Request::builder().method(http::Method::OPTIONS));
        svc.oneshot(request).await.unwrap();
    }
}
587}