1use flate2::{Decompress, FlushDecompress};
11use std::{
12 error::Error,
13 fmt::{Display, Formatter, Result as FmtResult},
14 time::Instant,
15};
16
17#[derive(Debug)]
19pub struct CompressionError {
20 kind: CompressionErrorType,
22 source: Option<Box<dyn Error + Send + Sync>>,
24}
25
26impl CompressionError {
27 #[must_use = "retrieving the type has no effect if left unused"]
29 pub const fn kind(&self) -> &CompressionErrorType {
30 &self.kind
31 }
32
33 #[must_use = "consuming the error and retrieving the source has no effect if left unused"]
35 pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
36 self.source
37 }
38
39 #[must_use = "consuming the error into its parts has no effect if left unused"]
41 pub fn into_parts(self) -> (CompressionErrorType, Option<Box<dyn Error + Send + Sync>>) {
42 (self.kind, None)
43 }
44}
45
46impl Display for CompressionError {
47 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
48 match self.kind {
49 CompressionErrorType::Decompressing => f.write_str("message could not be decompressed"),
50 CompressionErrorType::NotUtf8 => f.write_str("decompressed message is not UTF-8"),
51 }
52 }
53}
54
55impl Error for CompressionError {
56 fn source(&self) -> Option<&(dyn Error + 'static)> {
57 self.source
58 .as_ref()
59 .map(|source| &**source as &(dyn Error + 'static))
60 }
61}
62
/// Type of [`CompressionError`] that occurred.
#[non_exhaustive]
#[derive(Debug)]
pub enum CompressionErrorType {
    /// The decompressor rejected the compressed message.
    Decompressing,
    /// The decompressed bytes are not valid UTF-8.
    NotUtf8,
}
72
/// Whether `message` does not (yet) end with the zlib sync-flush suffix.
///
/// A message in the zlib stream is treated as complete only once its bytes
/// end with `00 00 ff ff`. Input shorter than four bytes therefore always
/// counts as incomplete.
fn is_incomplete_message(message: &[u8]) -> bool {
    /// Trailing bytes of the empty stored block written by a zlib sync flush.
    const ZLIB_SUFFIX: [u8; 4] = [0x00, 0x00, 0xff, 0xff];

    // `ends_with` already returns false for slices shorter than the suffix.
    !message.ends_with(&ZLIB_SUFFIX)
}
85
86#[derive(Debug)]
105pub struct Inflater {
106 buffer: Box<[u8]>,
108 compressed: Vec<u8>,
110 decompress: Decompress,
112 last_shrank: Instant,
114}
115
116impl Inflater {
117 const BUFFER_SIZE: usize = 32 * 1024;
119
120 pub(crate) fn new() -> Self {
122 Self {
123 buffer: vec![0; Self::BUFFER_SIZE].into_boxed_slice(),
124 compressed: Vec::new(),
125 decompress: Decompress::new(true),
126 last_shrank: Instant::now(),
127 }
128 }
129
130 fn clear(&mut self) {
132 if self.compressed.capacity() != 0 && self.last_shrank.elapsed().as_secs() > 60 {
133 self.compressed.shrink_to_fit();
134
135 tracing::trace!(
136 compressed.capacity = self.compressed.capacity(),
137 "shrank capacity to the size of the last message"
138 );
139
140 self.last_shrank = Instant::now();
141 }
142
143 self.compressed.clear();
144 }
145
146 pub(crate) fn inflate(&mut self, message: &[u8]) -> Result<Option<String>, CompressionError> {
159 let message = if self.compressed.is_empty() {
162 if is_incomplete_message(message) {
163 tracing::trace!("received incomplete message");
164 self.compressed.extend_from_slice(message);
165 return Ok(None);
166 }
167 message
168 } else {
169 self.compressed.extend_from_slice(message);
170 if is_incomplete_message(&self.compressed) {
171 tracing::trace!("received incomplete message");
172 return Ok(None);
173 }
174 &self.compressed
175 };
176
177 let processed_pre = self.processed();
178
179 let mut processed = 0;
180
181 let mut decompressed = Vec::new();
184
185 loop {
186 let produced_pre = self.produced();
187
188 self.decompress
190 .decompress(
191 &message[processed..],
192 &mut self.buffer,
193 FlushDecompress::Sync,
194 )
195 .map_err(|source| CompressionError {
196 kind: CompressionErrorType::Decompressing,
197 source: Some(Box::new(source)),
198 })?;
199
200 processed = (self.processed() - processed_pre).try_into().unwrap();
201 let produced = (self.produced() - produced_pre).try_into().unwrap();
202
203 decompressed.extend_from_slice(&self.buffer[..produced]);
204
205 if processed == message.len() {
207 break;
208 }
209
210 tracing::trace!(bytes.compressed.remaining = message.len() - processed);
211 }
212
213 {
214 #[allow(clippy::cast_precision_loss)]
215 let total_percentage_compressed =
216 self.processed() as f64 * 100.0 / self.produced() as f64;
217 let total_percentage_saved = 100.0 - total_percentage_compressed;
218 let total_kib_saved = (self.produced() - self.processed()) / 1024;
219
220 tracing::trace!(
221 bytes.compressed = message.len(),
222 bytes.decompressed = decompressed.len(),
223 total_percentage_saved,
224 "{total_kib_saved} KiB saved in total",
225 );
226 }
227
228 self.clear();
229
230 String::from_utf8(decompressed)
231 .map(Some)
232 .map_err(|source| CompressionError {
233 kind: CompressionErrorType::NotUtf8,
234 source: Some(Box::new(source)),
235 })
236 }
237
238 pub(crate) fn reset(&mut self) {
240 self.compressed = Vec::new();
241 self.decompress.reset(true);
242 }
243
244 pub fn processed(&self) -> u64 {
246 self.decompress.total_in()
247 }
248
249 pub fn produced(&self) -> u64 {
251 self.decompress.total_out()
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::Inflater;

    /// A single compressed payload ending with the `00 00 ff ff` sync-flush suffix.
    const MESSAGE: &[u8] = &[
        120, 156, 52, 201, 65, 10, 131, 48, 16, 5, 208, 187, 252, 117, 82, 98, 169, 32, 115, 21,
        35, 50, 53, 67, 27, 136, 81, 226, 216, 82, 66, 238, 222, 110, 186, 123, 240, 42, 20, 148,
        207, 148, 12, 142, 63, 182, 29, 212, 57, 131, 0, 170, 120, 10, 23, 189, 11, 235, 28, 179,
        74, 121, 113, 2, 221, 186, 107, 255, 251, 89, 11, 47, 2, 26, 49, 122, 60, 88, 229, 205, 31,
        187, 151, 96, 87, 142, 217, 14, 253, 16, 60, 76, 245, 88, 227, 82, 182, 195, 131, 220, 197,
        181, 9, 83, 107, 95, 0, 0, 0, 255, 255,
    ];
    /// JSON that [`MESSAGE`] decompresses to.
    const OUTPUT: &str = r#"{"t":null,"s":null,"op":10,"d":{"heartbeat_interval":41250,"_trace":["[\"gateway-prd-main-858d\",{\"micros\":0.0}]"]}}"#;

    /// Result a successful inflate of [`MESSAGE`] is expected to yield.
    fn expected_output() -> Option<String> {
        Some(OUTPUT.to_owned())
    }

    #[test]
    fn decompress_single_segment() {
        let mut inflater = Inflater::new();
        assert!(inflater.compressed.is_empty());

        let inflated = inflater.inflate(MESSAGE).unwrap();
        assert_eq!(inflated, expected_output());

        // Complete messages must not linger in the buffer.
        assert!(inflater.compressed.is_empty());
    }

    #[test]
    fn decompress_split_message() {
        let mut inflater = Inflater::new();
        assert!(inflater.compressed.is_empty());

        let (head, tail) = MESSAGE.split_at(MESSAGE.len() / 2);

        // The first half lacks the suffix, so it is buffered.
        assert_eq!(inflater.inflate(head).unwrap(), None);
        assert!(!inflater.compressed.is_empty());

        // The second half completes the message and drains the buffer.
        assert_eq!(inflater.inflate(tail).unwrap(), expected_output());
        assert!(inflater.compressed.is_empty());
    }

    #[test]
    fn invalid_is_none() {
        let mut inflater = Inflater::new();
        let truncated = &MESSAGE[..MESSAGE.len() - 2];

        assert_eq!(inflater.inflate(&[]).unwrap(), None);
        assert_eq!(inflater.inflate(truncated).unwrap(), None);
    }

    #[test]
    fn reset() {
        let mut inflater = Inflater::new();
        let truncated = &MESSAGE[..MESSAGE.len() - 2];

        assert_eq!(inflater.inflate(truncated).unwrap(), None);

        // After a reset the partial data is gone and a fresh stream decodes.
        inflater.reset();
        assert_eq!(inflater.inflate(MESSAGE).unwrap(), expected_output());
    }
}