1use std::f32::consts::PI;
22
23use schemars::JsonSchema;
24use serde::{Deserialize, Serialize};
25
26use crate::{
27 bitpack,
28 error::{Result, TurboQuantError},
29 rotation::{Rotation, RotationBackend, RotationKind},
30};
31
/// A vector stored in polar form: one radius and one quantized angle index for each
/// consecutive pair of coordinates.
///
/// The fields are public, so a value is only known to be well formed once
/// `validate_for` has accepted it.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct PolarCode {
    /// Dimension of the encoded vector; `validate_for` requires it to be positive and even.
    pub dim: usize,
    /// Bits per angle index; `validate_for` requires every index to lie in `[0, 2^bits)`.
    pub bits: u8,
    /// One radius per coordinate pair (`dim / 2` entries), each finite and non-negative
    /// in a valid code.
    pub radii: Vec<f32>,
    /// One quantized angle index per coordinate pair (`dim / 2` entries in a valid code).
    pub angle_indices: Vec<u16>,
}
47
48impl PolarCode {
49 pub fn pair_count(&self) -> usize {
51 self.dim / 2
52 }
53
54 pub fn from_parts(
56 dim: usize,
57 bits: u8,
58 radii: Vec<f32>,
59 angle_indices: &[u16],
60 ) -> Result<Self> {
61 let code = Self {
62 dim,
63 bits,
64 radii,
65 angle_indices: angle_indices.to_vec(),
66 };
67 code.validate_for(dim, bits)?;
68 Ok(code)
69 }
70
71 pub fn angle_index(&self, i: usize) -> Result<u16> {
73 if i >= self.pair_count() {
74 return Err(TurboQuantError::MalformedCode {
75 reason: format!(
76 "angle index {i} is outside pair count {}",
77 self.pair_count()
78 ),
79 });
80 }
81 Ok(self.angle_indices[i])
82 }
83
84 pub fn angle_indices(&self) -> Result<Vec<u16>> {
86 self.validate_for(self.dim, self.bits)?;
87 Ok(self.angle_indices.clone())
88 }
89
90 pub fn dequantize_angle(&self, i: usize) -> Result<f32> {
92 let levels = 1u32 << self.bits;
93 let idx = self.angle_index(i)? as f32;
94 Ok((idx / levels as f32) * (2.0 * PI) - PI)
95 }
96
97 pub fn encoded_bytes(&self) -> usize {
99 self.radii.len() * std::mem::size_of::<f32>()
100 + bitpack::packed_len(self.angle_indices.len(), self.bits).unwrap_or(usize::MAX)
101 }
102
103 pub fn validate_for(&self, dim: usize, bits: u8) -> Result<()> {
105 if self.dim != dim {
106 return Err(TurboQuantError::DimensionMismatch {
107 expected: dim,
108 got: self.dim,
109 });
110 }
111 if self.bits != bits {
112 return Err(TurboQuantError::MalformedCode {
113 reason: format!("code has bits={}, expected {bits}", self.bits),
114 });
115 }
116 if dim == 0 || dim % 2 != 0 {
117 return Err(TurboQuantError::MalformedCode {
118 reason: format!("code dimension must be positive and even, got {dim}"),
119 });
120 }
121 let pairs = dim / 2;
122 if self.radii.len() != pairs {
123 return Err(TurboQuantError::MalformedCode {
124 reason: format!("code has {} radii, expected {pairs}", self.radii.len()),
125 });
126 }
127 for (index, radius) in self.radii.iter().enumerate() {
128 if !radius.is_finite() || *radius < 0.0 {
129 return Err(TurboQuantError::MalformedCode {
130 reason: format!("radius {index} is not finite and non-negative"),
131 });
132 }
133 }
134 if self.angle_indices.len() != pairs {
135 return Err(TurboQuantError::MalformedCode {
136 reason: format!(
137 "code has {} angle indices, expected {pairs}",
138 self.angle_indices.len()
139 ),
140 });
141 }
142 let levels = 1u32 << bits;
143 for (index, angle_index) in self.angle_indices.iter().enumerate() {
144 if u32::from(*angle_index) >= levels {
145 return Err(TurboQuantError::MalformedCode {
146 reason: format!(
147 "angle index {index} value {angle_index} is outside [0, {levels})"
148 ),
149 });
150 }
151 }
152 Ok(())
153 }
154}
155
/// Quantizer that rotates a vector and then stores each consecutive coordinate pair as
/// a radius plus a `bits`-bit angle index.
///
/// NOTE(review): `Deserialize` is derived, so a deserialized value has not been through
/// the constructor checks on `dim` and `bits` — confirm callers only load trusted data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolarQuantizer {
    /// Input dimension; the constructors require it to be non-zero and even.
    dim: usize,
    /// Bits per quantized angle; the constructors require a value in `1..=16`.
    bits: u8,
    /// Rotation applied before encoding and inverted when decoding.
    rotation: RotationBackend,
}
167
/// A query vector that has already been rotated by `PolarQuantizer::project_query`, so
/// it can be scored against many codes without repeating the rotation.
#[derive(Debug, Clone, PartialEq)]
pub struct PolarProjectedQuery {
    /// The query after the quantizer's rotation was applied; checked to be finite when
    /// the projection is built.
    rotated_query: Vec<f32>,
}
173
174impl PolarQuantizer {
175 pub fn new(dim: usize, bits: u8, seed: u64) -> Result<Self> {
182 if dim == 0 {
183 return Err(TurboQuantError::ZeroDimension);
184 }
185 if dim % 2 != 0 {
186 return Err(TurboQuantError::OddDimension { got: dim });
187 }
188 if bits == 0 || bits > 16 {
189 return Err(TurboQuantError::InvalidBitWidth { got: bits });
190 }
191 Self::new_with_rotation(dim, bits, seed, RotationKind::Auto)
192 }
193
194 pub fn new_with_rotation(
196 dim: usize,
197 bits: u8,
198 seed: u64,
199 rotation_kind: RotationKind,
200 ) -> Result<Self> {
201 if dim == 0 {
202 return Err(TurboQuantError::ZeroDimension);
203 }
204 if dim % 2 != 0 {
205 return Err(TurboQuantError::OddDimension { got: dim });
206 }
207 if bits == 0 || bits > 16 {
208 return Err(TurboQuantError::InvalidBitWidth { got: bits });
209 }
210 let rotation = RotationBackend::new(dim, seed, rotation_kind)?;
211 Ok(Self {
212 dim,
213 bits,
214 rotation,
215 })
216 }
217
218 pub fn new_with_stored_rotation(dim: usize, bits: u8, seed: u64) -> Result<Self> {
220 Self::new_with_rotation(dim, bits, seed, RotationKind::StoredQr)
221 }
222
223 pub fn dim(&self) -> usize {
225 self.dim
226 }
227
228 pub fn bits(&self) -> u8 {
230 self.bits
231 }
232
233 pub fn rotation_kind(&self) -> RotationKind {
235 self.rotation.kind()
236 }
237
238 pub fn rotation_kind_label(&self) -> &'static str {
240 self.rotation.kind_label()
241 }
242
243 pub fn encode(&self, vector: &[f32]) -> Result<PolarCode> {
247 self.check_input_dim(vector.len())?;
248 check_finite_vector(vector)?;
249
250 let mut rotated = vec![0.0f32; self.dim];
251 self.rotation.apply(vector, &mut rotated)?;
252
253 let pairs = self.dim / 2;
254 let mut radii = Vec::with_capacity(pairs);
255 let mut angle_indices = Vec::with_capacity(pairs);
256
257 for i in 0..pairs {
258 let a = rotated[2 * i];
259 let b = rotated[2 * i + 1];
260 let (r, idx) = encode_pair(a, b, self.bits);
261 radii.push(r);
262 angle_indices.push(idx);
263 }
264
265 PolarCode::from_parts(self.dim, self.bits, radii, &angle_indices)
266 }
267
268 pub fn decode(&self, code: &PolarCode) -> Result<Vec<f32>> {
273 self.validate_code(code)?;
274 let rotated = self.decode_to_rotated(code)?;
275 let mut output = vec![0.0f32; self.dim];
276 self.rotation.apply_inverse(&rotated, &mut output)?;
277 Ok(output)
278 }
279
280 pub fn decode_batch(&self, codes: &[PolarCode]) -> Result<Vec<Vec<f32>>> {
287 if codes.is_empty() {
288 return Ok(Vec::new());
289 }
290 let mut rotated: Vec<Vec<f32>> = Vec::with_capacity(codes.len());
293 for code in codes {
294 self.validate_code(code)?;
295 rotated.push(self.decode_to_rotated(code)?);
296 }
297 let rotated_refs: Vec<&[f32]> = rotated.iter().map(|v| v.as_slice()).collect();
299 self.rotation.apply_inverse_batch(&rotated_refs)
300 }
301
302 fn decode_to_rotated(&self, code: &PolarCode) -> Result<Vec<f32>> {
307 let mut rotated = vec![0.0f32; self.dim];
308 let pairs = self.dim / 2;
309 for i in 0..pairs {
310 let theta = code.dequantize_angle(i)?;
311 let r = code.radii[i];
312 rotated[2 * i] = r * theta.cos();
313 rotated[2 * i + 1] = r * theta.sin();
314 }
315 Ok(rotated)
316 }
317
318 pub fn inner_product_estimate(&self, code: &PolarCode, query: &[f32]) -> Result<f32> {
331 let projected = self.project_query(query)?;
332 self.inner_product_estimate_with_projected_query(code, &projected)
333 }
334
335 pub fn project_query(&self, query: &[f32]) -> Result<PolarProjectedQuery> {
337 self.check_input_dim(query.len())?;
338 check_finite_vector(query)?;
339 let mut rotated_query = vec![0.0f32; self.dim];
340 self.rotation.apply(query, &mut rotated_query)?;
341 check_finite_vector(&rotated_query)?;
342 Ok(PolarProjectedQuery { rotated_query })
343 }
344
345 pub fn inner_product_estimate_with_projected_query(
347 &self,
348 code: &PolarCode,
349 query: &PolarProjectedQuery,
350 ) -> Result<f32> {
351 self.validate_code(code)?;
352
353 let pairs = self.dim / 2;
354 let mut estimate = 0.0f32;
355
356 for i in 0..pairs {
357 let theta = code.dequantize_angle(i)?;
358 let r = code.radii[i];
359 let q_a = query.rotated_query[2 * i];
360 let q_b = query.rotated_query[2 * i + 1];
361 estimate += r * (q_a * theta.cos() + q_b * theta.sin());
362 }
363
364 if !estimate.is_finite() {
365 return Err(TurboQuantError::MalformedCode {
366 reason: "polar score is not finite".into(),
367 });
368 }
369 Ok(estimate)
370 }
371
372 pub fn l2_distance_estimate(&self, code: &PolarCode, query: &[f32]) -> Result<f32> {
378 let ip = self.inner_product_estimate(code, query)?;
379
380 let query_norm_sq: f32 = query.iter().map(|x| x * x).sum();
381 let code_norm_sq: f32 = code.radii.iter().map(|r| r * r).sum();
382
383 Ok((query_norm_sq + code_norm_sq - 2.0 * ip).max(0.0))
384 }
385
386 fn check_input_dim(&self, got: usize) -> Result<()> {
387 if got != self.dim {
388 return Err(TurboQuantError::DimensionMismatch {
389 expected: self.dim,
390 got,
391 });
392 }
393 Ok(())
394 }
395
396 fn validate_code(&self, code: &PolarCode) -> Result<()> {
397 code.validate_for(self.dim, self.bits)
398 }
399}
400
401fn check_finite_vector(vector: &[f32]) -> Result<()> {
402 if let Some((index, _)) = vector
403 .iter()
404 .enumerate()
405 .find(|(_, value)| !value.is_finite())
406 {
407 return Err(TurboQuantError::NonFiniteInput { index });
408 }
409 Ok(())
410}
411
/// Converts one coordinate pair `(a, b)` to polar form.
///
/// Returns the Euclidean radius of the pair together with the index of the angular
/// bin that contains `atan2(b, a)`, where `[-π, π]` is divided into `2^bits` equal
/// bins starting at `-π`.
fn encode_pair(a: f32, b: f32, bits: u8) -> (f32, u16) {
    let bin_count = 1u32 << bits;
    let radius = (a * a + b * b).sqrt();
    let angle = b.atan2(a);
    // Shift and scale the angle from [-π, π] onto [0, 1].
    let unit_fraction = (angle + PI) / (2.0 * PI);
    // An angle of exactly +π scales to `bin_count`; the modulo wraps it to bin 0,
    // which covers the same direction as -π.
    let bin = (unit_fraction * bin_count as f32).floor() as u32 % bin_count;
    (radius, bin as u16)
}
422
#[cfg(test)]
mod tests {
    use super::*;

