Skip to main content

turbo_quant/
polar.rs

//! PolarQuant: high-efficiency vector compression via polar coordinate encoding.
//!
//! PolarQuant converts Cartesian pairs into polar form (radius, angle), then
//! uniformly quantizes angles. Because the rotation stage whitens the data,
//! angles are uniformly distributed on [−π, π], making uniform quantization
//! a profile-defined alternative to trained-codebook calibration.
//!
//! # Algorithm
//!
//! Given a d-dimensional (even d) vector x after rotation y = R·x:
//!
//! 1. Group into d/2 pairs: (y₀, y₁), (y₂, y₃), …
//! 2. Convert each pair to polar: rᵢ = √(y₂ᵢ² + y₂ᵢ₊₁²), θᵢ = atan2(y₂ᵢ₊₁, y₂ᵢ)
//! 3. Quantize each θᵢ to one of 2^`bits` levels spaced uniformly on [−π, π)
//! 4. Store radii as f32 and bitpack quantized angle indices
//!
//! For approximate nearest-neighbor search, exact reconstruction is not required.
//! The inner product estimator operates directly on polar codes, avoiding
//! the decode round-trip entirely.
20
21use std::f32::consts::PI;
22
23use schemars::JsonSchema;
24use serde::{Deserialize, Serialize};
25
26use crate::{
27    bitpack,
28    error::{Result, TurboQuantError},
29    rotation::{Rotation, RotationBackend, RotationKind},
30};
31
/// A compressed representation of a single vector in polar form.
///
/// The `radii` array has length d/2, and `angle_indices` stores d/2 logical
/// angle indices. Angle indices are in [0, 2^bits).
// NOTE(review): the `JsonSchema` derive typically copies `///` text into the
// generated schema descriptions — confirm before rewording the doc comments
// on this type or its fields.
//
// All fields are public and the type is `Deserialize`, so a value can exist
// without having gone through `from_parts`. The length and range invariants
// described above are only guaranteed after `validate_for` has succeeded.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct PolarCode {
    /// Original vector dimension (must be even).
    pub dim: usize,
    /// Number of bits used to quantize each angle.
    pub bits: u8,
    /// Per-pair radii (f32, lossless).
    // `validate_for` additionally requires every radius to be finite and >= 0.
    pub radii: Vec<f32>,
    /// Quantized angle indices in [0, 2^bits).
    // One `u16` per pair, held unpacked in memory.
    pub angle_indices: Vec<u16>,
}
47
48impl PolarCode {
49    /// Number of pairs (= dim / 2).
50    pub fn pair_count(&self) -> usize {
51        self.dim / 2
52    }
53
54    /// Build a packed code from logical angle indices.
55    pub fn from_parts(
56        dim: usize,
57        bits: u8,
58        radii: Vec<f32>,
59        angle_indices: &[u16],
60    ) -> Result<Self> {
61        let code = Self {
62            dim,
63            bits,
64            radii,
65            angle_indices: angle_indices.to_vec(),
66        };
67        code.validate_for(dim, bits)?;
68        Ok(code)
69    }
70
71    /// Return the logical angle index for pair `i`.
72    pub fn angle_index(&self, i: usize) -> Result<u16> {
73        if i >= self.pair_count() {
74            return Err(TurboQuantError::MalformedCode {
75                reason: format!(
76                    "angle index {i} is outside pair count {}",
77                    self.pair_count()
78                ),
79            });
80        }
81        Ok(self.angle_indices[i])
82    }
83
84    /// Unpack all logical angle indices.
85    pub fn angle_indices(&self) -> Result<Vec<u16>> {
86        self.validate_for(self.dim, self.bits)?;
87        Ok(self.angle_indices.clone())
88    }
89
90    /// Reconstruct the dequantized angle for pair `i` in radians ∈ [−π, π).
91    pub fn dequantize_angle(&self, i: usize) -> Result<f32> {
92        let levels = 1u32 << self.bits;
93        let idx = self.angle_index(i)? as f32;
94        Ok((idx / levels as f32) * (2.0 * PI) - PI)
95    }
96
97    /// Serialized payload bytes used by this code.
98    pub fn encoded_bytes(&self) -> usize {
99        self.radii.len() * std::mem::size_of::<f32>()
100            + bitpack::packed_len(self.angle_indices.len(), self.bits).unwrap_or(usize::MAX)
101    }
102
103    /// Validate this code against an expected profile.
104    pub fn validate_for(&self, dim: usize, bits: u8) -> Result<()> {
105        if self.dim != dim {
106            return Err(TurboQuantError::DimensionMismatch {
107                expected: dim,
108                got: self.dim,
109            });
110        }
111        if self.bits != bits {
112            return Err(TurboQuantError::MalformedCode {
113                reason: format!("code has bits={}, expected {bits}", self.bits),
114            });
115        }
116        if dim == 0 || dim % 2 != 0 {
117            return Err(TurboQuantError::MalformedCode {
118                reason: format!("code dimension must be positive and even, got {dim}"),
119            });
120        }
121        let pairs = dim / 2;
122        if self.radii.len() != pairs {
123            return Err(TurboQuantError::MalformedCode {
124                reason: format!("code has {} radii, expected {pairs}", self.radii.len()),
125            });
126        }
127        for (index, radius) in self.radii.iter().enumerate() {
128            if !radius.is_finite() || *radius < 0.0 {
129                return Err(TurboQuantError::MalformedCode {
130                    reason: format!("radius {index} is not finite and non-negative"),
131                });
132            }
133        }
134        if self.angle_indices.len() != pairs {
135            return Err(TurboQuantError::MalformedCode {
136                reason: format!(
137                    "code has {} angle indices, expected {pairs}",
138                    self.angle_indices.len()
139                ),
140            });
141        }
142        let levels = 1u32 << bits;
143        for (index, angle_index) in self.angle_indices.iter().enumerate() {
144            if u32::from(*angle_index) >= levels {
145                return Err(TurboQuantError::MalformedCode {
146                    reason: format!(
147                        "angle index {index} value {angle_index} is outside [0, {levels})"
148                    ),
149                });
150            }
151        }
152        Ok(())
153    }
154}
155
156/// Encodes and decodes vectors using PolarQuant.
157///
158/// The quantizer owns a selected rotation backend that is applied before
159/// encoding. The rotation is seeded deterministically, so the profile can
160/// record `(dim, seed, bits, rotation_kind)`.
161#[derive(Debug, Clone, Serialize, Deserialize)]
162pub struct PolarQuantizer {
163    dim: usize,
164    bits: u8,
165    rotation: RotationBackend,
166}
167
/// A query that has already been rotated by a [`PolarQuantizer`], so that it
/// can be scored against many PolarQuant codes without rotating it again.
#[derive(Debug, Clone, PartialEq)]
pub struct PolarProjectedQuery {
    /// The query after the quantizer's rotation; one entry per dimension.
    rotated_query: Vec<f32>,
}
173
174impl PolarQuantizer {
175    /// Create a new quantizer for vectors of dimension `dim`.
176    ///
177    /// - `bits`: angle quantization levels (1–16). Higher values generally
178    ///   reduce angle quantization error and increase storage.
179    /// - `seed`: controls the random rotation. Identical `(dim, bits, seed)`
180    ///   always produces an identical quantizer.
181    pub fn new(dim: usize, bits: u8, seed: u64) -> Result<Self> {
182        if dim == 0 {
183            return Err(TurboQuantError::ZeroDimension);
184        }
185        if dim % 2 != 0 {
186            return Err(TurboQuantError::OddDimension { got: dim });
187        }
188        if bits == 0 || bits > 16 {
189            return Err(TurboQuantError::InvalidBitWidth { got: bits });
190        }
191        Self::new_with_rotation(dim, bits, seed, RotationKind::Auto)
192    }
193
194    /// Create a quantizer with an explicit rotation policy.
195    pub fn new_with_rotation(
196        dim: usize,
197        bits: u8,
198        seed: u64,
199        rotation_kind: RotationKind,
200    ) -> Result<Self> {
201        if dim == 0 {
202            return Err(TurboQuantError::ZeroDimension);
203        }
204        if dim % 2 != 0 {
205            return Err(TurboQuantError::OddDimension { got: dim });
206        }
207        if bits == 0 || bits > 16 {
208            return Err(TurboQuantError::InvalidBitWidth { got: bits });
209        }
210        let rotation = RotationBackend::new(dim, seed, rotation_kind)?;
211        Ok(Self {
212            dim,
213            bits,
214            rotation,
215        })
216    }
217
218    /// Create a quantizer using dense QR reference rotation.
219    pub fn new_with_stored_rotation(dim: usize, bits: u8, seed: u64) -> Result<Self> {
220        Self::new_with_rotation(dim, bits, seed, RotationKind::StoredQr)
221    }
222
223    /// The vector dimension this quantizer operates on.
224    pub fn dim(&self) -> usize {
225        self.dim
226    }
227
228    /// Angle quantization bit width.
229    pub fn bits(&self) -> u8 {
230        self.bits
231    }
232
233    /// Resolved rotation backend.
234    pub fn rotation_kind(&self) -> RotationKind {
235        self.rotation.kind()
236    }
237
238    /// Resolved rotation backend label for profiles and receipts.
239    pub fn rotation_kind_label(&self) -> &'static str {
240        self.rotation.kind_label()
241    }
242
243    /// Encode a vector into a [`PolarCode`].
244    ///
245    /// `vector` must have length `dim`.
246    pub fn encode(&self, vector: &[f32]) -> Result<PolarCode> {
247        self.check_input_dim(vector.len())?;
248        check_finite_vector(vector)?;
249
250        let mut rotated = vec![0.0f32; self.dim];
251        self.rotation.apply(vector, &mut rotated)?;
252
253        let pairs = self.dim / 2;
254        let mut radii = Vec::with_capacity(pairs);
255        let mut angle_indices = Vec::with_capacity(pairs);
256
257        for i in 0..pairs {
258            let a = rotated[2 * i];
259            let b = rotated[2 * i + 1];
260            let (r, idx) = encode_pair(a, b, self.bits);
261            radii.push(r);
262            angle_indices.push(idx);
263        }
264
265        PolarCode::from_parts(self.dim, self.bits, radii, &angle_indices)
266    }
267
268    /// Decode a [`PolarCode`] back to an approximate vector.
269    ///
270    /// The result is the nearest-neighbor reconstruction in the rotated space,
271    /// then inverse-rotated back. Reconstruction error depends on `bits`.
272    pub fn decode(&self, code: &PolarCode) -> Result<Vec<f32>> {
273        self.validate_code(code)?;
274        let rotated = self.decode_to_rotated(code)?;
275        let mut output = vec![0.0f32; self.dim];
276        self.rotation.apply_inverse(&rotated, &mut output)?;
277        Ok(output)
278    }
279
280    /// Decode a batch of [`PolarCode`]s back to vectors in one call.
281    ///
282    /// Bit-exact identical to `decode` for each code in turn; the win is
283    /// amortizing the per-call branch / lookup overhead and keeping the
284    /// sign table (or matrix) hot in cache across the whole batch.
285    /// Returns one `Vec<f32>` per input code, in the same order.
286    pub fn decode_batch(&self, codes: &[PolarCode]) -> Result<Vec<Vec<f32>>> {
287        if codes.is_empty() {
288            return Ok(Vec::new());
289        }
290        // Phase 1: validate and dequantize every code's polar pairs into
291        // a flat (cos, sin)-rotated buffer. Allocations are pre-sized.
292        let mut rotated: Vec<Vec<f32>> = Vec::with_capacity(codes.len());
293        for code in codes {
294            self.validate_code(code)?;
295            rotated.push(self.decode_to_rotated(code)?);
296        }
297        // Phase 2: apply inverse rotation to the whole batch at once.
298        let rotated_refs: Vec<&[f32]> = rotated.iter().map(|v| v.as_slice()).collect();
299        self.rotation.apply_inverse_batch(&rotated_refs)
300    }
301
302    /// Decode a single [`PolarCode`] into its rotated-space representation
303    /// (the (r·cosθ, r·sinθ) pairs) without applying the inverse rotation.
304    /// Shared by `decode` and `decode_batch` to keep the angle math in
305    /// one place.
306    fn decode_to_rotated(&self, code: &PolarCode) -> Result<Vec<f32>> {
307        let mut rotated = vec![0.0f32; self.dim];
308        let pairs = self.dim / 2;
309        for i in 0..pairs {
310            let theta = code.dequantize_angle(i)?;
311            let r = code.radii[i];
312            rotated[2 * i] = r * theta.cos();
313            rotated[2 * i + 1] = r * theta.sin();
314        }
315        Ok(rotated)
316    }
317
318    /// Estimate the inner product ⟨query, encoded_vector⟩ without decoding.
319    ///
320    /// This is the core operation for approximate nearest-neighbor search.
321    /// The query is rotated, then each rotated pair is compared to the stored
322    /// polar code using the identity:
323    ///
324    /// ```text
325    /// ⟨q_pair, k_pair⟩ ≈ r_k · (q_pair · [cos θ_k, sin θ_k])
326    ///                   = r_k · (q_a cos θ_k + q_b sin θ_k)
327    /// ```
328    ///
329    /// Summing over all pairs gives the full inner product estimate.
330    pub fn inner_product_estimate(&self, code: &PolarCode, query: &[f32]) -> Result<f32> {
331        let projected = self.project_query(query)?;
332        self.inner_product_estimate_with_projected_query(code, &projected)
333    }
334
335    /// Rotate a query once so it can score multiple codes without repeated allocation.
336    pub fn project_query(&self, query: &[f32]) -> Result<PolarProjectedQuery> {
337        self.check_input_dim(query.len())?;
338        check_finite_vector(query)?;
339        let mut rotated_query = vec![0.0f32; self.dim];
340        self.rotation.apply(query, &mut rotated_query)?;
341        check_finite_vector(&rotated_query)?;
342        Ok(PolarProjectedQuery { rotated_query })
343    }
344
345    /// Estimate inner product using a pre-rotated query.
346    pub fn inner_product_estimate_with_projected_query(
347        &self,
348        code: &PolarCode,
349        query: &PolarProjectedQuery,
350    ) -> Result<f32> {
351        self.validate_code(code)?;
352
353        let pairs = self.dim / 2;
354        let mut estimate = 0.0f32;
355
356        for i in 0..pairs {
357            let theta = code.dequantize_angle(i)?;
358            let r = code.radii[i];
359            let q_a = query.rotated_query[2 * i];
360            let q_b = query.rotated_query[2 * i + 1];
361            estimate += r * (q_a * theta.cos() + q_b * theta.sin());
362        }
363
364        if !estimate.is_finite() {
365            return Err(TurboQuantError::MalformedCode {
366                reason: "polar score is not finite".into(),
367            });
368        }
369        Ok(estimate)
370    }
371
372    /// Compute the squared L2 distance estimate between query and encoded vector.
373    ///
374    /// Uses the identity: ||x - y||² = ||x||² + ||y||² - 2⟨x, y⟩.
375    /// The query's squared norm is computed exactly; the encoded vector's norm
376    /// is derived from the stored radii (lossless since radii are stored as f32).
377    pub fn l2_distance_estimate(&self, code: &PolarCode, query: &[f32]) -> Result<f32> {
378        let ip = self.inner_product_estimate(code, query)?;
379
380        let query_norm_sq: f32 = query.iter().map(|x| x * x).sum();
381        let code_norm_sq: f32 = code.radii.iter().map(|r| r * r).sum();
382
383        Ok((query_norm_sq + code_norm_sq - 2.0 * ip).max(0.0))
384    }
385
386    fn check_input_dim(&self, got: usize) -> Result<()> {
387        if got != self.dim {
388            return Err(TurboQuantError::DimensionMismatch {
389                expected: self.dim,
390                got,
391            });
392        }
393        Ok(())
394    }
395
396    fn validate_code(&self, code: &PolarCode) -> Result<()> {
397        code.validate_for(self.dim, self.bits)
398    }
399}
400
401fn check_finite_vector(vector: &[f32]) -> Result<()> {
402    if let Some((index, _)) = vector
403        .iter()
404        .enumerate()
405        .find(|(_, value)| !value.is_finite())
406    {
407        return Err(TurboQuantError::NonFiniteInput { index });
408    }
409    Ok(())
410}
411
/// Encode a Cartesian pair (a, b) into (radius, quantized_angle_index).
///
/// The radius is the Euclidean norm of the pair. The angle `atan2(b, a)` is
/// mapped to the *nearest* of `2^bits` levels spaced uniformly on [−π, π),
/// where level `k` stands for the angle `k / 2^bits · 2π − π`. Rounding to
/// the nearest level (with wraparound, since −π and π are the same direction)
/// keeps the angular error within half a level spacing of that reconstruction
/// point; truncating instead would allow errors of up to a full spacing.
///
/// `bits` is expected to be in `1..=16` so the index fits in a `u16`.
fn encode_pair(a: f32, b: f32, bits: u8) -> (f32, u16) {
    // `hypot` avoids the spurious overflow/underflow of `a * a + b * b` for
    // very large or very small finite inputs.
    let r = a.hypot(b);
    let theta = b.atan2(a); // ∈ [−π, π]
    let levels = 1u32 << bits;
    // Map [−π, π] → [0, 1], scale to [0, levels] and pick the nearest level.
    let normalized = (theta + PI) / (2.0 * PI);
    let nearest = (normalized * levels as f32).round() as u32;
    // `nearest == levels` means the angle is closest to +π, i.e. level 0.
    let idx = nearest % levels;
    // Lossless for bits <= 16 because idx < levels <= 65536.
    (r, idx as u16)
}
422
#[cfg(test)]
mod tests {
    use super::*;

