Skip to main content

sochdb_storage/
compression.rs

// SPDX-License-Identifier: AGPL-3.0-or-later
// SochDB - LLM-Optimized Embedded Database
// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
//! Storage compression and optimization module.
//!
//! Implements a multi-tier compression strategy:
//! - Hot data (less than 24 hours old): LZ4 for speed
//! - Warm data (1-30 days old): Zstd level 3 for balance
//! - Cold data (30 days old or more): Zstd level 19 for maximum compression
//!
//! Also provides:
//! - Deduplication of repeated payloads (e.g. system prompts)
//! - Automatic tier selection based on record age
//! - Compression ratio tracking
29
30use std::collections::HashMap;
31use std::io;
32use std::time::{SystemTime, UNIX_EPOCH};
33
/// Compression type identifier.
///
/// The explicit `#[repr(u8)]` discriminants let a value round-trip through a
/// single byte (`value as u8` / [`CompressionType::try_from_u8`]).
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionType {
    /// Bytes are stored unchanged.
    None = 0,
    /// LZ4 block compression.
    Lz4 = 1,
    /// Zstd at level 3.
    ZstdFast = 2,
    /// Zstd at level 19.
    ZstdMax = 3,
}

impl CompressionType {
    /// Decodes a tag byte, mapping any unknown value to [`CompressionType::None`].
    ///
    /// This is lossy: a corrupted or unrecognised tag is silently treated as
    /// "uncompressed". Use [`CompressionType::try_from_u8`] when an unknown tag
    /// should be reported instead of being reinterpreted.
    pub fn from_u8(value: u8) -> Self {
        Self::try_from_u8(value).unwrap_or(CompressionType::None)
    }

    /// Decodes a tag byte, returning `None` when it names no known algorithm.
    pub fn try_from_u8(value: u8) -> Option<Self> {
        match value {
            0 => Some(CompressionType::None),
            1 => Some(CompressionType::Lz4),
            2 => Some(CompressionType::ZstdFast),
            3 => Some(CompressionType::ZstdMax),
            _ => None,
        }
    }
}

/// Storage tier based on data age.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StorageTier {
    /// Less than 24 hours old.
    Hot,
    /// At least 24 hours but less than 30 days (720 hours) old.
    Warm,
    /// 30 days old or older.
    Cold,
}

impl StorageTier {
    const MICROS_PER_HOUR: u64 = 3_600_000_000;
    /// Ages below this many whole hours are Hot.
    const HOT_LIMIT_HOURS: u64 = 24;
    /// Ages below this many whole hours (30 days) are Warm; everything else is Cold.
    const WARM_LIMIT_HOURS: u64 = 720;

    /// Determines the tier of a record written at `timestamp_us`, relative to now.
    ///
    /// `timestamp_us` is an absolute time in microseconds since the Unix epoch
    /// (not an age). Timestamps in the future count as age zero (Hot). If the
    /// system clock reads earlier than the Unix epoch, "now" is taken as 0
    /// instead of panicking, which also classifies the record as Hot.
    pub fn from_age(timestamp_us: u64) -> Self {
        let now_us = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_micros() as u64;
        Self::from_age_at(timestamp_us, now_us)
    }

    /// Determines the tier of a record written at `timestamp_us` as seen at `now_us`.
    ///
    /// Both arguments are microseconds since the Unix epoch. Age is truncated to
    /// whole hours before comparison, and a `timestamp_us` later than `now_us`
    /// counts as age zero.
    pub fn from_age_at(timestamp_us: u64, now_us: u64) -> Self {
        let age_hours = now_us.saturating_sub(timestamp_us) / Self::MICROS_PER_HOUR;
        if age_hours < Self::HOT_LIMIT_HOURS {
            StorageTier::Hot
        } else if age_hours < Self::WARM_LIMIT_HOURS {
            StorageTier::Warm
        } else {
            StorageTier::Cold
        }
    }

    /// Returns the recommended compression algorithm for this tier.
    pub fn compression_type(&self) -> CompressionType {
        match self {
            StorageTier::Hot => CompressionType::Lz4,       // fast compression
            StorageTier::Warm => CompressionType::ZstdFast, // balanced
            StorageTier::Cold => CompressionType::ZstdMax,  // maximum compression
        }
    }
}
93
94/// Compression engine
95pub struct CompressionEngine {
96    /// Deduplication cache (hash -> compressed data)
97    dedup_cache: HashMap<u64, Vec<u8>>,
98    /// Compression statistics
99    stats: CompressionStats,
100}
101
/// Running totals collected while compressing payloads.
#[derive(Debug, Default, Clone)]
pub struct CompressionStats {
    /// Sum of the input sizes handed to the compressor, in bytes.
    pub total_uncompressed: u64,
    /// Sum of the produced output sizes, in bytes.
    pub total_compressed: u64,
    /// Number of LZ4 compressions performed.
    pub lz4_count: u64,
    /// Number of Zstd level-3 compressions performed.
    pub zstd_fast_count: u64,
    /// Number of Zstd level-19 compressions performed.
    pub zstd_max_count: u64,
    /// Number of requests answered from the deduplication cache.
    pub dedup_hits: u64,
}

