// ruvector_gnn/compress.rs

1//! Tensor compression with adaptive level selection
2//!
3//! This module provides multi-level tensor compression based on access frequency:
4//! - Hot data (f > 0.8): Full precision
5//! - Warm data (f > 0.4): Half precision
6//! - Cool data (f > 0.1): 8-bit product quantization
7//! - Cold data (f > 0.01): 4-bit product quantization
8//! - Archive (f <= 0.01): Binary quantization
9
10use crate::error::{GnnError, Result};
11use serde::{Deserialize, Serialize};
12
/// Compression level with associated parameters
///
/// Levels run from lossless (`None`) to most aggressive (`Binary`).
/// `TensorCompress::select_level` maps access frequency onto these levels.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum CompressionLevel {
    /// Full precision - no compression
    None,

    /// Half precision with scale factor
    ///
    /// Each value is multiplied by `scale` before encoding; decompression
    /// divides by it, so `scale` must be non-zero.
    Half { scale: f32 },

    /// Product quantization with 8-bit codes
    ///
    /// `subvectors` must evenly divide the embedding dimension;
    /// `centroids` is the codebook size per subvector.
    PQ8 { subvectors: u8, centroids: u8 },

    /// Product quantization with 4-bit codes and outlier handling
    ///
    /// Values further than `outlier_threshold` standard deviations from
    /// the mean are stored exactly instead of being quantized.
    PQ4 {
        subvectors: u8,
        outlier_threshold: f32,
    },

    /// Binary quantization with threshold
    ///
    /// Each value becomes one bit: set when the value exceeds `threshold`.
    Binary { threshold: f32 },
}
34
/// Compressed tensor data
///
/// Produced by `TensorCompress::compress`; each variant mirrors one
/// [`CompressionLevel`] and carries everything needed for decompression.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CompressedTensor {
    /// Uncompressed full precision data
    Full { data: Vec<f32> },

    /// Half precision data
    Half {
        /// One 16-bit encoded value per input element.
        data: Vec<u16>,
        /// Scale applied before encoding; decompression divides by it.
        scale: f32,
        /// Original embedding dimension.
        dim: usize,
    },

    /// 8-bit product quantization
    PQ8 {
        /// One code (centroid index) per subvector.
        codes: Vec<u8>,
        /// Per-subvector codebooks, each a flattened list of centroids
        /// of length `centroid_count * subvector_dim`.
        codebooks: Vec<Vec<f32>>,
        /// Elements per subvector (`dim / codes.len()`).
        subvector_dim: usize,
        /// Original embedding dimension.
        dim: usize,
    },

    /// 4-bit product quantization with outliers
    PQ4 {
        // One 4-bit code value (0..=15) per subvector, stored one per
        // byte — despite the 4-bit quantization the codes are NOT
        // bit-packed two-per-byte by `compress_pq4`.
        codes: Vec<u8>,
        codebooks: Vec<Vec<f32>>,
        /// Exact values preserved verbatim, as (index, value) pairs.
        outliers: Vec<(usize, f32)>,
        subvector_dim: usize,
        dim: usize,
    },

    /// Binary quantization
    Binary {
        /// One bit per element, LSB-first within each byte.
        bits: Vec<u8>,
        /// Threshold used at compression time (kept for reference; the
        /// decompressed values are fixed at +1.0 / -1.0).
        threshold: f32,
        /// Original embedding dimension.
        dim: usize,
    },
}
72
/// Tensor compressor with adaptive level selection
#[derive(Debug, Clone)]
pub struct TensorCompress {
    /// Default compression parameters
    // NOTE(review): this field is set by `new()` but never read in this
    // module — `compress` always derives the level from `access_freq`
    // via `select_level`. Confirm whether it should back a future
    // default-level API or be removed.
    default_level: CompressionLevel,
}
79
80impl Default for TensorCompress {
81    fn default() -> Self {
82        Self::new()
83    }
84}
85
86impl TensorCompress {
87    /// Create a new tensor compressor with default settings
88    pub fn new() -> Self {
89        Self {
90            default_level: CompressionLevel::None,
91        }
92    }
93
94    /// Compress an embedding based on access frequency
95    ///
96    /// # Arguments
97    /// * `embedding` - The input embedding vector
98    /// * `access_freq` - Access frequency in range [0.0, 1.0]
99    ///
100    /// # Returns
101    /// Compressed tensor using adaptive compression level
102    pub fn compress(&self, embedding: &[f32], access_freq: f32) -> Result<CompressedTensor> {
103        if embedding.is_empty() {
104            return Err(GnnError::InvalidInput("Empty embedding vector".to_string()));
105        }
106
107        let level = self.select_level(access_freq);
108        self.compress_with_level(embedding, &level)
109    }
110
111    /// Compress with explicit compression level
112    pub fn compress_with_level(
113        &self,
114        embedding: &[f32],
115        level: &CompressionLevel,
116    ) -> Result<CompressedTensor> {
117        match level {
118            CompressionLevel::None => self.compress_none(embedding),
119            CompressionLevel::Half { scale } => self.compress_half(embedding, *scale),
120            CompressionLevel::PQ8 {
121                subvectors,
122                centroids,
123            } => self.compress_pq8(embedding, *subvectors, *centroids),
124            CompressionLevel::PQ4 {
125                subvectors,
126                outlier_threshold,
127            } => self.compress_pq4(embedding, *subvectors, *outlier_threshold),
128            CompressionLevel::Binary { threshold } => self.compress_binary(embedding, *threshold),
129        }
130    }
131
132    /// Decompress a compressed tensor
133    pub fn decompress(&self, compressed: &CompressedTensor) -> Result<Vec<f32>> {
134        match compressed {
135            CompressedTensor::Full { data } => Ok(data.clone()),
136            CompressedTensor::Half { data, scale, dim } => self.decompress_half(data, *scale, *dim),
137            CompressedTensor::PQ8 {
138                codes,
139                codebooks,
140                subvector_dim,
141                dim,
142            } => self.decompress_pq8(codes, codebooks, *subvector_dim, *dim),
143            CompressedTensor::PQ4 {
144                codes,
145                codebooks,
146                outliers,
147                subvector_dim,
148                dim,
149            } => self.decompress_pq4(codes, codebooks, outliers, *subvector_dim, *dim),
150            CompressedTensor::Binary {
151                bits,
152                threshold,
153                dim,
154            } => self.decompress_binary(bits, *threshold, *dim),
155        }
156    }
157
158    /// Select compression level based on access frequency
159    ///
160    /// Thresholds:
161    /// - f > 0.8: None (hot data)
162    /// - f > 0.4: Half (warm data)
163    /// - f > 0.1: PQ8 (cool data)
164    /// - f > 0.01: PQ4 (cold data)
165    /// - f <= 0.01: Binary (archive)
166    fn select_level(&self, access_freq: f32) -> CompressionLevel {
167        if access_freq > 0.8 {
168            CompressionLevel::None
169        } else if access_freq > 0.4 {
170            CompressionLevel::Half { scale: 1.0 }
171        } else if access_freq > 0.1 {
172            CompressionLevel::PQ8 {
173                subvectors: 8,
174                centroids: 16,
175            }
176        } else if access_freq > 0.01 {
177            CompressionLevel::PQ4 {
178                subvectors: 8,
179                outlier_threshold: 3.0,
180            }
181        } else {
182            CompressionLevel::Binary { threshold: 0.0 }
183        }
184    }
185
186    // === Compression implementations ===
187
188    fn compress_none(&self, embedding: &[f32]) -> Result<CompressedTensor> {
189        Ok(CompressedTensor::Full {
190            data: embedding.to_vec(),
191        })
192    }
193
194    fn compress_half(&self, embedding: &[f32], scale: f32) -> Result<CompressedTensor> {
195        // Simple half precision: scale and convert to 16-bit
196        let data: Vec<u16> = embedding
197            .iter()
198            .map(|&x| {
199                let scaled = x * scale;
200                let clamped = scaled.clamp(-65504.0, 65504.0);
201                // Convert to half precision representation
202                f32_to_f16_bits(clamped)
203            })
204            .collect();
205
206        Ok(CompressedTensor::Half {
207            data,
208            scale,
209            dim: embedding.len(),
210        })
211    }
212
213    fn compress_pq8(
214        &self,
215        embedding: &[f32],
216        subvectors: u8,
217        centroids: u8,
218    ) -> Result<CompressedTensor> {
219        let dim = embedding.len();
220        let subvectors = subvectors as usize;
221
222        if dim % subvectors != 0 {
223            return Err(GnnError::InvalidInput(format!(
224                "Dimension {} not divisible by subvectors {}",
225                dim, subvectors
226            )));
227        }
228
229        let subvector_dim = dim / subvectors;
230        let mut codes = Vec::with_capacity(subvectors);
231        let mut codebooks = Vec::with_capacity(subvectors);
232
233        // For each subvector, create a codebook and quantize
234        for i in 0..subvectors {
235            let start = i * subvector_dim;
236            let end = start + subvector_dim;
237            let subvector = &embedding[start..end];
238
239            // Simple k-means clustering (k=centroids)
240            let (codebook, code) = self.quantize_subvector(subvector, centroids as usize);
241            codes.push(code);
242            codebooks.push(codebook);
243        }
244
245        Ok(CompressedTensor::PQ8 {
246            codes,
247            codebooks,
248            subvector_dim,
249            dim,
250        })
251    }
252
253    fn compress_pq4(
254        &self,
255        embedding: &[f32],
256        subvectors: u8,
257        outlier_threshold: f32,
258    ) -> Result<CompressedTensor> {
259        let dim = embedding.len();
260        let subvectors = subvectors as usize;
261
262        if dim % subvectors != 0 {
263            return Err(GnnError::InvalidInput(format!(
264                "Dimension {} not divisible by subvectors {}",
265                dim, subvectors
266            )));
267        }
268
269        let subvector_dim = dim / subvectors;
270        let mut codes = Vec::with_capacity(subvectors);
271        let mut codebooks = Vec::with_capacity(subvectors);
272        let mut outliers = Vec::new();
273
274        // Detect outliers based on magnitude
275        let mean = embedding.iter().sum::<f32>() / dim as f32;
276        let std_dev =
277            (embedding.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / dim as f32).sqrt();
278
279        // For each subvector
280        for i in 0..subvectors {
281            let start = i * subvector_dim;
282            let end = start + subvector_dim;
283            let subvector = &embedding[start..end];
284
285            // Extract outliers
286            let mut cleaned_subvector = subvector.to_vec();
287            for (j, &val) in subvector.iter().enumerate() {
288                if (val - mean).abs() > outlier_threshold * std_dev {
289                    outliers.push((start + j, val));
290                    cleaned_subvector[j] = mean; // Replace with mean
291                }
292            }
293
294            // Quantize to 4-bit (16 centroids)
295            let (codebook, code) = self.quantize_subvector(&cleaned_subvector, 16);
296            codes.push(code);
297            codebooks.push(codebook);
298        }
299
300        Ok(CompressedTensor::PQ4 {
301            codes,
302            codebooks,
303            outliers,
304            subvector_dim,
305            dim,
306        })
307    }
308
309    fn compress_binary(&self, embedding: &[f32], threshold: f32) -> Result<CompressedTensor> {
310        let dim = embedding.len();
311        let num_bytes = (dim + 7) / 8;
312        let mut bits = vec![0u8; num_bytes];
313
314        for (i, &val) in embedding.iter().enumerate() {
315            if val > threshold {
316                let byte_idx = i / 8;
317                let bit_idx = i % 8;
318                bits[byte_idx] |= 1 << bit_idx;
319            }
320        }
321
322        Ok(CompressedTensor::Binary {
323            bits,
324            threshold,
325            dim,
326        })
327    }
328
329    // === Decompression implementations ===
330
331    fn decompress_half(&self, data: &[u16], scale: f32, dim: usize) -> Result<Vec<f32>> {
332        if data.len() != dim {
333            return Err(GnnError::InvalidInput(format!(
334                "Dimension mismatch: expected {}, got {}",
335                dim,
336                data.len()
337            )));
338        }
339
340        Ok(data
341            .iter()
342            .map(|&bits| f16_bits_to_f32(bits) / scale)
343            .collect())
344    }
345
346    fn decompress_pq8(
347        &self,
348        codes: &[u8],
349        codebooks: &[Vec<f32>],
350        subvector_dim: usize,
351        dim: usize,
352    ) -> Result<Vec<f32>> {
353        let subvectors = codes.len();
354        let expected_dim = subvectors * subvector_dim;
355
356        if expected_dim != dim {
357            return Err(GnnError::InvalidInput(format!(
358                "Dimension mismatch: expected {}, got {}",
359                dim, expected_dim
360            )));
361        }
362
363        let mut result = Vec::with_capacity(dim);
364
365        for (code, codebook) in codes.iter().zip(codebooks.iter()) {
366            let centroid_idx = *code as usize;
367            if centroid_idx >= codebook.len() / subvector_dim {
368                return Err(GnnError::InvalidInput(format!(
369                    "Invalid centroid index: {}",
370                    centroid_idx
371                )));
372            }
373
374            let start = centroid_idx * subvector_dim;
375            let end = start + subvector_dim;
376            result.extend_from_slice(&codebook[start..end]);
377        }
378
379        Ok(result)
380    }
381
382    fn decompress_pq4(
383        &self,
384        codes: &[u8],
385        codebooks: &[Vec<f32>],
386        outliers: &[(usize, f32)],
387        subvector_dim: usize,
388        dim: usize,
389    ) -> Result<Vec<f32>> {
390        // First decompress using PQ8 logic
391        let mut result = self.decompress_pq8(codes, codebooks, subvector_dim, dim)?;
392
393        // Restore outliers
394        for &(idx, val) in outliers {
395            if idx < result.len() {
396                result[idx] = val;
397            }
398        }
399
400        Ok(result)
401    }
402
403    fn decompress_binary(&self, bits: &[u8], _threshold: f32, dim: usize) -> Result<Vec<f32>> {
404        let expected_bytes = (dim + 7) / 8;
405        if bits.len() != expected_bytes {
406            return Err(GnnError::InvalidInput(format!(
407                "Dimension mismatch: expected {} bytes, got {}",
408                expected_bytes,
409                bits.len()
410            )));
411        }
412
413        let mut result = Vec::with_capacity(dim);
414
415        for i in 0..dim {
416            let byte_idx = i / 8;
417            let bit_idx = i % 8;
418            let is_set = (bits[byte_idx] & (1 << bit_idx)) != 0;
419            result.push(if is_set { 1.0 } else { -1.0 });
420        }
421
422        Ok(result)
423    }
424
425    // === Helper methods ===
426
427    /// Simple quantization using k-means-like approach
428    fn quantize_subvector(&self, subvector: &[f32], k: usize) -> (Vec<f32>, u8) {
429        let dim = subvector.len();
430
431        // Initialize centroids using simple range-based approach
432        let min_val = subvector.iter().cloned().fold(f32::INFINITY, f32::min);
433        let max_val = subvector.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
434        let range = max_val - min_val;
435
436        if range < 1e-6 {
437            // All values are essentially the same
438            let codebook = vec![min_val; dim * k];
439            return (codebook, 0);
440        }
441
442        // Create k centroids evenly spaced across the range
443        let centroids: Vec<Vec<f32>> = (0..k)
444            .map(|i| {
445                let offset = min_val + (i as f32 / k as f32) * range;
446                vec![offset; dim]
447            })
448            .collect();
449
450        // Find nearest centroid for this subvector
451        let code = self.nearest_centroid(subvector, &centroids);
452
453        // Flatten codebook
454        let codebook: Vec<f32> = centroids.into_iter().flatten().collect();
455
456        (codebook, code as u8)
457    }
458
459    fn nearest_centroid(&self, subvector: &[f32], centroids: &[Vec<f32>]) -> usize {
460        centroids
461            .iter()
462            .enumerate()
463            .map(|(i, centroid)| {
464                let dist: f32 = subvector
465                    .iter()
466                    .zip(centroid.iter())
467                    .map(|(a, b)| (a - b).powi(2))
468                    .sum();
469                (i, dist)
470            })
471            .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
472            .map(|(i, _)| i)
473            .unwrap_or(0)
474    }
475}
476
// === Half precision conversion helpers ===

