walker_common/compression/
detecting.rs

1use bytes::Bytes;
2use std::collections::HashSet;
3
4#[derive(Debug, thiserror::Error)]
5pub enum Error<'a> {
6    #[error("unsupported compression: {0}")]
7    Unsupported(&'a str),
8    #[error(transparent)]
9    Io(#[from] std::io::Error),
10}
11
/// The compression formats this module can deal with.
///
/// Apart from [`Compression::None`], the available variants depend on the enabled features.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Compression {
    /// The payload is not compressed.
    None,
    /// bzip2 compression; requires the `bzip2` or the `bzip2-rs` feature.
    #[cfg(any(feature = "bzip2", feature = "bzip2-rs"))]
    Bzip2,
    /// XZ compression; requires the `liblzma` feature.
    #[cfg(feature = "liblzma")]
    Xz,
}
20
/// Options controlling how a payload gets decompressed.
#[non_exhaustive]
#[derive(Debug, Default, Clone, Eq, PartialEq)]
pub struct DecompressionOptions {
    /// Upper bound for the size of the decompressed payload.
    ///
    /// When the uncompressed payload exceeds this limit, an error is returned instead of the
    /// data. A value of zero means "unlimited".
    pub limit: usize,
}

impl DecompressionOptions {
    /// Create a new set of options, equal to [`DecompressionOptions::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Set the maximum uncompressed payload size, consuming and returning the options so that
    /// calls can be chained.
    pub fn limit(self, limit: usize) -> Self {
        Self { limit, ..self }
    }
}
42
43impl Compression {
44    /// Perform decompression.
45    ///
46    /// Returns the original data for [`Compression::None`].
47    pub fn decompress(&self, data: Bytes) -> Result<Bytes, std::io::Error> {
48        Ok(self.decompress_opt(&data)?.unwrap_or(data))
49    }
50
51    /// Perform decompression.
52    ///
53    /// Returns the original data for [`Compression::None`].
54    pub fn decompress_with(
55        &self,
56        data: Bytes,
57        opts: &DecompressionOptions,
58    ) -> Result<Bytes, std::io::Error> {
59        Ok(self.decompress_opt_with(&data, opts)?.unwrap_or(data))
60    }
61
62    /// Perform decompression.
63    ///
64    /// Returns `None` for [`Compression::None`]
65    pub fn decompress_opt(&self, data: &[u8]) -> Result<Option<Bytes>, std::io::Error> {
66        self.decompress_opt_with(data, &Default::default())
67    }
68
69    /// Perform decompression.
70    ///
71    /// Returns `None` for [`Compression::None`]
72    pub fn decompress_opt_with(
73        &self,
74        #[allow(unused_variables)] data: &[u8],
75        #[allow(unused_variables)] opts: &DecompressionOptions,
76    ) -> Result<Option<Bytes>, std::io::Error> {
77        match self {
78            #[cfg(any(feature = "bzip2", feature = "bzip2-rs"))]
79            Compression::Bzip2 =>
80            {
81                #[allow(deprecated)]
82                super::decompress_bzip2_with(data, opts).map(Some)
83            }
84            #[cfg(feature = "liblzma")]
85            Compression::Xz =>
86            {
87                #[allow(deprecated)]
88                super::decompress_xz_with(data, opts).map(Some)
89            }
90            Compression::None => Ok(None),
91        }
92    }
93}
94
/// Configuration for detecting the compression format of a payload.
#[derive(Debug, Default, Clone)]
pub struct Detector<'a> {
    /// Name of the file the payload belongs to, if known. Enables detection by extension.
    pub file_name: Option<&'a str>,

    /// When set, detection by magic bytes is skipped.
    pub disable_magic: bool,

    /// File name extensions which must not be reported as unknown.
    pub ignore_file_extensions: HashSet<&'a str>,
    /// When set, a present file name with an unknown extension is reported as an error.
    pub fail_unknown_file_extension: bool,
}
108
109impl<'a> Detector<'a> {
110    /// Detect and decompress in a single step.
111    pub fn decompress(&'a self, data: Bytes) -> Result<Bytes, Error<'a>> {
112        self.decompress_with(data, &Default::default())
113    }
114
115    /// Detect and decompress in a single step.
116    pub fn decompress_with(
117        &'a self,
118        data: Bytes,
119        opts: &DecompressionOptions,
120    ) -> Result<Bytes, Error<'a>> {
121        let compression = self.detect(&data)?;
122        Ok(compression.decompress_with(data, opts)?)
123    }
124
125    pub fn detect(&'a self, #[allow(unused)] data: &[u8]) -> Result<Compression, Error<'a>> {
126        // detect by file name extension
127
128        if let Some(file_name) = self.file_name {
129            #[cfg(any(feature = "bzip2", feature = "bzip2-rs"))]
130            if file_name.ends_with(".bz2") {
131                return Ok(Compression::Bzip2);
132            }
133            #[cfg(feature = "liblzma")]
134            if file_name.ends_with(".xz") {
135                return Ok(Compression::Xz);
136            }
137            if self.fail_unknown_file_extension {
138                if let Some((_, ext)) = file_name.rsplit_once('.') {
139                    if !self.ignore_file_extensions.contains(ext) {
140                        return Err(Error::Unsupported(ext));
141                    }
142                }
143            }
144        }
145
146        // magic bytes
147
148        if !self.disable_magic {
149            #[cfg(any(feature = "bzip2", feature = "bzip2-rs"))]
150            if data.starts_with(b"BZh") {
151                return Ok(Compression::Bzip2);
152            }
153            #[cfg(feature = "liblzma")]
154            if data.starts_with(&[0xFD, 0x37, 0x7A, 0x58, 0x5A, 0x00]) {
155                return Ok(Compression::Xz);
156            }
157        }
158
159        // done
160
161        Ok(Compression::None)
162    }
163}