tonic_async_interceptor/
lib.rs

//! Async gRPC interceptors, which are a kind of middleware.
//!
//! See [`AsyncInterceptor`] for more details.
4
5use bytes::Bytes;
6use http::{Extensions, Method, Uri, Version};
7use http_body::Body;
8use pin_project::pin_project;
9use std::{
10    fmt,
11    future::Future,
12    mem,
13    pin::Pin,
14    task::{Context, Poll},
15};
16use tonic::body::Body as TonicBody;
17use tonic::metadata::MetadataMap;
18use tonic::{Request, Status};
19use tower_layer::Layer;
20use tower_service::Service;
21
/// Type-erased error type: any boxed error that is `Send + Sync`.
///
/// Service and body errors are required to convert `Into` this type.
pub type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
23
24/// An async gRPC interceptor.
25///
26/// This interceptor is an `async` variant of Tonic's built-in [`Interceptor`].
27///
28/// gRPC interceptors are similar to middleware but have less flexibility. An interceptor allows
29/// you to do two main things, one is to add/remove/check items in the `MetadataMap` of each
30/// request. Two, cancel a request with a `Status`.
31///
32/// Any function that satisfies the bound `async FnMut(Request<()>) -> Result<Request<()>, Status>`
33/// can be used as an `AsyncInterceptor`.
34///
35/// An interceptor can be used on both the server and client side through the `tonic-build` crate's
36/// generated structs.
37///
38/// See the [interceptor example][example] for more details.
39///
40/// If you need more powerful middleware, [tower] is the recommended approach. You can find
41/// examples of how to use tower with tonic [here][tower-example].
42///
43/// Additionally, interceptors is not the recommended way to add logging to your service. For that
44/// a [tower] middleware is more appropriate since it can also act on the response. For example
45/// tower-http's [`Trace`](https://docs.rs/tower-http/latest/tower_http/trace/index.html)
46/// middleware supports gRPC out of the box.
47///
48/// [tower]: https://crates.io/crates/tower
49/// [example]: https://github.com/hyperium/tonic/tree/master/examples/src/interceptor
50/// [tower-example]: https://github.com/hyperium/tonic/tree/master/examples/src/tower
51///
52/// Async version of `Interceptor`.
53pub trait AsyncInterceptor {
54    /// The Future returned by the interceptor.
55    type Future: Future<Output = Result<Request<()>, Status>>;
56    /// Intercept a request before it is sent, optionally cancelling it.
57    fn call(&mut self, request: Request<()>) -> Self::Future;
58}
59
60impl<F, U> AsyncInterceptor for F
61where
62    F: FnMut(Request<()>) -> U,
63    U: Future<Output = Result<Request<()>, Status>>,
64{
65    type Future = U;
66
67    fn call(&mut self, request: Request<()>) -> Self::Future {
68        self(request)
69    }
70}
71
72/// Create a new async interceptor layer.
73///
74/// See [`AsyncInterceptor`] and [`Interceptor`] for more details.
75pub fn async_interceptor<F>(f: F) -> AsyncInterceptorLayer<F>
76where
77    F: AsyncInterceptor,
78{
79    AsyncInterceptorLayer { f }
80}
81
/// A gRPC async interceptor that can be used as a [`Layer`],
/// created by calling [`async_interceptor`].
///
/// See [`AsyncInterceptor`] for more details.
#[derive(Clone, Copy)]
pub struct AsyncInterceptorLayer<F> {
    f: F,
}

