Skip to main content

sklears_simd/
compression.rs

1//! SIMD-optimized compression algorithms
2//!
3//! This module provides SIMD-accelerated implementations of common compression algorithms
4//! including run-length encoding, LZ77, and dictionary-based compression.
5
6#[cfg(feature = "no-std")]
7extern crate alloc;
8
9#[cfg(feature = "no-std")]
10use alloc::{collections::BTreeMap as HashMap, vec, vec::Vec};
11#[cfg(not(feature = "no-std"))]
12use std::collections::HashMap;
13
/// Run-length encode a byte slice.
///
/// The input is split into maximal runs of identical bytes and each run is reported as a
/// `(value, count)` pair, in input order. An empty input produces an empty vector.
///
/// Runs longer than `u32::MAX` bytes are emitted as several consecutive pairs carrying the same
/// value, so the count can never overflow and [`run_length_decode`] still reproduces the input.
///
/// Despite the name, no explicit SIMD intrinsics are used: runs are scanned in fixed 16-byte
/// chunks, a shape the compiler may be able to vectorize.
pub fn run_length_encode_simd(data: &[u8]) -> Vec<(u8, u32)> {
    // Largest run length that fits in the `u32` count of a single pair.
    const MAX_RUN: usize = u32::MAX as usize;

    let mut result = Vec::new();
    let mut rest = data;

    while let Some(&byte) = rest.first() {
        let mut remaining = leading_run_len(rest, byte);
        rest = &rest[remaining..];

        // Usually a single iteration; only runs above `u32::MAX` bytes are split.
        while remaining > 0 {
            let segment = remaining.min(MAX_RUN);
            result.push((byte, segment as u32));
            remaining -= segment;
        }
    }

    result
}

/// Number of leading bytes of `data` that are equal to `byte`.
fn leading_run_len(data: &[u8], byte: u8) -> usize {
    // 16 bytes corresponds to one SSE2 register of `u8` lanes.
    const LANES: usize = 16;

    let mut len = 0;

    // Skip whole chunks that consist solely of `byte`.
    for chunk in data.chunks_exact(LANES) {
        if chunk.iter().any(|&b| b != byte) {
            break;
        }
        len += LANES;
    }

    // Finish byte by byte inside the first chunk that differs (or the short tail).
    len + data[len..].iter().take_while(|&&b| b == byte).count()
}
70
/// Expand run-length encoded `(value, count)` pairs back into a byte vector.
///
/// Each pair contributes `count` copies of `value`, in the order the pairs appear. Pairs with a
/// count of zero contribute nothing, and an empty slice decodes to an empty vector.
pub fn run_length_decode(encoded: &[(u8, u32)]) -> Vec<u8> {
    // Reserve the full output up front so the buffer is allocated once.
    let decoded_len = encoded.iter().map(|&(_, run)| run as usize).sum::<usize>();
    let mut output = Vec::with_capacity(decoded_len);

    for &(value, run) in encoded {
        let grown_len = output.len() + run as usize;
        output.resize(grown_len, value);
    }

    output
}
82
/// Simple LZ77-style compressor.
///
/// The output is a byte stream made of two kinds of tokens:
///
/// * a literal: the input byte itself;
/// * a back-reference: the marker byte `0xFF`, the match distance as a little-endian `u16`,
///   and the match length as a `u8`.
///
/// NOTE(review): literal `0xFF` input bytes are written unescaped, so in the output they are
/// indistinguishable from the back-reference marker. Inputs containing `0xFF` therefore cannot be
/// decoded unambiguously; fixing that requires a format change.
pub struct LZ77Compressor {
    /// Maximum number of bytes before the current position that are searched for matches.
    window_size: usize,
    /// Maximum number of bytes after the current position that a match may cover.
    lookahead_size: usize,
}

impl LZ77Compressor {
    /// Byte that introduces a back-reference token.
    const MATCH_MARKER: u8 = 0xFF;
    /// Shortest match that is encoded as a back-reference (which costs four bytes).
    const MIN_MATCH_LEN: usize = 3;
    /// Largest distance representable in the two-byte distance field.
    const MAX_DISTANCE: usize = u16::MAX as usize;
    /// Largest length representable in the one-byte length field.
    const MAX_MATCH_LEN: usize = u8::MAX as usize;

