1use crate::{CodecId, CompressionError, DecompressError, ExactFallbackAdapter};
31use quant_governor::{evaluate, GovernancePolicy, GovernanceRequest};
32
33#[cfg(feature = "fib")]
34use fib_quant::{FibCodeV1, FibQuantProfileV1, FibQuantizer};
35#[cfg(feature = "turbo")]
36use turbo_quant::TurboQuantizer;
37
/// Describes how the codec for an operation should be chosen.
#[derive(Debug, Clone)]
pub enum CodecDispatch<'a> {
    /// Defer the choice to the governance layer. `policy` and `request` are
    /// presumably evaluated together, the way `select_codec` does with the
    /// same types. TODO confirm how callers resolve this variant.
    Governed {
        /// Governance policy to evaluate against; borrowed for `'a`.
        policy: &'a GovernancePolicy,
        /// Governance request to evaluate; held by value.
        request: GovernanceRequest,
    },
    /// Use the given codec directly, presumably without consulting governance.
    Force(CodecId),
}
51
52#[allow(unused_variables)]
61type FallbackDecoder<T> = Box<dyn Fn(CodecId, &[u8]) -> Result<T, DecompressError> + Send + Sync>;
63pub fn build_adapter<T>(_dispatch: CodecDispatch) -> ExactFallbackAdapter<T>
64where
65 T: From<Vec<u8>> + Send + Sync + 'static,
66{
67 let fallback_decoder: FallbackDecoder<T> = Box::new(move |codec_id, data| {
68 match codec_id {
69 CodecId::Uncompressed => Ok(T::from(data.to_vec())),
70 #[cfg(feature = "turbo")]
71 CodecId::TurboQuant => turbo_quant_decode(data).map(T::from),
72 #[cfg(feature = "fib")]
73 CodecId::FibQuant => fib_quant_decode(data).map(T::from),
74 #[cfg(feature = "polar")]
76 CodecId::Polar => Ok(T::from(data.to_vec())),
77 #[cfg(feature = "qjl")]
78 CodecId::Qjl => Ok(T::from(data.to_vec())),
79 #[cfg(not(any(feature = "turbo", feature = "fib", feature = "polar", feature = "qjl")))]
80 _ => Err(DecompressError::DecodeFailed(
81 "No codec features enabled".to_string(),
82 )),
83 }
84 });
85
86 ExactFallbackAdapter::new(fallback_decoder)
87}
88
89pub fn select_codec(
93 policy: &GovernancePolicy,
94 request: GovernanceRequest,
95) -> Result<CodecId, quant_governor::error::GovernorError> {
96 let decision = evaluate(request, policy)?;
97 Ok(match decision.codec {
98 quant_governor::CodecProfile::Raw => CodecId::Uncompressed,
99 quant_governor::CodecProfile::Q8 => CodecId::Uncompressed, quant_governor::CodecProfile::Q4 => CodecId::Uncompressed, quant_governor::CodecProfile::Turbo => CodecId::TurboQuant,
102 quant_governor::CodecProfile::Fib => CodecId::FibQuant,
103 quant_governor::CodecProfile::Polar => CodecId::Polar,
104 quant_governor::CodecProfile::Qjl => CodecId::Qjl,
105 })
106}
107
/// Builds the FibQuant profile this module uses for `dim`-dimensional vectors.
///
/// It calls `FibQuantProfileV1::paper_default` with the fixed values `k = 4`
/// and `n = 32` and the given `seed`. Both the FibQuant encoder and decoder
/// here obtain their profile from this function.
#[cfg(feature = "fib")]
pub fn fib_quant_profile(
    dim: usize,
    seed: u64,
) -> std::result::Result<FibQuantProfileV1, fib_quant::FibQuantError> {
    const K: usize = 4;
    const N: usize = 32;
    FibQuantProfileV1::paper_default(dim, K, N, seed)
}
125
/// Builds the `TurboQuantizer` used when encoding `dim`-dimensional vectors
/// at `seed`, with fixed constructor arguments `8` and `32`.
#[cfg(feature = "turbo")]
pub fn turbo_quant_quantizer(
    dim: usize,
    seed: u64,
) -> std::result::Result<TurboQuantizer, turbo_quant::TurboQuantError> {
    // Assumed to mirror `new_with_mode`'s (bits, qjl_projections) argument
    // order used on the decode side. Confirm against the turbo_quant API.
    let (bits, qjl_projections) = (8, 32);
    TurboQuantizer::new(dim, bits, qjl_projections, seed)
}
136
137pub fn encode(codec_id: CodecId, vector: &[f32], seed: u64) -> Result<Vec<u8>, CompressionError> {
152 match codec_id {
153 CodecId::Uncompressed => Ok(bytemuck::cast_slice::<f32, u8>(vector).to_vec()),
154 #[cfg(feature = "fib")]
155 CodecId::FibQuant => fib_quant_encode(vector, seed),
156 #[cfg(feature = "turbo")]
157 CodecId::TurboQuant => turbo_quant_encode(vector, seed),
158 #[cfg(feature = "polar")]
159 CodecId::Polar => polar_quant_encode(vector, seed),
160 #[cfg(feature = "qjl")]
161 CodecId::Qjl => qjl_sketch_encode(vector, seed),
162 #[cfg(not(any(feature = "turbo", feature = "fib", feature = "polar", feature = "qjl")))]
163 _ => Err(CompressionError::EncodeFailed(
164 "no codec features enabled".to_string(),
165 )),
166 }
167}
168
169pub fn decode(codec_id: CodecId, compressed: &[u8]) -> Result<Vec<u8>, DecompressError> {
184 match codec_id {
185 CodecId::Uncompressed => Ok(compressed.to_vec()),
186 #[cfg(feature = "fib")]
187 CodecId::FibQuant => fib_quant_decode(compressed),
188 #[cfg(feature = "turbo")]
189 CodecId::TurboQuant => turbo_quant_decode(compressed),
190 #[cfg(feature = "polar")]
191 CodecId::Polar => Ok(compressed.to_vec()),
192 #[cfg(feature = "qjl")]
193 CodecId::Qjl => Ok(compressed.to_vec()),
194 #[cfg(not(any(feature = "turbo", feature = "fib", feature = "polar", feature = "qjl")))]
195 _ => Err(DecompressError::DecodeFailed(
196 "no codec features enabled".to_string(),
197 )),
198 }
199}
200
/// FibQuant encoder.
///
/// Builds the profile and quantizer for `vector.len()` dimensions at `seed`,
/// encodes `vector`, and serializes the resulting code as JSON.
#[cfg(feature = "fib")]
fn fib_quant_encode(vector: &[f32], seed: u64) -> Result<Vec<u8>, CompressionError> {
    /// Wraps a failure at `stage` as `EncodeFailed("fib_quant <stage>: <cause>")`.
    fn fail(stage: &str, cause: impl std::fmt::Display) -> CompressionError {
        CompressionError::EncodeFailed(format!("fib_quant {stage}: {cause}"))
    }

    let profile = fib_quant_profile(vector.len(), seed).map_err(|e| fail("profile build", e))?;
    let quantizer = FibQuantizer::new(profile).map_err(|e| fail("quantizer build", e))?;
    let code = quantizer.encode(vector).map_err(|e| fail("encode", e))?;
    serde_json::to_vec(&code).map_err(|e| fail("serialize", e))
}
219
/// FibQuant decoder.
///
/// Parses a JSON-serialized `FibCodeV1`, rebuilds the quantizer for the
/// code's `ambient_dim`, and returns the reconstructed vector as
/// native-endian `f32` bytes.
///
/// NOTE(review): the profile is always rebuilt with seed `42`, whereas
/// `fib_quant_encode` uses the caller's seed. Codes encoded with any other
/// seed may fail to decode, or decode against a different profile. Confirm
/// whether `FibCodeV1` carries its seed or profile so it can be read back here.
#[cfg(feature = "fib")]
fn fib_quant_decode(compressed: &[u8]) -> Result<Vec<u8>, DecompressError> {
    /// Wraps a failure at `stage` as `DecodeFailed("fib_quant <stage>: <cause>")`.
    fn fail(stage: &str, cause: impl std::fmt::Display) -> DecompressError {
        DecompressError::DecodeFailed(format!("fib_quant {stage}: {cause}"))
    }

    const DECODE_SEED: u64 = 42;

    let code: FibCodeV1 =
        serde_json::from_slice(compressed).map_err(|e| fail("deserialize", e))?;
    let profile = fib_quant_profile(code.ambient_dim as usize, DECODE_SEED)
        .map_err(|e| fail("profile build", e))?;
    let quantizer = FibQuantizer::new(profile).map_err(|e| fail("quantizer build", e))?;
    let decoded = quantizer.decode(&code).map_err(|e| fail("decode", e))?;
    Ok(bytemuck::cast_slice::<f32, u8>(&decoded).to_vec())
}
242
/// TurboQuant encoder.
///
/// Builds the quantizer from [`turbo_quant_quantizer`] for `vector.len()`
/// dimensions at `seed` and returns `encode_to_bytes` output. That output is
/// presumably the `TurboCodeWireV1` format the TurboQuant decoder parses.
#[cfg(feature = "turbo")]
fn turbo_quant_encode(vector: &[f32], seed: u64) -> Result<Vec<u8>, CompressionError> {
    let quantizer = match turbo_quant_quantizer(vector.len(), seed) {
        Ok(built) => built,
        Err(e) => {
            return Err(CompressionError::EncodeFailed(format!(
                "turbo_quant quantizer build: {e}"
            )));
        }
    };
    quantizer
        .encode_to_bytes(vector)
        .map_err(|e| CompressionError::EncodeFailed(format!("turbo_quant encode: {e}")))
}
255
/// TurboQuant decoder.
///
/// Parses the wire header to recover the encoder's dimension, bit width,
/// QJL projection count and seed, and rebuilds a matching `TurboQuantizer`.
/// It then decodes the wire payload and returns the approximate
/// reconstruction as native-endian `f32` bytes. No seed has to be assumed
/// because it is read from the header.
#[cfg(feature = "turbo")]
fn turbo_quant_decode(compressed: &[u8]) -> Result<Vec<u8>, DecompressError> {
    // `TurboQuantizer` is already imported at module level under the same
    // feature gate; only the wire/mode types are needed locally.
    use turbo_quant::{TurboCodeWireV1, TurboMode};

    let header = TurboCodeWireV1::parse_header(compressed).map_err(|e| {
        DecompressError::DecodeFailed(format!("turbo_quant header parse: {e}"))
    })?;

    // A non-zero QJL sign count marks a payload written in Polar+QJL mode. In
    // that mode the quantizer is rebuilt with one bit more than the header's
    // `polar_bits`, presumably the QJL sign bit. Confirm against
    // `TurboQuantizer::new_with_mode`.
    let (mode, bits) = if header.qjl_sign_count > 0 {
        (TurboMode::PolarWithQjl, header.polar_bits + 1)
    } else {
        (TurboMode::PolarOnly, header.polar_bits)
    };

    let quantizer = TurboQuantizer::new_with_mode(
        header.dim,
        bits,
        header.qjl_projections,
        header.seed,
        mode,
    )
    .map_err(|e| {
        DecompressError::DecodeFailed(format!("turbo_quant quantizer rebuild: {e}"))
    })?;

    let code = TurboCodeWireV1::decode(compressed, &quantizer)
        .map_err(|e| DecompressError::DecodeFailed(format!("turbo_quant wire decode: {e}")))?;

    let decoded = quantizer
        .decode_approximate(&code)
        .map_err(|e| DecompressError::DecodeFailed(format!("turbo_quant decode: {e}")))?;

    Ok(bytemuck::cast_slice::<f32, u8>(&decoded).to_vec())
}
309
/// Polar encoder.
///
/// Builds a `PolarQuantizer` with a stored rotation for `vector.len()`
/// dimensions at 8 bits and `seed`, encodes `vector`, and serializes the code
/// as JSON. The matching `decode` arm returns these bytes unchanged.
#[cfg(feature = "polar")]
fn polar_quant_encode(vector: &[f32], seed: u64) -> Result<Vec<u8>, CompressionError> {
    use turbo_quant::PolarQuantizer;

    /// Wraps a failure at `stage` as `EncodeFailed("polar_quant <stage>: <cause>")`.
    fn fail(stage: &str, cause: impl std::fmt::Display) -> CompressionError {
        CompressionError::EncodeFailed(format!("polar_quant {stage}: {cause}"))
    }

    const BITS: u8 = 8;

    let polar = PolarQuantizer::new_with_stored_rotation(vector.len(), BITS, seed)
        .map_err(|e| fail("build", e))?;
    let code = polar.encode(vector).map_err(|e| fail("encode", e))?;
    serde_json::to_vec(&code).map_err(|e| fail("serialize", e))
}
333
/// QJL sketch encoder.
///
/// Builds a `QjlQuantizer` for `vector.len()` dimensions with 32 projections
/// at `seed`, sketches `vector`, and serializes the sketch as JSON. The
/// matching `decode` arm returns these bytes unchanged.
#[cfg(feature = "qjl")]
fn qjl_sketch_encode(vector: &[f32], seed: u64) -> Result<Vec<u8>, CompressionError> {
    use turbo_quant::QjlQuantizer;

    /// Wraps a failure at `stage` as `EncodeFailed("qjl_quant <stage>: <cause>")`.
    fn fail(stage: &str, cause: impl std::fmt::Display) -> CompressionError {
        CompressionError::EncodeFailed(format!("qjl_quant {stage}: {cause}"))
    }

    const PROJECTIONS: usize = 32;

    let sketcher =
        QjlQuantizer::new(vector.len(), PROJECTIONS, seed).map_err(|e| fail("build", e))?;
    let sketch = sketcher.sketch(vector).map_err(|e| fail("sketch", e))?;
    serde_json::to_vec(&sketch).map_err(|e| fail("serialize", e))
}
358
#[cfg(test)]
// `expect` gives failing tests a readable message, so the `clippy::expect_used`
// lint is relaxed for this module.
#[allow(clippy::expect_used)]
mod tests {
    // The glob also brings in the parent's imports (`CodecId`,
    // `CompressionError`, `DecompressError`, ...).
    use super::*;

