// tonic_arcanyx_fork/codec/compression.rs

1use super::encode::BUFFER_SIZE;
2use crate::{metadata::MetadataValue, Status};
3use bytes::{Buf, BytesMut};
4#[cfg(feature = "gzip")]
5use flate2::read::{GzDecoder, GzEncoder};
6use std::fmt;
7
/// Header naming the compression encoding applied to the messages being sent.
pub(crate) const ENCODING_HEADER: &str = "grpc-encoding";
/// Header listing the compression encodings the sender is willing to receive.
pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding";
10
/// Set of compression encodings that a server or channel is willing to use.
///
/// The derived `Default` leaves every encoding disabled.
#[derive(Clone, Copy, Debug, Default)]
pub struct EnabledCompressionEncodings {
    /// Whether gzip compression has been enabled.
    #[cfg(feature = "gzip")]
    pub(crate) gzip: bool,
}
17
18impl EnabledCompressionEncodings {
19    /// Check if a [`CompressionEncoding`] is enabled.
20    pub fn is_enabled(&self, encoding: CompressionEncoding) -> bool {
21        match encoding {
22            #[cfg(feature = "gzip")]
23            CompressionEncoding::Gzip => self.gzip,
24        }
25    }
26
27    /// Enable a [`CompressionEncoding`].
28    pub fn enable(&mut self, encoding: CompressionEncoding) {
29        match encoding {
30            #[cfg(feature = "gzip")]
31            CompressionEncoding::Gzip => self.gzip = true,
32        }
33    }
34
35    pub(crate) fn into_accept_encoding_header_value(self) -> Option<http::HeaderValue> {
36        if self.is_gzip_enabled() {
37            Some(http::HeaderValue::from_static("gzip,identity"))
38        } else {
39            None
40        }
41    }
42
43    #[cfg(feature = "gzip")]
44    const fn is_gzip_enabled(&self) -> bool {
45        self.gzip
46    }
47
48    #[cfg(not(feature = "gzip"))]
49    const fn is_gzip_enabled(&self) -> bool {
50        false
51    }
52}
53
/// The compression encodings Tonic supports.
///
/// Which variants exist depends on the crate features that were compiled in.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum CompressionEncoding {
    /// The gzip encoding, available with the `gzip` feature.
    #[cfg(feature = "gzip")]
    #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))]
    Gzip,
}
63
64impl CompressionEncoding {
65    /// Based on the `grpc-accept-encoding` header, pick an encoding to use.
66    pub(crate) fn from_accept_encoding_header(
67        map: &http::HeaderMap,
68        enabled_encodings: EnabledCompressionEncodings,
69    ) -> Option<Self> {
70        if !enabled_encodings.is_gzip_enabled() {
71            return None;
72        }
73
74        let header_value = map.get(ACCEPT_ENCODING_HEADER)?;
75        let header_value_str = header_value.to_str().ok()?;
76
77        split_by_comma(header_value_str).find_map(|value| match value {
78            #[cfg(feature = "gzip")]
79            "gzip" => Some(CompressionEncoding::Gzip),
80            _ => None,
81        })
82    }
83
84    /// Get the value of `grpc-encoding` header. Returns an error if the encoding isn't supported.
85    pub(crate) fn from_encoding_header(
86        map: &http::HeaderMap,
87        enabled_encodings: EnabledCompressionEncodings,
88    ) -> Result<Option<Self>, Status> {
89        let header_value = if let Some(value) = map.get(ENCODING_HEADER) {
90            value
91        } else {
92            return Ok(None);
93        };
94
95        let header_value_str = if let Ok(value) = header_value.to_str() {
96            value
97        } else {
98            return Ok(None);
99        };
100
101        match header_value_str {
102            #[cfg(feature = "gzip")]
103            "gzip" if enabled_encodings.is_enabled(CompressionEncoding::Gzip) => {
104                Ok(Some(CompressionEncoding::Gzip))
105            }
106            "identity" => Ok(None),
107            other => {
108                let mut status = Status::unimplemented(format!(
109                    "Content is compressed with `{}` which isn't supported",
110                    other
111                ));
112
113                let header_value = enabled_encodings
114                    .into_accept_encoding_header_value()
115                    .map(MetadataValue::unchecked_from_header_value)
116                    .unwrap_or_else(|| MetadataValue::from_static("identity"));
117                status
118                    .metadata_mut()
119                    .insert(ACCEPT_ENCODING_HEADER, header_value);
120
121                Err(status)
122            }
123        }
124    }
125
126    pub(crate) fn into_header_value(self) -> http::HeaderValue {
127        match self {
128            #[cfg(feature = "gzip")]
129            CompressionEncoding::Gzip => http::HeaderValue::from_static("gzip"),
130        }
131    }
132
133    pub(crate) fn encodings() -> &'static [Self] {
134        &[
135            #[cfg(feature = "gzip")]
136            CompressionEncoding::Gzip,
137        ]
138    }
139}
140
141impl fmt::Display for CompressionEncoding {
142    #[allow(unused_variables)]
143    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
144        match *self {
145            #[cfg(feature = "gzip")]
146            CompressionEncoding::Gzip => write!(f, "gzip"),
147        }
148    }
149}
150
/// Split a comma-separated header value into its items, trimming whitespace around each one.
///
/// Empty items are kept. An empty or all-whitespace input yields a single empty item.
fn split_by_comma(s: &str) -> impl Iterator<Item = &str> {
    s.split(',').map(str::trim)
}
154
155/// Compress `len` bytes from `decompressed_buf` into `out_buf`.
156#[allow(unused_variables, unreachable_code)]
157pub(crate) fn compress(
158    encoding: CompressionEncoding,
159    decompressed_buf: &mut BytesMut,
160    out_buf: &mut BytesMut,
161    len: usize,
162) -> Result<(), std::io::Error> {
163    let capacity = ((len / BUFFER_SIZE) + 1) * BUFFER_SIZE;
164    out_buf.reserve(capacity);
165
166    match encoding {
167        #[cfg(feature = "gzip")]
168        CompressionEncoding::Gzip => {
169            let mut gzip_encoder = GzEncoder::new(
170                &decompressed_buf[0..len],
171                // FIXME: support customizing the compression level
172                flate2::Compression::new(6),
173            );
174            let mut out_writer = bytes::BufMut::writer(out_buf);
175
176            std::io::copy(&mut gzip_encoder, &mut out_writer)?;
177        }
178    }
179
180    decompressed_buf.advance(len);
181
182    Ok(())
183}
184
185/// Decompress `len` bytes from `compressed_buf` into `out_buf`.
186#[allow(unused_variables, unreachable_code)]
187pub(crate) fn decompress(
188    encoding: CompressionEncoding,
189    compressed_buf: &mut BytesMut,
190    out_buf: &mut BytesMut,
191    len: usize,
192) -> Result<(), std::io::Error> {
193    let estimate_decompressed_len = len * 2;
194    let capacity = ((estimate_decompressed_len / BUFFER_SIZE) + 1) * BUFFER_SIZE;
195    out_buf.reserve(capacity);
196
197    match encoding {
198        #[cfg(feature = "gzip")]
199        CompressionEncoding::Gzip => {
200            let mut gzip_decoder = GzDecoder::new(&compressed_buf[0..len]);
201            let mut out_writer = bytes::BufMut::writer(out_buf);
202
203            std::io::copy(&mut gzip_decoder, &mut out_writer)?;
204        }
205    }
206
207    compressed_buf.advance(len);
208
209    Ok(())
210}
211
/// Per-message override of the compression configured on a stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum SingleMessageCompressionOverride {
    /// Inherit whatever compression is already configured. If the stream is compressed this
    /// message will also be compressed.
    ///
    /// This is the default.
    Inherit,
    /// Don't compress this message, even if compression is enabled on the stream.
    Disable,
}

impl Default for SingleMessageCompressionOverride {
    /// Returns [`SingleMessageCompressionOverride::Inherit`].
    fn default() -> Self {
        Self::Inherit
    }
}