    /// Create a compressor with the given sliding-window and lookahead sizes (in bytes).
    ///
    /// Sizes beyond what a token can encode (65535 for the window, 255 for the lookahead) are
    /// accepted but have no additional effect.
    pub fn new(window_size: usize, lookahead_size: usize) -> Self {
        Self {
            window_size,
            lookahead_size,
        }
    }

    /// Find the longest match for the bytes starting at `pos` inside the preceding window.
    ///
    /// Returns `(distance, length)`, where `distance` counts backwards from `pos`; `(0, 0)` means
    /// no match was found. Matches never overlap `pos` (their length is at most their distance),
    /// and both values are capped so they fit the `u16`/`u8` fields written by `compress`.
    /// When several candidates share the best length, the one farthest from `pos` is returned.
    fn find_longest_match(&self, data: &[u8], pos: usize) -> (usize, usize) {
        let window = self.window_size.min(Self::MAX_DISTANCE);
        let lookahead = self.lookahead_size.min(Self::MAX_MATCH_LEN);
        let window_start = pos.saturating_sub(window);
        let lookahead_end = pos.saturating_add(lookahead).min(data.len());

        if window_start >= pos || pos >= lookahead_end {
            return (0, 0);
        }

        let max_possible = lookahead_end - pos;
        let mut best_distance = 0;
        let mut best_length = 0;

        for window_pos in window_start..pos {
            let distance = pos - window_pos;
            let max_length = max_possible.min(distance);

            // Length of the common prefix of the candidate and the lookahead.
            let candidate = &data[window_pos..window_pos + max_length];
            let target = &data[pos..pos + max_length];
            let match_length = candidate
                .iter()
                .zip(target)
                .take_while(|(a, b)| a == b)
                .count();

            if match_length > best_length {
                best_length = match_length;
                best_distance = distance;

                // Nothing can beat a match that already spans the whole lookahead.
                if best_length == max_possible {
                    break;
                }
            }
        }

        (best_distance, best_length)
    }

    /// Compress `data` into the token stream described on [`LZ77Compressor`].
    ///
    /// Matches of at least three bytes become four-byte back-references; everything else is
    /// copied through as literal bytes.
    pub fn compress(&self, data: &[u8]) -> Vec<u8> {
        let mut compressed = Vec::new();
        let mut pos = 0;

        while pos < data.len() {
            let (distance, length) = self.find_longest_match(data, pos);

            if length >= Self::MIN_MATCH_LEN {
                compressed.push(Self::MATCH_MARKER);
                // `find_longest_match` caps both values, so these casts cannot truncate.
                compressed.extend_from_slice(&(distance as u16).to_le_bytes());
                compressed.push(length as u8);
                pos += length;
            } else {
                compressed.push(data[pos]);
                pos += 1;
            }
        }

        compressed
    }
}
174
/// Dictionary-based compressor driven by pattern-frequency analysis.
///
/// Codes `0..=255` always stand for the corresponding single byte. Codes from 256 upwards are
/// assigned by [`DictionaryCompressor::build_dictionary`] to multi-byte patterns that occur
/// repeatedly in the training data. Compressed data can only be decoded by a compressor holding
/// the same dictionary.
pub struct DictionaryCompressor {
    /// Pattern -> code lookup used while compressing.
    dictionary: HashMap<Vec<u8>, u16>,
    /// Code -> pattern lookup used while decompressing.
    reverse_dictionary: HashMap<u16, Vec<u8>>,
    /// Next unassigned code.
    next_code: u16,
    /// Length of the longest pattern stored in `dictionary`; bounds the search in `compress`.
    max_pattern_len: usize,
}

impl Default for DictionaryCompressor {
    fn default() -> Self {
        Self::new()
    }
}

impl DictionaryCompressor {
    /// Create a compressor whose dictionary contains only the 256 single-byte patterns.
    pub fn new() -> Self {
        let mut compressor = Self {
            dictionary: HashMap::new(),
            reverse_dictionary: HashMap::new(),
            next_code: 256, // Start after single-byte codes
            max_pattern_len: 1,
        };

        // Initialize with single bytes: every byte maps to the code of the same value.
        for byte in 0..=u8::MAX {
            compressor.dictionary.insert(vec![byte], u16::from(byte));
            compressor
                .reverse_dictionary
                .insert(u16::from(byte), vec![byte]);
        }

        compressor
    }

