Skip to main content

tonic_web/
call.rs

1use std::fmt;
2use std::pin::Pin;
3use std::task::{ready, Context, Poll};
4
5use base64::Engine as _;
6use bytes::{Buf, BufMut, Bytes, BytesMut};
7use http::{header, HeaderMap, HeaderName, HeaderValue};
8use http_body::{Body, Frame, SizeHint};
9use pin_project::pin_project;
10use tokio_stream::Stream;
11use tonic::Status;
12
13use self::content_types::*;
14
/// Length of a gRPC frame header: a one-byte flag followed by a `u32` message length.
const GRPC_HEADER_SIZE: usize = std::mem::size_of::<u8>() + std::mem::size_of::<u32>();
17
18pub(crate) mod content_types {
19    use http::{header::CONTENT_TYPE, HeaderMap};
20
21    pub(crate) const GRPC_WEB: &str = "application/grpc-web";
22    pub(crate) const GRPC_WEB_PROTO: &str = "application/grpc-web+proto";
23    pub(crate) const GRPC_WEB_TEXT: &str = "application/grpc-web-text";
24    pub(crate) const GRPC_WEB_TEXT_PROTO: &str = "application/grpc-web-text+proto";
25
26    pub(crate) fn is_grpc_web(headers: &HeaderMap) -> bool {
27        matches!(
28            content_type(headers),
29            Some(GRPC_WEB) | Some(GRPC_WEB_PROTO) | Some(GRPC_WEB_TEXT) | Some(GRPC_WEB_TEXT_PROTO)
30        )
31    }
32
33    fn content_type(headers: &HeaderMap) -> Option<&str> {
34        headers.get(CONTENT_TYPE).and_then(|val| val.to_str().ok())
35    }
36}
37
/// Initial capacity (8 KiB) of the scratch buffers pre-allocated by the constructors.
const BUFFER_SIZE: usize = 8 << 10;

/// Size of a grpc-web frame header: one flag byte plus a four-byte length.
const FRAME_HEADER_SIZE: usize = 5;

