1use crate::error::{GnnError, Result};
11use serde::{Deserialize, Serialize};
12
/// Compression strategy for an embedding vector, ordered roughly from
/// highest fidelity (`None`) to highest compression (`Binary`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum CompressionLevel {
    /// No compression: the f32 vector is stored verbatim.
    None,

    /// Half-width storage: values are multiplied by `scale` before encoding
    /// to u16 and divided by it again on decompression.
    Half { scale: f32 },

    /// Product quantization with 8-bit codes: the vector is split into
    /// `subvectors` equal parts, each quantized against `centroids` centroids.
    PQ8 { subvectors: u8, centroids: u8 },

    /// Product quantization with 16 centroids per subvector plus exact
    /// storage of values further than `outlier_threshold` standard
    /// deviations from the vector mean.
    PQ4 {
        subvectors: u8,
        outlier_threshold: f32,
    },

    /// 1-bit sign quantization: values above `threshold` decompress to +1.0,
    /// the rest to -1.0.
    Binary { threshold: f32 },
}
34
/// A compressed embedding, one variant per [`CompressionLevel`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CompressedTensor {
    /// Uncompressed f32 data.
    Full { data: Vec<f32> },

    /// Half-width payload: one u16 per element, plus the scale applied at
    /// compression time and the original dimension.
    Half {
        data: Vec<u16>,
        scale: f32,
        dim: usize,
    },

    /// PQ8 payload: one code per subvector; each codebook is a flattened
    /// `centroids * subvector_dim` array of f32s.
    PQ8 {
        codes: Vec<u8>,
        codebooks: Vec<Vec<f32>>,
        subvector_dim: usize,
        dim: usize,
    },

    /// PQ4 payload: like `PQ8`, with `(index, value)` outliers written back
    /// verbatim after dequantization.
    PQ4 {
        codes: Vec<u8>,
        codebooks: Vec<Vec<f32>>,
        outliers: Vec<(usize, f32)>,
        subvector_dim: usize,
        dim: usize,
    },

    /// Binary payload: bit-packed signs, LSB-first within each byte.
    Binary {
        bits: Vec<u8>,
        threshold: f32,
        dim: usize,
    },
}
72
73#[derive(Debug, Clone)]
75pub struct TensorCompress {
76 default_level: CompressionLevel,
78}
79
80impl Default for TensorCompress {
81 fn default() -> Self {
82 Self::new()
83 }
84}
85
86impl TensorCompress {
87 pub fn new() -> Self {
89 Self {
90 default_level: CompressionLevel::None,
91 }
92 }
93
94 pub fn compress(&self, embedding: &[f32], access_freq: f32) -> Result<CompressedTensor> {
103 if embedding.is_empty() {
104 return Err(GnnError::InvalidInput("Empty embedding vector".to_string()));
105 }
106
107 let level = self.select_level(access_freq);
108 self.compress_with_level(embedding, &level)
109 }
110
111 pub fn compress_with_level(
113 &self,
114 embedding: &[f32],
115 level: &CompressionLevel,
116 ) -> Result<CompressedTensor> {
117 match level {
118 CompressionLevel::None => self.compress_none(embedding),
119 CompressionLevel::Half { scale } => self.compress_half(embedding, *scale),
120 CompressionLevel::PQ8 {
121 subvectors,
122 centroids,
123 } => self.compress_pq8(embedding, *subvectors, *centroids),
124 CompressionLevel::PQ4 {
125 subvectors,
126 outlier_threshold,
127 } => self.compress_pq4(embedding, *subvectors, *outlier_threshold),
128 CompressionLevel::Binary { threshold } => self.compress_binary(embedding, *threshold),
129 }
130 }
131
132 pub fn decompress(&self, compressed: &CompressedTensor) -> Result<Vec<f32>> {
134 match compressed {
135 CompressedTensor::Full { data } => Ok(data.clone()),
136 CompressedTensor::Half { data, scale, dim } => self.decompress_half(data, *scale, *dim),
137 CompressedTensor::PQ8 {
138 codes,
139 codebooks,
140 subvector_dim,
141 dim,
142 } => self.decompress_pq8(codes, codebooks, *subvector_dim, *dim),
143 CompressedTensor::PQ4 {
144 codes,
145 codebooks,
146 outliers,
147 subvector_dim,
148 dim,
149 } => self.decompress_pq4(codes, codebooks, outliers, *subvector_dim, *dim),
150 CompressedTensor::Binary {
151 bits,
152 threshold,
153 dim,
154 } => self.decompress_binary(bits, *threshold, *dim),
155 }
156 }
157
158 fn select_level(&self, access_freq: f32) -> CompressionLevel {
167 if access_freq > 0.8 {
168 CompressionLevel::None
169 } else if access_freq > 0.4 {
170 CompressionLevel::Half { scale: 1.0 }
171 } else if access_freq > 0.1 {
172 CompressionLevel::PQ8 {
173 subvectors: 8,
174 centroids: 16,
175 }
176 } else if access_freq > 0.01 {
177 CompressionLevel::PQ4 {
178 subvectors: 8,
179 outlier_threshold: 3.0,
180 }
181 } else {
182 CompressionLevel::Binary { threshold: 0.0 }
183 }
184 }
185
186 fn compress_none(&self, embedding: &[f32]) -> Result<CompressedTensor> {
189 Ok(CompressedTensor::Full {
190 data: embedding.to_vec(),
191 })
192 }
193
194 fn compress_half(&self, embedding: &[f32], scale: f32) -> Result<CompressedTensor> {
195 let data: Vec<u16> = embedding
197 .iter()
198 .map(|&x| {
199 let scaled = x * scale;
200 let clamped = scaled.clamp(-65504.0, 65504.0);
201 f32_to_f16_bits(clamped)
203 })
204 .collect();
205
206 Ok(CompressedTensor::Half {
207 data,
208 scale,
209 dim: embedding.len(),
210 })
211 }
212
213 fn compress_pq8(
214 &self,
215 embedding: &[f32],
216 subvectors: u8,
217 centroids: u8,
218 ) -> Result<CompressedTensor> {
219 let dim = embedding.len();
220 let subvectors = subvectors as usize;
221
222 if dim % subvectors != 0 {
223 return Err(GnnError::InvalidInput(format!(
224 "Dimension {} not divisible by subvectors {}",
225 dim, subvectors
226 )));
227 }
228
229 let subvector_dim = dim / subvectors;
230 let mut codes = Vec::with_capacity(subvectors);
231 let mut codebooks = Vec::with_capacity(subvectors);
232
233 for i in 0..subvectors {
235 let start = i * subvector_dim;
236 let end = start + subvector_dim;
237 let subvector = &embedding[start..end];
238
239 let (codebook, code) = self.quantize_subvector(subvector, centroids as usize);
241 codes.push(code);
242 codebooks.push(codebook);
243 }
244
245 Ok(CompressedTensor::PQ8 {
246 codes,
247 codebooks,
248 subvector_dim,
249 dim,
250 })
251 }
252
253 fn compress_pq4(
254 &self,
255 embedding: &[f32],
256 subvectors: u8,
257 outlier_threshold: f32,
258 ) -> Result<CompressedTensor> {
259 let dim = embedding.len();
260 let subvectors = subvectors as usize;
261
262 if dim % subvectors != 0 {
263 return Err(GnnError::InvalidInput(format!(
264 "Dimension {} not divisible by subvectors {}",
265 dim, subvectors
266 )));
267 }
268
269 let subvector_dim = dim / subvectors;
270 let mut codes = Vec::with_capacity(subvectors);
271 let mut codebooks = Vec::with_capacity(subvectors);
272 let mut outliers = Vec::new();
273
274 let mean = embedding.iter().sum::<f32>() / dim as f32;
276 let std_dev =
277 (embedding.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / dim as f32).sqrt();
278
279 for i in 0..subvectors {
281 let start = i * subvector_dim;
282 let end = start + subvector_dim;
283 let subvector = &embedding[start..end];
284
285 let mut cleaned_subvector = subvector.to_vec();
287 for (j, &val) in subvector.iter().enumerate() {
288 if (val - mean).abs() > outlier_threshold * std_dev {
289 outliers.push((start + j, val));
290 cleaned_subvector[j] = mean; }
292 }
293
294 let (codebook, code) = self.quantize_subvector(&cleaned_subvector, 16);
296 codes.push(code);
297 codebooks.push(codebook);
298 }
299
300 Ok(CompressedTensor::PQ4 {
301 codes,
302 codebooks,
303 outliers,
304 subvector_dim,
305 dim,
306 })
307 }
308
309 fn compress_binary(&self, embedding: &[f32], threshold: f32) -> Result<CompressedTensor> {
310 let dim = embedding.len();
311 let num_bytes = (dim + 7) / 8;
312 let mut bits = vec![0u8; num_bytes];
313
314 for (i, &val) in embedding.iter().enumerate() {
315 if val > threshold {
316 let byte_idx = i / 8;
317 let bit_idx = i % 8;
318 bits[byte_idx] |= 1 << bit_idx;
319 }
320 }
321
322 Ok(CompressedTensor::Binary {
323 bits,
324 threshold,
325 dim,
326 })
327 }
328
329 fn decompress_half(&self, data: &[u16], scale: f32, dim: usize) -> Result<Vec<f32>> {
332 if data.len() != dim {
333 return Err(GnnError::InvalidInput(format!(
334 "Dimension mismatch: expected {}, got {}",
335 dim,
336 data.len()
337 )));
338 }
339
340 Ok(data
341 .iter()
342 .map(|&bits| f16_bits_to_f32(bits) / scale)
343 .collect())
344 }
345
346 fn decompress_pq8(
347 &self,
348 codes: &[u8],
349 codebooks: &[Vec<f32>],
350 subvector_dim: usize,
351 dim: usize,
352 ) -> Result<Vec<f32>> {
353 let subvectors = codes.len();
354 let expected_dim = subvectors * subvector_dim;
355
356 if expected_dim != dim {
357 return Err(GnnError::InvalidInput(format!(
358 "Dimension mismatch: expected {}, got {}",
359 dim, expected_dim
360 )));
361 }
362
363 let mut result = Vec::with_capacity(dim);
364
365 for (code, codebook) in codes.iter().zip(codebooks.iter()) {
366 let centroid_idx = *code as usize;
367 if centroid_idx >= codebook.len() / subvector_dim {
368 return Err(GnnError::InvalidInput(format!(
369 "Invalid centroid index: {}",
370 centroid_idx
371 )));
372 }
373
374 let start = centroid_idx * subvector_dim;
375 let end = start + subvector_dim;
376 result.extend_from_slice(&codebook[start..end]);
377 }
378
379 Ok(result)
380 }
381
382 fn decompress_pq4(
383 &self,
384 codes: &[u8],
385 codebooks: &[Vec<f32>],
386 outliers: &[(usize, f32)],
387 subvector_dim: usize,
388 dim: usize,
389 ) -> Result<Vec<f32>> {
390 let mut result = self.decompress_pq8(codes, codebooks, subvector_dim, dim)?;
392
393 for &(idx, val) in outliers {
395 if idx < result.len() {
396 result[idx] = val;
397 }
398 }
399
400 Ok(result)
401 }
402
403 fn decompress_binary(&self, bits: &[u8], _threshold: f32, dim: usize) -> Result<Vec<f32>> {
404 let expected_bytes = (dim + 7) / 8;
405 if bits.len() != expected_bytes {
406 return Err(GnnError::InvalidInput(format!(
407 "Dimension mismatch: expected {} bytes, got {}",
408 expected_bytes,
409 bits.len()
410 )));
411 }
412
413 let mut result = Vec::with_capacity(dim);
414
415 for i in 0..dim {
416 let byte_idx = i / 8;
417 let bit_idx = i % 8;
418 let is_set = (bits[byte_idx] & (1 << bit_idx)) != 0;
419 result.push(if is_set { 1.0 } else { -1.0 });
420 }
421
422 Ok(result)
423 }
424
425 fn quantize_subvector(&self, subvector: &[f32], k: usize) -> (Vec<f32>, u8) {
429 let dim = subvector.len();
430
431 let min_val = subvector.iter().cloned().fold(f32::INFINITY, f32::min);
433 let max_val = subvector.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
434 let range = max_val - min_val;
435
436 if range < 1e-6 {
437 let codebook = vec![min_val; dim * k];
439 return (codebook, 0);
440 }
441
442 let centroids: Vec<Vec<f32>> = (0..k)
444 .map(|i| {
445 let offset = min_val + (i as f32 / k as f32) * range;
446 vec![offset; dim]
447 })
448 .collect();
449
450 let code = self.nearest_centroid(subvector, ¢roids);
452
453 let codebook: Vec<f32> = centroids.into_iter().flatten().collect();
455
456 (codebook, code as u8)
457 }
458
459 fn nearest_centroid(&self, subvector: &[f32], centroids: &[Vec<f32>]) -> usize {
460 centroids
461 .iter()
462 .enumerate()
463 .map(|(i, centroid)| {
464 let dist: f32 = subvector
465 .iter()
466 .zip(centroid.iter())
467 .map(|(a, b)| (a - b).powi(2))
468 .sum();
469 (i, dist)
470 })
471 .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
472 .map(|(i, _)| i)
473 .unwrap_or(0)
474 }
475}
476
/// Converts an f32 to IEEE 754 binary16 (half-precision) bits.
///
/// The previous implementation was a fixed-point hack (×1000 biased into
/// u16) that silently destroyed any magnitude above 32.767 even though
/// `compress_half` explicitly clamps to the f16 maximum of ±65504. This is a
/// real binary16 encoder: magnitudes above 65504 become infinity, magnitudes
/// below 2^-25 flush to signed zero, NaN is preserved, and rounding is to
/// nearest (ties away from zero on the mantissa round bit).
///
/// NOTE(review): this changes the u16 encoding, so `Half` tensors serialized
/// by the old fixed-point scheme are not decodable by the new pair.
fn f32_to_f16_bits(value: f32) -> u16 {
    let bits = value.to_bits();
    let sign = ((bits >> 16) & 0x8000) as u16;
    let abs = bits & 0x7fff_ffff;

    // NaN: emit a quiet NaN.
    if abs > 0x7f80_0000 {
        return sign | 0x7e00;
    }
    // Too large for f16 (including f32 infinity): half infinity.
    if abs >= 0x4780_0000 {
        return sign | 0x7c00;
    }
    // Normal f16 range: magnitude >= 2^-14.
    if abs >= 0x3880_0000 {
        let exp = ((abs >> 23) as i32 - 112) as u32; // re-bias 127 -> 15
        let man = abs & 0x007f_ffff;
        let half = ((exp << 10) | (man >> 13)) as u16;
        let round = ((man >> 12) & 1) as u16;
        // A mantissa carry from rounding correctly bumps the exponent.
        return sign | (half + round);
    }
    // Subnormal f16 range: 2^-25 <= magnitude < 2^-14.
    if abs >= 0x3300_0000 {
        let exp = abs >> 23;
        let man = (abs & 0x007f_ffff) | 0x0080_0000; // restore implicit 1
        let shift = 126 - exp; // 14..=24 for this range
        let half = (man >> shift) as u16;
        let round = ((man >> (shift - 1)) & 1) as u16;
        return sign | (half + round);
    }
    // Underflow: signed zero.
    sign
}

