walker_common/compression/
detecting.rs1use bytes::Bytes;
2use std::collections::HashSet;
3
4#[derive(Debug, thiserror::Error)]
5pub enum Error<'a> {
6 #[error("unsupported compression: {0}")]
7 Unsupported(&'a str),
8 #[error(transparent)]
9 Io(#[from] std::io::Error),
10}
11
/// A compression format that a payload may be encoded with.
///
/// Variants other than [`Compression::None`] only exist when the cargo feature
/// providing the matching decompressor is enabled.
#[non_exhaustive]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum Compression {
    /// The payload is not compressed.
    None,
    /// bzip2, available with the `bzip2` or `bzip2-rs` feature.
    #[cfg(any(feature = "bzip2", feature = "bzip2-rs"))]
    Bzip2,
    /// xz, available with the `lzma` feature.
    #[cfg(feature = "lzma")]
    Xz,
    /// gzip, available with the `flate2` feature.
    #[cfg(feature = "flate2")]
    Gzip,
}
23
/// Options that tune how a payload is decompressed.
///
/// Create with [`DecompressionOptions::new`] (equivalent to `Default::default()`)
/// and adjust with the builder methods.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq, Default)]
pub struct DecompressionOptions {
    /// Limit handed to the format-specific decompressors; defaults to `0`.
    ///
    /// NOTE(review): presumably this caps the decompressed size and `0` means
    /// "no limit" — confirm against the `decompress_*_with` helpers in the parent module.
    pub limit: usize,
}

impl DecompressionOptions {
    /// Create options with every field at its default value.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the `limit` field and return the updated options.
    ///
    /// This consumes `self`. A bare `opts.limit(n);` therefore has no effect on
    /// `opts`, and `#[must_use]` makes the compiler warn about that mistake.
    #[must_use]
    pub fn limit(mut self, limit: usize) -> Self {
        self.limit = limit;
        self
    }
}
45
46impl Compression {
47 pub fn decompress(&self, data: Bytes) -> Result<Bytes, std::io::Error> {
51 Ok(self.decompress_opt(&data)?.unwrap_or(data))
52 }
53
54 pub fn decompress_with(
58 &self,
59 data: Bytes,
60 opts: &DecompressionOptions,
61 ) -> Result<Bytes, std::io::Error> {
62 Ok(self.decompress_opt_with(&data, opts)?.unwrap_or(data))
63 }
64
65 pub fn decompress_opt(&self, data: &[u8]) -> Result<Option<Bytes>, std::io::Error> {
69 self.decompress_opt_with(data, &Default::default())
70 }
71
72 pub fn decompress_opt_with(
76 &self,
77 #[allow(unused_variables)] data: &[u8],
78 #[allow(unused_variables)] opts: &DecompressionOptions,
79 ) -> Result<Option<Bytes>, std::io::Error> {
80 match self {
81 #[cfg(any(feature = "bzip2", feature = "bzip2-rs"))]
82 Compression::Bzip2 => super::decompress_bzip2_with(data, opts).map(Some),
83 #[cfg(feature = "lzma")]
84 Compression::Xz => super::decompress_xz_with(data, opts).map(Some),
85 #[cfg(feature = "flate2")]
86 Compression::Gzip => super::decompress_gzip_with(data, opts).map(Some),
87 Compression::None => Ok(None),
88 }
89 }
90}
91
/// Detects the compression of a payload from its file name and/or its leading
/// "magic" bytes, and optionally decompresses it.
///
/// Typically built with struct-update syntax, e.g.
/// `Detector { file_name: Some("data.json.gz"), ..Default::default() }`.
#[derive(Default, Debug, Clone)]
pub struct Detector<'a> {
    /// Name used for extension-based detection; `None` skips that stage.
    pub file_name: Option<&'a str>,

    /// When `true`, the leading bytes of the payload are not inspected.
    pub disable_magic: bool,

    /// Extensions (without the leading dot) that are accepted even when
    /// `fail_unknown_file_extension` is set.
    pub ignore_file_extensions: HashSet<&'a str>,
    /// Reject, with `Error::Unsupported`, a file name whose extension is neither a
    /// supported compression format nor listed in `ignore_file_extensions`.
    pub fail_unknown_file_extension: bool,
}
105
106impl<'a> Detector<'a> {
107 pub fn decompress(&self, data: Bytes) -> Result<Bytes, Error<'a>> {
109 self.decompress_with(data, &Default::default())
110 }
111
112 pub fn decompress_with(
114 &self,
115 data: Bytes,
116 opts: &DecompressionOptions,
117 ) -> Result<Bytes, Error<'a>> {
118 let compression = self.detect(&data)?;
119 Ok(compression.decompress_with(data, opts)?)
120 }
121
122 pub fn detect(&self, #[allow(unused)] data: &[u8]) -> Result<Compression, Error<'a>> {
123 if let Some(file_name) = self.file_name {
126 #[cfg(any(feature = "bzip2", feature = "bzip2-rs"))]
127 if file_name.ends_with(".bz2") {
128 return Ok(Compression::Bzip2);
129 }
130 #[cfg(feature = "lzma")]
131 if file_name.ends_with(".xz") {
132 return Ok(Compression::Xz);
133 }
134 #[cfg(feature = "flate2")]
135 if file_name.ends_with(".gz") {
136 return Ok(Compression::Gzip);
137 }
138 if self.fail_unknown_file_extension
139 && let Some((_, ext)) = file_name.rsplit_once('.')
140 && !self.ignore_file_extensions.contains(ext)
141 {
142 return Err(Error::Unsupported(ext));
143 }
144 }
145
146 if !self.disable_magic {
149 #[cfg(any(feature = "bzip2", feature = "bzip2-rs"))]
150 if data.starts_with(b"BZh") {
151 return Ok(Compression::Bzip2);
152 }
153 #[cfg(feature = "lzma")]
154 if data.starts_with(&[0xFD, 0x37, 0x7A, 0x58, 0x5A, 0x00]) {
155 return Ok(Compression::Xz);
156 }
157 #[cfg(feature = "flate2")]
158 if data.starts_with(&[0x1F, 0x8B, 0x08]) {
159 return Ok(Compression::Gzip);
163 }
164 }
165
166 Ok(Compression::None)
169 }
170}
171
#[cfg(test)]
mod test {
    use super::*;

