Skip to main content

ugi/protocol/
grpc.rs

1use std::fmt;
2use std::future::IntoFuture;
3use std::io::{Read, Write};
4use std::pin::Pin;
5use std::time::Duration;
6
7use async_channel::{Receiver, Sender};
8use bytes::{Bytes, BytesMut};
9use flate2::Compression;
10use flate2::read::GzDecoder;
11use flate2::write::GzEncoder;
12use futures_lite::Stream;
13use futures_lite::StreamExt;
14use futures_lite::stream;
15use serde::Serialize;
16use serde::de::DeserializeOwned;
17
18use crate::BodyStream;
19use crate::error::{Error, ErrorKind, Result};
20use crate::header::HeaderMap;
21use crate::request::{ProtocolPolicy, RequestBuilder};
22use crate::response::{Response, TrailerState, Version};
23use crate::tls::{TlsBackend, TlsConfig};
24
/// Message codec used for a gRPC call.
///
/// The codec selects the request `content-type` (`application/grpc+json` or
/// `application/grpc+proto`). The typed (serde) message helpers only work with
/// [`GrpcCodec::Json`]; with [`GrpcCodec::Protobuf`] callers exchange raw message bytes.
#[derive(Clone, Copy, Debug)]
#[derive(Eq, PartialEq)]
pub enum GrpcCodec {
    /// Raw protobuf payloads (`application/grpc+proto`).
    Protobuf,
    /// JSON payloads (`application/grpc+json`).
    Json,
}
30
31type GrpcMessageStream = Pin<Box<dyn Stream<Item = Result<Bytes>> + Send + 'static>>;
32
/// Message compression algorithms understood by this client.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum GrpcCompression {
    /// gzip (DEFLATE) compression.
    Gzip,
}