    /// Standard basis vector: all zeros except a one at position `i`.
    fn unit_vector(dim: usize, i: usize) -> Vec<f32> {
        let mut basis = vec![0.0f32; dim];
        basis[i] = 1.0;
        basis
    }

    /// Reproducible standard-normal vector of length `dim`.
    fn random_vector(dim: usize, seed: u64) -> Vec<f32> {
        use rand::SeedableRng;
        use rand_chacha::ChaCha8Rng;
        use rand_distr::{Distribution, StandardNormal};
        let mut rng = ChaCha8Rng::seed_from_u64(seed);
        std::iter::repeat_with(|| StandardNormal.sample(&mut rng))
            .take(dim)
            .collect()
    }

    /// At 16 bits the decode of an encode is close to the input.
    #[test]
    fn encode_decode_roundtrip_high_bits() {
        let q = PolarQuantizer::new(8, 16, 42).unwrap();
        let x = vec![1.0f32, 2.0, -1.5, 0.5, 3.0, -2.0, 0.1, -0.8];

        let decoded = q.decode(&q.encode(&x).unwrap()).unwrap();

        for (orig, dec) in x.iter().zip(decoded.iter()) {
            assert!(
                (orig - dec).abs() < 0.01,
                "orig={orig:.4}, decoded={dec:.4}"
            );
        }
    }

