Skip to main content

sochdb_storage/
compression.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! Storage compression and optimization module
19//!
20//! Implements multi-tier compression strategy:
21//! - Hot data (recent): LZ4 for speed
22//! - Warm data (1-30 days): Zstd level 3 for balance
23//! - Cold data (>30 days): Zstd level 19 for maximum compression
24//!
25//! Also provides:
26//! - Deduplication for common patterns (system prompts)
27//! - Automatic tiering based on age
28//! - Compression ratio tracking
29
30use std::collections::HashMap;
31use std::io;
32use std::time::{SystemTime, UNIX_EPOCH};
33
/// Identifies the codec used to encode a stored payload.
///
/// The discriminants are pinned with `#[repr(u8)]`, so `value as u8` yields
/// the tag byte that [`CompressionType::from_u8`] and
/// [`CompressionType::try_from_u8`] map back to a variant.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionType {
    /// Payload is stored as-is.
    None = 0,
    /// LZ4 compression.
    Lz4 = 1,
    /// Zstd at level 3.
    ZstdFast = 2,
    /// Zstd at level 19.
    ZstdMax = 3,
}

impl CompressionType {
    /// Strictly decodes a tag byte.
    ///
    /// Returns `Some(variant)` for the tags `0..=3` and `Option::None` for any
    /// other value, so callers can distinguish "stored uncompressed" (tag 0)
    /// from "unrecognized tag".
    pub fn try_from_u8(value: u8) -> Option<Self> {
        match value {
            0 => Some(CompressionType::None),
            1 => Some(CompressionType::Lz4),
            2 => Some(CompressionType::ZstdFast),
            3 => Some(CompressionType::ZstdMax),
            _ => Option::None,
        }
    }

    /// Leniently decodes a tag byte.
    ///
    /// Unknown tags are mapped to [`CompressionType::None`]. Prefer
    /// [`CompressionType::try_from_u8`] when an unrecognized tag must be
    /// reported rather than treated as "uncompressed".
    pub fn from_u8(value: u8) -> Self {
        Self::try_from_u8(value).unwrap_or(CompressionType::None)
    }
}
54
55/// Storage tier based on data age
56#[derive(Debug, Clone, Copy, PartialEq, Eq)]
57pub enum StorageTier {
58    Hot,  // < 24 hours
59    Warm, // 1-30 days
60    Cold, // > 30 days
61}
62
63impl StorageTier {
64    /// Determine tier based on age
65    pub fn from_age(timestamp_us: u64) -> Self {
66        let now = SystemTime::now()
67            .duration_since(UNIX_EPOCH)
68            .unwrap()
69            .as_micros() as u64;
70
71        let age_us = now.saturating_sub(timestamp_us);
72        let age_hours = age_us / 3_600_000_000;
73
74        if age_hours < 24 {
75            StorageTier::Hot
76        } else if age_hours < 720 {
77            // 30 days
78            StorageTier::Warm
79        } else {
80            StorageTier::Cold
81        }
82    }
83
84    /// Get recommended compression for this tier
85    pub fn compression_type(&self) -> CompressionType {
86        match self {
87            StorageTier::Hot => CompressionType::Lz4, // Fast compression
88            StorageTier::Warm => CompressionType::ZstdFast, // Balanced
89            StorageTier::Cold => CompressionType::ZstdMax, // Maximum compression
90        }
91    }
92}
93
94/// Compression engine
95pub struct CompressionEngine {
96    /// Deduplication cache (hash -> compressed data)
97    dedup_cache: HashMap<u64, Vec<u8>>,
98    /// Compression statistics
99    stats: CompressionStats,
100}
101
/// Running counters describing compression activity.
#[derive(Debug, Default, Clone)]
pub struct CompressionStats {
    /// Total size, in bytes, of the inputs that were compressed.
    pub total_uncompressed: u64,
    /// Total size, in bytes, of the outputs that were produced.
    pub total_compressed: u64,
    /// Number of LZ4 compressions.
    pub lz4_count: u64,
    /// Number of Zstd level-3 compressions.
    pub zstd_fast_count: u64,
    /// Number of Zstd level-19 compressions.
    pub zstd_max_count: u64,
    /// Number of deduplication cache hits.
    pub dedup_hits: u64,
}

impl CompressionStats {
    /// Ratio of compressed to uncompressed bytes (smaller means better
    /// compression). Returns `1.0` while no input has been recorded.
    pub fn compression_ratio(&self) -> f64 {
        match self.total_uncompressed {
            0 => 1.0,
            original => self.total_compressed as f64 / original as f64,
        }
    }