/// Convert an `f32` to IEEE 754 binary16 (half precision) bits.
///
/// The previous fixed-point encoding saturated at ±32.768 even though
/// `compress_half` clamps to ±65504 (the binary16 maximum), silently
/// crushing values in between. This implementation covers the full
/// binary16 range: signed zero, subnormals, normals with round-half-up
/// on the discarded mantissa bits, overflow to infinity, and NaN.
fn f32_to_f16_bits(value: f32) -> u16 {
    let bits = value.to_bits();
    let sign = ((bits >> 16) & 0x8000) as u16;
    let exp32 = ((bits >> 23) & 0xFF) as i32;
    let man32 = bits & 0x007F_FFFF;

    if exp32 == 0xFF {
        // Infinity or NaN; keep NaN-ness by forcing a non-zero mantissa.
        let nan = if man32 != 0 { 0x0200 } else { 0 };
        return sign | 0x7C00 | nan;
    }

    // Rebias the exponent from f32 (bias 127) to f16 (bias 15).
    let exp16 = exp32 - 127 + 15;

    if exp16 >= 0x1F {
        // Too large for half precision: round to infinity.
        return sign | 0x7C00;
    }

    if exp16 <= 0 {
        // Result is subnormal in binary16 (or underflows entirely).
        if exp16 < -10 {
            // Below the smallest subnormal: flush to signed zero.
            return sign;
        }
        // Restore the implicit leading 1, shift into subnormal position,
        // rounding half-up; a carry out of the 10-bit field correctly
        // lands on the smallest normal number.
        let man = man32 | 0x0080_0000;
        let shift = (14 - exp16) as u32; // 14..=24
        let rounded = (man + (1 << (shift - 1))) >> shift;
        return sign | rounded as u16;
    }

    // Normal case: keep the top 10 mantissa bits with round-half-up.
    // A rounding carry propagates into the exponent field naturally
    // (and can correctly round the result up to infinity).
    let half = ((exp16 as u32) << 10) + ((man32 + 0x1000) >> 13);
    sign | half as u16
}