    /// Standard basis vector of length `dim` with a single `1.0` at position `i`.
    fn unit_vector(dim: usize, i: usize) -> Vec<f32> {
        let mut basis = vec![0.0f32; dim];
        basis[i] = 1.0;
        basis
    }

    /// Deterministic vector of `dim` standard-normal samples drawn from a seeded ChaCha8 RNG.
    fn random_vector(dim: usize, seed: u64) -> Vec<f32> {
        use rand::SeedableRng;
        use rand_chacha::ChaCha8Rng;
        use rand_distr::{Distribution, StandardNormal};

        let mut generator = ChaCha8Rng::seed_from_u64(seed);
        let mut samples = Vec::with_capacity(dim);
        for _ in 0..dim {
            samples.push(StandardNormal.sample(&mut generator));
        }
        samples
    }

    /// With 16-bit angles, decode(encode(x)) stays within 0.01 of x per coordinate.
    #[test]
    fn encode_decode_roundtrip_high_bits() {
        let quantizer = PolarQuantizer::new(8, 16, 42).unwrap();
        let input = vec![1.0f32, 2.0, -1.5, 0.5, 3.0, -2.0, 0.1, -0.8];

        let decoded = quantizer
            .decode(&quantizer.encode(&input).unwrap())
            .unwrap();

        for (orig, dec) in input.iter().zip(&decoded) {
            assert!(
                (orig - dec).abs() < 0.01,
                "orig={orig:.4}, decoded={dec:.4}"
            );
        }
    }

