1use crate::Status;
23
/// Default message size limit in bytes (4 MiB), suitable as the `max_size`
/// argument to [`Encoding::decompress`].
pub const DEFAULT_MAX_MESSAGE_SIZE: usize = 4 << 20;
27
/// Compression scheme applied to a message payload.
///
/// `Identity` (no compression) is always available. Every other variant is
/// only compiled in when the Cargo feature of the same name is enabled, so
/// exhaustive matches must carry the matching `#[cfg(feature = "...")]` arms.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Encoding {
    /// No compression: payload bytes are passed through unchanged.
    Identity,
    #[cfg(feature = "gzip")]
    /// gzip, implemented with the `flate2` crate (requires the `gzip` feature).
    Gzip,
    #[cfg(feature = "deflate")]
    /// The `deflate` encoding, implemented with the `flate2` crate
    /// (requires the `deflate` feature).
    Deflate,
    #[cfg(feature = "zstd")]
    /// Zstandard, implemented with the `zstd` crate (requires the `zstd` feature).
    Zstd,
}
44
45impl Encoding {
46 pub const ALL: &'static [Self] = &[
49 Self::Identity,
50 #[cfg(feature = "gzip")]
51 Self::Gzip,
52 #[cfg(feature = "deflate")]
53 Self::Deflate,
54 #[cfg(feature = "zstd")]
55 Self::Zstd,
56 ];
57
58 pub fn from_grpc_encoding(s: &str) -> Option<Self> {
61 match s {
62 "identity" => Some(Self::Identity),
63 #[cfg(feature = "gzip")]
64 "gzip" => Some(Self::Gzip),
65 #[cfg(feature = "deflate")]
66 "deflate" => Some(Self::Deflate),
67 #[cfg(feature = "zstd")]
68 "zstd" => Some(Self::Zstd),
69 _ => None,
70 }
71 }
72
73 pub fn as_grpc_encoding(&self) -> &'static str {
75 match self {
76 Self::Identity => "identity",
77 #[cfg(feature = "gzip")]
78 Self::Gzip => "gzip",
79 #[cfg(feature = "deflate")]
80 Self::Deflate => "deflate",
81 #[cfg(feature = "zstd")]
82 Self::Zstd => "zstd",
83 }
84 }
85
86 pub fn accepted_encodings() -> &'static str {
90 static LIST: std::sync::OnceLock<String> = std::sync::OnceLock::new();
91 LIST.get_or_init(|| {
92 Self::ALL
93 .iter()
94 .map(|e| e.as_grpc_encoding())
95 .collect::<Vec<_>>()
96 .join(",")
97 })
98 }
99
100 pub fn compress(&self, data: &[u8]) -> Result<Vec<u8>, Status> {
102 match self {
103 Self::Identity => Ok(data.to_vec()),
104 #[cfg(feature = "gzip")]
105 Self::Gzip => gzip_compress(data),
106 #[cfg(feature = "deflate")]
107 Self::Deflate => deflate_compress(data),
108 #[cfg(feature = "zstd")]
109 Self::Zstd => zstd_compress(data),
110 }
111 }
112
113 pub fn decompress(&self, data: &[u8], max_size: usize) -> Result<Vec<u8>, Status> {
117 match self {
118 Self::Identity => {
119 if data.len() > max_size {
120 return Err(oversize(max_size));
121 }
122 Ok(data.to_vec())
123 }
124 #[cfg(feature = "gzip")]
125 Self::Gzip => gzip_decompress(data, max_size),
126 #[cfg(feature = "deflate")]
127 Self::Deflate => deflate_decompress(data, max_size),
128 #[cfg(feature = "zstd")]
129 Self::Zstd => zstd_decompress(data, max_size),
130 }
131 }
132}
133
134fn oversize(max_size: usize) -> Status {
135 Status::resource_exhausted(format!(
136 "decompressed message exceeds limit of {max_size} bytes"
137 ))
138}
139
/// Gzip-compresses `data` at `flate2`'s default compression level.
///
/// # Errors
/// Returns an `Internal` status prefixed with `gzip compress:` if writing to
/// or finishing the encoder fails.
#[cfg(feature = "gzip")]
fn gzip_compress(data: &[u8]) -> Result<Vec<u8>, Status> {
    use flate2::{Compression, write::GzEncoder};
    use std::io::Write;
    let mut encoder = GzEncoder::new(Vec::with_capacity(data.len()), Compression::default());
    // Both the write and the final flush share one error mapping.
    encoder
        .write_all(data)
        .and_then(|()| encoder.finish())
        .map_err(|e| Status::internal(format!("gzip compress: {e}")))
}
150
/// Gunzips `data`, producing at most `max_size` bytes of output.
#[cfg(feature = "gzip")]
fn gzip_decompress(data: &[u8], max_size: usize) -> Result<Vec<u8>, Status> {
    let decoder = flate2::read::GzDecoder::new(data);
    read_capped(decoder, max_size, "gzip decompress")
}
156
/// Compresses `data` for the gRPC `deflate` encoding.
///
/// gRPC's `deflate` is the HTTP content-coding of the same name: a DEFLATE
/// stream (RFC 1951) inside the zlib container (RFC 1950), with its 2-byte
/// header and Adler-32 trailer. A raw, headerless DEFLATE stream is not
/// understood by peers that follow that definition, so the zlib encoder is
/// used here.
///
/// # Errors
/// Returns an `Internal` status prefixed with `deflate compress:` if writing
/// to or finishing the encoder fails.
#[cfg(feature = "deflate")]
fn deflate_compress(data: &[u8]) -> Result<Vec<u8>, Status> {
    use flate2::{Compression, write::ZlibEncoder};
    use std::io::Write;
    let mut enc = ZlibEncoder::new(Vec::with_capacity(data.len()), Compression::default());
    enc.write_all(data)
        .map_err(|e| Status::internal(format!("deflate compress: {e}")))?;
    enc.finish()
        .map_err(|e| Status::internal(format!("deflate compress: {e}")))
}