/// Flag byte of a grpc-web trailers frame: only the most significant (8th) bit
/// of the first frame byte is set, i.e. an uncompressed trailers frame carried
/// inside the body.
const GRPC_WEB_TRAILERS_BIT: u8 = 1 << 7;
45
/// Which way a `GrpcWebCall` transforms the frames of the body it wraps.
#[derive(Copy, Clone, PartialEq, Debug)]
enum Direction {
    /// Read grpc-web bytes from the inner body and yield gRPC frames.
    Decode,
    /// Read gRPC frames from the inner body and yield grpc-web bytes.
    Encode,
    /// Yield nothing at all (the state produced by `Default`).
    Empty,
}
52
/// Wire encoding of a grpc-web body.
#[derive(Copy, Clone, PartialEq, Debug)]
pub(crate) enum Encoding {
    /// Base64 text, selected by the `grpc-web-text` media types.
    Base64,
    /// Raw binary frames.
    None,
}
58
/// HttpBody adapter for the grpc web based services.
///
/// Wraps an inner body `B` and converts its frames between gRPC and grpc-web
/// framing in the direction given by `direction`, applying base64 when
/// `encoding` is `Encoding::Base64`.
#[derive(Debug)]
#[pin_project]
pub struct GrpcWebCall<B> {
    /// The wrapped body; pinned because it is polled through the projection.
    #[pin]
    inner: B,
    /// Base64 text received but not yet decoded; `decode_chunk` drains it in
    /// multiples of 4 bytes and keeps any remainder.
    buf: BytesMut,
    /// Client-side response bytes held until a whole message frame or the
    /// embedded trailers frame can be yielded.
    decoded: BytesMut,
    /// Whether frames are decoded, encoded, or nothing is produced.
    direction: Direction,
    /// Base64 (`grpc-web-text`) or binary wire format.
    encoding: Encoding,
    /// Set by the `client_*` constructors; enables trailer extraction from the
    /// response body in `poll_frame`.
    client: bool,
    /// Trailers collected but not yet yielded as a trailers frame.
    trailers: Option<HeaderMap>,
}
72
73impl<B: Default> Default for GrpcWebCall<B> {
74    fn default() -> Self {
75        Self {
76            inner: Default::default(),
77            buf: Default::default(),
78            decoded: Default::default(),
79            direction: Direction::Empty,
80            encoding: Encoding::None,
81            client: Default::default(),
82            trailers: Default::default(),
83        }
84    }
85}
86
87impl<B> GrpcWebCall<B> {
88    pub(crate) fn request(inner: B, encoding: Encoding) -> Self {
89        Self::new(inner, Direction::Decode, encoding)
90    }
91
92    pub(crate) fn response(inner: B, encoding: Encoding) -> Self {
93        Self::new(inner, Direction::Encode, encoding)
94    }
95
96    pub(crate) fn client_request(inner: B) -> Self {
97        Self::new_client(inner, Direction::Encode, Encoding::None)
98    }
99
100    pub(crate) fn client_response(inner: B) -> Self {
101        Self::new_client(inner, Direction::Decode, Encoding::None)
102    }
103
104    fn new_client(inner: B, direction: Direction, encoding: Encoding) -> Self {
105        GrpcWebCall {
106            inner,
107            buf: BytesMut::with_capacity(match (direction, encoding) {
108                (Direction::Encode, Encoding::Base64) => BUFFER_SIZE,
109                _ => 0,
110            }),
111            decoded: BytesMut::with_capacity(match direction {
112                Direction::Decode => BUFFER_SIZE,
113                _ => 0,
114            }),
115            direction,
116            encoding,
117            client: true,
118            trailers: None,
119        }
120    }
121
122    fn new(inner: B, direction: Direction, encoding: Encoding) -> Self {
123        GrpcWebCall {
124            inner,
125            buf: BytesMut::with_capacity(match (direction, encoding) {
126                (Direction::Encode, Encoding::Base64) => BUFFER_SIZE,
127                _ => 0,
128            }),
129            decoded: BytesMut::with_capacity(0),
130            direction,
131            encoding,
132            client: false,
133            trailers: None,
134        }
135    }
136
137    // This is to avoid passing a slice of bytes with a length that the base64
138    // decoder would consider invalid.
139    #[inline]
140    fn max_decodable(&self) -> usize {
141        (self.buf.len() / 4) * 4
142    }
143
144    fn decode_chunk(mut self: Pin<&mut Self>) -> Result<Option<Bytes>, Status> {
145        // not enough bytes to decode
146        if self.buf.is_empty() || self.buf.len() < 4 {
147            return Ok(None);
148        }
149
150        // Split `buf` at the largest index that is multiple of 4. Decode the
151        // returned `Bytes`, keeping the rest for the next attempt to decode.
152        let index = self.max_decodable();
153
154        crate::util::base64::STANDARD
155            .decode(self.as_mut().project().buf.split_to(index))
156            .map(|decoded| Some(Bytes::from(decoded)))
157            .map_err(internal_error)
158    }
159}
160
161impl<B> GrpcWebCall<B>
162where
163    B: Body,
164    B::Data: Buf,
165    B::Error: fmt::Display,
166{
167    // Poll body for data, decoding (e.g. via Base64 if necessary) and returning frames
168    // to the caller. If the caller is a client, it should look for trailers before
169    // returning these frames.
170    fn poll_decode(
171        mut self: Pin<&mut Self>,
172        cx: &mut Context<'_>,
173    ) -> Poll<Option<Result<Frame<Bytes>, Status>>> {
174        match self.encoding {
175            Encoding::Base64 => loop {
176                if let Some(bytes) = self.as_mut().decode_chunk()? {
177                    return Poll::Ready(Some(Ok(Frame::data(bytes))));
178                }
179
180                let this = self.as_mut().project();
181
182                match ready!(this.inner.poll_frame(cx)) {
183                    Some(Ok(frame)) if frame.is_data() => this
184                        .buf
185                        .put(frame.into_data().unwrap_or_else(|_| unreachable!())),
186                    Some(Ok(frame)) if frame.is_trailers() => {
187                        return Poll::Ready(Some(Err(internal_error(
188                            "malformed base64 request has unencoded trailers",
189                        ))))
190                    }
191                    Some(Ok(_)) => {
192                        return Poll::Ready(Some(Err(internal_error("unexpected frame type"))))
193                    }
194                    Some(Err(e)) => return Poll::Ready(Some(Err(internal_error(e)))),
195                    None => {
196                        return if this.buf.has_remaining() {
197                            Poll::Ready(Some(Err(internal_error("malformed base64 request"))))
198                        } else if let Some(trailers) = this.trailers.take() {
199                            Poll::Ready(Some(Ok(Frame::trailers(trailers))))
200                        } else {
201                            Poll::Ready(None)
202                        }
203                    }
204                }
205            },
206
207            Encoding::None => self
208                .project()
209                .inner
210                .poll_frame(cx)
211                .map_ok(|f| f.map_data(|mut d| d.copy_to_bytes(d.remaining())))
212                .map_err(internal_error),
213        }
214    }
215
216    fn poll_encode(
217        mut self: Pin<&mut Self>,
218        cx: &mut Context<'_>,
219    ) -> Poll<Option<Result<Frame<Bytes>, Status>>> {
220        let this = self.as_mut().project();
221
222        match ready!(this.inner.poll_frame(cx)) {
223            Some(Ok(frame)) if frame.is_data() => {
224                let mut data = frame.into_data().unwrap_or_else(|_| unreachable!());
225                let mut res = data.copy_to_bytes(data.remaining());
226
227                if *this.encoding == Encoding::Base64 {
228                    res = crate::util::base64::STANDARD.encode(res).into();
229                }
230
231                Poll::Ready(Some(Ok(Frame::data(res))))
232            }
233            Some(Ok(frame)) if frame.is_trailers() => {
234                let trailers = frame.into_trailers().unwrap_or_else(|_| unreachable!());
235                let mut res = make_trailers_frame(trailers);
236
237                if *this.encoding == Encoding::Base64 {
238                    res = crate::util::base64::STANDARD.encode(res).into();
239                }
240
241                Poll::Ready(Some(Ok(Frame::data(res))))
242            }
243            Some(Ok(_)) => Poll::Ready(Some(Err(internal_error("unexpected frame type")))),
244            Some(Err(e)) => Poll::Ready(Some(Err(internal_error(e)))),
245            None => Poll::Ready(None),
246        }
247    }
248}
249
250impl<B> Body for GrpcWebCall<B>
251where
252    B: Body,
253    B::Error: fmt::Display,
254{
255    type Data = Bytes;
256    type Error = Status;
257
258    fn poll_frame(
259        mut self: Pin<&mut Self>,
260        cx: &mut Context<'_>,
261    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
262        if self.client && self.direction == Direction::Decode {
263            let mut me = self.as_mut();
264
265            loop {
266                match ready!(me.as_mut().poll_decode(cx)) {
267                    Some(Ok(incoming_buf)) if incoming_buf.is_data() => {
268                        me.as_mut()
269                            .project()
270                            .decoded
271                            .put(incoming_buf.into_data().unwrap());
272                    }
273                    Some(Ok(incoming_buf)) if incoming_buf.is_trailers() => {
274                        let trailers = incoming_buf.into_trailers().unwrap();
275                        match me.as_mut().project().trailers {
276                            Some(current_trailers) => {
277                                current_trailers.extend(trailers);
278                            }
279                            None => {
280                                me.as_mut().project().trailers.replace(trailers);
281                            }
282                        }
283                        continue;
284                    }
285                    Some(Ok(_)) => unreachable!("unexpected frame type"),
286                    None => {} // No more data to decode, time to look for trailers
287                    Some(Err(e)) => return Poll::Ready(Some(Err(e))),
288                };
289
290                // Hold the incoming, decoded data until we have a full message
291                // or trailers to return.
292                let buf = me.as_mut().project().decoded;
293
294                return match find_trailers(&buf[..])? {
295                    FindTrailers::Trailer(len) => {
296                        // Extract up to len of where the trailers are at
297                        let msg_buf = buf.copy_to_bytes(len);
298                        match decode_trailers_frame(buf.split().freeze()) {
299                            Ok(Some(trailers)) => {
300                                me.as_mut().project().trailers.replace(trailers);
301                            }
302                            Err(e) => return Poll::Ready(Some(Err(e))),
303                            _ => {}
304                        }
305
306                        if msg_buf.has_remaining() {
307                            Poll::Ready(Some(Ok(Frame::data(msg_buf))))
308                        } else if let Some(trailers) = me.as_mut().project().trailers.take() {
309                            Poll::Ready(Some(Ok(Frame::trailers(trailers))))
310                        } else {
311                            Poll::Ready(None)
312                        }
313                    }
314                    FindTrailers::IncompleteBuf => continue,
315                    FindTrailers::Done(len) => Poll::Ready(match len {
316                        0 => None,
317                        _ => Some(Ok(Frame::data(buf.split_to(len).freeze()))),
318                    }),
319                };
320            }
321        }
322
323        match self.direction {
324            Direction::Decode => self.poll_decode(cx),
325            Direction::Encode => self.poll_encode(cx),
326            Direction::Empty => Poll::Ready(None),
327        }
328    }
329
330    fn is_end_stream(&self) -> bool {
331        self.inner.is_end_stream()
332    }
333
334    fn size_hint(&self) -> SizeHint {
335        self.inner.size_hint()
336    }
337}
338
339impl<B> Stream for GrpcWebCall<B>
340where
341    B: Body,
342    B::Error: fmt::Display,
343{
344    type Item = Result<Frame<Bytes>, Status>;
345
346    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
347        self.poll_frame(cx)
348    }
349}
350
351impl Encoding {
352    pub(crate) fn from_content_type(headers: &HeaderMap) -> Encoding {
353        Self::from_header(headers.get(header::CONTENT_TYPE))
354    }
355
356    pub(crate) fn from_accept(headers: &HeaderMap) -> Encoding {
357        Self::from_header(headers.get(header::ACCEPT))
358    }
359
360    pub(crate) fn to_content_type(self) -> &'static str {
361        match self {
362            Encoding::Base64 => GRPC_WEB_TEXT_PROTO,
363            Encoding::None => GRPC_WEB_PROTO,
364        }
365    }
366
367    fn from_header(value: Option<&HeaderValue>) -> Encoding {
368        match value.and_then(|val| val.to_str().ok()) {
369            Some(GRPC_WEB_TEXT_PROTO) | Some(GRPC_WEB_TEXT) => Encoding::Base64,
370            _ => Encoding::None,
371        }
372    }
373}
374
375fn internal_error(e: impl std::fmt::Display) -> Status {
376    Status::internal(format!("tonic-web: {e}"))
377}
378
379// Key-value pairs encoded as a HTTP/1 headers block (without the terminating newline)
380fn encode_trailers(trailers: HeaderMap) -> Vec<u8> {
381    trailers.iter().fold(Vec::new(), |mut acc, (key, value)| {
382        acc.put_slice(key.as_ref());
383        acc.push(b':');
384        acc.put_slice(value.as_bytes());
385        acc.put_slice(b"\r\n");
386        acc
387    })
388}
389
390fn decode_trailers_frame(mut buf: Bytes) -> Result<Option<HeaderMap>, Status> {
391    if buf.remaining() < GRPC_HEADER_SIZE {
392        return Ok(None);
393    }
394
395    buf.get_u8();
396    buf.get_u32();
397
398    let mut map = HeaderMap::new();
399    let mut temp_buf = buf.clone();
400
401    let mut trailers = Vec::new();
402    let mut cursor_pos = 0;
403
404    for (i, b) in buf.iter().enumerate() {
405        // if we are at a trailer delimiter (\r\n)
406        if b == &b'\r' && buf.get(i + 1) == Some(&b'\n') {
407            // read the bytes of the trailer passed so far
408            let trailer = temp_buf.copy_to_bytes(i - cursor_pos);
409            // increment cursor beyond the delimiter
410            cursor_pos = i + 2;
411            trailers.push(trailer);
412            if temp_buf.has_remaining() {
413                // advance buf beyond the delimiters
414                temp_buf.get_u8();
415                temp_buf.get_u8();
416            }
417        }
418    }
419
420    for trailer in trailers {
421        let Some((key, value)) = trailer
422            .iter()
423            .position(|b| *b == b':')
424            .map(|pos| trailer.split_at(pos))
425        else {
426            return Err(Status::internal("trailers couldn't parse key/value"));
427        };
428
429        // Skip the ':' separator and trim leading OWS (spaces/tabs) from the value
430        let value = &value[1..]; // skip ':'
431        let value = trim_ascii_start(value);
432
433        let header_key = HeaderName::try_from(key)
434            .map_err(|e| Status::internal(format!("Unable to parse HeaderName: {e}")))?;
435        let header_value = HeaderValue::try_from(value)
436            .map_err(|e| Status::internal(format!("Unable to parse HeaderValue: {e}")))?;
437        map.insert(header_key, header_value);
438    }
439
440    Ok(Some(map))
441}
442
/// Returns `bytes` with its leading ASCII whitespace removed, as defined by
/// `u8::is_ascii_whitespace` (space, tab, LF, form feed, CR).
fn trim_ascii_start(bytes: &[u8]) -> &[u8] {
    let mut rest = bytes;
    while let [first, tail @ ..] = rest {
        if !first.is_ascii_whitespace() {
            break;
        }
        rest = tail;
    }
    rest
}
450
451fn make_trailers_frame(trailers: HeaderMap) -> Bytes {
452    let trailers = encode_trailers(trailers);
453    let len = trailers.len();
454    assert!(len <= u32::MAX as usize);
455
456    let mut frame = BytesMut::with_capacity(len + FRAME_HEADER_SIZE);
457    frame.put_u8(GRPC_WEB_TRAILERS_BIT);
458    frame.put_u32(len as u32);
459    frame.put_slice(&trailers);
460
461    frame.freeze()
462}
463
464/// Search some buffer for grpc-web trailers headers and return
465/// its location in the original buf. If `None` is returned we did
466/// not find a trailers in this buffer either because its incomplete
467/// or the buffer just contained grpc message frames.
468fn find_trailers(buf: &[u8]) -> Result<FindTrailers, Status> {
469    let mut len = 0;
470    let mut temp_buf = buf;
471
472    loop {
473        // To check each frame, there must be at least GRPC_HEADER_SIZE
474        // amount of bytes available otherwise the buffer is incomplete.
475        if temp_buf.is_empty() || temp_buf.len() < GRPC_HEADER_SIZE {
476            return Ok(FindTrailers::Done(len));
477        }
478
479        let header = temp_buf.get_u8();
480
481        if header == GRPC_WEB_TRAILERS_BIT {
482            return Ok(FindTrailers::Trailer(len));
483        }
484
485        if !(header == 0 || header == 1) {
486            return Err(Status::internal(format!(
487                "Invalid header bit {header} expected 0 or 1"
488            )));
489        }
490
491        let msg_len = temp_buf.get_u32();
492
493        len += msg_len as usize + 4 + 1;
494
495        // If the msg len of a non-grpc-web trailer frame is larger than
496        // the overall buffer we know within that buffer there are no trailers.
497        if len > buf.len() {
498            return Ok(FindTrailers::IncompleteBuf);
499        }
500
501        temp_buf = &buf[len..];
502    }
503}
504
/// Outcome of scanning a buffer with `find_trailers`.
#[derive(Debug, PartialEq, Eq)]
enum FindTrailers {
    /// A trailers frame starts at this byte offset.
    Trailer(usize),
    /// A message frame extends past the end of the buffer.
    IncompleteBuf,
    /// No trailers; this many leading bytes are complete message frames.
    Done(usize),
}
511
// Unit tests for content-type detection, trailers (de)serialization and
// grpc-web frame scanning.
#[cfg(test)]
mod tests {
    use tonic::Code;