/// Converts IEEE 754 binary16 bits back to an f32. The conversion is exact:
/// every binary16 value is representable as an f32.
fn f16_bits_to_f32(bits: u16) -> f32 {
    let sign = ((bits & 0x8000) as u32) << 16;
    let exp = ((bits >> 10) & 0x1f) as u32;
    let man = (bits & 0x03ff) as u32;

    if exp == 0x1f {
        // Infinity or NaN.
        return f32::from_bits(sign | 0x7f80_0000 | (man << 13));
    }
    if exp == 0 {
        if man == 0 {
            // Signed zero.
            return f32::from_bits(sign);
        }
        // Subnormal: value = man * 2^-24 (exact, since man fits in 10 bits).
        let magnitude = man as f32 / 16_777_216.0;
        return f32::from_bits(sign | magnitude.to_bits());
    }
    // Normal: re-bias exponent 15 -> 127.
    f32::from_bits(sign | ((exp + 112) << 23) | (man << 13))
}
493
#[cfg(test)]
mod tests {
    use super::*;

    /// `None` level must round-trip losslessly.
    #[test]
    fn test_compress_none() {
        let compressor = TensorCompress::new();
        let input = vec![1.0, 2.0, 3.0, 4.0];

        let packed = compressor.compress(&input, 1.0).unwrap();
        let restored = compressor.decompress(&packed).unwrap();

        assert_eq!(input, restored);
    }