/// Decompresses a gRPC `deflate` payload (zlib format, RFC 1950), producing
/// at most `max_size` bytes of output.
///
/// Must stay in sync with `deflate_compress`: both sides use the zlib container.
#[cfg(feature = "deflate")]
fn deflate_decompress(data: &[u8], max_size: usize) -> Result<Vec<u8>, Status> {
    use flate2::read::ZlibDecoder;
    read_capped(ZlibDecoder::new(data), max_size, "deflate decompress")
}
173
/// Compresses `data` with Zstandard.
///
/// # Errors
/// Returns an `Internal` status prefixed with `zstd compress:` on encoder failure.
#[cfg(feature = "zstd")]
fn zstd_compress(data: &[u8]) -> Result<Vec<u8>, Status> {
    // Level 0 asks the zstd library for its default compression level.
    const LEVEL: i32 = 0;
    let encoded = zstd::stream::encode_all(data, LEVEL);
    encoded.map_err(|e| Status::internal(format!("zstd compress: {e}")))
}
178
/// Decompresses a Zstandard payload, producing at most `max_size` bytes.
///
/// # Errors
/// `Internal` (prefixed with `zstd decompress:`) if the decoder cannot be
/// created or fails while reading. Otherwise, errors from `read_capped`.
#[cfg(feature = "zstd")]
fn zstd_decompress(data: &[u8], max_size: usize) -> Result<Vec<u8>, Status> {
    zstd::stream::Decoder::new(data)
        .map_err(|e| Status::internal(format!("zstd decompress: {e}")))
        .and_then(|decoder| read_capped(decoder, max_size, "zstd decompress"))
}
185
/// Drains the decoder `r` into a new buffer, refusing to produce more than
/// `max_size` bytes.
///
/// At most `max_size + 1` bytes are pulled from `r`. Reading a single byte
/// past the limit is enough to detect an oversized payload without ever
/// buffering the whole decoded stream, which protects against small,
/// highly-compressible inputs that expand enormously.
///
/// # Errors
/// * `ResourceExhausted` (via `oversize`) if the output exceeds `max_size`.
/// * `Internal`, prefixed with `ctx`, if `r` returns an I/O error
///   (e.g. corrupt compressed input).
#[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
fn read_capped<R: std::io::Read>(r: R, max_size: usize, ctx: &str) -> Result<Vec<u8>, Status> {
    use std::io::Read;
    // `usize -> u64` is lossless on all supported targets. `saturating_add`
    // guards `max_size == usize::MAX` on 64-bit targets: a plain `+ 1` would
    // panic in debug builds and wrap to a read limit of 0 in release builds,
    // silently returning an empty message as success.
    let limit = (max_size as u64).saturating_add(1);
    let mut out = Vec::new();
    r.take(limit)
        .read_to_end(&mut out)
        .map_err(|e| Status::internal(format!("{ctx}: {e}")))?;
    if out.len() > max_size {
        return Err(oversize(max_size));
    }
    Ok(out)
}
200
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn identity_roundtrip() {
        let data = b"hello world";
        let compressed = Encoding::Identity.compress(data).unwrap();
        let decompressed = Encoding::Identity
            .decompress(&compressed, DEFAULT_MAX_MESSAGE_SIZE)
            .unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    fn identity_decompress_respects_max_size() {
        let data = vec![0u8; 100];
        let err = Encoding::Identity.decompress(&data, 50).unwrap_err();
        assert_eq!(err.code, crate::Code::ResourceExhausted);
    }

    // The limit is inclusive: a payload of exactly `max_size` bytes is accepted.
    #[test]
    fn identity_decompress_allows_exact_max_size() {
        let data = vec![0u8; 50];
        let decompressed = Encoding::Identity.decompress(&data, 50).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    fn from_grpc_encoding_identity_always_recognized() {
        assert_eq!(
            Encoding::from_grpc_encoding("identity"),
            Some(Encoding::Identity)
        );
    }

    #[test]
    fn from_grpc_encoding_unknown_returns_none() {
        assert!(Encoding::from_grpc_encoding("snappy").is_none());
        assert!(Encoding::from_grpc_encoding("").is_none());
        // Matching is case-sensitive.
        assert!(Encoding::from_grpc_encoding("GZIP").is_none());
    }

    // Parsing and serializing must be inverses for every compiled-in encoding.
    #[test]
    fn grpc_encoding_names_roundtrip_for_all() {
        for &encoding in Encoding::ALL {
            assert_eq!(
                Encoding::from_grpc_encoding(encoding.as_grpc_encoding()),
                Some(encoding)
            );
        }
    }

    #[test]
    fn accepted_encodings_starts_with_identity() {
        assert!(Encoding::accepted_encodings().starts_with("identity"));
    }

    #[test]
    fn accepted_encodings_lists_all_in_order() {
        let expected: Vec<&str> = Encoding::ALL
            .iter()
            .map(|e| e.as_grpc_encoding())
            .collect();
        assert_eq!(Encoding::accepted_encodings(), expected.join(","));
    }

    #[cfg(feature = "gzip")]
    mod gzip {
        use super::*;

        #[test]
        fn parse_and_serialize() {
            assert_eq!(Encoding::from_grpc_encoding("gzip"), Some(Encoding::Gzip));
            assert_eq!(Encoding::Gzip.as_grpc_encoding(), "gzip");
            assert!(Encoding::accepted_encodings().contains("gzip"));
        }

        #[test]
        fn roundtrip() {
            let data = b"hello, gzip-compressed world! ".repeat(100);
            let compressed = Encoding::Gzip.compress(&data).unwrap();
            assert!(compressed.len() < data.len(), "compression had effect");
            let decompressed = Encoding::Gzip
                .decompress(&compressed, DEFAULT_MAX_MESSAGE_SIZE)
                .unwrap();
            assert_eq!(decompressed, data);
        }

        #[test]
        fn decompress_respects_max_size() {
            let data = vec![b'a'; 100 * 1024];
            let compressed = Encoding::Gzip.compress(&data).unwrap();
            let err = Encoding::Gzip.decompress(&compressed, 1024).unwrap_err();
            assert_eq!(err.code, crate::Code::ResourceExhausted);
        }

        #[test]
        fn decompress_allows_exact_max_size() {
            let data = vec![b'a'; 1024];
            let compressed = Encoding::Gzip.compress(&data).unwrap();
            let decompressed = Encoding::Gzip.decompress(&compressed, data.len()).unwrap();
            assert_eq!(decompressed, data);
        }
    }

    #[cfg(feature = "deflate")]
    mod deflate {
        use super::*;

        #[test]
        fn parse_and_serialize() {
            assert_eq!(
                Encoding::from_grpc_encoding("deflate"),
                Some(Encoding::Deflate)
            );
            assert_eq!(Encoding::Deflate.as_grpc_encoding(), "deflate");
            assert!(Encoding::accepted_encodings().contains("deflate"));
        }

        #[test]
        fn roundtrip() {
            let data = b"hello, deflate-compressed world! ".repeat(100);
            let compressed = Encoding::Deflate.compress(&data).unwrap();
            assert!(compressed.len() < data.len(), "compression had effect");
            let decompressed = Encoding::Deflate
                .decompress(&compressed, DEFAULT_MAX_MESSAGE_SIZE)
                .unwrap();
            assert_eq!(decompressed, data);
        }

        #[test]
        fn decompress_respects_max_size() {
            let data = vec![b'a'; 100 * 1024];
            let compressed = Encoding::Deflate.compress(&data).unwrap();
            let err = Encoding::Deflate.decompress(&compressed, 1024).unwrap_err();
            assert_eq!(err.code, crate::Code::ResourceExhausted);
        }
    }

    #[cfg(feature = "zstd")]
    mod zstd {
        use super::*;

        #[test]
        fn parse_and_serialize() {
            assert_eq!(Encoding::from_grpc_encoding("zstd"), Some(Encoding::Zstd));
            assert_eq!(Encoding::Zstd.as_grpc_encoding(), "zstd");
            assert!(Encoding::accepted_encodings().contains("zstd"));
        }

        #[test]
        fn roundtrip() {
            let data = b"hello, zstd-compressed world! ".repeat(100);
            let compressed = Encoding::Zstd.compress(&data).unwrap();
            assert!(compressed.len() < data.len(), "compression had effect");
            let decompressed = Encoding::Zstd
                .decompress(&compressed, DEFAULT_MAX_MESSAGE_SIZE)
                .unwrap();
            assert_eq!(decompressed, data);
        }

        #[test]
        fn decompress_respects_max_size() {
            let data = vec![b'a'; 100 * 1024];
            let compressed = Encoding::Zstd.compress(&data).unwrap();
            let err = Encoding::Zstd.decompress(&compressed, 1024).unwrap_err();
            assert_eq!(err.code, crate::Code::ResourceExhausted);
        }
    }
}