    use super::*;

    // Both `content-type` and `accept` map the grpc-web media types onto an
    // `Encoding`; unknown media types fall back to `Encoding::None`.
    #[test]
    fn encoding_constructors() {
        let cases = &[
            (GRPC_WEB, Encoding::None),
            (GRPC_WEB_PROTO, Encoding::None),
            (GRPC_WEB_TEXT, Encoding::Base64),
            (GRPC_WEB_TEXT_PROTO, Encoding::Base64),
            ("foo", Encoding::None),
        ];

        let mut headers = HeaderMap::new();

        for case in cases {
            headers.insert(header::CONTENT_TYPE, case.0.parse().unwrap());
            headers.insert(header::ACCEPT, case.0.parse().unwrap());

            assert_eq!(Encoding::from_content_type(&headers), case.1, "{}", case.0);
            assert_eq!(Encoding::from_accept(&headers), case.1, "{}", case.0);
        }
    }

    // Round trip: trailers written by `make_trailers_frame` decode back to the
    // same map.
    #[test]
    fn decode_trailers() {
        let mut headers = HeaderMap::new();
        headers.insert(Status::GRPC_STATUS, 0.into());
        headers.insert(
            Status::GRPC_MESSAGE,
            "this is a message".try_into().unwrap(),
        );

        let trailers = make_trailers_frame(headers.clone());

        let map = decode_trailers_frame(trailers).unwrap().unwrap();

        assert_eq!(headers, map);
    }

