Skip to main content

ruvector_gnn/
compress.rs

1//! Tensor compression with adaptive level selection
2//!
3//! This module provides multi-level tensor compression based on access frequency:
4//! - Hot data (f > 0.8): Full precision
5//! - Warm data (f > 0.4): Half precision
6//! - Cool data (f > 0.1): 8-bit product quantization
7//! - Cold data (f > 0.01): 4-bit product quantization
8//! - Archive (f <= 0.01): Binary quantization
9
10use crate::error::{GnnError, Result};
11use serde::{Deserialize, Serialize};
12
13/// Compression level with associated parameters
14#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
15pub enum CompressionLevel {
16    /// Full precision - no compression
17    None,
18
19    /// Half precision with scale factor
20    Half { scale: f32 },
21
22    /// Product quantization with 8-bit codes
23    PQ8 { subvectors: u8, centroids: u8 },
24
25    /// Product quantization with 4-bit codes and outlier handling
26    PQ4 {
27        subvectors: u8,
28        outlier_threshold: f32,
29    },
30
31    /// Binary quantization with threshold
32    Binary { threshold: f32 },
33}
34
35/// Compressed tensor data
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub enum CompressedTensor {
38    /// Uncompressed full precision data
39    Full { data: Vec<f32> },
40
41    /// Half precision data
42    Half {
43        data: Vec<u16>,
44        scale: f32,
45        dim: usize,
46    },
47
48    /// 8-bit product quantization
49    PQ8 {
50        codes: Vec<u8>,
51        codebooks: Vec<Vec<f32>>,
52        subvector_dim: usize,
53        dim: usize,
54    },
55
56    /// 4-bit product quantization with outliers
57    PQ4 {
58        codes: Vec<u8>, // Packed 4-bit codes
59        codebooks: Vec<Vec<f32>>,
60        outliers: Vec<(usize, f32)>, // (index, value) pairs
61        subvector_dim: usize,
62        dim: usize,
63    },
64
65    /// Binary quantization
66    Binary {
67        bits: Vec<u8>,
68        threshold: f32,
69        dim: usize,
70    },
71}
72
73/// Tensor compressor with adaptive level selection
74#[derive(Debug, Clone)]
75pub struct TensorCompress {
76    /// Default compression parameters
77    #[allow(dead_code)]
78    default_level: CompressionLevel,
79}
80
81impl Default for TensorCompress {
82    fn default() -> Self {
83        Self::new()
84    }
85}
86
87impl TensorCompress {
88    /// Create a new tensor compressor with default settings
89    pub fn new() -> Self {
90        Self {
91            default_level: CompressionLevel::None,
92        }
93    }
94
95    /// Compress an embedding based on access frequency
96    ///
97    /// # Arguments
98    /// * `embedding` - The input embedding vector
99    /// * `access_freq` - Access frequency in range [0.0, 1.0]
100    ///
101    /// # Returns
102    /// Compressed tensor using adaptive compression level
103    pub fn compress(&self, embedding: &[f32], access_freq: f32) -> Result<CompressedTensor> {
104        if embedding.is_empty() {
105            return Err(GnnError::InvalidInput("Empty embedding vector".to_string()));
106        }
107
108        let level = self.select_level(access_freq);
109        self.compress_with_level(embedding, &level)
110    }
111
112    /// Compress with explicit compression level
113    pub fn compress_with_level(
114        &self,
115        embedding: &[f32],
116        level: &CompressionLevel,
117    ) -> Result<CompressedTensor> {
118        match level {
119            CompressionLevel::None => self.compress_none(embedding),
120            CompressionLevel::Half { scale } => self.compress_half(embedding, *scale),
121            CompressionLevel::PQ8 {
122                subvectors,
123                centroids,
124            } => self.compress_pq8(embedding, *subvectors, *centroids),
125            CompressionLevel::PQ4 {
126                subvectors,
127                outlier_threshold,
128            } => self.compress_pq4(embedding, *subvectors, *outlier_threshold),
129            CompressionLevel::Binary { threshold } => self.compress_binary(embedding, *threshold),
130        }
131    }
132
133    /// Decompress a compressed tensor
134    pub fn decompress(&self, compressed: &CompressedTensor) -> Result<Vec<f32>> {
135        match compressed {
136            CompressedTensor::Full { data } => Ok(data.clone()),
137            CompressedTensor::Half { data, scale, dim } => self.decompress_half(data, *scale, *dim),
138            CompressedTensor::PQ8 {
139                codes,
140                codebooks,
141                subvector_dim,
142                dim,
143            } => self.decompress_pq8(codes, codebooks, *subvector_dim, *dim),
144            CompressedTensor::PQ4 {
145                codes,
146                codebooks,
147                outliers,
148                subvector_dim,
149                dim,
150            } => self.decompress_pq4(codes, codebooks, outliers, *subvector_dim, *dim),
151            CompressedTensor::Binary {
152                bits,
153                threshold,
154                dim,
155            } => self.decompress_binary(bits, *threshold, *dim),
156        }
157    }
158
159    /// Select compression level based on access frequency
160    ///
161    /// Thresholds:
162    /// - f > 0.8: None (hot data)
163    /// - f > 0.4: Half (warm data)
164    /// - f > 0.1: PQ8 (cool data)
165    /// - f > 0.01: PQ4 (cold data)
166    /// - f <= 0.01: Binary (archive)
167    fn select_level(&self, access_freq: f32) -> CompressionLevel {
168        if access_freq > 0.8 {
169            CompressionLevel::None
170        } else if access_freq > 0.4 {
171            CompressionLevel::Half { scale: 1.0 }
172        } else if access_freq > 0.1 {
173            CompressionLevel::PQ8 {
174                subvectors: 8,
175                centroids: 16,
176            }
177        } else if access_freq > 0.01 {
178            CompressionLevel::PQ4 {
179                subvectors: 8,
180                outlier_threshold: 3.0,
181            }
182        } else {
183            CompressionLevel::Binary { threshold: 0.0 }
184        }
185    }
186
187    // === Compression implementations ===
188
189    fn compress_none(&self, embedding: &[f32]) -> Result<CompressedTensor> {
190        Ok(CompressedTensor::Full {
191            data: embedding.to_vec(),
192        })
193    }
194
195    fn compress_half(&self, embedding: &[f32], scale: f32) -> Result<CompressedTensor> {
196        // Simple half precision: scale and convert to 16-bit
197        let data: Vec<u16> = embedding
198            .iter()
199            .map(|&x| {
200                let scaled = x * scale;
201                let clamped = scaled.clamp(-65504.0, 65504.0);
202                // Convert to half precision representation
203                f32_to_f16_bits(clamped)
204            })
205            .collect();
206
207        Ok(CompressedTensor::Half {
208            data,
209            scale,
210            dim: embedding.len(),
211        })
212    }
213
214    fn compress_pq8(
215        &self,
216        embedding: &[f32],
217        subvectors: u8,
218        centroids: u8,
219    ) -> Result<CompressedTensor> {
220        let dim = embedding.len();
221        let subvectors = subvectors as usize;
222
223        if dim % subvectors != 0 {
224            return Err(GnnError::InvalidInput(format!(
225                "Dimension {} not divisible by subvectors {}",
226                dim, subvectors
227            )));
228        }
229
230        let subvector_dim = dim / subvectors;
231        let mut codes = Vec::with_capacity(subvectors);
232        let mut codebooks = Vec::with_capacity(subvectors);
233
234        // For each subvector, create a codebook and quantize
235        for i in 0..subvectors {
236            let start = i * subvector_dim;
237            let end = start + subvector_dim;
238            let subvector = &embedding[start..end];
239
240            // Simple k-means clustering (k=centroids)
241            let (codebook, code) = self.quantize_subvector(subvector, centroids as usize);
242            codes.push(code);
243            codebooks.push(codebook);
244        }
245
246        Ok(CompressedTensor::PQ8 {
247            codes,
248            codebooks,
249            subvector_dim,
250            dim,
251        })
252    }
253
254    fn compress_pq4(
255        &self,
256        embedding: &[f32],
257        subvectors: u8,
258        outlier_threshold: f32,
259    ) -> Result<CompressedTensor> {
260        let dim = embedding.len();
261        let subvectors = subvectors as usize;
262
263        if dim % subvectors != 0 {
264            return Err(GnnError::InvalidInput(format!(
265                "Dimension {} not divisible by subvectors {}",
266                dim, subvectors
267            )));
268        }
269
270        let subvector_dim = dim / subvectors;
271        let mut codes = Vec::with_capacity(subvectors);
272        let mut codebooks = Vec::with_capacity(subvectors);
273        let mut outliers = Vec::new();
274
275        // Detect outliers based on magnitude
276        let mean = embedding.iter().sum::<f32>() / dim as f32;
277        let std_dev =
278            (embedding.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / dim as f32).sqrt();
279
280        // For each subvector
281        for i in 0..subvectors {
282            let start = i * subvector_dim;
283            let end = start + subvector_dim;
284            let subvector = &embedding[start..end];
285
286            // Extract outliers
287            let mut cleaned_subvector = subvector.to_vec();
288            for (j, &val) in subvector.iter().enumerate() {
289                if (val - mean).abs() > outlier_threshold * std_dev {
290                    outliers.push((start + j, val));
291                    cleaned_subvector[j] = mean; // Replace with mean
292                }
293            }
294
295            // Quantize to 4-bit (16 centroids)
296            let (codebook, code) = self.quantize_subvector(&cleaned_subvector, 16);
297            codes.push(code);
298            codebooks.push(codebook);
299        }
300
301        Ok(CompressedTensor::PQ4 {
302            codes,
303            codebooks,
304            outliers,
305            subvector_dim,
306            dim,
307        })
308    }
309
310    fn compress_binary(&self, embedding: &[f32], threshold: f32) -> Result<CompressedTensor> {
311        let dim = embedding.len();
312        let num_bytes = dim.div_ceil(8);
313        let mut bits = vec![0u8; num_bytes];
314
315        for (i, &val) in embedding.iter().enumerate() {
316            if val > threshold {
317                let byte_idx = i / 8;
318                let bit_idx = i % 8;
319                bits[byte_idx] |= 1 << bit_idx;
320            }
321        }
322
323        Ok(CompressedTensor::Binary {
324            bits,
325            threshold,
326            dim,
327        })
328    }
329
330    // === Decompression implementations ===
331
332    fn decompress_half(&self, data: &[u16], scale: f32, dim: usize) -> Result<Vec<f32>> {
333        if data.len() != dim {
334            return Err(GnnError::InvalidInput(format!(
335                "Dimension mismatch: expected {}, got {}",
336                dim,
337                data.len()
338            )));
339        }
340
341        Ok(data
342            .iter()
343            .map(|&bits| f16_bits_to_f32(bits) / scale)
344            .collect())
345    }
346
347    fn decompress_pq8(
348        &self,
349        codes: &[u8],
350        codebooks: &[Vec<f32>],
351        subvector_dim: usize,
352        dim: usize,
353    ) -> Result<Vec<f32>> {
354        let subvectors = codes.len();
355        let expected_dim = subvectors * subvector_dim;
356
357        if expected_dim != dim {
358            return Err(GnnError::InvalidInput(format!(
359                "Dimension mismatch: expected {}, got {}",
360                dim, expected_dim
361            )));
362        }
363
364        let mut result = Vec::with_capacity(dim);
365
366        for (code, codebook) in codes.iter().zip(codebooks.iter()) {
367            let centroid_idx = *code as usize;
368            if centroid_idx >= codebook.len() / subvector_dim {
369                return Err(GnnError::InvalidInput(format!(
370                    "Invalid centroid index: {}",
371                    centroid_idx
372                )));
373            }
374
375            let start = centroid_idx * subvector_dim;
376            let end = start + subvector_dim;
377            result.extend_from_slice(&codebook[start..end]);
378        }
379
380        Ok(result)
381    }
382
383    fn decompress_pq4(
384        &self,
385        codes: &[u8],
386        codebooks: &[Vec<f32>],
387        outliers: &[(usize, f32)],
388        subvector_dim: usize,
389        dim: usize,
390    ) -> Result<Vec<f32>> {
391        // First decompress using PQ8 logic
392        let mut result = self.decompress_pq8(codes, codebooks, subvector_dim, dim)?;
393
394        // Restore outliers
395        for &(idx, val) in outliers {
396            if idx < result.len() {
397                result[idx] = val;
398            }
399        }
400
401        Ok(result)
402    }
403
404    fn decompress_binary(&self, bits: &[u8], _threshold: f32, dim: usize) -> Result<Vec<f32>> {
405        let expected_bytes = dim.div_ceil(8);
406        if bits.len() != expected_bytes {
407            return Err(GnnError::InvalidInput(format!(
408                "Dimension mismatch: expected {} bytes, got {}",
409                expected_bytes,
410                bits.len()
411            )));
412        }
413
414        let mut result = Vec::with_capacity(dim);
415
416        for i in 0..dim {
417            let byte_idx = i / 8;
418            let bit_idx = i % 8;
419            let is_set = (bits[byte_idx] & (1 << bit_idx)) != 0;
420            result.push(if is_set { 1.0 } else { -1.0 });
421        }
422
423        Ok(result)
424    }
425
426    // === Helper methods ===
427
428    /// Simple quantization using k-means-like approach
429    fn quantize_subvector(&self, subvector: &[f32], k: usize) -> (Vec<f32>, u8) {
430        let dim = subvector.len();
431
432        // Initialize centroids using simple range-based approach
433        let min_val = subvector.iter().cloned().fold(f32::INFINITY, f32::min);
434        let max_val = subvector.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
435        let range = max_val - min_val;
436
437        if range < 1e-6 {
438            // All values are essentially the same
439            let codebook = vec![min_val; dim * k];
440            return (codebook, 0);
441        }
442
443        // Create k centroids evenly spaced across the range
444        let centroids: Vec<Vec<f32>> = (0..k)
445            .map(|i| {
446                let offset = min_val + (i as f32 / k as f32) * range;
447                vec![offset; dim]
448            })
449            .collect();
450
451        // Find nearest centroid for this subvector
452        let code = self.nearest_centroid(subvector, &centroids);
453
454        // Flatten codebook
455        let codebook: Vec<f32> = centroids.into_iter().flatten().collect();
456
457        (codebook, code as u8)
458    }
459
460    fn nearest_centroid(&self, subvector: &[f32], centroids: &[Vec<f32>]) -> usize {
461        centroids
462            .iter()
463            .enumerate()
464            .map(|(i, centroid)| {
465                let dist: f32 = subvector
466                    .iter()
467                    .zip(centroid.iter())
468                    .map(|(a, b)| (a - b).powi(2))
469                    .sum();
470                (i, dist)
471            })
472            .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
473            .map(|(i, _)| i)
474            .unwrap_or(0)
475    }
476}
477
478// === Half precision conversion helpers ===
479
/// Encode an `f32` into the 16-bit format stored in `CompressedTensor::Half`.
///
/// Despite the name this is NOT IEEE 754 half precision: the value is stored
/// as signed fixed-point with a resolution of 0.001, offset by 32768 so it fits
/// in a `u16`. The representable range is [-32.768, 32.767]; values outside it
/// (including infinities) saturate, and NaN encodes the same as 0.0.
fn f32_to_f16_bits(value: f32) -> u16 {
    // Round to nearest rather than truncating toward zero: truncation biased
    // every value toward 0 and doubled the worst-case error to a full step.
    let scaled = (value * 1000.0).round().clamp(-32768.0, 32767.0);
    // `NaN as i32` is 0, so NaN lands on the 0.0 code.
    ((scaled as i32) + 32768) as u16
}
487
/// Decode a value produced by `f32_to_f16_bits` back into an `f32`.
///
/// Removes the 32768 offset and divides by 1000, giving a value in
/// [-32.768, 32.767] with a resolution of 0.001.
fn f16_bits_to_f32(bits: u16) -> f32 {
    let signed = i32::from(bits) - 32768;
    signed as f32 / 1000.0
}
494
#[cfg(test)]
mod tests {
    use super::*;