    /// Deterministic pseudo-random vector of length `dim` with entries in
    /// `[-0.5, 0.5]`. It is driven by a 64-bit LCG (Knuth's MMIX constants)
    /// seeded with `seed`.
    fn make_vector(dim: usize, seed: u64) -> Vec<f32> {
        let mut s = seed;
        (0..dim)
            .map(|_| {
                s = s.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
                ((s >> 32) as f32 / u32::MAX as f32) - 0.5
            })
            .collect()
    }

    /// Reinterprets native-endian bytes as `f32`s.
    ///
    /// `bytemuck::cast_slice::<u8, f32>` panics unless its input is 4-byte
    /// aligned, and a `Vec<u8>` only guarantees 1-byte alignment. Decoded
    /// buffers are therefore copied element-wise instead of cast in place.
    fn bytes_to_f32s(bytes: &[u8]) -> Vec<f32> {
        assert_eq!(
            bytes.len() % 4,
            0,
            "decoded byte length {} is not a whole number of f32s",
            bytes.len()
        );
        bytes
            .chunks_exact(4)
            .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
            .collect()
    }

    #[test]
    fn uncompressed_round_trip_is_exact() {
        let v = make_vector(128, 42);
        let encoded = encode(CodecId::Uncompressed, &v, 0).unwrap();
        let decoded = bytes_to_f32s(&decode(CodecId::Uncompressed, &encoded).unwrap());
        assert_eq!(v, decoded);
    }

