1use std::fmt;
2use std::pin::Pin;
3use std::task::{ready, Context, Poll};
4
5use base64::Engine as _;
6use bytes::{Buf, BufMut, Bytes, BytesMut};
7use http::{header, HeaderMap, HeaderName, HeaderValue};
8use http_body::{Body, Frame, SizeHint};
9use pin_project::pin_project;
10use tokio_stream::Stream;
11use tonic::Status;
12
13use self::content_types::*;
14
/// Size of a length-prefixed frame header: a 1-byte flag followed by a
/// 4-byte payload length (read with `get_u8` / `get_u32` below).
const GRPC_HEADER_SIZE: usize = 1 + 4;
17
18pub(crate) mod content_types {
19 use http::{header::CONTENT_TYPE, HeaderMap};
20
21 pub(crate) const GRPC_WEB: &str = "application/grpc-web";
22 pub(crate) const GRPC_WEB_PROTO: &str = "application/grpc-web+proto";
23 pub(crate) const GRPC_WEB_TEXT: &str = "application/grpc-web-text";
24 pub(crate) const GRPC_WEB_TEXT_PROTO: &str = "application/grpc-web-text+proto";
25
26 pub(crate) fn is_grpc_web(headers: &HeaderMap) -> bool {
27 matches!(
28 content_type(headers),
29 Some(GRPC_WEB) | Some(GRPC_WEB_PROTO) | Some(GRPC_WEB_TEXT) | Some(GRPC_WEB_TEXT_PROTO)
30 )
31 }
32
33 fn content_type(headers: &HeaderMap) -> Option<&str> {
34 headers.get(CONTENT_TYPE).and_then(|val| val.to_str().ok())
35 }
36}
37
/// Initial capacity, in bytes, reserved for `buf` when base64-encoding and for
/// `decoded` when decoding on the client side (see the constructors).
const BUFFER_SIZE: usize = 8 * 1024;

/// Header size of the trailers frame built by `make_trailers_frame`: one flag
/// byte plus a `u32` length. Numerically equal to `GRPC_HEADER_SIZE`.
const FRAME_HEADER_SIZE: usize = 5;