impl CompressionStats {
    /// Output size as a fraction of input size (`total_compressed / total_uncompressed`).
    ///
    /// Lower is better: `0.25` means the data shrank to a quarter of its size.
    /// Returns `1.0` when nothing has been recorded yet.
    pub fn compression_ratio(&self) -> f64 {
        match self.total_uncompressed {
            0 => 1.0,
            input_bytes => self.total_compressed as f64 / input_bytes as f64,
        }
    }

    /// Bytes saved overall, clamped to zero if the output grew instead.
    pub fn space_saved_bytes(&self) -> u64 {
        self.total_uncompressed
            .checked_sub(self.total_compressed)
            .unwrap_or(0)
    }
}
125
126impl CompressionEngine {
127    pub fn new() -> Self {
128        Self {
129            dedup_cache: HashMap::new(),
130            stats: CompressionStats::default(),
131        }
132    }
133
134    /// Compress data using specified algorithm
135    pub fn compress(
136        &mut self,
137        data: &[u8],
138        compression: CompressionType,
139    ) -> Result<Vec<u8>, std::io::Error> {
140        self.stats.total_uncompressed += data.len() as u64;
141
142        let compressed = match compression {
143            CompressionType::None => data.to_vec(),
144            CompressionType::Lz4 => self.compress_lz4(data)?,
145            CompressionType::ZstdFast => self.compress_zstd(data, 3)?,
146            CompressionType::ZstdMax => self.compress_zstd(data, 19)?,
147        };
148
149        self.stats.total_compressed += compressed.len() as u64;
150
151        match compression {
152            CompressionType::Lz4 => self.stats.lz4_count += 1,
153            CompressionType::ZstdFast => self.stats.zstd_fast_count += 1,
154            CompressionType::ZstdMax => self.stats.zstd_max_count += 1,
155            _ => {}
156        }
157
158        Ok(compressed)
159    }
160
161    /// Decompress data
162    pub fn decompress(
163        &self,
164        data: &[u8],
165        compression: CompressionType,
166    ) -> Result<Vec<u8>, std::io::Error> {
167        match compression {
168            CompressionType::None => Ok(data.to_vec()),
169            CompressionType::Lz4 => self.decompress_lz4(data),
170            CompressionType::ZstdFast | CompressionType::ZstdMax => self.decompress_zstd(data),
171        }
172    }
173
174    /// Compress with deduplication
175    pub fn compress_with_dedup(
176        &mut self,
177        data: &[u8],
178        compression: CompressionType,
179    ) -> Result<Vec<u8>, std::io::Error> {
180        // Use xxHash3 for dedup hashing — 5× faster than SipHash, non-adversarial context
181        let hash = twox_hash::xxh3::hash64(data);
182
183        // Check dedup cache
184        if let Some(cached) = self.dedup_cache.get(&hash) {
185            self.stats.dedup_hits += 1;
186            return Ok(cached.clone());
187        }
188
189        // Compress and cache
190        let compressed = self.compress(data, compression)?;
191
192        // Only cache if it's worth it (data > 1KB and compression ratio > 2:1)
193        if data.len() > 1024 && compressed.len() > 0 && (data.len() / compressed.len()) >= 2 {
194            self.dedup_cache.insert(hash, compressed.clone());
195        }
196
197        Ok(compressed)
198    }
199
200    /// LZ4 compression using lz4_flex (block mode, ~3 GB/s throughput)
201    ///
202    /// Wire format: [original_len: u32 LE] [lz4_compressed_payload...]
203    /// If compressed output >= original size, falls back to uncompressed with len=0 sentinel.
204    fn compress_lz4(&self, data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
205        let compressed = lz4_flex::compress_prepend_size(data);
206        // Fallback: if compressed is larger than original + 4-byte header, store raw
207        if compressed.len() >= data.len() + 4 {
208            let mut output = Vec::with_capacity(data.len() + 4);
209            output.extend_from_slice(&0u32.to_le_bytes()); // 0 = uncompressed sentinel
210            output.extend_from_slice(data);
211            Ok(output)
212        } else {
213            Ok(compressed)
214        }
215    }
216
217    /// LZ4 decompression
218    fn decompress_lz4(&self, data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
219        if data.len() < 4 {
220            return Err(io::Error::new(
221                io::ErrorKind::InvalidData,
222                "LZ4 data too short (< 4 bytes)",
223            ));
224        }
225        let original_len = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
226        if original_len == 0 {
227            // Uncompressed fallback: sentinel 0 means raw payload follows
228            return Ok(data[4..].to_vec());
229        }
230        lz4_flex::decompress_size_prepended(data).map_err(|e| {
231            io::Error::new(
232                io::ErrorKind::InvalidData,
233                format!("LZ4 decompression failed: {}", e),
234            )
235        })
236    }
237
238    /// Zstd compression at the specified level
239    ///
240    /// Level 3: ~500 MB/s, ~3× ratio (warm tier)
241    /// Level 19: ~40 MB/s, ~4.5× ratio (cold tier — use from background compaction only)
242    ///
243    /// Wire format: raw zstd frame (self-describing, includes original size)
244    /// If compressed output >= original, falls back with a 4-byte sentinel header.
245    fn compress_zstd(&self, data: &[u8], level: i32) -> Result<Vec<u8>, std::io::Error> {
246        let compressed = zstd::encode_all(std::io::Cursor::new(data), level)?;
247        // Fallback: if compression didn't help, store raw with sentinel
248        if compressed.len() >= data.len() {
249            let mut output = Vec::with_capacity(data.len() + 4);
250            output.extend_from_slice(b"\x00\x00\x00\x00"); // 4 zero bytes = uncompressed sentinel
251            output.extend_from_slice(data);
252            Ok(output)
253        } else {
254            Ok(compressed)
255        }
256    }
257
258    /// Zstd decompression
259    fn decompress_zstd(&self, data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
260        if data.len() < 4 {
261            return Err(io::Error::new(
262                io::ErrorKind::InvalidData,
263                "Zstd data too short (< 4 bytes)",
264            ));
265        }
266        // Check for uncompressed sentinel (4 zero bytes and NOT a valid zstd magic)
267        if &data[0..4] == b"\x00\x00\x00\x00" {
268            return Ok(data[4..].to_vec());
269        }
270        zstd::decode_all(std::io::Cursor::new(data))
271    }
272
273    /// Get compression statistics
274    pub fn stats(&self) -> &CompressionStats {
275        &self.stats
276    }
277
278    /// Clear deduplication cache
279    pub fn clear_cache(&mut self) {
280        self.dedup_cache.clear();
281    }
282
283    /// Get cache size in bytes
284    pub fn cache_size(&self) -> usize {
285        self.dedup_cache.values().map(|v| v.len()).sum()
286    }
287}
288
289impl Default for CompressionEngine {
290    fn default() -> Self {
291        Self::new()
292    }
293}
294
295/// Helper: Determine optimal compression for payload
296pub fn choose_compression(size: usize, age_us: u64) -> CompressionType {
297    // Small payloads: don't compress (overhead not worth it)
298    if size < 512 {
299        return CompressionType::None;
300    }
301
302    // Use tier-based compression
303    let tier = StorageTier::from_age(age_us);
304    tier.compression_type()
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    /// Wall-clock "now" in microseconds since the Unix epoch.
    fn now_micros() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_micros() as u64
    }

    /// Compresses `data` with `kind` on a fresh engine, decompresses the result,
    /// and asserts the original bytes come back unchanged.
    fn assert_roundtrip(data: &[u8], kind: CompressionType) {
        let mut engine = CompressionEngine::new();
        let packed = engine.compress(data, kind).unwrap();
        let unpacked = engine.decompress(&packed, kind).unwrap();
        assert_eq!(data, unpacked.as_slice());
    }

    #[test]
    fn test_storage_tier() {
        const HOUR_US: u64 = 3_600_000_000;
        const WEEK_US: u64 = 604_800_000_000;
        const ABOUT_35_DAYS_US: u64 = 3_000_000_000_000;

        let now = now_micros();
        let expectations = [
            (now - HOUR_US, StorageTier::Hot),
            (now - WEEK_US, StorageTier::Warm),
            (now - ABOUT_35_DAYS_US, StorageTier::Cold),
        ];
        for (written_at, tier) in expectations {
            assert_eq!(StorageTier::from_age(written_at), tier);
        }
    }

    #[test]
    fn test_lz4_roundtrip() {
        assert_roundtrip(
            b"Hello, World! This is test data for LZ4 compression roundtrip.",
            CompressionType::Lz4,
        );
    }

    #[test]
    fn test_zstd_fast_roundtrip() {
        assert_roundtrip(
            b"Hello, World! This is test data for Zstd level-3 compression roundtrip.",
            CompressionType::ZstdFast,
        );
    }

    #[test]
    fn test_zstd_max_roundtrip() {
        assert_roundtrip(
            b"Hello, World! This is test data for Zstd level-19 maximum compression roundtrip.",
            CompressionType::ZstdMax,
        );
    }

    #[test]
    fn test_real_compression_ratio() {
        // One sentence repeated 100 times: trivially compressible.
        let input = "The quick brown fox jumps over the lazy dog. "
            .repeat(100)
            .into_bytes();
        // A fresh engine per algorithm so their statistics never mix.
        let squeeze =
            |kind: CompressionType| CompressionEngine::new().compress(&input, kind).unwrap();

        let lz4 = squeeze(CompressionType::Lz4);
        assert!(
            lz4.len() < input.len(),
            "LZ4 should compress repetitive data: {} -> {}",
            input.len(),
            lz4.len()
        );

        let zstd_fast = squeeze(CompressionType::ZstdFast);
        assert!(
            zstd_fast.len() < input.len(),
            "ZstdFast should compress repetitive data: {} -> {}",
            input.len(),
            zstd_fast.len()
        );

        let zstd_max = squeeze(CompressionType::ZstdMax);
        assert!(
            zstd_max.len() < input.len(),
            "ZstdMax should compress repetitive data: {} -> {}",
            input.len(),
            zstd_max.len()
        );

        // The slower, stronger level must not lose to the fast one here.
        assert!(
            zstd_max.len() <= zstd_fast.len(),
            "ZstdMax ({}) should be <= ZstdFast ({})",
            zstd_max.len(),
            zstd_fast.len()
        );
    }

    #[test]
    fn test_compression_stats() {
        let mut engine = CompressionEngine::new();
        let payload = "Test data for compression statistics. "
            .repeat(50)
            .into_bytes();

        engine.compress(&payload, CompressionType::Lz4).unwrap();

        let totals = engine.stats();
        assert!(totals.total_uncompressed > 0);
        assert!(totals.total_compressed > 0);
        assert_eq!(totals.lz4_count, 1);
        // Repetitive input must actually shrink.
        assert!(
            totals.space_saved_bytes() > 0,
            "Should save space on compressible data"
        );
        assert!(
            totals.compression_ratio() < 1.0,
            "Ratio should be < 1.0 (compressed smaller than original)"
        );
    }

    #[test]
    fn test_deduplication() {
        let mut engine = CompressionEngine::new();
        // Only payloads over 1024 bytes that compress at least 2:1 get cached.
        let prompt = "Repeated system prompt for deduplication testing. "
            .repeat(100)
            .into_bytes();
        assert!(prompt.len() > 1024);

        // Initial request: compressed and stored, no hit yet.
        let first = engine
            .compress_with_dedup(&prompt, CompressionType::Lz4)
            .unwrap();
        assert_eq!(engine.stats().dedup_hits, 0);

        // Identical request: served from the cache.
        let second = engine
            .compress_with_dedup(&prompt, CompressionType::Lz4)
            .unwrap();
        assert_eq!(engine.stats().dedup_hits, 1);
        assert_eq!(first, second);
    }

    #[test]
    fn test_small_data_fallback() {
        // Too small for either codec to help; the raw fallback must still round-trip.
        let tiny = b"tiny";
        assert_roundtrip(tiny, CompressionType::Lz4);
        assert_roundtrip(tiny, CompressionType::ZstdFast);
    }

    #[test]
    fn test_choose_compression() {
        let now = now_micros();
        let forty_six_days_ago = now - 4_000_000_000_000;

        // Below the size threshold nothing is compressed.
        assert_eq!(choose_compression(100, now), CompressionType::None);
        // Fresh and large: fast codec.
        assert_eq!(choose_compression(10000, now), CompressionType::Lz4);
        // Old and large: strongest codec.
        assert_eq!(
            choose_compression(10000, forty_six_days_ago),
            CompressionType::ZstdMax
        );
    }

    #[test]
    fn test_none_compression_passthrough() {
        let mut engine = CompressionEngine::new();
        let raw = b"no compression applied";

        let stored = engine.compress(raw, CompressionType::None).unwrap();
        assert_eq!(&raw[..], stored.as_slice());

        let restored = engine.decompress(&stored, CompressionType::None).unwrap();
        assert_eq!(&raw[..], restored.as_slice());
    }

    #[test]
    fn test_large_payload_compression() {
        let mut engine = CompressionEngine::new();

        // Simulated LLM conversation context: 10k small JSON-like records.
        let mut json = String::new();
        for i in 0..10000 {
            json.push_str(&format!("{{\"role\":\"user\",\"content\":\"message {}\"}},", i));
        }
        let data = json.into_bytes();

        let compressed = engine.compress(&data, CompressionType::ZstdFast).unwrap();
        let ratio = compressed.len() as f64 / data.len() as f64;
        assert!(
            ratio < 0.5,
            "Large repetitive JSON should compress to < 50%: ratio={:.3}",
            ratio
        );

        let decompressed = engine
            .decompress(&compressed, CompressionType::ZstdFast)
            .unwrap();
        assert_eq!(data, decompressed);
    }
}