    /// Half precision round-trips small values within a tight tolerance.
    #[test]
    fn test_compress_half() {
        let compressor = TensorCompress::new();
        let input = vec![1.0, 2.0, 3.0, 4.0];

        let packed = compressor.compress(&input, 0.5).unwrap();
        let restored = compressor.decompress(&packed).unwrap();

        for (lhs, rhs) in input.iter().zip(restored.iter()) {
            assert!((lhs - rhs).abs() < 0.01, "Expected {}, got {}", lhs, rhs);
        }
    }

    /// Binary compression reduces every element to +1 or -1.
    #[test]
    fn test_compress_binary() {
        let compressor = TensorCompress::new();
        let input = vec![1.0, -1.0, 0.5, -0.5];

        let packed = compressor.compress(&input, 0.005).unwrap();
        let restored = compressor.decompress(&packed).unwrap();

        assert_eq!(restored.len(), input.len());
        for value in &restored {
            assert!(*value == 1.0 || *value == -1.0);
        }
    }

    /// Level selection follows the access-frequency bands.
    #[test]
    fn test_select_level() {
        let compressor = TensorCompress::new();

        assert!(matches!(
            compressor.select_level(0.9),
            CompressionLevel::None
        ));
        assert!(matches!(
            compressor.select_level(0.5),
            CompressionLevel::Half { .. }
        ));
        assert!(matches!(
            compressor.select_level(0.2),
            CompressionLevel::PQ8 { .. }
        ));
        assert!(matches!(
            compressor.select_level(0.05),
            CompressionLevel::PQ4 { .. }
        ));
        assert!(matches!(
            compressor.select_level(0.001),
            CompressionLevel::Binary { .. }
        ));
    }