/// Flag byte that marks a gRPC-Web frame carrying trailers instead of a
/// message (only the most significant bit set).
const GRPC_WEB_TRAILERS_BIT: u8 = 0b10000000;
45
/// Which transformation a `GrpcWebCall` applies to the body it wraps.
#[derive(Copy, Clone, PartialEq, Debug)]
enum Direction {
    /// Frames are produced by `poll_decode` (used by `request` and
    /// `client_response`).
    Decode,
    /// Frames are produced by `poll_encode` (used by `response` and
    /// `client_request`).
    Encode,
    /// No frames at all: `poll_frame` returns `None` immediately. Used by the
    /// `Default` impl.
    Empty,
}
52
/// Payload encoding, derived from the `content-type` / `accept` headers.
#[derive(Copy, Clone, PartialEq, Debug)]
pub(crate) enum Encoding {
    /// `application/grpc-web-text[+proto]`: the body is base64 text.
    Base64,
    /// Any other content type: the body is passed as binary, unencoded.
    None,
}
58
/// A [`Body`] adapter that translates between gRPC and gRPC-Web framing.
///
/// Depending on `direction` it decodes frames pulled from `inner`
/// (`poll_decode`), encodes them (`poll_encode`), or yields nothing
/// (`Direction::Empty`).
#[derive(Debug)]
#[pin_project]
pub struct GrpcWebCall<B> {
    /// The wrapped body.
    #[pin]
    inner: B,
    /// Base64 text received from `inner` that has not been decoded yet.
    buf: BytesMut,
    /// Client-side decoding only: bytes held back until a complete message
    /// frame or the trailers frame can be returned.
    decoded: BytesMut,
    /// Whether this call decodes, encodes, or is empty.
    direction: Direction,
    /// Whether the payload is base64 text or binary.
    encoding: Encoding,
    /// Set by `client_request` / `client_response`; selects the client-side
    /// decode path in `poll_frame`.
    client: bool,
    /// Trailers collected while decoding, returned after the data frames.
    trailers: Option<HeaderMap>,
}
72
73impl<B: Default> Default for GrpcWebCall<B> {
74 fn default() -> Self {
75 Self {
76 inner: Default::default(),
77 buf: Default::default(),
78 decoded: Default::default(),
79 direction: Direction::Empty,
80 encoding: Encoding::None,
81 client: Default::default(),
82 trailers: Default::default(),
83 }
84 }
85}
86
87impl<B> GrpcWebCall<B> {
88 pub(crate) fn request(inner: B, encoding: Encoding) -> Self {
89 Self::new(inner, Direction::Decode, encoding)
90 }
91
92 pub(crate) fn response(inner: B, encoding: Encoding) -> Self {
93 Self::new(inner, Direction::Encode, encoding)
94 }
95
96 pub(crate) fn client_request(inner: B) -> Self {
97 Self::new_client(inner, Direction::Encode, Encoding::None)
98 }
99
100 pub(crate) fn client_response(inner: B) -> Self {
101 Self::new_client(inner, Direction::Decode, Encoding::None)
102 }
103
104 fn new_client(inner: B, direction: Direction, encoding: Encoding) -> Self {
105 GrpcWebCall {
106 inner,
107 buf: BytesMut::with_capacity(match (direction, encoding) {
108 (Direction::Encode, Encoding::Base64) => BUFFER_SIZE,
109 _ => 0,
110 }),
111 decoded: BytesMut::with_capacity(match direction {
112 Direction::Decode => BUFFER_SIZE,
113 _ => 0,
114 }),
115 direction,
116 encoding,
117 client: true,
118 trailers: None,
119 }
120 }
121
122 fn new(inner: B, direction: Direction, encoding: Encoding) -> Self {
123 GrpcWebCall {
124 inner,
125 buf: BytesMut::with_capacity(match (direction, encoding) {
126 (Direction::Encode, Encoding::Base64) => BUFFER_SIZE,
127 _ => 0,
128 }),
129 decoded: BytesMut::with_capacity(0),
130 direction,
131 encoding,
132 client: false,
133 trailers: None,
134 }
135 }
136
137 #[inline]
140 fn max_decodable(&self) -> usize {
141 (self.buf.len() / 4) * 4
142 }
143
144 fn decode_chunk(mut self: Pin<&mut Self>) -> Result<Option<Bytes>, Status> {
145 if self.buf.is_empty() || self.buf.len() < 4 {
147 return Ok(None);
148 }
149
150 let index = self.max_decodable();
153
154 crate::util::base64::STANDARD
155 .decode(self.as_mut().project().buf.split_to(index))
156 .map(|decoded| Some(Bytes::from(decoded)))
157 .map_err(internal_error)
158 }
159}
160
161impl<B> GrpcWebCall<B>
162where
163 B: Body,
164 B::Data: Buf,
165 B::Error: fmt::Display,
166{
167 fn poll_decode(
171 mut self: Pin<&mut Self>,
172 cx: &mut Context<'_>,
173 ) -> Poll<Option<Result<Frame<Bytes>, Status>>> {
174 match self.encoding {
175 Encoding::Base64 => loop {
176 if let Some(bytes) = self.as_mut().decode_chunk()? {
177 return Poll::Ready(Some(Ok(Frame::data(bytes))));
178 }
179
180 let this = self.as_mut().project();
181
182 match ready!(this.inner.poll_frame(cx)) {
183 Some(Ok(frame)) if frame.is_data() => this
184 .buf
185 .put(frame.into_data().unwrap_or_else(|_| unreachable!())),
186 Some(Ok(frame)) if frame.is_trailers() => {
187 return Poll::Ready(Some(Err(internal_error(
188 "malformed base64 request has unencoded trailers",
189 ))))
190 }
191 Some(Ok(_)) => {
192 return Poll::Ready(Some(Err(internal_error("unexpected frame type"))))
193 }
194 Some(Err(e)) => return Poll::Ready(Some(Err(internal_error(e)))),
195 None => {
196 return if this.buf.has_remaining() {
197 Poll::Ready(Some(Err(internal_error("malformed base64 request"))))
198 } else if let Some(trailers) = this.trailers.take() {
199 Poll::Ready(Some(Ok(Frame::trailers(trailers))))
200 } else {
201 Poll::Ready(None)
202 }
203 }
204 }
205 },
206
207 Encoding::None => self
208 .project()
209 .inner
210 .poll_frame(cx)
211 .map_ok(|f| f.map_data(|mut d| d.copy_to_bytes(d.remaining())))
212 .map_err(internal_error),
213 }
214 }
215
216 fn poll_encode(
217 mut self: Pin<&mut Self>,
218 cx: &mut Context<'_>,
219 ) -> Poll<Option<Result<Frame<Bytes>, Status>>> {
220 let this = self.as_mut().project();
221
222 match ready!(this.inner.poll_frame(cx)) {
223 Some(Ok(frame)) if frame.is_data() => {
224 let mut data = frame.into_data().unwrap_or_else(|_| unreachable!());
225 let mut res = data.copy_to_bytes(data.remaining());
226
227 if *this.encoding == Encoding::Base64 {
228 res = crate::util::base64::STANDARD.encode(res).into();
229 }
230
231 Poll::Ready(Some(Ok(Frame::data(res))))
232 }
233 Some(Ok(frame)) if frame.is_trailers() => {
234 let trailers = frame.into_trailers().unwrap_or_else(|_| unreachable!());
235 let mut res = make_trailers_frame(trailers);
236
237 if *this.encoding == Encoding::Base64 {
238 res = crate::util::base64::STANDARD.encode(res).into();
239 }
240
241 Poll::Ready(Some(Ok(Frame::data(res))))
242 }
243 Some(Ok(_)) => Poll::Ready(Some(Err(internal_error("unexpected frame type")))),
244 Some(Err(e)) => Poll::Ready(Some(Err(internal_error(e)))),
245 None => Poll::Ready(None),
246 }
247 }
248}
249
250impl<B> Body for GrpcWebCall<B>
251where
252 B: Body,
253 B::Error: fmt::Display,
254{
255 type Data = Bytes;
256 type Error = Status;
257
258 fn poll_frame(
259 mut self: Pin<&mut Self>,
260 cx: &mut Context<'_>,
261 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
262 if self.client && self.direction == Direction::Decode {
263 let mut me = self.as_mut();
264
265 loop {
266 match ready!(me.as_mut().poll_decode(cx)) {
267 Some(Ok(incoming_buf)) if incoming_buf.is_data() => {
268 me.as_mut()
269 .project()
270 .decoded
271 .put(incoming_buf.into_data().unwrap());
272 }
273 Some(Ok(incoming_buf)) if incoming_buf.is_trailers() => {
274 let trailers = incoming_buf.into_trailers().unwrap();
275 match me.as_mut().project().trailers {
276 Some(current_trailers) => {
277 current_trailers.extend(trailers);
278 }
279 None => {
280 me.as_mut().project().trailers.replace(trailers);
281 }
282 }
283 continue;
284 }
285 Some(Ok(_)) => unreachable!("unexpected frame type"),
286 None => {} Some(Err(e)) => return Poll::Ready(Some(Err(e))),
288 };
289
290 let buf = me.as_mut().project().decoded;
293
294 return match find_trailers(&buf[..])? {
295 FindTrailers::Trailer(len) => {
296 let msg_buf = buf.copy_to_bytes(len);
298 match decode_trailers_frame(buf.split().freeze()) {
299 Ok(Some(trailers)) => {
300 me.as_mut().project().trailers.replace(trailers);
301 }
302 Err(e) => return Poll::Ready(Some(Err(e))),
303 _ => {}
304 }
305
306 if msg_buf.has_remaining() {
307 Poll::Ready(Some(Ok(Frame::data(msg_buf))))
308 } else if let Some(trailers) = me.as_mut().project().trailers.take() {
309 Poll::Ready(Some(Ok(Frame::trailers(trailers))))
310 } else {
311 Poll::Ready(None)
312 }
313 }
314 FindTrailers::IncompleteBuf => continue,
315 FindTrailers::Done(len) => Poll::Ready(match len {
316 0 => None,
317 _ => Some(Ok(Frame::data(buf.split_to(len).freeze()))),
318 }),
319 };
320 }
321 }
322
323 match self.direction {
324 Direction::Decode => self.poll_decode(cx),
325 Direction::Encode => self.poll_encode(cx),
326 Direction::Empty => Poll::Ready(None),
327 }
328 }
329
330 fn is_end_stream(&self) -> bool {
331 self.inner.is_end_stream()
332 }
333
334 fn size_hint(&self) -> SizeHint {
335 self.inner.size_hint()
336 }
337}
338
339impl<B> Stream for GrpcWebCall<B>
340where
341 B: Body,
342 B::Error: fmt::Display,
343{
344 type Item = Result<Frame<Bytes>, Status>;
345
346 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
347 self.poll_frame(cx)
348 }
349}
350
351impl Encoding {
352 pub(crate) fn from_content_type(headers: &HeaderMap) -> Encoding {
353 Self::from_header(headers.get(header::CONTENT_TYPE))
354 }
355
356 pub(crate) fn from_accept(headers: &HeaderMap) -> Encoding {
357 Self::from_header(headers.get(header::ACCEPT))
358 }
359
360 pub(crate) fn to_content_type(self) -> &'static str {
361 match self {
362 Encoding::Base64 => GRPC_WEB_TEXT_PROTO,
363 Encoding::None => GRPC_WEB_PROTO,
364 }
365 }
366
367 fn from_header(value: Option<&HeaderValue>) -> Encoding {
368 match value.and_then(|val| val.to_str().ok()) {
369 Some(GRPC_WEB_TEXT_PROTO) | Some(GRPC_WEB_TEXT) => Encoding::Base64,
370 _ => Encoding::None,
371 }
372 }
373}
374
375fn internal_error(e: impl std::fmt::Display) -> Status {
376 Status::internal(format!("tonic-web: {e}"))
377}
378
379fn encode_trailers(trailers: HeaderMap) -> Vec<u8> {
381 trailers.iter().fold(Vec::new(), |mut acc, (key, value)| {
382 acc.put_slice(key.as_ref());
383 acc.push(b':');
384 acc.put_slice(value.as_bytes());
385 acc.put_slice(b"\r\n");
386 acc
387 })
388}
389
390fn decode_trailers_frame(mut buf: Bytes) -> Result<Option<HeaderMap>, Status> {
391 if buf.remaining() < GRPC_HEADER_SIZE {
392 return Ok(None);
393 }
394
395 buf.get_u8();
396 buf.get_u32();
397
398 let mut map = HeaderMap::new();
399 let mut temp_buf = buf.clone();
400
401 let mut trailers = Vec::new();
402 let mut cursor_pos = 0;
403
404 for (i, b) in buf.iter().enumerate() {
405 if b == &b'\r' && buf.get(i + 1) == Some(&b'\n') {
407 let trailer = temp_buf.copy_to_bytes(i - cursor_pos);
409 cursor_pos = i + 2;
411 trailers.push(trailer);
412 if temp_buf.has_remaining() {
413 temp_buf.get_u8();
415 temp_buf.get_u8();
416 }
417 }
418 }
419
420 for trailer in trailers {
421 let mut s = trailer.split(|b| b == &b':');
422 let key = s
423 .next()
424 .ok_or_else(|| Status::internal("trailers couldn't parse key"))?;
425 let value = s
426 .next()
427 .ok_or_else(|| Status::internal("trailers couldn't parse value"))?;
428
429 let value = value
430 .split(|b| b == &b'\r')
431 .next()
432 .ok_or_else(|| Status::internal("trailers was not escaped"))?
433 .strip_prefix(b" ")
434 .unwrap_or(value);
435
436 let header_key = HeaderName::try_from(key)
437 .map_err(|e| Status::internal(format!("Unable to parse HeaderName: {e}")))?;
438 let header_value = HeaderValue::try_from(value)
439 .map_err(|e| Status::internal(format!("Unable to parse HeaderValue: {e}")))?;
440 map.insert(header_key, header_value);
441 }
442
443 Ok(Some(map))
444}
445
446fn make_trailers_frame(trailers: HeaderMap) -> Bytes {
447 let trailers = encode_trailers(trailers);
448 let len = trailers.len();
449 assert!(len <= u32::MAX as usize);
450
451 let mut frame = BytesMut::with_capacity(len + FRAME_HEADER_SIZE);
452 frame.put_u8(GRPC_WEB_TRAILERS_BIT);
453 frame.put_u32(len as u32);
454 frame.put_slice(&trailers);
455
456 frame.freeze()
457}
458
459fn find_trailers(buf: &[u8]) -> Result<FindTrailers, Status> {
464 let mut len = 0;
465 let mut temp_buf = buf;
466
467 loop {
468 if temp_buf.is_empty() || temp_buf.len() < GRPC_HEADER_SIZE {
471 return Ok(FindTrailers::Done(len));
472 }
473
474 let header = temp_buf.get_u8();
475
476 if header == GRPC_WEB_TRAILERS_BIT {
477 return Ok(FindTrailers::Trailer(len));
478 }
479
480 if !(header == 0 || header == 1) {
481 return Err(Status::internal(format!(
482 "Invalid header bit {header} expected 0 or 1"
483 )));
484 }
485
486 let msg_len = temp_buf.get_u32();
487
488 len += msg_len as usize + 4 + 1;
489
490 if len > buf.len() {
493 return Ok(FindTrailers::IncompleteBuf);
494 }
495
496 temp_buf = &buf[len..];
497 }
498}
499
/// Outcome of scanning a buffer of gRPC-Web frames with `find_trailers`.
#[derive(Debug, PartialEq, Eq)]
enum FindTrailers {
    /// A frame flagged with `GRPC_WEB_TRAILERS_BIT` begins at this byte
    /// offset; the bytes before it are message frames.
    Trailer(usize),
    /// A frame's declared length runs past the end of the buffer, so more data
    /// is needed before anything can be returned.
    IncompleteBuf,
    /// The first `n` bytes are complete message frames; whatever follows them
    /// was not reported as trailers.
    Done(usize),
}
506
#[cfg(test)]
mod tests {
    use tonic::Code;

