walker_common/compression/
detecting.rs

1use bytes::Bytes;
2use std::collections::HashSet;
3
/// Errors reported while detecting a compression format or decompressing a payload.
///
/// The lifetime `'a` is that of the string slice carried by [`Error::Unsupported`].
#[derive(Debug, thiserror::Error)]
pub enum Error<'a> {
    /// The compression is not supported.
    ///
    /// Carries the offending identifier; [`Detector::detect`] puts the unknown file name
    /// extension here.
    #[error("unsupported compression: {0}")]
    Unsupported(&'a str),
    /// An I/O error, reported transparently and convertible from [`std::io::Error`].
    #[error(transparent)]
    Io(#[from] std::io::Error),
}
11
/// The compression formats known to this module.
///
/// Every variant other than [`Compression::None`] only exists when the corresponding cargo
/// feature is enabled.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Compression {
    /// No compression.
    None,
    /// bzip2, available with either the `bzip2` or the `bzip2-rs` feature.
    #[cfg(any(feature = "bzip2", feature = "bzip2-rs"))]
    Bzip2,
    /// xz, available with the `liblzma` feature.
    #[cfg(feature = "liblzma")]
    Xz,
    /// gzip, available with the `flate2` feature.
    #[cfg(feature = "flate2")]
    Gzip,
}
22
/// Options controlling how a payload is decompressed.
///
/// The default value has a [`limit`](DecompressionOptions::limit) of zero, i.e. no limit.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq, Default)]
pub struct DecompressionOptions {
    /// The maximum decompressed payload size.
    ///
    /// If the size of the uncompressed payload exceeds this limit, an error is returned
    /// instead. Zero means unlimited.
    pub limit: usize,
}