    // A buffer that starts with the trailers frame reports it at offset 0.
    #[test]
    fn find_trailers_non_buffered() {
        // Byte version of this:
        // b"\x80\0\0\0\x0fgrpc-status:0\r\n"
        let buf = [
            128, 0, 0, 0, 15, 103, 114, 112, 99, 45, 115, 116, 97, 116, 117, 115, 58, 48, 13, 10,
        ];

        let out = find_trailers(&buf[..]).unwrap();

        assert_eq!(out, FindTrailers::Trailer(0));
    }

    // One message frame (5-byte header + 76-byte payload = 81 bytes) followed by
    // the trailers frame, so the trailers start at offset 81.
    #[test]
    fn find_trailers_buffered() {
        // Byte version of this:
        // b"\0\0\0\0L\n$975738af-1a17-4aea-b887-ed0bbced6093\x1a$da609e9b-f470-4cc0-a691-3fd6a005a436\x80\0\0\0\x0fgrpc-status:0\r\n"
        let buf = [
            0, 0, 0, 0, 76, 10, 36, 57, 55, 53, 55, 51, 56, 97, 102, 45, 49, 97, 49, 55, 45, 52,
            97, 101, 97, 45, 98, 56, 56, 55, 45, 101, 100, 48, 98, 98, 99, 101, 100, 54, 48, 57,
            51, 26, 36, 100, 97, 54, 48, 57, 101, 57, 98, 45, 102, 52, 55, 48, 45, 52, 99, 99, 48,
            45, 97, 54, 57, 49, 45, 51, 102, 100, 54, 97, 48, 48, 53, 97, 52, 51, 54, 128, 0, 0, 0,
            15, 103, 114, 112, 99, 45, 115, 116, 97, 116, 117, 115, 58, 48, 13, 10,
        ];

        let out = find_trailers(&buf[..]).unwrap();

        assert_eq!(out, FindTrailers::Trailer(81));

        let trailers = decode_trailers_frame(Bytes::copy_from_slice(&buf[81..]))
            .unwrap()
            .unwrap();
        let status = trailers.get(Status::GRPC_STATUS).unwrap();
        assert_eq!(status.to_str().unwrap(), "0")
    }