    use super::*;

    // Every gRPC-Web content type maps to the expected encoding through both the
    // `content-type` and `accept` headers; an unknown type falls back to `None`.
    #[test]
    fn encoding_constructors() {
        let cases = &[
            (GRPC_WEB, Encoding::None),
            (GRPC_WEB_PROTO, Encoding::None),
            (GRPC_WEB_TEXT, Encoding::Base64),
            (GRPC_WEB_TEXT_PROTO, Encoding::Base64),
            ("foo", Encoding::None),
        ];

        let mut headers = HeaderMap::new();

        for case in cases {
            headers.insert(header::CONTENT_TYPE, case.0.parse().unwrap());
            headers.insert(header::ACCEPT, case.0.parse().unwrap());

            assert_eq!(Encoding::from_content_type(&headers), case.1, "{}", case.0);
            assert_eq!(Encoding::from_accept(&headers), case.1, "{}", case.0);
        }
    }

    // Round trip: trailers written by `make_trailers_frame` decode back to the
    // same map.
    #[test]
    fn decode_trailers() {
        let mut headers = HeaderMap::new();
        headers.insert(Status::GRPC_STATUS, 0.into());
        headers.insert(
            Status::GRPC_MESSAGE,
            "this is a message".try_into().unwrap(),
        );

        let trailers = make_trailers_frame(headers.clone());

        let map = decode_trailers_frame(trailers).unwrap().unwrap();

        assert_eq!(headers, map);
    }