    /// Compress `embedding` at the level implied by `freq`, then decompress it.
    fn round_trip(compressor: &TensorCompress, embedding: &[f32], freq: f32) -> Vec<f32> {
        let packed = compressor.compress(embedding, freq).unwrap();
        compressor.decompress(&packed).unwrap()
    }

    /// Compress `embedding` with an explicit level, then decompress it.
    fn round_trip_with(
        compressor: &TensorCompress,
        embedding: &[f32],
        level: &CompressionLevel,
    ) -> Vec<f32> {
        let packed = compressor.compress_with_level(embedding, level).unwrap();
        compressor.decompress(&packed).unwrap()
    }

    #[test]
    fn test_compress_none() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let output = round_trip(&TensorCompress::new(), &input, 1.0);
        assert_eq!(input, output);
    }

    #[test]
    fn test_compress_half() {
        let input = [1.0f32, 2.0, 3.0, 4.0];
        let output = round_trip(&TensorCompress::new(), &input, 0.5);

        // The 16-bit encoding is lossy, so only approximate equality is expected.
        input.iter().zip(&output).for_each(|(expected, actual)| {
            assert!(
                (expected - actual).abs() < 0.01,
                "Expected {}, got {}",
                expected,
                actual
            );
        });
    }

    #[test]
    fn test_compress_binary() {
        let input = vec![1.0, -1.0, 0.5, -0.5];
        let output = round_trip(&TensorCompress::new(), &input, 0.005);

        // Archive-level data comes back as pure signs.
        assert_eq!(output.len(), input.len());
        assert!(output.iter().all(|&v| v == 1.0 || v == -1.0));
    }

    #[test]
    fn test_select_level() {
        let compressor = TensorCompress::new();

        let expectations: [(f32, fn(&CompressionLevel) -> bool); 5] = [
            // Hot data
            (0.9, |level| matches!(level, CompressionLevel::None)),
            // Warm data
            (0.5, |level| matches!(level, CompressionLevel::Half { .. })),
            // Cool data
            (0.2, |level| matches!(level, CompressionLevel::PQ8 { .. })),
            // Cold data
            (0.05, |level| matches!(level, CompressionLevel::PQ4 { .. })),
            // Archive
            (0.001, |level| matches!(level, CompressionLevel::Binary { .. })),
        ];

        for (freq, is_expected) in expectations {
            assert!(
                is_expected(&compressor.select_level(freq)),
                "unexpected level for frequency {}",
                freq
            );
        }
    }

    #[test]
    fn test_empty_embedding() {
        assert!(TensorCompress::new().compress(&[], 0.5).is_err());
    }

    #[test]
    fn test_pq8_compression() {
        let compressor = TensorCompress::new();
        let input: Vec<f32> = (0..64).map(|i| i as f32 * 0.1).collect();

        let packed = compressor.compress_pq8(&input, 8, 16).unwrap();
        let output = compressor.decompress(&packed).unwrap();

        assert_eq!(output.len(), input.len());
    }

    #[test]
    fn test_round_trip_all_levels() {
        let compressor = TensorCompress::new();
        let input: Vec<f32> = (0..128).map(|i| (i as f32 - 64.0) * 0.01).collect();

        for freq in [0.9, 0.5, 0.2, 0.05, 0.001] {
            assert_eq!(round_trip(&compressor, &input, freq).len(), input.len());
        }
    }

    #[test]
    fn test_half_precision_roundtrip() {
        let compressor = TensorCompress::new();
        let level = CompressionLevel::Half { scale: 1.0 };

        // Values stay inside the encoder's representable range (-32.768 to 32.767).
        for val in [-30.0f32, -1.0, 0.0, 1.0, 30.0] {
            let input = [val; 4];
            let output = round_trip_with(&compressor, &input, &level);

            for (a, b) in input.iter().zip(&output) {
                let diff = (a - b).abs();
                assert!(
                    diff < 0.1,
                    "Value {} decompressed to {}, diff: {}",
                    a,
                    b,
                    diff
                );
            }
        }
    }

    #[test]
    fn test_binary_threshold() {
        let output = round_trip_with(
            &TensorCompress::new(),
            &[0.5, -0.5, 1.5, -1.5],
            &CompressionLevel::Binary { threshold: 0.0 },
        );

        // Strictly above the threshold maps to 1.0, everything else to -1.0.
        assert_eq!(output, vec![1.0, -1.0, 1.0, -1.0]);
    }

    #[test]
    fn test_pq4_with_outliers() {
        let mut input: Vec<f32> = (0..64).map(|i| i as f32 * 0.01).collect();
        input[10] = 100.0;
        input[30] = -100.0;

        let level = CompressionLevel::PQ4 {
            subvectors: 8,
            outlier_threshold: 2.0,
        };
        let output = round_trip_with(&TensorCompress::new(), &input, &level);

        assert_eq!(output.len(), input.len());
        // Outliers are stored verbatim and must survive exactly.
        assert_eq!(output[10], 100.0);
        assert_eq!(output[30], -100.0);
    }

    #[test]
    fn test_dimension_validation() {
        // 10 elements cannot be split into 8 equal subvectors.
        let input = [1.0f32; 10];
        assert!(TensorCompress::new().compress_pq8(&input, 8, 16).is_err());
    }
}