tonic_web/
service.rs

1use core::fmt;
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{ready, Context, Poll};
5
6use http::{header, HeaderMap, HeaderValue, Method, Request, Response, StatusCode, Version};
7use pin_project::pin_project;
8use tonic::metadata::GRPC_CONTENT_TYPE;
9use tonic::{body::Body, server::NamedService};
10use tower_service::Service;
11use tracing::{debug, trace};
12
13use crate::call::content_types::is_grpc_web;
14use crate::call::{Encoding, GrpcWebCall};
15
/// Middleware service that speaks the grpc-web protocol.
///
/// grpc-web requests are rewritten into plain gRPC requests before being handed
/// to the wrapped service, and the wrapped service's responses are rewritten
/// back into grpc-web. How other requests are treated is described on the
/// `Service` implementation.
#[derive(Clone, Debug)]
pub struct GrpcWebService<S> {
    /// The wrapped service that ultimately handles every forwarded request.
    inner: S,
}
21
22#[derive(Debug, PartialEq)]
23enum RequestKind<'a> {
24    // The request is considered a grpc-web request if its `content-type`
25    // header is exactly one of:
26    //
27    //  - "application/grpc-web"
28    //  - "application/grpc-web+proto"
29    //  - "application/grpc-web-text"
30    //  - "application/grpc-web-text+proto"
31    GrpcWeb {
32        method: &'a Method,
33        encoding: Encoding,
34        accept: Encoding,
35    },
36    // All other requests, including `application/grpc`
37    Other(http::Version),
38}
39
40impl<S> GrpcWebService<S> {
41    pub(crate) fn new(inner: S) -> Self {
42        GrpcWebService { inner }
43    }
44}
45
46impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for GrpcWebService<S>
47where
48    S: Service<Request<Body>, Response = Response<ResBody>>,
49    ReqBody: http_body::Body<Data = bytes::Bytes> + Send + 'static,
50    ReqBody::Error: Into<crate::BoxError> + fmt::Display,
51    ResBody: http_body::Body<Data = bytes::Bytes> + Send + 'static,
52    ResBody::Error: Into<crate::BoxError> + fmt::Display,
53{
54    type Response = Response<Body>;
55    type Error = S::Error;
56    type Future = ResponseFuture<S::Future>;
57
58    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
59        self.inner.poll_ready(cx)
60    }
61
62    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
63        match RequestKind::new(req.headers(), req.method(), req.version()) {
64            // A valid grpc-web request, regardless of HTTP version.
65            //
66            // If the request includes an `origin` header, we verify it is allowed
67            // to access the resource, an HTTP 403 response is returned otherwise.
68            //
69            // If the origin is allowed to access the resource or there is no
70            // `origin` header present, translate the request into a grpc request,
71            // call the inner service, and translate the response back to
72            // grpc-web.
73            RequestKind::GrpcWeb {
74                method: &Method::POST,
75                encoding,
76                accept,
77            } => {
78                trace!(kind = "simple", path = ?req.uri().path(), ?encoding, ?accept);
79
80                ResponseFuture {
81                    case: Case::GrpcWeb {
82                        future: self.inner.call(coerce_request(req, encoding)),
83                        accept,
84                    },
85                }
86            }
87
88            // The request's content-type matches one of the 4 supported grpc-web
89            // content-types, but the request method is not `POST`.
90            // This is not a valid grpc-web request, return HTTP 405.
91            RequestKind::GrpcWeb { .. } => {
92                debug!(kind = "simple", error="method not allowed", method = ?req.method());
93
94                ResponseFuture {
95                    case: Case::immediate(StatusCode::METHOD_NOT_ALLOWED),
96                }
97            }
98
99            // All http/2 requests that are not grpc-web are passed through to the inner service,
100            // whatever they are.
101            RequestKind::Other(Version::HTTP_2) => {
102                debug!(kind = "other h2", content_type = ?req.headers().get(header::CONTENT_TYPE));
103                ResponseFuture {
104                    case: Case::Other {
105                        future: self.inner.call(req.map(Body::new)),
106                    },
107                }
108            }
109
110            // Return HTTP 400 for all other requests.
111            RequestKind::Other(_) => {
112                debug!(kind = "other h1", content_type = ?req.headers().get(header::CONTENT_TYPE));
113
114                ResponseFuture {
115                    case: Case::immediate(StatusCode::BAD_REQUEST),
116                }
117            }
118        }
119    }
120}
121
122/// Response future for the [`GrpcWebService`].
123#[pin_project]
124#[must_use = "futures do nothing unless polled"]
125pub struct ResponseFuture<F> {
126    #[pin]
127    case: Case<F>,
128}
129
130#[pin_project(project = CaseProj)]
131enum Case<F> {
132    GrpcWeb {
133        #[pin]
134        future: F,
135        accept: Encoding,
136    },
137    Other {
138        #[pin]
139        future: F,
140    },
141    ImmediateResponse {
142        res: Option<http::response::Parts>,
143    },
144}
145
146impl<F> Case<F> {
147    fn immediate(status: StatusCode) -> Self {
148        let (res, ()) = Response::builder()
149            .status(status)
150            .body(())
151            .unwrap()
152            .into_parts();
153        Self::ImmediateResponse { res: Some(res) }
154    }
155}
156
157impl<F, B, E> Future for ResponseFuture<F>
158where
159    F: Future<Output = Result<Response<B>, E>>,
160    B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
161    B::Error: Into<crate::BoxError> + fmt::Display,
162{
163    type Output = Result<Response<Body>, E>;
164
165    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
166        let this = self.project();
167
168        match this.case.project() {
169            CaseProj::GrpcWeb { future, accept } => {
170                let res = ready!(future.poll(cx))?;
171
172                Poll::Ready(Ok(coerce_response(res, *accept)))
173            }
174            CaseProj::Other { future } => future.poll(cx).map_ok(|res| res.map(Body::new)),
175            CaseProj::ImmediateResponse { res } => {
176                let res = Response::from_parts(res.take().unwrap(), Body::empty());
177                Poll::Ready(Ok(res))
178            }
179        }
180    }
181}
182
183impl<S: NamedService> NamedService for GrpcWebService<S> {
184    const NAME: &'static str = S::NAME;
185}
186
187impl<F> fmt::Debug for ResponseFuture<F> {
188    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189        f.debug_struct("ResponseFuture").finish()
190    }
191}
192
193impl<'a> RequestKind<'a> {
194    fn new(headers: &'a HeaderMap, method: &'a Method, version: Version) -> Self {
195        if is_grpc_web(headers) {
196            return RequestKind::GrpcWeb {
197                method,
198                encoding: Encoding::from_content_type(headers),
199                accept: Encoding::from_accept(headers),
200            };
201        }
202
203        RequestKind::Other(version)
204    }
205}
206
207// Mutating request headers to conform to a gRPC request is not really
208// necessary for us at this point. We could remove most of these except
209// maybe for inserting `header::TE`, which tonic should check?
210fn coerce_request<B>(mut req: Request<B>, encoding: Encoding) -> Request<Body>
211where
212    B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
213    B::Error: Into<crate::BoxError> + fmt::Display,
214{
215    req.headers_mut().remove(header::CONTENT_LENGTH);
216
217    req.headers_mut()
218        .insert(header::CONTENT_TYPE, GRPC_CONTENT_TYPE);
219
220    req.headers_mut()
221        .insert(header::TE, HeaderValue::from_static("trailers"));
222
223    req.headers_mut().insert(
224        header::ACCEPT_ENCODING,
225        HeaderValue::from_static("identity,deflate,gzip"),
226    );
227
228    req.map(|b| Body::new(GrpcWebCall::request(b, encoding)))
229}
230
231fn coerce_response<B>(res: Response<B>, encoding: Encoding) -> Response<Body>
232where
233    B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
234    B::Error: Into<crate::BoxError> + fmt::Display,
235{
236    let mut res = res
237        .map(|b| GrpcWebCall::response(b, encoding))
238        .map(Body::new);
239
240    res.headers_mut().insert(
241        header::CONTENT_TYPE,
242        HeaderValue::from_static(encoding.to_content_type()),
243    );
244
245    res
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use crate::call::content_types::*;
    use http::header::{
        ACCESS_CONTROL_REQUEST_HEADERS, ACCESS_CONTROL_REQUEST_METHOD, CONTENT_TYPE, ORIGIN,
    };
    use tower_layer::Layer as _;

    /// Boxed `Send` future returned by the stub service.
    type BoxFuture<T, E> = Pin<Box<dyn Future<Output = Result<T, E>> + Send>>;

    /// Stub inner service: accepts any body and always answers `200 OK` with an
    /// empty body.
    #[derive(Debug, Clone)]
    struct Svc;

    impl<B> tower_service::Service<Request<B>> for Svc {
        type Response = Response<Body>;
        type Error = std::convert::Infallible;
        type Future = BoxFuture<Self::Response, Self::Error>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: Request<B>) -> Self::Future {
            let response = Response::new(Body::default());
            Box::pin(async move { Ok(response) })
        }
    }

    impl NamedService for Svc {
        const NAME: &'static str = "test";
    }

    /// Wraps `service` in the grpc-web layer, with a default CORS layer outside.
    fn enable<S>(service: S) -> tower_http::cors::Cors<GrpcWebService<S>>
    where
        S: Service<http::Request<Body>, Response = http::Response<Body>>,
    {
        let stack = tower_layer::Stack::new(
            crate::GrpcWebLayer::new(),
            tower_http::cors::CorsLayer::new(),
        );
        stack.layer(service)
    }

    /// Requests carrying a grpc-web `content-type`.
    mod grpc_web {
        use super::*;
        use tower_layer::Layer;

        /// A `POST` grpc-web request with an `origin` header.
        fn request() -> Request<Body> {
            Request::builder()
                .method(Method::POST)
                .header(CONTENT_TYPE, GRPC_WEB)
                .header(ORIGIN, "http://example.com")
                .body(Body::default())
                .unwrap()
        }

        #[tokio::test]
        async fn default_cors_config() {
            let res = enable(Svc).call(request()).await.unwrap();
            assert_eq!(res.status(), StatusCode::OK);
        }

        #[tokio::test]
        async fn web_layer() {
            let mut layered = crate::GrpcWebLayer::new().layer(Svc);
            let res = layered.call(request()).await.unwrap();
            assert_eq!(res.status(), StatusCode::OK);
        }

        #[tokio::test]
        async fn web_layer_with_axum() {
            let mut router = axum::routing::Router::new()
                .route("/", axum::routing::post_service(Svc))
                .layer(crate::GrpcWebLayer::new());

            let res = router.call(request()).await.unwrap();
            assert_eq!(res.status(), StatusCode::OK);
        }

        #[tokio::test]
        async fn without_origin() {
            let mut req = request();
            req.headers_mut().remove(ORIGIN);

            let res = enable(Svc).call(req).await.unwrap();
            assert_eq!(res.status(), StatusCode::OK);
        }

        #[tokio::test]
        async fn only_post_and_options_allowed() {
            let mut svc = enable(Svc);
            let rejected = [
                Method::GET,
                Method::PUT,
                Method::DELETE,
                Method::HEAD,
                Method::PATCH,
            ];

            for method in rejected {
                let mut req = request();
                *req.method_mut() = method.clone();

                let status = svc.call(req).await.unwrap().status();
                assert_eq!(
                    status,
                    StatusCode::METHOD_NOT_ALLOWED,
                    "{method} should not be allowed"
                );
            }
        }

        #[tokio::test]
        async fn grpc_web_content_types() {
            let mut svc = enable(Svc);

            for content_type in [GRPC_WEB_TEXT, GRPC_WEB_PROTO, GRPC_WEB_TEXT_PROTO, GRPC_WEB] {
                let mut req = request();
                req.headers_mut()
                    .insert(CONTENT_TYPE, HeaderValue::from_static(content_type));

                let status = svc.call(req).await.unwrap().status();
                assert_eq!(status, StatusCode::OK);
            }
        }
    }

    /// CORS preflight requests.
    mod options {
        use super::*;

        /// An `OPTIONS` preflight asking permission for a grpc-web `POST`.
        fn request() -> Request<Body> {
            Request::builder()
                .method(Method::OPTIONS)
                .header(ORIGIN, "http://example.com")
                .header(ACCESS_CONTROL_REQUEST_HEADERS, "x-grpc-web")
                .header(ACCESS_CONTROL_REQUEST_METHOD, "POST")
                .body(Body::default())
                .unwrap()
        }

        #[tokio::test]
        async fn valid_grpc_web_preflight() {
            let res = enable(Svc).call(request()).await.unwrap();
            assert_eq!(res.status(), StatusCode::OK);
        }
    }

    /// Plain gRPC requests.
    mod grpc {
        use super::*;

        /// An HTTP/2 request with the gRPC content type.
        fn request() -> Request<Body> {
            Request::builder()
                .version(Version::HTTP_2)
                .header(CONTENT_TYPE, GRPC_CONTENT_TYPE)
                .body(Body::default())
                .unwrap()
        }

        #[tokio::test]
        async fn h2_is_ok() {
            let res = enable(Svc).call(request()).await.unwrap();
            assert_eq!(res.status(), StatusCode::OK);
        }

        #[tokio::test]
        async fn h1_is_err() {
            // Same request, but over the builder's default HTTP/1.1.
            let mut req = request();
            *req.version_mut() = Version::HTTP_11;

            let res = enable(Svc).call(req).await.unwrap();
            assert_eq!(res.status(), StatusCode::BAD_REQUEST);
        }

        #[tokio::test]
        async fn content_type_variants() {
            let mut svc = enable(Svc);

            for variant in ["grpc", "grpc+proto", "grpc+thrift", "grpc+foo"] {
                let value: HeaderValue = format!("application/{variant}").parse().unwrap();
                let mut req = request();
                req.headers_mut().insert(CONTENT_TYPE, value);

                let status = svc.call(req).await.unwrap().status();
                assert_eq!(status, StatusCode::OK);
            }
        }
    }

    /// Requests that are neither grpc-web nor gRPC.
    mod other {
        use super::*;

        /// A request with an unrelated content type over the default HTTP/1.1.
        fn request() -> Request<Body> {
            Request::builder()
                .header(CONTENT_TYPE, "application/text")
                .body(Body::default())
                .unwrap()
        }

        #[tokio::test]
        async fn h1_is_err() {
            let res = enable(Svc).call(request()).await.unwrap();
            assert_eq!(res.status(), StatusCode::BAD_REQUEST);
        }

        #[tokio::test]
        async fn h2_is_ok() {
            let mut req = request();
            *req.version_mut() = Version::HTTP_2;

            let res = enable(Svc).call(req).await.unwrap();
            assert_eq!(res.status(), StatusCode::OK);
        }
    }
}