1use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31
32use crate::{
33 error::{Result, TurboQuantError},
34 polar::{PolarCode, PolarQuantizer},
35 profile::{CodecProfileV1, CompressionReceiptV1, ValidationState},
36 qjl::{QjlProjectedQuery, QjlQuantizer, QjlSketch},
37 rotation::RotationKind,
38};
39
/// Selects how a [`TurboQuantizer`] spends its per-coordinate bit budget.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub enum TurboMode {
    /// All `bits` go to the polar quantizer. No QJL stage is built, and encoded
    /// codes carry an empty (zero-projection) residual sketch.
    PolarOnly,
    /// The polar quantizer gets `bits - 1` bits, and the polar reconstruction
    /// residual is summarized by a QJL sign sketch.
    PolarWithQjl,
}
48
/// A TurboQuant-encoded vector: a polar code plus a QJL sketch of the
/// residual left over after polar reconstruction.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct TurboCode {
    /// Polar-quantized representation of the input vector.
    pub polar_code: PolarCode,
    /// QJL sign sketch of `input - polar_decode(polar_code)`.
    ///
    /// Empty (zero projections, no signs) for codes produced in
    /// [`TurboMode::PolarOnly`].
    pub residual_sketch: QjlSketch,
}
57
58impl TurboCode {
59 pub fn encoded_bytes(&self) -> usize {
61 self.polar_code.encoded_bytes() + self.residual_sketch.encoded_bytes()
62 }
63
64 pub fn compression_ratio(&self) -> f32 {
66 let original = self.polar_code.dim * std::mem::size_of::<f32>();
67 original as f32 / self.encoded_bytes() as f32
68 }
69
70 pub fn validate_for(
72 &self,
73 dim: usize,
74 bits: u8,
75 projections: usize,
76 mode: TurboMode,
77 ) -> Result<()> {
78 let polar_bits = match mode {
79 TurboMode::PolarOnly => bits,
80 TurboMode::PolarWithQjl => bits.saturating_sub(1),
81 };
82 self.polar_code.validate_for(dim, polar_bits)?;
83 match mode {
84 TurboMode::PolarOnly => self.residual_sketch.validate_for(dim, 0),
85 TurboMode::PolarWithQjl => self.residual_sketch.validate_for(dim, projections),
86 }
87 }
88}
89
/// TurboQuant encoder/scorer: a polar quantizer, optionally followed by a QJL
/// sign sketch of the polar residual (see [`TurboMode`]).
///
/// NOTE(review): this type derives `Deserialize`, which bypasses the
/// validation in the constructors. A deserialized value could therefore carry
/// a `qjl` that disagrees with `mode`. `inner_product_estimate_prepared`
/// reports such a mismatch as `MalformedCode`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TurboQuantizer {
    // Vector dimensionality; the constructors require it to be non-zero and even.
    dim: usize,
    // Total bit budget as passed to the constructor (not the polar stage's bits).
    bits: u8,
    // QJL projection count; stored even in PolarOnly mode, where no QJL stage exists.
    projections: usize,
    // Base seed; the polar stage uses it directly, the QJL stage uses an offset copy.
    seed: u64,
    mode: TurboMode,
    polar: PolarQuantizer,
    // `Some` exactly when `mode == PolarWithQjl` for constructor-built values.
    qjl: Option<QjlQuantizer>,
}
107
/// A query pre-projected for repeated scoring against many [`TurboCode`]s.
///
/// Built by [`TurboQuantizer::prepare_query`]. Its `qjl` part is present only
/// when the preparing quantizer has a QJL stage.
#[derive(Debug, Clone, PartialEq)]
pub struct TurboProjectedQuery {
    polar: crate::polar::PolarProjectedQuery,
    qjl: Option<QjlProjectedQuery>,
}
114
115impl TurboQuantizer {
116 pub fn new(dim: usize, bits: u8, projections: usize, seed: u64) -> Result<Self> {
125 Self::new_with_mode(dim, bits, projections, seed, TurboMode::PolarWithQjl)
126 }
127
128 pub fn new_with_mode(
130 dim: usize,
131 bits: u8,
132 projections: usize,
133 seed: u64,
134 mode: TurboMode,
135 ) -> Result<Self> {
136 Self::new_with_mode_and_rotation(dim, bits, projections, seed, mode, RotationKind::Auto)
137 }
138
139 pub fn new_with_mode_and_rotation(
141 dim: usize,
142 bits: u8,
143 projections: usize,
144 seed: u64,
145 mode: TurboMode,
146 rotation_kind: RotationKind,
147 ) -> Result<Self> {
148 if dim == 0 {
149 return Err(TurboQuantError::ZeroDimension);
150 }
151 if dim % 2 != 0 {
152 return Err(TurboQuantError::OddDimension { got: dim });
153 }
154 let valid_bits = match mode {
155 TurboMode::PolarOnly => (1..=16).contains(&bits),
156 TurboMode::PolarWithQjl => (2..=16).contains(&bits),
157 };
158 if !valid_bits {
159 return Err(TurboQuantError::InvalidBitWidth { got: bits });
160 }
161 if mode == TurboMode::PolarWithQjl && projections == 0 {
162 return Err(TurboQuantError::ZeroProjectionCount);
163 }
164
165 let polar_seed = seed;
167 let qjl_seed = seed.wrapping_add(0xCAFE_BABE_0000_0001);
168
169 let polar_bits = match mode {
170 TurboMode::PolarOnly => bits,
171 TurboMode::PolarWithQjl => bits - 1,
172 };
173 let polar = PolarQuantizer::new_with_rotation(dim, polar_bits, polar_seed, rotation_kind)?;
174 let qjl = match mode {
175 TurboMode::PolarOnly => None,
176 TurboMode::PolarWithQjl => Some(QjlQuantizer::new(dim, projections, qjl_seed)?),
177 };
178
179 Ok(Self {
180 dim,
181 bits,
182 projections,
183 seed,
184 mode,
185 polar,
186 qjl,
187 })
188 }
189
190 pub fn new_with_stored_rotation(
192 dim: usize,
193 bits: u8,
194 projections: usize,
195 seed: u64,
196 ) -> Result<Self> {
197 Self::new_with_mode_and_rotation(
198 dim,
199 bits,
200 projections,
201 seed,
202 TurboMode::PolarWithQjl,
203 RotationKind::StoredQr,
204 )
205 }
206
207 pub fn dim(&self) -> usize {
209 self.dim
210 }
211
212 pub fn bits(&self) -> u8 {
214 self.bits
215 }
216
217 pub fn projections(&self) -> usize {
219 self.projections
220 }
221
222 pub fn seed(&self) -> u64 {
224 self.seed
225 }
226
227 pub fn mode(&self) -> TurboMode {
229 self.mode
230 }
231
232 pub fn rotation_kind(&self) -> RotationKind {
234 self.polar.rotation_kind()
235 }
236
237 pub fn profile(&self) -> CodecProfileV1 {
239 CodecProfileV1::turbo(
240 self.dim,
241 self.bits,
242 self.projections,
243 self.seed,
244 self.mode == TurboMode::PolarWithQjl,
245 self.polar.rotation_kind_label(),
246 )
247 }
248
249 pub fn encode(&self, vector: &[f32]) -> Result<TurboCode> {
257 if vector.len() != self.dim {
258 return Err(TurboQuantError::DimensionMismatch {
259 expected: self.dim,
260 got: vector.len(),
261 });
262 }
263 check_finite_vector(vector)?;
264
265 let polar_code = self.polar.encode(vector)?;
266
267 let reconstruction = self.polar.decode(&polar_code)?;
269 let residual: Vec<f32> = vector
270 .iter()
271 .zip(reconstruction.iter())
272 .map(|(orig, rec)| orig - rec)
273 .collect();
274
275 let residual_sketch = match &self.qjl {
276 Some(qjl) => qjl.sketch(&residual)?,
277 None => QjlSketch {
278 dim: self.dim,
279 projections: 0,
280 signs: Vec::new(),
281 },
282 };
283
284 Ok(TurboCode {
285 polar_code,
286 residual_sketch,
287 })
288 }
289
290 pub fn encode_with_receipt(
292 &self,
293 vector: &[f32],
294 source_digest: Option<String>,
295 ) -> Result<(TurboCode, CompressionReceiptV1)> {
296 let code = self.encode(vector)?;
297 let receipt = CompressionReceiptV1::new(
298 self.profile(),
299 source_digest,
300 vector.len(),
301 code.encoded_bytes(),
302 ValidationState::Validated,
303 );
304 Ok((code, receipt))
305 }
306
307 pub fn encode_batch(&self, vectors: &[&[f32]]) -> Result<Vec<TurboCode>> {
309 vectors.iter().map(|vector| self.encode(vector)).collect()
310 }
311
312 pub fn inner_product_estimate(&self, code: &TurboCode, query: &[f32]) -> Result<f32> {
317 let projected = self.prepare_query(query)?;
318 self.inner_product_estimate_prepared(code, &projected)
319 }
320
321 pub fn prepare_query(&self, query: &[f32]) -> Result<TurboProjectedQuery> {
323 if query.len() != self.dim {
324 return Err(TurboQuantError::DimensionMismatch {
325 expected: self.dim,
326 got: query.len(),
327 });
328 }
329 check_finite_vector(query)?;
330 Ok(TurboProjectedQuery {
331 polar: self.polar.project_query(query)?,
332 qjl: match &self.qjl {
333 Some(qjl) => Some(qjl.project_query(query)?),
334 None => None,
335 },
336 })
337 }
338
339 pub fn inner_product_estimate_prepared(
341 &self,
342 code: &TurboCode,
343 query: &TurboProjectedQuery,
344 ) -> Result<f32> {
345 code.validate_for(self.dim, self.bits, self.projections, self.mode)?;
346
347 let polar_estimate = self
348 .polar
349 .inner_product_estimate_with_projected_query(&code.polar_code, &query.polar)?;
350 let qjl_correction = match (&self.qjl, &query.qjl) {
351 (Some(qjl), Some(qjl_query)) => {
352 qjl.inner_product_estimate_with_projected_query(&code.residual_sketch, qjl_query)?
353 }
354 (None, None) => 0.0,
355 _ => {
356 return Err(TurboQuantError::MalformedCode {
357 reason: "TurboQuant QJL mode/query/code mismatch".into(),
358 });
359 }
360 };
361
362 let score = polar_estimate + qjl_correction;
363 if !score.is_finite() {
364 return Err(TurboQuantError::MalformedCode {
365 reason: "turbo score is not finite".into(),
366 });
367 }
368 Ok(score)
369 }
370
371 pub fn score_batch_prepared(
373 &self,
374 query: &TurboProjectedQuery,
375 codes: &[TurboCode],
376 ) -> Result<Vec<f32>> {
377 codes
378 .iter()
379 .map(|code| self.inner_product_estimate_prepared(code, query))
380 .collect()
381 }
382
383 pub fn l2_distance_estimate(&self, code: &TurboCode, query: &[f32]) -> Result<f32> {
388 let ip = self.inner_product_estimate(code, query)?;
389 let query_norm_sq: f32 = query.iter().map(|x| x * x).sum();
390 let code_norm_sq: f32 = code.polar_code.radii.iter().map(|r| r * r).sum();
391 let distance = (query_norm_sq + code_norm_sq - 2.0 * ip).max(0.0);
392 if !distance.is_finite() {
393 return Err(TurboQuantError::MalformedCode {
394 reason: "turbo l2 distance is not finite".into(),
395 });
396 }
397 Ok(distance)
398 }
399
400 pub fn decode_approximate(&self, code: &TurboCode) -> Result<Vec<f32>> {
405 code.validate_for(self.dim, self.bits, self.projections, self.mode)?;
406 self.polar.decode(&code.polar_code)
407 }
408
409 pub fn decode_approximate_batch(&self, codes: &[TurboCode]) -> Result<Vec<Vec<f32>>> {
415 for code in codes {
416 code.validate_for(self.dim, self.bits, self.projections, self.mode)?;
417 }
418 let polar_refs: Vec<PolarCode> = codes.iter().map(|c| c.polar_code.clone()).collect();
419 self.polar.decode_batch(&polar_refs)
420 }
421
422 pub fn encode_to_bytes(&self, vector: &[f32]) -> Result<Vec<u8>> {
424 let code = self.encode(vector)?;
425 crate::wire::TurboCodeWireV1::encode(&code, self)
426 }
427
428 pub fn decode_code_from_bytes(&self, bytes: &[u8]) -> Result<TurboCode> {
430 crate::wire::TurboCodeWireV1::decode(bytes, self)
431 }
432
433 pub fn score_inner_product_from_bytes(&self, bytes: &[u8], query: &[f32]) -> Result<f32> {
435 let code = self.decode_code_from_bytes(bytes)?;
436 let prepared = self.prepare_query(query)?;
437 self.inner_product_estimate_prepared(&code, &prepared)
438 }
439
440 pub fn batch_stats(&self, codes: &[TurboCode]) -> BatchStats {
442 let total_bytes: usize = codes.iter().map(|c| c.encoded_bytes()).sum();
443 let original_bytes = codes.len() * self.dim * std::mem::size_of::<f32>();
444 BatchStats {
445 count: codes.len(),
446 total_encoded_bytes: total_bytes,
447 total_original_bytes: original_bytes,
448 compression_ratio: if total_bytes > 0 {
449 original_bytes as f32 / total_bytes as f32
450 } else {
451 0.0
452 },
453 }
454 }
455}
456
457fn check_finite_vector(vector: &[f32]) -> Result<()> {
458 if let Some((index, _)) = vector
459 .iter()
460 .enumerate()
461 .find(|(_, value)| !value.is_finite())
462 {
463 return Err(TurboQuantError::NonFiniteInput { index });
464 }
465 Ok(())
466}
467
/// Aggregate size statistics over a batch of [`TurboCode`]s, as produced by
/// [`TurboQuantizer::batch_stats`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct BatchStats {
    /// Number of codes in the batch.
    pub count: usize,
    /// Sum of [`TurboCode::encoded_bytes`] over the batch.
    pub total_encoded_bytes: usize,
    /// Uncompressed size: `count * dim * size_of::<f32>()`.
    pub total_original_bytes: usize,
    /// `total_original_bytes / total_encoded_bytes`, or `0.0` when the total
    /// encoded size is zero.
    pub compression_ratio: f32,
}
476
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic standard-normal vector of length `dim` derived from `seed`.
    fn random_vector(dim: usize, seed: u64) -> Vec<f32> {
        use rand::SeedableRng;
        use rand_chacha::ChaCha8Rng;
        use rand_distr::{Distribution, StandardNormal};
        let mut rng = ChaCha8Rng::seed_from_u64(seed);
        (0..dim).map(|_| StandardNormal.sample(&mut rng)).collect()
    }

    #[test]
    fn encode_is_deterministic() {
        let q = TurboQuantizer::new(16, 8, 16, 42).unwrap();
        let x = random_vector(16, 1);
        let c1 = q.encode(&x).unwrap();
        let c2 = q.encode(&x).unwrap();
        assert_eq!(c1.polar_code, c2.polar_code);
        assert_eq!(c1.residual_sketch.signs, c2.residual_sketch.signs);
    }

    /// Compares TurboQuant (polar at `bits` plus a QJL residual sketch) with
    /// plain PolarQuant at the same polar bit width.
    ///
    /// Despite the name, this only asserts that TurboQuant is *competitive*
    /// (average error within 1.5x of PolarQuant), presumably so the test stays
    /// robust to noise over 20 samples.
    #[test]
    fn inner_product_estimate_outperforms_polar_alone_at_low_bits() {
        let dim = 64;
        let bits = 4u8;
        let projections = 64;

        let polar_only = PolarQuantizer::new(dim, bits, 0).unwrap();
        // `bits + 1`: TurboQuant reserves one bit for QJL, so its polar stage
        // also runs at `bits`.
        let turbo = TurboQuantizer::new(dim, bits + 1, projections, 0).unwrap();

        let mut polar_errors = Vec::new();
        let mut turbo_errors = Vec::new();

        for seed in 0..20u64 {
            let x = random_vector(dim, seed * 2);
            let y = random_vector(dim, seed * 2 + 1);

            let exact: f32 = x.iter().zip(y.iter()).map(|(a, b)| a * b).sum();

            let polar_code = polar_only.encode(&x).unwrap();
            let polar_est = polar_only.inner_product_estimate(&polar_code, &y).unwrap();

            let turbo_code = turbo.encode(&x).unwrap();
            let turbo_est = turbo.inner_product_estimate(&turbo_code, &y).unwrap();

            polar_errors.push((polar_est - exact).abs());
            turbo_errors.push((turbo_est - exact).abs());
        }

        let avg_polar: f32 = polar_errors.iter().sum::<f32>() / polar_errors.len() as f32;
        let avg_turbo: f32 = turbo_errors.iter().sum::<f32>() / turbo_errors.len() as f32;

        assert!(
            avg_turbo <= avg_polar * 1.5,
            "TurboQuant should be competitive with PolarQuant: turbo_avg={avg_turbo:.3}, polar_avg={avg_polar:.3}"
        );
    }

    #[test]
    fn nearest_neighbor_ordering_is_preserved() {
        let q = TurboQuantizer::new(16, 8, 16, 7).unwrap();
        let query = random_vector(16, 99);

        let close = {
            let mut v = query.clone();
            v.iter_mut().for_each(|x| *x += 0.05);
            v
        };
        let far1 = random_vector(16, 200);
        let far2 = random_vector(16, 201);

        let code_close = q.encode(&close).unwrap();
        let code_far1 = q.encode(&far1).unwrap();
        let code_far2 = q.encode(&far2).unwrap();

        let ip_close = q.inner_product_estimate(&code_close, &query).unwrap();
        let ip_far1 = q.inner_product_estimate(&code_far1, &query).unwrap();
        let ip_far2 = q.inner_product_estimate(&code_far2, &query).unwrap();

        assert!(
            ip_close > ip_far1 && ip_close > ip_far2,
            "close={ip_close:.3}, far1={ip_far1:.3}, far2={ip_far2:.3}"
        );
    }

    #[test]
    fn l2_distance_estimate_is_finite_and_ranks_self_closest() {
        let q = TurboQuantizer::new(16, 8, 16, 11).unwrap();
        let x = random_vector(16, 21);
        let other = random_vector(16, 22);
        let code = q.encode(&x).unwrap();

        let d_self = q.l2_distance_estimate(&code, &x).unwrap();
        let d_other = q.l2_distance_estimate(&code, &other).unwrap();

        assert!(d_self.is_finite() && d_self >= 0.0, "d_self={d_self}");
        assert!(d_other.is_finite() && d_other >= 0.0, "d_other={d_other}");
        assert!(d_self < d_other, "d_self={d_self:.3}, d_other={d_other:.3}");
    }

    #[test]
    fn polar_only_mode_produces_empty_sketch() {
        let q = TurboQuantizer::new_with_mode(16, 4, 0, 3, TurboMode::PolarOnly).unwrap();
        assert_eq!(q.mode(), TurboMode::PolarOnly);

        let x = random_vector(16, 5);
        let code = q.encode(&x).unwrap();
        assert_eq!(code.residual_sketch.projections, 0);
        assert!(code.residual_sketch.signs.is_empty());

        let y = random_vector(16, 6);
        assert!(q.inner_product_estimate(&code, &y).unwrap().is_finite());
    }

    #[test]
    fn compression_ratio_is_positive() {
        let q = TurboQuantizer::new(64, 8, 32, 0).unwrap();
        let x = random_vector(64, 1);
        let code = q.encode(&x).unwrap();
        assert!(code.compression_ratio() > 1.0);
    }

    #[test]
    fn batch_stats_sums_correctly() {
        let dim = 64;
        let q = TurboQuantizer::new(dim, 8, 16, 0).unwrap();
        let codes: Vec<_> = (0..10)
            .map(|i| q.encode(&random_vector(dim, i)).unwrap())
            .collect();
        let stats = q.batch_stats(&codes);
        assert_eq!(stats.count, 10);
        assert!(stats.compression_ratio > 1.0);
        assert_eq!(
            stats.total_original_bytes,
            10 * dim * std::mem::size_of::<f32>()
        );
    }

    #[test]
    fn turbo_code_serialization_roundtrip() {
        let q = TurboQuantizer::new(16, 8, 16, 42).unwrap();
        let x = random_vector(16, 1);
        let code = q.encode(&x).unwrap();
        let json = serde_json::to_string(&code).unwrap();
        let restored: TurboCode = serde_json::from_str(&json).unwrap();
        assert_eq!(code, restored);
    }

    #[test]
    fn invalid_config_rejected() {
        // Zero dimension.
        assert!(TurboQuantizer::new(0, 8, 16, 0).is_err());
        // Odd dimension.
        assert!(TurboQuantizer::new(7, 8, 16, 0).is_err());
        // PolarWithQjl needs at least two bits.
        assert!(TurboQuantizer::new(8, 1, 16, 0).is_err());
        // Above the 16-bit cap.
        assert!(TurboQuantizer::new(8, 17, 16, 0).is_err());
        // PolarWithQjl needs at least one projection.
        assert!(TurboQuantizer::new(8, 8, 0, 0).is_err());
        // PolarOnly still enforces the 1..=16 bit range.
        assert!(TurboQuantizer::new_with_mode(8, 0, 0, 0, TurboMode::PolarOnly).is_err());
        assert!(TurboQuantizer::new_with_mode(8, 17, 0, 0, TurboMode::PolarOnly).is_err());
    }
}