    /// The batch decode path must return exactly the same floats as decoding
    /// each code on its own.
    #[test]
    fn decode_batch_is_bit_exact_with_per_vec() {
        for bits in [4u8, 8, 12] {
            for seed in [0u64, 1, 42, 1337] {
                let q = PolarQuantizer::new(64, bits, seed).unwrap();
                let vecs: Vec<Vec<f32>> = (0..32)
                    .map(|i| {
                        (0..64)
                            .map(|j| ((i * 64 + j) as f32 * 0.137 + seed as f32 * 0.011).sin())
                            .collect()
                    })
                    .collect();
                let codes: Vec<PolarCode> =
                    vecs.iter().map(|v| q.encode(v).unwrap()).collect();
                // Per-vector baseline.
                let per_vec: Vec<Vec<f32>> =
                    codes.iter().map(|c| q.decode(c).unwrap()).collect();
                // Batch path.
                let batched = q.decode_batch(&codes).unwrap();
                assert_eq!(batched.len(), per_vec.len());
                for (i, (a, b)) in per_vec.iter().zip(batched.iter()).enumerate() {
                    assert_eq!(a.len(), b.len(), "vec {i} length mismatch");
                    for (j, (x, y)) in a.iter().zip(b.iter()).enumerate() {
                        assert_eq!(
                            x.to_bits(),
                            y.to_bits(),
                            "vec {i} coord {j}: per_vec={x} batch={y} (bits={bits}, seed={seed})"
                        );
                    }
                }
            }
        }
    }