    // A buffer that starts directly with a trailers frame reports it at offset 0.
    #[test]
    fn find_trailers_non_buffered() {
        // Flag 0x80, length 15, then "grpc-status:0\r\n".
        let buf = [
            128, 0, 0, 0, 15, 103, 114, 112, 99, 45, 115, 116, 97, 116, 117, 115, 58, 48, 13, 10,
        ];

        let out = find_trailers(&buf[..]).unwrap();

        assert_eq!(out, FindTrailers::Trailer(0));
    }

    // One message frame (flag 0, length 76, so 81 bytes in total) followed by a
    // trailers frame: the trailers start at offset 81 and carry grpc-status 0.
    #[test]
    fn find_trailers_buffered() {
        let buf = [
            0, 0, 0, 0, 76, 10, 36, 57, 55, 53, 55, 51, 56, 97, 102, 45, 49, 97, 49, 55, 45, 52,
            97, 101, 97, 45, 98, 56, 56, 55, 45, 101, 100, 48, 98, 98, 99, 101, 100, 54, 48, 57,
            51, 26, 36, 100, 97, 54, 48, 57, 101, 57, 98, 45, 102, 52, 55, 48, 45, 52, 99, 99, 48,
            45, 97, 54, 57, 49, 45, 51, 102, 100, 54, 97, 48, 48, 53, 97, 52, 51, 54, 128, 0, 0, 0,
            15, 103, 114, 112, 99, 45, 115, 116, 97, 116, 117, 115, 58, 48, 13, 10,
        ];

        let out = find_trailers(&buf[..]).unwrap();

        assert_eq!(out, FindTrailers::Trailer(81));

        let trailers = decode_trailers_frame(Bytes::copy_from_slice(&buf[81..]))
            .unwrap()
            .unwrap();
        let status = trailers.get(Status::GRPC_STATUS).unwrap();
        assert_eq!(status.to_str().unwrap(), "0")
    }