    /// An empty vector is rejected.
    #[test]
    fn test_empty_embedding() {
        let compressor = TensorCompress::new();
        assert!(compressor.compress(&[], 0.5).is_err());
    }

    /// PQ8 output decompresses back to the original dimension.
    #[test]
    fn test_pq8_compression() {
        let compressor = TensorCompress::new();
        let input: Vec<f32> = (0..64).map(|i| i as f32 * 0.1).collect();

        let packed = compressor.compress_pq8(&input, 8, 16).unwrap();
        let restored = compressor.decompress(&packed).unwrap();

        assert_eq!(restored.len(), input.len());
    }

    /// Every automatically-selected level round-trips to the right length.
    #[test]
    fn test_round_trip_all_levels() {
        let compressor = TensorCompress::new();
        let input: Vec<f32> = (0..128).map(|i| (i as f32 - 64.0) * 0.01).collect();

        let freqs = [0.9, 0.5, 0.2, 0.05, 0.001];
        for &freq in freqs.iter() {
            let packed = compressor.compress(&input, freq).unwrap();
            let restored = compressor.decompress(&packed).unwrap();
            assert_eq!(restored.len(), input.len());
        }
    }

    /// Half precision handles negative, zero, and moderately large values.
    #[test]
    fn test_half_precision_roundtrip() {
        let compressor = TensorCompress::new();
        let samples = [-30.0, -1.0, 0.0, 1.0, 30.0];

        for &sample in samples.iter() {
            let input = vec![sample; 4];
            let packed = compressor
                .compress_with_level(&input, &CompressionLevel::Half { scale: 1.0 })
                .unwrap();
            let restored = compressor.decompress(&packed).unwrap();

            for (lhs, rhs) in input.iter().zip(restored.iter()) {
                let diff = (lhs - rhs).abs();
                assert!(
                    diff < 0.1,
                    "Value {} decompressed to {}, diff: {}",
                    lhs,
                    rhs,
                    diff
                );
            }
        }
    }