    /// Detect purely by file name (magic detection disabled, no data).
    fn detect(name: &str) -> Compression {
        Detector {
            file_name: Some(name),
            disable_magic: true,
            ..Default::default()
        }
        .detect(&[])
        .unwrap()
    }

    #[test]
    fn by_name_none() {
        assert_eq!(detect("foo.bar.json"), Compression::None);
    }

    #[cfg(any(feature = "bzip2", feature = "bzip2-rs"))]
    #[test]
    fn by_name_bzip2() {
        assert_eq!(detect("foo.bar.bz2"), Compression::Bzip2);
    }

    #[cfg(feature = "lzma")]
    #[test]
    fn by_name_xz() {
        assert_eq!(detect("foo.bar.xz"), Compression::Xz);
    }

    #[cfg(feature = "flate2")]
    #[test]
    fn by_name_gzip() {
        assert_eq!(detect("foo.bar.gz"), Compression::Gzip);
    }

    #[test]
    fn unknown_extension_fails() {
        let detector = Detector {
            file_name: Some("foo.bar.json"),
            disable_magic: true,
            fail_unknown_file_extension: true,
            ..Default::default()
        };
        assert!(matches!(
            detector.detect(&[]),
            Err(Error::Unsupported("json"))
        ));
    }

    #[test]
    fn ignored_extension_passes() {
        let detector = Detector {
            file_name: Some("foo.bar.json"),
            disable_magic: true,
            ignore_file_extensions: HashSet::from(["json"]),
            fail_unknown_file_extension: true,
            ..Default::default()
        };
        assert_eq!(detector.detect(&[]).unwrap(), Compression::None);
    }

    #[test]
    fn no_extension_is_not_an_error() {
        let detector = Detector {
            file_name: Some("README"),
            disable_magic: true,
            fail_unknown_file_extension: true,
            ..Default::default()
        };
        assert_eq!(detector.detect(&[]).unwrap(), Compression::None);
    }

    #[test]
    fn magic_disabled_ignores_signature() {
        let detector = Detector {
            disable_magic: true,
            ..Default::default()
        };
        assert_eq!(
            detector.detect(&[0x1F, 0x8B, 0x08, 0x00]).unwrap(),
            Compression::None
        );
    }

    #[test]
    fn by_magic_none() {
        assert_eq!(
            Detector::default().detect(b"plain text").unwrap(),
            Compression::None
        );
    }

    #[cfg(any(feature = "bzip2", feature = "bzip2-rs"))]
    #[test]
    fn by_magic_bzip2() {
        assert_eq!(
            Detector::default().detect(b"BZh91AY").unwrap(),
            Compression::Bzip2
        );
    }

    #[cfg(feature = "lzma")]
    #[test]
    fn by_magic_xz() {
        assert_eq!(
            Detector::default()
                .detect(&[0xFD, 0x37, 0x7A, 0x58, 0x5A, 0x00, 0x00])
                .unwrap(),
            Compression::Xz
        );
    }

    #[cfg(feature = "flate2")]
    #[test]
    fn by_magic_gzip() {
        assert_eq!(
            Detector::default()
                .detect(&[0x1F, 0x8B, 0x08, 0x00])
                .unwrap(),
            Compression::Gzip
        );
    }

    /// Compile-time check: a default detector (and its borrowed error type) can be
    /// used both as a temporary and through a binding.
    #[test]
    fn default() {
        let _result = Detector::default().decompress(Bytes::from_static(b"foo"));

        let detector = Detector::default();
        let _result = detector.decompress(Bytes::from_static(b"foo"));
    }
}