1use crate::error::{GnnError, Result};
11use serde::{Deserialize, Serialize};
12
/// How strongly a single embedding should be compressed.
///
/// [`TensorCompress::compress`] picks a level from an access frequency (`None` for the most
/// frequently accessed embeddings, `Binary` for the least); [`TensorCompress::compress_with_level`]
/// takes one explicitly. Each level produces the [`CompressedTensor`] variant of the same name.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum CompressionLevel {
    /// Store the embedding verbatim.
    None,

    /// 16-bit fixed-point codes of `value * scale` (resolution 0.001, saturating near
    /// ±32.767 after scaling); `scale` is divided back out on decompression.
    Half { scale: f32 },

    /// Product quantization: split the embedding into `subvectors` equal chunks (the length
    /// must be divisible by `subvectors`) and store one byte per chunk selecting one of
    /// `centroids` codebook entries.
    PQ8 { subvectors: u8, centroids: u8 },

    /// Product quantization with 16 codebook entries per chunk. Values whose distance from
    /// the embedding's mean exceeds `outlier_threshold` standard deviations are stored
    /// exactly on the side and restored after decoding.
    PQ4 {
        subvectors: u8,
        outlier_threshold: f32,
    },

    /// One bit per dimension, set when the value is strictly greater than `threshold`.
    /// Decodes to `+1.0` / `-1.0`.
    Binary { threshold: f32 },
}
34
/// The compressed form of one embedding, produced by [`TensorCompress::compress`] or
/// [`TensorCompress::compress_with_level`] and reversed by [`TensorCompress::decompress`].
///
/// Each variant corresponds to the [`CompressionLevel`] with the same name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CompressedTensor {
    /// Exact copy of the embedding.
    Full { data: Vec<f32> },

    /// Fixed-point encoded embedding.
    Half {
        /// One 16-bit code per dimension, encoding `value * scale`.
        data: Vec<u16>,
        /// Factor applied before encoding; decoded values are divided by it.
        scale: f32,
        /// Length of the original embedding (must equal `data.len()`).
        dim: usize,
    },

    /// Product-quantized embedding.
    PQ8 {
        /// One codebook index per subvector.
        codes: Vec<u8>,
        /// One flattened codebook per subvector: entry `c` occupies
        /// `[c * subvector_dim, (c + 1) * subvector_dim)`.
        /// NOTE(review): each codebook holds `centroids * subvector_dim` floats, so this
        /// encoding is currently larger than `Full` for the same embedding.
        codebooks: Vec<Vec<f32>>,
        /// Number of dimensions covered by each subvector.
        subvector_dim: usize,
        /// Length of the original embedding (`codes.len() * subvector_dim`).
        dim: usize,
    },

    /// Product-quantized embedding with exactly stored outliers.
    PQ4 {
        /// One codebook index per subvector (16 possible entries).
        codes: Vec<u8>,
        /// Same layout as [`CompressedTensor::PQ8`]'s `codebooks`.
        codebooks: Vec<Vec<f32>>,
        /// `(dimension index, exact value)` pairs written over the decoded result.
        outliers: Vec<(usize, f32)>,
        /// Number of dimensions covered by each subvector.
        subvector_dim: usize,
        /// Length of the original embedding.
        dim: usize,
    },

    /// Sign-like binary embedding.
    Binary {
        /// Bit `i % 8` of byte `i / 8` is set when dimension `i` exceeded `threshold`.
        bits: Vec<u8>,
        /// Threshold used during compression; not consulted when decompressing.
        threshold: f32,
        /// Number of encoded dimensions (`bits.len() == dim.div_ceil(8)`).
        dim: usize,
    },
}
72
/// Compresses embeddings into [`CompressedTensor`]s and reconstructs them again.
///
/// All operations take `&self`; the only stored field is `default_level`.
#[derive(Debug, Clone)]
pub struct TensorCompress {
    // Set by `new()` but never read: `compress` chooses the level from the access
    // frequency instead. Kept (hence `allow(dead_code)`) presumably for future use — confirm.
    #[allow(dead_code)]
    default_level: CompressionLevel,
}
80
81impl Default for TensorCompress {
82 fn default() -> Self {
83 Self::new()
84 }
85}
86
87impl TensorCompress {
88 pub fn new() -> Self {
90 Self {
91 default_level: CompressionLevel::None,
92 }
93 }
94
95 pub fn compress(&self, embedding: &[f32], access_freq: f32) -> Result<CompressedTensor> {
104 if embedding.is_empty() {
105 return Err(GnnError::InvalidInput("Empty embedding vector".to_string()));
106 }
107
108 let level = self.select_level(access_freq);
109 self.compress_with_level(embedding, &level)
110 }
111
112 pub fn compress_with_level(
114 &self,
115 embedding: &[f32],
116 level: &CompressionLevel,
117 ) -> Result<CompressedTensor> {
118 match level {
119 CompressionLevel::None => self.compress_none(embedding),
120 CompressionLevel::Half { scale } => self.compress_half(embedding, *scale),
121 CompressionLevel::PQ8 {
122 subvectors,
123 centroids,
124 } => self.compress_pq8(embedding, *subvectors, *centroids),
125 CompressionLevel::PQ4 {
126 subvectors,
127 outlier_threshold,
128 } => self.compress_pq4(embedding, *subvectors, *outlier_threshold),
129 CompressionLevel::Binary { threshold } => self.compress_binary(embedding, *threshold),
130 }
131 }
132
133 pub fn decompress(&self, compressed: &CompressedTensor) -> Result<Vec<f32>> {
135 match compressed {
136 CompressedTensor::Full { data } => Ok(data.clone()),
137 CompressedTensor::Half { data, scale, dim } => self.decompress_half(data, *scale, *dim),
138 CompressedTensor::PQ8 {
139 codes,
140 codebooks,
141 subvector_dim,
142 dim,
143 } => self.decompress_pq8(codes, codebooks, *subvector_dim, *dim),
144 CompressedTensor::PQ4 {
145 codes,
146 codebooks,
147 outliers,
148 subvector_dim,
149 dim,
150 } => self.decompress_pq4(codes, codebooks, outliers, *subvector_dim, *dim),
151 CompressedTensor::Binary {
152 bits,
153 threshold,
154 dim,
155 } => self.decompress_binary(bits, *threshold, *dim),
156 }
157 }
158
159 fn select_level(&self, access_freq: f32) -> CompressionLevel {
168 if access_freq > 0.8 {
169 CompressionLevel::None
170 } else if access_freq > 0.4 {
171 CompressionLevel::Half { scale: 1.0 }
172 } else if access_freq > 0.1 {
173 CompressionLevel::PQ8 {
174 subvectors: 8,
175 centroids: 16,
176 }
177 } else if access_freq > 0.01 {
178 CompressionLevel::PQ4 {
179 subvectors: 8,
180 outlier_threshold: 3.0,
181 }
182 } else {
183 CompressionLevel::Binary { threshold: 0.0 }
184 }
185 }
186
187 fn compress_none(&self, embedding: &[f32]) -> Result<CompressedTensor> {
190 Ok(CompressedTensor::Full {
191 data: embedding.to_vec(),
192 })
193 }
194
195 fn compress_half(&self, embedding: &[f32], scale: f32) -> Result<CompressedTensor> {
196 let data: Vec<u16> = embedding
198 .iter()
199 .map(|&x| {
200 let scaled = x * scale;
201 let clamped = scaled.clamp(-65504.0, 65504.0);
202 f32_to_f16_bits(clamped)
204 })
205 .collect();
206
207 Ok(CompressedTensor::Half {
208 data,
209 scale,
210 dim: embedding.len(),
211 })
212 }
213
214 fn compress_pq8(
215 &self,
216 embedding: &[f32],
217 subvectors: u8,
218 centroids: u8,
219 ) -> Result<CompressedTensor> {
220 let dim = embedding.len();
221 let subvectors = subvectors as usize;
222
223 if dim % subvectors != 0 {
224 return Err(GnnError::InvalidInput(format!(
225 "Dimension {} not divisible by subvectors {}",
226 dim, subvectors
227 )));
228 }
229
230 let subvector_dim = dim / subvectors;
231 let mut codes = Vec::with_capacity(subvectors);
232 let mut codebooks = Vec::with_capacity(subvectors);
233
234 for i in 0..subvectors {
236 let start = i * subvector_dim;
237 let end = start + subvector_dim;
238 let subvector = &embedding[start..end];
239
240 let (codebook, code) = self.quantize_subvector(subvector, centroids as usize);
242 codes.push(code);
243 codebooks.push(codebook);
244 }
245
246 Ok(CompressedTensor::PQ8 {
247 codes,
248 codebooks,
249 subvector_dim,
250 dim,
251 })
252 }
253
254 fn compress_pq4(
255 &self,
256 embedding: &[f32],
257 subvectors: u8,
258 outlier_threshold: f32,
259 ) -> Result<CompressedTensor> {
260 let dim = embedding.len();
261 let subvectors = subvectors as usize;
262
263 if dim % subvectors != 0 {
264 return Err(GnnError::InvalidInput(format!(
265 "Dimension {} not divisible by subvectors {}",
266 dim, subvectors
267 )));
268 }
269
270 let subvector_dim = dim / subvectors;
271 let mut codes = Vec::with_capacity(subvectors);
272 let mut codebooks = Vec::with_capacity(subvectors);
273 let mut outliers = Vec::new();
274
275 let mean = embedding.iter().sum::<f32>() / dim as f32;
277 let std_dev =
278 (embedding.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / dim as f32).sqrt();
279
280 for i in 0..subvectors {
282 let start = i * subvector_dim;
283 let end = start + subvector_dim;
284 let subvector = &embedding[start..end];
285
286 let mut cleaned_subvector = subvector.to_vec();
288 for (j, &val) in subvector.iter().enumerate() {
289 if (val - mean).abs() > outlier_threshold * std_dev {
290 outliers.push((start + j, val));
291 cleaned_subvector[j] = mean; }
293 }
294
295 let (codebook, code) = self.quantize_subvector(&cleaned_subvector, 16);
297 codes.push(code);
298 codebooks.push(codebook);
299 }
300
301 Ok(CompressedTensor::PQ4 {
302 codes,
303 codebooks,
304 outliers,
305 subvector_dim,
306 dim,
307 })
308 }
309
310 fn compress_binary(&self, embedding: &[f32], threshold: f32) -> Result<CompressedTensor> {
311 let dim = embedding.len();
312 let num_bytes = dim.div_ceil(8);
313 let mut bits = vec![0u8; num_bytes];
314
315 for (i, &val) in embedding.iter().enumerate() {
316 if val > threshold {
317 let byte_idx = i / 8;
318 let bit_idx = i % 8;
319 bits[byte_idx] |= 1 << bit_idx;
320 }
321 }
322
323 Ok(CompressedTensor::Binary {
324 bits,
325 threshold,
326 dim,
327 })
328 }
329
330 fn decompress_half(&self, data: &[u16], scale: f32, dim: usize) -> Result<Vec<f32>> {
333 if data.len() != dim {
334 return Err(GnnError::InvalidInput(format!(
335 "Dimension mismatch: expected {}, got {}",
336 dim,
337 data.len()
338 )));
339 }
340
341 Ok(data
342 .iter()
343 .map(|&bits| f16_bits_to_f32(bits) / scale)
344 .collect())
345 }
346
347 fn decompress_pq8(
348 &self,
349 codes: &[u8],
350 codebooks: &[Vec<f32>],
351 subvector_dim: usize,
352 dim: usize,
353 ) -> Result<Vec<f32>> {
354 let subvectors = codes.len();
355 let expected_dim = subvectors * subvector_dim;
356
357 if expected_dim != dim {
358 return Err(GnnError::InvalidInput(format!(
359 "Dimension mismatch: expected {}, got {}",
360 dim, expected_dim
361 )));
362 }
363
364 let mut result = Vec::with_capacity(dim);
365
366 for (code, codebook) in codes.iter().zip(codebooks.iter()) {
367 let centroid_idx = *code as usize;
368 if centroid_idx >= codebook.len() / subvector_dim {
369 return Err(GnnError::InvalidInput(format!(
370 "Invalid centroid index: {}",
371 centroid_idx
372 )));
373 }
374
375 let start = centroid_idx * subvector_dim;
376 let end = start + subvector_dim;
377 result.extend_from_slice(&codebook[start..end]);
378 }
379
380 Ok(result)
381 }
382
383 fn decompress_pq4(
384 &self,
385 codes: &[u8],
386 codebooks: &[Vec<f32>],
387 outliers: &[(usize, f32)],
388 subvector_dim: usize,
389 dim: usize,
390 ) -> Result<Vec<f32>> {
391 let mut result = self.decompress_pq8(codes, codebooks, subvector_dim, dim)?;
393
394 for &(idx, val) in outliers {
396 if idx < result.len() {
397 result[idx] = val;
398 }
399 }
400
401 Ok(result)
402 }
403
404 fn decompress_binary(&self, bits: &[u8], _threshold: f32, dim: usize) -> Result<Vec<f32>> {
405 let expected_bytes = dim.div_ceil(8);
406 if bits.len() != expected_bytes {
407 return Err(GnnError::InvalidInput(format!(
408 "Dimension mismatch: expected {} bytes, got {}",
409 expected_bytes,
410 bits.len()
411 )));
412 }
413
414 let mut result = Vec::with_capacity(dim);
415
416 for i in 0..dim {
417 let byte_idx = i / 8;
418 let bit_idx = i % 8;
419 let is_set = (bits[byte_idx] & (1 << bit_idx)) != 0;
420 result.push(if is_set { 1.0 } else { -1.0 });
421 }
422
423 Ok(result)
424 }
425
426 fn quantize_subvector(&self, subvector: &[f32], k: usize) -> (Vec<f32>, u8) {
430 let dim = subvector.len();
431
432 let min_val = subvector.iter().cloned().fold(f32::INFINITY, f32::min);
434 let max_val = subvector.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
435 let range = max_val - min_val;
436
437 if range < 1e-6 {
438 let codebook = vec![min_val; dim * k];
440 return (codebook, 0);
441 }
442
443 let centroids: Vec<Vec<f32>> = (0..k)
445 .map(|i| {
446 let offset = min_val + (i as f32 / k as f32) * range;
447 vec![offset; dim]
448 })
449 .collect();
450
451 let code = self.nearest_centroid(subvector, ¢roids);
453
454 let codebook: Vec<f32> = centroids.into_iter().flatten().collect();
456
457 (codebook, code as u8)
458 }
459
460 fn nearest_centroid(&self, subvector: &[f32], centroids: &[Vec<f32>]) -> usize {
461 centroids
462 .iter()
463 .enumerate()
464 .map(|(i, centroid)| {
465 let dist: f32 = subvector
466 .iter()
467 .zip(centroid.iter())
468 .map(|(a, b)| (a - b).powi(2))
469 .sum();
470 (i, dist)
471 })
472 .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
473 .map(|(i, _)| i)
474 .unwrap_or(0)
475 }
476}
477
/// Encodes `value` as a 16-bit fixed-point code.
///
/// Despite the name, this is *not* IEEE 754 binary16: the value is multiplied by 1000,
/// rounded to the nearest integer, saturated to the `i16` range and offset by 32768 so it
/// fits in a `u16`. The representable range is therefore `[-32.768, 32.767]` with a
/// resolution of `0.001`. `NaN` encodes as `0.0` (code 32768).
///
/// Rounding to nearest (instead of truncating toward zero) keeps the worst-case
/// quantization error at `0.0005`. The decoder is unaffected, so existing codes still decode.
fn f32_to_f16_bits(value: f32) -> u16 {
    let fixed = (value * 1000.0).round().clamp(-32768.0, 32767.0);
    // `fixed` is an integral value within the i16 range, so these casts are exact.
    (fixed as i32 + 32768) as u16
}
487
/// Decodes a 16-bit fixed-point code produced by [`f32_to_f16_bits`].
///
/// Removes the 32768 offset and divides by 1000, yielding values in `[-32.768, 32.767]`
/// in steps of `0.001`.
fn f16_bits_to_f32(bits: u16) -> f32 {
    const OFFSET: i32 = 32768;
    const SCALE: f32 = 1000.0;

    let signed = i32::from(bits) - OFFSET;
    signed as f32 / SCALE
}
494
#[cfg(test)]
mod tests {
    use super::*;