    /// `decode_batch` must produce bit-identical output to calling `decode` per code.
    #[test]
    fn decode_batch_is_bit_exact_with_per_vec() {
        for bits in [4u8, 8, 12] {
            for seed in [0u64, 1, 42, 1337] {
                let q = PolarQuantizer::new(64, bits, seed).unwrap();

                let vecs: Vec<Vec<f32>> = (0..32)
                    .map(|i| {
                        (0..64)
                            .map(|j| ((i * 64 + j) as f32 * 0.137 + seed as f32 * 0.011).sin())
                            .collect()
                    })
                    .collect();
                let codes: Vec<PolarCode> = vecs.iter().map(|v| q.encode(v).unwrap()).collect();

                let per_vec: Vec<Vec<f32>> =
                    codes.iter().map(|c| q.decode(c).unwrap()).collect();
                let batched = q.decode_batch(&codes).unwrap();

                assert_eq!(batched.len(), per_vec.len());
                for (i, (a, b)) in per_vec.iter().zip(batched.iter()).enumerate() {
                    assert_eq!(a.len(), b.len(), "vec {i} length mismatch");
                    for (j, (x, y)) in a.iter().zip(b.iter()).enumerate() {
                        assert_eq!(
                            x.to_bits(),
                            y.to_bits(),
                            "vec {i} coord {j}: per_vec={x} batch={y} (bits={bits}, seed={seed})"
                        );
                    }
                }
            }
        }
    }