    // The first frame declares a 0x000009ee (2542-byte) payload, but the buffer
    // ends long before that, so the scan must report an incomplete buffer.
    #[test]
    fn find_trailers_buffered_incomplete_message() {
        let buf = vec![
            0, 0, 0, 9, 238, 10, 233, 19, 18, 230, 19, 10, 9, 10, 1, 120, 26, 4, 84, 69, 88, 84,
            18, 60, 10, 58, 10, 56, 3, 0, 0, 0, 44, 0, 0, 0, 0, 0, 0, 0, 116, 104, 105, 115, 32,
            118, 97, 108, 117, 101, 32, 119, 97, 115, 32, 119, 114, 105, 116, 116, 101, 110, 32,
            118, 105, 97, 32, 119, 114, 105, 116, 101, 32, 100, 101, 108, 101, 103, 97, 116, 105,
            111, 110, 33, 18, 62, 10, 60, 10, 58, 3, 0, 0, 0, 46, 0, 0, 0, 0, 0, 0, 0, 116, 104,
            105, 115, 32, 118, 97, 108, 117, 101, 32, 119, 97, 115, 32, 119, 114, 105, 116, 116,
            101, 110, 32, 98, 121, 32, 97, 110, 32, 101, 109, 98, 101, 100, 100, 101, 100, 32, 114,
            101, 112, 108, 105, 99, 97, 33, 18, 62, 10, 60, 10, 58, 3, 0, 0, 0, 46, 0, 0, 0, 0, 0,
            0, 0, 116, 104, 105, 115, 32, 118, 97, 108, 117, 101, 32, 119, 97, 115, 32, 119, 114,
            105, 116, 116, 101, 110, 32, 98, 121, 32, 97, 110, 32, 101, 109, 98, 101, 100, 100,
            101, 100, 32, 114, 101, 112, 108, 105, 99, 97, 33, 18, 62, 10, 60, 10, 58, 3, 0, 0, 0,
            46, 0, 0, 0, 0, 0, 0, 0, 116, 104, 105, 115, 32, 118, 97, 108, 117, 101, 32, 119, 97,
            115, 32, 119, 114, 105, 116, 116, 101, 110, 32, 98, 121, 32, 97, 110, 32, 101, 109, 98,
            101, 100, 100, 101, 100, 32, 114, 101, 112, 108, 105, 99, 97, 33, 18, 62, 10, 60, 10,
            58, 3, 0, 0, 0, 46, 0, 0, 0, 0, 0, 0, 0, 116, 104, 105, 115, 32, 118, 97, 108, 117,
            101, 32, 119, 97, 115, 32, 119, 114, 105, 116, 116, 101, 110, 32, 98, 121, 32, 97, 110,
            32, 101, 109, 98, 101, 100, 100, 101, 100, 32, 114, 101, 112, 108, 105, 99, 97, 33, 18,
            62, 10, 60, 10, 58, 3, 0, 0, 0, 46, 0, 0, 0, 0, 0, 0, 0, 116, 104, 105, 115, 32, 118,
            97, 108, 117, 101, 32, 119, 97, 115, 32, 119, 114, 105, 116, 116, 101, 110, 32, 98,
            121, 32, 97, 110, 32, 101, 109, 98, 101, 100, 100, 101, 100, 32, 114, 101, 112, 108,
            105, 99, 97, 33, 18, 62, 10, 60, 10, 58, 3, 0, 0, 0, 46, 0, 0, 0, 0, 0, 0, 0, 116, 104,
            105, 115, 32, 118, 97, 108, 117, 101, 32, 119, 97, 115, 32, 119, 114, 105, 116, 116,
            101, 110, 32, 98, 121, 32, 97, 110, 32, 101, 109, 98, 101, 100, 100, 101, 100, 32, 114,
            101, 112, 108, 105, 99, 97, 33, 18, 62, 10, 60, 10, 58, 3, 0, 0, 0, 46, 0, 0, 0, 0, 0,
            0, 0, 116, 104, 105, 115, 32, 118, 97, 108, 117, 101, 32, 119, 97, 115, 32, 119, 114,
            105, 116, 116, 101, 110, 32, 98, 121, 32,
        ];

        let out = find_trailers(&buf[..]).unwrap();

        assert_eq!(out, FindTrailers::IncompleteBuf);
    }