impl DecompressionOptions {
    /// Create a new set of options, equal to [`DecompressionOptions::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the limit of the maximum uncompressed payload size.
    ///
    /// This consumes the options and returns the updated value, so calls can be chained.
    /// Zero means unlimited.
    // The receiver is taken by value: dropping the result silently loses the new limit.
    #[must_use = "this returns the updated options instead of modifying them in place"]
    pub fn limit(mut self, limit: usize) -> Self {
        self.limit = limit;
        self
    }
}
44
45impl Compression {
46    /// Perform decompression.
47    ///
48    /// Returns the original data for [`Compression::None`].
49    pub fn decompress(&self, data: Bytes) -> Result<Bytes, std::io::Error> {
50        Ok(self.decompress_opt(&data)?.unwrap_or(data))
51    }
52
53    /// Perform decompression.
54    ///
55    /// Returns the original data for [`Compression::None`].
56    pub fn decompress_with(
57        &self,
58        data: Bytes,
59        opts: &DecompressionOptions,
60    ) -> Result<Bytes, std::io::Error> {
61        Ok(self.decompress_opt_with(&data, opts)?.unwrap_or(data))
62    }
63
64    /// Perform decompression.
65    ///
66    /// Returns `None` for [`Compression::None`]
67    pub fn decompress_opt(&self, data: &[u8]) -> Result<Option<Bytes>, std::io::Error> {
68        self.decompress_opt_with(data, &Default::default())
69    }
70
71    /// Perform decompression.
72    ///
73    /// Returns `None` for [`Compression::None`]
74    pub fn decompress_opt_with(
75        &self,
76        #[allow(unused_variables)] data: &[u8],
77        #[allow(unused_variables)] opts: &DecompressionOptions,
78    ) -> Result<Option<Bytes>, std::io::Error> {
79        match self {
80            #[cfg(any(feature = "bzip2", feature = "bzip2-rs"))]
81            Compression::Bzip2 =>
82            {
83                #[allow(deprecated)]
84                super::decompress_bzip2_with(data, opts).map(Some)
85            }
86            #[cfg(feature = "liblzma")]
87            Compression::Xz =>
88            {
89                #[allow(deprecated)]
90                super::decompress_xz_with(data, opts).map(Some)
91            }
92            #[cfg(feature = "flate2")]
93            Compression::Gzip =>
94            {
95                #[allow(deprecated)]
96                super::decompress_gzip_with(data, opts).map(Some)
97            }
98            Compression::None => Ok(None),
99        }
100    }
101}
102
/// Configuration for detecting the compression format of a payload.
///
/// Detection considers the file name extension first and then, unless disabled, the leading
/// magic bytes of the data.
#[derive(Debug, Clone, Default)]
pub struct Detector<'a> {
    /// Name of the file the payload belongs to, if known. Used for detection by extension.
    pub file_name: Option<&'a str>,

    /// When `true`, detection by magic bytes is skipped.
    pub disable_magic: bool,

    /// File name extensions (the part after the last `.`) which are not reported as an error
    /// when [`Detector::fail_unknown_file_extension`] is set.
    pub ignore_file_extensions: HashSet<&'a str>,
    /// When `true`, a file name with an unknown extension is reported as an error.
    pub fail_unknown_file_extension: bool,
}
116
117impl<'a> Detector<'a> {
118    /// Detect and decompress in a single step.
119    pub fn decompress(&'a self, data: Bytes) -> Result<Bytes, Error<'a>> {
120        self.decompress_with(data, &Default::default())
121    }
122
123    /// Detect and decompress in a single step.
124    pub fn decompress_with(
125        &'a self,
126        data: Bytes,
127        opts: &DecompressionOptions,
128    ) -> Result<Bytes, Error<'a>> {
129        let compression = self.detect(&data)?;
130        Ok(compression.decompress_with(data, opts)?)
131    }
132
133    pub fn detect(&'a self, #[allow(unused)] data: &[u8]) -> Result<Compression, Error<'a>> {
134        // detect by file name extension
135
136        if let Some(file_name) = self.file_name {
137            #[cfg(any(feature = "bzip2", feature = "bzip2-rs"))]
138            if file_name.ends_with(".bz2") {
139                return Ok(Compression::Bzip2);
140            }
141            #[cfg(feature = "liblzma")]
142            if file_name.ends_with(".xz") {
143                return Ok(Compression::Xz);
144            }
145            #[cfg(feature = "flate2")]
146            if file_name.ends_with(".gz") {
147                return Ok(Compression::Gzip);
148            }
149            if self.fail_unknown_file_extension {
150                if let Some((_, ext)) = file_name.rsplit_once('.') {
151                    if !self.ignore_file_extensions.contains(ext) {
152                        return Err(Error::Unsupported(ext));
153                    }
154                }
155            }
156        }
157
158        // magic bytes
159
160        if !self.disable_magic {
161            #[cfg(any(feature = "bzip2", feature = "bzip2-rs"))]
162            if data.starts_with(b"BZh") {
163                return Ok(Compression::Bzip2);
164            }
165            #[cfg(feature = "liblzma")]
166            if data.starts_with(&[0xFD, 0x37, 0x7A, 0x58, 0x5A, 0x00]) {
167                return Ok(Compression::Xz);
168            }
169            #[cfg(feature = "flate2")]
170            if data.starts_with(&[0x1F, 0x8B, 0x08]) {
171                // NOTE: Byte #3 (0x08) is the compression format, which means "deflate" and is the
172                // only one supported right now. Having additional compression formats, we'd need
173                // to extend this check, or drop the 3rd byte.
174                return Ok(Compression::Gzip);
175            }
176        }
177
178        // done
179
180        Ok(Compression::None)
181    }
182}
183
#[cfg(test)]
mod test {
    use super::*;

    /// Run the detection based on the file name only: magic bytes are disabled and the payload
    /// is empty.
    fn detect_by_name(name: &str) -> Compression {
        let detector = Detector {
            file_name: Some(name),
            disable_magic: true,
            ..Default::default()
        };

        detector.detect(&[]).unwrap()
    }

    /// An extension which is not a compression format must result in "no compression".
    #[test]
    fn by_name_none() {
        assert_eq!(detect_by_name("foo.bar.json"), Compression::None);
    }

    /// `.bz2` must be detected as bzip2.
    #[cfg(any(feature = "bzip2", feature = "bzip2-rs"))]
    #[test]
    fn by_name_bzip2() {
        assert_eq!(detect_by_name("foo.bar.bz2"), Compression::Bzip2);
    }

    /// `.xz` must be detected as xz.
    #[cfg(feature = "liblzma")]
    #[test]
    fn by_name_xz() {
        assert_eq!(detect_by_name("foo.bar.xz"), Compression::Xz);
    }

    /// `.gz` must be detected as gzip.
    #[cfg(feature = "flate2")]
    #[test]
    fn by_name_gzip() {
        assert_eq!(detect_by_name("foo.bar.gz"), Compression::Gzip);
    }
}