ruvector_core/quantization.rs

1//! Quantization techniques for memory compression
2//!
3//! This module provides tiered quantization strategies as specified in ADR-001:
4//!
5//! | Quantization | Compression | Use Case |
6//! |--------------|-------------|----------|
7//! | Scalar (u8)  | 4x          | Warm data (40-80% access) |
8//! | Int4         | 8x          | Cool data (10-40% access) |
9//! | Product      | 8-16x       | Cold data (1-10% access) |
10//! | Binary       | 32x         | Archive (<1% access) |
11//!
12//! ## Performance Optimizations v2
13//!
14//! - SIMD-accelerated distance calculations for scalar (u8) quantization
15//! - SIMD popcnt for binary hamming distance
16//! - 4x loop unrolling for better instruction-level parallelism
17//! - Separate accumulator strategy to reduce data dependencies
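//!
//! ## Example
//!
//! A minimal usage sketch; the crate path and the tier choice are illustrative:
//!
//! ```ignore
//! use ruvector_core::quantization::{BinaryQuantized, QuantizedVector, ScalarQuantized};
//!
//! let v = vec![0.1, -0.5, 0.9, 0.3];
//! let warm = ScalarQuantized::quantize(&v);  // 4x compression for warm data
//! let cold = BinaryQuantized::quantize(&v);  // 32x compression for archive data
//! let approx = warm.reconstruct();           // lossy reconstruction
//! ```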
18
19use crate::error::Result;
20use serde::{Deserialize, Serialize};
21
22/// Trait for quantized vector representations
23pub trait QuantizedVector: Send + Sync {
24    /// Quantize a full-precision vector
25    fn quantize(vector: &[f32]) -> Self;
26
27    /// Calculate distance to another quantized vector
28    fn distance(&self, other: &Self) -> f32;
29
30    /// Reconstruct approximate full-precision vector
31    fn reconstruct(&self) -> Vec<f32>;
32}
33
34/// Scalar quantization to u8 (4x compression)
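///
/// # Example
///
/// An illustrative sketch of the mapping `q = round((v - min) / scale)` with
/// `scale = (max - min) / 255` (crate path assumed):
///
/// ```ignore
/// use ruvector_core::quantization::{QuantizedVector, ScalarQuantized};
///
/// let q = ScalarQuantized::quantize(&[0.0, 51.0, 255.0]);
/// assert_eq!(q.data, vec![0, 51, 255]); // min = 0.0, scale = 1.0
/// let approx = q.reconstruct();         // v ≈ min + q * scale
/// ```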
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct ScalarQuantized {
37    /// Quantized values (u8)
38    pub data: Vec<u8>,
39    /// Minimum value for dequantization
40    pub min: f32,
41    /// Scale factor for dequantization
42    pub scale: f32,
43}
44
45impl QuantizedVector for ScalarQuantized {
46    fn quantize(vector: &[f32]) -> Self {
47        let min = vector.iter().copied().fold(f32::INFINITY, f32::min);
48        let max = vector.iter().copied().fold(f32::NEG_INFINITY, f32::max);
49
50        // Handle edge case where all values are the same (scale = 0)
51        let scale = if (max - min).abs() < f32::EPSILON {
52            1.0 // Arbitrary non-zero scale when all values are identical
53        } else {
54            (max - min) / 255.0
55        };
56
57        let data = vector
58            .iter()
59            .map(|&v| ((v - min) / scale).round().clamp(0.0, 255.0) as u8)
60            .collect();
61
62        Self { data, min, scale }
63    }
64
65    fn distance(&self, other: &Self) -> f32 {
66        // Fast u8 distance: diffs are squared and accumulated in i32 (255 * 255 = 65025)
67        assert_eq!(self.data.len(), other.data.len(), "vectors must have equal length");
68
69        // Scale handling: We use the average of both scales for balanced comparison.
70        // Using max(scale) would bias toward the vector with larger range,
71        // while average provides a more symmetric distance metric.
72        // This ensures distance(a, b) ≈ distance(b, a) in the reconstructed space.
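        // Illustrative example: with scales 0.02 and 0.06 the averaged factor is
        // 0.04 regardless of call order, so d(a, b) and d(b, a) agree.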
73        let avg_scale = (self.scale + other.scale) / 2.0;
74
75        // Use SIMD-optimized version for larger vectors
76        #[cfg(target_arch = "aarch64")]
77        {
78            if self.data.len() >= 16 {
79                return unsafe { scalar_distance_neon(&self.data, &other.data) }.sqrt() * avg_scale;
80            }
81        }
82
83        #[cfg(target_arch = "x86_64")]
84        {
85            if self.data.len() >= 32 && is_x86_feature_detected!("avx2") {
86                return unsafe { scalar_distance_avx2(&self.data, &other.data) }.sqrt() * avg_scale;
87            }
88        }
89
90        // Scalar fallback with 4x loop unrolling for better ILP
91        scalar_distance_scalar(&self.data, &other.data).sqrt() * avg_scale
92    }
93
94    fn reconstruct(&self) -> Vec<f32> {
95        self.data
96            .iter()
97            .map(|&v| self.min + (v as f32) * self.scale)
98            .collect()
99    }
100}
101
102/// Product quantization (8-16x compression)
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct ProductQuantized {
105    /// Quantized codes (one per subspace)
106    pub codes: Vec<u8>,
107    /// Codebooks for each subspace
108    pub codebooks: Vec<Vec<Vec<f32>>>,
109}
110
111impl ProductQuantized {
112    /// Train product quantization on a set of vectors
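    ///
    /// # Example
    ///
    /// An illustrative sketch (parameter values are arbitrary; `dimensions` is
    /// assumed to be divisible by `num_subspaces`, which is not validated here):
    ///
    /// ```ignore
    /// // 8 subspaces, 256 centroids each, 10 k-means iterations
    /// let pq = ProductQuantized::train(&training_vectors, 8, 256, 10)?;
    /// let codes = pq.encode(&query); // one u8 code per subspace
    /// ```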
113    pub fn train(
114        vectors: &[Vec<f32>],
115        num_subspaces: usize,
116        codebook_size: usize,
117        iterations: usize,
118    ) -> Result<Self> {
119        if vectors.is_empty() {
120            return Err(crate::error::RuvectorError::InvalidInput(
121                "Cannot train on empty vector set".into(),
122            ));
123        }
124        if vectors[0].is_empty() {
125            return Err(crate::error::RuvectorError::InvalidInput(
126                "Cannot train on vectors with zero dimensions".into(),
127            ));
128        }
129        if codebook_size > 256 {
130            return Err(crate::error::RuvectorError::InvalidParameter(format!(
131                "Codebook size {} exceeds the 256 codes addressable by a u8 index",
132                codebook_size
133            )));
134        }
135        let dimensions = vectors[0].len();
136        let subspace_dim = dimensions / num_subspaces;
137
138        let mut codebooks = Vec::with_capacity(num_subspaces);
139
140        // Train codebook for each subspace using k-means
141        for subspace_idx in 0..num_subspaces {
142            let start = subspace_idx * subspace_dim;
143            let end = start + subspace_dim;
144
145            // Extract subspace vectors
146            let subspace_vectors: Vec<Vec<f32>> =
147                vectors.iter().map(|v| v[start..end].to_vec()).collect();
148
149            // Run k-means
150            let codebook = kmeans_clustering(&subspace_vectors, codebook_size, iterations);
151            codebooks.push(codebook);
152        }
153
154        Ok(Self {
155            codes: vec![],
156            codebooks,
157        })
158    }
159
160    /// Quantize a vector using trained codebooks
161    pub fn encode(&self, vector: &[f32]) -> Vec<u8> {
162        let num_subspaces = self.codebooks.len();
163        let subspace_dim = vector.len() / num_subspaces;
164
165        let mut codes = Vec::with_capacity(num_subspaces);
166
167        for (subspace_idx, codebook) in self.codebooks.iter().enumerate() {
168            let start = subspace_idx * subspace_dim;
169            let end = start + subspace_dim;
170            let subvector = &vector[start..end];
171
172            // Find nearest centroid
173            let code = codebook
174                .iter()
175                .enumerate()
176                .min_by(|(_, a), (_, b)| {
177                    let dist_a = euclidean_squared(subvector, a);
178                    let dist_b = euclidean_squared(subvector, b);
179                    dist_a.partial_cmp(&dist_b).unwrap()
180                })
181                .map(|(idx, _)| idx as u8)
182                .unwrap_or(0);
183
184            codes.push(code);
185        }
186
187        codes
188    }
189}
190
191/// Int4 quantization (8x compression)
192///
193/// Quantizes f32 to 4-bit integers (0-15), packing 2 values per byte.
194/// Provides 8x compression with better precision than binary.
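///
/// # Packing example
///
/// An illustrative sketch: `[0.0, 15.0, 7.0]` quantizes with `min = 0.0` and
/// `scale = 1.0` to the 4-bit codes `[0, 15, 7]`. Even indices fill the low
/// nibble and odd indices the high nibble, so the packed bytes are
/// `[0xF0, 0x07]` and `dimensions` is stored as `3`.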
195#[derive(Debug, Clone, Serialize, Deserialize)]
196pub struct Int4Quantized {
197    /// Packed 4-bit values (2 per byte)
198    pub data: Vec<u8>,
199    /// Minimum value for dequantization
200    pub min: f32,
201    /// Scale factor for dequantization
202    pub scale: f32,
203    /// Number of dimensions
204    pub dimensions: usize,
205}
206
207impl Int4Quantized {
208    /// Quantize a vector to 4-bit representation
209    pub fn quantize(vector: &[f32]) -> Self {
210        let min = vector.iter().copied().fold(f32::INFINITY, f32::min);
211        let max = vector.iter().copied().fold(f32::NEG_INFINITY, f32::max);
212
213        // Handle edge case where all values are the same
214        let scale = if (max - min).abs() < f32::EPSILON {
215            1.0
216        } else {
217            (max - min) / 15.0 // 4-bit gives 0-15 range
218        };
219
220        let dimensions = vector.len();
221        let num_bytes = (dimensions + 1) / 2;
222        let mut data = vec![0u8; num_bytes];
223
224        for (i, &v) in vector.iter().enumerate() {
225            let quantized = ((v - min) / scale).round().clamp(0.0, 15.0) as u8;
226            let byte_idx = i / 2;
227            if i % 2 == 0 {
228                // Low nibble
229                data[byte_idx] |= quantized;
230            } else {
231                // High nibble
232                data[byte_idx] |= quantized << 4;
233            }
234        }
235
236        Self {
237            data,
238            min,
239            scale,
240            dimensions,
241        }
242    }
243
244    /// Calculate distance to another Int4 quantized vector
245    pub fn distance(&self, other: &Self) -> f32 {
246        assert_eq!(self.dimensions, other.dimensions);
247
248        // Use the average scale for a symmetric comparison. The min offsets are
249        // ignored, so the result is exact only when both vectors share the same min.
250        let avg_scale = (self.scale + other.scale) / 2.0;
251
252        let mut sum_sq = 0i32;
253
254        for i in 0..self.dimensions {
255            let byte_idx = i / 2;
256            let shift = if i % 2 == 0 { 0 } else { 4 };
257
258            let a = ((self.data[byte_idx] >> shift) & 0x0F) as i32;
259            let b = ((other.data[byte_idx] >> shift) & 0x0F) as i32;
260            let diff = a - b;
261            sum_sq += diff * diff;
262        }
263
264        (sum_sq as f32).sqrt() * avg_scale
265    }
266
267    /// Reconstruct approximate full-precision vector
268    pub fn reconstruct(&self) -> Vec<f32> {
269        let mut result = Vec::with_capacity(self.dimensions);
270
271        for i in 0..self.dimensions {
272            let byte_idx = i / 2;
273            let shift = if i % 2 == 0 { 0 } else { 4 };
274            let quantized = (self.data[byte_idx] >> shift) & 0x0F;
275            result.push(self.min + (quantized as f32) * self.scale);
276        }
277
278        result
279    }
280
281    /// Get compression ratio (8x for Int4)
282    pub fn compression_ratio() -> f32 {
283        8.0 // f32 (4 bytes) -> 4 bits (0.5 bytes)
284    }
285}
286
287/// Binary quantization (32x compression)
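///
/// # Example
///
/// An illustrative sketch: `[0.5, -1.2, 3.0, -0.1]` sets a bit for each value
/// strictly greater than zero, giving the single LSB-first byte `0b0000_0101`;
/// `reconstruct` maps set bits to `1.0` and cleared bits to `-1.0`.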
288#[derive(Debug, Clone, Serialize, Deserialize)]
289pub struct BinaryQuantized {
290    /// Binary representation (1 bit per dimension, packed into bytes)
291    pub bits: Vec<u8>,
292    /// Number of dimensions
293    pub dimensions: usize,
294}
295
296impl QuantizedVector for BinaryQuantized {
297    fn quantize(vector: &[f32]) -> Self {
298        let dimensions = vector.len();
299        let num_bytes = (dimensions + 7) / 8;
300        let mut bits = vec![0u8; num_bytes];
301
302        for (i, &v) in vector.iter().enumerate() {
303            if v > 0.0 {
304                let byte_idx = i / 8;
305                let bit_idx = i % 8;
306                bits[byte_idx] |= 1 << bit_idx;
307            }
308        }
309
310        Self { bits, dimensions }
311    }
312
313    fn distance(&self, other: &Self) -> f32 {
314        // Hamming distance using SIMD-friendly operations
315        Self::hamming_distance_fast(&self.bits, &other.bits) as f32
316    }
317
318    fn reconstruct(&self) -> Vec<f32> {
319        let mut result = Vec::with_capacity(self.dimensions);
320
321        for i in 0..self.dimensions {
322            let byte_idx = i / 8;
323            let bit_idx = i % 8;
324            let bit = (self.bits[byte_idx] >> bit_idx) & 1;
325            result.push(if bit == 1 { 1.0 } else { -1.0 });
326        }
327
328        result
329    }
330}
331
332impl BinaryQuantized {
333    /// Fast hamming distance using SIMD-optimized operations
334    ///
335    /// Uses hardware POPCNT on x86_64 or NEON vcnt on ARM64 for optimal performance.
336    /// Processes 16 bytes at a time on ARM64, 8 bytes at a time on x86_64.
337    /// Falls back to 64-bit chunked operations when SIMD is unavailable.
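    ///
    /// # Example
    ///
    /// An illustrative sketch: `0b1010_1010 ^ 0b0101_0101 = 0b1111_1111`, so
    /// `hamming_distance_fast(&[0xAA], &[0x55])` returns `8`.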
338    pub fn hamming_distance_fast(a: &[u8], b: &[u8]) -> u32 {
339        debug_assert_eq!(a.len(), b.len()); // the SIMD paths below assume equal lengths
340        #[cfg(target_arch = "aarch64")]
341        {
342            if a.len() >= 16 {
343                return unsafe { hamming_distance_neon(a, b) };
344            }
345        }
346
347        #[cfg(target_arch = "x86_64")]
348        {
349            if a.len() >= 8 && is_x86_feature_detected!("popcnt") {
350                return unsafe { hamming_distance_simd_x86(a, b) };
351            }
352        }
353
354        // Scalar fallback using 64-bit operations
355        let mut distance = 0u32;
356
357        // Process 8 bytes at a time using u64
358        let chunks_a = a.chunks_exact(8);
359        let chunks_b = b.chunks_exact(8);
360        let remainder_a = chunks_a.remainder();
361        let remainder_b = chunks_b.remainder();
362
363        for (chunk_a, chunk_b) in chunks_a.zip(chunks_b) {
364            let a_u64 = u64::from_le_bytes(chunk_a.try_into().unwrap());
365            let b_u64 = u64::from_le_bytes(chunk_b.try_into().unwrap());
366            distance += (a_u64 ^ b_u64).count_ones();
367        }
368
369        // Handle remainder bytes
370        for (&a_byte, &b_byte) in remainder_a.iter().zip(remainder_b) {
371            distance += (a_byte ^ b_byte).count_ones();
372        }
373
374        distance
375    }
376
377    /// Compute normalized hamming similarity (0.0 to 1.0)
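    ///
    /// For example, two 8-dimensional vectors that differ in 2 bits have a
    /// similarity of `1.0 - 2.0 / 8.0 = 0.75`.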
378    pub fn similarity(&self, other: &Self) -> f32 {
379        let distance = self.distance(other);
380        1.0 - (distance / self.dimensions as f32)
381    }
382
383    /// Get compression ratio (32x for binary)
384    pub fn compression_ratio() -> f32 {
385        32.0 // f32 (4 bytes = 32 bits) -> 1 bit
386    }
387
388    /// Convert to bytes for storage
389    pub fn to_bytes(&self) -> &[u8] {
390        &self.bits
391    }
392
393    /// Create from bytes
394    pub fn from_bytes(bits: Vec<u8>, dimensions: usize) -> Self {
395        Self { bits, dimensions }
396    }
397}
398
399// ============================================================================
400// Helper functions for scalar quantization distance
401// ============================================================================
402
403/// Scalar fallback for scalar quantization distance (sum of squared differences)
404fn scalar_distance_scalar(a: &[u8], b: &[u8]) -> f32 {
405    let mut sum_sq = 0i32;
406
407    // 4x loop unrolling for better ILP
408    let chunks = a.len() / 4;
409    for i in 0..chunks {
410        let idx = i * 4;
411        let d0 = (a[idx] as i32) - (b[idx] as i32);
412        let d1 = (a[idx + 1] as i32) - (b[idx + 1] as i32);
413        let d2 = (a[idx + 2] as i32) - (b[idx + 2] as i32);
414        let d3 = (a[idx + 3] as i32) - (b[idx + 3] as i32);
415        sum_sq += d0 * d0 + d1 * d1 + d2 * d2 + d3 * d3;
416    }
417
418    // Handle remainder
419    for i in (chunks * 4)..a.len() {
420        let diff = (a[i] as i32) - (b[i] as i32);
421        sum_sq += diff * diff;
422    }
423
424    sum_sq as f32
425}
426
427/// NEON SIMD distance for scalar quantization
428///
429/// # Safety
430/// Caller must ensure a.len() == b.len()
431#[cfg(target_arch = "aarch64")]
432#[inline(always)]
433unsafe fn scalar_distance_neon(a: &[u8], b: &[u8]) -> f32 {
434    use std::arch::aarch64::*;
435
436    let len = a.len();
437    let a_ptr = a.as_ptr();
438    let b_ptr = b.as_ptr();
439
440    let mut sum = vdupq_n_s32(0);
441
442    // Process 8 bytes at a time
443    let chunks = len / 8;
444    let mut idx = 0usize;
445
446    for _ in 0..chunks {
447        // Load 8 u8 values
448        let va = vld1_u8(a_ptr.add(idx));
449        let vb = vld1_u8(b_ptr.add(idx));
450
451        // Zero-extend u8 to u16
452        let va_u16 = vmovl_u8(va);
453        let vb_u16 = vmovl_u8(vb);
454
455        // Convert to signed for subtraction
456        let va_s16 = vreinterpretq_s16_u16(va_u16);
457        let vb_s16 = vreinterpretq_s16_u16(vb_u16);
458
459        // Compute difference
460        let diff = vsubq_s16(va_s16, vb_s16);
461
462        // Square and accumulate
463        let prod_lo = vmull_s16(vget_low_s16(diff), vget_low_s16(diff));
464        let prod_hi = vmull_s16(vget_high_s16(diff), vget_high_s16(diff));
465
466        sum = vaddq_s32(sum, prod_lo);
467        sum = vaddq_s32(sum, prod_hi);
468
469        idx += 8;
470    }
471
472    let mut total = vaddvq_s32(sum);
473
474    // Handle remainder with bounds-check elimination
475    for i in (chunks * 8)..len {
476        let diff = (*a.get_unchecked(i) as i32) - (*b.get_unchecked(i) as i32);
477        total += diff * diff;
478    }
479
480    total as f32
481}
482
483/// AVX2 SIMD distance for scalar quantization
484#[cfg(target_arch = "x86_64")]
485#[target_feature(enable = "avx2")]
486#[inline]
487unsafe fn scalar_distance_avx2(a: &[u8], b: &[u8]) -> f32 {
488    use std::arch::x86_64::*;
489
490    let len = a.len();
491    let mut sum = _mm256_setzero_si256();
492
493    // Process 16 bytes at a time
494    let chunks = len / 16;
495    for i in 0..chunks {
496        let idx = i * 16;
497
498        // Load 16 u8 values
499        let va = _mm_loadu_si128(a.as_ptr().add(idx) as *const __m128i);
500        let vb = _mm_loadu_si128(b.as_ptr().add(idx) as *const __m128i);
501
502        // Zero-extend all 16 u8 lanes to i16 in a 256-bit register
503        let va_16 = _mm256_cvtepu8_epi16(va);
504        let vb_16 = _mm256_cvtepu8_epi16(vb);
505
506        // Compute difference
507        let diff = _mm256_sub_epi16(va_16, vb_16);
508
509        // Square (multiply i16 * i16 -> i32)
510        let prod = _mm256_madd_epi16(diff, diff);
511
512        // Accumulate
513        sum = _mm256_add_epi32(sum, prod);
514    }
515
516    // Horizontal sum
517    let sum_lo = _mm256_castsi256_si128(sum);
518    let sum_hi = _mm256_extracti128_si256(sum, 1);
519    let sum_128 = _mm_add_epi32(sum_lo, sum_hi);
520
521    let shuffle = _mm_shuffle_epi32(sum_128, 0b10_11_00_01);
522    let sum_64 = _mm_add_epi32(sum_128, shuffle);
523
524    let shuffle2 = _mm_shuffle_epi32(sum_64, 0b00_00_10_10);
525    let final_sum = _mm_add_epi32(sum_64, shuffle2);
526
527    let mut total = _mm_cvtsi128_si32(final_sum);
528
529    // Handle remainder
530    for i in (chunks * 16)..len {
531        let diff = (a[i] as i32) - (b[i] as i32);
532        total += diff * diff;
533    }
534
535    total as f32
536}
537
538// Helper functions
539
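/// Squared Euclidean distance between two f32 slices (`zip` truncates to the shorter length).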
540fn euclidean_squared(a: &[f32], b: &[f32]) -> f32 {
541    a.iter()
542        .zip(b)
543        .map(|(&x, &y)| {
544            let diff = x - y;
545            diff * diff
546        })
547        .sum()
548}
549
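/// Naive k-means used to build PQ codebooks: centroids are initialized from a
/// random sample of the input, then refined over `iterations` rounds of
/// assignment and mean updates; clusters that receive no vectors keep their
/// previous centroid.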
550fn kmeans_clustering(vectors: &[Vec<f32>], k: usize, iterations: usize) -> Vec<Vec<f32>> {
551    use rand::seq::SliceRandom;
552    use rand::thread_rng;
553
554    let mut rng = thread_rng();
555
556    // Initialize centroids randomly
557    let mut centroids: Vec<Vec<f32>> = vectors.choose_multiple(&mut rng, k).cloned().collect();
558
559    for _ in 0..iterations {
560        // Assign vectors to nearest centroid
561        let mut assignments = vec![Vec::new(); k];
562
563        for vector in vectors {
564            let nearest = centroids
565                .iter()
566                .enumerate()
567                .min_by(|(_, a), (_, b)| {
568                    let dist_a = euclidean_squared(vector, a);
569                    let dist_b = euclidean_squared(vector, b);
570                    dist_a.partial_cmp(&dist_b).unwrap()
571                })
572                .map(|(idx, _)| idx)
573                .unwrap_or(0);
574
575            assignments[nearest].push(vector.clone());
576        }
577
578        // Update centroids
579        for (centroid, assigned) in centroids.iter_mut().zip(&assignments) {
580            if !assigned.is_empty() {
581                let dim = centroid.len();
582                *centroid = vec![0.0; dim];
583
584                for vector in assigned {
585                    for (i, &v) in vector.iter().enumerate() {
586                        centroid[i] += v;
587                    }
588                }
589
590                let count = assigned.len() as f32;
591                for v in centroid.iter_mut() {
592                    *v /= count;
593                }
594            }
595        }
596    }
597
598    centroids
599}
600
601// =============================================================================
602// SIMD-Optimized Distance Calculations for Quantized Vectors
603// =============================================================================
604
605// NOTE: scalar_distance_scalar is already defined above (lines 404-425)
606// NOTE: scalar_distance_neon is already defined above (lines 433-481)
607// NOTE: scalar_distance_avx2 is already defined above (lines 487-536)
608// This section only adds the SIMD hamming-distance kernels used by BinaryQuantized
609
610/// SIMD-optimized hamming distance using popcnt
611#[cfg(target_arch = "x86_64")]
612#[target_feature(enable = "popcnt")]
613#[inline]
614unsafe fn hamming_distance_simd_x86(a: &[u8], b: &[u8]) -> u32 {
615    use std::arch::x86_64::*;
616
617    let mut distance = 0u64;
618
619    // Process 8 bytes at a time using u64 with hardware popcnt
620    let chunks_a = a.chunks_exact(8);
621    let chunks_b = b.chunks_exact(8);
622    let remainder_a = chunks_a.remainder();
623    let remainder_b = chunks_b.remainder();
624
625    for (chunk_a, chunk_b) in chunks_a.zip(chunks_b) {
626        let a_u64 = u64::from_le_bytes(chunk_a.try_into().unwrap());
627        let b_u64 = u64::from_le_bytes(chunk_b.try_into().unwrap());
628        distance += _popcnt64((a_u64 ^ b_u64) as i64) as u64;
629    }
630
631    // Handle remainder
632    for (&a_byte, &b_byte) in remainder_a.iter().zip(remainder_b) {
633        distance += (a_byte ^ b_byte).count_ones() as u64;
634    }
635
636    distance as u32
637}
638
639/// NEON-optimized hamming distance for ARM64
640///
641/// # Safety
642/// Caller must ensure a.len() == b.len()
643#[cfg(target_arch = "aarch64")]
644#[inline(always)]
645unsafe fn hamming_distance_neon(a: &[u8], b: &[u8]) -> u32 {
646    use std::arch::aarch64::*;
647
648    let len = a.len();
649    let a_ptr = a.as_ptr();
650    let b_ptr = b.as_ptr();
651
652    let chunks = len / 16;
653    let mut idx = 0usize;
654
655    let mut sum_val = 0u32;
656
657    for _ in 0..chunks {
658        // Load 16 bytes
659        let a_vec = vld1q_u8(a_ptr.add(idx));
660        let b_vec = vld1q_u8(b_ptr.add(idx));
661
662        // XOR and count bits using vcntq_u8 (population count)
663        let xor_result = veorq_u8(a_vec, b_vec);
664        let bits = vcntq_u8(xor_result);
665
666        // Widen the 16 per-byte counts and accumulate in u32; summing in u8
667        // lanes (vaddq_u8 + vaddvq_u8) would silently wrap at 255.
668        sum_val += vaddlvq_u8(bits) as u32;
669        idx += 16;
670    }
671
672    // Per-chunk counts are already widened and summed, so no final
673    // horizontal reduction is needed here.
674
675    // Handle remainder with bounds-check elimination
676    let mut remainder_sum = 0u32;
677    let start = chunks * 16;
678    for i in start..len {
679        remainder_sum += (*a.get_unchecked(i) ^ *b.get_unchecked(i)).count_ones();
680    }
681
682    sum_val + remainder_sum
683}
684
685#[cfg(test)]
686mod tests {
687    use super::*;
688
689    #[test]
690    fn test_scalar_quantization() {
691        let vector = vec![1.0, 2.0, 3.0, 4.0, 5.0];
692        let quantized = ScalarQuantized::quantize(&vector);
693        let reconstructed = quantized.reconstruct();
694
695        // Check approximate reconstruction
696        for (orig, recon) in vector.iter().zip(&reconstructed) {
697            assert!((orig - recon).abs() < 0.1);
698        }
699    }
700
701    #[test]
702    fn test_binary_quantization() {
703        let vector = vec![1.0, -1.0, 2.0, -2.0, 0.5];
704        let quantized = BinaryQuantized::quantize(&vector);
705
706        assert_eq!(quantized.dimensions, 5);
707        assert_eq!(quantized.bits.len(), 1); // 5 bits fit in 1 byte
708    }
709
710    #[test]
711    fn test_binary_distance() {
712        let v1 = vec![1.0, 1.0, 1.0, 1.0];
713        let v2 = vec![1.0, 1.0, -1.0, -1.0];
714
715        let q1 = BinaryQuantized::quantize(&v1);
716        let q2 = BinaryQuantized::quantize(&v2);
717
718        let dist = q1.distance(&q2);
719        assert_eq!(dist, 2.0); // 2 bits differ
720    }
721
722    #[test]
723    fn test_scalar_quantization_roundtrip() {
724        // Test that quantize -> reconstruct produces values close to original
725        let test_vectors = vec![
726            vec![1.0, 2.0, 3.0, 4.0, 5.0],
727            vec![-10.0, -5.0, 0.0, 5.0, 10.0],
728            vec![0.1, 0.2, 0.3, 0.4, 0.5],
729            vec![100.0, 200.0, 300.0, 400.0, 500.0],
730        ];
731
732        for vector in test_vectors {
733            let quantized = ScalarQuantized::quantize(&vector);
734            let reconstructed = quantized.reconstruct();
735
736            assert_eq!(vector.len(), reconstructed.len());
737
738            for (orig, recon) in vector.iter().zip(reconstructed.iter()) {
739                // With 8-bit quantization, max error is roughly (max-min)/255
740                let max = vector.iter().copied().fold(f32::NEG_INFINITY, f32::max);
741                let min = vector.iter().copied().fold(f32::INFINITY, f32::min);
742                let max_error = (max - min) / 255.0 * 2.0; // Allow 2x for rounding
743
744                assert!(
745                    (orig - recon).abs() < max_error,
746                    "Roundtrip error too large: orig={}, recon={}, error={}",
747                    orig,
748                    recon,
749                    (orig - recon).abs()
750                );
751            }
752        }
753    }
754
755    #[test]
756    fn test_scalar_distance_symmetry() {
757        // Test that distance(a, b) == distance(b, a)
758        let v1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
759        let v2 = vec![2.0, 3.0, 4.0, 5.0, 6.0];
760
761        let q1 = ScalarQuantized::quantize(&v1);
762        let q2 = ScalarQuantized::quantize(&v2);
763
764        let dist_ab = q1.distance(&q2);
765        let dist_ba = q2.distance(&q1);
766
767        // Distance should be symmetric (within floating point precision)
768        assert!(
769            (dist_ab - dist_ba).abs() < 0.01,
770            "Distance is not symmetric: d(a,b)={}, d(b,a)={}",
771            dist_ab,
772            dist_ba
773        );
774    }
775
776    #[test]
777    fn test_scalar_distance_different_scales() {
778        // Test distance calculation with vectors that have different scales
779        let v1 = vec![1.0, 2.0, 3.0, 4.0, 5.0]; // range: 4.0
780        let v2 = vec![10.0, 20.0, 30.0, 40.0, 50.0]; // range: 40.0
781
782        let q1 = ScalarQuantized::quantize(&v1);
783        let q2 = ScalarQuantized::quantize(&v2);
784
785        let dist_ab = q1.distance(&q2);
786        let dist_ba = q2.distance(&q1);
787
788        // With average scaling, symmetry should be maintained
789        assert!(
790            (dist_ab - dist_ba).abs() < 0.01,
791            "Distance with different scales not symmetric: d(a,b)={}, d(b,a)={}",
792            dist_ab,
793            dist_ba
794        );
795    }
796
797    #[test]
798    fn test_scalar_quantization_edge_cases() {
799        // Test with all same values
800        let same_values = vec![5.0, 5.0, 5.0, 5.0];
801        let quantized = ScalarQuantized::quantize(&same_values);
802        let reconstructed = quantized.reconstruct();
803
804        for (orig, recon) in same_values.iter().zip(reconstructed.iter()) {
805            assert!((orig - recon).abs() < 0.01);
806        }
807
808        // Test with extreme ranges
809        let extreme = vec![f32::MIN / 1e10, 0.0, f32::MAX / 1e10];
810        let quantized = ScalarQuantized::quantize(&extreme);
811        let reconstructed = quantized.reconstruct();
812
813        assert_eq!(extreme.len(), reconstructed.len());
814    }
815
816    #[test]
817    fn test_binary_distance_symmetry() {
818        // Test that binary distance is symmetric
819        let v1 = vec![1.0, -1.0, 1.0, -1.0];
820        let v2 = vec![1.0, 1.0, -1.0, -1.0];
821
822        let q1 = BinaryQuantized::quantize(&v1);
823        let q2 = BinaryQuantized::quantize(&v2);
824
825        let dist_ab = q1.distance(&q2);
826        let dist_ba = q2.distance(&q1);
827
828        assert_eq!(
829            dist_ab, dist_ba,
830            "Binary distance not symmetric: d(a,b)={}, d(b,a)={}",
831            dist_ab, dist_ba
832        );
833    }
834
835    #[test]
836    fn test_int4_quantization() {
837        let vector = vec![1.0, 2.0, 3.0, 4.0, 5.0];
838        let quantized = Int4Quantized::quantize(&vector);
839        let reconstructed = quantized.reconstruct();
840
841        assert_eq!(quantized.dimensions, 5);
842        // 5 dimensions = 3 bytes (2 per byte, last byte has 1)
843        assert_eq!(quantized.data.len(), 3);
844
845        // Check approximate reconstruction
846        for (orig, recon) in vector.iter().zip(&reconstructed) {
847            // With 4-bit quantization, max error is roughly (max-min)/15
848            let max_error = (5.0 - 1.0) / 15.0 * 2.0;
849            assert!(
850                (orig - recon).abs() < max_error,
851                "Int4 roundtrip error too large: orig={}, recon={}",
852                orig,
853                recon
854            );
855        }
856    }
857
858    #[test]
859    fn test_int4_distance() {
860        // Use vectors with different quantized patterns
861        // v1 spans [0.0, 15.0] -> quantizes to [0, 1, 2, ..., 15] (linear mapping)
862        // v2 spans [0.0, 15.0] but with different distribution
863        let v1 = vec![0.0, 5.0, 10.0, 15.0];
864        let v2 = vec![0.0, 3.0, 12.0, 15.0]; // Different middle values
865
866        let q1 = Int4Quantized::quantize(&v1);
867        let q2 = Int4Quantized::quantize(&v2);
868
869        let dist = q1.distance(&q2);
870        // The quantized values differ in the middle, so distance should be positive
871        assert!(
872            dist > 0.0,
873            "Distance should be positive, got {}. q1.data={:?}, q2.data={:?}",
874            dist,
875            q1.data,
876            q2.data
877        );
878    }
879
880    #[test]
881    fn test_int4_distance_symmetry() {
882        let v1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
883        let v2 = vec![2.0, 3.0, 4.0, 5.0, 6.0];
884
885        let q1 = Int4Quantized::quantize(&v1);
886        let q2 = Int4Quantized::quantize(&v2);
887
888        let dist_ab = q1.distance(&q2);
889        let dist_ba = q2.distance(&q1);
890
891        assert!(
892            (dist_ab - dist_ba).abs() < 0.01,
893            "Int4 distance not symmetric: d(a,b)={}, d(b,a)={}",
894            dist_ab,
895            dist_ba
896        );
897    }
898
899    #[test]
900    fn test_int4_compression_ratio() {
901        assert_eq!(Int4Quantized::compression_ratio(), 8.0);
902    }
903
904    #[test]
905    fn test_binary_fast_hamming() {
906        // Test fast hamming distance with various sizes
907        let a = vec![0xFF, 0x00, 0xFF, 0x00, 0xFF, 0x00, 0xFF, 0x00, 0xAA];
908        let b = vec![0x00, 0xFF, 0x00, 0xFF, 0x00, 0xFF, 0x00, 0xFF, 0x55];
909
910        let distance = BinaryQuantized::hamming_distance_fast(&a, &b);
911        // All bits differ: 9 bytes * 8 bits = 72 bits
912        assert_eq!(distance, 72);
913    }
914
915    #[test]
916    fn test_binary_similarity() {
917        let v1 = vec![1.0; 8]; // All positive
918        let v2 = vec![1.0; 8]; // Same
919
920        let q1 = BinaryQuantized::quantize(&v1);
921        let q2 = BinaryQuantized::quantize(&v2);
922
923        let sim = q1.similarity(&q2);
924        assert!(
925            (sim - 1.0).abs() < 0.001,
926            "Same vectors should have similarity 1.0"
927        );
928    }
929
930    #[test]
931    fn test_binary_compression_ratio() {
932        assert_eq!(BinaryQuantized::compression_ratio(), 32.0);
933    }
934}