// Implemented by hand rather than derived: a derived impl would require `F: Debug`, which
// closures (the usual interceptor) never satisfy. The interceptor is shown by its type name.
impl<F> fmt::Debug for AsyncInterceptorLayer<F> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("AsyncInterceptorLayer")
            .field("f", &format_args!("{}", std::any::type_name::<F>()))
            .finish()
    }
}
90
91impl<S, F> Layer<S> for AsyncInterceptorLayer<F>
92where
93    S: Clone,
94    F: AsyncInterceptor + Clone,
95{
96    type Service = AsyncInterceptedService<S, F>;
97
98    fn layer(&self, service: S) -> Self::Service {
99        AsyncInterceptedService::new(service, self.f.clone())
100    }
101}
102
// Components and attributes of a request, without metadata or extensions.
//
// Holds the parts of an `http::Request` that are not shown to the async interceptor (see
// `decompose`), so they can be reattached to the request afterwards (see `recompose`).
#[derive(Debug)]
struct DecomposedRequest<ReqBody> {
    // Request URI; carried separately because a `tonic::Request` does not keep it.
    uri: Uri,
    // HTTP method; carried separately because a `tonic::Request` does not keep it.
    method: Method,
    // HTTP version; carried separately because a `tonic::Request` does not keep it.
    http_version: Version,
    // The request body/message, deliberately hidden from the interceptor.
    msg: ReqBody,
}
111
112// Note that tonic::Request::into_parts is not public, so we do it this way.
113fn request_into_parts<Msg>(mut req: Request<Msg>) -> (MetadataMap, Extensions, Msg) {
114    // We use mem::take because Tonic doesn't not provide public access to these fields.
115    let metadata = mem::take(req.metadata_mut());
116    let extensions = mem::take(req.extensions_mut());
117    (metadata, extensions, req.into_inner())
118}
119
120// Note that tonic::Request::from_parts is not public, so we do it this way.
121fn request_from_parts<Msg>(
122    msg: Msg,
123    metadata: MetadataMap,
124    extensions: Extensions,
125) -> Request<Msg> {
126    let mut req = Request::new(msg);
127    *req.metadata_mut() = metadata;
128    *req.extensions_mut() = extensions;
129    req
130}
131
132// Note that tonic::Request::into_http is not public, so we do it this way.
133fn request_into_http<Msg>(
134    msg: Msg,
135    uri: http::Uri,
136    method: http::Method,
137    version: http::Version,
138    metadata: MetadataMap,
139    extensions: Extensions,
140) -> http::Request<Msg> {
141    let mut request = http::Request::new(msg);
142    *request.version_mut() = version;
143    *request.method_mut() = method;
144    *request.uri_mut() = uri;
145    *request.headers_mut() = metadata.into_headers();
146    *request.extensions_mut() = extensions;
147
148    request
149}
150
151/// Decompose the request into its contents and properties, and create a new request without a body.
152///
153/// It is bad practice to modify the body (i.e. Message) of the request via an interceptor.
154/// To avoid exposing the body of the request to the interceptor function, we first remove it
155/// here, allow the interceptor to modify the metadata and extensions, and then recreate the
156/// HTTP request with the original message body with the `recompose` function. Also note that Tonic
157/// requests do not preserve the URI, HTTP version, and HTTP method of the HTTP request, so we
158/// extract them here and then add them back in `recompose`.
159fn decompose<ReqBody>(req: http::Request<ReqBody>) -> (DecomposedRequest<ReqBody>, Request<()>) {
160    let uri = req.uri().clone();
161    let method = req.method().clone();
162    let http_version = req.version();
163    let req = Request::from_http(req);
164    let (metadata, extensions, msg) = request_into_parts(req);
165
166    let dreq = DecomposedRequest {
167        uri,
168        method,
169        http_version,
170        msg,
171    };
172    let req_without_body = request_from_parts((), metadata, extensions);
173
174    (dreq, req_without_body)
175}
176
177/// Combine the modified metadata and extensions with the original message body and attributes.
178fn recompose<ReqBody>(
179    dreq: DecomposedRequest<ReqBody>,
180    modified_req: Request<()>,
181) -> http::Request<ReqBody> {
182    let (metadata, extensions, _) = request_into_parts(modified_req);
183
184    request_into_http(
185        dreq.msg,
186        dreq.uri,
187        dreq.method,
188        dreq.http_version,
189        metadata,
190        extensions,
191    )
192}
193
/// A service wrapped in an async interceptor middleware.
///
/// Each request's body, URI, method and version are held back while the interceptor
/// inspects or modifies the metadata and extensions. The reassembled request is then forwarded
/// to the wrapped service, or answered directly with the interceptor's `Status`.
///
/// See [`AsyncInterceptor`] for more details.
#[derive(Clone, Copy)]
pub struct AsyncInterceptedService<S, F> {
    // The wrapped service that receives requests once the interceptor succeeds.
    inner: S,
    // The async interceptor run on each request.
    f: F,
}
202
203impl<S, F> AsyncInterceptedService<S, F> {
204    /// Create a new `AsyncInterceptedService` that wraps `S` and intercepts each request with the
205    /// function `F`.
206    pub fn new(service: S, f: F) -> Self {
207        Self { inner: service, f }
208    }
209}
210
211impl<S, F> fmt::Debug for AsyncInterceptedService<S, F>
212where
213    S: fmt::Debug,
214{
215    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
216        f.debug_struct("AsyncInterceptedService")
217            .field("inner", &self.inner)
218            .field("f", &format_args!("{}", std::any::type_name::<F>()))
219            .finish()
220    }
221}
222
223impl<S, F, ReqBody, ResBody> Service<http::Request<ReqBody>> for AsyncInterceptedService<S, F>
224where
225    F: AsyncInterceptor + Clone,
226    S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>> + Clone,
227    S::Error: Into<Error>,
228    ReqBody: Default,
229    ResBody: Default + Body<Data = Bytes> + Send + 'static,
230    ResBody::Error: Into<Error>,
231{
232    type Response = http::Response<TonicBody>;
233    type Error = S::Error;
234    type Future = AsyncResponseFuture<S, F::Future, ReqBody>;
235
236    #[inline]
237    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
238        self.inner.poll_ready(cx)
239    }
240
241    fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
242        // This is necessary because tonic internally uses `tower::buffer::Buffer`.
243        // See https://github.com/tower-rs/tower/issues/547#issuecomment-767629149
244        // for details on why this is necessary
245        let clone = self.inner.clone();
246        let inner = std::mem::replace(&mut self.inner, clone);
247
248        AsyncResponseFuture::new(req, &mut self.f, inner)
249    }
250}
251
252// required to use `AsyncInterceptedService` with `Router`
253impl<S, F> tonic::server::NamedService for AsyncInterceptedService<S, F>
254where
255    S: tonic::server::NamedService,
256{
257    const NAME: &'static str = S::NAME;
258}
259
260/// Response future for [`InterceptedService`].
261#[pin_project]
262#[derive(Debug)]
263pub struct ResponseFuture<F> {
264    #[pin]
265    kind: Kind<F>,
266}
267
268impl<F> ResponseFuture<F> {
269    fn future(future: F) -> Self {
270        Self {
271            kind: Kind::Future(future),
272        }
273    }
274
275    fn status(status: Status) -> Self {
276        Self {
277            kind: Kind::Status(Some(status)),
278        }
279    }
280}
281
// State of a `ResponseFuture`.
#[pin_project(project = KindProj)]
#[derive(Debug)]
enum Kind<F> {
    // Waiting on the wrapped service's future (structurally pinned).
    Future(#[pin] F),
    // The interceptor rejected the request. The status is `Some` until the first poll takes it
    // and converts it into a response, and `None` afterwards.
    Status(Option<Status>),
}
288
289impl<F, E, B> Future for ResponseFuture<F>
290where
291    F: Future<Output = Result<http::Response<B>, E>>,
292    E: Into<Error>,
293    B: Default + Body<Data = Bytes> + Send + 'static,
294    B::Error: Into<Error>,
295{
296    type Output = Result<http::Response<TonicBody>, E>;
297
298    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
299        match self.project().kind.project() {
300            KindProj::Future(future) => future
301                .poll(cx)
302                .map(|result| result.map(|resp| resp.map(TonicBody::new))),
303            KindProj::Status(status) => {
304                let response = status.take().unwrap().into_http();
305                Poll::Ready(Ok(response))
306            }
307        }
308    }
309}
310
// A pin-projectable counterpart of `Option<F>`. A future stored inside a pinned struct can be
// polled in place and later replaced or cleared with `Pin::set`.
#[pin_project(project = PinnedOptionProj)]
#[derive(Debug)]
enum PinnedOption<F> {
    // Holds a structurally pinned value.
    Some(#[pin] F),
    // No value.
    None,
}
317
/// Response future for [`AsyncInterceptedService`].
///
/// Handles the call to the async interceptor, then calls the inner service and wraps the result in
/// [`ResponseFuture`].
#[pin_project(project = AsyncResponseFutureProj)]
#[derive(Debug)]
pub struct AsyncResponseFuture<S, I, ReqBody>
where
    S: Service<http::Request<ReqBody>>,
    S::Error: Into<Error>,
    I: Future<Output = Result<Request<()>, Status>>,
{
    // The interceptor's future: `Some` from construction until it completes, then cleared.
    #[pin]
    interceptor_fut: PinnedOption<I>,
    // The inner service's (wrapped) future: `None` until the interceptor has completed.
    #[pin]
    inner_fut: PinnedOption<ResponseFuture<S::Future>>,
    // The inner service instance moved out of `AsyncInterceptedService::call`; it is called
    // once the interceptor succeeds.
    inner: S,
    // URI, method, HTTP version and body held back from the interceptor, reattached afterwards.
    dreq: DecomposedRequest<ReqBody>,
}
337
338impl<S, I, ReqBody> AsyncResponseFuture<S, I, ReqBody>
339where
340    S: Service<http::Request<ReqBody>>,
341    S::Error: Into<Error>,
342    I: Future<Output = Result<Request<()>, Status>>,
343    ReqBody: Default,
344{
345    fn new<A: AsyncInterceptor<Future = I>>(
346        req: http::Request<ReqBody>,
347        interceptor: &mut A,
348        inner: S,
349    ) -> Self {
350        let (dreq, req_without_body) = decompose(req);
351        let interceptor_fut = interceptor.call(req_without_body);
352
353        AsyncResponseFuture {
354            interceptor_fut: PinnedOption::Some(interceptor_fut),
355            inner_fut: PinnedOption::None,
356            inner,
357            dreq,
358        }
359    }
360
361    /// Calls the inner service with the intercepted request (which has been modified by the
362    /// async interceptor func).
363    fn create_inner_fut(
364        this: &mut AsyncResponseFutureProj<'_, S, I, ReqBody>,
365        intercepted_req: Result<Request<()>, Status>,
366    ) -> ResponseFuture<S::Future> {
367        match intercepted_req {
368            Ok(req) => {
369                // We can't move the message body out of the pin projection. So, to
370                // avoid copying it, we swap its memory with an empty body and then can
371                // move it into the recomposed request.
372                let msg = mem::take(&mut this.dreq.msg);
373                let movable_dreq = DecomposedRequest {
374                    uri: this.dreq.uri.clone(),
375                    method: this.dreq.method.clone(),
376                    http_version: this.dreq.http_version,
377                    msg,
378                };
379                let modified_req_with_body = recompose(movable_dreq, req);
380
381                ResponseFuture::future(this.inner.call(modified_req_with_body))
382            }
383            Err(status) => ResponseFuture::status(status),
384        }
385    }
386}
387
388impl<S, I, ReqBody, ResBody> Future for AsyncResponseFuture<S, I, ReqBody>
389where
390    S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>>,
391    I: Future<Output = Result<Request<()>, Status>>,
392    S::Error: Into<Error>,
393    ReqBody: Default,
394    ResBody: Default + Body<Data = Bytes> + Send + 'static,
395    ResBody::Error: Into<Error>,
396{
397    type Output = Result<http::Response<TonicBody>, S::Error>;
398
399    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
400        let mut this = self.project();
401
402        // The struct was initialized (via `new`) with interceptor func future, which we poll here.
403        if let PinnedOptionProj::Some(f) = this.interceptor_fut.as_mut().project() {
404            match f.poll(cx) {
405                Poll::Ready(intercepted_req) => {
406                    let inner_fut = AsyncResponseFuture::<S, I, ReqBody>::create_inner_fut(
407                        &mut this,
408                        intercepted_req,
409                    );
410                    // Set the inner service future and clear the interceptor future.
411                    this.inner_fut.set(PinnedOption::Some(inner_fut));
412                    this.interceptor_fut.set(PinnedOption::None);
413                }
414                Poll::Pending => return Poll::Pending,
415            }
416        }
417        // At this point, inner_fut should always be Some.
418        let inner_fut = match this.inner_fut.project() {
419            PinnedOptionProj::None => panic!(),
420            PinnedOptionProj::Some(f) => f,
421        };
422
423        inner_fut.poll(cx)
424    }
425}
426
#[cfg(test)]
mod tests {
    // Exercises `AsyncInterceptedService` end-to-end through the public layer API.
    use super::*;
    use http::StatusCode;
    use http_body_util::Empty;
    use std::future;
    use tower::ServiceExt;