    /// Bytes saved by compression, clamped at zero when the compressed total
    /// exceeds the uncompressed total.
    pub fn space_saved_bytes(&self) -> u64 {
        let (before, after) = (self.total_uncompressed, self.total_compressed);
        before.saturating_sub(after)
    }
}
125
126impl CompressionEngine {
127    pub fn new() -> Self {
128        Self {
129            dedup_cache: HashMap::new(),
130            stats: CompressionStats::default(),
131        }
132    }
133
134    /// Compress data using specified algorithm
135    pub fn compress(
136        &mut self,
137        data: &[u8],
138        compression: CompressionType,
139    ) -> Result<Vec<u8>, std::io::Error> {
140        self.stats.total_uncompressed += data.len() as u64;
141
142        let compressed = match compression {
143            CompressionType::None => data.to_vec(),
144            CompressionType::Lz4 => self.compress_lz4(data)?,
145            CompressionType::ZstdFast => self.compress_zstd(data, 3)?,
146            CompressionType::ZstdMax => self.compress_zstd(data, 19)?,
147        };
148
149        self.stats.total_compressed += compressed.len() as u64;
150
151        match compression {
152            CompressionType::Lz4 => self.stats.lz4_count += 1,
153            CompressionType::ZstdFast => self.stats.zstd_fast_count += 1,
154            CompressionType::ZstdMax => self.stats.zstd_max_count += 1,
155            _ => {}
156        }
157
158        Ok(compressed)
159    }
160
161    /// Decompress data
162    pub fn decompress(
163        &self,
164        data: &[u8],
165        compression: CompressionType,
166    ) -> Result<Vec<u8>, std::io::Error> {
167        match compression {
168            CompressionType::None => Ok(data.to_vec()),
169            CompressionType::Lz4 => self.decompress_lz4(data),
170            CompressionType::ZstdFast | CompressionType::ZstdMax => self.decompress_zstd(data),
171        }
172    }
173
174    /// Compress with deduplication
175    pub fn compress_with_dedup(
176        &mut self,
177        data: &[u8],
178        compression: CompressionType,
179    ) -> Result<Vec<u8>, std::io::Error> {
180        // Use xxHash3 for dedup hashing — 5× faster than SipHash, non-adversarial context
181        let hash = twox_hash::xxh3::hash64(data);
182
183        // Check dedup cache
184        if let Some(cached) = self.dedup_cache.get(&hash) {
185            self.stats.dedup_hits += 1;
186            return Ok(cached.clone());
187        }
188
189        // Compress and cache
190        let compressed = self.compress(data, compression)?;
191
192        // Only cache if it's worth it (data > 1KB and compression ratio > 2:1)
193        if data.len() > 1024 && compressed.len() > 0 && (data.len() / compressed.len()) >= 2 {
194            self.dedup_cache.insert(hash, compressed.clone());
195        }
196
197        Ok(compressed)
198    }
199
200    /// LZ4 compression using lz4_flex (block mode, ~3 GB/s throughput)
201    ///
202    /// Wire format: [original_len: u32 LE] [lz4_compressed_payload...]
203    /// If compressed output >= original size, falls back to uncompressed with len=0 sentinel.
204    fn compress_lz4(&self, data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
205        let compressed = lz4_flex::compress_prepend_size(data);
206        // Fallback: if compressed is larger than original + 4-byte header, store raw
207        if compressed.len() >= data.len() + 4 {
208            let mut output = Vec::with_capacity(data.len() + 4);
209            output.extend_from_slice(&0u32.to_le_bytes()); // 0 = uncompressed sentinel
210            output.extend_from_slice(data);
211            Ok(output)
212        } else {
213            Ok(compressed)
214        }
215    }
216
217    /// LZ4 decompression
218    fn decompress_lz4(&self, data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
219        if data.len() < 4 {
220            return Err(io::Error::new(
221                io::ErrorKind::InvalidData,
222                "LZ4 data too short (< 4 bytes)",
223            ));
224        }
225        let original_len = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
226        if original_len == 0 {
227            // Uncompressed fallback: sentinel 0 means raw payload follows
228            return Ok(data[4..].to_vec());
229        }
230        lz4_flex::decompress_size_prepended(data).map_err(|e| {
231            io::Error::new(
232                io::ErrorKind::InvalidData,
233                format!("LZ4 decompression failed: {}", e),
234            )
235        })
236    }
237
238    /// Zstd compression at the specified level
239    ///
240    /// Level 3: ~500 MB/s, ~3× ratio (warm tier)
241    /// Level 19: ~40 MB/s, ~4.5× ratio (cold tier — use from background compaction only)
242    ///
243    /// Wire format: raw zstd frame (self-describing, includes original size)
244    /// If compressed output >= original, falls back with a 4-byte sentinel header.
245    fn compress_zstd(&self, data: &[u8], level: i32) -> Result<Vec<u8>, std::io::Error> {
246        let compressed = zstd::encode_all(std::io::Cursor::new(data), level)?;
247        // Fallback: if compression didn't help, store raw with sentinel
248        if compressed.len() >= data.len() {
249            let mut output = Vec::with_capacity(data.len() + 4);
250            output.extend_from_slice(b"\x00\x00\x00\x00"); // 4 zero bytes = uncompressed sentinel
251            output.extend_from_slice(data);
252            Ok(output)
253        } else {
254            Ok(compressed)
255        }
256    }
257
258    /// Zstd decompression
259    fn decompress_zstd(&self, data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
260        if data.len() < 4 {
261            return Err(io::Error::new(
262                io::ErrorKind::InvalidData,
263                "Zstd data too short (< 4 bytes)",
264            ));
265        }
266        // Check for uncompressed sentinel (4 zero bytes and NOT a valid zstd magic)
267        if &data[0..4] == b"\x00\x00\x00\x00" {
268            return Ok(data[4..].to_vec());
269        }
270        zstd::decode_all(std::io::Cursor::new(data))
271    }
272
273    /// Get compression statistics
274    pub fn stats(&self) -> &CompressionStats {
275        &self.stats
276    }
277
278    /// Clear deduplication cache
279    pub fn clear_cache(&mut self) {
280        self.dedup_cache.clear();
281    }
282
283    /// Get cache size in bytes
284    pub fn cache_size(&self) -> usize {
285        self.dedup_cache.values().map(|v| v.len()).sum()
286    }
287}
288
289impl Default for CompressionEngine {
290    fn default() -> Self {
291        Self::new()
292    }
293}
294
295/// Helper: Determine optimal compression for payload
296pub fn choose_compression(size: usize, age_us: u64) -> CompressionType {
297    // Small payloads: don't compress (overhead not worth it)
298    if size < 512 {
299        return CompressionType::None;
300    }
301
302    // Use tier-based compression
303    let tier = StorageTier::from_age(age_us);
304    tier.compression_type()
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    /// Current wall-clock time in microseconds since the Unix epoch.
    fn now_micros() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_micros() as u64
    }

    /// Compresses `input` with `codec` on a fresh engine, then decompresses it.
    fn roundtrip(input: &[u8], codec: CompressionType) -> Vec<u8> {
        let mut engine = CompressionEngine::new();
        let packed = engine.compress(input, codec).unwrap();
        engine.decompress(&packed, codec).unwrap()
    }

    #[test]
    fn test_storage_tier() {
        let now = now_micros();

        // (how long ago the record was written, expected tier)
        let expectations = [
            (3_600_000_000_u64, StorageTier::Hot),   // 1 hour ago
            (604_800_000_000, StorageTier::Warm),    // 7 days ago
            (3_000_000_000_000, StorageTier::Cold),  // ~35 days ago
        ];

        for &(elapsed_us, expected) in &expectations {
            assert_eq!(StorageTier::from_age(now - elapsed_us), expected);
        }
    }

    #[test]
    fn test_lz4_roundtrip() {
        let payload = b"Hello, World! This is test data for LZ4 compression roundtrip.";
        let restored = roundtrip(payload, CompressionType::Lz4);
        assert_eq!(payload.as_slice(), restored.as_slice());
    }

    #[test]
    fn test_zstd_fast_roundtrip() {
        let payload = b"Hello, World! This is test data for Zstd level-3 compression roundtrip.";
        let restored = roundtrip(payload, CompressionType::ZstdFast);
        assert_eq!(payload.as_slice(), restored.as_slice());
    }

    #[test]
    fn test_zstd_max_roundtrip() {
        let payload =
            b"Hello, World! This is test data for Zstd level-19 maximum compression roundtrip.";
        let restored = roundtrip(payload, CompressionType::ZstdMax);
        assert_eq!(payload.as_slice(), restored.as_slice());
    }

    #[test]
    fn test_real_compression_ratio() {
        // Highly compressible input: one sentence repeated many times.
        let data = "The quick brown fox jumps over the lazy dog. "
            .repeat(100)
            .into_bytes();

        // Every codec gets its own engine.
        let pack = |codec: CompressionType| -> Vec<u8> {
            CompressionEngine::new().compress(&data, codec).unwrap()
        };

        let lz4 = pack(CompressionType::Lz4);
        assert!(
            lz4.len() < data.len(),
            "LZ4 should compress repetitive data: {} -> {}",
            data.len(),
            lz4.len()
        );

        let zstd_fast = pack(CompressionType::ZstdFast);
        assert!(
            zstd_fast.len() < data.len(),
            "ZstdFast should compress repetitive data: {} -> {}",
            data.len(),
            zstd_fast.len()
        );

        let zstd_max = pack(CompressionType::ZstdMax);
        assert!(
            zstd_max.len() < data.len(),
            "ZstdMax should compress repetitive data: {} -> {}",
            data.len(),
            zstd_max.len()
        );

        // The highest level must not lose to the fast level.
        assert!(
            zstd_max.len() <= zstd_fast.len(),
            "ZstdMax ({}) should be <= ZstdFast ({})",
            zstd_max.len(),
            zstd_fast.len()
        );
    }

    #[test]
    fn test_compression_stats() {
        let payload = "Test data for compression statistics. "
            .repeat(50)
            .into_bytes();

        let mut engine = CompressionEngine::new();
        engine.compress(&payload, CompressionType::Lz4).unwrap();

        let snapshot = engine.stats();
        assert!(snapshot.total_uncompressed > 0);
        assert!(snapshot.total_compressed > 0);
        assert_eq!(snapshot.lz4_count, 1);
        // Repetitive input must actually shrink.
        assert!(
            snapshot.space_saved_bytes() > 0,
            "Should save space on compressible data"
        );
        assert!(
            snapshot.compression_ratio() < 1.0,
            "Ratio should be < 1.0 (compressed smaller than original)"
        );
    }

    #[test]
    fn test_deduplication() {
        // Caching requires more than 1024 input bytes and a 2:1 ratio.
        let prompt = "Repeated system prompt for deduplication testing. "
            .repeat(100)
            .into_bytes();
        assert!(prompt.len() > 1024);

        let mut engine = CompressionEngine::new();

        // Miss: the payload is compressed and stored in the cache.
        let initial = engine
            .compress_with_dedup(&prompt, CompressionType::Lz4)
            .unwrap();
        assert_eq!(engine.stats().dedup_hits, 0);

        // Hit: the cached bytes are returned.
        let repeated = engine
            .compress_with_dedup(&prompt, CompressionType::Lz4)
            .unwrap();
        assert_eq!(engine.stats().dedup_hits, 1);
        assert_eq!(initial, repeated);
    }

    #[test]
    fn test_small_data_fallback() {
        // Input too small to shrink must still survive a roundtrip.
        let payload = b"tiny";

        let via_lz4 = roundtrip(payload, CompressionType::Lz4);
        assert_eq!(payload.as_slice(), via_lz4.as_slice());

        let via_zstd = roundtrip(payload, CompressionType::ZstdFast);
        assert_eq!(payload.as_slice(), via_zstd.as_slice());
    }

    #[test]
    fn test_choose_compression() {
        let now = now_micros();

        // Small payload -> None
        assert_eq!(choose_compression(100, now), CompressionType::None);

        // Recent large payload -> LZ4
        assert_eq!(choose_compression(10000, now), CompressionType::Lz4);

        // Old large payload (~46 days) -> ZstdMax
        let long_ago = now - 4_000_000_000_000;
        assert_eq!(
            choose_compression(10000, long_ago),
            CompressionType::ZstdMax
        );
    }

    #[test]
    fn test_none_compression_passthrough() {
        let payload = b"no compression applied";
        let mut engine = CompressionEngine::new();

        let stored = engine.compress(payload, CompressionType::None).unwrap();
        assert_eq!(payload.as_slice(), stored.as_slice());

        let restored = engine.decompress(&stored, CompressionType::None).unwrap();
        assert_eq!(payload.as_slice(), restored.as_slice());
    }

    #[test]
    fn test_large_payload_compression() {
        // Simulated large LLM conversation context (JSON-like records).
        let mut json = String::new();
        for i in 0..10000 {
            json.push_str(&format!(
                "{{\"role\":\"user\",\"content\":\"message {}\"}},",
                i
            ));
        }
        let data = json.into_bytes();

        let mut engine = CompressionEngine::new();
        let packed = engine.compress(&data, CompressionType::ZstdFast).unwrap();

        let ratio = packed.len() as f64 / data.len() as f64;
        assert!(
            ratio < 0.5,
            "Large repetitive JSON should compress to < 50%: ratio={:.3}",
            ratio
        );

        let restored = engine
            .decompress(&packed, CompressionType::ZstdFast)
            .unwrap();
        assert_eq!(data, restored);
    }
}