    /// Compresses at the level chosen for `access_freq`, then decompresses.
    fn round_trip(compressor: &TensorCompress, embedding: &[f32], access_freq: f32) -> Vec<f32> {
        let packed = compressor.compress(embedding, access_freq).unwrap();
        compressor.decompress(&packed).unwrap()
    }

    /// Compresses at an explicit `level`, then decompresses.
    fn round_trip_at(
        compressor: &TensorCompress,
        embedding: &[f32],
        level: &CompressionLevel,
    ) -> Vec<f32> {
        let packed = compressor.compress_with_level(embedding, level).unwrap();
        compressor.decompress(&packed).unwrap()
    }

    #[test]
    fn test_compress_none() {
        let original: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let restored = round_trip(&TensorCompress::new(), &original, 1.0);
        assert_eq!(original, restored);
    }

    #[test]
    fn test_compress_half() {
        let original: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let restored = round_trip(&TensorCompress::new(), &original, 0.5);

        original.iter().zip(&restored).for_each(|(a, b)| {
            assert!((a - b).abs() < 0.01, "Expected {}, got {}", a, b);
        });
    }

    #[test]
    fn test_compress_binary() {
        let original: Vec<f32> = vec![1.0, -1.0, 0.5, -0.5];
        let restored = round_trip(&TensorCompress::new(), &original, 0.005);

        assert_eq!(restored.len(), original.len());
        assert!(restored.iter().all(|&v| v == 1.0 || v == -1.0));
    }

