use bytes::Bytes;
use std::collections::HashSet;
/// Errors produced while detecting or decompressing data.
///
/// The lifetime `'a` lets [`Error::Unsupported`] borrow the offending text
/// (the file extension) instead of allocating a copy.
#[derive(Debug, thiserror::Error)]
pub enum Error<'a> {
    /// Detection rejected the input; carries the file extension that was not
    /// recognised (text after the final `.`, without the dot).
    #[error("unsupported compression: {0}")]
    Unsupported(&'a str),
    /// An I/O error raised while decompressing; `Display` and `source()` are
    /// forwarded unchanged to the wrapped error (`#[error(transparent)]`).
    #[error(transparent)]
    Io(#[from] std::io::Error),
}
/// A compression format this module can recognise and undo.
///
/// Only the variants whose Cargo features are enabled are compiled in;
/// [`Compression::None`] is always available.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub enum Compression {
    /// The data is not compressed; decompression hands the input back as-is.
    None,
    /// bzip2 (available with the `bzip2` or `bzip2-rs` feature).
    #[cfg(any(feature = "bzip2", feature = "bzip2-rs"))]
    Bzip2,
    /// xz (available with the `liblzma` feature).
    #[cfg(feature = "liblzma")]
    Xz,
}
/// Options controlling how [`Compression`] decompresses data.
///
/// Build with [`DecompressionOptions::new`] (equivalent to [`Default::default`])
/// and adjust with the builder-style setters, e.g.
/// `DecompressionOptions::new().limit(1024)`. The struct is `#[non_exhaustive]`
/// so that new options can be added without breaking callers outside this crate.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq, Default)]
pub struct DecompressionOptions {
    /// Limit handed to the decompression backends together with the rest of
    /// these options; `0` by default.
    // NOTE(review): presumably an upper bound on the decompressed size in bytes,
    // with `0` meaning "no limit" — confirm against the backend implementations.
    pub limit: usize,
}

impl DecompressionOptions {
    /// Creates options with every field at its default value.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the `limit` field and returns the updated options.
    ///
    /// Consumes and returns `self`, so the result must be used; calling this as
    /// a statement (`opts.limit(5);`) would have no effect.
    #[must_use]
    pub fn limit(mut self, limit: usize) -> Self {
        self.limit = limit;
        self
    }
}
impl Compression {
pub fn decompress(&self, data: Bytes) -> Result<Bytes, std::io::Error> {
Ok(self.decompress_opt(&data)?.unwrap_or(data))
}
pub fn decompress_with(
&self,
data: Bytes,
opts: &DecompressionOptions,
) -> Result<Bytes, std::io::Error> {
Ok(self.decompress_opt_with(&data, opts)?.unwrap_or(data))
}
pub fn decompress_opt(&self, data: &[u8]) -> Result<Option<Bytes>, std::io::Error> {
self.decompress_opt_with(data, &Default::default())
}
pub fn decompress_opt_with(
&self,
data: &[u8],
opts: &DecompressionOptions,
) -> Result<Option<Bytes>, std::io::Error> {
match self {
#[cfg(any(feature = "bzip2", feature = "bzip2-rs"))]
Compression::Bzip2 => super::decompress_bzip2_with(data, opts).map(Some),
#[cfg(feature = "liblzma")]
Compression::Xz => super::decompress_xz_with(data, opts).map(Some),
Compression::None => Ok(None),
}
}
}
/// Configures how the [`Compression`] of a payload is detected.
///
/// With `Default`, no file name is set, magic-byte sniffing is enabled, and
/// unknown file extensions are not treated as errors.
#[derive(Clone, Debug, Default)]
pub struct Detector<'a> {
    /// Optional name of the payload, used for extension-based detection
    /// (e.g. a name ending in `.bz2` or `.xz`). The match is case-sensitive.
    pub file_name: Option<&'a str>,
    /// When `true`, skip inspecting the leading bytes of the data.
    pub disable_magic: bool,
    /// Extensions (without the leading `.`) that are acceptable even when
    /// `fail_unknown_file_extension` is set.
    pub ignore_file_extensions: HashSet<&'a str>,
    /// When `true`, a `file_name` whose extension is neither a supported
    /// compression nor listed in `ignore_file_extensions` makes detection fail
    /// with `Error::Unsupported`.
    pub fail_unknown_file_extension: bool,
}
impl<'a> Detector<'a> {
pub fn decompress(&'a self, data: Bytes) -> Result<Bytes, Error<'a>> {
self.decompress_with(data, &Default::default())
}
pub fn decompress_with(
&'a self,
data: Bytes,
opts: &DecompressionOptions,
) -> Result<Bytes, Error<'a>> {
let compression = self.detect(&data)?;
Ok(compression.decompress_with(data, opts)?)
}
pub fn detect(&'a self, #[allow(unused)] data: &[u8]) -> Result<Compression, Error<'a>> {
if let Some(file_name) = self.file_name {
#[cfg(any(feature = "bzip2", feature = "bzip2-rs"))]
if file_name.ends_with(".bz2") {
return Ok(Compression::Bzip2);
}
#[cfg(feature = "liblzma")]
if file_name.ends_with(".xz") {
return Ok(Compression::Xz);
}
if self.fail_unknown_file_extension {
if let Some((_, ext)) = file_name.rsplit_once('.') {
if !self.ignore_file_extensions.contains(ext) {
return Err(Error::Unsupported(ext));
}
}
}
}
if !self.disable_magic {
#[cfg(any(feature = "bzip2", feature = "bzip2-rs"))]
if data.starts_with(b"BZh") {
return Ok(Compression::Bzip2);
}
#[cfg(feature = "liblzma")]
if data.starts_with(&[0xFD, 0x37, 0x7A, 0x58, 0x5A, 0x00]) {
return Ok(Compression::Xz);
}
}
Ok(Compression::None)
}
}