1use std::io;
4
5#[cfg(feature = "compress")]
6use bytes::BufMut;
7use bytes::{Buf, BytesMut};
8#[cfg(feature = "compress")]
9pub use flate2::Compression as Level;
10#[cfg(feature = "gzip")]
11use flate2::bufread::{GzDecoder, GzEncoder};
12#[cfg(feature = "zlib")]
13use flate2::bufread::{ZlibDecoder, ZlibEncoder};
14use http::HeaderValue;
15
16use super::BUFFER_SIZE;
17#[cfg(feature = "compress")]
18use crate::Status;
19
/// Name of the header that `CompressionEncoding::from_encoding_header` reads to
/// learn how an incoming message was compressed.
pub const ENCODING_HEADER: &str = "grpc-encoding";
/// Name of the header that `CompressionEncoding::from_accept_encoding_header`
/// reads to learn which encodings the peer accepts.
pub const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding";
/// Compression level used by the `Default` impls of the per-codec configs and
/// returned by `CompressionEncoding::level` when no explicit config is present.
#[cfg(feature = "compress")]
const DEFAULT_LEVEL: Level = Level::new(6);
24
/// The compression encodings this module knows about.
///
/// Every codec variant is only compiled in when its cargo feature is enabled
/// and carries an optional, codec-specific configuration.
#[derive(Clone, Copy, Debug)]
pub enum CompressionEncoding {
    /// No compression.
    Identity,
    /// Gzip, with an optional configuration.
    #[cfg(feature = "gzip")]
    Gzip(Option<GzipConfig>),
    /// Zlib, with an optional configuration.
    #[cfg(feature = "zlib")]
    Zlib(Option<ZlibConfig>),
    /// Zstd, with an optional configuration.
    #[cfg(feature = "zstd")]
    Zstd(Option<ZstdConfig>),
}

/// Two encodings are equal when they are the same variant; the optional
/// configuration carried by a variant does not take part in the comparison.
impl PartialEq for CompressionEncoding {
    fn eq(&self, other: &Self) -> bool {
        // Comparing discriminants looks at the variant only and ignores payloads.
        std::mem::discriminant(self) == std::mem::discriminant(other)
    }
}
52
/// Configuration for the gzip codec.
#[cfg(feature = "gzip")]
#[derive(Clone, Copy, Debug)]
pub struct GzipConfig {
    /// Compression level handed to the gzip encoder.
    pub level: Level,
}

#[cfg(feature = "gzip")]
impl Default for GzipConfig {
    /// Builds a config that uses `DEFAULT_LEVEL`.
    fn default() -> Self {
        let level = DEFAULT_LEVEL;
        Self { level }
    }
}
67
/// Configuration for the zlib codec.
#[cfg(feature = "zlib")]
#[derive(Clone, Copy, Debug)]
pub struct ZlibConfig {
    /// Compression level handed to the zlib encoder.
    pub level: Level,
}

#[cfg(feature = "zlib")]
impl Default for ZlibConfig {
    /// Builds a config that uses `DEFAULT_LEVEL`.
    fn default() -> Self {
        let level = DEFAULT_LEVEL;
        Self { level }
    }
}
82
/// Configuration for the zstd codec.
#[cfg(feature = "zstd")]
#[derive(Clone, Copy, Debug)]
pub struct ZstdConfig {
    /// Requested compression level; `compress` substitutes zstd's own default
    /// level when this is `0`.
    pub level: Level,
}

