1use core::fmt;
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{ready, Context, Poll};
5
6use http::{header, HeaderMap, HeaderValue, Method, Request, Response, StatusCode, Version};
7use pin_project::pin_project;
8use tonic::metadata::GRPC_CONTENT_TYPE;
9use tonic::{body::Body, server::NamedService};
10use tower_service::Service;
11use tracing::{debug, trace};
12
13use crate::call::content_types::is_grpc_web;
14use crate::call::{Encoding, GrpcWebCall};
15
/// Tower service wrapping an inner service `S` and adapting grpc-web traffic for it.
///
/// Requests are classified in `Service::call`: grpc-web `POST`s are rewritten into gRPC
/// requests before reaching `inner`, other HTTP/2 requests pass through, and everything else is
/// answered directly without calling `inner`.
#[derive(Clone, Debug)]
pub struct GrpcWebService<S> {
    // The wrapped service that ultimately handles (translated) requests.
    inner: S,
}
21
22#[derive(Debug, PartialEq)]
23enum RequestKind<'a> {
24 GrpcWeb {
32 method: &'a Method,
33 encoding: Encoding,
34 accept: Encoding,
35 },
36 Other(http::Version),
38}
39
40impl<S> GrpcWebService<S> {
41 pub(crate) fn new(inner: S) -> Self {
42 GrpcWebService { inner }
43 }
44}
45
46impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for GrpcWebService<S>
47where
48 S: Service<Request<Body>, Response = Response<ResBody>>,
49 ReqBody: http_body::Body<Data = bytes::Bytes> + Send + 'static,
50 ReqBody::Error: Into<crate::BoxError> + fmt::Display,
51 ResBody: http_body::Body<Data = bytes::Bytes> + Send + 'static,
52 ResBody::Error: Into<crate::BoxError> + fmt::Display,
53{
54 type Response = Response<Body>;
55 type Error = S::Error;
56 type Future = ResponseFuture<S::Future>;
57
58 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
59 self.inner.poll_ready(cx)
60 }
61
62 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
63 match RequestKind::new(req.headers(), req.method(), req.version()) {
64 RequestKind::GrpcWeb {
74 method: &Method::POST,
75 encoding,
76 accept,
77 } => {
78 trace!(kind = "simple", path = ?req.uri().path(), ?encoding, ?accept);
79
80 ResponseFuture {
81 case: Case::GrpcWeb {
82 future: self.inner.call(coerce_request(req, encoding)),
83 accept,
84 },
85 }
86 }
87
88 RequestKind::GrpcWeb { .. } => {
92 debug!(kind = "simple", error="method not allowed", method = ?req.method());
93
94 ResponseFuture {
95 case: Case::immediate(StatusCode::METHOD_NOT_ALLOWED),
96 }
97 }
98
99 RequestKind::Other(Version::HTTP_2) => {
102 debug!(kind = "other h2", content_type = ?req.headers().get(header::CONTENT_TYPE));
103 ResponseFuture {
104 case: Case::Other {
105 future: self.inner.call(req.map(Body::new)),
106 },
107 }
108 }
109
110 RequestKind::Other(_) => {
112 debug!(kind = "other h1", content_type = ?req.headers().get(header::CONTENT_TYPE));
113
114 ResponseFuture {
115 case: Case::immediate(StatusCode::BAD_REQUEST),
116 }
117 }
118 }
119 }
120}
121
#[pin_project]
/// Future returned by `GrpcWebService::call`, resolving to a `Response<Body>` or the inner
/// service's error.
#[must_use = "futures do nothing unless polled"]
pub struct ResponseFuture<F> {
    // Which kind of request this future is completing; pinned so an inner future it holds can
    // be polled in place.
    #[pin]
    case: Case<F>,
}
129
/// Internal state of a `ResponseFuture`, one variant per way `GrpcWebService::call` handled
/// the request.
#[pin_project(project = CaseProj)]
enum Case<F> {
    /// A grpc-web request forwarded to the inner service.
    GrpcWeb {
        /// The inner service's response future.
        #[pin]
        future: F,
        /// Encoding passed to `coerce_response` once `future` resolves.
        accept: Encoding,
    },
    /// A non-grpc-web HTTP/2 request passed through; only the response body type is converted.
    Other {
        /// The inner service's response future.
        #[pin]
        future: F,
    },
    /// A response produced without calling the inner service.
    ImmediateResponse {
        /// Response head; `take`n (leaving `None`) when the future is first polled.
        res: Option<http::response::Parts>,
    },
}
145
146impl<F> Case<F> {
147 fn immediate(status: StatusCode) -> Self {
148 let (res, ()) = Response::builder()
149 .status(status)
150 .body(())
151 .unwrap()
152 .into_parts();
153 Self::ImmediateResponse { res: Some(res) }
154 }
155}
156
157impl<F, B, E> Future for ResponseFuture<F>
158where
159 F: Future<Output = Result<Response<B>, E>>,
160 B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
161 B::Error: Into<crate::BoxError> + fmt::Display,
162{
163 type Output = Result<Response<Body>, E>;
164
165 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
166 let this = self.project();
167
168 match this.case.project() {
169 CaseProj::GrpcWeb { future, accept } => {
170 let res = ready!(future.poll(cx))?;
171
172 Poll::Ready(Ok(coerce_response(res, *accept)))
173 }
174 CaseProj::Other { future } => future.poll(cx).map_ok(|res| res.map(Body::new)),
175 CaseProj::ImmediateResponse { res } => {
176 let res = Response::from_parts(res.take().unwrap(), Body::empty());
177 Poll::Ready(Ok(res))
178 }
179 }
180 }
181}
182
183impl<S: NamedService> NamedService for GrpcWebService<S> {
184 const NAME: &'static str = S::NAME;
185}
186
187impl<F> fmt::Debug for ResponseFuture<F> {
188 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189 f.debug_struct("ResponseFuture").finish()
190 }
191}
192
193impl<'a> RequestKind<'a> {
194 fn new(headers: &'a HeaderMap, method: &'a Method, version: Version) -> Self {
195 if is_grpc_web(headers) {
196 return RequestKind::GrpcWeb {
197 method,
198 encoding: Encoding::from_content_type(headers),
199 accept: Encoding::from_accept(headers),
200 };
201 }
202
203 RequestKind::Other(version)
204 }
205}
206
207fn coerce_request<B>(mut req: Request<B>, encoding: Encoding) -> Request<Body>
211where
212 B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
213 B::Error: Into<crate::BoxError> + fmt::Display,
214{
215 req.headers_mut().remove(header::CONTENT_LENGTH);
216
217 req.headers_mut()
218 .insert(header::CONTENT_TYPE, GRPC_CONTENT_TYPE);
219
220 req.headers_mut()
221 .insert(header::TE, HeaderValue::from_static("trailers"));
222
223 req.headers_mut().insert(
224 header::ACCEPT_ENCODING,
225 HeaderValue::from_static("identity,deflate,gzip"),
226 );
227
228 req.map(|b| Body::new(GrpcWebCall::request(b, encoding)))
229}
230
231fn coerce_response<B>(res: Response<B>, encoding: Encoding) -> Response<Body>
232where
233 B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
234 B::Error: Into<crate::BoxError> + fmt::Display,
235{
236 let mut res = res
237 .map(|b| GrpcWebCall::response(b, encoding))
238 .map(Body::new);
239
240 res.headers_mut().insert(
241 header::CONTENT_TYPE,
242 HeaderValue::from_static(encoding.to_content_type()),
243 );
244
245 res
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use crate::call::content_types::*;
    use http::header::{
        ACCESS_CONTROL_REQUEST_HEADERS, ACCESS_CONTROL_REQUEST_METHOD, CONTENT_TYPE, ORIGIN,
    };
    use tower_layer::Layer as _;

    type BoxFuture<T, E> = Pin<Box<dyn Future<Output = Result<T, E>> + Send>>;

    // Minimal inner service: ignores the request and always returns `Response::new` with a
    // default body (status 200). Any non-200 status observed by a test was therefore produced
    // by the layers under test, not by this service.
    #[derive(Debug, Clone)]
    struct Svc;

    impl<B> tower_service::Service<Request<B>> for Svc {
        type Response = Response<Body>;
        type Error = std::convert::Infallible;
        type Future = BoxFuture<Self::Response, Self::Error>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _: Request<B>) -> Self::Future {
            Box::pin(async { Ok(Response::new(Body::default())) })
        }
    }

    impl NamedService for Svc {
        const NAME: &'static str = "test";
    }

    // Wraps `service` in `GrpcWebService` and then a default `CorsLayer`; as the return type
    // `Cors<GrpcWebService<S>>` shows, CORS is the outermost layer.
    fn enable<S>(service: S) -> tower_http::cors::Cors<GrpcWebService<S>>
    where
        S: Service<http::Request<Body>, Response = http::Response<Body>>,
    {
        tower_layer::Stack::new(
            crate::GrpcWebLayer::new(),
            tower_http::cors::CorsLayer::new(),
        )
        .layer(service)
    }

    // Requests carrying a grpc-web content-type.
    mod grpc_web {
        use super::*;
        use tower_layer::Layer;

        // A grpc-web POST with an `origin` header, using the builder's default HTTP version.
        fn request() -> Request<Body> {
            Request::builder()
                .method(Method::POST)
                .header(CONTENT_TYPE, GRPC_WEB)
                .header(ORIGIN, "http://example.com")
                .body(Body::default())
                .unwrap()
        }

        #[tokio::test]
        async fn default_cors_config() {
            let mut svc = enable(Svc);
            let res = svc.call(request()).await.unwrap();

            assert_eq!(res.status(), StatusCode::OK);
        }

        // The grpc-web layer alone, without CORS in front of it.
        #[tokio::test]
        async fn web_layer() {
            let mut svc = crate::GrpcWebLayer::new().layer(Svc);
            let res = svc.call(request()).await.unwrap();

            assert_eq!(res.status(), StatusCode::OK);
        }

        // The grpc-web layer applied to an axum router.
        #[tokio::test]
        async fn web_layer_with_axum() {
            let mut svc = axum::routing::Router::new()
                .route("/", axum::routing::post_service(Svc))
                .layer(crate::GrpcWebLayer::new());

            let res = svc.call(request()).await.unwrap();

            assert_eq!(res.status(), StatusCode::OK);
        }

        // A missing `origin` header must not cause the request to be rejected.
        #[tokio::test]
        async fn without_origin() {
            let mut svc = enable(Svc);

            let mut req = request();
            req.headers_mut().remove(ORIGIN);

            let res = svc.call(req).await.unwrap();

            assert_eq!(res.status(), StatusCode::OK);
        }

        // Non-POST grpc-web requests get 405 (OPTIONS preflights are covered by the `options`
        // module below).
        #[tokio::test]
        async fn only_post_and_options_allowed() {
            let mut svc = enable(Svc);

            for method in &[
                Method::GET,
                Method::PUT,
                Method::DELETE,
                Method::HEAD,
                Method::PATCH,
            ] {
                let mut req = request();
                *req.method_mut() = method.clone();

                let res = svc.call(req).await.unwrap();

                assert_eq!(
                    res.status(),
                    StatusCode::METHOD_NOT_ALLOWED,
                    "{method} should not be allowed"
                );
            }
        }

        // Every grpc-web content-type constant is accepted on a POST.
        #[tokio::test]
        async fn grpc_web_content_types() {
            let mut svc = enable(Svc);

            for ct in &[GRPC_WEB_TEXT, GRPC_WEB_PROTO, GRPC_WEB_TEXT_PROTO, GRPC_WEB] {
                let mut req = request();
                req.headers_mut()
                    .insert(CONTENT_TYPE, HeaderValue::from_static(ct));

                let res = svc.call(req).await.unwrap();

                assert_eq!(res.status(), StatusCode::OK);
            }
        }
    }

    // CORS preflight requests.
    mod options {
        use super::*;

        // An OPTIONS preflight asking to send `x-grpc-web` with POST. It has no content-type,
        // so it presumably reaches `GrpcWebService` only as `RequestKind::Other`; a 200 is
        // expected to come from the CORS layer.
        fn request() -> Request<Body> {
            Request::builder()
                .method(Method::OPTIONS)
                .header(ORIGIN, "http://example.com")
                .header(ACCESS_CONTROL_REQUEST_HEADERS, "x-grpc-web")
                .header(ACCESS_CONTROL_REQUEST_METHOD, "POST")
                .body(Body::default())
                .unwrap()
        }

        #[tokio::test]
        async fn valid_grpc_web_preflight() {
            let mut svc = enable(Svc);
            let res = svc.call(request()).await.unwrap();

            assert_eq!(res.status(), StatusCode::OK);
        }
    }

    // Plain gRPC (not grpc-web): passed through over HTTP/2, rejected with 400 otherwise.
    mod grpc {
        use super::*;

        fn request() -> Request<Body> {
            Request::builder()
                .version(Version::HTTP_2)
                .header(CONTENT_TYPE, GRPC_CONTENT_TYPE)
                .body(Body::default())
                .unwrap()
        }

        #[tokio::test]
        async fn h2_is_ok() {
            let mut svc = enable(Svc);

            let req = request();
            let res = svc.call(req).await.unwrap();

            assert_eq!(res.status(), StatusCode::OK)
        }

        // Same request, but with the builder's default (non-HTTP/2) version.
        #[tokio::test]
        async fn h1_is_err() {
            let mut svc = enable(Svc);

            let req = Request::builder()
                .header(CONTENT_TYPE, GRPC_CONTENT_TYPE)
                .body(Body::default())
                .unwrap();

            let res = svc.call(req).await.unwrap();
            assert_eq!(res.status(), StatusCode::BAD_REQUEST)
        }

        // Any `application/grpc*` subtype is forwarded over HTTP/2.
        #[tokio::test]
        async fn content_type_variants() {
            let mut svc = enable(Svc);

            for variant in &["grpc", "grpc+proto", "grpc+thrift", "grpc+foo"] {
                let mut req = request();
                req.headers_mut().insert(
                    CONTENT_TYPE,
                    HeaderValue::from_maybe_shared(format!("application/{variant}")).unwrap(),
                );

                let res = svc.call(req).await.unwrap();

                assert_eq!(res.status(), StatusCode::OK)
            }
        }
    }

    // Non-gRPC traffic follows the same version rule: 400 unless HTTP/2.
    mod other {
        use super::*;

        fn request() -> Request<Body> {
            Request::builder()
                .header(CONTENT_TYPE, "application/text")
                .body(Body::default())
                .unwrap()
        }

        #[tokio::test]
        async fn h1_is_err() {
            let mut svc = enable(Svc);
            let res = svc.call(request()).await.unwrap();

            assert_eq!(res.status(), StatusCode::BAD_REQUEST)
        }

        #[tokio::test]
        async fn h2_is_ok() {
            let mut svc = enable(Svc);
            let mut req = request();
            *req.version_mut() = Version::HTTP_2;

            let res = svc.call(req).await.unwrap();
            assert_eq!(res.status(), StatusCode::OK)
        }
    }
}