    /// The code-space estimator tracks the exact dot product at 16 bits.
    #[test]
    fn inner_product_estimate_is_close_at_high_bits() {
        let q = PolarQuantizer::new(16, 16, 7).unwrap();
        let x = random_vector(16, 1);
        let y = random_vector(16, 2);

        let code = q.encode(&x).unwrap();
        let estimated = q.inner_product_estimate(&code, &y).unwrap();
        let exact: f32 = x.iter().zip(y.iter()).map(|(a, b)| a * b).sum();

        let relative_error = (estimated - exact).abs() / (exact.abs() + 1e-6);
        assert!(
            relative_error < 0.02,
            "relative error {relative_error:.4} too large: estimated={estimated:.4}, exact={exact:.4}"
        );
    }

    /// Encoding the same input twice yields the same code.
    #[test]
    fn encoding_is_deterministic() {
        let q = PolarQuantizer::new(8, 8, 0).unwrap();
        let x = vec![1.0f32; 8];

        let c1 = q.encode(&x).unwrap();
        let c2 = q.encode(&x).unwrap();
        assert_eq!(c1.angle_indices, c2.angle_indices);
        assert_eq!(c1.radii, c2.radii);
    }

    /// The zero vector encodes to all-zero radii.
    #[test]
    fn zero_vector_has_zero_radius() {
        let q = PolarQuantizer::new(8, 8, 1).unwrap();
        let code = q.encode(&vec![0.0f32; 8]).unwrap();
        for r in &code.radii {
            assert!(*r < 1e-7, "expected zero radius, got {r}");
        }
    }