    #[tokio::test]
    async fn propagates_added_extensions() {
        #[derive(Clone)]
        struct TestExtension {
            data: String,
        }
        let test_extension_data = "abc";

        let layer = async_interceptor(|mut req: Request<()>| {
            req.extensions_mut().insert(TestExtension {
                data: test_extension_data.to_owned(),
            });

            future::ready(Ok(req))
        });

        let svc = layer.layer(tower::service_fn(
            |http_req: http::Request<Empty<Bytes>>| async {
                let req = Request::from_http(http_req);
                let maybe_extension = req.extensions().get::<TestExtension>();
                assert!(maybe_extension.is_some());
                assert_eq!(maybe_extension.unwrap().data, test_extension_data);

                Ok::<_, Status>(http::Response::new(Empty::new()))
            },
        ));

        let request = http::Request::builder().body(Empty::new()).unwrap();
        let http_response = svc.oneshot(request).await.unwrap();

        assert_eq!(http_response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn propagates_added_metadata() {
        let test_metadata_key = "test_key";
        let test_metadata_val = "abc";

        let layer = async_interceptor(|mut req: Request<()>| {
            req.metadata_mut()
                .insert(test_metadata_key, test_metadata_val.parse().unwrap());

            future::ready(Ok(req))
        });

        let svc = layer.layer(tower::service_fn(
            |http_req: http::Request<Empty<Bytes>>| async {
                let req = Request::from_http(http_req);
                let maybe_metadata = req.metadata().get(test_metadata_key);
                assert!(maybe_metadata.is_some());
                assert_eq!(maybe_metadata.unwrap(), test_metadata_val);

                Ok::<_, Status>(http::Response::new(Empty::new()))
            },
        ));

        let request = http::Request::builder().body(Empty::new()).unwrap();
        let http_response = svc.oneshot(request).await.unwrap();

        assert_eq!(http_response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn propagates_removed_metadata() {
        let layer = async_interceptor(|mut req: Request<()>| {
            req.metadata_mut().remove("x-secret");

            future::ready(Ok(req))
        });

        let svc = layer.layer(tower::service_fn(
            |request: http::Request<Empty<Bytes>>| async move {
                assert!(request.headers().get("x-secret").is_none());

                Ok::<_, Status>(http::Response::new(Empty::new()))
            },
        ));

        let request = http::Request::builder()
            .header("x-secret", "hunter2")
            .body(Empty::new())
            .unwrap();

        svc.oneshot(request).await.unwrap();
    }

    #[tokio::test]
    async fn doesnt_remove_headers_from_request() {
        let layer = async_interceptor(|request: Request<()>| {
            assert_eq!(
                request
                    .metadata()
                    .get("user-agent")
                    .expect("missing in interceptor"),
                "test-tonic"
            );
            future::ready(Ok(request))
        });

        let svc = layer.layer(tower::service_fn(
            |request: http::Request<Empty<Bytes>>| async move {
                assert_eq!(
                    request
                        .headers()
                        .get("user-agent")
                        .expect("missing in leaf service"),
                    "test-tonic"
                );

                Ok::<_, Status>(http::Response::new(Empty::new()))
            },
        ));

        let request = http::Request::builder()
            .header("user-agent", "test-tonic")
            .body(Empty::new())
            .unwrap();

        svc.oneshot(request).await.unwrap();
    }

    #[tokio::test]
    async fn handles_intercepted_status_as_response() {
        let message = "Blocked by the interceptor";
        let expected = Status::permission_denied(message).into_http::<TonicBody>();

        let layer = async_interceptor(|_: Request<()>| {
            future::ready(Err(Status::permission_denied(message)))
        });

        let svc = layer.layer(tower::service_fn(|_: http::Request<Empty<Bytes>>| async {
            Ok::<_, Status>(http::Response::new(Empty::new()))
        }));

        let request = http::Request::builder().body(Empty::new()).unwrap();
        let response = svc.oneshot(request).await.unwrap();

        assert_eq!(expected.status(), response.status());
        assert_eq!(expected.version(), response.version());
        assert_eq!(expected.headers(), response.headers());
    }

    #[tokio::test]
    async fn doesnt_change_http_method() {
        let layer = async_interceptor(|request: Request<()>| future::ready(Ok(request)));

        let svc = layer.layer(tower::service_fn(
            |request: http::Request<Empty<Bytes>>| async move {
                assert_eq!(request.method(), http::Method::OPTIONS);

                Ok::<_, Status>(http::Response::new(Empty::new()))
            },
        ));

        let request = http::Request::builder()
            .method(http::Method::OPTIONS)
            .body(Empty::new())
            .unwrap();

        svc.oneshot(request).await.unwrap();
    }

    #[tokio::test]
    async fn doesnt_change_uri_and_http_version() {
        let layer = async_interceptor(|request: Request<()>| future::ready(Ok(request)));

        let svc = layer.layer(tower::service_fn(
            |request: http::Request<Empty<Bytes>>| async move {
                assert_eq!(request.uri(), "/test.Service/Method?key=value");
                assert_eq!(request.version(), http::Version::HTTP_2);

                Ok::<_, Status>(http::Response::new(Empty::new()))
            },
        ));

        let request = http::Request::builder()
            .uri("/test.Service/Method?key=value")
            .version(http::Version::HTTP_2)
            .body(Empty::new())
            .unwrap();

        svc.oneshot(request).await.unwrap();
    }
}