Skip to main content

sochdb_storage/
compression.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
//! Storage compression and optimization module.
//!
//! Implements a multi-tier compression strategy:
//! - Hot data (less than 24 hours old): LZ4 for speed
//! - Warm data (24 hours to 30 days old): Zstd level 3 for balance
//! - Cold data (30 days or older): Zstd level 19 for maximum compression
//!
//! Also provides:
//! - Deduplication of repeated payloads (e.g. system prompts)
//! - Automatic tier selection based on the write timestamp
//! - Compression ratio tracking
//!
//! Note: the LZ4 and Zstd codecs are currently placeholders that frame the data
//! with a small length header without actually compressing it.
29
30use std::collections::HashMap;
31use std::time::{SystemTime, UNIX_EPOCH};
32
/// Compression algorithm identifier.
///
/// The explicit `u8` discriminants are what [`CompressionType::from_u8`] and
/// [`CompressionType::try_from_u8`] decode. They are presumably persisted next to
/// payloads, so existing values should never be renumbered.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionType {
    /// Payload stored as-is.
    None = 0,
    /// LZ4 (recommended for the hot tier).
    Lz4 = 1,
    /// Zstd at level 3 (recommended for the warm tier).
    ZstdFast = 2,
    /// Zstd at level 19 (recommended for the cold tier).
    ZstdMax = 3,
}

impl CompressionType {
    /// Decode a tag byte, falling back to [`CompressionType::None`] for unknown values.
    ///
    /// The fallback is kept for backward compatibility. It means an unknown or
    /// corrupted tag is silently treated as "uncompressed". Prefer
    /// [`CompressionType::try_from_u8`] when the byte comes from storage.
    pub fn from_u8(value: u8) -> Self {
        Self::try_from_u8(value).unwrap_or(CompressionType::None)
    }

    /// Decode a tag byte, returning `Option::None` if it names no known algorithm.
    pub fn try_from_u8(value: u8) -> Option<Self> {
        match value {
            0 => Some(CompressionType::None),
            1 => Some(CompressionType::Lz4),
            2 => Some(CompressionType::ZstdFast),
            3 => Some(CompressionType::ZstdMax),
            _ => Option::None,
        }
    }

    /// The tag byte for this algorithm (inverse of [`CompressionType::try_from_u8`]).
    pub fn as_u8(self) -> u8 {
        self as u8
    }
}

/// Storage tier, derived from how long ago a record was written.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StorageTier {
    /// Younger than 24 hours.
    Hot,
    /// At least 24 hours old but younger than 30 days.
    Warm,
    /// 30 days old or older.
    Cold,
}

impl StorageTier {
    /// Microseconds in one hour.
    const HOUR_US: u64 = 3_600_000_000;
    /// Records younger than this are [`StorageTier::Hot`].
    const HOT_LIMIT_US: u64 = 24 * Self::HOUR_US;
    /// Records younger than this (and not hot) are [`StorageTier::Warm`].
    const WARM_LIMIT_US: u64 = 30 * 24 * Self::HOUR_US;