#[cfg(feature = "zstd")]
impl Default for ZstdConfig {
    /// Builds a config that uses `DEFAULT_LEVEL`.
    fn default() -> Self {
        let level = DEFAULT_LEVEL;
        Self { level }
    }
}
97
98pub fn compose_encodings(encodings: &[CompressionEncoding]) -> HeaderValue {
100 let encodings = encodings
101 .iter()
102 .map(|item| match item {
103 #[cfg(feature = "gzip")]
105 CompressionEncoding::Gzip(_) => "gzip",
106 #[cfg(feature = "zlib")]
107 CompressionEncoding::Zlib(_) => "zlib",
108 #[cfg(feature = "zstd")]
109 CompressionEncoding::Zstd(_) => "zstd",
110 CompressionEncoding::Identity => "identity",
111 })
112 .collect::<Vec<&'static str>>();
113 HeaderValue::from_str(encodings.join(",").as_str()).unwrap()
116}
117
/// Returns `true` when `encodings` contains an entry equal to `encoding`,
/// as decided by `CompressionEncoding`'s `PartialEq` impl.
#[cfg(feature = "compress")]
fn is_enabled(encoding: CompressionEncoding, encodings: &[CompressionEncoding]) -> bool {
    let wanted = &encoding;
    encodings.contains(wanted)
}
122
123impl CompressionEncoding {
124 pub fn into_header_value(self) -> HeaderValue {
126 match self {
127 #[cfg(feature = "gzip")]
128 CompressionEncoding::Gzip(_) => HeaderValue::from_static("gzip"),
129 #[cfg(feature = "zlib")]
130 CompressionEncoding::Zlib(_) => HeaderValue::from_static("zlib"),
131 #[cfg(feature = "zstd")]
132 CompressionEncoding::Zstd(_) => HeaderValue::from_static("zstd"),
133 CompressionEncoding::Identity => HeaderValue::from_static("identity"),
134 }
135 }
136
137 pub fn into_accept_encoding_header_value(
140 self,
141 encodings: &[CompressionEncoding],
142 ) -> Option<HeaderValue> {
143 if self.is_enabled() {
144 Some(compose_encodings(encodings))
145 } else {
146 None
147 }
148 }
149
150 #[cfg(feature = "compress")]
152 pub fn from_accept_encoding_header(
153 headers: &http::HeaderMap,
154 config: &Option<Vec<Self>>,
155 ) -> Option<Self> {
156 if let Some(available_encodings) = config {
157 let header_value = headers.get(ACCEPT_ENCODING_HEADER)?;
158 let header_value_str = header_value.to_str().ok()?;
159
160 header_value_str
161 .split(',')
162 .map(|s| s.trim())
163 .find_map(|encoding| match encoding {
164 #[cfg(feature = "gzip")]
165 "gzip" => available_encodings.iter().find_map(|item| {
166 if item.is_gzip_enabled() {
167 Some(*item)
168 } else {
169 None
170 }
171 }),
172 #[cfg(feature = "zlib")]
173 "zlib" => available_encodings.iter().find_map(|item| {
174 if item.is_zlib_enabled() {
175 Some(*item)
176 } else {
177 None
178 }
179 }),
180 #[cfg(feature = "zstd")]
181 "zstd" => available_encodings.iter().find_map(|item| {
182 if item.is_zstd_enabled() {
183 Some(*item)
184 } else {
185 None
186 }
187 }),
188 _ => None,
189 })
190 } else {
191 None
192 }
193 }
194
195 #[allow(clippy::result_large_err)]
197 #[cfg(feature = "compress")]
198 pub fn from_encoding_header(
199 headers: &http::HeaderMap,
200 config: &Option<Vec<Self>>,
201 ) -> Result<Option<Self>, Status> {
202 if let Some(encodings) = config {
203 let header_value = if let Some(header_value) = headers.get(ENCODING_HEADER) {
204 header_value
205 } else {
206 return Ok(None);
207 };
208
209 match header_value.to_str()? {
210 #[cfg(feature = "gzip")]
211 "gzip" if is_enabled(Self::Gzip(None), encodings) => Ok(Some(Self::Gzip(None))),
212 #[cfg(feature = "zlib")]
213 "zlib" if is_enabled(Self::Zlib(None), encodings) => Ok(Some(Self::Zlib(None))),
214 #[cfg(feature = "zstd")]
215 "zstd" if is_enabled(Self::Zstd(None), encodings) => Ok(Some(Self::Zstd(None))),
216 "identity" => Ok(None),
217 other => {
218 let status = Status::unimplemented(format!(
219 "Content is compressed with `{other}` which isn't supported"
220 ));
221 Err(status)
222 }
223 }
224 } else {
225 Ok(None)
226 }
227 }
228
229 #[cfg(feature = "compress")]
232 pub fn level(self) -> Level {
233 match self {
234 #[cfg(feature = "gzip")]
235 CompressionEncoding::Gzip(Some(config)) => config.level,
236 #[cfg(feature = "zlib")]
237 CompressionEncoding::Zlib(Some(config)) => config.level,
238 #[cfg(feature = "zstd")]
239 CompressionEncoding::Zstd(Some(config)) => config.level,
240 _ => DEFAULT_LEVEL,
241 }
242 }
243
244 #[cfg(feature = "gzip")]
245 const fn is_gzip_enabled(&self) -> bool {
246 matches!(self, CompressionEncoding::Gzip(_))
247 }
248
249 #[cfg(feature = "zlib")]
250 const fn is_zlib_enabled(&self) -> bool {
251 matches!(self, CompressionEncoding::Zlib(_))
252 }
253
254 #[cfg(feature = "zstd")]
255 const fn is_zstd_enabled(&self) -> bool {
256 matches!(self, CompressionEncoding::Zstd(_))
257 }
258
259 const fn is_enabled(&self) -> bool {
260 #[allow(unreachable_patterns)]
261 match self {
262 #[cfg(feature = "gzip")]
263 CompressionEncoding::Gzip(_) => true,
264 #[cfg(feature = "zlib")]
265 CompressionEncoding::Zlib(_) => true,
266 #[cfg(feature = "zstd")]
267 CompressionEncoding::Zstd(_) => true,
268 _ => false,
269 }
270 }
271}
272
273pub(crate) fn compress(
275 encoding: CompressionEncoding,
276 src_buf: &mut BytesMut,
277 dest_buf: &mut BytesMut,
278) -> Result<(), io::Error> {
279 let len = src_buf.len();
280 let capacity = ((len / BUFFER_SIZE) + 1) * BUFFER_SIZE;
281
282 dest_buf.reserve(capacity);
283
284 match encoding {
285 #[cfg(feature = "gzip")]
286 CompressionEncoding::Gzip(Some(config)) => {
287 let mut gz_encoder = GzEncoder::new(&src_buf[0..len], config.level);
288 io::copy(&mut gz_encoder, &mut dest_buf.writer())?;
289 }
290 #[cfg(feature = "zlib")]
291 CompressionEncoding::Zlib(Some(config)) => {
292 let mut zlib_encoder = ZlibEncoder::new(&src_buf[0..len], config.level);
293 io::copy(&mut zlib_encoder, &mut dest_buf.writer())?;
294 }
295 #[cfg(feature = "zstd")]
296 CompressionEncoding::Zstd(Some(config)) => {
297 let level = config.level.level();
298 let zstd_level = if level == 0 {
299 zstd::DEFAULT_COMPRESSION_LEVEL
300 } else {
301 level as i32
302 };
303 let mut zstd_encoder = zstd::Encoder::new(dest_buf.writer(), zstd_level)?;
304 io::copy(&mut &src_buf[0..len], &mut zstd_encoder)?;
305 zstd_encoder.finish()?;
306 }
307 _ => {}
308 };
309
310 src_buf.advance(len);
311 Ok(())
312}
313
314pub(crate) fn decompress(
316 encoding: CompressionEncoding,
317 src_buf: &mut BytesMut,
318 dest_buf: &mut BytesMut,
319) -> Result<(), io::Error> {
320 let len = src_buf.len();
321 let estimate_decompressed_len = len * 2;
322 let capacity = ((estimate_decompressed_len / BUFFER_SIZE) + 1) * BUFFER_SIZE;
323
324 dest_buf.reserve(capacity);
325
326 match encoding {
327 #[cfg(feature = "gzip")]
328 CompressionEncoding::Gzip(_) => {
329 let mut gz_decoder = GzDecoder::new(&src_buf[0..len]);
330 io::copy(&mut gz_decoder, &mut dest_buf.writer())?;
331 }
332 #[cfg(feature = "zlib")]
333 CompressionEncoding::Zlib(_) => {
334 let mut zlib_decoder = ZlibDecoder::new(&src_buf[0..len]);
335 io::copy(&mut zlib_decoder, &mut dest_buf.writer())?;
336 }
337 #[cfg(feature = "zstd")]
338 CompressionEncoding::Zstd(_) => {
339 let mut zstd_decoder = zstd::Decoder::new(&src_buf[0..len])?;
340 io::copy(&mut zstd_decoder, &mut dest_buf.writer())?;
341 }
342 _ => {}
343 };
344
345 src_buf.advance(len);
346 Ok(())
347}
348
349#[cfg(test)]
350mod tests {
351 use bytes::BytesMut;
352
353 #[cfg(feature = "gzip")]
354 use crate::codec::compression::GzipConfig;
355 #[cfg(feature = "compress")]
356 use crate::codec::compression::Level;
357 #[cfg(feature = "zlib")]
358 use crate::codec::compression::ZlibConfig;
359 #[cfg(feature = "zstd")]
360 use crate::codec::compression::ZstdConfig;
361 use crate::codec::{
362 BUFFER_SIZE,
363 compression::{CompressionEncoding, compress, decompress},
364 };
365
366 #[test]
367 fn test_consistency_for_compression() {
368 let mut src = BytesMut::with_capacity(BUFFER_SIZE);
369 let mut compress_buf = BytesMut::new();
370 let mut de_data = BytesMut::with_capacity(BUFFER_SIZE);
371 let test_data = &b"test compression"[..];
372 src.extend_from_slice(test_data);
373
374 let encodings = [
375 #[cfg(feature = "gzip")]
376 CompressionEncoding::Gzip(Some(GzipConfig {
377 level: Level::fast(),
378 })),
379 #[cfg(feature = "zlib")]
380 CompressionEncoding::Zlib(Some(ZlibConfig {
381 level: Level::fast(),
382 })),
383 #[cfg(feature = "zstd")]
384 CompressionEncoding::Zstd(Some(ZstdConfig {
385 level: Level::new(3),
386 })),
387 CompressionEncoding::Identity,
388 ];
389
390 for encoding in encodings {
391 compress_buf.clear();
392 compress(encoding, &mut src, &mut compress_buf).expect("compress failed:");
393 decompress(encoding, &mut compress_buf, &mut de_data).expect("decompress failed:");
394 assert_eq!(test_data, de_data.as_ref());
395 }
396 }
397}