    /// Extend the dictionary with patterns that occur at least twice in `data`.
    ///
    /// Every window of `data` with a length from 2 up to `max_pattern_length` (inclusive) is
    /// counted. Patterns are then given codes in order of decreasing frequency; patterns with
    /// equal frequency are ordered lexicographically, so the resulting codes are deterministic.
    /// Patterns already present keep their code, and assignment stops once the `u16` code space
    /// is used up.
    pub fn build_dictionary(&mut self, data: &[u8], max_pattern_length: usize) {
        // Keys borrow from `data`, so counting does not allocate per window.
        let mut pattern_counts: HashMap<&[u8], u32> = HashMap::new();

        for pattern_len in 2..=max_pattern_length.min(data.len()) {
            for window in data.windows(pattern_len) {
                *pattern_counts.entry(window).or_insert(0) += 1;
            }
        }

        // Most frequent first; the lexicographic tie-break removes any dependence on the
        // map's iteration order.
        let mut patterns: Vec<(&[u8], u32)> = pattern_counts.into_iter().collect();
        patterns.sort_unstable_by(|&(pat_a, count_a), &(pat_b, count_b)| {
            count_b.cmp(&count_a).then_with(|| pat_a.cmp(pat_b))
        });

        for (pattern, count) in patterns {
            // Counts only decrease from here on, and a full code space stays full.
            if count < 2 || self.next_code == u16::MAX {
                break;
            }
            if self.dictionary.contains_key(pattern) {
                continue;
            }

            self.dictionary.insert(pattern.to_vec(), self.next_code);
            self.reverse_dictionary
                .insert(self.next_code, pattern.to_vec());
            self.next_code += 1;
            self.max_pattern_len = self.max_pattern_len.max(pattern.len());
        }
    }

    /// Compress `data` into dictionary codes.
    ///
    /// At each position the longest dictionary pattern matching the upcoming bytes is emitted
    /// (greedy longest match); when no multi-byte pattern matches, the single byte's own code is
    /// emitted.
    pub fn compress(&self, data: &[u8]) -> Vec<u16> {
        let mut compressed = Vec::new();
        let mut pos = 0;

        while pos < data.len() {
            let longest = self.max_pattern_len.min(data.len() - pos);

            // Try candidate lengths from longest to shortest and take the first hit.
            let best = (2..=longest).rev().find_map(|len| {
                self.dictionary
                    .get(&data[pos..pos + len])
                    .map(|&code| (len, code))
            });
            let (match_len, code) = best.unwrap_or((1, u16::from(data[pos])));

            compressed.push(code);
            pos += match_len;
        }

        compressed
    }

    /// Expand dictionary codes back into bytes.
    ///
    /// # Errors
    ///
    /// Returns `Err("Invalid code in compressed data")` if a code is not present in this
    /// compressor's dictionary.
    pub fn decompress(&self, compressed: &[u16]) -> Result<Vec<u8>, &'static str> {
        let mut decompressed = Vec::new();

        for code in compressed {
            let pattern = self
                .reverse_dictionary
                .get(code)
                .ok_or("Invalid code in compressed data")?;
            decompressed.extend_from_slice(pattern);
        }