/// Convert IEEE 754 binary16 (half precision) bits back to `f32`.
///
/// Exact inverse of [`f32_to_f16_bits`] for every representable half
/// value (the widening conversion is lossless).
fn f16_bits_to_f32(bits: u16) -> f32 {
    let sign = ((bits & 0x8000) as u32) << 16;
    let exp = ((bits >> 10) & 0x1F) as u32;
    let man = (bits & 0x03FF) as u32;

    if exp == 0x1F {
        // Infinity or NaN: max out the f32 exponent, widen the mantissa.
        return f32::from_bits(sign | 0x7F80_0000 | (man << 13));
    }
    if exp == 0 {
        if man == 0 {
            // Signed zero.
            return f32::from_bits(sign);
        }
        // Subnormal: value = man * 2^-24, exactly representable in f32.
        let magnitude = man as f32 / 16_777_216.0;
        return if sign != 0 { -magnitude } else { magnitude };
    }
    // Normal: rebias exponent (15 -> 127) and widen mantissa (10 -> 23 bits).
    f32::from_bits(sign | ((exp + 112) << 23) | (man << 13))
}
493
#[cfg(test)]
mod tests {
    use super::*;

    // Full-precision (`None`) compression must round-trip losslessly.
    #[test]
    fn test_compress_none() {
        let compressor = TensorCompress::new();
        let embedding = vec![1.0, 2.0, 3.0, 4.0];

        let compressed = compressor.compress(&embedding, 1.0).unwrap();
        let decompressed = compressor.decompress(&compressed).unwrap();

        assert_eq!(embedding, decompressed);
    }