    // The first frame declares a payload (0x09ee bytes) longer than the buffer,
    // so the scan must ask for more data.
    #[test]
    fn find_trailers_buffered_incomplete_message() {
        let buf = vec![
            0, 0, 0, 9, 238, 10, 233, 19, 18, 230, 19, 10, 9, 10, 1, 120, 26, 4, 84, 69, 88, 84,
            18, 60, 10, 58, 10, 56, 3, 0, 0, 0, 44, 0, 0, 0, 0, 0, 0, 0, 116, 104, 105, 115, 32,
            118, 97, 108, 117, 101, 32, 119, 97, 115, 32, 119, 114, 105, 116, 116, 101, 110, 32,
            118, 105, 97, 32, 119, 114, 105, 116, 101, 32, 100, 101, 108, 101, 103, 97, 116, 105,
            111, 110, 33, 18, 62, 10, 60, 10, 58, 3, 0, 0, 0, 46, 0, 0, 0, 0, 0, 0, 0, 116, 104,
            105, 115, 32, 118, 97, 108, 117, 101, 32, 119, 97, 115, 32, 119, 114, 105, 116, 116,
            101, 110, 32, 98, 121, 32, 97, 110, 32, 101, 109, 98, 101, 100, 100, 101, 100, 32, 114,
            101, 112, 108, 105, 99, 97, 33, 18, 62, 10, 60, 10, 58, 3, 0, 0, 0, 46, 0, 0, 0, 0, 0,
            0, 0, 116, 104, 105, 115, 32, 118, 97, 108, 117, 101, 32, 119, 97, 115, 32, 119, 114,
            105, 116, 116, 101, 110, 32, 98, 121, 32, 97, 110, 32, 101, 109, 98, 101, 100, 100,
            101, 100, 32, 114, 101, 112, 108, 105, 99, 97, 33, 18, 62, 10, 60, 10, 58, 3, 0, 0, 0,
            46, 0, 0, 0, 0, 0, 0, 0, 116, 104, 105, 115, 32, 118, 97, 108, 117, 101, 32, 119, 97,
            115, 32, 119, 114, 105, 116, 116, 101, 110, 32, 98, 121, 32, 97, 110, 32, 101, 109, 98,
            101, 100, 100, 101, 100, 32, 114, 101, 112, 108, 105, 99, 97, 33, 18, 62, 10, 60, 10,
            58, 3, 0, 0, 0, 46, 0, 0, 0, 0, 0, 0, 0, 116, 104, 105, 115, 32, 118, 97, 108, 117,
            101, 32, 119, 97, 115, 32, 119, 114, 105, 116, 116, 101, 110, 32, 98, 121, 32, 97, 110,
            32, 101, 109, 98, 101, 100, 100, 101, 100, 32, 114, 101, 112, 108, 105, 99, 97, 33, 18,
            62, 10, 60, 10, 58, 3, 0, 0, 0, 46, 0, 0, 0, 0, 0, 0, 0, 116, 104, 105, 115, 32, 118,
            97, 108, 117, 101, 32, 119, 97, 115, 32, 119, 114, 105, 116, 116, 101, 110, 32, 98,
            121, 32, 97, 110, 32, 101, 109, 98, 101, 100, 100, 101, 100, 32, 114, 101, 112, 108,
            105, 99, 97, 33, 18, 62, 10, 60, 10, 58, 3, 0, 0, 0, 46, 0, 0, 0, 0, 0, 0, 0, 116, 104,
            105, 115, 32, 118, 97, 108, 117, 101, 32, 119, 97, 115, 32, 119, 114, 105, 116, 116,
            101, 110, 32, 98, 121, 32, 97, 110, 32, 101, 109, 98, 101, 100, 100, 101, 100, 32, 114,
            101, 112, 108, 105, 99, 97, 33, 18, 62, 10, 60, 10, 58, 3, 0, 0, 0, 46, 0, 0, 0, 0, 0,
            0, 0, 116, 104, 105, 115, 32, 118, 97, 108, 117, 101, 32, 119, 97, 115, 32, 119, 114,
            105, 116, 116, 101, 110, 32, 98, 121, 32,
        ];

        let out = find_trailers(&buf[..]).unwrap();

        assert_eq!(out, FindTrailers::IncompleteBuf);
    }