    #[test]
    #[cfg(feature = "fib")]
    fn fib_quant_round_trip_digest_stable() {
        let v = make_vector(128, 42);
        let encoded_a = encode(CodecId::FibQuant, &v, 42).unwrap();
        let encoded_b = encode(CodecId::FibQuant, &v, 42).unwrap();
        assert_eq!(
            encoded_a, encoded_b,
            "fib_quant encode must be deterministic at the same seed"
        );
        let decoded = bytes_to_f32s(&decode(CodecId::FibQuant, &encoded_a).unwrap());
        assert_eq!(decoded.len(), v.len());
        assert!(decoded.iter().all(|x| x.is_finite()));
    }

    #[test]
    #[cfg(feature = "turbo")]
    fn turbo_quant_round_trip_reconstructs_approximate_vector() {
        let v = make_vector(128, 7);
        let encoded = encode(CodecId::TurboQuant, &v, 7).expect("turbo encode failed");
        let decoded_bytes = decode(CodecId::TurboQuant, &encoded).expect("turbo decode failed");
        let decoded = bytes_to_f32s(&decoded_bytes);
        assert_eq!(decoded.len(), v.len());
        assert!(decoded.iter().all(|x| x.is_finite()));
    }

    #[test]
    #[cfg(feature = "turbo")]
    fn turbo_quant_round_trip_uses_wire_embedded_profile() {
        // Decoding must succeed for any encode seed, because the TurboQuant
        // decoder takes its seed from the wire header rather than assuming one.
        let v = make_vector(64, 1);
        let encoded_seed1 = encode(CodecId::TurboQuant, &v, 1).expect("encode seed=1");
        let _decoded = decode(CodecId::TurboQuant, &encoded_seed1)
            .expect("decode with wire-embedded seed must succeed");

        let v_seed99 = make_vector(64, 99);
        let encoded_seed99 = encode(CodecId::TurboQuant, &v_seed99, 99).expect("encode seed=99");
        let _decoded_99 = decode(CodecId::TurboQuant, &encoded_seed99)
            .expect("decode with wire-embedded seed must succeed");
    }