    /// With 16-bit angles the estimated inner product is within 2% of the exact one.
    #[test]
    fn inner_product_estimate_is_close_at_high_bits() {
        let quantizer = PolarQuantizer::new(16, 16, 7).unwrap();
        let stored = random_vector(16, 1);
        let probe = random_vector(16, 2);

        let code = quantizer.encode(&stored).unwrap();
        let estimated = quantizer.inner_product_estimate(&code, &probe).unwrap();
        let exact: f32 = stored.iter().zip(probe.iter()).map(|(a, b)| a * b).sum();

        let relative_error = (estimated - exact).abs() / (exact.abs() + 1e-6);
        assert!(
            relative_error < 0.02,
            "relative error {relative_error:.4} too large: estimated={estimated:.4}, exact={exact:.4}"
        );
    }

    /// Encoding the same vector twice yields identical radii and angle indices.
    #[test]
    fn encoding_is_deterministic() {
        let quantizer = PolarQuantizer::new(8, 8, 0).unwrap();
        let ones = vec![1.0f32; 8];

        let first = quantizer.encode(&ones).unwrap();
        let second = quantizer.encode(&ones).unwrap();
        assert_eq!(first.angle_indices, second.angle_indices);
        assert_eq!(first.radii, second.radii);
    }

    /// The zero vector encodes to (numerically) zero radii.
    #[test]
    fn zero_vector_has_zero_radius() {
        let quantizer = PolarQuantizer::new(8, 8, 1).unwrap();
        let code = quantizer.encode(&vec![0.0f32; 8]).unwrap();
        for r in &code.radii {
            assert!(*r < 1e-7, "expected zero radius, got {r}");
        }
    }

