twilight_gateway/
inflater.rs

1//! Efficiently decompress Discord gateway messages.
2//!
3//! The [`Inflater`] decompresses messages sent over the gateway by reusing a
4//! common buffer to minimize the amount of allocations in the hot path.
5//!
//! A compressed message buffer is used to store incomplete messages and is,
//! if used, shrunk every minute to the size of the most recent completed
//! message.
9
10use flate2::{Decompress, FlushDecompress};
11use std::{
12    error::Error,
13    fmt::{Display, Formatter, Result as FmtResult},
14    time::Instant,
15};
16
/// An operation relating to compression failed.
///
/// Pairs a [`CompressionErrorType`] describing what went wrong with the
/// underlying source error, when one was captured.
#[derive(Debug)]
pub struct CompressionError {
    /// Type of error.
    kind: CompressionErrorType,
    /// Source error if available.
    ///
    /// Boxed as a trait object so that different underlying error types can
    /// be stored behind the same field.
    source: Option<Box<dyn Error + Send + Sync>>,
}
25
26impl CompressionError {
27    /// Immutable reference to the type of error that occurred.
28    #[must_use = "retrieving the type has no effect if left unused"]
29    pub const fn kind(&self) -> &CompressionErrorType {
30        &self.kind
31    }
32
33    /// Consume the error, returning the source error if there is any.
34    #[must_use = "consuming the error and retrieving the source has no effect if left unused"]
35    pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
36        self.source
37    }
38
39    /// Consume the error, returning the owned error type and the source error.
40    #[must_use = "consuming the error into its parts has no effect if left unused"]
41    pub fn into_parts(self) -> (CompressionErrorType, Option<Box<dyn Error + Send + Sync>>) {
42        (self.kind, None)
43    }
44}
45
46impl Display for CompressionError {
47    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
48        match self.kind {
49            CompressionErrorType::Decompressing => f.write_str("message could not be decompressed"),
50            CompressionErrorType::NotUtf8 => f.write_str("decompressed message is not UTF-8"),
51        }
52    }
53}
54
55impl Error for CompressionError {
56    fn source(&self) -> Option<&(dyn Error + 'static)> {
57        self.source
58            .as_ref()
59            .map(|source| &**source as &(dyn Error + 'static))
60    }
61}
62
/// Type of [`CompressionError`] that occurred.
///
/// The enum is `#[non_exhaustive]`, so matches written outside of this crate
/// need a wildcard arm.
#[derive(Debug)]
#[non_exhaustive]
pub enum CompressionErrorType {
    /// Decompressing a frame failed.
    Decompressing,
    /// Decompressed message is not UTF-8.
    NotUtf8,
}
72
/// Whether the message is incomplete.
///
/// A message is treated as complete only when its final four bytes are the
/// zlib suffix; anything shorter than four bytes is therefore incomplete.
fn is_incomplete_message(message: &[u8]) -> bool {
    /// The "magic number" deciding if a message is done or if another
    /// message needs to be read.
    ///
    /// The suffix is documented in the [Discord docs].
    ///
    /// [Discord docs]: https://discord.com/developers/docs/topics/gateway#transport-compression-transport-compression-example
    const ZLIB_SUFFIX: [u8; 4] = [0x00, 0x00, 0xff, 0xff];

    // `ends_with` is `false` for slices shorter than the suffix, which covers
    // the "fewer than four bytes" case without a separate length check.
    !message.ends_with(&ZLIB_SUFFIX)
}
85
/// Gateway event decompressor.
///
/// Each received compressed event gets inflated into a [`String`] whose input
/// and output size is recorded.
///
/// # Example
///
/// Calculate the percentage bytes saved:
/// ```
/// # use twilight_gateway::{Intents, Shard, ShardId};
/// # #[tokio::main] async fn main() {
/// # let shard = Shard::new(ShardId::ONE, String::new(), Intents::empty());
/// let inflater = shard.inflater();
/// let total_percentage_compressed =
///     inflater.processed() as f64 * 100.0 / inflater.produced() as f64;
/// let total_percentage_saved = 100.0 - total_percentage_compressed;
/// # }
/// ```
#[derive(Debug)]
pub struct Inflater {
    /// Common decompressed message buffer.
    ///
    /// Scratch space the decompressor writes into; the produced bytes are
    /// copied out after every decompression step.
    buffer: Box<[u8]>,
    /// Per event compressed message buffer.
    ///
    /// Accumulates the fragments of a message until the complete message has
    /// been received.
    compressed: Vec<u8>,
    /// Zlib decompressor with a dictionary of past data.
    decompress: Decompress,
    /// When the compression buffer was last shrunk.
    last_shrank: Instant,
}
115
116impl Inflater {
117    /// [`Self::buffer`]'s size.
118    const BUFFER_SIZE: usize = 32 * 1024;
119
120    /// Create a new inflator for a shard.
121    pub(crate) fn new() -> Self {
122        Self {
123            buffer: vec![0; Self::BUFFER_SIZE].into_boxed_slice(),
124            compressed: Vec::new(),
125            decompress: Decompress::new(true),
126            last_shrank: Instant::now(),
127        }
128    }
129
130    /// Clear the compressed buffer and periodically shrink its capacity.
131    fn clear(&mut self) {
132        if self.compressed.capacity() != 0 && self.last_shrank.elapsed().as_secs() > 60 {
133            self.compressed.shrink_to_fit();
134
135            tracing::trace!(
136                compressed.capacity = self.compressed.capacity(),
137                "shrank capacity to the size of the last message"
138            );
139
140            self.last_shrank = Instant::now();
141        }
142
143        self.compressed.clear();
144    }
145
146    /// Decompress message.
147    ///
148    /// Returns `None` if the message is incomplete, saving its content to be
149    /// combined with the next one.
150    ///
151    /// # Errors
152    ///
153    /// Returns a [`CompressionErrorType::Decompressing`] error type if the
154    /// message could not be decompressed.
155    ///
156    /// Returns a [`CompressionErrorType::NotUtf8`] error type if the
157    /// decompressed message is not UTF-8.
158    pub(crate) fn inflate(&mut self, message: &[u8]) -> Result<Option<String>, CompressionError> {
159        // Complete message. Tries to bypass the `self.compressed` buffer if the
160        // message is incomplete.
161        let message = if self.compressed.is_empty() {
162            if is_incomplete_message(message) {
163                tracing::trace!("received incomplete message");
164                self.compressed.extend_from_slice(message);
165                return Ok(None);
166            }
167            message
168        } else {
169            self.compressed.extend_from_slice(message);
170            if is_incomplete_message(&self.compressed) {
171                tracing::trace!("received incomplete message");
172                return Ok(None);
173            }
174            &self.compressed
175        };
176
177        let processed_pre = self.processed();
178
179        let mut processed = 0;
180
181        // Decompressed message. `Vec::extend_from_slice` efficiently allocates
182        // only what's necessary.
183        let mut decompressed = Vec::new();
184
185        loop {
186            let produced_pre = self.produced();
187
188            // Use Sync to ensure data is flushed to the buffer.
189            self.decompress
190                .decompress(
191                    &message[processed..],
192                    &mut self.buffer,
193                    FlushDecompress::Sync,
194                )
195                .map_err(|source| CompressionError {
196                    kind: CompressionErrorType::Decompressing,
197                    source: Some(Box::new(source)),
198                })?;
199
200            processed = (self.processed() - processed_pre).try_into().unwrap();
201            let produced = (self.produced() - produced_pre).try_into().unwrap();
202
203            decompressed.extend_from_slice(&self.buffer[..produced]);
204
205            // Break when message has been fully decompressed.
206            if processed == message.len() {
207                break;
208            }
209
210            tracing::trace!(bytes.compressed.remaining = message.len() - processed);
211        }
212
213        {
214            #[allow(clippy::cast_precision_loss)]
215            let total_percentage_compressed =
216                self.processed() as f64 * 100.0 / self.produced() as f64;
217            let total_percentage_saved = 100.0 - total_percentage_compressed;
218            let total_kib_saved = (self.produced() - self.processed()) / 1024;
219
220            tracing::trace!(
221                bytes.compressed = message.len(),
222                bytes.decompressed = decompressed.len(),
223                total_percentage_saved,
224                "{total_kib_saved} KiB saved in total",
225            );
226        }
227
228        self.clear();
229
230        String::from_utf8(decompressed)
231            .map(Some)
232            .map_err(|source| CompressionError {
233                kind: CompressionErrorType::NotUtf8,
234                source: Some(Box::new(source)),
235            })
236    }
237
238    /// Reset the inflater's state.
239    pub(crate) fn reset(&mut self) {
240        self.compressed = Vec::new();
241        self.decompress.reset(true);
242    }
243
244    /// Total number of bytes processed.
245    pub fn processed(&self) -> u64 {
246        self.decompress.total_in()
247    }
248
249    /// Total number of bytes produced.
250    pub fn produced(&self) -> u64 {
251        self.decompress.total_out()
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::Inflater;

    // Compressed payload used by every test; it ends with the zlib suffix
    // `0, 0, 255, 255` and is expected to inflate to `OUTPUT`.
    const MESSAGE: &[u8] = &[
        120, 156, 52, 201, 65, 10, 131, 48, 16, 5, 208, 187, 252, 117, 82, 98, 169, 32, 115, 21,
        35, 50, 53, 67, 27, 136, 81, 226, 216, 82, 66, 238, 222, 110, 186, 123, 240, 42, 20, 148,
        207, 148, 12, 142, 63, 182, 29, 212, 57, 131, 0, 170, 120, 10, 23, 189, 11, 235, 28, 179,
        74, 121, 113, 2, 221, 186, 107, 255, 251, 89, 11, 47, 2, 26, 49, 122, 60, 88, 229, 205, 31,
        187, 151, 96, 87, 142, 217, 14, 253, 16, 60, 76, 245, 88, 227, 82, 182, 195, 131, 220, 197,
        181, 9, 83, 107, 95, 0, 0, 0, 255, 255,
    ];
    // Text the tests expect `MESSAGE` to decompress to.
    const OUTPUT: &str = r#"{"t":null,"s":null,"op":10,"d":{"heartbeat_interval":41250,"_trace":["[\"gateway-prd-main-858d\",{\"micros\":0.0}]"]}}"#;

    // A complete message delivered in one piece decompresses immediately and
    // leaves the compressed buffer empty.
    #[test]
    fn decompress_single_segment() {
        let mut inflator = Inflater::new();
        assert!(inflator.compressed.is_empty());
        assert_eq!(inflator.inflate(MESSAGE).unwrap(), Some(OUTPUT.to_owned()));

        assert!(inflator.compressed.is_empty());
    }

    // A message split in half yields `None` for the first half (which gets
    // buffered) and the full output once the second half arrives.
    #[test]
    fn decompress_split_message() {
        let mut inflator = Inflater::new();
        assert!(inflator.compressed.is_empty());
        assert_eq!(
            inflator.inflate(&MESSAGE[0..MESSAGE.len() / 2]).unwrap(),
            None
        );
        assert!(!inflator.compressed.is_empty());

        assert_eq!(
            inflator.inflate(&MESSAGE[MESSAGE.len() / 2..]).unwrap(),
            Some(OUTPUT.to_owned()),
        );
        assert!(inflator.compressed.is_empty());
    }

    // Input that does not end with the full zlib suffix (an empty slice, or
    // the message missing its last two bytes) is reported as incomplete
    // rather than as an error.
    #[test]
    fn invalid_is_none() {
        let mut inflator = Inflater::new();
        assert_eq!(inflator.inflate(&[]).unwrap(), None);

        assert_eq!(
            inflator.inflate(&MESSAGE[..MESSAGE.len() - 2]).unwrap(),
            None
        );
    }

    // After buffering a truncated message, `reset` discards it so that a
    // fresh, complete message decompresses as if nothing had been received.
    #[test]
    fn reset() {
        let mut inflator = Inflater::new();
        assert_eq!(
            inflator.inflate(&MESSAGE[..MESSAGE.len() - 2]).unwrap(),
            None
        );

        inflator.reset();
        assert_eq!(inflator.inflate(MESSAGE).unwrap(), Some(OUTPUT.to_owned()));
    }
}