    #[test]
    #[cfg(feature = "fib")]
    fn fib_quant_different_seeds_produce_different_codes() {
        let v = make_vector(128, 42);
        let a = encode(CodecId::FibQuant, &v, 1).unwrap();
        let b = encode(CodecId::FibQuant, &v, 2).unwrap();
        assert_ne!(a, b, "different seeds must produce different codes");
    }

    #[test]
    #[cfg(feature = "fib")]
    fn fib_quant_profile_digest_mismatch_is_an_error() {
        // The FibQuant decoder always rebuilds its profile with seed 42, so a
        // seed-1 code either fails with a decode error or decodes against a
        // different profile. Both outcomes are tolerated here; any other error
        // variant or message is not.
        let v = make_vector(128, 1);
        let encoded = encode(CodecId::FibQuant, &v, 1).unwrap();
        match decode(CodecId::FibQuant, &encoded) {
            Ok(_) => {}
            Err(DecompressError::DecodeFailed(msg)) => {
                assert!(
                    msg.contains("profile digest") || msg.contains("decode"),
                    "unexpected error: {msg}"
                );
            }
            Err(e) => panic!("unexpected error variant: {e:?}"),
        }
    }

    #[test]
    fn encode_uncompressed_forces_identity() {
        let v = make_vector(64, 7);
        let encoded = encode(CodecId::Uncompressed, &v, 99).unwrap();
        // Build the oracle independently of the bytemuck cast used by `encode`.
        let expected: Vec<u8> = v.iter().flat_map(|x| x.to_ne_bytes()).collect();
        assert_eq!(encoded, expected);
    }

