tonic_async_interceptor/
lib.rs

//! Async gRPC interceptors, which are a kind of middleware for tonic services.
//!
//! See [`AsyncInterceptor`] for more details.
4
5use bytes::Bytes;
6use http::{Extensions, Method, Uri, Version};
7use http_body::Body;
8use http_body_util::BodyExt;
9use pin_project::pin_project;
10use std::{
11    fmt,
12    future::Future,
13    mem,
14    pin::Pin,
15    task::{Context, Poll},
16};
17use tonic::metadata::MetadataMap;
18use tonic::{body::BoxBody, Request, Status};
19use tower_layer::Layer;
20use tower_service::Service;
21
/// Boxed, thread-safe error type. Errors from the wrapped service and from response bodies must
/// be convertible into it (`Into<Error>`).
pub type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
23
24/// An async gRPC interceptor.
25///
26/// This interceptor is an `async` variant of Tonic's built-in [`Interceptor`].
27///
28/// gRPC interceptors are similar to middleware but have less flexibility. An interceptor allows
29/// you to do two main things, one is to add/remove/check items in the `MetadataMap` of each
30/// request. Two, cancel a request with a `Status`.
31///
32/// Any function that satisfies the bound `async FnMut(Request<()>) -> Result<Request<()>, Status>`
33/// can be used as an `AsyncInterceptor`.
34///
35/// An interceptor can be used on both the server and client side through the `tonic-build` crate's
36/// generated structs.
37///
38/// See the [interceptor example][example] for more details.
39///
40/// If you need more powerful middleware, [tower] is the recommended approach. You can find
41/// examples of how to use tower with tonic [here][tower-example].
42///
43/// Additionally, interceptors is not the recommended way to add logging to your service. For that
44/// a [tower] middleware is more appropriate since it can also act on the response. For example
45/// tower-http's [`Trace`](https://docs.rs/tower-http/latest/tower_http/trace/index.html)
46/// middleware supports gRPC out of the box.
47///
48/// [tower]: https://crates.io/crates/tower
49/// [example]: https://github.com/hyperium/tonic/tree/master/examples/src/interceptor
50/// [tower-example]: https://github.com/hyperium/tonic/tree/master/examples/src/tower
51///
52/// Async version of `Interceptor`.
53pub trait AsyncInterceptor {
54    /// The Future returned by the interceptor.
55    type Future: Future<Output = Result<Request<()>, Status>>;
56    /// Intercept a request before it is sent, optionally cancelling it.
57    fn call(&mut self, request: Request<()>) -> Self::Future;
58}
59
60impl<F, U> AsyncInterceptor for F
61where
62    F: FnMut(Request<()>) -> U,
63    U: Future<Output = Result<Request<()>, Status>>,
64{
65    type Future = U;
66
67    fn call(&mut self, request: Request<()>) -> Self::Future {
68        self(request)
69    }
70}
71
72/// Create a new async interceptor layer.
73///
74/// See [`AsyncInterceptor`] and [`Interceptor`] for more details.
75pub fn async_interceptor<F>(f: F) -> AsyncInterceptorLayer<F>
76where
77    F: AsyncInterceptor,
78{
79    AsyncInterceptorLayer { f }
80}
81
/// A [`Layer`] that applies an async gRPC interceptor to every service it wraps.
///
/// Construct it with [`async_interceptor`]; see [`AsyncInterceptor`] for more details.
#[derive(Clone, Copy, Debug)]
pub struct AsyncInterceptorLayer<F> {
    /// The interceptor; each wrapped service receives its own clone.
    f: F,
}
90
91impl<S, F> Layer<S> for AsyncInterceptorLayer<F>
92where
93    S: Clone,
94    F: AsyncInterceptor + Clone,
95{
96    type Service = AsyncInterceptedService<S, F>;
97
98    fn layer(&self, service: S) -> Self::Service {
99        AsyncInterceptedService::new(service, self.f.clone())
100    }
101}
102
103/// Convert a [`http_body::Body`] into a [`BoxBody`].
104fn boxed<B>(body: B) -> BoxBody
105where
106    B: Body<Data = Bytes> + Send + 'static,
107    B::Error: Into<Error>,
108{
109    body.map_err(|e| Status::from_error(e.into()))
110        .boxed_unsync()
111}
112
113// Components and attributes of a request, without metadata or extensions.
114#[derive(Debug)]
115struct DecomposedRequest<ReqBody> {
116    uri: Uri,
117    method: Method,
118    http_version: Version,
119    msg: ReqBody,
120}
121
122// Note that tonic::Request::into_parts is not public, so we do it this way.
123fn request_into_parts<Msg>(mut req: Request<Msg>) -> (MetadataMap, Extensions, Msg) {
124    // We use mem::take because Tonic doesn't not provide public access to these fields.
125    let metadata = mem::take(req.metadata_mut());
126    let extensions = mem::take(req.extensions_mut());
127    (metadata, extensions, req.into_inner())
128}
129
130// Note that tonic::Request::from_parts is not public, so we do it this way.
131fn request_from_parts<Msg>(
132    msg: Msg,
133    metadata: MetadataMap,
134    extensions: Extensions,
135) -> Request<Msg> {
136    let mut req = Request::new(msg);
137    *req.metadata_mut() = metadata;
138    *req.extensions_mut() = extensions;
139    req
140}
141
142// Note that tonic::Request::into_http is not public, so we do it this way.
143fn request_into_http<Msg>(
144    msg: Msg,
145    uri: http::Uri,
146    method: http::Method,
147    version: http::Version,
148    metadata: MetadataMap,
149    extensions: Extensions,
150) -> http::Request<Msg> {
151    let mut request = http::Request::new(msg);
152    *request.version_mut() = version;
153    *request.method_mut() = method;
154    *request.uri_mut() = uri;
155    *request.headers_mut() = metadata.into_headers();
156    *request.extensions_mut() = extensions;
157
158    request
159}
160
161/// Decompose the request into its contents and properties, and create a new request without a body.
162///
163/// It is bad practice to modify the body (i.e. Message) of the request via an interceptor.
164/// To avoid exposing the body of the request to the interceptor function, we first remove it
165/// here, allow the interceptor to modify the metadata and extensions, and then recreate the
166/// HTTP request with the original message body with the `recompose` function. Also note that Tonic
167/// requests do not preserve the URI, HTTP version, and HTTP method of the HTTP request, so we
168/// extract them here and then add them back in `recompose`.
169fn decompose<ReqBody>(req: http::Request<ReqBody>) -> (DecomposedRequest<ReqBody>, Request<()>) {
170    let uri = req.uri().clone();
171    let method = req.method().clone();
172    let http_version = req.version();
173    let req = Request::from_http(req);
174    let (metadata, extensions, msg) = request_into_parts(req);
175
176    let dreq = DecomposedRequest {
177        uri,
178        method,
179        http_version,
180        msg,
181    };
182    let req_without_body = request_from_parts((), metadata, extensions);
183
184    (dreq, req_without_body)
185}
186
187/// Combine the modified metadata and extensions with the original message body and attributes.
188fn recompose<ReqBody>(
189    dreq: DecomposedRequest<ReqBody>,
190    modified_req: Request<()>,
191) -> http::Request<ReqBody> {
192    let (metadata, extensions, _) = request_into_parts(modified_req);
193
194    request_into_http(
195        dreq.msg,
196        dreq.uri,
197        dreq.method,
198        dreq.http_version,
199        metadata,
200        extensions,
201    )
202}
203
/// A service wrapped in an async interceptor middleware.
///
/// Each request is first passed, without its body, to the interceptor `f`. On success the
/// possibly-modified request is forwarded to `inner`. On failure the returned `Status` becomes
/// the response. See [`AsyncInterceptor`] for more details.
#[derive(Clone, Copy)]
pub struct AsyncInterceptedService<S, F> {
    /// The wrapped service.
    inner: S,
    /// The async interceptor that runs before `inner`.
    f: F,
}
212
213impl<S, F> AsyncInterceptedService<S, F> {
214    /// Create a new `AsyncInterceptedService` that wraps `S` and intercepts each request with the
215    /// function `F`.
216    pub fn new(service: S, f: F) -> Self {
217        Self { inner: service, f }
218    }
219}
220
221impl<S, F> fmt::Debug for AsyncInterceptedService<S, F>
222where
223    S: fmt::Debug,
224{
225    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
226        f.debug_struct("AsyncInterceptedService")
227            .field("inner", &self.inner)
228            .field("f", &format_args!("{}", std::any::type_name::<F>()))
229            .finish()
230    }
231}
232
233impl<S, F, ReqBody, ResBody> Service<http::Request<ReqBody>> for AsyncInterceptedService<S, F>
234where
235    F: AsyncInterceptor + Clone,
236    S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>> + Clone,
237    S::Error: Into<Error>,
238    ReqBody: Default,
239    ResBody: Default + Body<Data = Bytes> + Send + 'static,
240    ResBody::Error: Into<Error>,
241{
242    type Response = http::Response<BoxBody>;
243    type Error = S::Error;
244    type Future = AsyncResponseFuture<S, F::Future, ReqBody>;
245
246    #[inline]
247    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
248        self.inner.poll_ready(cx)
249    }
250
251    fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
252        // This is necessary because tonic internally uses `tower::buffer::Buffer`.
253        // See https://github.com/tower-rs/tower/issues/547#issuecomment-767629149
254        // for details on why this is necessary
255        let clone = self.inner.clone();
256        let inner = std::mem::replace(&mut self.inner, clone);
257
258        AsyncResponseFuture::new(req, &mut self.f, inner)
259    }
260}
261
262// required to use `AsyncInterceptedService` with `Router`
263impl<S, F> tonic::server::NamedService for AsyncInterceptedService<S, F>
264where
265    S: tonic::server::NamedService,
266{
267    const NAME: &'static str = S::NAME;
268}
269
270/// Response future for [`InterceptedService`].
271#[pin_project]
272#[derive(Debug)]
273pub struct ResponseFuture<F> {
274    #[pin]
275    kind: Kind<F>,
276}
277
278impl<F> ResponseFuture<F> {
279    fn future(future: F) -> Self {
280        Self {
281            kind: Kind::Future(future),
282        }
283    }
284
285    fn status(status: Status) -> Self {
286        Self {
287            kind: Kind::Status(Some(status)),
288        }
289    }
290}
291
292#[pin_project(project = KindProj)]
293#[derive(Debug)]
294enum Kind<F> {
295    Future(#[pin] F),
296    Status(Option<Status>),
297}
298
299impl<F, E, B> Future for ResponseFuture<F>
300where
301    F: Future<Output = Result<http::Response<B>, E>>,
302    E: Into<Error>,
303    B: Default + Body<Data = Bytes> + Send + 'static,
304    B::Error: Into<Error>,
305{
306    type Output = Result<http::Response<BoxBody>, E>;
307
308    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
309        match self.project().kind.project() {
310            KindProj::Future(future) => future
311                .poll(cx)
312                .map(|result| result.map(|res| res.map(boxed))),
313            KindProj::Status(status) => {
314                let response = status
315                    .take()
316                    .unwrap()
317                    .into_http()
318                    .map(|_| B::default())
319                    .map(boxed);
320                Poll::Ready(Ok(response))
321            }
322        }
323    }
324}
325
326#[pin_project(project = PinnedOptionProj)]
327#[derive(Debug)]
328enum PinnedOption<F> {
329    Some(#[pin] F),
330    None,
331}
332
333/// Response future for [`AsyncInterceptedService`].
334///
335/// Handles the call to the async interceptor, then calls the inner service and wraps the result in
336/// [`ResponseFuture`].
337#[pin_project(project = AsyncResponseFutureProj)]
338#[derive(Debug)]
339pub struct AsyncResponseFuture<S, I, ReqBody>
340where
341    S: Service<http::Request<ReqBody>>,
342    S::Error: Into<Error>,
343    I: Future<Output = Result<Request<()>, Status>>,
344{
345    #[pin]
346    interceptor_fut: PinnedOption<I>,
347    #[pin]
348    inner_fut: PinnedOption<ResponseFuture<S::Future>>,
349    inner: S,
350    dreq: DecomposedRequest<ReqBody>,
351}
352
353impl<S, I, ReqBody> AsyncResponseFuture<S, I, ReqBody>
354where
355    S: Service<http::Request<ReqBody>>,
356    S::Error: Into<Error>,
357    I: Future<Output = Result<Request<()>, Status>>,
358    ReqBody: Default,
359{
360    fn new<A: AsyncInterceptor<Future = I>>(
361        req: http::Request<ReqBody>,
362        interceptor: &mut A,
363        inner: S,
364    ) -> Self {
365        let (dreq, req_without_body) = decompose(req);
366        let interceptor_fut = interceptor.call(req_without_body);
367
368        AsyncResponseFuture {
369            interceptor_fut: PinnedOption::Some(interceptor_fut),
370            inner_fut: PinnedOption::None,
371            inner,
372            dreq,
373        }
374    }
375
376    /// Calls the inner service with the intercepted request (which has been modified by the
377    /// async interceptor func).
378    fn create_inner_fut(
379        this: &mut AsyncResponseFutureProj<'_, S, I, ReqBody>,
380        intercepted_req: Result<Request<()>, Status>,
381    ) -> ResponseFuture<S::Future> {
382        match intercepted_req {
383            Ok(req) => {
384                // We can't move the message body out of the pin projection. So, to
385                // avoid copying it, we swap its memory with an empty body and then can
386                // move it into the recomposed request.
387                let msg = mem::take(&mut this.dreq.msg);
388                let movable_dreq = DecomposedRequest {
389                    uri: this.dreq.uri.clone(),
390                    method: this.dreq.method.clone(),
391                    http_version: this.dreq.http_version,
392                    msg,
393                };
394                let modified_req_with_body = recompose(movable_dreq, req);
395
396                ResponseFuture::future(this.inner.call(modified_req_with_body))
397            }
398            Err(status) => ResponseFuture::status(status),
399        }
400    }
401}
402
403impl<S, I, ReqBody, ResBody> Future for AsyncResponseFuture<S, I, ReqBody>
404where
405    S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>>,
406    I: Future<Output = Result<Request<()>, Status>>,
407    S::Error: Into<Error>,
408    ReqBody: Default,
409    ResBody: Default + Body<Data = Bytes> + Send + 'static,
410    ResBody::Error: Into<Error>,
411{
412    type Output = Result<http::Response<BoxBody>, S::Error>;
413
414    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
415        let mut this = self.project();
416
417        // The struct was initialized (via `new`) with interceptor func future, which we poll here.
418        if let PinnedOptionProj::Some(f) = this.interceptor_fut.as_mut().project() {
419            match f.poll(cx) {
420                Poll::Ready(intercepted_req) => {
421                    let inner_fut = AsyncResponseFuture::<S, I, ReqBody>::create_inner_fut(
422                        &mut this,
423                        intercepted_req,
424                    );
425                    // Set the inner service future and clear the interceptor future.
426                    this.inner_fut.set(PinnedOption::Some(inner_fut));
427                    this.interceptor_fut.set(PinnedOption::None);
428                }
429                Poll::Pending => return Poll::Pending,
430            }
431        }
432        // At this point, inner_fut should always be Some.
433        let inner_fut = match this.inner_fut.project() {
434            PinnedOptionProj::None => panic!(),
435            PinnedOptionProj::Some(f) => f,
436        };
437
438        inner_fut.poll(cx)
439    }
440}
441
#[cfg(test)]
mod tests {
    use super::*;
    use http::StatusCode;
    use http_body_util::Empty;
    use std::future;
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };
    use tower::ServiceExt;

    #[tokio::test]
    async fn propagates_added_extensions() {
        #[derive(Clone)]
        struct TestExtension {
            data: String,
        }
        let test_extension_data = "abc";

        let layer = async_interceptor(|mut req: Request<()>| {
            req.extensions_mut().insert(TestExtension {
                data: test_extension_data.to_owned(),
            });

            future::ready(Ok(req))
        });

        let svc = layer.layer(tower::service_fn(
            |http_req: http::Request<Empty<Bytes>>| async {
                let req = Request::from_http(http_req);
                let maybe_extension = req.extensions().get::<TestExtension>();
                assert!(maybe_extension.is_some());
                assert_eq!(maybe_extension.unwrap().data, test_extension_data);

                Ok::<_, Status>(http::Response::new(Empty::new()))
            },
        ));

        let request = http::Request::builder().body(Empty::new()).unwrap();
        let http_response = svc.oneshot(request).await.unwrap();

        assert_eq!(http_response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn propagates_added_metadata() {
        let test_metadata_key = "test_key";
        let test_metadata_val = "abc";

        let layer = async_interceptor(|mut req: Request<()>| {
            req.metadata_mut()
                .insert(test_metadata_key, test_metadata_val.parse().unwrap());

            future::ready(Ok(req))
        });

        let svc = layer.layer(tower::service_fn(
            |http_req: http::Request<Empty<Bytes>>| async {
                let req = Request::from_http(http_req);
                let maybe_metadata = req.metadata().get(test_metadata_key);
                assert!(maybe_metadata.is_some());
                assert_eq!(maybe_metadata.unwrap(), test_metadata_val);

                Ok::<_, Status>(http::Response::new(Empty::new()))
            },
        ));

        let request = http::Request::builder().body(Empty::new()).unwrap();
        let http_response = svc.oneshot(request).await.unwrap();

        assert_eq!(http_response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn doesnt_remove_headers_from_request() {
        let layer = async_interceptor(|request: Request<()>| {
            assert_eq!(
                request
                    .metadata()
                    .get("user-agent")
                    .expect("missing in interceptor"),
                "test-tonic"
            );
            future::ready(Ok(request))
        });

        let svc = layer.layer(tower::service_fn(
            |request: http::Request<Empty<Bytes>>| async move {
                assert_eq!(
                    request
                        .headers()
                        .get("user-agent")
                        .expect("missing in leaf service"),
                    "test-tonic"
                );

                Ok::<_, Status>(http::Response::new(Empty::new()))
            },
        ));

        let request = http::Request::builder()
            .header("user-agent", "test-tonic")
            .body(Empty::new())
            .unwrap();

        svc.oneshot(request).await.unwrap();
    }

    #[tokio::test]
    async fn handles_intercepted_status_as_response() {
        let message = "Blocked by the interceptor";
        let expected = Status::permission_denied(message).into_http();

        let layer = async_interceptor(|_: Request<()>| {
            future::ready(Err(Status::permission_denied(message)))
        });

        // Count calls to the inner service: a rejected request must never reach it.
        let calls = Arc::new(AtomicUsize::new(0));
        let inner_calls = Arc::clone(&calls);
        let svc = layer.layer(tower::service_fn(
            move |_: http::Request<Empty<Bytes>>| {
                inner_calls.fetch_add(1, Ordering::SeqCst);
                async { Ok::<_, Status>(http::Response::new(Empty::new())) }
            },
        ));

        let request = http::Request::builder().body(Empty::new()).unwrap();
        let response = svc.oneshot(request).await.unwrap();

        assert_eq!(expected.status(), response.status());
        assert_eq!(expected.version(), response.version());
        assert_eq!(expected.headers(), response.headers());
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn doesnt_change_http_method() {
        let layer = async_interceptor(|request: Request<()>| future::ready(Ok(request)));

        let svc = layer.layer(tower::service_fn(
            |request: http::Request<Empty<Bytes>>| async move {
                assert_eq!(request.method(), http::Method::OPTIONS);

                Ok::<_, Status>(http::Response::new(Empty::new()))
            },
        ));

        let request = http::Request::builder()
            .method(http::Method::OPTIONS)
            .body(Empty::new())
            .unwrap();

        svc.oneshot(request).await.unwrap();
    }

    #[tokio::test]
    async fn doesnt_change_uri_or_http_version() {
        let layer = async_interceptor(|request: Request<()>| future::ready(Ok(request)));

        let svc = layer.layer(tower::service_fn(
            |request: http::Request<Empty<Bytes>>| async move {
                assert_eq!(request.uri(), "/test.Service/Method");
                assert_eq!(request.version(), http::Version::HTTP_2);

                Ok::<_, Status>(http::Response::new(Empty::new()))
            },
        ));

        let request = http::Request::builder()
            .uri("/test.Service/Method")
            .version(http::Version::HTTP_2)
            .body(Empty::new())
            .unwrap();

        svc.oneshot(request).await.unwrap();
    }

    #[tokio::test]
    async fn waits_for_pending_interceptor() {
        let layer = async_interceptor(|mut req: Request<()>| async move {
            // Force at least one `Poll::Pending` before the interceptor completes.
            tokio::task::yield_now().await;
            req.metadata_mut()
                .insert("x-intercepted", "yes".parse().unwrap());
            Ok::<_, Status>(req)
        });

        let svc = layer.layer(tower::service_fn(
            |request: http::Request<Empty<Bytes>>| async move {
                assert_eq!(
                    request
                        .headers()
                        .get("x-intercepted")
                        .expect("missing in leaf service"),
                    "yes"
                );

                Ok::<_, Status>(http::Response::new(Empty::new()))
            },
        ));

        let request = http::Request::builder().body(Empty::new()).unwrap();
        let http_response = svc.oneshot(request).await.unwrap();

        assert_eq!(http_response.status(), StatusCode::OK);
    }
}