        Ok(decompressed)
    }
}
276
/// Count how often each byte value occurs in `data`.
///
/// The returned array is indexed by byte value: element `b` holds the number of occurrences of
/// the byte `b`. The input is walked in 4096-byte blocks; no explicit SIMD intrinsics are used.
pub fn count_byte_frequencies_simd(data: &[u8]) -> [u32; 256] {
    // Block length used when walking the input.
    const BLOCK_LEN: usize = 4096;

    let mut histogram = [0u32; 256];

    data.chunks(BLOCK_LEN)
        .flatten()
        .for_each(|&byte| histogram[usize::from(byte)] += 1);

    histogram
}
292
/// Compute the compression ratio `compressed_size / original_size`.
///
/// Values below `1.0` mean the compressed form is smaller than the original. An
/// `original_size` of zero yields `0.0` instead of dividing by zero.
pub fn compression_ratio(original_size: usize, compressed_size: usize) -> f64 {
    match original_size {
        0 => 0.0,
        size => compressed_size as f64 / size as f64,
    }
}
300
// Unit tests. The module is only built with the standard library (it is gated on
// `not(feature = "no-std")`), so `vec!` and `Vec` come from the std prelude.
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;

    #[test]
    fn test_run_length_encode() {
        let data = b"aaabbbccccdddd";
        let encoded = run_length_encode_simd(data);
        let expected = vec![(b'a', 3), (b'b', 3), (b'c', 4), (b'd', 4)];
        assert_eq!(encoded, expected);
    }

    #[test]
    fn test_run_length_encode_across_chunks() {
        // Runs that start and end in the middle of the encoder's 16-byte chunks.
        let mut data = vec![b'a'; 20];
        data.extend(vec![b'b'; 20]);
        data.push(b'c');

        let encoded = run_length_encode_simd(&data);
        assert_eq!(encoded, vec![(b'a', 20), (b'b', 20), (b'c', 1)]);
        assert_eq!(run_length_decode(&encoded), data);
    }

    #[test]
    fn test_run_length_decode() {
        let encoded = vec![(b'a', 3), (b'b', 3), (b'c', 4), (b'd', 4)];
        let decoded = run_length_decode(&encoded);
        assert_eq!(decoded, b"aaabbbccccdddd");
    }

    #[test]
    fn test_run_length_roundtrip() {
        let original = b"aaaaabbbbcccccdddddeeeeee";
        let encoded = run_length_encode_simd(original);
        let decoded = run_length_decode(&encoded);
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_lz77_compression() {
        let compressor = LZ77Compressor::new(1024, 32);
        let data = b"abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz";
        let compressed = compressor.compress(data);

        // Should be able to compress repeated patterns
        assert!(compressed.len() < data.len());
    }

    #[test]
    fn test_dictionary_compression() {
        let mut compressor = DictionaryCompressor::new();
        let data = b"hello world hello world hello world";

        compressor.build_dictionary(data, 8);
        let compressed = compressor.compress(data);
        let decompressed = compressor
            .decompress(&compressed)
            .expect("codes produced by compress must decode with the same dictionary");

        assert_eq!(decompressed, data);

        // Calculate compression efficiency
        let original_bits = data.len() * 8;
        let compressed_bits = compressed.len() * 16; // 16 bits per code
        assert!(compressed_bits < original_bits);
    }

    #[test]
    fn test_byte_frequency_counter() {
        let data = b"hello world";
        let frequencies = count_byte_frequencies_simd(data);

        assert_eq!(frequencies[b'h' as usize], 1);
        assert_eq!(frequencies[b'e' as usize], 1);
        assert_eq!(frequencies[b'l' as usize], 3);
        assert_eq!(frequencies[b'o' as usize], 2);
        assert_eq!(frequencies[b' ' as usize], 1);
        assert_eq!(frequencies[b'w' as usize], 1);
        assert_eq!(frequencies[b'r' as usize], 1);
        assert_eq!(frequencies[b'd' as usize], 1);
    }

    #[test]
    fn test_compression_ratio() {
        let ratio = compression_ratio(1000, 750);
        assert!((ratio - 0.75).abs() < f64::EPSILON);

        let ratio_zero = compression_ratio(0, 100);
        assert_eq!(ratio_zero, 0.0);
    }

    #[test]
    fn test_empty_data() {
        let empty_data = b"";
        let encoded = run_length_encode_simd(empty_data);
        assert!(encoded.is_empty());

        let decoded = run_length_decode(&[]);
        assert!(decoded.is_empty());
    }

    #[test]
    fn test_single_byte() {
        let data = b"a";
        let encoded = run_length_encode_simd(data);
        assert_eq!(encoded, vec![(b'a', 1)]);

        let decoded = run_length_decode(&encoded);
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_long_runs() {
        let data = vec![b'x'; 1000];
        let encoded = run_length_encode_simd(&data);
        assert_eq!(encoded, vec![(b'x', 1000)]);

        let decoded = run_length_decode(&encoded);
        assert_eq!(decoded, data);
    }
}