    #[test]
    #[cfg(not(feature = "fib"))]
    fn encode_unsupported_codec_errors() {
        // With the `fib` feature off, `CodecId::FibQuant` has no encoder in
        // this build, so `encode` must refuse it instead of emitting bytes.
        let v = make_vector(64, 0);
        let result = encode(CodecId::FibQuant, &v, 0);
        assert!(
            matches!(result, Err(CompressionError::EncodeFailed(_))),
            "expected EncodeFailed for a codec whose feature is disabled"
        );
    }

    #[test]
    #[cfg(feature = "polar")]
    fn polar_quant_encode_is_deterministic() {
        let v = make_vector(128, 42);
        let a = encode(CodecId::Polar, &v, 42).unwrap();
        let b = encode(CodecId::Polar, &v, 42).unwrap();
        assert_eq!(a, b, "polar encode must be deterministic at the same seed");
    }

    #[test]
    #[cfg(feature = "polar")]
    fn polar_quant_different_seeds_produce_different_codes() {
        let v = make_vector(128, 42);
        let a = encode(CodecId::Polar, &v, 1).unwrap();
        let b = encode(CodecId::Polar, &v, 2).unwrap();
        assert_ne!(a, b, "different seeds must produce different polar codes");
    }

    #[test]
    #[cfg(feature = "polar")]
    fn polar_quant_decode_is_passthrough() {
        let v = make_vector(64, 7);
        let encoded = encode(CodecId::Polar, &v, 7).unwrap();
        let decoded = decode(CodecId::Polar, &encoded).unwrap();
        assert_eq!(encoded, decoded, "polar decode must be identity");
    }

    #[test]
    #[cfg(feature = "qjl")]
    fn qjl_sketch_encode_is_deterministic() {
        let v = make_vector(128, 42);
        let a = encode(CodecId::Qjl, &v, 42).unwrap();
        let b = encode(CodecId::Qjl, &v, 42).unwrap();
        assert_eq!(a, b, "qjl sketch must be deterministic at the same seed");
        // 128 f32s occupy 512 bytes uncompressed.
        assert!(
            a.len() < 512,
            "qjl sketch ({} bytes) should be smaller than raw (512 bytes)",
            a.len()
        );
    }

    #[test]
    #[cfg(feature = "qjl")]
    fn qjl_sketch_different_seeds_produce_different_codes() {
        let v = make_vector(128, 42);
        let a = encode(CodecId::Qjl, &v, 1).unwrap();
        let b = encode(CodecId::Qjl, &v, 2).unwrap();
        assert_ne!(a, b, "different seeds must produce different qjl sketches");
    }

    #[test]
    #[cfg(feature = "qjl")]
    fn qjl_sketch_decode_is_passthrough() {
        let v = make_vector(64, 7);
        let encoded = encode(CodecId::Qjl, &v, 7).unwrap();
        let decoded = decode(CodecId::Qjl, &encoded).unwrap();
        assert_eq!(encoded, decoded, "qjl decode must be identity");
    }
}