    // Regression fixture read from `tests/incomplete-buf-bug.bin`; scanning it is
    // expected to fail with `Code::Internal`. Ignored by default because it
    // depends on that file.
    #[test]
    #[ignore]
    fn find_trailers_buffered_incomplete_buf_bug() {
        let buf = std::fs::read("tests/incomplete-buf-bug.bin").unwrap();
        let out = find_trailers(&buf[..]).unwrap_err();

        assert_eq!(out.code(), Code::Internal);
    }

    // Several CRLF-separated trailers, including one with an empty value, are
    // all parsed. The declared length (0x0f) is shorter than the payload and is
    // not enforced by the decoder.
    #[test]
    fn decode_multiple_trailers() {
        let buf = b"\x80\0\0\0\x0fgrpc-status:0\r\ngrpc-message:\r\na:1\r\nb:2\r\n";

        let trailers = decode_trailers_frame(Bytes::copy_from_slice(&buf[..]))
            .unwrap()
            .unwrap();

        let mut expected = HeaderMap::new();
        expected.insert(Status::GRPC_STATUS, "0".parse().unwrap());
        expected.insert(Status::GRPC_MESSAGE, "".parse().unwrap());
        expected.insert("a", "1".parse().unwrap());
        expected.insert("b", "2".parse().unwrap());

        assert_eq!(trailers, expected);
    }

    // A space after the colon is not kept as part of the value.
    #[test]
    fn decode_trailers_with_space_after_colon() {
        let buf = b"\x80\0\0\0\x0fgrpc-status: 0\r\ngrpc-message: \r\n";

        let trailers = decode_trailers_frame(Bytes::copy_from_slice(&buf[..]))
            .unwrap()
            .unwrap();

        let mut expected = HeaderMap::new();
        expected.insert(Status::GRPC_STATUS, "0".parse().unwrap());
        expected.insert(Status::GRPC_MESSAGE, "".parse().unwrap());

        assert_eq!(trailers, expected);
    }
}