impl GrpcCompression {
    /// Returns the token for this algorithm as written in the `grpc-encoding` and
    /// `grpc-accept-encoding` headers.
    fn as_str(self) -> &'static str {
        match self {
            GrpcCompression::Gzip => "gzip",
        }
    }
}
45
46#[derive(Clone, Copy, Debug, Eq, PartialEq)]
47struct GrpcRequestConfig {
48    content_type: &'static str,
49    compression: Option<GrpcCompression>,
50}
51
52enum GrpcRequestPayload {
53    Empty,
54    JsonMessage(Bytes),
55    JsonStream(GrpcMessageStream),
56    RawMessage(Bytes),
57    RawStream(GrpcMessageStream),
58}
59
60pub struct GrpcRequestBuilder {
61    request: RequestBuilder,
62    codec: GrpcCodec,
63    compression: Option<String>,
64    timeout: Option<Duration>,
65    payload: GrpcRequestPayload,
66}
67
68impl GrpcRequestBuilder {
69    pub(crate) fn from_request_builder(request: RequestBuilder) -> Self {
70        Self {
71            request: request.http2_only(),
72            codec: GrpcCodec::Json,
73            compression: None,
74            timeout: None,
75            payload: GrpcRequestPayload::Empty,
76        }
77    }
78
79    pub fn metadata(mut self, name: impl AsRef<str>, value: impl AsRef<str>) -> Result<Self> {
80        self.request = self.request.headers([(name.as_ref(), value.as_ref())])?;
81        Ok(self)
82    }
83
84    pub fn metadata_map<I, K, V>(mut self, values: I) -> Result<Self>
85    where
86        I: IntoIterator<Item = (K, V)>,
87        K: AsRef<str>,
88        V: AsRef<str>,
89    {
90        for (name, value) in values {
91            self = self.metadata(name, value)?;
92        }
93        Ok(self)
94    }
95
96    pub fn metadata_bin(mut self, name: impl AsRef<str>, value: impl AsRef<[u8]>) -> Result<Self> {
97        let name = normalize_grpc_binary_metadata_name(name.as_ref())?;
98        let value = encode_grpc_binary_header(value.as_ref());
99        self.request = self.request.header(name, value)?;
100        Ok(self)
101    }
102
103    pub fn message<T: Serialize>(mut self, value: &T) -> Result<Self> {
104        self.payload = GrpcRequestPayload::JsonMessage(encode_json_message(value)?);
105        Ok(self)
106    }
107
108    pub fn messages<S, T>(mut self, stream: S) -> Result<Self>
109    where
110        S: Stream<Item = Result<T>> + Send + 'static,
111        T: Serialize + Send,
112    {
113        let stream = async_stream::stream! {
114            let mut stream = Box::pin(stream);
115            while let Some(item) = stream.next().await {
116                let item = item?;
117                yield Ok(encode_json_message(&item)?);
118            }
119        };
120        self.payload = GrpcRequestPayload::JsonStream(Box::pin(stream));
121        Ok(self)
122    }
123
124    pub fn message_bytes(mut self, value: impl Into<Bytes>) -> Result<Self> {
125        self.payload = GrpcRequestPayload::RawMessage(value.into());
126        Ok(self)
127    }
128
129    pub fn messages_bytes<S, B>(mut self, stream: S) -> Result<Self>
130    where
131        S: Stream<Item = Result<B>> + Send + 'static,
132        B: Into<Bytes> + Send + 'static,
133    {
134        let stream = async_stream::stream! {
135            let mut stream = Box::pin(stream);
136            while let Some(item) = stream.next().await {
137                let item = item?;
138                yield Ok(item.into());
139            }
140        };
141        self.payload = GrpcRequestPayload::RawStream(Box::pin(stream));
142        Ok(self)
143    }
144
145    pub fn codec(mut self, codec: GrpcCodec) -> Self {
146        self.codec = codec;
147        self
148    }
149
150    pub fn compression(mut self, algo: impl AsRef<str>) -> Self {
151        let algo = algo.as_ref().trim();
152        self.compression = if algo.is_empty() {
153            None
154        } else {
155            Some(algo.to_owned())
156        };
157        self
158    }
159
160    pub fn protocol_policy(mut self, policy: ProtocolPolicy) -> Self {
161        self.request = self.request.protocol_policy(policy);
162        self
163    }
164
165    pub fn prefer_http3(mut self) -> Self {
166        self.request = self.request.prefer_http3();
167        self
168    }
169
170    pub fn prefer_http2(mut self) -> Self {
171        self.request = self.request.prefer_http2();
172        self
173    }
174
175    pub fn http2_only(mut self) -> Self {
176        self.request = self.request.http2_only();
177        self
178    }
179
180    pub fn http3_only(mut self) -> Self {
181        self.request = self.request.http3_only();
182        self
183    }
184
185    pub fn prior_knowledge_h2c(mut self, enabled: bool) -> Self {
186        self.request = self.request.prior_knowledge_h2c(enabled);
187        self
188    }
189
190    pub fn timeout(mut self, duration: Duration) -> Self {
191        self.timeout = Some(duration);
192        self.request = self.request.timeout(duration);
193        self
194    }
195
196    pub fn connect_timeout(mut self, duration: Duration) -> Self {
197        self.request = self.request.connect_timeout(duration);
198        self
199    }
200
201    pub fn read_timeout(mut self, duration: Duration) -> Self {
202        self.request = self.request.read_timeout(duration);
203        self
204    }
205
206    pub fn write_timeout(mut self, duration: Duration) -> Self {
207        self.request = self.request.write_timeout(duration);
208        self
209    }
210
211    pub fn tls_config(mut self, tls_config: TlsConfig) -> Self {
212        self.request = self.request.tls_config(tls_config);
213        self
214    }
215
216    pub fn danger_accept_invalid_certs(mut self, enabled: bool) -> Self {
217        self.request = self.request.danger_accept_invalid_certs(enabled);
218        self
219    }
220
221    pub fn tls_backend(mut self, backend: TlsBackend) -> Self {
222        self.request = self.request.tls_backend(backend);
223        self
224    }
225
226    pub async fn send_streaming(self) -> Result<GrpcStreamingResponse> {
227        let codec = self.codec;
228        let config = self.validate_configuration()?;
229        let request = self.build_request(config)?;
230        let response = request.await?;
231        GrpcStreamingResponse::from_http_response(response, codec)
232    }
233
234    pub async fn open_duplex(self) -> Result<GrpcDuplexCall> {
235        if !matches!(self.payload, GrpcRequestPayload::Empty) {
236            return Err(duplex_payload_conflict_error());
237        }
238
239        let codec = self.codec;
240        let config = self.validate_configuration()?;
241        let mut request = self
242            .request
243            .header("te", "trailers")?
244            .header("content-type", config.content_type)?
245            .header("grpc-accept-encoding", GrpcCompression::Gzip.as_str())?;
246
247        if let Some(timeout) = self.timeout {
248            request = request.header("grpc-timeout", &encode_grpc_timeout(timeout))?;
249        }
250
251        if let Some(compression) = config.compression {
252            request = request.header("grpc-encoding", compression.as_str())?;
253        }
254
255        let (request_tx, request_rx) = async_channel::unbounded::<Bytes>();
256        let (response_tx, response_rx) = async_channel::bounded(1);
257        std::thread::Builder::new()
258            .name("request-grpc-duplex".to_owned())
259            .spawn(move || {
260                async_io::block_on(async move {
261                    let body_stream: BodyStream = Box::pin(async_stream::stream! {
262                        while let Ok(item) = request_rx.recv().await {
263                            yield Ok(item);
264                        }
265                    });
266                    let response = request.body_stream(body_stream).await.and_then(|response| {
267                        GrpcStreamingResponse::from_http_response(response, codec)
268                    });
269                    let _ = response_tx.send(response).await;
270                });
271            })
272            .map_err(|err| {
273                Error::with_source(
274                    ErrorKind::Transport,
275                    "failed to spawn grpc duplex task",
276                    err,
277                )
278            })?;
279
280        Ok(GrpcDuplexCall {
281            codec,
282            compression: config.compression,
283            request_tx,
284            response_rx,
285            response: None,
286        })
287    }
288
289    fn validate_configuration(&self) -> Result<GrpcRequestConfig> {
290        let compression = parse_grpc_compression(self.compression.as_deref())?;
291        let content_type = match self.codec {
292            GrpcCodec::Json => "application/grpc+json",
293            GrpcCodec::Protobuf => "application/grpc+proto",
294        };
295        Ok(GrpcRequestConfig {
296            content_type,
297            compression,
298        })
299    }
300
301    fn build_request(self, config: GrpcRequestConfig) -> Result<RequestBuilder> {
302        let mut request = self
303            .request
304            .header("te", "trailers")?
305            .header("content-type", config.content_type)?
306            .header("grpc-accept-encoding", GrpcCompression::Gzip.as_str())?;
307
308        if let Some(timeout) = self.timeout {
309            request = request.header("grpc-timeout", &encode_grpc_timeout(timeout))?;
310        }
311
312        if let Some(compression) = config.compression {
313            request = request.header("grpc-encoding", compression.as_str())?;
314        }
315
316        match self.payload {
317            GrpcRequestPayload::Empty => Err(Error::new(
318                ErrorKind::Transport,
319                "grpc request requires at least one message",
320            )),
321            GrpcRequestPayload::JsonMessage(message) => {
322                if self.codec != GrpcCodec::Json {
323                    return Err(typed_request_codec_error());
324                }
325                Ok(request.body(encode_grpc_frame(&message, config.compression)?))
326            }
327            GrpcRequestPayload::JsonStream(stream) => {
328                if self.codec != GrpcCodec::Json {
329                    return Err(typed_request_codec_error());
330                }
331                Ok(request.body_stream(frame_message_stream(stream, config.compression)))
332            }
333            GrpcRequestPayload::RawMessage(message) => {
334                Ok(request.body(encode_grpc_frame(&message, config.compression)?))
335            }
336            GrpcRequestPayload::RawStream(stream) => {
337                Ok(request.body_stream(frame_message_stream(stream, config.compression)))
338            }
339        }
340    }
341}
342
343impl IntoFuture for GrpcRequestBuilder {
344    type Output = Result<GrpcResponse>;
345    type IntoFuture = Pin<Box<dyn std::future::Future<Output = Self::Output> + Send + 'static>>;
346
347    fn into_future(self) -> Self::IntoFuture {
348        Box::pin(async move { self.send_streaming().await?.into_buffered_response().await })
349    }
350}
351
352#[derive(Debug)]
353pub struct GrpcResponse {
354    codec: GrpcCodec,
355    headers: HeaderMap,
356    messages: Vec<Bytes>,
357    next_message: usize,
358    trailers: Option<HeaderMap>,
359    status: GrpcStatus,
360}
361
362pub struct GrpcDuplexCall {
363    codec: GrpcCodec,
364    compression: Option<GrpcCompression>,
365    request_tx: Sender<Bytes>,
366    response_rx: Receiver<Result<GrpcStreamingResponse>>,
367    response: Option<GrpcStreamingResponse>,
368}
369
370impl fmt::Debug for GrpcDuplexCall {
371    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
372        f.debug_struct("GrpcDuplexCall")
373            .field("codec", &self.codec)
374            .field("compression", &self.compression)
375            .field("response_ready", &self.response.is_some())
376            .finish()
377    }
378}
379
380impl GrpcDuplexCall {
381    pub async fn send_message<T: Serialize>(&self, value: &T) -> Result<()> {
382        if self.codec != GrpcCodec::Json {
383            return Err(typed_request_codec_error());
384        }
385        let message = encode_json_message(value)?;
386        self.send_message_bytes(message).await
387    }
388
389    pub async fn send_message_bytes(&self, value: impl Into<Bytes>) -> Result<()> {
390        let frame = encode_grpc_frame(&value.into(), self.compression)?;
391        self.request_tx.send(frame).await.map_err(|_| {
392            Error::new(
393                ErrorKind::Transport,
394                "grpc duplex request stream is already closed",
395            )
396        })
397    }
398
399    pub fn finish_sending(&self) {
400        self.request_tx.close();
401    }
402
403    pub async fn next_message<T: DeserializeOwned>(&mut self) -> Result<Option<T>> {
404        self.ensure_response_ready().await?;
405        self.response
406            .as_mut()
407            .expect("response ready")
408            .next_message()
409            .await
410    }
411
412    pub async fn next_message_bytes(&mut self) -> Result<Option<Bytes>> {
413        self.ensure_response_ready().await?;
414        self.response
415            .as_mut()
416            .expect("response ready")
417            .next_message_bytes()
418            .await
419    }
420
421    pub fn messages<'a, T: DeserializeOwned + 'a>(
422        &'a mut self,
423    ) -> Result<Pin<Box<dyn Stream<Item = Result<T>> + 'a>>> {
424        if self.codec != GrpcCodec::Json {
425            return Err(typed_response_codec_error());
426        }
427        Ok(Box::pin(async_stream::stream! {
428            while let Some(item) = self.next_message_bytes().await? {
429                let value = serde_json::from_slice(&item).map_err(|err| {
430                    Error::with_source(ErrorKind::Decode, "failed to decode grpc json message", err)
431                })?;
432                yield Ok(value);
433            }
434        }))
435    }
436
437    pub fn messages_bytes<'a>(
438        &'a mut self,
439    ) -> Result<Pin<Box<dyn Stream<Item = Result<Bytes>> + 'a>>> {
440        Ok(Box::pin(async_stream::stream! {
441            while let Some(item) = self.next_message_bytes().await? {
442                yield Ok(item);
443            }
444        }))
445    }
446
447    pub async fn finish(&mut self) -> Result<GrpcStatus> {
448        self.finish_sending();
449        self.ensure_response_ready().await?;
450        self.response
451            .as_mut()
452            .expect("response ready")
453            .finish()
454            .await
455    }
456
457    pub fn trailers(&self) -> Result<Option<HeaderMap>> {
458        match self.response.as_ref() {
459            Some(response) => response.trailers(),
460            None => Ok(None),
461        }
462    }
463
464    pub fn metadata(&self, name: &str) -> Vec<String> {
465        self.response
466            .as_ref()
467            .map(|response| response.metadata(name))
468            .unwrap_or_default()
469    }
470
471    pub fn metadata_bin(&self, name: &str) -> Result<Vec<Bytes>> {
472        match self.response.as_ref() {
473            Some(response) => response.metadata_bin(name),
474            None => Ok(Vec::new()),
475        }
476    }
477
478    pub fn trailer_metadata(&self, name: &str) -> Vec<String> {
479        self.response
480            .as_ref()
481            .map(|response| response.trailer_metadata(name))
482            .unwrap_or_default()
483    }
484
485    pub fn trailer_metadata_bin(&self, name: &str) -> Result<Vec<Bytes>> {
486        match self.response.as_ref() {
487            Some(response) => response.trailer_metadata_bin(name),
488            None => Ok(Vec::new()),
489        }
490    }
491
492    pub fn status(&self) -> Result<Option<GrpcStatus>> {
493        match self.response.as_ref() {
494            Some(response) => response.status(),
495            None => Ok(None),
496        }
497    }
498
499    pub fn is_complete(&self) -> bool {
500        self.response
501            .as_ref()
502            .map(GrpcStreamingResponse::is_complete)
503            .unwrap_or(false)
504    }
505
506    async fn ensure_response_ready(&mut self) -> Result<()> {
507        if self.response.is_none() {
508            let response = self.response_rx.recv().await.map_err(|_| {
509                Error::new(
510                    ErrorKind::Transport,
511                    "grpc duplex task stopped before response headers arrived",
512                )
513            })??;
514            self.response = Some(response);
515        }
516        Ok(())
517    }
518}
519
520impl GrpcResponse {
521    pub fn metadata(&self, name: &str) -> Vec<String> {
522        grpc_metadata_values(&self.headers, name)
523    }
524
525    pub fn metadata_bin(&self, name: &str) -> Result<Vec<Bytes>> {
526        grpc_binary_metadata_values(&self.headers, name)
527    }
528
529    pub async fn message<T: DeserializeOwned>(&mut self) -> Result<T> {
530        if self.codec != GrpcCodec::Json {
531            return Err(typed_response_codec_error());
532        }
533        let message = self.message_bytes().await?;
534        serde_json::from_slice(&message).map_err(|err| {
535            Error::with_source(ErrorKind::Decode, "failed to decode grpc json message", err)
536        })
537    }
538
539    pub fn messages<'a, T: DeserializeOwned + 'a>(
540        &'a mut self,
541    ) -> Result<Pin<Box<dyn Stream<Item = Result<T>> + 'a>>> {
542        if self.codec != GrpcCodec::Json {
543            return Err(typed_response_codec_error());
544        }
545        let stream = self.messages_bytes()?;
546        Ok(Box::pin(async_stream::stream! {
547            let mut stream = stream;
548            while let Some(item) = stream.next().await {
549                let item = item?;
550                let value = serde_json::from_slice(&item).map_err(|err| {
551                    Error::with_source(ErrorKind::Decode, "failed to decode grpc json message", err)
552                })?;
553                yield Ok(value);
554            }
555        }))
556    }
557
558    pub async fn message_bytes(&mut self) -> Result<Bytes> {
559        self.ensure_ok_status()?;
560        let remaining = self.messages.len().saturating_sub(self.next_message);
561        if remaining != 1 {
562            return Err(Error::new(
563                ErrorKind::Transport,
564                format!("expected exactly one grpc response message, found {remaining}"),
565            ));
566        }
567        let message = self.messages[self.next_message].clone();
568        self.next_message += 1;
569        Ok(message)
570    }
571
572    pub fn messages_bytes<'a>(
573        &'a mut self,
574    ) -> Result<Pin<Box<dyn Stream<Item = Result<Bytes>> + 'a>>> {
575        self.ensure_ok_status()?;
576        let messages = self.messages[self.next_message..].to_vec();
577        self.next_message = self.messages.len();
578        Ok(Box::pin(stream::iter(messages.into_iter().map(Ok))))
579    }
580
581    pub fn trailers(&self) -> Result<Option<HeaderMap>> {
582        Ok(self.trailers.clone())
583    }
584
585    pub fn trailer_metadata(&self, name: &str) -> Vec<String> {
586        self.trailers
587            .as_ref()
588            .map(|trailers| grpc_metadata_values(trailers, name))
589            .unwrap_or_default()
590    }
591
592    pub fn trailer_metadata_bin(&self, name: &str) -> Result<Vec<Bytes>> {
593        match self.trailers.as_ref() {
594            Some(trailers) => grpc_binary_metadata_values(trailers, name),
595            None => Ok(Vec::new()),
596        }
597    }
598
599    pub fn status(&self) -> Result<GrpcStatus> {
600        Ok(self.status.clone())
601    }
602
603    fn ensure_ok_status(&self) -> Result<()> {
604        if self.status.code != 0 {
605            return Err(grpc_status_error(&self.status));
606        }
607        Ok(())
608    }
609}
610
611pub struct GrpcStreamingResponse {
612    codec: GrpcCodec,
613    decoder: GrpcResponseDecoder,
614}
615
616impl fmt::Debug for GrpcStreamingResponse {
617    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
618        f.debug_struct("GrpcStreamingResponse")
619            .field("codec", &self.codec)
620            .field("complete", &self.decoder.complete)
621            .field("status", &self.decoder.final_status)
622            .finish()
623    }
624}
625
626impl GrpcStreamingResponse {
627    fn from_http_response(response: Response, codec: GrpcCodec) -> Result<Self> {
628        validate_grpc_http_response(&response)?;
629        let headers = response.headers().clone();
630        let compression = parse_grpc_compression(headers.get("grpc-encoding"))?;
631        let (body, trailers) = response.into_body_stream_and_trailer_state();
632        Ok(Self {
633            codec,
634            decoder: GrpcResponseDecoder {
635                body,
636                headers,
637                trailers,
638                compression,
639                frame_decoder: GrpcFrameDecoder::default(),
640                final_trailers: None,
641                final_status: None,
642                complete: false,
643            },
644        })
645    }
646
647    pub async fn next_message<T: DeserializeOwned>(&mut self) -> Result<Option<T>> {
648        if self.codec != GrpcCodec::Json {
649            return Err(typed_response_codec_error());
650        }
651        let Some(message) = self.next_message_bytes().await? else {
652            return Ok(None);
653        };
654        let value = serde_json::from_slice(&message).map_err(|err| {
655            Error::with_source(ErrorKind::Decode, "failed to decode grpc json message", err)
656        })?;
657        Ok(Some(value))
658    }
659
660    pub async fn next_message_bytes(&mut self) -> Result<Option<Bytes>> {
661        let next = self.decoder.next_message().await?;
662        if next.is_none() {
663            self.ensure_ok_status()?;
664        }
665        Ok(next)
666    }
667
668    pub fn messages<'a, T: DeserializeOwned + 'a>(
669        &'a mut self,
670    ) -> Result<Pin<Box<dyn Stream<Item = Result<T>> + 'a>>> {
671        if self.codec != GrpcCodec::Json {
672            return Err(typed_response_codec_error());
673        }
674        Ok(Box::pin(async_stream::stream! {
675            while let Some(item) = self.next_message_bytes().await? {
676                let value = serde_json::from_slice(&item).map_err(|err| {
677                    Error::with_source(ErrorKind::Decode, "failed to decode grpc json message", err)
678                })?;
679                yield Ok(value);
680            }
681        }))
682    }
683
684    pub fn messages_bytes<'a>(
685        &'a mut self,
686    ) -> Result<Pin<Box<dyn Stream<Item = Result<Bytes>> + 'a>>> {
687        Ok(Box::pin(async_stream::stream! {
688            while let Some(item) = self.next_message_bytes().await? {
689                yield Ok(item);
690            }
691        }))
692    }
693
694    pub async fn finish(&mut self) -> Result<GrpcStatus> {
695        while self.decoder.next_message().await?.is_some() {}
696        self.decoder
697            .final_status
698            .clone()
699            .ok_or_else(|| Error::new(ErrorKind::Transport, "grpc response did not complete"))
700    }
701
702    pub fn metadata(&self, name: &str) -> Vec<String> {
703        grpc_metadata_values(&self.decoder.headers, name)
704    }
705
706    pub fn metadata_bin(&self, name: &str) -> Result<Vec<Bytes>> {
707        grpc_binary_metadata_values(&self.decoder.headers, name)
708    }
709
710    pub fn trailer_metadata(&self, name: &str) -> Vec<String> {
711        self.decoder
712            .final_trailers
713            .as_ref()
714            .map(|trailers| grpc_metadata_values(trailers, name))
715            .unwrap_or_default()
716    }
717
718    pub fn trailer_metadata_bin(&self, name: &str) -> Result<Vec<Bytes>> {
719        match self.decoder.final_trailers.as_ref() {
720            Some(trailers) => grpc_binary_metadata_values(trailers, name),
721            None => Ok(Vec::new()),
722        }
723    }
724
725    pub fn trailers(&self) -> Result<Option<HeaderMap>> {
726        Ok(self.decoder.final_trailers.clone())
727    }
728
729    pub fn status(&self) -> Result<Option<GrpcStatus>> {
730        Ok(self.decoder.final_status.clone())
731    }
732
733    pub fn is_complete(&self) -> bool {
734        self.decoder.complete
735    }
736
737    async fn into_buffered_response(mut self) -> Result<GrpcResponse> {
738        let mut messages = Vec::new();
739        while let Some(message) = self.decoder.next_message().await? {
740            messages.push(message);
741        }
742        let status =
743            self.decoder.final_status.clone().ok_or_else(|| {
744                Error::new(ErrorKind::Transport, "grpc response did not complete")
745            })?;
746        Ok(GrpcResponse {
747            codec: self.codec,
748            headers: self.decoder.headers.clone(),
749            messages,
750            next_message: 0,
751            trailers: self.decoder.final_trailers.clone(),
752            status,
753        })
754    }
755
756    fn ensure_ok_status(&self) -> Result<()> {
757        if let Some(status) = self.decoder.final_status.as_ref() {
758            if status.code != 0 {
759                return Err(grpc_status_error(status));
760            }
761        }
762        Ok(())
763    }
764}
765
766struct GrpcResponseDecoder {
767    body: BodyStream,
768    headers: HeaderMap,
769    trailers: TrailerState,
770    compression: Option<GrpcCompression>,
771    frame_decoder: GrpcFrameDecoder,
772    final_trailers: Option<HeaderMap>,
773    final_status: Option<GrpcStatus>,
774    complete: bool,
775}
776
777impl GrpcResponseDecoder {
778    async fn next_message(&mut self) -> Result<Option<Bytes>> {
779        loop {
780            if let Some(message) = self.frame_decoder.next_message(self.compression)? {
781                return Ok(Some(message));
782            }
783
784            if self.complete {
785                if self.frame_decoder.has_buffered_data() {
786                    return Err(self.frame_decoder.incomplete_frame_error());
787                }
788                return Ok(None);
789            }
790
791            match self.body.next().await {
792                Some(chunk) => self.frame_decoder.push(chunk?),
793                None => self.finish_stream()?,
794            }
795        }
796    }
797
798    fn finish_stream(&mut self) -> Result<()> {
799        if self.complete {
800            return Ok(());
801        }
802        self.complete = true;
803        let trailer_state = std::mem::replace(&mut self.trailers, TrailerState::Ready(None));
804        let trailers = match trailer_state.take() {
805            Some(trailers) => Some(trailers),
806            None if self.headers.get("grpc-status").is_some()
807                || self.headers.get("grpc-message").is_some() =>
808            {
809                Some(self.headers.clone())
810            }
811            None => None,
812        };
813        self.final_status = Some(parse_grpc_status(trailers.as_ref(), &self.headers)?);
814        self.final_trailers = trailers;
815        Ok(())
816    }
817}
818
819#[derive(Default)]
820struct GrpcFrameDecoder {
821    buffer: BytesMut,
822}
823
824impl GrpcFrameDecoder {
825    fn push(&mut self, chunk: Bytes) {
826        self.buffer.extend_from_slice(&chunk);
827    }
828
829    fn next_message(&mut self, compression: Option<GrpcCompression>) -> Result<Option<Bytes>> {
830        if self.buffer.len() < 5 {
831            return Ok(None);
832        }
833
834        let compressed_flag = self.buffer[0];
835        if compressed_flag > 1 {
836            return Err(Error::new(
837                ErrorKind::Transport,
838                format!("invalid grpc compression flag: {compressed_flag}"),
839            ));
840        }
841
842        let message_len = u32::from_be_bytes([
843            self.buffer[1],
844            self.buffer[2],
845            self.buffer[3],
846            self.buffer[4],
847        ]) as usize;
848        if self.buffer.len() < 5 + message_len {
849            return Ok(None);
850        }
851
852        let frame = self.buffer.split_to(5 + message_len);
853        let payload = &frame[5..];
854        let message = if compressed_flag == 1 {
855            let compression = compression.ok_or_else(|| {
856                Error::new(
857                    ErrorKind::Transport,
858                    "grpc frame is compressed but grpc-encoding is missing",
859                )
860            })?;
861            decompress_grpc_message(payload, compression)?
862        } else {
863            Bytes::copy_from_slice(payload)
864        };
865        Ok(Some(message))
866    }
867
868    fn has_buffered_data(&self) -> bool {
869        !self.buffer.is_empty()
870    }
871
872    fn incomplete_frame_error(&self) -> Error {
873        if self.buffer.len() < 5 {
874            Error::new(ErrorKind::Transport, "incomplete grpc frame header")
875        } else {
876            Error::new(
877                ErrorKind::Transport,
878                "grpc frame length exceeds remaining response body",
879            )
880        }
881    }
882}
883
884#[derive(Clone, Debug, Eq, PartialEq)]
885pub struct GrpcStatus {
886    code: i32,
887    message: String,
888    details_bin: Option<Bytes>,
889}
890
891impl GrpcStatus {
892    pub fn new(code: i32, message: impl Into<String>) -> Self {
893        Self {
894            code,
895            message: message.into(),
896            details_bin: None,
897        }
898    }
899
900    fn with_details(code: i32, message: impl Into<String>, details_bin: Option<Bytes>) -> Self {
901        Self {
902            code,
903            message: message.into(),
904            details_bin,
905        }
906    }
907
908    pub fn code(&self) -> i32 {
909        self.code
910    }
911
912    pub fn message(&self) -> &str {
913        &self.message
914    }
915
916    pub fn details_bin(&self) -> Option<&[u8]> {
917        self.details_bin.as_deref()
918    }
919}
920
921fn validate_grpc_http_response(response: &Response) -> Result<()> {
922    if !matches!(response.version(), Version::Http2 | Version::Http3) {
923        return Err(Error::new(
924            ErrorKind::Transport,
925            "grpc requires an HTTP/2 or HTTP/3 response",
926        ));
927    }
928    if response.status().as_u16() != 200 {
929        return Err(Error::new(
930            ErrorKind::Transport,
931            format!(
932                "unexpected grpc http status: {}",
933                response.status().as_u16()
934            ),
935        ));
936    }
937
938    let Some(content_type) = response.headers().get("content-type") else {
939        return Err(Error::new(
940            ErrorKind::Transport,
941            "grpc response did not include content-type",
942        ));
943    };
944    if !content_type
945        .to_ascii_lowercase()
946        .starts_with("application/grpc")
947    {
948        return Err(Error::new(
949            ErrorKind::Transport,
950            format!("unexpected grpc content-type: {content_type}"),
951        ));
952    }
953    Ok(())
954}
955
/// Formats `duration` as a `grpc-timeout` header value: an integer of at most
/// eight digits followed by a unit suffix (`H`, `M`, `S`, `m`, `u`, `n`).
///
/// The coarsest unit that represents the duration exactly is preferred. If
/// no unit is exact within eight digits, the value is rounded up in the finest
/// unit that fits. A zero duration becomes `1n`, and anything beyond
/// 99,999,999 hours is clamped to `99999999H`.
fn encode_grpc_timeout(duration: Duration) -> String {
    // Largest value expressible with eight ASCII digits.
    const MAX_VALUE: u128 = 99_999_999;
    // (nanoseconds per unit, wire suffix), from coarsest to finest.
    const UNITS: [(u128, char); 6] = [
        (3_600_000_000_000, 'H'),
        (60_000_000_000, 'M'),
        (1_000_000_000, 'S'),
        (1_000_000, 'm'),
        (1_000, 'u'),
        (1, 'n'),
    ];

    let nanos = duration.as_nanos();
    if nanos == 0 {
        return String::from("1n");
    }

    let exact = UNITS.iter().find_map(|&(unit, suffix)| {
        let value = nanos / unit;
        let fits = value > 0 && value <= MAX_VALUE && nanos % unit == 0;
        fits.then_some((value, suffix))
    });
    let rounded_up = || {
        UNITS.iter().rev().find_map(|&(unit, suffix)| {
            let value = nanos.div_ceil(unit);
            (value <= MAX_VALUE).then_some((value, suffix))
        })
    };

    match exact.or_else(rounded_up) {
        Some((value, suffix)) => format!("{value}{suffix}"),
        None => String::from("99999999H"),
    }
}
987
988fn encode_json_message<T: Serialize>(value: &T) -> Result<Bytes> {
989    let payload = serde_json::to_vec(value).map_err(|err| {
990        Error::with_source(ErrorKind::Decode, "failed to encode grpc json message", err)
991    })?;
992    Ok(Bytes::from(payload))
993}
994
995fn frame_message_stream(
996    stream: GrpcMessageStream,
997    compression: Option<GrpcCompression>,
998) -> BodyStream {
999    Box::pin(async_stream::stream! {
1000        let mut stream = stream;
1001        while let Some(item) = stream.next().await {
1002            let item = item?;
1003            yield Ok(encode_grpc_frame(&item, compression)?);
1004        }
1005    })
1006}
1007
1008fn encode_grpc_frame(payload: &[u8], compression: Option<GrpcCompression>) -> Result<Bytes> {
1009    let (compressed_flag, payload) = match compression {
1010        Some(GrpcCompression::Gzip) => (1_u8, gzip_compress(payload)?),
1011        None => (0_u8, payload.to_vec()),
1012    };
1013
1014    let mut framed = Vec::with_capacity(5 + payload.len());
1015    framed.push(compressed_flag);
1016    framed.extend_from_slice(&(payload.len() as u32).to_be_bytes());
1017    framed.extend_from_slice(&payload);
1018    Ok(Bytes::from(framed))
1019}
1020
/// Test helper: decodes every complete frame in `body` into message payloads.
///
/// # Errors
///
/// Propagates frame-decoding errors, and returns the decoder's
/// incomplete-frame error if bytes remain after the last whole frame.
#[cfg(test)]
fn decode_grpc_frames(body: Bytes, compression: Option<GrpcCompression>) -> Result<Vec<Bytes>> {
    let mut decoder = GrpcFrameDecoder::default();
    decoder.push(body);
    let mut messages = Vec::new();
    loop {
        match decoder.next_message(compression)? {
            Some(message) => messages.push(message),
            // Leftover bytes mean the body ended part-way through a frame.
            None if decoder.has_buffered_data() => {
                return Err(decoder.incomplete_frame_error());
            }
            None => return Ok(messages),
        }
    }
}
1034
1035fn parse_grpc_status(trailers: Option<&HeaderMap>, headers: &HeaderMap) -> Result<GrpcStatus> {
1036    let code = trailers
1037        .and_then(|map| map.get("grpc-status"))
1038        .or_else(|| headers.get("grpc-status"))
1039        .ok_or_else(|| {
1040            Error::new(
1041                ErrorKind::Transport,
1042                "grpc response did not include grpc-status",
1043            )
1044        })?;
1045    let code = code.parse::<i32>().map_err(|err| {
1046        Error::with_source(ErrorKind::Transport, "invalid grpc-status value", err)
1047    })?;
1048    let message = trailers
1049        .and_then(|map| map.get("grpc-message"))
1050        .or_else(|| headers.get("grpc-message"))
1051        .map(decode_grpc_message_header)
1052        .transpose()?
1053        .unwrap_or_default();
1054    let details_bin = trailers
1055        .and_then(|map| map.get("grpc-status-details-bin"))
1056        .or_else(|| headers.get("grpc-status-details-bin"))
1057        .map(decode_grpc_binary_header)
1058        .transpose()?;
1059    Ok(GrpcStatus::with_details(code, message, details_bin))
1060}
1061
1062fn grpc_metadata_values(headers: &HeaderMap, name: &str) -> Vec<String> {
1063    headers
1064        .get_all(name)
1065        .into_iter()
1066        .map(str::to_owned)
1067        .collect()
1068}
1069
1070fn grpc_binary_metadata_values(headers: &HeaderMap, name: &str) -> Result<Vec<Bytes>> {
1071    let name = normalize_grpc_binary_metadata_name(name)?;
1072    headers
1073        .get_all(&name)
1074        .into_iter()
1075        .map(decode_grpc_binary_header)
1076        .collect()
1077}
1078
1079fn normalize_grpc_binary_metadata_name(name: &str) -> Result<String> {
1080    let name = name.trim();
1081    if name.is_empty() {
1082        return Err(Error::new(
1083            ErrorKind::InvalidHeaderName,
1084            "grpc binary metadata name cannot be empty",
1085        ));
1086    }
1087    if name.ends_with("-bin") {
1088        return Ok(name.to_ascii_lowercase());
1089    }
1090    Ok(format!("{}-bin", name.to_ascii_lowercase()))
1091}
1092
/// Encodes `bytes` as padded standard base64 (RFC 4648 §4) for use as a
/// gRPC `-bin` metadata value.
fn encode_grpc_binary_header(bytes: &[u8]) -> String {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    // Maps the low six bits of `bits` to their base64 character.
    let sextet = |bits: u32| char::from(ALPHABET[(bits & 0x3F) as usize]);

    let mut encoded = String::with_capacity(bytes.len().div_ceil(3) * 4);
    for group in bytes.chunks(3) {
        // Pack up to three bytes big-endian into the low 24 bits; bytes
        // missing from a short final group stay zero.
        let packed = group
            .iter()
            .zip([16_u32, 8, 0])
            .fold(0_u32, |acc, (&byte, shift)| acc | (u32::from(byte) << shift));
        encoded.push(sextet(packed >> 18));
        encoded.push(sextet(packed >> 12));
        encoded.push(if group.len() > 1 { sextet(packed >> 6) } else { '=' });
        encoded.push(if group.len() > 2 { sextet(packed) } else { '=' });
    }
    encoded
}
1116
1117fn decode_grpc_binary_header(value: &str) -> Result<Bytes> {
1118    let bytes = value.trim().as_bytes();
1119    if bytes.is_empty() {
1120        return Ok(Bytes::new());
1121    }
1122
1123    let mut output = Vec::with_capacity((bytes.len() / 4) * 3);
1124    let mut chunk = [0_u8; 4];
1125    let mut chunk_len = 0usize;
1126    let mut padding = 0usize;
1127
1128    for &byte in bytes {
1129        if byte == b'=' {
1130            chunk[chunk_len] = 0;
1131            chunk_len += 1;
1132            padding += 1;
1133        } else {
1134            chunk[chunk_len] = decode_base64_value(byte)?;
1135            chunk_len += 1;
1136        }
1137
1138        if chunk_len == 4 {
1139            output.push((chunk[0] << 2) | (chunk[1] >> 4));
1140            if padding < 2 {
1141                output.push((chunk[1] << 4) | (chunk[2] >> 2));
1142            }
1143            if padding == 0 {
1144                output.push((chunk[2] << 6) | chunk[3]);
1145            }
1146            chunk_len = 0;
1147            padding = 0;
1148        }
1149    }
1150
1151    // Handle remaining partial groups.  The gRPC spec (and RFC 4648 §3.2)
1152    // explicitly permits base64 without trailing `=` padding, so 2 or 3
1153    // remaining characters are valid; 1 character is not decodable.
1154    match chunk_len {
1155        0 => {}
1156        2 => {
1157            // Two base64 chars → one output byte
1158            output.push((chunk[0] << 2) | (chunk[1] >> 4));
1159        }
1160        3 => {
1161            // Three base64 chars → two output bytes
1162            output.push((chunk[0] << 2) | (chunk[1] >> 4));
1163            output.push((chunk[1] << 4) | (chunk[2] >> 2));
1164        }
1165        _ => {
1166            return Err(Error::new(
1167                ErrorKind::Transport,
1168                "invalid grpc binary metadata encoding",
1169            ));
1170        }
1171    }
1172
1173    Ok(Bytes::from(output))
1174}
1175
1176fn decode_base64_value(byte: u8) -> Result<u8> {
1177    match byte {
1178        b'A'..=b'Z' => Ok(byte - b'A'),
1179        b'a'..=b'z' => Ok(byte - b'a' + 26),
1180        b'0'..=b'9' => Ok(byte - b'0' + 52),
1181        b'+' => Ok(62),
1182        b'/' => Ok(63),
1183        _ => Err(Error::new(
1184            ErrorKind::Transport,
1185            "invalid grpc binary metadata encoding",
1186        )),
1187    }
1188}
1189
1190fn parse_grpc_compression(value: Option<&str>) -> Result<Option<GrpcCompression>> {
1191    match value.map(str::trim) {
1192        None | Some("") => Ok(None),
1193        Some(value) if value.eq_ignore_ascii_case("identity") => Ok(None),
1194        Some(value) if value.eq_ignore_ascii_case("gzip") => Ok(Some(GrpcCompression::Gzip)),
1195        Some(value) => Err(Error::new(
1196            ErrorKind::Transport,
1197            format!("unsupported grpc compression algorithm: {value}"),
1198        )),
1199    }
1200}
1201
1202fn gzip_compress(payload: &[u8]) -> Result<Vec<u8>> {
1203    let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
1204    encoder.write_all(payload).map_err(|err| {
1205        Error::with_source(ErrorKind::Transport, "failed to gzip grpc message", err)
1206    })?;
1207    encoder.finish().map_err(|err| {
1208        Error::with_source(
1209            ErrorKind::Transport,
1210            "failed to finish grpc gzip encoding",
1211            err,
1212        )
1213    })
1214}
1215
1216fn decompress_grpc_message(payload: &[u8], compression: GrpcCompression) -> Result<Bytes> {
1217    match compression {
1218        GrpcCompression::Gzip => {
1219            let mut decoder = GzDecoder::new(payload);
1220            let mut decoded = Vec::new();
1221            decoder.read_to_end(&mut decoded).map_err(|err| {
1222                Error::with_source(ErrorKind::Transport, "failed to gunzip grpc message", err)
1223            })?;
1224            Ok(Bytes::from(decoded))
1225        }
1226    }
1227}
1228
1229fn decode_grpc_message_header(value: &str) -> Result<String> {
1230    let bytes = value.as_bytes();
1231    let mut decoded = Vec::with_capacity(bytes.len());
1232    let mut index = 0usize;
1233    while index < bytes.len() {
1234        if bytes[index] == b'%' {
1235            if index + 2 >= bytes.len() {
1236                return Err(Error::new(
1237                    ErrorKind::Transport,
1238                    "invalid grpc-message percent encoding",
1239                ));
1240            }
1241            let hex = std::str::from_utf8(&bytes[index + 1..index + 3]).map_err(|err| {
1242                Error::with_source(
1243                    ErrorKind::Transport,
1244                    "invalid grpc-message percent encoding",
1245                    err,
1246                )
1247            })?;
1248            let byte = u8::from_str_radix(hex, 16).map_err(|err| {
1249                Error::with_source(
1250                    ErrorKind::Transport,
1251                    "invalid grpc-message percent encoding",
1252                    err,
1253                )
1254            })?;
1255            decoded.push(byte);
1256            index += 3;
1257            continue;
1258        }
1259        decoded.push(bytes[index]);
1260        index += 1;
1261    }
1262    String::from_utf8(decoded).map_err(|err| {
1263        Error::with_source(ErrorKind::Transport, "grpc-message is not valid utf-8", err)
1264    })
1265}
1266
1267fn grpc_status_error(status: &GrpcStatus) -> Error {
1268    let message = if status.message.is_empty() {
1269        format!("grpc request failed with status {}", status.code)
1270    } else {
1271        format!(
1272            "grpc request failed with status {}: {}",
1273            status.code, status.message
1274        )
1275    };
1276    Error::new(ErrorKind::Transport, message)
1277}
1278
1279fn duplex_payload_conflict_error() -> Error {
1280    Error::new(
1281        ErrorKind::Transport,
1282        "grpc duplex call manages request messages itself; do not combine open_duplex() with message/messages/message_bytes/messages_bytes",
1283    )
1284}
1285
1286fn typed_request_codec_error() -> Error {
1287    Error::new(
1288        ErrorKind::Transport,
1289        "typed grpc request APIs only support GrpcCodec::Json; use message_bytes/messages_bytes for protobuf payloads",
1290    )
1291}
1292
1293fn typed_response_codec_error() -> Error {
1294    Error::new(
1295        ErrorKind::Transport,
1296        "typed grpc response APIs only support GrpcCodec::Json; use message_bytes/messages_bytes for protobuf payloads",
1297    )
1298}
1299
#[cfg(test)]
mod tests {
    use super::{
        GrpcCodec, GrpcCompression, GrpcFrameDecoder, decode_grpc_binary_header,
        decode_grpc_frames, decode_grpc_message_header, encode_grpc_binary_header,
        encode_grpc_frame, encode_grpc_timeout, normalize_grpc_binary_metadata_name,
        parse_grpc_status, validate_grpc_http_response,
    };
    use bytes::Bytes;
    use std::time::Duration;