    // Warm-tier frequency (0.5) selects half precision; round-trip is
    // lossy but must stay within a tight tolerance for small values.
    #[test]
    fn test_compress_half() {
        let compressor = TensorCompress::new();
        let embedding = vec![1.0, 2.0, 3.0, 4.0];

        let compressed = compressor.compress(&embedding, 0.5).unwrap();
        let decompressed = compressor.decompress(&compressed).unwrap();

        // Half precision should be close but not exact
        for (a, b) in embedding.iter().zip(decompressed.iter()) {
            assert!((a - b).abs() < 0.01, "Expected {}, got {}", a, b);
        }
    }

    // Archive-tier frequency (0.005) selects binary quantization; every
    // reconstructed value collapses to +1.0 or -1.0.
    #[test]
    fn test_compress_binary() {
        let compressor = TensorCompress::new();
        let embedding = vec![1.0, -1.0, 0.5, -0.5];

        let compressed = compressor.compress(&embedding, 0.005).unwrap();
        let decompressed = compressor.decompress(&compressed).unwrap();

        // Binary should be +1 or -1
        assert_eq!(decompressed.len(), embedding.len());
        for val in &decompressed {
            assert!(*val == 1.0 || *val == -1.0);
        }
    }

    // One frequency per tier, checking each maps to the documented level.
    #[test]
    fn test_select_level() {
        let compressor = TensorCompress::new();

        // Hot data
        assert!(matches!(
            compressor.select_level(0.9),
            CompressionLevel::None
        ));

        // Warm data
        assert!(matches!(
            compressor.select_level(0.5),
            CompressionLevel::Half { .. }
        ));

        // Cool data
        assert!(matches!(
            compressor.select_level(0.2),
            CompressionLevel::PQ8 { .. }
        ));

        // Cold data
        assert!(matches!(
            compressor.select_level(0.05),
            CompressionLevel::PQ4 { .. }
        ));

        // Archive
        assert!(matches!(
            compressor.select_level(0.001),
            CompressionLevel::Binary { .. }
        ));
    }

