// ruvector_attention/curvature/component_quantizer.rs
use serde::{Deserialize, Serialize};
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct QuantizationConfig {
13 pub euclidean_bits: u8,
15 pub hyperbolic_bits: u8,
17 pub spherical_bits: u8,
19}
20
21impl Default for QuantizationConfig {
22 fn default() -> Self {
23 Self {
24 euclidean_bits: 8,
25 hyperbolic_bits: 5,
26 spherical_bits: 5,
27 }
28 }
29}
30
/// A vector quantized component by component.
///
/// Each component holds its values as signed 8-bit integers together with one
/// `f32` scale; an approximate original value is recovered as `q as f32 * scale`.
#[derive(Debug, Clone, PartialEq)]
pub struct QuantizedVector {
    /// Quantized values of the Euclidean component.
    pub euclidean: Vec<i8>,
    /// Scale that maps a quantized Euclidean value back to `f32`.
    pub euclidean_scale: f32,
    /// Quantized values of the hyperbolic component.
    pub hyperbolic: Vec<i8>,
    /// Scale that maps a quantized hyperbolic value back to `f32`.
    pub hyperbolic_scale: f32,
    /// Quantized values of the spherical component.
    pub spherical: Vec<i8>,
    /// Scale that maps a quantized spherical value back to `f32`.
    pub spherical_scale: f32,
}
47
48#[derive(Debug, Clone)]
50pub struct ComponentQuantizer {
51 config: QuantizationConfig,
52 euclidean_levels: i32,
53 hyperbolic_levels: i32,
54 spherical_levels: i32,
55}
56
57impl ComponentQuantizer {
58 pub fn new(config: QuantizationConfig) -> Self {
60 Self {
61 euclidean_levels: (1 << (config.euclidean_bits - 1)) - 1,
62 hyperbolic_levels: (1 << (config.hyperbolic_bits - 1)) - 1,
63 spherical_levels: (1 << (config.spherical_bits - 1)) - 1,
64 config,
65 }
66 }
67
68 fn quantize_component(&self, values: &[f32], levels: i32) -> (Vec<i8>, f32) {
70 if values.is_empty() {
71 return (vec![], 1.0);
72 }
73
74 let absmax = values
76 .iter()
77 .map(|v| v.abs())
78 .fold(0.0f32, f32::max)
79 .max(1e-8);
80
81 let scale = absmax / levels as f32;
82 let inv_scale = levels as f32 / absmax;
83
84 let quantized: Vec<i8> = values
85 .iter()
86 .map(|v| (v * inv_scale).round().clamp(-127.0, 127.0) as i8)
87 .collect();
88
89 (quantized, scale)
90 }
91
92 fn dequantize_component(&self, quantized: &[i8], scale: f32) -> Vec<f32> {
94 quantized.iter().map(|&q| q as f32 * scale).collect()
95 }
96
97 pub fn quantize(
99 &self,
100 vector: &[f32],
101 e_range: std::ops::Range<usize>,
102 h_range: std::ops::Range<usize>,
103 s_range: std::ops::Range<usize>,
104 ) -> QuantizedVector {
105 let (euclidean, euclidean_scale) = self.quantize_component(
106 &vector[e_range],
107 self.euclidean_levels,
108 );
109
110 let (hyperbolic, hyperbolic_scale) = self.quantize_component(
111 &vector[h_range],
112 self.hyperbolic_levels,
113 );
114
115 let (spherical, spherical_scale) = self.quantize_component(
116 &vector[s_range],
117 self.spherical_levels,
118 );
119
120 QuantizedVector {
121 euclidean,
122 euclidean_scale,
123 hyperbolic,
124 hyperbolic_scale,
125 spherical,
126 spherical_scale,
127 }
128 }
129
130 #[inline]
132 pub fn quantized_dot_product(
133 &self,
134 a: &QuantizedVector,
135 b: &QuantizedVector,
136 weights: &[f32; 3],
137 ) -> f32 {
138 let dot_e = Self::int_dot(&a.euclidean, &b.euclidean);
140 let dot_h = Self::int_dot(&a.hyperbolic, &b.hyperbolic);
141 let dot_s = Self::int_dot(&a.spherical, &b.spherical);
142
143 let sim_e = dot_e as f32 * a.euclidean_scale * b.euclidean_scale;
145 let sim_h = dot_h as f32 * a.hyperbolic_scale * b.hyperbolic_scale;
146 let sim_s = dot_s as f32 * a.spherical_scale * b.spherical_scale;
147
148 weights[0] * sim_e + weights[1] * sim_h + weights[2] * sim_s
149 }
150
151 #[inline(always)]
153 fn int_dot(a: &[i8], b: &[i8]) -> i32 {
154 let len = a.len().min(b.len());
155 let chunks = len / 4;
156 let remainder = len % 4;
157
158 let mut sum0 = 0i32;
159 let mut sum1 = 0i32;
160 let mut sum2 = 0i32;
161 let mut sum3 = 0i32;
162
163 for i in 0..chunks {
164 let base = i * 4;
165 sum0 += a[base] as i32 * b[base] as i32;
166 sum1 += a[base + 1] as i32 * b[base + 1] as i32;
167 sum2 += a[base + 2] as i32 * b[base + 2] as i32;
168 sum3 += a[base + 3] as i32 * b[base + 3] as i32;
169 }
170
171 let base = chunks * 4;
172 for i in 0..remainder {
173 sum0 += a[base + i] as i32 * b[base + i] as i32;
174 }
175
176 sum0 + sum1 + sum2 + sum3
177 }
178
179 pub fn dequantize(
181 &self,
182 quant: &QuantizedVector,
183 total_dim: usize,
184 ) -> Vec<f32> {
185 let mut result = vec![0.0f32; total_dim];
186
187 let e_vec = self.dequantize_component(&quant.euclidean, quant.euclidean_scale);
188 let h_vec = self.dequantize_component(&quant.hyperbolic, quant.hyperbolic_scale);
189 let s_vec = self.dequantize_component(&quant.spherical, quant.spherical_scale);
190
191 let e_end = e_vec.len();
192 let h_end = e_end + h_vec.len();
193
194 result[0..e_end].copy_from_slice(&e_vec);
195 result[e_end..h_end].copy_from_slice(&h_vec);
196 result[h_end..h_end + s_vec.len()].copy_from_slice(&s_vec);
197
198 result
199 }
200
201 pub fn compression_ratio(&self, dim: usize, e_dim: usize, h_dim: usize, s_dim: usize) -> f32 {
203 let original_bits = dim as f32 * 32.0;
204 let quantized_bits = e_dim as f32 * self.config.euclidean_bits as f32
205 + h_dim as f32 * self.config.hyperbolic_bits as f32
206 + s_dim as f32 * self.config.spherical_bits as f32
207 + 3.0 * 32.0; original_bits / quantized_bits
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a constant vector through quantize/dequantize and checks both
    /// the component sizes and the reconstructed values.
    #[test]
    fn test_quantize_dequantize() {
        let quantizer = ComponentQuantizer::new(QuantizationConfig::default());

        let vector = vec![0.5f32; 64];
        let quantized = quantizer.quantize(&vector, 0..32, 32..48, 48..64);

        assert_eq!(quantized.euclidean.len(), 32);
        assert_eq!(quantized.hyperbolic.len(), 16);
        assert_eq!(quantized.spherical.len(), 16);

        let dequantized = quantizer.dequantize(&quantized, 64);

        // `zip` would silently stop at the shorter side, so pin the length first.
        assert_eq!(dequantized.len(), vector.len());
        for (&orig, &deq) in vector.iter().zip(dequantized.iter()) {
            // Every value equals the component maximum, so it lands exactly on
            // the top level and reconstructs up to float rounding.
            assert!((orig - deq).abs() < 1e-3, "orig={} deq={}", orig, deq);
        }
    }

    /// Checks the weighted dot product against its exact expected value.
    #[test]
    fn test_quantized_dot_product() {
        let quantizer = ComponentQuantizer::new(QuantizationConfig::default());

        let a = vec![1.0f32; 64];
        let b = vec![1.0f32; 64];

        let qa = quantizer.quantize(&a, 0..32, 32..48, 48..64);
        let qb = quantizer.quantize(&b, 0..32, 32..48, 48..64);

        let weights = [0.5, 0.3, 0.2];
        let dot = quantizer.quantized_dot_product(&qa, &qb, &weights);

        // All-ones vectors: per-component dots are 32, 16 and 16, so the
        // weighted sum is 0.5 * 32 + 0.3 * 16 + 0.2 * 16 = 24.
        assert!(dot > 0.0);
        assert!((dot - 24.0).abs() < 1e-3, "dot={}", dot);
    }

    /// Checks the compression ratio for the default 8/5/5-bit configuration.
    #[test]
    fn test_compression_ratio() {
        let quantizer = ComponentQuantizer::new(QuantizationConfig::default());

        let ratio = quantizer.compression_ratio(512, 256, 192, 64);

        // 512 * 32 bits over 256 * 8 + 192 * 5 + 64 * 5 + 3 * 32 bits.
        let expected = 16384.0f32 / 3424.0;
        assert!(ratio > 3.0);
        assert!(ratio < 7.0);
        assert!((ratio - expected).abs() < 1e-4, "ratio={}", ratio);
    }
}