Skip to main content

titan_api_codec/transform/
zstd.rs

1//! Defines transforms that utilize [zstd] to compress and decompress data.
2//!
3//! [zstd]: https://github.com/facebook/zstd
4
5use super::common::BinaryTransform;
6
7use bytes::{Buf, Bytes};
8use zstd::bulk::{compress, decompress, Compressor, Decompressor};
9use zstd::stream::decode_all;
10
11/// Transform that applies zstd compression to input.
12#[derive(Default)]
13pub struct ZstdCompressor {
14    level: i32,
15    inner: Compressor<'static>,
16}
17
18impl ZstdCompressor {
19    /// Creates a new compressor with the given compression level.
20    pub fn new(level: i32) -> std::io::Result<Self> {
21        let mut inner = Compressor::default();
22        if level != 0 {
23            inner.set_dictionary(level, &[])?;
24        }
25
26        Ok(Self { level, inner })
27    }
28}
29
30impl BinaryTransform for ZstdCompressor {
31    fn transform(&self, data: Bytes) -> Result<Bytes, std::io::Error> {
32        compress(&data, self.level).map(Bytes::from)
33    }
34
35    fn transform_mut(&mut self, data: Bytes) -> Result<Bytes, std::io::Error> {
36        self.inner.compress(&data).map(Bytes::from)
37    }
38}
39
40/// Transform that transforms zstd-compressed data back to its original content.
41#[derive(Default)]
42pub struct ZstdDecompressor {
43    inner: Decompressor<'static>,
44}
45
46impl ZstdCompressor {}
47
48impl BinaryTransform for ZstdDecompressor {
49    fn transform(&self, data: Bytes) -> Result<Bytes, std::io::Error> {
50        if let Ok(bound) = zstd_safe::decompress_bound(&data) {
51            decompress(&data, bound as usize).map(Bytes::from)
52        } else {
53            // Unable to determine size, fallback to stream decoding.
54            decode_all(data.reader()).map(Bytes::from)
55        }
56    }
57
58    fn transform_mut(&mut self, data: Bytes) -> Result<Bytes, std::io::Error> {
59        if let Ok(bound) = zstd_safe::decompress_bound(&data) {
60            self.inner
61                .decompress(&data, bound as usize)
62                .map(Bytes::from)
63        } else {
64            // Unable to determine size, fallback to stream decoding.
65            decode_all(data.reader()).map(Bytes::from)
66        }
67    }
68}
69
#[cfg(test)]
mod test {
    use super::{ZstdCompressor, ZstdDecompressor};
    use crate::transform::BinaryTransform;
    use bytes::Bytes;
    use lipsum::lipsum;

    /// Builds the 1000-word lorem-ipsum payload that every test compresses.
    fn sample_payload() -> Bytes {
        Bytes::from(lipsum(1000))
    }

    /// The `&self` entry points must reproduce the input after a round trip.
    #[test]
    fn test_roundtrip_default() {
        let (compressor, decompressor) = (ZstdCompressor::default(), ZstdDecompressor::default());
        let original = sample_payload();

        let frame = compressor
            .transform(original.clone())
            .expect("should compress via zstd");
        let restored = decompressor
            .transform(frame)
            .expect("should decompress from zstd");

        assert_eq!(original, restored);
    }

    /// The `&mut self` entry points, which use the reusable inner contexts, must also round-trip.
    #[test]
    fn test_roundtrip_mut_default() {
        let (mut compressor, mut decompressor) =
            (ZstdCompressor::default(), ZstdDecompressor::default());
        let original = sample_payload();

        let frame = compressor
            .transform_mut(original.clone())
            .expect("should compress via zstd");
        let restored = decompressor
            .transform_mut(frame)
            .expect("should decompress from zstd");

        assert_eq!(original, restored);
    }

    /// Frames produced by the compressor must let zstd report a decompressed-size bound equal to
    /// the original length.
    #[test]
    fn test_compressor_adds_size() {
        let mut compressor = ZstdCompressor::default();
        let original = sample_payload();
        let expected_len = original.len();

        let frame = compressor
            .transform_mut(original)
            .expect("should compress via zstd");
        let size_bound = zstd_safe::decompress_bound(&frame)
            .expect("should be able to determine decompress bound");

        assert_eq!(size_bound as usize, expected_len);
    }
}
125}