// ruvector_attention/curvature/component_quantizer.rs

//! Component quantization for mixed-curvature attention.
//!
//! Each geometric component is stored with its own precision:
//! - Euclidean: 7-8 bits (needs precision)
//! - Hyperbolic (tangent-space) component: 5 bits (tolerates noise)
//! - Spherical: 5 bits (only direction matters)

8use serde::{Deserialize, Serialize};
9
10/// Quantization configuration
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct QuantizationConfig {
13    /// Bits for Euclidean component
14    pub euclidean_bits: u8,
15    /// Bits for Hyperbolic component
16    pub hyperbolic_bits: u8,
17    /// Bits for Spherical component
18    pub spherical_bits: u8,
19}
20
21impl Default for QuantizationConfig {
22    fn default() -> Self {
23        Self {
24            euclidean_bits: 8,
25            hyperbolic_bits: 5,
26            spherical_bits: 5,
27        }
28    }
29}
30
/// A vector quantized independently per geometric component.
///
/// Each component holds integer codes plus one `f32` scale; a code `q`
/// reconstructs to `q as f32 * scale`. Every code occupies a full `i8` in
/// memory regardless of the configured bit width — the bit width only limits
/// the range of codes that are produced.
#[derive(Debug, Clone, PartialEq)]
pub struct QuantizedVector {
    /// Quantized Euclidean codes.
    pub euclidean: Vec<i8>,
    /// Scale that maps a Euclidean code back to its approximate value.
    pub euclidean_scale: f32,
    /// Quantized hyperbolic codes.
    pub hyperbolic: Vec<i8>,
    /// Scale that maps a hyperbolic code back to its approximate value.
    pub hyperbolic_scale: f32,
    /// Quantized spherical codes.
    pub spherical: Vec<i8>,
    /// Scale that maps a spherical code back to its approximate value.
    pub spherical_scale: f32,
}

