use flate2::{Decompress, FlushDecompress};
use std::{
error::Error,
fmt::{Display, Formatter, Result as FmtResult},
time::Instant,
};
/// Error returned when an incoming message could not be inflated.
#[derive(Debug)]
pub struct CompressionError {
    /// Which stage of inflation failed.
    kind: CompressionErrorType,
    /// Lower-level error that caused this one, if any.
    source: Option<Box<dyn Error + Send + Sync>>,
}

impl CompressionError {
    /// Immutable reference to the type of error that occurred.
    #[must_use = "retrieving the type has no effect if left unused"]
    pub const fn kind(&self) -> &CompressionErrorType {
        &self.kind
    }

    /// Consume the error, returning the source error if there is any.
    #[must_use = "consuming the error and retrieving the source has no effect if left unused"]
    pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
        self.source
    }

    /// Consume the error, returning the owned error type and the source error.
    ///
    /// The source is returned exactly as [`into_source`](Self::into_source)
    /// would return it; it is not discarded.
    #[must_use = "consuming the error into its parts has no effect if left unused"]
    pub fn into_parts(self) -> (CompressionErrorType, Option<Box<dyn Error + Send + Sync>>) {
        (self.kind, self.source)
    }
}

impl Display for CompressionError {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        let description = match self.kind {
            CompressionErrorType::Decompressing => "message could not be decompressed",
            CompressionErrorType::NotUtf8 => "decompressed message is not UTF-8",
        };

        f.write_str(description)
    }
}

impl Error for CompressionError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        // Dropping the `Send + Sync` auto traits is a plain unsizing coercion.
        self.source
            .as_deref()
            .map(|source| source as &(dyn Error + 'static))
    }
}

/// Type of [`CompressionError`] that occurred.
#[derive(Debug)]
#[non_exhaustive]
pub enum CompressionErrorType {
    /// The message could not be decompressed.
    Decompressing,
    /// The decompressed bytes are not valid UTF-8.
    NotUtf8,
}
/// Report whether `message` is not yet a complete unit of the zlib stream.
///
/// A message counts as complete only when its last four bytes are the zlib
/// suffix `00 00 ff ff`. Anything shorter than four bytes is incomplete.
fn is_incomplete_message(message: &[u8]) -> bool {
    const ZLIB_SUFFIX: [u8; 4] = [0x00, 0x00, 0xff, 0xff];

    // `ends_with` is false for inputs shorter than the suffix, which covers the
    // short-message case without a separate length check.
    !message.ends_with(&ZLIB_SUFFIX)
}
/// Incremental zlib-stream decompressor for incoming messages.
///
/// Fragments are accumulated until a complete message is available, and a
/// single zlib stream state is carried across successive messages.
#[derive(Debug)]
pub struct Inflater {
    /// Fixed-size scratch space each decompression call writes into.
    buffer: Box<[u8]>,
    /// Compressed bytes of a message whose end has not been received yet.
    compressed: Vec<u8>,
    /// zlib stream state; only restarted by `reset`.
    decompress: Decompress,
    /// When the capacity of `compressed` was last shrunk.
    last_shrank: Instant,
}
impl Inflater {
    /// Size in bytes of the scratch buffer handed to each decompression call.
    const BUFFER_SIZE: usize = 32 * 1024;

    /// Create an inflater with an empty message buffer and a fresh zlib stream
    /// that expects a zlib header.
    pub(crate) fn new() -> Self {
        Self {
            buffer: vec![0; Self::BUFFER_SIZE].into_boxed_slice(),
            compressed: Vec::new(),
            decompress: Decompress::new(true),
            last_shrank: Instant::now(),
        }
    }