    // An empty embedding must be rejected, not compressed.
    #[test]
    fn test_empty_embedding() {
        let compressor = TensorCompress::new();
        let result = compressor.compress(&[], 0.5);
        assert!(result.is_err());
    }

    // PQ8 round-trip preserves the dimension (reconstruction is lossy,
    // so only the length is asserted here).
    #[test]
    fn test_pq8_compression() {
        let compressor = TensorCompress::new();
        let embedding: Vec<f32> = (0..64).map(|i| i as f32 * 0.1).collect();

        let compressed = compressor.compress_pq8(&embedding, 8, 16).unwrap();
        let decompressed = compressor.decompress(&compressed).unwrap();

        assert_eq!(decompressed.len(), embedding.len());
    }

    // Every tier's round-trip must succeed and preserve the dimension.
    #[test]
    fn test_round_trip_all_levels() {
        let compressor = TensorCompress::new();
        let embedding: Vec<f32> = (0..128).map(|i| (i as f32 - 64.0) * 0.01).collect();

        // One representative frequency per compression tier.
        let access_frequencies = vec![0.9, 0.5, 0.2, 0.05, 0.001];

        for freq in access_frequencies {
            let compressed = compressor.compress(&embedding, freq).unwrap();
            let decompressed = compressor.decompress(&compressed).unwrap();
            assert_eq!(decompressed.len(), embedding.len());
        }
    }

