Skip to main content

trillium_grpc/
encoding.rs

1//! Per-message compression negotiation and codecs.
2//!
3//! The wire surface is two HTTP/2 headers plus the per-message
4//! Compressed-Flag byte:
5//!
6//! - `grpc-encoding`: encoding the sender used for *its* messages
7//!   (request → request body messages; response → response body messages).
8//! - `grpc-accept-encoding`: comma-separated list the sender will accept
9//!   on the *peer's* messages.
10//!
11//! [`Encoding`] enumerates the codecs trillium-grpc was built with. The
12//! `Identity` variant is always present; `Gzip`, `Deflate`, and `Zstd` are
13//! cfg-gated on their respective Cargo features. `gzip` is on by default
14//! because it's the de-facto baseline for gRPC compression in the wild.
15//!
16//! Compression is one-shot (bytes → bytes) on per-message buffers, so we
17//! use the synchronous codec crates directly (`flate2` for gzip+deflate,
18//! `zstd` for zstd). The async-compression wrappers would only add
19//! AsyncRead-shaped ceremony for what is fundamentally a small in-memory
20//! transformation.
21
22use crate::Status;
23
/// Upper bound (4 MiB) on a single decompressed message. Chosen to match
/// grpc-go's default and the per-frame `max_message_size` in
/// [`crate::frame::reader`].
pub const DEFAULT_MAX_MESSAGE_SIZE: usize = 4 << 20;
27
/// Codec applied to individual gRPC messages.
///
/// `Identity` (messages sent as-is) is always available. Every other variant
/// exists only when its Cargo feature is enabled, so the set of variants
/// depends on the build.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Encoding {
    /// Uncompressed payload.
    Identity,
    /// gzip, implemented with `flate2`; needs the `gzip` feature.
    #[cfg(feature = "gzip")]
    Gzip,
    /// Raw DEFLATE, implemented with `flate2`; needs the `deflate` feature.
    #[cfg(feature = "deflate")]
    Deflate,
    /// Zstandard, implemented with the `zstd` crate; needs the `zstd` feature.
    #[cfg(feature = "zstd")]
    Zstd,
}
44
45impl Encoding {
46    /// Every codec compiled into this build, including `Identity`. Order
47    /// is the order presented in `grpc-accept-encoding`.
48    pub const ALL: &'static [Self] = &[
49        Self::Identity,
50        #[cfg(feature = "gzip")]
51        Self::Gzip,
52        #[cfg(feature = "deflate")]
53        Self::Deflate,
54        #[cfg(feature = "zstd")]
55        Self::Zstd,
56    ];
57
58    /// Parse a single `grpc-encoding` token. Returns `None` for codecs not
59    /// compiled in or values outside the spec set.
60    pub fn from_grpc_encoding(s: &str) -> Option<Self> {
61        match s {
62            "identity" => Some(Self::Identity),
63            #[cfg(feature = "gzip")]
64            "gzip" => Some(Self::Gzip),
65            #[cfg(feature = "deflate")]
66            "deflate" => Some(Self::Deflate),
67            #[cfg(feature = "zstd")]
68            "zstd" => Some(Self::Zstd),
69            _ => None,
70        }
71    }
72
73    /// The `grpc-encoding` token for this codec (`"identity"`, `"gzip"`, …).
74    pub fn as_grpc_encoding(&self) -> &'static str {
75        match self {
76            Self::Identity => "identity",
77            #[cfg(feature = "gzip")]
78            Self::Gzip => "gzip",
79            #[cfg(feature = "deflate")]
80            Self::Deflate => "deflate",
81            #[cfg(feature = "zstd")]
82            Self::Zstd => "zstd",
83        }
84    }
85
86    /// Comma-separated list of every codec in this build, suitable for the
87    /// `grpc-accept-encoding` response header. Memoized — the value is
88    /// constant for a given build.
89    pub fn accepted_encodings() -> &'static str {
90        static LIST: std::sync::OnceLock<String> = std::sync::OnceLock::new();
91        LIST.get_or_init(|| {
92            Self::ALL
93                .iter()
94                .map(|e| e.as_grpc_encoding())
95                .collect::<Vec<_>>()
96                .join(",")
97        })
98    }
99
100    /// Compress `data` with this codec. `Identity` returns a copy.
101    pub fn compress(&self, data: &[u8]) -> Result<Vec<u8>, Status> {
102        match self {
103            Self::Identity => Ok(data.to_vec()),
104            #[cfg(feature = "gzip")]
105            Self::Gzip => gzip_compress(data),
106            #[cfg(feature = "deflate")]
107            Self::Deflate => deflate_compress(data),
108            #[cfg(feature = "zstd")]
109            Self::Zstd => zstd_compress(data),
110        }
111    }
112
113    /// Decompress `data` with this codec, capping the decompressed size at
114    /// `max_size` bytes (zip-bomb defense). `Identity` returns a copy and
115    /// errors if `data.len() > max_size`.
116    pub fn decompress(&self, data: &[u8], max_size: usize) -> Result<Vec<u8>, Status> {
117        match self {
118            Self::Identity => {
119                if data.len() > max_size {
120                    return Err(oversize(max_size));
121                }
122                Ok(data.to_vec())
123            }
124            #[cfg(feature = "gzip")]
125            Self::Gzip => gzip_decompress(data, max_size),
126            #[cfg(feature = "deflate")]
127            Self::Deflate => deflate_decompress(data, max_size),
128            #[cfg(feature = "zstd")]
129            Self::Zstd => zstd_decompress(data, max_size),
130        }
131    }
132}
133
134fn oversize(max_size: usize) -> Status {
135    Status::resource_exhausted(format!(
136        "decompressed message exceeds limit of {max_size} bytes"
137    ))
138}
139
/// One-shot gzip compression of `data` at `flate2`'s default level.
#[cfg(feature = "gzip")]
fn gzip_compress(data: &[u8]) -> Result<Vec<u8>, Status> {
    use flate2::{Compression, write::GzEncoder};
    use std::io::Write;

    // Writing the payload and finishing the stream can each fail. Both are
    // reported the same way. The closure captures nothing, so it is `Copy`
    // and can be handed to `map_err` twice.
    let to_status = |e: std::io::Error| Status::internal(format!("gzip compress: {e}"));
    let mut encoder = GzEncoder::new(Vec::with_capacity(data.len()), Compression::default());
    encoder.write_all(data).map_err(to_status)?;
    encoder.finish().map_err(to_status)
}
150
/// gzip-decodes `data`, refusing output larger than `max_size` bytes.
#[cfg(feature = "gzip")]
fn gzip_decompress(data: &[u8], max_size: usize) -> Result<Vec<u8>, Status> {
    let decoder = flate2::read::GzDecoder::new(data);
    read_capped(decoder, max_size, "gzip decompress")
}
156
/// One-shot raw-DEFLATE compression of `data` at `flate2`'s default level.
#[cfg(feature = "deflate")]
fn deflate_compress(data: &[u8]) -> Result<Vec<u8>, Status> {
    use flate2::{Compression, write::DeflateEncoder};
    use std::io::Write;

    // Same error text for the write and the finish; the closure captures
    // nothing, so it can be reused by value.
    let to_status = |e: std::io::Error| Status::internal(format!("deflate compress: {e}"));
    let mut encoder = DeflateEncoder::new(Vec::with_capacity(data.len()), Compression::default());
    encoder.write_all(data).map_err(to_status)?;
    encoder.finish().map_err(to_status)
}
167
/// Raw-DEFLATE-decodes `data`, refusing output larger than `max_size` bytes.
#[cfg(feature = "deflate")]
fn deflate_decompress(data: &[u8], max_size: usize) -> Result<Vec<u8>, Status> {
    let decoder = flate2::read::DeflateDecoder::new(data);
    read_capped(decoder, max_size, "deflate decompress")
}
173
/// One-shot Zstandard compression of `data`.
#[cfg(feature = "zstd")]
fn zstd_compress(data: &[u8]) -> Result<Vec<u8>, Status> {
    // Level 0 tells the `zstd` crate to use zstd's built-in default level.
    const LEVEL: i32 = 0;
    let compressed = zstd::stream::encode_all(data, LEVEL);
    compressed.map_err(|err| Status::internal(format!("zstd compress: {err}")))
}
178
/// Zstandard-decodes `data`, refusing output larger than `max_size` bytes.
/// Failure to construct the decoder is reported as an internal error.
#[cfg(feature = "zstd")]
fn zstd_decompress(data: &[u8], max_size: usize) -> Result<Vec<u8>, Status> {
    match zstd::stream::Decoder::new(data) {
        Ok(decoder) => read_capped(decoder, max_size, "zstd decompress"),
        Err(e) => Err(Status::internal(format!("zstd decompress: {e}"))),
    }
}
185
/// Drains `r` into a fresh `Vec`, refusing to buffer more than `max_size`
/// bytes (zip-bomb defense).
///
/// At most `max_size + 1` bytes are pulled from `r`. Seeing that extra byte
/// proves the decoded message is over the cap, so the result is the
/// [`oversize`] (`ResourceExhausted`) error. Any I/O or decoder error is
/// returned via `Status::internal`, prefixed with `ctx`.
///
/// The `+ 1` is saturating. With a plain `+ 1`, `max_size == usize::MAX` on a
/// 64-bit target overflows `u64`. That panics in debug builds and wraps to a
/// zero-byte limit in release builds, which would silently turn every message
/// into an empty `Ok(vec![])`.
#[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
fn read_capped<R: std::io::Read>(r: R, max_size: usize, ctx: &str) -> Result<Vec<u8>, Status> {
    use std::io::Read;
    // `usize -> u64` is lossless on every target Rust supports (<= 64-bit).
    let limit = (max_size as u64).saturating_add(1);
    let mut out = Vec::new();
    r.take(limit)
        .read_to_end(&mut out)
        .map_err(|e| Status::internal(format!("{ctx}: {e}")))?;
    if out.len() > max_size {
        return Err(oversize(max_size));
    }
    Ok(out)
}
200
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn identity_roundtrip() {
        let data = b"hello world";
        let compressed = Encoding::Identity.compress(data).unwrap();
        let decompressed = Encoding::Identity
            .decompress(&compressed, DEFAULT_MAX_MESSAGE_SIZE)
            .unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    fn identity_decompress_respects_max_size() {
        let data = vec![0u8; 100];
        let err = Encoding::Identity.decompress(&data, 50).unwrap_err();
        assert_eq!(err.code, crate::Code::ResourceExhausted);
    }

    #[test]
    fn identity_decompress_allows_exactly_max_size() {
        // The cap is inclusive: a message of exactly `max_size` bytes is fine.
        let data = vec![0u8; 50];
        let out = Encoding::Identity.decompress(&data, 50).unwrap();
        assert_eq!(out, data);
    }

    #[test]
    fn from_grpc_encoding_identity_always_recognized() {
        assert_eq!(
            Encoding::from_grpc_encoding("identity"),
            Some(Encoding::Identity)
        );
    }

    #[test]
    fn from_grpc_encoding_unknown_returns_none() {
        assert!(Encoding::from_grpc_encoding("snappy").is_none());
        assert!(Encoding::from_grpc_encoding("").is_none());
        assert!(Encoding::from_grpc_encoding("GZIP").is_none()); // case-sensitive per spec
    }

    #[test]
    fn every_compiled_codec_round_trips_through_its_token() {
        // `ALL`, `as_grpc_encoding`, and `from_grpc_encoding` must agree for
        // whatever feature set this build was compiled with.
        for &encoding in Encoding::ALL {
            assert_eq!(
                Encoding::from_grpc_encoding(encoding.as_grpc_encoding()),
                Some(encoding)
            );
        }
    }

    #[test]
    fn accepted_encodings_starts_with_identity() {
        assert!(Encoding::accepted_encodings().starts_with("identity"));
    }

    #[test]
    fn accepted_encodings_lists_every_compiled_codec_in_order() {
        let advertised: Vec<&str> = Encoding::accepted_encodings().split(',').collect();
        let expected: Vec<&str> = Encoding::ALL
            .iter()
            .map(|e| e.as_grpc_encoding())
            .collect();
        assert_eq!(advertised, expected);
    }

    #[test]
    fn every_compiled_codec_round_trips_data() {
        let data = b"round trip me ".repeat(64);
        for &encoding in Encoding::ALL {
            let compressed = encoding.compress(&data).unwrap();
            let decompressed = encoding
                .decompress(&compressed, DEFAULT_MAX_MESSAGE_SIZE)
                .unwrap();
            assert_eq!(decompressed, data, "{encoding:?}");
        }
    }

    #[test]
    fn every_compiled_codec_enforces_max_size() {
        // 100 KiB of 'a' compresses well but decompresses far past a 1 KiB cap.
        let data = vec![b'a'; 100 * 1024];
        for &encoding in Encoding::ALL {
            let compressed = encoding.compress(&data).unwrap();
            let err = encoding.decompress(&compressed, 1024).unwrap_err();
            assert_eq!(err.code, crate::Code::ResourceExhausted, "{encoding:?}");
        }
    }

    #[test]
    fn every_compiled_codec_accepts_message_of_exactly_max_size() {
        let data = vec![b'z'; 1024];
        for &encoding in Encoding::ALL {
            let compressed = encoding.compress(&data).unwrap();
            let out = encoding.decompress(&compressed, data.len()).unwrap();
            assert_eq!(out, data, "{encoding:?}");
        }
    }

    #[cfg(feature = "gzip")]
    mod gzip {
        use super::*;

        #[test]
        fn parse_and_serialize() {
            assert_eq!(Encoding::from_grpc_encoding("gzip"), Some(Encoding::Gzip));
            assert_eq!(Encoding::Gzip.as_grpc_encoding(), "gzip");
            assert!(Encoding::accepted_encodings().contains("gzip"));
        }

        #[test]
        fn roundtrip() {
            let data = b"hello, gzip-compressed world! ".repeat(100);
            let compressed = Encoding::Gzip.compress(&data).unwrap();
            assert!(compressed.len() < data.len(), "compression had effect");
            let decompressed = Encoding::Gzip
                .decompress(&compressed, DEFAULT_MAX_MESSAGE_SIZE)
                .unwrap();
            assert_eq!(decompressed, data);
        }

        #[test]
        fn decompress_respects_max_size() {
            // 100 KB of 'a' compresses well; decompressed size > 1 KB cap.
            let data = vec![b'a'; 100 * 1024];
            let compressed = Encoding::Gzip.compress(&data).unwrap();
            let err = Encoding::Gzip.decompress(&compressed, 1024).unwrap_err();
            assert_eq!(err.code, crate::Code::ResourceExhausted);
        }

        #[test]
        fn decompress_rejects_corrupt_input() {
            // Not a gzip stream at all (bad magic bytes).
            assert!(
                Encoding::Gzip
                    .decompress(b"definitely not gzip", DEFAULT_MAX_MESSAGE_SIZE)
                    .is_err()
            );
        }
    }

    #[cfg(feature = "deflate")]
    mod deflate {
        use super::*;

        #[test]
        fn parse_and_serialize() {
            assert_eq!(
                Encoding::from_grpc_encoding("deflate"),
                Some(Encoding::Deflate)
            );
            assert_eq!(Encoding::Deflate.as_grpc_encoding(), "deflate");
            assert!(Encoding::accepted_encodings().contains("deflate"));
        }

        #[test]
        fn roundtrip() {
            let data = b"hello, deflate-compressed world! ".repeat(100);
            let compressed = Encoding::Deflate.compress(&data).unwrap();
            let decompressed = Encoding::Deflate
                .decompress(&compressed, DEFAULT_MAX_MESSAGE_SIZE)
                .unwrap();
            assert_eq!(decompressed, data);
        }

        #[test]
        fn decompress_respects_max_size() {
            let data = vec![b'a'; 100 * 1024];
            let compressed = Encoding::Deflate.compress(&data).unwrap();
            let err = Encoding::Deflate.decompress(&compressed, 1024).unwrap_err();
            assert_eq!(err.code, crate::Code::ResourceExhausted);
        }
    }

    #[cfg(feature = "zstd")]
    mod zstd {
        use super::*;

        #[test]
        fn parse_and_serialize() {
            assert_eq!(Encoding::from_grpc_encoding("zstd"), Some(Encoding::Zstd));
            assert_eq!(Encoding::Zstd.as_grpc_encoding(), "zstd");
            assert!(Encoding::accepted_encodings().contains("zstd"));
        }

        #[test]
        fn roundtrip() {
            let data = b"hello, zstd-compressed world! ".repeat(100);
            let compressed = Encoding::Zstd.compress(&data).unwrap();
            let decompressed = Encoding::Zstd
                .decompress(&compressed, DEFAULT_MAX_MESSAGE_SIZE)
                .unwrap();
            assert_eq!(decompressed, data);
        }

        #[test]
        fn decompress_respects_max_size() {
            let data = vec![b'a'; 100 * 1024];
            let compressed = Encoding::Zstd.compress(&data).unwrap();
            let err = Encoding::Zstd.decompress(&compressed, 1024).unwrap_err();
            assert_eq!(err.code, crate::Code::ResourceExhausted);
        }
    }
}