    /// Empty the buffer of accumulated compressed bytes.
    ///
    /// If more than 60 seconds have passed since the last shrink, the buffer's
    /// capacity is first shrunk to the length of the message it currently holds.
    /// This releases memory left over from an unusually large message.
    fn clear(&mut self) {
        if self.compressed.capacity() != 0 && self.last_shrank.elapsed().as_secs() > 60 {
            self.compressed.shrink_to_fit();

            tracing::trace!(
                compressed.capacity = self.compressed.capacity(),
                "shrank capacity to the size of the last message"
            );

            self.last_shrank = Instant::now();
        }

        self.compressed.clear();
    }

    /// Feed one fragment of the zlib stream and return the decompressed message
    /// once it is complete.
    ///
    /// Fragments are buffered until the accumulated bytes end with the zlib
    /// suffix `00 00 ff ff`. Until then `Ok(None)` is returned.
    ///
    /// # Errors
    ///
    /// Returns a [`CompressionErrorType::Decompressing`] error in two cases:
    /// zlib rejects the data, or zlib stops making progress before consuming the
    /// whole message.
    ///
    /// Returns a [`CompressionErrorType::NotUtf8`] error if the decompressed
    /// bytes are not valid UTF-8.
    ///
    /// On a decompression error the buffered bytes are not cleared. The caller
    /// should presumably [`reset`](Self::reset) before reuse.
    pub(crate) fn inflate(&mut self, message: &[u8]) -> Result<Option<String>, CompressionError> {
        // A complete message with nothing buffered is decompressed straight from
        // the caller's slice, avoiding a copy into `compressed`.
        let message = if self.compressed.is_empty() {
            if is_incomplete_message(message) {
                tracing::trace!("received incomplete message");
                self.compressed.extend_from_slice(message);

                return Ok(None);
            }

            message
        } else {
            self.compressed.extend_from_slice(message);

            if is_incomplete_message(&self.compressed) {
                tracing::trace!("received incomplete message");

                return Ok(None);
            }

            &self.compressed
        };

        let processed_pre = self.processed();
        let mut processed = 0;
        let mut decompressed = Vec::new();

        loop {
            let produced_pre = self.produced();
            let processed_before_call = processed;

            // `Sync` makes zlib flush all output it can for the input given.
            self.decompress
                .decompress(
                    &message[processed..],
                    &mut self.buffer,
                    FlushDecompress::Sync,
                )
                .map_err(|source| CompressionError {
                    kind: CompressionErrorType::Decompressing,
                    source: Some(Box::new(source)),
                })?;

            processed = usize::try_from(self.processed() - processed_pre)
                .expect("bytes consumed are bounded by the message length");
            let produced = usize::try_from(self.produced() - produced_pre)
                .expect("bytes produced are bounded by the buffer size");

            decompressed.extend_from_slice(&self.buffer[..produced]);

            // The message has been fully decompressed.
            if processed == message.len() {
                break;
            }

            // A call that neither consumed input nor produced output would do the
            // same on every retry (for example once the stream has ended). Bail out
            // instead of spinning forever.
            if processed == processed_before_call && produced == 0 {
                return Err(CompressionError {
                    kind: CompressionErrorType::Decompressing,
                    source: Some("decompressor made no progress on the remaining input".into()),
                });
            }

            tracing::trace!(bytes.compressed.remaining = message.len() - processed);
        }

        {
            let total_in = self.processed();
            let total_out = self.produced();

            // Precision loss is irrelevant for a trace-level percentage.
            #[allow(clippy::cast_precision_loss)]
            let total_percentage_compressed = total_in as f64 * 100.0 / total_out as f64;
            let total_percentage_saved = 100.0 - total_percentage_compressed;
            // More bytes can go in than come out (early in the stream, or for
            // poorly compressible data). A plain `-` would overflow there.
            let total_kib_saved = total_out.saturating_sub(total_in) / 1024;

            tracing::trace!(
                bytes.compressed = message.len(),
                bytes.decompressed = decompressed.len(),
                total_percentage_saved,
                "{total_kib_saved} KiB saved in total",
            );
        }

        self.clear();

        String::from_utf8(decompressed)
            .map(Some)
            .map_err(|source| CompressionError {
                kind: CompressionErrorType::NotUtf8,
                source: Some(Box::new(source)),
            })
    }