48/// Component quantizer for efficient storage and compute
49#[derive(Debug, Clone)]
50pub struct ComponentQuantizer {
51    config: QuantizationConfig,
52    euclidean_levels: i32,
53    hyperbolic_levels: i32,
54    spherical_levels: i32,
55}
56
57impl ComponentQuantizer {
58    /// Create new quantizer
59    pub fn new(config: QuantizationConfig) -> Self {
60        Self {
61            euclidean_levels: (1 << (config.euclidean_bits - 1)) - 1,
62            hyperbolic_levels: (1 << (config.hyperbolic_bits - 1)) - 1,
63            spherical_levels: (1 << (config.spherical_bits - 1)) - 1,
64            config,
65        }
66    }
67
68    /// Quantize a component vector
69    fn quantize_component(&self, values: &[f32], levels: i32) -> (Vec<i8>, f32) {
70        if values.is_empty() {
71            return (vec![], 1.0);
72        }
73
74        // Find absmax for scale
75        let absmax = values
76            .iter()
77            .map(|v| v.abs())
78            .fold(0.0f32, f32::max)
79            .max(1e-8);
80
81        let scale = absmax / levels as f32;
82        let inv_scale = levels as f32 / absmax;
83
84        let quantized: Vec<i8> = values
85            .iter()
86            .map(|v| (v * inv_scale).round().clamp(-127.0, 127.0) as i8)
87            .collect();
88
89        (quantized, scale)
90    }
91
92    /// Dequantize a component
93    fn dequantize_component(&self, quantized: &[i8], scale: f32) -> Vec<f32> {
94        quantized.iter().map(|&q| q as f32 * scale).collect()
95    }
96
97    /// Quantize full vector with component ranges
98    pub fn quantize(
99        &self,
100        vector: &[f32],
101        e_range: std::ops::Range<usize>,
102        h_range: std::ops::Range<usize>,
103        s_range: std::ops::Range<usize>,
104    ) -> QuantizedVector {
105        let (euclidean, euclidean_scale) = self.quantize_component(
106            &vector[e_range],
107            self.euclidean_levels,
108        );
109
110        let (hyperbolic, hyperbolic_scale) = self.quantize_component(
111            &vector[h_range],
112            self.hyperbolic_levels,
113        );
114
115        let (spherical, spherical_scale) = self.quantize_component(
116            &vector[s_range],
117            self.spherical_levels,
118        );
119
120        QuantizedVector {
121            euclidean,
122            euclidean_scale,
123            hyperbolic,
124            hyperbolic_scale,
125            spherical,
126            spherical_scale,
127        }
128    }
129
130    /// Compute dot product between quantized vectors (integer arithmetic)
131    #[inline]
132    pub fn quantized_dot_product(
133        &self,
134        a: &QuantizedVector,
135        b: &QuantizedVector,
136        weights: &[f32; 3],
137    ) -> f32 {
138        // Integer dot products
139        let dot_e = Self::int_dot(&a.euclidean, &b.euclidean);
140        let dot_h = Self::int_dot(&a.hyperbolic, &b.hyperbolic);
141        let dot_s = Self::int_dot(&a.spherical, &b.spherical);
142
143        // Scale and weight
144        let sim_e = dot_e as f32 * a.euclidean_scale * b.euclidean_scale;
145        let sim_h = dot_h as f32 * a.hyperbolic_scale * b.hyperbolic_scale;
146        let sim_s = dot_s as f32 * a.spherical_scale * b.spherical_scale;
147
148        weights[0] * sim_e + weights[1] * sim_h + weights[2] * sim_s
149    }
150
151    /// Integer dot product (SIMD-friendly)
152    #[inline(always)]
153    fn int_dot(a: &[i8], b: &[i8]) -> i32 {
154        let len = a.len().min(b.len());
155        let chunks = len / 4;
156        let remainder = len % 4;
157
158        let mut sum0 = 0i32;
159        let mut sum1 = 0i32;
160        let mut sum2 = 0i32;
161        let mut sum3 = 0i32;
162
163        for i in 0..chunks {
164            let base = i * 4;
165            sum0 += a[base] as i32 * b[base] as i32;
166            sum1 += a[base + 1] as i32 * b[base + 1] as i32;
167            sum2 += a[base + 2] as i32 * b[base + 2] as i32;
168            sum3 += a[base + 3] as i32 * b[base + 3] as i32;
169        }
170
171        let base = chunks * 4;
172        for i in 0..remainder {
173            sum0 += a[base + i] as i32 * b[base + i] as i32;
174        }
175
176        sum0 + sum1 + sum2 + sum3
177    }
178
179    /// Dequantize to full vector
180    pub fn dequantize(
181        &self,
182        quant: &QuantizedVector,
183        total_dim: usize,
184    ) -> Vec<f32> {
185        let mut result = vec![0.0f32; total_dim];
186
187        let e_vec = self.dequantize_component(&quant.euclidean, quant.euclidean_scale);
188        let h_vec = self.dequantize_component(&quant.hyperbolic, quant.hyperbolic_scale);
189        let s_vec = self.dequantize_component(&quant.spherical, quant.spherical_scale);
190
191        let e_end = e_vec.len();
192        let h_end = e_end + h_vec.len();
193
194        result[0..e_end].copy_from_slice(&e_vec);
195        result[e_end..h_end].copy_from_slice(&h_vec);
196        result[h_end..h_end + s_vec.len()].copy_from_slice(&s_vec);
197
198        result
199    }
200
201    /// Get memory savings ratio
202    pub fn compression_ratio(&self, dim: usize, e_dim: usize, h_dim: usize, s_dim: usize) -> f32 {
203        let original_bits = dim as f32 * 32.0;
204        let quantized_bits = e_dim as f32 * self.config.euclidean_bits as f32
205            + h_dim as f32 * self.config.hyperbolic_bits as f32
206            + s_dim as f32 * self.config.spherical_bits as f32
207            + 3.0 * 32.0; // 3 scale factors
208
209        original_bits / quantized_bits
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// Component layout shared by the tests: 32 Euclidean, 16 hyperbolic and
    /// 16 spherical dimensions of a 64-dimensional vector.
    fn layout() -> (
        std::ops::Range<usize>,
        std::ops::Range<usize>,
        std::ops::Range<usize>,
    ) {
        (0..32, 32..48, 48..64)
    }

    #[test]
    fn test_quantize_dequantize() {
        let quantizer = ComponentQuantizer::new(QuantizationConfig::default());

        // A ramp rather than a constant, so codes across the whole range are exercised.
        let vector: Vec<f32> = (0..64).map(|i| (i as f32 - 31.5) / 32.0).collect();
        let (e_range, h_range, s_range) = layout();

        let quantized =
            quantizer.quantize(&vector, e_range.clone(), h_range.clone(), s_range.clone());

        assert_eq!(quantized.euclidean.len(), 32);
        assert_eq!(quantized.hyperbolic.len(), 16);
        assert_eq!(quantized.spherical.len(), 16);

        let dequantized = quantizer.dequantize(&quantized, 64);
        assert_eq!(dequantized.len(), 64);

        // Round-to-nearest never errs by more than half a quantization step.
        let components = [
            (e_range, quantized.euclidean_scale),
            (h_range, quantized.hyperbolic_scale),
            (s_range, quantized.spherical_scale),
        ];
        for (range, scale) in components {
            for i in range {
                let err = (vector[i] - dequantized[i]).abs();
                assert!(
                    err <= 0.5 * scale + 1e-6,
                    "index {}: error {} exceeds half step {}",
                    i,
                    err,
                    0.5 * scale
                );
            }
        }
    }

    #[test]
    fn test_quantized_dot_product() {
        let quantizer = ComponentQuantizer::new(QuantizationConfig::default());
        let (e_range, h_range, s_range) = layout();

        let ones = vec![1.0f32; 64];
        let neg_ones = vec![-1.0f32; 64];
        let q_ones = quantizer.quantize(&ones, e_range.clone(), h_range.clone(), s_range.clone());
        let q_neg = quantizer.quantize(&neg_ones, e_range, h_range, s_range);

        let weights = [0.5, 0.3, 0.2];

        // Exact float result: 0.5 * 32 + 0.3 * 16 + 0.2 * 16 = 24.
        let dot = quantizer.quantized_dot_product(&q_ones, &q_ones, &weights);
        assert!((dot - 24.0).abs() < 1e-3, "dot = {}", dot);

        // Opposite vectors give the exact negation.
        let anti = quantizer.quantized_dot_product(&q_ones, &q_neg, &weights);
        assert!((anti + 24.0).abs() < 1e-3, "anti = {}", anti);
    }

    #[test]
    fn test_compression_ratio() {
        let quantizer = ComponentQuantizer::new(QuantizationConfig::default());

        let ratio = quantizer.compression_ratio(512, 256, 192, 64);

        // 512 * 32 bits vs 256*8 + 192*5 + 64*5 bits plus three 32-bit scales (~4.8x).
        let expected = (512.0 * 32.0) / (256.0 * 8.0 + 192.0 * 5.0 + 64.0 * 5.0 + 96.0);
        assert!((ratio - expected).abs() < 1e-4, "ratio = {}", ratio);
        assert!(ratio > 3.0 && ratio < 7.0);
    }
}
269}