    /// Radii of an encoded unit vector carry unit squared norm.
    #[test]
    fn unit_vectors_preserve_norm() {
        let q = PolarQuantizer::new(8, 16, 3).unwrap();
        for i in 0..8 {
            let code = q.encode(&unit_vector(8, i)).unwrap();
            let norm_sq: f32 = code.radii.iter().map(|r| r * r).sum();
            assert!((norm_sq - 1.0).abs() < 1e-5, "norm_sq={norm_sq}");
        }
    }

    /// A near-duplicate of the query scores higher than two unrelated vectors.
    #[test]
    fn nearest_neighbor_ordering_preserved_at_8bits() {
        let q = PolarQuantizer::new(16, 8, 42).unwrap();
        let query = random_vector(16, 99);

        // Three database vectors: one close to the query, two far from it.
        let close: Vec<f32> = query.iter().map(|x| x + 0.01).collect();
        let far1 = random_vector(16, 200);
        let far2 = random_vector(16, 201);

        let score = |v: &[f32]| {
            let code = q.encode(v).unwrap();
            q.inner_product_estimate(&code, &query).unwrap()
        };
        let ip_close = score(&close);
        let ip_far1 = score(&far1);
        let ip_far2 = score(&far2);

        assert!(
            ip_close > ip_far1 && ip_close > ip_far2,
            "close={ip_close:.3}, far1={ip_far1:.3}, far2={ip_far2:.3}"
        );
    }

    /// Inputs of the wrong length are refused.
    #[test]
    fn dimension_mismatch_is_rejected() {
        let q = PolarQuantizer::new(8, 8, 0).unwrap();
        assert!(q.encode(&[1.0f32; 10]).is_err());
    }

    /// Odd dimensions cannot be paired up.
    #[test]
    fn odd_dimension_is_rejected() {
        assert!(PolarQuantizer::new(7, 8, 0).is_err());
    }

    /// A zero bit width is not a valid profile.
    #[test]
    fn zero_bits_rejected() {
        assert!(PolarQuantizer::new(8, 0, 0).is_err());
    }

    /// A code survives a JSON round trip unchanged.
    #[test]
    fn code_serialization_roundtrip() {
        let q = PolarQuantizer::new(8, 8, 42).unwrap();
        let x = vec![1.0f32, -2.0, 0.5, 1.5, -0.3, 0.8, -1.0, 2.0];
        let code = q.encode(&x).unwrap();
        let json = serde_json::to_string(&code).unwrap();
        let restored: PolarCode = serde_json::from_str(&json).unwrap();
        assert_eq!(code, restored);
    }
}