    /// Discard any buffered partial message (releasing its allocation) and
    /// restart the zlib stream.
    pub(crate) fn reset(&mut self) {
        self.compressed = Vec::new();
        self.decompress.reset(true);
    }

    /// Total compressed bytes consumed by the zlib stream, as reported by
    /// `Decompress::total_in`.
    pub fn processed(&self) -> u64 {
        self.decompress.total_in()
    }

    /// Total decompressed bytes produced by the zlib stream, as reported by
    /// `Decompress::total_out`.
    pub fn produced(&self) -> u64 {
        self.decompress.total_out()
    }
}
#[cfg(test)]
mod tests {
    use super::Inflater;

    /// A zlib-compressed payload (decompressing to `OUTPUT`) that ends with the
    /// `00 00 ff ff` suffix.
    const MESSAGE: &[u8] = &[
        120, 156, 52, 201, 65, 10, 131, 48, 16, 5, 208, 187, 252, 117, 82, 98, 169, 32, 115, 21,
        35, 50, 53, 67, 27, 136, 81, 226, 216, 82, 66, 238, 222, 110, 186, 123, 240, 42, 20, 148,
        207, 148, 12, 142, 63, 182, 29, 212, 57, 131, 0, 170, 120, 10, 23, 189, 11, 235, 28, 179,
        74, 121, 113, 2, 221, 186, 107, 255, 251, 89, 11, 47, 2, 26, 49, 122, 60, 88, 229, 205, 31,
        187, 151, 96, 87, 142, 217, 14, 253, 16, 60, 76, 245, 88, 227, 82, 182, 195, 131, 220, 197,
        181, 9, 83, 107, 95, 0, 0, 0, 255, 255,
    ];

    /// The decompressed text of `MESSAGE`.
    const OUTPUT: &str = r#"{"t":null,"s":null,"op":10,"d":{"heartbeat_interval":41250,"_trace":["[\"gateway-prd-main-858d\",{\"micros\":0.0}]"]}}"#;

    /// The value `inflate` is expected to return for a complete `MESSAGE`.
    fn expected() -> Option<String> {
        Some(OUTPUT.to_owned())
    }

    #[test]
    fn decompress_single_segment() {
        let mut inflater = Inflater::new();
        assert!(inflater.compressed.is_empty());

        let inflated = inflater.inflate(MESSAGE).unwrap();
        assert_eq!(inflated, expected());
        assert!(inflater.compressed.is_empty());
    }

    #[test]
    fn decompress_split_message() {
        let mut inflater = Inflater::new();
        let (head, tail) = MESSAGE.split_at(MESSAGE.len() / 2);
        assert!(inflater.compressed.is_empty());

        // The first half lacks the suffix, so it is buffered.
        assert_eq!(inflater.inflate(head).unwrap(), None);
        assert!(!inflater.compressed.is_empty());

        // The second half completes the message and drains the buffer.
        assert_eq!(inflater.inflate(tail).unwrap(), expected());
        assert!(inflater.compressed.is_empty());
    }

    #[test]
    fn invalid_is_none() {
        let mut inflater = Inflater::new();
        let truncated = &MESSAGE[..MESSAGE.len() - 2];

        assert_eq!(inflater.inflate(&[]).unwrap(), None);
        assert_eq!(inflater.inflate(truncated).unwrap(), None);
    }

    #[test]
    fn reset() {
        let mut inflater = Inflater::new();
        let truncated = &MESSAGE[..MESSAGE.len() - 2];

        assert_eq!(inflater.inflate(truncated).unwrap(), None);

        // After a reset the stale partial message must not leak into the next one.
        inflater.reset();
        assert_eq!(inflater.inflate(MESSAGE).unwrap(), expected());
    }
}