    /// Each basis vector keeps unit norm: the squared radii sum to 1 within 1e-5.
    #[test]
    fn unit_vectors_preserve_norm() {
        let quantizer = PolarQuantizer::new(8, 16, 3).unwrap();
        for axis in 0..8 {
            let code = quantizer.encode(&unit_vector(8, axis)).unwrap();
            let norm_sq: f32 = code.radii.iter().map(|r| r * r).sum();
            assert!((norm_sq - 1.0).abs() < 1e-5, "norm_sq={norm_sq}");
        }
    }

    /// A slightly perturbed copy of the query scores higher than two unrelated vectors.
    #[test]
    fn nearest_neighbor_ordering_preserved_at_8bits() {
        let quantizer = PolarQuantizer::new(16, 8, 42).unwrap();
        let query = random_vector(16, 99);

        let mut close = query.clone();
        for component in &mut close {
            *component += 0.01;
        }
        let far1 = random_vector(16, 200);
        let far2 = random_vector(16, 201);

        let score = |candidate: &[f32]| {
            let code = quantizer.encode(candidate).unwrap();
            quantizer.inner_product_estimate(&code, &query).unwrap()
        };
        let ip_close = score(&close);
        let ip_far1 = score(&far1);
        let ip_far2 = score(&far2);

        assert!(
            ip_close > ip_far1 && ip_close > ip_far2,
            "close={ip_close:.3}, far1={ip_far1:.3}, far2={ip_far2:.3}"
        );
    }

    /// Encoding a vector of the wrong length is an error.
    #[test]
    fn dimension_mismatch_is_rejected() {
        let quantizer = PolarQuantizer::new(8, 8, 0).unwrap();
        assert!(quantizer.encode(&[1.0f32; 10]).is_err());
    }

    /// Odd dimensions cannot be split into pairs and are rejected at construction.
    #[test]
    fn odd_dimension_is_rejected() {
        assert!(PolarQuantizer::new(7, 8, 0).is_err());
    }

    /// A bit width of zero is rejected at construction.
    #[test]
    fn zero_bits_rejected() {
        assert!(PolarQuantizer::new(8, 0, 0).is_err());
    }

    /// A code survives a JSON serialize/deserialize round trip unchanged.
    #[test]
    fn code_serialization_roundtrip() {
        let quantizer = PolarQuantizer::new(8, 8, 42).unwrap();
        let input = vec![1.0f32, -2.0, 0.5, 1.5, -0.3, 0.8, -1.0, 2.0];
        let code = quantizer.encode(&input).unwrap();

        let json = serde_json::to_string(&code).unwrap();
        let restored: PolarCode = serde_json::from_str(&json).unwrap();
        assert_eq!(code, restored);
    }
}