1use std::error::Error;
8use std::fmt;
9use std::mem::size_of;
10
11use crate::MAX_VECTOR_DIMENSION;
12
13#[path = "turbo_quant/blocked.rs"]
14mod blocked;
15
16pub use blocked::{TURBO_QUANT_BLOCK_ROWS, TurboQuantBlockedCodes};
17
18pub type TurboQuantCodecResult<T> = Result<T, TurboQuantCodecError>;
20
/// Failures reported by the TurboQuant codebook and packed-code types.
#[derive(Clone, Debug, PartialEq)]
pub enum TurboQuantCodecError {
    /// A bit width outside the supported `2..=4` range was requested.
    InvalidBitWidth {
        /// The rejected bit width.
        bits: u8,
    },
    /// A vector dimension of zero, or one above the allowed maximum, was given.
    InvalidDimension {
        /// The rejected dimension.
        dimension: usize,
        /// The largest dimension that is accepted.
        max: usize,
    },
    /// A packed byte buffer does not have the length its shape requires.
    ByteLengthMismatch {
        /// Byte length implied by the bit width, dimensions and row count.
        expected: usize,
        /// Byte length of the buffer that was supplied.
        actual: usize,
    },
    /// A packed-code size or offset computation overflowed `usize`.
    SizeOverflow,
    /// A row index is not below the number of stored rows.
    RowOutOfBounds {
        /// The rejected row index.
        row: usize,
        /// Number of rows in the container.
        rows: usize,
    },
    /// A dimension index is not below the number of dimensions per row.
    DimensionOutOfBounds {
        /// The rejected dimension index.
        dimension: usize,
        /// Number of dimensions per row.
        dimensions: usize,
    },
    /// A code is larger than the bit width can represent.
    InvalidCode {
        /// The rejected code.
        code: u8,
        /// The largest code the bit width allows.
        max: u8,
    },
    /// A NaN or infinite value was offered for quantisation.
    NonFiniteValue {
        /// The rejected value.
        value: f32,
    },
}
73
74impl fmt::Display for TurboQuantCodecError {
75 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76 match self {
77 Self::InvalidBitWidth { bits } => {
78 write!(f, "invalid TurboQuant bit width {bits}; expected 2..=4")
79 }
80 Self::InvalidDimension { dimension, max } => write!(
81 f,
82 "invalid TurboQuant dimension {dimension}; expected 1..={max}"
83 ),
84 Self::ByteLengthMismatch { expected, actual } => write!(
85 f,
86 "invalid TurboQuant packed byte length {actual}; expected {expected}"
87 ),
88 Self::SizeOverflow => write!(f, "TurboQuant packed-code size overflowed usize"),
89 Self::RowOutOfBounds { row, rows } => {
90 write!(f, "TurboQuant row {row} is out of bounds for {rows} rows")
91 }
92 Self::DimensionOutOfBounds {
93 dimension,
94 dimensions,
95 } => write!(
96 f,
97 "TurboQuant dimension {dimension} is out of bounds for {dimensions} dimensions"
98 ),
99 Self::InvalidCode { code, max } => {
100 write!(f, "TurboQuant code {code} exceeds maximum code {max}")
101 }
102 Self::NonFiniteValue { value } => {
103 write!(f, "TurboQuant value must be finite, got {value}")
104 }
105 }
106 }
107}
108
109impl Error for TurboQuantCodecError {}
110
/// Number of bits used to store one quantisation code.
///
/// The wrapped value is private; `TurboQuantBitWidth::new` only accepts 2, 3
/// or 4.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct TurboQuantBitWidth(u8);
114
115impl TurboQuantBitWidth {
116 pub const fn new(bits: u8) -> TurboQuantCodecResult<Self> {
123 if bits >= 2 && bits <= 4 {
124 Ok(Self(bits))
125 } else {
126 Err(TurboQuantCodecError::InvalidBitWidth { bits })
127 }
128 }
129
130 #[must_use]
132 pub const fn bits(self) -> u8 {
133 self.0
134 }
135
136 #[must_use]
138 pub const fn levels(self) -> usize {
139 1_usize << self.0
140 }
141
142 #[must_use]
144 pub const fn max_code(self) -> u8 {
145 (1_u8 << self.0) - 1
146 }
147}
148
149impl TryFrom<u8> for TurboQuantBitWidth {
150 type Error = TurboQuantCodecError;
151
152 fn try_from(value: u8) -> Result<Self, Self::Error> {
153 Self::new(value)
154 }
155}
156
157impl From<TurboQuantBitWidth> for u8 {
158 fn from(value: TurboQuantBitWidth) -> Self {
159 value.bits()
160 }
161}
162
/// Selects how a [`TurboQuantCodebook`] derives its centroids.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum TurboQuantCodebookKind {
    /// Centroids are the midpoints of equal-width bins over a clipped range.
    ClippedUniform,
    /// Centroids are refined by Lloyd-Max iteration against a normal distribution.
    NormalLloydMax,
}
172
173#[derive(Clone, Debug, PartialEq)]
175pub struct TurboQuantCodebook {
176 kind: TurboQuantCodebookKind,
177 bit_width: TurboQuantBitWidth,
178 dimension: usize,
179 centroids: Vec<f32>,
180 boundaries: Vec<f32>,
181}
182
183impl TurboQuantCodebook {
184 pub fn new(
192 kind: TurboQuantCodebookKind,
193 bit_width: TurboQuantBitWidth,
194 dimension: usize,
195 ) -> TurboQuantCodecResult<Self> {
196 validate_dimension(dimension)?;
197 let centroids = match kind {
198 TurboQuantCodebookKind::ClippedUniform => {
199 clipped_uniform_centroids(bit_width, dimension)
200 }
201 TurboQuantCodebookKind::NormalLloydMax => {
202 normal_lloyd_max_centroids(bit_width, dimension)
203 }
204 };
205 let boundaries = centroid_boundaries(¢roids);
206 Ok(Self {
207 kind,
208 bit_width,
209 dimension,
210 centroids,
211 boundaries,
212 })
213 }
214
215 pub fn clipped_uniform(
222 bit_width: TurboQuantBitWidth,
223 dimension: usize,
224 ) -> TurboQuantCodecResult<Self> {
225 Self::new(TurboQuantCodebookKind::ClippedUniform, bit_width, dimension)
226 }
227
228 pub fn normal_lloyd_max(
235 bit_width: TurboQuantBitWidth,
236 dimension: usize,
237 ) -> TurboQuantCodecResult<Self> {
238 Self::new(TurboQuantCodebookKind::NormalLloydMax, bit_width, dimension)
239 }
240
241 #[must_use]
243 pub const fn kind(&self) -> TurboQuantCodebookKind {
244 self.kind
245 }
246
247 #[must_use]
249 pub const fn bit_width(&self) -> TurboQuantBitWidth {
250 self.bit_width
251 }
252
253 #[must_use]
255 pub const fn dimension(&self) -> usize {
256 self.dimension
257 }
258
259 #[must_use]
261 pub fn centroids(&self) -> &[f32] {
262 &self.centroids
263 }
264
265 #[must_use]
267 pub fn boundaries(&self) -> &[f32] {
268 &self.boundaries
269 }
270
271 pub fn centroid(&self, code: u8) -> TurboQuantCodecResult<f32> {
278 self.validate_code(code)?;
279 Ok(self.centroids[usize::from(code)])
280 }
281
282 pub fn encode_scalar(&self, value: f32) -> TurboQuantCodecResult<u8> {
291 if !value.is_finite() {
292 return Err(TurboQuantCodecError::NonFiniteValue { value });
293 }
294 Ok(self
295 .boundaries
296 .partition_point(|boundary| value > *boundary) as u8)
297 }
298
299 #[must_use]
301 pub fn estimated_bytes(&self) -> usize {
302 self.centroids
303 .len()
304 .saturating_add(self.boundaries.len())
305 .saturating_mul(size_of::<f32>())
306 }
307
308 fn validate_code(&self, code: u8) -> TurboQuantCodecResult<()> {
309 let max = self.bit_width.max_code();
310 if code <= max {
311 Ok(())
312 } else {
313 Err(TurboQuantCodecError::InvalidCode { code, max })
314 }
315 }
316}
317
318#[derive(Clone, Debug, Eq, PartialEq)]
320pub struct TurboQuantPackedCodes {
321 bit_width: TurboQuantBitWidth,
322 dimensions: usize,
323 rows: usize,
324 bytes_per_row: usize,
325 bytes: Vec<u8>,
326}
327
328impl TurboQuantPackedCodes {
329 pub fn new(
336 bit_width: TurboQuantBitWidth,
337 dimensions: usize,
338 rows: usize,
339 ) -> TurboQuantCodecResult<Self> {
340 let bytes_per_row = bytes_per_row(bit_width, dimensions)?;
341 let byte_len = bytes_per_row
342 .checked_mul(rows)
343 .ok_or(TurboQuantCodecError::SizeOverflow)?;
344 Ok(Self {
345 bit_width,
346 dimensions,
347 rows,
348 bytes_per_row,
349 bytes: vec![0; byte_len],
350 })
351 }
352
353 pub fn from_bytes(
360 bit_width: TurboQuantBitWidth,
361 dimensions: usize,
362 rows: usize,
363 bytes: Vec<u8>,
364 ) -> TurboQuantCodecResult<Self> {
365 let bytes_per_row = bytes_per_row(bit_width, dimensions)?;
366 let expected = bytes_per_row
367 .checked_mul(rows)
368 .ok_or(TurboQuantCodecError::SizeOverflow)?;
369 if bytes.len() != expected {
370 return Err(TurboQuantCodecError::ByteLengthMismatch {
371 expected,
372 actual: bytes.len(),
373 });
374 }
375 Ok(Self {
376 bit_width,
377 dimensions,
378 rows,
379 bytes_per_row,
380 bytes,
381 })
382 }
383
384 #[must_use]
386 pub const fn bit_width(&self) -> TurboQuantBitWidth {
387 self.bit_width
388 }
389
390 #[must_use]
392 pub const fn dimensions(&self) -> usize {
393 self.dimensions
394 }
395
396 #[must_use]
398 pub const fn rows(&self) -> usize {
399 self.rows
400 }
401
402 #[must_use]
404 pub const fn bytes_per_row(&self) -> usize {
405 self.bytes_per_row
406 }
407
408 #[must_use]
410 pub fn as_bytes(&self) -> &[u8] {
411 &self.bytes
412 }
413
414 #[must_use]
416 pub fn into_bytes(self) -> Vec<u8> {
417 self.bytes
418 }
419
420 #[must_use]
422 pub fn estimated_bytes(&self) -> usize {
423 self.bytes.len()
424 }
425
426 pub fn resize_rows(&mut self, rows: usize) -> TurboQuantCodecResult<()> {
434 let byte_len = self
435 .bytes_per_row
436 .checked_mul(rows)
437 .ok_or(TurboQuantCodecError::SizeOverflow)?;
438 self.bytes.resize(byte_len, 0);
439 self.rows = rows;
440 Ok(())
441 }
442
443 pub fn read(&self, row: usize, dimension: usize) -> TurboQuantCodecResult<u8> {
450 let bit_offset = self.bit_offset(row, dimension)?;
451 let byte = bit_offset / u8::BITS as usize;
452 let shift = bit_offset % u8::BITS as usize;
453 let mut word = u16::from(self.bytes[byte]);
454 if byte + 1 < self.bytes.len() {
455 word |= u16::from(self.bytes[byte + 1]) << u8::BITS;
456 }
457 let mask = (1_u16 << self.bit_width.bits()) - 1;
458 Ok(((word >> shift) & mask) as u8)
459 }
460
461 pub fn write(&mut self, row: usize, dimension: usize, code: u8) -> TurboQuantCodecResult<()> {
469 self.validate_code(code)?;
470 let bit_offset = self.bit_offset(row, dimension)?;
471 let byte = bit_offset / u8::BITS as usize;
472 let shift = bit_offset % u8::BITS as usize;
473 let mask = ((1_u16 << self.bit_width.bits()) - 1) << shift;
474 let mut word = u16::from(self.bytes[byte]);
475 if byte + 1 < self.bytes.len() {
476 word |= u16::from(self.bytes[byte + 1]) << u8::BITS;
477 }
478 word = (word & !mask) | (u16::from(code) << shift);
479 self.bytes[byte] = (word & 0xff) as u8;
480 if shift + usize::from(self.bit_width.bits()) > u8::BITS as usize {
481 self.bytes[byte + 1] = (word >> u8::BITS) as u8;
482 }
483 Ok(())
484 }
485
486 fn validate_code(&self, code: u8) -> TurboQuantCodecResult<()> {
487 let max = self.bit_width.max_code();
488 if code <= max {
489 Ok(())
490 } else {
491 Err(TurboQuantCodecError::InvalidCode { code, max })
492 }
493 }
494
495 fn bit_offset(&self, row: usize, dimension: usize) -> TurboQuantCodecResult<usize> {
496 if row >= self.rows {
497 return Err(TurboQuantCodecError::RowOutOfBounds {
498 row,
499 rows: self.rows,
500 });
501 }
502 if dimension >= self.dimensions {
503 return Err(TurboQuantCodecError::DimensionOutOfBounds {
504 dimension,
505 dimensions: self.dimensions,
506 });
507 }
508 let row_bits = row
509 .checked_mul(self.bytes_per_row)
510 .and_then(|offset| offset.checked_mul(u8::BITS as usize))
511 .ok_or(TurboQuantCodecError::SizeOverflow)?;
512 let dimension_bits = dimension
513 .checked_mul(usize::from(self.bit_width.bits()))
514 .ok_or(TurboQuantCodecError::SizeOverflow)?;
515 row_bits
516 .checked_add(dimension_bits)
517 .ok_or(TurboQuantCodecError::SizeOverflow)
518 }
519}
520
521fn validate_dimension(dimension: usize) -> TurboQuantCodecResult<()> {
522 if dimension == 0 || dimension > MAX_VECTOR_DIMENSION {
523 Err(TurboQuantCodecError::InvalidDimension {
524 dimension,
525 max: MAX_VECTOR_DIMENSION,
526 })
527 } else {
528 Ok(())
529 }
530}
531
532fn bytes_per_row(bit_width: TurboQuantBitWidth, dimensions: usize) -> TurboQuantCodecResult<usize> {
533 validate_dimension(dimensions)?;
534 let bits_per_row = dimensions
535 .checked_mul(usize::from(bit_width.bits()))
536 .ok_or(TurboQuantCodecError::SizeOverflow)?;
537 bits_per_row
538 .checked_add(u8::BITS as usize - 1)
539 .map(|bits| bits / u8::BITS as usize)
540 .ok_or(TurboQuantCodecError::SizeOverflow)
541}
542
543fn clipped_uniform_centroids(bit_width: TurboQuantBitWidth, dimension: usize) -> Vec<f32> {
544 let levels = bit_width.levels();
545 let sigma = (dimension as f32).sqrt().recip();
546 let clip = 3.0 * sigma;
547 (0..levels)
548 .map(|code| {
549 let midpoint = (code as f32 + 0.5) / levels as f32;
550 midpoint.mul_add(2.0 * clip, -clip)
551 })
552 .collect()
553}
554
555fn normal_lloyd_max_centroids(bit_width: TurboQuantBitWidth, dimension: usize) -> Vec<f32> {
556 let levels = bit_width.levels();
557 let sigma = (dimension as f64).sqrt().recip();
558 let spread = 3.0 * sigma;
559 let mut centroids = (0..levels)
560 .map(|code| -spread + 2.0 * spread * code as f64 / (levels - 1) as f64)
561 .collect::<Vec<_>>();
562
563 for _ in 0..64 {
564 let boundaries = f64_centroid_boundaries(¢roids);
565 let mut max_change = 0.0f64;
566 for code in 0..levels {
567 let low = if code == 0 {
568 f64::NEG_INFINITY
569 } else {
570 boundaries[code - 1]
571 };
572 let high = if code + 1 == levels {
573 f64::INFINITY
574 } else {
575 boundaries[code]
576 };
577 let next = normal_interval_mean(low, high, sigma);
578 max_change = max_change.max((centroids[code] - next).abs());
579 centroids[code] = next;
580 }
581 if max_change < 1e-12 {
582 break;
583 }
584 }
585
586 centroids
587 .into_iter()
588 .map(|centroid| centroid as f32)
589 .collect()
590}
591
/// Returns the midpoint between each pair of consecutive centroids.
///
/// The result has one element fewer than `centroids` and is empty when
/// fewer than two centroids are given.
fn centroid_boundaries(centroids: &[f32]) -> Vec<f32> {
    let successors = centroids.iter().skip(1);
    centroids
        .iter()
        .zip(successors)
        .map(|(left, right)| (left + right) * 0.5)
        .collect()
}
598
/// `f64` variant of the boundary computation: the midpoint between each pair
/// of consecutive centroids, empty when fewer than two centroids are given.
fn f64_centroid_boundaries(centroids: &[f64]) -> Vec<f64> {
    let mut midpoints = Vec::with_capacity(centroids.len().saturating_sub(1));
    for pair in centroids.windows(2) {
        midpoints.push((pair[0] + pair[1]) * 0.5);
    }
    midpoints
}
605
606fn normal_interval_mean(low: f64, high: f64, sigma: f64) -> f64 {
607 let low_z = low / sigma;
608 let high_z = high / sigma;
609 let probability = standard_normal_cdf(high_z) - standard_normal_cdf(low_z);
610 if probability <= 1e-15 {
611 return (low + high) * 0.5;
612 }
613 sigma * (standard_normal_pdf(low_z) - standard_normal_pdf(high_z)) / probability
614}
615
/// Evaluates the standard normal density `exp(-x^2 / 2) / sqrt(2 * pi)`.
///
/// Infinite arguments short-circuit to `0.0`.
fn standard_normal_pdf(value: f64) -> f64 {
    const INV_SQRT_2_PI: f64 = 0.398_942_280_401_432_7;
    if !value.is_infinite() {
        let exponent = -0.5 * value * value;
        return INV_SQRT_2_PI * exponent.exp();
    }
    0.0
}
624
625fn standard_normal_cdf(value: f64) -> f64 {
626 if value == f64::NEG_INFINITY {
627 0.0
628 } else if value == f64::INFINITY {
629 1.0
630 } else {
631 0.5 * (1.0 + erf_approx(value / f64::sqrt(2.0)))
632 }
633}
634
/// Approximates the error function with a five-term polynomial in
/// `t = 1 / (1 + p * |x|)` multiplied by `exp(-x^2)`.
///
/// The coefficients match Abramowitz & Stegun formula 7.1.26, whose
/// documented absolute error is about `1.5e-7`. The result is computed for
/// `|x|` and negated for negative input.
fn erf_approx(value: f64) -> f64 {
    let x = value.abs();
    let t = 1.0 / (1.0 + 0.327_591_1 * x);

    // Horner evaluation, highest-order coefficient first.
    let mut horner = 1.061_405_429 * t - 1.453_152_027;
    horner = horner * t + 1.421_413_741;
    horner = horner * t - 0.284_496_736;
    horner = horner * t + 0.254_829_592;
    let tail = horner * t * (-x * x).exp();

    let magnitude = 1.0 - tail;
    if value < 0.0 {
        -magnitude
    } else {
        magnitude
    }
}
646
647#[cfg(test)]
648mod tests;