    // Regression check against a captured buffer; ignored by default.
    // NOTE(review): requires the fixture `tests/incomplete-buf-bug.bin`
    // relative to the test's working directory — confirm it is shipped.
    #[test]
    #[ignore]
    fn find_trailers_buffered_incomplete_buf_bug() {
        let buf = std::fs::read("tests/incomplete-buf-bug.bin").unwrap();
        let out = find_trailers(&buf[..]).unwrap_err();

        assert_eq!(out.code(), Code::Internal);
    }

    // The length prefix (0x0f) is shorter than the payload. Every line is still
    // decoded because `decode_trailers_frame` parses all CRLF-terminated lines
    // after the header rather than bounding by the length field.
    #[test]
    fn decode_multiple_trailers() {
        let buf = b"\x80\0\0\0\x0fgrpc-status:0\r\ngrpc-message:\r\na:1\r\nb:2\r\n";

        let trailers = decode_trailers_frame(Bytes::copy_from_slice(&buf[..]))
            .unwrap()
            .unwrap();

        let mut expected = HeaderMap::new();
        expected.insert(Status::GRPC_STATUS, "0".parse().unwrap());
        expected.insert(Status::GRPC_MESSAGE, "".parse().unwrap());
        expected.insert("a", "1".parse().unwrap());
        expected.insert("b", "2".parse().unwrap());

        assert_eq!(trailers, expected);
    }