    /// The binary threshold splits values into exactly +1/-1.
    #[test]
    fn test_binary_threshold() {
        let compressor = TensorCompress::new();
        let input = vec![0.5, -0.5, 1.5, -1.5];

        let packed = compressor
            .compress_with_level(&input, &CompressionLevel::Binary { threshold: 0.0 })
            .unwrap();
        let restored = compressor.decompress(&packed).unwrap();

        assert_eq!(restored, vec![1.0, -1.0, 1.0, -1.0]);
    }

    /// Outliers must be restored exactly after a PQ4 round-trip.
    #[test]
    fn test_pq4_with_outliers() {
        let compressor = TensorCompress::new();
        let mut input: Vec<f32> = (0..64).map(|i| i as f32 * 0.01).collect();
        input[10] = 100.0;
        input[30] = -100.0;

        let packed = compressor
            .compress_with_level(
                &input,
                &CompressionLevel::PQ4 {
                    subvectors: 8,
                    outlier_threshold: 2.0,
                },
            )
            .unwrap();
        let restored = compressor.decompress(&packed).unwrap();

        assert_eq!(restored.len(), input.len());
        assert_eq!(restored[10], 100.0);
        assert_eq!(restored[30], -100.0);
    }

    /// PQ8 rejects dimensions not divisible by the subvector count.
    #[test]
    fn test_dimension_validation() {
        let compressor = TensorCompress::new();
        let input = vec![1.0; 10];
        assert!(compressor.compress_pq8(&input, 8, 16).is_err());
    }
}