1use crate::{Error, Result, MAGIC_COMPRESSED};
4use alloc::vec::Vec;
5use serde::{Deserialize, Serialize};
6
/// Identifies how a payload is (or is not) compressed.
///
/// The explicit `u8` discriminants are the bytes written into the framing header by
/// `serialize_with_header` (via `algorithm as u8`) and parsed back by
/// [`CompressionAlgorithm::from_byte`], so they must not be renumbered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[repr(u8)]
pub enum CompressionAlgorithm {
    /// Payload is stored as-is.
    None = 0x00,
    /// gzip stream produced with `flate2`'s `GzEncoder`.
    Gzip = 0x01,
    /// Brotli stream produced with the `brotli` crate.
    Brotli = 0x02,
}
18
19impl CompressionAlgorithm {
20 pub fn from_byte(byte: u8) -> Result<Self> {
22 match byte {
23 0x00 => Ok(CompressionAlgorithm::None),
24 0x01 => Ok(CompressionAlgorithm::Gzip),
25 0x02 => Ok(CompressionAlgorithm::Brotli),
26 _ => Err(Error::UnsupportedAlgorithm(byte)),
27 }
28 }
29
30 pub fn name(&self) -> &'static str {
32 match self {
33 CompressionAlgorithm::None => "none",
34 CompressionAlgorithm::Gzip => "gzip",
35 CompressionAlgorithm::Brotli => "brotli",
36 }
37 }
38
39 pub fn from_name(name: &str) -> Result<Self> {
41 match name.to_lowercase().as_str() {
42 "none" => Ok(CompressionAlgorithm::None),
43 "gzip" => Ok(CompressionAlgorithm::Gzip),
44 "brotli" => Ok(CompressionAlgorithm::Brotli),
45 _ => Err(Error::UnsupportedAlgorithm(0)),
46 }
47 }
48}
49
/// Outcome of [`compress`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressionResult {
    /// Output bytes: the compressed stream, or a copy of the input when `algorithm` is
    /// `CompressionAlgorithm::None`.
    pub compressed: Vec<u8>,
    /// Algorithm actually applied. This can be `None` even if another algorithm was
    /// requested, because `compress` falls back to storing small or incompressible input.
    pub algorithm: CompressionAlgorithm,
    /// Length of the input in bytes.
    pub original_size: usize,
    /// Length of `compressed` in bytes.
    pub compressed_size: usize,
    /// `compressed_size / original_size`; exactly `1.0` when the input was stored as-is.
    pub compression_ratio: f64,
}
64
/// Tuning knobs for [`compress`].
#[derive(Debug, Clone)]
pub struct CompressionOptions {
    /// Algorithm to try. `CompressionAlgorithm::None` always stores the input as-is.
    pub algorithm: CompressionAlgorithm,
    /// Inputs shorter than this many bytes are stored uncompressed.
    pub min_size_threshold: usize,
    /// Encoder level. It becomes Brotli's `quality` for `Brotli` and `flate2::Compression::new(level)`
    /// for `Gzip`.
    /// NOTE(review): the usual ranges are 0-11 for Brotli and 0-9 for gzip. Confirm how
    /// out-of-range values are handled.
    pub level: u32,
}
75
76impl Default for CompressionOptions {
77 fn default() -> Self {
78 Self {
79 algorithm: CompressionAlgorithm::Brotli,
80 min_size_threshold: 100,
81 level: 6,
82 }
83 }
84}
85
86pub fn compress(data: &[u8], options: Option<CompressionOptions>) -> Result<CompressionResult> {
88 let opts = options.unwrap_or_default();
89 let original_size = data.len();
90
91 if original_size < opts.min_size_threshold {
93 return Ok(CompressionResult {
94 compressed: data.to_vec(),
95 algorithm: CompressionAlgorithm::None,
96 original_size,
97 compressed_size: original_size,
98 compression_ratio: 1.0,
99 });
100 }
101
102 if opts.algorithm == CompressionAlgorithm::None {
104 return Ok(CompressionResult {
105 compressed: data.to_vec(),
106 algorithm: CompressionAlgorithm::None,
107 original_size,
108 compressed_size: original_size,
109 compression_ratio: 1.0,
110 });
111 }
112
113 let (compressed, algorithm) = match opts.algorithm {
114 CompressionAlgorithm::Brotli => {
115 compress_brotli(data, opts.level)?
116 }
117 CompressionAlgorithm::Gzip => {
118 compress_gzip(data, opts.level)?
119 }
120 CompressionAlgorithm::None => {
121 (data.to_vec(), CompressionAlgorithm::None)
122 }
123 };
124
125 let compressed_size = compressed.len();
126 let compression_ratio = compressed_size as f64 / original_size as f64;
127
128 if compression_ratio < 0.9 {
130 Ok(CompressionResult {
131 compressed,
132 algorithm,
133 original_size,
134 compressed_size,
135 compression_ratio,
136 })
137 } else {
138 Ok(CompressionResult {
139 compressed: data.to_vec(),
140 algorithm: CompressionAlgorithm::None,
141 original_size,
142 compressed_size: original_size,
143 compression_ratio: 1.0,
144 })
145 }
146}
147
148pub fn decompress(data: &[u8], algorithm: CompressionAlgorithm) -> Result<Vec<u8>> {
150 match algorithm {
151 CompressionAlgorithm::None => Ok(data.to_vec()),
152 CompressionAlgorithm::Gzip => decompress_gzip(data),
153 CompressionAlgorithm::Brotli => decompress_brotli(data),
154 }
155}
156
157fn compress_brotli(data: &[u8], level: u32) -> Result<(Vec<u8>, CompressionAlgorithm)> {
159 use brotli::enc::BrotliEncoderParams;
160
161 let mut output = Vec::new();
162 let mut params = BrotliEncoderParams::default();
163 params.quality = level as i32;
164
165 brotli::BrotliCompress(
166 &mut std::io::Cursor::new(data),
167 &mut output,
168 ¶ms,
169 ).map_err(|e| Error::CompressionFailed(e.to_string()))?;
170
171 Ok((output, CompressionAlgorithm::Brotli))
172}
173
174fn decompress_brotli(data: &[u8]) -> Result<Vec<u8>> {
176 let mut output = Vec::new();
177
178 brotli::BrotliDecompress(
179 &mut std::io::Cursor::new(data),
180 &mut output,
181 ).map_err(|e| Error::DecompressionFailed(e.to_string()))?;
182
183 Ok(output)
184}
185
186fn compress_gzip(data: &[u8], level: u32) -> Result<(Vec<u8>, CompressionAlgorithm)> {
188 use flate2::write::GzEncoder;
189 use flate2::Compression;
190 use std::io::Write;
191
192 let mut encoder = GzEncoder::new(Vec::new(), Compression::new(level));
193 encoder.write_all(data)
194 .map_err(|e| Error::CompressionFailed(e.to_string()))?;
195
196 let output = encoder.finish()
197 .map_err(|e| Error::CompressionFailed(e.to_string()))?;
198
199 Ok((output, CompressionAlgorithm::Gzip))
200}
201
202fn decompress_gzip(data: &[u8]) -> Result<Vec<u8>> {
204 use flate2::read::GzDecoder;
205 use std::io::Read;
206
207 let mut decoder = GzDecoder::new(data);
208 let mut output = Vec::new();
209
210 decoder.read_to_end(&mut output)
211 .map_err(|e| Error::DecompressionFailed(e.to_string()))?;
212
213 Ok(output)
214}
215
216pub fn serialize_with_header(result: &CompressionResult) -> Vec<u8> {
218 let original_size = result.original_size as u32;
219 let mut output = Vec::with_capacity(7 + result.compressed.len());
220
221 output.extend_from_slice(MAGIC_COMPRESSED);
223 output.push(result.algorithm as u8);
225 output.extend_from_slice(&original_size.to_be_bytes());
227 output.extend_from_slice(&result.compressed);
229
230 output
231}
232
233pub fn deserialize_with_header(data: &[u8]) -> Result<(Vec<u8>, CompressionAlgorithm, usize)> {
235 if data.len() < 7 {
236 return Err(Error::TruncatedPayload {
237 expected: 7,
238 actual: data.len(),
239 });
240 }
241
242 if &data[0..2] != MAGIC_COMPRESSED {
244 return Err(Error::InvalidFormat);
245 }
246
247 let algorithm = CompressionAlgorithm::from_byte(data[2])?;
249
250 let original_size = u32::from_be_bytes([data[3], data[4], data[5], data[6]]) as usize;
252
253 let compressed = data[7..].to_vec();
255
256 Ok((compressed, algorithm, original_size))
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Options for `algorithm` that compress anything of at least `min_size_threshold` bytes.
    fn options(
        algorithm: CompressionAlgorithm,
        min_size_threshold: usize,
    ) -> Option<CompressionOptions> {
        Some(CompressionOptions {
            algorithm,
            min_size_threshold,
            level: 6,
        })
    }

    #[test]
    fn test_gzip_roundtrip() {
        let data = b"Hello, World! This is a test message that should be compressed.";

        let result = compress(data, options(CompressionAlgorithm::Gzip, 10)).unwrap();

        let decompressed = decompress(&result.compressed, result.algorithm).unwrap();
        assert_eq!(data, &decompressed[..]);
    }

    #[test]
    fn test_brotli_roundtrip() {
        let data = b"Hello, World! This is a test message that should be compressed with Brotli.";

        let result = compress(data, options(CompressionAlgorithm::Brotli, 10)).unwrap();

        let decompressed = decompress(&result.compressed, result.algorithm).unwrap();
        assert_eq!(data, &decompressed[..]);
    }

    #[test]
    fn test_skip_small_data() {
        let data = b"tiny";

        let result = compress(data, options(CompressionAlgorithm::Brotli, 100)).unwrap();

        assert_eq!(result.algorithm, CompressionAlgorithm::None);
        assert_eq!(result.compressed, data);
    }

    #[test]
    fn test_empty_input_is_stored() {
        let result = compress(b"", options(CompressionAlgorithm::Gzip, 0)).unwrap();

        assert_eq!(result.algorithm, CompressionAlgorithm::None);
        assert!(result.compressed.is_empty());
        assert_eq!(result.original_size, 0);
        assert_eq!(result.compressed_size, 0);
    }

    #[test]
    fn test_none_algorithm_passthrough() {
        let data = b"Plain bytes that are long enough to pass any size threshold here.";

        let result = compress(data, options(CompressionAlgorithm::None, 0)).unwrap();

        assert_eq!(result.algorithm, CompressionAlgorithm::None);
        assert_eq!(result.compressed, data);
        assert_eq!(decompress(&result.compressed, result.algorithm).unwrap(), data);
    }

    #[test]
    fn test_incompressible_data_is_stored() {
        // 256 distinct bytes: neither encoder can get below 90% once framing overhead is added.
        let data: Vec<u8> = (0u8..=255).collect();

        for &algorithm in &[CompressionAlgorithm::Gzip, CompressionAlgorithm::Brotli] {
            let result = compress(&data, options(algorithm, 10)).unwrap();
            assert_eq!(result.algorithm, CompressionAlgorithm::None);
            assert_eq!(result.compressed, data);
            assert_eq!(result.compressed_size, result.original_size);
        }
    }

    #[test]
    fn test_algorithm_byte_and_name_roundtrip() {
        for &algorithm in &[
            CompressionAlgorithm::None,
            CompressionAlgorithm::Gzip,
            CompressionAlgorithm::Brotli,
        ] {
            assert_eq!(CompressionAlgorithm::from_byte(algorithm as u8).unwrap(), algorithm);
            assert_eq!(CompressionAlgorithm::from_name(algorithm.name()).unwrap(), algorithm);
        }

        assert_eq!(
            CompressionAlgorithm::from_name("GZIP").unwrap(),
            CompressionAlgorithm::Gzip
        );
        assert!(CompressionAlgorithm::from_byte(0xFF).is_err());
        assert!(CompressionAlgorithm::from_name("zstd").is_err());
    }

    #[test]
    fn test_header_serialization() {
        let data = b"Test data for header serialization test with enough content.";

        let result = compress(data, options(CompressionAlgorithm::Gzip, 10)).unwrap();

        let serialized = serialize_with_header(&result);
        let (compressed, algorithm, original_size) = deserialize_with_header(&serialized).unwrap();

        assert_eq!(algorithm, result.algorithm);
        assert_eq!(original_size, result.original_size);
        assert_eq!(compressed, result.compressed);
    }

    #[test]
    fn test_header_roundtrip_then_decompress() {
        let data = b"Round-trip through the framing header and back to the original bytes.";

        let result = compress(data, options(CompressionAlgorithm::Brotli, 10)).unwrap();
        let serialized = serialize_with_header(&result);
        let (compressed, algorithm, original_size) = deserialize_with_header(&serialized).unwrap();

        let restored = decompress(&compressed, algorithm).unwrap();
        assert_eq!(restored, data);
        assert_eq!(original_size, data.len());
    }

    #[test]
    fn test_deserialize_rejects_truncated_header() {
        let err = deserialize_with_header(&[0u8; 3]).unwrap_err();
        assert!(matches!(err, crate::Error::TruncatedPayload { actual: 3, .. }));
    }

    #[test]
    fn test_deserialize_rejects_bad_magic() {
        let data = b"Some payload used to build a valid frame before corrupting its magic.";
        let result = compress(data, options(CompressionAlgorithm::Gzip, 10)).unwrap();

        let mut corrupted = serialize_with_header(&result);
        corrupted[0] ^= 0xFF;

        let err = deserialize_with_header(&corrupted).unwrap_err();
        assert!(matches!(err, crate::Error::InvalidFormat));
    }
}
323