    #[test]
    fn test_select_level() {
        let compressor = TensorCompress::new();
        let tier_of = |freq: f32| match compressor.select_level(freq) {
            CompressionLevel::None => "none",
            CompressionLevel::Half { .. } => "half",
            CompressionLevel::PQ8 { .. } => "pq8",
            CompressionLevel::PQ4 { .. } => "pq4",
            CompressionLevel::Binary { .. } => "binary",
        };

        assert_eq!(tier_of(0.9), "none");
        assert_eq!(tier_of(0.5), "half");
        assert_eq!(tier_of(0.2), "pq8");
        assert_eq!(tier_of(0.05), "pq4");
        assert_eq!(tier_of(0.001), "binary");
    }

    #[test]
    fn test_empty_embedding() {
        assert!(TensorCompress::new().compress(&[], 0.5).is_err());
    }

    #[test]
    fn test_pq8_compression() {
        let compressor = TensorCompress::new();
        let original: Vec<f32> = (0..64).map(|i| i as f32 * 0.1).collect();

        let packed = compressor.compress_pq8(&original, 8, 16).unwrap();
        let restored = compressor.decompress(&packed).unwrap();

        assert_eq!(restored.len(), original.len());
    }

    #[test]
    fn test_round_trip_all_levels() {
        let compressor = TensorCompress::new();
        let original: Vec<f32> = (0..128).map(|i| (i as f32 - 64.0) * 0.01).collect();

        for freq in [0.9, 0.5, 0.2, 0.05, 0.001] {
            let restored = round_trip(&compressor, &original, freq);
            assert_eq!(restored.len(), original.len());
        }
    }