    /// Determine the tier for a record written at `timestamp_us`, measured
    /// against the current wall clock.
    ///
    /// Despite the method name, `timestamp_us` is an absolute timestamp in
    /// microseconds since the Unix epoch, not an elapsed age. Timestamps in the
    /// future count as age zero (hot). If the system clock reads earlier than
    /// the Unix epoch, "now" is taken as 0 instead of panicking.
    pub fn from_age(timestamp_us: u64) -> Self {
        // u128 -> u64 only truncates after roughly 584,000 years of microseconds.
        let now_us = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|elapsed| elapsed.as_micros() as u64)
            .unwrap_or(0);
        Self::from_age_at(timestamp_us, now_us)
    }

    /// Like [`StorageTier::from_age`], but takes "now" explicitly so the result
    /// is deterministic (useful for tests and batch re-tiering).
    ///
    /// Both arguments are microseconds since the Unix epoch.
    pub fn from_age_at(timestamp_us: u64, now_us: u64) -> Self {
        let age_us = now_us.saturating_sub(timestamp_us);
        // Comparing raw microseconds is equivalent to comparing whole elapsed
        // hours against 24 and 720, because floor(a / h) < k  <=>  a < k * h.
        if age_us < Self::HOT_LIMIT_US {
            StorageTier::Hot
        } else if age_us < Self::WARM_LIMIT_US {
            StorageTier::Warm
        } else {
            StorageTier::Cold
        }
    }

    /// Recommended compression algorithm for this tier.
    pub fn compression_type(&self) -> CompressionType {
        match self {
            StorageTier::Hot => CompressionType::Lz4,       // Fast compression
            StorageTier::Warm => CompressionType::ZstdFast, // Balanced
            StorageTier::Cold => CompressionType::ZstdMax,  // Maximum compression
        }
    }
}
92
93/// Compression engine
94pub struct CompressionEngine {
95    /// Deduplication cache (hash -> compressed data)
96    dedup_cache: HashMap<u64, Vec<u8>>,
97    /// Compression statistics
98    stats: CompressionStats,
99}
100
101#[derive(Debug, Default, Clone)]
102pub struct CompressionStats {
103    pub total_uncompressed: u64,
104    pub total_compressed: u64,
105    pub lz4_count: u64,
106    pub zstd_fast_count: u64,
107    pub zstd_max_count: u64,
108    pub dedup_hits: u64,
109}
110
111impl CompressionStats {
112    pub fn compression_ratio(&self) -> f64 {
113        if self.total_uncompressed == 0 {
114            return 1.0;
115        }
116        self.total_compressed as f64 / self.total_uncompressed as f64
117    }
118
119    pub fn space_saved_bytes(&self) -> u64 {
120        self.total_uncompressed
121            .saturating_sub(self.total_compressed)
122    }
123}
124
125impl CompressionEngine {
126    pub fn new() -> Self {
127        Self {
128            dedup_cache: HashMap::new(),
129            stats: CompressionStats::default(),
130        }
131    }
132
133    /// Compress data using specified algorithm
134    pub fn compress(
135        &mut self,
136        data: &[u8],
137        compression: CompressionType,
138    ) -> Result<Vec<u8>, std::io::Error> {
139        self.stats.total_uncompressed += data.len() as u64;
140
141        let compressed = match compression {
142            CompressionType::None => data.to_vec(),
143            CompressionType::Lz4 => self.compress_lz4(data)?,
144            CompressionType::ZstdFast => self.compress_zstd(data, 3)?,
145            CompressionType::ZstdMax => self.compress_zstd(data, 19)?,
146        };
147
148        self.stats.total_compressed += compressed.len() as u64;
149
150        match compression {
151            CompressionType::Lz4 => self.stats.lz4_count += 1,
152            CompressionType::ZstdFast => self.stats.zstd_fast_count += 1,
153            CompressionType::ZstdMax => self.stats.zstd_max_count += 1,
154            _ => {}
155        }
156
157        Ok(compressed)
158    }
159
160    /// Decompress data
161    pub fn decompress(
162        &self,
163        data: &[u8],
164        compression: CompressionType,
165    ) -> Result<Vec<u8>, std::io::Error> {
166        match compression {
167            CompressionType::None => Ok(data.to_vec()),
168            CompressionType::Lz4 => self.decompress_lz4(data),
169            CompressionType::ZstdFast | CompressionType::ZstdMax => self.decompress_zstd(data),
170        }
171    }
172
173    /// Compress with deduplication
174    pub fn compress_with_dedup(
175        &mut self,
176        data: &[u8],
177        compression: CompressionType,
178    ) -> Result<Vec<u8>, std::io::Error> {
179        use std::collections::hash_map::DefaultHasher;
180        use std::hash::{Hash, Hasher};
181
182        // Hash the data
183        let mut hasher = DefaultHasher::new();
184        data.hash(&mut hasher);
185        let hash = hasher.finish();
186
187        // Check dedup cache
188        if let Some(cached) = self.dedup_cache.get(&hash) {
189            self.stats.dedup_hits += 1;
190            return Ok(cached.clone());
191        }
192
193        // Compress and cache
194        let compressed = self.compress(data, compression)?;
195
196        // Only cache if it's worth it (data > 1KB and compression ratio > 2:1)
197        if data.len() > 1024 && (data.len() / compressed.len()) >= 2 {
198            self.dedup_cache.insert(hash, compressed.clone());
199        }
200
201        Ok(compressed)
202    }
203
204    /// LZ4 compression (placeholder - would use lz4_flex crate in production)
205    fn compress_lz4(&self, data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
206        // Placeholder: In production, use lz4_flex::compress_prepend_size()
207        // For now, just return the data with a simple encoding
208        let mut output = Vec::with_capacity(data.len() + 4);
209        output.extend_from_slice(&(data.len() as u32).to_le_bytes());
210        output.extend_from_slice(data);
211        Ok(output)
212    }
213
214    /// LZ4 decompression (placeholder)
215    fn decompress_lz4(&self, data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
216        if data.len() < 4 {
217            return Err(std::io::Error::new(
218                std::io::ErrorKind::InvalidData,
219                "Invalid LZ4 data",
220            ));
221        }
222
223        let _size = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
224        Ok(data[4..].to_vec())
225    }
226
227    /// Zstd compression (placeholder - would use zstd crate in production)
228    fn compress_zstd(&self, data: &[u8], _level: i32) -> Result<Vec<u8>, std::io::Error> {
229        // Placeholder: In production, use zstd::encode_all(data, level)
230        // For now, simple encoding
231        let mut output = Vec::with_capacity(data.len() + 8);
232        output.extend_from_slice(b"ZSTD");
233        output.extend_from_slice(&(data.len() as u32).to_le_bytes());
234        output.extend_from_slice(data);
235        Ok(output)
236    }
237
238    /// Zstd decompression (placeholder)
239    fn decompress_zstd(&self, data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
240        if data.len() < 8 || &data[0..4] != b"ZSTD" {
241            return Err(std::io::Error::new(
242                std::io::ErrorKind::InvalidData,
243                "Invalid Zstd data",
244            ));
245        }
246
247        let _size = u32::from_le_bytes([data[4], data[5], data[6], data[7]]);
248        Ok(data[8..].to_vec())
249    }
250
251    /// Get compression statistics
252    pub fn stats(&self) -> &CompressionStats {
253        &self.stats
254    }
255
256    /// Clear deduplication cache
257    pub fn clear_cache(&mut self) {
258        self.dedup_cache.clear();
259    }
260
261    /// Get cache size in bytes
262    pub fn cache_size(&self) -> usize {
263        self.dedup_cache.values().map(|v| v.len()).sum()
264    }
265}
266
267impl Default for CompressionEngine {
268    fn default() -> Self {
269        Self::new()
270    }
271}
272
273/// Helper: Determine optimal compression for payload
274pub fn choose_compression(size: usize, age_us: u64) -> CompressionType {
275    // Small payloads: don't compress (overhead not worth it)
276    if size < 512 {
277        return CompressionType::None;
278    }
279
280    // Use tier-based compression
281    let tier = StorageTier::from_age(age_us);
282    tier.compression_type()
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// Current wall-clock time in microseconds since the Unix epoch.
    fn now_us() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_micros() as u64
    }

    #[test]
    fn test_storage_tier() {
        let now = now_us();

        // Recent data -> Hot
        let tier = StorageTier::from_age(now - 3_600_000_000); // 1 hour ago
        assert_eq!(tier, StorageTier::Hot);

        // Week old -> Warm
        let tier = StorageTier::from_age(now - 604_800_000_000); // 7 days ago
        assert_eq!(tier, StorageTier::Warm);

        // Very old -> Cold
        let tier = StorageTier::from_age(now - 3_000_000_000_000); // ~35 days ago
        assert_eq!(tier, StorageTier::Cold);
    }

    #[test]
    fn test_compression_basic() {
        let mut engine = CompressionEngine::new();
        let data = b"Hello, World! This is test data.";

        let compressed = engine.compress(data, CompressionType::Lz4).unwrap();
        let decompressed = engine
            .decompress(&compressed, CompressionType::Lz4)
            .unwrap();

        assert_eq!(data, decompressed.as_slice());
    }

    #[test]
    fn test_round_trip_all_algorithms() {
        let mut engine = CompressionEngine::new();
        let data = b"Round-trip payload for every algorithm";

        for &kind in &[
            CompressionType::None,
            CompressionType::Lz4,
            CompressionType::ZstdFast,
            CompressionType::ZstdMax,
        ] {
            let compressed = engine.compress(data, kind).unwrap();
            let restored = engine.decompress(&compressed, kind).unwrap();
            assert_eq!(restored.as_slice(), &data[..], "round trip failed for {:?}", kind);
        }
    }

    #[test]
    fn test_decompress_rejects_malformed_input() {
        let engine = CompressionEngine::new();

        // Too short for the LZ4 length header.
        assert!(engine.decompress(b"abc", CompressionType::Lz4).is_err());
        // Wrong magic bytes.
        assert!(engine
            .decompress(b"NOPE1234", CompressionType::ZstdFast)
            .is_err());
        // Magic present but header missing.
        assert!(engine.decompress(b"ZSTD", CompressionType::ZstdMax).is_err());
    }

    #[test]
    fn test_compression_stats() {
        let mut engine = CompressionEngine::new();
        let data = b"Test data for compression statistics";

        engine.compress(data, CompressionType::Lz4).unwrap();

        let stats = engine.stats();
        assert!(stats.total_uncompressed > 0);
        assert!(stats.total_compressed > 0);
        assert_eq!(stats.lz4_count, 1);
    }

    #[test]
    fn test_empty_stats() {
        let stats = CompressionStats::default();
        assert_eq!(stats.compression_ratio(), 1.0);
        assert_eq!(stats.space_saved_bytes(), 0);
    }

    #[test]
    #[ignore = "dedup only caches inputs > 1 KiB that compress at least 2:1; the placeholder codecs never shrink data, so this fails deterministically until real LZ4 is wired in"]
    fn test_deduplication() {
        let mut engine = CompressionEngine::new();
        // Large and highly repetitive, so a real LZ4 codec clears both caching thresholds.
        let data = b"Repeated system prompt. ".repeat(100);

        // Compress twice with same data
        engine
            .compress_with_dedup(&data, CompressionType::Lz4)
            .unwrap();
        engine
            .compress_with_dedup(&data, CompressionType::Lz4)
            .unwrap();

        // Second call should be dedup hit
        assert!(engine.stats().dedup_hits > 0);
    }

    #[test]
    fn test_choose_compression() {
        let now = now_us();

        // Small payload -> None (512 is the first size that gets compressed)
        assert_eq!(choose_compression(100, now), CompressionType::None);
        assert_eq!(choose_compression(511, now), CompressionType::None);
        assert_eq!(choose_compression(512, now), CompressionType::Lz4);

        // Recent large payload -> LZ4
        assert_eq!(choose_compression(10000, now), CompressionType::Lz4);

        // Week-old large payload -> ZstdFast
        let week_old = now - 604_800_000_000; // 7 days ago
        assert_eq!(choose_compression(10000, week_old), CompressionType::ZstdFast);

        // Old large payload -> ZstdMax
        let old = now - 4_000_000_000_000; // ~46 days ago
        assert_eq!(choose_compression(10000, old), CompressionType::ZstdMax);
    }
}