    // Half-precision round-trip accuracy at several magnitudes.
    #[test]
    fn test_half_precision_roundtrip() {
        let compressor = TensorCompress::new();
        // Values chosen well inside the half encoder's representable range.
        let values = vec![-30.0, -1.0, 0.0, 1.0, 30.0];

        for val in values {
            let embedding = vec![val; 4];
            let compressed = compressor
                .compress_with_level(&embedding, &CompressionLevel::Half { scale: 1.0 })
                .unwrap();
            let decompressed = compressor.decompress(&compressed).unwrap();

            for (a, b) in embedding.iter().zip(decompressed.iter()) {
                let diff = (a - b).abs();
                assert!(
                    diff < 0.1,
                    "Value {} decompressed to {}, diff: {}",
                    a,
                    b,
                    diff
                );
            }
        }
    }

    // Binary quantization with an explicit threshold of 0.0.
    #[test]
    fn test_binary_threshold() {
        let compressor = TensorCompress::new();
        let embedding = vec![0.5, -0.5, 1.5, -1.5];

        let compressed = compressor
            .compress_with_level(&embedding, &CompressionLevel::Binary { threshold: 0.0 })
            .unwrap();
        let decompressed = compressor.decompress(&compressed).unwrap();

        // Values > 0 should be 1.0, values <= 0 should be -1.0
        assert_eq!(decompressed, vec![1.0, -1.0, 1.0, -1.0]);
    }

    // PQ4 must restore outlier values exactly after decompression.
    #[test]
    fn test_pq4_with_outliers() {
        let compressor = TensorCompress::new();
        // Create embedding with some outliers
        let mut embedding: Vec<f32> = (0..64).map(|i| i as f32 * 0.01).collect();
        embedding[10] = 100.0; // Outlier
        embedding[30] = -100.0; // Outlier

        let compressed = compressor
            .compress_with_level(
                &embedding,
                &CompressionLevel::PQ4 {
                    subvectors: 8,
                    outlier_threshold: 2.0,
                },
            )
            .unwrap();
        let decompressed = compressor.decompress(&compressed).unwrap();

        assert_eq!(decompressed.len(), embedding.len());
        // Outliers should be preserved
        assert_eq!(decompressed[10], 100.0);
        assert_eq!(decompressed[30], -100.0);
    }

    // PQ requires the dimension to divide evenly into subvectors.
    #[test]
    fn test_dimension_validation() {
        let compressor = TensorCompress::new();
        let embedding = vec![1.0; 10]; // Not divisible by 8

        let result = compressor.compress_pq8(&embedding, 8, 16);
        assert!(result.is_err());
    }
}