    #[test]
    fn test_half_precision_roundtrip() {
        let compressor = TensorCompress::new();
        let level = CompressionLevel::Half { scale: 1.0 };

        for val in [-30.0f32, -1.0, 0.0, 1.0, 30.0] {
            let original = vec![val; 4];
            let restored = round_trip_at(&compressor, &original, &level);

            for (a, b) in original.iter().zip(&restored) {
                let diff = (a - b).abs();
                assert!(
                    diff < 0.1,
                    "Value {} decompressed to {}, diff: {}",
                    a,
                    b,
                    diff
                );
            }
        }
    }

    #[test]
    fn test_binary_threshold() {
        let restored = round_trip_at(
            &TensorCompress::new(),
            &[0.5, -0.5, 1.5, -1.5],
            &CompressionLevel::Binary { threshold: 0.0 },
        );
        assert_eq!(restored, vec![1.0, -1.0, 1.0, -1.0]);
    }

    #[test]
    fn test_pq4_with_outliers() {
        let mut original: Vec<f32> = (0..64).map(|i| i as f32 * 0.01).collect();
        original[10] = 100.0;
        original[30] = -100.0;

        let level = CompressionLevel::PQ4 {
            subvectors: 8,
            outlier_threshold: 2.0,
        };
        let restored = round_trip_at(&TensorCompress::new(), &original, &level);

        assert_eq!(restored.len(), original.len());
        assert_eq!(restored[10], 100.0);
        assert_eq!(restored[30], -100.0);
    }

    #[test]
    fn test_dimension_validation() {
        let original = vec![1.0f32; 10];
        assert!(TensorCompress::new()
            .compress_pq8(&original, 8, 16)
            .is_err());
    }
}