    // Whitespace after the colon is trimmed, including a value that is only
    // whitespace.
    #[test]
    fn decode_trailers_with_space_after_colon() {
        let buf = b"\x80\0\0\0\x0fgrpc-status: 0\r\ngrpc-message: \r\n";

        let trailers = decode_trailers_frame(Bytes::copy_from_slice(&buf[..]))
            .unwrap()
            .unwrap();

        let mut expected = HeaderMap::new();
        expected.insert(Status::GRPC_STATUS, "0".parse().unwrap());
        expected.insert(Status::GRPC_MESSAGE, "".parse().unwrap());

        assert_eq!(trailers, expected);
    }

    // Same as above, but the frame is built with a length prefix that matches
    // the payload.
    #[test]
    fn decode_trailers_space_after_colon() {
        // connect-rpc and standard HTTP use "key: value" (space after colon)
        let trailers_bytes = b"grpc-status: 0\r\ngrpc-message: this is a message\r\n";
        let len = trailers_bytes.len();

        let mut frame = BytesMut::new();
        frame.put_u8(GRPC_WEB_TRAILERS_BIT);
        frame.put_u32(len as u32);
        frame.put_slice(&trailers_bytes[..]);

        let map = decode_trailers_frame(frame.freeze()).unwrap().unwrap();

        let mut expected = HeaderMap::new();
        expected.insert(Status::GRPC_STATUS, HeaderValue::from_static("0"));
        expected.insert(
            Status::GRPC_MESSAGE,
            HeaderValue::from_static("this is a message"),
        );

        assert_eq!(map, expected);
    }

    // Only the first ':' separates key from value; later colons belong to the
    // value.
    #[test]
    fn decode_trailers_value_with_colons() {
        let trailers_bytes = b"grpc-status: 0\r\ngrpc-message: error: something: went wrong\r\n";
        let len = trailers_bytes.len();

        let mut frame = BytesMut::new();
        frame.put_u8(GRPC_WEB_TRAILERS_BIT);
        frame.put_u32(len as u32);
        frame.put_slice(&trailers_bytes[..]);

        let map = decode_trailers_frame(frame.freeze()).unwrap().unwrap();

        assert_eq!(map.get("grpc-status").unwrap(), "0");
        assert_eq!(
            map.get("grpc-message").unwrap(),
            "error: something: went wrong"
        );
    }
}