Skip to main content

selene_core/vector/
turbo_quant.rs

1//! Safe TurboQuant codec primitives for compressed vector storage.
2//!
3//! This module owns the deterministic, dependency-free pieces that graph vector
4//! indexes can reuse before they add index maintenance, persistence, or SIMD
5//! search caches.
6
7use std::error::Error;
8use std::fmt;
9use std::mem::size_of;
10
11use crate::MAX_VECTOR_DIMENSION;
12
13#[path = "turbo_quant/blocked.rs"]
14mod blocked;
15
16pub use blocked::{TURBO_QUANT_BLOCK_ROWS, TurboQuantBlockedCodes};
17
/// Result type for TurboQuant codec operations.
///
/// Shorthand for [`Result`] whose error type is [`TurboQuantCodecError`].
pub type TurboQuantCodecResult<T> = Result<T, TurboQuantCodecError>;
20
/// Errors returned by safe TurboQuant codec primitives.
///
/// The enum derives `PartialEq` but not `Eq` because
/// [`TurboQuantCodecError::NonFiniteValue`] carries an `f32`; an error holding
/// NaN therefore compares unequal to itself.
#[derive(Clone, Debug, PartialEq)]
pub enum TurboQuantCodecError {
    /// Bit widths must be in the inclusive `2..=4` range.
    InvalidBitWidth {
        /// Rejected bit width.
        bits: u8,
    },
    /// Vector dimensions must be non-zero and no larger than
    /// [`MAX_VECTOR_DIMENSION`].
    InvalidDimension {
        /// Rejected dimension.
        dimension: usize,
        /// Maximum accepted dimension.
        max: usize,
    },
    /// Caller-supplied packed bytes did not match the codec shape.
    ByteLengthMismatch {
        /// Expected byte length for the requested packed-code operation.
        expected: usize,
        /// Actual byte length supplied by the caller.
        actual: usize,
    },
    /// Packed-code storage size overflowed `usize`.
    ///
    /// Also reported when a bit-offset computation overflows `usize`.
    SizeOverflow,
    /// Requested row is outside the packed-code matrix.
    RowOutOfBounds {
        /// Requested row.
        row: usize,
        /// Number of rows stored.
        rows: usize,
    },
    /// Requested dimension is outside the packed-code matrix.
    DimensionOutOfBounds {
        /// Requested dimension.
        dimension: usize,
        /// Number of dimensions stored per row.
        dimensions: usize,
    },
    /// A code exceeded the active bit width's maximum representable code.
    InvalidCode {
        /// Rejected code.
        code: u8,
        /// Maximum accepted code for this bit width.
        max: u8,
    },
    /// A value that must be finite was not finite.
    NonFiniteValue {
        /// Non-finite value (NaN or an infinity).
        value: f32,
    },
}
73
74impl fmt::Display for TurboQuantCodecError {
75    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76        match self {
77            Self::InvalidBitWidth { bits } => {
78                write!(f, "invalid TurboQuant bit width {bits}; expected 2..=4")
79            }
80            Self::InvalidDimension { dimension, max } => write!(
81                f,
82                "invalid TurboQuant dimension {dimension}; expected 1..={max}"
83            ),
84            Self::ByteLengthMismatch { expected, actual } => write!(
85                f,
86                "invalid TurboQuant packed byte length {actual}; expected {expected}"
87            ),
88            Self::SizeOverflow => write!(f, "TurboQuant packed-code size overflowed usize"),
89            Self::RowOutOfBounds { row, rows } => {
90                write!(f, "TurboQuant row {row} is out of bounds for {rows} rows")
91            }
92            Self::DimensionOutOfBounds {
93                dimension,
94                dimensions,
95            } => write!(
96                f,
97                "TurboQuant dimension {dimension} is out of bounds for {dimensions} dimensions"
98            ),
99            Self::InvalidCode { code, max } => {
100                write!(f, "TurboQuant code {code} exceeds maximum code {max}")
101            }
102            Self::NonFiniteValue { value } => {
103                write!(f, "TurboQuant value must be finite, got {value}")
104            }
105        }
106    }
107}
108
// Marker impl: `Error` only requires `Debug` and `Display`, both of which this
// type already provides; every provided method keeps its default.
impl Error for TurboQuantCodecError {}
110
111/// Validated TurboQuant bit width.
112#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
113pub struct TurboQuantBitWidth(u8);
114
115impl TurboQuantBitWidth {
116    /// Construct a validated TurboQuant bit width.
117    ///
118    /// # Errors
119    ///
120    /// Returns [`TurboQuantCodecError::InvalidBitWidth`] unless `bits` is in
121    /// the inclusive `2..=4` range.
122    pub const fn new(bits: u8) -> TurboQuantCodecResult<Self> {
123        if bits >= 2 && bits <= 4 {
124            Ok(Self(bits))
125        } else {
126            Err(TurboQuantCodecError::InvalidBitWidth { bits })
127        }
128    }
129
130    /// Return the number of bits stored for each encoded coordinate.
131    #[must_use]
132    pub const fn bits(self) -> u8 {
133        self.0
134    }
135
136    /// Return the number of quantization levels represented by this bit width.
137    #[must_use]
138    pub const fn levels(self) -> usize {
139        1_usize << self.0
140    }
141
142    /// Return the maximum representable code for this bit width.
143    #[must_use]
144    pub const fn max_code(self) -> u8 {
145        (1_u8 << self.0) - 1
146    }
147}
148
149impl TryFrom<u8> for TurboQuantBitWidth {
150    type Error = TurboQuantCodecError;
151
152    fn try_from(value: u8) -> Result<Self, Self::Error> {
153        Self::new(value)
154    }
155}
156
157impl From<TurboQuantBitWidth> for u8 {
158    fn from(value: TurboQuantBitWidth) -> Self {
159        value.bits()
160    }
161}
162
/// Deterministic TurboQuant scalar codebook family.
///
/// Selects which centroid generator [`TurboQuantCodebook::new`] runs.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum TurboQuantCodebookKind {
    /// Uniform centroids clipped to three standard deviations of
    /// `N(0, 1 / dimension)`.
    ///
    /// The centroids are the centers of equal-width cells spanning
    /// `[-3 * sigma, 3 * sigma]` with `sigma = 1 / sqrt(dimension)`.
    ClippedUniform,
    /// Lloyd-Max centroids for `N(0, 1 / dimension)`.
    ///
    /// Computed iteratively, starting from evenly spaced points on
    /// `[-3 * sigma, 3 * sigma]`.
    NormalLloydMax,
}
172
/// Deterministic scalar codebook for one TurboQuant vector dimension.
///
/// All fields are private and exposed through read-only accessors.
#[derive(Clone, Debug, PartialEq)]
pub struct TurboQuantCodebook {
    // Family whose generator produced `centroids`.
    kind: TurboQuantCodebookKind,
    // Bit width handed to the centroid generator.
    bit_width: TurboQuantBitWidth,
    // Vector dimension handed to the centroid generator.
    dimension: usize,
    // Reconstruction values, indexed by code.
    centroids: Vec<f32>,
    // Midpoints between adjacent centroids (one fewer entry than `centroids`).
    boundaries: Vec<f32>,
}
182
183impl TurboQuantCodebook {
184    /// Build a deterministic codebook for `kind`, `bit_width`, and
185    /// `dimension`.
186    ///
187    /// # Errors
188    ///
189    /// Returns an error when `dimension` is zero or exceeds
190    /// [`MAX_VECTOR_DIMENSION`].
191    pub fn new(
192        kind: TurboQuantCodebookKind,
193        bit_width: TurboQuantBitWidth,
194        dimension: usize,
195    ) -> TurboQuantCodecResult<Self> {
196        validate_dimension(dimension)?;
197        let centroids = match kind {
198            TurboQuantCodebookKind::ClippedUniform => {
199                clipped_uniform_centroids(bit_width, dimension)
200            }
201            TurboQuantCodebookKind::NormalLloydMax => {
202                normal_lloyd_max_centroids(bit_width, dimension)
203            }
204        };
205        let boundaries = centroid_boundaries(&centroids);
206        Ok(Self {
207            kind,
208            bit_width,
209            dimension,
210            centroids,
211            boundaries,
212        })
213    }
214
215    /// Build a clipped-uniform codebook.
216    ///
217    /// # Errors
218    ///
219    /// Returns an error when `dimension` is zero or exceeds
220    /// [`MAX_VECTOR_DIMENSION`].
221    pub fn clipped_uniform(
222        bit_width: TurboQuantBitWidth,
223        dimension: usize,
224    ) -> TurboQuantCodecResult<Self> {
225        Self::new(TurboQuantCodebookKind::ClippedUniform, bit_width, dimension)
226    }
227
228    /// Build a normal Lloyd-Max codebook.
229    ///
230    /// # Errors
231    ///
232    /// Returns an error when `dimension` is zero or exceeds
233    /// [`MAX_VECTOR_DIMENSION`].
234    pub fn normal_lloyd_max(
235        bit_width: TurboQuantBitWidth,
236        dimension: usize,
237    ) -> TurboQuantCodecResult<Self> {
238        Self::new(TurboQuantCodebookKind::NormalLloydMax, bit_width, dimension)
239    }
240
241    /// Return the codebook family.
242    #[must_use]
243    pub const fn kind(&self) -> TurboQuantCodebookKind {
244        self.kind
245    }
246
247    /// Return the bit width used by this codebook.
248    #[must_use]
249    pub const fn bit_width(&self) -> TurboQuantBitWidth {
250        self.bit_width
251    }
252
253    /// Return the vector dimension this codebook was calibrated for.
254    #[must_use]
255    pub const fn dimension(&self) -> usize {
256        self.dimension
257    }
258
259    /// Return the codebook centroids in ascending code order.
260    #[must_use]
261    pub fn centroids(&self) -> &[f32] {
262        &self.centroids
263    }
264
265    /// Return midpoint boundaries between adjacent centroids.
266    #[must_use]
267    pub fn boundaries(&self) -> &[f32] {
268        &self.boundaries
269    }
270
271    /// Return the centroid for `code`.
272    ///
273    /// # Errors
274    ///
275    /// Returns [`TurboQuantCodecError::InvalidCode`] when `code` exceeds this
276    /// codebook's bit width.
277    pub fn centroid(&self, code: u8) -> TurboQuantCodecResult<f32> {
278        self.validate_code(code)?;
279        Ok(self.centroids[usize::from(code)])
280    }
281
282    /// Quantize a finite scalar into a code by scanning codebook boundaries.
283    ///
284    /// Values equal to a boundary choose the lower code, matching the
285    /// lower-code tie break used by exact nearest-centroid scans.
286    ///
287    /// # Errors
288    ///
289    /// Returns [`TurboQuantCodecError::NonFiniteValue`] for NaN or infinity.
290    pub fn encode_scalar(&self, value: f32) -> TurboQuantCodecResult<u8> {
291        if !value.is_finite() {
292            return Err(TurboQuantCodecError::NonFiniteValue { value });
293        }
294        Ok(self
295            .boundaries
296            .partition_point(|boundary| value > *boundary) as u8)
297    }
298
299    /// Return an approximate heap allocation footprint for this codebook.
300    #[must_use]
301    pub fn estimated_bytes(&self) -> usize {
302        self.centroids
303            .len()
304            .saturating_add(self.boundaries.len())
305            .saturating_mul(size_of::<f32>())
306    }
307
308    fn validate_code(&self, code: u8) -> TurboQuantCodecResult<()> {
309        let max = self.bit_width.max_code();
310        if code <= max {
311            Ok(())
312        } else {
313            Err(TurboQuantCodecError::InvalidCode { code, max })
314        }
315    }
316}
317
/// Row-major packed TurboQuant coordinate codes.
///
/// Each row starts on a byte boundary. Within a row, the code for dimension
/// `d` starts at bit `d * bits`, counted from the least significant bit of the
/// row's first byte.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TurboQuantPackedCodes {
    // Bits stored per code.
    bit_width: TurboQuantBitWidth,
    // Number of codes in each row.
    dimensions: usize,
    // Number of rows currently stored.
    rows: usize,
    // Byte stride of one row: `dimensions * bits`, rounded up to whole bytes.
    bytes_per_row: usize,
    // Backing storage, `rows * bytes_per_row` bytes long.
    bytes: Vec<u8>,
}
327
328impl TurboQuantPackedCodes {
329    /// Allocate zero-filled packed-code storage.
330    ///
331    /// # Errors
332    ///
333    /// Returns an error when dimensions are invalid or the computed byte size
334    /// overflows `usize`.
335    pub fn new(
336        bit_width: TurboQuantBitWidth,
337        dimensions: usize,
338        rows: usize,
339    ) -> TurboQuantCodecResult<Self> {
340        let bytes_per_row = bytes_per_row(bit_width, dimensions)?;
341        let byte_len = bytes_per_row
342            .checked_mul(rows)
343            .ok_or(TurboQuantCodecError::SizeOverflow)?;
344        Ok(Self {
345            bit_width,
346            dimensions,
347            rows,
348            bytes_per_row,
349            bytes: vec![0; byte_len],
350        })
351    }
352
353    /// Build packed-code storage from existing bytes.
354    ///
355    /// # Errors
356    ///
357    /// Returns an error when dimensions are invalid, the computed byte size
358    /// overflows `usize`, or `bytes.len()` does not match the expected shape.
359    pub fn from_bytes(
360        bit_width: TurboQuantBitWidth,
361        dimensions: usize,
362        rows: usize,
363        bytes: Vec<u8>,
364    ) -> TurboQuantCodecResult<Self> {
365        let bytes_per_row = bytes_per_row(bit_width, dimensions)?;
366        let expected = bytes_per_row
367            .checked_mul(rows)
368            .ok_or(TurboQuantCodecError::SizeOverflow)?;
369        if bytes.len() != expected {
370            return Err(TurboQuantCodecError::ByteLengthMismatch {
371                expected,
372                actual: bytes.len(),
373            });
374        }
375        Ok(Self {
376            bit_width,
377            dimensions,
378            rows,
379            bytes_per_row,
380            bytes,
381        })
382    }
383
384    /// Return the bit width used by the packed codes.
385    #[must_use]
386    pub const fn bit_width(&self) -> TurboQuantBitWidth {
387        self.bit_width
388    }
389
390    /// Return the number of dimensions encoded in each row.
391    #[must_use]
392    pub const fn dimensions(&self) -> usize {
393        self.dimensions
394    }
395
396    /// Return the number of encoded rows.
397    #[must_use]
398    pub const fn rows(&self) -> usize {
399        self.rows
400    }
401
402    /// Return the byte stride for one row, including any trailing padding bits.
403    #[must_use]
404    pub const fn bytes_per_row(&self) -> usize {
405        self.bytes_per_row
406    }
407
408    /// Return the packed backing bytes.
409    #[must_use]
410    pub fn as_bytes(&self) -> &[u8] {
411        &self.bytes
412    }
413
414    /// Consume this storage and return the packed backing bytes.
415    #[must_use]
416    pub fn into_bytes(self) -> Vec<u8> {
417        self.bytes
418    }
419
420    /// Return the packed-code byte footprint.
421    #[must_use]
422    pub fn estimated_bytes(&self) -> usize {
423        self.bytes.len()
424    }
425
426    /// Resize the row count while preserving existing packed rows.
427    ///
428    /// Newly added rows are zero-filled. Shrinking drops trailing rows.
429    ///
430    /// # Errors
431    ///
432    /// Returns an error when the computed byte size overflows `usize`.
433    pub fn resize_rows(&mut self, rows: usize) -> TurboQuantCodecResult<()> {
434        let byte_len = self
435            .bytes_per_row
436            .checked_mul(rows)
437            .ok_or(TurboQuantCodecError::SizeOverflow)?;
438        self.bytes.resize(byte_len, 0);
439        self.rows = rows;
440        Ok(())
441    }
442
443    /// Read one packed coordinate code.
444    ///
445    /// # Errors
446    ///
447    /// Returns bounds errors when `row` or `dimension` is outside the packed
448    /// matrix.
449    pub fn read(&self, row: usize, dimension: usize) -> TurboQuantCodecResult<u8> {
450        let bit_offset = self.bit_offset(row, dimension)?;
451        let byte = bit_offset / u8::BITS as usize;
452        let shift = bit_offset % u8::BITS as usize;
453        let mut word = u16::from(self.bytes[byte]);
454        if byte + 1 < self.bytes.len() {
455            word |= u16::from(self.bytes[byte + 1]) << u8::BITS;
456        }
457        let mask = (1_u16 << self.bit_width.bits()) - 1;
458        Ok(((word >> shift) & mask) as u8)
459    }
460
461    /// Write one packed coordinate code.
462    ///
463    /// # Errors
464    ///
465    /// Returns bounds errors when `row` or `dimension` is outside the packed
466    /// matrix, and [`TurboQuantCodecError::InvalidCode`] when `code` exceeds
467    /// this storage's bit width.
468    pub fn write(&mut self, row: usize, dimension: usize, code: u8) -> TurboQuantCodecResult<()> {
469        self.validate_code(code)?;
470        let bit_offset = self.bit_offset(row, dimension)?;
471        let byte = bit_offset / u8::BITS as usize;
472        let shift = bit_offset % u8::BITS as usize;
473        let mask = ((1_u16 << self.bit_width.bits()) - 1) << shift;
474        let mut word = u16::from(self.bytes[byte]);
475        if byte + 1 < self.bytes.len() {
476            word |= u16::from(self.bytes[byte + 1]) << u8::BITS;
477        }
478        word = (word & !mask) | (u16::from(code) << shift);
479        self.bytes[byte] = (word & 0xff) as u8;
480        if shift + usize::from(self.bit_width.bits()) > u8::BITS as usize {
481            self.bytes[byte + 1] = (word >> u8::BITS) as u8;
482        }
483        Ok(())
484    }
485
486    fn validate_code(&self, code: u8) -> TurboQuantCodecResult<()> {
487        let max = self.bit_width.max_code();
488        if code <= max {
489            Ok(())
490        } else {
491            Err(TurboQuantCodecError::InvalidCode { code, max })
492        }
493    }
494
495    fn bit_offset(&self, row: usize, dimension: usize) -> TurboQuantCodecResult<usize> {
496        if row >= self.rows {
497            return Err(TurboQuantCodecError::RowOutOfBounds {
498                row,
499                rows: self.rows,
500            });
501        }
502        if dimension >= self.dimensions {
503            return Err(TurboQuantCodecError::DimensionOutOfBounds {
504                dimension,
505                dimensions: self.dimensions,
506            });
507        }
508        let row_bits = row
509            .checked_mul(self.bytes_per_row)
510            .and_then(|offset| offset.checked_mul(u8::BITS as usize))
511            .ok_or(TurboQuantCodecError::SizeOverflow)?;
512        let dimension_bits = dimension
513            .checked_mul(usize::from(self.bit_width.bits()))
514            .ok_or(TurboQuantCodecError::SizeOverflow)?;
515        row_bits
516            .checked_add(dimension_bits)
517            .ok_or(TurboQuantCodecError::SizeOverflow)
518    }
519}
520
521fn validate_dimension(dimension: usize) -> TurboQuantCodecResult<()> {
522    if dimension == 0 || dimension > MAX_VECTOR_DIMENSION {
523        Err(TurboQuantCodecError::InvalidDimension {
524            dimension,
525            max: MAX_VECTOR_DIMENSION,
526        })
527    } else {
528        Ok(())
529    }
530}
531
532fn bytes_per_row(bit_width: TurboQuantBitWidth, dimensions: usize) -> TurboQuantCodecResult<usize> {
533    validate_dimension(dimensions)?;
534    let bits_per_row = dimensions
535        .checked_mul(usize::from(bit_width.bits()))
536        .ok_or(TurboQuantCodecError::SizeOverflow)?;
537    bits_per_row
538        .checked_add(u8::BITS as usize - 1)
539        .map(|bits| bits / u8::BITS as usize)
540        .ok_or(TurboQuantCodecError::SizeOverflow)
541}
542
543fn clipped_uniform_centroids(bit_width: TurboQuantBitWidth, dimension: usize) -> Vec<f32> {
544    let levels = bit_width.levels();
545    let sigma = (dimension as f32).sqrt().recip();
546    let clip = 3.0 * sigma;
547    (0..levels)
548        .map(|code| {
549            let midpoint = (code as f32 + 0.5) / levels as f32;
550            midpoint.mul_add(2.0 * clip, -clip)
551        })
552        .collect()
553}
554
555fn normal_lloyd_max_centroids(bit_width: TurboQuantBitWidth, dimension: usize) -> Vec<f32> {
556    let levels = bit_width.levels();
557    let sigma = (dimension as f64).sqrt().recip();
558    let spread = 3.0 * sigma;
559    let mut centroids = (0..levels)
560        .map(|code| -spread + 2.0 * spread * code as f64 / (levels - 1) as f64)
561        .collect::<Vec<_>>();
562
563    for _ in 0..64 {
564        let boundaries = f64_centroid_boundaries(&centroids);
565        let mut max_change = 0.0f64;
566        for code in 0..levels {
567            let low = if code == 0 {
568                f64::NEG_INFINITY
569            } else {
570                boundaries[code - 1]
571            };
572            let high = if code + 1 == levels {
573                f64::INFINITY
574            } else {
575                boundaries[code]
576            };
577            let next = normal_interval_mean(low, high, sigma);
578            max_change = max_change.max((centroids[code] - next).abs());
579            centroids[code] = next;
580        }
581        if max_change < 1e-12 {
582            break;
583        }
584    }
585
586    centroids
587        .into_iter()
588        .map(|centroid| centroid as f32)
589        .collect()
590}
591
/// Return the midpoint between every pair of adjacent `f32` centroids.
///
/// The result has one fewer entry than `centroids` (and is empty for fewer
/// than two centroids).
fn centroid_boundaries(centroids: &[f32]) -> Vec<f32> {
    centroids
        .iter()
        .zip(centroids.iter().skip(1))
        .map(|(lower, upper)| (lower + upper) * 0.5)
        .collect()
}
598
/// Return the midpoint between every pair of adjacent `f64` centroids.
///
/// The result has one fewer entry than `centroids` (and is empty for fewer
/// than two centroids).
fn f64_centroid_boundaries(centroids: &[f64]) -> Vec<f64> {
    centroids
        .iter()
        .zip(centroids.iter().skip(1))
        .map(|(lower, upper)| (lower + upper) * 0.5)
        .collect()
}
605
606fn normal_interval_mean(low: f64, high: f64, sigma: f64) -> f64 {
607    let low_z = low / sigma;
608    let high_z = high / sigma;
609    let probability = standard_normal_cdf(high_z) - standard_normal_cdf(low_z);
610    if probability <= 1e-15 {
611        return (low + high) * 0.5;
612    }
613    sigma * (standard_normal_pdf(low_z) - standard_normal_pdf(high_z)) / probability
614}
615
/// Evaluate the standard normal density `exp(-x^2 / 2) / sqrt(2 * pi)`.
///
/// Infinite inputs are short-circuited to `0.0`.
fn standard_normal_pdf(value: f64) -> f64 {
    /// `1 / sqrt(2 * pi)`.
    const INV_SQRT_2_PI: f64 = 0.398_942_280_401_432_7;

    if value.is_infinite() {
        return 0.0;
    }
    let exponent = -0.5 * value * value;
    INV_SQRT_2_PI * exponent.exp()
}
624
625fn standard_normal_cdf(value: f64) -> f64 {
626    if value == f64::NEG_INFINITY {
627        0.0
628    } else if value == f64::INFINITY {
629        1.0
630    } else {
631        0.5 * (1.0 + erf_approx(value / f64::sqrt(2.0)))
632    }
633}
634
/// Approximate the error function with a degree-five polynomial in
/// `t = 1 / (1 + p * |x|)` multiplied by `exp(-x^2)`.
///
/// The polynomial is evaluated on `|x|` and the result is negated for
/// negative inputs.
fn erf_approx(value: f64) -> f64 {
    const P: f64 = 0.327_591_1;
    // Polynomial coefficients from the highest power of `t` downward; the
    // fold below is Horner's scheme seeded with the leading coefficient.
    const LEADING: f64 = 1.061_405_429;
    const REMAINING: [f64; 4] = [
        -1.453_152_027,
        1.421_413_741,
        -0.284_496_736,
        0.254_829_592,
    ];

    let magnitude = value.abs();
    let t = 1.0 / (1.0 + P * magnitude);
    let horner = REMAINING
        .iter()
        .fold(LEADING, |accumulated, &coefficient| accumulated * t + coefficient);
    let complement = (horner * t) * (-magnitude * magnitude).exp();
    let unsigned = 1.0 - complement;
    if value < 0.0 {
        -unsigned
    } else {
        unsigned
    }
}
646
647#[cfg(test)]
648mod tests;