    use crate::{Body, HeaderMap, StatusCode, Url, Version};

    #[test]
    fn grpc_frame_round_trip_keeps_message_boundaries() {
        let first = encode_grpc_frame(br#"{"name":"one"}"#, None).unwrap();
        let second = encode_grpc_frame(br#"{"name":"two"}"#, None).unwrap();
        let combined = Bytes::from([first.as_ref(), second.as_ref()].concat());

        let frames = decode_grpc_frames(combined, None).unwrap();
        assert_eq!(frames.len(), 2);
        assert_eq!(frames[0], Bytes::from_static(br#"{"name":"one"}"#));
        assert_eq!(frames[1], Bytes::from_static(br#"{"name":"two"}"#));
    }

    #[test]
    fn grpc_gzip_frame_round_trip_restores_original_payload() {
        let frame = encode_grpc_frame(br#"{"name":"gzip"}"#, Some(GrpcCompression::Gzip)).unwrap();
        let frames = decode_grpc_frames(frame, Some(GrpcCompression::Gzip)).unwrap();
        assert_eq!(frames, vec![Bytes::from_static(br#"{"name":"gzip"}"#)]);
    }

    #[test]
    fn grpc_frame_decoder_handles_split_frame_boundaries() {
        let frame = encode_grpc_frame(br#"{"name":"split"}"#, None).unwrap();
        let mut decoder = GrpcFrameDecoder::default();
        decoder.push(frame.slice(..2));
        assert!(decoder.next_message(None).unwrap().is_none());
        decoder.push(frame.slice(2..7));
        assert!(decoder.next_message(None).unwrap().is_none());
        decoder.push(frame.slice(7..));
        assert_eq!(
            decoder.next_message(None).unwrap(),
            Some(Bytes::from_static(br#"{"name":"split"}"#))
        );
    }

    #[test]
    fn grpc_message_header_percent_decodes() {
        assert_eq!(
            decode_grpc_message_header("user%20not%20found").unwrap(),
            "user not found"
        );
    }

    #[test]
    fn open_duplex_rejects_preconfigured_payload_before_network() {
        let err = futures_lite::future::block_on(async {
            crate::grpc("https://example.com/chat.Service/Talk")
                .message_bytes(Bytes::from_static(b"hello"))?
                .open_duplex()
                .await
        })
        .unwrap_err();
        assert_eq!(err.kind(), &crate::ErrorKind::Transport);
        assert!(
            err.to_string()
                .contains("grpc duplex call manages request messages itself")
        );
    }

    #[test]
    fn protobuf_typed_request_api_returns_explicit_error_before_network() {
        let err = futures_lite::future::block_on(async {
            crate::grpc("https://example.com/greeter.SayHello/Call")
                .codec(GrpcCodec::Protobuf)
                .message(&serde_json::json!({ "name": "Ada" }))?
                .await
        })
        .unwrap_err();
        assert_eq!(err.kind(), &crate::ErrorKind::Transport);
        assert!(
            err.to_string()
                .contains("typed grpc request APIs only support GrpcCodec::Json")
        );
    }

    // The encoder picks the coarsest unit that represents the duration
    // exactly (1500ms stays "1500m" rather than being rounded to seconds).
    #[test]
    fn grpc_timeout_header_uses_largest_exact_unit() {
        assert_eq!(encode_grpc_timeout(Duration::from_millis(1500)), "1500m");
        assert_eq!(encode_grpc_timeout(Duration::from_secs(12)), "12S");
        assert_eq!(encode_grpc_timeout(Duration::from_nanos(1)), "1n");
        assert_eq!(encode_grpc_timeout(Duration::from_secs(5400)), "90M");
    }

    #[test]
    fn grpc_timeout_header_rounds_up_when_no_exact_unit_fits() {
        // 1s + 1ns needs ten digits in nanoseconds, more than the eight the
        // header allows, so it is rounded up to whole microseconds.
        assert_eq!(
            encode_grpc_timeout(Duration::from_nanos(1_000_000_001)),
            "1000001u"
        );
    }

    #[test]
    fn grpc_timeout_header_clamps_zero_and_huge_durations() {
        assert_eq!(encode_grpc_timeout(Duration::from_secs(0)), "1n");
        assert_eq!(
            encode_grpc_timeout(Duration::from_secs(u64::MAX)),
            "99999999H"
        );
    }

    #[test]
    fn grpc_response_requires_content_type() {
        let response = crate::Response::new(
            StatusCode::OK,
            Version::Http2,
            Url::parse("https://example.com/grpc").unwrap(),
            HeaderMap::new(),
            Some({
                let mut trailers = HeaderMap::new();
                trailers.insert("grpc-status", "0").unwrap();
                trailers
            }),
            Body::default(),
        );

        let err = validate_grpc_http_response(&response).unwrap_err();
        assert_eq!(err.kind(), &crate::ErrorKind::Transport);
        assert!(err.to_string().contains("did not include content-type"));
    }

    #[test]
    fn grpc_response_rejects_non_grpc_content_type() {
        let mut headers = HeaderMap::new();
        headers.insert("content-type", "application/json").unwrap();
        let response = crate::Response::new(
            StatusCode::OK,
            Version::Http2,
            Url::parse("https://example.com/grpc").unwrap(),
            headers,
            Some({
                let mut trailers = HeaderMap::new();
                trailers.insert("grpc-status", "0").unwrap();
                trailers
            }),
            Body::default(),
        );

        let err = validate_grpc_http_response(&response).unwrap_err();
        assert_eq!(err.kind(), &crate::ErrorKind::Transport);
        assert!(err.to_string().contains("unexpected grpc content-type"));
    }

    #[test]
    fn grpc_binary_metadata_name_normalizes_suffix() {
        assert_eq!(
            normalize_grpc_binary_metadata_name("trace").unwrap(),
            "trace-bin"
        );
        assert_eq!(
            normalize_grpc_binary_metadata_name("trace-bin").unwrap(),
            "trace-bin"
        );
    }

    #[test]
    fn grpc_binary_metadata_round_trips() {
        let encoded = encode_grpc_binary_header(b"\x00\x01grpc");
        let decoded = decode_grpc_binary_header(&encoded).unwrap();
        assert_eq!(decoded, Bytes::from_static(b"\x00\x01grpc"));
    }

    // Test vectors from RFC 4648 §10, checked in both directions.
    #[test]
    fn grpc_binary_header_matches_rfc4648_test_vectors() {
        let vectors: [(&[u8], &str); 7] = [
            (b"", ""),
            (b"f", "Zg=="),
            (b"fo", "Zm8="),
            (b"foo", "Zm9v"),
            (b"foob", "Zm9vYg=="),
            (b"fooba", "Zm9vYmE="),
            (b"foobar", "Zm9vYmFy"),
        ];
        for (raw, encoded) in vectors {
            assert_eq!(encode_grpc_binary_header(raw), encoded);
            let decoded = decode_grpc_binary_header(encoded).unwrap();
            assert_eq!(&decoded[..], raw);
        }
    }

    // RFC 4648 §3.2 and the gRPC spec permit base64 without trailing `=` padding.
    // Servers frequently omit padding; the decoder must handle all partial groups.
    #[test]
    fn grpc_binary_header_decodes_unpadded_base64() {
        // 1-byte value: encode produces "AA==" but servers may send "AA"
        let decoded = decode_grpc_binary_header("AA").unwrap();
        assert_eq!(&decoded[..], b"\x00");

        // 2-byte value: encode produces "AAA=" but servers may send "AAA"
        let decoded = decode_grpc_binary_header("AAA").unwrap();
        assert_eq!(&decoded[..], b"\x00\x00");

        // 3-byte value: no padding needed, always 4 chars
        let decoded = decode_grpc_binary_header("AAAA").unwrap();
        assert_eq!(&decoded[..], b"\x00\x00\x00");

        // Round-trip of a realistic value without padding
        let raw = b"\xde\xad\xbe\xef";
        let padded = encode_grpc_binary_header(raw);
        // Strip trailing `=` to simulate a server that omits padding
        let unpadded = padded.trim_end_matches('=');
        let decoded = decode_grpc_binary_header(unpadded).unwrap();
        assert_eq!(&decoded[..], raw);
    }

    #[test]
    fn grpc_binary_header_rejects_invalid_chunk_length_of_one() {
        // A single base64 character cannot decode to any byte (need at least 2).
        let err = decode_grpc_binary_header("A").unwrap_err();
        assert_eq!(err.kind(), &crate::ErrorKind::Transport);
    }

    #[test]
    fn grpc_status_exposes_status_details_bin() {
        let mut trailers = HeaderMap::new();
        trailers.insert("grpc-status", "7").unwrap();
        trailers.insert("grpc-message", "denied").unwrap();
        trailers
            .insert(
                "grpc-status-details-bin",
                encode_grpc_binary_header(b"details").as_str(),
            )
            .unwrap();

        let status = parse_grpc_status(Some(&trailers), &HeaderMap::new()).unwrap();
        assert_eq!(status.code(), 7);
        assert_eq!(status.message(), "denied");
        assert_eq!(status.details_bin(), Some(&b"details"[..]));
    }
}